Commit
·
99b1671
1
Parent(s):
367e06e
(wip)modify models
Browse files
app.py
CHANGED
@@ -129,6 +129,34 @@ CACHE_AUDIO_DIR = os.path.join(TEMP_AUDIO_DIR, CACHE_AUDIO_SUBDIR)
|
|
129 |
os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)
|
130 |
os.makedirs(CACHE_AUDIO_DIR, exist_ok=True) # Ensure cache subdir exists
|
131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
|
133 |
# Store active TTS sessions
|
134 |
app.tts_sessions = {}
|
@@ -382,8 +410,13 @@ def generate_and_save_tts(text, model_id, output_dir):
|
|
382 |
temp_audio_path = None # Initialize to None
|
383 |
try:
|
384 |
app.logger.debug(f"[TTS Gen {model_id}] Starting generation for: '{text[:30]}...'")
|
|
|
|
|
|
|
|
|
|
|
385 |
# If predict_tts saves file itself and returns path:
|
386 |
-
temp_audio_path = predict_tts(text, model_id)
|
387 |
app.logger.debug(f"[TTS Gen {model_id}] predict_tts returned: {temp_audio_path}")
|
388 |
|
389 |
if not temp_audio_path or not os.path.exists(temp_audio_path):
|
@@ -396,7 +429,7 @@ def generate_and_save_tts(text, model_id, output_dir):
|
|
396 |
# Move the file generated by predict_tts to the target cache directory
|
397 |
shutil.move(temp_audio_path, dest_path)
|
398 |
app.logger.debug(f"[TTS Gen {model_id}] Move successful. Returning {dest_path}")
|
399 |
-
return dest_path
|
400 |
|
401 |
except Exception as e:
|
402 |
app.logger.error(f"Error generating/saving TTS for model {model_id} and text '{text[:30]}...': {str(e)}")
|
@@ -407,7 +440,7 @@ def generate_and_save_tts(text, model_id, output_dir):
|
|
407 |
os.remove(temp_audio_path)
|
408 |
except OSError:
|
409 |
pass # Ignore error if file couldn't be removed
|
410 |
-
return None
|
411 |
|
412 |
|
413 |
def _generate_cache_entry_task(sentence):
|
@@ -445,8 +478,8 @@ def _generate_cache_entry_task(sentence):
|
|
445 |
future_b = audio_executor.submit(generate_and_save_tts, sentence, model_b_id, CACHE_AUDIO_DIR)
|
446 |
|
447 |
timeout_seconds = 120
|
448 |
-
audio_a_path = future_a.result(timeout=timeout_seconds)
|
449 |
-
audio_b_path = future_b.result(timeout=timeout_seconds)
|
450 |
|
451 |
if audio_a_path and audio_b_path:
|
452 |
with tts_cache_lock:
|
@@ -458,6 +491,8 @@ def _generate_cache_entry_task(sentence):
|
|
458 |
"model_b": model_b_id,
|
459 |
"audio_a": audio_a_path,
|
460 |
"audio_b": audio_b_path,
|
|
|
|
|
461 |
"created_at": datetime.utcnow(),
|
462 |
}
|
463 |
app.logger.info(f"Successfully cached entry for: '{sentence[:50]}...'")
|
@@ -1112,7 +1147,7 @@ def setup_periodic_tasks():
|
|
1112 |
|
1113 |
db_path = app.config["SQLALCHEMY_DATABASE_URI"].replace("sqlite:///", "instance/") # Get relative path
|
1114 |
preferences_repo_id = "kemuriririn/arena-preferences"
|
1115 |
-
database_repo_id = "kemuriririn/database-arena
|
1116 |
votes_dir = "./votes"
|
1117 |
|
1118 |
def sync_database():
|
@@ -1318,10 +1353,27 @@ def toggle_leaderboard_visibility():
|
|
1318 |
|
1319 |
@app.route("/api/tts/cached-sentences")
|
1320 |
def get_cached_sentences():
|
1321 |
-
"""Returns a list of sentences currently available in the TTS cache."""
|
1322 |
with tts_cache_lock:
|
1323 |
-
|
1324 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1325 |
|
1326 |
|
1327 |
def get_weighted_random_models(
|
@@ -1414,6 +1466,7 @@ if __name__ == "__main__":
|
|
1414 |
except Exception as e:
|
1415 |
print(f"Error downloading database from HF dataset: {str(e)} ⚠️")
|
1416 |
|
|
|
1417 |
|
1418 |
db.create_all() # Create tables if they don't exist
|
1419 |
insert_initial_models()
|
|
|
129 |
os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)
|
130 |
os.makedirs(CACHE_AUDIO_DIR, exist_ok=True) # Ensure cache subdir exists
|
131 |
|
132 |
+
# --- Reference audio (voice prompt) download and management ---
|
133 |
+
REFERENCE_AUDIO_DIR = os.path.join(TEMP_AUDIO_DIR, "reference_audios")
|
134 |
+
REFERENCE_AUDIO_DATASET = os.getenv("REFERENCE_AUDIO_DATASET", "kemuriririn/arena-files")
|
135 |
+
REFERENCE_AUDIO_PATTERN = os.getenv("REFERENCE_AUDIO_PATTERN", "reference_audios/")
|
136 |
+
reference_audio_files = []
|
137 |
+
|
138 |
+
def download_reference_audios():
    """Download reference audio clips from a Hugging Face dataset and record their local paths.

    Lists the files of the ``REFERENCE_AUDIO_DATASET`` dataset repo, keeps the
    ``.wav`` files whose repo path starts with ``REFERENCE_AUDIO_PATTERN``,
    downloads each of them under ``REFERENCE_AUDIO_DIR`` and stores the
    returned local paths in the module-level ``reference_audio_files`` list.

    The list is rebuilt from scratch on every call, so calling this function
    more than once does not accumulate duplicate entries. If listing the repo
    or any single download fails, the error is printed and
    ``reference_audio_files`` is left empty.
    """
    global reference_audio_files
    os.makedirs(REFERENCE_AUDIO_DIR, exist_ok=True)
    token = os.getenv("HF_TOKEN")
    # Collect into a local list and publish it only once everything succeeded,
    # so readers never observe a half-filled or duplicated list.
    downloaded = []
    try:
        api = HfApi(token=token)
        repo_files = api.list_repo_files(repo_id=REFERENCE_AUDIO_DATASET, repo_type="dataset")
        # Only download wav files located under the configured prefix.
        wav_files = [
            repo_path
            for repo_path in repo_files
            if repo_path.startswith(REFERENCE_AUDIO_PATTERN) and repo_path.endswith(".wav")
        ]
        for repo_path in wav_files:
            local_path = hf_hub_download(
                repo_id=REFERENCE_AUDIO_DATASET,
                filename=repo_path,
                repo_type="dataset",
                local_dir=REFERENCE_AUDIO_DIR,
                token=token,
            )
            downloaded.append(local_path)
    except Exception as e:
        # Best-effort: the app keeps running without reference audios.
        print(f"Error downloading reference audios: {e}")
        reference_audio_files = []
        return
    reference_audio_files = downloaded
    print(f"Downloaded {len(reference_audio_files)} reference audios.")
|
160 |
|
161 |
# Store active TTS sessions
|
162 |
app.tts_sessions = {}
|
|
|
410 |
temp_audio_path = None # Initialize to None
|
411 |
try:
|
412 |
app.logger.debug(f"[TTS Gen {model_id}] Starting generation for: '{text[:30]}...'")
|
413 |
+
# 随机选一个参考音频
|
414 |
+
reference_audio_path = None
|
415 |
+
if reference_audio_files:
|
416 |
+
reference_audio_path = random.choice(reference_audio_files)
|
417 |
+
app.logger.debug(f"[TTS Gen {model_id}] Using reference audio: {reference_audio_path}")
|
418 |
# If predict_tts saves file itself and returns path:
|
419 |
+
temp_audio_path = predict_tts(text, model_id, reference_audio_path=reference_audio_path)
|
420 |
app.logger.debug(f"[TTS Gen {model_id}] predict_tts returned: {temp_audio_path}")
|
421 |
|
422 |
if not temp_audio_path or not os.path.exists(temp_audio_path):
|
|
|
429 |
# Move the file generated by predict_tts to the target cache directory
|
430 |
shutil.move(temp_audio_path, dest_path)
|
431 |
app.logger.debug(f"[TTS Gen {model_id}] Move successful. Returning {dest_path}")
|
432 |
+
return dest_path, reference_audio_path
|
433 |
|
434 |
except Exception as e:
|
435 |
app.logger.error(f"Error generating/saving TTS for model {model_id} and text '{text[:30]}...': {str(e)}")
|
|
|
440 |
os.remove(temp_audio_path)
|
441 |
except OSError:
|
442 |
pass # Ignore error if file couldn't be removed
|
443 |
+
return None, None
|
444 |
|
445 |
|
446 |
def _generate_cache_entry_task(sentence):
|
|
|
478 |
future_b = audio_executor.submit(generate_and_save_tts, sentence, model_b_id, CACHE_AUDIO_DIR)
|
479 |
|
480 |
timeout_seconds = 120
|
481 |
+
audio_a_path, ref_a = future_a.result(timeout=timeout_seconds)
|
482 |
+
audio_b_path, ref_b = future_b.result(timeout=timeout_seconds)
|
483 |
|
484 |
if audio_a_path and audio_b_path:
|
485 |
with tts_cache_lock:
|
|
|
491 |
"model_b": model_b_id,
|
492 |
"audio_a": audio_a_path,
|
493 |
"audio_b": audio_b_path,
|
494 |
+
"ref_a": ref_a,
|
495 |
+
"ref_b": ref_b,
|
496 |
"created_at": datetime.utcnow(),
|
497 |
}
|
498 |
app.logger.info(f"Successfully cached entry for: '{sentence[:50]}...'")
|
|
|
1147 |
|
1148 |
db_path = app.config["SQLALCHEMY_DATABASE_URI"].replace("sqlite:///", "instance/") # Get relative path
|
1149 |
preferences_repo_id = "kemuriririn/arena-preferences"
|
1150 |
+
database_repo_id = "kemuriririn/database-arena"
|
1151 |
votes_dir = "./votes"
|
1152 |
|
1153 |
def sync_database():
|
|
|
1353 |
|
1354 |
@app.route("/api/tts/cached-sentences")
def get_cached_sentences():
    """Return the TTS cache contents as a JSON list.

    Each element describes one cached sentence: the sentence text, the two
    model ids stored for it, and the reference audio used for each model
    expressed as a path relative to ``REFERENCE_AUDIO_DIR`` (``None`` when the
    cache entry has no reference audio recorded).
    """

    def _relative_ref(entry, key):
        # Entries without a (truthy) reference path are reported as None.
        if not entry.get(key):
            return None
        return os.path.relpath(entry[key], start=REFERENCE_AUDIO_DIR)

    payload = []
    # Read the shared cache under its lock; the response is built afterwards.
    with tts_cache_lock:
        for sentence, entry in tts_cache.items():
            payload.append(
                {
                    "sentence": sentence,
                    "model_a": entry["model_a"],
                    "model_b": entry["model_b"],
                    "ref_a": _relative_ref(entry, "ref_a"),
                    "ref_b": _relative_ref(entry, "ref_b"),
                }
            )
    return jsonify(payload)
|
1369 |
+
|
1370 |
+
@app.route("/api/tts/reference-audio/<path:filename>")
def get_reference_audio(filename):
    """Serve a reference audio file so it can be previewed.

    ``filename`` is interpreted relative to ``REFERENCE_AUDIO_DIR``. The route
    uses the ``path`` converter so that relative paths containing ``/`` (as
    produced by ``os.path.relpath`` for files stored in sub-directories) can
    be requested.

    Returns the file as ``audio/wav``, or a JSON error with status 404 when
    the path does not resolve to a regular file inside ``REFERENCE_AUDIO_DIR``.
    """
    base_dir = os.path.realpath(REFERENCE_AUDIO_DIR)
    file_path = os.path.realpath(os.path.join(base_dir, filename))

    # The name comes from the URL: refuse anything that resolves outside the
    # reference directory (``..`` segments, absolute paths, symlinks).
    try:
        common = os.path.commonpath([base_dir, file_path])
    except ValueError:
        # Raised e.g. for paths on different drives on Windows.
        common = None
    inside_base = common is not None and os.path.normcase(common) == os.path.normcase(base_dir)

    # isfile (not exists) so that a directory is reported as 404 instead of
    # failing inside send_file.
    if not inside_base or not os.path.isfile(file_path):
        return jsonify({"error": "Reference audio not found"}), 404
    return send_file(file_path, mimetype="audio/wav")
|
1377 |
|
1378 |
|
1379 |
def get_weighted_random_models(
|
|
|
1466 |
except Exception as e:
|
1467 |
print(f"Error downloading database from HF dataset: {str(e)} ⚠️")
|
1468 |
|
1469 |
+
download_reference_audios()
|
1470 |
|
1471 |
db.create_all() # Create tables if they don't exist
|
1472 |
insert_initial_models()
|
models.py
CHANGED
@@ -446,7 +446,7 @@ def insert_initial_models():
|
|
446 |
name="Spark TTS",
|
447 |
model_type=ModelType.TTS,
|
448 |
is_open=False,
|
449 |
-
is_active=
|
450 |
model_url="https://github.com/SparkAudio/Spark-TTS",
|
451 |
),
|
452 |
# Model(
|
|
|
446 |
name="Spark TTS",
|
447 |
model_type=ModelType.TTS,
|
448 |
is_open=False,
|
449 |
+
is_active=True, # API stopped working
|
450 |
model_url="https://github.com/SparkAudio/Spark-TTS",
|
451 |
),
|
452 |
# Model(
|