Spaces:
Running
Running
Upload folder using huggingface_hub
Browse files
app.py
CHANGED
@@ -184,7 +184,7 @@ class GenerationManager:
|
|
184 |
}
|
185 |
payload = {
|
186 |
"prompt": prompt,
|
187 |
-
"enable_safety_checker":
|
188 |
"enable_base64_output": True, # Enable base64 output
|
189 |
"size": "1024*1024",
|
190 |
"seed": -1,
|
@@ -295,6 +295,7 @@ async def poll_once(manager, backend, request_id):
|
|
295 |
|
296 |
# Handle base64 output
|
297 |
output = data["outputs"][0]
|
|
|
298 |
|
299 |
# Check if it's a base64 string or URL
|
300 |
if isinstance(output, str) and output.startswith("http"):
|
@@ -308,7 +309,7 @@ async def poll_once(manager, backend, request_id):
|
|
308 |
output, str
|
309 |
) and not output.startswith("data:image"):
|
310 |
# Convert raw base64 to data URI format
|
311 |
-
return f"data:image/
|
312 |
else:
|
313 |
# Already in data URI format
|
314 |
return output
|
@@ -329,23 +330,45 @@ async def poll_once(manager, backend, request_id):
|
|
329 |
raise Exception(f"Poll error: {response.status}")
|
330 |
|
331 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
332 |
# Use a state variable to store session ID
|
333 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
334 |
session_id = gr.State(None) # Add this to store session ID
|
335 |
|
336 |
-
gr.Markdown("# 🌊
|
337 |
|
338 |
# Add the introduction with link to WaveSpeedAI
|
339 |
gr.Markdown(
|
340 |
-
"
|
|
|
|
|
|
|
341 |
)
|
342 |
gr.Markdown(
|
343 |
-
"
|
|
|
|
|
344 |
)
|
345 |
|
346 |
with gr.Row():
|
347 |
with gr.Column(scale=3):
|
|
|
|
|
|
|
|
|
|
|
348 |
input_text = gr.Textbox(
|
|
|
349 |
label="Enter your prompt",
|
350 |
placeholder="Type here...",
|
351 |
lines=3,
|
@@ -353,6 +376,8 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
353 |
with gr.Column(scale=1):
|
354 |
generate_btn = gr.Button("Generate", variant="primary")
|
355 |
|
|
|
|
|
356 |
# Two status boxes - small (default) and big (during generation)
|
357 |
small_status_box = gr.Markdown("Ready to generate images",
|
358 |
elem_id="small-status")
|
@@ -374,6 +399,24 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
374 |
|
375 |
performance_plot = gr.Plot(label="Performance Metrics")
|
376 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
377 |
# Add custom CSS for the big status box
|
378 |
css = """
|
379 |
#big-status-row {
|
@@ -442,6 +485,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
442 |
None,
|
443 |
None,
|
444 |
session_id, # Return the session ID
|
|
|
445 |
)
|
446 |
return
|
447 |
|
@@ -459,6 +503,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
459 |
None,
|
460 |
None,
|
461 |
session_id, # Return the session ID
|
|
|
462 |
)
|
463 |
|
464 |
# For production mode:
|
@@ -514,6 +559,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
514 |
(manager.get_performance_plot()
|
515 |
if any(completed_backends) else None),
|
516 |
session_id,
|
|
|
517 |
)
|
518 |
except Exception as e:
|
519 |
print(f"Error polling {backend}: {str(e)}")
|
@@ -526,6 +572,8 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
526 |
if len(completed_backends) == 3 else
|
527 |
"⚠️ Some generations timed out")
|
528 |
|
|
|
|
|
529 |
# Final yield
|
530 |
yield (
|
531 |
final_status,
|
@@ -537,6 +585,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
537 |
results["hidream-full"],
|
538 |
manager.get_performance_plot(),
|
539 |
session_id,
|
|
|
540 |
)
|
541 |
|
542 |
except Exception as e:
|
@@ -552,6 +601,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
552 |
None,
|
553 |
None,
|
554 |
session_id,
|
|
|
555 |
)
|
556 |
|
557 |
# Schedule periodic cleanup of old sessions
|
@@ -577,6 +627,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
577 |
best_output,
|
578 |
performance_plot,
|
579 |
session_id, # Update the session ID
|
|
|
580 |
],
|
581 |
api_name="generate",
|
582 |
max_batch_size=10, # Process up to 10 requests at once
|
|
|
184 |
}
|
185 |
payload = {
|
186 |
"prompt": prompt,
|
187 |
+
"enable_safety_checker": True,
|
188 |
"enable_base64_output": True, # Enable base64 output
|
189 |
"size": "1024*1024",
|
190 |
"seed": -1,
|
|
|
295 |
|
296 |
# Handle base64 output
|
297 |
output = data["outputs"][0]
|
298 |
+
# has_nsfw_content = data["has_nsfw_contents"][0]
|
299 |
|
300 |
# Check if it's a base64 string or URL
|
301 |
if isinstance(output, str) and output.startswith("http"):
|
|
|
309 |
output, str
|
310 |
) and not output.startswith("data:image"):
|
311 |
# Convert raw base64 to data URI format
|
312 |
+
return f"data:image/jpeg;base64,{output}"
|
313 |
else:
|
314 |
# Already in data URI format
|
315 |
return output
|
|
|
330 |
raise Exception(f"Poll error: {response.status}")
|
331 |
|
332 |
|
333 |
+
# Store recent generations
|
334 |
+
recent_generations = []
|
335 |
+
|
336 |
+
# Example prompts
|
337 |
+
example_prompts = [
|
338 |
+
"A deep sea diver exploring an underwater city ruins, using a palette of deep blues and silvers."
|
339 |
+
"A Martian greenhouse complex that uses genetically modified crops designed to thrive in low gravity. Outside, rovers fitted with AI guidance systems patrol dusty red plains, ensuring each pressurized dome remains airtight.",
|
340 |
+
"A sleek, futuristic sports car with glowing blue accents, racing through a virtual reality landscape, 3D render",
|
341 |
+
]
|
342 |
+
|
343 |
+
|
344 |
# Use a state variable to store session ID
|
345 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
346 |
session_id = gr.State(None) # Add this to store session ID
|
347 |
|
348 |
+
gr.Markdown("# 🌊 WaveSpeed AI HiDream Arena")
|
349 |
|
350 |
# Add the introduction with link to WaveSpeedAI
|
351 |
gr.Markdown(
|
352 |
+
"""
|
353 |
+
[WaveSpeedAI](https://wavespeed.ai/) is the global pioneer in accelerating AI-powered video and image generation.
|
354 |
+
"Our in-house inference accelerator provides lossless speedup on image & video generation based on our rich inference optimization software stack, including our in-house inference compiler, CUDA kernel libraries and parallel computing libraries."
|
355 |
+
"""
|
356 |
)
|
357 |
gr.Markdown(
|
358 |
+
"""
|
359 |
+
This demo showcases the performance and outputs of leading image generation models, including HiDream and Flux, on our accelerated inference platform.
|
360 |
+
"""
|
361 |
)
|
362 |
|
363 |
with gr.Row():
|
364 |
with gr.Column(scale=3):
|
365 |
+
example_dropdown = gr.Dropdown(
|
366 |
+
choices=example_prompts,
|
367 |
+
label="Choose an example prompt",
|
368 |
+
interactive=True
|
369 |
+
)
|
370 |
input_text = gr.Textbox(
|
371 |
+
example_prompts[0],
|
372 |
label="Enter your prompt",
|
373 |
placeholder="Type here...",
|
374 |
lines=3,
|
|
|
376 |
with gr.Column(scale=1):
|
377 |
generate_btn = gr.Button("Generate", variant="primary")
|
378 |
|
379 |
+
example_dropdown.change(lambda ex: ex, inputs=[example_dropdown], outputs=[input_text])
|
380 |
+
|
381 |
# Two status boxes - small (default) and big (during generation)
|
382 |
small_status_box = gr.Markdown("Ready to generate images",
|
383 |
elem_id="small-status")
|
|
|
399 |
|
400 |
performance_plot = gr.Plot(label="Performance Metrics")
|
401 |
|
402 |
+
with gr.Accordion("Recent Generations (last 32)", open=False):
|
403 |
+
recent_gallery = gr.Gallery(label="Prompt and Output")
|
404 |
+
|
405 |
+
def update_recent_gallery(prompt, results):
|
406 |
+
recent_generations.append({
|
407 |
+
"prompt": prompt,
|
408 |
+
"flux-dev": results["flux-dev"],
|
409 |
+
"hidream-dev": results["hidream-dev"],
|
410 |
+
"hidream-full": results["hidream-full"],
|
411 |
+
})
|
412 |
+
if len(recent_generations) > 32:
|
413 |
+
recent_generations.pop(0)
|
414 |
+
gallery_items = [
|
415 |
+
(r["prompt"], r["flux-dev"], r["hidream-dev"], r["hidream-full"])
|
416 |
+
for r in reversed(recent_generations)
|
417 |
+
]
|
418 |
+
return gr.update(value=gallery_items)
|
419 |
+
|
420 |
# Add custom CSS for the big status box
|
421 |
css = """
|
422 |
#big-status-row {
|
|
|
485 |
None,
|
486 |
None,
|
487 |
session_id, # Return the session ID
|
488 |
+
None,
|
489 |
)
|
490 |
return
|
491 |
|
|
|
503 |
None,
|
504 |
None,
|
505 |
session_id, # Return the session ID
|
506 |
+
None,
|
507 |
)
|
508 |
|
509 |
# For production mode:
|
|
|
559 |
(manager.get_performance_plot()
|
560 |
if any(completed_backends) else None),
|
561 |
session_id,
|
562 |
+
None,
|
563 |
)
|
564 |
except Exception as e:
|
565 |
print(f"Error polling {backend}: {str(e)}")
|
|
|
572 |
if len(completed_backends) == 3 else
|
573 |
"⚠️ Some generations timed out")
|
574 |
|
575 |
+
gallery_update = update_recent_gallery(prompt, results)
|
576 |
+
|
577 |
# Final yield
|
578 |
yield (
|
579 |
final_status,
|
|
|
585 |
results["hidream-full"],
|
586 |
manager.get_performance_plot(),
|
587 |
session_id,
|
588 |
+
gallery_update,
|
589 |
)
|
590 |
|
591 |
except Exception as e:
|
|
|
601 |
None,
|
602 |
None,
|
603 |
session_id,
|
604 |
+
None,
|
605 |
)
|
606 |
|
607 |
# Schedule periodic cleanup of old sessions
|
|
|
627 |
best_output,
|
628 |
performance_plot,
|
629 |
session_id, # Update the session ID
|
630 |
+
recent_gallery, # Update the gallery
|
631 |
],
|
632 |
api_name="generate",
|
633 |
max_batch_size=10, # Process up to 10 requests at once
|