File size: 2,284 Bytes
a180d8c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
# SPDX-FileCopyrightText: 2024 Idiap Research Institute
# SPDX-FileContributor: Karl El Hajal
#
# SPDX-License-Identifier: MIT
import torch
import torchaudio
from knn_tts.text_cleaners import clean_input_text
from knn_tts.tts.models import GlowTTS
from knn_tts.vc.knn import knn_vc, load_target_style_feats
from knn_tts.vocoder.models import HiFiGANWavLM
class Synthesizer:
    """End-to-end TTS pipeline: Text-to-SSL features, optional kNN voice
    conversion toward a target speaker, then a HiFi-GAN (WavLM) vocoder.

    The target-speaker feature set is cached per path so repeated calls with
    the same ``target_style_feats_path`` do not reload it from disk.
    """

    # Output sample rate of the HiFiGAN-WavLM vocoder (Hz), used when saving.
    SAMPLE_RATE = 16000

    def __init__(
        self,
        tts_model_base_path,
        tts_model_checkpoint,
        vocoder_checkpoint_path,
        model_name="glowtts",
    ):
        """
        Args:
            tts_model_base_path: Base directory of the Text-to-SSL model.
            tts_model_checkpoint: Checkpoint file of the Text-to-SSL model.
            vocoder_checkpoint_path: Checkpoint file of the HiFi-GAN vocoder.
            model_name: Identifier of the TTS architecture (default "glowtts").
        """
        self.model_name = model_name
        self.model = GlowTTS(tts_model_base_path, tts_model_checkpoint)
        # Vocoder runs on the same device as the TTS model.
        self.vocoder = HiFiGANWavLM(checkpoint_path=vocoder_checkpoint_path, device=self.model.device)
        # Per-path cache of target speaker features (see __call__).
        self.target_style_feats_path = None
        self.target_style_feats = None

    def __call__(
        self,
        text_input,
        target_style_feats_path,
        knnvc_topk=4,
        weighted_average=False,
        interpolation_rate=1.0,
        save_path=None,
        timesteps=10,
        max_target_num_files=1000,
    ):
        """Synthesize a waveform for ``text_input`` in the target style.

        Args:
            text_input: Raw input text; cleaned before synthesis.
            target_style_feats_path: Directory of target-speaker SSL features.
            knnvc_topk: Number of nearest neighbors used by kNN-VC.
            weighted_average: Whether kNN-VC averages neighbors by distance.
            interpolation_rate: Blend between converted (1.0) and plain TTS
                features (0.0); 0.0 skips voice conversion entirely.
            save_path: If given, the waveform is written there as audio.
            timesteps: Decoder timesteps (used by GradTTS only).
            max_target_num_files: Cap on target feature files to load.

        Returns:
            Waveform tensor of shape (1, 1, T).
        """
        with torch.no_grad():
            # Text-to-SSL
            text_input = clean_input_text(text_input)
            tts_feats = self.model(text_input, timesteps)  # timesteps are used for GradTTS only

            # kNN-VC
            if interpolation_rate != 0.0:
                if target_style_feats_path != self.target_style_feats_path:
                    # Load first, record the cache key second: if loading
                    # raises, the old path/feats pair stays intact instead of
                    # associating the new path with stale features.
                    self.target_style_feats = load_target_style_feats(target_style_feats_path, max_target_num_files)
                    self.target_style_feats_path = target_style_feats_path
                selected_feats = knn_vc(
                    tts_feats,
                    self.target_style_feats,
                    topk=knnvc_topk,
                    weighted_average=weighted_average,
                    device=self.model.device,
                )
                # Linear blend between converted and raw TTS features.
                converted_feats = interpolation_rate * selected_feats + (1.0 - interpolation_rate) * tts_feats
            else:
                converted_feats = tts_feats

            # Vocoder
            wav = self.vocoder(converted_feats.unsqueeze(0)).unsqueeze(0)

        if save_path is not None:
            torchaudio.save(save_path, wav, self.SAMPLE_RATE)
        return wav
|