kevinwang676 committed on
Commit a35e94c · verified · 1 Parent(s): d4069e8

Update speech_edit.py

Files changed (1):
  1. speech_edit.py +255 -181

speech_edit.py CHANGED
@@ -1,183 +1,257 @@
- import os
-
  import torch
- import torch.nn.functional as F
- import torchaudio
- from vocos import Vocos
-
- from model import CFM, UNetT, DiT, MMDiT
- from model.utils import (
-     load_checkpoint,
-     get_tokenizer,
-     convert_char_to_pinyin,
-     save_spectrogram,
- )
-
- device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
-
-
- # --------------------- Dataset Settings -------------------- #
-
- target_sample_rate = 24000
- n_mel_channels = 100
- hop_length = 256
- target_rms = 0.1
-
- tokenizer = "pinyin"
- dataset_name = "Emilia_ZH_EN"
-
-
- # ---------------------- infer setting ---------------------- #
-
- seed = None  # int | None
-
- exp_name = "F5TTS_Base"  # F5TTS_Base | E2TTS_Base
- ckpt_step = 1200000
-
- nfe_step = 32  # 16, 32
- cfg_strength = 2.
- ode_method = 'euler'  # euler | midpoint
- sway_sampling_coef = -1.
- speed = 1.
-
- if exp_name == "F5TTS_Base":
-     model_cls = DiT
-     model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
-
- elif exp_name == "E2TTS_Base":
-     model_cls = UNetT
-     model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
-
- ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
- output_dir = "tests"
-
- # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
- # pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
- # [write the origin_text into a file, e.g. tests/test_edit.txt]
- # ctc-forced-aligner --audio_path "tests/ref_audio/test_en_1_ref_short.wav" --text_path "tests/test_edit.txt" --language "zho" --romanize --split_size "char"
- # [result will be saved at same path of audio file]
- # [--language "zho" for Chinese, "eng" for English]
- # [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
-
- audio_to_edit = "tests/ref_audio/test_en_1_ref_short.wav"
- origin_text = "Some call me nature, others call me mother nature."
- target_text = "Some call me optimist, others call me realist."
- parts_to_edit = [[1.42, 2.44], [4.04, 4.9], ]  # stard_ends of "nature" & "mother nature", in seconds
- fix_duration = [1.2, 1, ]  # fix duration for "optimist" & "realist", in seconds
-
- # audio_to_edit = "tests/ref_audio/test_zh_1_ref_short.wav"
- # origin_text = "对,这就是我,万人敬仰的太乙真人。"
- # target_text = "对,那就是你,万人敬仰的太白金星。"
- # parts_to_edit = [[0.84, 1.4], [1.92, 2.4], [4.26, 6.26], ]
- # fix_duration = None  # use origin text duration
-
-
- # -------------------------------------------------#
-
- use_ema = True
-
- if not os.path.exists(output_dir):
-     os.makedirs(output_dir)
-
- # Vocoder model
- local = False
- if local:
-     vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
-     vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
-     state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
-     vocos.load_state_dict(state_dict)
-
-     vocos.eval()
- else:
-     vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
-
- # Tokenizer
- vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
-
- # Model
- model = CFM(
-     transformer = model_cls(
-         **model_cfg,
-         text_num_embeds = vocab_size,
-         mel_dim = n_mel_channels
-     ),
-     mel_spec_kwargs = dict(
-         target_sample_rate = target_sample_rate,
-         n_mel_channels = n_mel_channels,
-         hop_length = hop_length,
-     ),
-     odeint_kwargs = dict(
-         method = ode_method,
-     ),
-     vocab_char_map = vocab_char_map,
- ).to(device)
-
- model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
-
- # Audio
- audio, sr = torchaudio.load(audio_to_edit)
- if audio.shape[0] > 1:
-     audio = torch.mean(audio, dim=0, keepdim=True)
- rms = torch.sqrt(torch.mean(torch.square(audio)))
- if rms < target_rms:
-     audio = audio * target_rms / rms
- if sr != target_sample_rate:
-     resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
-     audio = resampler(audio)
- offset = 0
- audio_ = torch.zeros(1, 0)
- edit_mask = torch.zeros(1, 0, dtype=torch.bool)
- for part in parts_to_edit:
-     start, end = part
-     part_dur = end - start if fix_duration is None else fix_duration.pop(0)
-     part_dur = part_dur * target_sample_rate
-     start = start * target_sample_rate
-     audio_ = torch.cat((audio_, audio[:, round(offset):round(start)], torch.zeros(1, round(part_dur))), dim = -1)
-     edit_mask = torch.cat((edit_mask,
-                            torch.ones(1, round((start - offset) / hop_length), dtype = torch.bool),
-                            torch.zeros(1, round(part_dur / hop_length), dtype = torch.bool)
-                            ), dim = -1)
-     offset = end * target_sample_rate
- # audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1)
- edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value = True)
- audio = audio.to(device)
- edit_mask = edit_mask.to(device)
-
- # Text
- text_list = [target_text]
- if tokenizer == "pinyin":
-     final_text_list = convert_char_to_pinyin(text_list)
- else:
-     final_text_list = [text_list]
- print(f"text : {text_list}")
- print(f"pinyin: {final_text_list}")
-
- # Duration
- ref_audio_len = 0
- duration = audio.shape[-1] // hop_length
-
- # Inference
- with torch.inference_mode():
-     generated, trajectory = model.sample(
-         cond = audio,
-         text = final_text_list,
-         duration = duration,
-         steps = nfe_step,
-         cfg_strength = cfg_strength,
-         sway_sampling_coef = sway_sampling_coef,
-         seed = seed,
-         edit_mask = edit_mask,
      )
- print(f"Generated mel: {generated.shape}")
-
- # Final result
- generated = generated.to(torch.float32)
- generated = generated[:, ref_audio_len:, :]
- generated_mel_spec = generated.permute(0, 2, 1)
- generated_wave = vocos.decode(generated_mel_spec.cpu())
- if rms < target_rms:
-     generated_wave = generated_wave * rms / target_rms
-
- save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
- torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate)
- print(f"Generated wav: {generated_wave.shape}")
+ #!/usr/bin/env python3
+ # coding: utf-8
+ """
+ CosyVoice gRPC back-end – updated to mirror the FastAPI logic
+ * loads CosyVoice2 with TRT / FP16 first (falls back to CosyVoice)
+ * inference_zero_shot  ➜ adds stream=False + speed
+ * inference_instruct   ➜ keeps original "speaker-ID" path
+ * inference_instruct2  ➜ new: prompt-audio + speed (no speaker-ID)
+ """
+
+ import io, os, tempfile, requests, soundfile as sf, torchaudio
+ import sys
+ from concurrent import futures
+ import argparse
+ import logging
+ import grpc
+ import numpy as np
  import torch
+
+ import cosyvoice_pb2
+ import cosyvoice_pb2_grpc
+
+ # ────────────────────────────────────────────────────────────────────────────────
+ # set-up
+ # ────────────────────────────────────────────────────────────────────────────────
+ logging.getLogger("matplotlib").setLevel(logging.WARNING)
+ logging.basicConfig(level=logging.INFO,
+                     format="%(asctime)s %(levelname)s %(message)s")
+
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+ sys.path.extend([
+     f"{ROOT_DIR}/../../..",
+     f"{ROOT_DIR}/../../../third_party/Matcha-TTS",
+ ])
+
+ from cosyvoice.cli.cosyvoice import CosyVoice2  # noqa: E402
+
+
+ # ────────────────────────────────────────────────────────────────────────────────
+ # helpers
+ # ────────────────────────────────────────────────────────────────────────────────
+ def _bytes_to_tensor(wav_bytes: bytes) -> torch.Tensor:
+     """
+     Convert int16 little-endian PCM bytes → torch.FloatTensor in range [-1, 1]
+     """
+     speech = torch.from_numpy(
+         np.frombuffer(wav_bytes, dtype=np.int16)
+     ).unsqueeze(0).float() / (2 ** 15)
+     return speech  # [1, T]
+
+
+ def _yield_audio(model_output):
+     """
+     Generator that converts CosyVoice output → protobuf Response messages.
+     """
+     for seg in model_output:
+         pcm16 = (seg["tts_speech"].numpy() * (2 ** 15)).astype(np.int16)
+         resp = cosyvoice_pb2.Response(tts_audio=pcm16.tobytes())
+         yield resp
+
+
+ # ────────────────────────────────────────────────────────────────────────────────
+ # gRPC service
+ # ────────────────────────────────────────────────────────────────────────────────
+ class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
+     def __init__(self, args):
+         # try CosyVoice2 first (preferred runtime: TRT / FP16)
+         try:
+             self.cosyvoice = CosyVoice2(args.model_dir,
+                                         load_jit=False,
+                                         load_trt=True,
+                                         fp16=True)
+             logging.info("Loaded CosyVoice2 (TRT / FP16).")
+         except Exception:
+             raise TypeError("No valid CosyVoice model found!")
+
+     # ---------------------------------------------------------------------
+     # single bi-di streaming RPC
+     # ---------------------------------------------------------------------
+     def Inference(self, request, context):
+         """Route to the correct model call based on the oneof field present."""
+         # 1. Supervised fine-tuning
+         if request.HasField("sft_request"):
+             logging.info("Received SFT inference request")
+             mo = self.cosyvoice.inference_sft(
+                 request.sft_request.tts_text,
+                 request.sft_request.spk_id
+             )
+             yield from _yield_audio(mo)
+             return
+
+         # 2. Zero-shot speaker cloning (bytes OR S3 URL)
+         if request.HasField("zero_shot_request"):
+             logging.info("Received zero-shot inference request")
+             zr = request.zero_shot_request
+             tmp_path = None  # initialise so we can delete later
+
+             try:
+                 # ───── determine payload type ─────────────────────────────────────
+                 if zr.prompt_audio.startswith(b'http'):
+                     # —— remote URL —— ---------------------------------------------
+                     url = zr.prompt_audio.decode('utf-8')
+                     logging.info("Downloading prompt audio from %s", url)
+                     resp = requests.get(url, timeout=10)
+                     resp.raise_for_status()
+
+                     # save to a temp file
+                     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
+                         f.write(resp.content)
+                         tmp_path = f.name
+
+                     # load, mono-ise, resample → tensor [1, T]
+                     wav, sr = sf.read(tmp_path, dtype="float32")
+                     if wav.ndim > 1:
+                         wav = wav.mean(axis=1)
+                     if sr != 16_000:
+                         wav = torchaudio.functional.resample(
+                             torch.from_numpy(wav).unsqueeze(0), sr, 16_000
+                         )[0].numpy()
+                     prompt = torch.from_numpy(wav).unsqueeze(0)
+
+                 else:
+                     # —— legacy raw PCM bytes —— -----------------------------------
+                     prompt = _bytes_to_tensor(zr.prompt_audio)
+
+                 # ───── call the model ─────────────────────────────────────────────
+                 speed = getattr(zr, "speed", 1.0)
+                 mo = self.cosyvoice.inference_zero_shot(
+                     zr.tts_text,
+                     zr.prompt_text,
+                     prompt,
+                     stream=False,
+                     speed=speed,
+                 )
+
+             finally:
+                 # clean up any temporary file we created
+                 if tmp_path and os.path.exists(tmp_path):
+                     try:
+                         os.remove(tmp_path)
+                     except Exception as e:
+                         logging.warning("Could not remove temp file %s: %s", tmp_path, e)
+
+             yield from _yield_audio(mo)
+             return
+
+         # 3. Cross-lingual
+         if request.HasField("cross_lingual_request"):
+             logging.info("Received cross-lingual inference request")
+             cr = request.cross_lingual_request
+             prompt = _bytes_to_tensor(cr.prompt_audio)
+             mo = self.cosyvoice.inference_cross_lingual(
+                 cr.tts_text,
+                 prompt
+             )
+             yield from _yield_audio(mo)
+             return
+
+         # 4. Instruction-TTS (two flavours)
+         if request.HasField("instruct_request"):
+             ir = request.instruct_request
+
+             # ──────────────────────────────────────────────────────────────────
+             # 4-a) instruct-2 (has prompt_audio → bytes OR S3 URL)
+             # ──────────────────────────────────────────────────────────────────
+             if ir.HasField("prompt_audio"):
+                 logging.info("Received instruct-2 inference request")
+
+                 tmp_path = None
+                 try:
+                     if ir.prompt_audio.startswith(b'http'):
+                         # treat as URL, download then load
+                         url = ir.prompt_audio.decode('utf-8')
+                         logging.info("Downloading prompt audio from %s", url)
+                         resp = requests.get(url, timeout=10)
+                         resp.raise_for_status()
+
+                         with tempfile.NamedTemporaryFile(delete=False,
+                                                          suffix=".wav") as f:
+                             f.write(resp.content)
+                             tmp_path = f.name
+
+                         wav, sr = sf.read(tmp_path, dtype='float32')
+                         if wav.ndim > 1:
+                             wav = wav.mean(axis=1)
+                         if sr != 16_000:
+                             wav = torchaudio.functional.resample(
+                                 torch.from_numpy(wav).unsqueeze(0), sr, 16_000
+                             )[0].numpy()
+                         prompt = torch.from_numpy(wav).unsqueeze(0)
+
+                     else:
+                         # legacy raw-bytes payload
+                         prompt = _bytes_to_tensor(ir.prompt_audio)
+
+                     speed = getattr(ir, "speed", 1.0)
+                     mo = self.cosyvoice.inference_instruct2(
+                         ir.tts_text,
+                         ir.instruct_text,
+                         prompt,
+                         stream=False,
+                         speed=speed
+                     )
+
+                 finally:
+                     if tmp_path and os.path.exists(tmp_path):
+                         try:
+                             os.remove(tmp_path)
+                         except Exception as e:
+                             logging.warning("Could not remove temp file %s: %s",
+                                             tmp_path, e)
+
+             # ──────────────────────────────────────────────────────────────────
+             # 4-b) classic instruct (speaker-ID, no prompt audio)
+             # ──────────────────────────────────────────────────────────────────
+             else:
+                 logging.info("Received instruct inference request")
+                 mo = self.cosyvoice.inference_instruct(
+                     ir.tts_text,
+                     ir.spk_id,
+                     ir.instruct_text
+                 )
+
+             yield from _yield_audio(mo)
+             return
+
+         # unknown request type
+         context.abort(grpc.StatusCode.INVALID_ARGUMENT,
+                       "Unsupported request type in oneof field.")
+
+
+ # ────────────────────────────────────────────────────────────────────────────────
+ # entry-point
+ # ────────────────────────────────────────────────────────────────────────────────
+ def serve(args):
+     server = grpc.server(
+         futures.ThreadPoolExecutor(max_workers=args.max_conc),
+         maximum_concurrent_rpcs=args.max_conc
+     )
+     cosyvoice_pb2_grpc.add_CosyVoiceServicer_to_server(
+         CosyVoiceServiceImpl(args), server
      )
+     server.add_insecure_port(f"0.0.0.0:{args.port}")
+     server.start()
+     logging.info("CosyVoice gRPC server listening on 0.0.0.0:%d", args.port)
+     server.wait_for_termination()
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--port", type=int, default=8000)
+     parser.add_argument("--max_conc", type=int, default=4,
+                         help="maximum concurrent requests / threads")
+     parser.add_argument("--model_dir", type=str,
+                         default="pretrained_models/CosyVoice2-0.5B",
+                         help="local path or ModelScope repo id")
+     serve(parser.parse_args())
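
For reference, below is a minimal client-side sketch of the zero-shot route that the new module docstring describes. Only the field names used by the server above (tts_text, prompt_text, prompt_audio, speed, tts_audio) and the service/method pair (CosyVoice / Inference, implied by the generated CosyVoiceServicer) come from this commit; cosyvoice.proto itself is not part of the diff, so the Request and zeroshotRequest message names and the 24 kHz output rate are assumptions that may need adjusting to the actual proto.

    # client_zero_shot_sketch.py -- illustrative only, not part of the commit.
    # Assumed names: cosyvoice_pb2.Request, cosyvoice_pb2.zeroshotRequest,
    # and a 24 kHz output sample rate; field names mirror the server code above.
    import grpc
    import numpy as np
    import soundfile as sf

    import cosyvoice_pb2
    import cosyvoice_pb2_grpc


    def zero_shot(host: str = "localhost:8000", prompt_wav: str = "prompt_16k_mono.wav"):
        # raw int16 PCM bytes at 16 kHz mono, matching _bytes_to_tensor on the server
        wav, sr = sf.read(prompt_wav, dtype="int16")
        assert sr == 16_000 and wav.ndim == 1, "expects a 16 kHz mono prompt"

        request = cosyvoice_pb2.Request(
            zero_shot_request=cosyvoice_pb2.zeroshotRequest(  # assumed message name
                tts_text="Text to synthesise in the prompt speaker's voice.",
                prompt_text="Transcript of the prompt audio.",
                prompt_audio=wav.tobytes(),  # an http(s) URL encoded as bytes also works
                speed=1.0,                   # only if the proto defines this field
            )
        )

        with grpc.insecure_channel(host) as channel:
            stub = cosyvoice_pb2_grpc.CosyVoiceStub(channel)
            # Inference is server-streaming: each Response carries an int16 PCM chunk
            pcm = b"".join(resp.tts_audio for resp in stub.Inference(request))

        audio = np.frombuffer(pcm, dtype=np.int16)
        sf.write("zero_shot_out.wav", audio, 24_000)  # output rate assumed


    if __name__ == "__main__":
        zero_shot()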