"""This file contains the definition of the perceptual loss.""" import torch from torchvision import models from torchvision.models.feature_extraction import create_feature_extractor class PerceptualLoss(torch.nn.Module): def __init__( self, model_name: str = "resnet50", compute_perceptual_loss_on_logits: bool = True, ): """Initialize the perceptual loss. Args: model_name -> str: The name of the model to use. compute_perceptual_loss_on_logits -> bool: Whether to compute the perceptual loss on the logits or the features. """ super().__init__() if model_name == "resnet50": model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1) return_nodes = {"layer4": "features", "fc": "logits"} elif model_name == "convnext_s": model = models.convnext_small( weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1 ) return_nodes = {"features": "features", "classifier": "logits"} if compute_perceptual_loss_on_logits: self.model = model else: self.model = create_feature_extractor(model, return_nodes=return_nodes) self.compute_perceptual_loss_on_logits = compute_perceptual_loss_on_logits self.register_buffer( "mean", torch.Tensor([0.485, 0.456, 0.406])[None, :, None, None] ) self.register_buffer( "std", torch.Tensor([0.229, 0.224, 0.225])[None, :, None, None] ) for param in self.parameters(): param.requires_grad = False def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Compute the perceptual loss. Args: input -> torch.Tensor: The input tensor. target -> torch.Tensor: The target tensor. Returns: loss -> torch.Tensor: The perceptual loss. """ input = torch.nn.functional.interpolate( input, size=224, mode="bilinear", antialias=True, align_corners=False ) target = torch.nn.functional.interpolate( target, size=224, mode="bilinear", antialias=True, align_corners=False ) input = (input - self.mean) / self.std target = (target - self.mean) / self.std features_input = self.model(input) features_target = self.model(target) if self.compute_perceptual_loss_on_logits: loss = torch.nn.functional.mse_loss( features_input, features_target, reduction="mean" ) else: loss = torch.nn.functional.mse_loss( features_input["features"], features_target["features"], reduction="mean", ) loss += torch.nn.functional.mse_loss( features_input["logits"], features_target["logits"], reduction="mean" ) return loss if __name__ == "__main__": model = PerceptualLoss() input = torch.randn(2, 3, 256, 256).clamp_(0, 1) target = torch.randn(2, 3, 256, 256).clamp_(0, 1) loss = model(input, target) print(loss)