Spaces:
Build error
Build error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.nn.utils.spectral_norm as spectral_norm | |
from modules.eg3ds.models.networks_stylegan2 import SynthesisBlock | |
from modules.eg3ds.models.superresolution import SynthesisBlockNoUp | |
from modules.eg3ds.models.superresolution import SuperresolutionHybrid8XDC | |
from modules.real3d.facev2v_warp.model import WarpBasedTorsoModelMediaPipe as torso_model_v1 | |
from modules.real3d.facev2v_warp.model2 import WarpBasedTorsoModelMediaPipe as torso_model_v2 | |
from utils.commons.hparams import hparams | |
from utils.commons.image_utils import dilate, erode | |
class SuperresolutionHybrid8XDC_Warp(SuperresolutionHybrid8XDC):
    """Super-resolution module that also fuses a warped torso and a background.

    On top of the parent's ``block0`` / ``block1`` this class owns a warp-based
    torso model, 256-channel encoders for the torso hidden features and the
    reference background image, and the conv stacks used to fuse head, torso
    and background before ``block1``.

    Behaviour is driven by the global ``hparams``:

    * ``torso_model_version`` (default ``'v1'``): which torso model is built and
      how it is called (``'v2'`` additionally receives the head weights).
    * ``weight_fuse`` (default ``True``): blend with the head weight map
      (512-channel fusion input) or plainly concatenate head/torso/background
      features (768-channel fusion input).
    * ``htbsr_head_weight_fuse_mode``: ``'v1'``, ``'v2'`` or ``'v3'``; read in
      ``forward`` when ``weight_fuse`` is enabled.
    * ``htbsr_head_threshold``: head weights above this value are treated as
      fully opaque when compositing over the background.
    """

    def __init__(self, channels, img_resolution, sr_num_fp16_res, sr_antialias, **block_kwargs):
        """Build the parent SR blocks plus the torso/background fusion layers.

        All positional arguments and ``block_kwargs`` are forwarded unchanged to
        the parent constructor; ``block_kwargs`` is also forwarded to the
        ``SynthesisBlockNoUp`` used for head/torso fusion.

        Raises:
            NotImplementedError: if ``hparams['torso_model_version']`` is
                neither ``'v1'`` nor ``'v2'``.
        """
        super().__init__(channels, img_resolution, sr_num_fp16_res, sr_antialias, **block_kwargs)

        torso_model_version = hparams.get("torso_model_version", "v1")
        if torso_model_version == 'v1':
            self.torso_model = torso_model_v1('standard')
        elif torso_model_version == 'v2':
            self.torso_model = torso_model_v2('standard')
        else:
            raise NotImplementedError(f"Unknown torso_model_version: {torso_model_version!r}")

        # 1x1 conv lifting the 64-channel deformed torso hidden map to 256 channels.
        self.torso_encoder = nn.Sequential(
            nn.Conv2d(64, 256, 1, 1, padding=0),
        )
        # Encodes the 3-channel reference background image into 256 channels.
        self.bg_encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(64, 256, 3, 1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(256, 256, 3, 1, padding=1),
        )

        if hparams.get("weight_fuse", True):
            # Every weight-fuse mode concatenates person and background
            # features, i.e. 256 + 256 input channels.
            fuse_in_dim = 512
            # Input: head rgb (3) + head weights (1) + torso rgb (3); used by mode 'v3'.
            self.head_torso_alpha_predictor = nn.Sequential(
                nn.Conv2d(3 + 1 + 3, 32, 3, 1, padding=1),
                nn.LeakyReLU(),
                nn.Conv2d(32, 32, 3, 1, padding=1),
                nn.LeakyReLU(),
                nn.Conv2d(32, 1, 3, 1, padding=1),
                nn.Sigmoid(),
            )
            # Fuses the concatenated head/torso features; used by modes 'v2' and 'v3'.
            self.fuse_head_torso_convs = nn.Sequential(
                nn.Conv2d(256 + 256, 256, 3, 1, padding=1),
                nn.LeakyReLU(),
                nn.Conv2d(256, 256, 3, 1, padding=1),
            )
            self.head_torso_block = SynthesisBlockNoUp(
                256, 256, w_dim=512, resolution=256,
                img_channels=3, is_last=False, use_fp16=False, conv_clamp=None, **block_kwargs)
        else:
            # Plain concatenation of head, torso and background features.
            fuse_in_dim = 768

        self.fuse_fg_bg_convs = nn.Sequential(
            nn.Conv2d(fuse_in_dim, 64, 1, 1, padding=0),
            nn.LeakyReLU(),
            nn.Conv2d(64, 256, 3, 1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(256, 256, 3, 1, padding=1),
        )

    def _resize_square(self, img, size):
        """Bilinearly resize ``img`` to ``size`` x ``size`` using ``self.sr_antialias``."""
        return torch.nn.functional.interpolate(
            img, size=(size, size), mode='bilinear', align_corners=False, antialias=self.sr_antialias)

    def _prepare_inputs(self, rgb, x, ref_torso_rgb, ref_bg_rgb, weights_img):
        """Bring all image-like inputs to the resolutions used by the fusion code.

        ``x`` and ``rgb`` are resized to ``self.input_resolution`` when ``x`` is
        not already at that size; 256x256 copies of ``rgb``, the (detached)
        head weights and both reference images are produced for the torso
        model and for compositing.

        Returns:
            ``(x, rgb, rgb_256, weights_256, ref_torso_rgb_256, ref_bg_rgb_256)``.
        """
        if x.shape[-1] != self.input_resolution:
            x = self._resize_square(x, self.input_resolution)
            rgb = self._resize_square(rgb, self.input_resolution)
        rgb_256 = self._resize_square(rgb, 256)
        weights_256 = self._resize_square(weights_img.detach(), 256)
        ref_torso_rgb_256 = self._resize_square(ref_torso_rgb, 256)
        ref_bg_rgb_256 = self._resize_square(ref_bg_rgb, 256)
        return x, rgb, rgb_256, weights_256, ref_torso_rgb_256, ref_bg_rgb_256

    def _composite_over_background(self, rgb, x, ws, x_bg, ref_bg_rgb_256, head_alpha,
                                   torso_occlusion, head_threshold, **block_kwargs):
        """Composite the person (head + torso) over the background and run ``block1``.

        Head alpha values above ``head_threshold`` are snapped to 1, added to the
        torso occlusion (resized to 256x256) and clamped to [0, 1]. The resulting
        person mask blends ``rgb`` with the reference background and gates the
        person/background features before they are concatenated and fused.

        Returns:
            The ``(x, rgb)`` pair produced by ``self.block1``.
        """
        head_occlusion = head_alpha.clone()  # clone so the caller's alpha is not modified in place
        head_occlusion[head_occlusion > head_threshold] = 1.
        torso_occlusion = self._resize_square(torso_occlusion, 256)
        person_occlusion = (torso_occlusion + head_occlusion).clamp_(0, 1)
        rgb = rgb * person_occlusion + ref_bg_rgb_256 * (1 - person_occlusion)
        # TODO: revisit this -- drop the occlusion from the concat, or truncate the occlusion?
        x = torch.cat([x * person_occlusion, x_bg * (1 - person_occlusion)], dim=1)
        x = self.fuse_fg_bg_convs(x)
        return self.block1(x, rgb, ws, **block_kwargs)

    def forward(self, rgb, x, ws, ref_torso_rgb, ref_bg_rgb, weights_img, segmap, kp_s, kp_d, target_torso_mask=None, **block_kwargs):
        """Super-resolve the head and fuse it with the warped torso and the background.

        Args:
            rgb: head image; presumably (B, 3, H, W) -- confirm against caller.
            x: head feature map matching ``rgb`` spatially.
            ws: latent codes; only the last one along dim 1 is used, expanded to
                ``rgb.shape[0]`` x 3 entries.
            ref_torso_rgb: reference torso image, resized to 256x256 internally.
            ref_bg_rgb: reference background image, resized to 256x256 internally.
            weights_img: head weight (alpha) map; detached before use.
            segmap, kp_s, kp_d: forwarded unchanged to the torso model.
            target_torso_mask: optional, forwarded to the torso model.
            **block_kwargs: forwarded to the synthesis blocks.

        Returns:
            ``(rgb, facev2v_ret)``: the final image from ``block1`` and the dict
            returned by the torso model.

        Raises:
            NotImplementedError: for an unknown ``torso_model_version`` or
                ``htbsr_head_weight_fuse_mode``.
        """
        ws = ws[:, -1:, :].expand([rgb.shape[0], 3, -1])
        x, rgb, rgb_256, weights_256, ref_torso_rgb_256, ref_bg_rgb_256 = self._prepare_inputs(
            rgb, x, ref_torso_rgb, ref_bg_rgb, weights_img)

        # SR branch: 128x128 head img ==> 256x256 head img.
        x, rgb = self.block0(x, rgb, ws, **block_kwargs)

        torso_model_version = hparams.get("torso_model_version", "v1")
        if torso_model_version == 'v1':
            rgb_torso, facev2v_ret = self.torso_model.forward(
                ref_torso_rgb_256, segmap, kp_s, kp_d, rgb_256.detach(),
                cal_loss=True, target_torso_mask=target_torso_mask)
        elif torso_model_version == 'v2':
            rgb_torso, facev2v_ret = self.torso_model.forward(
                ref_torso_rgb_256, segmap, kp_s, kp_d, rgb_256.detach(), weights_256.detach(),
                cal_loss=True, target_torso_mask=target_torso_mask)
        else:
            # Previously this fell through and failed later on an unbound `rgb_torso`.
            raise NotImplementedError(f"Unknown torso_model_version: {torso_model_version!r}")

        x_torso = self.torso_encoder(facev2v_ret['deformed_torso_hid'])
        x_bg = self.bg_encoder(ref_bg_rgb_256)

        if not hparams.get("weight_fuse", True):
            # No alpha blending: let the convs fuse the concatenated features.
            x = torch.cat([x, x_torso, x_bg], dim=1)
            x = self.fuse_fg_bg_convs(x)
            x, rgb = self.block1(x, None, ws, **block_kwargs)
            return rgb, facev2v_ret

        fuse_mode = hparams['htbsr_head_weight_fuse_mode']
        head_threshold = hparams['htbsr_head_threshold']

        if fuse_mode == 'v1':
            # Direct alpha blending of both the image and the features.
            head_torso_alpha = weights_256
            rgb = rgb * head_torso_alpha + rgb_torso * (1 - head_torso_alpha)  # person image
            x = x * head_torso_alpha + x_torso * (1 - head_torso_alpha)
        elif fuse_mode == 'v2':
            # Fuse the head/torso features via alpha-weighted concatenation;
            # this replaces the earlier direct alpha addition.
            # Now that the weights are encouraged to approach the mask, the
            # head threshold no longer has to be hand-tuned to a very small
            # value; 0.7 is enough.
            head_torso_alpha = weights_256
            rgb = rgb * head_torso_alpha + rgb_torso * (1 - head_torso_alpha)  # person image
            x = torch.cat([x * head_torso_alpha, x_torso * (1 - head_torso_alpha)], dim=1)
            x = self.fuse_head_torso_convs(x)
            x, rgb = self.head_torso_block(x, rgb, ws, **block_kwargs)
        elif fuse_mode == 'v3':
            # v2: alpha-weighted concatenation of head/torso features.
            # v3: additionally post-process the head mask with a small network.
            alpha_inp = torch.cat(
                [rgb.clamp(-1, 1) / 2 + 0.5, weights_256, rgb_torso.clamp(-1, 1) / 2 + 0.5], dim=1)
            head_torso_alpha = self.head_torso_alpha_predictor(alpha_inp).clone()
            # The predicted alpha may never exceed the rendered head weights.
            exceeds = head_torso_alpha > weights_256
            head_torso_alpha[exceeds] = weights_256[exceeds]
            rgb = rgb * head_torso_alpha + rgb_torso * (1 - head_torso_alpha)  # person image
            x = torch.cat([x * head_torso_alpha, x_torso * (1 - head_torso_alpha)], dim=1)
            x = self.fuse_head_torso_convs(x)
            x, rgb = self.head_torso_block(x, rgb, ws, **block_kwargs)
            if not self.training:
                # Raise the threshold to the 5% quantile of the alphas above
                # 0.05 when that is larger than the configured value.
                candidates = head_torso_alpha[head_torso_alpha > 0.05]
                # torch.quantile raises on an empty tensor (e.g. no visible
                # head), so keep the configured threshold in that case.
                if candidates.numel() > 0:
                    head_threshold = max(candidates.quantile(0.05), head_threshold)
        else:
            # v4 tried fusing the head/torso hidden features with a plain
            # concat; it did not work well.
            # v5 try1 also removed alpha from the concat when handling x, but
            # that made the occlusion collapse to 1, so it was dropped.
            # v5 try2 also thresholded the torso for a sharper rgb; this may
            # cause a black border around the torso.
            raise NotImplementedError(f"Unknown htbsr_head_weight_fuse_mode: {fuse_mode!r}")

        x, rgb = self._composite_over_background(
            rgb, x, ws, x_bg, ref_bg_rgb_256, head_torso_alpha,
            facev2v_ret['occlusion_2'], head_threshold, **block_kwargs)
        return rgb, facev2v_ret

    def infer_forward_stage1(self, rgb, x, ws, ref_torso_rgb, ref_bg_rgb, weights_img, segmap, kp_s, kp_d, **block_kwargs):
        """Run the first inference stage and cache what stage 2 needs.

        Runs ``block0`` and the torso model's ``infer_forward_stage1``, then
        stores ``ref_bg_rgb_256``, ``weights_256``, ``x``, ``ws`` and ``rgb`` in
        the dict returned by the torso model.

        Returns:
            The torso model's result dict, extended with the keys listed above.
        """
        ws = ws[:, -1:, :].repeat(1, 3, 1)
        x, rgb, rgb_256, weights_256, ref_torso_rgb_256, ref_bg_rgb_256 = self._prepare_inputs(
            rgb, x, ref_torso_rgb, ref_bg_rgb, weights_img)
        x, rgb = self.block0(x, rgb, ws, **block_kwargs)
        facev2v_ret = self.torso_model.infer_forward_stage1(
            ref_torso_rgb_256, segmap, kp_s, kp_d, rgb_256.detach(), cal_loss=True)
        facev2v_ret['ref_bg_rgb_256'] = ref_bg_rgb_256
        facev2v_ret['weights_256'] = weights_256
        facev2v_ret['x'] = x
        facev2v_ret['ws'] = ws
        facev2v_ret['rgb'] = rgb
        return facev2v_ret

    def infer_forward_stage2(self, facev2v_ret, **block_kwargs):
        """Finish inference from the dict produced by ``infer_forward_stage1``.

        NOTE(review): with ``weight_fuse`` enabled this path always applies the
        'v1'-style direct alpha blend with a fixed 0.5 head threshold and does
        not read ``htbsr_head_weight_fuse_mode`` / ``htbsr_head_threshold``, so
        it differs from ``forward`` for modes 'v2' and 'v3' -- confirm that this
        is intended.

        Returns:
            ``(rgb, facev2v_ret)``.
        """
        x = facev2v_ret['x']
        ws = facev2v_ret['ws']
        rgb = facev2v_ret['rgb']
        ref_bg_rgb_256 = facev2v_ret['ref_bg_rgb_256']
        weights_256 = facev2v_ret['weights_256']

        rgb_torso = self.torso_model.infer_forward_stage2(facev2v_ret)
        x_torso = self.torso_encoder(facev2v_ret['deformed_torso_hid'])
        x_bg = self.bg_encoder(ref_bg_rgb_256)

        if hparams.get("weight_fuse", True):
            rgb = rgb * weights_256 + rgb_torso * (1 - weights_256)  # person image
            x = x * weights_256 + x_torso * (1 - weights_256)
            x, rgb = self._composite_over_background(
                rgb, x, ws, x_bg, ref_bg_rgb_256, weights_256,
                facev2v_ret['occlusion_2'], 0.5, **block_kwargs)
        else:
            x = torch.cat([x, x_torso, x_bg], dim=1)
            x = self.fuse_fg_bg_convs(x)
            x, rgb = self.block1(x, None, ws, **block_kwargs)
        return rgb, facev2v_ret
if __name__ == '__main__':
    # Manual smoke test: build the model on the GPU and push random inputs
    # through it. Requires CUDA.
    sr_model = SuperresolutionHybrid8XDC_Warp(32, 512, 0, False).cuda()

    def _rand_cuda(*shape):
        """Draw a standard-normal tensor on the CPU and move it to the GPU."""
        return torch.randn(list(shape)).cuda()

    batch_size, low_res = 4, 128
    rgb = _rand_cuda(batch_size, 3, low_res, low_res)
    x = _rand_cuda(batch_size, 32, low_res, low_res)
    ws = _rand_cuda(batch_size, 14, 512)
    ref_rgb = _rand_cuda(batch_size, 3, low_res, low_res)
    ref_torso_rgb = _rand_cuda(batch_size, 3, low_res, low_res)

    # NOTE(review): `forward` also requires `weights_img`, `segmap`, `kp_s`
    # and `kp_d`, and takes `ref_torso_rgb` before `ref_bg_rgb`; as written
    # this call looks out of date with the signature -- confirm and supply the
    # missing inputs.
    y = sr_model(rgb, x, ws, ref_rgb, ref_torso_rgb)
    print(" ")