from typing import Mapping, Text, Tuple import torch import torch.nn.functional as F from .lpips import LPIPS from .perceptual_loss import PerceptualLoss from . import gan_utils def create_perception_loss( perception_loss: str, compute_on_logits: bool = True ) -> torch.nn.Module: """Creates the perception loss. Args: perception_loss -> str: The name of the perception loss. compute_on_logits -> bool: Whether to compute the loss on logits or on multiple features. Returns: perception_loss -> torch.nn.Module: The perception loss. """ if perception_loss == "lpips": return LPIPS().eval() elif perception_loss in ("resnet50", "convnext_s"): return PerceptualLoss( model_name=perception_loss, compute_perceptual_loss_on_logits=compute_on_logits, ).eval() else: raise ValueError(f"Perception loss {perception_loss} is not supported.") class VQGANLoss(torch.nn.Module): def __init__( self, discriminator_config, loss_config, ): """Initializes the VQGAN loss. Args: discriminator_config: The configuration of the discriminator. loss_config: The configuration of the loss. """ super().__init__() assert loss_config.discriminator_loss in ("hinge", "vanilla", "non-saturating") assert loss_config.reconstruction_loss in ("l2", "l1") assert loss_config.discriminator_gradient_penalty in ("none", "adopt_weight") self.discriminator = gan_utils.create_discriminator(discriminator_config) self.reconstruction_loss = loss_config.reconstruction_loss self.reconstruction_weight = loss_config.get("reconstruction_weight", 1.0) self.quantizer_weight = loss_config.quantizer_weight self.perceptual_loss = create_perception_loss( loss_config.perceptual_loss, loss_config.get("perceptual_loss_on_logits", True), ) self.perceptual_weight = loss_config.perceptual_weight self.lecam_regularization_weight = loss_config.lecam_regularization_weight self.ema_decay = loss_config.get("ema_decay", 0.999) self.entropy_annealing_steps = loss_config.get("entropy_annealing_steps", 2000) self.entropy_annealing_factor = loss_config.get("entropy_annealing_factor", 0.0) self.discriminator_iter_start = loss_config.discriminator_start if loss_config.discriminator_loss == "hinge": self.discriminator_loss = gan_utils.hinge_d_loss elif loss_config.discriminator_loss == "vanilla": self.discriminator_loss = gan_utils.vanilla_d_loss elif loss_config.discriminator_loss == "non-saturating": self.discriminator_loss = gan_utils.non_saturating_d_loss else: raise ValueError(f"Unknown GAN loss '{loss_config.discriminator_loss}'.") if loss_config.discriminator_loss == "hinge": self.generator_loss = gan_utils.hinge_g_loss elif loss_config.discriminator_loss == "vanilla": self.generator_loss = gan_utils.hinge_g_loss elif loss_config.discriminator_loss == "non-saturating": self.generator_loss = gan_utils.non_saturating_g_loss else: raise ValueError(f"Unknown GAN loss '{loss_config.discriminator_loss}'.") self.discriminator_factor = loss_config.discriminator_factor self.discriminator_weight = loss_config.discriminator_weight self.discriminator_gradient_penalty = ( "" if loss_config.discriminator_gradient_penalty == "none" else loss_config.discriminator_gradient_penalty ) self.discriminator_penalty_cost = loss_config.discriminator_penalty_cost if self.lecam_regularization_weight > 0.0: self.register_buffer("ema_real_logits_mean", torch.zeros((1))) self.register_buffer("ema_fake_logits_mean", torch.zeros((1))) def calculate_adaptive_weight( self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer ) -> torch.Tensor: """Calculates the adaptive weight for the discriminator loss. Args: nll_loss -> torch.Tensor: The NLL loss. g_loss -> torch.Tensor: The generator loss. last_layer: The last layer of the model. Returns: d_weight -> torch.Tensor: The adaptive weight for the discriminator loss. """ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() return d_weight def forward( self, inputs: torch.Tensor, reconstructions: torch.Tensor, extra_result_dict: Mapping[Text, torch.Tensor], global_step: int, last_layer, mode: str = "gen", ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: """Computes the VQGAN loss for the generator or discriminator. Args: inputs -> torch.Tensor: The input images. reconstructions -> torch.Tensor: The reconstructed images. extra_result_dict -> Mapping[Text, torch.Tensor]: The extra result dictionary. global_step -> int: The global step. last_layer: The last layer of the model. mode -> str: The mode. Must be either "gen" or "disc". Returns: loss -> torch.Tensor: The loss. loss_dict -> Mapping[Text, torch.Tensor]: The loss dictionary for logging individual losses. """ assert mode in ("gen", "disc") if mode == "gen": return self._forward_generator( inputs, reconstructions, extra_result_dict, global_step, last_layer ) elif mode == "disc": return self._forward_discriminator( inputs, reconstructions, extra_result_dict, global_step ) def should_discriminator_be_trained(self, global_step: int): """Returns if the discriminator should be trained at given step.""" return global_step >= self.discriminator_iter_start def _forward_generator( self, inputs: torch.Tensor, reconstructions: torch.Tensor, extra_result_dict: Mapping[Text, torch.Tensor], global_step: int, last_layer, ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: """Computes the VQGAN loss for the generator. Args: inputs -> torch.Tensor: The input images. reconstructions -> torch.Tensor: The reconstructed images. extra_result_dict -> Mapping[Text, torch.Tensor]: The extra result dictionary. global_step -> int: The global step. last_layer: The last layer of the model. Returns: loss -> torch.Tensor: The loss. loss_dict -> Mapping[Text, torch.Tensor]: The loss dictionary for logging individual losses. """ inputs = inputs.contiguous() reconstructions = reconstructions.contiguous() if self.reconstruction_loss == "l1": reconstruction_loss = F.l1_loss(inputs, reconstructions, reduction="mean") else: reconstruction_loss = F.mse_loss(inputs, reconstructions, reduction="mean") reconstruction_loss *= self.reconstruction_weight perceptual_loss = self.perceptual_loss(inputs, reconstructions).mean() generator_loss = torch.zeros((), device=inputs.device) extra_generator_loss = torch.zeros((), device=inputs.device) discriminator_factor = gan_utils.adopt_weight( self.discriminator_factor, global_step, threshold=self.discriminator_iter_start, ) d_weight = 1.0 if discriminator_factor > 0.0: # Disable discriminator gradients gan_utils.toggle_off_gradients(self.discriminator) logits_fake = self.discriminator(reconstructions) generator_loss = self.generator_loss(logits_fake) if self.discriminator_gradient_penalty == "adopt_weight": d_weight *= self.calculate_adaptive_weight( reconstruction_loss + self.perceptual_weight * perceptual_loss, generator_loss, last_layer=last_layer, ) d_weight *= self.discriminator_weight quantizer_loss = extra_result_dict["quantizer_loss"] if self.entropy_annealing_factor > 0.0: quantizer_loss += ( max(0.0, 1 - global_step / self.entropy_annealing_steps) * self.entropy_annealing_factor * extra_result_dict["entropy_loss"] ) total_loss = ( reconstruction_loss + self.perceptual_weight * perceptual_loss + self.quantizer_weight * quantizer_loss + d_weight * discriminator_factor * (generator_loss + extra_generator_loss) ) loss_dict = dict( total_loss=total_loss.clone().detach(), reconstruction_loss=reconstruction_loss.detach(), perceptual_loss=(self.perceptual_weight * perceptual_loss).detach(), quantizer_loss=(self.quantizer_weight * quantizer_loss).detach(), weighted_gan_loss=( d_weight * discriminator_factor * (generator_loss + extra_generator_loss) ).detach(), discriminator_factor=torch.tensor(discriminator_factor), commitment_loss=extra_result_dict["commitment_loss"].detach(), entropy_loss=extra_result_dict["entropy_loss"].detach(), per_sample_entropy=extra_result_dict["per_sample_entropy"], avg_entropy=extra_result_dict["avg_entropy"], d_weight=d_weight, gan_loss=generator_loss.detach(), ) if "codebook_loss" in extra_result_dict: loss_dict["codebook_loss"] = extra_result_dict["codebook_loss"].detach() return total_loss, loss_dict def _forward_discriminator( self, inputs: torch.Tensor, reconstructions: torch.Tensor, extra_result_dict: Mapping[Text, torch.Tensor], global_step: int, ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: """Computes the VQGAN loss for the discriminator. Args: inputs -> torch.Tensor: The input images. reconstructions -> torch.Tensor: The reconstructed images. extra_result_dict -> Mapping[Text, torch.Tensor]: The extra result dictionary. global_step -> int: The global step. Returns: loss -> torch.Tensor: The loss. loss_dict -> Mapping[Text, torch.Tensor]: The loss dictionary for logging individual losses. """ discriminator_factor = gan_utils.adopt_weight( self.discriminator_factor, global_step, threshold=self.discriminator_iter_start, ) loss_dict = {} # Turn on gradients on gan_utils.toggle_on_gradients(self.discriminator) real_images = inputs.detach().requires_grad_(True) logits_real = self.discriminator(real_images) logits_fake = self.discriminator(reconstructions.detach()) discriminator_loss = discriminator_factor * self.discriminator_loss( logits_real=logits_real, logits_fake=logits_fake ) lecam_loss = torch.zeros((), device=inputs.device) if self.lecam_regularization_weight > 0.0: lecam_loss = ( gan_utils.compute_lecam_loss( torch.mean(logits_real), torch.mean(logits_fake), self.ema_real_logits_mean, self.ema_fake_logits_mean, ) * self.lecam_regularization_weight ) self.ema_real_logits_mean = ( self.ema_real_logits_mean * self.ema_decay + torch.mean(logits_real).detach() * (1 - self.ema_decay) ) self.ema_fake_logits_mean = ( self.ema_fake_logits_mean * self.ema_decay + torch.mean(logits_fake).detach() * (1 - self.ema_decay) ) discriminator_loss += lecam_loss loss_dict = dict( discriminator_loss=discriminator_loss.detach(), logits_real=logits_real.detach().mean(), logits_fake=logits_fake.detach().mean(), lecam_loss=lecam_loss.detach(), ) return discriminator_loss, loss_dict class MLMLoss(torch.nn.Module): def __init__(self, label_smoothing: float = 0.1, sum_splits: bool = False): """Initializes the MLM loss, which is essentially a CrossEntropy loss with label smoothing. Args: label_smoothing -> float: The label smoothing factor. sum_splits -> bool: Whether to sum the loss over the splits. """ super().__init__() self.label_smoothing = label_smoothing self.criterion = torch.nn.CrossEntropyLoss(label_smoothing=self.label_smoothing) self.sum_splits = sum_splits def forward( self, inputs: torch.Tensor, targets: torch.Tensor, masks: torch.Tensor ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: """Computes the MLM loss. Args: inputs -> torch.Tensor: The input logits. targets -> torch.Tensor: The target tokens. masks -> torch.Tensor: The mask for the tokens. Returns: loss -> torch.Tensor: The loss. loss_dict -> Mapping[Text, torch.Tensor]: The loss dictionary for logging individual losses. """ b, n, m, codebook_size = inputs.shape loss = self.criterion(inputs.reshape(-1, codebook_size), targets.view(-1)) correct_tokens = ( torch.argmax(inputs.detach(), dim=-1) == targets ).float().mean() ** m masked_input = inputs[masks, :].detach() masked_loss = self.criterion(masked_input, targets[masks]) masked_correct_tokens = ( torch.argmax(masked_input, dim=-1) == targets[masks] ).float().mean() ** m if self.sum_splits: loss *= m masked_loss *= m loss_dict = { "mlm_loss": loss, "correct_tokens": correct_tokens, "masked_token_loss": masked_loss, "masked_correct_tokens": masked_correct_tokens, } return loss, loss_dict if __name__ == "__main__": loss_module = MLMLoss() batchsize = 2 codebook_dim = 4 num_codebooks = 1 logits = torch.rand((batchsize, 3, num_codebooks, codebook_dim)) targets = torch.randint(0, codebook_dim, (batchsize, 3, num_codebooks)) masks = torch.randint(0, 2, (batchsize, 3, num_codebooks), dtype=bool) loss, loss_dict = loss_module(logits, targets, masks) print(logits) print(targets) print(masks) print(loss, loss_dict)