LPX55 committed
Commit 2310622 · verified · 1 Parent(s): 4af365d

Update optimized.py

Files changed (1)
  1. optimized.py +37 -2
optimized.py CHANGED
@@ -6,6 +6,31 @@ from diffusers import FluxControlNetModel, FluxControlNetPipeline, AutoencoderKL
 import gradio as gr
 from accelerate import init_empty_weights
 
+def self_attention_slicing(module, slice_size=3):
+    """Modified from Diffusers' original for Flux compatibility:
+    run `module` chunk-by-chunk along `dim` and concatenate the outputs."""
+    def sliced_attention(*args, **kwargs):
+        dim = kwargs.pop("dim", 1)  # consume `dim`; don't forward it to `module`
+
+        if slice_size == "auto":
+            # Automatic slicing based on Flux architecture: fall through unsliced
+            return module(*args, **kwargs)
+
+        def take(x, i):
+            # Slice tensor inputs along `dim`; pass non-tensor arguments through
+            if torch.is_tensor(x):
+                return x.narrow(dim, i, min(slice_size, x.shape[dim] - i))
+            return x
+
+        output = torch.cat([
+            module(*[take(a, i) for a in args],
+                   **{k: take(v, i) for k, v in kwargs.items()})
+            for i in range(0, args[0].shape[dim], slice_size)
+        ], dim=dim)
+
+        return output
+    return sliced_attention
+
 huggingface_token = os.getenv("HUGGINFACE_TOKEN")
 
 good_vae = AutoencoderKL.from_pretrained(
@@ -35,8 +60,18 @@ pipe = FluxControlNetPipeline.from_pretrained(
 print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f}GB")
 # Proper CPU offloading sequence
 pipe.enable_model_cpu_offload(device="cuda")  # First enable offloading
-pipe.enable_vae_slicing()  # Then enable memory optimizations
-pipe.enable_attention_slicing(1)
+
+# 2. Then apply custom VAE slicing
+if getattr(pipe, "vae", None) is not None:
+    # Method 1: Use official implementation if available
+    try:
+        pipe.vae.enable_slicing()
+    except AttributeError:
+        # Method 2: Apply manual slicing for Flux compatibility (see pipeline_flux_controlnet.py)
+        pipe.vae.decode = self_attention_slicing(pipe.vae.decode, 2)
+
+# 3. Attention optimizations
+pipe.enable_attention_slicing(1)  # Mandatory for Flux
 
 # Handle xformers/SDP attention after offloading
 try:
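
A quick way to sanity-check the slicing wrapper outside the pipeline is to wrap a stand-in decoder and compare sliced against unsliced output. This is a minimal sketch, not part of the commit: `fake_decode`, the tensor shapes, and the import path are illustrative. Diffusers' own `vae.enable_slicing()` splits decoding along the batch dimension, which is what `dim=0` mimics here.

import torch
from optimized import self_attention_slicing  # or paste the definition into a scratch file

def fake_decode(latents):
    # Stand-in for vae.decode: any per-sample operation works for this check
    return latents * 2.0

latents = torch.randn(4, 16, 32, 32)               # (batch, channels, height, width)
sliced_decode = self_attention_slicing(fake_decode, slice_size=2)

reference = fake_decode(latents)                   # one full-batch call
chunked = sliced_decode(latents, dim=0)            # two calls of batch size 2

assert torch.allclose(reference, chunked)
print("sliced decode matches full decode")

Because each inner call sees only a `slice_size`-sized chunk, peak activation memory drops roughly in proportion, at the cost of extra kernel launches.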