FluxM-Lightning-Upscaler / model_loader.py
LPX · Remove sequential CPU offload from model loading and simplify return value in safe_model_load function · f725d52
# model_loader.py
import os
import torch
import spaces
from diffusers import FluxControlNetPipeline
from transformers import T5EncoderModel
from moondream import vl

@spaces.GPU()
def safe_model_load():
    """Load models in a single GPU invocation to keep them warm."""
    try:
        # Set max memory usage for ZeroGPU
        torch.cuda.set_per_process_memory_fraction(1.0)
        torch.set_float32_matmul_precision("high")

        # Load models
        huggingface_token = os.getenv("HUGGINFACE_TOKEN")
        md_api_key = os.getenv("MD_KEY")  # Moondream API key; not used in this function
        text_encoder = T5EncoderModel.from_pretrained(
            "LPX55/FLUX.1-merged_uncensored",
            subfolder="text_encoder_2",
            torch_dtype=torch.bfloat16,
            token=huggingface_token
        )
        pipe = FluxControlNetPipeline.from_pretrained(
            "LPX55/FLUX.1M-8step_upscaler-cnet",
            torch_dtype=torch.bfloat16,
            text_encoder_2=text_encoder,
            token=huggingface_token
        )

        # Apply memory optimizations
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as e:
            print(f"XFormers not available: {e}")
            # Fall back to attention slicing when xformers is unavailable
            pipe.enable_attention_slicing()
        # pipe.enable_sequential_cpu_offload()
        pipe.to("cuda")

        # For memory-sensitive environments
        try:
            torch.multiprocessing.set_sharing_strategy('file_system')
        except Exception as e:
            print(f"Exception raised (torch.multiprocessing): {e}")

        return pipe
    except Exception as e:
        print(f"Model loading failed: {e}")
        # Return placeholder to handle gracefully in UI
        return {"error": str(e)}