Commit 7ea41f5 (verified) by chengzeyi
Parent(s): 4bead13

Upload folder using huggingface_hub

Files changed (1):
  1. app.py +56 -5
app.py CHANGED
@@ -184,7 +184,7 @@ class GenerationManager:
         }
         payload = {
             "prompt": prompt,
-            "enable_safety_checker": False,
+            "enable_safety_checker": True,
             "enable_base64_output": True,  # Enable base64 output
             "size": "1024*1024",
             "seed": -1,
@@ -295,6 +295,7 @@ async def poll_once(manager, backend, request_id):

                 # Handle base64 output
                 output = data["outputs"][0]
+                # has_nsfw_content = data["has_nsfw_contents"][0]

                 # Check if it's a base64 string or URL
                 if isinstance(output, str) and output.startswith("http"):
@@ -308,7 +309,7 @@ async def poll_once(manager, backend, request_id):
                     output, str
                 ) and not output.startswith("data:image"):
                     # Convert raw base64 to data URI format
-                    return f"data:image/png;base64,{output}"
+                    return f"data:image/jpeg;base64,{output}"
                 else:
                     # Already in data URI format
                     return output
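The two hunks above are the output-normalization path in poll_once: hosted URLs pass through, raw base64 gets wrapped in a data URI, and existing data URIs are returned unchanged. The commit only swaps the assumed MIME type from PNG to JPEG. As a standalone illustration (helper name and signature are mine, not from app.py):

```python
def normalize_output(output: str, mime: str = "image/jpeg") -> str:
    """Return something a web UI can display directly: a URL or a data URI."""
    if output.startswith("http"):
        # Hosted image: pass the URL through unchanged.
        return output
    if output.startswith("data:image"):
        # Already a data URI.
        return output
    # Raw base64 payload: wrap it with the expected MIME type.
    return f"data:{mime};base64,{output}"

# normalize_output("/9j/4AAQSkZJRg...")  ->  "data:image/jpeg;base64,/9j/4AAQSkZJRg..."
```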
@@ -329,23 +330,45 @@ async def poll_once(manager, backend, request_id):
                 raise Exception(f"Poll error: {response.status}")


+# Store recent generations
+recent_generations = []
+
+# Example prompts
+example_prompts = [
+    "A deep sea diver exploring an underwater city ruins, using a palette of deep blues and silvers.",
+    "A Martian greenhouse complex that uses genetically modified crops designed to thrive in low gravity. Outside, rovers fitted with AI guidance systems patrol dusty red plains, ensuring each pressurized dome remains airtight.",
+    "A sleek, futuristic sports car with glowing blue accents, racing through a virtual reality landscape, 3D render",
+]
+
+
 # Use a state variable to store session ID
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     session_id = gr.State(None)  # Add this to store session ID

-    gr.Markdown("# 🌊 HiDream Arena powered by WaveSpeed AI Image Generator")
+    gr.Markdown("# 🌊 WaveSpeed AI HiDream Arena")

     # Add the introduction with link to WaveSpeedAI
     gr.Markdown(
-        "[WaveSpeedAI](https://wavespeed.ai/) is the global pioneer in accelerating AI-powered video and image generation."
+        """
+        [WaveSpeedAI](https://wavespeed.ai/) is the global pioneer in accelerating AI-powered video and image generation.
+        Our in-house inference accelerator provides lossless speedup on image & video generation based on our rich inference optimization software stack, including our in-house inference compiler, CUDA kernel libraries and parallel computing libraries.
+        """
     )
     gr.Markdown(
-        "Our in-house inference accelerator provides lossless speedup on image & video generation based on our rich inference optimization software stack, including our in-house inference compiler, CUDA kernel libraries and parallel computing libraries."
+        """
+        This demo showcases the performance and outputs of leading image generation models, including HiDream and Flux, on our accelerated inference platform.
+        """
     )

     with gr.Row():
         with gr.Column(scale=3):
+            example_dropdown = gr.Dropdown(
+                choices=example_prompts,
+                label="Choose an example prompt",
+                interactive=True
+            )
             input_text = gr.Textbox(
+                example_prompts[0],
                 label="Enter your prompt",
                 placeholder="Type here...",
                 lines=3,
@@ -353,6 +376,8 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         with gr.Column(scale=1):
             generate_btn = gr.Button("Generate", variant="primary")

+    example_dropdown.change(lambda ex: ex, inputs=[example_dropdown], outputs=[input_text])
+
     # Two status boxes - small (default) and big (during generation)
     small_status_box = gr.Markdown("Ready to generate images",
                                    elem_id="small-status")
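The `example_dropdown.change(...)` wiring added above simply copies the selected example into the prompt textbox. A minimal, self-contained Gradio sketch of the same pattern (independent of app.py) looks like this:

```python
import gradio as gr

examples = [
    "A deep sea diver exploring underwater city ruins",
    "A sleek, futuristic sports car with glowing blue accents",
]

with gr.Blocks() as demo:
    dropdown = gr.Dropdown(choices=examples, label="Choose an example prompt", interactive=True)
    prompt = gr.Textbox(examples[0], label="Enter your prompt", lines=3)
    # The dropdown's value is passed straight through to the textbox on every change.
    dropdown.change(lambda ex: ex, inputs=[dropdown], outputs=[prompt])

# demo.launch()
```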
@@ -374,6 +399,24 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:

     performance_plot = gr.Plot(label="Performance Metrics")

+    with gr.Accordion("Recent Generations (last 32)", open=False):
+        recent_gallery = gr.Gallery(label="Prompt and Output")
+
+    def update_recent_gallery(prompt, results):
+        recent_generations.append({
+            "prompt": prompt,
+            "flux-dev": results["flux-dev"],
+            "hidream-dev": results["hidream-dev"],
+            "hidream-full": results["hidream-full"],
+        })
+        if len(recent_generations) > 32:
+            recent_generations.pop(0)
+        gallery_items = [
+            (r["prompt"], r["flux-dev"], r["hidream-dev"], r["hidream-full"])
+            for r in reversed(recent_generations)
+        ]
+        return gr.update(value=gallery_items)
+
     # Add custom CSS for the big status box
     css = """
     #big-status-row {
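`update_recent_gallery` above keeps a rolling window of the last 32 generations and rebuilds the gallery value, newest first. Note that `gr.Gallery` is normally fed a flat list of images or `(image, caption)` pairs; a variant of the helper under that assumption (illustrative only, not the code in app.py) could flatten each stored entry like this:

```python
def build_gallery_items(recent_generations):
    """Flatten stored generations into (image, caption) pairs, newest first."""
    items = []
    for r in reversed(recent_generations):
        for backend in ("flux-dev", "hidream-dev", "hidream-full"):
            image = r[backend]
            if image is not None:
                # Caption pairs the backend name with the prompt that produced the image.
                items.append((image, f"{backend}: {r['prompt']}"))
    return items
```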
@@ -442,6 +485,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
                 None,
                 None,
                 session_id,  # Return the session ID
+                None,
             )
             return

@@ -459,6 +503,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
                 None,
                 None,
                 session_id,  # Return the session ID
+                None,
             )

         # For production mode:
@@ -514,6 +559,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
                     (manager.get_performance_plot()
                      if any(completed_backends) else None),
                     session_id,
+                    None,
                 )
             except Exception as e:
                 print(f"Error polling {backend}: {str(e)}")
@@ -526,6 +572,8 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
                         if len(completed_backends) == 3 else
                         "⚠️ Some generations timed out")

+        gallery_update = update_recent_gallery(prompt, results)
+
         # Final yield
         yield (
             final_status,
@@ -537,6 +585,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             results["hidream-full"],
             manager.get_performance_plot(),
             session_id,
+            gallery_update,
         )

     except Exception as e:
@@ -552,6 +601,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             None,
             None,
             session_id,
+            None,
         )

     # Schedule periodic cleanup of old sessions
@@ -577,6 +627,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             best_output,
             performance_plot,
             session_id,  # Update the session ID
+            recent_gallery,  # Update the gallery
         ],
         api_name="generate",
         max_batch_size=10,  # Process up to 10 requests at once
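The remaining hunks all follow from one constraint: the click handler's `outputs` list gains `recent_gallery`, so every tuple the generator yields needs one matching extra element (`None` on intermediate and error yields, `gallery_update` on the final yield). A tiny self-contained Gradio sketch of that arity rule, with illustrative component names only:

```python
import gradio as gr

with gr.Blocks() as demo:
    btn = gr.Button("Generate")
    status = gr.Markdown("Ready")
    gallery = gr.Gallery(label="Recent Generations")

    def run():
        # A generator handler yields one value per output component, in order.
        yield ("Working...", None)  # nothing for the gallery yet
        # Final yield: the second slot carries the gallery items (a single captioned URL here).
        yield ("Done", gr.update(value=[("https://example.com/a.png", "prompt A")]))

    btn.click(run, inputs=[], outputs=[status, gallery])

# demo.launch()
```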
 