kNN-TTS / knn_tts /ssl /models.py
Karl El Hajal
Add kNN-TTS code and demo interface
a180d8c
raw
history blame contribute delete
1.92 kB
# SPDX-FileCopyrightText: 2024 Idiap Research Institute
# SPDX-FileContributor: Karl El Hajal
#
# SPDX-License-Identifier: MIT
import torch
import torchaudio
from knn_tts.ssl.WavLM.WavLM import WavLMConfig, WavLMModel
class WavLM:
    """
    Wrapper around the WavLM-Large model for frame-level feature extraction.

    Attributes set by the constructor:
        device: ``torch.device`` the model is placed on.
        model: ``WavLMModel`` in eval mode, tagged with ``sample_rate = 16000``.
        SPEAKER_INFORMATION_LAYER: layer index (6) used by
            ``extract_framewise_features`` when no layer is requested.
    """

    def __init__(self, device=None):
        """
        Load the WavLM-Large model (see: https://github.com/microsoft/unilm/tree/master/wavlm).
        Torch hub checkpoint from https://github.com/bshall/knn-vc

        Args:
            device: Anything accepted by ``torch.device``. When None, CUDA is
                used if it is available, otherwise the CPU.
        """
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        # The checkpoint is a dict holding the model config under "cfg" and the
        # weights under "model"; torch.hub fetches it from the URL below.
        checkpoint_url = "https://github.com/bshall/knn-vc/releases/download/v0.1/WavLM-Large.pt"
        checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=self.device, progress=True)

        cfg = WavLMConfig(checkpoint["cfg"])
        self.model = WavLMModel(cfg)
        self.model.load_state_dict(checkpoint["model"])
        self.model = self.model.to(self.device)
        self.model.eval()

        # Input audio is resampled to this rate before feature extraction.
        self.model.sample_rate = 16000
        # Default layer whose outputs are returned by extract_framewise_features.
        self.SPEAKER_INFORMATION_LAYER = 6

        num_parameters = sum(p.numel() for p in self.model.parameters())
        print(f"WavLM-Large loaded with {num_parameters:,d} parameters.")

    @torch.no_grad()
    def extract_framewise_features(self, wav_path, output_layer=None) -> torch.Tensor:
        """
        Extract frame-level WavLM features from an audio file.

        The audio is loaded with torchaudio, downmixed to a single channel if
        it has several, resampled to ``self.model.sample_rate`` when needed and
        passed through the model without gradient tracking.

        Args:
            wav_path: Path (or other source accepted by ``torchaudio.load``)
                of the audio to process.
            output_layer: Layer index forwarded to the model's
                ``extract_features``. Defaults to
                ``self.SPEAKER_INFORMATION_LAYER`` when None.

        Returns:
            Tensor of per-frame embeddings, ``[frame_len, 1024]``.
        """
        audio, orig_sr = torchaudio.load(wav_path)
        audio = audio.to(self.device)

        # torchaudio.load yields [channels, samples]. Only a leading dimension
        # of size 1 is squeezed away below, so a multi-channel file would not
        # produce the documented 2-D output. Average the channels into one.
        if audio.dim() == 2 and audio.size(0) > 1:
            audio = audio.mean(dim=0, keepdim=True)

        if orig_sr != self.model.sample_rate:
            audio = torchaudio.functional.resample(audio, orig_freq=orig_sr, new_freq=self.model.sample_rate)

        if output_layer is None:
            output_layer = self.SPEAKER_INFORMATION_LAYER

        # extract_features returns a tuple; its first element holds the features.
        embeddings = self.model.extract_features(audio, output_layer=output_layer, ret_layer_results=False)[0]
        embeddings = embeddings.squeeze(0)
        return embeddings  # [frame_len, 1024]