File size: 1,793 Bytes
c80eda9
 
 
ec75dde
c80eda9
 
 
 
ec75dde
c80eda9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f725d52
c80eda9
 
 
 
 
 
 
 
f725d52
c80eda9
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# model_loader.py
import os
import torch
import spaces
from diffusers import FluxControlNetPipeline
from transformers import T5EncoderModel
from moondream import vl

@spaces.GPU()
def safe_model_load():
    """Load the FLUX ControlNet pipeline inside a single ``@spaces.GPU`` call.

    Loads the T5 text encoder (``text_encoder_2`` subfolder of
    ``LPX55/FLUX.1-merged_uncensored``) and plugs it into the
    ``LPX55/FLUX.1M-8step_upscaler-cnet`` ``FluxControlNetPipeline``. Both are
    loaded in bfloat16. The function then applies attention memory
    optimizations and moves the pipeline to CUDA.

    The Hugging Face access token is read from ``HUGGINGFACE_TOKEN``. The
    historical misspelled variable ``HUGGINFACE_TOKEN`` is still honored as a
    fallback so existing deployments keep working. If neither is set, the
    token is ``None`` (anonymous access).

    Returns:
        The loaded ``FluxControlNetPipeline`` on success. If any loading step
        raises, returns a dict ``{"error": <message>}`` instead; callers are
        expected to check for this dict rather than catch an exception.
    """
    import traceback

    try:
        # Let this process allocate up to the full memory of the current CUDA device.
        torch.cuda.set_per_process_memory_fraction(1.0)
        # Allow faster reduced-precision (e.g. TF32) float32 matmuls.
        torch.set_float32_matmul_precision("high")

        # Correct spelling first; fall back to the legacy misspelled name.
        huggingface_token = (
            os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HUGGINFACE_TOKEN")
        )

        text_encoder = T5EncoderModel.from_pretrained(
            "LPX55/FLUX.1-merged_uncensored",
            subfolder="text_encoder_2",
            torch_dtype=torch.bfloat16,
            token=huggingface_token
        )

        # Reuse the separately loaded T5 encoder as the pipeline's second text encoder.
        pipe = FluxControlNetPipeline.from_pretrained(
            "LPX55/FLUX.1M-8step_upscaler-cnet",
            torch_dtype=torch.bfloat16,
            text_encoder_2=text_encoder,
            token=huggingface_token
        )

        # Prefer xFormers memory-efficient attention. Fall back to attention
        # slicing only when xFormers is unavailable: diffusers' memory guide
        # advises against enabling attention slicing on top of xFormers/SDPA.
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as e:
            print(f"XFormers not available: {e}")
            pipe.enable_attention_slicing()

        pipe.to("cuda")

        # Process-wide setting: share tensors between worker processes via the
        # file system instead of file descriptors. Best-effort; failures are
        # only reported.
        try:
            torch.multiprocessing.set_sharing_strategy('file_system')
        except Exception as e:
            print(f"Exception raised (torch.multiprocessing): {e}")

        return pipe

    except Exception as e:
        print(f"Model loading failed: {e}")
        # Keep the full stack trace in the logs; the UI only gets the message.
        traceback.print_exc()
        # Return placeholder to handle gracefully in UI
        return {"error": str(e)}