kemuriririn commited on
Commit
782d74b
·
1 Parent(s): 9e032ec

(wip)debug

Browse files
Files changed (4) hide show
  1. .gitmodules +3 -0
  2. CosyVoice2-0.5B +1 -0
  3. requirements.txt +2 -1
  4. tts.py +31 -23
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "CosyVoice2-0.5B"]
2
+ path = CosyVoice2-0.5B
3
+ url = [email protected]:spaces/FunAudioLLM/CosyVoice2-0.5B
CosyVoice2-0.5B ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit b1769de266d0f5f94ea6e8dfb1df1519a2407be2
requirements.txt CHANGED
@@ -12,4 +12,5 @@ gunicorn
12
  waitress
13
  fal-client
14
  gradio_client==1.7.0
15
- git+https://github.com/playht/pyht
 
 
12
  waitress
13
  fal-client
14
  gradio_client==1.7.0
15
+ git+https://github.com/playht/pyht
16
+ modelscope
tts.py CHANGED
@@ -2,6 +2,8 @@
2
  # Currently just use current TTS router.
3
  import os
4
  import json
 
 
5
  from dotenv import load_dotenv
6
  import fal_client
7
  import requests
@@ -232,35 +234,41 @@ def predict_spark_tts(text, reference_audio_path=None):
232
 
233
 
234
  def predict_cosyvoice_tts(text, reference_audio_path=None):
235
- from gradio_client import Client, file, handle_file
236
- client = Client("https://iic-cosyvoice2-0-5b.ms.show/")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  if not reference_audio_path:
238
  raise ValueError("cosyvoice-2.0 需要 reference_audio_path")
239
- prompt_wav = handle_file(reference_audio_path)
240
- # 先识别参考音频文本
241
- recog_result = client.predict(
242
- prompt_wav=file(reference_audio_path),
243
- api_name="/prompt_wav_recognition"
244
- )
245
- print("cosyvoice-2.0 prompt_wav_recognition result:", recog_result)
246
- prompt_text = recog_result if isinstance(recog_result, str) else str(recog_result)
247
- result = client.predict(
248
- tts_text=text,
249
- mode_checkbox_group="3s极速复刻",
250
- prompt_text=prompt_text,
251
- prompt_wav_upload=prompt_wav,
252
- prompt_wav_record=prompt_wav,
253
- instruct_text="",
254
- seed=0,
255
- api_name="/generate_audio"
256
- )
257
- print("cosyvoice-2.0 result:", result)
258
- return result
259
 
260
 
261
  def predict_maskgct(text, reference_audio_path=None):
262
  from gradio_client import Client, handle_file
263
- client = Client("amphion/maskgct")
264
  if not reference_audio_path:
265
  raise ValueError("maskgct 需要 reference_audio_path")
266
  prompt_wav = handle_file(reference_audio_path)
 
2
  # Currently just use current TTS router.
3
  import os
4
  import json
5
+ import sys
6
+
7
  from dotenv import load_dotenv
8
  import fal_client
9
  import requests
 
234
 
235
 
236
  def predict_cosyvoice_tts(text, reference_audio_path=None):
237
+ import tempfile
238
+ import soundfile as sf
239
+ from modelscope import snapshot_download
240
+ model_dir = os.path.join(os.path.dirname(__file__), "CosyVoice2-0.5B", "pretrained_models", "CosyVoice2-0.5B")
241
+ if not os.path.exists(model_dir) or not os.listdir(model_dir):
242
+ snapshot_download('iic/CosyVoice2-0.5B', local_dir=model_dir)
243
+ sys.path.append(os.path.join(os.path.dirname(__file__), "CosyVoice2-0.5B"))
244
+ from cosyvoice.cli.cosyvoice import CosyVoice2
245
+ from cosyvoice.utils.file_utils import load_wav
246
+
247
+ # 全局模型初始化
248
+ global _cosyvoice_model
249
+ if '_cosyvoice_model' not in globals() or _cosyvoice_model is None:
250
+ _cosyvoice_model = CosyVoice2(model_dir)
251
+ model = _cosyvoice_model
252
+
253
  if not reference_audio_path:
254
  raise ValueError("cosyvoice-2.0 需要 reference_audio_path")
255
+ # 读取参考音频
256
+ prompt_speech_16k = load_wav(reference_audio_path, 16000)
257
+ # 参考文本可选,这里不做ASR,直接传空字符串
258
+ prompt_text = ""
259
+ # 推理
260
+ result = None
261
+ for i in model.inference_zero_shot(text, prompt_text, prompt_speech_16k):
262
+ result = i['tts_speech'].numpy().flatten()
263
+ # 保存为临时wav
264
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
265
+ sf.write(temp_file.name, result, 24000)
266
+ return temp_file.name
 
 
 
 
 
 
 
 
267
 
268
 
269
  def predict_maskgct(text, reference_audio_path=None):
270
  from gradio_client import Client, handle_file
271
+ client = Client("cocktailpeanut/maskgct")
272
  if not reference_audio_path:
273
  raise ValueError("maskgct 需要 reference_audio_path")
274
  prompt_wav = handle_file(reference_audio_path)