Spaces:
Build error
Build error
Commit
·
b32806b
1
Parent(s):
28e8293
init project
Browse files- Dockerfile +42 -0
- README.md +190 -0
- inference/app_mimictalk.py +365 -0
- inference/app_real3dportrait.py +247 -0
- inference/edit_secc.py +147 -0
- inference/infer_utils.py +154 -0
- inference/mimictalk_infer.py +357 -0
- inference/real3d_infer.py +667 -0
- inference/real3dportrait_demo.ipynb +287 -0
- inference/train_mimictalk_on_a_video.py +608 -0
- install_guide.md +27 -0
- requirements.txt +79 -0
Dockerfile
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04
|
2 |
+
|
3 |
+
# Install system dependencies
|
4 |
+
RUN apt-get update && apt-get install -y \
|
5 |
+
git \
|
6 |
+
wget \
|
7 |
+
ffmpeg \
|
8 |
+
libsndfile1 \
|
9 |
+
python3-pip \
|
10 |
+
&& rm -rf /var/lib/apt/lists/*
|
11 |
+
|
12 |
+
# Install Miniconda
|
13 |
+
RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh \
|
14 |
+
&& bash miniconda.sh -b -p /opt/conda \
|
15 |
+
&& rm miniconda.sh
|
16 |
+
|
17 |
+
# Add conda to path
|
18 |
+
ENV PATH=/opt/conda/bin:$PATH
|
19 |
+
|
20 |
+
# Create and activate conda environment
|
21 |
+
RUN conda create -n mimictalk python=3.9 -y
|
22 |
+
SHELL ["conda", "run", "-n", "mimictalk", "/bin/bash", "-c"]
|
23 |
+
|
24 |
+
# Install PyTorch and other dependencies
|
25 |
+
RUN pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121 \
|
26 |
+
&& pip install cython openmim==0.3.9 \
|
27 |
+
&& mim install mmcv==2.1.0 \
|
28 |
+
&& pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
|
29 |
+
|
30 |
+
# Install requirements
|
31 |
+
COPY requirements.txt /tmp/
|
32 |
+
RUN pip install -r /tmp/requirements.txt
|
33 |
+
|
34 |
+
# Download pre-trained models
|
35 |
+
RUN pip install huggingface_hub \
|
36 |
+
&& python -c "from huggingface_hub import snapshot_download; snapshot_download('mrbear1024/demo-model', local_dir='/app/models')"
|
37 |
+
|
38 |
+
# Set working directory
|
39 |
+
WORKDIR /app
|
40 |
+
|
41 |
+
# Set default command to activate conda env
|
42 |
+
CMD ["conda", "run", "--no-capture-output", "-n", "mimictalk", "python", "inference/app_mimictalk.py"]
|
README.md
CHANGED
@@ -5,6 +5,196 @@ colorFrom: blue
|
|
5 |
colorTo: pink
|
6 |
sdk: docker
|
7 |
pinned: false
|
|
|
|
|
|
|
|
|
8 |
---
|
9 |
|
10 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
colorTo: pink
|
6 |
sdk: docker
|
7 |
pinned: false
|
8 |
+
app_file: inference/app_mimictalk.py
|
9 |
+
models:
|
10 |
+
- mrbear1024/demo-model
|
11 |
+
|
12 |
---
|
13 |
|
14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
15 |
+
|
16 |
+
|
17 |
+
# MimicTalk: Mimicking a personalized and expressive 3D talking face in few minutes | NeurIPS 2024
|
18 |
+
[](https://arxiv.org/abs/2401.08503)| [](https://github.com/yerfor/MimicTalk) | [English Readme](./README.md)
|
20 |
+
|
21 |
+
这个仓库是MimicTalk的官方PyTorch实现, 用于实现特定说话人的高表现力的虚拟人视频合成。该仓库代码基于我们先前的工作[Real3D-Portrait](https://github.com/yerfor/Real3DPortrait) (ICLR 2024),即基于NeRF的one-shot说话人合成,这让Mimictalk的训练加速且效果增强。您可以访问我们的[项目页面](https://mimictalk.github.io/)以观看Demo视频, 阅读我们的[论文](https://arxiv.org/abs/2410.06734)以了解技术细节。
|
22 |
+
|
23 |
+
<p align="center">
|
24 |
+
<br>
|
25 |
+
<img src="assets/mimictalk.png" width="100%"/>
|
26 |
+
<br>
|
27 |
+
</p>
|
28 |
+
|
29 |
+
# 快速上手!
|
30 |
+
## 安装环境
|
31 |
+
请参照[环境配置文档](docs/prepare_env/install_guide-zh.md),配置Conda环境`mimictalk`
|
32 |
+
## 下载预训练与第三方模型
|
33 |
+
### 3DMM BFM模型
|
34 |
+
下载3DMM BFM模型:[Google Drive](https://drive.google.com/drive/folders/1o4t5YIw7w4cMUN4bgU9nPf6IyWVG1bEk?usp=sharing) 或 [BaiduYun Disk](https://pan.baidu.com/s/1aqv1z_qZ23Vp2VP4uxxblQ?pwd=m9q5 ) 提取码: m9q5
|
35 |
+
|
36 |
+
|
37 |
+
下载完成后,放置全部的文件到`deep_3drecon/BFM`里,文件结构如下:
|
38 |
+
```
|
39 |
+
deep_3drecon/BFM/
|
40 |
+
├── 01_MorphableModel.mat
|
41 |
+
├── BFM_exp_idx.mat
|
42 |
+
├── BFM_front_idx.mat
|
43 |
+
├── BFM_model_front.mat
|
44 |
+
├── Exp_Pca.bin
|
45 |
+
├── facemodel_info.mat
|
46 |
+
├── index_mp468_from_mesh35709.npy
|
47 |
+
├── mediapipe_in_bfm53201.npy
|
48 |
+
└── std_exp.txt
|
49 |
+
```
|
50 |
+
|
51 |
+
### 预训练模型
|
52 |
+
下载预训练的MimicTalk相关Checkpoints:[Google Drive](https://drive.google.com/drive/folders/1Kc6ueDO9HFDN3BhtJCEKNCZtyKHSktaA?usp=sharing) or [BaiduYun Disk](https://pan.baidu.com/s/1nQKyGV5JB6rJtda7qsThUg?pwd=mimi) 提取码: mimi
|
53 |
+
|
54 |
+
下载完成后,放置全部的文件到`checkpoints`与`checkpoints_mimictalk`里并解压,文件结构如下:
|
55 |
+
```
|
56 |
+
checkpoints/
|
57 |
+
├── mimictalk_orig
|
58 |
+
│ └── os_secc2plane_torso
|
59 |
+
│ ├── config.yaml
|
60 |
+
│ └── model_ckpt_steps_100000.ckpt
|
61 |
+
|-- 240112_icl_audio2secc_vox2_cmlr
|
62 |
+
│ ├── config.yaml
|
63 |
+
│ └── model_ckpt_steps_1856000.ckpt
|
64 |
+
└── pretrained_ckpts
|
65 |
+
└── mit_b0.pth
|
66 |
+
|
67 |
+
checkpoints_mimictalk/
|
68 |
+
└── German_20s
|
69 |
+
├── config.yaml
|
70 |
+
└── model_ckpt_steps_10000.ckpt
|
71 |
+
```
|
72 |
+
|
73 |
+
## MimicTalk训练与推理的最简命令
|
74 |
+
```
|
75 |
+
python inference/train_mimictalk_on_a_video.py # train the model, this may take 10 minutes for 2,000 steps
|
76 |
+
python inference/mimictalk_infer.py # infer the model
|
77 |
+
```
|
78 |
+
|
79 |
+
|
80 |
+
# 训练与推理细节
|
81 |
+
我们目前提供了**命令行(CLI)**与**Gradio WebUI**推理方式。音频驱动推理的人像信息来自于`torso_ckpt`,因此需要至少再提供`driving audio`用于推理。另外,可以提供`style video`让模型能够预测与该视频风格一致的说话人动作。
|
82 |
+
|
83 |
+
首先,切换至项目根目录并启用Conda环境:
|
84 |
+
```bash
|
85 |
+
cd <Real3DPortraitRoot>
|
86 |
+
conda activate mimictalk
|
87 |
+
export PYTHONPATH=./
|
88 |
+
export HF_ENDPOINT=https://hf-mirror.com
|
89 |
+
```
|
90 |
+
|
91 |
+
## Gradio WebUI推理
|
92 |
+
启动Gradio WebUI,按照提示上传素材,点击`Training`按钮进行训练;训练完成后点击`Generate`按钮即可推理:
|
93 |
+
```bash
|
94 |
+
python inference/app_mimictalk.py
|
95 |
+
```
|
96 |
+
|
97 |
+
## 使用Transformers Pipeline(Hugging Face Space)
|
98 |
+
在Hugging Face Space中,我们提供了使用Transformers Pipeline的简化版本,可以直接从Hugging Face Hub下载模型文件:
|
99 |
+
|
100 |
+
```bash
|
101 |
+
python app.py --model_id mrbear1024/demo-model
|
102 |
+
```
|
103 |
+
|
104 |
+
您也可以通过环境变量设置模型来源:
|
105 |
+
```bash
|
106 |
+
export MODEL_ID=mrbear1024/demo-model
|
107 |
+
python app.py
|
108 |
+
```
|
109 |
+
|
110 |
+
### 模型路径说明
|
111 |
+
在Hugging Face Space中,模型文件默认会被下载到`/app/models`目录。您可以通过以下方式自定义模型路径:
|
112 |
+
|
113 |
+
```bash
|
114 |
+
# 使用命令行参数
|
115 |
+
python app.py --model_dir /path/to/models
|
116 |
+
|
117 |
+
# 使用环境变量
|
118 |
+
export MODEL_DIR=/path/to/models
|
119 |
+
python app.py
|
120 |
+
```
|
121 |
+
|
122 |
+
如果您在本地运行,模型文件会被下载到临时目录中。如果您想指定自定义路径,可以使用上述方法。
|
123 |
+
|
124 |
+
## 命令行特定说话人训练
|
125 |
+
|
126 |
+
需要至少提供`source video`,训练指令:
|
127 |
+
```bash
|
128 |
+
python inference/train_mimictalk_on_a_video.py \
|
129 |
+
--video_id <PATH_TO_SOURCE_VIDEO> \
|
130 |
+
--max_updates <UPDATES_NUMBER> \
|
131 |
+
--work_dir <PATH_TO_SAVING_CKPT>
|
132 |
+
```
|
133 |
+
|
134 |
+
一些可选参数注释:
|
135 |
+
|
136 |
+
- `--torso_ckpt` 预训练的Real3D-Portrait模型
|
137 |
+
- `--max_updates` 训练更新次数
|
138 |
+
- `--batch_size` 训练的batch size: `1` 需要约8GB显存; `2`需要约15GB��存
|
139 |
+
- `--lr_triplane` triplane的学习率:对于视频输入, 应为0.1; 对于图片输入,应为0.001
|
140 |
+
- `--work_dir` 未指定时,将默认存储在`checkpoints_mimictalk/`中
|
141 |
+
|
142 |
+
指令示例:
|
143 |
+
```bash
|
144 |
+
python inference/train_mimictalk_on_a_video.py \
|
145 |
+
--video_id data/raw/videos/German_20s.mp4 \
|
146 |
+
--max_updates 2000 \
|
147 |
+
--work_dir checkpoints_mimictalk/German_20s
|
148 |
+
```
|
149 |
+
|
150 |
+
## 命令行推理
|
151 |
+
|
152 |
+
需要至少提供`driving audio`,可选提供`driving style`,推理指令:
|
153 |
+
```bash
|
154 |
+
python inference/mimictalk_infer.py \
|
155 |
+
--drv_aud <PATH_TO_AUDIO> \
|
156 |
+
--drv_style <PATH_TO_STYLE_VIDEO, OPTIONAL> \
|
157 |
+
--drv_pose <PATH_TO_POSE_VIDEO, OPTIONAL> \
|
158 |
+
--bg_img <PATH_TO_BACKGROUND_IMAGE, OPTIONAL> \
|
159 |
+
--out_name <PATH_TO_OUTPUT_VIDEO, OPTIONAL>
|
160 |
+
```
|
161 |
+
|
162 |
+
一些可选参数注释:
|
163 |
+
- `--drv_pose` 指定时提供了运动pose信息,不指定则为静态运动
|
164 |
+
- `--bg_img` 指定时提供了背景信息,不指定则为source image提取的背景
|
165 |
+
- `--mouth_amp` 嘴部张幅参数,值越大张幅越大
|
166 |
+
- `--map_to_init_pose` 值为`True`时,首帧的pose将被映射到source pose,后续帧也作相同变换
|
167 |
+
- `--temperature` 代表audio2motion的采样温度,值越大结果越多样,但同时精确度越低
|
168 |
+
- `--out_name` 不指定时,结果将保存在`infer_out/tmp/`中
|
169 |
+
- `--out_mode` 值为`final`时,只输出说话人视频;值为`concat_debug`时,同时输出一些可视化的中间结果
|
170 |
+
|
171 |
+
推理命令例子:
|
172 |
+
```bash
|
173 |
+
python inference/mimictalk_infer.py \
|
174 |
+
--drv_aud data/raw/examples/Obama_5s.wav \
|
175 |
+
--drv_pose data/raw/examples/German_20s.mp4 \
|
176 |
+
--drv_style data/raw/examples/German_20s.mp4 \
|
177 |
+
--bg_img data/raw/examples/bg.png \
|
178 |
+
--out_name output.mp4 \
|
179 |
+
--out_mode final
|
180 |
+
```
|
181 |
+
|
182 |
+
# 声明
|
183 |
+
任何组织或个人未经本人同意,不得使用本文提及的任何技术生成他人说话的视频,包括但不限于政府领导人、政界人士、社会名流等。如不遵守本条款,则可能违反版权法。
|
184 |
+
|
185 |
+
# 引用我们
|
186 |
+
如果这个仓库对你有帮助,请考虑引用我们的工作:
|
187 |
+
```
|
188 |
+
@inproceedings{ye2024mimicktalk,
|
189 |
+
author = {Ye, Zhenhui and Zhong, Tianyun and Ren, Yi and Yang, Jiaqi and Li, Weichuang and Huang, Jiangwei and Jiang, Ziyue and He, Jinzheng and Huang, Rongjie and Liu, Jinglin and Zhang, Chen and Yin, Xiang and Ma, Zejun and Zhao, Zhou},
|
190 |
+
title = {MimicTalk: Mimicking a personalized and expressive 3D talking face in few minutes},
|
191 |
+
journal = {NeurIPS},
|
192 |
+
year = {2024},
|
193 |
+
}
|
194 |
+
@inproceedings{ye2024real3d,
|
195 |
+
title = {Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis},
|
196 |
+
author = {Ye, Zhenhui and Zhong, Tianyun and Ren, Yi and Yang, Jiaqi and Li, Weichuang and Huang, Jiawei and Jiang, Ziyue and He, Jinzheng and Huang, Rongjie and Liu, Jinglin and others},
|
197 |
+
journal = {ICLR},
|
198 |
+
year={2024}
|
199 |
+
}
|
200 |
+
```
|
inference/app_mimictalk.py
ADDED
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
sys.path.append('./')
|
3 |
+
import argparse
|
4 |
+
import traceback
|
5 |
+
import gradio as gr
|
6 |
+
# from inference.real3d_infer import GeneFace2Infer
|
7 |
+
from inference.mimictalk_infer import AdaptGeneFace2Infer
|
8 |
+
from inference.train_mimictalk_on_a_video import LoRATrainer
|
9 |
+
from utils.commons.hparams import hparams
|
10 |
+
|
11 |
+
class Trainer():
|
12 |
+
def __init__(self):
|
13 |
+
pass
|
14 |
+
|
15 |
+
def train_once_args(self, *args, **kargs):
|
16 |
+
assert len(kargs) == 0
|
17 |
+
args = [
|
18 |
+
'', # head_ckpt
|
19 |
+
'checkpoints/mimictalk_orig/os_secc2plane_torso/', # torso_ckpt
|
20 |
+
args[0],
|
21 |
+
'',
|
22 |
+
10000,
|
23 |
+
1,
|
24 |
+
False,
|
25 |
+
0.001,
|
26 |
+
0.005,
|
27 |
+
0.2,
|
28 |
+
'secc2plane_sr',
|
29 |
+
2,
|
30 |
+
]
|
31 |
+
keys = [
|
32 |
+
'head_ckpt',
|
33 |
+
'torso_ckpt',
|
34 |
+
'video_id',
|
35 |
+
'work_dir',
|
36 |
+
'max_updates',
|
37 |
+
'batch_size',
|
38 |
+
'test',
|
39 |
+
'lr',
|
40 |
+
'lr_triplane',
|
41 |
+
'lambda_lpips',
|
42 |
+
'lora_mode',
|
43 |
+
'lora_r',
|
44 |
+
]
|
45 |
+
inp = {}
|
46 |
+
info = ""
|
47 |
+
|
48 |
+
try: # try to catch errors and jump to return
|
49 |
+
for key_index in range(len(keys)):
|
50 |
+
key = keys[key_index]
|
51 |
+
inp[key] = args[key_index]
|
52 |
+
if '_name' in key:
|
53 |
+
inp[key] = inp[key] if inp[key] is not None else ''
|
54 |
+
|
55 |
+
if inp['video_id'] == '':
|
56 |
+
info = "Input Error: Source video is REQUIRED!"
|
57 |
+
raise ValueError
|
58 |
+
|
59 |
+
inp['out_name'] = ''
|
60 |
+
inp['seed'] = 42
|
61 |
+
|
62 |
+
if inp['work_dir'] is None or inp['work_dir'] == '':
|
63 |
+
video_id = os.path.basename(inp['video_id'])[:-4] if inp['video_id'].endswith((".mp4", ".png", ".jpg", ".jpeg")) else inp['video_id']
|
64 |
+
inp['work_dir'] = f'checkpoints_mimictalk/{video_id}'
|
65 |
+
try:
|
66 |
+
trainer = LoRATrainer(inp)
|
67 |
+
trainer.training_loop(inp)
|
68 |
+
except Exception as e:
|
69 |
+
content = f"{e}"
|
70 |
+
info = f"Training ERROR: {content}"
|
71 |
+
traceback.print_exc()
|
72 |
+
raise ValueError
|
73 |
+
except Exception as e:
|
74 |
+
if info == "": # unexpected errors
|
75 |
+
content = f"{e}"
|
76 |
+
info = f"WebUI ERROR: {content}"
|
77 |
+
traceback.print_exc()
|
78 |
+
|
79 |
+
|
80 |
+
# output part
|
81 |
+
if len(info) > 0 : # there is errors
|
82 |
+
print(info)
|
83 |
+
info_gr = gr.update(visible=True, value=info)
|
84 |
+
else: # no errors
|
85 |
+
info_gr = gr.update(visible=False, value=info)
|
86 |
+
|
87 |
+
torso_model_dir = gr.FileExplorer(glob="checkpoints_mimictalk/**/*.ckpt", value=inp['work_dir'], file_count='single', label='mimictalk model ckpt path or directory')
|
88 |
+
|
89 |
+
|
90 |
+
return info_gr, torso_model_dir
|
91 |
+
|
92 |
+
class Inferer(AdaptGeneFace2Infer):
|
93 |
+
def infer_once_args(self, *args, **kargs):
|
94 |
+
assert len(kargs) == 0
|
95 |
+
keys = [
|
96 |
+
# 'src_image_name',
|
97 |
+
'drv_audio_name',
|
98 |
+
'drv_pose_name',
|
99 |
+
'drv_talking_style_name',
|
100 |
+
'bg_image_name',
|
101 |
+
'blink_mode',
|
102 |
+
'temperature',
|
103 |
+
'cfg_scale',
|
104 |
+
'out_mode',
|
105 |
+
'map_to_init_pose',
|
106 |
+
'low_memory_usage',
|
107 |
+
'hold_eye_opened',
|
108 |
+
'a2m_ckpt',
|
109 |
+
# 'head_ckpt',
|
110 |
+
'torso_ckpt',
|
111 |
+
'min_face_area_percent',
|
112 |
+
]
|
113 |
+
inp = {}
|
114 |
+
out_name = None
|
115 |
+
info = ""
|
116 |
+
|
117 |
+
try: # try to catch errors and jump to return
|
118 |
+
for key_index in range(len(keys)):
|
119 |
+
key = keys[key_index]
|
120 |
+
inp[key] = args[key_index]
|
121 |
+
if '_name' in key:
|
122 |
+
inp[key] = inp[key] if inp[key] is not None else ''
|
123 |
+
|
124 |
+
inp['head_ckpt'] = ''
|
125 |
+
|
126 |
+
# if inp['src_image_name'] == '':
|
127 |
+
# info = "Input Error: Source image is REQUIRED!"
|
128 |
+
# raise ValueError
|
129 |
+
if inp['drv_audio_name'] == '' and inp['drv_pose_name'] == '':
|
130 |
+
info = "Input Error: At least one of driving audio or video is REQUIRED!"
|
131 |
+
raise ValueError
|
132 |
+
|
133 |
+
|
134 |
+
if inp['drv_audio_name'] == '' and inp['drv_pose_name'] != '':
|
135 |
+
inp['drv_audio_name'] = inp['drv_pose_name']
|
136 |
+
print("No audio input, we use driving pose video for video driving")
|
137 |
+
|
138 |
+
if inp['drv_pose_name'] == '':
|
139 |
+
inp['drv_pose_name'] = 'static'
|
140 |
+
|
141 |
+
reload_flag = False
|
142 |
+
if inp['a2m_ckpt'] != self.audio2secc_dir:
|
143 |
+
print("Changes of a2m_ckpt detected, reloading model")
|
144 |
+
reload_flag = True
|
145 |
+
if inp['head_ckpt'] != self.head_model_dir:
|
146 |
+
print("Changes of head_ckpt detected, reloading model")
|
147 |
+
reload_flag = True
|
148 |
+
if inp['torso_ckpt'] != self.torso_model_dir:
|
149 |
+
print("Changes of torso_ckpt detected, reloading model")
|
150 |
+
reload_flag = True
|
151 |
+
|
152 |
+
inp['out_name'] = ''
|
153 |
+
inp['seed'] = 42
|
154 |
+
inp['denoising_steps'] = 20
|
155 |
+
print(f"infer inputs : {inp}")
|
156 |
+
|
157 |
+
try:
|
158 |
+
if reload_flag:
|
159 |
+
self.__init__(inp['a2m_ckpt'], inp['head_ckpt'], inp['torso_ckpt'], inp=inp, device=self.device)
|
160 |
+
except Exception as e:
|
161 |
+
content = f"{e}"
|
162 |
+
info = f"Reload ERROR: {content}"
|
163 |
+
traceback.print_exc()
|
164 |
+
raise ValueError
|
165 |
+
try:
|
166 |
+
out_name = self.infer_once(inp)
|
167 |
+
except Exception as e:
|
168 |
+
content = f"{e}"
|
169 |
+
info = f"Inference ERROR: {content}"
|
170 |
+
traceback.print_exc()
|
171 |
+
raise ValueError
|
172 |
+
except Exception as e:
|
173 |
+
if info == "": # unexpected errors
|
174 |
+
content = f"{e}"
|
175 |
+
info = f"WebUI ERROR: {content}"
|
176 |
+
traceback.print_exc()
|
177 |
+
|
178 |
+
# output part
|
179 |
+
if len(info) > 0 : # there is errors
|
180 |
+
print(info)
|
181 |
+
info_gr = gr.update(visible=True, value=info)
|
182 |
+
else: # no errors
|
183 |
+
info_gr = gr.update(visible=False, value=info)
|
184 |
+
if out_name is not None and len(out_name) > 0 and os.path.exists(out_name): # good output
|
185 |
+
print(f"Succefully generated in {out_name}")
|
186 |
+
video_gr = gr.update(visible=True, value=out_name)
|
187 |
+
else:
|
188 |
+
print(out_name)
|
189 |
+
print(os.path.exists(out_name))
|
190 |
+
print(f"Failed to generate")
|
191 |
+
video_gr = gr.update(visible=True, value=out_name)
|
192 |
+
|
193 |
+
return video_gr, info_gr
|
194 |
+
|
195 |
+
def toggle_audio_file(choice):
|
196 |
+
if choice == False:
|
197 |
+
return gr.update(visible=True), gr.update(visible=False)
|
198 |
+
else:
|
199 |
+
return gr.update(visible=False), gr.update(visible=True)
|
200 |
+
|
201 |
+
def ref_video_fn(path_of_ref_video):
|
202 |
+
if path_of_ref_video is not None:
|
203 |
+
return gr.update(value=True)
|
204 |
+
else:
|
205 |
+
return gr.update(value=False)
|
206 |
+
|
207 |
+
def mimictalk_demo(
|
208 |
+
audio2secc_dir,
|
209 |
+
head_model_dir,
|
210 |
+
torso_model_dir,
|
211 |
+
device = 'cuda',
|
212 |
+
warpfn = None,
|
213 |
+
):
|
214 |
+
|
215 |
+
sep_line = "-" * 40
|
216 |
+
|
217 |
+
infer_obj = Inferer(
|
218 |
+
audio2secc_dir=audio2secc_dir,
|
219 |
+
head_model_dir=head_model_dir,
|
220 |
+
torso_model_dir=torso_model_dir,
|
221 |
+
device=device,
|
222 |
+
)
|
223 |
+
train_obj = Trainer()
|
224 |
+
|
225 |
+
print(sep_line)
|
226 |
+
print("Model loading is finished.")
|
227 |
+
print(sep_line)
|
228 |
+
with gr.Blocks(analytics_enabled=False) as real3dportrait_interface:
|
229 |
+
# gr.Markdown("\
|
230 |
+
# <div align='center'> <h2> MimicTalk: Mimicking a personalized and expressive 3D talking face in minutes (NIPS 2024) </span> </h2> \
|
231 |
+
# <a style='font-size:18px;color: #a0a0a0' href=''>Arxiv</a> \
|
232 |
+
# <a style='font-size:18px;color: #a0a0a0' href='https://mimictalk.github.io/'>Homepage</a> \
|
233 |
+
# <a style='font-size:18px;color: #a0a0a0' href='https://github.com/yerfor/MimicTalk/'> Github </div>")
|
234 |
+
|
235 |
+
gr.Markdown("\
|
236 |
+
<div align='center'> <h2> MimicTalk: Mimicking a personalized and expressive 3D talking face in minutes (NIPS 2024) </span> </h2> \
|
237 |
+
<a style='font-size:18px;color: #a0a0a0' href='https://mimictalk.github.io/'>Homepage</a> ")
|
238 |
+
|
239 |
+
sources = None
|
240 |
+
with gr.Row():
|
241 |
+
with gr.Column(variant='panel'):
|
242 |
+
with gr.Tabs(elem_id="source_image"):
|
243 |
+
with gr.TabItem('Upload Training Video'):
|
244 |
+
with gr.Row():
|
245 |
+
src_video_name = gr.Video(label="Source video (required for training)", sources=sources, value="data/raw/videos/German_20s.mp4")
|
246 |
+
# src_video_name = gr.Image(label="Source video (required for training)", sources=sources, type="filepath", value="data/raw/videos/German_20s.mp4")
|
247 |
+
with gr.Tabs(elem_id="driven_audio"):
|
248 |
+
with gr.TabItem('Upload Driving Audio'):
|
249 |
+
with gr.Column(variant='panel'):
|
250 |
+
drv_audio_name = gr.Audio(label="Input audio (required for inference)", sources=sources, type="filepath", value="data/raw/examples/80_vs_60_10s.wav")
|
251 |
+
with gr.Tabs(elem_id="driven_style"):
|
252 |
+
with gr.TabItem('Upload Style Prompt'):
|
253 |
+
with gr.Column(variant='panel'):
|
254 |
+
drv_style_name = gr.Video(label="Driven Style (optional for inference)", sources=sources, value="data/raw/videos/German_20s.mp4")
|
255 |
+
with gr.Tabs(elem_id="driven_pose"):
|
256 |
+
with gr.TabItem('Upload Driving Pose'):
|
257 |
+
with gr.Column(variant='panel'):
|
258 |
+
drv_pose_name = gr.Video(label="Driven Pose (optional for inference)", sources=sources, value="data/raw/videos/German_20s.mp4")
|
259 |
+
with gr.Tabs(elem_id="bg_image"):
|
260 |
+
with gr.TabItem('Upload Background Image'):
|
261 |
+
with gr.Row():
|
262 |
+
bg_image_name = gr.Image(label="Background image (optional for inference)", sources=sources, type="filepath", value=None)
|
263 |
+
|
264 |
+
|
265 |
+
with gr.Column(variant='panel'):
|
266 |
+
with gr.Tabs(elem_id="checkbox"):
|
267 |
+
with gr.TabItem('General Settings'):
|
268 |
+
with gr.Column(variant='panel'):
|
269 |
+
|
270 |
+
blink_mode = gr.Radio(['none', 'period'], value='period', label='blink mode', info="whether to blink periodly") #
|
271 |
+
min_face_area_percent = gr.Slider(minimum=0.15, maximum=0.5, step=0.01, label="min_face_area_percent", value=0.2, info='The minimum face area percent in the output frame, to prevent bad cases caused by a too small face.',)
|
272 |
+
temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.025, label="temperature", value=0.2, info='audio to secc temperature',)
|
273 |
+
cfg_scale = gr.Slider(minimum=1.0, maximum=3.0, step=0.025, label="talking style cfg_scale", value=1.5, info='higher -> encourage the generated motion more coherent to talking style',)
|
274 |
+
out_mode = gr.Radio(['final', 'concat_debug'], value='concat_debug', label='output layout', info="final: only final output ; concat_debug: final output concated with internel features")
|
275 |
+
low_memory_usage = gr.Checkbox(label="Low Memory Usage Mode: save memory at the expense of lower inference speed. Useful when running a low audio (minutes-long).", value=False)
|
276 |
+
map_to_init_pose = gr.Checkbox(label="Whether to map pose of first frame to initial pose", value=True)
|
277 |
+
hold_eye_opened = gr.Checkbox(label="Whether to maintain eyes always open")
|
278 |
+
|
279 |
+
train_submit = gr.Button('Train', elem_id="train", variant='primary')
|
280 |
+
infer_submit = gr.Button('Generate', elem_id="generate", variant='primary')
|
281 |
+
|
282 |
+
with gr.Tabs(elem_id="genearted_video"):
|
283 |
+
info_box = gr.Textbox(label="Error", interactive=False, visible=False)
|
284 |
+
gen_video = gr.Video(label="Generated video", format="mp4", visible=True)
|
285 |
+
with gr.Column(variant='panel'):
|
286 |
+
with gr.Tabs(elem_id="checkbox"):
|
287 |
+
with gr.TabItem('Checkpoints'):
|
288 |
+
with gr.Column(variant='panel'):
|
289 |
+
ckpt_info_box = gr.Textbox(value="Please select \"ckpt\" under the checkpoint folder ", interactive=False, visible=True, show_label=False)
|
290 |
+
audio2secc_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=audio2secc_dir, file_count='single', label='audio2secc model ckpt path or directory')
|
291 |
+
# head_model_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=head_model_dir, file_count='single', label='head model ckpt path or directory (will be ignored if torso model is set)')
|
292 |
+
torso_model_dir = gr.FileExplorer(glob="checkpoints_mimictalk/**/*.ckpt", value=torso_model_dir, file_count='single', label='mimictalk model ckpt path or directory')
|
293 |
+
# audio2secc_dir = gr.Textbox(audio2secc_dir, max_lines=1, label='audio2secc model ckpt path or directory (will be ignored if torso model is set)')
|
294 |
+
# head_model_dir = gr.Textbox(head_model_dir, max_lines=1, label='head model ckpt path or directory (will be ignored if torso model is set)')
|
295 |
+
# torso_model_dir = gr.Textbox(torso_model_dir, max_lines=1, label='torso model ckpt path or directory')
|
296 |
+
|
297 |
+
|
298 |
+
fn = infer_obj.infer_once_args
|
299 |
+
if warpfn:
|
300 |
+
fn = warpfn(fn)
|
301 |
+
infer_submit.click(
|
302 |
+
fn=fn,
|
303 |
+
inputs=[
|
304 |
+
drv_audio_name,
|
305 |
+
drv_pose_name,
|
306 |
+
drv_style_name,
|
307 |
+
bg_image_name,
|
308 |
+
blink_mode,
|
309 |
+
temperature,
|
310 |
+
cfg_scale,
|
311 |
+
out_mode,
|
312 |
+
map_to_init_pose,
|
313 |
+
low_memory_usage,
|
314 |
+
hold_eye_opened,
|
315 |
+
audio2secc_dir,
|
316 |
+
# head_model_dir,
|
317 |
+
torso_model_dir,
|
318 |
+
min_face_area_percent,
|
319 |
+
],
|
320 |
+
outputs=[
|
321 |
+
gen_video,
|
322 |
+
info_box,
|
323 |
+
],
|
324 |
+
)
|
325 |
+
|
326 |
+
fn_train = train_obj.train_once_args
|
327 |
+
|
328 |
+
train_submit.click(
|
329 |
+
fn=fn_train,
|
330 |
+
inputs=[
|
331 |
+
src_video_name,
|
332 |
+
|
333 |
+
],
|
334 |
+
outputs=[
|
335 |
+
# gen_video,
|
336 |
+
info_box,
|
337 |
+
torso_model_dir,
|
338 |
+
],
|
339 |
+
)
|
340 |
+
|
341 |
+
print(sep_line)
|
342 |
+
print("Gradio page is constructed.")
|
343 |
+
print(sep_line)
|
344 |
+
|
345 |
+
return real3dportrait_interface
|
346 |
+
|
347 |
+
if __name__ == "__main__":
|
348 |
+
parser = argparse.ArgumentParser()
|
349 |
+
parser.add_argument("--a2m_ckpt", default='checkpoints/240112_icl_audio2secc_vox2_cmlr') # checkpoints/0727_audio2secc/audio2secc_withlm2d100_randomframe
|
350 |
+
parser.add_argument("--head_ckpt", default='') # checkpoints/0729_th1kh/secc_img2plane checkpoints/0720_img2planes/secc_img2plane_two_stage
|
351 |
+
parser.add_argument("--torso_ckpt", default='checkpoints_mimictalk/German_20s/model_ckpt_steps_10000.ckpt')
|
352 |
+
parser.add_argument("--port", type=int, default=None)
|
353 |
+
parser.add_argument("--server", type=str, default='127.0.0.1')
|
354 |
+
parser.add_argument("--share", action='store_true', dest='share', help='share srever to Internet')
|
355 |
+
|
356 |
+
args = parser.parse_args()
|
357 |
+
demo = mimictalk_demo(
|
358 |
+
audio2secc_dir=args.a2m_ckpt,
|
359 |
+
head_model_dir=args.head_ckpt,
|
360 |
+
torso_model_dir=args.torso_ckpt,
|
361 |
+
device='cuda:0',
|
362 |
+
warpfn=None,
|
363 |
+
)
|
364 |
+
demo.queue()
|
365 |
+
demo.launch(share=args.share, server_name=args.server, server_port=args.port)
|
inference/app_real3dportrait.py
ADDED
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
sys.path.append('./')
|
3 |
+
import argparse
|
4 |
+
import gradio as gr
|
5 |
+
from inference.real3d_infer import GeneFace2Infer
|
6 |
+
from utils.commons.hparams import hparams
|
7 |
+
|
8 |
+
class Inferer(GeneFace2Infer):
|
9 |
+
def infer_once_args(self, *args, **kargs):
|
10 |
+
assert len(kargs) == 0
|
11 |
+
keys = [
|
12 |
+
'src_image_name',
|
13 |
+
'drv_audio_name',
|
14 |
+
'drv_pose_name',
|
15 |
+
'bg_image_name',
|
16 |
+
'blink_mode',
|
17 |
+
'temperature',
|
18 |
+
'mouth_amp',
|
19 |
+
'out_mode',
|
20 |
+
'map_to_init_pose',
|
21 |
+
'low_memory_usage',
|
22 |
+
'hold_eye_opened',
|
23 |
+
'a2m_ckpt',
|
24 |
+
'head_ckpt',
|
25 |
+
'torso_ckpt',
|
26 |
+
'min_face_area_percent',
|
27 |
+
]
|
28 |
+
inp = {}
|
29 |
+
out_name = None
|
30 |
+
info = ""
|
31 |
+
|
32 |
+
try: # try to catch errors and jump to return
|
33 |
+
for key_index in range(len(keys)):
|
34 |
+
key = keys[key_index]
|
35 |
+
inp[key] = args[key_index]
|
36 |
+
if '_name' in key:
|
37 |
+
inp[key] = inp[key] if inp[key] is not None else ''
|
38 |
+
|
39 |
+
if inp['src_image_name'] == '':
|
40 |
+
info = "Input Error: Source image is REQUIRED!"
|
41 |
+
raise ValueError
|
42 |
+
if inp['drv_audio_name'] == '' and inp['drv_pose_name'] == '':
|
43 |
+
info = "Input Error: At least one of driving audio or video is REQUIRED!"
|
44 |
+
raise ValueError
|
45 |
+
|
46 |
+
|
47 |
+
if inp['drv_audio_name'] == '' and inp['drv_pose_name'] != '':
|
48 |
+
inp['drv_audio_name'] = inp['drv_pose_name']
|
49 |
+
print("No audio input, we use driving pose video for video driving")
|
50 |
+
|
51 |
+
if inp['drv_pose_name'] == '':
|
52 |
+
inp['drv_pose_name'] = 'static'
|
53 |
+
|
54 |
+
reload_flag = False
|
55 |
+
if inp['a2m_ckpt'] != self.audio2secc_dir:
|
56 |
+
print("Changes of a2m_ckpt detected, reloading model")
|
57 |
+
reload_flag = True
|
58 |
+
if inp['head_ckpt'] != self.head_model_dir:
|
59 |
+
print("Changes of head_ckpt detected, reloading model")
|
60 |
+
reload_flag = True
|
61 |
+
if inp['torso_ckpt'] != self.torso_model_dir:
|
62 |
+
print("Changes of torso_ckpt detected, reloading model")
|
63 |
+
reload_flag = True
|
64 |
+
|
65 |
+
inp['out_name'] = ''
|
66 |
+
inp['seed'] = 42
|
67 |
+
|
68 |
+
print(f"infer inputs : {inp}")
|
69 |
+
|
70 |
+
try:
|
71 |
+
if reload_flag:
|
72 |
+
self.__init__(inp['a2m_ckpt'], inp['head_ckpt'], inp['torso_ckpt'], inp=inp, device=self.device)
|
73 |
+
except Exception as e:
|
74 |
+
content = f"{e}"
|
75 |
+
info = f"Reload ERROR: {content}"
|
76 |
+
raise ValueError
|
77 |
+
try:
|
78 |
+
out_name = self.infer_once(inp)
|
79 |
+
except Exception as e:
|
80 |
+
content = f"{e}"
|
81 |
+
info = f"Inference ERROR: {content}"
|
82 |
+
raise ValueError
|
83 |
+
except Exception as e:
|
84 |
+
if info == "": # unexpected errors
|
85 |
+
content = f"{e}"
|
86 |
+
info = f"WebUI ERROR: {content}"
|
87 |
+
|
88 |
+
# output part
|
89 |
+
if len(info) > 0 : # there is errors
|
90 |
+
print(info)
|
91 |
+
info_gr = gr.update(visible=True, value=info)
|
92 |
+
else: # no errors
|
93 |
+
info_gr = gr.update(visible=False, value=info)
|
94 |
+
if out_name is not None and len(out_name) > 0 and os.path.exists(out_name): # good output
|
95 |
+
print(f"Succefully generated in {out_name}")
|
96 |
+
video_gr = gr.update(visible=True, value=out_name)
|
97 |
+
else:
|
98 |
+
print(f"Failed to generate")
|
99 |
+
video_gr = gr.update(visible=True, value=out_name)
|
100 |
+
|
101 |
+
return video_gr, info_gr
|
102 |
+
|
103 |
+
def toggle_audio_file(choice):
|
104 |
+
if choice == False:
|
105 |
+
return gr.update(visible=True), gr.update(visible=False)
|
106 |
+
else:
|
107 |
+
return gr.update(visible=False), gr.update(visible=True)
|
108 |
+
|
109 |
+
def ref_video_fn(path_of_ref_video):
|
110 |
+
if path_of_ref_video is not None:
|
111 |
+
return gr.update(value=True)
|
112 |
+
else:
|
113 |
+
return gr.update(value=False)
|
114 |
+
|
115 |
+
def real3dportrait_demo(
|
116 |
+
audio2secc_dir,
|
117 |
+
head_model_dir,
|
118 |
+
torso_model_dir,
|
119 |
+
device = 'cuda',
|
120 |
+
warpfn = None,
|
121 |
+
):
|
122 |
+
|
123 |
+
sep_line = "-" * 40
|
124 |
+
|
125 |
+
infer_obj = Inferer(
|
126 |
+
audio2secc_dir=audio2secc_dir,
|
127 |
+
head_model_dir=head_model_dir,
|
128 |
+
torso_model_dir=torso_model_dir,
|
129 |
+
device=device,
|
130 |
+
)
|
131 |
+
|
132 |
+
print(sep_line)
|
133 |
+
print("Model loading is finished.")
|
134 |
+
print(sep_line)
|
135 |
+
with gr.Blocks(analytics_enabled=False) as real3dportrait_interface:
|
136 |
+
gr.Markdown("\
|
137 |
+
<div align='center'> <h2> Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis (ICLR 2024 Spotlight) </span> </h2> \
|
138 |
+
<a style='font-size:18px;color: #a0a0a0' href='https://arxiv.org/pdf/2401.08503.pdf'>Arxiv</a> \
|
139 |
+
<a style='font-size:18px;color: #a0a0a0' href='https://real3dportrait.github.io/'>Homepage</a> \
|
140 |
+
<a style='font-size:18px;color: #a0a0a0' href='https://github.com/yerfor/Real3DPortrait/'> Github </div>")
|
141 |
+
|
142 |
+
sources = None
|
143 |
+
with gr.Row():
|
144 |
+
with gr.Column(variant='panel'):
|
145 |
+
with gr.Tabs(elem_id="source_image"):
|
146 |
+
with gr.TabItem('Upload image'):
|
147 |
+
with gr.Row():
|
148 |
+
src_image_name = gr.Image(label="Source image (required)", sources=sources, type="filepath", value="data/raw/examples/Macron.png")
|
149 |
+
with gr.Tabs(elem_id="driven_audio"):
|
150 |
+
with gr.TabItem('Upload audio'):
|
151 |
+
with gr.Column(variant='panel'):
|
152 |
+
drv_audio_name = gr.Audio(label="Input audio (required for audio-driven)", sources=sources, type="filepath", value="data/raw/examples/Obama_5s.wav")
|
153 |
+
with gr.Tabs(elem_id="driven_pose"):
|
154 |
+
with gr.TabItem('Upload video'):
|
155 |
+
with gr.Column(variant='panel'):
|
156 |
+
drv_pose_name = gr.Video(label="Driven Pose (required for video-driven, optional for audio-driven)", sources=sources, value="data/raw/examples/May_5s.mp4")
|
157 |
+
with gr.Tabs(elem_id="bg_image"):
|
158 |
+
with gr.TabItem('Upload image'):
|
159 |
+
with gr.Row():
|
160 |
+
bg_image_name = gr.Image(label="Background image (optional)", sources=sources, type="filepath", value="data/raw/examples/bg.png")
|
161 |
+
|
162 |
+
|
163 |
+
with gr.Column(variant='panel'):
|
164 |
+
with gr.Tabs(elem_id="checkbox"):
|
165 |
+
with gr.TabItem('General Settings'):
|
166 |
+
with gr.Column(variant='panel'):
|
167 |
+
|
168 |
+
blink_mode = gr.Radio(['none', 'period'], value='period', label='blink mode', info="whether to blink periodly") #
|
169 |
+
min_face_area_percent = gr.Slider(minimum=0.15, maximum=0.5, step=0.01, label="min_face_area_percent", value=0.2, info='The minimum face area percent in the output frame, to prevent bad cases caused by a too small face.',)
|
170 |
+
temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.025, label="temperature", value=0.2, info='audio to secc temperature',)
|
171 |
+
mouth_amp = gr.Slider(minimum=0.0, maximum=1.0, step=0.025, label="mouth amplitude", value=0.45, info='higher -> mouth will open wider, default to be 0.4',)
|
172 |
+
out_mode = gr.Radio(['final', 'concat_debug'], value='concat_debug', label='output layout', info="final: only final output ; concat_debug: final output concated with internel features")
|
173 |
+
low_memory_usage = gr.Checkbox(label="Low Memory Usage Mode: save memory at the expense of lower inference speed. Useful when running a low audio (minutes-long).", value=False)
|
174 |
+
map_to_init_pose = gr.Checkbox(label="Whether to map pose of first frame to initial pose", value=True)
|
175 |
+
hold_eye_opened = gr.Checkbox(label="Whether to maintain eyes always open")
|
176 |
+
|
177 |
+
submit = gr.Button('Generate', elem_id="generate", variant='primary')
|
178 |
+
|
179 |
+
with gr.Tabs(elem_id="genearted_video"):
|
180 |
+
info_box = gr.Textbox(label="Error", interactive=False, visible=False)
|
181 |
+
gen_video = gr.Video(label="Generated video", format="mp4", visible=True)
|
182 |
+
with gr.Column(variant='panel'):
|
183 |
+
with gr.Tabs(elem_id="checkbox"):
|
184 |
+
with gr.TabItem('Checkpoints'):
|
185 |
+
with gr.Column(variant='panel'):
|
186 |
+
ckpt_info_box = gr.Textbox(value="Please select \"ckpt\" under the checkpoint folder ", interactive=False, visible=True, show_label=False)
|
187 |
+
audio2secc_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=audio2secc_dir, file_count='single', label='audio2secc model ckpt path or directory')
|
188 |
+
head_model_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=head_model_dir, file_count='single', label='head model ckpt path or directory (will be ignored if torso model is set)')
|
189 |
+
torso_model_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=torso_model_dir, file_count='single', label='torso model ckpt path or directory')
|
190 |
+
# audio2secc_dir = gr.Textbox(audio2secc_dir, max_lines=1, label='audio2secc model ckpt path or directory (will be ignored if torso model is set)')
|
191 |
+
# head_model_dir = gr.Textbox(head_model_dir, max_lines=1, label='head model ckpt path or directory (will be ignored if torso model is set)')
|
192 |
+
# torso_model_dir = gr.Textbox(torso_model_dir, max_lines=1, label='torso model ckpt path or directory')
|
193 |
+
|
194 |
+
|
195 |
+
fn = infer_obj.infer_once_args
|
196 |
+
if warpfn:
|
197 |
+
fn = warpfn(fn)
|
198 |
+
submit.click(
|
199 |
+
fn=fn,
|
200 |
+
inputs=[
|
201 |
+
src_image_name,
|
202 |
+
drv_audio_name,
|
203 |
+
drv_pose_name,
|
204 |
+
bg_image_name,
|
205 |
+
blink_mode,
|
206 |
+
temperature,
|
207 |
+
mouth_amp,
|
208 |
+
out_mode,
|
209 |
+
map_to_init_pose,
|
210 |
+
low_memory_usage,
|
211 |
+
hold_eye_opened,
|
212 |
+
audio2secc_dir,
|
213 |
+
head_model_dir,
|
214 |
+
torso_model_dir,
|
215 |
+
min_face_area_percent,
|
216 |
+
],
|
217 |
+
outputs=[
|
218 |
+
gen_video,
|
219 |
+
info_box,
|
220 |
+
],
|
221 |
+
)
|
222 |
+
|
223 |
+
print(sep_line)
|
224 |
+
print("Gradio page is constructed.")
|
225 |
+
print(sep_line)
|
226 |
+
|
227 |
+
return real3dportrait_interface
|
228 |
+
|
229 |
+
if __name__ == "__main__":
|
230 |
+
parser = argparse.ArgumentParser()
|
231 |
+
parser.add_argument("--a2m_ckpt", type=str, default='checkpoints/240210_real3dportrait_orig/audio2secc_vae/model_ckpt_steps_400000.ckpt')
|
232 |
+
parser.add_argument("--head_ckpt", type=str, default='')
|
233 |
+
parser.add_argument("--torso_ckpt", type=str, default='checkpoints/240210_real3dportrait_orig/secc2plane_torso_orig/model_ckpt_steps_100000.ckpt')
|
234 |
+
parser.add_argument("--port", type=int, default=None)
|
235 |
+
parser.add_argument("--server", type=str, default='127.0.0.1')
|
236 |
+
parser.add_argument("--share", action='store_true', dest='share', help='share srever to Internet')
|
237 |
+
|
238 |
+
args = parser.parse_args()
|
239 |
+
demo = real3dportrait_demo(
|
240 |
+
audio2secc_dir=args.a2m_ckpt,
|
241 |
+
head_model_dir=args.head_ckpt,
|
242 |
+
torso_model_dir=args.torso_ckpt,
|
243 |
+
device='cuda:0',
|
244 |
+
warpfn=None,
|
245 |
+
)
|
246 |
+
demo.queue()
|
247 |
+
demo.launch(share=args.share, server_name=args.server, server_port=args.port)
|
inference/edit_secc.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
from utils.commons.image_utils import dilate, erode
|
4 |
+
from sklearn.neighbors import NearestNeighbors
|
5 |
+
import copy
|
6 |
+
import numpy as np
|
7 |
+
from utils.commons.meters import Timer
|
8 |
+
|
9 |
+
def hold_eye_opened_for_secc(img):
|
10 |
+
img = img.permute(1,2,0).cpu().numpy()
|
11 |
+
img = ((img +1)/2*255).astype(np.uint)
|
12 |
+
face_mask = (img[...,0] != 0) & (img[...,1] != 0) & (img[...,2] != 0)
|
13 |
+
face_xys = np.stack(np.nonzero(face_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
|
14 |
+
h,w = face_mask.shape
|
15 |
+
# get face and eye mask
|
16 |
+
left_eye_prior_reigon = np.zeros([h,w], dtype=bool)
|
17 |
+
right_eye_prior_reigon = np.zeros([h,w], dtype=bool)
|
18 |
+
left_eye_prior_reigon[h//4:h//2, w//4:w//2] = True
|
19 |
+
right_eye_prior_reigon[h//4:h//2, w//2:w//4*3] = True
|
20 |
+
eye_prior_reigon = left_eye_prior_reigon | right_eye_prior_reigon
|
21 |
+
coarse_eye_mask = (~ face_mask) & eye_prior_reigon
|
22 |
+
coarse_eye_xys = np.stack(np.nonzero(coarse_eye_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
|
23 |
+
|
24 |
+
opened_eye_mask = cv2.imread('inference/os_avatar/opened_eye_mask.png')
|
25 |
+
opened_eye_mask = torch.nn.functional.interpolate(torch.tensor(opened_eye_mask).permute(2,0,1).unsqueeze(0), size=(img.shape[0], img.shape[1]), mode='nearest')[0].permute(1,2,0).sum(-1).bool().cpu() # [512,512,3]
|
26 |
+
coarse_opened_eye_xys = np.stack(np.nonzero(opened_eye_mask)) # [N_nonbg,2] coordinate of non-face pixels
|
27 |
+
|
28 |
+
nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(coarse_eye_xys)
|
29 |
+
dists, _ = nbrs.kneighbors(coarse_opened_eye_xys) # [512*512, 1] distance to nearest non-bg pixel
|
30 |
+
# print(dists.max())
|
31 |
+
non_opened_eye_pixs = dists > max(dists.max()*0.75, 4) # 大于这个距离的opened eye部分会被合上
|
32 |
+
non_opened_eye_pixs = non_opened_eye_pixs.reshape([-1])
|
33 |
+
opened_eye_xys_to_erode = coarse_opened_eye_xys[non_opened_eye_pixs]
|
34 |
+
opened_eye_mask[opened_eye_xys_to_erode[...,0], opened_eye_xys_to_erode[...,1]] = False # shrink 将mask在face-eye边界收缩3pixel,为了平滑
|
35 |
+
|
36 |
+
img[opened_eye_mask] = 0
|
37 |
+
return torch.tensor(img.astype(np.float32) / 127.5 - 1).permute(2,0,1)
|
38 |
+
|
39 |
+
|
40 |
+
# def hold_eye_opened_for_secc(img):
|
41 |
+
# img = copy.copy(img)
|
42 |
+
# eye_mask = cv2.imread('inference/os_avatar/opened_eye_mask.png')
|
43 |
+
# eye_mask = torch.nn.functional.interpolate(torch.tensor(eye_mask).permute(2,0,1).unsqueeze(0), size=(img.shape[-2], img.shape[-1]), mode='nearest')[0].bool().to(img.device) # [3,512,512]
|
44 |
+
# img[eye_mask] = -1
|
45 |
+
# return img
|
46 |
+
|
47 |
+
def blink_eye_for_secc(img, close_eye_percent=0.5):
|
48 |
+
"""
|
49 |
+
secc_img: [3,h,w], tensor, -1~1
|
50 |
+
"""
|
51 |
+
img = img.permute(1,2,0).cpu().numpy()
|
52 |
+
img = ((img +1)/2*255).astype(np.uint)
|
53 |
+
assert close_eye_percent <= 1.0 and close_eye_percent >= 0.
|
54 |
+
if close_eye_percent == 0: return torch.tensor(img.astype(np.float32) / 127.5 - 1).permute(2,0,1)
|
55 |
+
img = copy.deepcopy(img)
|
56 |
+
face_mask = (img[...,0] != 0) & (img[...,1] != 0) & (img[...,2] != 0)
|
57 |
+
h,w = face_mask.shape
|
58 |
+
|
59 |
+
# get face and eye mask
|
60 |
+
left_eye_prior_reigon = np.zeros([h,w], dtype=bool)
|
61 |
+
right_eye_prior_reigon = np.zeros([h,w], dtype=bool)
|
62 |
+
left_eye_prior_reigon[h//4:h//2, w//4:w//2] = True
|
63 |
+
right_eye_prior_reigon[h//4:h//2, w//2:w//4*3] = True
|
64 |
+
eye_prior_reigon = left_eye_prior_reigon | right_eye_prior_reigon
|
65 |
+
coarse_eye_mask = (~ face_mask) & eye_prior_reigon
|
66 |
+
coarse_left_eye_mask = (~ face_mask) & left_eye_prior_reigon
|
67 |
+
coarse_right_eye_mask = (~ face_mask) & right_eye_prior_reigon
|
68 |
+
coarse_eye_xys = np.stack(np.nonzero(coarse_eye_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
|
69 |
+
min_h = coarse_eye_xys[:, 0].min()
|
70 |
+
max_h = coarse_eye_xys[:, 0].max()
|
71 |
+
coarse_left_eye_xys = np.stack(np.nonzero(coarse_left_eye_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
|
72 |
+
left_min_w = coarse_left_eye_xys[:, 1].min()
|
73 |
+
left_max_w = coarse_left_eye_xys[:, 1].max()
|
74 |
+
coarse_right_eye_xys = np.stack(np.nonzero(coarse_right_eye_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
|
75 |
+
right_min_w = coarse_right_eye_xys[:, 1].min()
|
76 |
+
right_max_w = coarse_right_eye_xys[:, 1].max()
|
77 |
+
|
78 |
+
# 尽力较少需要考虑的face_xyz,以降低KNN的损耗
|
79 |
+
left_eye_prior_reigon = np.zeros([h,w], dtype=bool)
|
80 |
+
more_room = 4 # 过小会导致一些问题
|
81 |
+
left_eye_prior_reigon[min_h-more_room:max_h+more_room, left_min_w-more_room:left_max_w+more_room] = True
|
82 |
+
right_eye_prior_reigon = np.zeros([h,w], dtype=bool)
|
83 |
+
right_eye_prior_reigon[min_h-more_room:max_h+more_room, right_min_w-more_room:right_max_w+more_room] = True
|
84 |
+
eye_prior_reigon = left_eye_prior_reigon | right_eye_prior_reigon
|
85 |
+
|
86 |
+
around_eye_face_mask = face_mask & eye_prior_reigon
|
87 |
+
face_mask = around_eye_face_mask
|
88 |
+
face_xys = np.stack(np.nonzero(around_eye_face_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
|
89 |
+
|
90 |
+
nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(coarse_eye_xys)
|
91 |
+
dists, _ = nbrs.kneighbors(face_xys) # [512*512, 1] distance to nearest non-bg pixel
|
92 |
+
face_pixs = dists > 5 # 只有距离最近的eye pixel大于5的才被认为是face,过小会导致一些问题
|
93 |
+
face_pixs = face_pixs.reshape([-1])
|
94 |
+
face_xys_to_erode = face_xys[~face_pixs]
|
95 |
+
face_mask[face_xys_to_erode[...,0], face_xys_to_erode[...,1]] = False # shrink 将mask在face-eye边界收缩3pixel,为了平滑
|
96 |
+
eye_mask = (~ face_mask) & eye_prior_reigon
|
97 |
+
|
98 |
+
h_grid = np.mgrid[0:h, 0:w][0]
|
99 |
+
eye_num_pixel_along_w_axis = eye_mask.sum(axis=0)
|
100 |
+
eye_mask_along_w_axis = eye_num_pixel_along_w_axis != 0
|
101 |
+
|
102 |
+
tmp_h_grid = h_grid.copy()
|
103 |
+
tmp_h_grid[~eye_mask] = 0
|
104 |
+
eye_mean_h_coord_along_w_axis = tmp_h_grid.sum(axis=0) / np.clip(eye_num_pixel_along_w_axis, a_min=1, a_max=h)
|
105 |
+
tmp_h_grid = h_grid.copy()
|
106 |
+
tmp_h_grid[~eye_mask] = 99999
|
107 |
+
eye_min_h_coord_along_w_axis = tmp_h_grid.min(axis=0)
|
108 |
+
tmp_h_grid = h_grid.copy()
|
109 |
+
tmp_h_grid[~eye_mask] = -99999
|
110 |
+
eye_max_h_coord_along_w_axis = tmp_h_grid.max(axis=0)
|
111 |
+
|
112 |
+
eye_low_h_coord_along_w_axis = close_eye_percent * eye_mean_h_coord_along_w_axis + (1-close_eye_percent) * eye_min_h_coord_along_w_axis # upper eye
|
113 |
+
eye_high_h_coord_along_w_axis = close_eye_percent * eye_mean_h_coord_along_w_axis + (1-close_eye_percent) * eye_max_h_coord_along_w_axis # lower eye
|
114 |
+
|
115 |
+
tmp_h_grid = h_grid.copy()
|
116 |
+
tmp_h_grid[~eye_mask] = 99999
|
117 |
+
upper_eye_blink_mask = tmp_h_grid <= eye_low_h_coord_along_w_axis
|
118 |
+
tmp_h_grid = h_grid.copy()
|
119 |
+
tmp_h_grid[~eye_mask] = -99999
|
120 |
+
lower_eye_blink_mask = tmp_h_grid >= eye_high_h_coord_along_w_axis
|
121 |
+
eye_blink_mask = upper_eye_blink_mask | lower_eye_blink_mask
|
122 |
+
|
123 |
+
face_xys = np.stack(np.nonzero(around_eye_face_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
|
124 |
+
eye_blink_xys = np.stack(np.nonzero(eye_blink_mask)).transpose(1, 0) # [N_nonbg,hw] coordinate of non-face pixels
|
125 |
+
nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(face_xys)
|
126 |
+
distances, indices = nbrs.kneighbors(eye_blink_xys)
|
127 |
+
bg_fg_xys = face_xys[indices[:, 0]]
|
128 |
+
img[eye_blink_xys[:, 0], eye_blink_xys[:, 1], :] = img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :]
|
129 |
+
return torch.tensor(img.astype(np.float32) / 127.5 - 1).permute(2,0,1)
|
130 |
+
|
131 |
+
|
132 |
+
if __name__ == '__main__':
|
133 |
+
import imageio
|
134 |
+
import tqdm
|
135 |
+
img = cv2.imread("assets/cano_secc.png")
|
136 |
+
img = img / 127.5 - 1
|
137 |
+
img = torch.FloatTensor(img).permute(2, 0, 1)
|
138 |
+
fps = 25
|
139 |
+
writer = imageio.get_writer('demo_blink.mp4', fps=fps)
|
140 |
+
|
141 |
+
for i in tqdm.trange(33):
|
142 |
+
blink_percent = 0.03 * i
|
143 |
+
with Timer("Blink", True):
|
144 |
+
out_img = blink_eye_for_secc(img, blink_percent)
|
145 |
+
out_img = ((out_img.permute(1,2,0)+1)*127.5).int().numpy()
|
146 |
+
writer.append_data(out_img)
|
147 |
+
writer.close()
|
inference/infer_utils.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import librosa
|
5 |
+
import numpy as np
|
6 |
+
import importlib
|
7 |
+
import tqdm
|
8 |
+
import copy
|
9 |
+
import cv2
|
10 |
+
from scipy.spatial.transform import Rotation
|
11 |
+
|
12 |
+
|
13 |
+
def load_img_to_512_hwc_array(img_name):
|
14 |
+
img = cv2.imread(img_name)
|
15 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
16 |
+
img = cv2.resize(img, (512, 512))
|
17 |
+
return img
|
18 |
+
|
19 |
+
def load_img_to_normalized_512_bchw_tensor(img_name):
|
20 |
+
img = load_img_to_512_hwc_array(img_name)
|
21 |
+
img = ((torch.tensor(img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2) # [b,c,h,w]
|
22 |
+
return img
|
23 |
+
|
24 |
+
def mirror_index(index, len_seq):
|
25 |
+
"""
|
26 |
+
get mirror index when indexing a sequence and the index is larger than len_pose
|
27 |
+
args:
|
28 |
+
index: int
|
29 |
+
len_pose: int
|
30 |
+
return:
|
31 |
+
mirror_index: int
|
32 |
+
"""
|
33 |
+
turn = index // len_seq
|
34 |
+
res = index % len_seq
|
35 |
+
if turn % 2 == 0:
|
36 |
+
return res # forward indexing
|
37 |
+
else:
|
38 |
+
return len_seq - res - 1 # reverse indexing
|
39 |
+
|
40 |
+
def smooth_camera_sequence(camera, kernel_size=7):
|
41 |
+
"""
|
42 |
+
smooth the camera trajectory (i.e., rotation & translation)...
|
43 |
+
args:
|
44 |
+
camera: [N, 25] or [N, 16]. np.ndarray
|
45 |
+
kernel_size: int
|
46 |
+
return:
|
47 |
+
smoothed_camera: [N, 25] or [N, 16]. np.ndarray
|
48 |
+
"""
|
49 |
+
# poses: [N, 25], numpy array
|
50 |
+
N = camera.shape[0]
|
51 |
+
K = kernel_size // 2
|
52 |
+
poses = camera[:, :16].reshape([-1, 4, 4]).copy()
|
53 |
+
trans = poses[:, :3, 3].copy() # [N, 3]
|
54 |
+
rots = poses[:, :3, :3].copy() # [N, 3, 3]
|
55 |
+
|
56 |
+
for i in range(N):
|
57 |
+
start = max(0, i - K)
|
58 |
+
end = min(N, i + K + 1)
|
59 |
+
poses[i, :3, 3] = trans[start:end].mean(0)
|
60 |
+
try:
|
61 |
+
poses[i, :3, :3] = Rotation.from_matrix(rots[start:end]).mean().as_matrix()
|
62 |
+
except:
|
63 |
+
if i == 0:
|
64 |
+
poses[i, :3, :3] = rots[i]
|
65 |
+
else:
|
66 |
+
poses[i, :3, :3] = poses[i-1, :3, :3]
|
67 |
+
poses = poses.reshape([-1, 16])
|
68 |
+
camera[:, :16] = poses
|
69 |
+
return camera
|
70 |
+
|
71 |
+
def smooth_features_xd(in_tensor, kernel_size=7):
|
72 |
+
"""
|
73 |
+
smooth the feature maps
|
74 |
+
args:
|
75 |
+
in_tensor: [T, c,h,w] or [T, c1,c2,h,w]
|
76 |
+
kernel_size: int
|
77 |
+
return:
|
78 |
+
out_tensor: [T, c,h,w] or [T, c1,c2,h,w]
|
79 |
+
"""
|
80 |
+
t = in_tensor.shape[0]
|
81 |
+
ndim = in_tensor.ndim
|
82 |
+
pad = (kernel_size- 1)//2
|
83 |
+
in_tensor = torch.cat([torch.flip(in_tensor[0:pad], dims=[0]), in_tensor, torch.flip(in_tensor[t-pad:t], dims=[0])], dim=0)
|
84 |
+
if ndim == 2: # tc
|
85 |
+
_,c = in_tensor.shape
|
86 |
+
in_tensor = in_tensor.permute(1,0).reshape([-1,1,t+2*pad]) # [c, 1, t]
|
87 |
+
elif ndim == 4: # tchw
|
88 |
+
_,c,h,w = in_tensor.shape
|
89 |
+
in_tensor = in_tensor.permute(1,2,3,0).reshape([-1,1,t+2*pad]) # [c, 1, t]
|
90 |
+
elif ndim == 5: # tcchw, like deformation
|
91 |
+
_,c1,c2, h,w = in_tensor.shape
|
92 |
+
in_tensor = in_tensor.permute(1,2,3,4,0).reshape([-1,1,t+2*pad]) # [c, 1, t]
|
93 |
+
else: raise NotImplementedError()
|
94 |
+
avg_kernel = 1 / kernel_size * torch.Tensor([1.]*kernel_size).reshape([1,1,kernel_size]).float().to(in_tensor.device) # [1, 1, kw]
|
95 |
+
out_tensor = F.conv1d(in_tensor, avg_kernel)
|
96 |
+
if ndim == 2: # tc
|
97 |
+
return out_tensor.reshape([c,t]).permute(1,0)
|
98 |
+
elif ndim == 4: # tchw
|
99 |
+
return out_tensor.reshape([c,h,w,t]).permute(3,0,1,2)
|
100 |
+
elif ndim == 5: # tcchw, like deformation
|
101 |
+
return out_tensor.reshape([c1,c2,h,w,t]).permute(4,0,1,2,3)
|
102 |
+
|
103 |
+
|
104 |
+
def extract_audio_motion_from_ref_video(video_name):
|
105 |
+
def save_wav16k(audio_name):
|
106 |
+
supported_types = ('.wav', '.mp3', '.mp4', '.avi')
|
107 |
+
assert audio_name.endswith(supported_types), f"Now we only support {','.join(supported_types)} as audio source!"
|
108 |
+
wav16k_name = audio_name[:-4] + '_16k.wav'
|
109 |
+
extract_wav_cmd = f"ffmpeg -i {audio_name} -f wav -ar 16000 -v quiet -y {wav16k_name} -y"
|
110 |
+
os.system(extract_wav_cmd)
|
111 |
+
print(f"Extracted wav file (16khz) from {audio_name} to {wav16k_name}.")
|
112 |
+
return wav16k_name
|
113 |
+
|
114 |
+
def get_f0( wav16k_name):
|
115 |
+
from data_gen.utils.process_audio.extract_mel_f0 import extract_mel_from_fname, extract_f0_from_wav_and_mel
|
116 |
+
wav, mel = extract_mel_from_fname(wav16k_name)
|
117 |
+
f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)
|
118 |
+
f0 = f0.reshape([-1,1])
|
119 |
+
f0 = torch.tensor(f0)
|
120 |
+
return f0
|
121 |
+
|
122 |
+
def get_hubert(wav16k_name):
|
123 |
+
from data_gen.utils.process_audio.extract_hubert import get_hubert_from_16k_wav
|
124 |
+
hubert = get_hubert_from_16k_wav(wav16k_name).detach().numpy()
|
125 |
+
len_mel = hubert.shape[0]
|
126 |
+
x_multiply = 8
|
127 |
+
if len_mel % x_multiply == 0:
|
128 |
+
num_to_pad = 0
|
129 |
+
else:
|
130 |
+
num_to_pad = x_multiply - len_mel % x_multiply
|
131 |
+
hubert = np.pad(hubert, pad_width=((0,num_to_pad), (0,0)))
|
132 |
+
hubert = torch.tensor(hubert)
|
133 |
+
return hubert
|
134 |
+
|
135 |
+
def get_exp(video_name):
|
136 |
+
from data_gen.utils.process_video.fit_3dmm_landmark import fit_3dmm_for_a_video
|
137 |
+
drv_motion_coeff_dict = fit_3dmm_for_a_video(video_name, save=False)
|
138 |
+
exp = torch.tensor(drv_motion_coeff_dict['exp'])
|
139 |
+
return exp
|
140 |
+
|
141 |
+
wav16k_name = save_wav16k(video_name)
|
142 |
+
f0 = get_f0(wav16k_name)
|
143 |
+
hubert = get_hubert(wav16k_name)
|
144 |
+
os.system(f"rm {wav16k_name}")
|
145 |
+
exp = get_exp(video_name)
|
146 |
+
target_length = min(len(exp), len(hubert)//2, len(f0)//2)
|
147 |
+
exp = exp[:target_length]
|
148 |
+
f0 = f0[:target_length*2]
|
149 |
+
hubert = hubert[:target_length*2]
|
150 |
+
return exp.unsqueeze(0), hubert.unsqueeze(0), f0.unsqueeze(0)
|
151 |
+
|
152 |
+
|
153 |
+
if __name__ == '__main__':
|
154 |
+
extract_audio_motion_from_ref_video('data/raw/videos/crop_0213.mp4')
|
inference/mimictalk_infer.py
ADDED
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
用于推理 inference/train_mimictalk_on_a_video.py 得到的person-specific模型
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
# import librosa
|
9 |
+
import random
|
10 |
+
import time
|
11 |
+
import numpy as np
|
12 |
+
import importlib
|
13 |
+
import tqdm
|
14 |
+
import copy
|
15 |
+
import cv2
|
16 |
+
|
17 |
+
# common utils
|
18 |
+
from utils.commons.hparams import hparams, set_hparams
|
19 |
+
from utils.commons.tensor_utils import move_to_cuda, convert_to_tensor
|
20 |
+
from utils.commons.ckpt_utils import load_ckpt, get_last_checkpoint
|
21 |
+
# 3DMM-related utils
|
22 |
+
from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
|
23 |
+
from data_util.face3d_helper import Face3DHelper
|
24 |
+
from data_gen.utils.process_image.fit_3dmm_landmark import fit_3dmm_for_a_image
|
25 |
+
from data_gen.utils.process_video.fit_3dmm_landmark import fit_3dmm_for_a_video
|
26 |
+
from deep_3drecon.secc_renderer import SECC_Renderer
|
27 |
+
from data_gen.eg3d.convert_to_eg3d_convention import get_eg3d_convention_camera_pose_intrinsic
|
28 |
+
# Face Parsing
|
29 |
+
from data_gen.utils.mp_feature_extractors.mp_segmenter import MediapipeSegmenter
|
30 |
+
from data_gen.utils.process_video.extract_segment_imgs import inpaint_torso_job, extract_background
|
31 |
+
# other inference utils
|
32 |
+
from inference.infer_utils import mirror_index, load_img_to_512_hwc_array, load_img_to_normalized_512_bchw_tensor
|
33 |
+
from inference.infer_utils import smooth_camera_sequence, smooth_features_xd
|
34 |
+
from inference.edit_secc import blink_eye_for_secc, hold_eye_opened_for_secc
|
35 |
+
from inference.real3d_infer import GeneFace2Infer
|
36 |
+
|
37 |
+
|
38 |
+
class AdaptGeneFace2Infer(GeneFace2Infer):
|
39 |
+
def __init__(self, audio2secc_dir, head_model_dir, torso_model_dir, device=None, **kwargs):
|
40 |
+
if device is None:
|
41 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
42 |
+
self.device = device
|
43 |
+
self.audio2secc_dir = audio2secc_dir
|
44 |
+
self.head_model_dir = head_model_dir
|
45 |
+
self.torso_model_dir = torso_model_dir
|
46 |
+
self.audio2secc_model = self.load_audio2secc(audio2secc_dir)
|
47 |
+
self.secc2video_model = self.load_secc2video(head_model_dir, torso_model_dir)
|
48 |
+
self.audio2secc_model.to(device).eval()
|
49 |
+
self.secc2video_model.to(device).eval()
|
50 |
+
self.seg_model = MediapipeSegmenter()
|
51 |
+
self.secc_renderer = SECC_Renderer(512)
|
52 |
+
self.face3d_helper = Face3DHelper(use_gpu=True, keypoint_mode='lm68')
|
53 |
+
self.mp_face3d_helper = Face3DHelper(use_gpu=True, keypoint_mode='mediapipe')
|
54 |
+
# self.camera_selector = KNearestCameraSelector()
|
55 |
+
|
56 |
+
def load_secc2video(self, head_model_dir, torso_model_dir):
|
57 |
+
if torso_model_dir != '':
|
58 |
+
config_dir = torso_model_dir if os.path.isdir(torso_model_dir) else os.path.dirname(torso_model_dir)
|
59 |
+
set_hparams(f"{config_dir}/config.yaml", print_hparams=False)
|
60 |
+
hparams['htbsr_head_threshold'] = 1.0
|
61 |
+
self.secc2video_hparams = copy.deepcopy(hparams)
|
62 |
+
ckpt = get_last_checkpoint(torso_model_dir)[0]
|
63 |
+
lora_args = ckpt.get("lora_args", None)
|
64 |
+
from modules.real3d.secc_img2plane_torso import OSAvatarSECC_Img2plane_Torso
|
65 |
+
model = OSAvatarSECC_Img2plane_Torso(self.secc2video_hparams, lora_args=lora_args)
|
66 |
+
load_ckpt(model, f"{torso_model_dir}", model_name='model', strict=True)
|
67 |
+
self.learnable_triplane = nn.Parameter(torch.zeros([1, 3, model.triplane_hid_dim*model.triplane_depth, 256, 256]).float().cuda(), requires_grad=True)
|
68 |
+
load_ckpt(self.learnable_triplane, f"{torso_model_dir}", model_name='learnable_triplane', strict=True)
|
69 |
+
model._last_cano_planes = self.learnable_triplane
|
70 |
+
if head_model_dir != '':
|
71 |
+
print("| Warning: Assigned --torso_ckpt which also contains head, but --head_ckpt is also assigned, skipping the --head_ckpt.")
|
72 |
+
else:
|
73 |
+
from modules.real3d.secc_img2plane_torso import OSAvatarSECC_Img2plane
|
74 |
+
set_hparams(f"{head_model_dir}/config.yaml", print_hparams=False)
|
75 |
+
ckpt = get_last_checkpoint(head_model_dir)[0]
|
76 |
+
lora_args = ckpt.get("lora_args", None)
|
77 |
+
self.secc2video_hparams = copy.deepcopy(hparams)
|
78 |
+
model = OSAvatarSECC_Img2plane(self.secc2video_hparams, lora_args=lora_args)
|
79 |
+
load_ckpt(model, f"{head_model_dir}", model_name='model', strict=True)
|
80 |
+
self.learnable_triplane = nn.Parameter(torch.zeros([1, 3, model.triplane_hid_dim*model.triplane_depth, 256, 256]).float().cuda(), requires_grad=True)
|
81 |
+
model._last_cano_planes = self.learnable_triplane
|
82 |
+
load_ckpt(model._last_cano_planes, f"{head_model_dir}", model_name='learnable_triplane', strict=True)
|
83 |
+
self.person_ds = ckpt['person_ds']
|
84 |
+
return model
|
85 |
+
|
86 |
+
def prepare_batch_from_inp(self, inp):
|
87 |
+
"""
|
88 |
+
:param inp: {'audio_source_name': (str)}
|
89 |
+
:return: a dict that contains the condition feature of NeRF
|
90 |
+
"""
|
91 |
+
sample = {}
|
92 |
+
# Process Driving Motion
|
93 |
+
if inp['drv_audio_name'][-4:] in ['.wav', '.mp3']:
|
94 |
+
self.save_wav16k(inp['drv_audio_name'])
|
95 |
+
if self.audio2secc_hparams['audio_type'] == 'hubert':
|
96 |
+
hubert = self.get_hubert(self.wav16k_name)
|
97 |
+
elif self.audio2secc_hparams['audio_type'] == 'mfcc':
|
98 |
+
hubert = self.get_mfcc(self.wav16k_name) / 100
|
99 |
+
|
100 |
+
f0 = self.get_f0(self.wav16k_name)
|
101 |
+
if f0.shape[0] > len(hubert):
|
102 |
+
f0 = f0[:len(hubert)]
|
103 |
+
else:
|
104 |
+
num_to_pad = len(hubert) - len(f0)
|
105 |
+
f0 = np.pad(f0, pad_width=((0,num_to_pad), (0,0)))
|
106 |
+
t_x = hubert.shape[0]
|
107 |
+
x_mask = torch.ones([1, t_x]).float() # mask for audio frames
|
108 |
+
y_mask = torch.ones([1, t_x//2]).float() # mask for motion/image frames
|
109 |
+
sample.update({
|
110 |
+
'hubert': torch.from_numpy(hubert).float().unsqueeze(0).cuda(),
|
111 |
+
'f0': torch.from_numpy(f0).float().reshape([1,-1]).cuda(),
|
112 |
+
'x_mask': x_mask.cuda(),
|
113 |
+
'y_mask': y_mask.cuda(),
|
114 |
+
})
|
115 |
+
sample['blink'] = torch.zeros([1, t_x, 1]).long().cuda()
|
116 |
+
sample['audio'] = sample['hubert']
|
117 |
+
sample['eye_amp'] = torch.ones([1, 1]).cuda() * 1.0
|
118 |
+
elif inp['drv_audio_name'][-4:] in ['.mp4']:
|
119 |
+
drv_motion_coeff_dict = fit_3dmm_for_a_video(inp['drv_audio_name'], save=False)
|
120 |
+
drv_motion_coeff_dict = convert_to_tensor(drv_motion_coeff_dict)
|
121 |
+
t_x = drv_motion_coeff_dict['exp'].shape[0] * 2
|
122 |
+
self.drv_motion_coeff_dict = drv_motion_coeff_dict
|
123 |
+
elif inp['drv_audio_name'][-4:] in ['.npy']:
|
124 |
+
drv_motion_coeff_dict = np.load(inp['drv_audio_name'], allow_pickle=True).tolist()
|
125 |
+
drv_motion_coeff_dict = convert_to_tensor(drv_motion_coeff_dict)
|
126 |
+
t_x = drv_motion_coeff_dict['exp'].shape[0] * 2
|
127 |
+
self.drv_motion_coeff_dict = drv_motion_coeff_dict
|
128 |
+
|
129 |
+
# Face Parsing
|
130 |
+
sample['ref_gt_img'] = self.person_ds['gt_img'].cuda()
|
131 |
+
img = self.person_ds['gt_img'].reshape([3, 512, 512]).permute(1, 2, 0)
|
132 |
+
img = (img + 1) * 127.5
|
133 |
+
img = np.ascontiguousarray(img.int().numpy()).astype(np.uint8)
|
134 |
+
segmap = self.seg_model._cal_seg_map(img)
|
135 |
+
sample['segmap'] = torch.tensor(segmap).float().unsqueeze(0).cuda()
|
136 |
+
head_img = self.seg_model._seg_out_img_with_segmap(img, segmap, mode='head')[0]
|
137 |
+
sample['ref_head_img'] = ((torch.tensor(head_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
|
138 |
+
inpaint_torso_img, _, _, _ = inpaint_torso_job(img, segmap)
|
139 |
+
sample['ref_torso_img'] = ((torch.tensor(inpaint_torso_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
|
140 |
+
|
141 |
+
if inp['bg_image_name'] == '':
|
142 |
+
bg_img = extract_background([img], [segmap], 'knn')
|
143 |
+
else:
|
144 |
+
bg_img = cv2.imread(inp['bg_image_name'])
|
145 |
+
bg_img = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB)
|
146 |
+
bg_img = cv2.resize(bg_img, (512,512))
|
147 |
+
sample['bg_img'] = ((torch.tensor(bg_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
|
148 |
+
|
149 |
+
# 3DMM, get identity code and camera pose
|
150 |
+
image_name = f"data/raw/val_imgs/{self.person_ds['video_id']}_img.png"
|
151 |
+
os.makedirs(os.path.dirname(image_name), exist_ok=True)
|
152 |
+
cv2.imwrite(image_name, img[:,:,::-1])
|
153 |
+
coeff_dict = fit_3dmm_for_a_image(image_name, save=False)
|
154 |
+
coeff_dict['id'] = self.person_ds['id'].reshape([1,80]).numpy()
|
155 |
+
|
156 |
+
assert coeff_dict is not None
|
157 |
+
src_id = torch.tensor(coeff_dict['id']).reshape([1,80]).cuda()
|
158 |
+
src_exp = torch.tensor(coeff_dict['exp']).reshape([1,64]).cuda()
|
159 |
+
src_euler = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda()
|
160 |
+
src_trans = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda()
|
161 |
+
sample['id'] = src_id.repeat([t_x//2,1])
|
162 |
+
|
163 |
+
# get the src_kp for torso model
|
164 |
+
sample['src_kp'] = self.person_ds['src_kp'].cuda().reshape([1, 68, 3]).repeat([t_x//2,1,1])[..., :2] # [B, 68, 2]
|
165 |
+
|
166 |
+
# get camera pose file
|
167 |
+
random.seed(time.time())
|
168 |
+
if inp['drv_pose_name'] in ['nearest', 'topk']:
|
169 |
+
camera_ret = get_eg3d_convention_camera_pose_intrinsic({'euler': torch.tensor(coeff_dict['euler']).reshape([1,3]), 'trans': torch.tensor(coeff_dict['trans']).reshape([1,3])})
|
170 |
+
c2w, intrinsics = camera_ret['c2w'], camera_ret['intrinsics']
|
171 |
+
camera = np.concatenate([c2w.reshape([1,16]), intrinsics.reshape([1,9])], axis=-1)
|
172 |
+
coeff_names, distance_matrix = self.camera_selector.find_k_nearest(camera, k=100)
|
173 |
+
coeff_names = coeff_names[0] # squeeze
|
174 |
+
if inp['drv_pose_name'] == 'nearest':
|
175 |
+
inp['drv_pose_name'] = coeff_names[0]
|
176 |
+
else:
|
177 |
+
inp['drv_pose_name'] = random.choice(coeff_names)
|
178 |
+
# inp['drv_pose_name'] = coeff_names[0]
|
179 |
+
elif inp['drv_pose_name'] == 'random':
|
180 |
+
inp['drv_pose_name'] = self.camera_selector.random_select()
|
181 |
+
else:
|
182 |
+
inp['drv_pose_name'] = inp['drv_pose_name']
|
183 |
+
|
184 |
+
print(f"| To extract pose from {inp['drv_pose_name']}")
|
185 |
+
|
186 |
+
# extract camera pose
|
187 |
+
if inp['drv_pose_name'] == 'static':
|
188 |
+
sample['euler'] = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda().repeat([t_x//2,1]) # default static pose
|
189 |
+
sample['trans'] = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda().repeat([t_x//2,1])
|
190 |
+
else: # from file
|
191 |
+
if inp['drv_pose_name'].endswith('.mp4'):
|
192 |
+
# extract coeff from video
|
193 |
+
drv_pose_coeff_dict = fit_3dmm_for_a_video(inp['drv_pose_name'], save=False)
|
194 |
+
else:
|
195 |
+
# load from npy
|
196 |
+
drv_pose_coeff_dict = np.load(inp['drv_pose_name'], allow_pickle=True).tolist()
|
197 |
+
print(f"| Extracted pose from {inp['drv_pose_name']}")
|
198 |
+
eulers = convert_to_tensor(drv_pose_coeff_dict['euler']).reshape([-1,3]).cuda()
|
199 |
+
trans = convert_to_tensor(drv_pose_coeff_dict['trans']).reshape([-1,3]).cuda()
|
200 |
+
len_pose = len(eulers)
|
201 |
+
index_lst = [mirror_index(i, len_pose) for i in range(t_x//2)]
|
202 |
+
sample['euler'] = eulers[index_lst]
|
203 |
+
sample['trans'] = trans[index_lst]
|
204 |
+
|
205 |
+
# fix the z axis
|
206 |
+
sample['trans'][:, -1] = sample['trans'][0:1, -1].repeat([sample['trans'].shape[0]])
|
207 |
+
|
208 |
+
# mapping to the init pose
|
209 |
+
if inp.get("map_to_init_pose", 'False') == 'True':
|
210 |
+
diff_euler = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda() - sample['euler'][0:1]
|
211 |
+
sample['euler'] = sample['euler'] + diff_euler
|
212 |
+
diff_trans = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda() - sample['trans'][0:1]
|
213 |
+
sample['trans'] = sample['trans'] + diff_trans
|
214 |
+
|
215 |
+
# prepare camera
|
216 |
+
camera_ret = get_eg3d_convention_camera_pose_intrinsic({'euler':sample['euler'].cpu(), 'trans':sample['trans'].cpu()})
|
217 |
+
c2w, intrinsics = camera_ret['c2w'], camera_ret['intrinsics']
|
218 |
+
# smooth camera
|
219 |
+
camera_smo_ksize = 7
|
220 |
+
camera = np.concatenate([c2w.reshape([-1,16]), intrinsics.reshape([-1,9])], axis=-1)
|
221 |
+
camera = smooth_camera_sequence(camera, kernel_size=camera_smo_ksize) # [T, 25]
|
222 |
+
camera = torch.tensor(camera).cuda().float()
|
223 |
+
sample['camera'] = camera
|
224 |
+
|
225 |
+
return sample
|
226 |
+
|
227 |
+
@torch.no_grad()
|
228 |
+
def forward_secc2video(self, batch, inp=None):
|
229 |
+
num_frames = len(batch['drv_secc'])
|
230 |
+
camera = batch['camera']
|
231 |
+
src_kps = batch['src_kp']
|
232 |
+
drv_kps = batch['drv_kp']
|
233 |
+
cano_secc_color = batch['cano_secc']
|
234 |
+
src_secc_color = batch['src_secc']
|
235 |
+
drv_secc_colors = batch['drv_secc']
|
236 |
+
ref_img_gt = batch['ref_gt_img']
|
237 |
+
ref_img_head = batch['ref_head_img']
|
238 |
+
ref_torso_img = batch['ref_torso_img']
|
239 |
+
bg_img = batch['bg_img']
|
240 |
+
segmap = batch['segmap']
|
241 |
+
|
242 |
+
# smooth torso drv_kp
|
243 |
+
torso_smo_ksize = 7
|
244 |
+
drv_kps = smooth_features_xd(drv_kps.reshape([-1, 68*2]), kernel_size=torso_smo_ksize).reshape([-1, 68, 2])
|
245 |
+
|
246 |
+
# forward renderer
|
247 |
+
img_raw_lst = []
|
248 |
+
img_lst = []
|
249 |
+
depth_img_lst = []
|
250 |
+
with torch.no_grad():
|
251 |
+
for i in tqdm.trange(num_frames, desc="MimicTalk is rendering frames"):
|
252 |
+
kp_src = torch.cat([src_kps[i:i+1].reshape([1, 68, 2]), torch.zeros([1, 68,1]).to(src_kps.device)],dim=-1)
|
253 |
+
kp_drv = torch.cat([drv_kps[i:i+1].reshape([1, 68, 2]), torch.zeros([1, 68,1]).to(drv_kps.device)],dim=-1)
|
254 |
+
cond={'cond_cano': cano_secc_color,'cond_src': src_secc_color, 'cond_tgt': drv_secc_colors[i:i+1].cuda(),
|
255 |
+
'ref_torso_img': ref_torso_img, 'bg_img': bg_img, 'segmap': segmap,
|
256 |
+
'kp_s': kp_src, 'kp_d': kp_drv}
|
257 |
+
|
258 |
+
########################################################################################################
|
259 |
+
### 相比real3d_infer只修改了这行👇,即cano_triplane来自cache里的learnable_triplane,而不是img预测的plane ####
|
260 |
+
########################################################################################################
|
261 |
+
gen_output = self.secc2video_model.forward(img=None, camera=camera[i:i+1], cond=cond, ret={}, cache_backbone=False, use_cached_backbone=True)
|
262 |
+
|
263 |
+
img_lst.append(gen_output['image'])
|
264 |
+
img_raw_lst.append(gen_output['image_raw'])
|
265 |
+
depth_img_lst.append(gen_output['image_depth'])
|
266 |
+
|
267 |
+
# save demo video
|
268 |
+
depth_imgs = torch.cat(depth_img_lst)
|
269 |
+
imgs = torch.cat(img_lst)
|
270 |
+
imgs_raw = torch.cat(img_raw_lst)
|
271 |
+
secc_img = torch.cat([torch.nn.functional.interpolate(drv_secc_colors[i:i+1], (512,512)) for i in range(num_frames)])
|
272 |
+
|
273 |
+
if inp['out_mode'] == 'concat_debug':
|
274 |
+
secc_img = secc_img.cpu()
|
275 |
+
secc_img = ((secc_img + 1) * 127.5).permute(0, 2, 3, 1).int().numpy()
|
276 |
+
|
277 |
+
depth_img = F.interpolate(depth_imgs, (512,512)).cpu()
|
278 |
+
depth_img = depth_img.repeat([1,3,1,1])
|
279 |
+
depth_img = (depth_img - depth_img.min()) / (depth_img.max() - depth_img.min())
|
280 |
+
depth_img = depth_img * 2 - 1
|
281 |
+
depth_img = depth_img.clamp(-1,1)
|
282 |
+
|
283 |
+
secc_img = secc_img / 127.5 - 1
|
284 |
+
secc_img = torch.from_numpy(secc_img).permute(0, 3, 1, 2)
|
285 |
+
imgs = torch.cat([ref_img_gt.repeat([imgs.shape[0],1,1,1]).cpu(), secc_img, F.interpolate(imgs_raw, (512,512)).cpu(), depth_img, imgs.cpu()], dim=-1)
|
286 |
+
elif inp['out_mode'] == 'final':
|
287 |
+
imgs = imgs.cpu()
|
288 |
+
elif inp['out_mode'] == 'debug':
|
289 |
+
raise NotImplementedError("to do: save separate videos")
|
290 |
+
imgs = imgs.clamp(-1,1)
|
291 |
+
|
292 |
+
import imageio
|
293 |
+
import uuid
|
294 |
+
debug_name = f'{uuid.uuid1()}.mp4'
|
295 |
+
out_imgs = ((imgs.permute(0, 2, 3, 1) + 1)/2 * 255).int().cpu().numpy().astype(np.uint8)
|
296 |
+
writer = imageio.get_writer(debug_name, fps=25, format='FFMPEG', codec='h264')
|
297 |
+
for i in tqdm.trange(len(out_imgs), desc="Imageio is saving video"):
|
298 |
+
writer.append_data(out_imgs[i])
|
299 |
+
writer.close()
|
300 |
+
|
301 |
+
out_fname = 'infer_out/tmp/' + os.path.basename(inp['drv_pose_name'])[:-4] + '.mp4' if inp['out_name'] == '' else inp['out_name']
|
302 |
+
try:
|
303 |
+
os.makedirs(os.path.dirname(out_fname), exist_ok=True)
|
304 |
+
except: pass
|
305 |
+
if inp['drv_audio_name'][-4:] in ['.wav', '.mp3']:
|
306 |
+
# os.system(f"ffmpeg -i {debug_name} -i {inp['drv_audio_name']} -y -v quiet -shortest {out_fname}")
|
307 |
+
cmd = f"/usr/bin/ffmpeg -i {debug_name} -i {self.wav16k_name} -y -r 25 -ar 16000 -c:v copy -c:a libmp3lame -pix_fmt yuv420p -b:v 2000k -strict experimental -shortest {out_fname}"
|
308 |
+
os.system(cmd)
|
309 |
+
os.system(f"rm {debug_name}")
|
310 |
+
else:
|
311 |
+
ret = os.system(f"ffmpeg -i {debug_name} -i {inp['drv_audio_name']} -map 0:v -map 1:a -y -v quiet -shortest {out_fname}")
|
312 |
+
if ret != 0: # 没有成功从drv_audio_name里面提取到音频, 则直接输出无音频轨道的纯视频
|
313 |
+
os.system(f"mv {debug_name} {out_fname}")
|
314 |
+
print(f"Saved at {out_fname}")
|
315 |
+
return out_fname
|
316 |
+
|
317 |
+
if __name__ == '__main__':
|
318 |
+
import argparse, glob, tqdm
|
319 |
+
parser = argparse.ArgumentParser()
|
320 |
+
parser.add_argument("--a2m_ckpt", default='checkpoints/240112_icl_audio2secc_vox2_cmlr') # checkpoints/0727_audio2secc/audio2secc_withlm2d100_randomframe
|
321 |
+
parser.add_argument("--head_ckpt", default='') # checkpoints/0729_th1kh/secc_img2plane checkpoints/0720_img2planes/secc_img2plane_two_stage
|
322 |
+
parser.add_argument("--torso_ckpt", default='checkpoints_mimictalk/German_20s')
|
323 |
+
parser.add_argument("--bg_img", default='') # data/raw/val_imgs/bg3.png
|
324 |
+
parser.add_argument("--drv_aud", default='data/raw/examples/80_vs_60_10s.wav')
|
325 |
+
parser.add_argument("--drv_pose", default='data/raw/examples/German_20s.mp4') # nearest | topk | random | static | vid_name
|
326 |
+
parser.add_argument("--drv_style", default='data/raw/examples/angry.mp4') # nearest | topk | random | static | vid_name
|
327 |
+
parser.add_argument("--blink_mode", default='period') # none | period
|
328 |
+
parser.add_argument("--temperature", default=0.3, type=float) # nearest | random
|
329 |
+
parser.add_argument("--denoising_steps", default=20, type=int) # nearest | random
|
330 |
+
parser.add_argument("--cfg_scale", default=1.5, type=float) # nearest | random
|
331 |
+
parser.add_argument("--out_name", default='') # nearest | random
|
332 |
+
parser.add_argument("--out_mode", default='concat_debug') # concat_debug | debug | final
|
333 |
+
parser.add_argument("--hold_eye_opened", default='False') # concat_debug | debug | final
|
334 |
+
parser.add_argument("--map_to_init_pose", default='True') # concat_debug | debug | final
|
335 |
+
parser.add_argument("--seed", default=None, type=int) # random seed, default None to use time.time()
|
336 |
+
|
337 |
+
args = parser.parse_args()
|
338 |
+
|
339 |
+
inp = {
|
340 |
+
'a2m_ckpt': args.a2m_ckpt,
|
341 |
+
'head_ckpt': args.head_ckpt,
|
342 |
+
'torso_ckpt': args.torso_ckpt,
|
343 |
+
'bg_image_name': args.bg_img,
|
344 |
+
'drv_audio_name': args.drv_aud,
|
345 |
+
'drv_pose_name': args.drv_pose,
|
346 |
+
'drv_talking_style_name': args.drv_style,
|
347 |
+
'blink_mode': args.blink_mode,
|
348 |
+
'temperature': args.temperature,
|
349 |
+
'denoising_steps': args.denoising_steps,
|
350 |
+
'cfg_scale': args.cfg_scale,
|
351 |
+
'out_name': args.out_name,
|
352 |
+
'out_mode': args.out_mode,
|
353 |
+
'map_to_init_pose': args.map_to_init_pose,
|
354 |
+
'hold_eye_opened': args.hold_eye_opened,
|
355 |
+
'seed': args.seed,
|
356 |
+
}
|
357 |
+
AdaptGeneFace2Infer.example_run(inp)
|
inference/real3d_infer.py
ADDED
@@ -0,0 +1,667 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torchshow as ts
|
5 |
+
import librosa
|
6 |
+
import random
|
7 |
+
import time
|
8 |
+
import numpy as np
|
9 |
+
import importlib
|
10 |
+
import tqdm
|
11 |
+
import copy
|
12 |
+
import cv2
|
13 |
+
import math
|
14 |
+
|
15 |
+
# common utils
|
16 |
+
from utils.commons.hparams import hparams, set_hparams
|
17 |
+
from utils.commons.tensor_utils import move_to_cuda, convert_to_tensor
|
18 |
+
from utils.commons.ckpt_utils import load_ckpt, get_last_checkpoint
|
19 |
+
# 3DMM-related utils
|
20 |
+
from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
|
21 |
+
from data_util.face3d_helper import Face3DHelper
|
22 |
+
from data_gen.utils.process_image.fit_3dmm_landmark import fit_3dmm_for_a_image
|
23 |
+
from data_gen.utils.process_video.fit_3dmm_landmark import fit_3dmm_for_a_video
|
24 |
+
from data_gen.utils.process_image.extract_lm2d import extract_lms_mediapipe_job
|
25 |
+
from data_gen.utils.process_image.fit_3dmm_landmark import index_lm68_from_lm468
|
26 |
+
from deep_3drecon.secc_renderer import SECC_Renderer
|
27 |
+
from data_gen.eg3d.convert_to_eg3d_convention import get_eg3d_convention_camera_pose_intrinsic
|
28 |
+
# Face Parsing
|
29 |
+
from data_gen.utils.mp_feature_extractors.mp_segmenter import MediapipeSegmenter
|
30 |
+
from data_gen.utils.process_video.extract_segment_imgs import inpaint_torso_job, extract_background
|
31 |
+
# other inference utils
|
32 |
+
from inference.infer_utils import mirror_index, load_img_to_512_hwc_array, load_img_to_normalized_512_bchw_tensor
|
33 |
+
from inference.infer_utils import smooth_camera_sequence, smooth_features_xd
|
34 |
+
from inference.edit_secc import blink_eye_for_secc, hold_eye_opened_for_secc
|
35 |
+
|
36 |
+
|
37 |
+
def read_first_frame_from_a_video(vid_name):
|
38 |
+
frames = []
|
39 |
+
cap = cv2.VideoCapture(vid_name)
|
40 |
+
ret, frame_bgr = cap.read()
|
41 |
+
frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
|
42 |
+
return frame_rgb
|
43 |
+
|
44 |
+
def analyze_weights_img(gen_output):
|
45 |
+
img_raw = gen_output['image_raw']
|
46 |
+
mask_005_to_03 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.3).repeat([1,3,1,1])
|
47 |
+
mask_005_to_05 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.5).repeat([1,3,1,1])
|
48 |
+
mask_005_to_07 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.7).repeat([1,3,1,1])
|
49 |
+
mask_005_to_09 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.9).repeat([1,3,1,1])
|
50 |
+
mask_005_to_10 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<1.0).repeat([1,3,1,1])
|
51 |
+
|
52 |
+
img_raw_005_to_03 = img_raw.clone()
|
53 |
+
img_raw_005_to_03[~mask_005_to_03] = -1
|
54 |
+
img_raw_005_to_05 = img_raw.clone()
|
55 |
+
img_raw_005_to_05[~mask_005_to_05] = -1
|
56 |
+
img_raw_005_to_07 = img_raw.clone()
|
57 |
+
img_raw_005_to_07[~mask_005_to_07] = -1
|
58 |
+
img_raw_005_to_09 = img_raw.clone()
|
59 |
+
img_raw_005_to_09[~mask_005_to_09] = -1
|
60 |
+
img_raw_005_to_10 = img_raw.clone()
|
61 |
+
img_raw_005_to_10[~mask_005_to_10] = -1
|
62 |
+
ts.save([img_raw_005_to_03[0], img_raw_005_to_05[0], img_raw_005_to_07[0], img_raw_005_to_09[0], img_raw_005_to_10[0]])
|
63 |
+
|
64 |
+
|
65 |
+
def cal_face_area_percent(img_name):
|
66 |
+
img = cv2.resize(cv2.imread(img_name)[:,:,::-1], (512,512))
|
67 |
+
lm478 = extract_lms_mediapipe_job(img) / 512
|
68 |
+
min_x = lm478[:,0].min()
|
69 |
+
max_x = lm478[:,0].max()
|
70 |
+
min_y = lm478[:,1].min()
|
71 |
+
max_y = lm478[:,1].max()
|
72 |
+
area = (max_x - min_x) * (max_y - min_y)
|
73 |
+
return area
|
74 |
+
|
75 |
+
def crop_img_on_face_area_percent(img_name, out_name='temp/cropped_src_img.png', min_face_area_percent=0.2):
|
76 |
+
try:
|
77 |
+
os.makedirs(os.path.dirname(out_name), exist_ok=True)
|
78 |
+
except: pass
|
79 |
+
face_area_percent = cal_face_area_percent(img_name)
|
80 |
+
if face_area_percent >= min_face_area_percent:
|
81 |
+
print(f"face area percent {face_area_percent} larger than threshold {min_face_area_percent}, directly use the input image...")
|
82 |
+
cmd = f"cp {img_name} {out_name}"
|
83 |
+
os.system(cmd)
|
84 |
+
return out_name
|
85 |
+
else:
|
86 |
+
print(f"face area percent {face_area_percent} smaller than threshold {min_face_area_percent}, crop the input image...")
|
87 |
+
img = cv2.resize(cv2.imread(img_name)[:,:,::-1], (512,512))
|
88 |
+
lm478 = extract_lms_mediapipe_job(img).astype(int)
|
89 |
+
min_x = lm478[:,0].min()
|
90 |
+
max_x = lm478[:,0].max()
|
91 |
+
min_y = lm478[:,1].min()
|
92 |
+
max_y = lm478[:,1].max()
|
93 |
+
face_area = (max_x - min_x) * (max_y - min_y)
|
94 |
+
target_total_area = face_area / min_face_area_percent
|
95 |
+
target_hw = int(target_total_area**0.5)
|
96 |
+
center_x, center_y = (min_x+max_x)/2, (min_y+max_y)/2
|
97 |
+
shrink_pixels = 2 * max(-(center_x - target_hw/2), center_x + target_hw/2 - 512, -(center_y - target_hw/2), center_y + target_hw/2-512)
|
98 |
+
shrink_pixels = max(0, shrink_pixels)
|
99 |
+
hw = math.floor(target_hw - shrink_pixels)
|
100 |
+
new_min_x = int(center_x - hw/2)
|
101 |
+
new_max_x = int(center_x + hw/2)
|
102 |
+
new_min_y = int(center_y - hw/2)
|
103 |
+
new_max_y = int(center_y + hw/2)
|
104 |
+
|
105 |
+
img = img[new_min_y:new_max_y, new_min_x:new_max_x]
|
106 |
+
img = cv2.resize(img, (512, 512))
|
107 |
+
cv2.imwrite(out_name, img[:,:,::-1])
|
108 |
+
return out_name
|
109 |
+
|
110 |
+
|
111 |
+
class GeneFace2Infer:
|
112 |
+
def __init__(self, audio2secc_dir, head_model_dir, torso_model_dir, device=None, inp=None):
|
113 |
+
if device is None:
|
114 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
115 |
+
self.audio2secc_model = self.load_audio2secc(audio2secc_dir)
|
116 |
+
self.secc2video_model = self.load_secc2video(head_model_dir, torso_model_dir, inp)
|
117 |
+
self.audio2secc_model.to(device).eval()
|
118 |
+
self.secc2video_model.to(device).eval()
|
119 |
+
self.seg_model = MediapipeSegmenter()
|
120 |
+
self.secc_renderer = SECC_Renderer(512)
|
121 |
+
self.face3d_helper = Face3DHelper(use_gpu=True, keypoint_mode='lm68')
|
122 |
+
self.mp_face3d_helper = Face3DHelper(use_gpu=True, keypoint_mode='mediapipe')
|
123 |
+
self.camera_selector = KNearestCameraSelector()
|
124 |
+
|
125 |
+
def load_audio2secc(self, audio2secc_dir):
|
126 |
+
config_name = f"{audio2secc_dir}/config.yaml" if not audio2secc_dir.endswith(".ckpt") else f"{os.path.dirname(audio2secc_dir)}/config.yaml"
|
127 |
+
set_hparams(f"{config_name}", print_hparams=False)
|
128 |
+
self.audio2secc_dir = audio2secc_dir
|
129 |
+
self.audio2secc_hparams = copy.deepcopy(hparams)
|
130 |
+
from modules.audio2motion.vae import VAEModel, PitchContourVAEModel
|
131 |
+
from modules.audio2motion.cfm.icl_audio2motion_model import InContextAudio2MotionModel
|
132 |
+
if self.audio2secc_hparams['audio_type'] == 'hubert':
|
133 |
+
audio_in_dim = 1024
|
134 |
+
elif self.audio2secc_hparams['audio_type'] == 'mfcc':
|
135 |
+
audio_in_dim = 13
|
136 |
+
|
137 |
+
if 'icl' in hparams['task_cls']:
|
138 |
+
self.use_icl_audio2motion = True
|
139 |
+
model = InContextAudio2MotionModel(hparams['icl_model_type'], hparams=self.audio2secc_hparams)
|
140 |
+
else:
|
141 |
+
self.use_icl_audio2motion = False
|
142 |
+
if hparams.get("use_pitch", False) is True:
|
143 |
+
model = PitchContourVAEModel(hparams, in_out_dim=64, audio_in_dim=audio_in_dim)
|
144 |
+
else:
|
145 |
+
model = VAEModel(in_out_dim=64, audio_in_dim=audio_in_dim)
|
146 |
+
load_ckpt(model, f"{audio2secc_dir}", model_name='model', strict=True)
|
147 |
+
return model
|
148 |
+
|
149 |
+
def load_secc2video(self, head_model_dir, torso_model_dir, inp):
|
150 |
+
if inp is None:
|
151 |
+
inp = {}
|
152 |
+
self.head_model_dir = head_model_dir
|
153 |
+
self.torso_model_dir = torso_model_dir
|
154 |
+
if torso_model_dir != '':
|
155 |
+
if torso_model_dir.endswith(".ckpt"):
|
156 |
+
set_hparams(f"{os.path.dirname(torso_model_dir)}/config.yaml", print_hparams=False)
|
157 |
+
else:
|
158 |
+
set_hparams(f"{torso_model_dir}/config.yaml", print_hparams=False)
|
159 |
+
if inp.get('head_torso_threshold', None) is not None:
|
160 |
+
hparams['htbsr_head_threshold'] = inp['head_torso_threshold']
|
161 |
+
self.secc2video_hparams = copy.deepcopy(hparams)
|
162 |
+
from modules.real3d.secc_img2plane_torso import OSAvatarSECC_Img2plane_Torso
|
163 |
+
model = OSAvatarSECC_Img2plane_Torso()
|
164 |
+
load_ckpt(model, f"{torso_model_dir}", model_name='model', strict=False)
|
165 |
+
if head_model_dir != '':
|
166 |
+
print("| Warning: Assigned --torso_ckpt which also contains head, but --head_ckpt is also assigned, skipping the --head_ckpt.")
|
167 |
+
else:
|
168 |
+
from modules.real3d.secc_img2plane_torso import OSAvatarSECC_Img2plane
|
169 |
+
if head_model_dir.endswith(".ckpt"):
|
170 |
+
set_hparams(f"{os.path.dirname(head_model_dir)}/config.yaml", print_hparams=False)
|
171 |
+
else:
|
172 |
+
set_hparams(f"{head_model_dir}/config.yaml", print_hparams=False)
|
173 |
+
if inp.get('head_torso_threshold', None) is not None:
|
174 |
+
hparams['htbsr_head_threshold'] = inp['head_torso_threshold']
|
175 |
+
self.secc2video_hparams = copy.deepcopy(hparams)
|
176 |
+
model = OSAvatarSECC_Img2plane()
|
177 |
+
load_ckpt(model, f"{head_model_dir}", model_name='model', strict=False)
|
178 |
+
return model
|
179 |
+
|
180 |
+
def infer_once(self, inp):
|
181 |
+
self.inp = inp
|
182 |
+
samples = self.prepare_batch_from_inp(inp)
|
183 |
+
seed = inp['seed'] if inp['seed'] is not None else int(time.time())
|
184 |
+
random.seed(seed)
|
185 |
+
torch.manual_seed(seed)
|
186 |
+
np.random.seed(seed)
|
187 |
+
out_name = self.forward_system(samples, inp)
|
188 |
+
return out_name
|
189 |
+
|
190 |
+
def prepare_batch_from_inp(self, inp):
|
191 |
+
"""
|
192 |
+
:param inp: {'audio_source_name': (str)}
|
193 |
+
:return: a dict that contains the condition feature of NeRF
|
194 |
+
"""
|
195 |
+
cropped_name = 'temp/cropped_src_img_512.png'
|
196 |
+
crop_img_on_face_area_percent(inp['src_image_name'], cropped_name, min_face_area_percent=inp['min_face_area_percent'])
|
197 |
+
inp['src_image_name'] = cropped_name
|
198 |
+
|
199 |
+
sample = {}
|
200 |
+
# Process Driving Motion
|
201 |
+
if inp['drv_audio_name'][-4:] in ['.wav', '.mp3']:
|
202 |
+
self.save_wav16k(inp['drv_audio_name'])
|
203 |
+
if self.audio2secc_hparams['audio_type'] == 'hubert':
|
204 |
+
hubert = self.get_hubert(self.wav16k_name)
|
205 |
+
elif self.audio2secc_hparams['audio_type'] == 'mfcc':
|
206 |
+
hubert = self.get_mfcc(self.wav16k_name) / 100
|
207 |
+
|
208 |
+
f0 = self.get_f0(self.wav16k_name)
|
209 |
+
if f0.shape[0] > len(hubert):
|
210 |
+
f0 = f0[:len(hubert)]
|
211 |
+
else:
|
212 |
+
num_to_pad = len(hubert) - len(f0)
|
213 |
+
f0 = np.pad(f0, pad_width=((0,num_to_pad), (0,0)))
|
214 |
+
t_x = hubert.shape[0]
|
215 |
+
x_mask = torch.ones([1, t_x]).float() # mask for audio frames
|
216 |
+
y_mask = torch.ones([1, t_x//2]).float() # mask for motion/image frames
|
217 |
+
sample.update({
|
218 |
+
'hubert': torch.from_numpy(hubert).float().unsqueeze(0).cuda(),
|
219 |
+
'f0': torch.from_numpy(f0).float().reshape([1,-1]).cuda(),
|
220 |
+
'x_mask': x_mask.cuda(),
|
221 |
+
'y_mask': y_mask.cuda(),
|
222 |
+
})
|
223 |
+
sample['blink'] = torch.zeros([1, t_x, 1]).long().cuda()
|
224 |
+
sample['audio'] = sample['hubert']
|
225 |
+
sample['eye_amp'] = torch.ones([1, 1]).cuda() * 1.0
|
226 |
+
sample['mouth_amp'] = torch.ones([1, 1]).cuda() * inp['mouth_amp']
|
227 |
+
elif inp['drv_audio_name'][-4:] in ['.mp4']:
|
228 |
+
drv_motion_coeff_dict = fit_3dmm_for_a_video(inp['drv_audio_name'], save=False)
|
229 |
+
drv_motion_coeff_dict = convert_to_tensor(drv_motion_coeff_dict)
|
230 |
+
t_x = drv_motion_coeff_dict['exp'].shape[0] * 2
|
231 |
+
self.drv_motion_coeff_dict = drv_motion_coeff_dict
|
232 |
+
elif inp['drv_audio_name'][-4:] in ['.npy']:
|
233 |
+
drv_motion_coeff_dict = np.load(inp['drv_audio_name'], allow_pickle=True).tolist()
|
234 |
+
drv_motion_coeff_dict = convert_to_tensor(drv_motion_coeff_dict)
|
235 |
+
t_x = drv_motion_coeff_dict['exp'].shape[0] * 2
|
236 |
+
self.drv_motion_coeff_dict = drv_motion_coeff_dict
|
237 |
+
else:
|
238 |
+
raise ValueError()
|
239 |
+
|
240 |
+
# Face Parsing
|
241 |
+
image_name = inp['src_image_name']
|
242 |
+
if image_name.endswith(".mp4"):
|
243 |
+
img = read_first_frame_from_a_video(image_name)
|
244 |
+
image_name = inp['src_image_name'] = image_name[:-4] + '.png'
|
245 |
+
cv2.imwrite(image_name, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
|
246 |
+
sample['ref_gt_img'] = load_img_to_normalized_512_bchw_tensor(image_name).cuda()
|
247 |
+
img = load_img_to_512_hwc_array(image_name)
|
248 |
+
segmap = self.seg_model._cal_seg_map(img)
|
249 |
+
sample['segmap'] = torch.tensor(segmap).float().unsqueeze(0).cuda()
|
250 |
+
head_img = self.seg_model._seg_out_img_with_segmap(img, segmap, mode='head')[0]
|
251 |
+
sample['ref_head_img'] = ((torch.tensor(head_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
|
252 |
+
inpaint_torso_img, _, _, _ = inpaint_torso_job(img, segmap)
|
253 |
+
sample['ref_torso_img'] = ((torch.tensor(inpaint_torso_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
|
254 |
+
|
255 |
+
if inp['bg_image_name'] == '':
|
256 |
+
bg_img = extract_background([img], [segmap], 'lama')
|
257 |
+
else:
|
258 |
+
bg_img = cv2.imread(inp['bg_image_name'])
|
259 |
+
bg_img = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB)
|
260 |
+
bg_img = cv2.resize(bg_img, (512,512))
|
261 |
+
sample['bg_img'] = ((torch.tensor(bg_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
|
262 |
+
|
263 |
+
# 3DMM, get identity code and camera pose
|
264 |
+
coeff_dict = fit_3dmm_for_a_image(image_name, save=False)
|
265 |
+
assert coeff_dict is not None
|
266 |
+
src_id = torch.tensor(coeff_dict['id']).reshape([1,80]).cuda()
|
267 |
+
src_exp = torch.tensor(coeff_dict['exp']).reshape([1,64]).cuda()
|
268 |
+
src_euler = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda()
|
269 |
+
src_trans = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda()
|
270 |
+
sample['id'] = src_id.repeat([t_x//2,1])
|
271 |
+
|
272 |
+
# get the src_kp for torso model
|
273 |
+
src_kp = self.face3d_helper.reconstruct_lm2d(src_id, src_exp, src_euler, src_trans) # [1, 68, 2]
|
274 |
+
src_kp = (src_kp-0.5) / 0.5 # rescale to -1~1
|
275 |
+
sample['src_kp'] = torch.clamp(src_kp, -1, 1).repeat([t_x//2,1,1])
|
276 |
+
|
277 |
+
# get camera pose file
|
278 |
+
# random.seed(time.time())
|
279 |
+
if inp['drv_pose_name'] in ['nearest', 'topk']:
|
280 |
+
camera_ret = get_eg3d_convention_camera_pose_intrinsic({'euler': torch.tensor(coeff_dict['euler']).reshape([1,3]), 'trans': torch.tensor(coeff_dict['trans']).reshape([1,3])})
|
281 |
+
c2w, intrinsics = camera_ret['c2w'], camera_ret['intrinsics']
|
282 |
+
camera = np.concatenate([c2w.reshape([1,16]), intrinsics.reshape([1,9])], axis=-1)
|
283 |
+
coeff_names, distance_matrix = self.camera_selector.find_k_nearest(camera, k=100)
|
284 |
+
coeff_names = coeff_names[0] # squeeze
|
285 |
+
if inp['drv_pose_name'] == 'nearest':
|
286 |
+
inp['drv_pose_name'] = coeff_names[0]
|
287 |
+
else:
|
288 |
+
inp['drv_pose_name'] = random.choice(coeff_names)
|
289 |
+
# inp['drv_pose_name'] = coeff_names[0]
|
290 |
+
elif inp['drv_pose_name'] == 'random':
|
291 |
+
inp['drv_pose_name'] = self.camera_selector.random_select()
|
292 |
+
else:
|
293 |
+
inp['drv_pose_name'] = inp['drv_pose_name']
|
294 |
+
|
295 |
+
print(f"| To extract pose from {inp['drv_pose_name']}")
|
296 |
+
|
297 |
+
# extract camera pose
|
298 |
+
if inp['drv_pose_name'] == 'static':
|
299 |
+
sample['euler'] = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda().repeat([t_x//2,1]) # default static pose
|
300 |
+
sample['trans'] = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda().repeat([t_x//2,1])
|
301 |
+
else: # from file
|
302 |
+
if inp['drv_pose_name'].endswith('.mp4'):
|
303 |
+
# extract coeff from video
|
304 |
+
drv_pose_coeff_dict = fit_3dmm_for_a_video(inp['drv_pose_name'], save=False)
|
305 |
+
else:
|
306 |
+
# load from npy
|
307 |
+
drv_pose_coeff_dict = np.load(inp['drv_pose_name'], allow_pickle=True).tolist()
|
308 |
+
print(f"| Extracted pose from {inp['drv_pose_name']}")
|
309 |
+
eulers = convert_to_tensor(drv_pose_coeff_dict['euler']).reshape([-1,3]).cuda()
|
310 |
+
trans = convert_to_tensor(drv_pose_coeff_dict['trans']).reshape([-1,3]).cuda()
|
311 |
+
len_pose = len(eulers)
|
312 |
+
index_lst = [mirror_index(i, len_pose) for i in range(t_x//2)]
|
313 |
+
sample['euler'] = eulers[index_lst]
|
314 |
+
sample['trans'] = trans[index_lst]
|
315 |
+
|
316 |
+
# fix the z axis
|
317 |
+
sample['trans'][:, -1] = sample['trans'][0:1, -1].repeat([sample['trans'].shape[0]])
|
318 |
+
|
319 |
+
# mapping to the init pose
|
320 |
+
if inp.get("map_to_init_pose", 'False') == 'True':
|
321 |
+
diff_euler = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda() - sample['euler'][0:1]
|
322 |
+
sample['euler'] = sample['euler'] + diff_euler
|
323 |
+
diff_trans = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda() - sample['trans'][0:1]
|
324 |
+
sample['trans'] = sample['trans'] + diff_trans
|
325 |
+
|
326 |
+
# prepare camera
|
327 |
+
camera_ret = get_eg3d_convention_camera_pose_intrinsic({'euler':sample['euler'].cpu(), 'trans':sample['trans'].cpu()})
|
328 |
+
c2w, intrinsics = camera_ret['c2w'], camera_ret['intrinsics']
|
329 |
+
# smooth camera
|
330 |
+
camera_smo_ksize = 7
|
331 |
+
camera = np.concatenate([c2w.reshape([-1,16]), intrinsics.reshape([-1,9])], axis=-1)
|
332 |
+
camera = smooth_camera_sequence(camera, kernel_size=camera_smo_ksize) # [T, 25]
|
333 |
+
camera = torch.tensor(camera).cuda().float()
|
334 |
+
sample['camera'] = camera
|
335 |
+
|
336 |
+
return sample
|
337 |
+
|
338 |
+
@torch.no_grad()
|
339 |
+
def get_hubert(self, wav16k_name):
|
340 |
+
from data_gen.utils.process_audio.extract_hubert import get_hubert_from_16k_wav
|
341 |
+
hubert = get_hubert_from_16k_wav(wav16k_name).detach().numpy()
|
342 |
+
len_mel = hubert.shape[0]
|
343 |
+
x_multiply = 8
|
344 |
+
if len_mel % x_multiply == 0:
|
345 |
+
num_to_pad = 0
|
346 |
+
else:
|
347 |
+
num_to_pad = x_multiply - len_mel % x_multiply
|
348 |
+
hubert = np.pad(hubert, pad_width=((0,num_to_pad), (0,0)))
|
349 |
+
return hubert
|
350 |
+
|
351 |
+
def get_mfcc(self, wav16k_name):
|
352 |
+
from utils.audio import librosa_wav2mfcc
|
353 |
+
hparams['fft_size'] = 1200
|
354 |
+
hparams['win_size'] = 1200
|
355 |
+
hparams['hop_size'] = 480
|
356 |
+
hparams['audio_num_mel_bins'] = 80
|
357 |
+
hparams['fmin'] = 80
|
358 |
+
hparams['fmax'] = 12000
|
359 |
+
hparams['audio_sample_rate'] = 24000
|
360 |
+
mfcc = librosa_wav2mfcc(wav16k_name,
|
361 |
+
fft_size=hparams['fft_size'],
|
362 |
+
hop_size=hparams['hop_size'],
|
363 |
+
win_length=hparams['win_size'],
|
364 |
+
num_mels=hparams['audio_num_mel_bins'],
|
365 |
+
fmin=hparams['fmin'],
|
366 |
+
fmax=hparams['fmax'],
|
367 |
+
sample_rate=hparams['audio_sample_rate'],
|
368 |
+
center=True)
|
369 |
+
mfcc = np.array(mfcc).reshape([-1, 13])
|
370 |
+
len_mel = mfcc.shape[0]
|
371 |
+
x_multiply = 8
|
372 |
+
if len_mel % x_multiply == 0:
|
373 |
+
num_to_pad = 0
|
374 |
+
else:
|
375 |
+
num_to_pad = x_multiply - len_mel % x_multiply
|
376 |
+
mfcc = np.pad(mfcc, pad_width=((0,num_to_pad), (0,0)))
|
377 |
+
return mfcc
|
378 |
+
|
379 |
+
@torch.no_grad()
|
380 |
+
def forward_audio2secc(self, batch, inp=None):
|
381 |
+
if inp['drv_audio_name'][-4:] in ['.wav', '.mp3']:
|
382 |
+
from inference.infer_utils import extract_audio_motion_from_ref_video
|
383 |
+
if self.use_icl_audio2motion:
|
384 |
+
self.audio2secc_model.empty_context() # make this function reloadable
|
385 |
+
if self.use_icl_audio2motion and inp['drv_talking_style_name'].endswith(".mp4"):
|
386 |
+
ref_exp, ref_hubert, ref_f0 = extract_audio_motion_from_ref_video(inp['drv_talking_style_name'])
|
387 |
+
self.audio2secc_model.add_sample_to_context(ref_exp, ref_hubert, ref_f0)
|
388 |
+
elif self.use_icl_audio2motion and inp['drv_talking_style_name'].endswith((".png",'.jpg')):
|
389 |
+
style_coeff_dict = fit_3dmm_for_a_image(inp['drv_talking_style_name'])
|
390 |
+
ref_exp = torch.tensor(style_coeff_dict['exp']).reshape([1,1,64]).cuda()
|
391 |
+
self.audio2secc_model.add_sample_to_context(ref_exp.repeat([1, 100, 1]), hubert=None, f0=None)
|
392 |
+
else:
|
393 |
+
print("| WARNING: Not assigned reference talking style, passing...")
|
394 |
+
# audio-to-exp
|
395 |
+
ret = {}
|
396 |
+
# pred = self.audio2secc_model.forward(batch, ret=ret,train=False, ,)
|
397 |
+
pred = self.audio2secc_model.forward(batch, ret=ret,train=False, temperature=inp['temperature'], denoising_steps=inp['denoising_steps'], cond_scale=inp['cfg_scale'])
|
398 |
+
|
399 |
+
print("| audio-to-motion finished")
|
400 |
+
if pred.shape[-1] == 144:
|
401 |
+
id = ret['pred'][0][:,:80]
|
402 |
+
exp = ret['pred'][0][:,80:]
|
403 |
+
else:
|
404 |
+
id = batch['id']
|
405 |
+
exp = ret['pred'][0]
|
406 |
+
if len(id) < len(exp): # happens when use ICL
|
407 |
+
id = torch.cat([id, id[0].unsqueeze(0).repeat([len(exp)-len(id),1])])
|
408 |
+
batch['id'] = id
|
409 |
+
batch['exp'] = exp
|
410 |
+
else:
|
411 |
+
drv_motion_coeff_dict = self.drv_motion_coeff_dict
|
412 |
+
batch['exp'] = torch.FloatTensor(drv_motion_coeff_dict['exp']).cuda()
|
413 |
+
|
414 |
+
batch['id'] = batch['id'][:-4]
|
415 |
+
batch['exp'] = batch['exp'][:-4]
|
416 |
+
batch['euler'] = batch['euler'][:-4]
|
417 |
+
batch['trans'] = batch['trans'][:-4]
|
418 |
+
batch = self.get_driving_motion(batch['id'], batch['exp'], batch['euler'], batch['trans'], batch, inp)
|
419 |
+
if self.use_icl_audio2motion:
|
420 |
+
self.audio2secc_model.empty_context()
|
421 |
+
return batch
|
422 |
+
|
423 |
+
@torch.no_grad()
|
424 |
+
def get_driving_motion(self, id, exp, euler, trans, batch, inp):
|
425 |
+
zero_eulers = torch.zeros([id.shape[0], 3]).to(id.device)
|
426 |
+
zero_trans = torch.zeros([id.shape[0], 3]).to(exp.device)
|
427 |
+
# render the secc given the id,exp
|
428 |
+
with torch.no_grad():
|
429 |
+
chunk_size = 50
|
430 |
+
drv_secc_color_lst = []
|
431 |
+
num_iters = len(id)//chunk_size if len(id)%chunk_size == 0 else len(id)//chunk_size+1
|
432 |
+
for i in tqdm.trange(num_iters, desc="rendering drv secc"):
|
433 |
+
torch.cuda.empty_cache()
|
434 |
+
face_mask, drv_secc_color = self.secc_renderer(id[i*chunk_size:(i+1)*chunk_size], exp[i*chunk_size:(i+1)*chunk_size], zero_eulers[i*chunk_size:(i+1)*chunk_size], zero_trans[i*chunk_size:(i+1)*chunk_size])
|
435 |
+
drv_secc_color_lst.append(drv_secc_color.cpu())
|
436 |
+
drv_secc_colors = torch.cat(drv_secc_color_lst, dim=0)
|
437 |
+
_, src_secc_color = self.secc_renderer(id[0:1], exp[0:1], zero_eulers[0:1], zero_trans[0:1])
|
438 |
+
_, cano_secc_color = self.secc_renderer(id[0:1], exp[0:1]*0, zero_eulers[0:1], zero_trans[0:1])
|
439 |
+
batch['drv_secc'] = drv_secc_colors.cuda()
|
440 |
+
batch['src_secc'] = src_secc_color.cuda()
|
441 |
+
batch['cano_secc'] = cano_secc_color.cuda()
|
442 |
+
|
443 |
+
# blinking secc
|
444 |
+
if inp['blink_mode'] == 'period':
|
445 |
+
period = 5 # second
|
446 |
+
|
447 |
+
if inp['hold_eye_opened'] == 'True':
|
448 |
+
for i in tqdm.trange(len(drv_secc_colors),desc="opening eye for secc"):
|
449 |
+
batch['drv_secc'][i] = hold_eye_opened_for_secc(batch['drv_secc'][i])
|
450 |
+
|
451 |
+
for i in tqdm.trange(len(drv_secc_colors),desc="blinking secc"):
|
452 |
+
if i % (25*period) == 0:
|
453 |
+
blink_dur_frames = random.randint(8, 12)
|
454 |
+
for offset in range(blink_dur_frames):
|
455 |
+
j = offset + i
|
456 |
+
if j >= len(drv_secc_colors)-1: break
|
457 |
+
def blink_percent_fn(t, T):
|
458 |
+
return -4/T**2 * t**2 + 4/T * t
|
459 |
+
blink_percent = blink_percent_fn(offset, blink_dur_frames)
|
460 |
+
secc = batch['drv_secc'][j]
|
461 |
+
out_secc = blink_eye_for_secc(secc, blink_percent)
|
462 |
+
out_secc = out_secc.cuda()
|
463 |
+
batch['drv_secc'][j] = out_secc
|
464 |
+
|
465 |
+
# get the drv_kp for torso model, using the transformed trajectory
|
466 |
+
drv_kp = self.face3d_helper.reconstruct_lm2d(id, exp, euler, trans) # [T, 68, 2]
|
467 |
+
|
468 |
+
drv_kp = (drv_kp-0.5) / 0.5 # rescale to -1~1
|
469 |
+
batch['drv_kp'] = torch.clamp(drv_kp, -1, 1)
|
470 |
+
return batch
|
471 |
+
|
472 |
+
@torch.no_grad()
|
473 |
+
def forward_secc2video(self, batch, inp=None):
|
474 |
+
num_frames = len(batch['drv_secc'])
|
475 |
+
camera = batch['camera']
|
476 |
+
src_kps = batch['src_kp']
|
477 |
+
drv_kps = batch['drv_kp']
|
478 |
+
cano_secc_color = batch['cano_secc']
|
479 |
+
src_secc_color = batch['src_secc']
|
480 |
+
drv_secc_colors = batch['drv_secc']
|
481 |
+
ref_img_gt = batch['ref_gt_img']
|
482 |
+
ref_img_head = batch['ref_head_img']
|
483 |
+
ref_torso_img = batch['ref_torso_img']
|
484 |
+
bg_img = batch['bg_img']
|
485 |
+
segmap = batch['segmap']
|
486 |
+
|
487 |
+
# smooth torso drv_kp
|
488 |
+
torso_smo_ksize = 7
|
489 |
+
drv_kps = smooth_features_xd(drv_kps.reshape([-1, 68*2]), kernel_size=torso_smo_ksize).reshape([-1, 68, 2])
|
490 |
+
|
491 |
+
# forward renderer
|
492 |
+
img_raw_lst = []
|
493 |
+
img_lst = []
|
494 |
+
depth_img_lst = []
|
495 |
+
with torch.no_grad():
|
496 |
+
with torch.cuda.amp.autocast(inp['fp16']):
|
497 |
+
for i in tqdm.trange(num_frames, desc="Real3D-Portrait is rendering frames"):
|
498 |
+
kp_src = torch.cat([src_kps[i:i+1].reshape([1, 68, 2]), torch.zeros([1, 68,1]).to(src_kps.device)],dim=-1)
|
499 |
+
kp_drv = torch.cat([drv_kps[i:i+1].reshape([1, 68, 2]), torch.zeros([1, 68,1]).to(drv_kps.device)],dim=-1)
|
500 |
+
cond={'cond_cano': cano_secc_color,'cond_src': src_secc_color, 'cond_tgt': drv_secc_colors[i:i+1].cuda(),
|
501 |
+
'ref_torso_img': ref_torso_img, 'bg_img': bg_img, 'segmap': segmap,
|
502 |
+
'kp_s': kp_src, 'kp_d': kp_drv,
|
503 |
+
'ref_cameras': camera[i:i+1],
|
504 |
+
}
|
505 |
+
if i == 0:
|
506 |
+
gen_output = self.secc2video_model.forward(img=ref_img_head, camera=camera[i:i+1], cond=cond, ret={}, cache_backbone=True, use_cached_backbone=False)
|
507 |
+
else:
|
508 |
+
gen_output = self.secc2video_model.forward(img=ref_img_head, camera=camera[i:i+1], cond=cond, ret={}, cache_backbone=False, use_cached_backbone=True)
|
509 |
+
img_lst.append(gen_output['image'])
|
510 |
+
img_raw_lst.append(gen_output['image_raw'])
|
511 |
+
depth_img_lst.append(gen_output['image_depth'])
|
512 |
+
|
513 |
+
# save demo video
|
514 |
+
depth_imgs = torch.cat(depth_img_lst)
|
515 |
+
imgs = torch.cat(img_lst)
|
516 |
+
imgs_raw = torch.cat(img_raw_lst)
|
517 |
+
secc_img = torch.cat([torch.nn.functional.interpolate(drv_secc_colors[i:i+1], (512,512)) for i in range(num_frames)])
|
518 |
+
|
519 |
+
if inp['out_mode'] == 'concat_debug':
|
520 |
+
secc_img = secc_img.cpu()
|
521 |
+
secc_img = ((secc_img + 1) * 127.5).permute(0, 2, 3, 1).int().numpy()
|
522 |
+
|
523 |
+
depth_img = F.interpolate(depth_imgs, (512,512)).cpu()
|
524 |
+
depth_img = depth_img.repeat([1,3,1,1])
|
525 |
+
depth_img = (depth_img - depth_img.min()) / (depth_img.max() - depth_img.min())
|
526 |
+
depth_img = depth_img * 2 - 1
|
527 |
+
depth_img = depth_img.clamp(-1,1)
|
528 |
+
|
529 |
+
secc_img = secc_img / 127.5 - 1
|
530 |
+
secc_img = torch.from_numpy(secc_img).permute(0, 3, 1, 2)
|
531 |
+
imgs = torch.cat([ref_img_gt.repeat([imgs.shape[0],1,1,1]).cpu(), secc_img, F.interpolate(imgs_raw, (512,512)).cpu(), depth_img, imgs.cpu()], dim=-1)
|
532 |
+
elif inp['out_mode'] == 'final':
|
533 |
+
imgs = imgs.cpu()
|
534 |
+
elif inp['out_mode'] == 'debug':
|
535 |
+
raise NotImplementedError("to do: save separate videos")
|
536 |
+
imgs = imgs.clamp(-1,1)
|
537 |
+
|
538 |
+
import imageio
|
539 |
+
debug_name = 'demo.mp4'
|
540 |
+
out_imgs = ((imgs.permute(0, 2, 3, 1) + 1)/2 * 255).int().cpu().numpy().astype(np.uint8)
|
541 |
+
writer = imageio.get_writer(debug_name, fps=25, format='FFMPEG', codec='h264')
|
542 |
+
|
543 |
+
for i in tqdm.trange(len(out_imgs), desc="Imageio is saving video"):
|
544 |
+
writer.append_data(out_imgs[i])
|
545 |
+
writer.close()
|
546 |
+
|
547 |
+
out_fname = 'infer_out/tmp/' + os.path.basename(inp['src_image_name'])[:-4] + '_' + os.path.basename(inp['drv_pose_name'])[:-4] + '.mp4' if inp['out_name'] == '' else inp['out_name']
|
548 |
+
try:
|
549 |
+
os.makedirs(os.path.dirname(out_fname), exist_ok=True)
|
550 |
+
except: pass
|
551 |
+
if inp['drv_audio_name'][-4:] in ['.wav', '.mp3']:
|
552 |
+
# cmd = f"ffmpeg -i {debug_name} -i {self.wav16k_name} -y -shortest {out_fname}"
|
553 |
+
cmd = f"ffmpeg -i {debug_name} -i {self.wav16k_name} -y -v quiet -shortest {out_fname}"
|
554 |
+
print(cmd)
|
555 |
+
os.system(cmd)
|
556 |
+
os.system(f"rm {debug_name}")
|
557 |
+
os.system(f"rm {self.wav16k_name}")
|
558 |
+
else:
|
559 |
+
ret = os.system(f"ffmpeg -i {debug_name} -i {inp['drv_audio_name']} -map 0:v -map 1:a -y -v quiet -shortest {out_fname}")
|
560 |
+
if ret != 0: # 没有成功从drv_audio_name里面提取到音频, 则直接输出无音频轨道的纯视频
|
561 |
+
os.system(f"mv {debug_name} {out_fname}")
|
562 |
+
print(f"Saved at {out_fname}")
|
563 |
+
return out_fname
|
564 |
+
|
565 |
+
@torch.no_grad()
|
566 |
+
def forward_system(self, batch, inp):
|
567 |
+
self.forward_audio2secc(batch, inp)
|
568 |
+
out_fname = self.forward_secc2video(batch, inp)
|
569 |
+
return out_fname
|
570 |
+
|
571 |
+
@classmethod
|
572 |
+
def example_run(cls, inp=None):
|
573 |
+
inp_tmp = {
|
574 |
+
'drv_audio_name': 'data/raw/val_wavs/zozo.wav',
|
575 |
+
'src_image_name': 'data/raw/val_imgs/Macron.png'
|
576 |
+
}
|
577 |
+
if inp is not None:
|
578 |
+
inp_tmp.update(inp)
|
579 |
+
inp = inp_tmp
|
580 |
+
|
581 |
+
infer_instance = cls(inp['a2m_ckpt'], inp['head_ckpt'], inp['torso_ckpt'], inp=inp)
|
582 |
+
infer_instance.infer_once(inp)
|
583 |
+
|
584 |
+
##############
|
585 |
+
# IO-related
|
586 |
+
##############
|
587 |
+
def save_wav16k(self, audio_name):
|
588 |
+
supported_types = ('.wav', '.mp3', '.mp4', '.avi')
|
589 |
+
assert audio_name.endswith(supported_types), f"Now we only support {','.join(supported_types)} as audio source!"
|
590 |
+
import uuid
|
591 |
+
wav16k_name = audio_name[:-4] + f'{uuid.uuid1()}_16k.wav'
|
592 |
+
self.wav16k_name = wav16k_name
|
593 |
+
extract_wav_cmd = f"ffmpeg -i {audio_name} -f wav -ar 16000 -v quiet -y {wav16k_name} -y"
|
594 |
+
# extract_wav_cmd = f"ffmpeg -i {audio_name} -f wav -ar 16000 -y {wav16k_name} -y"
|
595 |
+
print(extract_wav_cmd)
|
596 |
+
os.system(extract_wav_cmd)
|
597 |
+
print(f"Extracted wav file (16khz) from {audio_name} to {wav16k_name}.")
|
598 |
+
|
599 |
+
def get_f0(self, wav16k_name):
|
600 |
+
from data_gen.utils.process_audio.extract_mel_f0 import extract_mel_from_fname, extract_f0_from_wav_and_mel
|
601 |
+
wav, mel = extract_mel_from_fname(self.wav16k_name)
|
602 |
+
f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)
|
603 |
+
f0 = f0.reshape([-1,1])
|
604 |
+
return f0
|
605 |
+
|
606 |
+
if __name__ == '__main__':
|
607 |
+
import argparse, glob, tqdm
|
608 |
+
parser = argparse.ArgumentParser()
|
609 |
+
# parser.add_argument("--a2m_ckpt", default='checkpoints/240112_audio2secc/icl_audio2secc_vox2_cmlr') # checkpoints/0727_audio2secc/audio2secc_withlm2d100_randomframe
|
610 |
+
parser.add_argument("--a2m_ckpt", default='checkpoints/240126_real3dportrait_orig/audio2secc_vae') # checkpoints/0727_audio2secc/audio2secc_withlm2d100_randomframe
|
611 |
+
parser.add_argument("--head_ckpt", default='') # checkpoints/0729_th1kh/secc_img2plane checkpoints/0720_img2planes/secc_img2plane_two_stage
|
612 |
+
# parser.add_argument("--head_ckpt", default='checkpoints/240210_os_secc2plane/secc2plane_trigridv2_blink0.3_pertubeNone') # checkpoints/0729_th1kh/secc_img2plane checkpoints/0720_img2planes/secc_img2plane_two_stage
|
613 |
+
# parser.add_argument("--torso_ckpt", default='')
|
614 |
+
# parser.add_argument("--torso_ckpt", default='checkpoints/240209_robust_secc2plane_torso/secc2plane_torso_orig_fuseV1_MulMaskFalse')
|
615 |
+
parser.add_argument("--torso_ckpt", default='checkpoints/240211_robust_secc2plane_torso/secc2plane_torso_orig_fuseV3_MulMaskTrue')
|
616 |
+
# parser.add_argument("--torso_ckpt", default='checkpoints/240209_robust_secc2plane_torso/secc2plane_torso_orig_fuseV2_MulMaskTrue')
|
617 |
+
# parser.add_argument("--src_img", default='data/raw/val_imgs/Macron_img.png')
|
618 |
+
# parser.add_argument("--src_img", default='gf2_iclr_test_data/cross_imgs/Trump.png')
|
619 |
+
parser.add_argument("--src_img", default='data/raw/val_imgs/mercy.png')
|
620 |
+
parser.add_argument("--bg_img", default='') # data/raw/val_imgs/bg3.png
|
621 |
+
parser.add_argument("--drv_aud", default='data/raw/val_wavs/yiwise.wav')
|
622 |
+
parser.add_argument("--drv_pose", default='infer_out/240319_tta/trump.mp4') # nearest | topk | random | static | vid_name
|
623 |
+
# parser.add_argument("--drv_pose", default='nearest') # nearest | topk | random | static | vid_name
|
624 |
+
parser.add_argument("--drv_style", default='') # nearest | topk | random | static | vid_name
|
625 |
+
# parser.add_argument("--drv_style", default='infer_out/240319_tta/trump.mp4') # nearest | topk | random | static | vid_name
|
626 |
+
parser.add_argument("--blink_mode", default='period') # none | period
|
627 |
+
parser.add_argument("--temperature", default=0.2, type=float) # nearest | random
|
628 |
+
parser.add_argument("--denoising_steps", default=20, type=int) # nearest | random
|
629 |
+
parser.add_argument("--cfg_scale", default=2.5, type=float) # nearest | random
|
630 |
+
parser.add_argument("--mouth_amp", default=0.4, type=float) # scale of predicted mouth, enabled in audio-driven
|
631 |
+
parser.add_argument("--min_face_area_percent", default=0.2, type=float) # scale of predicted mouth, enabled in audio-driven
|
632 |
+
parser.add_argument("--head_torso_threshold", default=0.5, type=float, help="0.1~1.0, 如果发现头发有半透明的现象,调小该值,以将小weights的头发直接clamp到weights=1.0; 如果发现头外部有荧光色的虚影,调小这个值. 对不同超参的Nerf也是case-to-case")
|
633 |
+
# parser.add_argument("--head_torso_threshold", default=None, type=float, help="0.1~1.0, 如果发现头发有半透明的现象,调小该值,以将小weights的头发直接clamp到weights=1.0; 如果发现头外部有荧光色的虚影,调小这个值. 对不同超参的Nerf也是case-to-case")
|
634 |
+
parser.add_argument("--out_name", default='') # nearest | random
|
635 |
+
parser.add_argument("--out_mode", default='concat_debug') # concat_debug | debug | final
|
636 |
+
parser.add_argument("--hold_eye_opened", default='False') # concat_debug | debug | final
|
637 |
+
parser.add_argument("--map_to_init_pose", default='True') # concat_debug | debug | final
|
638 |
+
parser.add_argument("--seed", default=None, type=int) # random seed, default None to use time.time()
|
639 |
+
parser.add_argument("--fp16", action='store_true')
|
640 |
+
|
641 |
+
args = parser.parse_args()
|
642 |
+
|
643 |
+
inp = {
|
644 |
+
'a2m_ckpt': args.a2m_ckpt,
|
645 |
+
'head_ckpt': args.head_ckpt,
|
646 |
+
'torso_ckpt': args.torso_ckpt,
|
647 |
+
'src_image_name': args.src_img,
|
648 |
+
'bg_image_name': args.bg_img,
|
649 |
+
'drv_audio_name': args.drv_aud,
|
650 |
+
'drv_pose_name': args.drv_pose,
|
651 |
+
'drv_talking_style_name': args.drv_style,
|
652 |
+
'blink_mode': args.blink_mode,
|
653 |
+
'temperature': args.temperature,
|
654 |
+
'mouth_amp': args.mouth_amp,
|
655 |
+
'out_name': args.out_name,
|
656 |
+
'out_mode': args.out_mode,
|
657 |
+
'map_to_init_pose': args.map_to_init_pose,
|
658 |
+
'hold_eye_opened': args.hold_eye_opened,
|
659 |
+
'head_torso_threshold': args.head_torso_threshold,
|
660 |
+
'min_face_area_percent': args.min_face_area_percent,
|
661 |
+
'denoising_steps': args.denoising_steps,
|
662 |
+
'cfg_scale': args.cfg_scale,
|
663 |
+
'seed': args.seed,
|
664 |
+
'fp16': args.fp16, # 目前的ckpt使用fp16会导致nan,发现是因为i2p模型的layernorm产生了单个nan导致的,在训练阶段也采用fp16可能可以解决这个问题
|
665 |
+
}
|
666 |
+
|
667 |
+
GeneFace2Infer.example_run(inp)
|
inference/real3dportrait_demo.ipynb
ADDED
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"colab_type": "text",
|
7 |
+
"id": "view-in-github"
|
8 |
+
},
|
9 |
+
"source": [
|
10 |
+
"<a href=\"https://colab.research.google.com/github/yerfor/Real3DPortrait/blob/main/real3dportrait_demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "markdown",
|
15 |
+
"metadata": {
|
16 |
+
"id": "QS04K9oO21AW"
|
17 |
+
},
|
18 |
+
"source": [
|
19 |
+
"Check GPU"
|
20 |
+
]
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"cell_type": "code",
|
24 |
+
"execution_count": null,
|
25 |
+
"metadata": {
|
26 |
+
"id": "1ESQRDb-yVUG"
|
27 |
+
},
|
28 |
+
"outputs": [],
|
29 |
+
"source": [
|
30 |
+
"!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader"
|
31 |
+
]
|
32 |
+
},
|
33 |
+
{
|
34 |
+
"cell_type": "markdown",
|
35 |
+
"metadata": {
|
36 |
+
"id": "y-ctmIvu3Ei8"
|
37 |
+
},
|
38 |
+
"source": [
|
39 |
+
"Installation"
|
40 |
+
]
|
41 |
+
},
|
42 |
+
{
|
43 |
+
"cell_type": "code",
|
44 |
+
"execution_count": null,
|
45 |
+
"metadata": {
|
46 |
+
"id": "gXu76wdDgaxo"
|
47 |
+
},
|
48 |
+
"outputs": [],
|
49 |
+
"source": [
|
50 |
+
"# install pytorch3d, about 15s\n",
|
51 |
+
"import os\n",
|
52 |
+
"import sys\n",
|
53 |
+
"import torch\n",
|
54 |
+
"need_pytorch3d=False\n",
|
55 |
+
"try:\n",
|
56 |
+
" import pytorch3d\n",
|
57 |
+
"except ModuleNotFoundError:\n",
|
58 |
+
" need_pytorch3d=True\n",
|
59 |
+
"if need_pytorch3d:\n",
|
60 |
+
" if torch.__version__.startswith(\"2.1.\") and sys.platform.startswith(\"linux\"):\n",
|
61 |
+
" # We try to install PyTorch3D via a released wheel.\n",
|
62 |
+
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
|
63 |
+
" version_str=\"\".join([\n",
|
64 |
+
" f\"py3{sys.version_info.minor}_cu\",\n",
|
65 |
+
" torch.version.cuda.replace(\".\",\"\"),\n",
|
66 |
+
" f\"_pyt{pyt_version_str}\"\n",
|
67 |
+
" ])\n",
|
68 |
+
" !pip install fvcore iopath\n",
|
69 |
+
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
|
70 |
+
" else:\n",
|
71 |
+
" # We try to install PyTorch3D from source.\n",
|
72 |
+
" !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
|
73 |
+
]
|
74 |
+
},
|
75 |
+
{
|
76 |
+
"cell_type": "code",
|
77 |
+
"execution_count": null,
|
78 |
+
"metadata": {
|
79 |
+
"id": "DuUynxmotG_-"
|
80 |
+
},
|
81 |
+
"outputs": [],
|
82 |
+
"source": [
|
83 |
+
"# install dependencies, about 5~10 min\n",
|
84 |
+
"!pip install tensorboard==2.13.0 tensorboardX==2.6.1\n",
|
85 |
+
"!pip install pyspy==0.1.1\n",
|
86 |
+
"!pip install protobuf==3.20.3\n",
|
87 |
+
"!pip install scipy==1.9.1\n",
|
88 |
+
"!pip install kornia==0.5.0\n",
|
89 |
+
"!pip install trimesh==3.22.0\n",
|
90 |
+
"!pip install einops==0.6.1 torchshow==0.5.1\n",
|
91 |
+
"!pip install imageio==2.31.1 imageio-ffmpeg==0.4.8\n",
|
92 |
+
"!pip install scikit-learn==1.3.0 scikit-image==0.21.0\n",
|
93 |
+
"!pip install av==10.0.0 lpips==0.1.4\n",
|
94 |
+
"!pip install timm==0.9.2 librosa==0.9.2\n",
|
95 |
+
"!pip install openmim==0.3.9\n",
|
96 |
+
"!mim install mmcv==2.1.0 # use mim to speed up installation for mmcv\n",
|
97 |
+
"!pip install transformers==4.33.2\n",
|
98 |
+
"!pip install pretrainedmodels==0.7.4\n",
|
99 |
+
"!pip install ninja==1.11.1\n",
|
100 |
+
"!pip install faiss-cpu==1.7.4\n",
|
101 |
+
"!pip install praat-parselmouth==0.4.3 moviepy==1.0.3\n",
|
102 |
+
"!pip install mediapipe==0.10.7\n",
|
103 |
+
"!pip install --upgrade attr\n",
|
104 |
+
"!pip install beartype==0.16.4 gateloop_transformer==0.4.0\n",
|
105 |
+
"!pip install torchode==0.2.0 torchdiffeq==0.2.3\n",
|
106 |
+
"!pip install hydra-core==1.3.2 pandas==2.1.3\n",
|
107 |
+
"!pip install pytorch_lightning==2.1.2\n",
|
108 |
+
"!pip install httpx==0.23.3\n",
|
109 |
+
"!pip install gradio==4.16.0\n",
|
110 |
+
"!pip install gdown\n",
|
111 |
+
"!pip install pyloudnorm webrtcvad pyworld==0.2.1rc0 pypinyin==0.42.0"
|
112 |
+
]
|
113 |
+
},
|
114 |
+
{
|
115 |
+
"cell_type": "code",
|
116 |
+
"execution_count": null,
|
117 |
+
"metadata": {
|
118 |
+
"id": "0GLEV0HVu8rj"
|
119 |
+
},
|
120 |
+
"outputs": [],
|
121 |
+
"source": [
|
122 |
+
"# RESTART kernel to make sure runtime is correct if you meet runtime errors\n",
|
123 |
+
"# import os\n",
|
124 |
+
"# os.kill(os.getpid(), 9)"
|
125 |
+
]
|
126 |
+
},
|
127 |
+
{
|
128 |
+
"cell_type": "markdown",
|
129 |
+
"metadata": {
|
130 |
+
"id": "5UfKHKrH6kcq"
|
131 |
+
},
|
132 |
+
"source": [
|
133 |
+
"Clone code and download checkpoints"
|
134 |
+
]
|
135 |
+
},
|
136 |
+
{
|
137 |
+
"cell_type": "code",
|
138 |
+
"execution_count": null,
|
139 |
+
"metadata": {
|
140 |
+
"id": "-gfRsd9DwIgl"
|
141 |
+
},
|
142 |
+
"outputs": [],
|
143 |
+
"source": [
|
144 |
+
"# clone Real3DPortrait repo from github\n",
|
145 |
+
"!git clone https://github.com/yerfor/Real3DPortrait\n",
|
146 |
+
"%cd Real3DPortrait"
|
147 |
+
]
|
148 |
+
},
|
149 |
+
{
|
150 |
+
"cell_type": "code",
|
151 |
+
"execution_count": null,
|
152 |
+
"metadata": {
|
153 |
+
"id": "Yju8dQY7x5OS"
|
154 |
+
},
|
155 |
+
"outputs": [],
|
156 |
+
"source": [
|
157 |
+
"# download pretrained ckpts & third-party ckpts from google drive, about 1 min\n",
|
158 |
+
"!pip install --upgrade --no-cache-dir gdown\n",
|
159 |
+
"%cd deep_3drecon/BFM\n",
|
160 |
+
"!gdown https://drive.google.com/uc?id=1SPM3IHsyNAaVMwqZZGV6QVaV7I2Hly0v\n",
|
161 |
+
"!gdown https://drive.google.com/uc?id=1MSldX9UChKEb3AXLVTPzZQcsbGD4VmGF\n",
|
162 |
+
"!gdown https://drive.google.com/uc?id=180ciTvm16peWrcpl4DOekT9eUQ-lJfMU\n",
|
163 |
+
"!gdown https://drive.google.com/uc?id=1KX9MyGueFB3M-X0Ss152x_johyTXHTfU\n",
|
164 |
+
"!gdown https://drive.google.com/uc?id=19-NyZn_I0_mkF-F5GPyFMwQJ_-WecZIL\n",
|
165 |
+
"!gdown https://drive.google.com/uc?id=11ouQ7Wr2I-JKStp2Fd1afedmWeuifhof\n",
|
166 |
+
"!gdown https://drive.google.com/uc?id=18ICIvQoKX-7feYWP61RbpppzDuYTptCq\n",
|
167 |
+
"!gdown https://drive.google.com/uc?id=1VktuY46m0v_n_d4nvOupauJkK4LF6mHE\n",
|
168 |
+
"%cd ../..\n",
|
169 |
+
"\n",
|
170 |
+
"%cd checkpoints\n",
|
171 |
+
"!gdown https://drive.google.com/uc?id=1gz8A6xestHp__GbZT5qozb43YaybRJhZ\n",
|
172 |
+
"!gdown https://drive.google.com/uc?id=1gSUIw2AkkKnlLJnNfS2FCqtaVw9tw3QF\n",
|
173 |
+
"!unzip 240210_real3dportrait_orig.zip\n",
|
174 |
+
"!unzip pretrained_ckpts.zip\n",
|
175 |
+
"!ls\n",
|
176 |
+
"%cd ..\n"
|
177 |
+
]
|
178 |
+
},
|
179 |
+
{
|
180 |
+
"cell_type": "markdown",
|
181 |
+
"metadata": {
|
182 |
+
"id": "LHzLro206pnA"
|
183 |
+
},
|
184 |
+
"source": [
|
185 |
+
"Inference sample"
|
186 |
+
]
|
187 |
+
},
|
188 |
+
{
|
189 |
+
"cell_type": "markdown",
|
190 |
+
"metadata": {},
|
191 |
+
"source": [
|
192 |
+
"!python inference/real3d_infer.py -h"
|
193 |
+
]
|
194 |
+
},
|
195 |
+
{
|
196 |
+
"cell_type": "code",
|
197 |
+
"execution_count": null,
|
198 |
+
"metadata": {
|
199 |
+
"id": "2aCDwxNivQoS"
|
200 |
+
},
|
201 |
+
"outputs": [],
|
202 |
+
"source": [
|
203 |
+
"# sample inference, about 3 min\n",
|
204 |
+
"!python inference/real3d_infer.py \\\n",
|
205 |
+
"--src_img data/raw/examples/Macron.png \\\n",
|
206 |
+
"--drv_aud data/raw/examples/Obama_5s.wav \\\n",
|
207 |
+
"--drv_pose data/raw/examples/May_5s.mp4 \\\n",
|
208 |
+
"--bg_img data/raw/examples/bg.png \\\n",
|
209 |
+
"--out_name output.mp4 \\\n",
|
210 |
+
"--out_mode concat_debug \\\n",
|
211 |
+
"--low_memory_usage"
|
212 |
+
]
|
213 |
+
},
|
214 |
+
{
|
215 |
+
"cell_type": "markdown",
|
216 |
+
"metadata": {
|
217 |
+
"id": "XL0c54l19mBG"
|
218 |
+
},
|
219 |
+
"source": [
|
220 |
+
"Display output video"
|
221 |
+
]
|
222 |
+
},
|
223 |
+
{
|
224 |
+
"cell_type": "code",
|
225 |
+
"execution_count": null,
|
226 |
+
"metadata": {
|
227 |
+
"id": "6olmWwZP9Icj"
|
228 |
+
},
|
229 |
+
"outputs": [],
|
230 |
+
"source": [
|
231 |
+
"# borrow code from makeittalk\n",
|
232 |
+
"from IPython.display import HTML\n",
|
233 |
+
"from base64 import b64encode\n",
|
234 |
+
"import os, sys\n",
|
235 |
+
"import glob\n",
|
236 |
+
"\n",
|
237 |
+
"mp4_name = './output.mp4'\n",
|
238 |
+
"\n",
|
239 |
+
"mp4 = open('{}'.format(mp4_name),'rb').read()\n",
|
240 |
+
"data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
|
241 |
+
"\n",
|
242 |
+
"print('Display animation: {}'.format(mp4_name), file=sys.stderr)\n",
|
243 |
+
"display(HTML(\"\"\"\n",
|
244 |
+
" <video width=256 controls>\n",
|
245 |
+
" <source src=\"%s\" type=\"video/mp4\">\n",
|
246 |
+
" </video>\n",
|
247 |
+
" \"\"\" % data_url))"
|
248 |
+
]
|
249 |
+
},
|
250 |
+
{
|
251 |
+
"cell_type": "markdown",
|
252 |
+
"metadata": {},
|
253 |
+
"source": [
|
254 |
+
"WebUI"
|
255 |
+
]
|
256 |
+
},
|
257 |
+
{
|
258 |
+
"cell_type": "code",
|
259 |
+
"execution_count": null,
|
260 |
+
"metadata": {},
|
261 |
+
"outputs": [],
|
262 |
+
"source": [
|
263 |
+
"\n",
|
264 |
+
"!python inference/app_real3dportrait.py --share"
|
265 |
+
]
|
266 |
+
}
|
267 |
+
],
|
268 |
+
"metadata": {
|
269 |
+
"accelerator": "GPU",
|
270 |
+
"colab": {
|
271 |
+
"authorship_tag": "ABX9TyPu++zOlOS4yKF4xn4FHGtZ",
|
272 |
+
"gpuType": "T4",
|
273 |
+
"include_colab_link": true,
|
274 |
+
"private_outputs": true,
|
275 |
+
"provenance": []
|
276 |
+
},
|
277 |
+
"kernelspec": {
|
278 |
+
"display_name": "Python 3",
|
279 |
+
"name": "python3"
|
280 |
+
},
|
281 |
+
"language_info": {
|
282 |
+
"name": "python"
|
283 |
+
}
|
284 |
+
},
|
285 |
+
"nbformat": 4,
|
286 |
+
"nbformat_minor": 0
|
287 |
+
}
|
inference/train_mimictalk_on_a_video.py
ADDED
@@ -0,0 +1,608 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
将One-shot的说话人大模型(os_secc2plane or os_secc2plane_torso)在单一说话人(一张照片或一段视频)上overfit, 实现和GeneFace++类似的效果
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import librosa
|
9 |
+
import random
|
10 |
+
import time
|
11 |
+
import numpy as np
|
12 |
+
import importlib
|
13 |
+
import tqdm
|
14 |
+
import copy
|
15 |
+
import cv2
|
16 |
+
import glob
|
17 |
+
import imageio
|
18 |
+
# common utils
|
19 |
+
from utils.commons.hparams import hparams, set_hparams
|
20 |
+
from utils.commons.tensor_utils import move_to_cuda, convert_to_tensor
|
21 |
+
from utils.commons.ckpt_utils import load_ckpt, get_last_checkpoint
|
22 |
+
# 3DMM-related utils
|
23 |
+
from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
|
24 |
+
from data_util.face3d_helper import Face3DHelper
|
25 |
+
from data_gen.utils.process_image.fit_3dmm_landmark import fit_3dmm_for_a_image
|
26 |
+
from data_gen.utils.process_video.fit_3dmm_landmark import fit_3dmm_for_a_video
|
27 |
+
from data_gen.utils.process_video.extract_segment_imgs import decode_segmap_mask_from_image
|
28 |
+
from deep_3drecon.secc_renderer import SECC_Renderer
|
29 |
+
from data_gen.eg3d.convert_to_eg3d_convention import get_eg3d_convention_camera_pose_intrinsic
|
30 |
+
from data_gen.runs.binarizer_nerf import get_lip_rect
|
31 |
+
# Face Parsing
|
32 |
+
from data_gen.utils.mp_feature_extractors.mp_segmenter import MediapipeSegmenter
|
33 |
+
from data_gen.utils.process_video.extract_segment_imgs import inpaint_torso_job, extract_background
|
34 |
+
# other inference utils
|
35 |
+
from inference.infer_utils import mirror_index, load_img_to_512_hwc_array, load_img_to_normalized_512_bchw_tensor
|
36 |
+
from inference.infer_utils import smooth_camera_sequence, smooth_features_xd
|
37 |
+
from inference.edit_secc import blink_eye_for_secc, hold_eye_opened_for_secc
|
38 |
+
from modules.commons.loralib.utils import mark_only_lora_as_trainable
|
39 |
+
from utils.nn.model_utils import num_params
|
40 |
+
import lpips
|
41 |
+
from utils.commons.meters import AvgrageMeter
|
42 |
+
meter = AvgrageMeter()
|
43 |
+
from torch.utils.tensorboard import SummaryWriter
|
44 |
+
class LoRATrainer(nn.Module):
|
45 |
+
def __init__(self, inp):
|
46 |
+
super().__init__()
|
47 |
+
self.inp = inp
|
48 |
+
self.lora_args = {'lora_mode': inp['lora_mode'], 'lora_r': inp['lora_r']}
|
49 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
50 |
+
head_model_dir = inp['head_ckpt']
|
51 |
+
torso_model_dir = inp['torso_ckpt']
|
52 |
+
model_dir = torso_model_dir if torso_model_dir != '' else head_model_dir
|
53 |
+
cmd = f"cp {os.path.join(model_dir, 'config.yaml')} {self.inp['work_dir']}"
|
54 |
+
print(cmd)
|
55 |
+
os.system(cmd)
|
56 |
+
with open(os.path.join(self.inp['work_dir'], 'config.yaml'), "a") as f:
|
57 |
+
f.write(f"\nlora_r: {inp['lora_r']}")
|
58 |
+
f.write(f"\nlora_mode: {inp['lora_mode']}")
|
59 |
+
f.write(f"\n")
|
60 |
+
self.secc2video_model = self.load_secc2video(model_dir)
|
61 |
+
self.secc2video_model.to(device).eval()
|
62 |
+
self.seg_model = MediapipeSegmenter()
|
63 |
+
self.secc_renderer = SECC_Renderer(512)
|
64 |
+
self.face3d_helper = Face3DHelper(use_gpu=True, keypoint_mode='lm68')
|
65 |
+
self.mp_face3d_helper = Face3DHelper(use_gpu=True, keypoint_mode='mediapipe')
|
66 |
+
# self.camera_selector = KNearestCameraSelector()
|
67 |
+
self.load_training_data(inp)
|
68 |
+
def load_secc2video(self, model_dir):
|
69 |
+
inp = self.inp
|
70 |
+
from modules.real3d.secc_img2plane_torso import OSAvatarSECC_Img2plane, OSAvatarSECC_Img2plane_Torso
|
71 |
+
hp = set_hparams(f"{model_dir}/config.yaml", print_hparams=False, global_hparams=True)
|
72 |
+
hp['htbsr_head_threshold'] = 1.0
|
73 |
+
self.neural_rendering_resolution = hp['neural_rendering_resolution']
|
74 |
+
if 'torso' in hp['task_cls'].lower():
|
75 |
+
self.torso_mode = True
|
76 |
+
model = OSAvatarSECC_Img2plane_Torso(hp=hp, lora_args=self.lora_args)
|
77 |
+
else:
|
78 |
+
self.torso_mode = False
|
79 |
+
model = OSAvatarSECC_Img2plane(hp=hp, lora_args=self.lora_args)
|
80 |
+
mark_only_lora_as_trainable(model, bias='none')
|
81 |
+
lora_ckpt_path = os.path.join(inp['work_dir'], 'checkpoint.ckpt')
|
82 |
+
if os.path.exists(lora_ckpt_path):
|
83 |
+
self.learnable_triplane = nn.Parameter(torch.zeros([1, 3, model.triplane_hid_dim*model.triplane_depth, 256, 256]).float().cuda(), requires_grad=True)
|
84 |
+
model._last_cano_planes = self.learnable_triplane
|
85 |
+
load_ckpt(model, lora_ckpt_path, model_name='model', strict=False)
|
86 |
+
else:
|
87 |
+
load_ckpt(model, f"{model_dir}", model_name='model', strict=False)
|
88 |
+
|
89 |
+
num_params(model)
|
90 |
+
self.model = model
|
91 |
+
return model
|
92 |
+
def load_training_data(self, inp):
|
93 |
+
video_id = inp['video_id']
|
94 |
+
if video_id.endswith((".mp4", ".png", ".jpg", ".jpeg")):
|
95 |
+
# If input video is not GeneFace training videos, convert it into GeneFace convention
|
96 |
+
video_id_ = video_id
|
97 |
+
video_id = os.path.basename(video_id)[:-4]
|
98 |
+
inp['video_id'] = video_id
|
99 |
+
target_video_path = f'data/raw/videos/{video_id}.mp4'
|
100 |
+
if not os.path.exists(target_video_path):
|
101 |
+
print(f"| Copying video to {target_video_path}")
|
102 |
+
os.makedirs(os.path.dirname(target_video_path), exist_ok=True)
|
103 |
+
cmd = f"ffmpeg -i {video_id_} -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -y {target_video_path}"
|
104 |
+
print(f"| {cmd}")
|
105 |
+
os.system(cmd)
|
106 |
+
target_video_path = f'data/raw/videos/{video_id}.mp4'
|
107 |
+
print(f"| Copy source video into work dir: {self.inp['work_dir']}")
|
108 |
+
os.system(f"cp {target_video_path} {self.inp['work_dir']}")
|
109 |
+
# check head_img path
|
110 |
+
head_img_pattern = f'data/processed/videos/{video_id}/head_imgs/*.png'
|
111 |
+
head_img_names = sorted(glob.glob(head_img_pattern))
|
112 |
+
if len(head_img_names) == 0:
|
113 |
+
# extract head_imgs
|
114 |
+
head_img_dir = os.path.dirname(head_img_pattern)
|
115 |
+
print(f"| Pre-extracted head_imgs not found, try to extract and save to {head_img_dir}, this may take a while...")
|
116 |
+
gt_img_dir = f"data/processed/videos/{video_id}/gt_imgs"
|
117 |
+
os.makedirs(gt_img_dir, exist_ok=True)
|
118 |
+
target_video_path = f'data/raw/videos/{video_id}.mp4'
|
119 |
+
cmd = f"ffmpeg -i {target_video_path} -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -start_number 0 -y {gt_img_dir}/%08d.jpg"
|
120 |
+
print(f"| {cmd}")
|
121 |
+
os.system(cmd)
|
122 |
+
# extract image, segmap, and background
|
123 |
+
cmd = f"python data_gen/utils/process_video/extract_segment_imgs.py --ds_name=nerf --vid_dir={target_video_path}"
|
124 |
+
print(f"| {cmd}")
|
125 |
+
os.system(cmd)
|
126 |
+
print("| Head images Extracted!")
|
127 |
+
num_samples = len(head_img_names)
|
128 |
+
npy_name = f"data/processed/videos/{video_id}/coeff_fit_mp_for_lora.npy"
|
129 |
+
if os.path.exists(npy_name):
|
130 |
+
coeff_dict = np.load(npy_name, allow_pickle=True).tolist()
|
131 |
+
else:
|
132 |
+
print(f"| Pre-extracted 3DMM coefficient not found, try to extract and save to {npy_name}, this may take a while...")
|
133 |
+
coeff_dict = fit_3dmm_for_a_video(f'data/raw/videos/{video_id}.mp4', save=False)
|
134 |
+
os.makedirs(os.path.dirname(npy_name), exist_ok=True)
|
135 |
+
np.save(npy_name, coeff_dict)
|
136 |
+
ids = convert_to_tensor(coeff_dict['id']).reshape([-1,80]).cuda()
|
137 |
+
exps = convert_to_tensor(coeff_dict['exp']).reshape([-1,64]).cuda()
|
138 |
+
eulers = convert_to_tensor(coeff_dict['euler']).reshape([-1,3]).cuda()
|
139 |
+
trans = convert_to_tensor(coeff_dict['trans']).reshape([-1,3]).cuda()
|
140 |
+
WH = 512 # now we only support 512x512
|
141 |
+
lm2ds = WH * self.face3d_helper.reconstruct_lm2d(ids, exps, eulers, trans).cpu().numpy()
|
142 |
+
lip_rects = [get_lip_rect(lm2ds[i], WH, WH) for i in range(len(lm2ds))]
|
143 |
+
kps = self.face3d_helper.reconstruct_lm2d(ids, exps, eulers, trans).cuda()
|
144 |
+
kps = (kps-0.5) / 0.5 # rescale to -1~1
|
145 |
+
kps = torch.cat([kps, torch.zeros([*kps.shape[:-1], 1]).cuda()], dim=-1)
|
146 |
+
camera_ret = get_eg3d_convention_camera_pose_intrinsic({'euler': torch.tensor(coeff_dict['euler']).reshape([-1,3]), 'trans': torch.tensor(coeff_dict['trans']).reshape([-1,3])})
|
147 |
+
c2w, intrinsics = camera_ret['c2w'], camera_ret['intrinsics']
|
148 |
+
cameras = torch.tensor(np.concatenate([c2w.reshape([-1,16]), intrinsics.reshape([-1,9])], axis=-1)).cuda()
|
149 |
+
camera_smo_ksize = 7
|
150 |
+
cameras = smooth_camera_sequence(cameras.cpu().numpy(), kernel_size=camera_smo_ksize) # [T, 25]
|
151 |
+
cameras = torch.tensor(cameras).cuda()
|
152 |
+
zero_eulers = eulers * 0
|
153 |
+
zero_trans = trans * 0
|
154 |
+
_, cano_secc_color = self.secc_renderer(ids[0:1], exps[0:1]*0, zero_eulers[0:1], zero_trans[0:1])
|
155 |
+
src_idx = 0
|
156 |
+
_, src_secc_color = self.secc_renderer(ids[0:1], exps[src_idx:src_idx+1], zero_eulers[0:1], zero_trans[0:1])
|
157 |
+
drv_secc_colors = [None for _ in range(len(exps))]
|
158 |
+
drv_head_imgs = [None for _ in range(len(exps))]
|
159 |
+
drv_torso_imgs = [None for _ in range(len(exps))]
|
160 |
+
drv_com_imgs = [None for _ in range(len(exps))]
|
161 |
+
segmaps = [None for _ in range(len(exps))]
|
162 |
+
img_name = f'data/processed/videos/{video_id}/bg.jpg'
|
163 |
+
bg_img = torch.tensor(cv2.imread(img_name)[..., ::-1] / 127.5 - 1).permute(2,0,1).float() # [3, H, W]
|
164 |
+
ds = {
|
165 |
+
'id': ids.cuda().float(),
|
166 |
+
'exps': exps.cuda().float(),
|
167 |
+
'eulers': eulers.cuda().float(),
|
168 |
+
'trans': trans.cuda().float(),
|
169 |
+
'cano_secc_color': cano_secc_color.cuda().float(),
|
170 |
+
'src_secc_color': src_secc_color.cuda().float(),
|
171 |
+
'cameras': cameras.float(),
|
172 |
+
'video_id': video_id,
|
173 |
+
'lip_rects': lip_rects,
|
174 |
+
'head_imgs': drv_head_imgs,
|
175 |
+
'torso_imgs': drv_torso_imgs,
|
176 |
+
'com_imgs': drv_com_imgs,
|
177 |
+
'bg_img': bg_img,
|
178 |
+
'segmaps': segmaps,
|
179 |
+
'kps': kps,
|
180 |
+
}
|
181 |
+
self.ds = ds
|
182 |
+
return ds
|
183 |
+
|
184 |
+
def training_loop(self, inp):
|
185 |
+
trainer = self
|
186 |
+
video_id = self.ds['video_id']
|
187 |
+
lora_params = [p for k, p in self.secc2video_model.named_parameters() if 'lora_' in k]
|
188 |
+
self.criterion_lpips = lpips.LPIPS(net='alex',lpips=True).cuda()
|
189 |
+
self.logger = SummaryWriter(log_dir=inp['work_dir'])
|
190 |
+
if not hasattr(self, 'learnable_triplane'):
|
191 |
+
src_idx = 0 # init triplane from the first frame's prediction
|
192 |
+
self.learnable_triplane = nn.Parameter(torch.zeros([1, 3, self.secc2video_model.triplane_hid_dim*self.secc2video_model.triplane_depth, 256, 256]).float().cuda(), requires_grad=True)
|
193 |
+
img_name = f'data/processed/videos/{video_id}/head_imgs/{format(src_idx, "08d")}.png'
|
194 |
+
img = torch.tensor(cv2.imread(img_name)[..., ::-1] / 127.5 - 1).permute(2,0,1).float().cuda().float() # [3, H, W]
|
195 |
+
cano_plane = self.secc2video_model.cal_cano_plane(img.unsqueeze(0)) # [1, 3, CD, h, w]
|
196 |
+
self.learnable_triplane.data = cano_plane.data
|
197 |
+
self.secc2video_model._last_cano_planes = self.learnable_triplane
|
198 |
+
if len(lora_params) == 0:
|
199 |
+
self.optimizer = torch.optim.AdamW([self.learnable_triplane], lr=inp['lr_triplane'], weight_decay=0.01, betas=(0.9,0.98))
|
200 |
+
else:
|
201 |
+
self.optimizer = torch.optim.Adam(lora_params, lr=inp['lr'], betas=(0.9,0.98))
|
202 |
+
self.optimizer.add_param_group({
|
203 |
+
'params': [self.learnable_triplane],
|
204 |
+
'lr': inp['lr_triplane'],
|
205 |
+
'betas': (0.9, 0.98)
|
206 |
+
})
|
207 |
+
|
208 |
+
ids = self.ds['id']
|
209 |
+
exps = self.ds['exps']
|
210 |
+
zero_eulers = self.ds['eulers']*0
|
211 |
+
zero_trans = self.ds['trans']*0
|
212 |
+
num_updates = inp['max_updates']
|
213 |
+
batch_size = inp['batch_size'] # 1 for lower gpu mem usage
|
214 |
+
num_samples = len(self.ds['cameras'])
|
215 |
+
init_plane = self.learnable_triplane.detach().clone()
|
216 |
+
if num_samples <= 5:
|
217 |
+
lambda_reg_triplane = 1.0
|
218 |
+
elif num_samples <= 250:
|
219 |
+
lambda_reg_triplane = 0.1
|
220 |
+
else:
|
221 |
+
lambda_reg_triplane = 0.
|
222 |
+
for i_step in tqdm.trange(num_updates+1,desc="training lora..."):
|
223 |
+
milestone_steps = []
|
224 |
+
# milestone_steps = [100, 200, 500]
|
225 |
+
if i_step % 2000 == 0 or i_step in milestone_steps:
|
226 |
+
trainer.test_loop(inp, step=i_step)
|
227 |
+
if i_step != 0:
|
228 |
+
filepath = os.path.join(inp['work_dir'], f"model_ckpt_steps_{i_step}.ckpt")
|
229 |
+
checkpoint = self.dump_checkpoint(inp)
|
230 |
+
tmp_path = str(filepath) + ".part"
|
231 |
+
torch.save(checkpoint, tmp_path, _use_new_zipfile_serialization=False)
|
232 |
+
os.replace(tmp_path, filepath)
|
233 |
+
|
234 |
+
drv_idx = [random.randint(0, num_samples-1) for _ in range(batch_size)]
|
235 |
+
drv_secc_colors = []
|
236 |
+
gt_imgs = []
|
237 |
+
head_imgs = []
|
238 |
+
segmaps_0 = []
|
239 |
+
segmaps = []
|
240 |
+
torso_imgs = []
|
241 |
+
drv_lip_rects = []
|
242 |
+
kp_src = []
|
243 |
+
kp_drv = []
|
244 |
+
for di in drv_idx:
|
245 |
+
# 读取target image
|
246 |
+
if self.torso_mode:
|
247 |
+
if self.ds['com_imgs'][di] is None:
|
248 |
+
# img_name = f'data/processed/videos/{video_id}/gt_imgs/{format(di, "08d")}.jpg'
|
249 |
+
img_name = f'data/processed/videos/{video_id}/com_imgs/{format(di, "08d")}.jpg'
|
250 |
+
img = torch.tensor(cv2.imread(img_name)[..., ::-1] / 127.5 - 1).permute(2,0,1).float() # [3, H, W]
|
251 |
+
self.ds['com_imgs'][di] = img
|
252 |
+
gt_imgs.append(self.ds['com_imgs'][di])
|
253 |
+
else:
|
254 |
+
if self.ds['head_imgs'][di] is None:
|
255 |
+
img_name = f'data/processed/videos/{video_id}/head_imgs/{format(di, "08d")}.png'
|
256 |
+
img = torch.tensor(cv2.imread(img_name)[..., ::-1] / 127.5 - 1).permute(2,0,1).float() # [3, H, W]
|
257 |
+
self.ds['head_imgs'][di] = img
|
258 |
+
gt_imgs.append(self.ds['head_imgs'][di])
|
259 |
+
if self.ds['head_imgs'][di] is None:
|
260 |
+
img_name = f'data/processed/videos/{video_id}/head_imgs/{format(di, "08d")}.png'
|
261 |
+
img = torch.tensor(cv2.imread(img_name)[..., ::-1] / 127.5 - 1).permute(2,0,1).float() # [3, H, W]
|
262 |
+
self.ds['head_imgs'][di] = img
|
263 |
+
head_imgs.append(self.ds['head_imgs'][di])
|
264 |
+
# 使用第一帧的torso作为face v2v的输入
|
265 |
+
if self.ds['torso_imgs'][0] is None:
|
266 |
+
img_name = f'data/processed/videos/{video_id}/inpaint_torso_imgs/{format(0, "08d")}.png'
|
267 |
+
img = torch.tensor(cv2.imread(img_name)[..., ::-1] / 127.5 - 1).permute(2,0,1).float() # [3, H, W]
|
268 |
+
self.ds['torso_imgs'][0] = img
|
269 |
+
torso_imgs.append(self.ds['torso_imgs'][0])
|
270 |
+
# 所以segmap也用第一帧的了
|
271 |
+
if self.ds['segmaps'][0] is None:
|
272 |
+
img_name = f'data/processed/videos/{video_id}/segmaps/{format(0, "08d")}.png'
|
273 |
+
seg_img = cv2.imread(img_name)[:,:, ::-1]
|
274 |
+
segmap = torch.from_numpy(decode_segmap_mask_from_image(seg_img)) # [6, H, W]
|
275 |
+
self.ds['segmaps'][0] = segmap
|
276 |
+
segmaps_0.append(self.ds['segmaps'][0])
|
277 |
+
if self.ds['segmaps'][di] is None:
|
278 |
+
img_name = f'data/processed/videos/{video_id}/segmaps/{format(di, "08d")}.png'
|
279 |
+
seg_img = cv2.imread(img_name)[:,:, ::-1]
|
280 |
+
segmap = torch.from_numpy(decode_segmap_mask_from_image(seg_img)) # [6, H, W]
|
281 |
+
self.ds['segmaps'][di] = segmap
|
282 |
+
segmaps.append(self.ds['segmaps'][di])
|
283 |
+
_, secc_color = self.secc_renderer(ids[0:1], exps[di:di+1], zero_eulers[0:1], zero_trans[0:1])
|
284 |
+
drv_secc_colors.append(secc_color)
|
285 |
+
drv_lip_rects.append(self.ds['lip_rects'][di])
|
286 |
+
kp_src.append(self.ds['kps'][0])
|
287 |
+
kp_drv.append(self.ds['kps'][di])
|
288 |
+
bg_img = self.ds['bg_img'].unsqueeze(0).repeat([batch_size, 1, 1, 1]).cuda()
|
289 |
+
ref_torso_imgs = torch.stack(torso_imgs).float().cuda()
|
290 |
+
kp_src = torch.stack(kp_src).float().cuda()
|
291 |
+
kp_drv = torch.stack(kp_drv).float().cuda()
|
292 |
+
segmaps = torch.stack(segmaps).float().cuda()
|
293 |
+
segmaps_0 = torch.stack(segmaps_0).float().cuda()
|
294 |
+
tgt_imgs = torch.stack(gt_imgs).float().cuda()
|
295 |
+
head_imgs = torch.stack(head_imgs).float().cuda()
|
296 |
+
drv_secc_color = torch.cat(drv_secc_colors)
|
297 |
+
cano_secc_color = self.ds['cano_secc_color'].repeat([batch_size, 1, 1, 1])
|
298 |
+
src_secc_color = self.ds['src_secc_color'].repeat([batch_size, 1, 1, 1])
|
299 |
+
cond = {'cond_cano': cano_secc_color,'cond_src': src_secc_color, 'cond_tgt': drv_secc_color,
|
300 |
+
'ref_torso_img': ref_torso_imgs, 'bg_img': bg_img,
|
301 |
+
'segmap': segmaps_0, # v2v使用第一帧的torso作为source image来warp
|
302 |
+
'kp_s': kp_src, 'kp_d': kp_drv}
|
303 |
+
camera = self.ds['cameras'][drv_idx]
|
304 |
+
gen_output = self.secc2video_model.forward(img=None, camera=camera, cond=cond, ret={}, cache_backbone=False, use_cached_backbone=True)
|
305 |
+
pred_imgs = gen_output['image']
|
306 |
+
pred_imgs_raw = gen_output['image_raw']
|
307 |
+
|
308 |
+
losses = {}
|
309 |
+
loss_weights = {
|
310 |
+
'v2v_occlusion_reg_l1_loss': 0.001, # loss for face_vid2vid-based torso
|
311 |
+
'v2v_occlusion_2_reg_l1_loss': 0.001, # loss for face_vid2vid-based torso
|
312 |
+
'v2v_occlusion_2_weights_entropy_loss': hparams['lam_occlusion_weights_entropy'], # loss for face_vid2vid-based torso
|
313 |
+
'density_weight_l2_loss': 0.01, # supervised density
|
314 |
+
'density_weight_entropy_loss': 0.001, # keep the density change sharp
|
315 |
+
'mse_loss': 1.,
|
316 |
+
'head_mse_loss': 0.2, # loss on neural rendering low-reso pred_img
|
317 |
+
'lip_mse_loss': 1.0,
|
318 |
+
'lpips_loss': 0.5,
|
319 |
+
'head_lpips_loss': 0.1,
|
320 |
+
'lip_lpips_loss': 1.0, # make the teeth more clear
|
321 |
+
'blink_reg_loss': 0.003, # increase it when you find head shake while blinking; decrease it when you find the eye cannot closed.
|
322 |
+
'triplane_reg_loss': lambda_reg_triplane,
|
323 |
+
'secc_reg_loss': 0.01, # used to reduce flicking
|
324 |
+
}
|
325 |
+
|
326 |
+
occlusion_reg_l1 = gen_output.get("losses", {}).get('facev2v/occlusion_reg_l1', 0.)
|
327 |
+
occlusion_2_reg_l1 = gen_output.get("losses", {}).get('facev2v/occlusion_2_reg_l1', 0.)
|
328 |
+
occlusion_2_weights_entropy = gen_output.get("losses", {}).get('facev2v/occlusion_2_weights_entropy', 0.)
|
329 |
+
losses['v2v_occlusion_reg_l1_loss'] = occlusion_reg_l1
|
330 |
+
losses['v2v_occlusion_2_reg_l1_loss'] = occlusion_2_reg_l1
|
331 |
+
losses['v2v_occlusion_2_weights_entropy_loss'] = occlusion_2_weights_entropy
|
332 |
+
|
333 |
+
# Weights Reg loss in torso
|
334 |
+
neural_rendering_reso = self.neural_rendering_resolution
|
335 |
+
alphas = gen_output['weights_img'].clamp(1e-5, 1 - 1e-5)
|
336 |
+
loss_weights_entropy = torch.mean(- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas))
|
337 |
+
mv_head_masks = segmaps[:, [1,3,5]].sum(dim=1)
|
338 |
+
mv_head_masks_raw = F.interpolate(mv_head_masks.unsqueeze(1), size=(neural_rendering_reso,neural_rendering_reso)).squeeze(1)
|
339 |
+
face_mask = mv_head_masks_raw.bool().unsqueeze(1)
|
340 |
+
nonface_mask = ~ face_mask
|
341 |
+
loss_weights_l2_loss = (alphas[nonface_mask]-0).pow(2).mean() + (alphas[face_mask]-1).pow(2).mean()
|
342 |
+
losses['density_weight_l2_loss'] = loss_weights_l2_loss
|
343 |
+
losses['density_weight_entropy_loss'] = loss_weights_entropy
|
344 |
+
|
345 |
+
mse_loss = (pred_imgs - tgt_imgs).abs().mean()
|
346 |
+
head_mse_loss = (pred_imgs_raw - F.interpolate(head_imgs, size=(neural_rendering_reso,neural_rendering_reso), mode='bilinear', antialias=True)).abs().mean()
|
347 |
+
lpips_loss = self.criterion_lpips(pred_imgs, tgt_imgs).mean()
|
348 |
+
head_lpips_loss = self.criterion_lpips(pred_imgs_raw, F.interpolate(head_imgs, size=(neural_rendering_reso,neural_rendering_reso), mode='bilinear', antialias=True)).mean()
|
349 |
+
lip_mse_loss = 0
|
350 |
+
lip_lpips_loss = 0
|
351 |
+
for i in range(len(drv_idx)):
|
352 |
+
xmin, xmax, ymin, ymax = drv_lip_rects[i]
|
353 |
+
lip_tgt_imgs = tgt_imgs[i:i+1,:, ymin:ymax,xmin:xmax].contiguous()
|
354 |
+
lip_pred_imgs = pred_imgs[i:i+1,:, ymin:ymax,xmin:xmax].contiguous()
|
355 |
+
try:
|
356 |
+
lip_mse_loss = lip_mse_loss + (lip_pred_imgs - lip_tgt_imgs).abs().mean()
|
357 |
+
lip_lpips_loss = lip_lpips_loss + self.criterion_lpips(lip_pred_imgs, lip_tgt_imgs).mean()
|
358 |
+
except: pass
|
359 |
+
losses['mse_loss'] = mse_loss
|
360 |
+
losses['head_mse_loss'] = head_mse_loss
|
361 |
+
losses['lpips_loss'] = lpips_loss
|
362 |
+
losses['head_lpips_loss'] = head_lpips_loss
|
363 |
+
losses['lip_mse_loss'] = lip_mse_loss
|
364 |
+
losses['lip_lpips_loss'] = lip_lpips_loss
|
365 |
+
|
366 |
+
# eye blink reg loss
|
367 |
+
if i_step % 4 == 0:
|
368 |
+
blink_secc_lst1 = []
|
369 |
+
blink_secc_lst2 = []
|
370 |
+
blink_secc_lst3 = []
|
371 |
+
for i in range(len(drv_secc_color)):
|
372 |
+
secc = drv_secc_color[i]
|
373 |
+
blink_percent1 = random.random() * 0.5 # 0~0.5
|
374 |
+
blink_percent3 = 0.5 + random.random() * 0.5 # 0.5~1.0
|
375 |
+
blink_percent2 = (blink_percent1 + blink_percent3)/2
|
376 |
+
try:
|
377 |
+
out_secc1 = blink_eye_for_secc(secc, blink_percent1).to(secc.device)
|
378 |
+
out_secc2 = blink_eye_for_secc(secc, blink_percent2).to(secc.device)
|
379 |
+
out_secc3 = blink_eye_for_secc(secc, blink_percent3).to(secc.device)
|
380 |
+
except:
|
381 |
+
print("blink eye for secc failed, use original secc")
|
382 |
+
out_secc1 = copy.deepcopy(secc)
|
383 |
+
out_secc2 = copy.deepcopy(secc)
|
384 |
+
out_secc3 = copy.deepcopy(secc)
|
385 |
+
blink_secc_lst1.append(out_secc1)
|
386 |
+
blink_secc_lst2.append(out_secc2)
|
387 |
+
blink_secc_lst3.append(out_secc3)
|
388 |
+
src_secc_color1 = torch.stack(blink_secc_lst1)
|
389 |
+
src_secc_color2 = torch.stack(blink_secc_lst2)
|
390 |
+
src_secc_color3 = torch.stack(blink_secc_lst3)
|
391 |
+
blink_cond1 = {'cond_cano': cano_secc_color, 'cond_src': src_secc_color, 'cond_tgt': src_secc_color1}
|
392 |
+
blink_cond2 = {'cond_cano': cano_secc_color, 'cond_src': src_secc_color, 'cond_tgt': src_secc_color2}
|
393 |
+
blink_cond3 = {'cond_cano': cano_secc_color, 'cond_src': src_secc_color, 'cond_tgt': src_secc_color3}
|
394 |
+
blink_secc_plane1 = self.model.cal_secc_plane(blink_cond1)
|
395 |
+
blink_secc_plane2 = self.model.cal_secc_plane(blink_cond2)
|
396 |
+
blink_secc_plane3 = self.model.cal_secc_plane(blink_cond3)
|
397 |
+
interpolate_blink_secc_plane = (blink_secc_plane1 + blink_secc_plane3) / 2
|
398 |
+
blink_reg_loss = torch.nn.functional.l1_loss(blink_secc_plane2, interpolate_blink_secc_plane)
|
399 |
+
losses['blink_reg_loss'] = blink_reg_loss
|
400 |
+
|
401 |
+
# Triplane Reg loss
|
402 |
+
triplane_reg_loss = (self.learnable_triplane - init_plane).abs().mean()
|
403 |
+
losses['triplane_reg_loss'] = triplane_reg_loss
|
404 |
+
|
405 |
+
|
406 |
+
ref_id = self.ds['id'][0:1]
|
407 |
+
secc_pertube_randn_scale = hparams['secc_pertube_randn_scale']
|
408 |
+
perturbed_id = ref_id + torch.randn_like(ref_id) * secc_pertube_randn_scale
|
409 |
+
drv_exp = self.ds['exps'][drv_idx]
|
410 |
+
perturbed_exp = drv_exp + torch.randn_like(drv_exp) * secc_pertube_randn_scale
|
411 |
+
zero_euler = torch.zeros([len(drv_idx), 3], device=ref_id.device, dtype=ref_id.dtype)
|
412 |
+
zero_trans = torch.zeros([len(drv_idx), 3], device=ref_id.device, dtype=ref_id.dtype)
|
413 |
+
perturbed_secc = self.secc_renderer(perturbed_id, perturbed_exp, zero_euler, zero_trans)[1]
|
414 |
+
secc_reg_loss = torch.nn.functional.l1_loss(drv_secc_color, perturbed_secc)
|
415 |
+
losses['secc_reg_loss'] = secc_reg_loss
|
416 |
+
|
417 |
+
|
418 |
+
total_loss = sum([loss_weights[k] * v for k, v in losses.items() if isinstance(v, torch.Tensor) and v.requires_grad])
|
419 |
+
# Update weights
|
420 |
+
self.optimizer.zero_grad()
|
421 |
+
total_loss.backward()
|
422 |
+
self.learnable_triplane.grad.data = self.learnable_triplane.grad.data * self.learnable_triplane.numel()
|
423 |
+
self.optimizer.step()
|
424 |
+
meter.update(total_loss.item())
|
425 |
+
if i_step % 10 == 0:
|
426 |
+
log_line = f"Iter {i_step+1}: total_loss={meter.avg} "
|
427 |
+
for k, v in losses.items():
|
428 |
+
log_line = log_line + f" {k}={v.item()}, "
|
429 |
+
self.logger.add_scalar(f"train/{k}", v.item(), i_step)
|
430 |
+
print(log_line)
|
431 |
+
meter.reset()
|
432 |
+
@torch.no_grad()
|
433 |
+
def test_loop(self, inp, step=''):
|
434 |
+
self.model.eval()
|
435 |
+
# coeff_dict = np.load('data/processed/videos/Lieu/coeff_fit_mp_for_lora.npy', allow_pickle=True).tolist()
|
436 |
+
# drv_exps = torch.tensor(coeff_dict['exp']).cuda().float()
|
437 |
+
drv_exps = self.ds['exps']
|
438 |
+
zero_eulers = self.ds['eulers']*0
|
439 |
+
zero_trans = self.ds['trans']*0
|
440 |
+
batch_size = 1
|
441 |
+
num_samples = len(self.ds['cameras'])
|
442 |
+
video_writer = imageio.get_writer(os.path.join(inp['work_dir'], f'val_step{step}.mp4'), fps=25)
|
443 |
+
total_iters = min(num_samples, 250)
|
444 |
+
video_id = inp['video_id']
|
445 |
+
for i in tqdm.trange(total_iters,desc="testing lora..."):
|
446 |
+
drv_idx = [i]
|
447 |
+
drv_secc_colors = []
|
448 |
+
gt_imgs = []
|
449 |
+
segmaps = []
|
450 |
+
torso_imgs = []
|
451 |
+
drv_lip_rects = []
|
452 |
+
kp_src = []
|
453 |
+
kp_drv = []
|
454 |
+
for di in drv_idx:
|
455 |
+
# 读取target image
|
456 |
+
if self.torso_mode:
|
457 |
+
if self.ds['com_imgs'][di] is None:
|
458 |
+
img_name = f'data/processed/videos/{video_id}/com_imgs/{format(di, "08d")}.jpg'
|
459 |
+
img = torch.tensor(cv2.imread(img_name)[..., ::-1] / 127.5 - 1).permute(2,0,1).float() # [3, H, W]
|
460 |
+
self.ds['com_imgs'][di] = img
|
461 |
+
gt_imgs.append(self.ds['com_imgs'][di])
|
462 |
+
else:
|
463 |
+
if self.ds['head_imgs'][di] is None:
|
464 |
+
img_name = f'data/processed/videos/{video_id}/head_imgs/{format(di, "08d")}.png'
|
465 |
+
img = torch.tensor(cv2.imread(img_name)[..., ::-1] / 127.5 - 1).permute(2,0,1).float() # [3, H, W]
|
466 |
+
self.ds['head_imgs'][di] = img
|
467 |
+
gt_imgs.append(self.ds['head_imgs'][di])
|
468 |
+
# 使用第一帧的torso作为face v2v的输入
|
469 |
+
if self.ds['torso_imgs'][0] is None:
|
470 |
+
img_name = f'data/processed/videos/{video_id}/inpaint_torso_imgs/{format(0, "08d")}.png'
|
471 |
+
img = torch.tensor(cv2.imread(img_name)[..., ::-1] / 127.5 - 1).permute(2,0,1).float() # [3, H, W]
|
472 |
+
self.ds['torso_imgs'][0] = img
|
473 |
+
torso_imgs.append(self.ds['torso_imgs'][0])
|
474 |
+
# 所以segmap也用第一帧的了
|
475 |
+
if self.ds['segmaps'][0] is None:
|
476 |
+
img_name = f'data/processed/videos/{video_id}/segmaps/{format(0, "08d")}.png'
|
477 |
+
seg_img = cv2.imread(img_name)[:,:, ::-1]
|
478 |
+
segmap = torch.from_numpy(decode_segmap_mask_from_image(seg_img)) # [6, H, W]
|
479 |
+
self.ds['segmaps'][0] = segmap
|
480 |
+
segmaps.append(self.ds['segmaps'][0])
|
481 |
+
drv_lip_rects.append(self.ds['lip_rects'][di])
|
482 |
+
kp_src.append(self.ds['kps'][0])
|
483 |
+
kp_drv.append(self.ds['kps'][di])
|
484 |
+
bg_img = self.ds['bg_img'].unsqueeze(0).repeat([batch_size, 1, 1, 1]).cuda()
|
485 |
+
ref_torso_imgs = torch.stack(torso_imgs).float().cuda()
|
486 |
+
kp_src = torch.stack(kp_src).float().cuda()
|
487 |
+
kp_drv = torch.stack(kp_drv).float().cuda()
|
488 |
+
segmaps = torch.stack(segmaps).float().cuda()
|
489 |
+
tgt_imgs = torch.stack(gt_imgs).float().cuda()
|
490 |
+
for di in drv_idx:
|
491 |
+
_, secc_color = self.secc_renderer(self.ds['id'][0:1], drv_exps[di:di+1], zero_eulers[0:1], zero_trans[0:1])
|
492 |
+
drv_secc_colors.append(secc_color)
|
493 |
+
drv_secc_color = torch.cat(drv_secc_colors)
|
494 |
+
cano_secc_color = self.ds['cano_secc_color'].repeat([batch_size, 1, 1, 1])
|
495 |
+
src_secc_color = self.ds['src_secc_color'].repeat([batch_size, 1, 1, 1])
|
496 |
+
cond = {'cond_cano': cano_secc_color,'cond_src': src_secc_color, 'cond_tgt': drv_secc_color,
|
497 |
+
'ref_torso_img': ref_torso_imgs, 'bg_img': bg_img, 'segmap': segmaps,
|
498 |
+
'kp_s': kp_src, 'kp_d': kp_drv}
|
499 |
+
camera = self.ds['cameras'][drv_idx]
|
500 |
+
gen_output = self.secc2video_model.forward(img=None, camera=camera, cond=cond, ret={}, cache_backbone=False, use_cached_backbone=True)
|
501 |
+
pred_img = gen_output['image']
|
502 |
+
pred_img = ((pred_img.permute(0, 2, 3, 1) + 1)/2 * 255).int().cpu().numpy().astype(np.uint8)
|
503 |
+
video_writer.append_data(pred_img[0])
|
504 |
+
video_writer.close()
|
505 |
+
self.model.train()
|
506 |
+
|
507 |
+
def masked_error_loss(self, img_pred, img_gt, mask, unmasked_weight=0.1, mode='l1'):
|
508 |
+
# 对raw图像,因为deform的原因背景没法全黑,导致这部分mse过高,我们将其mask掉,只计算人脸部分
|
509 |
+
masked_weight = 1.0
|
510 |
+
weight_mask = mask.float() * masked_weight + (~mask).float() * unmasked_weight
|
511 |
+
if mode == 'l1':
|
512 |
+
error = (img_pred - img_gt).abs().sum(dim=1) * weight_mask
|
513 |
+
else:
|
514 |
+
error = (img_pred - img_gt).pow(2).sum(dim=1) * weight_mask
|
515 |
+
error.clamp_(0, max(0.5, error.quantile(0.8).item())) # clamp掉较高loss的pixel,避免姿态没对齐的pixel导致的异常值占主导影响训练
|
516 |
+
loss = error.mean()
|
517 |
+
return loss
|
518 |
+
|
519 |
+
def dilate(self, bin_img, ksize=5, mode='max_pool'):
|
520 |
+
"""
|
521 |
+
mode: max_pool or avg_pool
|
522 |
+
"""
|
523 |
+
# bin_img, [1, h, w]
|
524 |
+
pad = (ksize-1)//2
|
525 |
+
bin_img = F.pad(bin_img, pad=[pad,pad,pad,pad], mode='reflect')
|
526 |
+
if mode == 'max_pool':
|
527 |
+
out = F.max_pool2d(bin_img, kernel_size=ksize, stride=1, padding=0)
|
528 |
+
else:
|
529 |
+
out = F.avg_pool2d(bin_img, kernel_size=ksize, stride=1, padding=0)
|
530 |
+
return out
|
531 |
+
|
532 |
+
def dilate_mask(self, mask, ksize=21):
|
533 |
+
mask = self.dilate(mask, ksize=ksize, mode='max_pool')
|
534 |
+
return mask
|
535 |
+
|
536 |
+
def set_unmasked_to_black(self, img, mask):
|
537 |
+
out_img = img * mask.float() - (~mask).float() # -1 denotes black
|
538 |
+
return out_img
|
539 |
+
|
540 |
+
def dump_checkpoint(self, inp):
|
541 |
+
checkpoint = {}
|
542 |
+
# save optimizers
|
543 |
+
optimizer_states = []
|
544 |
+
self.optimizers = [self.optimizer]
|
545 |
+
for i, optimizer in enumerate(self.optimizers):
|
546 |
+
if optimizer is not None:
|
547 |
+
state_dict = optimizer.state_dict()
|
548 |
+
state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
|
549 |
+
optimizer_states.append(state_dict)
|
550 |
+
checkpoint['optimizer_states'] = optimizer_states
|
551 |
+
state_dict = {
|
552 |
+
'model': self.model.state_dict(),
|
553 |
+
'learnable_triplane': self.model.state_dict()['_last_cano_planes'],
|
554 |
+
}
|
555 |
+
del state_dict['model']['_last_cano_planes']
|
556 |
+
checkpoint['state_dict'] = state_dict
|
557 |
+
checkpoint['lora_args'] = self.lora_args
|
558 |
+
person_ds = {}
|
559 |
+
video_id = inp['video_id']
|
560 |
+
img_name = f'data/processed/videos/{video_id}/gt_imgs/{format(0, "08d")}.jpg'
|
561 |
+
gt_img = torch.tensor(cv2.resize(cv2.imread(img_name), (512, 512))[..., ::-1] / 127.5 - 1).permute(2,0,1).float() # [3, H, W]
|
562 |
+
person_ds['gt_img'] = gt_img.reshape([1, 3, 512, 512])
|
563 |
+
person_ds['id'] = self.ds['id'].cpu().reshape([1, 80])
|
564 |
+
person_ds['src_kp'] = self.ds['kps'][0].cpu()
|
565 |
+
person_ds['video_id'] = inp['video_id']
|
566 |
+
checkpoint['person_ds'] = person_ds
|
567 |
+
return checkpoint
|
568 |
+
if __name__ == '__main__':
|
569 |
+
import argparse, glob, tqdm
|
570 |
+
parser = argparse.ArgumentParser()
|
571 |
+
parser.add_argument("--head_ckpt", default='') # checkpoints/0729_th1kh/secc_img2plane checkpoints/0720_img2planes/secc_img2plane_two_stage
|
572 |
+
# parser.add_argument("--torso_ckpt", default='checkpoints/240210_real3dportrait_orig/secc2plane_torso_orig') # checkpoints/0729_th1kh/secc_img2plane checkpoints/0720_img2planes/secc_img2plane_two_stage
|
573 |
+
parser.add_argument("--torso_ckpt", default='checkpoints/mimictalk_orig/os_secc2plane_torso') # checkpoints/0729_th1kh/secc_img2plane checkpoints/0720_img2planes/secc_img2plane_two_stage
|
574 |
+
parser.add_argument("--video_id", default='data/raw/examples/GER.mp4', help="identity source, we support (1) already processed <video_id> of GeneFace, (2) video path, (3) image path")
|
575 |
+
parser.add_argument("--work_dir", default=None)
|
576 |
+
parser.add_argument("--max_updates", default=10000, type=int, help="for video, 2000 is good; for an image, 3~10 is good")
|
577 |
+
parser.add_argument("--test", action='store_true')
|
578 |
+
parser.add_argument("--batch_size", default=1, type=int, help="batch size during training, 1 needs 8GB, 2 needs 15GB")
|
579 |
+
parser.add_argument("--lr", default=0.001)
|
580 |
+
parser.add_argument("--lr_triplane", default=0.005, help="for video, 0.1; for an image, 0.001; for ablation with_triplane, 0.")
|
581 |
+
parser.add_argument("--lora_r", default=2, type=int, help="width of lora unit")
|
582 |
+
parser.add_argument("--lora_mode", default='secc2plane_sr', help='for video, full; for an image, none')
|
583 |
+
|
584 |
+
args = parser.parse_args()
|
585 |
+
inp = {
|
586 |
+
'head_ckpt': args.head_ckpt,
|
587 |
+
'torso_ckpt': args.torso_ckpt,
|
588 |
+
'video_id': args.video_id,
|
589 |
+
'work_dir': args.work_dir,
|
590 |
+
'max_updates': args.max_updates,
|
591 |
+
'batch_size': args.batch_size,
|
592 |
+
'test': args.test,
|
593 |
+
'lr': float(args.lr),
|
594 |
+
'lr_triplane': float(args.lr_triplane),
|
595 |
+
'lora_mode': args.lora_mode,
|
596 |
+
'lora_r': args.lora_r,
|
597 |
+
}
|
598 |
+
if inp['work_dir'] == None:
|
599 |
+
video_id = os.path.basename(inp['video_id'])[:-4] if inp['video_id'].endswith((".mp4", ".png", ".jpg", ".jpeg")) else inp['video_id']
|
600 |
+
inp['work_dir'] = f'checkpoints_mimictalk/{video_id}'
|
601 |
+
os.makedirs(inp['work_dir'], exist_ok=True)
|
602 |
+
trainer = LoRATrainer(inp)
|
603 |
+
if inp['test']:
|
604 |
+
trainer.test_loop(inp)
|
605 |
+
else:
|
606 |
+
trainer.training_loop(inp)
|
607 |
+
trainer.test_loop(inp)
|
608 |
+
print(" ")
|
install_guide.md
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Prepare the Environment
|
2 |
+
[中文文档](./install_guide-zh.md)
|
3 |
+
|
4 |
+
This guide is about building a python environment for MimicTalk with Conda (the same as `Real3D-Portrait`).
|
5 |
+
|
6 |
+
The following installation process is verified in A100/V100 + CUDA12.1.
|
7 |
+
|
8 |
+
# Install Python Packages & CUDA
|
9 |
+
```bash
|
10 |
+
cd <MimicTalkRoot>
|
11 |
+
source <CondaRoot>/bin/activate
|
12 |
+
conda create -n mimictalk python=3.9
|
13 |
+
conda activate mimictalk
|
14 |
+
|
15 |
+
# MMCV for SegFormer network structure
|
16 |
+
# other dependencies
|
17 |
+
pip install -r docs/prepare_env/requirements.txt -v
|
18 |
+
pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121
|
19 |
+
pip install cython
|
20 |
+
pip install openmim==0.3.9
|
21 |
+
mim install mmcv==2.1.0 # use mim to speed up installation for mmcv
|
22 |
+
## build pytorch3d from Github's source code.
|
23 |
+
## It may take a long time (maybe tens of minutes), Proxy is recommended if encountering the time-out problem
|
24 |
+
# Before install pytorch3d, you need to install CUDA-12.1 (https://developer.nvidia.com/cuda-toolkit-archive) and make sure /usr/local/cuda points to the `cuda-12.1` directory
|
25 |
+
pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
|
26 |
+
|
27 |
+
```
|
requirements.txt
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Cython
|
2 |
+
numpy # ==1.23.0
|
3 |
+
numba==0.56.4
|
4 |
+
pandas
|
5 |
+
transformers
|
6 |
+
scipy==1.11.1 # required by cal_fid. https://github.com/mseitzer/pytorch-fid/issues/103
|
7 |
+
scikit-learn
|
8 |
+
scikit-image
|
9 |
+
# tensorflow # you can flexible it, this is gpu version
|
10 |
+
tensorboard
|
11 |
+
tensorboardX
|
12 |
+
python_speech_features
|
13 |
+
resampy
|
14 |
+
opencv_python
|
15 |
+
face_alignment
|
16 |
+
matplotlib
|
17 |
+
configargparse
|
18 |
+
librosa==0.9.2
|
19 |
+
praat-parselmouth # ==0.4.3
|
20 |
+
trimesh
|
21 |
+
kornia==0.5.0
|
22 |
+
PyMCubes
|
23 |
+
lpips
|
24 |
+
setuptools # ==59.5.0
|
25 |
+
ffmpeg-python
|
26 |
+
moviepy
|
27 |
+
dearpygui
|
28 |
+
ninja
|
29 |
+
pyaudio # for extract esperanto
|
30 |
+
mediapipe
|
31 |
+
protobuf
|
32 |
+
decord
|
33 |
+
soundfile
|
34 |
+
pillow
|
35 |
+
# torch # it's better to install torch with conda
|
36 |
+
av
|
37 |
+
timm
|
38 |
+
pretrainedmodels
|
39 |
+
faiss-cpu # for fast nearest camera pose retriveal
|
40 |
+
einops
|
41 |
+
# mmcv # use mim install is faster
|
42 |
+
|
43 |
+
# conditional flow matching
|
44 |
+
beartype
|
45 |
+
torchode
|
46 |
+
torchdiffeq
|
47 |
+
|
48 |
+
# tts
|
49 |
+
cython
|
50 |
+
textgrid
|
51 |
+
pyloudnorm
|
52 |
+
websocket-client
|
53 |
+
pyworld==0.2.1rc0
|
54 |
+
pypinyin==0.42.0
|
55 |
+
webrtcvad
|
56 |
+
torchshow
|
57 |
+
|
58 |
+
# cal spk sim
|
59 |
+
# s3prl
|
60 |
+
# fire
|
61 |
+
|
62 |
+
# cal LMD
|
63 |
+
# dlib
|
64 |
+
|
65 |
+
# debug
|
66 |
+
# ipykernel
|
67 |
+
|
68 |
+
# lama
|
69 |
+
# hydra-core
|
70 |
+
# pytorch_lightning
|
71 |
+
# setproctitle
|
72 |
+
|
73 |
+
# Gradio GUI
|
74 |
+
# httpx==0.23.3
|
75 |
+
# gradio==4.16.0
|
76 |
+
gradio==4.43.0
|
77 |
+
httpx==0.23.3
|
78 |
+
# gradio_client==0.8.1
|
79 |
+
fastapi==0.112.2
|