LPX55 committed
Commit 4af365d · verified · Parent(s): b1c8464

Update optimized.py

Files changed (1): optimized.py (+53, -40)
optimized.py CHANGED
@@ -8,70 +8,83 @@ from accelerate import init_empty_weights
 
 huggingface_token = os.getenv("HUGGINFACE_TOKEN")
 
-
-good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae",
-    torch_dtype=torch.bfloat16,
-    # variant="4bit",
-    device_map="balanced",
-    use_safetensors=True,
-    token=huggingface_token).to("cuda")
-
-# Load pipeline
+good_vae = AutoencoderKL.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    subfolder="vae",
+    torch_dtype=torch.bfloat16,
+    use_safetensors=True,
+    device_map=None,  # Disable automatic mapping
+    token=huggingface_token
+)
+
 controlnet = FluxControlNetModel.from_pretrained(
     "jasperai/Flux.1-dev-Controlnet-Upscaler",
     torch_dtype=torch.bfloat16
 )
-#with init_empty_weights():
+
+# Initialize pipeline without automatic device mapping
 pipe = FluxControlNetPipeline.from_pretrained(
     "LPX55/FLUX.1-merged_uncensored",
     controlnet=controlnet,
-    torch_dtype=torch.bfloat16,
-    device_map="balanced",
     vae=good_vae,
-    use_safetensors=True,
+    torch_dtype=torch.bfloat16,
+    use_safetensors=True,
+    device_map=None,  # Disable automatic device mapping
     token=huggingface_token
 )
-pipe.enable_model_cpu_offload(device="cuda")
-# Add to your pipeline initialization:
-# pipe.enable_xformers_memory_efficient_attention()
-# pipe.enable_vae_slicing()  # Batch processing of VAE
-# pipe.enable_model_cpu_offload()  # Use with accelerate
+print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f}GB")
+# Proper CPU offloading sequence
+pipe.enable_model_cpu_offload(device="cuda")  # First enable offloading
+pipe.enable_vae_slicing()  # Then enable memory optimizations
+pipe.enable_attention_slicing(1)
+
+# Handle xformers/SDP attention after offloading
 try:
     import xformers
     pipe.enable_xformers_memory_efficient_attention()
 except ImportError:
     print("XFormers missing! Using PyTorch attention instead")
-    # Fallback to PyTorch 2.0+ memory efficient attention
     pipe.enable_sdp_attention()
     torch.backends.cuda.enable_flash_sdp(True)
-# Convert all models to memory-efficient format
-#pipe.to(memory_format=torch.channels_last)
-pipe.to("cuda")
 
+# Memory format optimization (only after other configs)
+pipe.to(memory_format=torch.channels_last)
+print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f}GB")
 @spaces.GPU
 def generate_image(prompt, scale, steps, control_image, controlnet_conditioning_scale, guidance_scale):
-    # Load control image
-    control_image = control_image.resize((int(w * scale), int(h * scale)), PIL.Image.BICUBIC)
-    # control_image = load_image(control_image)
+    # Clean up input handling
     w, h = control_image.size
-    # Upscale x1
-    control_image = control_image.resize((int(w * scale), int(h * scale)))
-    print("Size to: " + str(control_image.size[0]) + ", " + str(control_image.size[1]))
-    image = pipe(
-        prompt=prompt,
-        control_image=control_image,
-        controlnet_conditioning_scale=controlnet_conditioning_scale,
-        num_inference_steps=steps,
-        guidance_scale=guidance_scale,
-        height=control_image.size[1],
-        width=control_image.size[0],
-        torch_dtype=torch.bfloat16,
-        device_map="balanced"
-    ).images[0]
+    scale = min(scale, 2.0)  # Cap scale factor
+
+    # Size calculation with safety limits
+    max_dim = 1536  # Set based on your VRAM
+    target_w = min(int(w * scale), max_dim)
+    target_h = min(int(h * scale), max_dim)
+
+    control_image = control_image.resize(
+        (target_w, target_h),
+        PIL.Image.BICUBIC
+    )
+
+    # Generation with memory-friendly parameters
+    with torch.autocast("cuda"):  # Mixed precision
+        image = pipe(
+            prompt=prompt,
+            control_image=control_image,
+            controlnet_conditioning_scale=controlnet_conditioning_scale,
+            num_inference_steps=steps,
+            guidance_scale=guidance_scale,
+            height=target_h,
+            width=target_w,
+            output_type="pil",  # Avoid extra latent decoding steps
+            generator=torch.Generator(device="cuda").manual_seed(0)
+        ).images[0]
+    print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f}GB")
+    # Aggressive memory cleanup
    torch.cuda.empty_cache()
+    torch.cuda.ipc_collect()
+    print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f}GB")
     return image
-
 # Create Gradio interface
 iface = gr.Interface(
     fn=generate_image,
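
Note on the retained fallback path: pipe.enable_sdp_attention() survives this commit as a context line, but it does not appear to be a standard diffusers pipeline method. A minimal sketch of the usual PyTorch 2.x SDPA fallback, using only public torch/diffusers calls; the helper name enable_pytorch_sdp_fallback is hypothetical, not from this repo:

import torch

def enable_pytorch_sdp_fallback(pipe):
    # Best-effort switch to PyTorch scaled-dot-product attention (assumes torch >= 2.0).
    if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
        raise RuntimeError("PyTorch 2.0+ is required for the SDPA fallback")
    torch.backends.cuda.enable_flash_sdp(True)  # prefer the flash kernel where supported
    # Restore the library's default attention processors, which dispatch to SDPA on torch >= 2.0.
    for name in ("transformer", "unet"):
        component = getattr(pipe, name, None)
        if component is not None and hasattr(component, "set_default_attn_processor"):
            component.set_default_attn_processor()

Calling this helper in the except ImportError branch would keep the same intent without relying on a nonstandard pipeline method.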
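The new generate_image caps both the scale factor and each output dimension before resizing. A standalone sketch of that sizing logic, assuming Pillow; cap_and_resize is an illustrative name, not from the repo:

from PIL import Image

def cap_and_resize(img, scale, max_dim=1536):
    # Clamp the upscale factor to 2.0 and each target dimension to max_dim,
    # then resize with the same BICUBIC filter the pipeline uses.
    scale = min(scale, 2.0)
    w, h = img.size
    target = (min(int(w * scale), max_dim), min(int(h * scale), max_dim))
    return img.resize(target, Image.BICUBIC)

# Example: a 1024x768 input at scale 2.0 is capped to (1536, 1536).
print(cap_and_resize(Image.new("RGB", (1024, 768)), 2.0).size)

As in the committed code, each dimension is capped independently, so very large inputs can end up with a different aspect ratio than the source.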