Spaces:
Build error
Build error
import torch | |
import torch.nn.functional as F | |
import numpy as np | |
import torchvision | |
from torch import nn | |
from modules.real3d.facev2v_warp.func_utils import apply_imagenet_normalization, apply_vggface_normalization | |
def fuse_math_min_mean_pos(x):
    r"""Hinge loss term for discriminator outputs that should be scored as real.

    Evaluates ``-mean(min(x - 1, 0))``, i.e. ``mean(max(0, 1 - x))`` for
    finite ``x``.
    """
    zeros = x * 0  # zeros with the shape/dtype/device of x
    margin = x - 1
    return -torch.minimum(margin, zeros).mean()
def fuse_math_min_mean_neg(x):
    r"""Hinge loss term for discriminator outputs that should be scored as fake.

    Evaluates ``-mean(min(-x - 1, 0))``, i.e. ``mean(max(0, 1 + x))`` for
    finite ``x``.
    """
    zeros = x * 0  # zeros with the shape/dtype/device of x
    margin = -x - 1
    return -torch.minimum(margin, zeros).mean()
class _PerceptualNetwork(nn.Module):
    """Frozen feature extractor returning selected intermediate activations.

    Args:
        network: iterable module (e.g. a torchvision ``features`` stack) whose
            children are applied to the input one after another.
        layer_name_mapping: dict from child index in ``network`` to a name.
        layers: collection of names whose activations ``forward`` returns.
    """

    def __init__(self, network, layer_name_mapping, layers):
        super().__init__()
        # Keep the previous "lives on the GPU" behaviour when a GPU exists,
        # but do not make construction impossible on CPU-only machines.
        if torch.cuda.is_available():
            network = network.cuda()
        self.network = network
        self.layer_name_mapping = layer_name_mapping
        self.layers = layers
        # The backbone is fixed; gradients only flow back to the inputs.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Apply ``network`` to ``x`` and collect the requested activations.

        Returns:
            dict mapping each requested name present in ``layer_name_mapping``
            to the output of the corresponding child.
        """
        output = {}
        # Deepest child whose output is requested. Children after it cannot
        # affect the result, so they are skipped (saves compute and avoids
        # failures of late pooling layers on small inputs).
        last_index = max(
            (i for i, name in self.layer_name_mapping.items() if name in self.layers),
            default=None,
        )
        if last_index is None:
            return output
        for i, layer in enumerate(self.network):
            x = layer(x)
            layer_name = self.layer_name_mapping.get(i, None)
            if layer_name in self.layers:
                output[layer_name] = x
            if i >= last_index:
                break
        return output
def _vgg19(layers):
    """Build a frozen VGG19 feature extractor from torchvision's pretrained checkpoint.

    The checkpoint is fetched from download.pytorch.org and loaded on the CPU
    into a torchvision ``vgg19``. Only its ``features`` stack is kept.

    Args:
        layers: collection of ``relu_<block>_<index>`` names to expose.

    Returns:
        ``_PerceptualNetwork`` wrapping ``vgg19().features``.
    """
    backbone = torchvision.models.vgg19()
    checkpoint_url = "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth"
    weights = torch.utils.model_zoo.load_url(checkpoint_url, map_location=torch.device("cpu"), progress=True)
    backbone.load_state_dict(weights)
    # Child index inside ``vgg19().features`` -> name of the ReLU at that index.
    relu_names = {
        1: "relu_1_1",
        3: "relu_1_2",
        6: "relu_2_1",
        8: "relu_2_2",
        11: "relu_3_1",
        13: "relu_3_2",
        15: "relu_3_3",
        17: "relu_3_4",
        20: "relu_4_1",
        22: "relu_4_2",
        24: "relu_4_3",
        26: "relu_4_4",
        29: "relu_5_1",
    }
    return _PerceptualNetwork(backbone.features, relu_names, layers)
def _vgg_face(layers):
    """Build a frozen VGG-Face feature extractor on a torchvision VGG16 skeleton.

    The downloaded checkpoint names its parameters like ``conv1_1.weight`` or
    ``fc6.bias``. They are renamed to torchvision's ``features.<idx>.*`` /
    ``classifier.<idx>.*`` keys before a strict ``load_state_dict``.

    Args:
        layers: collection of ``relu_<block>_<index>`` names to expose.

    Returns:
        ``_PerceptualNetwork`` wrapping the loaded model's ``features`` stack.
    """
    backbone = torchvision.models.vgg16(num_classes=2622)
    checkpoint_url = "http://www.robots.ox.ac.uk/~albanie/models/pytorch-mcn/" "vgg_face_dag.pth"
    # NOTE(review): this downloads a pickled checkpoint over plain HTTP and
    # deserializes it; consider an https URL and/or a pinned checksum.
    source_weights = torch.utils.model_zoo.load_url(checkpoint_url, map_location=torch.device("cpu"), progress=True)

    conv_names = {
        0: "conv1_1",
        2: "conv1_2",
        5: "conv2_1",
        7: "conv2_2",
        10: "conv3_1",
        12: "conv3_2",
        14: "conv3_3",
        17: "conv4_1",
        19: "conv4_2",
        21: "conv4_3",
        24: "conv5_1",
        26: "conv5_2",
        28: "conv5_3",
    }
    fc_names = {0: "fc6", 3: "fc7", 6: "fc8"}
    # torchvision parameter prefix -> checkpoint parameter prefix.
    prefix_map = {f"features.{idx}": name for idx, name in conv_names.items()}
    prefix_map.update({f"classifier.{idx}": name for idx, name in fc_names.items()})
    remapped = {
        f"{dst}.{suffix}": source_weights[f"{src}.{suffix}"]
        for dst, src in prefix_map.items()
        for suffix in ("weight", "bias")
    }
    backbone.load_state_dict(remapped)

    # Child index inside ``vgg16().features`` -> name of the ReLU at that index.
    relu_names = {
        1: "relu_1_1",
        3: "relu_1_2",
        6: "relu_2_1",
        8: "relu_2_2",
        11: "relu_3_1",
        13: "relu_3_2",
        15: "relu_3_3",
        18: "relu_4_1",
        20: "relu_4_2",
        22: "relu_4_3",
        25: "relu_5_1",
    }
    return _PerceptualNetwork(backbone.features, relu_names, layers)
class PerceptualLoss(nn.Module):
    """Perceptual loss on frozen VGG-Face and multi-scale VGG19 features.

    Inputs are first resized to a 512x512 working resolution. For every
    ``layer -> weight`` entry of ``layers_weight`` the loss sums:

    * ``vggface_loss_weight * weight * L1(vggface(input), vggface(target)) / 255``,
    * ``vgg19_loss_weight * weight * L1(vgg19(input), vgg19(target))``, and
    * ``weight * L1(vgg19(input), vgg19(target))`` at each of ``n_scale``
      successively halved resolutions.

    A term containing NaN is replaced by zeros that carry no gradient. Target
    features are always detached.

    Args:
        layers_weight: mapping from layer name (e.g. ``"relu_3_1"``) to weight.
            ``None`` (default) uses relu_1_1..relu_5_1 weighted 1/32, 1/16,
            1/8, 1/4 and 1.
        n_scale: number of additional half-resolution VGG19 passes.
        vgg19_loss_weight: multiplier of the full-resolution VGG19 term.
        vggface_loss_weight: multiplier of the VGG-Face term.
    """

    def __init__(
        self,
        layers_weight=None,
        n_scale=3,
        vgg19_loss_weight=1.0,
        vggface_loss_weight=1.0,
    ):
        super().__init__()
        if layers_weight is None:
            # Built per instance; a dict literal default would be shared by every instance.
            layers_weight = {"relu_1_1": 0.03125, "relu_2_1": 0.0625, "relu_3_1": 0.125, "relu_4_1": 0.25, "relu_5_1": 1.0}
        self.vgg19 = _vgg19(layers_weight.keys())
        self.vggface = _vgg_face(layers_weight.keys())
        self.mse_criterion = nn.MSELoss()  # not used by forward below
        self.criterion = nn.L1Loss()
        self.layers_weight, self.n_scale = layers_weight, n_scale
        self.vgg19_loss_weight = vgg19_loss_weight
        self.vggface_loss_weight = vggface_loss_weight
        self.vgg19.eval()
        self.vggface.eval()

    @staticmethod
    def _resize_to_working_size(x, size=512):
        """Bilinearly (antialiased) resize ``x`` to ``size`` x ``size`` unless it already is."""
        if x.shape[-2:] == (size, size):
            return x
        assert x.ndim == 4
        return F.interpolate(x, mode="bilinear", size=(size, size), antialias=True, align_corners=False)

    @staticmethod
    def _add_unless_nan(loss, term):
        """Return ``loss + term``; if ``term`` has a NaN, add gradient-free zeros instead."""
        if torch.any(torch.isnan(term)):
            return loss + torch.zeros_like(term)
        return loss + term

    def forward(self, input, target):
        """Compute the perceptual loss.

        Args:
            input: ``[B, 3, H, W]`` images in the 0.~1. range.
            target: images with the same layout as ``input``.

        Returns:
            Scalar loss tensor (the int ``0`` if ``layers_weight`` is empty).
        """
        input = self._resize_to_working_size(input)
        target = self._resize_to_working_size(target)
        # Re-assert eval mode in case a parent ``.train()`` call flipped it.
        self.vgg19.eval()
        self.vggface.eval()

        loss = 0
        features_vggface_input = self.vggface(apply_vggface_normalization(input))
        features_vggface_target = self.vggface(apply_vggface_normalization(target))
        input = apply_imagenet_normalization(input)
        target = apply_imagenet_normalization(target)
        features_vgg19_input = self.vgg19(input)
        features_vgg19_target = self.vgg19(target)
        for layer, weight in self.layers_weight.items():
            vggface_term = self.vggface_loss_weight * weight * self.criterion(features_vggface_input[layer], features_vggface_target[layer].detach()) / 255
            loss = self._add_unless_nan(loss, vggface_term)
            vgg19_term = self.vgg19_loss_weight * weight * self.criterion(features_vgg19_input[layer], features_vgg19_target[layer].detach())
            loss = self._add_unless_nan(loss, vgg19_term)

        # Multi-scale VGG19 term. Each scale is downsampled ONCE and its features
        # reused for every layer; previously this loop was nested in the layer
        # loop, so the images kept shrinking per layer until VGG19 failed.
        # As before, these terms are weighted by ``weight`` only.
        for _ in range(self.n_scale):
            input = F.interpolate(input, mode="bilinear", scale_factor=0.5, align_corners=False, recompute_scale_factor=True)
            target = F.interpolate(target, mode="bilinear", scale_factor=0.5, align_corners=False, recompute_scale_factor=True)
            features_vgg19_input = self.vgg19(input)
            features_vgg19_target = self.vgg19(target)
            for layer, weight in self.layers_weight.items():
                scale_term = weight * self.criterion(features_vgg19_input[layer], features_vgg19_target[layer].detach())
                loss = self._add_unless_nan(loss, scale_term)
        return loss
class GANLoss(nn.Module):
    """Hinge GAN loss.

    Usage:
        Update generator: ``gan_loss(fake_output, True, False)`` + other losses.
        Update discriminator: ``gan_loss(fake_output.detach(), False, True)``
        + ``gan_loss(real_output, True, True)``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, dis_output, t_real, dis_update=True):
        r"""GAN loss computation.

        Args:
            dis_output (tensor or list/tuple of tensors): Discriminator outputs.
                For a list/tuple (e.g. a multi-scale discriminator) the loss of
                each entry is computed separately and the results are averaged.
            t_real (bool): If ``True``, uses the real label as target, otherwise
                uses the fake label as target. Only used when ``dis_update``.
            dis_update (bool): If ``True``, the loss will be used to update the
                discriminator, otherwise the generator.

        Returns:
            loss (tensor): Loss value.

        Raises:
            ValueError: if ``dis_output`` is an empty list/tuple.
        """
        if isinstance(dis_output, (list, tuple)):
            if not dis_output:
                raise ValueError("dis_output must contain at least one tensor")
            # Each per-output loss is already a mean, so averaging gives every
            # discriminator output equal weight regardless of its size.
            total = sum(self.forward(out, t_real, dis_update) for out in dis_output)
            return total / len(dis_output)
        if dis_update:
            if t_real:
                return fuse_math_min_mean_pos(dis_output)
            return fuse_math_min_mean_neg(dis_output)
        # Generator update: push the discriminator score up.
        return -torch.mean(dis_output)
class FeatureMatchingLoss(nn.Module):
    """L1 feature-matching loss across one or more discriminators.

    ``fake_features[i][j]`` is compared with ``real_features[i][j]``, with the
    real side detached. Every term from discriminator ``i`` is scaled by
    ``1 / len(fake_features)``, and all terms are summed.
    """

    def __init__(self):
        super().__init__()
        self.criterion = nn.L1Loss()

    def forward(self, fake_features, real_features):
        scale = 1.0 / len(fake_features)
        # Zero-valued accumulator with the dtype/device of the first fake feature.
        total = fake_features[0][0].new_tensor(0)
        for d_idx, fake_per_d in enumerate(fake_features):
            for f_idx, fake_feat in enumerate(fake_per_d):
                real_feat = real_features[d_idx][f_idx].detach()
                total += scale * self.criterion(fake_feat, real_feat)
        return total
class EquivarianceLoss(nn.Module):
    """L1 distance between the first two coordinates of ``kp_d`` and ``reverse_kp``.

    ``kp_d`` is indexed as ``[:, :, :2]``; presumably (batch, num_kp, 3)
    keypoints, with ``reverse_kp`` matching (batch, num_kp, 2) — TODO confirm.
    """

    def __init__(self):
        super().__init__()
        self.criterion = nn.L1Loss()

    def forward(self, kp_d, reverse_kp):
        planar = kp_d[:, :, :2]
        return self.criterion(planar, reverse_kp)
class KeypointPriorLoss(nn.Module):
    """Keypoint prior: spread keypoints apart and pull their mean depth to ``zt``.

    The separation term sums ``max(0, Dt - ||k_i - k_j||^2)`` over all ordered
    keypoint pairs, then averages over the batch. The ``i == j`` pairs have
    zero distance and each contribute ``Dt``; ``num_kp * Dt`` is subtracted to
    cancel them. The depth term is the mean absolute deviation of each sample's
    mean coordinate 2 from ``zt``.

    Args:
        Dt: squared-distance margin below which keypoint pairs are penalised.
        zt: target mean value of coordinate 2.
    """

    def __init__(self, Dt=0.1, zt=0.33):
        super().__init__()
        self.Dt, self.zt = Dt, zt

    def forward(self, kp_d):
        num_kp = kp_d.size(1)
        # Full pairwise distance matrix instead of an explicit double loop.
        sq_dist = torch.cdist(kp_d, kp_d).square()
        separation = torch.maximum(sq_dist * 0, self.Dt - sq_dist).sum(dim=(1, 2)).mean()
        depth = (kp_d[:, :, 2].mean(dim=1) - self.zt).abs().mean()
        return separation + depth - num_kp * self.Dt
class HeadPoseLoss(nn.Module):
    """Mean L1 error over yaw, pitch and roll, scaled by ``180 / pi``.

    The reference angles are detached. The ``180 / pi`` factor converts the
    error to degrees, assuming the angles are in radians (TODO confirm against
    the pose estimator).
    """

    def __init__(self):
        super().__init__()
        self.criterion = nn.L1Loss()

    def forward(self, yaw, pitch, roll, real_yaw, real_pitch, real_roll):
        pairs = ((yaw, real_yaw), (pitch, real_pitch), (roll, real_roll))
        mean_err = sum(self.criterion(pred, ref.detach()) for pred, ref in pairs) / 3
        return mean_err / np.pi * 180
class DeformationPriorLoss(nn.Module):
    """Mean absolute value of ``delta_d`` (an L1 penalty on the deformation)."""

    def __init__(self):
        super().__init__()

    def forward(self, delta_d):
        return torch.mean(torch.abs(delta_d))
if __name__ == '__main__':
    # Smoke test: evaluate the perceptual loss once on random images.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loss_fn = PerceptualLoss().to(device)
    # forward() documents inputs in the 0~1 range, so sample uniformly there.
    x1 = torch.rand([4, 3, 512, 512], device=device)
    x2 = torch.rand([4, 3, 512, 512], device=device)
    with torch.no_grad():
        loss = loss_fn(x1, x2)
    print(f"perceptual loss: {loss.item():.6f}")