LPX55 committed on
Commit
30ad131
·
verified ·
1 Parent(s): b16d959

Update optimized.py

Browse files
Files changed (1) hide show
  1. optimized.py +21 -4
optimized.py CHANGED
@@ -1,10 +1,11 @@
1
  import torch
2
  import spaces
3
  import os
 
4
  from diffusers.utils import load_image
5
  from diffusers import FluxControlNetModel, FluxControlNetPipeline, AutoencoderKL
6
  import gradio as gr
7
- from accelerate import init_empty_weights
8
 
9
  def self_attention_slicing(module, slice_size=3):
10
  """Modified from Diffusers' original for Flux compatibility"""
@@ -59,11 +60,27 @@ pipe = FluxControlNetPipeline.from_pretrained(
59
  )
60
  print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f}GB")
61
  # Proper CPU offloading sequence
62
- pipe.enable_sequential_cpu_offload(
63
- device=torch.device("cuda:0"),
64
- execution_device="cuda"
 
 
 
 
 
 
 
 
 
65
  )
66
 
 
 
 
 
 
 
 
67
  # # 2. Then apply custom VAE slicing
68
  # if getattr(pipe, "vae", None) is not None:
69
  # # Method 1: Use official implementation if available
 
1
  import torch
2
  import spaces
3
  import os
4
+ import diffusers
5
  from diffusers.utils import load_image
6
  from diffusers import FluxControlNetModel, FluxControlNetPipeline, AutoencoderKL
7
  import gradio as gr
8
+ from accelerate import dispatch_model, infer_auto_device_map
9
 
10
  def self_attention_slicing(module, slice_size=3):
11
  """Modified from Diffusers' original for Flux compatibility"""
 
60
  )
61
print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f}GB")

# --- CPU offloading sequence -------------------------------------------------
# NOTE(review): `infer_auto_device_map` / `dispatch_model` operate on a single
# `torch.nn.Module`; passing the whole `DiffusionPipeline` raises, and
# `device_types` is not an accepted keyword of `infer_auto_device_map`.
# If manual dispatch is ever needed, apply it per component, e.g.:
#   device_map = infer_auto_device_map(
#       pipe.transformer, max_memory={0: "38GB", "cpu": "64GB"})
#   pipe.transformer = dispatch_model(pipe.transformer, device_map=device_map)
# Mixing manual dispatch with sequential offload is also contradictory, so we
# rely solely on Diffusers' built-in sequential offload below, which moves each
# submodule to the GPU only while it executes.

# Cast weights BEFORE installing the offload hooks: calling `.to()` after
# `enable_sequential_cpu_offload()` fights the accelerate hooks that manage
# device placement. Flux pipelines expose the denoiser as `transformer`
# (there is no `pipe.unet` on FluxControlNetPipeline).
pipe.transformer.to(dtype=torch.bfloat16)
pipe.controlnet.to(dtype=torch.bfloat16)
pipe.vae.to(dtype=torch.bfloat16)

# Diffusers v0.20+: takes no required parameters.
pipe.enable_sequential_cpu_offload()
84
  # # 2. Then apply custom VAE slicing
85
  # if getattr(pipe, "vae", None) is not None:
86
  # # Method 1: Use official implementation if available