junseok committed · commit 08cc398 · 1 parent: ce904ba
Commit message: new commit

Files changed:
- app.py (+13 -28)
- predict.py (+37 -69)
- score.py (+44 -58)
- ssl_ecapa_model.py (+65 -8)
- wavlm_large_cfg.pt (+0 -0, binary, added)
app.py
CHANGED
@@ -1,44 +1,29 @@
-…
-from predict import loadWav
+import os
 import torch
 import torch.nn.functional as F
+from ssl_ecapa_model import SSL_ECAPA_TDNN
+from score import loadModel
+from predict import loadWav
 import gradio as gr
-import time

-model = …
+model = loadModel('voxsim_wavlm_ecapa.model')
 model.eval()

 def calc_voxsim(inp_path, ref_path):
-    start = time.time()
-    inp_wavs, inp_wav = loadWav(inp_path)
-    ref_wavs, ref_wav = loadWav(ref_path)
-    print("loadWav time: ", time.time() - start)
-
-    inp_wavs = torch.FloatTensor(inp_wavs)
-    inp_wav = torch.FloatTensor(inp_wav)
-    ref_wavs = torch.FloatTensor(ref_wavs)
-    ref_wav = torch.FloatTensor(ref_wav)
-    print("torch.FloatTensor time: ", time.time() - start)
+    inp_wav = loadWav(inp_path, max_frames=0)
+    ref_wav = loadWav(ref_path, max_frames=0)

     with torch.no_grad():
-        input_emb_1 = F.normalize(model.forward(inp_wavs), p=2, dim=1)
-        print("input_emb_1 time: ", time.time() - start)
-        input_emb_2 = F.normalize(model.forward(inp_wav), p=2, dim=1)
-        print("input_emb_2 time: ", time.time() - start)
-        ref_emb_1 = F.normalize(model.forward(ref_wavs), p=2, dim=1)
-        print("ref_emb_1 time: ", time.time() - start)
-        ref_emb_2 = F.normalize(model.forward(ref_wav), p=2, dim=1)
-        print("ref_emb_2 time: ", time.time() - start)
+        input_emb = F.normalize(model.forward(inp_wav), p=2, dim=1)
+        ref_emb = F.normalize(model.forward(ref_wav), p=2, dim=1)

-    score_1 = torch.mean(torch.matmul(input_emb_1, ref_emb_1.T))
-    score_2 = torch.mean(torch.matmul(input_emb_2, ref_emb_2.T))
-    score = (score_1 + score_2) / 2
-    print("score time: ", time.time() - start)
+    score = torch.matmul(input_emb, ref_emb.T)
     return score.detach().cpu().numpy()

 description = """
 Voice similarity demo using wavlm-ecapa model, which is trained on Voxsim dataset.
 This demo only accepts .wav format. Best at 16 kHz sampling rate.
+The inference process of this Spaces demo is suboptimal due to the limitations of a basic CPU. To obtain an accurate score, refer to the "[voxsim_trainer](https://github.com/kaistmm/voxsim_trainer)" repository and run the code via the CLI.

 Paper is available [here](https://arxiv.org/abs/2407.18505)
 """
@@ -46,8 +31,8 @@ Paper is available [here](https://arxiv.org/abs/2407.18505)
 iface = gr.Interface(
     fn=calc_voxsim,
     inputs=(
-        gr.Audio(label="Input Audio"),
-        gr.Audio(label="Reference Audio")
+        gr.Audio(label="Input Audio", type='filepath'),
+        gr.Audio(label="Reference Audio", type='filepath')
     ),
     outputs="text",
     title="voice similarity with VoxSim",
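For reference, the rewritten scoring path in app.py can be exercised outside Gradio. A minimal sketch, assuming this commit's modules are importable and that input.wav and reference.wav are hypothetical 16 kHz mono files:

import torch
import torch.nn.functional as F
from score import loadModel
from predict import loadWav  # re-exported from score.py in this commit

model = loadModel('voxsim_wavlm_ecapa.model')
model.eval()

# max_frames=0 returns a single (1, T) tensor holding the whole clip
inp_wav = loadWav('input.wav', max_frames=0)
ref_wav = loadWav('reference.wav', max_frames=0)

with torch.no_grad():
    inp_emb = F.normalize(model.forward(inp_wav), p=2, dim=1)
    ref_emb = F.normalize(model.forward(ref_wav), p=2, dim=1)

# Dot product of L2-normalized embeddings, i.e. cosine similarity
print(float(torch.matmul(inp_emb, ref_emb.T)))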
predict.py
CHANGED
@@ -2,10 +2,9 @@ import argparse
 import pathlib
 import tqdm
 from torch.utils.data import Dataset, DataLoader
-import librosa
-import numpy
-from score import Score
+from score import loadWav, Score
 import torch
+import os

 import warnings
 warnings.filterwarnings("ignore")
@@ -13,93 +12,61 @@ warnings.filterwarnings("ignore")

 def get_arg():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--…
-    parser.add_argument("--…
-    parser.add_argument("--…
-    parser.add_argument("--…
-    parser.add_argument("--…
-    parser.add_argument("--…
-    parser.add_argument("--…
-    parser.add_argument("--…
-    parser.add_argument("--num_workers", required=False, default=0, type=int)
+    parser.add_argument("--mode", required=True, choices=["predict_file", "predict_dir"], type=str, help="predict mode")
+    parser.add_argument("--ckpt_path", required=False, default="voxsim_wavlm_ecapa.model", type=pathlib.Path, help="path to the model checkpoint")
+    parser.add_argument("--inp_dir", required=False, default=None, type=pathlib.Path, help="input directory when predict_dir mode")
+    parser.add_argument("--ref_dir", required=False, default=None, type=pathlib.Path, help="reference directory when predict_dir mode")
+    parser.add_argument("--inp_path", required=False, default=None, type=pathlib.Path, help="input file when predict_file mode")
+    parser.add_argument("--ref_path", required=False, default=None, type=pathlib.Path, help="reference file when predict_file mode")
+    parser.add_argument("--out_path", required=True, type=pathlib.Path, help="output path")
+    parser.add_argument("--num_workers", required=False, default=4, type=int, help="number of workers for dataloader")
     return parser.parse_args()


-def loadWav(filename, max_frames: int = 400):
-
-    # Maximum audio length
-    max_audio = max_frames * 160 + 240
-
-    # Read wav file and convert to torch tensor
-    if type(filename) == tuple:
-        sr, audio = filename
-        audio = librosa.util.normalize(audio)
-        print(numpy.linalg.norm(audio))
-    else:
-        audio, sr = librosa.load(filename, sr=16000)
-    audio_org = audio.copy()
-
-    audiosize = audio.shape[0]
-
-    if audiosize <= max_audio:
-        shortage = max_audio - audiosize + 1
-        audio = numpy.pad(audio, (0, shortage), 'wrap')
-        audiosize = audio.shape[0]
-
-    startframe = numpy.linspace(0,audiosize-max_audio,num=10)
-
-    feats = []
-    for asf in startframe:
-        feats.append(audio[int(asf):int(asf)+max_audio])
-
-    feat = numpy.stack(feats,axis=0).astype(numpy.float32)
-
-    return torch.FloatTensor(feat), torch.FloatTensor(numpy.stack([audio_org],axis=0).astype(numpy.float32))
-
-
 class AudioDataset(Dataset):
     def __init__(self, inp_dir_path: pathlib.Path, ref_dir_path: pathlib.Path, max_frames: int = 400):
-        self.…
-        self.…
-        …
+        self.inp_dir_path = inp_dir_path
+        self.ref_dir_path = ref_dir_path
+        self.inp_wavlist = [file for file in os.listdir(inp_dir_path) if file.endswith(".wav")]
+        inp_wavset = set(self.inp_wavlist)
+        ref_wavset = set([file for file in os.listdir(ref_dir_path) if file.endswith(".wav")])
+        diff = inp_wavset - ref_wavset
+        if diff:
+            diff = list(diff)
+            diff.sort()
+            raise ValueError(f"Files {diff} are in inp_dir but not in ref_dir.")
         self.inp_wavlist.sort()

-        _, self.sr = librosa.load(self.inp_wavlist[0], sr=None)
         self.max_audio = max_frames * 160 + 240

     def __len__(self):
         return len(self.inp_wavlist)

     def __getitem__(self, idx):
-        inp_wavs, inp_wav = loadWav(self.inp_wavlist[idx])
-        ref_wavs, ref_wav = loadWav(self.…
+        inp_wavs, inp_wav = loadWav(os.path.join(self.inp_dir_path, self.inp_wavlist[idx]))
+        ref_wavs, ref_wav = loadWav(os.path.join(self.ref_dir_path, self.inp_wavlist[idx]))
         return inp_wavs, inp_wav, ref_wavs, ref_wav

 def main():
     args = get_arg()
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     if args.mode == "predict_file":
-        assert args.inp_path is not None
-        assert args.ref_path is not None
-        assert args.inp_dir is None
-        assert args.ref_dir is None
+        assert args.inp_path is not None, "inp_path is required when mode is predict_file."
+        assert args.ref_path is not None, "ref_path is required when mode is predict_file."
         assert args.inp_path.exists()
-        assert args.inp_path.is_file()
         assert args.ref_path.exists()
+        assert args.inp_path.is_file()
         assert args.ref_path.is_file()
         inp_wavs, inp_wav = loadWav(args.inp_path)
         ref_wavs, ref_wav = loadWav(args.ref_path)
         scorer = Score(ckpt_path=args.ckpt_path, device=device)
         score = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
-        print("…
+        print("VoxSIM score: ", score)
         with open(args.out_path, "w") as fw:
-            fw.write(str(score…
+            fw.write(str(score))
     else:
         assert args.inp_dir is not None, "inp_dir is required when mode is predict_dir."
         assert args.ref_dir is not None, "ref_dir is required when mode is predict_dir."
-        assert args.bs is not None, "bs is required when mode is predict_dir."
-        assert args.inp_path is None, "inp_path should be None"
-        assert args.ref_path is None, "ref_path should be None"
         assert args.inp_dir.exists()
         assert args.ref_dir.exists()
         assert args.inp_dir.is_dir()
@@ -107,17 +74,18 @@ def main():
         dataset = AudioDataset(args.inp_dir, args.ref_dir)
         loader = DataLoader(
             dataset,
-            batch_size=…
+            batch_size=1,
             shuffle=False,
             num_workers=args.num_workers)
         scorer = Score(ckpt_path=args.ckpt_path, device=device)
-        …
+        avg_score = []
+        with open(args.out_path, 'w') as fw:
+            for batch in tqdm.tqdm(loader):
+                inp_wavs, inp_wav, ref_wavs, ref_wav = batch
+                score = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
+                avg_score.append(score)
+                fw.write(str(score) + "\n")
+        print("Average VoxSIM score: ", sum(avg_score)/len(avg_score))
         print("save to ", args.out_path)

 if __name__ == "__main__":
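The predict_dir mode can also be driven programmatically instead of through argparse. A minimal sketch, assuming hypothetical directories inp_wavs/ and ref_wavs/ that hold identically named .wav files:

import pathlib
import torch
from torch.utils.data import DataLoader
from predict import AudioDataset
from score import Score

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = AudioDataset(pathlib.Path("inp_wavs"), pathlib.Path("ref_wavs"))
loader = DataLoader(dataset, batch_size=1, shuffle=False)

scorer = Score(ckpt_path="voxsim_wavlm_ecapa.model", device=device)
# Score.score returns a Python float (score.detach().cpu().item())
scores = [scorer.score(*batch) for batch in loader]
print("Average VoxSIM score:", sum(scores) / len(scores))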
score.py
CHANGED
@@ -1,35 +1,49 @@
 import os
+import numpy
+import librosa
 import torch
 import torch.nn.functional as F
 from ssl_ecapa_model import SSL_ECAPA_TDNN
 from huggingface_hub import hf_hub_download


-def …
-    …
+def loadWav(filename, max_frames: int = 400, num_eval: int = 10):
+
+    # Maximum audio length
+    max_audio = max_frames * 160 + 240
+
+    # Read wav file and convert to torch tensor
+    audio, sr = librosa.load(filename, sr=16000)
+    audio_org = audio.copy()
+
+    audiosize = audio.shape[0]
+
+    if audiosize <= max_audio:
+        shortage = max_audio - audiosize + 1
+        audio = numpy.pad(audio, (0, shortage), 'wrap')
+        audiosize = audio.shape[0]
+
+    startframe = numpy.linspace(0,audiosize-max_audio, num=num_eval)
+
+    feats = []
+    if max_frames == 0:
+        feats.append(audio)
+        feat = numpy.stack(feats,axis=0).astype(numpy.float32)
+        return torch.FloatTensor(feat)
+    else:
+        for asf in startframe:
+            feats.append(audio[int(asf):int(asf)+max_audio])
+        feat = numpy.stack(feats,axis=0).astype(numpy.float32)
+        return torch.FloatTensor(feat), torch.FloatTensor(numpy.stack([audio_org],axis=0).astype(numpy.float32))


-def …
-    …
+def loadModel(ckpt_path):
+    model = SSL_ECAPA_TDNN(feat_dim=1024, emb_dim=256, feat_type='wavlm_large')
     if not os.path.isfile(ckpt_path):
         print("Downloading model from Hugging Face Hub...")
-        …
-    …
-    for name, param in loaded_state.items():
-        if name.startswith('__S__.'):
-            if name[6:] in model_state:
-                model_state[name[6:]].copy_(param)
-            else:
-                print("{} is not in the model.".format(name[6:]))
-        else:
-            if name in model_state:
-                model_state[name].copy_(param)
-            else:
-                print("{} is not in the model.".format(name))
+        ckpt_path = hf_hub_download(repo_id="junseok520/voxsim-models", filename=ckpt_path, local_dir="./")
+    model.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
+    return model


 class Score:
@@ -37,7 +51,7 @@ class Score:

     def __init__(
         self,
-        ckpt_path: str = "…
+        ckpt_path: str = "voxsim_wavlm_ecapa.model",
         device: str = "gpu"):
         """
         Args:
@@ -47,43 +61,15 @@ class Score:
         """
         print(f"Using device: {device}")
         self.device = device
-        self.model = …
+        self.model = loadModel(ckpt_path).to(self.device)
         self.model.eval()

     def score(self, inp_wavs: torch.tensor, inp_wav: torch.tensor, ref_wavs: torch.tensor, ref_wav: torch.tensor) -> torch.tensor:
-        """
-        Args:
-            wavs: audio waveform to be evaluated. When len(wavs) == 1 or 2,
-                the model processes the input as a single audio clip. The model
-                performs batch processing when len(wavs) == 3.
-        """
-        # if len(wavs.shape) == 1:
-        #     out_wavs = wavs.unsqueeze(0).unsqueeze(0)
-        # elif len(wavs.shape) == 2:
-        #     out_wavs = wavs.unsqueeze(0)
-        # elif len(wavs.shape) == 3:
-        #     out_wavs = wavs
-        # else:
-        #     raise ValueError('Dimension of input tensor needs to be <= 3.')
-
-        if len(inp_wavs.shape) == 2:
-            bs = 1
-        elif len(inp_wavs.shape) == 3:
-            bs = inp_wavs.shape[0]
-        else:
-            raise ValueError('Dimension of input tensor needs to be <= 3.')

         inp_wavs = inp_wavs.reshape(-1, inp_wavs.shape[-1]).to(self.device)
         inp_wav = inp_wav.reshape(-1, inp_wav.shape[-1]).to(self.device)
         ref_wavs = ref_wavs.reshape(-1, ref_wavs.shape[-1]).to(self.device)
         ref_wav = ref_wav.reshape(-1, ref_wav.shape[-1]).to(self.device)
-
-        # assert inp_wavs.shape[1] == 10
-        # assert ref_wavs.shape[1] == 10
-        # assert inp_wav.shape[1] == 1
-        # assert ref_wav.shape[1] == 1
-
-        # import pdb; pdb.set_trace()

         with torch.no_grad():
             input_emb_1 = F.normalize(self.model.forward(inp_wavs), p=2, dim=1).detach()
@@ -92,15 +78,15 @@ class Score:
             ref_emb_2 = F.normalize(self.model.forward(ref_wav), p=2, dim=1).detach()

         emb_size = input_emb_1.shape[-1]
-        input_emb_1 = input_emb_1.reshape(…
-        input_emb_2 = input_emb_2.reshape(…
-        ref_emb_1 = ref_emb_1.reshape(…
-        ref_emb_2 = ref_emb_2.reshape(…
+        input_emb_1 = input_emb_1.reshape(-1, emb_size)
+        input_emb_2 = input_emb_2.reshape(-1, emb_size)
+        ref_emb_1 = ref_emb_1.reshape(-1, emb_size)
+        ref_emb_2 = ref_emb_2.reshape(-1, emb_size)

-        score_1 = torch.mean(torch.…
-        score_2 = torch.mean(torch.…
+        score_1 = torch.mean(torch.matmul(input_emb_1, ref_emb_1.T))
+        score_2 = torch.mean(torch.matmul(input_emb_2, ref_emb_2.T))
         score = (score_1 + score_2) / 2
-        score = score.detach().cpu().…
+        score = score.detach().cpu().item()

         return score
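The reworked loadWav has two return conventions that the rest of the commit depends on; a small sketch with a hypothetical file name:

from score import loadWav

# Default mode: ten fixed-length crops plus the full waveform, the pair
# consumed by Score.score (max_audio = 400 * 160 + 240 = 64240 samples).
crops, full = loadWav("speech.wav")            # crops: (10, 64240), full: (1, T)

# max_frames=0: a single tensor holding the whole clip, as used by app.py.
whole = loadWav("speech.wav", max_frames=0)    # (1, T)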
ssl_ecapa_model.py
CHANGED
@@ -4,14 +4,68 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchaudio.transforms as trans
+from torch.nn.utils.rnn import pad_sequence
+from s3prl.upstream.interfaces import UpstreamBase
+from s3prl.upstream.wavlm.WavLM import WavLM, WavLMConfig

-…
+
+''' WavLM UpstreamExpert without loading pretrained checkpoint
+'''
+
+
+class UpstreamExpert(UpstreamBase):
+    def __init__(self, cfg, **kwargs):
+        super().__init__(**kwargs)
+
+        self.cfg = WavLMConfig(torch.load(cfg))
+        self.model = WavLM(self.cfg)
+
+        self.model.feature_grad_mult = 0.0
+        self.model.encoder.layerdrop = 0.0
+
+        if len(self.hooks) == 0:
+            module_name = "self.model.encoder.layers"
+            for module_id in range(len(eval(module_name))):
+                self.add_hook(
+                    f"{module_name}[{module_id}]",
+                    lambda input, output: input[0].transpose(0, 1),
+                )
+            self.add_hook("self.model.encoder", lambda input, output: output[0])
+
+        self._init_layerdrop = self.model.encoder.layerdrop
+
+    @property
+    def layer_drop(self):
+        return self.model.encoder.layerdrop
+
+    def set_layer_drop(self, layerdrop: float = None):
+        if isinstance(layerdrop, float):
+            self.model.encoder.layerdrop = layerdrop
+        elif layerdrop is None:
+            self.model.encoder.layerdrop = self._init_layerdrop
+        else:
+            raise ValueError("layerdrop can only be float or None")
+
+    def get_downsample_rates(self, key: str) -> int:
+        return 320
+
+    def forward(self, wavs):
+        if self.cfg.normalize:
+            wavs = [F.layer_norm(wav, wav.shape) for wav in wavs]
+
+        device = wavs[0].device
+        wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device)
+        wav_padding_mask = ~torch.lt(
+            torch.arange(max(wav_lengths)).unsqueeze(0).to(device),
+            wav_lengths.unsqueeze(1),
+        )
+        padded_wav = pad_sequence(wavs, batch_first=True)
+
+        features, feat_padding_mask = self.model.extract_features(
+            padded_wav,
+            padding_mask=wav_padding_mask,
+            mask=False,
+        )


 ''' Res2Conv1d + BatchNorm1d + ReLU
@@ -199,7 +253,10 @@ class SSL_ECAPA_TDNN(nn.Module):
             self.feature_extract = trans.MFCC(sample_rate=sr, n_mfcc=feat_dim, log_mels=False,
                                               melkwargs=melkwargs)
         else:
-            …
+            if feat_type == "wavlm_large":
+                self.feature_extract = UpstreamExpert(cfg="wavlm_large_cfg.pt")
+            else:
+                raise NotImplementedError

         if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
             self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
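For a quick sanity check of the new wavlm_large branch, the extractor can be instantiated directly. A rough sketch under stated assumptions: s3prl is installed, wavlm_large_cfg.pt sits in the working directory, the WavLM weights stay random until a trained checkpoint is loaded (e.g. via loadModel in score.py), and the output shape is inferred from emb_dim rather than verified:

import torch
from ssl_ecapa_model import SSL_ECAPA_TDNN

# Builds the WavLM-Large front end (random weights) plus the ECAPA-TDNN head.
model = SSL_ECAPA_TDNN(feat_dim=1024, emb_dim=256, feat_type='wavlm_large')
model.eval()

wav = torch.randn(2, 16000)  # a batch of two fake one-second 16 kHz clips
with torch.no_grad():
    emb = model.forward(wav)
print(emb.shape)  # expected: torch.Size([2, 256])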
wavlm_large_cfg.pt
ADDED
Binary file (1.92 kB)