import gradio as gr
import numpy as np
import random
from PIL import Image
import os
# import spaces
from diffusers import StableDiffusion3Pipeline
import torch
from peft import PeftModel
device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "frankjoshua/stable-diffusion-3.5-medium"
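# fp16 on GPU for speed and memory; CPU inference requires fp32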
if torch.cuda.is_available():
torch_dtype = torch.float16
else:
torch_dtype = torch.float32
pipe = StableDiffusion3Pipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
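# Flow-GRPO LoRA adapters on the HF Hub, one per training task; the keys
# double as the PEFT adapter names registered below.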
lora_models = {
"None": None,
"GenEval": "jieliu/SD3.5M-FlowGRPO-GenEval",
"Text Rendering": "jieliu/SD3.5M-FlowGRPO-Text",
"Human Prefer": "jieliu/SD3.5M-FlowGRPO-PickScore",
}
lora_prompts = {
"GenEval": os.path.join(os.getcwd(), "prompts/geneval.txt"),
"Text Rendering": os.path.join(os.getcwd(), "prompts/ocr.txt"),
"Human Prefer": os.path.join(os.getcwd(), "prompts/pickscore.txt"),
}
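# Register all three LoRA adapters on the transformer at startup so requests
# can switch between them with set_adapter() instead of reloading weights.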
pipe.transformer = PeftModel.from_pretrained(pipe.transformer, lora_models["GenEval"], adapter_name="GenEval")
pipe.transformer.load_adapter(lora_models["Text Rendering"], adapter_name="Text Rendering")
pipe.transformer.load_adapter(lora_models["Human Prefer"], adapter_name="Human Prefer")
pipe = pipe.to(device)
COUNTER_FILE = os.path.join(os.getcwd(), "model_call_counter.txt")
def get_call_count():
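    """Read the persisted call count, returning 0 if the file is missing or unreadable."""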
if not os.path.exists(COUNTER_FILE):
return 0
try:
with open(COUNTER_FILE, 'r') as f:
return int(f.read().strip())
    except (OSError, ValueError):
        return 0
def update_call_count():
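    """Increment and persist the call count (plain file I/O with no locking,
    so counts may drift under concurrent requests)."""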
count = get_call_count() + 1
with open(COUNTER_FILE, 'w') as f:
f.write(str(count))
return count
def sample_prompt(lora_model):
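    """Sample a random prompt from the selected task's prompt file ("" for "None")."""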
if lora_model in lora_models and lora_model != "None":
        file_path = lora_prompts[lora_model]
try:
            with open(file_path, 'r') as file:
                prompts = file.readlines()
            if lora_model == 'GenEval':
                total_lines = len(prompts)
                if total_lines > 0:
                    # Harmonic weights 1/(i+1): prompts nearer the top of the
                    # file are sampled more often.
                    weights = [1 / (i + 1) for i in range(total_lines)]
                    sum_weights = sum(weights)
                    normalized_weights = [w / sum_weights for w in weights]
                    return random.choices(prompts, weights=normalized_weights, k=1)[0].strip()
                return "No prompts found in file."
            else:
                return random.choice(prompts).strip()
except FileNotFoundError:
return "Prompt file not found."
return ""
def create_grid_image(images):
# Create a 2x2 grid from the 4 images
width, height = images[0].size
grid_image = Image.new('RGB', (width * 2, height * 2))
# Paste images in a 2x2 grid
grid_image.paste(images[0], (0, 0))
grid_image.paste(images[1], (width, 0))
grid_image.paste(images[2], (0, height))
grid_image.paste(images[3], (width, height))
return grid_image
# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(
prompt,
seed,
randomize_seed,
width,
height,
guidance_scale,
num_inference_steps,
lora_model,
progress=gr.Progress(track_tqdm=True),
):
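    """Generate four images with the selected LoRA adapter (or the base model)
    and return a 2x2 grid, the seeds used, and the updated call count."""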
call_count = update_call_count()
images = []
seeds = []
# Generate 4 images
for i in range(4):
if randomize_seed:
current_seed = random.randint(0, MAX_SEED)
else:
current_seed = seed + i # Use sequential seeds if not randomizing
seeds.append(current_seed)
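        # torch.Generator() defaults to CPU, which keeps the initial noise
        # reproducible across devices in diffusers pipelines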
generator = torch.Generator().manual_seed(current_seed)
        # Fall back to a randomly sampled task prompt when the box is empty
        final_prompt = prompt if prompt else sample_prompt(lora_model)
        pipe_kwargs = dict(
            prompt=final_prompt,
            negative_prompt="",
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
        )
        if lora_model == "None":
            # Temporarily bypass all LoRA adapters to run the base model
            with pipe.transformer.disable_adapter():
                image = pipe(**pipe_kwargs).images[0]
        else:
            pipe.transformer.set_adapter(lora_model)
            image = pipe(**pipe_kwargs).images[0]
images.append(image)
# Create a 2x2 grid from the 4 images
grid_image = create_grid_image(images)
return grid_image, ", ".join(map(str, seeds)), f"Model has been called {call_count} times"
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""
with gr.Blocks(css=css) as demo:
with gr.Column(elem_id="col-container"):
gr.Markdown("""
# SD3.5 Medium + Flow-GRPO
Our model is trained separately for different tasks, so it’s best to use the corresponding prompt format for each task.
**User Guide:**
1. Select a LoRA model (choose “None” to use the base model)
2. Click “Sample Prompt” to randomly select from ~1000 task-specific prompts, or write your own
3. Click “Run” to generate images (a 2×2 grid of 4 images will be produced)
**Note:**
- For the *Text Rendering* task, please enclose the text to be displayed in **double quotes (`"`)**, not single quotes (`'`)
""")
with gr.Row():
prompt = gr.Textbox(
label="Prompt",
show_label=False,
max_lines=1,
placeholder="Enter your prompt",
container=False,
)
with gr.Row():
lora_model = gr.Dropdown(
label="LoRA Model",
choices=list(lora_models.keys()),
value="GenEval"
)
sample_prompt_button = gr.Button("Sample Prompt", scale=0, variant="secondary")
            sample_prompt_button.click(
                fn=sample_prompt,
                inputs=[lora_model],
                outputs=[prompt]
            )
run_button = gr.Button("Run", scale=0, variant="primary")
result = gr.Image(label="Results (2x2 Grid)", show_label=True)
seed_display = gr.Textbox(label="Seeds Used", show_label=True)
with gr.Accordion("Advanced Settings", open=False):
seed = gr.Slider(
label="Starting Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=0,
)
randomize_seed = gr.Checkbox(label="Randomize seeds", value=True)
with gr.Row():
width = gr.Slider(
label="Width",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
                    value=512,
)
height = gr.Slider(
label="Height",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
                    value=512,
)
with gr.Row():
guidance_scale = gr.Slider(
label="Guidance scale",
minimum=0.0,
maximum=10.0,
step=0.1,
                    value=4.5,
)
num_inference_steps = gr.Slider(
label="Number of inference steps",
minimum=1,
maximum=50,
step=1,
                    value=40,
)
call_count_display = gr.Textbox(
label="Model Call Count",
value=f"Model has been called {get_call_count()} times",
interactive=False
)
gr.on(
triggers=[run_button.click, prompt.submit],
fn=infer,
inputs=[
prompt,
seed,
randomize_seed,
width,
height,
guidance_scale,
num_inference_steps,
lora_model,
],
outputs=[result, seed_display, call_count_display],
)
if __name__ == "__main__":
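    # Note: for a shared GPU demo one might enable Gradio's request queue,
    # e.g. demo.queue(max_size=20).launch(); plain launch() is kept here.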
demo.launch()