Spaces: Running on Zero

Update modules/v2/vc_wrapper.py

modules/v2/vc_wrapper.py (+21 -80) CHANGED
@@ -1,3 +1,4 @@
+import spaces
 import torch
 import librosa
 import torchaudio
@@ -52,56 +53,6 @@ class VoiceConversionWrapper(torch.nn.Module):
         self.ar_max_content_len = 1500  # in num of narrow tokens
         self.compile_len = 87 * self.dit_max_context_len
 
-    def forward_cfm(self, content_indices_wide, content_lens, mels, mel_lens, style_vectors):
-        device = content_indices_wide.device
-        B = content_indices_wide.size(0)
-        cond, _ = self.cfm_length_regulator(content_indices_wide, ylens=mel_lens)
-
-        # randomly set a length as prompt
-        prompt_len_max = mel_lens - 1
-        prompt_len = (torch.rand([B], device=device) * prompt_len_max).floor().to(dtype=torch.long)
-        prompt_len[torch.rand([B], device=device) < 0.1] = 0
-
-        loss = self.cfm(mels, mel_lens, prompt_len, cond, style_vectors)
-        return loss
-
-    def forward_ar(self, content_indices_narrow, content_indices_wide, content_lens):
-        device = content_indices_narrow.device
-        duration_reduced_narrow_tokens = []
-        duration_reduced_narrow_lens = []
-        for bib in range(content_indices_narrow.size(0)):
-            reduced, reduced_len = self.duration_reduction_func(content_indices_narrow[bib])
-            duration_reduced_narrow_tokens.append(reduced)
-            duration_reduced_narrow_lens.append(reduced_len)
-        duration_reduced_narrow_tokens = torch.nn.utils.rnn.pad_sequence(duration_reduced_narrow_tokens,
-                                                                         batch_first=True, padding_value=0).to(device)
-        duration_reduced_narrow_lens = torch.LongTensor(duration_reduced_narrow_lens).to(device)
-
-        # interpolate speech token to match acoustic feature length
-        cond, _ = self.ar_length_regulator(duration_reduced_narrow_tokens)
-        loss = self.ar(cond, duration_reduced_narrow_lens, content_indices_wide, content_lens)
-        return loss
-
-    def forward(self, waves_16k, mels, wave_lens_16k, mel_lens, forward_ar=False, forward_cfm=True):
-        """
-        Forward pass for the model.
-        """
-        # extract wide content features as both AR and CFM models use them
-        with torch.no_grad():
-            _, content_indices_wide, content_lens = self.content_extractor_wide(waves_16k, wave_lens_16k)
-        if forward_ar:
-            # extract narrow content features for AR model
-            _, content_indices_narrow, _ = self.content_extractor_narrow(waves_16k, wave_lens_16k, ssl_model=self.content_extractor_wide.ssl_model)
-            loss_ar = self.forward_ar(content_indices_narrow.clone(), content_indices_wide.clone(), content_lens)
-        else:
-            loss_ar = torch.tensor(0.0, device=waves_16k.device, dtype=waves_16k.dtype)
-        if forward_cfm:
-            style_vectors = self.compute_style(waves_16k, wave_lens_16k)
-            loss_cfm = self.forward_cfm(content_indices_wide, content_lens, mels, mel_lens, style_vectors)
-        else:
-            loss_cfm = torch.tensor(0.0, device=waves_16k.device, dtype=waves_16k.dtype)
-        return loss_ar, loss_cfm
-
     def compile_ar(self):
         """
         Compile the AR model for inference.
@@ -258,28 +209,24 @@ class VoiceConversionWrapper(torch.nn.Module):
                 repo_id=DEFAULT_REPO_ID,
                 model_filename=DEFAULT_CFM_CHECKPOINT,
             )
-        else:
-            print(f"Loading CFM checkpoint from {cfm_checkpoint_path}...")
         if ar_checkpoint_path is None:
             ar_checkpoint_path = load_custom_model_from_hf(
                 repo_id=DEFAULT_REPO_ID,
                 model_filename=DEFAULT_AR_CHECKPOINT,
             )
-        else:
-            print(f"Loading AR checkpoint from {ar_checkpoint_path}...")
         # cfm
         cfm_checkpoint = torch.load(cfm_checkpoint_path, map_location="cpu")
         cfm_length_regulator_state_dict = self.strip_prefix(cfm_checkpoint["net"]['length_regulator'], "module.")
         cfm_state_dict = self.strip_prefix(cfm_checkpoint["net"]['cfm'], "module.")
-
-
+        self.cfm.load_state_dict(cfm_state_dict, strict=False)
+        self.cfm_length_regulator.load_state_dict(cfm_length_regulator_state_dict, strict=False)
 
         # ar
         ar_checkpoint = torch.load(ar_checkpoint_path, map_location="cpu")
         ar_length_regulator_state_dict = self.strip_prefix(ar_checkpoint["net"]['length_regulator'], "module.")
         ar_state_dict = self.strip_prefix(ar_checkpoint["net"]['ar'], "module.")
-
-
+        self.ar.load_state_dict(ar_state_dict, strict=False)
+        self.ar_length_regulator.load_state_dict(ar_length_regulator_state_dict, strict=False)
 
         # content extractor
         content_extractor_narrow_checkpoint_path = load_custom_model_from_hf(
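Note on the added load_state_dict calls: passing strict=False loads the matching parameters from each checkpoint and reports, rather than raises on, keys that are missing or unexpected. A minimal sketch of that behavior with a stand-in module (names below are illustrative, not from this repository):

    import torch

    model = torch.nn.Linear(4, 2)                        # stand-in module for illustration
    state = {"weight": torch.zeros(2, 4),                # matches a parameter, so it is loaded
             "unknown.bias": torch.zeros(2)}             # tolerated because strict=False
    result = model.load_state_dict(state, strict=False)
    print(result.missing_keys)                           # ['bias'] -> keeps its initialized value
    print(result.unexpected_keys)                        # ['unknown.bias'] -> ignored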
@@ -308,26 +255,13 @@ class VoiceConversionWrapper(torch.nn.Module):
     def setup_ar_caches(self, max_batch_size=1, max_seq_len=4096, dtype=torch.float32, device=torch.device("cpu")):
         self.ar.setup_caches(max_batch_size=max_batch_size, max_seq_len=max_seq_len, dtype=dtype, device=device)
 
-    …
-                                                 num_mel_bins=80,
-                                                 dither=0,
-                                                 sample_frequency=16000)
-            feat = feat - feat.mean(dim=0, keepdim=True)
-            feat_list.append(feat)
-        max_feat_len = max([feat.size(0) for feat in feat_list])
-        feat_lens = torch.tensor([feat.size(0) for feat in feat_list], dtype=torch.int32).to(waves_16k.device) // 2
-        feat_list = [
-            torch.nn.functional.pad(feat, (0, 0, 0, max_feat_len - feat.size(0)), value=float(feat.min().item()))
-            for feat in feat_list
-        ]
-        feat = torch.stack(feat_list, dim=0)
-        style = self.style_encoder(feat, feat_lens)
+    def compute_style(self, waves_16k: torch.Tensor):
+        feat = torchaudio.compliance.kaldi.fbank(waves_16k,
+                                                 num_mel_bins=80,
+                                                 dither=0,
+                                                 sample_frequency=16000)
+        feat = feat - feat.mean(dim=0, keepdim=True)
+        style = self.style_encoder(feat.unsqueeze(0))
         return style
 
     @torch.no_grad()
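Note on the rewritten compute_style: the batched padding path is replaced by a single-utterance feature pipeline of Kaldi-style 80-bin filterbanks with dithering disabled, per-utterance mean normalization, and a batch dimension added before the style encoder. A minimal standalone sketch of that pipeline, assuming a mono 16 kHz recording (the file path is a placeholder):

    import torch
    import torchaudio

    wave, sr = torchaudio.load("reference.wav")          # placeholder path; expected mono, 16 kHz
    assert sr == 16000 and wave.size(0) == 1
    feat = torchaudio.compliance.kaldi.fbank(wave,
                                             num_mel_bins=80,
                                             dither=0,
                                             sample_frequency=16000)  # (num_frames, 80)
    feat = feat - feat.mean(dim=0, keepdim=True)          # per-utterance mean normalization
    feat = feat.unsqueeze(0)                              # (1, num_frames, 80) for the style encoder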
@@ -490,6 +424,7 @@ class VoiceConversionWrapper(torch.nn.Module):
 
         return content_indices
 
+    @spaces.GPU
     @torch.no_grad()
     @torch.inference_mode()
     def convert_voice_with_streaming(
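Note on @spaces.GPU: together with the import spaces added in the first hunk, this is the usual ZeroGPU pattern; the Space holds no GPU while idle, and a device is attached only for calls into the decorated entry point. A minimal sketch of the pattern (the function below is illustrative, not part of this module):

    import spaces
    import torch

    @spaces.GPU                                   # borrow a ZeroGPU device for each call
    def run_inference(x: torch.Tensor) -> torch.Tensor:
        # CUDA is only available inside the decorated call on ZeroGPU hardware
        return (x.to("cuda") * 2).cpu()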
@@ -623,7 +558,10 @@ class VoiceConversionWrapper(torch.nn.Module):
 
                 if stream_output and mp3_bytes is not None:
                     yield mp3_bytes, full_audio
+
                 if should_break:
+                    if not stream_output:
+                        return full_audio
                     break
         else:
             cond, _ = self.cfm_length_regulator(source_content_indices, ylens=torch.LongTensor([source_mel_len]).to(device))
@@ -641,7 +579,7 @@ class VoiceConversionWrapper(torch.nn.Module):
             if self.dit_compiled:
                 cat_condition = torch.nn.functional.pad(cat_condition,
                                                         (0, 0, 0, self.compile_len - cat_condition.size(1),), value=0)
-            with torch.autocast(device_type=device.type, dtype=
+            with torch.autocast(device_type=device.type, dtype=dtype):
                # Voice Conversion
                vc_mel = self.cfm.inference(
                    cat_condition,
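Note on the autocast change: the voice-conversion call now runs inside torch.autocast with the dtype taken from the surrounding code. A minimal sketch of what that context does (device, shapes, and dtype below are illustrative):

    import torch

    x = torch.randn(8, 512, device="cuda")
    w = torch.randn(512, 512, device="cuda")
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        y = x @ w                                 # eligible ops run in float16 inside the block
    print(y.dtype)                                # torch.float16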
@@ -660,5 +598,8 @@ class VoiceConversionWrapper(torch.nn.Module):
 
             if stream_output and mp3_bytes is not None:
                 yield mp3_bytes, full_audio
+
             if should_break:
-                …
+                if not stream_output:
+                    return full_audio
+                break
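Note on the added return full_audio: convert_voice_with_streaming is a generator, so the return ends iteration and the value rides on StopIteration rather than being handed back directly. A hedged sketch of how a non-streaming caller could retrieve it, assuming every yield in the method is guarded by stream_output as in the hunks above (argument names are placeholders; the real signature has more parameters):

    gen = wrapper.convert_voice_with_streaming(source_audio, target_audio, stream_output=False)
    try:
        next(gen)                                 # nothing is yielded when stream_output=False
        full_audio = None                         # not expected under the assumption above
    except StopIteration as stop:
        full_audio = stop.value                   # the array from `return full_audio`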