import spaces
import gradio as gr
import numpy as np
import os
import torch
import random
import subprocess

# Install flash-attn at startup; FLASH_ATTENTION_SKIP_CUDA_BUILD avoids compiling the CUDA kernels.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
from PIL import Image

from data.data_utils import add_special_tokens, pil_img2rgb
from data.transforms import ImageTransform
from inferencer import InterleaveInferencer
from modeling.autoencoder import load_ae
from modeling.bagel.qwen2_navit import NaiveCache
from modeling.bagel import (
    BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM,
    SiglipVisionConfig, SiglipVisionModel
)
from modeling.qwen2 import Qwen2Tokenizer
from huggingface_hub import snapshot_download

# Download model weights from the Hugging Face Hub
save_dir = "./model"
repo_id = "ByteDance-Seed/BAGEL-7B-MoT"
cache_dir = save_dir + "/cache"

snapshot_download(
    cache_dir=cache_dir,
    local_dir=save_dir,
    repo_id=repo_id,
    local_dir_use_symlinks=False,
    resume_download=True,
    allow_patterns=["*.json", "*.safetensors", "*.bin", "*.py", "*.md", "*.txt"],
)

# Model Initialization
model_path = "./model"  # Downloaded from https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT

llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
llm_config.qk_norm = True
llm_config.tie_word_embeddings = False
llm_config.layer_module = "Qwen2MoTDecoderLayer"

vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
vit_config.rope = False
vit_config.num_hidden_layers -= 1

vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))

config = BagelConfig(
    visual_gen=True,
    visual_und=True,
    llm_config=llm_config,
    vit_config=vit_config,
    vae_config=vae_config,
    vit_max_num_patch_per_side=70,
    connector_act='gelu_pytorch_tanh',
    latent_patch_size=2,
    max_latent_size=64,
)

with init_empty_weights():
    language_model = Qwen2ForCausalLM(llm_config)
    vit_model = SiglipVisionModel(vit_config)
    model = Bagel(language_model, vit_model, config)
    model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)

tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)

vae_transform = ImageTransform(1024, 512, 16)
vit_transform = ImageTransform(980, 224, 14)

# Model Loading and Multi-GPU Inference Preparation
device_map = infer_auto_device_map(
    model,
    max_memory={i: "80GiB" for i in range(torch.cuda.device_count())},
    no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
)

same_device_modules = [
    'language_model.model.embed_tokens',
    'time_embedder',
    'latent_pos_embed',
    'vae2llm',
    'llm2vae',
    'connector',
    'vit_pos_embed',
]

if torch.cuda.device_count() == 1:
    first_device = device_map.get(same_device_modules[0], "cuda:0")
    for k in same_device_modules:
        if k in device_map:
            device_map[k] = first_device
        else:
            device_map[k] = "cuda:0"
else:
    first_device = device_map.get(same_device_modules[0])
    for k in same_device_modules:
        if k in device_map:
            device_map[k] = first_device

model = load_checkpoint_and_dispatch(
    model,
    checkpoint=os.path.join(model_path, "ema.safetensors"),
    device_map=device_map,
    offload_buffers=True,
    dtype=torch.bfloat16,
    force_hooks=True,
).eval()

# Inferencer Preparation
inferencer = InterleaveInferencer(
    model=model,
    vae_model=vae_model,
    tokenizer=tokenizer,
    vae_transform=vae_transform,
    vit_transform=vit_transform,
    new_token_ids=new_token_ids,
)


def set_seed(seed):
    """Set random seeds for reproducibility."""
    if seed > 0:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
    return seed


# Text-to-Image function with thinking option and hyperparameters
@spaces.GPU(duration=90)
def text_to_image(prompt, show_thinking=False, cfg_text_scale=4.0, cfg_interval=0.4,
                  timestep_shift=3.0, num_timesteps=50,
                  cfg_renorm_min=1.0, cfg_renorm_type="global",
                  max_think_token_n=1024, do_sample=False, text_temperature=0.3,
                  seed=0, image_ratio="1:1"):
    # Set seed for reproducibility
    set_seed(seed)

    # Map the requested aspect ratio to (height, width); the longer side is 1024 px.
    if image_ratio == "1:1":
        image_shapes = (1024, 1024)
    elif image_ratio == "4:3":
        image_shapes = (768, 1024)
    elif image_ratio == "3:4":
        image_shapes = (1024, 768)
    elif image_ratio == "16:9":
        image_shapes = (576, 1024)
    elif image_ratio == "9:16":
        image_shapes = (1024, 576)
    else:
        image_shapes = (1024, 1024)  # Fall back to a square canvas for unknown ratios

    # Set hyperparameters
    inference_hyper = dict(
        max_think_token_n=max_think_token_n if show_thinking else 1024,
        do_sample=do_sample if show_thinking else False,
        temperature=text_temperature if show_thinking else 0.3,
        cfg_text_scale=cfg_text_scale,
        cfg_interval=[cfg_interval, 1.0],  # End fixed at 1.0
        timestep_shift=timestep_shift,
        num_timesteps=num_timesteps,
        cfg_renorm_min=cfg_renorm_min,
        cfg_renorm_type=cfg_renorm_type,
        image_shapes=image_shapes,
    )

    result = {"text": "", "image": None}
    # Call the inferencer with or without the think parameter based on user choice
    for i in inferencer(text=prompt, think=show_thinking, understanding_output=False, **inference_hyper):
        # print(type(i))  # For debugging the stream
        if isinstance(i, str):
            result["text"] += i
        else:
            result["image"] = i
        yield result["image"], result.get("text", "")


# Image Understanding function with thinking option and hyperparameters
@spaces.GPU(duration=90)
def image_understanding(image: Image.Image, prompt: str, show_thinking=False,
                        do_sample=False, text_temperature=0.3, max_new_tokens=512):
    if image is None:
        yield "Please upload an image for understanding."
        return

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = pil_img2rgb(image)

    # Set hyperparameters
    inference_hyper = dict(
        do_sample=do_sample,
        temperature=text_temperature,
        max_think_token_n=max_new_tokens,  # Set max length for text generation
    )

    result_text = ""
    # Use the show_thinking parameter to control the thinking process
    for i in inferencer(image=image, text=prompt, think=show_thinking, understanding_output=True, **inference_hyper):
        if isinstance(i, str):
            result_text += i
            yield result_text
        # No else branch: with understanding_output=True the inferencer is expected to
        # yield text only, so non-string chunks are ignored.
    yield result_text  # Ensure the final text is yielded


# Image Editing function with thinking option and hyperparameters
@spaces.GPU(duration=90)
def edit_image(image: Image.Image, prompt: str, show_thinking=False,
               cfg_text_scale=4.0, cfg_img_scale=2.0, cfg_interval=0.0,
               timestep_shift=3.0, num_timesteps=50,
               cfg_renorm_min=1.0, cfg_renorm_type="text_channel",
               max_think_token_n=1024, do_sample=False, text_temperature=0.3, seed=0):
    # Set seed for reproducibility
    set_seed(seed)

    if image is None:
        yield None, "Please upload an image for editing."  # Yield a (image, text) tuple
        return

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = pil_img2rgb(image)

    # Set hyperparameters
    inference_hyper = dict(
        max_think_token_n=max_think_token_n if show_thinking else 1024,
        do_sample=do_sample if show_thinking else False,
        temperature=text_temperature if show_thinking else 0.3,
        cfg_text_scale=cfg_text_scale,
        cfg_img_scale=cfg_img_scale,
        cfg_interval=[cfg_interval, 1.0],  # End fixed at 1.0
        timestep_shift=timestep_shift,
        num_timesteps=num_timesteps,
        cfg_renorm_min=cfg_renorm_min,
        cfg_renorm_type=cfg_renorm_type,
    )

    # Include the thinking parameter based on user choice
    result = {"text": "", "image": None}
    for i in inferencer(image=image, text=prompt, think=show_thinking, understanding_output=False, **inference_hyper):
        if isinstance(i, str):
            result["text"] += i
        else:
            result["image"] = i
        yield result["image"], result.get("text", "")  # Yield a (image, text) tuple


# Helper function to load example images
def load_example_image(image_path):
    try:
        return Image.open(image_path)
    except Exception as e:
        print(f"Error loading example image: {e}")
        return None


# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("""
# BAGEL Multimodal Chatbot
Interact with BAGEL to generate images from text, edit existing images, or understand image content.
""")

    # Chatbot display area
    chatbot = gr.Chatbot(
        label="Chat History",
        height=500,
        avatar_images=(None, "https://lf3-static.bytednsdoc.com/obj/eden-cn/nuhojubrps/BAGEL_favicon.png"),
    )

    # Input area
    with gr.Row():
        image_input = gr.Image(type="pil", label="Optional: Upload an Image (for Image Understanding/Edit)",
                               scale=0.5, value=None)
        with gr.Column(scale=1.5):
            user_prompt = gr.Textbox(label="Your Message", placeholder="Type your prompt here...", lines=3)
            with gr.Row():
                mode_selector = gr.Radio(
                    choices=["Text to Image", "Image Understanding", "Image Edit"],
                    value="Text to Image",
                    label="Select Mode",
                    interactive=True,
                )
                submit_btn = gr.Button("Send", variant="primary")

    # Global/Shared Hyperparameters
    with gr.Accordion("General Settings & Hyperparameters", open=False) as general_accordion:
        with gr.Row():
            show_thinking_global = gr.Checkbox(label="Show Thinking Process", value=False,
                                               info="Enable to see the model's intermediate thinking text.")
            seed_global = gr.Slider(minimum=0, maximum=1000000, value=0, step=1, label="Seed",
                                    info="0 for a random seed, positive for reproducible results.")

        # Container for thinking-specific parameters, visibility controlled by show_thinking_global
        thinking_params_container = gr.Group(visible=False)
        with thinking_params_container:
            gr.Markdown("#### Thinking Process Parameters (affect text generation)")
            with gr.Row():
                common_do_sample = gr.Checkbox(label="Enable Sampling", value=False,
                                               info="Enable sampling for text generation (otherwise greedy).")
                common_text_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1,
                                                    label="Text Temperature",
                                                    info="Controls randomness in text generation (higher = more random).")
                common_max_think_token_n = gr.Slider(minimum=64, maximum=4096, value=1024, step=64,
                                                     label="Max Think Tokens / Max New Tokens",
                                                     info="Maximum number of tokens for thinking (T2I/Edit) or generated text (Understanding).")

    # T2I Hyperparameters
    t2i_params_accordion = gr.Accordion("Text to Image Specific Parameters", open=False)
    with t2i_params_accordion:
        gr.Markdown("#### Text to Image Parameters")
        with gr.Row():
            t2i_image_ratio = gr.Dropdown(choices=["1:1", "4:3", "3:4", "16:9", "9:16"], value="1:1",
                                          label="Image Ratio", info="The longer side is fixed to 1024 pixels.")
        with gr.Row():
            t2i_cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, label="CFG Text Scale",
                                           info="Controls how strongly the model follows the text prompt (4.0-8.0 recommended).")
            t2i_cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.1, label="CFG Interval",
                                         info="Start of the Classifier-Free Guidance application interval (end is fixed at 1.0).")
        with gr.Row():
            t2i_cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"], value="global",
                                              label="CFG Renorm Type",
                                              info="Normalization type for CFG. Use 'global' if the generated image is blurry.")
            t2i_cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="CFG Renorm Min",
                                           info="Minimum value for CFG renormalization (1.0 disables CFG-Renorm).")
        with gr.Row():
            t2i_num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, label="Timesteps",
                                          info="Total denoising steps for image generation.")
            t2i_timestep_shift = gr.Slider(minimum=1.0, maximum=5.0, value=3.0, step=0.5, label="Timestep Shift",
                                           info="Higher values favor layout control, lower values favor fine details.")

    # Image Edit Hyperparameters
    edit_params_accordion = gr.Accordion("Image Edit Specific Parameters", open=False)
    with edit_params_accordion:
        gr.Markdown("#### Image Edit Parameters")
        with gr.Row():
            edit_cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, label="CFG Text Scale",
                                            info="Controls how strongly the model follows the text prompt for editing.")
            edit_cfg_img_scale = gr.Slider(minimum=1.0, maximum=4.0, value=2.0, step=0.1, label="CFG Image Scale",
                                           info="Controls how much the model preserves input image details during editing.")
        with gr.Row():
            edit_cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="CFG Interval",
                                          info="Start of the CFG application interval for editing (end is fixed at 1.0).")
            edit_cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"], value="text_channel",
                                               label="CFG Renorm Type",
                                               info="Normalization type for CFG during editing. Use 'global' if the output is blurry.")
        with gr.Row():
            edit_cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="CFG Renorm Min",
                                            info="Minimum value for CFG renormalization during editing (1.0 disables CFG-Renorm).")
        with gr.Row():
            edit_num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, label="Timesteps",
                                           info="Total denoising steps for image editing.")
            edit_timestep_shift = gr.Slider(minimum=1.0, maximum=10.0, value=3.0, step=0.5, label="Timestep Shift",
                                            info="Higher values favor layout control, lower values favor fine details during editing.")

    # Main chat processing function
    @spaces.GPU(duration=90)  # Apply the GPU decorator to the combined function
    def process_chat_message(history, prompt, uploaded_image, mode,
                             show_thinking_global_val, seed_global_val,
                             common_do_sample_val, common_text_temperature_val, common_max_think_token_n_val,
                             t2i_cfg_text_scale_val, t2i_cfg_interval_val, t2i_timestep_shift_val,
                             t2i_num_timesteps_val, t2i_cfg_renorm_min_val, t2i_cfg_renorm_type_val,
                             t2i_image_ratio_val,
                             edit_cfg_text_scale_val, edit_cfg_img_scale_val, edit_cfg_interval_val,
                             edit_timestep_shift_val, edit_num_timesteps_val, edit_cfg_renorm_min_val,
                             edit_cfg_renorm_type_val):
        # Append the user message to the chat history
        history.append([prompt, None])

        # Common parameters shared by all inference functions
        common_infer_params = dict(
            show_thinking=show_thinking_global_val,
            do_sample=common_do_sample_val,
            text_temperature=common_text_temperature_val,
        )

        try:
            if mode == "Text to Image":
                # Add T2I-specific parameters, including max_think_token_n and seed
                t2i_params = {
                    **common_infer_params,
                    "max_think_token_n": common_max_think_token_n_val,
                    "seed": seed_global_val,
                    "cfg_text_scale": t2i_cfg_text_scale_val,
                    "cfg_interval": t2i_cfg_interval_val,
                    "timestep_shift": t2i_timestep_shift_val,
                    "num_timesteps": t2i_num_timesteps_val,
                    "cfg_renorm_min": t2i_cfg_renorm_min_val,
                    "cfg_renorm_type": t2i_cfg_renorm_type_val,
                    "image_ratio": t2i_image_ratio_val,
                }
                for img, txt in text_to_image(prompt=prompt, **t2i_params):
                    # Show the image once it is ready; otherwise stream the thinking text
                    if img is not None:
                        history[-1] = [prompt, (img, txt)]
                    elif txt:  # Only update the text while the image is not ready yet
                        history[-1] = [prompt, txt]
                    yield history, gr.update(value="")  # Update the chatbot and clear the input

            elif mode == "Image Understanding":
                if uploaded_image is None:
                    history[-1] = [prompt, "Please upload an image for Image Understanding."]
                    yield history, gr.update(value="")
                    return

                # Add Understanding-specific parameters (max_new_tokens maps to common_max_think_token_n).
                # Note: seed is not used by image_understanding.
                understand_params = {
                    **common_infer_params,
                    "max_new_tokens": common_max_think_token_n_val,
                }
                # Defensive: make sure no seed key is passed to image_understanding
                understand_params.pop('seed', None)

                for txt in image_understanding(image=uploaded_image, prompt=prompt, **understand_params):
                    history[-1] = [prompt, txt]
                    yield history, gr.update(value="")

            elif mode == "Image Edit":
                if uploaded_image is None:
                    history[-1] = [prompt, "Please upload an image for Image Editing."]
                    yield history, gr.update(value="")
                    return

                # Add Edit-specific parameters, including max_think_token_n and seed
                edit_params = {
                    **common_infer_params,
                    "max_think_token_n": common_max_think_token_n_val,
                    "seed": seed_global_val,
                    "cfg_text_scale": edit_cfg_text_scale_val,
                    "cfg_img_scale": edit_cfg_img_scale_val,
                    "cfg_interval": edit_cfg_interval_val,
                    "timestep_shift": edit_timestep_shift_val,
                    "num_timesteps": edit_num_timesteps_val,
                    "cfg_renorm_min": edit_cfg_renorm_min_val,
                    "cfg_renorm_type": edit_cfg_renorm_type_val,
                }
                for img, txt in edit_image(image=uploaded_image, prompt=prompt, **edit_params):
                    # Show the edited image once it is ready; otherwise stream the thinking text
                    if img is not None:
                        history[-1] = [prompt, (img, txt)]
                    elif txt:  # Only update the text while the image is not ready yet
                        history[-1] = [prompt, txt]
                    yield history, gr.update(value="")

        except Exception as e:
            history[-1] = [prompt, f"An error occurred: {e}"]
            yield history, gr.update(value="")  # Update the history with the error and clear the input

    # Event handlers for dynamic UI updates and submission

    # Control visibility of the thinking parameters
    show_thinking_global.change(
        fn=lambda x: gr.update(visible=x),
        inputs=[show_thinking_global],
        outputs=[thinking_params_container],
    )

    # Clear the image input if the mode switches to Text to Image
    mode_selector.change(
        fn=lambda mode: gr.update(value=None) if mode == "Text to Image" else gr.update(),
        inputs=[mode_selector],
        outputs=[image_input],
    )

    # List of all input components whose values are passed to process_chat_message
    inputs_list = [
        chatbot, user_prompt, image_input, mode_selector,
        show_thinking_global, seed_global,
        common_do_sample, common_text_temperature, common_max_think_token_n,
        t2i_cfg_text_scale, t2i_cfg_interval, t2i_timestep_shift, t2i_num_timesteps,
        t2i_cfg_renorm_min, t2i_cfg_renorm_type, t2i_image_ratio,
        edit_cfg_text_scale, edit_cfg_img_scale, edit_cfg_interval, edit_timestep_shift,
        edit_num_timesteps, edit_cfg_renorm_min, edit_cfg_renorm_type,
    ]

    # Link the submit button and the textbox's Enter key to the processing function.
    # queue=True is required because process_chat_message is a generator (it streams updates).
    submit_btn.click(
        fn=process_chat_message,
        inputs=inputs_list,
        outputs=[chatbot, user_prompt],
        scroll_to_output=True,
        queue=True,
    )
    user_prompt.submit(  # Allows pressing Enter in the textbox to submit
        fn=process_chat_message,
        inputs=inputs_list,
        outputs=[chatbot, user_prompt],
        scroll_to_output=True,
        queue=True,
    )

demo.launch()