|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
|
import torch |
|
|
|
|
|
def load_target_style_feats(feats_base_path, max_num_files=1000):
    """Load the ``.pt`` feature files of a directory into one CPU tensor.

    Args:
        feats_base_path: Directory whose direct entries are scanned for files
            ending in ``.pt``.
        max_num_files: Maximum number of ``.pt`` files to load. The limit is
            applied after filtering, so unrelated directory entries do not
            consume it.

    Returns:
        The loaded objects concatenated along dim 0 and moved to the CPU.

    Raises:
        RuntimeError: If no ``.pt`` file is selected (nothing to concatenate).
    """
    # os.listdir returns entries in arbitrary order; sorting makes both the
    # selected subset and the concatenation order reproducible across runs.
    pt_names = sorted(
        name for name in os.listdir(feats_base_path) if name.endswith(".pt")
    )
    pt_names = pt_names[:max_num_files]
    if not pt_names:
        raise RuntimeError(f"No .pt feature files found in {feats_base_path!r}")

    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only point this function at feature files from a trusted source.
    feats = [
        torch.load(os.path.join(feats_base_path, name), weights_only=False)
        for name in pt_names
    ]
    return torch.concat(feats, dim=0).cpu()
|
|
|
|
|
def fast_cosine_dist(source_feats, matching_pool, device):
    """Like torch.cdist, but fixed dim=-1 and for cosine distance.

    The pairwise dot products are recovered from squared Euclidean distances
    via ``||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b`` and then normalised by
    the product of the vector norms.

    Args:
        source_feats: 2-D tensor of query vectors, one per row.
        matching_pool: 2-D tensor of candidate vectors, one per row, with the
            same feature size as ``source_feats``.
        device: Device on which the computation is carried out; both inputs
            are moved there.

    Returns:
        Tensor of shape ``(len(source_feats), len(matching_pool))`` holding
        ``1 - cosine_similarity`` for every pair. Pairs involving a zero-norm
        vector get distance 1 (similarity 0) instead of NaN.
    """
    # Move both operands so the caller does not have to pre-place the pool.
    source_feats = source_feats.to(device)
    matching_pool = matching_pool.to(device)

    source_norms = torch.norm(source_feats, p=2, dim=-1)
    matching_norms = torch.norm(matching_pool, p=2, dim=-1)

    # cdist works on batches, hence the temporary leading dimension.
    sq_euclid = torch.cdist(source_feats[None], matching_pool[None], p=2)[0] ** 2
    dotprod = -sq_euclid + source_norms[:, None] ** 2 + matching_norms[None] ** 2
    dotprod /= 2

    norm_prod = source_norms[:, None] * matching_norms[None]
    # A zero norm would make this a 0/0 division; define the similarity as 0.
    cos_sim = torch.where(
        norm_prod > 0, dotprod / norm_prod, torch.zeros_like(dotprod)
    )
    return 1 - cos_sim
|
|
|
|
|
@torch.inference_mode()
def knn_vc(source_frames, target_style_set, topk=4, weighted_average=False, device=None):
    """Replace each source frame by a blend of its nearest target-style frames.

    For every row of ``source_frames`` the ``topk`` rows of
    ``target_style_set`` with the smallest cosine distance are looked up and
    averaged.

    Args:
        source_frames: 2-D tensor of query frames, one per row.
        target_style_set: 2-D tensor of candidate frames, one per row.
        topk: Number of nearest neighbours to blend. Clamped to the number of
            frames in ``target_style_set`` when the pool is smaller.
        weighted_average: If true, neighbours are weighted by inverse cosine
            distance; otherwise a plain mean is used.
        device: Device to compute on. ``None`` selects CUDA when available,
            else the CPU.

    Returns:
        Tensor with one blended frame per source frame, on ``device``.

    Raises:
        RuntimeError: If ``target_style_set`` contains no frames.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)
    target_style_set = target_style_set.to(device)
    source_frames = source_frames.to(device)

    dists = fast_cosine_dist(source_frames, target_style_set, device=device)

    pool_size = dists.shape[-1]
    if pool_size == 0:
        raise RuntimeError("knn_vc: target_style_set contains no frames")
    # Asking topk for more neighbours than the pool holds would raise.
    k = min(topk, pool_size)
    best = dists.topk(k=k, largest=False, dim=-1)
    neighbours = target_style_set[best.indices]

    if weighted_average:
        # Rounding can push the distance of near-identical frames slightly
        # below zero; unclamped, that turns into a huge negative weight.
        weights = 1 / (best.values.clamp_min(0) + 1e-8)
        weights /= weights.sum(dim=-1, keepdim=True)
        selected_frames = (neighbours * weights[..., None]).sum(dim=1)
    else:
        selected_frames = neighbours.mean(dim=1)

    return selected_frames
|
|
|
|