kemuriririn commited on
Commit
99b1671
·
1 Parent(s): 367e06e

(wip)modify models

Browse files
Files changed (2) hide show
  1. app.py +62 -9
  2. models.py +1 -1
app.py CHANGED
@@ -129,6 +129,34 @@ CACHE_AUDIO_DIR = os.path.join(TEMP_AUDIO_DIR, CACHE_AUDIO_SUBDIR)
129
  os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)
130
  os.makedirs(CACHE_AUDIO_DIR, exist_ok=True) # Ensure cache subdir exists
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
  # Store active TTS sessions
134
  app.tts_sessions = {}
@@ -382,8 +410,13 @@ def generate_and_save_tts(text, model_id, output_dir):
382
  temp_audio_path = None # Initialize to None
383
  try:
384
  app.logger.debug(f"[TTS Gen {model_id}] Starting generation for: '{text[:30]}...'")
 
 
 
 
 
385
  # If predict_tts saves file itself and returns path:
386
- temp_audio_path = predict_tts(text, model_id)
387
  app.logger.debug(f"[TTS Gen {model_id}] predict_tts returned: {temp_audio_path}")
388
 
389
  if not temp_audio_path or not os.path.exists(temp_audio_path):
@@ -396,7 +429,7 @@ def generate_and_save_tts(text, model_id, output_dir):
396
  # Move the file generated by predict_tts to the target cache directory
397
  shutil.move(temp_audio_path, dest_path)
398
  app.logger.debug(f"[TTS Gen {model_id}] Move successful. Returning {dest_path}")
399
- return dest_path
400
 
401
  except Exception as e:
402
  app.logger.error(f"Error generating/saving TTS for model {model_id} and text '{text[:30]}...': {str(e)}")
@@ -407,7 +440,7 @@ def generate_and_save_tts(text, model_id, output_dir):
407
  os.remove(temp_audio_path)
408
  except OSError:
409
  pass # Ignore error if file couldn't be removed
410
- return None
411
 
412
 
413
  def _generate_cache_entry_task(sentence):
@@ -445,8 +478,8 @@ def _generate_cache_entry_task(sentence):
445
  future_b = audio_executor.submit(generate_and_save_tts, sentence, model_b_id, CACHE_AUDIO_DIR)
446
 
447
  timeout_seconds = 120
448
- audio_a_path = future_a.result(timeout=timeout_seconds)
449
- audio_b_path = future_b.result(timeout=timeout_seconds)
450
 
451
  if audio_a_path and audio_b_path:
452
  with tts_cache_lock:
@@ -458,6 +491,8 @@ def _generate_cache_entry_task(sentence):
458
  "model_b": model_b_id,
459
  "audio_a": audio_a_path,
460
  "audio_b": audio_b_path,
 
 
461
  "created_at": datetime.utcnow(),
462
  }
463
  app.logger.info(f"Successfully cached entry for: '{sentence[:50]}...'")
@@ -1112,7 +1147,7 @@ def setup_periodic_tasks():
1112
 
1113
  db_path = app.config["SQLALCHEMY_DATABASE_URI"].replace("sqlite:///", "instance/") # Get relative path
1114
  preferences_repo_id = "kemuriririn/arena-preferences"
1115
- database_repo_id = "kemuriririn/database-arena-v2"
1116
  votes_dir = "./votes"
1117
 
1118
  def sync_database():
@@ -1318,10 +1353,27 @@ def toggle_leaderboard_visibility():
1318
 
1319
  @app.route("/api/tts/cached-sentences")
1320
  def get_cached_sentences():
1321
- """Returns a list of sentences currently available in the TTS cache."""
1322
  with tts_cache_lock:
1323
- cached_keys = list(tts_cache.keys())
1324
- return jsonify(cached_keys)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1325
 
1326
 
1327
  def get_weighted_random_models(
@@ -1414,6 +1466,7 @@ if __name__ == "__main__":
1414
  except Exception as e:
1415
  print(f"Error downloading database from HF dataset: {str(e)} ⚠️")
1416
 
 
1417
 
1418
  db.create_all() # Create tables if they don't exist
1419
  insert_initial_models()
 
129
  os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)
130
  os.makedirs(CACHE_AUDIO_DIR, exist_ok=True) # Ensure cache subdir exists
131
 
132
+ # --- 参考音色下载与管理 ---
133
+ REFERENCE_AUDIO_DIR = os.path.join(TEMP_AUDIO_DIR, "reference_audios")
134
+ REFERENCE_AUDIO_DATASET = os.getenv("REFERENCE_AUDIO_DATASET", "kemuriririn/arena-files")
135
+ REFERENCE_AUDIO_PATTERN = os.getenv("REFERENCE_AUDIO_PATTERN", "reference_audios/")
136
+ reference_audio_files = []
137
+
138
+ def download_reference_audios():
139
+ """从 Hugging Face dataset 下载参考音频到本地目录,并生成文件列表"""
140
+ global reference_audio_files
141
+ os.makedirs(REFERENCE_AUDIO_DIR, exist_ok=True)
142
+ try:
143
+ api = HfApi(token=os.getenv("HF_TOKEN"))
144
+ files = api.list_repo_files(repo_id=REFERENCE_AUDIO_DATASET, repo_type="dataset")
145
+ # 只下载 wav 文件
146
+ wav_files = [f for f in files if f.startswith(REFERENCE_AUDIO_PATTERN) and f.endswith(".wav")]
147
+ for f in wav_files:
148
+ local_path = hf_hub_download(
149
+ repo_id=REFERENCE_AUDIO_DATASET,
150
+ filename=f,
151
+ repo_type="dataset",
152
+ local_dir=REFERENCE_AUDIO_DIR,
153
+ token=os.getenv("HF_TOKEN"),
154
+ )
155
+ reference_audio_files.append(local_path)
156
+ print(f"Downloaded {len(reference_audio_files)} reference audios.")
157
+ except Exception as e:
158
+ print(f"Error downloading reference audios: {e}")
159
+ reference_audio_files = []
160
 
161
  # Store active TTS sessions
162
  app.tts_sessions = {}
 
410
  temp_audio_path = None # Initialize to None
411
  try:
412
  app.logger.debug(f"[TTS Gen {model_id}] Starting generation for: '{text[:30]}...'")
413
+ # 随机选一个参考音频
414
+ reference_audio_path = None
415
+ if reference_audio_files:
416
+ reference_audio_path = random.choice(reference_audio_files)
417
+ app.logger.debug(f"[TTS Gen {model_id}] Using reference audio: {reference_audio_path}")
418
  # If predict_tts saves file itself and returns path:
419
+ temp_audio_path = predict_tts(text, model_id, reference_audio_path=reference_audio_path)
420
  app.logger.debug(f"[TTS Gen {model_id}] predict_tts returned: {temp_audio_path}")
421
 
422
  if not temp_audio_path or not os.path.exists(temp_audio_path):
 
429
  # Move the file generated by predict_tts to the target cache directory
430
  shutil.move(temp_audio_path, dest_path)
431
  app.logger.debug(f"[TTS Gen {model_id}] Move successful. Returning {dest_path}")
432
+ return dest_path, reference_audio_path
433
 
434
  except Exception as e:
435
  app.logger.error(f"Error generating/saving TTS for model {model_id} and text '{text[:30]}...': {str(e)}")
 
440
  os.remove(temp_audio_path)
441
  except OSError:
442
  pass # Ignore error if file couldn't be removed
443
+ return None, None
444
 
445
 
446
  def _generate_cache_entry_task(sentence):
 
478
  future_b = audio_executor.submit(generate_and_save_tts, sentence, model_b_id, CACHE_AUDIO_DIR)
479
 
480
  timeout_seconds = 120
481
+ audio_a_path, ref_a = future_a.result(timeout=timeout_seconds)
482
+ audio_b_path, ref_b = future_b.result(timeout=timeout_seconds)
483
 
484
  if audio_a_path and audio_b_path:
485
  with tts_cache_lock:
 
491
  "model_b": model_b_id,
492
  "audio_a": audio_a_path,
493
  "audio_b": audio_b_path,
494
+ "ref_a": ref_a,
495
+ "ref_b": ref_b,
496
  "created_at": datetime.utcnow(),
497
  }
498
  app.logger.info(f"Successfully cached entry for: '{sentence[:50]}...'")
 
1147
 
1148
  db_path = app.config["SQLALCHEMY_DATABASE_URI"].replace("sqlite:///", "instance/") # Get relative path
1149
  preferences_repo_id = "kemuriririn/arena-preferences"
1150
+ database_repo_id = "kemuriririn/database-arena"
1151
  votes_dir = "./votes"
1152
 
1153
  def sync_database():
 
1353
 
1354
  @app.route("/api/tts/cached-sentences")
1355
  def get_cached_sentences():
1356
+ """Returns a list of sentences currently available in the TTS cache, with reference audio."""
1357
  with tts_cache_lock:
1358
+ cached = [
1359
+ {
1360
+ "sentence": k,
1361
+ "model_a": v["model_a"],
1362
+ "model_b": v["model_b"],
1363
+ "ref_a": os.path.relpath(v["ref_a"], start=REFERENCE_AUDIO_DIR) if v.get("ref_a") else None,
1364
+ "ref_b": os.path.relpath(v["ref_b"], start=REFERENCE_AUDIO_DIR) if v.get("ref_b") else None,
1365
+ }
1366
+ for k, v in tts_cache.items()
1367
+ ]
1368
+ return jsonify(cached)
1369
+
1370
+ @app.route("/api/tts/reference-audio/<filename>")
1371
+ def get_reference_audio(filename):
1372
+ """试听参考音频"""
1373
+ file_path = os.path.join(REFERENCE_AUDIO_DIR, filename)
1374
+ if not os.path.exists(file_path):
1375
+ return jsonify({"error": "Reference audio not found"}), 404
1376
+ return send_file(file_path, mimetype="audio/wav")
1377
 
1378
 
1379
  def get_weighted_random_models(
 
1466
  except Exception as e:
1467
  print(f"Error downloading database from HF dataset: {str(e)} ⚠️")
1468
 
1469
+ download_reference_audios()
1470
 
1471
  db.create_all() # Create tables if they don't exist
1472
  insert_initial_models()
models.py CHANGED
@@ -446,7 +446,7 @@ def insert_initial_models():
446
  name="Spark TTS",
447
  model_type=ModelType.TTS,
448
  is_open=False,
449
- is_active=False, # API stopped working
450
  model_url="https://github.com/SparkAudio/Spark-TTS",
451
  ),
452
  # Model(
 
446
  name="Spark TTS",
447
  model_type=ModelType.TTS,
448
  is_open=False,
449
+ is_active=True, # API stopped working
450
  model_url="https://github.com/SparkAudio/Spark-TTS",
451
  ),
452
  # Model(