LPX55 committed on
Commit
93afc0b
·
verified ·
1 Parent(s): 2b6a07b

Update optimized.py

Browse files
Files changed (1) hide show
  1. optimized.py +25 -2
optimized.py CHANGED
@@ -5,11 +5,18 @@ import diffusers
5
  import PIL
6
  from diffusers.utils import load_image
7
  from diffusers import FluxControlNetModel, FluxControlNetPipeline, AutoencoderKL
 
 
8
  import gradio as gr
9
  from accelerate import dispatch_model, infer_auto_device_map
10
  from PIL import Image
 
 
11
  import gc
12
  # Corrected and optimized FluxControlNet implementation
 
 
 
13
 
14
  def self_attention_slicing(module, slice_size=3):
15
  """Modified from Diffusers' original for Flux compatibility"""
@@ -35,9 +42,23 @@ def self_attention_slicing(module, slice_size=3):
35
 
36
  return output
37
  return sliced_attention
38
- device = "cuda"
39
 
40
- huggingface_token = os.getenv("HUGGINFACE_TOKEN")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  good_vae = AutoencoderKL.from_pretrained(
42
  "black-forest-labs/FLUX.1-dev",
43
  subfolder="vae",
@@ -54,6 +75,8 @@ pipe = FluxControlNetPipeline.from_pretrained(
54
  torch_dtype=torch.bfloat16
55
  ),
56
  vae=good_vae, # Now defined in scope
 
 
57
  torch_dtype=torch.bfloat16,
58
  use_safetensors=True,
59
  device_map=None,
 
5
  import PIL
6
  from diffusers.utils import load_image
7
  from diffusers import FluxControlNetModel, FluxControlNetPipeline, AutoencoderKL
8
+ from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
9
+ from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
10
  import gradio as gr
11
  from accelerate import dispatch_model, infer_auto_device_map
12
  from PIL import Image
13
+ from diffusers import FluxTransformer2DModel
14
+ from transformers import T5EncoderModel
15
  import gc
16
  # Corrected and optimized FluxControlNet implementation
17
+ huggingface_token = os.getenv("HUGGINFACE_TOKEN")
18
+ device = "cuda"
19
+ torch_dtype = torch.bfloat16
20
 
21
  def self_attention_slicing(module, slice_size=3):
22
  """Modified from Diffusers' original for Flux compatibility"""
 
42
 
43
  return output
44
  return sliced_attention
 
45
 
46
+ quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True,)
47
+ text_encoder_2_8bit = T5EncoderModel.from_pretrained(
48
+ "LPX55/FLUX.1-merged_uncensored",
49
+ subfolder="text_encoder_2",
50
+ quantization_config=quant_config,
51
+ torch_dtype=torch.bfloat16,
52
+ token=huggingface_token
53
+ )
54
+ quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True,)
55
+ transformer_8bit = FluxTransformer2DModel.from_pretrained(
56
+ "LPX55/FLUX.1-merged_uncensored",
57
+ subfolder="transformer",
58
+ quantization_config=quant_config,
59
+ torch_dtype=torch.bfloat16,
60
+ token=huggingface_token
61
+ )
62
  good_vae = AutoencoderKL.from_pretrained(
63
  "black-forest-labs/FLUX.1-dev",
64
  subfolder="vae",
 
75
  torch_dtype=torch.bfloat16
76
  ),
77
  vae=good_vae, # Now defined in scope
78
+ transformer=transformer_8bit,
79
+ text_encoder_2=text_encoder_2_8bit,
80
  torch_dtype=torch.bfloat16,
81
  use_safetensors=True,
82
  device_map=None,