SanaSprint / app.py
multimodalart's picture
Update app.py
b9cc117 verified
raw
history blame
4.46 kB
import gradio as gr
import spaces
import numpy as np
import random
import spaces
import torch
from diffusers import SanaSprintPipeline
# Upper bounds shared by the seed randomizer and the UI sliders.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

# Both checkpoints are loaded in bfloat16 and placed on the GPU when one is
# visible, otherwise on the CPU.
dtype = torch.bfloat16
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# 0.6B-parameter checkpoint.
pipe = SanaSprintPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers",
    torch_dtype=dtype,
)
# 1.6B-parameter checkpoint.
pipe2 = SanaSprintPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
    torch_dtype=dtype,
)
pipe.to(device)
pipe2.to(device)
@spaces.GPU(duration=5)
def infer(
    prompt,
    model_size,
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    guidance_scale=3.5,
    num_inference_steps=2,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a single image with one of the two Sana Sprint pipelines.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        model_size: ``"0.6B"`` selects ``pipe``; any other value selects
            ``pipe2`` (the 1.6B checkpoint).
        seed: Seed for the random generator. Ignored when
            ``randomize_seed`` is true.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of sampling steps.
        progress: Not used in the body; declared so Gradio can attach its
            tqdm-tracking progress bar to this call.

    Returns:
        A ``(image, seed)`` tuple: the first generated PIL image and the
        integer seed that was actually used.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # torch.Generator.manual_seed only accepts ints, while numeric values
    # coming from UI widgets may arrive as floats (e.g. 42.0).
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)

    # Choose the appropriate model based on the selected model size.
    selected_pipe = pipe if model_size == "0.6B" else pipe2

    output = selected_pipe(
        prompt=prompt,
        guidance_scale=guidance_scale,
        # Step count and image dimensions must be whole numbers as well.
        num_inference_steps=int(num_inference_steps),
        width=int(width),
        height=int(height),
        generator=generator,
        output_type="pil",
    )
    return output.images[0], seed
# Prompt / model-size pairs offered as clickable examples in the UI; each
# entry is materialized as a two-element list.
examples = [
    list(pair)
    for pair in (
        ("a tiny astronaut hatching from an egg on the moon", "0.6B"),
        ("a cat holding a sign that says hello world", "1.6B"),
        ("an anime illustration of a wiener schnitzel", "0.6B"),
        ("a photorealistic landscape of mountains at sunset", "1.6B"),
    )
]

# Stylesheet handed to gr.Blocks: centers the main column and caps its width.
css = """
#col-container {
margin: 0 auto;
max-width: 520px;
}
"""
# Build the demo page: a single centered column holding the prompt row, the
# result image, the model selector, the advanced settings and the examples.
# NOTE(review): the nesting below is the conventional layout for this kind of
# demo; confirm that gr.Examples belongs inside the column.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Sana Sprint")
        gr.Markdown("Demo for the real-time [Sana Sprint](https://huggingface.co/collections/Efficient-Large-Model/sana-sprint-67d6810d65235085b3b17c76) model")

        # Prompt box and run button share one row.
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)

        result = gr.Image(label="Result", show_label=False)

        model_size = gr.Radio(
            label="Model Size",
            choices=["0.6B", "1.6B"],
            value="1.6B",
            interactive=True,
        )

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=15,
                    step=0.1,
                    value=1,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=2,
                )

        # Examples pass only the prompt and the model size, so infer falls
        # back to its own defaults for every other parameter.
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt, model_size],
            outputs=[result, seed],
            cache_examples="lazy",
        )

    # Clicking "Run" and submitting the prompt box trigger the same call.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            model_size,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, seed],
    )

demo.launch()