import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch import einsum
from einops import rearrange
import torch.distributed as dist

from utils.commons.hparams import hparams


class ClusteringVectorQuantiser(nn.Module):
    """
    Improved vector quantiser with dynamic re-initialisation of unoptimised ("dead")
    codebook entries.
        num_embed: number of codebook entries
        embed_dim: dimensionality of each codebook entry
        beta: weight of the commitment loss
        distance: distance metric used to look up the closest code ('l2' or 'cos')
        anchor: anchor sampling method ('closest', 'random' or 'probrandom')
        first_batch: if True, run the offline version that re-initialises only on the first batch
        contras_loss: if True, add a contrastive loss to further improve codebook usage
    Inputs to forward() are expected to be channel-last, i.e. z of shape (..., embed_dim);
    see the usage sketch below the class definition.
    """
    def __init__(self, num_embed=1024, embed_dim=512, beta=0.25, distance='l2',
                 anchor='closest', first_batch=False, contras_loss=True):
        super().__init__()
        self.num_embed = num_embed
        self.embed_dim = embed_dim
        self.beta = beta
        self.distance = distance
        self.anchor = anchor
        self.first_batch = first_batch
        self.contras_loss = contras_loss
        self.decay = 0.99
        self.init = False
        self.pool = FeaturePool(self.num_embed, self.embed_dim)
        self.embedding = nn.Embedding(self.num_embed, self.embed_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.num_embed, 1.0 / self.num_embed)
        self.register_buffer("embed_prob", torch.zeros(self.num_embed))

    def forward(self, z, mask=None, temp=None, rescale_logits=False, return_logits=False):
        if mask is not None:
            assert mask.shape[:2] == z.shape[:2], (mask.shape, z.shape)
            assert mask.shape[-1] == 1, (mask.shape,)
            z = z * mask
        assert temp is None or temp == 1.0, "Only for interface compatibility with Gumbel"
        assert rescale_logits is False, "Only for interface compatibility with Gumbel"
        assert return_logits is False, "Only for interface compatibility with Gumbel"
        # the input is already channel-last, so just flatten z to (N, embed_dim)
        assert z.shape[-1] == self.embed_dim
        z_flattened = z.view(-1, self.embed_dim)
        # calculate the distance from z to every codebook entry
        if self.distance == 'l2':
            # negative squared l2 distance: -(z - e)^2 = -z^2 - e^2 + 2 * e * z,
            # so the closest code has the largest value in d
            d = - torch.sum(z_flattened.detach() ** 2, dim=1, keepdim=True) - \
                torch.sum(self.embedding.weight ** 2, dim=1) + \
                2 * torch.einsum('bd,dn->bn', z_flattened.detach(), rearrange(self.embedding.weight, 'n d -> d n'))
        elif self.distance == 'cos':
            # cosine similarity between z and the codebook entries (again, larger means closer)
            normed_z_flattened = F.normalize(z_flattened, dim=1).detach()
            normed_codebook = F.normalize(self.embedding.weight, dim=1)
            d = torch.einsum('bd,dn->bn', normed_z_flattened, rearrange(normed_codebook, 'n d -> d n'))
        # encoding: pick the closest code (largest value in d) for every flattened vector
        sort_distance, indices = d.sort(dim=1)
        encoding_indices = indices[:, -1]
        encodings = torch.zeros(encoding_indices.unsqueeze(1).shape[0], self.num_embed, device=z.device)
        encodings.scatter_(1, encoding_indices.unsqueeze(1), 1)
        # quantise and unflatten
        z_q = torch.matmul(encodings, self.embedding.weight).view(z.shape)
        # beta-weighted commitment loss plus codebook loss
        loss = self.beta * (z_q.detach() - z) ** 2 + (z_q - z.detach()) ** 2
        if mask is not None:
            loss = (loss * mask).sum() / mask.sum() / self.embed_dim
        else:
            loss = loss.mean()
        # preserve gradients with the straight-through estimator
        z_q = z + (z_q - z).detach()
        # count the average usage of each code within this batch
        avg_probs = torch.mean(encodings, dim=0)
        # online clustered re-initialisation for unoptimised points
        if self.training:
            # exponential moving average of the usage of each code entry
            self.embed_prob.mul_(self.decay).add_(avg_probs, alpha=1 - self.decay)
            # running average updates
            if self.anchor in ['closest', 'random', 'probrandom'] and (not self.init):
                # closest sampling
                if self.anchor == 'closest':
                    sort_distance, indices = d.sort(dim=0)
                    random_feat = z_flattened.detach()[indices[-1, :]]
                # feature-pool based random sampling
                elif self.anchor == 'random':
                    random_feat = self.pool.query(z_flattened.detach())
                # probability based random sampling
                elif self.anchor == 'probrandom':
                    norm_distance = F.softmax(d.t(), dim=1)
                    prob = torch.multinomial(norm_distance, num_samples=1).view(-1)
                    random_feat = z_flattened.detach()[prob]
                # decay parameter based on the average usage: rarely used codes get a decay close to 1
                decay = torch.exp(-(self.embed_prob * self.num_embed * 10) / (1 - self.decay) - 1e-3).unsqueeze(1).repeat(1, self.embed_dim)
                if hparams.get('reduce_cvq_embed') and dist.is_initialized():
                    # keep the embedding weights synchronised across all GPUs
                    dist.all_reduce(random_feat.data, op=dist.ReduceOp.SUM)
                    random_feat.data /= dist.get_world_size()
                self.embedding.weight.data = self.embedding.weight.data * (1 - decay) + random_feat * decay
                if self.first_batch:
                    self.init = True
            # contrastive loss over the sorted distances (positives are the closest features per code)
            if self.contras_loss:
                sort_distance, indices = d.sort(dim=0)
                dis_pos = sort_distance[-max(1, int(sort_distance.size(0) / self.num_embed)):, :].mean(dim=0, keepdim=True)
                dis_neg = sort_distance[:int(sort_distance.size(0) * 1 / 2), :]
                dis = torch.cat([dis_pos, dis_neg], dim=0).t() / 0.07
                contra_loss = F.cross_entropy(dis, torch.zeros((dis.size(0),), dtype=torch.long, device=dis.device))
                loss += contra_loss
        encoding_indices = encoding_indices.reshape(z.shape[:-1])
        return z_q, loss, encoding_indices

    def get_codebook_entry(self, encoding_indices):
        # look the indices up in the codebook to get the quantised latent vectors
        z_q = self.embedding(encoding_indices)
        return z_q
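

# A minimal usage sketch (added for illustration; the shapes below are assumptions based on
# the asserts in forward(), not part of the original module):
#
#   quantiser = ClusteringVectorQuantiser(num_embed=1024, embed_dim=512)
#   z = torch.randn(batch, frames, 512)               # channel-last input, last dim == embed_dim
#   mask = torch.ones(batch, frames, 1)               # optional padding mask
#   z_q, vq_loss, indices = quantiser(z, mask=mask)   # z_q: (batch, frames, 512), indices: (batch, frames)
#   z_dec = quantiser.get_codebook_entry(indices)     # look the indices back up in the codebook
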
class FeaturePool:
    """
    This class implements a feature buffer that stores previously encoded features.
    The buffer lets us initialise the codebook from a history of generated features
    rather than only the ones produced by the latest encoder.
    """
    def __init__(self, pool_size, dim=64):
        """
        Initialise the FeaturePool class.
        Parameters:
            pool_size (int) -- the size of the feature buffer
            dim (int)       -- the dimensionality of each stored feature
        """
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.nums_features = 0
            self.features = (torch.rand((pool_size, dim)) * 2 - 1) / pool_size

    def query(self, features):
        """
        Return features from the pool after mixing in the incoming batch.
        """
        self.features = self.features.to(features.device)
        if self.nums_features < self.pool_size:
            if features.size(0) > self.pool_size:  # if the batch is large enough, directly replace the whole buffer
                random_feat_id = torch.randint(0, features.size(0), (int(self.pool_size),))
                self.features = features[random_feat_id]
                self.nums_features = self.pool_size
            else:
                # if the mini-batch is not large enough, just store it for the next update
                num = self.nums_features + features.size(0)
                self.features[self.nums_features:num] = features
                self.nums_features = num
        else:
            if features.size(0) > int(self.pool_size):
                random_feat_id = torch.randint(0, features.size(0), (int(self.pool_size),))
                self.features = features[random_feat_id]
            else:
                random_id = torch.randperm(self.pool_size)
                self.features[random_id[:features.size(0)]] = features
        return self.features
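

if __name__ == "__main__":
    # Minimal smoke test (a sketch added for illustration, not part of the original training
    # code). It assumes `utils.commons.hparams` exposes a (possibly empty) dict-like `hparams`
    # and that no distributed process group is initialised, so the all_reduce branch is skipped.
    torch.manual_seed(0)
    quantiser = ClusteringVectorQuantiser(num_embed=64, embed_dim=32)
    quantiser.train()                        # exercise the online re-initialisation path
    z = torch.randn(2, 16, 32)               # (batch, frames, embed_dim), channel-last
    mask = torch.ones(2, 16, 1)              # keep every position
    z_q, vq_loss, indices = quantiser(z, mask=mask)
    z_dec = quantiser.get_codebook_entry(indices)
    print(z_q.shape, vq_loss.item(), indices.shape, z_dec.shape)
    # expected: torch.Size([2, 16, 32]) <scalar loss> torch.Size([2, 16]) torch.Size([2, 16, 32])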