Spaces:
Running
on
Zero
Running
on
Zero
Update optimized.py
Browse files- optimized.py +25 -2
optimized.py
CHANGED
@@ -5,11 +5,18 @@ import diffusers
|
|
5 |
import PIL
|
6 |
from diffusers.utils import load_image
|
7 |
from diffusers import FluxControlNetModel, FluxControlNetPipeline, AutoencoderKL
|
|
|
|
|
8 |
import gradio as gr
|
9 |
from accelerate import dispatch_model, infer_auto_device_map
|
10 |
from PIL import Image
|
|
|
|
|
11 |
import gc
|
12 |
# Corrected and optimized FluxControlNet implementation
|
|
|
|
|
|
|
13 |
|
14 |
def self_attention_slicing(module, slice_size=3):
|
15 |
"""Modified from Diffusers' original for Flux compatibility"""
|
@@ -35,9 +42,23 @@ def self_attention_slicing(module, slice_size=3):
|
|
35 |
|
36 |
return output
|
37 |
return sliced_attention
|
38 |
-
device = "cuda"
|
39 |
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
good_vae = AutoencoderKL.from_pretrained(
|
42 |
"black-forest-labs/FLUX.1-dev",
|
43 |
subfolder="vae",
|
@@ -54,6 +75,8 @@ pipe = FluxControlNetPipeline.from_pretrained(
|
|
54 |
torch_dtype=torch.bfloat16
|
55 |
),
|
56 |
vae=good_vae, # Now defined in scope
|
|
|
|
|
57 |
torch_dtype=torch.bfloat16,
|
58 |
use_safetensors=True,
|
59 |
device_map=None,
|
|
|
5 |
import PIL
|
6 |
from diffusers.utils import load_image
|
7 |
from diffusers import FluxControlNetModel, FluxControlNetPipeline, AutoencoderKL
|
8 |
+
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
|
9 |
+
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
|
10 |
import gradio as gr
|
11 |
from accelerate import dispatch_model, infer_auto_device_map
|
12 |
from PIL import Image
|
13 |
+
from diffusers import FluxTransformer2DModel
|
14 |
+
from transformers import T5EncoderModel
|
15 |
import gc
|
16 |
# Corrected and optimized FluxControlNet implementation
|
17 |
+
huggingface_token = os.getenv("HUGGINFACE_TOKEN")
|
18 |
+
device = "cuda"
|
19 |
+
torch_dtype = torch.bfloat16
|
20 |
|
21 |
def self_attention_slicing(module, slice_size=3):
|
22 |
"""Modified from Diffusers' original for Flux compatibility"""
|
|
|
42 |
|
43 |
return output
|
44 |
return sliced_attention
|
|
|
45 |
|
46 |
+
quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True,)
|
47 |
+
text_encoder_2_8bit = T5EncoderModel.from_pretrained(
|
48 |
+
"LPX55/FLUX.1-merged_uncensored",
|
49 |
+
subfolder="text_encoder_2",
|
50 |
+
quantization_config=quant_config,
|
51 |
+
torch_dtype=torch.bfloat16,
|
52 |
+
token=huggingface_token
|
53 |
+
)
|
54 |
+
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True,)
|
55 |
+
transformer_8bit = FluxTransformer2DModel.from_pretrained(
|
56 |
+
"LPX55/FLUX.1-merged_uncensored",
|
57 |
+
subfolder="transformer",
|
58 |
+
quantization_config=quant_config,
|
59 |
+
torch_dtype=torch.bfloat16,
|
60 |
+
token=huggingface_token
|
61 |
+
)
|
62 |
good_vae = AutoencoderKL.from_pretrained(
|
63 |
"black-forest-labs/FLUX.1-dev",
|
64 |
subfolder="vae",
|
|
|
75 |
torch_dtype=torch.bfloat16
|
76 |
),
|
77 |
vae=good_vae, # Now defined in scope
|
78 |
+
transformer=transformer_8bit,
|
79 |
+
text_encoder_2=text_encoder_2_8bit,
|
80 |
torch_dtype=torch.bfloat16,
|
81 |
use_safetensors=True,
|
82 |
device_map=None,
|