Plachta committed on
Commit 84a7891 · verified · 1 Parent(s): 5e78e49

Update modules/v2/vc_wrapper.py

Files changed (1)
  1. modules/v2/vc_wrapper.py +21 -80
modules/v2/vc_wrapper.py CHANGED
@@ -1,3 +1,4 @@
+import spaces
 import torch
 import librosa
 import torchaudio
@@ -52,56 +53,6 @@ class VoiceConversionWrapper(torch.nn.Module):
         self.ar_max_content_len = 1500  # in num of narrow tokens
         self.compile_len = 87 * self.dit_max_context_len
 
-    def forward_cfm(self, content_indices_wide, content_lens, mels, mel_lens, style_vectors):
-        device = content_indices_wide.device
-        B = content_indices_wide.size(0)
-        cond, _ = self.cfm_length_regulator(content_indices_wide, ylens=mel_lens)
-
-        # randomly set a length as prompt
-        prompt_len_max = mel_lens - 1
-        prompt_len = (torch.rand([B], device=device) * prompt_len_max).floor().to(dtype=torch.long)
-        prompt_len[torch.rand([B], device=device) < 0.1] = 0
-
-        loss = self.cfm(mels, mel_lens, prompt_len, cond, style_vectors)
-        return loss
-
-    def forward_ar(self, content_indices_narrow, content_indices_wide, content_lens):
-        device = content_indices_narrow.device
-        duration_reduced_narrow_tokens = []
-        duration_reduced_narrow_lens = []
-        for bib in range(content_indices_narrow.size(0)):
-            reduced, reduced_len = self.duration_reduction_func(content_indices_narrow[bib])
-            duration_reduced_narrow_tokens.append(reduced)
-            duration_reduced_narrow_lens.append(reduced_len)
-        duration_reduced_narrow_tokens = torch.nn.utils.rnn.pad_sequence(duration_reduced_narrow_tokens,
-                                                                         batch_first=True, padding_value=0).to(device)
-        duration_reduced_narrow_lens = torch.LongTensor(duration_reduced_narrow_lens).to(device)
-
-        # interpolate speech token to match acoustic feature length
-        cond, _ = self.ar_length_regulator(duration_reduced_narrow_tokens)
-        loss = self.ar(cond, duration_reduced_narrow_lens, content_indices_wide, content_lens)
-        return loss
-
-    def forward(self, waves_16k, mels, wave_lens_16k, mel_lens, forward_ar=False, forward_cfm=True):
-        """
-        Forward pass for the model.
-        """
-        # extract wide content features as both AR and CFM models use them
-        with torch.no_grad():
-            _, content_indices_wide, content_lens = self.content_extractor_wide(waves_16k, wave_lens_16k)
-        if forward_ar:
-            # extract narrow content features for AR model
-            _, content_indices_narrow, _ = self.content_extractor_narrow(waves_16k, wave_lens_16k, ssl_model=self.content_extractor_wide.ssl_model)
-            loss_ar = self.forward_ar(content_indices_narrow.clone(), content_indices_wide.clone(), content_lens)
-        else:
-            loss_ar = torch.tensor(0.0, device=waves_16k.device, dtype=waves_16k.dtype)
-        if forward_cfm:
-            style_vectors = self.compute_style(waves_16k, wave_lens_16k)
-            loss_cfm = self.forward_cfm(content_indices_wide, content_lens, mels, mel_lens, style_vectors)
-        else:
-            loss_cfm = torch.tensor(0.0, device=waves_16k.device, dtype=waves_16k.dtype)
-        return loss_ar, loss_cfm
-
     def compile_ar(self):
         """
         Compile the AR model for inference.
@@ -258,28 +209,24 @@ class VoiceConversionWrapper(torch.nn.Module):
                 repo_id=DEFAULT_REPO_ID,
                 model_filename=DEFAULT_CFM_CHECKPOINT,
             )
-        else:
-            print(f"Loading CFM checkpoint from {cfm_checkpoint_path}...")
         if ar_checkpoint_path is None:
             ar_checkpoint_path = load_custom_model_from_hf(
                 repo_id=DEFAULT_REPO_ID,
                 model_filename=DEFAULT_AR_CHECKPOINT,
             )
-        else:
-            print(f"Loading AR checkpoint from {ar_checkpoint_path}...")
         # cfm
         cfm_checkpoint = torch.load(cfm_checkpoint_path, map_location="cpu")
         cfm_length_regulator_state_dict = self.strip_prefix(cfm_checkpoint["net"]['length_regulator'], "module.")
         cfm_state_dict = self.strip_prefix(cfm_checkpoint["net"]['cfm'], "module.")
-        missing_keys, unexpected_keys = self.cfm.load_state_dict(cfm_state_dict, strict=False)
-        missing_keys, unexpected_keys = self.cfm_length_regulator.load_state_dict(cfm_length_regulator_state_dict, strict=False)
+        self.cfm.load_state_dict(cfm_state_dict, strict=False)
+        self.cfm_length_regulator.load_state_dict(cfm_length_regulator_state_dict, strict=False)
 
         # ar
         ar_checkpoint = torch.load(ar_checkpoint_path, map_location="cpu")
         ar_length_regulator_state_dict = self.strip_prefix(ar_checkpoint["net"]['length_regulator'], "module.")
         ar_state_dict = self.strip_prefix(ar_checkpoint["net"]['ar'], "module.")
-        missing_keys, unexpected_keys = self.ar.load_state_dict(ar_state_dict, strict=False)
-        missing_keys, unexpected_keys = self.ar_length_regulator.load_state_dict(ar_length_regulator_state_dict, strict=False)
+        self.ar.load_state_dict(ar_state_dict, strict=False)
+        self.ar_length_regulator.load_state_dict(ar_length_regulator_state_dict, strict=False)
 
         # content extractor
         content_extractor_narrow_checkpoint_path = load_custom_model_from_hf(
@@ -308,26 +255,13 @@ class VoiceConversionWrapper(torch.nn.Module):
     def setup_ar_caches(self, max_batch_size=1, max_seq_len=4096, dtype=torch.float32, device=torch.device("cpu")):
         self.ar.setup_caches(max_batch_size=max_batch_size, max_seq_len=max_seq_len, dtype=dtype, device=device)
 
-    @torch.no_grad()
-    def compute_style(self, waves_16k: torch.Tensor, wave_lens_16k: torch.Tensor = None):
-        if wave_lens_16k is None:
-            wave_lens_16k = torch.tensor([waves_16k.size(-1)], dtype=torch.int32).to(waves_16k.device)
-        feat_list = []
-        for bib in range(waves_16k.size(0)):
-            feat = torchaudio.compliance.kaldi.fbank(waves_16k[bib:bib + 1, :wave_lens_16k[bib]],
-                                                     num_mel_bins=80,
-                                                     dither=0,
-                                                     sample_frequency=16000)
-            feat = feat - feat.mean(dim=0, keepdim=True)
-            feat_list.append(feat)
-        max_feat_len = max([feat.size(0) for feat in feat_list])
-        feat_lens = torch.tensor([feat.size(0) for feat in feat_list], dtype=torch.int32).to(waves_16k.device) // 2
-        feat_list = [
-            torch.nn.functional.pad(feat, (0, 0, 0, max_feat_len - feat.size(0)), value=float(feat.min().item()))
-            for feat in feat_list
-        ]
-        feat = torch.stack(feat_list, dim=0)
-        style = self.style_encoder(feat, feat_lens)
+    def compute_style(self, waves_16k: torch.Tensor):
+        feat = torchaudio.compliance.kaldi.fbank(waves_16k,
+                                                 num_mel_bins=80,
+                                                 dither=0,
+                                                 sample_frequency=16000)
+        feat = feat - feat.mean(dim=0, keepdim=True)
+        style = self.style_encoder(feat.unsqueeze(0))
         return style
 
     @torch.no_grad()
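
The rewritten compute_style above drops the batched, length-aware path in favor of a single-utterance one. A minimal usage sketch, assuming a mono 16 kHz reference clip; `wrapper` and `reference.wav` are placeholder names, not part of this repo:

```python
import librosa
import torch

# Placeholder path; any mono 16 kHz reference clip works the same way.
wav, _ = librosa.load("reference.wav", sr=16000, mono=True)
waves_16k = torch.from_numpy(wav).unsqueeze(0)  # (1, num_samples), as kaldi.fbank expects

# `wrapper` stands in for an already-constructed VoiceConversionWrapper.
style = wrapper.compute_style(waves_16k)  # fbank -> mean normalization -> style encoder
```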
@@ -490,6 +424,7 @@ class VoiceConversionWrapper(torch.nn.Module):
 
         return content_indices
 
+    @spaces.GPU
     @torch.no_grad()
     @torch.inference_mode()
     def convert_voice_with_streaming(
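
The new @spaces.GPU decorator (paired with the `import spaces` added at the top of the file) is the Hugging Face Spaces ZeroGPU hook: on ZeroGPU hardware it attaches a GPU only while the decorated call runs, and it degrades to a no-op elsewhere. A minimal sketch of the pattern, with an illustrative function and duration:

```python
import spaces
import torch

@spaces.GPU(duration=120)  # illustrative: request up to ~120 s of GPU time per call
def convert(audio: torch.Tensor) -> torch.Tensor:
    # CUDA is only guaranteed to be visible inside the decorated function.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return audio.to(device)
```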
@@ -623,7 +558,10 @@ class VoiceConversionWrapper(torch.nn.Module):
 
             if stream_output and mp3_bytes is not None:
                 yield mp3_bytes, full_audio
+
             if should_break:
+                if not stream_output:
+                    return full_audio
                 break
         else:
             cond, _ = self.cfm_length_regulator(source_content_indices, ylens=torch.LongTensor([source_mel_len]).to(device))
@@ -641,7 +579,7 @@ class VoiceConversionWrapper(torch.nn.Module):
             if self.dit_compiled:
                 cat_condition = torch.nn.functional.pad(cat_condition,
                                                         (0, 0, 0, self.compile_len - cat_condition.size(1),), value=0)
-            with torch.autocast(device_type=device.type, dtype=torch.float32):  # force CFM to use float32
+            with torch.autocast(device_type=device.type, dtype=dtype):
                 # Voice Conversion
                 vc_mel = self.cfm.inference(
                     cat_condition,
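
With the autocast change above, CFM inference follows the caller-chosen dtype instead of being pinned to float32. A small sketch of the pattern, assuming dtype is selected once and reused:

```python
import torch

# Choose a compute dtype once; CUDA autocast supports float16,
# CPU autocast supports bfloat16.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if device.type == "cuda" else torch.bfloat16

with torch.autocast(device_type=device.type, dtype=dtype):
    x = torch.randn(2, 4, device=device)
    w = torch.randn(3, 4, device=device)
    y = torch.nn.functional.linear(x, w)  # runs in reduced precision under autocast
```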
@@ -660,5 +598,8 @@ class VoiceConversionWrapper(torch.nn.Module):
 
             if stream_output and mp3_bytes is not None:
                 yield mp3_bytes, full_audio
+
             if should_break:
-                break
+                if not stream_output:
+                    return full_audio
+                break
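
Since convert_voice_with_streaming is a generator (it contains yield), the `return full_audio` added above surfaces as StopIteration.value rather than as an ordinary return value. A hedged sketch of a non-streaming caller; collect_full_audio is a hypothetical helper, not part of this repo:

```python
def collect_full_audio(vc_generator):
    """Drain a voice-conversion generator and return its `return` value.

    Hypothetical helper: with stream_output=False, the generator's
    `return full_audio` is delivered via StopIteration.value.
    """
    try:
        while True:
            next(vc_generator)
    except StopIteration as stop:
        return stop.value
```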