import spaces
import gradio as gr
import numpy as np
import os
import torch
import random
import subprocess
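# Install FlashAttention at startup (useful on hosted Spaces where it cannot be
# preinstalled); skipping the CUDA build pulls a prebuilt wheel instead of compiling kernels.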
subprocess.run(
"pip install flash-attn --no-build-isolation",
env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
shell=True,
)
from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
from PIL import Image
from data.data_utils import add_special_tokens, pil_img2rgb
from data.transforms import ImageTransform
from inferencer import InterleaveInferencer
from modeling.autoencoder import load_ae
from modeling.bagel.qwen2_navit import NaiveCache
from modeling.bagel import (
BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM,
SiglipVisionConfig, SiglipVisionModel
)
from modeling.qwen2 import Qwen2Tokenizer
from huggingface_hub import snapshot_download
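# Download the BAGEL-7B-MoT weights and configs from the Hugging Face Hub into ./model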
save_dir = "./model"
repo_id = "ByteDance-Seed/BAGEL-7B-MoT"
cache_dir = save_dir + "/cache"
snapshot_download(cache_dir=cache_dir,
local_dir=save_dir,
repo_id=repo_id,
local_dir_use_symlinks=False,
resume_download=True,
allow_patterns=["*.json", "*.safetensors", "*.bin", "*.py", "*.md", "*.txt"],
)
# Model Initialization
model_path = "./model" #Download from https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT
llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
llm_config.qk_norm = True
llm_config.tie_word_embeddings = False
llm_config.layer_module = "Qwen2MoTDecoderLayer"
vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
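# Adjust the SigLIP vision config: disable rotary position embeddings and drop the last transformer block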
vit_config.rope = False
vit_config.num_hidden_layers -= 1
vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))
config = BagelConfig(
visual_gen=True,
visual_und=True,
llm_config=llm_config,
vit_config=vit_config,
vae_config=vae_config,
vit_max_num_patch_per_side=70,
connector_act='gelu_pytorch_tanh',
latent_patch_size=2,
max_latent_size=64,
)
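# Build the model skeleton on the meta device (no weight memory allocated yet);
# the real weights are loaded and dispatched across GPUs further below.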
with init_empty_weights():
language_model = Qwen2ForCausalLM(llm_config)
vit_model = SiglipVisionModel(vit_config)
model = Bagel(language_model, vit_model, config)
model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)
tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)
vae_transform = ImageTransform(1024, 512, 16)
vit_transform = ImageTransform(980, 224, 14)
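# Image preprocessing for the two visual branches: VAE (generation) and ViT (understanding).
# The positional arguments are assumed to be (max_image_size, min_image_size, stride).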
# Model Loading and Multi-GPU Inference Preparation
device_map = infer_auto_device_map(
model,
max_memory={i: "80GiB" for i in range(torch.cuda.device_count())},
no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
)
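# The modules listed below exchange tensors directly with the token embeddings,
# so they are pinned to the same device to avoid cross-device errors during dispatch.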
same_device_modules = [
'language_model.model.embed_tokens',
'time_embedder',
'latent_pos_embed',
'vae2llm',
'llm2vae',
'connector',
'vit_pos_embed'
]
if torch.cuda.device_count() == 1:
first_device = device_map.get(same_device_modules[0], "cuda:0")
for k in same_device_modules:
if k in device_map:
device_map[k] = first_device
else:
device_map[k] = "cuda:0"
else:
first_device = device_map.get(same_device_modules[0])
for k in same_device_modules:
if k in device_map:
device_map[k] = first_device
model = load_checkpoint_and_dispatch(
model,
checkpoint=os.path.join(model_path, "ema.safetensors"),
device_map=device_map,
offload_buffers=True,
dtype=torch.bfloat16,
force_hooks=True,
).eval()
# Inferencer Preparation
inferencer = InterleaveInferencer(
model=model,
vae_model=vae_model,
tokenizer=tokenizer,
vae_transform=vae_transform,
vit_transform=vit_transform,
new_token_ids=new_token_ids,
)
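# InterleaveInferencer streams interleaved outputs (text chunks and PIL images);
# the three mode-specific functions below wrap it with their own hyperparameters.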
def set_seed(seed):
"""Set random seeds for reproducibility"""
if seed > 0:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
return seed
# Text to Image function with thinking option and hyperparameters
@spaces.GPU(duration=90)
def text_to_image(prompt, show_thinking=False, cfg_text_scale=4.0, cfg_interval=0.4,
timestep_shift=3.0, num_timesteps=50,
cfg_renorm_min=1.0, cfg_renorm_type="global",
max_think_token_n=1024, do_sample=False, text_temperature=0.3,
seed=0, image_ratio="1:1"):
# Set seed for reproducibility
set_seed(seed)
if image_ratio == "1:1":
image_shapes = (1024, 1024)
elif image_ratio == "4:3":
image_shapes = (768, 1024)
elif image_ratio == "3:4":
image_shapes = (1024, 768)
elif image_ratio == "16:9":
image_shapes = (576, 1024)
elif image_ratio == "9:16":
image_shapes = (1024, 576)
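    else:
        # Defensive fallback; the UI dropdown restricts choices to the ratios handled above
        image_shapes = (1024, 1024)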
    # Build generation hyperparameters; thinking-related options fall back to defaults when thinking is disabled
inference_hyper = dict(
max_think_token_n=max_think_token_n if show_thinking else 1024,
do_sample=do_sample if show_thinking else False,
temperature=text_temperature if show_thinking else 0.3,
cfg_text_scale=cfg_text_scale,
cfg_interval=[cfg_interval, 1.0], # End fixed at 1.0
timestep_shift=timestep_shift,
num_timesteps=num_timesteps,
cfg_renorm_min=cfg_renorm_min,
cfg_renorm_type=cfg_renorm_type,
image_shapes=image_shapes,
)
result = {"text": "", "image": None}
# Call inferencer with or without think parameter based on user choice
    for i in inferencer(text=prompt, think=show_thinking, understanding_output=False, **inference_hyper):
        if isinstance(i, str):
            result["text"] += i
        else:
            result["image"] = i
        # Stream partial results: thinking text arrives first, the generated image last
        yield result["image"], result.get("text", "")
# Image Understanding function with thinking option and hyperparameters
@spaces.GPU(duration=90)
def image_understanding(image: Image.Image, prompt: str, show_thinking=False,
do_sample=False, text_temperature=0.3, max_new_tokens=512):
if image is None:
yield "Please upload an image for understanding."
return
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
image = pil_img2rgb(image)
# Set hyperparameters
inference_hyper = dict(
do_sample=do_sample,
temperature=text_temperature,
max_think_token_n=max_new_tokens, # Set max_length for text generation
)
result_text = ""
    # Stream the understanding output; with understanding_output=True the inferencer is expected to yield only text
    for i in inferencer(image=image, text=prompt, think=show_thinking,
                        understanding_output=True, **inference_hyper):
        if isinstance(i, str):
            result_text += i
            yield result_text
    yield result_text  # Ensure the final text is yielded
# Image Editing function with thinking option and hyperparameters
@spaces.GPU(duration=90)
def edit_image(image: Image.Image, prompt: str, show_thinking=False, cfg_text_scale=4.0,
cfg_img_scale=2.0, cfg_interval=0.0,
timestep_shift=3.0, num_timesteps=50, cfg_renorm_min=1.0,
cfg_renorm_type="text_channel", max_think_token_n=1024,
do_sample=False, text_temperature=0.3, seed=0):
# Set seed for reproducibility
set_seed(seed)
if image is None:
yield None, "Please upload an image for editing." # Yield tuple for image/text
return
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
image = pil_img2rgb(image)
    # Build generation hyperparameters; thinking-related options fall back to defaults when thinking is disabled
inference_hyper = dict(
max_think_token_n=max_think_token_n if show_thinking else 1024,
do_sample=do_sample if show_thinking else False,
temperature=text_temperature if show_thinking else 0.3,
cfg_text_scale=cfg_text_scale,
cfg_img_scale=cfg_img_scale,
cfg_interval=[cfg_interval, 1.0], # End fixed at 1.0
timestep_shift=timestep_shift,
num_timesteps=num_timesteps,
cfg_renorm_min=cfg_renorm_min,
cfg_renorm_type=cfg_renorm_type,
)
# Include thinking parameter based on user choice
result = {"text": "", "image": None}
    for i in inferencer(image=image, text=prompt, think=show_thinking, understanding_output=False, **inference_hyper):
        if isinstance(i, str):
            result["text"] += i
        else:
            result["image"] = i
        yield result["image"], result.get("text", "")  # Stream (image, text) pairs as they become available
# Helper function to load example images
def load_example_image(image_path):
try:
return Image.open(image_path)
except Exception as e:
print(f"Error loading example image: {e}")
return None
# Gradio UI
with gr.Blocks() as demo:
gr.Markdown("""
# BAGEL Multimodal Chatbot
Interact with BAGEL to generate images from text, edit existing images, or understand image content.
""")
# Chatbot display area
chatbot = gr.Chatbot(label="Chat History", height=500, avatar_images=(None, "https://lf3-static.bytednsdoc.com/obj/eden-cn/nuhojubrps/BAGEL_favicon.png"))
# Input area
with gr.Row():
image_input = gr.Image(type="pil", label="Optional: Upload an Image (for Image Understanding/Edit)", scale=0.5, value=None)
with gr.Column(scale=1.5):
user_prompt = gr.Textbox(label="Your Message", placeholder="Type your prompt here...", lines=3)
with gr.Row():
mode_selector = gr.Radio(
choices=["Text to Image", "Image Understanding", "Image Edit"],
value="Text to Image",
label="Select Mode",
interactive=True
)
submit_btn = gr.Button("Send", variant="primary")
# Global/Shared Hyperparameters
with gr.Accordion("General Settings & Hyperparameters", open=False) as general_accordion:
with gr.Row():
show_thinking_global = gr.Checkbox(label="Show Thinking Process", value=False, info="Enable to see model's intermediate thinking text.")
seed_global = gr.Slider(minimum=0, maximum=1000000, value=0, step=1, label="Seed", info="0 for random seed, positive for reproducible results.")
# Container for thinking-specific parameters, visibility controlled by show_thinking_global
thinking_params_container = gr.Group(visible=False)
with thinking_params_container:
gr.Markdown("#### Thinking Process Parameters (affect text generation)")
with gr.Row():
common_do_sample = gr.Checkbox(label="Enable Sampling", value=False, info="Enable sampling for text generation (otherwise greedy).")
common_text_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1, label="Text Temperature", info="Controls randomness in text generation (higher = more random).")
common_max_think_token_n = gr.Slider(minimum=64, maximum=4096, value=1024, step=64, label="Max Think Tokens / Max New Tokens", info="Maximum number of tokens for thinking (T2I/Edit) or generated text (Understanding).")
# T2I Hyperparameters
t2i_params_accordion = gr.Accordion("Text to Image Specific Parameters", open=False)
with t2i_params_accordion:
gr.Markdown("#### Text to Image Parameters")
with gr.Row():
            t2i_image_ratio = gr.Dropdown(choices=["1:1", "4:3", "3:4", "16:9", "9:16"], value="1:1", label="Image Ratio", info="The longer side is fixed at 1024 pixels.")
with gr.Row():
t2i_cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, label="CFG Text Scale", info="Controls how strongly the model follows the text prompt (4.0-8.0 recommended).")
t2i_cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.1, label="CFG Interval", info="Start of Classifier-Free Guidance application interval (end is fixed at 1.0).")
with gr.Row():
t2i_cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"], value="global", label="CFG Renorm Type", info="Normalization type for CFG. Use 'global' if the generated image is blurry.")
t2i_cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="CFG Renorm Min", info="Minimum value for CFG Renormalization (1.0 disables CFG-Renorm).")
with gr.Row():
t2i_num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, label="Timesteps", info="Total denoising steps for image generation.")
t2i_timestep_shift = gr.Slider(minimum=1.0, maximum=5.0, value=3.0, step=0.5, label="Timestep Shift", info="Higher values for layout control, lower for fine details.")
# Image Edit Hyperparameters
edit_params_accordion = gr.Accordion("Image Edit Specific Parameters", open=False)
with edit_params_accordion:
gr.Markdown("#### Image Edit Parameters")
with gr.Row():
edit_cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, label="CFG Text Scale", info="Controls how strongly the model follows the text prompt for editing.")
edit_cfg_img_scale = gr.Slider(minimum=1.0, maximum=4.0, value=2.0, step=0.1, label="CFG Image Scale", info="Controls how much the model preserves input image details during editing.")
with gr.Row():
edit_cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="CFG Interval", info="Start of CFG application interval for editing (end is fixed at 1.0).")
edit_cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"], value="text_channel", label="CFG Renorm Type", info="Normalization type for CFG during editing. Use 'global' if output is blurry.")
with gr.Row():
edit_cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="CFG Renorm Min", info="Minimum value for CFG Renormalization during editing (1.0 disables CFG-Renorm).")
with gr.Row():
edit_num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, label="Timesteps", info="Total denoising steps for image editing.")
edit_timestep_shift = gr.Slider(minimum=1.0, maximum=10.0, value=3.0, step=0.5, label="Timestep Shift", info="Higher values for layout control, lower for fine details during editing.")
# Main chat processing function
@spaces.GPU(duration=90) # Apply GPU decorator to the combined function
def process_chat_message(history, prompt, uploaded_image, mode,
show_thinking_global_val, seed_global_val,
common_do_sample_val, common_text_temperature_val, common_max_think_token_n_val,
t2i_cfg_text_scale_val, t2i_cfg_interval_val, t2i_timestep_shift_val,
t2i_num_timesteps_val, t2i_cfg_renorm_min_val, t2i_cfg_renorm_type_val,
t2i_image_ratio_val,
edit_cfg_text_scale_val, edit_cfg_img_scale_val, edit_cfg_interval_val,
edit_timestep_shift_val, edit_num_timesteps_val, edit_cfg_renorm_min_val,
edit_cfg_renorm_type_val):
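        """Dispatch the prompt to the selected mode and stream chatbot updates.

        Yields (history, textbox_update) pairs so partial thinking text and images
        appear in the chat as they are generated.
        """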
# Append user message to history
history.append([prompt, None])
# Define common parameters for inference functions
common_infer_params = dict(
show_thinking=show_thinking_global_val,
do_sample=common_do_sample_val,
text_temperature=common_text_temperature_val,
)
try:
if mode == "Text to Image":
# Add T2I specific parameters, including max_think_token_n and seed
t2i_params = {
**common_infer_params,
"max_think_token_n": common_max_think_token_n_val,
"seed": seed_global_val,
"cfg_text_scale": t2i_cfg_text_scale_val,
"cfg_interval": t2i_cfg_interval_val,
"timestep_shift": t2i_timestep_shift_val,
"num_timesteps": t2i_num_timesteps_val,
"cfg_renorm_min": t2i_cfg_renorm_min_val,
"cfg_renorm_type": t2i_cfg_renorm_type_val,
"image_ratio": t2i_image_ratio_val,
}
for img, txt in text_to_image(
prompt=prompt,
**t2i_params
):
# For Text to Image, yield image first, then thinking text (if available)
if img is not None:
history[-1] = [prompt, (img, txt)]
elif txt: # Only update text if image is not ready yet
history[-1] = [prompt, txt]
yield history, gr.update(value="") # Update chatbot and clear input
elif mode == "Image Understanding":
if uploaded_image is None:
history[-1] = [prompt, "Please upload an image for Image Understanding."]
yield history, gr.update(value="")
return
                # Understanding mode: the shared max-token slider maps to max_new_tokens;
                # seed is not used here and is never part of common_infer_params
                understand_params = {
                    **common_infer_params,
                    "max_new_tokens": common_max_think_token_n_val,
                }
for txt in image_understanding(
image=uploaded_image,
prompt=prompt,
**understand_params
):
history[-1] = [prompt, txt]
yield history, gr.update(value="")
elif mode == "Image Edit":
if uploaded_image is None:
history[-1] = [prompt, "Please upload an image for Image Editing."]
yield history, gr.update(value="")
return
# Add Edit specific parameters, including max_think_token_n and seed
edit_params = {
**common_infer_params,
"max_think_token_n": common_max_think_token_n_val,
"seed": seed_global_val,
"cfg_text_scale": edit_cfg_text_scale_val,
"cfg_img_scale": edit_cfg_img_scale_val,
"cfg_interval": edit_cfg_interval_val,
"timestep_shift": edit_timestep_shift_val,
"num_timesteps": edit_num_timesteps_val,
"cfg_renorm_min": edit_cfg_renorm_min_val,
"cfg_renorm_type": edit_cfg_renorm_type_val,
}
for img, txt in edit_image(
image=uploaded_image,
prompt=prompt,
**edit_params
):
# For Image Edit, yield image first, then thinking text (if available)
if img is not None:
history[-1] = [prompt, (img, txt)]
elif txt: # Only update text if image is not ready yet
history[-1] = [prompt, txt]
yield history, gr.update(value="")
except Exception as e:
history[-1] = [prompt, f"An error occurred: {e}"]
yield history, gr.update(value="") # Update history with error and clear input
# Event handlers for dynamic UI updates and submission
# Control visibility of thinking parameters
show_thinking_global.change(
fn=lambda x: gr.update(visible=x),
inputs=[show_thinking_global],
outputs=[thinking_params_container]
)
# Clear image input if mode switches to Text to Image
mode_selector.change(
fn=lambda mode: gr.update(value=None) if mode == "Text to Image" else gr.update(),
inputs=[mode_selector],
outputs=[image_input]
)
# List of all input components whose values are passed to process_chat_message
inputs_list = [
chatbot, user_prompt, image_input, mode_selector,
show_thinking_global, seed_global,
common_do_sample, common_text_temperature, common_max_think_token_n,
t2i_cfg_text_scale, t2i_cfg_interval, t2i_timestep_shift,
t2i_num_timesteps, t2i_cfg_renorm_min, t2i_cfg_renorm_type,
t2i_image_ratio,
edit_cfg_text_scale, edit_cfg_img_scale, edit_cfg_interval,
edit_timestep_shift, edit_num_timesteps, edit_cfg_renorm_min,
edit_cfg_renorm_type
]
# Link submit button and text input 'Enter' key to the processing function
submit_btn.click(
fn=process_chat_message,
inputs=inputs_list,
outputs=[chatbot, user_prompt],
scroll_to_output=True,
        queue=True,  # Generator handlers must run through the queue to stream partial updates
)
user_prompt.submit( # Allows pressing Enter in textbox to submit
fn=process_chat_message,
inputs=inputs_list,
outputs=[chatbot, user_prompt],
scroll_to_output=True,
        queue=True,  # Required for streaming generator output (see note above)
)
demo.launch()