FluxM-Lightning-Upscaler / optimized.py
LPX55's picture
Update optimized.py
a102a01 verified
raw
history blame
4.8 kB
import torch
import spaces
import os
import diffusers
from diffusers.utils import load_image
from diffusers import FluxControlNetModel, FluxControlNetPipeline, AutoencoderKL
import gradio as gr
from accelerate import dispatch_model, infer_auto_device_map
# Corrected and optimized FluxControlNet implementation
def self_attention_slicing(module, slice_size=3):
    """Wrap ``module`` so it is evaluated on slices of its tensor inputs.

    The returned callable splits the inputs into chunks of ``slice_size``
    along one axis, calls ``module`` once per chunk and concatenates the
    per-chunk results along the same axis.

    Args:
        module: Callable to wrap. When slicing is active it must return a
            single tensor, because the per-slice results are joined with
            ``torch.cat``.
        slice_size: Number of elements per slice along the slicing axis, or
            the string ``"auto"`` to disable slicing and call ``module``
            directly.

    Returns:
        A function ``sliced_attention(*args, **kwargs)``. The slicing axis is
        read from an optional ``dim`` keyword (default ``1``); that keyword is
        still forwarded to ``module`` unchanged together with all others.
    """
    def sliced_attention(*args, **kwargs):
        dim = kwargs.get("dim", 1)

        # Nothing to slice: slicing disabled, or no leading tensor to drive it.
        if slice_size == "auto" or not args or not isinstance(args[0], torch.Tensor):
            return module(*args, **kwargs)

        reference = args[0]
        total = reference.shape[dim]
        if total <= slice_size:
            # A single slice would cover the whole input.
            return module(*args, **kwargs)

        def take(value, start):
            # Only tensors laid out like the reference input are sliced;
            # scalars, flags and differently shaped tensors pass through.
            if not isinstance(value, torch.Tensor):
                return value
            if value.dim() != reference.dim() or value.shape[dim] != total:
                return value
            return value.narrow(dim, start, min(slice_size, total - start))

        pieces = [
            module(
                *(take(arg, start) for arg in args),
                **{key: take(value, start) for key, value in kwargs.items()},
            )
            for start in range(0, total, slice_size)
        ]
        return torch.cat(pieces, dim=dim)

    return sliced_attention
# 1. Access token and VAE
# The token is looked up under the historical (misspelled) variable name
# first so existing deployments keep working, then under the correctly
# spelled name and the short HF_TOKEN name.
huggingface_token = (
    os.getenv("HUGGINFACE_TOKEN")
    or os.getenv("HUGGINGFACE_TOKEN")
    or os.getenv("HF_TOKEN")
)

# The VAE is loaded on its own so it can be handed to the pipeline below.
good_vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="vae",
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    device_map=None,  # Disable automatic mapping
    token=huggingface_token,
)

# 2. Main pipeline: merged Flux checkpoint + upscaler ControlNet + the VAE
#    loaded above, all in bfloat16.
pipe = FluxControlNetPipeline.from_pretrained(
    "LPX55/FLUX.1-merged_uncensored",
    controlnet=FluxControlNetModel.from_pretrained(
        "jasperai/Flux.1-dev-Controlnet-Upscaler",
        torch_dtype=torch.bfloat16,
    ),
    vae=good_vae,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    device_map=None,
    token=huggingface_token,
)
# 3. Memory optimizations -- the order of these calls is deliberate.
# A. CPU offloading is enabled before anything else.
pipe.enable_sequential_cpu_offload()  # Called without arguments

# B. VAE slicing: use the VAE's own enable_slicing() when it exists and fall
#    back to wrapping decode() with the manual slicer when it does not.
if getattr(pipe, "vae", None) is not None:
    try:
        pipe.vae.enable_slicing()
    except AttributeError:
        pipe.vae.decode = self_attention_slicing(pipe.vae.decode, 2)

# C. Attention slicing with a slice size of 1.
pipe.enable_attention_slicing(1)

# Left disabled on purpose:
# pipe.enable_vae_tiling()
# pipe.enable_xformers_memory_efficient_attention()
# for comp in [pipe.unet, pipe.vae, pipe.controlnet]:
#     comp.to(dtype=torch.bfloat16)

# Report the CUDA memory allocated once setup has finished.
print("VRAM used: {:.2f}GB".format(torch.cuda.memory_allocated() / 1e9))
def _fit_target_size(width, height, scale, max_dim=1536, multiple=16):
    """Return the (width, height) to generate at for a given upscale factor.

    Both sides are scaled by the same factor so the aspect ratio is kept. The
    factor is reduced when needed so that neither side exceeds ``max_dim``,
    and each side is rounded down to a multiple of ``multiple`` (but never
    below ``multiple`` itself).
    """
    factor = min(scale, max_dim / width, max_dim / height)

    def snap(side):
        scaled = min(int(round(side * factor)), max_dim)
        return max(multiple, scaled // multiple * multiple)

    return snap(width), snap(height)


@spaces.GPU
def generate_image(prompt, scale, steps, control_image, controlnet_conditioning_scale, guidance_scale):
    """Upscale ``control_image`` with the Flux ControlNet pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        scale: Requested upscale factor; capped at 2.0.
        steps: Number of inference steps; converted to ``int``.
        control_image: PIL image used as the ControlNet condition.
        controlnet_conditioning_scale: ControlNet conditioning strength.
        guidance_scale: Guidance scale passed to the pipeline.

    Returns:
        The first generated image as a PIL image.

    Raises:
        gr.Error: If no control image was provided.
    """
    # Imported locally: the module-level imports never bring PIL into scope,
    # so the bare ``PIL.Image`` reference below would raise NameError.
    import PIL.Image

    if control_image is None:
        raise gr.Error("Please upload a control image.")

    w, h = control_image.size
    scale = min(scale, 2.0)  # Cap scale factor

    # Size calculation with safety limits (1536 is chosen for available VRAM).
    # NOTE(review): rounding to a multiple of 16 is meant to match the size
    # constraints of the Flux pipelines -- confirm against the installed
    # diffusers version.
    target_w, target_h = _fit_target_size(w, h, scale, max_dim=1536, multiple=16)
    control_image = control_image.resize((target_w, target_h), PIL.Image.BICUBIC)

    # The models are loaded as bfloat16; autocast on CUDA defaults to float16,
    # so the dtype is pinned to keep one precision throughout.
    with torch.autocast("cuda", dtype=torch.bfloat16):
        image = pipe(
            prompt=prompt,
            control_image=control_image,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            num_inference_steps=int(steps),  # sliders may deliver floats
            guidance_scale=guidance_scale,
            height=target_h,
            width=target_w,
            output_type="pil",
            # Fixed seed: the same inputs are generated with the same seed.
            generator=torch.Generator(device="cuda").manual_seed(0),
        ).images[0]
    print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f}GB")

    # Aggressive memory cleanup
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f}GB")
    return image
# Create the Gradio interface. The input order must match the positional
# parameters of generate_image:
# prompt, scale, steps, control_image, controlnet_conditioning_scale,
# guidance_scale.
iface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
        # generate_image caps the factor at 2.0, so the slider stops there
        # instead of offering values that would be silently reduced.
        gr.Slider(1, 2, value=1, label="Scale"),
        # step=1 keeps the step count a whole number.
        gr.Slider(6, 30, value=8, step=1, label="Steps"),
        gr.Image(type="pil", label="Control Image"),
        gr.Slider(0, 1, value=0.6, label="ControlNet Scale"),
        gr.Slider(1, 20, value=3.5, label="Guidance Scale"),
    ],
    outputs=[
        gr.Image(type="pil", label="Generated Image", format="png"),
    ],
    title="FLUX ControlNet Image Generation",
    description="Generate images using the FluxControlNetPipeline. Upload a control image and enter a prompt to create an image.",
)

# Launch the app
iface.launch()