chengzeyi committed on
Commit
3cf3832
·
1 Parent(s): 2eb2f52

add rate limit

Browse files
Files changed (1) hide show
  1. app.py +106 -68
app.py CHANGED
@@ -64,10 +64,9 @@ class BackendStatus:
64
  self.status = "completed"
65
  self.progress = 100
66
  self.end_time = time.time()
67
- self.history.append({
68
- "timestamp": datetime.now(),
69
- "duration": self.end_time - self.start_time
70
- })
71
 
72
  def fail(self):
73
  self.status = "failed"
@@ -94,8 +93,10 @@ class SessionManager:
94
  with cls._lock:
95
  to_remove = []
96
  for session_id, manager in cls._instances.items():
97
- if (hasattr(manager, "last_activity")
98
- and current_time - manager.last_activity > max_age):
 
 
99
  to_remove.append(session_id)
100
 
101
  for session_id in to_remove:
@@ -105,15 +106,24 @@ class SessionManager:
105
  class GenerationManager:
106
 
107
  def __init__(self):
108
- self.backend_statuses = {
109
- backend: BackendStatus()
110
- for backend in BACKENDS
111
- }
112
  self.last_activity = time.time()
 
113
 
114
  def update_activity(self):
115
  self.last_activity = time.time()
116
 
 
 
 
 
 
 
 
 
 
 
 
117
  def get_performance_plot(self):
118
  fig = go.Figure()
119
 
@@ -127,21 +137,24 @@ class GenerationManager:
127
  # Use bar chart instead of box plot
128
  fig.add_trace(
129
  go.Bar(
130
- y=[avg_duration], #
131
  x=[BACKENDS[backend]["name"]], # Backend name
132
  name=BACKENDS[backend]["name"],
133
  marker_color=BACKENDS[backend]["color"],
134
  text=[f"{avg_duration:.2f}s"], # Show time in seconds
135
  textposition="auto",
136
  width=[0.5], # Make bars narrower
137
- ))
 
138
 
139
  # Set a minimum y-axis range if we have data
140
  if has_data:
141
- max_duration = max([
142
- max([h["duration"] for h in status.history] or [0])
143
- for status in self.backend_statuses.values()
144
- ])
 
 
145
  # Add 20% padding to the top
146
  y_max = max_duration * 1.2
147
  # Ensure the y-axis always starts at 0
@@ -196,19 +209,15 @@ class GenerationManager:
196
 
197
  # Use aiohttp instead of requests for async
198
  async with aiohttp.ClientSession() as session:
199
- async with session.post(url, headers=headers,
200
- json=payload) as response:
201
  if response.status == 200:
202
  result = await response.json()
203
  request_id = result["data"]["id"]
204
- print(
205
- f"Task submitted successfully. Request ID: {request_id}"
206
- )
207
  return request_id
208
  else:
209
  text = await response.text()
210
- raise Exception(
211
- f"API error: {response.status}, {text}")
212
 
213
  except Exception as e:
214
  status.fail()
@@ -296,9 +305,9 @@ async def poll_once(manager, backend, request_id):
296
  # It's base64 data - format it as a data URI if needed
297
  try:
298
  # Format as data URI for Gradio to display directly
299
- if isinstance(
300
- output, str
301
- ) and not output.startswith("data:image"):
302
  # Convert raw base64 to data URI format
303
  return f"data:image/jpeg;base64,{output}"
304
  else:
@@ -306,8 +315,7 @@ async def poll_once(manager, backend, request_id):
306
  return output
307
  except Exception as e:
308
  print(f"Error processing base64 image: {e}")
309
- raise Exception(
310
- f"Failed to process base64 image: {str(e)}")
311
 
312
  elif current_status == "failed":
313
  manager.backend_statuses[backend].fail()
@@ -338,19 +346,25 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
338
  gr.Markdown("# 🌊 WaveSpeedAI HiDream Arena")
339
 
340
  # Add the introduction with link to WaveSpeedAI
341
- gr.Markdown("""
 
342
  [WaveSpeedAI](https://wavespeed.ai/) is the global pioneer in accelerating AI-powered video and image generation.
343
  Our in-house inference accelerator provides lossless speedup on image & video generation based on our rich inference optimization software stack, including our in-house inference compiler, CUDA kernel libraries and parallel computing libraries.
344
- """)
345
- gr.Markdown("""
 
 
346
  This demo showcases the performance and outputs of leading image generation models, including HiDream and Flux, on our accelerated inference platform.
347
- """)
 
348
 
349
  with gr.Row():
350
  with gr.Column(scale=3):
351
- example_dropdown = gr.Dropdown(choices=example_prompts,
352
- label="Choose an example prompt",
353
- interactive=True)
 
 
354
  input_text = gr.Textbox(
355
  example_prompts[0],
356
  label="Enter your prompt",
@@ -360,20 +374,18 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
360
  with gr.Column(scale=1):
361
  generate_btn = gr.Button("Generate", variant="primary")
362
 
363
- example_dropdown.change(lambda ex: ex,
364
- inputs=[example_dropdown],
365
- outputs=[input_text])
366
 
367
  # Two status boxes - small (default) and big (during generation)
368
- small_status_box = gr.Markdown("Ready to generate images",
369
- elem_id="small-status")
370
 
371
  # Big status box in its own row with styling
372
  with gr.Row(elem_id="big-status-row"):
373
- big_status_box = gr.Markdown("",
374
- elem_id="big-status",
375
- visible=False,
376
- elem_classes="big-status-box")
377
 
378
  with gr.Row():
379
  with gr.Column():
@@ -386,27 +398,27 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
386
  performance_plot = gr.Plot(label="Performance Metrics")
387
 
388
  with gr.Accordion("Recent Generations (last 16)", open=False):
389
- recent_gallery = gr.Gallery(label="Prompt and Output",
390
- columns=3,
391
- interactive=False)
392
 
393
  def get_recent_gallery_items():
394
  gallery_items = []
395
  for r in reversed(recent_generations):
396
  gallery_items.append((r["flux-dev"], f"FLUX-dev: {r['prompt']}"))
397
- gallery_items.append(
398
- (r["hidream-dev"], f"HiDream-dev: {r['prompt']}"))
399
- gallery_items.append(
400
- (r["hidream-full"], f"HiDream-full: {r['prompt']}"))
401
  return gallery_items
402
 
403
  def update_recent_gallery(prompt, results):
404
- recent_generations.append({
405
- "prompt": prompt,
406
- "flux-dev": results["flux-dev"],
407
- "hidream-dev": results["hidream-dev"],
408
- "hidream-full": results["hidream-full"],
409
- })
 
 
410
  if len(recent_generations) > 16:
411
  recent_generations.pop(0)
412
  gallery_items = get_recent_gallery_items()
@@ -457,13 +469,34 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
457
  gr.HTML(f"<style>{css}</style>")
458
 
459
  # Update the generation function to use session manager
460
- async def generate_all_backends_with_status_boxes(prompt,
461
- current_session_id):
462
  """Generate images with big status box during generation"""
463
  # Get or create a session manager
464
  session_id, manager = SessionManager.get_manager(current_session_id)
465
  manager.update_activity()
466
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
467
  # IMPORTANT: Reset history when starting a new generation
468
  if prompt and prompt.strip() != "":
469
  manager.reset_history() # Clear previous performance metrics
@@ -523,8 +556,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
523
  poll_attempt = 0
524
 
525
  # Main polling loop
526
- while len(completed_backends
527
- ) < 3 and poll_attempt < max_poll_attempts:
528
  poll_attempt += 1
529
 
530
  # Poll each pending backend
@@ -536,8 +568,9 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
536
  # Only do actual API calls every few attempts to reduce load
537
  if poll_attempt % 2 == 0 or backend == "flux-dev":
538
  # Use the session manager instead of global manager
539
- result = await poll_once(manager, backend,
540
- request_ids[backend])
 
541
  if result: # Backend completed
542
  results[backend] = result
543
  completed_backends.add(backend)
@@ -551,8 +584,11 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
551
  results["flux-dev"],
552
  results["hidream-dev"],
553
  results["hidream-full"],
554
- (manager.get_performance_plot()
555
- if any(completed_backends) else None),
 
 
 
556
  session_id,
557
  None,
558
  )
@@ -563,9 +599,11 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
563
  await asyncio.sleep(0.1)
564
 
565
  # Final status
566
- final_status = ("✅ All generations completed!"
567
- if len(completed_backends) == 3 else
568
- "⚠️ Some generations timed out")
 
 
569
 
570
  gallery_update = update_recent_gallery(prompt, results)
571
 
 
64
  self.status = "completed"
65
  self.progress = 100
66
  self.end_time = time.time()
67
+ self.history.append(
68
+ {"timestamp": datetime.now(), "duration": self.end_time - self.start_time}
69
+ )
 
70
 
71
  def fail(self):
72
  self.status = "failed"
 
93
  with cls._lock:
94
  to_remove = []
95
  for session_id, manager in cls._instances.items():
96
+ if (
97
+ hasattr(manager, "last_activity")
98
+ and current_time - manager.last_activity > max_age
99
+ ):
100
  to_remove.append(session_id)
101
 
102
  for session_id in to_remove:
 
106
  class GenerationManager:
107
 
108
  def __init__(self):
109
+ self.backend_statuses = {backend: BackendStatus() for backend in BACKENDS}
 
 
 
110
  self.last_activity = time.time()
111
+ self.request_timestamps = [] # Track timestamps of requests
112
 
113
  def update_activity(self):
114
  self.last_activity = time.time()
115
 
116
+ def add_request_timestamp(self):
117
+ self.request_timestamps.append(time.time())
118
+
119
+ def has_exceeded_limit(self, limit=10): # Default limit: 10 requests per hour
120
+ current_time = time.time()
121
+ # Filter timestamps to only include those within the last hour
122
+ self.request_timestamps = [
123
+ ts for ts in self.request_timestamps if current_time - ts <= 3600
124
+ ]
125
+ return len(self.request_timestamps) >= limit
126
+
127
  def get_performance_plot(self):
128
  fig = go.Figure()
129
 
 
137
  # Use bar chart instead of box plot
138
  fig.add_trace(
139
  go.Bar(
140
+ y=[avg_duration], #
141
  x=[BACKENDS[backend]["name"]], # Backend name
142
  name=BACKENDS[backend]["name"],
143
  marker_color=BACKENDS[backend]["color"],
144
  text=[f"{avg_duration:.2f}s"], # Show time in seconds
145
  textposition="auto",
146
  width=[0.5], # Make bars narrower
147
+ )
148
+ )
149
 
150
  # Set a minimum y-axis range if we have data
151
  if has_data:
152
+ max_duration = max(
153
+ [
154
+ max([h["duration"] for h in status.history] or [0])
155
+ for status in self.backend_statuses.values()
156
+ ]
157
+ )
158
  # Add 20% padding to the top
159
  y_max = max_duration * 1.2
160
  # Ensure the y-axis always starts at 0
 
209
 
210
  # Use aiohttp instead of requests for async
211
  async with aiohttp.ClientSession() as session:
212
+ async with session.post(url, headers=headers, json=payload) as response:
 
213
  if response.status == 200:
214
  result = await response.json()
215
  request_id = result["data"]["id"]
216
+ print(f"Task submitted successfully. Request ID: {request_id}")
 
 
217
  return request_id
218
  else:
219
  text = await response.text()
220
+ raise Exception(f"API error: {response.status}, {text}")
 
221
 
222
  except Exception as e:
223
  status.fail()
 
305
  # It's base64 data - format it as a data URI if needed
306
  try:
307
  # Format as data URI for Gradio to display directly
308
+ if isinstance(output, str) and not output.startswith(
309
+ "data:image"
310
+ ):
311
  # Convert raw base64 to data URI format
312
  return f"data:image/jpeg;base64,{output}"
313
  else:
 
315
  return output
316
  except Exception as e:
317
  print(f"Error processing base64 image: {e}")
318
+ raise Exception(f"Failed to process base64 image: {str(e)}")
 
319
 
320
  elif current_status == "failed":
321
  manager.backend_statuses[backend].fail()
 
346
  gr.Markdown("# 🌊 WaveSpeedAI HiDream Arena")
347
 
348
  # Add the introduction with link to WaveSpeedAI
349
+ gr.Markdown(
350
+ """
351
  [WaveSpeedAI](https://wavespeed.ai/) is the global pioneer in accelerating AI-powered video and image generation.
352
  Our in-house inference accelerator provides lossless speedup on image & video generation based on our rich inference optimization software stack, including our in-house inference compiler, CUDA kernel libraries and parallel computing libraries.
353
+ """
354
+ )
355
+ gr.Markdown(
356
+ """
357
  This demo showcases the performance and outputs of leading image generation models, including HiDream and Flux, on our accelerated inference platform.
358
+ """
359
+ )
360
 
361
  with gr.Row():
362
  with gr.Column(scale=3):
363
+ example_dropdown = gr.Dropdown(
364
+ choices=example_prompts,
365
+ label="Choose an example prompt",
366
+ interactive=True,
367
+ )
368
  input_text = gr.Textbox(
369
  example_prompts[0],
370
  label="Enter your prompt",
 
374
  with gr.Column(scale=1):
375
  generate_btn = gr.Button("Generate", variant="primary")
376
 
377
+ example_dropdown.change(
378
+ lambda ex: ex, inputs=[example_dropdown], outputs=[input_text]
379
+ )
380
 
381
  # Two status boxes - small (default) and big (during generation)
382
+ small_status_box = gr.Markdown("Ready to generate images", elem_id="small-status")
 
383
 
384
  # Big status box in its own row with styling
385
  with gr.Row(elem_id="big-status-row"):
386
+ big_status_box = gr.Markdown(
387
+ "", elem_id="big-status", visible=False, elem_classes="big-status-box"
388
+ )
 
389
 
390
  with gr.Row():
391
  with gr.Column():
 
398
  performance_plot = gr.Plot(label="Performance Metrics")
399
 
400
  with gr.Accordion("Recent Generations (last 16)", open=False):
401
+ recent_gallery = gr.Gallery(
402
+ label="Prompt and Output", columns=3, interactive=False
403
+ )
404
 
405
  def get_recent_gallery_items():
406
  gallery_items = []
407
  for r in reversed(recent_generations):
408
  gallery_items.append((r["flux-dev"], f"FLUX-dev: {r['prompt']}"))
409
+ gallery_items.append((r["hidream-dev"], f"HiDream-dev: {r['prompt']}"))
410
+ gallery_items.append((r["hidream-full"], f"HiDream-full: {r['prompt']}"))
 
 
411
  return gallery_items
412
 
413
  def update_recent_gallery(prompt, results):
414
+ recent_generations.append(
415
+ {
416
+ "prompt": prompt,
417
+ "flux-dev": results["flux-dev"],
418
+ "hidream-dev": results["hidream-dev"],
419
+ "hidream-full": results["hidream-full"],
420
+ }
421
+ )
422
  if len(recent_generations) > 16:
423
  recent_generations.pop(0)
424
  gallery_items = get_recent_gallery_items()
 
469
  gr.HTML(f"<style>{css}</style>")
470
 
471
  # Update the generation function to use session manager
472
+ async def generate_all_backends_with_status_boxes(prompt, current_session_id):
 
473
  """Generate images with big status box during generation"""
474
  # Get or create a session manager
475
  session_id, manager = SessionManager.get_manager(current_session_id)
476
  manager.update_activity()
477
 
478
+ # Check if the user has exceeded the request limit
479
+ if manager.has_exceeded_limit(
480
+ limit=10
481
+ ): # Set the limit to 10 requests per hour
482
+ error_message = "❌ You have exceeded the limit of 10 requests per hour. Please try again later."
483
+ yield (
484
+ error_message,
485
+ error_message,
486
+ gr.update(visible=False),
487
+ gr.update(visible=True),
488
+ None,
489
+ None,
490
+ None,
491
+ None,
492
+ session_id,
493
+ None,
494
+ )
495
+ return
496
+
497
+ # Add the current request timestamp
498
+ manager.add_request_timestamp()
499
+
500
  # IMPORTANT: Reset history when starting a new generation
501
  if prompt and prompt.strip() != "":
502
  manager.reset_history() # Clear previous performance metrics
 
556
  poll_attempt = 0
557
 
558
  # Main polling loop
559
+ while len(completed_backends) < 3 and poll_attempt < max_poll_attempts:
 
560
  poll_attempt += 1
561
 
562
  # Poll each pending backend
 
568
  # Only do actual API calls every few attempts to reduce load
569
  if poll_attempt % 2 == 0 or backend == "flux-dev":
570
  # Use the session manager instead of global manager
571
+ result = await poll_once(
572
+ manager, backend, request_ids[backend]
573
+ )
574
  if result: # Backend completed
575
  results[backend] = result
576
  completed_backends.add(backend)
 
584
  results["flux-dev"],
585
  results["hidream-dev"],
586
  results["hidream-full"],
587
+ (
588
+ manager.get_performance_plot()
589
+ if any(completed_backends)
590
+ else None
591
+ ),
592
  session_id,
593
  None,
594
  )
 
599
  await asyncio.sleep(0.1)
600
 
601
  # Final status
602
+ final_status = (
603
+ "✅ All generations completed!"
604
+ if len(completed_backends) == 3
605
+ else "⚠️ Some generations timed out"
606
+ )
607
 
608
  gallery_update = update_recent_gallery(prompt, results)
609