junseok committed
Commit 08cc398 · 1 Parent(s): ce904ba

new commit

Files changed (5)
  1. app.py +13 -28
  2. predict.py +37 -69
  3. score.py +44 -58
  4. ssl_ecapa_model.py +65 -8
  5. wavlm_large_cfg.pt +0 -0
app.py CHANGED
@@ -1,44 +1,29 @@
-from score import load_model
-from predict import loadWav
+import os
 import torch
 import torch.nn.functional as F
+from ssl_ecapa_model import SSL_ECAPA_TDNN
+from score import loadModel
+from predict import loadWav
 import gradio as gr
-import time

-model = load_model("wavlm_ecapa.model")
+model = loadModel('voxsim_wavlm_ecapa.model')
 model.eval()

 def calc_voxsim(inp_path, ref_path):
-    start = time.time()
-    inp_wavs, inp_wav = loadWav(inp_path)
-    ref_wavs, ref_wav = loadWav(ref_path)
-    print("loadWav time: ", time.time() - start)
-
-    inp_wavs = torch.FloatTensor(inp_wavs)
-    inp_wav = torch.FloatTensor(inp_wav)
-    ref_wavs = torch.FloatTensor(ref_wavs)
-    ref_wav = torch.FloatTensor(ref_wav)
-    print("torch.FloatTensor time: ", time.time() - start)
+    inp_wav = loadWav(inp_path, max_frames=0)
+    ref_wav = loadWav(ref_path, max_frames=0)

     with torch.no_grad():
-        input_emb_1 = F.normalize(model.forward(inp_wavs), p=2, dim=1)
-        print("input_emb_1 time: ", time.time() - start)
-        input_emb_2 = F.normalize(model.forward(inp_wav), p=2, dim=1)
-        print("input_emb_2 time: ", time.time() - start)
-        ref_emb_1 = F.normalize(model.forward(ref_wavs), p=2, dim=1)
-        print("ref_emb_1 time: ", time.time() - start)
-        ref_emb_2 = F.normalize(model.forward(ref_wav), p=2, dim=1)
-        print("ref_emb_2 time: ", time.time() - start)
+        input_emb = F.normalize(model.forward(inp_wav), p=2, dim=1)
+        ref_emb = F.normalize(model.forward(ref_wav), p=2, dim=1)

-    score_1 = torch.mean(torch.matmul(input_emb_1, ref_emb_1.T))
-    score_2 = torch.mean(torch.matmul(input_emb_2, ref_emb_2.T))
-    score = (score_1 + score_2) / 2
-    print("score time: ", time.time() - start)
+    score = torch.matmul(input_emb, ref_emb.T)
     return score.detach().cpu().numpy()

 description = """
 Voice similarity demo using wavlm-ecapa model, which is trained on Voxsim dataset.
 This demo only accepts .wav format. Best at 16 kHz sampling rate.
+The inference process of this Spaces demo is suboptimal due to the limitations of a basic CPU. To obtain an accurate score, refer to the "[voxsim_trainer](https://github.com/kaistmm/voxsim_trainer)" repository and run the code via the CLI.

 Paper is available [here](https://arxiv.org/abs/2407.18505)
 """
@@ -46,8 +31,8 @@ Paper is available [here](https://arxiv.org/abs/2407.18505)
 iface = gr.Interface(
     fn=calc_voxsim,
     inputs=(
-        gr.Audio(label="Input Audio"),
-        gr.Audio(label="Reference Audio")
+        gr.Audio(label="Input Audio", type='filepath'),
+        gr.Audio(label="Reference Audio", type='filepath')
     ),
     outputs="text",
     title="voice similarity with VoxSim",
predict.py CHANGED
@@ -2,10 +2,9 @@ import argparse
 import pathlib
 import tqdm
 from torch.utils.data import Dataset, DataLoader
-import librosa
-import numpy
-from score import Score
+from score import loadWav, Score
 import torch
+import os

 import warnings
 warnings.filterwarnings("ignore")
@@ -13,93 +12,61 @@ warnings.filterwarnings("ignore")

 def get_arg():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--bs", required=False, default=None, type=int)
-    parser.add_argument("--mode", required=True, choices=["predict_file", "predict_dir"], type=str)
-    parser.add_argument("--ckpt_path", required=False, default="wavlm_ecapa.model", type=pathlib.Path)
-    parser.add_argument("--inp_dir", required=False, default=None, type=pathlib.Path)
-    parser.add_argument("--ref_dir", required=False, default=None, type=pathlib.Path)
-    parser.add_argument("--inp_path", required=False, default=None, type=pathlib.Path)
-    parser.add_argument("--ref_path", required=False, default=None, type=pathlib.Path)
-    parser.add_argument("--out_path", required=True, type=pathlib.Path)
-    parser.add_argument("--num_workers", required=False, default=0, type=int)
+    parser.add_argument("--mode", required=True, choices=["predict_file", "predict_dir"], type=str, help="predict mode")
+    parser.add_argument("--ckpt_path", required=False, default="voxsim_wavlm_ecapa.model", type=pathlib.Path, help="path to the model checkpoint")
+    parser.add_argument("--inp_dir", required=False, default=None, type=pathlib.Path, help="input directory when predict_dir mode")
+    parser.add_argument("--ref_dir", required=False, default=None, type=pathlib.Path, help="reference directory when predict_dir mode")
+    parser.add_argument("--inp_path", required=False, default=None, type=pathlib.Path, help="input file when predict_file mode")
+    parser.add_argument("--ref_path", required=False, default=None, type=pathlib.Path, help="reference file when predict_file mode")
+    parser.add_argument("--out_path", required=True, type=pathlib.Path, help="output path")
+    parser.add_argument("--num_workers", required=False, default=4, type=int, help="number of workers for dataloader")
     return parser.parse_args()


-def loadWav(filename, max_frames: int = 400):
-
-    # Maximum audio length
-    max_audio = max_frames * 160 + 240
-
-    # Read wav file and convert to torch tensor
-    if type(filename) == tuple:
-        sr, audio = filename
-        audio = librosa.util.normalize(audio)
-        print(numpy.linalg.norm(audio))
-    else:
-        audio, sr = librosa.load(filename, sr=16000)
-    audio_org = audio.copy()
-
-    audiosize = audio.shape[0]
-
-    if audiosize <= max_audio:
-        shortage = max_audio - audiosize + 1
-        audio = numpy.pad(audio, (0, shortage), 'wrap')
-        audiosize = audio.shape[0]
-
-    startframe = numpy.linspace(0,audiosize-max_audio,num=10)
-
-    feats = []
-    for asf in startframe:
-        feats.append(audio[int(asf):int(asf)+max_audio])
-
-    feat = numpy.stack(feats,axis=0).astype(numpy.float32)
-
-    return torch.FloatTensor(feat), torch.FloatTensor(numpy.stack([audio_org],axis=0).astype(numpy.float32))
-
-
 class AudioDataset(Dataset):
     def __init__(self, inp_dir_path: pathlib.Path, ref_dir_path: pathlib.Path, max_frames: int = 400):
-        self.inp_wavlist = list(inp_dir_path.glob("*.wav"))
-        self.ref_wavlist = list(ref_dir_path.glob("*.wav"))
-        assert len(self.inp_wavlist) == len(self.ref_wavlist)
+        self.inp_dir_path = inp_dir_path
+        self.ref_dir_path = ref_dir_path
+        self.inp_wavlist = [file for file in os.listdir(inp_dir_path) if file.endswith(".wav")]
+        inp_wavset = set(self.inp_wavlist)
+        ref_wavset = set([file for file in os.listdir(ref_dir_path) if file.endswith(".wav")])
+        diff = inp_wavset - ref_wavset
+        if diff:
+            diff = list(diff)
+            diff.sort()
+            raise ValueError(f"Files {diff} are in inp_dir but not in ref_dir.")
         self.inp_wavlist.sort()
-        self.ref_wavlist.sort()
-        _, self.sr = librosa.load(self.inp_wavlist[0], sr=None)
+
         self.max_audio = max_frames * 160 + 240

     def __len__(self):
         return len(self.inp_wavlist)

     def __getitem__(self, idx):
-        inp_wavs, inp_wav = loadWav(self.inp_wavlist[idx])
-        ref_wavs, ref_wav = loadWav(self.ref_wavlist[idx])
+        inp_wavs, inp_wav = loadWav(os.path.join(self.inp_dir_path, self.inp_wavlist[idx]))
+        ref_wavs, ref_wav = loadWav(os.path.join(self.ref_dir_path, self.inp_wavlist[idx]))
         return inp_wavs, inp_wav, ref_wavs, ref_wav

 def main():
     args = get_arg()
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     if args.mode == "predict_file":
-        assert args.inp_path is not None
-        assert args.ref_path is not None
-        assert args.inp_dir is None
-        assert args.ref_dir is None
+        assert args.inp_path is not None, "inp_path is required when mode is predict_file."
+        assert args.ref_path is not None, "ref_path is required when mode is predict_file."
         assert args.inp_path.exists()
-        assert args.inp_path.is_file()
         assert args.ref_path.exists()
+        assert args.inp_path.is_file()
         assert args.ref_path.is_file()
         inp_wavs, inp_wav = loadWav(args.inp_path)
         ref_wavs, ref_wav = loadWav(args.ref_path)
         scorer = Score(ckpt_path=args.ckpt_path, device=device)
         score = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
-        print("Voxsim score: ", score[0])
+        print("VoxSIM score: ", score)
         with open(args.out_path, "w") as fw:
-            fw.write(str(score[0]))
+            fw.write(str(score))
     else:
         assert args.inp_dir is not None, "inp_dir is required when mode is predict_dir."
         assert args.ref_dir is not None, "ref_dir is required when mode is predict_dir."
-        assert args.bs is not None, "bs is required when mode is predict_dir."
-        assert args.inp_path is None, "inp_path should be None"
-        assert args.ref_path is None, "ref_path should be None"
         assert args.inp_dir.exists()
         assert args.ref_dir.exists()
         assert args.inp_dir.is_dir()
@@ -107,17 +74,18 @@ def main():
         dataset = AudioDataset(args.inp_dir, args.ref_dir)
         loader = DataLoader(
             dataset,
-            batch_size=args.bs,
+            batch_size=1,
             shuffle=False,
             num_workers=args.num_workers)
         scorer = Score(ckpt_path=args.ckpt_path, device=device)
-        with open(args.out_path, 'w'):
-            pass
-        for batch in tqdm.tqdm(loader):
-            scores = score.score(batch.to(device))
-            with open(args.out_path, 'a') as fw:
-                for s in scores:
-                    fw.write(str(s) + "\n")
+        avg_score = []
+        with open(args.out_path, 'w') as fw:
+            for batch in tqdm.tqdm(loader):
+                inp_wavs, inp_wav, ref_wavs, ref_wav = batch
+                score = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
+                avg_score.append(score)
+                fw.write(str(score) + "\n")
+        print("Average VoxSIM score: ", sum(avg_score)/len(avg_score))
         print("save to ", args.out_path)

 if __name__ == "__main__":
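The predict_file branch above can also be driven directly from Python rather than the CLI; a sketch, with placeholder .wav paths and the new default checkpoint name:

import torch
from score import loadWav, Score

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Same preprocessing as predict.py: ten fixed-length windows plus the full clip per file.
inp_wavs, inp_wav = loadWav("input.wav")
ref_wavs, ref_wav = loadWav("reference.wav")

scorer = Score(ckpt_path="voxsim_wavlm_ecapa.model", device=device)
print("VoxSIM score:", scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav))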
score.py CHANGED
@@ -1,35 +1,49 @@
 import os
+import numpy
+import librosa
 import torch
 import torch.nn.functional as F
 from ssl_ecapa_model import SSL_ECAPA_TDNN
 from huggingface_hub import hf_hub_download


-def load_model(ckpt_path):
-    model = SSL_ECAPA_TDNN(feat_dim=1024, emb_dim=256, feat_type='wavlm_large')
-    load_parameters(model, ckpt_path)
-    return model
+def loadWav(filename, max_frames: int = 400, num_eval: int = 10):
+
+    # Maximum audio length
+    max_audio = max_frames * 160 + 240
+
+    # Read wav file and convert to torch tensor
+    audio, sr = librosa.load(filename, sr=16000)
+    audio_org = audio.copy()
+
+    audiosize = audio.shape[0]
+
+    if audiosize <= max_audio:
+        shortage = max_audio - audiosize + 1
+        audio = numpy.pad(audio, (0, shortage), 'wrap')
+        audiosize = audio.shape[0]
+
+    startframe = numpy.linspace(0,audiosize-max_audio, num=num_eval)
+
+    feats = []
+    if max_frames == 0:
+        feats.append(audio)
+        feat = numpy.stack(feats,axis=0).astype(numpy.float32)
+        return torch.FloatTensor(feat)
+    else:
+        for asf in startframe:
+            feats.append(audio[int(asf):int(asf)+max_audio])
+        feat = numpy.stack(feats,axis=0).astype(numpy.float32)
+        return torch.FloatTensor(feat), torch.FloatTensor(numpy.stack([audio_org],axis=0).astype(numpy.float32))


-def load_parameters(model, ckpt_path):
-    model_state = model.state_dict()
+def loadModel(ckpt_path):
+    model = SSL_ECAPA_TDNN(feat_dim=1024, emb_dim=256, feat_type='wavlm_large')
     if not os.path.isfile(ckpt_path):
         print("Downloading model from Hugging Face Hub...")
-        new_ckpt_path = hf_hub_download(repo_id="junseok520/voxsim-models", filename=ckpt_path, local_dir="./")
-        ckpt_path = new_ckpt_path
-    loaded_state = torch.load(ckpt_path, map_location='cpu', weights_only=True)
-
-    for name, param in loaded_state.items():
-        if name.startswith('__S__.'):
-            if name[6:] in model_state:
-                model_state[name[6:]].copy_(param)
-            else:
-                print("{} is not in the model.".format(name[6:]))
-        else:
-            if name in model_state:
-                model_state[name].copy_(param)
-            else:
-                print("{} is not in the model.".format(name))
+        ckpt_path = hf_hub_download(repo_id="junseok520/voxsim-models", filename=ckpt_path, local_dir="./")
+    model.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
+    return model


 class Score:
@@ -37,7 +51,7 @@ class Score:

     def __init__(
         self,
-        ckpt_path: str = "wavlm_ecapa.pt",
+        ckpt_path: str = "voxsim_wavlm_ecapa.model",
         device: str = "gpu"):
         """
         Args:
@@ -47,43 +61,15 @@ class Score:
         """
         print(f"Using device: {device}")
         self.device = device
-        self.model = load_model(ckpt_path).to(self.device)
+        self.model = loadModel(ckpt_path).to(self.device)
         self.model.eval()

     def score(self, inp_wavs: torch.tensor, inp_wav: torch.tensor, ref_wavs: torch.tensor, ref_wav: torch.tensor) -> torch.tensor:
-        """
-        Args:
-            wavs: audio waveform to be evaluated. When len(wavs) == 1 or 2,
-            the model processes the input as a single audio clip. The model
-            performs batch processing when len(wavs) == 3.
-        """
-        # if len(wavs.shape) == 1:
-        #     out_wavs = wavs.unsqueeze(0).unsqueeze(0)
-        # elif len(wavs.shape) == 2:
-        #     out_wavs = wavs.unsqueeze(0)
-        # elif len(wavs.shape) == 3:
-        #     out_wavs = wavs
-        # else:
-        #     raise ValueError('Dimension of input tensor needs to be <= 3.')
-
-        if len(inp_wavs.shape) == 2:
-            bs = 1
-        elif len(inp_wavs.shape) == 3:
-            bs = inp_wavs.shape[0]
-        else:
-            raise ValueError('Dimension of input tensor needs to be <= 3.')

         inp_wavs = inp_wavs.reshape(-1, inp_wavs.shape[-1]).to(self.device)
         inp_wav = inp_wav.reshape(-1, inp_wav.shape[-1]).to(self.device)
         ref_wavs = ref_wavs.reshape(-1, ref_wavs.shape[-1]).to(self.device)
         ref_wav = ref_wav.reshape(-1, ref_wav.shape[-1]).to(self.device)
-
-        # assert inp_wavs.shape[1] == 10
-        # assert ref_wavs.shape[1] == 10
-        # assert inp_wav.shape[1] == 1
-        # assert ref_wav.shape[1] == 1
-
-        # import pdb; pdb.set_trace()

         with torch.no_grad():
             input_emb_1 = F.normalize(self.model.forward(inp_wavs), p=2, dim=1).detach()
@@ -92,15 +78,15 @@ class Score:
             ref_emb_2 = F.normalize(self.model.forward(ref_wav), p=2, dim=1).detach()

         emb_size = input_emb_1.shape[-1]
-        input_emb_1 = input_emb_1.reshape(bs, -1, emb_size)
-        input_emb_2 = input_emb_2.reshape(bs, -1, emb_size)
-        ref_emb_1 = ref_emb_1.reshape(bs, -1, emb_size)
-        ref_emb_2 = ref_emb_2.reshape(bs, -1, emb_size)
+        input_emb_1 = input_emb_1.reshape(-1, emb_size)
+        input_emb_2 = input_emb_2.reshape(-1, emb_size)
+        ref_emb_1 = ref_emb_1.reshape(-1, emb_size)
+        ref_emb_2 = ref_emb_2.reshape(-1, emb_size)

-        score_1 = torch.mean(torch.bmm(input_emb_1, ref_emb_1.transpose(1,2)), dim=(1,2))
-        score_2 = torch.mean(torch.bmm(input_emb_2, ref_emb_2.transpose(1,2)), dim=(1,2))
+        score_1 = torch.mean(torch.matmul(input_emb_1, ref_emb_1.T))
+        score_2 = torch.mean(torch.matmul(input_emb_2, ref_emb_2.T))
         score = (score_1 + score_2) / 2
-        score = score.detach().cpu().numpy()
+        score = score.detach().cpu().item()

         return score
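Note the asymmetric return of the new loadWav: with the default max_frames it returns a tuple of ten fixed-length crops plus the full clip, while max_frames=0 returns a single tensor. A small sketch with a placeholder file path:

from score import loadWav

wavs, wav = loadWav("sample.wav")            # default: wavs is 10 crops of 400*160+240 = 64240 samples, wav is (1, num_samples)
whole = loadWav("sample.wav", max_frames=0)  # whole-clip mode used by app.py: a single (1, num_samples) tensor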
 
ssl_ecapa_model.py CHANGED
@@ -4,14 +4,68 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchaudio.transforms as trans
+from torch.nn.utils.rnn import pad_sequence
+from s3prl.upstream.interfaces import UpstreamBase
+from s3prl.upstream.wavlm.WavLM import WavLM, WavLMConfig

-urls = {
-    'hubert_large_ll60k': "https://dl.fbaipublicfiles.com/hubert/hubert_large_ll60k.pt",
-    'xls_r_300m': "https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr2_300m.pt",
-    'unispeech_sat': "https://huggingface.co/s3prl/converted_ckpts/resolve/main/unispeech_sat_large.pt",
-    'wavlm_base_plus': "https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_base_plus.pt",
-    'wavlm_large': "https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_large.pt",
-    }
+
+''' WavLM UpstreamExpert without loading pretrained checkpoint
+'''
+
+
+class UpstreamExpert(UpstreamBase):
+    def __init__(self, cfg, **kwargs):
+        super().__init__(**kwargs)
+
+        self.cfg = WavLMConfig(torch.load(cfg))
+        self.model = WavLM(self.cfg)
+
+        self.model.feature_grad_mult = 0.0
+        self.model.encoder.layerdrop = 0.0
+
+        if len(self.hooks) == 0:
+            module_name = "self.model.encoder.layers"
+            for module_id in range(len(eval(module_name))):
+                self.add_hook(
+                    f"{module_name}[{module_id}]",
+                    lambda input, output: input[0].transpose(0, 1),
+                )
+            self.add_hook("self.model.encoder", lambda input, output: output[0])
+
+        self._init_layerdrop = self.model.encoder.layerdrop
+
+    @property
+    def layer_drop(self):
+        return self.model.encoder.layerdrop
+
+    def set_layer_drop(self, layerdrop: float = None):
+        if isinstance(layerdrop, float):
+            self.model.encoder.layerdrop = layerdrop
+        elif layerdrop is None:
+            self.model.encoder.layerdrop = self._init_layerdrop
+        else:
+            raise ValueError("layerdrop can only be float or None")
+
+    def get_downsample_rates(self, key: str) -> int:
+        return 320
+
+    def forward(self, wavs):
+        if self.cfg.normalize:
+            wavs = [F.layer_norm(wav, wav.shape) for wav in wavs]
+
+        device = wavs[0].device
+        wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device)
+        wav_padding_mask = ~torch.lt(
+            torch.arange(max(wav_lengths)).unsqueeze(0).to(device),
+            wav_lengths.unsqueeze(1),
+        )
+        padded_wav = pad_sequence(wavs, batch_first=True)
+
+        features, feat_padding_mask = self.model.extract_features(
+            padded_wav,
+            padding_mask=wav_padding_mask,
+            mask=False,
+        )


 ''' Res2Conv1d + BatchNorm1d + ReLU
@@ -199,7 +253,10 @@ class SSL_ECAPA_TDNN(nn.Module):
             self.feature_extract = trans.MFCC(sample_rate=sr, n_mfcc=feat_dim, log_mels=False,
                                               melkwargs=melkwargs)
         else:
-            self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)
+            if feat_type == "wavlm_large":
+                self.feature_extract = UpstreamExpert(cfg="wavlm_large_cfg.pt")
+            else:
+                raise NotImplementedError

         if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
             self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
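A quick shape check of the reworked feature-extractor path, assuming the wavlm_large_cfg.pt added below sits in the working directory; the audio is random, so only the output shape is meaningful here.

import torch
from ssl_ecapa_model import SSL_ECAPA_TDNN

# feat_type values other than "wavlm_large" now raise NotImplementedError.
model = SSL_ECAPA_TDNN(feat_dim=1024, emb_dim=256, feat_type='wavlm_large')
model.eval()

wav = torch.randn(2, 32000)        # two ~2 s clips of 16 kHz audio
with torch.no_grad():
    emb = model(wav)               # expected shape: (2, 256) speaker embeddings
print(emb.shape)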
wavlm_large_cfg.pt ADDED
Binary file (1.92 kB).