import gradio as gr
import spaces
import numpy as np
import random
import torch
from diffusers import SanaSprintPipeline

dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

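# Load both Sana Sprint checkpoints so the demo can switch between the 0.6B and 1.6B models at request time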
pipe = SanaSprintPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers",
    torch_dtype=dtype
)
pipe2 = SanaSprintPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
    torch_dtype=dtype
)
pipe.to(device)
pipe2.to(device)
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

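# Generation entry point: @spaces.GPU requests a ZeroGPU slot for each call;
# Sana Sprint needs only a handful of denoising steps, so a short duration suffices.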
@spaces.GPU(duration=5)
def infer(prompt, model_size, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=4.5, num_inference_steps=2, progress=gr.Progress(track_tqdm=True)):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)

    # Choose the pipeline that matches the selected model size
    selected_pipe = pipe if model_size == "0.6B" else pipe2

    img = selected_pipe(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
        output_type="pil"
    )
    return img.images[0], seed
    
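# Example prompts shown below the UI, each paired with the model size used to render it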
examples = [
    ["a tiny astronaut hatching from an egg on the moon", "1.6B"],
    ["๐Ÿถ Wearing ๐Ÿ•ถ flying on the ๐ŸŒˆ", "1.6B"],
    ["an anime illustration of a wiener schnitzel", "0.6B"],
    ["a photorealistic landscape of mountains at sunset", "0.6B"],
]

css="""
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

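# Build the Gradio UI: prompt box, result image, model-size selector, and advanced settings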
with gr.Blocks(css=css) as demo:
    
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Sana Sprint")
        gr.Markdown("Demo for the real-time [Sana Sprint](https://huggingface.co/collections/Efficient-Large-Model/sana-sprint-67d6810d65235085b3b17c76) model")
        with gr.Row():
            
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            
            run_button = gr.Button("Run", scale=0)
        
        result = gr.Image(label="Result", show_label=False)
        
        model_size = gr.Radio(
            label="Model Size",
            choices=["0.6B", "1.6B"],
            value="1.6B",
            interactive=True
        )
        
        with gr.Accordion("Advanced Settings", open=False):
            
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            
            with gr.Row():
                
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            
            with gr.Row():

                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=15,
                    step=0.1,
                    value=4.5,
                )
  
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=2,
                )
        
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt, model_size],
            outputs=[result, seed],
            cache_examples="lazy"
        )

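    # Run inference when the Run button is clicked or Enter is pressed in the prompt box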
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, model_size, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs=[result, seed]
    )

demo.launch()