kNN-TTS / knn_tts /vc /knn.py
Karl El Hajal
Add kNN-TTS code and demo interface
a180d8c
raw
history blame contribute delete
1.99 kB
# SPDX-FileCopyrightText: 2023 MediaLab, Department of Electrical & Electronic Engineering, Stellenbosch University
# SPDX-FileContributor: Karl El Hajal
#
# SPDX-License-Identifier: MIT
import os
import torch
def load_target_style_feats(feats_base_path, max_num_files=1000):
feats = []
for filepath in os.listdir(feats_base_path)[:max_num_files]:
if ".pt" in filepath:
filepath = os.path.join(feats_base_path, filepath)
feats.append(torch.load(filepath, weights_only=False))
feats = torch.concat(feats, dim=0).cpu()
return feats
def fast_cosine_dist(source_feats, matching_pool, device):
"""Like torch.cdist, but fixed dim=-1 and for cosine distance."""
source_norms = torch.norm(source_feats, p=2, dim=-1).to(device)
matching_norms = torch.norm(matching_pool, p=2, dim=-1)
dotprod = -(torch.cdist(source_feats[None].to(device), matching_pool[None], p=2)[0] ** 2) + source_norms[:, None] ** 2 + matching_norms[None] ** 2
dotprod /= 2
dists = 1 - (dotprod / (source_norms[:, None] * matching_norms[None]))
return dists
@torch.inference_mode()
def knn_vc(source_frames, target_style_set, topk=4, weighted_average=False, device=None):
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
device = torch.device(device)
target_style_set = target_style_set.to(device)
source_frames = source_frames.to(device)
dists = fast_cosine_dist(source_frames, target_style_set, device=device)
best = dists.topk(k=topk, largest=False, dim=-1)
if weighted_average:
weights = 1 / (best.values + 1e-8) # Adding a small value to avoid division by zero
weights /= weights.sum(dim=-1, keepdim=True) # Normalize weights
selected_frames = (target_style_set[best.indices] * weights[..., None]).sum(dim=1)
else:
selected_frames = target_style_set[best.indices].mean(dim=1)
return selected_frames