Spaces:
Build error
Build error
# Copyright 2020 Google LLC | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# https://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import os | |
import numpy as np | |
import torch | |
import math | |
from modules.real3d.segformer import SegFormerImg2PlaneBackbone, SegFormerSECC2PlaneBackbone | |
from modules.real3d.img2plane_baseline import OSAvatar_Img2plane | |
from modules.img2plane.img2plane_model import Img2PlaneModel | |
from utils.commons.hparams import hparams | |
# TODO: consider replacing the plane fusion with attention, using the planes as the values.
class OSAvatarSECC_Img2plane(OSAvatar_Img2plane):
    """Img2plane avatar whose planes combine a canonical plane and a SECC/PNCC-conditioned plane.

    The canonical plane is predicted from the input image by
    ``cano_img2plane_backbone``, which is the parent's ``img2plane_backbone``
    renamed. The condition plane is predicted from concatenated PNCC maps by
    ``secc_img2plane_backbone``. The two are fused according to
    ``hparams['phase1_plane_fusion_mode']`` (``'add'`` or ``'mul'``) before
    volume rendering.
    """

    def __init__(self, hp=None, lora_args=None):
        """Build the model.

        Args:
            hp: hyper-parameters forwarded to ``OSAvatar_Img2plane.__init__``.
            lora_args: optional dict whose ``"lora_mode"`` entry selects where
                LoRA is applied. ``None`` or a mode of ``'none'`` disables LoRA.
                The SECC->plane backbone gets LoRA when the mode is ``'all'`` or
                contains ``'secc2plane'``.
        """
        if lora_args is None or lora_args.get("lora_mode", 'none') == 'none':
            lora_args = None
        super().__init__(hp=hp, lora_args=lora_args)
        hparams = self.hparams
        # The parent's image->plane backbone now produces the canonical plane;
        # rename it so its role is explicit.
        self.cano_img2plane_backbone = self.img2plane_backbone
        del self.img2plane_backbone
        lora_args_secc2plane = self._secc2plane_lora_args(lora_args)
        if lora_args_secc2plane:
            print("lora_args_secc2plane: ", lora_args_secc2plane)
        self.secc_img2plane_backbone = SegFormerSECC2PlaneBackbone(
            mode=hparams.get('secc_segformer_scale', 'b0'),
            out_channels=3 * self.triplane_hid_dim * self.triplane_depth,
            pncc_cond_mode=hparams['pncc_cond_mode'],
            lora_args=lora_args_secc2plane,
        )
        # Frozen scalars (requires_grad=False). They are registered as
        # nn.Parameters, so they follow the module's device and state_dict.
        self.lambda_pertube_blink_secc = torch.nn.Parameter(torch.tensor([0.001]), requires_grad=False)
        self.lambda_pertube_secc = torch.nn.Parameter(torch.tensor([0.001]), requires_grad=False)

    @staticmethod
    def _secc2plane_lora_args(lora_args):
        """Return ``lora_args`` if LoRA applies to the SECC->plane backbone, else ``None``."""
        # Check for None first: the mode lookup below would fail on None.
        if not lora_args:
            return None
        lora_mode = lora_args.get("lora_mode", 'none')
        if lora_mode == 'all' or 'secc2plane' in lora_mode:
            return lora_args
        return None

    def on_train_full_model(self):
        """Make every parameter of the model trainable."""
        self.requires_grad_(True)

    def on_train_nerf(self):
        """Train both plane backbones and the decoder; freeze super-resolution."""
        self.cano_img2plane_backbone.requires_grad_(True)
        self.secc_img2plane_backbone.requires_grad_(True)
        self.decoder.requires_grad_(True)
        self.superresolution.requires_grad_(False)

    def on_train_superresolution(self):
        """Train only the super-resolution module; freeze backbones and decoder."""
        self.cano_img2plane_backbone.requires_grad_(False)
        self.secc_img2plane_backbone.requires_grad_(False)
        self.decoder.requires_grad_(False)
        self.superresolution.requires_grad_(True)

    def cal_cano_plane(self, img, cond=None, **kwargs):
        """Predict the canonical planes from ``img``.

        The layout depends on ``hparams['triplane_feature_type']``:
        ``'triplane'`` / ``'trigrid'`` view the output as ``[B, 3, C*D, H, W]``;
        ``'trigrid_v2'`` additionally passes it through ``self.plane2grid_module``.

        Raises:
            NotImplementedError: for any other feature type.
        """
        hparams = self.hparams
        feature_type = hparams.get("triplane_feature_type", "triplane")
        planes = self.cano_img2plane_backbone(img, cond, **kwargs)  # [B, 3, C*D, H, W]
        if feature_type in ['triplane', 'trigrid']:
            planes = planes.view(len(planes), 3, self.triplane_hid_dim * self.triplane_depth,
                                 planes.shape[-2], planes.shape[-1])
        elif feature_type in ['trigrid_v2']:
            b, k, cd, h, w = planes.shape  # k = 3
            planes = planes.reshape([b, k * cd, h, w])
            planes = self.plane2grid_module(planes)
            planes = planes.reshape([b, k, cd, h, w])
        else:
            raise NotImplementedError()
        return planes

    def cal_secc_plane(self, cond):
        """Predict the condition-dependent planes from PNCC maps.

        ``cond['cond_cano']`` and ``cond['cond_tgt']`` are concatenated along
        dim 1. In ``'cano_src_tgt'`` mode, ``cond['cond_src']`` is inserted
        between them. The result is fed to ``secc_img2plane_backbone``.
        """
        cano_pncc, tgt_pncc = cond['cond_cano'], cond['cond_tgt']
        if self.hparams.get("pncc_cond_mode", "cano_tgt") == 'cano_src_tgt':
            # 'cond_src' is only required (and only looked up) in this mode.
            inp_pncc = torch.cat([cano_pncc, cond['cond_src'], tgt_pncc], dim=1)
        else:
            inp_pncc = torch.cat([cano_pncc, tgt_pncc], dim=1)
        return self.secc_img2plane_backbone(inp_pncc)

    def cal_plane_given_cano(self, cano_planes, cond=None):
        """Fuse precomputed canonical planes with the SECC planes computed from ``cond``.

        Raises:
            NotImplementedError: if ``phase1_plane_fusion_mode`` is neither
                ``'add'`` nor ``'mul'``.
        """
        # cano_planes: [B, 3, C*D, H, W]
        secc_planes = self.cal_secc_plane(cond)  # [B, 3, C*D, H, W]
        fusion_mode = self.hparams.get("phase1_plane_fusion_mode", "add")
        if fusion_mode == 'add':
            return cano_planes + secc_planes
        if fusion_mode == 'mul':
            return cano_planes * secc_planes
        raise NotImplementedError()

    def cal_plane(self, img, cond, ret=None, **kwargs):
        """Return ``(fused_planes, cano_planes)`` for ``img`` under condition ``cond``."""
        cano_planes = self.cal_cano_plane(img, cond, **kwargs)  # [B, 3, C*D, H, W]
        planes = self.cal_plane_given_cano(cano_planes, cond)
        return planes, cano_planes

    def sample(self, coordinates, directions, img, cond=None, truncation_psi=1, truncation_cutoff=None,
               update_emas=False, ref_camera=None, **synthesis_kwargs):
        """Compute RGB features and density for arbitrary 3D coordinates.

        Mostly used for extracting shapes.
        """
        planes, _ = self.cal_plane(img, cond, ret={}, ref_camera=ref_camera)
        return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)

    def synthesis(self, img, camera, cond=None, ret=None, update_emas=False, cache_backbone=True,
                  use_cached_backbone=False, **synthesis_kwargs):
        """Volume-render ``img`` from ``camera`` and run super-resolution.

        Args:
            camera: per-sample camera vector. Columns ``[:16]`` are viewed as a
                4x4 cam2world matrix; columns ``[16:25]`` as 3x3 intrinsics.
            cache_backbone: store the canonical planes on ``self`` for reuse.
            use_cached_backbone: reuse the canonical planes from an earlier call
                made with ``cache_backbone=True`` instead of recomputing them.

        Returns:
            ``ret`` updated with ``image_raw``, ``image_depth``, ``image``,
            ``image_feature``, ``plane`` and ``weights_img``.
        """
        if ret is None:
            ret = {}
        cam2world_matrix = camera[:, :16].view(-1, 4, 4)
        intrinsics = camera[:, 16:25].view(-1, 3, 3)
        neural_rendering_resolution = self.neural_rendering_resolution

        # Create a batch of rays for volume rendering
        ray_origins, ray_directions = self.ray_sampler(cam2world_matrix, intrinsics, neural_rendering_resolution)
        N, M, _ = ray_origins.shape

        if use_cached_backbone:
            cano_planes = getattr(self, '_last_cano_planes', None)
            if cano_planes is None:
                raise AttributeError(
                    "use_cached_backbone=True but no canonical planes are cached; "
                    "call synthesis with cache_backbone=True first.")
            planes = self.cal_plane_given_cano(cano_planes, cond)
        else:
            planes, cano_planes = self.cal_plane(img, cond, ret, **synthesis_kwargs)
        if cache_backbone:
            self._last_cano_planes = cano_planes

        # Perform volume rendering (outputs are channels last)
        feature_samples, depth_samples, weights_samples, is_ray_valid = self.renderer(
            planes, self.decoder, ray_origins, ray_directions, self.rendering_kwargs)

        # Reshape into 'raw' neural-rendered images
        H = W = neural_rendering_resolution
        feature_image = feature_samples.permute(0, 2, 1).reshape(N, feature_samples.shape[-1], H, W).contiguous()
        weights_image = weights_samples.permute(0, 2, 1).reshape(N, 1, H, W).contiguous()  # [N, 1, H, W]
        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)

        if self.hparams.get("mask_invalid_rays", False):
            is_ray_valid_mask = is_ray_valid.reshape([feature_samples.shape[0], 1, H, W])  # [B, 1, H, W]
            feature_image[~is_ray_valid_mask.repeat([1, feature_image.shape[1], 1, 1])] = -1
            # Fill invalid depths with the minimum valid depth. This is skipped when
            # no ray is valid, because min() over an empty selection raises.
            if is_ray_valid_mask.any():
                depth_image[~is_ray_valid_mask] = depth_image[is_ray_valid_mask].min().item()

        # Run superresolution to get the final image
        rgb_image = feature_image[:, :3]
        ret['weights_img'] = weights_image
        sr_image = self._forward_sr(rgb_image, feature_image, cond, ret, **synthesis_kwargs)
        rgb_image = rgb_image.clamp(-1, 1)
        sr_image = sr_image.clamp(-1, 1)
        ret.update({'image_raw': rgb_image, 'image_depth': depth_image, 'image': sr_image,
                    'image_feature': feature_image[:, 3:], 'plane': planes})
        return ret