mrbear1024 committed on
Commit
b32806b
·
1 Parent(s): 28e8293

init project

Dockerfile ADDED
@@ -0,0 +1,42 @@
1
+ FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04
2
+
3
+ # Install system dependencies
4
+ RUN apt-get update && apt-get install -y \
5
+ git \
6
+ wget \
7
+ ffmpeg \
8
+ libsndfile1 \
9
+ python3-pip \
10
+ && rm -rf /var/lib/apt/lists/*
11
+
12
+ # Install Miniconda
13
+ RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh \
14
+ && bash miniconda.sh -b -p /opt/conda \
15
+ && rm miniconda.sh
16
+
17
+ # Add conda to path
18
+ ENV PATH=/opt/conda/bin:$PATH
19
+
20
+ # Create and activate conda environment
21
+ RUN conda create -n mimictalk python=3.9 -y
22
+ SHELL ["conda", "run", "-n", "mimictalk", "/bin/bash", "-c"]
23
+
24
+ # Install PyTorch and other dependencies
25
+ RUN pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121 \
26
+ && pip install cython openmim==0.3.9 \
27
+ && mim install mmcv==2.1.0 \
28
+ && pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
29
+
30
+ # Install requirements
31
+ COPY requirements.txt /tmp/
32
+ RUN pip install -r /tmp/requirements.txt
33
+
34
+ # Download pre-trained models
35
+ RUN pip install huggingface_hub \
36
+ && python -c "from huggingface_hub import snapshot_download; snapshot_download('mrbear1024/demo-model', local_dir='/app/models')"
37
+
38
+ # Set working directory
39
+ WORKDIR /app
40
+
41
+ # Default command: run the Gradio app inside the conda env
42
+ CMD ["conda", "run", "--no-capture-output", "-n", "mimictalk", "python", "inference/app_mimictalk.py"]
README.md CHANGED
@@ -5,6 +5,196 @@ colorFrom: blue
5
  colorTo: pink
6
  sdk: docker
7
  pinned: false
8
+ app_file: inference/app_mimictalk.py
9
+ models:
10
+ - mrbear1024/demo-model
11
+
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
15
+
16
+
17
+ # MimicTalk: Mimicking a personalized and expressive 3D talking face in few minutes | NeurIPS 2024
18
+ [![arXiv](https://img.shields.io/badge/arXiv-Paper-%3CCOLOR%3E.svg)](https://arxiv.org/abs/2401.08503)| [![GitHub Stars](https://img.shields.io/github/stars/yerfor/MimicTalk
19
+ )](https://github.com/yerfor/MimicTalk) | [English Readme](./README.md)
20
+
21
+ This repository is the official PyTorch implementation of MimicTalk, which synthesizes expressive, person-specific talking-face videos. The code builds on our previous work [Real3D-Portrait](https://github.com/yerfor/Real3DPortrait) (ICLR 2024), a NeRF-based one-shot talking-face system, which speeds up MimicTalk's training and improves its quality. Visit our [project page](https://mimictalk.github.io/) for demo videos, and read our [paper](https://arxiv.org/abs/2410.06734) for technical details.
22
+
23
+ <p align="center">
24
+ <br>
25
+ <img src="assets/mimictalk.png" width="100%"/>
26
+ <br>
27
+ </p>
28
+
29
+ # Quick Start
30
+ ## Set Up the Environment
31
+ Please follow the [environment setup guide](docs/prepare_env/install_guide-zh.md) to set up the Conda environment `mimictalk`.
32
+ ## Download Pre-trained and Third-party Models
33
+ ### 3DMM BFM Model
34
+ Download the 3DMM BFM model from [Google Drive](https://drive.google.com/drive/folders/1o4t5YIw7w4cMUN4bgU9nPf6IyWVG1bEk?usp=sharing) or [BaiduYun Disk](https://pan.baidu.com/s/1aqv1z_qZ23Vp2VP4uxxblQ?pwd=m9q5) (extraction code: m9q5).
35
+
36
+
37
+ After downloading, place all of the files into `deep_3drecon/BFM`; the expected file structure is shown below (a quick check sketch follows the tree):
38
+ ```
39
+ deep_3drecon/BFM/
40
+ ├── 01_MorphableModel.mat
41
+ ├── BFM_exp_idx.mat
42
+ ├── BFM_front_idx.mat
43
+ ├── BFM_model_front.mat
44
+ ├── Exp_Pca.bin
45
+ ├── facemodel_info.mat
46
+ ├── index_mp468_from_mesh35709.npy
47
+ ├── mediapipe_in_bfm53201.npy
48
+ └── std_exp.txt
49
+ ```
50
+
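A minimal check sketch for the tree above (assuming the same file names; the helper `check_bfm_dir` is hypothetical and not provided by this repository):

```python
import os

# Files listed in the tree above; all must be present under deep_3drecon/BFM.
EXPECTED_BFM_FILES = [
    "01_MorphableModel.mat", "BFM_exp_idx.mat", "BFM_front_idx.mat",
    "BFM_model_front.mat", "Exp_Pca.bin", "facemodel_info.mat",
    "index_mp468_from_mesh35709.npy", "mediapipe_in_bfm53201.npy", "std_exp.txt",
]

def check_bfm_dir(bfm_dir: str = "deep_3drecon/BFM") -> None:
    """Hypothetical helper: raise if any expected BFM file is missing."""
    missing = [f for f in EXPECTED_BFM_FILES if not os.path.isfile(os.path.join(bfm_dir, f))]
    if missing:
        raise FileNotFoundError(f"Missing BFM files in {bfm_dir}: {missing}")
    print(f"All {len(EXPECTED_BFM_FILES)} BFM files found in {bfm_dir}.")

if __name__ == "__main__":
    check_bfm_dir()
```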
51
+ ### Pre-trained Checkpoints
52
+ Download the pre-trained MimicTalk checkpoints from [Google Drive](https://drive.google.com/drive/folders/1Kc6ueDO9HFDN3BhtJCEKNCZtyKHSktaA?usp=sharing) or [BaiduYun Disk](https://pan.baidu.com/s/1nQKyGV5JB6rJtda7qsThUg?pwd=mimi) (extraction code: mimi).
53
+
54
+ After downloading, place and unzip all files into `checkpoints` and `checkpoints_mimictalk`, so that the file structure looks like this:
55
+ ```
56
+ checkpoints/
57
+ ├── mimictalk_orig
58
+ │ └── os_secc2plane_torso
59
+ │ ├── config.yaml
60
+ │ └── model_ckpt_steps_100000.ckpt
61
+ ├── 240112_icl_audio2secc_vox2_cmlr
62
+ │ ├── config.yaml
63
+ │ └── model_ckpt_steps_1856000.ckpt
64
+ └── pretrained_ckpts
65
+ └── mit_b0.pth
66
+
67
+ checkpoints_mimictalk/
68
+ └── German_20s
69
+ ├── config.yaml
70
+ └── model_ckpt_steps_10000.ckpt
71
+ ```
72
+
73
+ ## Minimal Commands for MimicTalk Training and Inference
74
+ ```
75
+ python inference/train_mimictalk_on_a_video.py # train the model, this may take 10 minutes for 2,000 steps
76
+ python inference/mimictalk_infer.py # infer the model
77
+ ```
78
+
79
+
80
+ # Training and Inference Details
81
+ We currently provide a **command-line (CLI)** interface and a **Gradio WebUI** for inference. For audio-driven inference, the portrait identity comes from `torso_ckpt`, so you must additionally provide at least a `driving audio`. Optionally, you can provide a `style video` so that the model predicts talking motions consistent with that video's style.
82
+
83
+ First, switch to the project root directory and activate the Conda environment:
84
+ ```bash
85
+ cd <Real3DPortraitRoot>
86
+ conda activate mimictalk
87
+ export PYTHONPATH=./
88
+ export HF_ENDPOINT=https://hf-mirror.com
89
+ ```
90
+
91
+ ## Gradio WebUI Inference
92
+ Launch the Gradio WebUI, upload the materials as prompted, and click the `Train` button to start training; once training finishes, click the `Generate` button to run inference:
93
+ ```bash
94
+ python inference/app_mimictalk.py
95
+ ```
96
+
97
+ ## Using the Transformers Pipeline (Hugging Face Space)
98
+ In the Hugging Face Space, we provide a simplified version based on the Transformers Pipeline that downloads the model files directly from the Hugging Face Hub:
99
+
100
+ ```bash
101
+ python app.py --model_id mrbear1024/demo-model
102
+ ```
103
+
104
+ You can also set the model source via an environment variable:
105
+ ```bash
106
+ export MODEL_ID=mrbear1024/demo-model
107
+ python app.py
108
+ ```
109
+
110
+ ### Model Path Notes
111
+ In the Hugging Face Space, model files are downloaded to `/app/models` by default. You can customize the model path in the following ways:
112
+
113
+ ```bash
114
+ # Use a command-line argument
115
+ python app.py --model_dir /path/to/models
116
+
117
+ # Use an environment variable
118
+ export MODEL_DIR=/path/to/models
119
+ python app.py
120
+ ```
121
+
122
+ When running locally, model files are downloaded into a temporary directory. To specify a custom path, use one of the methods above.
123
+
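To illustrate the path resolution described above, here is a minimal sketch (the `MODEL_ID`/`MODEL_DIR` names come from this README; `resolve_model_dir` is hypothetical, and `app.py` itself is not part of this commit):

```python
import os
from huggingface_hub import snapshot_download  # same call the Dockerfile uses to pre-fetch the model

def resolve_model_dir() -> str:
    """Hypothetical sketch: use MODEL_DIR if it already holds the files, else download MODEL_ID."""
    model_dir = os.environ.get("MODEL_DIR", "")
    if model_dir and os.path.isdir(model_dir):
        return model_dir  # e.g. /app/models inside the Hugging Face Space container
    model_id = os.environ.get("MODEL_ID", "mrbear1024/demo-model")
    # Without MODEL_DIR this lands in the Hub cache, i.e. the temporary directory mentioned above.
    return snapshot_download(model_id, local_dir=model_dir or None)

if __name__ == "__main__":
    print(resolve_model_dir())
```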
124
+ ## CLI Person-Specific Training
125
+
126
+ You must provide at least a `source video`. Training command:
127
+ ```bash
128
+ python inference/train_mimictalk_on_a_video.py \
129
+ --video_id <PATH_TO_SOURCE_VIDEO> \
130
+ --max_updates <UPDATES_NUMBER> \
131
+ --work_dir <PATH_TO_SAVING_CKPT>
132
+ ```
133
+
134
+ Notes on some optional arguments:
135
+
136
+ - `--torso_ckpt` the pre-trained Real3D-Portrait model
137
+ - `--max_updates` number of training updates
138
+ - `--batch_size` training batch size: `1` requires about 8 GB of VRAM; `2` requires about 15 GB of VRAM
139
+ - `--lr_triplane` learning rate of the triplane: use 0.1 for video input and 0.001 for image input
140
+ - `--work_dir` if not specified, checkpoints are saved under `checkpoints_mimictalk/` by default
141
+
142
+ Example command:
143
+ ```bash
144
+ python inference/train_mimictalk_on_a_video.py \
145
+ --video_id data/raw/videos/German_20s.mp4 \
146
+ --max_updates 2000 \
147
+ --work_dir checkpoints_mimictalk/German_20s
148
+ ```
149
+
150
+ ## CLI Inference
151
+
152
+ You must provide at least a `driving audio`; a `driving style` video is optional. Inference command:
153
+ ```bash
154
+ python inference/mimictalk_infer.py \
155
+ --drv_aud <PATH_TO_AUDIO> \
156
+ --drv_style <PATH_TO_STYLE_VIDEO, OPTIONAL> \
157
+ --drv_pose <PATH_TO_POSE_VIDEO, OPTIONAL> \
158
+ --bg_img <PATH_TO_BACKGROUND_IMAGE, OPTIONAL> \
159
+ --out_name <PATH_TO_OUTPUT_VIDEO, OPTIONAL>
160
+ ```
161
+
162
+ Notes on some optional arguments:
163
+ - `--drv_pose` provides head-pose motion when specified; otherwise a static pose is used
164
+ - `--bg_img` provides the background when specified; otherwise the background extracted from the source image is used
165
+ - `--mouth_amp` mouth amplitude; larger values open the mouth wider
166
+ - `--map_to_init_pose` when `True`, the pose of the first frame is mapped to the source pose, and the same transform is applied to subsequent frames
167
+ - `--temperature` sampling temperature of audio2motion; higher values give more diverse but less accurate results
168
+ - `--out_name` if not specified, results are saved under `infer_out/tmp/`
169
+ - `--out_mode` `final` outputs only the talking-face video; `concat_debug` also outputs some visualized intermediate results
170
+
171
+ Example inference command:
172
+ ```bash
173
+ python inference/mimictalk_infer.py \
174
+ --drv_aud data/raw/examples/Obama_5s.wav \
175
+ --drv_pose data/raw/examples/German_20s.mp4 \
176
+ --drv_style data/raw/examples/German_20s.mp4 \
177
+ --bg_img data/raw/examples/bg.png \
178
+ --out_name output.mp4 \
179
+ --out_mode final
180
+ ```
181
+
182
+ # Disclaimer
183
+ No organization or individual may use any technique mentioned here to generate talking videos of another person without that person's consent, including but not limited to government leaders, politicians, and celebrities. Violating this clause may break copyright and related laws.
184
+
185
+ # Citation
186
+ If this repository helps you, please consider citing our work:
187
+ ```
188
+ @inproceedings{ye2024mimicktalk,
189
+ author = {Ye, Zhenhui and Zhong, Tianyun and Ren, Yi and Yang, Jiaqi and Li, Weichuang and Huang, Jiangwei and Jiang, Ziyue and He, Jinzheng and Huang, Rongjie and Liu, Jinglin and Zhang, Chen and Yin, Xiang and Ma, Zejun and Zhao, Zhou},
190
+ title = {MimicTalk: Mimicking a personalized and expressive 3D talking face in few minutes},
191
+ journal = {NeurIPS},
192
+ year = {2024},
193
+ }
194
+ @inproceedings{ye2024real3d,
195
+ title = {Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis},
196
+ author = {Ye, Zhenhui and Zhong, Tianyun and Ren, Yi and Yang, Jiaqi and Li, Weichuang and Huang, Jiawei and Jiang, Ziyue and He, Jinzheng and Huang, Rongjie and Liu, Jinglin and others},
197
+ journal = {ICLR},
198
+ year={2024}
199
+ }
200
+ ```
inference/app_mimictalk.py ADDED
@@ -0,0 +1,365 @@
1
+ import os, sys
2
+ sys.path.append('./')
3
+ import argparse
4
+ import traceback
5
+ import gradio as gr
6
+ # from inference.real3d_infer import GeneFace2Infer
7
+ from inference.mimictalk_infer import AdaptGeneFace2Infer
8
+ from inference.train_mimictalk_on_a_video import LoRATrainer
9
+ from utils.commons.hparams import hparams
10
+
11
+ class Trainer():
12
+ def __init__(self):
13
+ pass
14
+
15
+ def train_once_args(self, *args, **kargs):
16
+ assert len(kargs) == 0
17
+ args = [
18
+ '', # head_ckpt
19
+ 'checkpoints/mimictalk_orig/os_secc2plane_torso/', # torso_ckpt
20
+ args[0],
21
+ '',
22
+ 10000,
23
+ 1,
24
+ False,
25
+ 0.001,
26
+ 0.005,
27
+ 0.2,
28
+ 'secc2plane_sr',
29
+ 2,
30
+ ]
31
+ keys = [
32
+ 'head_ckpt',
33
+ 'torso_ckpt',
34
+ 'video_id',
35
+ 'work_dir',
36
+ 'max_updates',
37
+ 'batch_size',
38
+ 'test',
39
+ 'lr',
40
+ 'lr_triplane',
41
+ 'lambda_lpips',
42
+ 'lora_mode',
43
+ 'lora_r',
44
+ ]
45
+ inp = {}
46
+ info = ""
47
+
48
+ try: # try to catch errors and jump to return
49
+ for key_index in range(len(keys)):
50
+ key = keys[key_index]
51
+ inp[key] = args[key_index]
52
+ if '_name' in key:
53
+ inp[key] = inp[key] if inp[key] is not None else ''
54
+
55
+ if inp['video_id'] == '':
56
+ info = "Input Error: Source video is REQUIRED!"
57
+ raise ValueError
58
+
59
+ inp['out_name'] = ''
60
+ inp['seed'] = 42
61
+
62
+ if inp['work_dir'] is None or inp['work_dir'] == '':
63
+ video_id = os.path.basename(inp['video_id'])[:-4] if inp['video_id'].endswith((".mp4", ".png", ".jpg", ".jpeg")) else inp['video_id']
64
+ inp['work_dir'] = f'checkpoints_mimictalk/{video_id}'
65
+ try:
66
+ trainer = LoRATrainer(inp)
67
+ trainer.training_loop(inp)
68
+ except Exception as e:
69
+ content = f"{e}"
70
+ info = f"Training ERROR: {content}"
71
+ traceback.print_exc()
72
+ raise ValueError
73
+ except Exception as e:
74
+ if info == "": # unexpected errors
75
+ content = f"{e}"
76
+ info = f"WebUI ERROR: {content}"
77
+ traceback.print_exc()
78
+
79
+
80
+ # output part
81
+ if len(info) > 0 : # there is errors
82
+ print(info)
83
+ info_gr = gr.update(visible=True, value=info)
84
+ else: # no errors
85
+ info_gr = gr.update(visible=False, value=info)
86
+
87
+ torso_model_dir = gr.FileExplorer(glob="checkpoints_mimictalk/**/*.ckpt", value=inp['work_dir'], file_count='single', label='mimictalk model ckpt path or directory')
88
+
89
+
90
+ return info_gr, torso_model_dir
91
+
92
+ class Inferer(AdaptGeneFace2Infer):
93
+ def infer_once_args(self, *args, **kargs):
94
+ assert len(kargs) == 0
95
+ keys = [
96
+ # 'src_image_name',
97
+ 'drv_audio_name',
98
+ 'drv_pose_name',
99
+ 'drv_talking_style_name',
100
+ 'bg_image_name',
101
+ 'blink_mode',
102
+ 'temperature',
103
+ 'cfg_scale',
104
+ 'out_mode',
105
+ 'map_to_init_pose',
106
+ 'low_memory_usage',
107
+ 'hold_eye_opened',
108
+ 'a2m_ckpt',
109
+ # 'head_ckpt',
110
+ 'torso_ckpt',
111
+ 'min_face_area_percent',
112
+ ]
113
+ inp = {}
114
+ out_name = None
115
+ info = ""
116
+
117
+ try: # try to catch errors and jump to return
118
+ for key_index in range(len(keys)):
119
+ key = keys[key_index]
120
+ inp[key] = args[key_index]
121
+ if '_name' in key:
122
+ inp[key] = inp[key] if inp[key] is not None else ''
123
+
124
+ inp['head_ckpt'] = ''
125
+
126
+ # if inp['src_image_name'] == '':
127
+ # info = "Input Error: Source image is REQUIRED!"
128
+ # raise ValueError
129
+ if inp['drv_audio_name'] == '' and inp['drv_pose_name'] == '':
130
+ info = "Input Error: At least one of driving audio or video is REQUIRED!"
131
+ raise ValueError
132
+
133
+
134
+ if inp['drv_audio_name'] == '' and inp['drv_pose_name'] != '':
135
+ inp['drv_audio_name'] = inp['drv_pose_name']
136
+ print("No audio input, we use driving pose video for video driving")
137
+
138
+ if inp['drv_pose_name'] == '':
139
+ inp['drv_pose_name'] = 'static'
140
+
141
+ reload_flag = False
142
+ if inp['a2m_ckpt'] != self.audio2secc_dir:
143
+ print("Changes of a2m_ckpt detected, reloading model")
144
+ reload_flag = True
145
+ if inp['head_ckpt'] != self.head_model_dir:
146
+ print("Changes of head_ckpt detected, reloading model")
147
+ reload_flag = True
148
+ if inp['torso_ckpt'] != self.torso_model_dir:
149
+ print("Changes of torso_ckpt detected, reloading model")
150
+ reload_flag = True
151
+
152
+ inp['out_name'] = ''
153
+ inp['seed'] = 42
154
+ inp['denoising_steps'] = 20
155
+ print(f"infer inputs : {inp}")
156
+
157
+ try:
158
+ if reload_flag:
159
+ self.__init__(inp['a2m_ckpt'], inp['head_ckpt'], inp['torso_ckpt'], inp=inp, device=self.device)
160
+ except Exception as e:
161
+ content = f"{e}"
162
+ info = f"Reload ERROR: {content}"
163
+ traceback.print_exc()
164
+ raise ValueError
165
+ try:
166
+ out_name = self.infer_once(inp)
167
+ except Exception as e:
168
+ content = f"{e}"
169
+ info = f"Inference ERROR: {content}"
170
+ traceback.print_exc()
171
+ raise ValueError
172
+ except Exception as e:
173
+ if info == "": # unexpected errors
174
+ content = f"{e}"
175
+ info = f"WebUI ERROR: {content}"
176
+ traceback.print_exc()
177
+
178
+ # output part
179
+ if len(info) > 0 : # there is errors
180
+ print(info)
181
+ info_gr = gr.update(visible=True, value=info)
182
+ else: # no errors
183
+ info_gr = gr.update(visible=False, value=info)
184
+ if out_name is not None and len(out_name) > 0 and os.path.exists(out_name): # good output
185
+ print(f"Successfully generated in {out_name}")
186
+ video_gr = gr.update(visible=True, value=out_name)
187
+ else:
188
+ print(out_name)
189
+ print(os.path.exists(out_name))
190
+ print(f"Failed to generate")
191
+ video_gr = gr.update(visible=True, value=out_name)
192
+
193
+ return video_gr, info_gr
194
+
195
+ def toggle_audio_file(choice):
196
+ if choice == False:
197
+ return gr.update(visible=True), gr.update(visible=False)
198
+ else:
199
+ return gr.update(visible=False), gr.update(visible=True)
200
+
201
+ def ref_video_fn(path_of_ref_video):
202
+ if path_of_ref_video is not None:
203
+ return gr.update(value=True)
204
+ else:
205
+ return gr.update(value=False)
206
+
207
+ def mimictalk_demo(
208
+ audio2secc_dir,
209
+ head_model_dir,
210
+ torso_model_dir,
211
+ device = 'cuda',
212
+ warpfn = None,
213
+ ):
214
+
215
+ sep_line = "-" * 40
216
+
217
+ infer_obj = Inferer(
218
+ audio2secc_dir=audio2secc_dir,
219
+ head_model_dir=head_model_dir,
220
+ torso_model_dir=torso_model_dir,
221
+ device=device,
222
+ )
223
+ train_obj = Trainer()
224
+
225
+ print(sep_line)
226
+ print("Model loading is finished.")
227
+ print(sep_line)
228
+ with gr.Blocks(analytics_enabled=False) as real3dportrait_interface:
229
+ # gr.Markdown("\
230
+ # <div align='center'> <h2> MimicTalk: Mimicking a personalized and expressive 3D talking face in minutes (NIPS 2024) </span> </h2> \
231
+ # <a style='font-size:18px;color: #a0a0a0' href=''>Arxiv</a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; \
232
+ # <a style='font-size:18px;color: #a0a0a0' href='https://mimictalk.github.io/'>Homepage</a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; \
233
+ # <a style='font-size:18px;color: #a0a0a0' href='https://github.com/yerfor/MimicTalk/'> Github </div>")
234
+
235
+ gr.Markdown("\
236
+ <div align='center'> <h2> MimicTalk: Mimicking a personalized and expressive 3D talking face in minutes (NIPS 2024) </span> </h2> \
237
+ <a style='font-size:18px;color: #a0a0a0' href='https://mimictalk.github.io/'>Homepage</a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;")
238
+
239
+ sources = None
240
+ with gr.Row():
241
+ with gr.Column(variant='panel'):
242
+ with gr.Tabs(elem_id="source_image"):
243
+ with gr.TabItem('Upload Training Video'):
244
+ with gr.Row():
245
+ src_video_name = gr.Video(label="Source video (required for training)", sources=sources, value="data/raw/videos/German_20s.mp4")
246
+ # src_video_name = gr.Image(label="Source video (required for training)", sources=sources, type="filepath", value="data/raw/videos/German_20s.mp4")
247
+ with gr.Tabs(elem_id="driven_audio"):
248
+ with gr.TabItem('Upload Driving Audio'):
249
+ with gr.Column(variant='panel'):
250
+ drv_audio_name = gr.Audio(label="Input audio (required for inference)", sources=sources, type="filepath", value="data/raw/examples/80_vs_60_10s.wav")
251
+ with gr.Tabs(elem_id="driven_style"):
252
+ with gr.TabItem('Upload Style Prompt'):
253
+ with gr.Column(variant='panel'):
254
+ drv_style_name = gr.Video(label="Driven Style (optional for inference)", sources=sources, value="data/raw/videos/German_20s.mp4")
255
+ with gr.Tabs(elem_id="driven_pose"):
256
+ with gr.TabItem('Upload Driving Pose'):
257
+ with gr.Column(variant='panel'):
258
+ drv_pose_name = gr.Video(label="Driven Pose (optional for inference)", sources=sources, value="data/raw/videos/German_20s.mp4")
259
+ with gr.Tabs(elem_id="bg_image"):
260
+ with gr.TabItem('Upload Background Image'):
261
+ with gr.Row():
262
+ bg_image_name = gr.Image(label="Background image (optional for inference)", sources=sources, type="filepath", value=None)
263
+
264
+
265
+ with gr.Column(variant='panel'):
266
+ with gr.Tabs(elem_id="checkbox"):
267
+ with gr.TabItem('General Settings'):
268
+ with gr.Column(variant='panel'):
269
+
270
+ blink_mode = gr.Radio(['none', 'period'], value='period', label='blink mode', info="whether to blink periodically") #
271
+ min_face_area_percent = gr.Slider(minimum=0.15, maximum=0.5, step=0.01, label="min_face_area_percent", value=0.2, info='The minimum face area percent in the output frame, to prevent bad cases caused by a too small face.',)
272
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.025, label="temperature", value=0.2, info='audio to secc temperature',)
273
+ cfg_scale = gr.Slider(minimum=1.0, maximum=3.0, step=0.025, label="talking style cfg_scale", value=1.5, info='higher -> encourage the generated motion more coherent to talking style',)
274
+ out_mode = gr.Radio(['final', 'concat_debug'], value='concat_debug', label='output layout', info="final: only final output ; concat_debug: final output concatenated with internal features")
275
+ low_memory_usage = gr.Checkbox(label="Low Memory Usage Mode: save memory at the expense of lower inference speed. Useful when running on a long (minutes-long) audio.", value=False)
276
+ map_to_init_pose = gr.Checkbox(label="Whether to map pose of first frame to initial pose", value=True)
277
+ hold_eye_opened = gr.Checkbox(label="Whether to maintain eyes always open")
278
+
279
+ train_submit = gr.Button('Train', elem_id="train", variant='primary')
280
+ infer_submit = gr.Button('Generate', elem_id="generate", variant='primary')
281
+
282
+ with gr.Tabs(elem_id="genearted_video"):
283
+ info_box = gr.Textbox(label="Error", interactive=False, visible=False)
284
+ gen_video = gr.Video(label="Generated video", format="mp4", visible=True)
285
+ with gr.Column(variant='panel'):
286
+ with gr.Tabs(elem_id="checkbox"):
287
+ with gr.TabItem('Checkpoints'):
288
+ with gr.Column(variant='panel'):
289
+ ckpt_info_box = gr.Textbox(value="Please select \"ckpt\" under the checkpoint folder ", interactive=False, visible=True, show_label=False)
290
+ audio2secc_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=audio2secc_dir, file_count='single', label='audio2secc model ckpt path or directory')
291
+ # head_model_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=head_model_dir, file_count='single', label='head model ckpt path or directory (will be ignored if torso model is set)')
292
+ torso_model_dir = gr.FileExplorer(glob="checkpoints_mimictalk/**/*.ckpt", value=torso_model_dir, file_count='single', label='mimictalk model ckpt path or directory')
293
+ # audio2secc_dir = gr.Textbox(audio2secc_dir, max_lines=1, label='audio2secc model ckpt path or directory (will be ignored if torso model is set)')
294
+ # head_model_dir = gr.Textbox(head_model_dir, max_lines=1, label='head model ckpt path or directory (will be ignored if torso model is set)')
295
+ # torso_model_dir = gr.Textbox(torso_model_dir, max_lines=1, label='torso model ckpt path or directory')
296
+
297
+
298
+ fn = infer_obj.infer_once_args
299
+ if warpfn:
300
+ fn = warpfn(fn)
301
+ infer_submit.click(
302
+ fn=fn,
303
+ inputs=[
304
+ drv_audio_name,
305
+ drv_pose_name,
306
+ drv_style_name,
307
+ bg_image_name,
308
+ blink_mode,
309
+ temperature,
310
+ cfg_scale,
311
+ out_mode,
312
+ map_to_init_pose,
313
+ low_memory_usage,
314
+ hold_eye_opened,
315
+ audio2secc_dir,
316
+ # head_model_dir,
317
+ torso_model_dir,
318
+ min_face_area_percent,
319
+ ],
320
+ outputs=[
321
+ gen_video,
322
+ info_box,
323
+ ],
324
+ )
325
+
326
+ fn_train = train_obj.train_once_args
327
+
328
+ train_submit.click(
329
+ fn=fn_train,
330
+ inputs=[
331
+ src_video_name,
332
+
333
+ ],
334
+ outputs=[
335
+ # gen_video,
336
+ info_box,
337
+ torso_model_dir,
338
+ ],
339
+ )
340
+
341
+ print(sep_line)
342
+ print("Gradio page is constructed.")
343
+ print(sep_line)
344
+
345
+ return real3dportrait_interface
346
+
347
+ if __name__ == "__main__":
348
+ parser = argparse.ArgumentParser()
349
+ parser.add_argument("--a2m_ckpt", default='checkpoints/240112_icl_audio2secc_vox2_cmlr') # checkpoints/0727_audio2secc/audio2secc_withlm2d100_randomframe
350
+ parser.add_argument("--head_ckpt", default='') # checkpoints/0729_th1kh/secc_img2plane checkpoints/0720_img2planes/secc_img2plane_two_stage
351
+ parser.add_argument("--torso_ckpt", default='checkpoints_mimictalk/German_20s/model_ckpt_steps_10000.ckpt')
352
+ parser.add_argument("--port", type=int, default=None)
353
+ parser.add_argument("--server", type=str, default='127.0.0.1')
354
+ parser.add_argument("--share", action='store_true', dest='share', help='share server to the Internet')
355
+
356
+ args = parser.parse_args()
357
+ demo = mimictalk_demo(
358
+ audio2secc_dir=args.a2m_ckpt,
359
+ head_model_dir=args.head_ckpt,
360
+ torso_model_dir=args.torso_ckpt,
361
+ device='cuda:0',
362
+ warpfn=None,
363
+ )
364
+ demo.queue()
365
+ demo.launch(share=args.share, server_name=args.server, server_port=args.port)
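For reference, the same WebUI can also be constructed and launched programmatically; this is a sketch assuming the default checkpoints from the README are in place (the `0.0.0.0`/`7860` binding is an assumption for container/Space-style hosting, not taken from this commit):

```python
# Sketch: embed the MimicTalk Gradio interface from another script instead of the CLI entry point.
from inference.app_mimictalk import mimictalk_demo

demo = mimictalk_demo(
    audio2secc_dir="checkpoints/240112_icl_audio2secc_vox2_cmlr",
    head_model_dir="",
    torso_model_dir="checkpoints_mimictalk/German_20s/model_ckpt_steps_10000.ckpt",
    device="cuda:0",
)
demo.queue()
demo.launch(server_name="0.0.0.0", server_port=7860)  # assumption: bind address/port for hosted deployment
```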
inference/app_real3dportrait.py ADDED
@@ -0,0 +1,247 @@
1
+ import os, sys
2
+ sys.path.append('./')
3
+ import argparse
4
+ import gradio as gr
5
+ from inference.real3d_infer import GeneFace2Infer
6
+ from utils.commons.hparams import hparams
7
+
8
+ class Inferer(GeneFace2Infer):
9
+ def infer_once_args(self, *args, **kargs):
10
+ assert len(kargs) == 0
11
+ keys = [
12
+ 'src_image_name',
13
+ 'drv_audio_name',
14
+ 'drv_pose_name',
15
+ 'bg_image_name',
16
+ 'blink_mode',
17
+ 'temperature',
18
+ 'mouth_amp',
19
+ 'out_mode',
20
+ 'map_to_init_pose',
21
+ 'low_memory_usage',
22
+ 'hold_eye_opened',
23
+ 'a2m_ckpt',
24
+ 'head_ckpt',
25
+ 'torso_ckpt',
26
+ 'min_face_area_percent',
27
+ ]
28
+ inp = {}
29
+ out_name = None
30
+ info = ""
31
+
32
+ try: # try to catch errors and jump to return
33
+ for key_index in range(len(keys)):
34
+ key = keys[key_index]
35
+ inp[key] = args[key_index]
36
+ if '_name' in key:
37
+ inp[key] = inp[key] if inp[key] is not None else ''
38
+
39
+ if inp['src_image_name'] == '':
40
+ info = "Input Error: Source image is REQUIRED!"
41
+ raise ValueError
42
+ if inp['drv_audio_name'] == '' and inp['drv_pose_name'] == '':
43
+ info = "Input Error: At least one of driving audio or video is REQUIRED!"
44
+ raise ValueError
45
+
46
+
47
+ if inp['drv_audio_name'] == '' and inp['drv_pose_name'] != '':
48
+ inp['drv_audio_name'] = inp['drv_pose_name']
49
+ print("No audio input, we use driving pose video for video driving")
50
+
51
+ if inp['drv_pose_name'] == '':
52
+ inp['drv_pose_name'] = 'static'
53
+
54
+ reload_flag = False
55
+ if inp['a2m_ckpt'] != self.audio2secc_dir:
56
+ print("Changes of a2m_ckpt detected, reloading model")
57
+ reload_flag = True
58
+ if inp['head_ckpt'] != self.head_model_dir:
59
+ print("Changes of head_ckpt detected, reloading model")
60
+ reload_flag = True
61
+ if inp['torso_ckpt'] != self.torso_model_dir:
62
+ print("Changes of torso_ckpt detected, reloading model")
63
+ reload_flag = True
64
+
65
+ inp['out_name'] = ''
66
+ inp['seed'] = 42
67
+
68
+ print(f"infer inputs : {inp}")
69
+
70
+ try:
71
+ if reload_flag:
72
+ self.__init__(inp['a2m_ckpt'], inp['head_ckpt'], inp['torso_ckpt'], inp=inp, device=self.device)
73
+ except Exception as e:
74
+ content = f"{e}"
75
+ info = f"Reload ERROR: {content}"
76
+ raise ValueError
77
+ try:
78
+ out_name = self.infer_once(inp)
79
+ except Exception as e:
80
+ content = f"{e}"
81
+ info = f"Inference ERROR: {content}"
82
+ raise ValueError
83
+ except Exception as e:
84
+ if info == "": # unexpected errors
85
+ content = f"{e}"
86
+ info = f"WebUI ERROR: {content}"
87
+
88
+ # output part
89
+ if len(info) > 0 : # there is errors
90
+ print(info)
91
+ info_gr = gr.update(visible=True, value=info)
92
+ else: # no errors
93
+ info_gr = gr.update(visible=False, value=info)
94
+ if out_name is not None and len(out_name) > 0 and os.path.exists(out_name): # good output
95
+ print(f"Successfully generated in {out_name}")
96
+ video_gr = gr.update(visible=True, value=out_name)
97
+ else:
98
+ print(f"Failed to generate")
99
+ video_gr = gr.update(visible=True, value=out_name)
100
+
101
+ return video_gr, info_gr
102
+
103
+ def toggle_audio_file(choice):
104
+ if choice == False:
105
+ return gr.update(visible=True), gr.update(visible=False)
106
+ else:
107
+ return gr.update(visible=False), gr.update(visible=True)
108
+
109
+ def ref_video_fn(path_of_ref_video):
110
+ if path_of_ref_video is not None:
111
+ return gr.update(value=True)
112
+ else:
113
+ return gr.update(value=False)
114
+
115
+ def real3dportrait_demo(
116
+ audio2secc_dir,
117
+ head_model_dir,
118
+ torso_model_dir,
119
+ device = 'cuda',
120
+ warpfn = None,
121
+ ):
122
+
123
+ sep_line = "-" * 40
124
+
125
+ infer_obj = Inferer(
126
+ audio2secc_dir=audio2secc_dir,
127
+ head_model_dir=head_model_dir,
128
+ torso_model_dir=torso_model_dir,
129
+ device=device,
130
+ )
131
+
132
+ print(sep_line)
133
+ print("Model loading is finished.")
134
+ print(sep_line)
135
+ with gr.Blocks(analytics_enabled=False) as real3dportrait_interface:
136
+ gr.Markdown("\
137
+ <div align='center'> <h2> Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis (ICLR 2024 Spotlight) </span> </h2> \
138
+ <a style='font-size:18px;color: #a0a0a0' href='https://arxiv.org/pdf/2401.08503.pdf'>Arxiv</a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; \
139
+ <a style='font-size:18px;color: #a0a0a0' href='https://real3dportrait.github.io/'>Homepage</a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; \
140
+ <a style='font-size:18px;color: #a0a0a0' href='https://github.com/yerfor/Real3DPortrait/'> Github </div>")
141
+
142
+ sources = None
143
+ with gr.Row():
144
+ with gr.Column(variant='panel'):
145
+ with gr.Tabs(elem_id="source_image"):
146
+ with gr.TabItem('Upload image'):
147
+ with gr.Row():
148
+ src_image_name = gr.Image(label="Source image (required)", sources=sources, type="filepath", value="data/raw/examples/Macron.png")
149
+ with gr.Tabs(elem_id="driven_audio"):
150
+ with gr.TabItem('Upload audio'):
151
+ with gr.Column(variant='panel'):
152
+ drv_audio_name = gr.Audio(label="Input audio (required for audio-driven)", sources=sources, type="filepath", value="data/raw/examples/Obama_5s.wav")
153
+ with gr.Tabs(elem_id="driven_pose"):
154
+ with gr.TabItem('Upload video'):
155
+ with gr.Column(variant='panel'):
156
+ drv_pose_name = gr.Video(label="Driven Pose (required for video-driven, optional for audio-driven)", sources=sources, value="data/raw/examples/May_5s.mp4")
157
+ with gr.Tabs(elem_id="bg_image"):
158
+ with gr.TabItem('Upload image'):
159
+ with gr.Row():
160
+ bg_image_name = gr.Image(label="Background image (optional)", sources=sources, type="filepath", value="data/raw/examples/bg.png")
161
+
162
+
163
+ with gr.Column(variant='panel'):
164
+ with gr.Tabs(elem_id="checkbox"):
165
+ with gr.TabItem('General Settings'):
166
+ with gr.Column(variant='panel'):
167
+
168
+ blink_mode = gr.Radio(['none', 'period'], value='period', label='blink mode', info="whether to blink periodically") #
169
+ min_face_area_percent = gr.Slider(minimum=0.15, maximum=0.5, step=0.01, label="min_face_area_percent", value=0.2, info='The minimum face area percent in the output frame, to prevent bad cases caused by a too small face.',)
170
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.025, label="temperature", value=0.2, info='audio to secc temperature',)
171
+ mouth_amp = gr.Slider(minimum=0.0, maximum=1.0, step=0.025, label="mouth amplitude", value=0.45, info='higher -> mouth will open wider, default to be 0.4',)
172
+ out_mode = gr.Radio(['final', 'concat_debug'], value='concat_debug', label='output layout', info="final: only final output ; concat_debug: final output concatenated with internal features")
173
+ low_memory_usage = gr.Checkbox(label="Low Memory Usage Mode: save memory at the expense of lower inference speed. Useful when running on a long (minutes-long) audio.", value=False)
174
+ map_to_init_pose = gr.Checkbox(label="Whether to map pose of first frame to initial pose", value=True)
175
+ hold_eye_opened = gr.Checkbox(label="Whether to maintain eyes always open")
176
+
177
+ submit = gr.Button('Generate', elem_id="generate", variant='primary')
178
+
179
+ with gr.Tabs(elem_id="genearted_video"):
180
+ info_box = gr.Textbox(label="Error", interactive=False, visible=False)
181
+ gen_video = gr.Video(label="Generated video", format="mp4", visible=True)
182
+ with gr.Column(variant='panel'):
183
+ with gr.Tabs(elem_id="checkbox"):
184
+ with gr.TabItem('Checkpoints'):
185
+ with gr.Column(variant='panel'):
186
+ ckpt_info_box = gr.Textbox(value="Please select \"ckpt\" under the checkpoint folder ", interactive=False, visible=True, show_label=False)
187
+ audio2secc_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=audio2secc_dir, file_count='single', label='audio2secc model ckpt path or directory')
188
+ head_model_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=head_model_dir, file_count='single', label='head model ckpt path or directory (will be ignored if torso model is set)')
189
+ torso_model_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=torso_model_dir, file_count='single', label='torso model ckpt path or directory')
190
+ # audio2secc_dir = gr.Textbox(audio2secc_dir, max_lines=1, label='audio2secc model ckpt path or directory (will be ignored if torso model is set)')
191
+ # head_model_dir = gr.Textbox(head_model_dir, max_lines=1, label='head model ckpt path or directory (will be ignored if torso model is set)')
192
+ # torso_model_dir = gr.Textbox(torso_model_dir, max_lines=1, label='torso model ckpt path or directory')
193
+
194
+
195
+ fn = infer_obj.infer_once_args
196
+ if warpfn:
197
+ fn = warpfn(fn)
198
+ submit.click(
199
+ fn=fn,
200
+ inputs=[
201
+ src_image_name,
202
+ drv_audio_name,
203
+ drv_pose_name,
204
+ bg_image_name,
205
+ blink_mode,
206
+ temperature,
207
+ mouth_amp,
208
+ out_mode,
209
+ map_to_init_pose,
210
+ low_memory_usage,
211
+ hold_eye_opened,
212
+ audio2secc_dir,
213
+ head_model_dir,
214
+ torso_model_dir,
215
+ min_face_area_percent,
216
+ ],
217
+ outputs=[
218
+ gen_video,
219
+ info_box,
220
+ ],
221
+ )
222
+
223
+ print(sep_line)
224
+ print("Gradio page is constructed.")
225
+ print(sep_line)
226
+
227
+ return real3dportrait_interface
228
+
229
+ if __name__ == "__main__":
230
+ parser = argparse.ArgumentParser()
231
+ parser.add_argument("--a2m_ckpt", type=str, default='checkpoints/240210_real3dportrait_orig/audio2secc_vae/model_ckpt_steps_400000.ckpt')
232
+ parser.add_argument("--head_ckpt", type=str, default='')
233
+ parser.add_argument("--torso_ckpt", type=str, default='checkpoints/240210_real3dportrait_orig/secc2plane_torso_orig/model_ckpt_steps_100000.ckpt')
234
+ parser.add_argument("--port", type=int, default=None)
235
+ parser.add_argument("--server", type=str, default='127.0.0.1')
236
+ parser.add_argument("--share", action='store_true', dest='share', help='share server to the Internet')
237
+
238
+ args = parser.parse_args()
239
+ demo = real3dportrait_demo(
240
+ audio2secc_dir=args.a2m_ckpt,
241
+ head_model_dir=args.head_ckpt,
242
+ torso_model_dir=args.torso_ckpt,
243
+ device='cuda:0',
244
+ warpfn=None,
245
+ )
246
+ demo.queue()
247
+ demo.launch(share=args.share, server_name=args.server, server_port=args.port)
inference/edit_secc.py ADDED
@@ -0,0 +1,147 @@
1
+ import cv2
2
+ import torch
3
+ from utils.commons.image_utils import dilate, erode
4
+ from sklearn.neighbors import NearestNeighbors
5
+ import copy
6
+ import numpy as np
7
+ from utils.commons.meters import Timer
8
+
9
+ def hold_eye_opened_for_secc(img):
10
+ img = img.permute(1,2,0).cpu().numpy()
11
+ img = ((img +1)/2*255).astype(np.uint)
12
+ face_mask = (img[...,0] != 0) & (img[...,1] != 0) & (img[...,2] != 0)
13
+ face_xys = np.stack(np.nonzero(face_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
14
+ h,w = face_mask.shape
15
+ # get face and eye mask
16
+ left_eye_prior_reigon = np.zeros([h,w], dtype=bool)
17
+ right_eye_prior_reigon = np.zeros([h,w], dtype=bool)
18
+ left_eye_prior_reigon[h//4:h//2, w//4:w//2] = True
19
+ right_eye_prior_reigon[h//4:h//2, w//2:w//4*3] = True
20
+ eye_prior_reigon = left_eye_prior_reigon | right_eye_prior_reigon
21
+ coarse_eye_mask = (~ face_mask) & eye_prior_reigon
22
+ coarse_eye_xys = np.stack(np.nonzero(coarse_eye_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
23
+
24
+ opened_eye_mask = cv2.imread('inference/os_avatar/opened_eye_mask.png')
25
+ opened_eye_mask = torch.nn.functional.interpolate(torch.tensor(opened_eye_mask).permute(2,0,1).unsqueeze(0), size=(img.shape[0], img.shape[1]), mode='nearest')[0].permute(1,2,0).sum(-1).bool().cpu() # [512,512,3]
26
+ coarse_opened_eye_xys = np.stack(np.nonzero(opened_eye_mask)) # [N_nonbg,2] coordinate of non-face pixels
27
+
28
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(coarse_eye_xys)
29
+ dists, _ = nbrs.kneighbors(coarse_opened_eye_xys) # [512*512, 1] distance to nearest non-bg pixel
30
+ # print(dists.max())
31
+ non_opened_eye_pixs = dists > max(dists.max()*0.75, 4) # opened-eye pixels farther than this distance will be closed
32
+ non_opened_eye_pixs = non_opened_eye_pixs.reshape([-1])
33
+ opened_eye_xys_to_erode = coarse_opened_eye_xys[non_opened_eye_pixs]
34
+ opened_eye_mask[opened_eye_xys_to_erode[...,0], opened_eye_xys_to_erode[...,1]] = False # shrink: contract the mask by ~3 pixels at the face-eye boundary for smoothness
35
+
36
+ img[opened_eye_mask] = 0
37
+ return torch.tensor(img.astype(np.float32) / 127.5 - 1).permute(2,0,1)
38
+
39
+
40
+ # def hold_eye_opened_for_secc(img):
41
+ # img = copy.copy(img)
42
+ # eye_mask = cv2.imread('inference/os_avatar/opened_eye_mask.png')
43
+ # eye_mask = torch.nn.functional.interpolate(torch.tensor(eye_mask).permute(2,0,1).unsqueeze(0), size=(img.shape[-2], img.shape[-1]), mode='nearest')[0].bool().to(img.device) # [3,512,512]
44
+ # img[eye_mask] = -1
45
+ # return img
46
+
47
+ def blink_eye_for_secc(img, close_eye_percent=0.5):
48
+ """
49
+ secc_img: [3,h,w], tensor, -1~1
50
+ """
51
+ img = img.permute(1,2,0).cpu().numpy()
52
+ img = ((img +1)/2*255).astype(np.uint)
53
+ assert close_eye_percent <= 1.0 and close_eye_percent >= 0.
54
+ if close_eye_percent == 0: return torch.tensor(img.astype(np.float32) / 127.5 - 1).permute(2,0,1)
55
+ img = copy.deepcopy(img)
56
+ face_mask = (img[...,0] != 0) & (img[...,1] != 0) & (img[...,2] != 0)
57
+ h,w = face_mask.shape
58
+
59
+ # get face and eye mask
60
+ left_eye_prior_reigon = np.zeros([h,w], dtype=bool)
61
+ right_eye_prior_reigon = np.zeros([h,w], dtype=bool)
62
+ left_eye_prior_reigon[h//4:h//2, w//4:w//2] = True
63
+ right_eye_prior_reigon[h//4:h//2, w//2:w//4*3] = True
64
+ eye_prior_reigon = left_eye_prior_reigon | right_eye_prior_reigon
65
+ coarse_eye_mask = (~ face_mask) & eye_prior_reigon
66
+ coarse_left_eye_mask = (~ face_mask) & left_eye_prior_reigon
67
+ coarse_right_eye_mask = (~ face_mask) & right_eye_prior_reigon
68
+ coarse_eye_xys = np.stack(np.nonzero(coarse_eye_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
69
+ min_h = coarse_eye_xys[:, 0].min()
70
+ max_h = coarse_eye_xys[:, 0].max()
71
+ coarse_left_eye_xys = np.stack(np.nonzero(coarse_left_eye_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
72
+ left_min_w = coarse_left_eye_xys[:, 1].min()
73
+ left_max_w = coarse_left_eye_xys[:, 1].max()
74
+ coarse_right_eye_xys = np.stack(np.nonzero(coarse_right_eye_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
75
+ right_min_w = coarse_right_eye_xys[:, 1].min()
76
+ right_max_w = coarse_right_eye_xys[:, 1].max()
77
+
78
+ # reduce the number of face_xys to consider, to lower the cost of the KNN search
79
+ left_eye_prior_reigon = np.zeros([h,w], dtype=bool)
80
+ more_room = 4 # too small a margin causes artifacts
81
+ left_eye_prior_reigon[min_h-more_room:max_h+more_room, left_min_w-more_room:left_max_w+more_room] = True
82
+ right_eye_prior_reigon = np.zeros([h,w], dtype=bool)
83
+ right_eye_prior_reigon[min_h-more_room:max_h+more_room, right_min_w-more_room:right_max_w+more_room] = True
84
+ eye_prior_reigon = left_eye_prior_reigon | right_eye_prior_reigon
85
+
86
+ around_eye_face_mask = face_mask & eye_prior_reigon
87
+ face_mask = around_eye_face_mask
88
+ face_xys = np.stack(np.nonzero(around_eye_face_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
89
+
90
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(coarse_eye_xys)
91
+ dists, _ = nbrs.kneighbors(face_xys) # [512*512, 1] distance to nearest non-bg pixel
92
+ face_pixs = dists > 5 # only pixels more than 5 px from the nearest eye pixel count as face; too small a threshold causes artifacts
93
+ face_pixs = face_pixs.reshape([-1])
94
+ face_xys_to_erode = face_xys[~face_pixs]
95
+ face_mask[face_xys_to_erode[...,0], face_xys_to_erode[...,1]] = False # shrink: contract the mask by ~3 pixels at the face-eye boundary for smoothness
96
+ eye_mask = (~ face_mask) & eye_prior_reigon
97
+
98
+ h_grid = np.mgrid[0:h, 0:w][0]
99
+ eye_num_pixel_along_w_axis = eye_mask.sum(axis=0)
100
+ eye_mask_along_w_axis = eye_num_pixel_along_w_axis != 0
101
+
102
+ tmp_h_grid = h_grid.copy()
103
+ tmp_h_grid[~eye_mask] = 0
104
+ eye_mean_h_coord_along_w_axis = tmp_h_grid.sum(axis=0) / np.clip(eye_num_pixel_along_w_axis, a_min=1, a_max=h)
105
+ tmp_h_grid = h_grid.copy()
106
+ tmp_h_grid[~eye_mask] = 99999
107
+ eye_min_h_coord_along_w_axis = tmp_h_grid.min(axis=0)
108
+ tmp_h_grid = h_grid.copy()
109
+ tmp_h_grid[~eye_mask] = -99999
110
+ eye_max_h_coord_along_w_axis = tmp_h_grid.max(axis=0)
111
+
112
+ eye_low_h_coord_along_w_axis = close_eye_percent * eye_mean_h_coord_along_w_axis + (1-close_eye_percent) * eye_min_h_coord_along_w_axis # upper eye
113
+ eye_high_h_coord_along_w_axis = close_eye_percent * eye_mean_h_coord_along_w_axis + (1-close_eye_percent) * eye_max_h_coord_along_w_axis # lower eye
114
+
115
+ tmp_h_grid = h_grid.copy()
116
+ tmp_h_grid[~eye_mask] = 99999
117
+ upper_eye_blink_mask = tmp_h_grid <= eye_low_h_coord_along_w_axis
118
+ tmp_h_grid = h_grid.copy()
119
+ tmp_h_grid[~eye_mask] = -99999
120
+ lower_eye_blink_mask = tmp_h_grid >= eye_high_h_coord_along_w_axis
121
+ eye_blink_mask = upper_eye_blink_mask | lower_eye_blink_mask
122
+
123
+ face_xys = np.stack(np.nonzero(around_eye_face_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
124
+ eye_blink_xys = np.stack(np.nonzero(eye_blink_mask)).transpose(1, 0) # [N_nonbg,hw] coordinate of non-face pixels
125
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(face_xys)
126
+ distances, indices = nbrs.kneighbors(eye_blink_xys)
127
+ bg_fg_xys = face_xys[indices[:, 0]]
128
+ img[eye_blink_xys[:, 0], eye_blink_xys[:, 1], :] = img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :]
129
+ return torch.tensor(img.astype(np.float32) / 127.5 - 1).permute(2,0,1)
130
+
131
+
132
+ if __name__ == '__main__':
133
+ import imageio
134
+ import tqdm
135
+ img = cv2.imread("assets/cano_secc.png")
136
+ img = img / 127.5 - 1
137
+ img = torch.FloatTensor(img).permute(2, 0, 1)
138
+ fps = 25
139
+ writer = imageio.get_writer('demo_blink.mp4', fps=fps)
140
+
141
+ for i in tqdm.trange(33):
142
+ blink_percent = 0.03 * i
143
+ with Timer("Blink", True):
144
+ out_img = blink_eye_for_secc(img, blink_percent)
145
+ out_img = ((out_img.permute(1,2,0)+1)*127.5).int().numpy()
146
+ writer.append_data(out_img)
147
+ writer.close()
inference/infer_utils.py ADDED
@@ -0,0 +1,154 @@
1
+ import os
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import librosa
5
+ import numpy as np
6
+ import importlib
7
+ import tqdm
8
+ import copy
9
+ import cv2
10
+ from scipy.spatial.transform import Rotation
11
+
12
+
13
+ def load_img_to_512_hwc_array(img_name):
14
+ img = cv2.imread(img_name)
15
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
16
+ img = cv2.resize(img, (512, 512))
17
+ return img
18
+
19
+ def load_img_to_normalized_512_bchw_tensor(img_name):
20
+ img = load_img_to_512_hwc_array(img_name)
21
+ img = ((torch.tensor(img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2) # [b,c,h,w]
22
+ return img
23
+
24
+ def mirror_index(index, len_seq):
25
+ """
26
+ get mirror index when indexing a sequence and the index is larger than len_pose
27
+ args:
28
+ index: int
29
+ len_pose: int
30
+ return:
31
+ mirror_index: int
32
+ """
33
+ turn = index // len_seq
34
+ res = index % len_seq
35
+ if turn % 2 == 0:
36
+ return res # forward indexing
37
+ else:
38
+ return len_seq - res - 1 # reverse indexing
39
+
40
+ def smooth_camera_sequence(camera, kernel_size=7):
41
+ """
42
+ smooth the camera trajectory (i.e., rotation & translation)...
43
+ args:
44
+ camera: [N, 25] or [N, 16]. np.ndarray
45
+ kernel_size: int
46
+ return:
47
+ smoothed_camera: [N, 25] or [N, 16]. np.ndarray
48
+ """
49
+ # poses: [N, 25], numpy array
50
+ N = camera.shape[0]
51
+ K = kernel_size // 2
52
+ poses = camera[:, :16].reshape([-1, 4, 4]).copy()
53
+ trans = poses[:, :3, 3].copy() # [N, 3]
54
+ rots = poses[:, :3, :3].copy() # [N, 3, 3]
55
+
56
+ for i in range(N):
57
+ start = max(0, i - K)
58
+ end = min(N, i + K + 1)
59
+ poses[i, :3, 3] = trans[start:end].mean(0)
60
+ try:
61
+ poses[i, :3, :3] = Rotation.from_matrix(rots[start:end]).mean().as_matrix()
62
+ except:
63
+ if i == 0:
64
+ poses[i, :3, :3] = rots[i]
65
+ else:
66
+ poses[i, :3, :3] = poses[i-1, :3, :3]
67
+ poses = poses.reshape([-1, 16])
68
+ camera[:, :16] = poses
69
+ return camera
70
+
71
+ def smooth_features_xd(in_tensor, kernel_size=7):
72
+ """
73
+ smooth the feature maps
74
+ args:
75
+ in_tensor: [T, c,h,w] or [T, c1,c2,h,w]
76
+ kernel_size: int
77
+ return:
78
+ out_tensor: [T, c,h,w] or [T, c1,c2,h,w]
79
+ """
80
+ t = in_tensor.shape[0]
81
+ ndim = in_tensor.ndim
82
+ pad = (kernel_size- 1)//2
83
+ in_tensor = torch.cat([torch.flip(in_tensor[0:pad], dims=[0]), in_tensor, torch.flip(in_tensor[t-pad:t], dims=[0])], dim=0)
84
+ if ndim == 2: # tc
85
+ _,c = in_tensor.shape
86
+ in_tensor = in_tensor.permute(1,0).reshape([-1,1,t+2*pad]) # [c, 1, t]
87
+ elif ndim == 4: # tchw
88
+ _,c,h,w = in_tensor.shape
89
+ in_tensor = in_tensor.permute(1,2,3,0).reshape([-1,1,t+2*pad]) # [c, 1, t]
90
+ elif ndim == 5: # tcchw, like deformation
91
+ _,c1,c2, h,w = in_tensor.shape
92
+ in_tensor = in_tensor.permute(1,2,3,4,0).reshape([-1,1,t+2*pad]) # [c, 1, t]
93
+ else: raise NotImplementedError()
94
+ avg_kernel = 1 / kernel_size * torch.Tensor([1.]*kernel_size).reshape([1,1,kernel_size]).float().to(in_tensor.device) # [1, 1, kw]
95
+ out_tensor = F.conv1d(in_tensor, avg_kernel)
96
+ if ndim == 2: # tc
97
+ return out_tensor.reshape([c,t]).permute(1,0)
98
+ elif ndim == 4: # tchw
99
+ return out_tensor.reshape([c,h,w,t]).permute(3,0,1,2)
100
+ elif ndim == 5: # tcchw, like deformation
101
+ return out_tensor.reshape([c1,c2,h,w,t]).permute(4,0,1,2,3)
102
+
103
+
104
+ def extract_audio_motion_from_ref_video(video_name):
105
+ def save_wav16k(audio_name):
106
+ supported_types = ('.wav', '.mp3', '.mp4', '.avi')
107
+ assert audio_name.endswith(supported_types), f"Now we only support {','.join(supported_types)} as audio source!"
108
+ wav16k_name = audio_name[:-4] + '_16k.wav'
109
+ extract_wav_cmd = f"ffmpeg -i {audio_name} -f wav -ar 16000 -v quiet -y {wav16k_name} -y"
110
+ os.system(extract_wav_cmd)
111
+ print(f"Extracted wav file (16khz) from {audio_name} to {wav16k_name}.")
112
+ return wav16k_name
113
+
114
+ def get_f0( wav16k_name):
115
+ from data_gen.utils.process_audio.extract_mel_f0 import extract_mel_from_fname, extract_f0_from_wav_and_mel
116
+ wav, mel = extract_mel_from_fname(wav16k_name)
117
+ f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)
118
+ f0 = f0.reshape([-1,1])
119
+ f0 = torch.tensor(f0)
120
+ return f0
121
+
122
+ def get_hubert(wav16k_name):
123
+ from data_gen.utils.process_audio.extract_hubert import get_hubert_from_16k_wav
124
+ hubert = get_hubert_from_16k_wav(wav16k_name).detach().numpy()
125
+ len_mel = hubert.shape[0]
126
+ x_multiply = 8
127
+ if len_mel % x_multiply == 0:
128
+ num_to_pad = 0
129
+ else:
130
+ num_to_pad = x_multiply - len_mel % x_multiply
131
+ hubert = np.pad(hubert, pad_width=((0,num_to_pad), (0,0)))
132
+ hubert = torch.tensor(hubert)
133
+ return hubert
134
+
135
+ def get_exp(video_name):
136
+ from data_gen.utils.process_video.fit_3dmm_landmark import fit_3dmm_for_a_video
137
+ drv_motion_coeff_dict = fit_3dmm_for_a_video(video_name, save=False)
138
+ exp = torch.tensor(drv_motion_coeff_dict['exp'])
139
+ return exp
140
+
141
+ wav16k_name = save_wav16k(video_name)
142
+ f0 = get_f0(wav16k_name)
143
+ hubert = get_hubert(wav16k_name)
144
+ os.system(f"rm {wav16k_name}")
145
+ exp = get_exp(video_name)
146
+ target_length = min(len(exp), len(hubert)//2, len(f0)//2)
147
+ exp = exp[:target_length]
148
+ f0 = f0[:target_length*2]
149
+ hubert = hubert[:target_length*2]
150
+ return exp.unsqueeze(0), hubert.unsqueeze(0), f0.unsqueeze(0)
151
+
152
+
153
+ if __name__ == '__main__':
154
+ extract_audio_motion_from_ref_video('data/raw/videos/crop_0213.mp4')
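A quick illustration of `mirror_index` above (a sketch, not part of this file): it ping-pongs through a sequence so that pose/motion frames can be reused for arbitrarily long driving audio.

```python
# Sketch: mirror_index walks a length-5 sequence forward, then backward, then forward again.
from inference.infer_utils import mirror_index

indices = [mirror_index(i, len_seq=5) for i in range(12)]
print(indices)  # [0, 1, 2, 3, 4, 4, 3, 2, 1, 0, 0, 1]
```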
inference/mimictalk_infer.py ADDED
@@ -0,0 +1,357 @@
1
+ """
2
+ Inference entry for the person-specific model obtained from inference/train_mimictalk_on_a_video.py
3
+ """
4
+ import os
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ # import librosa
9
+ import random
10
+ import time
11
+ import numpy as np
12
+ import importlib
13
+ import tqdm
14
+ import copy
15
+ import cv2
16
+
17
+ # common utils
18
+ from utils.commons.hparams import hparams, set_hparams
19
+ from utils.commons.tensor_utils import move_to_cuda, convert_to_tensor
20
+ from utils.commons.ckpt_utils import load_ckpt, get_last_checkpoint
21
+ # 3DMM-related utils
22
+ from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
23
+ from data_util.face3d_helper import Face3DHelper
24
+ from data_gen.utils.process_image.fit_3dmm_landmark import fit_3dmm_for_a_image
25
+ from data_gen.utils.process_video.fit_3dmm_landmark import fit_3dmm_for_a_video
26
+ from deep_3drecon.secc_renderer import SECC_Renderer
27
+ from data_gen.eg3d.convert_to_eg3d_convention import get_eg3d_convention_camera_pose_intrinsic
28
+ # Face Parsing
29
+ from data_gen.utils.mp_feature_extractors.mp_segmenter import MediapipeSegmenter
30
+ from data_gen.utils.process_video.extract_segment_imgs import inpaint_torso_job, extract_background
31
+ # other inference utils
32
+ from inference.infer_utils import mirror_index, load_img_to_512_hwc_array, load_img_to_normalized_512_bchw_tensor
33
+ from inference.infer_utils import smooth_camera_sequence, smooth_features_xd
34
+ from inference.edit_secc import blink_eye_for_secc, hold_eye_opened_for_secc
35
+ from inference.real3d_infer import GeneFace2Infer
36
+
37
+
38
+ class AdaptGeneFace2Infer(GeneFace2Infer):
39
+ def __init__(self, audio2secc_dir, head_model_dir, torso_model_dir, device=None, **kwargs):
40
+ if device is None:
41
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
42
+ self.device = device
43
+ self.audio2secc_dir = audio2secc_dir
44
+ self.head_model_dir = head_model_dir
45
+ self.torso_model_dir = torso_model_dir
46
+ self.audio2secc_model = self.load_audio2secc(audio2secc_dir)
47
+ self.secc2video_model = self.load_secc2video(head_model_dir, torso_model_dir)
48
+ self.audio2secc_model.to(device).eval()
49
+ self.secc2video_model.to(device).eval()
50
+ self.seg_model = MediapipeSegmenter()
51
+ self.secc_renderer = SECC_Renderer(512)
52
+ self.face3d_helper = Face3DHelper(use_gpu=True, keypoint_mode='lm68')
53
+ self.mp_face3d_helper = Face3DHelper(use_gpu=True, keypoint_mode='mediapipe')
54
+ # self.camera_selector = KNearestCameraSelector()
55
+
56
+ def load_secc2video(self, head_model_dir, torso_model_dir):
57
+ if torso_model_dir != '':
58
+ config_dir = torso_model_dir if os.path.isdir(torso_model_dir) else os.path.dirname(torso_model_dir)
59
+ set_hparams(f"{config_dir}/config.yaml", print_hparams=False)
60
+ hparams['htbsr_head_threshold'] = 1.0
61
+ self.secc2video_hparams = copy.deepcopy(hparams)
62
+ ckpt = get_last_checkpoint(torso_model_dir)[0]
63
+ lora_args = ckpt.get("lora_args", None)
64
+ from modules.real3d.secc_img2plane_torso import OSAvatarSECC_Img2plane_Torso
65
+ model = OSAvatarSECC_Img2plane_Torso(self.secc2video_hparams, lora_args=lora_args)
66
+ load_ckpt(model, f"{torso_model_dir}", model_name='model', strict=True)
67
+ self.learnable_triplane = nn.Parameter(torch.zeros([1, 3, model.triplane_hid_dim*model.triplane_depth, 256, 256]).float().cuda(), requires_grad=True)
68
+ load_ckpt(self.learnable_triplane, f"{torso_model_dir}", model_name='learnable_triplane', strict=True)
69
+ model._last_cano_planes = self.learnable_triplane
70
+ if head_model_dir != '':
71
+ print("| Warning: Assigned --torso_ckpt which also contains head, but --head_ckpt is also assigned, skipping the --head_ckpt.")
72
+ else:
73
+ from modules.real3d.secc_img2plane_torso import OSAvatarSECC_Img2plane
74
+ set_hparams(f"{head_model_dir}/config.yaml", print_hparams=False)
75
+ ckpt = get_last_checkpoint(head_model_dir)[0]
76
+ lora_args = ckpt.get("lora_args", None)
77
+ self.secc2video_hparams = copy.deepcopy(hparams)
78
+ model = OSAvatarSECC_Img2plane(self.secc2video_hparams, lora_args=lora_args)
79
+ load_ckpt(model, f"{head_model_dir}", model_name='model', strict=True)
80
+ self.learnable_triplane = nn.Parameter(torch.zeros([1, 3, model.triplane_hid_dim*model.triplane_depth, 256, 256]).float().cuda(), requires_grad=True)
81
+ model._last_cano_planes = self.learnable_triplane
82
+ load_ckpt(model._last_cano_planes, f"{head_model_dir}", model_name='learnable_triplane', strict=True)
83
+ self.person_ds = ckpt['person_ds']
84
+ return model
85
+
86
+ def prepare_batch_from_inp(self, inp):
87
+ """
88
+ :param inp: {'audio_source_name': (str)}
89
+ :return: a dict that contains the condition feature of NeRF
90
+ """
91
+ sample = {}
92
+ # Process Driving Motion
93
+ if inp['drv_audio_name'][-4:] in ['.wav', '.mp3']:
94
+ self.save_wav16k(inp['drv_audio_name'])
95
+ if self.audio2secc_hparams['audio_type'] == 'hubert':
96
+ hubert = self.get_hubert(self.wav16k_name)
97
+ elif self.audio2secc_hparams['audio_type'] == 'mfcc':
98
+ hubert = self.get_mfcc(self.wav16k_name) / 100
99
+
100
+ f0 = self.get_f0(self.wav16k_name)
101
+ if f0.shape[0] > len(hubert):
102
+ f0 = f0[:len(hubert)]
103
+ else:
104
+ num_to_pad = len(hubert) - len(f0)
105
+ f0 = np.pad(f0, pad_width=((0,num_to_pad), (0,0)))
106
+ t_x = hubert.shape[0]
107
+ x_mask = torch.ones([1, t_x]).float() # mask for audio frames
108
+ y_mask = torch.ones([1, t_x//2]).float() # mask for motion/image frames
109
+ sample.update({
110
+ 'hubert': torch.from_numpy(hubert).float().unsqueeze(0).cuda(),
111
+ 'f0': torch.from_numpy(f0).float().reshape([1,-1]).cuda(),
112
+ 'x_mask': x_mask.cuda(),
113
+ 'y_mask': y_mask.cuda(),
114
+ })
115
+ sample['blink'] = torch.zeros([1, t_x, 1]).long().cuda()
116
+ sample['audio'] = sample['hubert']
117
+ sample['eye_amp'] = torch.ones([1, 1]).cuda() * 1.0
118
+ elif inp['drv_audio_name'][-4:] in ['.mp4']:
119
+ drv_motion_coeff_dict = fit_3dmm_for_a_video(inp['drv_audio_name'], save=False)
120
+ drv_motion_coeff_dict = convert_to_tensor(drv_motion_coeff_dict)
121
+ t_x = drv_motion_coeff_dict['exp'].shape[0] * 2
122
+ self.drv_motion_coeff_dict = drv_motion_coeff_dict
123
+ elif inp['drv_audio_name'][-4:] in ['.npy']:
124
+ drv_motion_coeff_dict = np.load(inp['drv_audio_name'], allow_pickle=True).tolist()
125
+ drv_motion_coeff_dict = convert_to_tensor(drv_motion_coeff_dict)
126
+ t_x = drv_motion_coeff_dict['exp'].shape[0] * 2
127
+ self.drv_motion_coeff_dict = drv_motion_coeff_dict
128
+
129
+ # Face Parsing
130
+ sample['ref_gt_img'] = self.person_ds['gt_img'].cuda()
131
+ img = self.person_ds['gt_img'].reshape([3, 512, 512]).permute(1, 2, 0)
132
+ img = (img + 1) * 127.5
133
+ img = np.ascontiguousarray(img.int().numpy()).astype(np.uint8)
134
+ segmap = self.seg_model._cal_seg_map(img)
135
+ sample['segmap'] = torch.tensor(segmap).float().unsqueeze(0).cuda()
136
+ head_img = self.seg_model._seg_out_img_with_segmap(img, segmap, mode='head')[0]
137
+ sample['ref_head_img'] = ((torch.tensor(head_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
138
+ inpaint_torso_img, _, _, _ = inpaint_torso_job(img, segmap)
139
+ sample['ref_torso_img'] = ((torch.tensor(inpaint_torso_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
140
+
141
+ if inp['bg_image_name'] == '':
142
+ bg_img = extract_background([img], [segmap], 'knn')
143
+ else:
144
+ bg_img = cv2.imread(inp['bg_image_name'])
145
+ bg_img = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB)
146
+ bg_img = cv2.resize(bg_img, (512,512))
147
+ sample['bg_img'] = ((torch.tensor(bg_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
148
+
149
+ # 3DMM, get identity code and camera pose
150
+ image_name = f"data/raw/val_imgs/{self.person_ds['video_id']}_img.png"
151
+ os.makedirs(os.path.dirname(image_name), exist_ok=True)
152
+ cv2.imwrite(image_name, img[:,:,::-1])
153
+ coeff_dict = fit_3dmm_for_a_image(image_name, save=False)
154
+ coeff_dict['id'] = self.person_ds['id'].reshape([1,80]).numpy()
155
+
156
+ assert coeff_dict is not None
157
+ src_id = torch.tensor(coeff_dict['id']).reshape([1,80]).cuda()
158
+ src_exp = torch.tensor(coeff_dict['exp']).reshape([1,64]).cuda()
159
+ src_euler = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda()
160
+ src_trans = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda()
161
+ sample['id'] = src_id.repeat([t_x//2,1])
162
+
163
+ # get the src_kp for torso model
164
+ sample['src_kp'] = self.person_ds['src_kp'].cuda().reshape([1, 68, 3]).repeat([t_x//2,1,1])[..., :2] # [B, 68, 2]
165
+
166
+ # get camera pose file
167
+ random.seed(time.time())
168
+ if inp['drv_pose_name'] in ['nearest', 'topk']:
169
+ camera_ret = get_eg3d_convention_camera_pose_intrinsic({'euler': torch.tensor(coeff_dict['euler']).reshape([1,3]), 'trans': torch.tensor(coeff_dict['trans']).reshape([1,3])})
170
+ c2w, intrinsics = camera_ret['c2w'], camera_ret['intrinsics']
171
+ camera = np.concatenate([c2w.reshape([1,16]), intrinsics.reshape([1,9])], axis=-1)
172
+ coeff_names, distance_matrix = self.camera_selector.find_k_nearest(camera, k=100)
173
+ coeff_names = coeff_names[0] # squeeze
174
+ if inp['drv_pose_name'] == 'nearest':
175
+ inp['drv_pose_name'] = coeff_names[0]
176
+ else:
177
+ inp['drv_pose_name'] = random.choice(coeff_names)
178
+ # inp['drv_pose_name'] = coeff_names[0]
179
+ elif inp['drv_pose_name'] == 'random':
180
+ inp['drv_pose_name'] = self.camera_selector.random_select()
181
+ else:
182
+ inp['drv_pose_name'] = inp['drv_pose_name']
183
+
184
+ print(f"| To extract pose from {inp['drv_pose_name']}")
185
+
186
+ # extract camera pose
187
+ if inp['drv_pose_name'] == 'static':
188
+ sample['euler'] = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda().repeat([t_x//2,1]) # default static pose
189
+ sample['trans'] = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda().repeat([t_x//2,1])
190
+ else: # from file
191
+ if inp['drv_pose_name'].endswith('.mp4'):
192
+ # extract coeff from video
193
+ drv_pose_coeff_dict = fit_3dmm_for_a_video(inp['drv_pose_name'], save=False)
194
+ else:
195
+ # load from npy
196
+ drv_pose_coeff_dict = np.load(inp['drv_pose_name'], allow_pickle=True).tolist()
197
+ print(f"| Extracted pose from {inp['drv_pose_name']}")
198
+ eulers = convert_to_tensor(drv_pose_coeff_dict['euler']).reshape([-1,3]).cuda()
199
+ trans = convert_to_tensor(drv_pose_coeff_dict['trans']).reshape([-1,3]).cuda()
200
+ len_pose = len(eulers)
201
+ index_lst = [mirror_index(i, len_pose) for i in range(t_x//2)]
202
+ sample['euler'] = eulers[index_lst]
203
+ sample['trans'] = trans[index_lst]
204
+
205
+ # fix the z axis
206
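+ # keep the z-translation fixed to its first-frame value so the apparent head scale stays constant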
+ sample['trans'][:, -1] = sample['trans'][0:1, -1].repeat([sample['trans'].shape[0]])
207
+
208
+ # mapping to the init pose
209
+ if inp.get("map_to_init_pose", 'False') == 'True':
210
+ diff_euler = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda() - sample['euler'][0:1]
211
+ sample['euler'] = sample['euler'] + diff_euler
212
+ diff_trans = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda() - sample['trans'][0:1]
213
+ sample['trans'] = sample['trans'] + diff_trans
214
+
215
+ # prepare camera
216
+ camera_ret = get_eg3d_convention_camera_pose_intrinsic({'euler':sample['euler'].cpu(), 'trans':sample['trans'].cpu()})
217
+ c2w, intrinsics = camera_ret['c2w'], camera_ret['intrinsics']
218
+ # smooth camera
219
+ camera_smo_ksize = 7
220
+ camera = np.concatenate([c2w.reshape([-1,16]), intrinsics.reshape([-1,9])], axis=-1)
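+ # each camera is a 25-dim vector: flattened 4x4 cam2world matrix (16) + flattened 3x3 intrinsics (9)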
221
+ camera = smooth_camera_sequence(camera, kernel_size=camera_smo_ksize) # [T, 25]
222
+ camera = torch.tensor(camera).cuda().float()
223
+ sample['camera'] = camera
224
+
225
+ return sample
226
+
227
+ @torch.no_grad()
228
+ def forward_secc2video(self, batch, inp=None):
229
+ num_frames = len(batch['drv_secc'])
230
+ camera = batch['camera']
231
+ src_kps = batch['src_kp']
232
+ drv_kps = batch['drv_kp']
233
+ cano_secc_color = batch['cano_secc']
234
+ src_secc_color = batch['src_secc']
235
+ drv_secc_colors = batch['drv_secc']
236
+ ref_img_gt = batch['ref_gt_img']
237
+ ref_img_head = batch['ref_head_img']
238
+ ref_torso_img = batch['ref_torso_img']
239
+ bg_img = batch['bg_img']
240
+ segmap = batch['segmap']
241
+
242
+ # smooth torso drv_kp
243
+ torso_smo_ksize = 7
244
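+ # temporally smooth the 68 driving 2D keypoints used by the torso branch to reduce jitter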
+ drv_kps = smooth_features_xd(drv_kps.reshape([-1, 68*2]), kernel_size=torso_smo_ksize).reshape([-1, 68, 2])
245
+
246
+ # forward renderer
247
+ img_raw_lst = []
248
+ img_lst = []
249
+ depth_img_lst = []
250
+ with torch.no_grad():
251
+ for i in tqdm.trange(num_frames, desc="MimicTalk is rendering frames"):
252
+ kp_src = torch.cat([src_kps[i:i+1].reshape([1, 68, 2]), torch.zeros([1, 68,1]).to(src_kps.device)],dim=-1)
253
+ kp_drv = torch.cat([drv_kps[i:i+1].reshape([1, 68, 2]), torch.zeros([1, 68,1]).to(drv_kps.device)],dim=-1)
254
+ cond={'cond_cano': cano_secc_color,'cond_src': src_secc_color, 'cond_tgt': drv_secc_colors[i:i+1].cuda(),
255
+ 'ref_torso_img': ref_torso_img, 'bg_img': bg_img, 'segmap': segmap,
256
+ 'kp_s': kp_src, 'kp_d': kp_drv}
257
+
258
+ ########################################################################################################
259
+ ### Compared with real3d_infer, only the line below 👇 is changed: cano_triplane comes from the cached learnable_triplane instead of a plane predicted from the image ####
260
+ ########################################################################################################
261
+ gen_output = self.secc2video_model.forward(img=None, camera=camera[i:i+1], cond=cond, ret={}, cache_backbone=False, use_cached_backbone=True)
262
+
263
+ img_lst.append(gen_output['image'])
264
+ img_raw_lst.append(gen_output['image_raw'])
265
+ depth_img_lst.append(gen_output['image_depth'])
266
+
267
+ # save demo video
268
+ depth_imgs = torch.cat(depth_img_lst)
269
+ imgs = torch.cat(img_lst)
270
+ imgs_raw = torch.cat(img_raw_lst)
271
+ secc_img = torch.cat([torch.nn.functional.interpolate(drv_secc_colors[i:i+1], (512,512)) for i in range(num_frames)])
272
+
273
+ if inp['out_mode'] == 'concat_debug':
274
+ secc_img = secc_img.cpu()
275
+ secc_img = ((secc_img + 1) * 127.5).permute(0, 2, 3, 1).int().numpy()
276
+
277
+ depth_img = F.interpolate(depth_imgs, (512,512)).cpu()
278
+ depth_img = depth_img.repeat([1,3,1,1])
279
+ depth_img = (depth_img - depth_img.min()) / (depth_img.max() - depth_img.min())
280
+ depth_img = depth_img * 2 - 1
281
+ depth_img = depth_img.clamp(-1,1)
282
+
283
+ secc_img = secc_img / 127.5 - 1
284
+ secc_img = torch.from_numpy(secc_img).permute(0, 3, 1, 2)
285
+ imgs = torch.cat([ref_img_gt.repeat([imgs.shape[0],1,1,1]).cpu(), secc_img, F.interpolate(imgs_raw, (512,512)).cpu(), depth_img, imgs.cpu()], dim=-1)
286
+ elif inp['out_mode'] == 'final':
287
+ imgs = imgs.cpu()
288
+ elif inp['out_mode'] == 'debug':
289
+ raise NotImplementedError("to do: save separate videos")
290
+ imgs = imgs.clamp(-1,1)
291
+
292
+ import imageio
293
+ import uuid
294
+ debug_name = f'{uuid.uuid1()}.mp4'
295
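+ # convert frames from [-1, 1] CHW tensors to uint8 HWC arrays for imageio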
+ out_imgs = ((imgs.permute(0, 2, 3, 1) + 1)/2 * 255).int().cpu().numpy().astype(np.uint8)
296
+ writer = imageio.get_writer(debug_name, fps=25, format='FFMPEG', codec='h264')
297
+ for i in tqdm.trange(len(out_imgs), desc="Imageio is saving video"):
298
+ writer.append_data(out_imgs[i])
299
+ writer.close()
300
+
301
+ out_fname = 'infer_out/tmp/' + os.path.basename(inp['drv_pose_name'])[:-4] + '.mp4' if inp['out_name'] == '' else inp['out_name']
302
+ try:
303
+ os.makedirs(os.path.dirname(out_fname), exist_ok=True)
304
+ except: pass
305
+ if inp['drv_audio_name'][-4:] in ['.wav', '.mp3']:
306
+ # os.system(f"ffmpeg -i {debug_name} -i {inp['drv_audio_name']} -y -v quiet -shortest {out_fname}")
307
+ cmd = f"/usr/bin/ffmpeg -i {debug_name} -i {self.wav16k_name} -y -r 25 -ar 16000 -c:v copy -c:a libmp3lame -pix_fmt yuv420p -b:v 2000k -strict experimental -shortest {out_fname}"
308
+ os.system(cmd)
309
+ os.system(f"rm {debug_name}")
310
+ else:
311
+ ret = os.system(f"ffmpeg -i {debug_name} -i {inp['drv_audio_name']} -map 0:v -map 1:a -y -v quiet -shortest {out_fname}")
312
+ if ret != 0: # failed to extract audio from drv_audio_name, so output the video without an audio track
313
+ os.system(f"mv {debug_name} {out_fname}")
314
+ print(f"Saved at {out_fname}")
315
+ return out_fname
316
+
317
+ if __name__ == '__main__':
318
+ import argparse, glob, tqdm
319
+ parser = argparse.ArgumentParser()
320
+ parser.add_argument("--a2m_ckpt", default='checkpoints/240112_icl_audio2secc_vox2_cmlr') # checkpoints/0727_audio2secc/audio2secc_withlm2d100_randomframe
321
+ parser.add_argument("--head_ckpt", default='') # checkpoints/0729_th1kh/secc_img2plane checkpoints/0720_img2planes/secc_img2plane_two_stage
322
+ parser.add_argument("--torso_ckpt", default='checkpoints_mimictalk/German_20s')
323
+ parser.add_argument("--bg_img", default='') # data/raw/val_imgs/bg3.png
324
+ parser.add_argument("--drv_aud", default='data/raw/examples/80_vs_60_10s.wav')
325
+ parser.add_argument("--drv_pose", default='data/raw/examples/German_20s.mp4') # nearest | topk | random | static | vid_name
326
+ parser.add_argument("--drv_style", default='data/raw/examples/angry.mp4') # nearest | topk | random | static | vid_name
327
+ parser.add_argument("--blink_mode", default='period') # none | period
328
+ parser.add_argument("--temperature", default=0.3, type=float) # nearest | random
329
+ parser.add_argument("--denoising_steps", default=20, type=int) # nearest | random
330
+ parser.add_argument("--cfg_scale", default=1.5, type=float) # nearest | random
331
+ parser.add_argument("--out_name", default='') # nearest | random
332
+ parser.add_argument("--out_mode", default='concat_debug') # concat_debug | debug | final
333
+ parser.add_argument("--hold_eye_opened", default='False') # concat_debug | debug | final
334
+ parser.add_argument("--map_to_init_pose", default='True') # concat_debug | debug | final
335
+ parser.add_argument("--seed", default=None, type=int) # random seed, default None to use time.time()
336
+
337
+ args = parser.parse_args()
338
+
339
+ inp = {
340
+ 'a2m_ckpt': args.a2m_ckpt,
341
+ 'head_ckpt': args.head_ckpt,
342
+ 'torso_ckpt': args.torso_ckpt,
343
+ 'bg_image_name': args.bg_img,
344
+ 'drv_audio_name': args.drv_aud,
345
+ 'drv_pose_name': args.drv_pose,
346
+ 'drv_talking_style_name': args.drv_style,
347
+ 'blink_mode': args.blink_mode,
348
+ 'temperature': args.temperature,
349
+ 'denoising_steps': args.denoising_steps,
350
+ 'cfg_scale': args.cfg_scale,
351
+ 'out_name': args.out_name,
352
+ 'out_mode': args.out_mode,
353
+ 'map_to_init_pose': args.map_to_init_pose,
354
+ 'hold_eye_opened': args.hold_eye_opened,
355
+ 'seed': args.seed,
356
+ }
357
+ AdaptGeneFace2Infer.example_run(inp)
inference/real3d_infer.py ADDED
@@ -0,0 +1,667 @@
1
+ import os
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import torchshow as ts
5
+ import librosa
6
+ import random
7
+ import time
8
+ import numpy as np
9
+ import importlib
10
+ import tqdm
11
+ import copy
12
+ import cv2
13
+ import math
14
+
15
+ # common utils
16
+ from utils.commons.hparams import hparams, set_hparams
17
+ from utils.commons.tensor_utils import move_to_cuda, convert_to_tensor
18
+ from utils.commons.ckpt_utils import load_ckpt, get_last_checkpoint
19
+ # 3DMM-related utils
20
+ from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
21
+ from data_util.face3d_helper import Face3DHelper
22
+ from data_gen.utils.process_image.fit_3dmm_landmark import fit_3dmm_for_a_image
23
+ from data_gen.utils.process_video.fit_3dmm_landmark import fit_3dmm_for_a_video
24
+ from data_gen.utils.process_image.extract_lm2d import extract_lms_mediapipe_job
25
+ from data_gen.utils.process_image.fit_3dmm_landmark import index_lm68_from_lm468
26
+ from deep_3drecon.secc_renderer import SECC_Renderer
27
+ from data_gen.eg3d.convert_to_eg3d_convention import get_eg3d_convention_camera_pose_intrinsic
28
+ # Face Parsing
29
+ from data_gen.utils.mp_feature_extractors.mp_segmenter import MediapipeSegmenter
30
+ from data_gen.utils.process_video.extract_segment_imgs import inpaint_torso_job, extract_background
31
+ # other inference utils
32
+ from inference.infer_utils import mirror_index, load_img_to_512_hwc_array, load_img_to_normalized_512_bchw_tensor
33
+ from inference.infer_utils import smooth_camera_sequence, smooth_features_xd
34
+ from inference.edit_secc import blink_eye_for_secc, hold_eye_opened_for_secc
35
+
36
+
37
+ def read_first_frame_from_a_video(vid_name):
38
+ frames = []
39
+ cap = cv2.VideoCapture(vid_name)
40
+ ret, frame_bgr = cap.read()
41
+ frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
42
+ return frame_rgb
43
+
44
+ def analyze_weights_img(gen_output):
45
+ img_raw = gen_output['image_raw']
46
+ mask_005_to_03 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.3).repeat([1,3,1,1])
47
+ mask_005_to_05 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.5).repeat([1,3,1,1])
48
+ mask_005_to_07 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.7).repeat([1,3,1,1])
49
+ mask_005_to_09 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.9).repeat([1,3,1,1])
50
+ mask_005_to_10 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<1.0).repeat([1,3,1,1])
51
+
52
+ img_raw_005_to_03 = img_raw.clone()
53
+ img_raw_005_to_03[~mask_005_to_03] = -1
54
+ img_raw_005_to_05 = img_raw.clone()
55
+ img_raw_005_to_05[~mask_005_to_05] = -1
56
+ img_raw_005_to_07 = img_raw.clone()
57
+ img_raw_005_to_07[~mask_005_to_07] = -1
58
+ img_raw_005_to_09 = img_raw.clone()
59
+ img_raw_005_to_09[~mask_005_to_09] = -1
60
+ img_raw_005_to_10 = img_raw.clone()
61
+ img_raw_005_to_10[~mask_005_to_10] = -1
62
+ ts.save([img_raw_005_to_03[0], img_raw_005_to_05[0], img_raw_005_to_07[0], img_raw_005_to_09[0], img_raw_005_to_10[0]])
63
+
64
+
65
+ def cal_face_area_percent(img_name):
66
+ img = cv2.resize(cv2.imread(img_name)[:,:,::-1], (512,512))
67
+ lm478 = extract_lms_mediapipe_job(img) / 512
68
+ min_x = lm478[:,0].min()
69
+ max_x = lm478[:,0].max()
70
+ min_y = lm478[:,1].min()
71
+ max_y = lm478[:,1].max()
72
+ area = (max_x - min_x) * (max_y - min_y)
73
+ return area
74
+
75
+ def crop_img_on_face_area_percent(img_name, out_name='temp/cropped_src_img.png', min_face_area_percent=0.2):
76
+ try:
77
+ os.makedirs(os.path.dirname(out_name), exist_ok=True)
78
+ except: pass
79
+ face_area_percent = cal_face_area_percent(img_name)
80
+ if face_area_percent >= min_face_area_percent:
81
+ print(f"face area percent {face_area_percent} larger than threshold {min_face_area_percent}, directly use the input image...")
82
+ cmd = f"cp {img_name} {out_name}"
83
+ os.system(cmd)
84
+ return out_name
85
+ else:
86
+ print(f"face area percent {face_area_percent} smaller than threshold {min_face_area_percent}, crop the input image...")
87
+ img = cv2.resize(cv2.imread(img_name)[:,:,::-1], (512,512))
88
+ lm478 = extract_lms_mediapipe_job(img).astype(int)
89
+ min_x = lm478[:,0].min()
90
+ max_x = lm478[:,0].max()
91
+ min_y = lm478[:,1].min()
92
+ max_y = lm478[:,1].max()
93
+ face_area = (max_x - min_x) * (max_y - min_y)
94
+ target_total_area = face_area / min_face_area_percent
95
+ target_hw = int(target_total_area**0.5)
96
+ center_x, center_y = (min_x+max_x)/2, (min_y+max_y)/2
97
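+ # if the target square crop would extend beyond the 512x512 image, shrink it so it stays inside the bounds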
+ shrink_pixels = 2 * max(-(center_x - target_hw/2), center_x + target_hw/2 - 512, -(center_y - target_hw/2), center_y + target_hw/2-512)
98
+ shrink_pixels = max(0, shrink_pixels)
99
+ hw = math.floor(target_hw - shrink_pixels)
100
+ new_min_x = int(center_x - hw/2)
101
+ new_max_x = int(center_x + hw/2)
102
+ new_min_y = int(center_y - hw/2)
103
+ new_max_y = int(center_y + hw/2)
104
+
105
+ img = img[new_min_y:new_max_y, new_min_x:new_max_x]
106
+ img = cv2.resize(img, (512, 512))
107
+ cv2.imwrite(out_name, img[:,:,::-1])
108
+ return out_name
109
+
110
+
111
+ class GeneFace2Infer:
112
+ def __init__(self, audio2secc_dir, head_model_dir, torso_model_dir, device=None, inp=None):
113
+ if device is None:
114
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
115
+ self.audio2secc_model = self.load_audio2secc(audio2secc_dir)
116
+ self.secc2video_model = self.load_secc2video(head_model_dir, torso_model_dir, inp)
117
+ self.audio2secc_model.to(device).eval()
118
+ self.secc2video_model.to(device).eval()
119
+ self.seg_model = MediapipeSegmenter()
120
+ self.secc_renderer = SECC_Renderer(512)
121
+ self.face3d_helper = Face3DHelper(use_gpu=True, keypoint_mode='lm68')
122
+ self.mp_face3d_helper = Face3DHelper(use_gpu=True, keypoint_mode='mediapipe')
123
+ self.camera_selector = KNearestCameraSelector()
124
+
125
+ def load_audio2secc(self, audio2secc_dir):
126
+ config_name = f"{audio2secc_dir}/config.yaml" if not audio2secc_dir.endswith(".ckpt") else f"{os.path.dirname(audio2secc_dir)}/config.yaml"
127
+ set_hparams(f"{config_name}", print_hparams=False)
128
+ self.audio2secc_dir = audio2secc_dir
129
+ self.audio2secc_hparams = copy.deepcopy(hparams)
130
+ from modules.audio2motion.vae import VAEModel, PitchContourVAEModel
131
+ from modules.audio2motion.cfm.icl_audio2motion_model import InContextAudio2MotionModel
132
+ if self.audio2secc_hparams['audio_type'] == 'hubert':
133
+ audio_in_dim = 1024
134
+ elif self.audio2secc_hparams['audio_type'] == 'mfcc':
135
+ audio_in_dim = 13
136
+
137
+ if 'icl' in hparams['task_cls']:
138
+ self.use_icl_audio2motion = True
139
+ model = InContextAudio2MotionModel(hparams['icl_model_type'], hparams=self.audio2secc_hparams)
140
+ else:
141
+ self.use_icl_audio2motion = False
142
+ if hparams.get("use_pitch", False) is True:
143
+ model = PitchContourVAEModel(hparams, in_out_dim=64, audio_in_dim=audio_in_dim)
144
+ else:
145
+ model = VAEModel(in_out_dim=64, audio_in_dim=audio_in_dim)
146
+ load_ckpt(model, f"{audio2secc_dir}", model_name='model', strict=True)
147
+ return model
148
+
149
+ def load_secc2video(self, head_model_dir, torso_model_dir, inp):
150
+ if inp is None:
151
+ inp = {}
152
+ self.head_model_dir = head_model_dir
153
+ self.torso_model_dir = torso_model_dir
154
+ if torso_model_dir != '':
155
+ if torso_model_dir.endswith(".ckpt"):
156
+ set_hparams(f"{os.path.dirname(torso_model_dir)}/config.yaml", print_hparams=False)
157
+ else:
158
+ set_hparams(f"{torso_model_dir}/config.yaml", print_hparams=False)
159
+ if inp.get('head_torso_threshold', None) is not None:
160
+ hparams['htbsr_head_threshold'] = inp['head_torso_threshold']
161
+ self.secc2video_hparams = copy.deepcopy(hparams)
162
+ from modules.real3d.secc_img2plane_torso import OSAvatarSECC_Img2plane_Torso
163
+ model = OSAvatarSECC_Img2plane_Torso()
164
+ load_ckpt(model, f"{torso_model_dir}", model_name='model', strict=False)
165
+ if head_model_dir != '':
166
+ print("| Warning: Assigned --torso_ckpt which also contains head, but --head_ckpt is also assigned, skipping the --head_ckpt.")
167
+ else:
168
+ from modules.real3d.secc_img2plane_torso import OSAvatarSECC_Img2plane
169
+ if head_model_dir.endswith(".ckpt"):
170
+ set_hparams(f"{os.path.dirname(head_model_dir)}/config.yaml", print_hparams=False)
171
+ else:
172
+ set_hparams(f"{head_model_dir}/config.yaml", print_hparams=False)
173
+ if inp.get('head_torso_threshold', None) is not None:
174
+ hparams['htbsr_head_threshold'] = inp['head_torso_threshold']
175
+ self.secc2video_hparams = copy.deepcopy(hparams)
176
+ model = OSAvatarSECC_Img2plane()
177
+ load_ckpt(model, f"{head_model_dir}", model_name='model', strict=False)
178
+ return model
179
+
180
+ def infer_once(self, inp):
181
+ self.inp = inp
182
+ samples = self.prepare_batch_from_inp(inp)
183
+ seed = inp['seed'] if inp['seed'] is not None else int(time.time())
184
+ random.seed(seed)
185
+ torch.manual_seed(seed)
186
+ np.random.seed(seed)
187
+ out_name = self.forward_system(samples, inp)
188
+ return out_name
189
+
190
+ def prepare_batch_from_inp(self, inp):
191
+ """
192
+ :param inp: {'audio_source_name': (str)}
193
+ :return: a dict that contains the condition feature of NeRF
194
+ """
195
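+ # crop/zoom the source image so that the face covers at least min_face_area_percent of the frame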
+ cropped_name = 'temp/cropped_src_img_512.png'
196
+ crop_img_on_face_area_percent(inp['src_image_name'], cropped_name, min_face_area_percent=inp['min_face_area_percent'])
197
+ inp['src_image_name'] = cropped_name
198
+
199
+ sample = {}
200
+ # Process Driving Motion
201
+ if inp['drv_audio_name'][-4:] in ['.wav', '.mp3']:
202
+ self.save_wav16k(inp['drv_audio_name'])
203
+ if self.audio2secc_hparams['audio_type'] == 'hubert':
204
+ hubert = self.get_hubert(self.wav16k_name)
205
+ elif self.audio2secc_hparams['audio_type'] == 'mfcc':
206
+ hubert = self.get_mfcc(self.wav16k_name) / 100
207
+
208
+ f0 = self.get_f0(self.wav16k_name)
209
+ if f0.shape[0] > len(hubert):
210
+ f0 = f0[:len(hubert)]
211
+ else:
212
+ num_to_pad = len(hubert) - len(f0)
213
+ f0 = np.pad(f0, pad_width=((0,num_to_pad), (0,0)))
214
+ t_x = hubert.shape[0]
215
+ x_mask = torch.ones([1, t_x]).float() # mask for audio frames
216
+ y_mask = torch.ones([1, t_x//2]).float() # mask for motion/image frames
217
+ sample.update({
218
+ 'hubert': torch.from_numpy(hubert).float().unsqueeze(0).cuda(),
219
+ 'f0': torch.from_numpy(f0).float().reshape([1,-1]).cuda(),
220
+ 'x_mask': x_mask.cuda(),
221
+ 'y_mask': y_mask.cuda(),
222
+ })
223
+ sample['blink'] = torch.zeros([1, t_x, 1]).long().cuda()
224
+ sample['audio'] = sample['hubert']
225
+ sample['eye_amp'] = torch.ones([1, 1]).cuda() * 1.0
226
+ sample['mouth_amp'] = torch.ones([1, 1]).cuda() * inp['mouth_amp']
227
+ elif inp['drv_audio_name'][-4:] in ['.mp4']:
228
+ drv_motion_coeff_dict = fit_3dmm_for_a_video(inp['drv_audio_name'], save=False)
229
+ drv_motion_coeff_dict = convert_to_tensor(drv_motion_coeff_dict)
230
+ t_x = drv_motion_coeff_dict['exp'].shape[0] * 2
231
+ self.drv_motion_coeff_dict = drv_motion_coeff_dict
232
+ elif inp['drv_audio_name'][-4:] in ['.npy']:
233
+ drv_motion_coeff_dict = np.load(inp['drv_audio_name'], allow_pickle=True).tolist()
234
+ drv_motion_coeff_dict = convert_to_tensor(drv_motion_coeff_dict)
235
+ t_x = drv_motion_coeff_dict['exp'].shape[0] * 2
236
+ self.drv_motion_coeff_dict = drv_motion_coeff_dict
237
+ else:
238
+ raise ValueError()
239
+
240
+ # Face Parsing
241
+ image_name = inp['src_image_name']
242
+ if image_name.endswith(".mp4"):
243
+ img = read_first_frame_from_a_video(image_name)
244
+ image_name = inp['src_image_name'] = image_name[:-4] + '.png'
245
+ cv2.imwrite(image_name, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
246
+ sample['ref_gt_img'] = load_img_to_normalized_512_bchw_tensor(image_name).cuda()
247
+ img = load_img_to_512_hwc_array(image_name)
248
+ segmap = self.seg_model._cal_seg_map(img)
249
+ sample['segmap'] = torch.tensor(segmap).float().unsqueeze(0).cuda()
250
+ head_img = self.seg_model._seg_out_img_with_segmap(img, segmap, mode='head')[0]
251
+ sample['ref_head_img'] = ((torch.tensor(head_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
252
+ inpaint_torso_img, _, _, _ = inpaint_torso_job(img, segmap)
253
+ sample['ref_torso_img'] = ((torch.tensor(inpaint_torso_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
254
+
255
+ if inp['bg_image_name'] == '':
256
+ bg_img = extract_background([img], [segmap], 'lama')
257
+ else:
258
+ bg_img = cv2.imread(inp['bg_image_name'])
259
+ bg_img = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB)
260
+ bg_img = cv2.resize(bg_img, (512,512))
261
+ sample['bg_img'] = ((torch.tensor(bg_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
262
+
263
+ # 3DMM, get identity code and camera pose
264
+ coeff_dict = fit_3dmm_for_a_image(image_name, save=False)
265
+ assert coeff_dict is not None
266
+ src_id = torch.tensor(coeff_dict['id']).reshape([1,80]).cuda()
267
+ src_exp = torch.tensor(coeff_dict['exp']).reshape([1,64]).cuda()
268
+ src_euler = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda()
269
+ src_trans = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda()
270
+ sample['id'] = src_id.repeat([t_x//2,1])
271
+
272
+ # get the src_kp for torso model
273
+ src_kp = self.face3d_helper.reconstruct_lm2d(src_id, src_exp, src_euler, src_trans) # [1, 68, 2]
274
+ src_kp = (src_kp-0.5) / 0.5 # rescale to -1~1
275
+ sample['src_kp'] = torch.clamp(src_kp, -1, 1).repeat([t_x//2,1,1])
276
+
277
+ # get camera pose file
278
+ # random.seed(time.time())
279
+ if inp['drv_pose_name'] in ['nearest', 'topk']:
280
+ camera_ret = get_eg3d_convention_camera_pose_intrinsic({'euler': torch.tensor(coeff_dict['euler']).reshape([1,3]), 'trans': torch.tensor(coeff_dict['trans']).reshape([1,3])})
281
+ c2w, intrinsics = camera_ret['c2w'], camera_ret['intrinsics']
282
+ camera = np.concatenate([c2w.reshape([1,16]), intrinsics.reshape([1,9])], axis=-1)
283
+ coeff_names, distance_matrix = self.camera_selector.find_k_nearest(camera, k=100)
284
+ coeff_names = coeff_names[0] # squeeze
285
+ if inp['drv_pose_name'] == 'nearest':
286
+ inp['drv_pose_name'] = coeff_names[0]
287
+ else:
288
+ inp['drv_pose_name'] = random.choice(coeff_names)
289
+ # inp['drv_pose_name'] = coeff_names[0]
290
+ elif inp['drv_pose_name'] == 'random':
291
+ inp['drv_pose_name'] = self.camera_selector.random_select()
292
+ else:
293
+ inp['drv_pose_name'] = inp['drv_pose_name']
294
+
295
+ print(f"| To extract pose from {inp['drv_pose_name']}")
296
+
297
+ # extract camera pose
298
+ if inp['drv_pose_name'] == 'static':
299
+ sample['euler'] = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda().repeat([t_x//2,1]) # default static pose
300
+ sample['trans'] = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda().repeat([t_x//2,1])
301
+ else: # from file
302
+ if inp['drv_pose_name'].endswith('.mp4'):
303
+ # extract coeff from video
304
+ drv_pose_coeff_dict = fit_3dmm_for_a_video(inp['drv_pose_name'], save=False)
305
+ else:
306
+ # load from npy
307
+ drv_pose_coeff_dict = np.load(inp['drv_pose_name'], allow_pickle=True).tolist()
308
+ print(f"| Extracted pose from {inp['drv_pose_name']}")
309
+ eulers = convert_to_tensor(drv_pose_coeff_dict['euler']).reshape([-1,3]).cuda()
310
+ trans = convert_to_tensor(drv_pose_coeff_dict['trans']).reshape([-1,3]).cuda()
311
+ len_pose = len(eulers)
312
+ index_lst = [mirror_index(i, len_pose) for i in range(t_x//2)]
313
+ sample['euler'] = eulers[index_lst]
314
+ sample['trans'] = trans[index_lst]
315
+
316
+ # fix the z axis
317
+ sample['trans'][:, -1] = sample['trans'][0:1, -1].repeat([sample['trans'].shape[0]])
318
+
319
+ # mapping to the init pose
320
+ if inp.get("map_to_init_pose", 'False') == 'True':
321
+ diff_euler = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda() - sample['euler'][0:1]
322
+ sample['euler'] = sample['euler'] + diff_euler
323
+ diff_trans = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda() - sample['trans'][0:1]
324
+ sample['trans'] = sample['trans'] + diff_trans
325
+
326
+ # prepare camera
327
+ camera_ret = get_eg3d_convention_camera_pose_intrinsic({'euler':sample['euler'].cpu(), 'trans':sample['trans'].cpu()})
328
+ c2w, intrinsics = camera_ret['c2w'], camera_ret['intrinsics']
329
+ # smooth camera
330
+ camera_smo_ksize = 7
331
+ camera = np.concatenate([c2w.reshape([-1,16]), intrinsics.reshape([-1,9])], axis=-1)
332
+ camera = smooth_camera_sequence(camera, kernel_size=camera_smo_ksize) # [T, 25]
333
+ camera = torch.tensor(camera).cuda().float()
334
+ sample['camera'] = camera
335
+
336
+ return sample
337
+
338
+ @torch.no_grad()
339
+ def get_hubert(self, wav16k_name):
340
+ from data_gen.utils.process_audio.extract_hubert import get_hubert_from_16k_wav
341
+ hubert = get_hubert_from_16k_wav(wav16k_name).detach().numpy()
342
+ len_mel = hubert.shape[0]
343
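+ # pad the HuBERT features so the sequence length is a multiple of x_multiply (8)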
+ x_multiply = 8
344
+ if len_mel % x_multiply == 0:
345
+ num_to_pad = 0
346
+ else:
347
+ num_to_pad = x_multiply - len_mel % x_multiply
348
+ hubert = np.pad(hubert, pad_width=((0,num_to_pad), (0,0)))
349
+ return hubert
350
+
351
+ def get_mfcc(self, wav16k_name):
352
+ from utils.audio import librosa_wav2mfcc
353
+ hparams['fft_size'] = 1200
354
+ hparams['win_size'] = 1200
355
+ hparams['hop_size'] = 480
356
+ hparams['audio_num_mel_bins'] = 80
357
+ hparams['fmin'] = 80
358
+ hparams['fmax'] = 12000
359
+ hparams['audio_sample_rate'] = 24000
360
+ mfcc = librosa_wav2mfcc(wav16k_name,
361
+ fft_size=hparams['fft_size'],
362
+ hop_size=hparams['hop_size'],
363
+ win_length=hparams['win_size'],
364
+ num_mels=hparams['audio_num_mel_bins'],
365
+ fmin=hparams['fmin'],
366
+ fmax=hparams['fmax'],
367
+ sample_rate=hparams['audio_sample_rate'],
368
+ center=True)
369
+ mfcc = np.array(mfcc).reshape([-1, 13])
370
+ len_mel = mfcc.shape[0]
371
+ x_multiply = 8
372
+ if len_mel % x_multiply == 0:
373
+ num_to_pad = 0
374
+ else:
375
+ num_to_pad = x_multiply - len_mel % x_multiply
376
+ mfcc = np.pad(mfcc, pad_width=((0,num_to_pad), (0,0)))
377
+ return mfcc
378
+
379
+ @torch.no_grad()
380
+ def forward_audio2secc(self, batch, inp=None):
381
+ if inp['drv_audio_name'][-4:] in ['.wav', '.mp3']:
382
+ from inference.infer_utils import extract_audio_motion_from_ref_video
383
+ if self.use_icl_audio2motion:
384
+ self.audio2secc_model.empty_context() # make this function reloadable
385
+ if self.use_icl_audio2motion and inp['drv_talking_style_name'].endswith(".mp4"):
386
+ ref_exp, ref_hubert, ref_f0 = extract_audio_motion_from_ref_video(inp['drv_talking_style_name'])
387
+ self.audio2secc_model.add_sample_to_context(ref_exp, ref_hubert, ref_f0)
388
+ elif self.use_icl_audio2motion and inp['drv_talking_style_name'].endswith((".png",'.jpg')):
389
+ style_coeff_dict = fit_3dmm_for_a_image(inp['drv_talking_style_name'])
390
+ ref_exp = torch.tensor(style_coeff_dict['exp']).reshape([1,1,64]).cuda()
391
+ self.audio2secc_model.add_sample_to_context(ref_exp.repeat([1, 100, 1]), hubert=None, f0=None)
392
+ else:
393
+ print("| WARNING: Not assigned reference talking style, passing...")
394
+ # audio-to-exp
395
+ ret = {}
396
+ # pred = self.audio2secc_model.forward(batch, ret=ret,train=False, ,)
397
+ pred = self.audio2secc_model.forward(batch, ret=ret,train=False, temperature=inp['temperature'], denoising_steps=inp['denoising_steps'], cond_scale=inp['cfg_scale'])
398
+
399
+ print("| audio-to-motion finished")
400
+ if pred.shape[-1] == 144:
401
+ id = ret['pred'][0][:,:80]
402
+ exp = ret['pred'][0][:,80:]
403
+ else:
404
+ id = batch['id']
405
+ exp = ret['pred'][0]
406
+ if len(id) < len(exp): # happens when use ICL
407
+ id = torch.cat([id, id[0].unsqueeze(0).repeat([len(exp)-len(id),1])])
408
+ batch['id'] = id
409
+ batch['exp'] = exp
410
+ else:
411
+ drv_motion_coeff_dict = self.drv_motion_coeff_dict
412
+ batch['exp'] = torch.FloatTensor(drv_motion_coeff_dict['exp']).cuda()
413
+
414
+ batch['id'] = batch['id'][:-4]
415
+ batch['exp'] = batch['exp'][:-4]
416
+ batch['euler'] = batch['euler'][:-4]
417
+ batch['trans'] = batch['trans'][:-4]
418
+ batch = self.get_driving_motion(batch['id'], batch['exp'], batch['euler'], batch['trans'], batch, inp)
419
+ if self.use_icl_audio2motion:
420
+ self.audio2secc_model.empty_context()
421
+ return batch
422
+
423
+ @torch.no_grad()
424
+ def get_driving_motion(self, id, exp, euler, trans, batch, inp):
425
+ zero_eulers = torch.zeros([id.shape[0], 3]).to(id.device)
426
+ zero_trans = torch.zeros([id.shape[0], 3]).to(exp.device)
427
+ # render the secc given the id,exp
428
+ with torch.no_grad():
429
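+ # render the driving SECC maps in chunks of 50 frames (moving results to CPU) to keep GPU memory bounded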
+ chunk_size = 50
430
+ drv_secc_color_lst = []
431
+ num_iters = len(id)//chunk_size if len(id)%chunk_size == 0 else len(id)//chunk_size+1
432
+ for i in tqdm.trange(num_iters, desc="rendering drv secc"):
433
+ torch.cuda.empty_cache()
434
+ face_mask, drv_secc_color = self.secc_renderer(id[i*chunk_size:(i+1)*chunk_size], exp[i*chunk_size:(i+1)*chunk_size], zero_eulers[i*chunk_size:(i+1)*chunk_size], zero_trans[i*chunk_size:(i+1)*chunk_size])
435
+ drv_secc_color_lst.append(drv_secc_color.cpu())
436
+ drv_secc_colors = torch.cat(drv_secc_color_lst, dim=0)
437
+ _, src_secc_color = self.secc_renderer(id[0:1], exp[0:1], zero_eulers[0:1], zero_trans[0:1])
438
+ _, cano_secc_color = self.secc_renderer(id[0:1], exp[0:1]*0, zero_eulers[0:1], zero_trans[0:1])
439
+ batch['drv_secc'] = drv_secc_colors.cuda()
440
+ batch['src_secc'] = src_secc_color.cuda()
441
+ batch['cano_secc'] = cano_secc_color.cuda()
442
+
443
+ # blinking secc
444
+ if inp['blink_mode'] == 'period':
445
+ period = 5 # second
446
+
447
+ if inp['hold_eye_opened'] == 'True':
448
+ for i in tqdm.trange(len(drv_secc_colors),desc="opening eye for secc"):
449
+ batch['drv_secc'][i] = hold_eye_opened_for_secc(batch['drv_secc'][i])
450
+
451
+ for i in tqdm.trange(len(drv_secc_colors),desc="blinking secc"):
452
+ if i % (25*period) == 0:
453
+ blink_dur_frames = random.randint(8, 12)
454
+ for offset in range(blink_dur_frames):
455
+ j = offset + i
456
+ if j >= len(drv_secc_colors)-1: break
457
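+ # blink_percent_fn traces a parabola from 0 up to 1 at t = T/2 and back to 0 at t = T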
+ def blink_percent_fn(t, T):
458
+ return -4/T**2 * t**2 + 4/T * t
459
+ blink_percent = blink_percent_fn(offset, blink_dur_frames)
460
+ secc = batch['drv_secc'][j]
461
+ out_secc = blink_eye_for_secc(secc, blink_percent)
462
+ out_secc = out_secc.cuda()
463
+ batch['drv_secc'][j] = out_secc
464
+
465
+ # get the drv_kp for torso model, using the transformed trajectory
466
+ drv_kp = self.face3d_helper.reconstruct_lm2d(id, exp, euler, trans) # [T, 68, 2]
467
+
468
+ drv_kp = (drv_kp-0.5) / 0.5 # rescale to -1~1
469
+ batch['drv_kp'] = torch.clamp(drv_kp, -1, 1)
470
+ return batch
471
+
472
+ @torch.no_grad()
473
+ def forward_secc2video(self, batch, inp=None):
474
+ num_frames = len(batch['drv_secc'])
475
+ camera = batch['camera']
476
+ src_kps = batch['src_kp']
477
+ drv_kps = batch['drv_kp']
478
+ cano_secc_color = batch['cano_secc']
479
+ src_secc_color = batch['src_secc']
480
+ drv_secc_colors = batch['drv_secc']
481
+ ref_img_gt = batch['ref_gt_img']
482
+ ref_img_head = batch['ref_head_img']
483
+ ref_torso_img = batch['ref_torso_img']
484
+ bg_img = batch['bg_img']
485
+ segmap = batch['segmap']
486
+
487
+ # smooth torso drv_kp
488
+ torso_smo_ksize = 7
489
+ drv_kps = smooth_features_xd(drv_kps.reshape([-1, 68*2]), kernel_size=torso_smo_ksize).reshape([-1, 68, 2])
490
+
491
+ # forward renderer
492
+ img_raw_lst = []
493
+ img_lst = []
494
+ depth_img_lst = []
495
+ with torch.no_grad():
496
+ with torch.cuda.amp.autocast(inp['fp16']):
497
+ for i in tqdm.trange(num_frames, desc="Real3D-Portrait is rendering frames"):
498
+ kp_src = torch.cat([src_kps[i:i+1].reshape([1, 68, 2]), torch.zeros([1, 68,1]).to(src_kps.device)],dim=-1)
499
+ kp_drv = torch.cat([drv_kps[i:i+1].reshape([1, 68, 2]), torch.zeros([1, 68,1]).to(drv_kps.device)],dim=-1)
500
+ cond={'cond_cano': cano_secc_color,'cond_src': src_secc_color, 'cond_tgt': drv_secc_colors[i:i+1].cuda(),
501
+ 'ref_torso_img': ref_torso_img, 'bg_img': bg_img, 'segmap': segmap,
502
+ 'kp_s': kp_src, 'kp_d': kp_drv,
503
+ 'ref_cameras': camera[i:i+1],
504
+ }
505
+ if i == 0:
506
+ gen_output = self.secc2video_model.forward(img=ref_img_head, camera=camera[i:i+1], cond=cond, ret={}, cache_backbone=True, use_cached_backbone=False)
507
+ else:
508
+ gen_output = self.secc2video_model.forward(img=ref_img_head, camera=camera[i:i+1], cond=cond, ret={}, cache_backbone=False, use_cached_backbone=True)
509
+ img_lst.append(gen_output['image'])
510
+ img_raw_lst.append(gen_output['image_raw'])
511
+ depth_img_lst.append(gen_output['image_depth'])
512
+
513
+ # save demo video
514
+ depth_imgs = torch.cat(depth_img_lst)
515
+ imgs = torch.cat(img_lst)
516
+ imgs_raw = torch.cat(img_raw_lst)
517
+ secc_img = torch.cat([torch.nn.functional.interpolate(drv_secc_colors[i:i+1], (512,512)) for i in range(num_frames)])
518
+
519
+ if inp['out_mode'] == 'concat_debug':
520
+ secc_img = secc_img.cpu()
521
+ secc_img = ((secc_img + 1) * 127.5).permute(0, 2, 3, 1).int().numpy()
522
+
523
+ depth_img = F.interpolate(depth_imgs, (512,512)).cpu()
524
+ depth_img = depth_img.repeat([1,3,1,1])
525
+ depth_img = (depth_img - depth_img.min()) / (depth_img.max() - depth_img.min())
526
+ depth_img = depth_img * 2 - 1
527
+ depth_img = depth_img.clamp(-1,1)
528
+
529
+ secc_img = secc_img / 127.5 - 1
530
+ secc_img = torch.from_numpy(secc_img).permute(0, 3, 1, 2)
531
+ imgs = torch.cat([ref_img_gt.repeat([imgs.shape[0],1,1,1]).cpu(), secc_img, F.interpolate(imgs_raw, (512,512)).cpu(), depth_img, imgs.cpu()], dim=-1)
532
+ elif inp['out_mode'] == 'final':
533
+ imgs = imgs.cpu()
534
+ elif inp['out_mode'] == 'debug':
535
+ raise NotImplementedError("to do: save separate videos")
536
+ imgs = imgs.clamp(-1,1)
537
+
538
+ import imageio
539
+ debug_name = 'demo.mp4'
540
+ out_imgs = ((imgs.permute(0, 2, 3, 1) + 1)/2 * 255).int().cpu().numpy().astype(np.uint8)
541
+ writer = imageio.get_writer(debug_name, fps=25, format='FFMPEG', codec='h264')
542
+
543
+ for i in tqdm.trange(len(out_imgs), desc="Imageio is saving video"):
544
+ writer.append_data(out_imgs[i])
545
+ writer.close()
546
+
547
+ out_fname = 'infer_out/tmp/' + os.path.basename(inp['src_image_name'])[:-4] + '_' + os.path.basename(inp['drv_pose_name'])[:-4] + '.mp4' if inp['out_name'] == '' else inp['out_name']
548
+ try:
549
+ os.makedirs(os.path.dirname(out_fname), exist_ok=True)
550
+ except: pass
551
+ if inp['drv_audio_name'][-4:] in ['.wav', '.mp3']:
552
+ # cmd = f"ffmpeg -i {debug_name} -i {self.wav16k_name} -y -shortest {out_fname}"
553
+ cmd = f"ffmpeg -i {debug_name} -i {self.wav16k_name} -y -v quiet -shortest {out_fname}"
554
+ print(cmd)
555
+ os.system(cmd)
556
+ os.system(f"rm {debug_name}")
557
+ os.system(f"rm {self.wav16k_name}")
558
+ else:
559
+ ret = os.system(f"ffmpeg -i {debug_name} -i {inp['drv_audio_name']} -map 0:v -map 1:a -y -v quiet -shortest {out_fname}")
560
+ if ret != 0: # failed to extract audio from drv_audio_name, so output the video without an audio track
561
+ os.system(f"mv {debug_name} {out_fname}")
562
+ print(f"Saved at {out_fname}")
563
+ return out_fname
564
+
565
+ @torch.no_grad()
566
+ def forward_system(self, batch, inp):
567
+ self.forward_audio2secc(batch, inp)
568
+ out_fname = self.forward_secc2video(batch, inp)
569
+ return out_fname
570
+
571
+ @classmethod
572
+ def example_run(cls, inp=None):
573
+ inp_tmp = {
574
+ 'drv_audio_name': 'data/raw/val_wavs/zozo.wav',
575
+ 'src_image_name': 'data/raw/val_imgs/Macron.png'
576
+ }
577
+ if inp is not None:
578
+ inp_tmp.update(inp)
579
+ inp = inp_tmp
580
+
581
+ infer_instance = cls(inp['a2m_ckpt'], inp['head_ckpt'], inp['torso_ckpt'], inp=inp)
582
+ infer_instance.infer_once(inp)
583
+
584
+ ##############
585
+ # IO-related
586
+ ##############
587
+ def save_wav16k(self, audio_name):
588
+ supported_types = ('.wav', '.mp3', '.mp4', '.avi')
589
+ assert audio_name.endswith(supported_types), f"Only {','.join(supported_types)} are currently supported as the audio source!"
590
+ import uuid
591
+ wav16k_name = audio_name[:-4] + f'{uuid.uuid1()}_16k.wav'
592
+ self.wav16k_name = wav16k_name
593
+ extract_wav_cmd = f"ffmpeg -i {audio_name} -f wav -ar 16000 -v quiet -y {wav16k_name} -y"
594
+ # extract_wav_cmd = f"ffmpeg -i {audio_name} -f wav -ar 16000 -y {wav16k_name} -y"
595
+ print(extract_wav_cmd)
596
+ os.system(extract_wav_cmd)
597
+ print(f"Extracted wav file (16khz) from {audio_name} to {wav16k_name}.")
598
+
599
+ def get_f0(self, wav16k_name):
600
+ from data_gen.utils.process_audio.extract_mel_f0 import extract_mel_from_fname, extract_f0_from_wav_and_mel
601
+ wav, mel = extract_mel_from_fname(self.wav16k_name)
602
+ f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)
603
+ f0 = f0.reshape([-1,1])
604
+ return f0
605
+
606
+ if __name__ == '__main__':
607
+ import argparse, glob, tqdm
608
+ parser = argparse.ArgumentParser()
609
+ # parser.add_argument("--a2m_ckpt", default='checkpoints/240112_audio2secc/icl_audio2secc_vox2_cmlr') # checkpoints/0727_audio2secc/audio2secc_withlm2d100_randomframe
610
+ parser.add_argument("--a2m_ckpt", default='checkpoints/240126_real3dportrait_orig/audio2secc_vae') # checkpoints/0727_audio2secc/audio2secc_withlm2d100_randomframe
611
+ parser.add_argument("--head_ckpt", default='') # checkpoints/0729_th1kh/secc_img2plane checkpoints/0720_img2planes/secc_img2plane_two_stage
612
+ # parser.add_argument("--head_ckpt", default='checkpoints/240210_os_secc2plane/secc2plane_trigridv2_blink0.3_pertubeNone') # checkpoints/0729_th1kh/secc_img2plane checkpoints/0720_img2planes/secc_img2plane_two_stage
613
+ # parser.add_argument("--torso_ckpt", default='')
614
+ # parser.add_argument("--torso_ckpt", default='checkpoints/240209_robust_secc2plane_torso/secc2plane_torso_orig_fuseV1_MulMaskFalse')
615
+ parser.add_argument("--torso_ckpt", default='checkpoints/240211_robust_secc2plane_torso/secc2plane_torso_orig_fuseV3_MulMaskTrue')
616
+ # parser.add_argument("--torso_ckpt", default='checkpoints/240209_robust_secc2plane_torso/secc2plane_torso_orig_fuseV2_MulMaskTrue')
617
+ # parser.add_argument("--src_img", default='data/raw/val_imgs/Macron_img.png')
618
+ # parser.add_argument("--src_img", default='gf2_iclr_test_data/cross_imgs/Trump.png')
619
+ parser.add_argument("--src_img", default='data/raw/val_imgs/mercy.png')
620
+ parser.add_argument("--bg_img", default='') # data/raw/val_imgs/bg3.png
621
+ parser.add_argument("--drv_aud", default='data/raw/val_wavs/yiwise.wav')
622
+ parser.add_argument("--drv_pose", default='infer_out/240319_tta/trump.mp4') # nearest | topk | random | static | vid_name
623
+ # parser.add_argument("--drv_pose", default='nearest') # nearest | topk | random | static | vid_name
624
+ parser.add_argument("--drv_style", default='') # nearest | topk | random | static | vid_name
625
+ # parser.add_argument("--drv_style", default='infer_out/240319_tta/trump.mp4') # nearest | topk | random | static | vid_name
626
+ parser.add_argument("--blink_mode", default='period') # none | period
627
+ parser.add_argument("--temperature", default=0.2, type=float) # nearest | random
628
+ parser.add_argument("--denoising_steps", default=20, type=int) # nearest | random
629
+ parser.add_argument("--cfg_scale", default=2.5, type=float) # nearest | random
630
+ parser.add_argument("--mouth_amp", default=0.4, type=float) # scale of predicted mouth, enabled in audio-driven
631
+ parser.add_argument("--min_face_area_percent", default=0.2, type=float) # scale of predicted mouth, enabled in audio-driven
632
+ parser.add_argument("--head_torso_threshold", default=0.5, type=float, help="0.1~1.0, 如果发现头发有半透明的现象,调小该值,以将小weights的头发直接clamp到weights=1.0; 如果发现头外部有荧光色的虚影,调小这个值. 对不同超参的Nerf也是case-to-case")
633
+ # parser.add_argument("--head_torso_threshold", default=None, type=float, help="0.1~1.0, 如果发现头发有半透明的现象,调小该值,以将小weights的头发直接clamp到weights=1.0; 如果发现头外部有荧光色的虚影,调小这个值. 对不同超参的Nerf也是case-to-case")
634
+ parser.add_argument("--out_name", default='') # nearest | random
635
+ parser.add_argument("--out_mode", default='concat_debug') # concat_debug | debug | final
636
+ parser.add_argument("--hold_eye_opened", default='False') # concat_debug | debug | final
637
+ parser.add_argument("--map_to_init_pose", default='True') # concat_debug | debug | final
638
+ parser.add_argument("--seed", default=None, type=int) # random seed, default None to use time.time()
639
+ parser.add_argument("--fp16", action='store_true')
640
+
641
+ args = parser.parse_args()
642
+
643
+ inp = {
644
+ 'a2m_ckpt': args.a2m_ckpt,
645
+ 'head_ckpt': args.head_ckpt,
646
+ 'torso_ckpt': args.torso_ckpt,
647
+ 'src_image_name': args.src_img,
648
+ 'bg_image_name': args.bg_img,
649
+ 'drv_audio_name': args.drv_aud,
650
+ 'drv_pose_name': args.drv_pose,
651
+ 'drv_talking_style_name': args.drv_style,
652
+ 'blink_mode': args.blink_mode,
653
+ 'temperature': args.temperature,
654
+ 'mouth_amp': args.mouth_amp,
655
+ 'out_name': args.out_name,
656
+ 'out_mode': args.out_mode,
657
+ 'map_to_init_pose': args.map_to_init_pose,
658
+ 'hold_eye_opened': args.hold_eye_opened,
659
+ 'head_torso_threshold': args.head_torso_threshold,
660
+ 'min_face_area_percent': args.min_face_area_percent,
661
+ 'denoising_steps': args.denoising_steps,
662
+ 'cfg_scale': args.cfg_scale,
663
+ 'seed': args.seed,
664
+ 'fp16': args.fp16, # with the current ckpt, fp16 causes NaNs; we found a single NaN produced by the layernorm of the i2p model is the cause, and training with fp16 as well might fix it
665
+ }
666
+
667
+ GeneFace2Infer.example_run(inp)
inference/real3dportrait_demo.ipynb ADDED
@@ -0,0 +1,287 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "colab_type": "text",
7
+ "id": "view-in-github"
8
+ },
9
+ "source": [
10
+ "<a href=\"https://colab.research.google.com/github/yerfor/Real3DPortrait/blob/main/real3dportrait_demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "metadata": {
16
+ "id": "QS04K9oO21AW"
17
+ },
18
+ "source": [
19
+ "Check GPU"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": null,
25
+ "metadata": {
26
+ "id": "1ESQRDb-yVUG"
27
+ },
28
+ "outputs": [],
29
+ "source": [
30
+ "!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "markdown",
35
+ "metadata": {
36
+ "id": "y-ctmIvu3Ei8"
37
+ },
38
+ "source": [
39
+ "Installation"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": null,
45
+ "metadata": {
46
+ "id": "gXu76wdDgaxo"
47
+ },
48
+ "outputs": [],
49
+ "source": [
50
+ "# install pytorch3d, about 15s\n",
51
+ "import os\n",
52
+ "import sys\n",
53
+ "import torch\n",
54
+ "need_pytorch3d=False\n",
55
+ "try:\n",
56
+ " import pytorch3d\n",
57
+ "except ModuleNotFoundError:\n",
58
+ " need_pytorch3d=True\n",
59
+ "if need_pytorch3d:\n",
60
+ " if torch.__version__.startswith(\"2.1.\") and sys.platform.startswith(\"linux\"):\n",
61
+ " # We try to install PyTorch3D via a released wheel.\n",
62
+ " pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
63
+ " version_str=\"\".join([\n",
64
+ " f\"py3{sys.version_info.minor}_cu\",\n",
65
+ " torch.version.cuda.replace(\".\",\"\"),\n",
66
+ " f\"_pyt{pyt_version_str}\"\n",
67
+ " ])\n",
68
+ " !pip install fvcore iopath\n",
69
+ " !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
70
+ " else:\n",
71
+ " # We try to install PyTorch3D from source.\n",
72
+ " !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": null,
78
+ "metadata": {
79
+ "id": "DuUynxmotG_-"
80
+ },
81
+ "outputs": [],
82
+ "source": [
83
+ "# install dependencies, about 5~10 min\n",
84
+ "!pip install tensorboard==2.13.0 tensorboardX==2.6.1\n",
85
+ "!pip install pyspy==0.1.1\n",
86
+ "!pip install protobuf==3.20.3\n",
87
+ "!pip install scipy==1.9.1\n",
88
+ "!pip install kornia==0.5.0\n",
89
+ "!pip install trimesh==3.22.0\n",
90
+ "!pip install einops==0.6.1 torchshow==0.5.1\n",
91
+ "!pip install imageio==2.31.1 imageio-ffmpeg==0.4.8\n",
92
+ "!pip install scikit-learn==1.3.0 scikit-image==0.21.0\n",
93
+ "!pip install av==10.0.0 lpips==0.1.4\n",
94
+ "!pip install timm==0.9.2 librosa==0.9.2\n",
95
+ "!pip install openmim==0.3.9\n",
96
+ "!mim install mmcv==2.1.0 # use mim to speed up installation for mmcv\n",
97
+ "!pip install transformers==4.33.2\n",
98
+ "!pip install pretrainedmodels==0.7.4\n",
99
+ "!pip install ninja==1.11.1\n",
100
+ "!pip install faiss-cpu==1.7.4\n",
101
+ "!pip install praat-parselmouth==0.4.3 moviepy==1.0.3\n",
102
+ "!pip install mediapipe==0.10.7\n",
103
+ "!pip install --upgrade attr\n",
104
+ "!pip install beartype==0.16.4 gateloop_transformer==0.4.0\n",
105
+ "!pip install torchode==0.2.0 torchdiffeq==0.2.3\n",
106
+ "!pip install hydra-core==1.3.2 pandas==2.1.3\n",
107
+ "!pip install pytorch_lightning==2.1.2\n",
108
+ "!pip install httpx==0.23.3\n",
109
+ "!pip install gradio==4.16.0\n",
110
+ "!pip install gdown\n",
111
+ "!pip install pyloudnorm webrtcvad pyworld==0.2.1rc0 pypinyin==0.42.0"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": null,
117
+ "metadata": {
118
+ "id": "0GLEV0HVu8rj"
119
+ },
120
+ "outputs": [],
121
+ "source": [
122
+ "# RESTART kernel to make sure runtime is correct if you meet runtime errors\n",
123
+ "# import os\n",
124
+ "# os.kill(os.getpid(), 9)"
125
+ ]
126
+ },
127
+ {
128
+ "cell_type": "markdown",
129
+ "metadata": {
130
+ "id": "5UfKHKrH6kcq"
131
+ },
132
+ "source": [
133
+ "Clone code and download checkpoints"
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "execution_count": null,
139
+ "metadata": {
140
+ "id": "-gfRsd9DwIgl"
141
+ },
142
+ "outputs": [],
143
+ "source": [
144
+ "# clone Real3DPortrait repo from github\n",
145
+ "!git clone https://github.com/yerfor/Real3DPortrait\n",
146
+ "%cd Real3DPortrait"
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "execution_count": null,
152
+ "metadata": {
153
+ "id": "Yju8dQY7x5OS"
154
+ },
155
+ "outputs": [],
156
+ "source": [
157
+ "# download pretrained ckpts & third-party ckpts from google drive, about 1 min\n",
158
+ "!pip install --upgrade --no-cache-dir gdown\n",
159
+ "%cd deep_3drecon/BFM\n",
160
+ "!gdown https://drive.google.com/uc?id=1SPM3IHsyNAaVMwqZZGV6QVaV7I2Hly0v\n",
161
+ "!gdown https://drive.google.com/uc?id=1MSldX9UChKEb3AXLVTPzZQcsbGD4VmGF\n",
162
+ "!gdown https://drive.google.com/uc?id=180ciTvm16peWrcpl4DOekT9eUQ-lJfMU\n",
163
+ "!gdown https://drive.google.com/uc?id=1KX9MyGueFB3M-X0Ss152x_johyTXHTfU\n",
164
+ "!gdown https://drive.google.com/uc?id=19-NyZn_I0_mkF-F5GPyFMwQJ_-WecZIL\n",
165
+ "!gdown https://drive.google.com/uc?id=11ouQ7Wr2I-JKStp2Fd1afedmWeuifhof\n",
166
+ "!gdown https://drive.google.com/uc?id=18ICIvQoKX-7feYWP61RbpppzDuYTptCq\n",
167
+ "!gdown https://drive.google.com/uc?id=1VktuY46m0v_n_d4nvOupauJkK4LF6mHE\n",
168
+ "%cd ../..\n",
169
+ "\n",
170
+ "%cd checkpoints\n",
171
+ "!gdown https://drive.google.com/uc?id=1gz8A6xestHp__GbZT5qozb43YaybRJhZ\n",
172
+ "!gdown https://drive.google.com/uc?id=1gSUIw2AkkKnlLJnNfS2FCqtaVw9tw3QF\n",
173
+ "!unzip 240210_real3dportrait_orig.zip\n",
174
+ "!unzip pretrained_ckpts.zip\n",
175
+ "!ls\n",
176
+ "%cd ..\n"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "markdown",
181
+ "metadata": {
182
+ "id": "LHzLro206pnA"
183
+ },
184
+ "source": [
185
+ "Inference sample"
186
+ ]
187
+ },
188
+ {
189
+ "cell_type": "markdown",
190
+ "metadata": {},
191
+ "source": [
192
+ "!python inference/real3d_infer.py -h"
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "code",
197
+ "execution_count": null,
198
+ "metadata": {
199
+ "id": "2aCDwxNivQoS"
200
+ },
201
+ "outputs": [],
202
+ "source": [
203
+ "# sample inference, about 3 min\n",
204
+ "!python inference/real3d_infer.py \\\n",
205
+ "--src_img data/raw/examples/Macron.png \\\n",
206
+ "--drv_aud data/raw/examples/Obama_5s.wav \\\n",
207
+ "--drv_pose data/raw/examples/May_5s.mp4 \\\n",
208
+ "--bg_img data/raw/examples/bg.png \\\n",
209
+ "--out_name output.mp4 \\\n",
210
+ "--out_mode concat_debug \\\n",
211
+ "--low_memory_usage"
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "markdown",
216
+ "metadata": {
217
+ "id": "XL0c54l19mBG"
218
+ },
219
+ "source": [
220
+ "Display output video"
221
+ ]
222
+ },
223
+ {
224
+ "cell_type": "code",
225
+ "execution_count": null,
226
+ "metadata": {
227
+ "id": "6olmWwZP9Icj"
228
+ },
229
+ "outputs": [],
230
+ "source": [
231
+ "# borrow code from makeittalk\n",
232
+ "from IPython.display import HTML\n",
233
+ "from base64 import b64encode\n",
234
+ "import os, sys\n",
235
+ "import glob\n",
236
+ "\n",
237
+ "mp4_name = './output.mp4'\n",
238
+ "\n",
239
+ "mp4 = open('{}'.format(mp4_name),'rb').read()\n",
240
+ "data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
241
+ "\n",
242
+ "print('Display animation: {}'.format(mp4_name), file=sys.stderr)\n",
243
+ "display(HTML(\"\"\"\n",
244
+ " <video width=256 controls>\n",
245
+ " <source src=\"%s\" type=\"video/mp4\">\n",
246
+ " </video>\n",
247
+ " \"\"\" % data_url))"
248
+ ]
249
+ },
250
+ {
251
+ "cell_type": "markdown",
252
+ "metadata": {},
253
+ "source": [
254
+ "WebUI"
255
+ ]
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": null,
260
+ "metadata": {},
261
+ "outputs": [],
262
+ "source": [
263
+ "\n",
264
+ "!python inference/app_real3dportrait.py --share"
265
+ ]
266
+ }
267
+ ],
268
+ "metadata": {
269
+ "accelerator": "GPU",
270
+ "colab": {
271
+ "authorship_tag": "ABX9TyPu++zOlOS4yKF4xn4FHGtZ",
272
+ "gpuType": "T4",
273
+ "include_colab_link": true,
274
+ "private_outputs": true,
275
+ "provenance": []
276
+ },
277
+ "kernelspec": {
278
+ "display_name": "Python 3",
279
+ "name": "python3"
280
+ },
281
+ "language_info": {
282
+ "name": "python"
283
+ }
284
+ },
285
+ "nbformat": 4,
286
+ "nbformat_minor": 0
287
+ }
inference/train_mimictalk_on_a_video.py ADDED
@@ -0,0 +1,608 @@
1
+ """
2
+ Overfit the one-shot talking-face foundation model (os_secc2plane or os_secc2plane_torso) to a single speaker (one photo or one video), achieving an effect similar to GeneFace++.
3
+ """
4
+ import os
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import librosa
9
+ import random
10
+ import time
11
+ import numpy as np
12
+ import importlib
13
+ import tqdm
14
+ import copy
15
+ import cv2
16
+ import glob
17
+ import imageio
18
+ # common utils
19
+ from utils.commons.hparams import hparams, set_hparams
20
+ from utils.commons.tensor_utils import move_to_cuda, convert_to_tensor
21
+ from utils.commons.ckpt_utils import load_ckpt, get_last_checkpoint
22
+ # 3DMM-related utils
23
+ from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
24
+ from data_util.face3d_helper import Face3DHelper
25
+ from data_gen.utils.process_image.fit_3dmm_landmark import fit_3dmm_for_a_image
26
+ from data_gen.utils.process_video.fit_3dmm_landmark import fit_3dmm_for_a_video
27
+ from data_gen.utils.process_video.extract_segment_imgs import decode_segmap_mask_from_image
28
+ from deep_3drecon.secc_renderer import SECC_Renderer
29
+ from data_gen.eg3d.convert_to_eg3d_convention import get_eg3d_convention_camera_pose_intrinsic
30
+ from data_gen.runs.binarizer_nerf import get_lip_rect
31
+ # Face Parsing
32
+ from data_gen.utils.mp_feature_extractors.mp_segmenter import MediapipeSegmenter
33
+ from data_gen.utils.process_video.extract_segment_imgs import inpaint_torso_job, extract_background
34
+ # other inference utils
35
+ from inference.infer_utils import mirror_index, load_img_to_512_hwc_array, load_img_to_normalized_512_bchw_tensor
36
+ from inference.infer_utils import smooth_camera_sequence, smooth_features_xd
37
+ from inference.edit_secc import blink_eye_for_secc, hold_eye_opened_for_secc
38
+ from modules.commons.loralib.utils import mark_only_lora_as_trainable
39
+ from utils.nn.model_utils import num_params
40
+ import lpips
41
+ from utils.commons.meters import AvgrageMeter
42
+ meter = AvgrageMeter()
43
+ from torch.utils.tensorboard import SummaryWriter
44
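+ # LoRATrainer adapts the pretrained one-shot secc2video model to a single identity:
+ # (1) only LoRA adapters are kept trainable, (2) a per-identity canonical triplane is
+ # learned (initialized from the model's prediction on frame 0), and (3) reconstruction,
+ # lip, blink-regularization and triplane/SECC regularization losses are optimized in training_loop below.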
+ class LoRATrainer(nn.Module):
45
+ def __init__(self, inp):
46
+ super().__init__()
47
+ self.inp = inp
48
+ self.lora_args = {'lora_mode': inp['lora_mode'], 'lora_r': inp['lora_r']}
49
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
50
+ head_model_dir = inp['head_ckpt']
51
+ torso_model_dir = inp['torso_ckpt']
52
+ model_dir = torso_model_dir if torso_model_dir != '' else head_model_dir
53
+ cmd = f"cp {os.path.join(model_dir, 'config.yaml')} {self.inp['work_dir']}"
54
+ print(cmd)
55
+ os.system(cmd)
56
+ with open(os.path.join(self.inp['work_dir'], 'config.yaml'), "a") as f:
57
+ f.write(f"\nlora_r: {inp['lora_r']}")
58
+ f.write(f"\nlora_mode: {inp['lora_mode']}")
59
+ f.write(f"\n")
60
+ self.secc2video_model = self.load_secc2video(model_dir)
61
+ self.secc2video_model.to(device).eval()
62
+ self.seg_model = MediapipeSegmenter()
63
+ self.secc_renderer = SECC_Renderer(512)
64
+ self.face3d_helper = Face3DHelper(use_gpu=True, keypoint_mode='lm68')
65
+ self.mp_face3d_helper = Face3DHelper(use_gpu=True, keypoint_mode='mediapipe')
66
+ # self.camera_selector = KNearestCameraSelector()
67
+ self.load_training_data(inp)
68
+ def load_secc2video(self, model_dir):
69
+ inp = self.inp
70
+ from modules.real3d.secc_img2plane_torso import OSAvatarSECC_Img2plane, OSAvatarSECC_Img2plane_Torso
71
+ hp = set_hparams(f"{model_dir}/config.yaml", print_hparams=False, global_hparams=True)
72
+ hp['htbsr_head_threshold'] = 1.0
73
+ self.neural_rendering_resolution = hp['neural_rendering_resolution']
74
+ if 'torso' in hp['task_cls'].lower():
75
+ self.torso_mode = True
76
+ model = OSAvatarSECC_Img2plane_Torso(hp=hp, lora_args=self.lora_args)
77
+ else:
78
+ self.torso_mode = False
79
+ model = OSAvatarSECC_Img2plane(hp=hp, lora_args=self.lora_args)
80
+ mark_only_lora_as_trainable(model, bias='none')  # freeze all base weights; only LoRA parameters remain trainable
81
+ lora_ckpt_path = os.path.join(inp['work_dir'], 'checkpoint.ckpt')
82
+ if os.path.exists(lora_ckpt_path):
83
+ self.learnable_triplane = nn.Parameter(torch.zeros([1, 3, model.triplane_hid_dim*model.triplane_depth, 256, 256]).float().cuda(), requires_grad=True)
84
+ model._last_cano_planes = self.learnable_triplane
85
+ load_ckpt(model, lora_ckpt_path, model_name='model', strict=False)
86
+ else:
87
+ load_ckpt(model, f"{model_dir}", model_name='model', strict=False)
88
+
89
+ num_params(model)
90
+ self.model = model
91
+ return model
92
+ def load_training_data(self, inp):
93
+ video_id = inp['video_id']
94
+ if video_id.endswith((".mp4", ".png", ".jpg", ".jpeg")):
95
+ # If input video is not GeneFace training videos, convert it into GeneFace convention
96
+ video_id_ = video_id
97
+ video_id = os.path.basename(video_id)[:-4]
98
+ inp['video_id'] = video_id
99
+ target_video_path = f'data/raw/videos/{video_id}.mp4'
100
+ if not os.path.exists(target_video_path):
101
+ print(f"| Copying video to {target_video_path}")
102
+ os.makedirs(os.path.dirname(target_video_path), exist_ok=True)
103
+ cmd = f"ffmpeg -i {video_id_} -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -y {target_video_path}"
104
+ print(f"| {cmd}")
105
+ os.system(cmd)
106
+ target_video_path = f'data/raw/videos/{video_id}.mp4'
107
+ print(f"| Copy source video into work dir: {self.inp['work_dir']}")
108
+ os.system(f"cp {target_video_path} {self.inp['work_dir']}")
109
+ # check head_img path
110
+ head_img_pattern = f'data/processed/videos/{video_id}/head_imgs/*.png'
111
+ head_img_names = sorted(glob.glob(head_img_pattern))
112
+ if len(head_img_names) == 0:
113
+ # extract head_imgs
114
+ head_img_dir = os.path.dirname(head_img_pattern)
115
+ print(f"| Pre-extracted head_imgs not found, try to extract and save to {head_img_dir}, this may take a while...")
116
+ gt_img_dir = f"data/processed/videos/{video_id}/gt_imgs"
117
+ os.makedirs(gt_img_dir, exist_ok=True)
118
+ target_video_path = f'data/raw/videos/{video_id}.mp4'
119
+ cmd = f"ffmpeg -i {target_video_path} -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -start_number 0 -y {gt_img_dir}/%08d.jpg"
120
+ print(f"| {cmd}")
121
+ os.system(cmd)
122
+ # extract image, segmap, and background
123
+ cmd = f"python data_gen/utils/process_video/extract_segment_imgs.py --ds_name=nerf --vid_dir={target_video_path}"
124
+ print(f"| {cmd}")
125
+ os.system(cmd)
126
+ print("| Head images Extracted!")
127
+ num_samples = len(head_img_names)
128
+ npy_name = f"data/processed/videos/{video_id}/coeff_fit_mp_for_lora.npy"
129
+ if os.path.exists(npy_name):
130
+ coeff_dict = np.load(npy_name, allow_pickle=True).tolist()
131
+ else:
132
+ print(f"| Pre-extracted 3DMM coefficient not found, try to extract and save to {npy_name}, this may take a while...")
133
+ coeff_dict = fit_3dmm_for_a_video(f'data/raw/videos/{video_id}.mp4', save=False)
134
+ os.makedirs(os.path.dirname(npy_name), exist_ok=True)
135
+ np.save(npy_name, coeff_dict)
136
+ ids = convert_to_tensor(coeff_dict['id']).reshape([-1,80]).cuda()
137
+ exps = convert_to_tensor(coeff_dict['exp']).reshape([-1,64]).cuda()
138
+ eulers = convert_to_tensor(coeff_dict['euler']).reshape([-1,3]).cuda()
139
+ trans = convert_to_tensor(coeff_dict['trans']).reshape([-1,3]).cuda()
140
+ WH = 512 # now we only support 512x512
141
+ lm2ds = WH * self.face3d_helper.reconstruct_lm2d(ids, exps, eulers, trans).cpu().numpy()
142
+ lip_rects = [get_lip_rect(lm2ds[i], WH, WH) for i in range(len(lm2ds))]
143
+ kps = self.face3d_helper.reconstruct_lm2d(ids, exps, eulers, trans).cuda()
144
+ kps = (kps-0.5) / 0.5 # rescale to -1~1
145
+ kps = torch.cat([kps, torch.zeros([*kps.shape[:-1], 1]).cuda()], dim=-1)
146
+ camera_ret = get_eg3d_convention_camera_pose_intrinsic({'euler': torch.tensor(coeff_dict['euler']).reshape([-1,3]), 'trans': torch.tensor(coeff_dict['trans']).reshape([-1,3])})
147
+ c2w, intrinsics = camera_ret['c2w'], camera_ret['intrinsics']
148
+ cameras = torch.tensor(np.concatenate([c2w.reshape([-1,16]), intrinsics.reshape([-1,9])], axis=-1)).cuda()
149
+ camera_smo_ksize = 7
150
+ cameras = smooth_camera_sequence(cameras.cpu().numpy(), kernel_size=camera_smo_ksize) # [T, 25]
151
+ cameras = torch.tensor(cameras).cuda()
152
+ zero_eulers = eulers * 0
153
+ zero_trans = trans * 0
154
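+ # canonical SECC uses a neutral (zeroed) expression; source SECC uses the first frame's expression (both rendered with zero pose)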
+ _, cano_secc_color = self.secc_renderer(ids[0:1], exps[0:1]*0, zero_eulers[0:1], zero_trans[0:1])
155
+ src_idx = 0
156
+ _, src_secc_color = self.secc_renderer(ids[0:1], exps[src_idx:src_idx+1], zero_eulers[0:1], zero_trans[0:1])
157
+ drv_secc_colors = [None for _ in range(len(exps))]
158
+ drv_head_imgs = [None for _ in range(len(exps))]
159
+ drv_torso_imgs = [None for _ in range(len(exps))]
160
+ drv_com_imgs = [None for _ in range(len(exps))]
161
+ segmaps = [None for _ in range(len(exps))]
162
+ img_name = f'data/processed/videos/{video_id}/bg.jpg'
163
+ bg_img = torch.tensor(cv2.imread(img_name)[..., ::-1] / 127.5 - 1).permute(2,0,1).float() # [3, H, W]
164
+ ds = {
165
+ 'id': ids.cuda().float(),
166
+ 'exps': exps.cuda().float(),
167
+ 'eulers': eulers.cuda().float(),
168
+ 'trans': trans.cuda().float(),
169
+ 'cano_secc_color': cano_secc_color.cuda().float(),
170
+ 'src_secc_color': src_secc_color.cuda().float(),
171
+ 'cameras': cameras.float(),
172
+ 'video_id': video_id,
173
+ 'lip_rects': lip_rects,
174
+ 'head_imgs': drv_head_imgs,
175
+ 'torso_imgs': drv_torso_imgs,
176
+ 'com_imgs': drv_com_imgs,
177
+ 'bg_img': bg_img,
178
+ 'segmaps': segmaps,
179
+ 'kps': kps,
180
+ }
181
+ self.ds = ds
182
+ return ds
183
+
184
+ def training_loop(self, inp):
185
+ trainer = self
186
+ video_id = self.ds['video_id']
187
+ lora_params = [p for k, p in self.secc2video_model.named_parameters() if 'lora_' in k]
188
+ self.criterion_lpips = lpips.LPIPS(net='alex',lpips=True).cuda()
189
+ self.logger = SummaryWriter(log_dir=inp['work_dir'])
190
+ if not hasattr(self, 'learnable_triplane'):
191
+ src_idx = 0 # init triplane from the first frame's prediction
192
+ self.learnable_triplane = nn.Parameter(torch.zeros([1, 3, self.secc2video_model.triplane_hid_dim*self.secc2video_model.triplane_depth, 256, 256]).float().cuda(), requires_grad=True)
193
+ img_name = f'data/processed/videos/{video_id}/head_imgs/{format(src_idx, "08d")}.png'
194
+ img = torch.tensor(cv2.imread(img_name)[..., ::-1] / 127.5 - 1).permute(2,0,1).float().cuda().float() # [3, H, W]
195
+ cano_plane = self.secc2video_model.cal_cano_plane(img.unsqueeze(0)) # [1, 3, CD, h, w]
196
+ self.learnable_triplane.data = cano_plane.data
197
+ self.secc2video_model._last_cano_planes = self.learnable_triplane
198
+ if len(lora_params) == 0:
199
+ self.optimizer = torch.optim.AdamW([self.learnable_triplane], lr=inp['lr_triplane'], weight_decay=0.01, betas=(0.9,0.98))
200
+ else:
201
+ self.optimizer = torch.optim.Adam(lora_params, lr=inp['lr'], betas=(0.9,0.98))
202
+ self.optimizer.add_param_group({
203
+ 'params': [self.learnable_triplane],
204
+ 'lr': inp['lr_triplane'],
205
+ 'betas': (0.9, 0.98)
206
+ })
207
+
208
+ ids = self.ds['id']
209
+ exps = self.ds['exps']
210
+ zero_eulers = self.ds['eulers']*0
211
+ zero_trans = self.ds['trans']*0
212
+ num_updates = inp['max_updates']
213
+ batch_size = inp['batch_size'] # 1 for lower gpu mem usage
214
+ num_samples = len(self.ds['cameras'])
215
+ init_plane = self.learnable_triplane.detach().clone()
216
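+ # fewer training frames -> stronger regularization of the learnable triplane toward its initialization (init_plane)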
+ if num_samples <= 5:
217
+ lambda_reg_triplane = 1.0
218
+ elif num_samples <= 250:
219
+ lambda_reg_triplane = 0.1
220
+ else:
221
+ lambda_reg_triplane = 0.
222
+ for i_step in tqdm.trange(num_updates+1,desc="training lora..."):
223
+ milestone_steps = []
224
+ # milestone_steps = [100, 200, 500]
225
+ if i_step % 2000 == 0 or i_step in milestone_steps:
226
+ trainer.test_loop(inp, step=i_step)
227
+ if i_step != 0:
228
+ filepath = os.path.join(inp['work_dir'], f"model_ckpt_steps_{i_step}.ckpt")
229
+ checkpoint = self.dump_checkpoint(inp)
230
+ tmp_path = str(filepath) + ".part"
231
+ torch.save(checkpoint, tmp_path, _use_new_zipfile_serialization=False)
232
+ os.replace(tmp_path, filepath)
233
+
234
+ drv_idx = [random.randint(0, num_samples-1) for _ in range(batch_size)]
235
+ drv_secc_colors = []
236
+ gt_imgs = []
237
+ head_imgs = []
238
+ segmaps_0 = []
239
+ segmaps = []
240
+ torso_imgs = []
241
+ drv_lip_rects = []
242
+ kp_src = []
243
+ kp_drv = []
244
+ for di in drv_idx:
245
+ # load the target image
246
+ if self.torso_mode:
247
+ if self.ds['com_imgs'][di] is None:
248
+ # img_name = f'data/processed/videos/{video_id}/gt_imgs/{format(di, "08d")}.jpg'
249
+ img_name = f'data/processed/videos/{video_id}/com_imgs/{format(di, "08d")}.jpg'
250
+ img = torch.tensor(cv2.imread(img_name)[..., ::-1] / 127.5 - 1).permute(2,0,1).float() # [3, H, W]
251
+ self.ds['com_imgs'][di] = img
252
+ gt_imgs.append(self.ds['com_imgs'][di])
253
+ else:
254
+ if self.ds['head_imgs'][di] is None:
255
+ img_name = f'data/processed/videos/{video_id}/head_imgs/{format(di, "08d")}.png'
256
+ img = torch.tensor(cv2.imread(img_name)[..., ::-1] / 127.5 - 1).permute(2,0,1).float() # [3, H, W]
257
+ self.ds['head_imgs'][di] = img
258
+ gt_imgs.append(self.ds['head_imgs'][di])
259
+ if self.ds['head_imgs'][di] is None:
260
+ img_name = f'data/processed/videos/{video_id}/head_imgs/{format(di, "08d")}.png'
261
+ img = torch.tensor(cv2.imread(img_name)[..., ::-1] / 127.5 - 1).permute(2,0,1).float() # [3, H, W]
262
+ self.ds['head_imgs'][di] = img
263
+ head_imgs.append(self.ds['head_imgs'][di])
264
+ # use the first frame's torso as the input to face vid2vid
265
+ if self.ds['torso_imgs'][0] is None:
266
+ img_name = f'data/processed/videos/{video_id}/inpaint_torso_imgs/{format(0, "08d")}.png'
267
+ img = torch.tensor(cv2.imread(img_name)[..., ::-1] / 127.5 - 1).permute(2,0,1).float() # [3, H, W]
268
+ self.ds['torso_imgs'][0] = img
269
+ torso_imgs.append(self.ds['torso_imgs'][0])
270
+ # accordingly, also use the first frame's segmap
271
+ if self.ds['segmaps'][0] is None:
272
+ img_name = f'data/processed/videos/{video_id}/segmaps/{format(0, "08d")}.png'
273
+ seg_img = cv2.imread(img_name)[:,:, ::-1]
274
+ segmap = torch.from_numpy(decode_segmap_mask_from_image(seg_img)) # [6, H, W]
275
+ self.ds['segmaps'][0] = segmap
276
+ segmaps_0.append(self.ds['segmaps'][0])
277
+ if self.ds['segmaps'][di] is None:
278
+ img_name = f'data/processed/videos/{video_id}/segmaps/{format(di, "08d")}.png'
279
+ seg_img = cv2.imread(img_name)[:,:, ::-1]
280
+ segmap = torch.from_numpy(decode_segmap_mask_from_image(seg_img)) # [6, H, W]
281
+ self.ds['segmaps'][di] = segmap
282
+ segmaps.append(self.ds['segmaps'][di])
283
+ _, secc_color = self.secc_renderer(ids[0:1], exps[di:di+1], zero_eulers[0:1], zero_trans[0:1])
284
+ drv_secc_colors.append(secc_color)
285
+ drv_lip_rects.append(self.ds['lip_rects'][di])
286
+ kp_src.append(self.ds['kps'][0])
287
+ kp_drv.append(self.ds['kps'][di])
288
+ bg_img = self.ds['bg_img'].unsqueeze(0).repeat([batch_size, 1, 1, 1]).cuda()
289
+ ref_torso_imgs = torch.stack(torso_imgs).float().cuda()
290
+ kp_src = torch.stack(kp_src).float().cuda()
291
+ kp_drv = torch.stack(kp_drv).float().cuda()
292
+ segmaps = torch.stack(segmaps).float().cuda()
293
+ segmaps_0 = torch.stack(segmaps_0).float().cuda()
294
+ tgt_imgs = torch.stack(gt_imgs).float().cuda()
295
+ head_imgs = torch.stack(head_imgs).float().cuda()
296
+ drv_secc_color = torch.cat(drv_secc_colors)
297
+ cano_secc_color = self.ds['cano_secc_color'].repeat([batch_size, 1, 1, 1])
298
+ src_secc_color = self.ds['src_secc_color'].repeat([batch_size, 1, 1, 1])
299
+ cond = {'cond_cano': cano_secc_color,'cond_src': src_secc_color, 'cond_tgt': drv_secc_color,
300
+ 'ref_torso_img': ref_torso_imgs, 'bg_img': bg_img,
301
+ 'segmap': segmaps_0, # vid2vid warps the first frame's torso as the source image
302
+ 'kp_s': kp_src, 'kp_d': kp_drv}
303
+ camera = self.ds['cameras'][drv_idx]
304
+ gen_output = self.secc2video_model.forward(img=None, camera=camera, cond=cond, ret={}, cache_backbone=False, use_cached_backbone=True)
305
+ pred_imgs = gen_output['image']
306
+ pred_imgs_raw = gen_output['image_raw']
307
+
308
+ losses = {}
309
+ loss_weights = {
310
+ 'v2v_occlusion_reg_l1_loss': 0.001, # loss for face_vid2vid-based torso
311
+ 'v2v_occlusion_2_reg_l1_loss': 0.001, # loss for face_vid2vid-based torso
312
+ 'v2v_occlusion_2_weights_entropy_loss': hparams['lam_occlusion_weights_entropy'], # loss for face_vid2vid-based torso
313
+ 'density_weight_l2_loss': 0.01, # supervised density
314
+ 'density_weight_entropy_loss': 0.001, # keep the density change sharp
315
+ 'mse_loss': 1.,
316
+ 'head_mse_loss': 0.2, # loss on neural rendering low-reso pred_img
317
+ 'lip_mse_loss': 1.0,
318
+ 'lpips_loss': 0.5,
319
+ 'head_lpips_loss': 0.1,
320
+ 'lip_lpips_loss': 1.0, # make the teeth more clear
321
+ 'blink_reg_loss': 0.003, # increase it if the head shakes while blinking; decrease it if the eyes cannot fully close.
322
+ 'triplane_reg_loss': lambda_reg_triplane,
323
+ 'secc_reg_loss': 0.01, # used to reduce flickering
324
+ }
325
+
326
+ occlusion_reg_l1 = gen_output.get("losses", {}).get('facev2v/occlusion_reg_l1', 0.)
327
+ occlusion_2_reg_l1 = gen_output.get("losses", {}).get('facev2v/occlusion_2_reg_l1', 0.)
328
+ occlusion_2_weights_entropy = gen_output.get("losses", {}).get('facev2v/occlusion_2_weights_entropy', 0.)
329
+ losses['v2v_occlusion_reg_l1_loss'] = occlusion_reg_l1
330
+ losses['v2v_occlusion_2_reg_l1_loss'] = occlusion_2_reg_l1
331
+ losses['v2v_occlusion_2_weights_entropy_loss'] = occlusion_2_weights_entropy
332
+
333
+ # Weights Reg loss in torso
334
+ neural_rendering_reso = self.neural_rendering_resolution
335
+ alphas = gen_output['weights_img'].clamp(1e-5, 1 - 1e-5)
336
+ loss_weights_entropy = torch.mean(- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas))
337
+ mv_head_masks = segmaps[:, [1,3,5]].sum(dim=1)
338
+ mv_head_masks_raw = F.interpolate(mv_head_masks.unsqueeze(1), size=(neural_rendering_reso,neural_rendering_reso)).squeeze(1)
339
+ face_mask = mv_head_masks_raw.bool().unsqueeze(1)
340
+ nonface_mask = ~ face_mask
341
+ loss_weights_l2_loss = (alphas[nonface_mask]-0).pow(2).mean() + (alphas[face_mask]-1).pow(2).mean()
342
+ losses['density_weight_l2_loss'] = loss_weights_l2_loss
343
+ losses['density_weight_entropy_loss'] = loss_weights_entropy
344
+
345
+ mse_loss = (pred_imgs - tgt_imgs).abs().mean()
346
+ head_mse_loss = (pred_imgs_raw - F.interpolate(head_imgs, size=(neural_rendering_reso,neural_rendering_reso), mode='bilinear', antialias=True)).abs().mean()
347
+ lpips_loss = self.criterion_lpips(pred_imgs, tgt_imgs).mean()
348
+ head_lpips_loss = self.criterion_lpips(pred_imgs_raw, F.interpolate(head_imgs, size=(neural_rendering_reso,neural_rendering_reso), mode='bilinear', antialias=True)).mean()
349
+ lip_mse_loss = 0
350
+ lip_lpips_loss = 0
351
+ for i in range(len(drv_idx)):
352
+ xmin, xmax, ymin, ymax = drv_lip_rects[i]
353
+ lip_tgt_imgs = tgt_imgs[i:i+1,:, ymin:ymax,xmin:xmax].contiguous()
354
+ lip_pred_imgs = pred_imgs[i:i+1,:, ymin:ymax,xmin:xmax].contiguous()
355
+ try:
356
+ lip_mse_loss = lip_mse_loss + (lip_pred_imgs - lip_tgt_imgs).abs().mean()
357
+ lip_lpips_loss = lip_lpips_loss + self.criterion_lpips(lip_pred_imgs, lip_tgt_imgs).mean()
358
+ except Exception: pass  # skip this frame's lip loss if the crop is invalid
359
+ losses['mse_loss'] = mse_loss
360
+ losses['head_mse_loss'] = head_mse_loss
361
+ losses['lpips_loss'] = lpips_loss
362
+ losses['head_lpips_loss'] = head_lpips_loss
363
+ losses['lip_mse_loss'] = lip_mse_loss
364
+ losses['lip_lpips_loss'] = lip_lpips_loss
365
+
366
+ # eye-blink regularization: synthesize SECCs at three blink ratios and encourage the predicted planes to interpolate linearly between them
367
+ if i_step % 4 == 0:
368
+ blink_secc_lst1 = []
369
+ blink_secc_lst2 = []
370
+ blink_secc_lst3 = []
371
+ for i in range(len(drv_secc_color)):
372
+ secc = drv_secc_color[i]
373
+ blink_percent1 = random.random() * 0.5 # 0~0.5
374
+ blink_percent3 = 0.5 + random.random() * 0.5 # 0.5~1.0
375
+ blink_percent2 = (blink_percent1 + blink_percent3)/2
376
+ try:
377
+ out_secc1 = blink_eye_for_secc(secc, blink_percent1).to(secc.device)
378
+ out_secc2 = blink_eye_for_secc(secc, blink_percent2).to(secc.device)
379
+ out_secc3 = blink_eye_for_secc(secc, blink_percent3).to(secc.device)
380
+ except Exception:
381
+ print("blink eye for secc failed, use original secc")
382
+ out_secc1 = copy.deepcopy(secc)
383
+ out_secc2 = copy.deepcopy(secc)
384
+ out_secc3 = copy.deepcopy(secc)
385
+ blink_secc_lst1.append(out_secc1)
386
+ blink_secc_lst2.append(out_secc2)
387
+ blink_secc_lst3.append(out_secc3)
388
+ src_secc_color1 = torch.stack(blink_secc_lst1)
389
+ src_secc_color2 = torch.stack(blink_secc_lst2)
390
+ src_secc_color3 = torch.stack(blink_secc_lst3)
391
+ blink_cond1 = {'cond_cano': cano_secc_color, 'cond_src': src_secc_color, 'cond_tgt': src_secc_color1}
392
+ blink_cond2 = {'cond_cano': cano_secc_color, 'cond_src': src_secc_color, 'cond_tgt': src_secc_color2}
393
+ blink_cond3 = {'cond_cano': cano_secc_color, 'cond_src': src_secc_color, 'cond_tgt': src_secc_color3}
394
+ blink_secc_plane1 = self.model.cal_secc_plane(blink_cond1)
395
+ blink_secc_plane2 = self.model.cal_secc_plane(blink_cond2)
396
+ blink_secc_plane3 = self.model.cal_secc_plane(blink_cond3)
397
+ interpolate_blink_secc_plane = (blink_secc_plane1 + blink_secc_plane3) / 2
398
+ blink_reg_loss = torch.nn.functional.l1_loss(blink_secc_plane2, interpolate_blink_secc_plane)
399
+ losses['blink_reg_loss'] = blink_reg_loss
400
+
401
+ # Triplane Reg loss
402
+ triplane_reg_loss = (self.learnable_triplane - init_plane).abs().mean()
403
+ losses['triplane_reg_loss'] = triplane_reg_loss
404
+
405
+
406
+ ref_id = self.ds['id'][0:1]
407
+ secc_pertube_randn_scale = hparams['secc_pertube_randn_scale']
408
+ perturbed_id = ref_id + torch.randn_like(ref_id) * secc_pertube_randn_scale
409
+ drv_exp = self.ds['exps'][drv_idx]
410
+ perturbed_exp = drv_exp + torch.randn_like(drv_exp) * secc_pertube_randn_scale
411
+ zero_euler = torch.zeros([len(drv_idx), 3], device=ref_id.device, dtype=ref_id.dtype)
412
+ zero_trans = torch.zeros([len(drv_idx), 3], device=ref_id.device, dtype=ref_id.dtype)
413
+ perturbed_secc = self.secc_renderer(perturbed_id, perturbed_exp, zero_euler, zero_trans)[1]
414
+ secc_reg_loss = torch.nn.functional.l1_loss(drv_secc_color, perturbed_secc)
415
+ losses['secc_reg_loss'] = secc_reg_loss
416
+
417
+
418
+ total_loss = sum([loss_weights[k] * v for k, v in losses.items() if isinstance(v, torch.Tensor) and v.requires_grad])
419
+ # Update weights
420
+ self.optimizer.zero_grad()
421
+ total_loss.backward()
422
+ self.learnable_triplane.grad.data = self.learnable_triplane.grad.data * self.learnable_triplane.numel()  # amplify the triplane gradient by its element count (mean-reduced losses otherwise yield tiny per-element gradients)
423
+ self.optimizer.step()
424
+ meter.update(total_loss.item())
425
+ if i_step % 10 == 0:
426
+ log_line = f"Iter {i_step+1}: total_loss={meter.avg} "
427
+ for k, v in losses.items():
428
+ log_line = log_line + f" {k}={v.item()}, "
429
+ self.logger.add_scalar(f"train/{k}", v.item(), i_step)
430
+ print(log_line)
431
+ meter.reset()
432
+ @torch.no_grad()
433
+ def test_loop(self, inp, step=''):
434
+ self.model.eval()
435
+ # coeff_dict = np.load('data/processed/videos/Lieu/coeff_fit_mp_for_lora.npy', allow_pickle=True).tolist()
436
+ # drv_exps = torch.tensor(coeff_dict['exp']).cuda().float()
437
+ drv_exps = self.ds['exps']
438
+ zero_eulers = self.ds['eulers']*0
439
+ zero_trans = self.ds['trans']*0
440
+ batch_size = 1
441
+ num_samples = len(self.ds['cameras'])
442
+ video_writer = imageio.get_writer(os.path.join(inp['work_dir'], f'val_step{step}.mp4'), fps=25)
443
+ total_iters = min(num_samples, 250)
444
+ video_id = inp['video_id']
445
+ for i in tqdm.trange(total_iters,desc="testing lora..."):
446
+ drv_idx = [i]
447
+ drv_secc_colors = []
448
+ gt_imgs = []
449
+ segmaps = []
450
+ torso_imgs = []
451
+ drv_lip_rects = []
452
+ kp_src = []
453
+ kp_drv = []
454
+ for di in drv_idx:
455
+ # load the target image
456
+ if self.torso_mode:
457
+ if self.ds['com_imgs'][di] is None:
458
+ img_name = f'data/processed/videos/{video_id}/com_imgs/{format(di, "08d")}.jpg'
459
+ img = torch.tensor(cv2.imread(img_name)[..., ::-1] / 127.5 - 1).permute(2,0,1).float() # [3, H, W]
460
+ self.ds['com_imgs'][di] = img
461
+ gt_imgs.append(self.ds['com_imgs'][di])
462
+ else:
463
+ if self.ds['head_imgs'][di] is None:
464
+ img_name = f'data/processed/videos/{video_id}/head_imgs/{format(di, "08d")}.png'
465
+ img = torch.tensor(cv2.imread(img_name)[..., ::-1] / 127.5 - 1).permute(2,0,1).float() # [3, H, W]
466
+ self.ds['head_imgs'][di] = img
467
+ gt_imgs.append(self.ds['head_imgs'][di])
468
+ # use the first frame's torso as the input to face vid2vid
469
+ if self.ds['torso_imgs'][0] is None:
470
+ img_name = f'data/processed/videos/{video_id}/inpaint_torso_imgs/{format(0, "08d")}.png'
471
+ img = torch.tensor(cv2.imread(img_name)[..., ::-1] / 127.5 - 1).permute(2,0,1).float() # [3, H, W]
472
+ self.ds['torso_imgs'][0] = img
473
+ torso_imgs.append(self.ds['torso_imgs'][0])
474
+ # accordingly, also use the first frame's segmap
475
+ if self.ds['segmaps'][0] is None:
476
+ img_name = f'data/processed/videos/{video_id}/segmaps/{format(0, "08d")}.png'
477
+ seg_img = cv2.imread(img_name)[:,:, ::-1]
478
+ segmap = torch.from_numpy(decode_segmap_mask_from_image(seg_img)) # [6, H, W]
479
+ self.ds['segmaps'][0] = segmap
480
+ segmaps.append(self.ds['segmaps'][0])
481
+ drv_lip_rects.append(self.ds['lip_rects'][di])
482
+ kp_src.append(self.ds['kps'][0])
483
+ kp_drv.append(self.ds['kps'][di])
484
+ bg_img = self.ds['bg_img'].unsqueeze(0).repeat([batch_size, 1, 1, 1]).cuda()
485
+ ref_torso_imgs = torch.stack(torso_imgs).float().cuda()
486
+ kp_src = torch.stack(kp_src).float().cuda()
487
+ kp_drv = torch.stack(kp_drv).float().cuda()
488
+ segmaps = torch.stack(segmaps).float().cuda()
489
+ tgt_imgs = torch.stack(gt_imgs).float().cuda()
490
+ for di in drv_idx:
491
+ _, secc_color = self.secc_renderer(self.ds['id'][0:1], drv_exps[di:di+1], zero_eulers[0:1], zero_trans[0:1])
492
+ drv_secc_colors.append(secc_color)
493
+ drv_secc_color = torch.cat(drv_secc_colors)
494
+ cano_secc_color = self.ds['cano_secc_color'].repeat([batch_size, 1, 1, 1])
495
+ src_secc_color = self.ds['src_secc_color'].repeat([batch_size, 1, 1, 1])
496
+ cond = {'cond_cano': cano_secc_color,'cond_src': src_secc_color, 'cond_tgt': drv_secc_color,
497
+ 'ref_torso_img': ref_torso_imgs, 'bg_img': bg_img, 'segmap': segmaps,
498
+ 'kp_s': kp_src, 'kp_d': kp_drv}
499
+ camera = self.ds['cameras'][drv_idx]
500
+ gen_output = self.secc2video_model.forward(img=None, camera=camera, cond=cond, ret={}, cache_backbone=False, use_cached_backbone=True)
501
+ pred_img = gen_output['image']
502
+ pred_img = ((pred_img.permute(0, 2, 3, 1) + 1)/2 * 255).int().cpu().numpy().astype(np.uint8)
503
+ video_writer.append_data(pred_img[0])
504
+ video_writer.close()
505
+ self.model.train()
506
+
507
+ def masked_error_loss(self, img_pred, img_gt, mask, unmasked_weight=0.1, mode='l1'):
508
+ # For the raw (low-res) image, deformation keeps the background from being fully black, which inflates the MSE there; mask it out and compute the loss mainly on the face region.
509
+ masked_weight = 1.0
510
+ weight_mask = mask.float() * masked_weight + (~mask).float() * unmasked_weight
511
+ if mode == 'l1':
512
+ error = (img_pred - img_gt).abs().sum(dim=1) * weight_mask
513
+ else:
514
+ error = (img_pred - img_gt).pow(2).sum(dim=1) * weight_mask
515
+ error.clamp_(0, max(0.5, error.quantile(0.8).item())) # clamp high-loss pixels so outliers from misaligned poses do not dominate training
516
+ loss = error.mean()
517
+ return loss
518
+
519
+ def dilate(self, bin_img, ksize=5, mode='max_pool'):
520
+ """
521
+ mode: max_pool or avg_pool
522
+ """
523
+ # bin_img, [1, h, w]
524
+ pad = (ksize-1)//2
525
+ bin_img = F.pad(bin_img, pad=[pad,pad,pad,pad], mode='reflect')
526
+ if mode == 'max_pool':
527
+ out = F.max_pool2d(bin_img, kernel_size=ksize, stride=1, padding=0)
528
+ else:
529
+ out = F.avg_pool2d(bin_img, kernel_size=ksize, stride=1, padding=0)
530
+ return out
531
+
532
+ def dilate_mask(self, mask, ksize=21):
533
+ mask = self.dilate(mask, ksize=ksize, mode='max_pool')
534
+ return mask
535
+
536
+ def set_unmasked_to_black(self, img, mask):
537
+ out_img = img * mask.float() - (~mask).float() # -1 denotes black
538
+ return out_img
539
+
540
+ def dump_checkpoint(self, inp):
541
+ checkpoint = {}
542
+ # save optimizers
543
+ optimizer_states = []
544
+ self.optimizers = [self.optimizer]
545
+ for i, optimizer in enumerate(self.optimizers):
546
+ if optimizer is not None:
547
+ state_dict = optimizer.state_dict()
548
+ state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
549
+ optimizer_states.append(state_dict)
550
+ checkpoint['optimizer_states'] = optimizer_states
551
+ state_dict = {
552
+ 'model': self.model.state_dict(),
553
+ 'learnable_triplane': self.model.state_dict()['_last_cano_planes'],
554
+ }
555
+ del state_dict['model']['_last_cano_planes']
556
+ checkpoint['state_dict'] = state_dict
557
+ checkpoint['lora_args'] = self.lora_args
558
+ person_ds = {}
559
+ video_id = inp['video_id']
560
+ img_name = f'data/processed/videos/{video_id}/gt_imgs/{format(0, "08d")}.jpg'
561
+ gt_img = torch.tensor(cv2.resize(cv2.imread(img_name), (512, 512))[..., ::-1] / 127.5 - 1).permute(2,0,1).float() # [3, H, W]
562
+ person_ds['gt_img'] = gt_img.reshape([1, 3, 512, 512])
563
+ person_ds['id'] = self.ds['id'].cpu().reshape([1, 80])
564
+ person_ds['src_kp'] = self.ds['kps'][0].cpu()
565
+ person_ds['video_id'] = inp['video_id']
566
+ checkpoint['person_ds'] = person_ds
567
+ return checkpoint
568
+ if __name__ == '__main__':
569
+ import argparse
570
+ parser = argparse.ArgumentParser()
571
+ parser.add_argument("--head_ckpt", default='') # checkpoints/0729_th1kh/secc_img2plane checkpoints/0720_img2planes/secc_img2plane_two_stage
572
+ # parser.add_argument("--torso_ckpt", default='checkpoints/240210_real3dportrait_orig/secc2plane_torso_orig') # checkpoints/0729_th1kh/secc_img2plane checkpoints/0720_img2planes/secc_img2plane_two_stage
573
+ parser.add_argument("--torso_ckpt", default='checkpoints/mimictalk_orig/os_secc2plane_torso') # checkpoints/0729_th1kh/secc_img2plane checkpoints/0720_img2planes/secc_img2plane_two_stage
574
+ parser.add_argument("--video_id", default='data/raw/examples/GER.mp4', help="identity source, we support (1) already processed <video_id> of GeneFace, (2) video path, (3) image path")
575
+ parser.add_argument("--work_dir", default=None)
576
+ parser.add_argument("--max_updates", default=10000, type=int, help="for video, 2000 is good; for an image, 3~10 is good")
577
+ parser.add_argument("--test", action='store_true')
578
+ parser.add_argument("--batch_size", default=1, type=int, help="batch size during training, 1 needs 8GB, 2 needs 15GB")
579
+ parser.add_argument("--lr", default=0.001)
580
+ parser.add_argument("--lr_triplane", default=0.005, help="for video, 0.1; for an image, 0.001; for ablation with_triplane, 0.")
581
+ parser.add_argument("--lora_r", default=2, type=int, help="width of lora unit")
582
+ parser.add_argument("--lora_mode", default='secc2plane_sr', help='for video, full; for an image, none')
583
+
584
+ args = parser.parse_args()
585
+ inp = {
586
+ 'head_ckpt': args.head_ckpt,
587
+ 'torso_ckpt': args.torso_ckpt,
588
+ 'video_id': args.video_id,
589
+ 'work_dir': args.work_dir,
590
+ 'max_updates': args.max_updates,
591
+ 'batch_size': args.batch_size,
592
+ 'test': args.test,
593
+ 'lr': float(args.lr),
594
+ 'lr_triplane': float(args.lr_triplane),
595
+ 'lora_mode': args.lora_mode,
596
+ 'lora_r': args.lora_r,
597
+ }
598
+ if inp['work_dir'] is None:
599
+ video_id = os.path.basename(inp['video_id'])[:-4] if inp['video_id'].endswith((".mp4", ".png", ".jpg", ".jpeg")) else inp['video_id']
600
+ inp['work_dir'] = f'checkpoints_mimictalk/{video_id}'
601
+ os.makedirs(inp['work_dir'], exist_ok=True)
602
+ trainer = LoRATrainer(inp)
603
+ if inp['test']:
604
+ trainer.test_loop(inp)
605
+ else:
606
+ trainer.training_loop(inp)
607
+ trainer.test_loop(inp)
608
+ print(" ")
install_guide.md ADDED
@@ -0,0 +1,27 @@
1
+ # Prepare the Environment
2
+ [中文文档](./install_guide-zh.md)
3
+
4
+ This guide is about building a python environment for MimicTalk with Conda (the same as `Real3D-Portrait`).
5
+
6
+ The following installation process is verified in A100/V100 + CUDA12.1.
7
+
8
+ # Install Python Packages & CUDA
9
+ ```bash
10
+ cd <MimicTalkRoot>
11
+ source <CondaRoot>/bin/activate
12
+ conda create -n mimictalk python=3.9
13
+ conda activate mimictalk
14
+
15
+ # MMCV for SegFormer network structure
16
+ # other dependencies
17
+ pip install -r docs/prepare_env/requirements.txt -v
18
+ pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121
19
+ pip install cython
20
+ pip install openmim==0.3.9
21
+ mim install mmcv==2.1.0 # use mim to speed up installation for mmcv
22
+ ## build pytorch3d from Github's source code.
23
+ ## It may take a long time (tens of minutes); a proxy is recommended if you encounter time-out problems
24
+ # Before installing pytorch3d, install CUDA 12.1 (https://developer.nvidia.com/cuda-toolkit-archive) and make sure /usr/local/cuda points to the `cuda-12.1` directory
25
+ pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
26
+
27
+ ```
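After the steps above, a quick sanity check of the environment can look like this (a minimal sketch; the expected version strings are simply what the commands above install):

```bash
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"  # expect 2.4.0 and True
python -c "import mmcv; print(mmcv.__version__)"                               # expect 2.1.0
python -c "import pytorch3d; print('pytorch3d OK')"
```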
requirements.txt ADDED
@@ -0,0 +1,79 @@
1
+ Cython
2
+ numpy # ==1.23.0
3
+ numba==0.56.4
4
+ pandas
5
+ transformers
6
+ scipy==1.11.1 # required by cal_fid. https://github.com/mseitzer/pytorch-fid/issues/103
7
+ scikit-learn
8
+ scikit-image
9
+ # tensorflow # optional; install the variant (CPU/GPU) that fits your setup
10
+ tensorboard
11
+ tensorboardX
12
+ python_speech_features
13
+ resampy
14
+ opencv_python
15
+ face_alignment
16
+ matplotlib
17
+ configargparse
18
+ librosa==0.9.2
19
+ praat-parselmouth # ==0.4.3
20
+ trimesh
21
+ kornia==0.5.0
22
+ PyMCubes
23
+ lpips
24
+ setuptools # ==59.5.0
25
+ ffmpeg-python
26
+ moviepy
27
+ dearpygui
28
+ ninja
29
+ pyaudio # for extracting esperanto
30
+ mediapipe
31
+ protobuf
32
+ decord
33
+ soundfile
34
+ pillow
35
+ # torch # it's better to install torch with conda
36
+ av
37
+ timm
38
+ pretrainedmodels
39
+ faiss-cpu # for fast nearest camera pose retrieval
40
+ einops
41
+ # mmcv # use mim install is faster
42
+
43
+ # conditional flow matching
44
+ beartype
45
+ torchode
46
+ torchdiffeq
47
+
48
+ # tts
49
+ cython
50
+ textgrid
51
+ pyloudnorm
52
+ websocket-client
53
+ pyworld==0.2.1rc0
54
+ pypinyin==0.42.0
55
+ webrtcvad
56
+ torchshow
57
+
58
+ # cal spk sim
59
+ # s3prl
60
+ # fire
61
+
62
+ # cal LMD
63
+ # dlib
64
+
65
+ # debug
66
+ # ipykernel
67
+
68
+ # lama
69
+ # hydra-core
70
+ # pytorch_lightning
71
+ # setproctitle
72
+
73
+ # Gradio GUI
74
+ # httpx==0.23.3
75
+ # gradio==4.16.0
76
+ gradio==4.43.0
77
+ httpx==0.23.3
78
+ # gradio_client==0.8.1
79
+ fastapi==0.112.2