Spaces:
Running
Running
add rate limit
Browse files
app.py
CHANGED
@@ -64,10 +64,9 @@ class BackendStatus:
|
|
64 |
self.status = "completed"
|
65 |
self.progress = 100
|
66 |
self.end_time = time.time()
|
67 |
-
self.history.append(
|
68 |
-
"timestamp": datetime.now(),
|
69 |
-
|
70 |
-
})
|
71 |
|
72 |
def fail(self):
|
73 |
self.status = "failed"
|
@@ -94,8 +93,10 @@ class SessionManager:
|
|
94 |
with cls._lock:
|
95 |
to_remove = []
|
96 |
for session_id, manager in cls._instances.items():
|
97 |
-
if (
|
98 |
-
|
|
|
|
|
99 |
to_remove.append(session_id)
|
100 |
|
101 |
for session_id in to_remove:
|
@@ -105,15 +106,24 @@ class SessionManager:
|
|
105 |
class GenerationManager:
|
106 |
|
107 |
def __init__(self):
|
108 |
-
self.backend_statuses = {
|
109 |
-
backend: BackendStatus()
|
110 |
-
for backend in BACKENDS
|
111 |
-
}
|
112 |
self.last_activity = time.time()
|
|
|
113 |
|
114 |
def update_activity(self):
|
115 |
self.last_activity = time.time()
|
116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
def get_performance_plot(self):
|
118 |
fig = go.Figure()
|
119 |
|
@@ -127,21 +137,24 @@ class GenerationManager:
|
|
127 |
# Use bar chart instead of box plot
|
128 |
fig.add_trace(
|
129 |
go.Bar(
|
130 |
-
y=[avg_duration], #
|
131 |
x=[BACKENDS[backend]["name"]], # Backend name
|
132 |
name=BACKENDS[backend]["name"],
|
133 |
marker_color=BACKENDS[backend]["color"],
|
134 |
text=[f"{avg_duration:.2f}s"], # Show time in seconds
|
135 |
textposition="auto",
|
136 |
width=[0.5], # Make bars narrower
|
137 |
-
)
|
|
|
138 |
|
139 |
# Set a minimum y-axis range if we have data
|
140 |
if has_data:
|
141 |
-
max_duration = max(
|
142 |
-
|
143 |
-
|
144 |
-
|
|
|
|
|
145 |
# Add 20% padding to the top
|
146 |
y_max = max_duration * 1.2
|
147 |
# Ensure the y-axis always starts at 0
|
@@ -196,19 +209,15 @@ class GenerationManager:
|
|
196 |
|
197 |
# Use aiohttp instead of requests for async
|
198 |
async with aiohttp.ClientSession() as session:
|
199 |
-
async with session.post(url, headers=headers,
|
200 |
-
json=payload) as response:
|
201 |
if response.status == 200:
|
202 |
result = await response.json()
|
203 |
request_id = result["data"]["id"]
|
204 |
-
print(
|
205 |
-
f"Task submitted successfully. Request ID: {request_id}"
|
206 |
-
)
|
207 |
return request_id
|
208 |
else:
|
209 |
text = await response.text()
|
210 |
-
raise Exception(
|
211 |
-
f"API error: {response.status}, {text}")
|
212 |
|
213 |
except Exception as e:
|
214 |
status.fail()
|
@@ -296,9 +305,9 @@ async def poll_once(manager, backend, request_id):
|
|
296 |
# It's base64 data - format it as a data URI if needed
|
297 |
try:
|
298 |
# Format as data URI for Gradio to display directly
|
299 |
-
if isinstance(
|
300 |
-
|
301 |
-
)
|
302 |
# Convert raw base64 to data URI format
|
303 |
return f"data:image/jpeg;base64,{output}"
|
304 |
else:
|
@@ -306,8 +315,7 @@ async def poll_once(manager, backend, request_id):
|
|
306 |
return output
|
307 |
except Exception as e:
|
308 |
print(f"Error processing base64 image: {e}")
|
309 |
-
raise Exception(
|
310 |
-
f"Failed to process base64 image: {str(e)}")
|
311 |
|
312 |
elif current_status == "failed":
|
313 |
manager.backend_statuses[backend].fail()
|
@@ -338,19 +346,25 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
338 |
gr.Markdown("# 🌊 WaveSpeedAI HiDream Arena")
|
339 |
|
340 |
# Add the introduction with link to WaveSpeedAI
|
341 |
-
gr.Markdown(
|
|
|
342 |
[WaveSpeedAI](https://wavespeed.ai/) is the global pioneer in accelerating AI-powered video and image generation.
|
343 |
Our in-house inference accelerator provides lossless speedup on image & video generation based on our rich inference optimization software stack, including our in-house inference compiler, CUDA kernel libraries and parallel computing libraries.
|
344 |
-
"""
|
345 |
-
|
|
|
|
|
346 |
This demo showcases the performance and outputs of leading image generation models, including HiDream and Flux, on our accelerated inference platform.
|
347 |
-
"""
|
|
|
348 |
|
349 |
with gr.Row():
|
350 |
with gr.Column(scale=3):
|
351 |
-
example_dropdown = gr.Dropdown(
|
352 |
-
|
353 |
-
|
|
|
|
|
354 |
input_text = gr.Textbox(
|
355 |
example_prompts[0],
|
356 |
label="Enter your prompt",
|
@@ -360,20 +374,18 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
360 |
with gr.Column(scale=1):
|
361 |
generate_btn = gr.Button("Generate", variant="primary")
|
362 |
|
363 |
-
example_dropdown.change(
|
364 |
-
|
365 |
-
|
366 |
|
367 |
# Two status boxes - small (default) and big (during generation)
|
368 |
-
small_status_box = gr.Markdown("Ready to generate images",
|
369 |
-
elem_id="small-status")
|
370 |
|
371 |
# Big status box in its own row with styling
|
372 |
with gr.Row(elem_id="big-status-row"):
|
373 |
-
big_status_box = gr.Markdown(
|
374 |
-
|
375 |
-
|
376 |
-
elem_classes="big-status-box")
|
377 |
|
378 |
with gr.Row():
|
379 |
with gr.Column():
|
@@ -386,27 +398,27 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
386 |
performance_plot = gr.Plot(label="Performance Metrics")
|
387 |
|
388 |
with gr.Accordion("Recent Generations (last 16)", open=False):
|
389 |
-
recent_gallery = gr.Gallery(
|
390 |
-
|
391 |
-
|
392 |
|
393 |
def get_recent_gallery_items():
|
394 |
gallery_items = []
|
395 |
for r in reversed(recent_generations):
|
396 |
gallery_items.append((r["flux-dev"], f"FLUX-dev: {r['prompt']}"))
|
397 |
-
gallery_items.append(
|
398 |
-
|
399 |
-
gallery_items.append(
|
400 |
-
(r["hidream-full"], f"HiDream-full: {r['prompt']}"))
|
401 |
return gallery_items
|
402 |
|
403 |
def update_recent_gallery(prompt, results):
|
404 |
-
recent_generations.append(
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
|
|
|
|
410 |
if len(recent_generations) > 16:
|
411 |
recent_generations.pop(0)
|
412 |
gallery_items = get_recent_gallery_items()
|
@@ -457,13 +469,34 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
457 |
gr.HTML(f"<style>{css}</style>")
|
458 |
|
459 |
# Update the generation function to use session manager
|
460 |
-
async def generate_all_backends_with_status_boxes(prompt,
|
461 |
-
current_session_id):
|
462 |
"""Generate images with big status box during generation"""
|
463 |
# Get or create a session manager
|
464 |
session_id, manager = SessionManager.get_manager(current_session_id)
|
465 |
manager.update_activity()
|
466 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
467 |
# IMPORTANT: Reset history when starting a new generation
|
468 |
if prompt and prompt.strip() != "":
|
469 |
manager.reset_history() # Clear previous performance metrics
|
@@ -523,8 +556,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
523 |
poll_attempt = 0
|
524 |
|
525 |
# Main polling loop
|
526 |
-
while len(completed_backends
|
527 |
-
) < 3 and poll_attempt < max_poll_attempts:
|
528 |
poll_attempt += 1
|
529 |
|
530 |
# Poll each pending backend
|
@@ -536,8 +568,9 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
536 |
# Only do actual API calls every few attempts to reduce load
|
537 |
if poll_attempt % 2 == 0 or backend == "flux-dev":
|
538 |
# Use the session manager instead of global manager
|
539 |
-
result = await poll_once(
|
540 |
-
|
|
|
541 |
if result: # Backend completed
|
542 |
results[backend] = result
|
543 |
completed_backends.add(backend)
|
@@ -551,8 +584,11 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
551 |
results["flux-dev"],
|
552 |
results["hidream-dev"],
|
553 |
results["hidream-full"],
|
554 |
-
(
|
555 |
-
|
|
|
|
|
|
|
556 |
session_id,
|
557 |
None,
|
558 |
)
|
@@ -563,9 +599,11 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
563 |
await asyncio.sleep(0.1)
|
564 |
|
565 |
# Final status
|
566 |
-
final_status = (
|
567 |
-
|
568 |
-
|
|
|
|
|
569 |
|
570 |
gallery_update = update_recent_gallery(prompt, results)
|
571 |
|
|
|
64 |
self.status = "completed"
|
65 |
self.progress = 100
|
66 |
self.end_time = time.time()
|
67 |
+
self.history.append(
|
68 |
+
{"timestamp": datetime.now(), "duration": self.end_time - self.start_time}
|
69 |
+
)
|
|
|
70 |
|
71 |
def fail(self):
|
72 |
self.status = "failed"
|
|
|
93 |
with cls._lock:
|
94 |
to_remove = []
|
95 |
for session_id, manager in cls._instances.items():
|
96 |
+
if (
|
97 |
+
hasattr(manager, "last_activity")
|
98 |
+
and current_time - manager.last_activity > max_age
|
99 |
+
):
|
100 |
to_remove.append(session_id)
|
101 |
|
102 |
for session_id in to_remove:
|
|
|
106 |
class GenerationManager:
|
107 |
|
108 |
def __init__(self):
|
109 |
+
self.backend_statuses = {backend: BackendStatus() for backend in BACKENDS}
|
|
|
|
|
|
|
110 |
self.last_activity = time.time()
|
111 |
+
self.request_timestamps = [] # Track timestamps of requests
|
112 |
|
113 |
def update_activity(self):
|
114 |
self.last_activity = time.time()
|
115 |
|
116 |
+
def add_request_timestamp(self):
|
117 |
+
self.request_timestamps.append(time.time())
|
118 |
+
|
119 |
+
def has_exceeded_limit(self, limit=10): # Default limit: 10 requests per hour
|
120 |
+
current_time = time.time()
|
121 |
+
# Filter timestamps to only include those within the last hour
|
122 |
+
self.request_timestamps = [
|
123 |
+
ts for ts in self.request_timestamps if current_time - ts <= 3600
|
124 |
+
]
|
125 |
+
return len(self.request_timestamps) >= limit
|
126 |
+
|
127 |
def get_performance_plot(self):
|
128 |
fig = go.Figure()
|
129 |
|
|
|
137 |
# Use bar chart instead of box plot
|
138 |
fig.add_trace(
|
139 |
go.Bar(
|
140 |
+
y=[avg_duration], #
|
141 |
x=[BACKENDS[backend]["name"]], # Backend name
|
142 |
name=BACKENDS[backend]["name"],
|
143 |
marker_color=BACKENDS[backend]["color"],
|
144 |
text=[f"{avg_duration:.2f}s"], # Show time in seconds
|
145 |
textposition="auto",
|
146 |
width=[0.5], # Make bars narrower
|
147 |
+
)
|
148 |
+
)
|
149 |
|
150 |
# Set a minimum y-axis range if we have data
|
151 |
if has_data:
|
152 |
+
max_duration = max(
|
153 |
+
[
|
154 |
+
max([h["duration"] for h in status.history] or [0])
|
155 |
+
for status in self.backend_statuses.values()
|
156 |
+
]
|
157 |
+
)
|
158 |
# Add 20% padding to the top
|
159 |
y_max = max_duration * 1.2
|
160 |
# Ensure the y-axis always starts at 0
|
|
|
209 |
|
210 |
# Use aiohttp instead of requests for async
|
211 |
async with aiohttp.ClientSession() as session:
|
212 |
+
async with session.post(url, headers=headers, json=payload) as response:
|
|
|
213 |
if response.status == 200:
|
214 |
result = await response.json()
|
215 |
request_id = result["data"]["id"]
|
216 |
+
print(f"Task submitted successfully. Request ID: {request_id}")
|
|
|
|
|
217 |
return request_id
|
218 |
else:
|
219 |
text = await response.text()
|
220 |
+
raise Exception(f"API error: {response.status}, {text}")
|
|
|
221 |
|
222 |
except Exception as e:
|
223 |
status.fail()
|
|
|
305 |
# It's base64 data - format it as a data URI if needed
|
306 |
try:
|
307 |
# Format as data URI for Gradio to display directly
|
308 |
+
if isinstance(output, str) and not output.startswith(
|
309 |
+
"data:image"
|
310 |
+
):
|
311 |
# Convert raw base64 to data URI format
|
312 |
return f"data:image/jpeg;base64,{output}"
|
313 |
else:
|
|
|
315 |
return output
|
316 |
except Exception as e:
|
317 |
print(f"Error processing base64 image: {e}")
|
318 |
+
raise Exception(f"Failed to process base64 image: {str(e)}")
|
|
|
319 |
|
320 |
elif current_status == "failed":
|
321 |
manager.backend_statuses[backend].fail()
|
|
|
346 |
gr.Markdown("# 🌊 WaveSpeedAI HiDream Arena")
|
347 |
|
348 |
# Add the introduction with link to WaveSpeedAI
|
349 |
+
gr.Markdown(
|
350 |
+
"""
|
351 |
[WaveSpeedAI](https://wavespeed.ai/) is the global pioneer in accelerating AI-powered video and image generation.
|
352 |
Our in-house inference accelerator provides lossless speedup on image & video generation based on our rich inference optimization software stack, including our in-house inference compiler, CUDA kernel libraries and parallel computing libraries.
|
353 |
+
"""
|
354 |
+
)
|
355 |
+
gr.Markdown(
|
356 |
+
"""
|
357 |
This demo showcases the performance and outputs of leading image generation models, including HiDream and Flux, on our accelerated inference platform.
|
358 |
+
"""
|
359 |
+
)
|
360 |
|
361 |
with gr.Row():
|
362 |
with gr.Column(scale=3):
|
363 |
+
example_dropdown = gr.Dropdown(
|
364 |
+
choices=example_prompts,
|
365 |
+
label="Choose an example prompt",
|
366 |
+
interactive=True,
|
367 |
+
)
|
368 |
input_text = gr.Textbox(
|
369 |
example_prompts[0],
|
370 |
label="Enter your prompt",
|
|
|
374 |
with gr.Column(scale=1):
|
375 |
generate_btn = gr.Button("Generate", variant="primary")
|
376 |
|
377 |
+
example_dropdown.change(
|
378 |
+
lambda ex: ex, inputs=[example_dropdown], outputs=[input_text]
|
379 |
+
)
|
380 |
|
381 |
# Two status boxes - small (default) and big (during generation)
|
382 |
+
small_status_box = gr.Markdown("Ready to generate images", elem_id="small-status")
|
|
|
383 |
|
384 |
# Big status box in its own row with styling
|
385 |
with gr.Row(elem_id="big-status-row"):
|
386 |
+
big_status_box = gr.Markdown(
|
387 |
+
"", elem_id="big-status", visible=False, elem_classes="big-status-box"
|
388 |
+
)
|
|
|
389 |
|
390 |
with gr.Row():
|
391 |
with gr.Column():
|
|
|
398 |
performance_plot = gr.Plot(label="Performance Metrics")
|
399 |
|
400 |
with gr.Accordion("Recent Generations (last 16)", open=False):
|
401 |
+
recent_gallery = gr.Gallery(
|
402 |
+
label="Prompt and Output", columns=3, interactive=False
|
403 |
+
)
|
404 |
|
405 |
def get_recent_gallery_items():
|
406 |
gallery_items = []
|
407 |
for r in reversed(recent_generations):
|
408 |
gallery_items.append((r["flux-dev"], f"FLUX-dev: {r['prompt']}"))
|
409 |
+
gallery_items.append((r["hidream-dev"], f"HiDream-dev: {r['prompt']}"))
|
410 |
+
gallery_items.append((r["hidream-full"], f"HiDream-full: {r['prompt']}"))
|
|
|
|
|
411 |
return gallery_items
|
412 |
|
413 |
def update_recent_gallery(prompt, results):
|
414 |
+
recent_generations.append(
|
415 |
+
{
|
416 |
+
"prompt": prompt,
|
417 |
+
"flux-dev": results["flux-dev"],
|
418 |
+
"hidream-dev": results["hidream-dev"],
|
419 |
+
"hidream-full": results["hidream-full"],
|
420 |
+
}
|
421 |
+
)
|
422 |
if len(recent_generations) > 16:
|
423 |
recent_generations.pop(0)
|
424 |
gallery_items = get_recent_gallery_items()
|
|
|
469 |
gr.HTML(f"<style>{css}</style>")
|
470 |
|
471 |
# Update the generation function to use session manager
|
472 |
+
async def generate_all_backends_with_status_boxes(prompt, current_session_id):
|
|
|
473 |
"""Generate images with big status box during generation"""
|
474 |
# Get or create a session manager
|
475 |
session_id, manager = SessionManager.get_manager(current_session_id)
|
476 |
manager.update_activity()
|
477 |
|
478 |
+
# Check if the user has exceeded the request limit
|
479 |
+
if manager.has_exceeded_limit(
|
480 |
+
limit=10
|
481 |
+
): # Set the limit to 10 requests per hour
|
482 |
+
error_message = "❌ You have exceeded the limit of 10 requests per hour. Please try again later."
|
483 |
+
yield (
|
484 |
+
error_message,
|
485 |
+
error_message,
|
486 |
+
gr.update(visible=False),
|
487 |
+
gr.update(visible=True),
|
488 |
+
None,
|
489 |
+
None,
|
490 |
+
None,
|
491 |
+
None,
|
492 |
+
session_id,
|
493 |
+
None,
|
494 |
+
)
|
495 |
+
return
|
496 |
+
|
497 |
+
# Add the current request timestamp
|
498 |
+
manager.add_request_timestamp()
|
499 |
+
|
500 |
# IMPORTANT: Reset history when starting a new generation
|
501 |
if prompt and prompt.strip() != "":
|
502 |
manager.reset_history() # Clear previous performance metrics
|
|
|
556 |
poll_attempt = 0
|
557 |
|
558 |
# Main polling loop
|
559 |
+
while len(completed_backends) < 3 and poll_attempt < max_poll_attempts:
|
|
|
560 |
poll_attempt += 1
|
561 |
|
562 |
# Poll each pending backend
|
|
|
568 |
# Only do actual API calls every few attempts to reduce load
|
569 |
if poll_attempt % 2 == 0 or backend == "flux-dev":
|
570 |
# Use the session manager instead of global manager
|
571 |
+
result = await poll_once(
|
572 |
+
manager, backend, request_ids[backend]
|
573 |
+
)
|
574 |
if result: # Backend completed
|
575 |
results[backend] = result
|
576 |
completed_backends.add(backend)
|
|
|
584 |
results["flux-dev"],
|
585 |
results["hidream-dev"],
|
586 |
results["hidream-full"],
|
587 |
+
(
|
588 |
+
manager.get_performance_plot()
|
589 |
+
if any(completed_backends)
|
590 |
+
else None
|
591 |
+
),
|
592 |
session_id,
|
593 |
None,
|
594 |
)
|
|
|
599 |
await asyncio.sleep(0.1)
|
600 |
|
601 |
# Final status
|
602 |
+
final_status = (
|
603 |
+
"✅ All generations completed!"
|
604 |
+
if len(completed_backends) == 3
|
605 |
+
else "⚠️ Some generations timed out"
|
606 |
+
)
|
607 |
|
608 |
gallery_update = update_recent_gallery(prompt, results)
|
609 |
|