jiandan1998 commited on
Commit
9b22478
·
verified ·
1 Parent(s): 9a748ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -21
app.py CHANGED
@@ -99,6 +99,7 @@ def image_to_base64(file_path):
99
  def generate_video(
100
  context_scale,
101
  enable_safety_checker,
 
102
  flow_shift,
103
  guidance_scale,
104
  images,
@@ -120,7 +121,7 @@ def generate_video(
120
 
121
  if not rate_limiter.check(session_id):
122
  error_img = create_error_image("每小时限制20次请求")
123
- yield "❌ 请求过于频繁,请稍后再试", error_img
124
  return
125
 
126
  session = session_manager.get_session(session_id)
@@ -129,13 +130,13 @@ def generate_video(
129
 
130
  API_KEY = os.getenv("WAVESPEED_API_KEY")
131
  if not API_KEY:
132
- error_img = create_error_image("API密钥缺失")
133
  yield "❌ Error: Missing API Key", error_img
134
  return
135
 
136
  try:
137
  if not images or len(images) < 2:
138
- raise ValueError("需要上传至少两张图片")
139
 
140
  base64_images = []
141
  for img_path in images[:2]:
@@ -144,7 +145,7 @@ def generate_video(
144
 
145
  except Exception as e:
146
  error_img = create_error_image(str(e))
147
- yield f"❌ 文件处理失败: {str(e)}", error_img
148
  return
149
 
150
  video_payload = ""
@@ -156,7 +157,7 @@ def generate_video(
156
 
157
  payload = {
158
  "context_scale": context_scale,
159
- "enable_fast_mode": False,
160
  "enable_safety_checker": enable_safety_checker,
161
  "flow_shift": flow_shift,
162
  "guidance_scale": guidance_scale,
@@ -170,7 +171,7 @@ def generate_video(
170
  "video": str(video_payload) if video_payload else "",
171
  }
172
 
173
- logging.debug(f"API请求payload: {json.dumps(payload, indent=2)}")
174
 
175
  headers = {
176
  "Content-Type": "application/json",
@@ -187,14 +188,14 @@ def generate_video(
187
 
188
  if response.status_code != 200:
189
  error_img = create_error_image(response.text)
190
- yield f"❌ API错误 ({response.status_code}): {response.text}", error_img
191
  return
192
 
193
  request_id = response.json()["data"]["id"]
194
- yield f"✅ 任务已提交 (ID: {request_id})", None
195
  except Exception as e:
196
  error_img = create_error_image(str(e))
197
- yield f"❌ 连接错误: {str(e)}", error_img
198
  return
199
 
200
  result_url = f"https://api.wavespeed.ai/api/v2/predictions/{request_id}/result"
@@ -246,6 +247,11 @@ with gr.Blocks(
246
  .safe { background: #e8f5e9; border: 1px solid #a5d6a7; }
247
  .warning { background: #fff3e0; border: 1px solid #ffcc80; }
248
  .error { background: #ffebee; border: 1px solid #ef9a9a; }
 
 
 
 
 
249
  """
250
  ) as app:
251
 
@@ -256,19 +262,28 @@ with gr.Blocks(
256
 
257
  with gr.Row():
258
  with gr.Column(scale=1):
259
- images = gr.File(label="upload image", file_count="multiple", file_types=["image"], type="filepath", elem_id="image-uploader")
260
- video = gr.Video(label="Input Video", format="mp4", sources=["upload"])
261
- prompt = gr.Textbox(label="Prompt", lines=5, placeholder="Prompt...")
 
 
 
262
  negative_prompt = gr.Textbox(label="Negative Prompt", lines=2)
263
- size = gr.Dropdown(["832*480", "480*832"], value="832*480", label="Size")
264
- context_scale = gr.Slider(0, 2, value=1, step=0.1, label="Context Scale")
265
- num_inference_steps = gr.Slider(1, 100, value=20, step=1, label="Inference Steps")
266
- task = gr.Dropdown(["depth", "pose"], value="depth", label="Task")
267
- seed = gr.Number(-1, label="Seed")
268
- random_seed_btn = gr.Button("Random🎲Seed", variant="secondary")
269
- guidance = gr.Slider(1, 20, value=7.5, step=0.1, label="Guidance_Scale")
270
- flow_shift = gr.Slider(1, 20, value=16, step=1, label="Shift")
271
- enable_safety_checker = gr.Checkbox(True, label="Enable Safety Checker", interactive=True)
 
 
 
 
 
 
272
  with gr.Column(scale=1):
273
  video_output = gr.Video(label="Video Output", format="mp4", interactive=False, elem_classes=["video-preview"])
274
  generate_btn = gr.Button("Generate", variant="primary")
@@ -296,6 +311,7 @@ with gr.Blocks(
296
  inputs=[
297
  context_scale,
298
  enable_safety_checker,
 
299
  flow_shift,
300
  guidance,
301
  images,
 
99
  def generate_video(
100
  context_scale,
101
  enable_safety_checker,
102
+ enable_fast_mode,
103
  flow_shift,
104
  guidance_scale,
105
  images,
 
121
 
122
  if not rate_limiter.check(session_id):
123
  error_img = create_error_image("每小时限制20次请求")
124
+ yield "❌ rate limit exceeded", error_img
125
  return
126
 
127
  session = session_manager.get_session(session_id)
 
130
 
131
  API_KEY = os.getenv("WAVESPEED_API_KEY")
132
  if not API_KEY:
133
+ error_img = create_error_image("API key not found")
134
  yield "❌ Error: Missing API Key", error_img
135
  return
136
 
137
  try:
138
  if not images or len(images) < 2:
139
+ raise ValueError("must provide at least 2 images")
140
 
141
  base64_images = []
142
  for img_path in images[:2]:
 
145
 
146
  except Exception as e:
147
  error_img = create_error_image(str(e))
148
+ yield f"❌failed to upload images: {str(e)}", error_img
149
  return
150
 
151
  video_payload = ""
 
157
 
158
  payload = {
159
  "context_scale": context_scale,
160
+ "enable_fast_mode": enable_fast_mode,
161
  "enable_safety_checker": enable_safety_checker,
162
  "flow_shift": flow_shift,
163
  "guidance_scale": guidance_scale,
 
171
  "video": str(video_payload) if video_payload else "",
172
  }
173
 
174
+ logging.debug(f"API request payload: {json.dumps(payload, indent=2)}")
175
 
176
  headers = {
177
  "Content-Type": "application/json",
 
188
 
189
  if response.status_code != 200:
190
  error_img = create_error_image(response.text)
191
+ yield f"❌ API Error ({response.status_code}): {response.text}", error_img
192
  return
193
 
194
  request_id = response.json()["data"]["id"]
195
+ yield f"✅ Task ID (ID: {request_id})", None
196
  except Exception as e:
197
  error_img = create_error_image(str(e))
198
+ yield f"❌ Connection Error: {str(e)}", error_img
199
  return
200
 
201
  result_url = f"https://api.wavespeed.ai/api/v2/predictions/{request_id}/result"
 
247
  .safe { background: #e8f5e9; border: 1px solid #a5d6a7; }
248
  .warning { background: #fff3e0; border: 1px solid #ffcc80; }
249
  .error { background: #ffebee; border: 1px solid #ef9a9a; }
250
+ #centered_button {
251
+ align-self: center !important;
252
+ height: fit-content !important;
253
+ margin-top: 22px !important; # 根据输入框高度微调
254
+ }
255
  """
256
  ) as app:
257
 
 
262
 
263
  with gr.Row():
264
  with gr.Column(scale=1):
265
+ with gr.Row():
266
+ images = gr.File(label="upload image", file_count="multiple", file_types=["image"], type="filepath", elem_id="image-uploader",
267
+ scale=1)
268
+ video = gr.Video(label="Input Video", format="mp4", sources=["upload"],
269
+ scale=1)
270
+ prompt = gr.Textbox(label="Prompt", lines=5, placeholder="Prompt")
271
  negative_prompt = gr.Textbox(label="Negative Prompt", lines=2)
272
+ with gr.Row():
273
+ size = gr.Dropdown(["832*480", "480*832"], value="832*480", label="Size")
274
+ task = gr.Dropdown(["depth", "pose"], value="depth", label="Task")
275
+ with gr.Row():
276
+ num_inference_steps = gr.Slider(1, 100, value=20, step=1, label="Inference Steps")
277
+ context_scale = gr.Slider(0, 2, value=1, step=0.1, label="Context Scale")
278
+ with gr.Row():
279
+ guidance = gr.Slider(1, 20, value=7.5, step=0.1, label="Guidance_Scale")
280
+ flow_shift = gr.Slider(1, 20, value=16, step=1, label="Shift")
281
+ with gr.Row():
282
+ seed = gr.Number(-1, label="Seed")
283
+ random_seed_btn = gr.Button("Random🎲Seed", variant="secondary", elem_id="centered_button")
284
+ with gr.Row():
285
+ enable_safety_checker = gr.Checkbox(True, label="Enable Safety Checker", interactive=True)
286
+ enable_fast_mode = gr.Checkbox(False, label="To enable the fast mode, please visit Wave Speed AI", interactive=False)
287
  with gr.Column(scale=1):
288
  video_output = gr.Video(label="Video Output", format="mp4", interactive=False, elem_classes=["video-preview"])
289
  generate_btn = gr.Button("Generate", variant="primary")
 
311
  inputs=[
312
  context_scale,
313
  enable_safety_checker,
314
+ enable_fast_mode,
315
  flow_shift,
316
  guidance,
317
  images,