prithivMLmods committed on
Commit a791c81 · verified · 1 Parent(s): 50c3bc2

Update app.py

Files changed (1)
  1. app.py +298 -282
app.py CHANGED
@@ -1,313 +1,329 @@
  import os
- import gradio as gr
  import json
  import spaces
- import logging
  import torch
  from PIL import Image
- import random
- import time
- from hi_diffusers import HiDreamImagePipeline
- from hi_diffusers import HiDreamImageTransformer2DModel
- from hi_diffusers.schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
- from hi_diffusers.schedulers.flash_flow_match import FlashFlowMatchEulerDiscreteScheduler
- from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
- from huggingface_hub import ModelCard
-
- # Constants
- MODEL_PREFIX = "HiDream-ai"
- LLAMA_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
-
- FAST_MODEL_CONFIG = {
-     "path": f"{MODEL_PREFIX}/HiDream-I1-Full",
-     "guidance_scale": 5.0,
-     "num_inference_steps": 50,
-     "shift": 3.0,
-     "scheduler": FlowUniPCMultistepScheduler
- }
-
- RESOLUTION_OPTIONS = [
-     "1024 × 1024 (Square)",
-     "768 × 1360 (Portrait)",
-     "1360 × 768 (Landscape)",
-     "880 × 1168 (Portrait)",
-     "1168 × 880 (Landscape)",
-     "1248 × 832 (Landscape)",
-     "832 × 1248 (Portrait)"
- ]
-
- # Load LoRAs from JSON file (assumed to be compatible with Hi-Dream)
- with open('loras.json', 'r') as f:
-     loras = json.load(f)

- device = "cuda" if torch.cuda.is_available() else "cpu"
- MAX_SEED = 2**32 - 1

- # Parse resolution string to height and width
- def parse_resolution(res_str):
-     mapping = {
-         "1024 × 1024": (1024, 1024),
-         "768 × 1360": (768, 1360),
-         "1360 × 768": (1360, 768),
-         "880 × 1168": (880, 1168),
-         "1168 × 880": (1168, 880),
-         "1248 × 832": (1248, 832),
-         "832 × 1248": (832, 1248)
-     }
-     for key, (h, w) in mapping.items():
-         if key in res_str:
-             return h, w
-     return 1024, 1024  # fallback

- # Load the Hi-Dream Fast Model pipeline
- pipe, MODEL_CONFIG = None, None

- def load_fast_model():
-     global pipe, MODEL_CONFIG
-     config = FAST_MODEL_CONFIG
-     scheduler = config["scheduler"](
-         num_train_timesteps=1000,
-         shift=config["shift"],
-         use_dynamic_shifting=False
-     )

-     tokenizer = PreTrainedTokenizerFast.from_pretrained(
-         LLAMA_MODEL_NAME,
-         use_fast=False
-     )
-     text_encoder = LlamaForCausalLM.from_pretrained(
-         LLAMA_MODEL_NAME,
-         output_hidden_states=True,
-         output_attentions=True,
-         torch_dtype=torch.bfloat16
-     ).to(device)

-     transformer = HiDreamImageTransformer2DModel.from_pretrained(
-         config["path"],
-         subfolder="transformer",
-         torch_dtype=torch.bfloat16
-     ).to(device)

-     pipe = HiDreamImagePipeline.from_pretrained(
-         config["path"],
-         scheduler=scheduler,
-         tokenizer_4=tokenizer,
-         text_encoder_4=text_encoder,
-         torch_dtype=torch.bfloat16
-     ).to(device, torch.bfloat16)

-     pipe.transformer = transformer
-     MODEL_CONFIG = config
-     return pipe, config

- # Generate image
- @spaces.GPU
- def generate_image(prompt, resolution, seed, guidance_scale, num_inference_steps):
-     global pipe, MODEL_CONFIG
-     if pipe is None:
-         pipe, MODEL_CONFIG = load_fast_model()

-     height, width = parse_resolution(resolution)
-     if seed == -1 or seed is None:
          seed = random.randint(0, MAX_SEED)
-     generator = torch.Generator(device=device).manual_seed(int(seed))

-     result = pipe(
-         prompt=prompt,
-         height=height,
-         width=width,
-         guidance_scale=guidance_scale,
-         num_inference_steps=num_inference_steps,
-         num_images_per_prompt=1,
-         generator=generator
-     )

-     return result.images[0], seed
-
- class calculateDuration:
-     def __init__(self, activity_name=""):
-         self.activity_name = activity_name
-
-     def __enter__(self):
-         self.start_time = time.time()
-         return self
-
-     def __exit__(self, exc_type, exc_value, traceback):
-         self.end_time = time.time()
-         self.elapsed_time = self.end_time - self.start_time
-         if self.activity_name:
-             print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
-         else:
-             print(f"Elapsed time: {self.elapsed_time:.6f} seconds")

- def update_selection(evt: gr.SelectData, resolution):
-     selected_lora = loras[evt.index]
-     new_placeholder = f"Type a prompt for {selected_lora['title']}"
-     lora_repo = selected_lora["repo"]
-     updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨"
-     if "aspect" in selected_lora:
-         if selected_lora["aspect"] == "portrait":
-             resolution = "768 × 1360 (Portrait)"
-         elif selected_lora["aspect"] == "landscape":
-             resolution = "1360 × 768 (Landscape)"
-         else:
-             resolution = "1024 × 1024 (Square)"
-     return (
-         gr.update(placeholder=new_placeholder),
-         updated_text,
-         evt.index,
-         resolution,
-     )

- def run_lora(prompt, resolution, cfg_scale, steps, selected_index, randomize_seed, seed):
-     global pipe
-     if pipe is None:
-         pipe, _ = load_fast_model()

-     if selected_index is not None:
-         selected_lora = loras[selected_index]
-         lora_path = selected_lora["repo"]
-         weight_name = selected_lora.get("weights", None)
-         with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
-             pipe.load_lora_weights(lora_path, weight_name=weight_name, low_cpu_mem_usage=True)
-         trigger_word = selected_lora.get("trigger_word", "")
-         if trigger_word:
-             if "trigger_position" in selected_lora and selected_lora["trigger_position"] == "prepend":
-                 prompt = f"{trigger_word} {prompt}"
-             else:
-                 prompt = f"{prompt} {trigger_word}"

-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)

-     with calculateDuration("Generating image"):
-         final_image, used_seed = generate_image(prompt, resolution, seed, cfg_scale, steps)
-     return final_image, used_seed

- def check_custom_model(link):
-     split_link = link.split("/")
-     if len(split_link) != 2:
-         raise Exception("Invalid Hugging Face repository link format.")
-     model_card = ModelCard.load(link)
-     image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
-     trigger_word = model_card.data.get("instance_prompt", "")
-     image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
-     safetensors_name = None  # Simplified; assumes a safetensors file exists
-     return split_link[1], link, safetensors_name, trigger_word, image_url

- def add_custom_lora(custom_lora):
-     global loras
-     if custom_lora:
-         try:
-             title, repo, path, trigger_word, image = check_custom_model(custom_lora)
-             card = f'''
-             <div class="custom_lora_card">
-                 <span>Loaded custom LoRA:</span>
-                 <div class="card_internal">
-                     <img src="{image}" />
-                     <div>
-                         <h3>{title}</h3>
-                         <small>{"Using: <code><b>"+trigger_word+"</code></b> as the trigger word" if trigger_word else "No trigger word found."}</small>
-                     </div>
-                 </div>
-             </div>
-             '''
-             existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
-             if not existing_item_index:
-                 new_item = {
-                     "image": image,
-                     "title": title,
-                     "repo": repo,
-                     "weights": path,
-                     "trigger_word": trigger_word
-                 }
-                 existing_item_index = len(loras)
-                 loras.append(new_item)
-             return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
-         except Exception as e:
-             gr.Warning(f"Invalid LoRA: {str(e)}")
-             return gr.update(visible=True, value=f"Invalid LoRA: {str(e)}"), gr.update(visible=True), gr.update(), "", None, ""
-     return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""

- def remove_custom_lora():
-     return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""

- css = '''
- #gen_btn{height: 100%}
- #gen_column{align-self: stretch}
- #title{text-align: center}
- #title h1{font-size: 3em; display:inline-flex; align-items:center}
- #title img{width: 100px; margin-right: 0.5em}
- #gallery .grid-wrap{height: 10vh}
- #lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
- .card_internal{display: flex;height: 100px;margin-top: .5em}
- .card_internal img{margin-right: 1em}
- .styler{--form-gap-width: 0px !important}
- '''

- font = [gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"]
- with gr.Blocks(theme=gr.themes.Soft(font=font), css=css, delete_cache=(60, 60)) as app:
-     title = gr.HTML(
-         """<h1>Hi-Dream Full LoRA DLC 🤩</h1>""",
-         elem_id="title",
-     )
-     selected_index = gr.State(None)
-     with gr.Row():
-         with gr.Column(scale=3):
-             prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
-         with gr.Column(scale=1, elem_id="gen_column"):
-             generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
-     with gr.Row():
-         with gr.Column():
-             selected_info = gr.Markdown("")
-             gallery = gr.Gallery(
-                 [(item["image"], item["title"]) for item in loras],
-                 label="LoRA Gallery",
-                 allow_preview=False,
-                 columns=3,
-                 elem_id="gallery",
-                 show_share_button=False
-             )
-             with gr.Group():
-                 custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path", placeholder="linoyts/HiDream-yarn-art-LoRA")
-                 gr.Markdown("[Check the list of Hi-Dream LoRAs]", elem_id="lora_list")
-             custom_lora_info = gr.HTML(visible=False)
-             custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
-         with gr.Column():
-             result = gr.Image(label="Generated Image")

-     with gr.Row():
-         with gr.Accordion("Advanced Settings", open=False):
-             cfg_scale = gr.Slider(label="Guidance Scale", minimum=0, maximum=20, step=0.1, value=FAST_MODEL_CONFIG["guidance_scale"])
-             steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=FAST_MODEL_CONFIG["num_inference_steps"])
-             resolution = gr.Radio(
-                 choices=RESOLUTION_OPTIONS,
-                 value=RESOLUTION_OPTIONS[0],
-                 label="Resolution"
-             )
-             randomize_seed = gr.Checkbox(True, label="Randomize seed")
-             seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)

-     gallery.select(
-         update_selection,
-         inputs=[resolution],
-         outputs=[prompt, selected_info, selected_index, resolution]
-     )
-     custom_lora.input(
-         add_custom_lora,
-         inputs=[custom_lora],
-         outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt]
-     )
-     custom_lora_button.click(
-         remove_custom_lora,
-         outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora]
-     )
-     gr.on(
-         triggers=[generate_button.click, prompt.submit],
-         fn=run_lora,
-         inputs=[prompt, resolution, cfg_scale, steps, selected_index, randomize_seed, seed],
-         outputs=[result, seed]
-     )

- app.queue()
- app.launch()
  import os
+ import random
+ import uuid
  import json
+ import time
+ import asyncio
+ import re
+ from threading import Thread
+
+ import gradio as gr
  import spaces
  import torch
+ import numpy as np
  from PIL import Image
+ import cv2

+ from transformers import (
+     AutoProcessor,
+     Gemma3ForConditionalGeneration,
+     Qwen2VLForConditionalGeneration,
+     TextIteratorStreamer,
+ )
+ from transformers.image_utils import load_image

+ # Constants
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+ MAX_SEED = np.iinfo(np.int32).max

+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

+ # Helper function to return a progress bar HTML snippet.
+ def progress_bar_html(label: str) -> str:
+     return f'''
+ <div style="display: flex; align-items: center;">
+     <span style="margin-right: 10px; font-size: 14px;">{label}</span>
+     <div style="width: 110px; height: 5px; background-color: #F0FFF0; border-radius: 2px; overflow: hidden;">
+         <div style="width: 100%; height: 100%; background-color: #00FF00; animation: loading 1.5s linear infinite;"></div>
+     </div>
+ </div>
+ <style>
+ @keyframes loading {{
+     0% {{ transform: translateX(-100%); }}
+     100% {{ transform: translateX(100%); }}
+ }}
+ </style>
+     '''

+ # Qwen2-VL (for optional image inference)

+ MODEL_ID_VL = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
+ processor = AutoProcessor.from_pretrained(MODEL_ID_VL, trust_remote_code=True)
+ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
+     MODEL_ID_VL,
+     trust_remote_code=True,
+     torch_dtype=torch.float16
+ ).to("cuda").eval()

+ def clean_chat_history(chat_history):
+     cleaned = []
+     for msg in chat_history:
+         if isinstance(msg, dict) and isinstance(msg.get("content"), str):
+             cleaned.append(msg)
+     return cleaned

+ bad_words = json.loads(os.getenv('BAD_WORDS', "[]"))
+ bad_words_negative = json.loads(os.getenv('BAD_WORDS_NEGATIVE', "[]"))
+ default_negative = os.getenv("default_negative", "")

+ def check_text(prompt, negative=""):
+     for i in bad_words:
+         if i in prompt:
+             return True
+     for i in bad_words_negative:
+         if i in negative:
+             return True
+     return False

+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+     if randomize_seed:
          seed = random.randint(0, MAX_SEED)
+     return seed

+ CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"

+ dtype = torch.float16 if device.type == "cuda" else torch.float32

+ # Gemma3 Model (default for text, image, & video inference)

+ gemma3_model_id = "google/gemma-3-4b-it"  # [or] Duplicate the space to use 12b
+ gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
+     gemma3_model_id, device_map="auto"
+ ).eval()
+ gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)

+ # VIDEO PROCESSING HELPER

+ def downsample_video(video_path):
+     vidcap = cv2.VideoCapture(video_path)
+     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+     fps = vidcap.get(cv2.CAP_PROP_FPS)
+     frames = []
+     # Sample 10 evenly spaced frames.
+     frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
+     for i in frame_indices:
+         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+         success, image = vidcap.read()
+         if success:
+             # Convert from BGR to RGB and then to PIL Image.
+             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+             pil_image = Image.fromarray(image)
+             timestamp = round(i / fps, 2)
+             frames.append((pil_image, timestamp))
+     vidcap.release()
+     return frames

+ # MAIN GENERATION FUNCTION

+ @spaces.GPU
+ def generate(
+     input_dict: dict,
+     chat_history: list[dict],
+     max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.2,
+ ):
+     text = input_dict["text"]
+     files = input_dict.get("files", [])
+     lower_text = text.lower().strip()

+     # ----- Qwen2-VL branch (triggered with @qwen2-vl) -----
+     if lower_text.startswith("@qwen2-vl"):
+         prompt_clean = re.sub(r"@qwen2-vl", "", text, flags=re.IGNORECASE).strip().strip('"')
+         if files:
+             images = [load_image(f) for f in files]
+             messages = [{
+                 "role": "user",
+                 "content": [
+                     *[{"type": "image", "image": image} for image in images],
+                     {"type": "text", "text": prompt_clean},
+                 ]
+             }]
+             prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+             inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to("cuda")
+         else:
+             messages = [
+                 {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+                 {"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
+             ]
+             inputs = processor.apply_chat_template(
+                 messages, add_generation_prompt=True, tokenize=True,
+                 return_dict=True, return_tensors="pt"
+             ).to("cuda", dtype=torch.float16)
+         streamer = TextIteratorStreamer(processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+         generation_kwargs = {
+             **inputs,
+             "streamer": streamer,
+             "max_new_tokens": max_new_tokens,
+             "do_sample": True,
+             "temperature": temperature,
+             "top_p": top_p,
+             "top_k": top_k,
+             "repetition_penalty": repetition_penalty,
+         }
+         thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
+         thread.start()
+         buffer = ""
+         yield progress_bar_html("Processing with Qwen2VL")
+         for new_text in streamer:
+             buffer += new_text
+             buffer = buffer.replace("<|im_end|>", "")
+             time.sleep(0.01)
+             yield buffer
+         return

+     # ----- Default branch: Gemma3 (for text, image, & video inference) -----
+     if files:
+         # Check if any provided file is a video based on extension.
+         video_extensions = (".mp4", ".mov", ".avi", ".mkv", ".webm")
+         if any(str(f).lower().endswith(video_extensions) for f in files):
+             # Video inference branch.
+             prompt_clean = re.sub(r"@video-infer", "", text, flags=re.IGNORECASE).strip().strip('"')
+             video_path = files[0]
+             frames = downsample_video(video_path)
+             messages = [
+                 {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+                 {"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
+             ]
+             # Append each frame (with its timestamp) to the conversation.
+             for frame in frames:
+                 image, timestamp = frame
+                 image_path = f"video_frame_{uuid.uuid4().hex}.png"
+                 image.save(image_path)
+                 messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
+                 messages[1]["content"].append({"type": "image", "url": image_path})
+             inputs = gemma3_processor.apply_chat_template(
+                 messages, add_generation_prompt=True, tokenize=True,
+                 return_dict=True, return_tensors="pt"
+             ).to(gemma3_model.device, dtype=torch.bfloat16)
+             streamer = TextIteratorStreamer(gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+             generation_kwargs = {
+                 **inputs,
+                 "streamer": streamer,
+                 "max_new_tokens": max_new_tokens,
+                 "do_sample": True,
+                 "temperature": temperature,
+                 "top_p": top_p,
+                 "top_k": top_k,
+                 "repetition_penalty": repetition_penalty,
+             }
+             thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
+             thread.start()
+             buffer = ""
+             yield progress_bar_html("Processing video with Gemma3")
+             for new_text in streamer:
+                 buffer += new_text
+                 time.sleep(0.01)
+                 yield buffer
+             return
+         else:
+             # Image inference branch.
+             prompt_clean = re.sub(r"@gemma3", "", text, flags=re.IGNORECASE).strip().strip('"')
+             images = [load_image(f) for f in files]
+             messages = [{
+                 "role": "user",
+                 "content": [
+                     *[{"type": "image", "image": image} for image in images],
+                     {"type": "text", "text": prompt_clean},
+                 ]
+             }]
+             inputs = gemma3_processor.apply_chat_template(
+                 messages, tokenize=True, add_generation_prompt=True,
+                 return_dict=True, return_tensors="pt"
+             ).to(gemma3_model.device, dtype=torch.bfloat16)
+             streamer = TextIteratorStreamer(gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+             generation_kwargs = {
+                 **inputs,
+                 "streamer": streamer,
+                 "max_new_tokens": max_new_tokens,
+                 "do_sample": True,
+                 "temperature": temperature,
+                 "top_p": top_p,
+                 "top_k": top_k,
+                 "repetition_penalty": repetition_penalty,
+             }
+             thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
+             thread.start()
+             buffer = ""
+             yield progress_bar_html("Processing with Gemma3")
+             for new_text in streamer:
+                 buffer += new_text
+                 time.sleep(0.01)
+                 yield buffer
+             return
+     else:
+         # Text-only inference branch.
+         messages = [
+             {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+             {"role": "user", "content": [{"type": "text", "text": text}]}
+         ]
+         inputs = gemma3_processor.apply_chat_template(
+             messages, add_generation_prompt=True, tokenize=True,
+             return_dict=True, return_tensors="pt"
+         ).to(gemma3_model.device, dtype=torch.bfloat16)
+         streamer = TextIteratorStreamer(gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+         generation_kwargs = {
+             **inputs,
+             "streamer": streamer,
+             "max_new_tokens": max_new_tokens,
+             "do_sample": True,
+             "temperature": temperature,
+             "top_p": top_p,
+             "top_k": top_k,
+             "repetition_penalty": repetition_penalty,
+         }
+         thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
+         thread.start()
+         outputs = []
+         for new_text in streamer:
+             outputs.append(new_text)
+             yield "".join(outputs)
+         final_response = "".join(outputs)
+         yield final_response

+ # Gradio Interface

+ demo = gr.ChatInterface(
+     fn=generate,
+     additional_inputs=[
+         gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
+         gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
+         gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
+         gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
+         gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
+     ],
+     examples=[
+         [{"text": "Create a short story based on the image.", "files": ["examples/1111.jpg"]}],
+         [{"text": "Explain the Image", "files": ["examples/3.jpg"]}],
+         [{"text": "Explain the content of the Advertisement", "files": ["examples/videoplayback.mp4"]}],
+         [{"text": "Which movie character is this?", "files": ["examples/9999.jpg"]}],
+         ["Explain Critical Temperature of Substance"],
+         [{"text": "@qwen2-vl Transcription of the letter", "files": ["examples/222.png"]}],
+         [{"text": "Explain the content of the video in detail", "files": ["examples/breakfast.mp4"]}],
+         [{"text": "Describe the video", "files": ["examples/Missing.mp4"]}],
+         [{"text": "Explain what is happening in this video ?", "files": ["examples/oreo.mp4"]}],
+         [{"text": "Summarize the events in this video", "files": ["examples/sky.mp4"]}],
+         [{"text": "What is in the video ?", "files": ["examples/redlight.mp4"]}],
+         ["Python Program for Array Rotation"],
+         ["Explain Critical Temperature of Substance"]
+     ],
+     cache_examples=False,
+     type="messages",
+     description="# **Gemma 3 Multimodal** \n`Use @qwen2-vl to switch to Qwen2-VL OCR for image inference and @video-infer for video input`",
+     fill_height=True,
+     textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple", placeholder="Tag with @qwen2-vl for Qwen2-VL inference if needed."),
+     stop_btn="Stop Generation",
+     multimodal=True,
+ )

+ if __name__ == "__main__":
+     demo.queue(max_size=20).launch(share=True)
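
Every branch of the new generate function follows the same streaming recipe: model.generate is launched on a background Thread while a TextIteratorStreamer hands partial text back to the Gradio generator, which yields the growing buffer. Below is a minimal, self-contained sketch of that pattern in isolation; it assumes a small placeholder causal LM ("gpt2") rather than the Gemma 3 or Qwen2-VL checkpoints this Space actually loads, and it prints the stream where the Space would yield it to the chat UI.

# Minimal sketch of the Thread + TextIteratorStreamer pattern used in the new app.py.
# "gpt2" is only an illustrative stand-in model, not part of this Space.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Explain the critical temperature of a substance.", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

generation_kwargs = {
    **inputs,                 # input_ids / attention_mask
    "streamer": streamer,     # generate() pushes decoded text chunks here
    "max_new_tokens": 64,
    "do_sample": True,
}
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

buffer = ""
for new_text in streamer:     # iterates as tokens are produced
    buffer += new_text
    print(new_text, end="", flush=True)   # the Space yields `buffer` to Gradio instead
thread.join()

Running generation on a separate thread is what keeps the streamer's iterator from blocking: the main thread can consume tokens (and, in the Space, update the chat window) while decoding is still in progress.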