""" | |
Lookup Free Quantization | |
Proposed in https://arxiv.org/abs/2310.05737 | |
basically a 2-level FSQ (Finite Scalar Quantization) with entropy loss | |
https://arxiv.org/abs/2309.15505 | |
""" | |
import torch | |
from einops import rearrange | |
from torch.nn import Module | |
# entropy | |
def binary_entropy(prob): | |
return -prob * log(prob) - (1 - prob) * log(1 - prob) | |
# tensor helpers | |
def log(t, eps=1e-20): | |
return t.clamp(min=eps).log() | |
# convert to bit representations and back

def decimal_to_bits(x: torch.LongTensor, bits: int) -> torch.FloatTensor:
    # [b, ...] {0, 1, ..., 2 ** bits - 1} -> [b, ..., d] {-1, 1}
    mask = 2 ** torch.arange(bits).to(x)                 # [d]
    bit_repr = ((x.unsqueeze(-1) & mask) != 0).float()   # [b, ..., d] {0, 1}
    return bit_repr * 2 - 1                              # {0, 1} -> {-1, 1}

def bits_to_decimal(x: torch.FloatTensor) -> torch.LongTensor:
    # [b, ..., d] {-1, 1} -> [b, ...] {0, 1, ..., 2 ** d - 1}
    x = (x > 0).long()                                   # {-1, 1} -> {0, 1}, [b, ..., d]
    mask = 2 ** torch.arange(x.size(-1)).to(x)           # [d]
    dec = (x * mask).sum(-1)                             # [b, ...]
    return dec
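
# Worked example (illustrative, not part of the original module): the encoding is
# least-significant-bit first, since mask = [1, 2, 4, ...]. For instance,
# decimal_to_bits(torch.tensor(5), bits=3) gives [1., -1., 1.] (5 = 0b101),
# and bits_to_decimal maps it back to 5.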
# class

class LFQY(Module):
    def __init__(self, dim, entropy_loss_weight=0.1, diversity_gamma=1.0):
        super().__init__()
        self.dim = dim  # number of bits per code, i.e. log2(codebook size)
        self.diversity_gamma = diversity_gamma
        self.entropy_loss_weight = entropy_loss_weight

    def indices_to_codes(self, indices):
        codes = decimal_to_bits(indices, self.dim)
        # codes = rearrange(codes, 'b ... d -> b d ...')
        return codes
    def forward(self, x, mask=None, inv_temperature=1.):
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension, which is also log2(codebook size)

        note: `mask` is accepted but currently unused
        """
        # x = rearrange(x, 'b d ... -> b ... d')
        assert x.shape[-1] == self.dim

        z = torch.tanh(x / inv_temperature)  # (-1, 1)

        # quantize by eq 3. (torch.where instead of torch.sign so that x == 0 maps to -1, not 0)
        quantized = torch.where(x > 0, torch.ones_like(x), -torch.ones_like(x))  # {-1, 1}
        z = z + (quantized - z).detach()  # straight-through estimator

        # calculate indices
        indices = bits_to_decimal(z)
        # entropy aux loss
        if self.training:
            prob = torch.sigmoid(x / inv_temperature)  # [b, ..., d]

            bit_entropy = binary_entropy(prob).sum(-1).mean()
            # E[H(q)] = avg(sum(H(q_i)))

            avg_prob = prob.flatten(0, -2).mean(0)  # [b, ..., d] -> [n, d] -> [d]
            codebook_entropy = binary_entropy(avg_prob).sum()
            # H(E[q]) = sum(H(avg(q_i)))

            """
            1. the per-bit entropy is nudged to be low,
               so each scalar commits to one latent binary bit or the other.
            2. the codebook entropy is nudged to be high,
               to encourage all codes to be used uniformly.
            """
            entropy_aux_loss = bit_entropy - self.diversity_gamma * codebook_entropy
        else:
            # if not training, just return a dummy scalar 0
            entropy_aux_loss = torch.zeros(()).to(z)

        entropy_aux_loss = entropy_aux_loss * self.entropy_loss_weight

        # reconstitute image or video dimensions
        # z = rearrange(z, 'b ... d -> b d ...')

        # indices are the decimal form of the quantized bits
        return z, entropy_aux_loss, indices
    def get_codebook_entry(self, encoding_indices):
        return self.indices_to_codes(encoding_indices)
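
# Minimal usage sketch (illustrative only; the names and shapes below are arbitrary
# assumptions, not part of the original module): quantize features whose last
# dimension equals `dim`, then recover the same codes from the returned indices.
if __name__ == "__main__":
    quantizer = LFQY(dim=8)       # codebook size is 2 ** 8 = 256
    x = torch.randn(2, 16, 8)     # [batch, sequence, dim]

    z, entropy_aux_loss, indices = quantizer(x)
    print(z.shape, indices.shape, entropy_aux_loss.item())  # [2, 16, 8], [2, 16], scalar

    # indices round-trip back to (numerically) the same {-1, 1} codes
    codes = quantizer.get_codebook_entry(indices)
    assert torch.allclose(codes, z)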