mrbear1024 committed
Commit 8eb4303 · 0 Parent(s):

init project

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. .gitattributes +12 -0
  2. .gitignore +201 -0
  3. LICENSE +21 -0
  4. README-zh.md +157 -0
  5. README.md +155 -0
  6. checkpoints/.gitkeep +0 -0
  7. data_gen/eg3d/convert_to_eg3d_convention.py +146 -0
  8. data_gen/runs/binarizer_nerf.py +335 -0
  9. data_gen/runs/binarizer_th1kh.py +100 -0
  10. data_gen/runs/nerf/process_guide.md +49 -0
  11. data_gen/runs/nerf/run.sh +51 -0
  12. data_gen/utils/mp_feature_extractors/face_landmarker.py +130 -0
  13. data_gen/utils/mp_feature_extractors/face_landmarker.task +3 -0
  14. data_gen/utils/mp_feature_extractors/mp_segmenter.py +303 -0
  15. data_gen/utils/mp_feature_extractors/selfie_multiclass_256x256.tflite +3 -0
  16. data_gen/utils/path_converter.py +24 -0
  17. data_gen/utils/process_audio/extract_hubert.py +92 -0
  18. data_gen/utils/process_audio/extract_mel_f0.py +148 -0
  19. data_gen/utils/process_audio/resample_audio_to_16k.py +49 -0
  20. data_gen/utils/process_image/extract_lm2d.py +197 -0
  21. data_gen/utils/process_image/extract_segment_imgs.py +114 -0
  22. data_gen/utils/process_image/fit_3dmm_landmark.py +369 -0
  23. data_gen/utils/process_video/euler2quaterion.py +35 -0
  24. data_gen/utils/process_video/extract_blink.py +50 -0
  25. data_gen/utils/process_video/extract_lm2d.py +164 -0
  26. data_gen/utils/process_video/extract_segment_imgs.py +494 -0
  27. data_gen/utils/process_video/fit_3dmm_landmark.py +565 -0
  28. data_gen/utils/process_video/inpaint_torso_imgs.py +193 -0
  29. data_gen/utils/process_video/resample_video_to_25fps_resize_to_512.py +87 -0
  30. data_gen/utils/process_video/split_video_to_imgs.py +53 -0
  31. data_util/face3d_helper.py +309 -0
  32. deep_3drecon/BFM/.gitkeep +0 -0
  33. deep_3drecon/BFM/basel_53201.txt +0 -0
  34. deep_3drecon/BFM/index_mp468_from_mesh35709_v1.npy +3 -0
  35. deep_3drecon/BFM/index_mp468_from_mesh35709_v2.npy +3 -0
  36. deep_3drecon/BFM/index_mp468_from_mesh35709_v3.1.npy +3 -0
  37. deep_3drecon/BFM/index_mp468_from_mesh35709_v3.npy +3 -0
  38. deep_3drecon/BFM/select_vertex_id.mat +0 -0
  39. deep_3drecon/BFM/similarity_Lm3D_all.mat +0 -0
  40. deep_3drecon/__init__.py +1 -0
  41. deep_3drecon/bfm_left_eye_faces.npy +3 -0
  42. deep_3drecon/bfm_right_eye_faces.npy +3 -0
  43. deep_3drecon/data_preparation.py +45 -0
  44. deep_3drecon/deep_3drecon_models/__init__.py +67 -0
  45. deep_3drecon/deep_3drecon_models/arcface_torch/README.md +218 -0
  46. deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__init__.py +85 -0
  47. deep_3drecon/deep_3drecon_models/arcface_torch/backbones/iresnet.py +194 -0
  48. deep_3drecon/deep_3drecon_models/arcface_torch/backbones/iresnet2060.py +176 -0
  49. deep_3drecon/deep_3drecon_models/arcface_torch/backbones/mobilefacenet.py +147 -0
  50. deep_3drecon/deep_3drecon_models/arcface_torch/backbones/vit.py +280 -0
.gitattributes ADDED
@@ -0,0 +1,12 @@
+ data_gen/utils/mp_feature_extractors/selfie_multiclass_256x256.tflite filter=lfs diff=lfs merge=lfs -text
+ data_gen/utils/mp_feature_extractors/face_landmarker.task filter=lfs diff=lfs merge=lfs -text
+ utils/audio/pitch/bin/ReaperF0 filter=lfs diff=lfs merge=lfs -text
+ utils/audio/pitch/bin/ExtractF0ByStraight filter=lfs diff=lfs merge=lfs -text
+ utils/audio/pitch/bin/InterpF0 filter=lfs diff=lfs merge=lfs -text
+ deep_3drecon/bfm_left_eye_faces.npy filter=lfs diff=lfs merge=lfs -text
+ deep_3drecon/bfm_right_eye_faces.npy filter=lfs diff=lfs merge=lfs -text
+ deep_3drecon/ncc_code.npy filter=lfs diff=lfs merge=lfs -text
+ deep_3drecon/BFM/index_mp468_from_mesh35709_v1.npy filter=lfs diff=lfs merge=lfs -text
+ deep_3drecon/BFM/index_mp468_from_mesh35709_v2.npy filter=lfs diff=lfs merge=lfs -text
+ deep_3drecon/BFM/index_mp468_from_mesh35709_v3.1.npy filter=lfs diff=lfs merge=lfs -text
+ deep_3drecon/BFM/index_mp468_from_mesh35709_v3.npy filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,201 @@
+ # big files
+ data_util/face_tracking/3DMM/01_MorphableModel.mat
+ data_util/face_tracking/3DMM/3DMM_info.npy
+
+ !/deep_3drecon/BFM/.gitkeep
+ deep_3drecon/BFM/Exp_Pca.bin
+ deep_3drecon/BFM/01_MorphableModel.mat
+ deep_3drecon/BFM/BFM_model_front.mat
+ deep_3drecon/network/FaceReconModel.pb
+ deep_3drecon/checkpoints/*
+
+ .vscode
+ ### Project ignore
+ ./checkpoints/*
+ /checkpoints/*
+ !/checkpoints/.gitkeep
+ /data/*
+ !/data/.gitkeep
+ infer_out
+ rsync
+ .idea
+ .DS_Store
+ bak
+ tmp
+ *.tar.gz
+ mos
+ nbs
+ /configs_usr/*
+ !/configs_usr/.gitkeep
+ /egs_usr/*
+ !/egs_usr/.gitkeep
+ /rnnoise
+ #/usr/*
+ #!/usr/.gitkeep
+ scripts_usr
+
+ # Created by .ignore support plugin (hsz.mobi)
+ ### Python template
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # celery beat schedule file
+ celerybeat-schedule
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+ data_util/deepspeech_features/deepspeech-0.9.2-models.pbmm
+ deep_3drecon/mesh_renderer/bazel-bin
+ deep_3drecon/mesh_renderer/bazel-mesh_renderer
+ deep_3drecon/mesh_renderer/bazel-out
+ deep_3drecon/mesh_renderer/bazel-testlogs
+
+ .nfs*
+ infer_outs/*
+
+ *.pth
+ venv_113/*
+ *.pt
+ experiments/trials
+ flame_3drecon/*
+
+ temp/
+ /kill.sh
+ /datasets
+ data_util/imagenet_classes.txt
+ process_data_May.sh
+ /env_prepare_reproduce.md
+ /my_debug.py
+
+ utils/metrics/shape_predictor_68_face_landmarks.dat
+ *.mp4
+ _torchshow/
+ *.png
+ *.jpg
+
+ *.mrc
+
+ deep_3drecon/BFM/BFM_exp_idx.mat
+ deep_3drecon/BFM/BFM_front_idx.mat
+ deep_3drecon/BFM/facemodel_info.mat
+ deep_3drecon/BFM/index_mp468_from_mesh35709.npy
+ deep_3drecon/BFM/mediapipe_in_bfm53201.npy
+ deep_3drecon/BFM/std_exp.txt
+ !data/raw/examples/*
+ /checkpoints_mimictalk
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 ZhenhuiYe
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README-zh.md ADDED
@@ -0,0 +1,157 @@
+ # MimicTalk: Mimicking a personalized and expressive 3D talking face in few minutes | NeurIPS 2024
+ [![arXiv](https://img.shields.io/badge/arXiv-Paper-%3CCOLOR%3E.svg)](https://arxiv.org/abs/2401.08503)| [![GitHub Stars](https://img.shields.io/github/stars/yerfor/MimicTalk
+ )](https://github.com/yerfor/MimicTalk) | [English Readme](./README.md)
+
+ 这个仓库是MimicTalk的官方PyTorch实现, 用于实现特定说话人的高表现力的虚拟人视频合成。该仓库代码基于我们先前的工作[Real3D-Portrait](https://github.com/yerfor/Real3DPortrait) (ICLR 2024),即基于NeRF的one-shot说话人合成,这让MimicTalk的训练加速且效果增强。您可以访问我们的[项目页面](https://mimictalk.github.io/)以观看Demo视频, 阅读我们的[论文](https://arxiv.org/abs/2410.06734)以了解技术细节。
+
+ <p align="center">
+ <br>
+ <img src="assets/mimictalk.png" width="100%"/>
+ <br>
+ </p>
+
+ # 快速上手!
+ ## 安装环境
+ 请参照[环境配置文档](docs/prepare_env/install_guide-zh.md),配置Conda环境`mimictalk`
+ ## 下载预训练与第三方模型
+ ### 3DMM BFM模型
+ 下载3DMM BFM模型:[Google Drive](https://drive.google.com/drive/folders/1o4t5YIw7w4cMUN4bgU9nPf6IyWVG1bEk?usp=sharing) 或 [BaiduYun Disk](https://pan.baidu.com/s/1aqv1z_qZ23Vp2VP4uxxblQ?pwd=m9q5) 提取码: m9q5
+
+
+ 下载完成后,放置全部的文件到`deep_3drecon/BFM`里,文件结构如下:
+ ```
+ deep_3drecon/BFM/
+ ├── 01_MorphableModel.mat
+ ├── BFM_exp_idx.mat
+ ├── BFM_front_idx.mat
+ ├── BFM_model_front.mat
+ ├── Exp_Pca.bin
+ ├── facemodel_info.mat
+ ├── index_mp468_from_mesh35709.npy
+ ├── mediapipe_in_bfm53201.npy
+ └── std_exp.txt
+ ```
+
+ ### 预训练模型
+ 下载预训练的MimicTalk相关Checkpoints:[Google Drive](https://drive.google.com/drive/folders/1Kc6ueDO9HFDN3BhtJCEKNCZtyKHSktaA?usp=sharing) or [BaiduYun Disk](https://pan.baidu.com/s/1nQKyGV5JB6rJtda7qsThUg?pwd=mimi) 提取码: mimi
+
+ 下载完成后,放置全部的文件到`checkpoints`与`checkpoints_mimictalk`里并解压,文件结构如下:
+ ```
+ checkpoints/
+ ├── mimictalk_orig
+ │   └── os_secc2plane_torso
+ │       ├── config.yaml
+ │       └── model_ckpt_steps_100000.ckpt
+ ├── 240112_icl_audio2secc_vox2_cmlr
+ │   ├── config.yaml
+ │   └── model_ckpt_steps_1856000.ckpt
+ └── pretrained_ckpts
+     └── mit_b0.pth
+
+ checkpoints_mimictalk/
+ └── German_20s
+     ├── config.yaml
+     └── model_ckpt_steps_10000.ckpt
+ ```
+
+ ## MimicTalk训练与推理的最简命令
+ ```
+ python inference/train_mimictalk_on_a_video.py # train the model, this may take 10 minutes for 2,000 steps
+ python inference/mimictalk_infer.py # infer the model
+ ```
+
+
+ # 训练与推理细节
+ 我们目前提供了**命令行(CLI)**与**Gradio WebUI**推理方式。音频驱动推理的人像信息来自于`torso_ckpt`,因此需要至少再提供`driving audio`用于推理。另外,可以提供`style video`让模型能够预测与该视频风格一致的说话人动作。
+
+ 首先,切换至项目根目录并启用Conda环境:
+ ```bash
+ cd <Real3DPortraitRoot>
+ conda activate mimictalk
+ export PYTHONPATH=./
+ export HF_ENDPOINT=https://hf-mirror.com
+ ```
+
+ ## Gradio WebUI推理
+ 启动Gradio WebUI,按照提示上传素材,点击`Training`按钮进行训练;训练完成后点击`Generate`按钮即可推理:
+ ```bash
+ python inference/app_mimictalk.py
+ ```
+
+ ## 命令行特定说话人训练
+
+ 需要至少提供`source video`,训练指令:
+ ```bash
+ python inference/train_mimictalk_on_a_video.py \
+     --video_id <PATH_TO_SOURCE_VIDEO> \
+     --max_updates <UPDATES_NUMBER> \
+     --work_dir <PATH_TO_SAVING_CKPT>
+ ```
+
+ 一些可选参数注释:
+
+ - `--torso_ckpt` 预训练的Real3D-Portrait模型
+ - `--max_updates` 训练更新次数
+ - `--batch_size` 训练的batch size: `1` 需要约8GB显存; `2`需要约15GB显存
+ - `--lr_triplane` triplane的学习率:对于视频输入, 应为0.1; 对于图片输入,应为0.001
+ - `--work_dir` 未指定时,将默认存储在`checkpoints_mimictalk/`中
+
+ 指令示例:
+ ```bash
+ python inference/train_mimictalk_on_a_video.py \
+     --video_id data/raw/videos/German_20s.mp4 \
+     --max_updates 2000 \
+     --work_dir checkpoints_mimictalk/German_20s
+ ```
+
+ ## 命令行推理
+
+ 需要至少提供`driving audio`,可选提供`driving style`,推理指令:
+ ```bash
+ python inference/mimictalk_infer.py \
+     --drv_aud <PATH_TO_AUDIO> \
+     --drv_style <PATH_TO_STYLE_VIDEO, OPTIONAL> \
+     --drv_pose <PATH_TO_POSE_VIDEO, OPTIONAL> \
+     --bg_img <PATH_TO_BACKGROUND_IMAGE, OPTIONAL> \
+     --out_name <PATH_TO_OUTPUT_VIDEO, OPTIONAL>
+ ```
+
+ 一些可选参数注释:
+ - `--drv_pose` 指定时提供了运动pose信息,不指定则为静态运动
+ - `--bg_img` 指定时提供了背景信息,不指定则为source image提取的背景
+ - `--mouth_amp` 嘴部张幅参数,值越大张幅越大
+ - `--map_to_init_pose` 值为`True`时,首帧的pose将被映射到source pose,后续帧也作相同变换
+ - `--temperature` 代表audio2motion的采样温度,值越大结果越多样,但同时精确度越低
+ - `--out_name` 不指定时,结果将保存在`infer_out/tmp/`中
+ - `--out_mode` 值为`final`时,只输出说话人视频;值为`concat_debug`时,同时输出一些可视化的中间结果
+
+ 推理命令例子:
+ ```bash
+ python inference/mimictalk_infer.py \
+     --drv_aud data/raw/examples/Obama_5s.wav \
+     --drv_pose data/raw/examples/German_20s.mp4 \
+     --drv_style data/raw/examples/German_20s.mp4 \
+     --bg_img data/raw/examples/bg.png \
+     --out_name output.mp4 \
+     --out_mode final
+ ```
+
+ # 声明
+ 任何组织或个人未经本人同意,不得使用本文提及的任何技术生成他人说话的视频,包括但不限于政府领导人、政界人士、社会名流等。如不遵守本条款,则可能违反版权法。
+
+ # 引用我们
+ 如果这个仓库对你有帮助,请考虑引用我们的工作:
+ ```
+ @inproceedings{ye2024mimicktalk,
+     author    = {Ye, Zhenhui and Zhong, Tianyun and Ren, Yi and Yang, Jiaqi and Li, Weichuang and Huang, Jiangwei and Jiang, Ziyue and He, Jinzheng and Huang, Rongjie and Liu, Jinglin and Zhang, Chen and Yin, Xiang and Ma, Zejun and Zhao, Zhou},
+     title     = {MimicTalk: Mimicking a personalized and expressive 3D talking face in few minutes},
+     journal   = {NeurIPS},
+     year      = {2024},
+ }
+ @inproceedings{ye2024real3d,
+     title     = {Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis},
+     author    = {Ye, Zhenhui and Zhong, Tianyun and Ren, Yi and Yang, Jiaqi and Li, Weichuang and Huang, Jiawei and Jiang, Ziyue and He, Jinzheng and Huang, Rongjie and Liu, Jinglin and others},
+     journal   = {ICLR},
+     year      = {2024}
+ }
+ ```
README.md ADDED
@@ -0,0 +1,155 @@
+ # MimicTalk: Mimicking a personalized and expressive 3D talking face in few minutes | NeurIPS 2024
+ [![arXiv](https://img.shields.io/badge/arXiv-Paper-%3CCOLOR%3E.svg)](https://arxiv.org/abs/2401.08503)| [![GitHub Stars](https://img.shields.io/github/stars/yerfor/MimicTalk
+ )](https://github.com/yerfor/MimicTalk) | [中文文档](./README-zh.md)
+
+ This is the official PyTorch repo of MimicTalk, for training a personalized and expressive talking avatar in minutes. The code is built upon our previous work, [Real3D-Portrait](https://github.com/yerfor/Real3DPortrait) (ICLR 2024), a one-shot NeRF-based talking avatar system that enables the fast training and good quality of MimicTalk. You can visit our [Demo Page](https://mimictalk.github.io/) to watch demo videos and read our [Paper](https://arxiv.org/abs/2410.06734) for technical details.
+
+ <p align="center">
+ <br>
+ <img src="assets/mimictalk.png" width="100%"/>
+ <br>
+ </p>
+
+
+
+ # Quick Start!
+ ## Environment Installation
+ Please refer to the [Installation Guide](docs/prepare_env/install_guide.md) to prepare a Conda environment `mimictalk`.
+ ## Download Pre-trained & Third-Party Models
+ ### 3DMM BFM Model
+ Download the 3DMM BFM Model from [Google Drive](https://drive.google.com/drive/folders/1o4t5YIw7w4cMUN4bgU9nPf6IyWVG1bEk?usp=sharing) or [BaiduYun Disk](https://pan.baidu.com/s/1aqv1z_qZ23Vp2VP4uxxblQ?pwd=m9q5) with password m9q5.
+
+
+ Put all the files in `deep_3drecon/BFM`; the file structure should look like this:
+ ```
+ deep_3drecon/BFM/
+ ├── 01_MorphableModel.mat
+ ├── BFM_exp_idx.mat
+ ├── BFM_front_idx.mat
+ ├── BFM_model_front.mat
+ ├── Exp_Pca.bin
+ ├── facemodel_info.mat
+ ├── index_mp468_from_mesh35709.npy
+ └── std_exp.txt
+ ```
+
+ ### Pre-trained Real3D-Portrait & MimicTalk
+ Download the pre-trained MimicTalk checkpoints: [Google Drive](https://drive.google.com/drive/folders/1Kc6ueDO9HFDN3BhtJCEKNCZtyKHSktaA?usp=sharing) or [BaiduYun Disk](https://pan.baidu.com/s/1nQKyGV5JB6rJtda7qsThUg?pwd=mimi) with password `mimi`


+ Put the zip files in `checkpoints` & `checkpoints_mimictalk` and unzip them; the file structure should look like this:
+ ```
+ checkpoints/
+ ├── mimictalk_orig
+ │   └── os_secc2plane_torso
+ │       ├── config.yaml
+ │       └── model_ckpt_steps_100000.ckpt
+ ├── 240112_icl_audio2secc_vox2_cmlr
+ │   ├── config.yaml
+ │   └── model_ckpt_steps_1856000.ckpt
+ └── pretrained_ckpts
+     └── mit_b0.pth
+
+ checkpoints_mimictalk/
+ └── German_20s
+     ├── config.yaml
+     └── model_ckpt_steps_10000.ckpt
+ ```
+
+ ## Train & Infer MimicTalk in two lines
+ ```
+ python inference/train_mimictalk_on_a_video.py # train the model, this may take 10 minutes for 2,000 steps
+ python inference/mimictalk_infer.py # infer the model
+ ```
+
+ # Detailed options for train & infer
+ Currently, we provide **CLI** and **Gradio WebUI** for inference. We support audio-driven talking-head generation for a specific person (taken from `torso_ckpt`), so you need to prepare at least a `driving audio` for inference. Optionally, providing a `style video` enables the model to predict a talking style consistent with it.
+
+ First, switch to the project folder and activate the Conda environment:
+ ```bash
+ cd <mimictalkRoot>
+ conda activate mimictalk
+ export PYTHONPATH=./
+ export HF_ENDPOINT=https://hf-mirror.com
+ ```
+
+ ## Gradio WebUI
+ Run the Gradio WebUI demo, upload resources in the webpage, click the `Training` button to train a person-specific MimicTalk model, and then click the `Generate` button to run inference with arbitrary audio and style:
+ ```bash
+ python inference/app_mimictalk.py
+ ```
+
+ ## CLI Training on a person-specific video
+ Provide a `source video` of the specific person:
+ ```bash
+ python inference/train_mimictalk_on_a_video.py \
+     --video_id <PATH_TO_SOURCE_VIDEO> \
+     --max_updates <UPDATES_NUMBER> \
+     --work_dir <PATH_TO_SAVING_CKPT>
+ ```
+
+ Some optional training parameters:
+ - `--torso_ckpt` Pre-trained Real3D-Portrait checkpoint path
+ - `--max_updates` The number of training updates.
+ - `--batch_size` Batch size during training: `1` needs about 8GB VRAM; `2` needs about 15GB
+ - `--lr_triplane` Learning rate of the triplane: for a video, 0.1; for an image, 0.001
+ - `--work_dir` When not assigned, the results will be stored at `checkpoints_mimictalk/`.
+
+ Command-line example:
+ ```bash
+ python inference/train_mimictalk_on_a_video.py \
+     --video_id data/raw/videos/German_20s.mp4 \
+     --max_updates 2000 \
+     --work_dir checkpoints_mimictalk/German_20s
+ ```
+
+ ## CLI Inference
+
+ Provide a `driving audio` and, optionally, a `driving style`:
+ ```bash
+ python inference/mimictalk_infer.py \
+     --drv_aud <PATH_TO_AUDIO> \
+     --drv_style <PATH_TO_STYLE_VIDEO, OPTIONAL> \
+     --drv_pose <PATH_TO_POSE_VIDEO, OPTIONAL> \
+     --bg_img <PATH_TO_BACKGROUND_IMAGE, OPTIONAL> \
+     --out_name <PATH_TO_OUTPUT_VIDEO, OPTIONAL>
+ ```
+
+ Some optional inference parameters:
+ - `--drv_pose` provides motion pose information; defaults to a static pose
+ - `--bg_img` provides background information; defaults to the background extracted from the source
+ - `--map_to_init_pose` when set to `True`, the initial pose will be mapped to the source pose, and the other poses will be transformed in the same way
+ - `--temperature` the sampling temperature of audio2motion; higher gives more diverse results at the expense of lower accuracy
+ - `--out_name` When not assigned, the results will be stored at `infer_out/tmp/`.
+ - `--out_mode` When `final`, only outputs the final result; when `concat_debug`, also outputs visualizations of several intermediate results.
+
+ Command-line example:
+ ```bash
+ python inference/mimictalk_infer.py \
+     --drv_aud data/raw/examples/Obama_5s.wav \
+     --drv_pose data/raw/examples/German_20s.mp4 \
+     --drv_style data/raw/examples/German_20s.mp4 \
+     --bg_img data/raw/examples/bg.png \
+     --out_name output.mp4 \
+     --out_mode final
+ ```
+
+ # Disclaimer
+ Any organization or individual is prohibited from using any technology mentioned in this paper to generate someone's talking video without his/her consent, including but not limited to government leaders, political figures, and celebrities. If you do not comply with this item, you could be in violation of copyright laws.
+
+ # Citation
+ If you find this repo helpful to your work, please consider citing us:
+ ```
+ @inproceedings{ye2024mimicktalk,
+     author    = {Ye, Zhenhui and Zhong, Tianyun and Ren, Yi and Yang, Jiaqi and Li, Weichuang and Huang, Jiangwei and Jiang, Ziyue and He, Jinzheng and Huang, Rongjie and Liu, Jinglin and Zhang, Chen and Yin, Xiang and Ma, Zejun and Zhao, Zhou},
+     title     = {MimicTalk: Mimicking a personalized and expressive 3D talking face in few minutes},
+     journal   = {NeurIPS},
+     year      = {2024},
+ }
+ @inproceedings{ye2024real3d,
+     title     = {Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis},
+     author    = {Ye, Zhenhui and Zhong, Tianyun and Ren, Yi and Yang, Jiaqi and Li, Weichuang and Huang, Jiawei and Jiang, Ziyue and He, Jinzheng and Huang, Rongjie and Liu, Jinglin and others},
+     journal   = {ICLR},
+     year      = {2024}
+ }
+ ```
checkpoints/.gitkeep ADDED
File without changes
data_gen/eg3d/convert_to_eg3d_convention.py ADDED
@@ -0,0 +1,146 @@
+ import numpy as np
+ import torch
+ import copy
+ from utils.commons.tensor_utils import convert_to_tensor, convert_to_np
+ from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
+
+
+ def _fix_intrinsics(intrinsics):
+     """
+     intrinsics: [3,3], not batch-wise
+     """
+     # unnormalized           ->  normalized
+     # [[ f_x, s=0, x_0]          [[ f_x/size_x, s=0,        x_0/size_x=0.5]
+     #  [ 0,   f_y, y_0]    ->     [ 0,          f_y/size_y, y_0/size_y=0.5]
+     #  [ 0,   0,   1  ]]          [ 0,          0,          1             ]]
+     intrinsics = np.array(intrinsics).copy()
+     assert intrinsics.shape == (3, 3), intrinsics
+     intrinsics[0,0] = 2985.29/700
+     intrinsics[1,1] = 2985.29/700
+     intrinsics[0,2] = 1/2
+     intrinsics[1,2] = 1/2
+     assert intrinsics[0,1] == 0
+     assert intrinsics[2,2] == 1
+     assert intrinsics[1,0] == 0
+     assert intrinsics[2,0] == 0
+     assert intrinsics[2,1] == 0
+     return intrinsics
+
+ # Used in original submission
+ def _fix_pose_orig(pose):
+     """
+     pose: [4,4], not batch-wise
+     """
+     pose = np.array(pose).copy()
+     location = pose[:3, 3]
+     radius = np.linalg.norm(location)
+     pose[:3, 3] = pose[:3, 3]/radius * 2.7
+     return pose
+
+
+ def get_eg3d_convention_camera_pose_intrinsic(item):
+     """
+     item: a dict during binarize
+
+     """
+     if item['euler'].ndim == 1:
+         angle = convert_to_tensor(copy.copy(item['euler']))
+         trans = copy.deepcopy(item['trans'])
+
+         # handle the difference of euler axis between eg3d and ours
+         # see data_gen/process_ffhq_for_eg3d/transplant_eg3d_ckpt_into_our_convention.ipynb
+         # angle += torch.tensor([0, 3.1415926535, 3.1415926535], device=angle.device)
+         R = ParametricFaceModel.compute_rotation(angle.unsqueeze(0))[0].cpu().numpy()
+         trans[2] += -10
+         c = -np.dot(R, trans)
+         pose = np.eye(4)
+         pose[:3,:3] = R
+         c *= 0.27 # normalize camera radius
+         c[1] += 0.006 # additional offset used in submission
+         c[2] += 0.161 # additional offset used in submission
+         pose[0,3] = c[0]
+         pose[1,3] = c[1]
+         pose[2,3] = c[2]
+
+         focal = 2985.29 # = 1015*1024/224*(300/466.285),
+         # todo: if the camera intrinsics of the fit-3dmm stage are changed, update this accordingly
+         pp = 512  # 112
+         w = 1024  # 224
+         h = 1024  # 224
+
+         K = np.eye(3)
+         K[0][0] = focal
+         K[1][1] = focal
+         K[0][2] = w/2.0
+         K[1][2] = h/2.0
+         convention_K = _fix_intrinsics(K)
+
+         Rot = np.eye(3)
+         Rot[0, 0] = 1
+         Rot[1, 1] = -1
+         Rot[2, 2] = -1
+         pose[:3, :3] = np.dot(pose[:3, :3], Rot) # permute axes
+         convention_pose = _fix_pose_orig(pose)
+
+         item['c2w'] = pose
+         item['convention_c2w'] = convention_pose
+         item['intrinsics'] = convention_K
+         return item
+     else:
+         num_samples = len(item['euler'])
+         eulers_all = convert_to_tensor(copy.deepcopy(item['euler'])) # [B, 3]
+         trans_all = copy.deepcopy(item['trans']) # [B, 3]
+
+         # handle the difference of euler axis between eg3d and ours
+         # see data_gen/process_ffhq_for_eg3d/transplant_eg3d_ckpt_into_our_convention.ipynb
+         # eulers_all += torch.tensor([0, 3.1415926535, 3.1415926535], device=eulers_all.device).unsqueeze(0).repeat([eulers_all.shape[0],1])
+
+         intrinsics = []
+         poses = []
+         convention_poses = []
+         for i in range(num_samples):
+             angle = eulers_all[i]
+             trans = trans_all[i]
+             R = ParametricFaceModel.compute_rotation(angle.unsqueeze(0))[0].cpu().numpy()
+             trans[2] += -10
+             c = -np.dot(R, trans)
+             pose = np.eye(4)
+             pose[:3,:3] = R
+             c *= 0.27 # normalize camera radius
+             c[1] += 0.006 # additional offset used in submission
+             c[2] += 0.161 # additional offset used in submission
+             pose[0,3] = c[0]
+             pose[1,3] = c[1]
+             pose[2,3] = c[2]
+
+             focal = 2985.29 # = 1015*1024/224*(300/466.285),
+             # todo: if the camera intrinsics of the fit-3dmm stage are changed, update this accordingly
+             pp = 512  # 112
+             w = 1024  # 224
+             h = 1024  # 224
+
+             K = np.eye(3)
+             K[0][0] = focal
+             K[1][1] = focal
+             K[0][2] = w/2.0
+             K[1][2] = h/2.0
+             convention_K = _fix_intrinsics(K)
+             intrinsics.append(convention_K)
+
+             Rot = np.eye(3)
+             Rot[0, 0] = 1
+             Rot[1, 1] = -1
+             Rot[2, 2] = -1
+             pose[:3, :3] = np.dot(pose[:3, :3], Rot)
+             convention_pose = _fix_pose_orig(pose)
+             convention_poses.append(convention_pose)
+             poses.append(pose)
+
+         intrinsics = np.stack(intrinsics) # [B, 3, 3]
+         poses = np.stack(poses) # [B, 4, 4]
+         convention_poses = np.stack(convention_poses) # [B, 4, 4]
+         item['intrinsics'] = intrinsics
+         item['c2w'] = poses
+         item['convention_c2w'] = convention_poses
+         return item
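As a quick illustration of how this helper is meant to be called (this sketch is not part of the commit; the euler/trans values below are made-up placeholders, not real coefficients from `coeff_fit_mp.npy`):

```python
import numpy as np
from data_gen.eg3d.convert_to_eg3d_convention import get_eg3d_convention_camera_pose_intrinsic

item = {
    'euler': np.zeros([3], dtype=np.float32),            # [pitch, yaw, roll] of one frame (placeholder)
    'trans': np.array([0., 0., 0.1], dtype=np.float32),  # translation of one frame (placeholder)
}
item = get_eg3d_convention_camera_pose_intrinsic(item)
print(item['c2w'].shape, item['convention_c2w'].shape, item['intrinsics'].shape)
# expected: (4, 4) (4, 4) (3, 3); the batched branch returns [B, 4, 4] / [B, 3, 3] instead
```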
data_gen/runs/binarizer_nerf.py ADDED
@@ -0,0 +1,335 @@
+ import os
+ import numpy as np
+ import math
+ import json
+ import imageio
+ import torch
+ import tqdm
+ import cv2
+
+ from data_util.face3d_helper import Face3DHelper
+ from utils.commons.euler2rot import euler_trans_2_c2w, c2w_to_euler_trans
+ from data_gen.utils.process_video.euler2quaterion import euler2quaterion, quaterion2euler
+ from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
+
+
+ def euler2rot(euler_angle):
+     batch_size = euler_angle.shape[0]
+     theta = euler_angle[:, 0].reshape(-1, 1, 1)
+     phi = euler_angle[:, 1].reshape(-1, 1, 1)
+     psi = euler_angle[:, 2].reshape(-1, 1, 1)
+     one = torch.ones(batch_size, 1, 1).to(euler_angle.device)
+     zero = torch.zeros(batch_size, 1, 1).to(euler_angle.device)
+     rot_x = torch.cat((
+         torch.cat((one, zero, zero), 1),
+         torch.cat((zero, theta.cos(), theta.sin()), 1),
+         torch.cat((zero, -theta.sin(), theta.cos()), 1),
+     ), 2)
+     rot_y = torch.cat((
+         torch.cat((phi.cos(), zero, -phi.sin()), 1),
+         torch.cat((zero, one, zero), 1),
+         torch.cat((phi.sin(), zero, phi.cos()), 1),
+     ), 2)
+     rot_z = torch.cat((
+         torch.cat((psi.cos(), -psi.sin(), zero), 1),
+         torch.cat((psi.sin(), psi.cos(), zero), 1),
+         torch.cat((zero, zero, one), 1)
+     ), 2)
+     return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))
+
+
+ def rot2euler(rot_mat):
+     batch_size = len(rot_mat)
+     # we assert that y is in [-0.5pi, 0.5pi]
+     cos_y = torch.sqrt(rot_mat[:, 1, 2] * rot_mat[:, 1, 2] + rot_mat[:, 2, 2] * rot_mat[:, 2, 2])
+     theta_x = torch.atan2(-rot_mat[:, 1, 2], rot_mat[:, 2, 2])
+     theta_y = torch.atan2(rot_mat[:, 2, 0], cos_y)
+     theta_z = torch.atan2(rot_mat[:, 0, 1], rot_mat[:, 0, 0])
+     euler_angles = torch.zeros([batch_size, 3])
+     euler_angles[:, 0] = theta_x
+     euler_angles[:, 1] = theta_y
+     euler_angles[:, 2] = theta_z
+     return euler_angles
+
+ index_lm68_from_lm468 = [127,234,93,132,58,136,150,176,152,400,379,365,288,361,323,454,356,70,63,105,66,107,336,296,334,293,300,168,197,5,4,75,97,2,326,305,
+                          33,160,158,133,153,144,362,385,387,263,373,380,61,40,37,0,267,270,291,321,314,17,84,91,78,81,13,311,308,402,14,178]
+
+ def plot_lm2d(lm2d):
+     WH = 512
+     img = np.ones([WH, WH, 3], dtype=np.uint8) * 255
+
+     for i in range(len(lm2d)):
+         x, y = lm2d[i]
+         color = (255,0,0)
+         img = cv2.circle(img, center=(int(x),int(y)), radius=3, color=color, thickness=-1)
+     font = cv2.FONT_HERSHEY_SIMPLEX
+     for i in range(len(lm2d)):
+         x, y = lm2d[i]
+         img = cv2.putText(img, f"{i}", org=(int(x),int(y)), fontFace=font, fontScale=0.3, color=(255,0,0))
+     return img
+
+ def get_face_rect(lms, h, w):
+     """
+     lms: [68, 2]
+     h, w: int
+     return: [4,]
+     """
+     assert len(lms) == 68
+     # min_x, max_x = np.min(lms, 0)[0], np.max(lms, 0)[0]
+     min_x, max_x = np.min(lms[:, 0]), np.max(lms[:, 0])
+     cx = int((min_x+max_x)/2.0)
+     cy = int(lms[27, 1])
+     h_w = int((max_x-cx)*1.5)
+     h_h = int((lms[8, 1]-cy)*1.15)
+     rect_x = cx - h_w
+     rect_y = cy - h_h
+     if rect_x < 0:
+         rect_x = 0
+     if rect_y < 0:
+         rect_y = 0
+     rect_w = min(w-1-rect_x, 2*h_w)
+     rect_h = min(h-1-rect_y, 2*h_h)
+     # rect = np.array((rect_x, rect_y, rect_w, rect_h), dtype=np.int32)
+     # rect = [rect_x, rect_y, rect_w, rect_h]
+     rect = [rect_x, rect_x + rect_w, rect_y, rect_y + rect_h] # min_j, max_j, min_i, max_i
+     return rect # this x is width, y is height
+
+ def get_lip_rect(lms, h, w):
+     """
+     lms: [68, 2]
+     h, w: int
+     return: [4,]
+     """
+     # this x is width, y is height
+     # for lms, lms[:, 0] is width, lms[:, 1] is height
+     assert len(lms) == 68
+     lips = slice(48, 60)
+     lms = lms[lips]
+     min_x, max_x = np.min(lms[:, 0]), np.max(lms[:, 0])
+     min_y, max_y = np.min(lms[:, 1]), np.max(lms[:, 1])
+     cx = int((min_x+max_x)/2.0)
+     cy = int((min_y+max_y)/2.0)
+     h_w = int((max_x-cx)*1.2)
+     h_h = int((max_y-cy)*1.2)
+
+     h_w = max(h_w, h_h)
+     h_h = h_w
+
+     rect_x = cx - h_w
+     rect_y = cy - h_h
+     rect_w = 2*h_w
+     rect_h = 2*h_h
+     if rect_x < 0:
+         rect_x = 0
+     if rect_y < 0:
+         rect_y = 0
+
+     if rect_x + rect_w > w:
+         rect_x = w - rect_w
+     if rect_y + rect_h > h:
+         rect_y = h - rect_h
+
+     rect = [rect_x, rect_x + rect_w, rect_y, rect_y + rect_h] # min_j, max_j, min_i, max_i
+     return rect # this x is width, y is height
+
+
+ # def get_lip_rect(lms, h, w):
+ #     """
+ #     lms: [68, 2]
+ #     h, w: int
+ #     return: [4,]
+ #     """
+ #     assert len(lms) == 68
+ #     lips = slice(48, 60)
+ #     # this x is width, y is height
+ #     xmin, xmax = int(lms[lips, 1].min()), int(lms[lips, 1].max())
+ #     ymin, ymax = int(lms[lips, 0].min()), int(lms[lips, 0].max())
+ #     # padding to H == W
+ #     cx = (xmin + xmax) // 2
+ #     cy = (ymin + ymax) // 2
+ #     l = max(xmax - xmin, ymax - ymin) // 2
+ #     xmin = max(0, cx - l)
+ #     xmax = min(h, cx + l)
+ #     ymin = max(0, cy - l)
+ #     ymax = min(w, cy + l)
+ #     lip_rect = [xmin, xmax, ymin, ymax]
+ #     return lip_rect
+
+ def get_win_conds(conds, idx, smo_win_size=8, pad_option='zero'):
+     """
+     conds: [b, t=16, h=29]
+     idx: long, time index of the selected frame
+     """
+     idx = max(0, idx)
+     idx = min(idx, conds.shape[0]-1)
+     smo_half_win_size = smo_win_size//2
+     left_i = idx - smo_half_win_size
+     right_i = idx + (smo_win_size - smo_half_win_size)
+     pad_left, pad_right = 0, 0
+     if left_i < 0:
+         pad_left = -left_i
+         left_i = 0
+     if right_i > conds.shape[0]:
+         pad_right = right_i - conds.shape[0]
+         right_i = conds.shape[0]
+     conds_win = conds[left_i:right_i]
+     if pad_left > 0:
+         if pad_option == 'zero':
+             conds_win = np.concatenate([np.zeros_like(conds_win)[:pad_left], conds_win], axis=0)
+         elif pad_option == 'edge':
+             edge_value = conds[0][np.newaxis, ...]
+             conds_win = np.concatenate([edge_value] * pad_left + [conds_win], axis=0)
+         else:
+             raise NotImplementedError
+     if pad_right > 0:
+         if pad_option == 'zero':
+             conds_win = np.concatenate([conds_win, np.zeros_like(conds_win)[:pad_right]], axis=0)
+         elif pad_option == 'edge':
+             edge_value = conds[-1][np.newaxis, ...]
+             conds_win = np.concatenate([conds_win] + [edge_value] * pad_right , axis=0)
+         else:
+             raise NotImplementedError
+     assert conds_win.shape[0] == smo_win_size
+     return conds_win
+
+
+ def load_processed_data(processed_dir):
+     # load necessary files
+     background_img_name = os.path.join(processed_dir, "bg.jpg")
+     assert os.path.exists(background_img_name)
+     head_img_dir = os.path.join(processed_dir, "head_imgs")
+     torso_img_dir = os.path.join(processed_dir, "inpaint_torso_imgs")
+     gt_img_dir = os.path.join(processed_dir, "gt_imgs")
+
+     hubert_npy_name = os.path.join(processed_dir, "aud_hubert.npy")
+     mel_f0_npy_name = os.path.join(processed_dir, "aud_mel_f0.npy")
+     coeff_npy_name = os.path.join(processed_dir, "coeff_fit_mp.npy")
+     lm2d_npy_name = os.path.join(processed_dir, "lms_2d.npy")
+
+     ret_dict = {}
+
+     ret_dict['bg_img'] = imageio.imread(background_img_name)
+     ret_dict['H'], ret_dict['W'] = ret_dict['bg_img'].shape[:2]
+     ret_dict['focal'], ret_dict['cx'], ret_dict['cy'] = face_model.focal, face_model.center, face_model.center
+
+     print("loading lm2d coeff ...")
+     lm2d_arr = np.load(lm2d_npy_name)
+     face_rect_lst = []
+     lip_rect_lst = []
+     for lm2d in lm2d_arr:
+         if len(lm2d) in [468, 478]:
+             lm2d = lm2d[index_lm68_from_lm468]
+         face_rect = get_face_rect(lm2d, ret_dict['H'], ret_dict['W'])
+         lip_rect = get_lip_rect(lm2d, ret_dict['H'], ret_dict['W'])
+         face_rect_lst.append(face_rect)
+         lip_rect_lst.append(lip_rect)
+     face_rects = np.stack(face_rect_lst, axis=0) # [T, 4]
+
+     print("loading fitted 3dmm coeff ...")
+     coeff_dict = np.load(coeff_npy_name, allow_pickle=True).tolist()
+     identity_arr = coeff_dict['id']
+     exp_arr = coeff_dict['exp']
+     ret_dict['id'] = identity_arr
+     ret_dict['exp'] = exp_arr
+     euler_arr = ret_dict['euler'] = coeff_dict['euler']
+     trans_arr = ret_dict['trans'] = coeff_dict['trans']
+     print("calculating lm3d ...")
+     idexp_lm3d_arr = face3d_helper.reconstruct_idexp_lm3d(torch.from_numpy(identity_arr), torch.from_numpy(exp_arr)).cpu().numpy().reshape([-1, 68*3])
+     len_motion = len(idexp_lm3d_arr)
+     video_idexp_lm3d_mean = idexp_lm3d_arr.mean(axis=0)
+     video_idexp_lm3d_std = idexp_lm3d_arr.std(axis=0)
+     ret_dict['idexp_lm3d'] = idexp_lm3d_arr
+     ret_dict['idexp_lm3d_mean'] = video_idexp_lm3d_mean
+     ret_dict['idexp_lm3d_std'] = video_idexp_lm3d_std
+
+     # now we convert the euler_trans from deep3d convention to adnerf convention
+     eulers = torch.FloatTensor(euler_arr)
+     trans = torch.FloatTensor(trans_arr)
+     rots = face_model.compute_rotation(eulers) # rotation matrix is a better intermediate for convention transplant than euler
+
+     # handle the camera pose to geneface's convention
+     trans[:, 2] = 10 - trans[:, 2] # undo the to_camera op in the fitting stage, i.e. trans[...,2] = 10 - trans[...,2]
+     rots = rots.permute(0, 2, 1)
+     trans[:, 2] = - trans[:,2] # because the intrinsic projection differs
+     # below is the NeRF camera preprocessing strategy, see `save_transforms` in data_util/process.py
+     trans = trans / 10.0
+     rots_inv = rots.permute(0, 2, 1)
+     trans_inv = - torch.bmm(rots_inv, trans.unsqueeze(2))
+
+     pose = torch.eye(4, dtype=torch.float32).unsqueeze(0).repeat([len_motion, 1, 1]) # [T, 4, 4]
+     pose[:, :3, :3] = rots_inv
+     pose[:, :3, 3] = trans_inv[:, :, 0]
+     c2w_transform_matrices = pose.numpy()
+
+     # process the audio features used for postnet training
+     print("loading hubert ...")
+     hubert_features = np.load(hubert_npy_name)
+     print("loading Mel and F0 ...")
+     mel_f0_features = np.load(mel_f0_npy_name, allow_pickle=True).tolist()
+
+     ret_dict['hubert'] = hubert_features
+     ret_dict['mel'] = mel_f0_features['mel']
+     ret_dict['f0'] = mel_f0_features['f0']
+
+     # obtaining train samples
+     frame_indices = list(range(len_motion))
+     num_train = len_motion // 11 * 10
+     train_indices = frame_indices[:num_train]
+     val_indices = frame_indices[num_train:]
+
+     for split in ['train', 'val']:
+         if split == 'train':
+             indices = train_indices
+             samples = []
+             ret_dict['train_samples'] = samples
+         elif split == 'val':
+             indices = val_indices
+             samples = []
+             ret_dict['val_samples'] = samples
+
+         for idx in indices:
+             sample = {}
+             sample['idx'] = idx
+             sample['head_img_fname'] = os.path.join(head_img_dir,f"{idx:08d}.png")
+             sample['torso_img_fname'] = os.path.join(torso_img_dir,f"{idx:08d}.png")
+             sample['gt_img_fname'] = os.path.join(gt_img_dir,f"{idx:08d}.jpg")
+             # assert os.path.exists(sample['head_img_fname']) and os.path.exists(sample['torso_img_fname']) and os.path.exists(sample['gt_img_fname'])
+             sample['face_rect'] = face_rects[idx]
+             sample['lip_rect'] = lip_rect_lst[idx]
+             sample['c2w'] = c2w_transform_matrices[idx]
+             samples.append(sample)
+     return ret_dict
+
+
+ class Binarizer:
+     def __init__(self):
+         self.data_dir = 'data/'
+
+     def parse(self, video_id):
+         processed_dir = os.path.join(self.data_dir, 'processed/videos', video_id)
+         binary_dir = os.path.join(self.data_dir, 'binary/videos', video_id)
+         out_fname = os.path.join(binary_dir, "trainval_dataset.npy")
+         os.makedirs(binary_dir, exist_ok=True)
+         ret = load_processed_data(processed_dir)
+         mel_name = os.path.join(processed_dir, 'aud_mel_f0.npy')
+         mel_f0_dict = np.load(mel_name, allow_pickle=True).tolist()
+         ret.update(mel_f0_dict)
+         np.save(out_fname, ret, allow_pickle=True)
+
+
+
+ if __name__ == '__main__':
+     from argparse import ArgumentParser
+     parser = ArgumentParser()
+     parser.add_argument('--video_id', type=str, default='May', help='')
+     args = parser.parse_args()
+     ### Process Single Long Audio for NeRF dataset
+     video_id = args.video_id
+     face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
+                                      camera_distance=10, focal=1015)
+     face_model.to("cpu")
+     face3d_helper = Face3DHelper()
+
+     binarizer = Binarizer()
+     binarizer.parse(video_id)
+     print(f"Binarization for {video_id} Done!")
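For context on the windowing helper above, here is a small sanity-check sketch (not part of the commit) showing how `get_win_conds` pads a smoothing window that starts before the first frame:

```python
import numpy as np
from data_gen.runs.binarizer_nerf import get_win_conds

conds = np.arange(20, dtype=np.float32).reshape(10, 2)  # toy condition track, [T=10, h=2]
win = get_win_conds(conds, idx=0, smo_win_size=8, pad_option='edge')
print(win.shape)  # (8, 2)
print(win[:4])    # the first 4 rows repeat conds[0], since the window starts at t=-4
```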
data_gen/runs/binarizer_th1kh.py ADDED
@@ -0,0 +1,100 @@
+ import os
+ import numpy as np
+ import torch
+ from tqdm import trange
+ import pickle
+ from copy import deepcopy
+
+ from data_util.face3d_helper import Face3DHelper
+ from utils.commons.indexed_datasets import IndexedDataset, IndexedDatasetBuilder
+
+
+ def load_video_npy(fn):
+     assert fn.endswith("_coeff_fit_mp.npy")
+     ret_dict = np.load(fn, allow_pickle=True).item()
+     video_dict = {
+         'euler': ret_dict['euler'], # [T, 3]
+         'trans': ret_dict['trans'], # [T, 3]
+         'id': ret_dict['id'], # [T, 80]
+         'exp': ret_dict['exp'], # [T, 64]
+     }
+     return video_dict
+
+ def cal_lm3d_in_video_dict(video_dict, face3d_helper):
+     identity = video_dict['id']
+     exp = video_dict['exp']
+     idexp_lm3d = face3d_helper.reconstruct_idexp_lm3d(identity, exp).cpu().numpy()
+     video_dict['idexp_lm3d'] = idexp_lm3d
+
+
+ def load_audio_npy(fn):
+     assert fn.endswith(".npy")
+     ret_dict = np.load(fn, allow_pickle=True).item()
+     audio_dict = {
+         "mel": ret_dict['mel'], # [T, 80]
+         "f0": ret_dict['f0'], # [T,1]
+     }
+     return audio_dict
+
+
+ if __name__ == '__main__':
+     face3d_helper = Face3DHelper(use_gpu=False)
+
+     import glob, tqdm
+     prefixs = ['val', 'train']
+     binarized_ds_path = "data/binary/th1kh"
+     os.makedirs(binarized_ds_path, exist_ok=True)
+     for prefix in prefixs:
+         databuilder = IndexedDatasetBuilder(os.path.join(binarized_ds_path, prefix), gzip=False, default_idx_size=1024*1024*1024*2)
+         raw_base_dir = '/mnt/bn/ailabrenyi/entries/yezhenhui/datasets/raw/TH1KH_512/video'
+         mp4_names = glob.glob(os.path.join(raw_base_dir, '*.mp4'))
+         mp4_names = mp4_names[:1000]
+         cnt = 0
+         scnt = 0
+         pbar = tqdm.tqdm(enumerate(mp4_names), total=len(mp4_names))
+         for i, mp4_name in pbar:
+             cnt += 1
+             if prefix == 'train':
+                 if i % 100 == 0:
+                     continue
+             else:
+                 if i % 100 != 0:
+                     continue
+             hubert_npy_name = mp4_name.replace("/video/", "/hubert/").replace(".mp4", "_hubert.npy")
+             audio_npy_name = mp4_name.replace("/video/", "/mel_f0/").replace(".mp4", "_mel_f0.npy")
+             video_npy_name = mp4_name.replace("/video/", "/coeff_fit_mp/").replace(".mp4", "_coeff_fit_mp.npy")
+             if not os.path.exists(audio_npy_name):
+                 print(f"Skip {mp4_name}: audio npy not found.")
+                 continue
+             if not os.path.exists(video_npy_name):
+                 print(f"Skip {mp4_name}: video npy not found.")
+                 continue
+             if not os.path.exists(hubert_npy_name):
+                 print(f"Skip {mp4_name}: hubert npy not found.")
+                 continue
+             audio_dict = load_audio_npy(audio_npy_name)
+             hubert = np.load(hubert_npy_name)
+             video_dict = load_video_npy(video_npy_name)
+             com_img_dir = mp4_name.replace("/video/", "/com_imgs/").replace(".mp4", "")
+             num_com_imgs = len(glob.glob(os.path.join(com_img_dir, '*')))
+             num_frames = len(video_dict['exp'])
+             if num_com_imgs != num_frames:
+                 print(f"Skip {mp4_name}: length mismatch.")
+                 continue
+             mel = audio_dict['mel']
+             if mel.shape[0] < 32: # the video is shorter than 0.6s
+                 print(f"Skip {mp4_name}: too short.")
+                 continue
+
+             audio_dict.update(video_dict)
+             audio_dict['item_id'] = os.path.basename(mp4_name)[:-4]
+             audio_dict['hubert'] = hubert # [T_x, hid=1024]
+             audio_dict['img_dir'] = com_img_dir
+
+
+             databuilder.add_item(audio_dict)
+             scnt += 1
+             pbar.set_postfix({'success': scnt, 'success rate': scnt / cnt})
+         databuilder.finalize()
+         print(f"{prefix} set has {scnt} samples!")
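To spot-check the binarized TH1KH set afterwards, something along these lines should work; this is a sketch, not code from the commit, and it assumes `IndexedDataset` mirrors the builder interface used above (constructor takes the same path prefix, items come back as the dicts that were added):

```python
import os
from utils.commons.indexed_datasets import IndexedDataset  # read-side interface assumed

ds = IndexedDataset(os.path.join("data/binary/th1kh", "train"))
item = ds[0]  # one of the dicts assembled in the loop above
print(item['item_id'], item['mel'].shape, item['hubert'].shape, item['exp'].shape)
```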
data_gen/runs/nerf/process_guide.md ADDED
@@ -0,0 +1,49 @@
+ # Tip: for the first run, step through the commands below one by one to make sure the environment works; after that, you can simply run run.sh in the same directory to complete all the steps below in one go.
+
+ # Step 0. Crop the video to 512x512 resolution at 25 FPS, and make sure the target face is present in every frame
+ ```
+ ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 data/raw/videos/${VIDEO_ID}_512.mp4
+ mv data/raw/videos/${VIDEO_ID}.mp4 data/raw/videos/${VIDEO_ID}_to_rm.mp4
+ mv data/raw/videos/${VIDEO_ID}_512.mp4 data/raw/videos/${VIDEO_ID}.mp4
+ ```
+ # Step 1. Extract audio features, such as mel, f0, hubert, esperanto
+ ```
+ export CUDA_VISIBLE_DEVICES=0
+ export VIDEO_ID=May
+ mkdir -p data/processed/videos/${VIDEO_ID}
+ ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -f wav -ar 16000 data/processed/videos/${VIDEO_ID}/aud.wav
+ python data_gen/utils/process_audio/extract_hubert.py --video_id=${VIDEO_ID}
+ python data_gen/utils/process_audio/extract_mel_f0.py --video_id=${VIDEO_ID}
+ ```
+
+ # Step 2. Extract images
+ ```
+ export VIDEO_ID=May
+ export CUDA_VISIBLE_DEVICES=0
+ mkdir -p data/processed/videos/${VIDEO_ID}/gt_imgs
+ ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -start_number 0 data/processed/videos/${VIDEO_ID}/gt_imgs/%08d.jpg
+ python data_gen/utils/process_video/extract_segment_imgs.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 # extract image, segmap, and background
+ ```
+
+ # Step 3. Extract lm2d_mediapipe
+ ### Extract 2D landmarks for the later 3DMM fitting
+ ### num_workers is the number of CPU workers on this machine; total_process is the number of machines used; process_id is the index of this machine
+
+ ```
+ export VIDEO_ID=May
+ python data_gen/utils/process_video/extract_lm2d.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4
+ ```
+
+ # Step 4. Fit 3DMM
+ ```
+ export VIDEO_ID=May
+ export CUDA_VISIBLE_DEVICES=0
+ python data_gen/utils/process_video/fit_3dmm_landmark.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 --reset --debug --id_mode=global
+ ```
+
+ # Step 5. Binarize
+ ```
+ export VIDEO_ID=May
+ python data_gen/runs/binarizer_nerf.py --video_id=${VIDEO_ID}
+ ```
+ The binarized dataset can then be found under `data/binary/videos/May`.
data_gen/runs/nerf/run.sh ADDED
@@ -0,0 +1,51 @@
+ # usage: CUDA_VISIBLE_DEVICES=0 bash data_gen/runs/nerf/run.sh <VIDEO_ID>
+ # please place the video at data/raw/videos/${VIDEO_ID}.mp4
+ VIDEO_ID=$1
+ echo Processing $VIDEO_ID
+
+ echo Resizing the video to 512x512
+ ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -y data/raw/videos/${VIDEO_ID}_512.mp4
+ mv data/raw/videos/${VIDEO_ID}.mp4 data/raw/videos/${VIDEO_ID}_to_rm.mp4
+ mv data/raw/videos/${VIDEO_ID}_512.mp4 data/raw/videos/${VIDEO_ID}.mp4
+ echo Done
+ echo The old video is moved to data/raw/videos/${VIDEO_ID}_to_rm.mp4
+
+ echo mkdir -p data/processed/videos/${VIDEO_ID}
+ mkdir -p data/processed/videos/${VIDEO_ID}
+ echo Done
+
+ # extract audio file from the training video
+ echo ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -f wav -ar 16000 -v quiet -y data/processed/videos/${VIDEO_ID}/aud.wav
+ ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -f wav -ar 16000 -v quiet -y data/processed/videos/${VIDEO_ID}/aud.wav
+ echo Done
+
+ # extract hubert_mel_f0 from audio
+ echo python data_gen/utils/process_audio/extract_hubert.py --video_id=${VIDEO_ID}
+ python data_gen/utils/process_audio/extract_hubert.py --video_id=${VIDEO_ID}
+ echo python data_gen/utils/process_audio/extract_mel_f0.py --video_id=${VIDEO_ID}
+ python data_gen/utils/process_audio/extract_mel_f0.py --video_id=${VIDEO_ID}
+ echo Done
+
+ # extract segment images
+ echo mkdir -p data/processed/videos/${VIDEO_ID}/gt_imgs
+ mkdir -p data/processed/videos/${VIDEO_ID}/gt_imgs
+ echo ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -start_number 0 -v quiet data/processed/videos/${VIDEO_ID}/gt_imgs/%08d.jpg
+ ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -start_number 0 -v quiet data/processed/videos/${VIDEO_ID}/gt_imgs/%08d.jpg
+ echo Done
+
+ echo python data_gen/utils/process_video/extract_segment_imgs.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 # extract image, segmap, and background
+ python data_gen/utils/process_video/extract_segment_imgs.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 # extract image, segmap, and background
+ echo Done
+
+ echo python data_gen/utils/process_video/extract_lm2d.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4
+ python data_gen/utils/process_video/extract_lm2d.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4
+ echo Done
+
+ pkill -f void*
+ echo python data_gen/utils/process_video/fit_3dmm_landmark.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 --reset --debug --id_mode=global
+ python data_gen/utils/process_video/fit_3dmm_landmark.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 --reset --debug --id_mode=global
+ echo Done
+
+ echo python data_gen/runs/binarizer_nerf.py --video_id=${VIDEO_ID}
+ python data_gen/runs/binarizer_nerf.py --video_id=${VIDEO_ID}
+ echo Done
data_gen/utils/mp_feature_extractors/face_landmarker.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import mediapipe as mp
2
+ from mediapipe.tasks import python
3
+ from mediapipe.tasks.python import vision
4
+ import numpy as np
5
+ import cv2
6
+ import os
7
+ import copy
8
+
9
+ # simplified mediapipe ldm at https://github.com/k-m-irfan/simplified_mediapipe_face_landmarks
10
+ index_lm141_from_lm478 = [70,63,105,66,107,55,65,52,53,46] + [300,293,334,296,336,285,295,282,283,276] + [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249] + [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95] + [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146] + [10,338,297,332,284,251,389,356,454,323,361,288,397,365,379,378,400,377,152,148,176,149,150,136,172,58,132,93,234,127,162,21,54,103,67,109] + [468,469,470,471,472] + [473,474,475,476,477] + [64,4,294]
11
+ # lm141 without iris
12
+ index_lm131_from_lm478 = [70,63,105,66,107,55,65,52,53,46] + [300,293,334,296,336,285,295,282,283,276] + [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249] + [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95] + [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146] + [10,338,297,332,284,251,389,356,454,323,361,288,397,365,379,378,400,377,152,148,176,149,150,136,172,58,132,93,234,127,162,21,54,103,67,109] + [64,4,294]
13
+
14
+ # face alignment lm68
15
+ index_lm68_from_lm478 = [127,234,93,132,58,136,150,176,152,400,379,365,288,361,323,454,356,70,63,105,66,107,336,296,334,293,300,168,197,5,4,75,97,2,326,305,
16
+ 33,160,158,133,153,144,362,385,387,263,373,380,61,40,37,0,267,270,291,321,314,17,84,91,78,81,13,311,308,402,14,178]
17
+ # used for weights for key parts
18
+ unmatch_mask_from_lm478 = [ 93, 127, 132, 234, 323, 356, 361, 454]
19
+ index_eye_from_lm478 = [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249]
20
+ index_innerlip_from_lm478 = [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95]
21
+ index_outerlip_from_lm478 = [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146]
22
+ index_withinmouth_from_lm478 = [76, 62] + [184, 183, 74, 72, 73, 41, 72, 38, 11, 12, 302, 268, 303, 271, 304, 272, 408, 407] + [292, 306] + [325, 307, 319, 320, 403, 404, 316, 315, 15, 16, 86, 85, 179, 180, 89, 90, 96, 77]
23
+ index_mouth_from_lm478 = index_innerlip_from_lm478 + index_outerlip_from_lm478 + index_withinmouth_from_lm478
24
+
25
+ index_yaw_from_lm68 = list(range(0, 17))
26
+ index_brow_from_lm68 = list(range(17, 27))
27
+ index_nose_from_lm68 = list(range(27, 36))
28
+ index_eye_from_lm68 = list(range(36, 48))
29
+ index_mouth_from_lm68 = list(range(48, 68))
30
+
31
+
32
+ def read_video_to_frames(video_name):
33
+ frames = []
34
+ cap = cv2.VideoCapture(video_name)
35
+ while cap.isOpened():
36
+ ret, frame_bgr = cap.read()
37
+ if frame_bgr is None:
38
+ break
39
+ frames.append(frame_bgr)
40
+ frames = np.stack(frames)
41
+ frames = np.flip(frames, -1) # BGR ==> RGB
42
+ return frames
43
+
44
+ class MediapipeLandmarker:
45
+ def __init__(self):
46
+ model_path = 'data_gen/utils/mp_feature_extractors/face_landmarker.task'
47
+ if not os.path.exists(model_path):
48
+ os.makedirs(os.path.dirname(model_path), exist_ok=True)
49
+ print("downloading face_landmarker model from mediapipe...")
50
+ model_url = 'https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/latest/face_landmarker.task'
51
+ os.system(f"wget {model_url}")
52
+ os.system(f"mv face_landmarker.task {model_path}")
53
+ print("download success")
54
+ base_options = python.BaseOptions(model_asset_path=model_path)
55
+ self.image_mode_options = vision.FaceLandmarkerOptions(base_options=base_options,
56
+ running_mode=vision.RunningMode.IMAGE, # IMAGE, VIDEO, LIVE_STREAM
57
+ num_faces=1)
58
+ self.video_mode_options = vision.FaceLandmarkerOptions(base_options=base_options,
59
+ running_mode=vision.RunningMode.VIDEO, # IMAGE, VIDEO, LIVE_STREAM
60
+ num_faces=1)
61
+
62
+ def extract_lm478_from_img_name(self, img_name):
63
+ img = cv2.imread(img_name)
64
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
65
+ img_lm478 = self.extract_lm478_from_img(img)
66
+ return img_lm478
67
+
68
+ def extract_lm478_from_img(self, img):
69
+ img_landmarker = vision.FaceLandmarker.create_from_options(self.image_mode_options)
70
+ frame = mp.Image(image_format=mp.ImageFormat.SRGB, data=img.astype(np.uint8))
71
+ img_face_landmarker_result = img_landmarker.detect(image=frame)
72
+ img_ldm_i = img_face_landmarker_result.face_landmarks[0]
73
+ img_face_landmarks = np.array([[l.x, l.y, l.z] for l in img_ldm_i])
74
+ H, W, _ = img.shape
75
+ img_lm478 = np.array(img_face_landmarks)[:, :2] * np.array([W, H]).reshape([1,2]) # [478, 2]
76
+ return img_lm478
77
+
78
+ def extract_lm478_from_video_name(self, video_name, fps=25, anti_smooth_factor=2):
79
+ frames = read_video_to_frames(video_name)
80
+ img_lm478, vid_lm478 = self.extract_lm478_from_frames(frames, fps, anti_smooth_factor)
81
+ return img_lm478, vid_lm478
82
+
83
+ def extract_lm478_from_frames(self, frames, fps=25, anti_smooth_factor=20):
84
+ """
85
+ frames: RGB, uint8
86
+ anti_smooth_factor: float, scales the timestamp interval used in video mode; 1 means unchanged, larger values behave closer to image mode
87
+ """
88
+ img_mpldms = []
89
+ vid_mpldms = []
90
+ img_landmarker = vision.FaceLandmarker.create_from_options(self.image_mode_options)
91
+ vid_landmarker = vision.FaceLandmarker.create_from_options(self.video_mode_options)
92
+
93
+ for i in range(len(frames)):
94
+ frame = mp.Image(image_format=mp.ImageFormat.SRGB, data=frames[i].astype(np.uint8))
95
+ img_face_landmarker_result = img_landmarker.detect(image=frame)
96
+ vid_face_landmarker_result = vid_landmarker.detect_for_video(image=frame, timestamp_ms=int((1000/fps)*anti_smooth_factor*i))
97
+ try:
98
+ img_ldm_i = img_face_landmarker_result.face_landmarks[0]
99
+ vid_ldm_i = vid_face_landmarker_result.face_landmarks[0]
100
+ except Exception:
101
+ print(f"Warning: failed to detect landmarks at idx={i}, reusing the previous frame's results.")
102
+ img_face_landmarks = np.array([[l.x, l.y, l.z] for l in img_ldm_i])
103
+ vid_face_landmarks = np.array([[l.x, l.y, l.z] for l in vid_ldm_i])
104
+ img_mpldms.append(img_face_landmarks)
105
+ vid_mpldms.append(vid_face_landmarks)
106
+ img_lm478 = np.stack(img_mpldms)[..., :2]
107
+ vid_lm478 = np.stack(vid_mpldms)[..., :2]
108
+ bs, H, W, _ = frames.shape
109
+ img_lm478 = np.array(img_lm478)[..., :2] * np.array([W, H]).reshape([1,1,2]) # [T, 478, 2]
110
+ vid_lm478 = np.array(vid_lm478)[..., :2] * np.array([W, H]).reshape([1,1,2]) # [T, 478, 2]
111
+ return img_lm478, vid_lm478
112
+
113
+ def combine_vid_img_lm478_to_lm68(self, img_lm478, vid_lm478):
114
+ img_lm68 = img_lm478[:, index_lm68_from_lm478]
115
+ vid_lm68 = vid_lm478[:, index_lm68_from_lm478]
116
+ combined_lm68 = copy.deepcopy(img_lm68)
117
+ combined_lm68[:, index_yaw_from_lm68] = vid_lm68[:, index_yaw_from_lm68]
118
+ combined_lm68[:, index_brow_from_lm68] = vid_lm68[:, index_brow_from_lm68]
119
+ combined_lm68[:, index_nose_from_lm68] = vid_lm68[:, index_nose_from_lm68]
120
+ return combined_lm68
121
+
122
+ def combine_vid_img_lm478_to_lm478(self, img_lm478, vid_lm478):
123
+ combined_lm478 = copy.deepcopy(vid_lm478)
124
+ combined_lm478[:, index_mouth_from_lm478] = img_lm478[:, index_mouth_from_lm478]
125
+ combined_lm478[:, index_eye_from_lm478] = img_lm478[:, index_eye_from_lm478]
126
+ return combined_lm478
127
+
128
+ if __name__ == '__main__':
129
+ landmarker = MediapipeLandmarker()
130
+ ret = landmarker.extract_lm478_from_video_name("00000.mp4")
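Reviewer note, not part of the commit: a minimal usage sketch for the landmarker above. The clip name demo.mp4 is a placeholder; everything else is the API defined in this file.

    from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker

    landmarker = MediapipeLandmarker()
    # image mode re-detects every frame (sharper but jittery); video mode is temporally smoothed
    img_lm478, vid_lm478 = landmarker.extract_lm478_from_video_name("demo.mp4", fps=25)
    # fuse them: smoothed jaw/brow/nose from video mode, per-frame eyes/mouth from image mode
    lm478 = landmarker.combine_vid_img_lm478_to_lm478(img_lm478, vid_lm478)
    lm68 = landmarker.combine_vid_img_lm478_to_lm68(img_lm478, vid_lm478)
    print(lm478.shape, lm68.shape)  # (T, 478, 2) and (T, 68, 2), in pixel coordinates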
data_gen/utils/mp_feature_extractors/face_landmarker.task ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64184e229b263107bc2b804c6625db1341ff2bb731874b0bcc2fe6544e0bc9ff
3
+ size 3758596
data_gen/utils/mp_feature_extractors/mp_segmenter.py ADDED
@@ -0,0 +1,303 @@
1
+ import os
2
+ import copy
3
+ import numpy as np
4
+ import tqdm
5
+ import mediapipe as mp
6
+ import torch
7
+ from mediapipe.tasks import python
8
+ from mediapipe.tasks.python import vision
9
+ from utils.commons.multiprocess_utils import multiprocess_run_tqdm, multiprocess_run
10
+ from utils.commons.tensor_utils import convert_to_np
11
+ from sklearn.neighbors import NearestNeighbors
12
+
13
+ def scatter_np(condition_img, classSeg=5):
14
+ # def scatter(condition_img, classSeg=19, label_size=(512, 512)):
15
+ batch, c, height, width = condition_img.shape
16
+ # if height != label_size[0] or width != label_size[1]:
17
+ # condition_img= F.interpolate(condition_img, size=label_size, mode='nearest')
18
+ input_label = np.zeros([batch, classSeg, condition_img.shape[2], condition_img.shape[3]]).astype(np.int_)
19
+ # input_label = torch.zeros(batch, classSeg, *label_size, device=condition_img.device)
20
+ np.put_along_axis(input_label, condition_img, 1, 1)
21
+ return input_label
22
+
23
+ def scatter(condition_img, classSeg=19):
24
+ # def scatter(condition_img, classSeg=19, label_size=(512, 512)):
25
+ batch, c, height, width = condition_img.size()
26
+ # if height != label_size[0] or width != label_size[1]:
27
+ # condition_img= F.interpolate(condition_img, size=label_size, mode='nearest')
28
+ input_label = torch.zeros(batch, classSeg, condition_img.shape[2], condition_img.shape[3], device=condition_img.device)
29
+ # input_label = torch.zeros(batch, classSeg, *label_size, device=condition_img.device)
30
+ return input_label.scatter_(1, condition_img.long(), 1)
31
+
32
+ def encode_segmap_mask_to_image(segmap):
33
+ # rgb
34
+ _,h,w = segmap.shape
35
+ encoded_img = np.ones([h,w,3],dtype=np.uint8) * 255
36
+ colors = [(255,255,255),(255,255,0),(255,0,255),(0,255,255),(255,0,0),(0,255,0)]
37
+ for i, color in enumerate(colors):
38
+ mask = segmap[i].astype(int)
39
+ index = np.where(mask != 0)
40
+ encoded_img[index[0], index[1], :] = np.array(color)
41
+ return encoded_img.astype(np.uint8)
42
+
43
+ def decode_segmap_mask_from_image(encoded_img):
44
+ # rgb
45
+ colors = [(255,255,255),(255,255,0),(255,0,255),(0,255,255),(255,0,0),(0,255,0)]
46
+ bg = (encoded_img[..., 0] == 255) & (encoded_img[..., 1] == 255) & (encoded_img[..., 2] == 255)
47
+ hair = (encoded_img[..., 0] == 255) & (encoded_img[..., 1] == 255) & (encoded_img[..., 2] == 0)
48
+ body_skin = (encoded_img[..., 0] == 255) & (encoded_img[..., 1] == 0) & (encoded_img[..., 2] == 255)
49
+ face_skin = (encoded_img[..., 0] == 0) & (encoded_img[..., 1] == 255) & (encoded_img[..., 2] == 255)
50
+ clothes = (encoded_img[..., 0] == 255) & (encoded_img[..., 1] == 0) & (encoded_img[..., 2] == 0)
51
+ others = (encoded_img[..., 0] == 0) & (encoded_img[..., 1] == 255) & (encoded_img[..., 2] == 0)
52
+ segmap = np.stack([bg, hair, body_skin, face_skin, clothes, others], axis=0)
53
+ return segmap.astype(np.uint8)
54
+
55
+ def read_video_frame(video_name, frame_id):
56
+ # https://blog.csdn.net/bby1987/article/details/108923361
57
+ # frame_num = video_capture.get(cv2.CAP_PROP_FRAME_COUNT) # ==> total frame count
58
+ # fps = video_capture.get(cv2.CAP_PROP_FPS) # ==> frame rate
59
+ # width = video_capture.get(cv2.CAP_PROP_FRAME_WIDTH) # ==> video width
60
+ # height = video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT) # ==> video height
61
+ # pos = video_capture.get(cv2.CAP_PROP_POS_FRAMES) # ==> current frame position
62
+ # video_capture.set(cv2.CAP_PROP_POS_FRAMES, 1000) # ==> seek to frame 1000
63
+ # pos = video_capture.get(cv2.CAP_PROP_POS_FRAMES) # ==> now pos = 1000.0
64
+ # video_capture.release()
65
+ vr = cv2.VideoCapture(video_name)
66
+ vr.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
67
+ _, frame = vr.read()
68
+ return frame
69
+
70
+ def decode_segmap_mask_from_segmap_video_frame(video_frame):
71
+ # video_frame: 0~255 BGR, obtained by read_video_frame
72
+ def assign_values(array):
73
+ remainder = array % 40 # remainder of each value modulo 40
74
+ assigned_values = np.where(remainder <= 20, array - remainder, array + (40 - remainder))
75
+ return assigned_values
76
+ segmap = video_frame.mean(-1)
77
+ segmap = assign_values(segmap) // 40 # [H, W] with value 0~5
78
+ segmap_mask = scatter_np(segmap[None, None, ...], classSeg=6)[0] # [6, H, W]
79
+ return segmap_mask.astype(np.uint8) # return the one-hot [6, H, W] mask computed above
80
+
81
+ def extract_background(img_lst, segmap_lst=None):
82
+ """
83
+ img_lst: list of rgb ndarray
84
+ """
85
+ # only use 1/20 images
86
+ num_frames = len(img_lst)
87
+ img_lst = img_lst[::20] if num_frames > 20 else img_lst[0:1]
88
+
89
+ if segmap_lst is not None:
90
+ segmap_lst = segmap_lst[::20] if num_frames > 20 else segmap_lst[0:1]
91
+ assert len(img_lst) == len(segmap_lst)
92
+ # get H/W
93
+ h, w = img_lst[0].shape[:2]
94
+
95
+ # nearest neighbors
96
+ all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose()
97
+ distss = []
98
+ for idx, img in enumerate(img_lst):
99
+ if segmap_lst is not None:
100
+ segmap = segmap_lst[idx]
101
+ else:
102
+ segmap = seg_model._cal_seg_map(img)
103
+ bg = (segmap[0]).astype(bool)
104
+ fg_xys = np.stack(np.nonzero(~bg)).transpose(1, 0)
105
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
106
+ dists, _ = nbrs.kneighbors(all_xys)
107
+ distss.append(dists)
108
+
109
+ distss = np.stack(distss)
110
+ max_dist = np.max(distss, 0)
111
+ max_id = np.argmax(distss, 0)
112
+
113
+ bc_pixs = max_dist > 10 # 5
114
+ bc_pixs_id = np.nonzero(bc_pixs)
115
+ bc_ids = max_id[bc_pixs]
116
+
117
+ num_pixs = distss.shape[1]
118
+ imgs = np.stack(img_lst).reshape(-1, num_pixs, 3)
119
+
120
+ bg_img = np.zeros((h*w, 3), dtype=np.uint8)
121
+ bg_img[bc_pixs_id, :] = imgs[bc_ids, bc_pixs_id, :]
122
+ bg_img = bg_img.reshape(h, w, 3)
123
+
124
+ max_dist = max_dist.reshape(h, w)
125
+ bc_pixs = max_dist > 10 # 5
126
+ bg_xys = np.stack(np.nonzero(~bc_pixs)).transpose()
127
+ fg_xys = np.stack(np.nonzero(bc_pixs)).transpose()
128
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
129
+ distances, indices = nbrs.kneighbors(bg_xys)
130
+ bg_fg_xys = fg_xys[indices[:, 0]]
131
+ bg_img[bg_xys[:, 0], bg_xys[:, 1], :] = bg_img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :]
132
+ return bg_img
133
+
134
+
135
+ global_segmenter = None
136
+ def job_cal_seg_map_for_image(img, segmenter_options=None, segmenter=None):
137
+ """
138
+ Used by MediapipeSegmenter.multiprocess_cal_seg_map_for_a_video; intended for processing a single long video.
139
+ """
140
+ global global_segmenter
141
+ if segmenter is not None:
142
+ segmenter_actual = segmenter
143
+ else:
144
+ global_segmenter = vision.ImageSegmenter.create_from_options(segmenter_options) if global_segmenter is None else global_segmenter
145
+ segmenter_actual = global_segmenter
146
+ mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
147
+ out = segmenter_actual.segment(mp_image)
148
+ segmap = out.category_mask.numpy_view().copy() # [H, W]
149
+
150
+ segmap_mask = scatter_np(segmap[None, None, ...], classSeg=6)[0] # [6, H, W]
151
+ segmap_image = segmap[:, :, None].repeat(3, 2).astype(float)
152
+ segmap_image = (segmap_image * 40).astype(np.uint8)
153
+
154
+ return segmap_mask, segmap_image
155
+
156
+ class MediapipeSegmenter:
157
+ def __init__(self):
158
+ model_path = 'data_gen/utils/mp_feature_extractors/selfie_multiclass_256x256.tflite'
159
+ if not os.path.exists(model_path):
160
+ os.makedirs(os.path.dirname(model_path), exist_ok=True)
161
+ print("downloading segmenter model from mediapipe...")
162
+ os.system(f"wget https://storage.googleapis.com/mediapipe-models/image_segmenter/selfie_multiclass_256x256/float32/latest/selfie_multiclass_256x256.tflite")
163
+ os.system(f"mv selfie_multiclass_256x256.tflite {model_path}")
164
+ print("download success")
165
+ base_options = python.BaseOptions(model_asset_path=model_path)
166
+ self.options = vision.ImageSegmenterOptions(base_options=base_options,running_mode=vision.RunningMode.IMAGE, output_category_mask=True)
167
+ self.video_options = vision.ImageSegmenterOptions(base_options=base_options,running_mode=vision.RunningMode.VIDEO, output_category_mask=True)
168
+
169
+ def multiprocess_cal_seg_map_for_a_video(self, imgs, num_workers=4):
170
+ """
171
+ Process a single long video with multiple worker processes.
172
+ imgs: list of rgb array in 0~255
173
+ """
174
+ segmap_masks = []
175
+ segmap_images = []
176
+ img_lst = [(imgs[i], self.options) for i in range(len(imgs))] # arg order matches job_cal_seg_map_for_image(img, segmenter_options)
177
+ for (i, res) in multiprocess_run_tqdm(job_cal_seg_map_for_image, args=img_lst, num_workers=num_workers, desc='extracting from a video in multi-process'):
178
+ segmap_mask, segmap_image = res
179
+ segmap_masks.append(segmap_mask)
180
+ segmap_images.append(segmap_image)
181
+ return segmap_masks, segmap_images
182
+
183
+ def _cal_seg_map_for_video(self, imgs, segmenter=None, return_onehot_mask=True, return_segmap_image=True):
184
+ segmenter = vision.ImageSegmenter.create_from_options(self.video_options) if segmenter is None else segmenter
185
+ assert return_onehot_mask or return_segmap_image # you should at least return one
186
+ segmap_masks = []
187
+ segmap_images = []
188
+ for i in tqdm.trange(len(imgs), desc="extracting segmaps from a video..."):
189
+ img = imgs[i]
190
+ mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
191
+ out = segmenter.segment_for_video(mp_image, 40 * i)
192
+ segmap = out.category_mask.numpy_view().copy() # [H, W]
193
+
194
+ if return_onehot_mask:
195
+ segmap_mask = scatter_np(segmap[None, None, ...], classSeg=6)[0] # [6, H, W]
196
+ segmap_masks.append(segmap_mask)
197
+ if return_segmap_image:
198
+ segmap_image = segmap[:, :, None].repeat(3, 2).astype(float)
199
+ segmap_image = (segmap_image * 40).astype(np.uint8)
200
+ segmap_images.append(segmap_image)
201
+
202
+ if return_onehot_mask and return_segmap_image:
203
+ return segmap_masks, segmap_images
204
+ elif return_onehot_mask:
205
+ return segmap_masks
206
+ elif return_segmap_image:
207
+ return segmap_images
208
+
209
+ def _cal_seg_map(self, img, segmenter=None, return_onehot_mask=True):
210
+ """
211
+ segmenter: vision.ImageSegmenter.create_from_options(options)
212
+ img: numpy, [H, W, 3], 0~255
213
+ segmap: [C, H, W]
214
+ 0 - background
215
+ 1 - hair
216
+ 2 - body-skin
217
+ 3 - face-skin
218
+ 4 - clothes
219
+ 5 - others (accessories)
220
+ """
221
+ assert img.ndim == 3
222
+ segmenter = vision.ImageSegmenter.create_from_options(self.options) if segmenter is None else segmenter
223
+ image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
224
+ out = segmenter.segment(image)
225
+ segmap = out.category_mask.numpy_view().copy() # [H, W]
226
+ if return_onehot_mask:
227
+ segmap = scatter_np(segmap[None, None, ...], classSeg=6)[0] # [6, H, W]
228
+ return segmap
229
+
230
+ def _seg_out_img_with_segmap(self, img, segmap, mode='head'):
231
+ """
232
+ img: [h,w,c], img is in 0~255, np
233
+ """
234
+ #
235
+ img = copy.deepcopy(img)
236
+ if mode == 'head':
237
+ selected_mask = segmap[[1,3,5] , :, :].sum(axis=0)[None,:] > 0.5 # glasses also belong to 'others'
238
+ img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
239
+ # selected_mask = segmap[[1,3] , :, :].sum(dim=0, keepdim=True) > 0.5
240
+ elif mode == 'person':
241
+ selected_mask = segmap[[1,2,3,4,5], :, :].sum(axis=0)[None,:] > 0.5
242
+ img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
243
+ elif mode == 'torso':
244
+ selected_mask = segmap[[2,4], :, :].sum(axis=0)[None,:] > 0.5
245
+ img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
246
+ elif mode == 'torso_with_bg':
247
+ selected_mask = segmap[[0, 2,4], :, :].sum(axis=0)[None,:] > 0.5
248
+ img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
249
+ elif mode == 'bg':
250
+ selected_mask = segmap[[0], :, :].sum(axis=0)[None,:] > 0.5 # only seg out 0, which means background
251
+ img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
252
+ elif mode == 'full':
253
+ pass
254
+ else:
255
+ raise NotImplementedError()
256
+ return img, selected_mask
257
+
258
+ def _seg_out_img(self, img, segmenter=None, mode='head'):
259
+ """
260
+ imgs [H, W, 3] 0-255
261
+ return : person_img [B, 3, H, W]
262
+ """
263
+ segmenter = vision.ImageSegmenter.create_from_options(self.options) if segmenter is None else segmenter
264
+ segmap = self._cal_seg_map(img, segmenter=segmenter, return_onehot_mask=True) # [B, 19, H, W]
265
+ return self._seg_out_img_with_segmap(img, segmap, mode=mode)
266
+
267
+ def seg_out_imgs(self, img, mode='head'):
268
+ """
269
+ api for pytorch img, -1~1
270
+ img: [B, 3, H, W], -1~1
271
+ """
272
+ device = img.device
273
+ img = convert_to_np(img.permute(0, 2, 3, 1)) # [B, H, W, 3]
274
+ img = ((img + 1) * 127.5).astype(np.uint8)
275
+ img_lst = [copy.deepcopy(img[i]) for i in range(len(img))]
276
+ out_lst = []
277
+ for im in img_lst:
278
+ out, _ = self._seg_out_img(im, mode=mode) # _seg_out_img returns (image, mask); keep only the image
279
+ out_lst.append(out)
280
+ seg_imgs = np.stack(out_lst) # [B, H, W, 3]
281
+ seg_imgs = (seg_imgs - 127.5) / 127.5
282
+ seg_imgs = torch.from_numpy(seg_imgs).permute(0, 3, 1, 2).to(device)
283
+ return seg_imgs
284
+
285
+ if __name__ == '__main__':
286
+ import imageio, cv2, tqdm
287
+ import torchshow as ts
288
+ img = imageio.imread("1.png")
289
+ img = cv2.resize(img, (512,512))
290
+
291
+ seg_model = MediapipeSegmenter()
292
+ img = torch.tensor(img).unsqueeze(0).repeat([1, 1, 1, 1]).permute(0, 3,1,2)
293
+ img = (img-127.5)/127.5
294
+ out = seg_model.seg_out_imgs(img, 'torso')
295
+ ts.save(out,"torso.png")
296
+ out = seg_model.seg_out_imgs(img, 'head')
297
+ ts.save(out,"head.png")
298
+ out = seg_model.seg_out_imgs(img, 'bg')
299
+ ts.save(out,"bg.png")
300
+ img = convert_to_np(img.permute(0, 2, 3, 1)) # [B, H, W, 3]
301
+ img = ((img + 1) * 127.5).astype(np.uint8)
302
+ bg = extract_background(img)
303
+ ts.save(bg,"bg2.png")
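For orientation, the segmenter's category mask uses six classes (0 background, 1 hair, 2 body-skin, 3 face-skin, 4 clothes, 5 others; see the _cal_seg_map docstring), and scatter_np turns it into a one-hot map. A tiny self-contained check of that conversion, with scatter_np re-declared so it runs without mediapipe installed:

    import numpy as np

    def scatter_np(condition_img, classSeg=6):  # mirrors the helper defined above
        batch, c, height, width = condition_img.shape
        input_label = np.zeros([batch, classSeg, height, width]).astype(np.int_)
        np.put_along_axis(input_label, condition_img, 1, 1)
        return input_label

    mask = np.array([[0, 3, 3], [4, 4, 1]])                    # toy [H, W] category ids
    onehot = scatter_np(mask[None, None, ...], classSeg=6)[0]  # [6, H, W]
    assert onehot.sum(axis=0).min() == 1                       # exactly one class per pixel
    print(onehot[3])                                           # face-skin channel: [[0 1 1], [0 0 0]]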
data_gen/utils/mp_feature_extractors/selfie_multiclass_256x256.tflite ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c6748b1253a99067ef71f7e26ca71096cd449baefa8f101900ea23016507e0e0
3
+ size 16371837
data_gen/utils/path_converter.py ADDED
@@ -0,0 +1,24 @@
1
+ import os
2
+
3
+
4
+ class PathConverter():
5
+ def __init__(self):
6
+ self.prefixs = {
7
+ "vid": "/video/",
8
+ "gt": "/gt_imgs/",
9
+ "head": "/head_imgs/",
10
+ "torso": "/torso_imgs/",
11
+ "person": "/person_imgs/",
12
+ "torso_with_bg": "/torso_with_bg_imgs/",
13
+ "single_bg": "/bg_img/",
14
+ "bg": "/bg_imgs/",
15
+ "segmaps": "/segmaps/",
16
+ "inpaint_torso": "/inpaint_torso_imgs/",
17
+ "com": "/com_imgs/",
18
+ "inpaint_torso_with_com_bg": "/inpaint_torso_with_com_bg_imgs/",
19
+ }
20
+
21
+ def to(self, path: str, old_pattern: str, new_pattern: str):
22
+ return path.replace(self.prefixs[old_pattern], self.prefixs[new_pattern], 1)
23
+
24
+ pc = PathConverter()
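A quick illustration of the prefix mapping (the concrete path below is hypothetical; to() simply swaps the first occurrence of one directory prefix for another):

    from data_gen.utils.path_converter import pc

    gt_path = "data/processed/videos/May/gt_imgs/00000.jpg"   # hypothetical processed-video layout
    print(pc.to(gt_path, "gt", "head"))                       # data/processed/videos/May/head_imgs/00000.jpg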
data_gen/utils/process_audio/extract_hubert.py ADDED
@@ -0,0 +1,92 @@
1
+ from transformers import Wav2Vec2Processor, HubertModel
2
+ import soundfile as sf
3
+ import numpy as np
4
+ import torch
5
+ import os
6
+ from utils.commons.hparams import set_hparams, hparams
7
+
8
+
9
+ wav2vec2_processor = None
10
+ hubert_model = None
11
+
12
+
13
+ def get_hubert_from_16k_wav(wav_16k_name):
14
+ speech_16k, _ = sf.read(wav_16k_name)
15
+ hubert = get_hubert_from_16k_speech(speech_16k)
16
+ return hubert
17
+
18
+ @torch.no_grad()
19
+ def get_hubert_from_16k_speech(speech, device="cuda:0"):
20
+ global hubert_model, wav2vec2_processor
21
+ local_path = '/home/tiger/.cache/huggingface/hub/models--facebook--hubert-large-ls960-ft/snapshots/ece5fabbf034c1073acae96d5401b25be96709d8'
22
+ if hubert_model is None:
23
+ print("Loading the HuBERT Model...")
24
+ print("Loading the Wav2Vec2 Processor...")
25
+ if os.path.exists(local_path):
26
+ hubert_model = HubertModel.from_pretrained(local_path)
27
+ wav2vec2_processor = Wav2Vec2Processor.from_pretrained(local_path)
28
+ else:
29
+ hubert_model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
30
+ wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
31
+ hubert_model = hubert_model.to(device)
32
+
33
+ if speech.ndim ==2:
34
+ speech = speech[:, 0] # [T, 2] ==> [T,]
35
+
36
+ input_values_all = wav2vec2_processor(speech, return_tensors="pt", sampling_rate=16000).input_values # [1, T]
37
+ input_values_all = input_values_all.to(device)
38
+ # For long audio sequences, due to memory limitations, we cannot process them in one run.
39
+ # HuBERT processes the wav with a CNN of strides [5,2,2,2,2,2,2], giving an overall stride of 320.
40
+ # Besides, the kernels are [10,3,3,3,3,2,2], so 400 samples form the fundamental unit for 1 time step.
41
+ # The CNN stack is thus equal to one big Conv1D with kernel k=400 and stride s=320.
42
+ # The number of output time steps is T = (t - (k - s)) // s, i.e. floor((t - k) / s) + 1.
43
+ # To prevent overlap, we set each clip length to (K + S*(N-1)), where N is the expected output length T of that clip.
44
+ # The start point of the next clip should roll back by (kernel - stride), so it is stride * N.
45
+ kernel = 400
46
+ stride = 320
47
+ clip_length = stride * 1000
48
+ num_iter = input_values_all.shape[1] // clip_length
49
+ expected_T = (input_values_all.shape[1] - (kernel-stride)) // stride
50
+ res_lst = []
51
+ for i in range(num_iter):
52
+ if i == 0:
53
+ start_idx = 0
54
+ end_idx = clip_length - stride + kernel
55
+ else:
56
+ start_idx = clip_length * i
57
+ end_idx = start_idx + (clip_length - stride + kernel)
58
+ input_values = input_values_all[:, start_idx: end_idx]
59
+ hidden_states = hubert_model.forward(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]
60
+ res_lst.append(hidden_states[0])
61
+ if num_iter > 0:
62
+ input_values = input_values_all[:, clip_length * num_iter:]
63
+ else:
64
+ input_values = input_values_all
65
+
66
+ if input_values.shape[1] >= kernel: # if the last batch is shorter than kernel_size, skip it
67
+ hidden_states = hubert_model(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]
68
+ res_lst.append(hidden_states[0])
69
+ ret = torch.cat(res_lst, dim=0).cpu() # [T, 1024]
70
+
71
+ assert abs(ret.shape[0] - expected_T) <= 1
72
+ if ret.shape[0] < expected_T: # the last short clip was skipped, so pad by repeating the final frame
73
+ ret = torch.cat([ret, ret[-1:].repeat([expected_T - ret.shape[0], 1])], dim=0) # ret is [T, 1024]
74
+ else:
75
+ ret = ret[:expected_T]
76
+
77
+ return ret
78
+
79
+
80
+ if __name__ == '__main__':
81
+ from argparse import ArgumentParser
82
+ parser = ArgumentParser()
83
+ parser.add_argument('--video_id', type=str, default='May', help='')
84
+ args = parser.parse_args()
85
+ ### Process Single Long Audio for NeRF dataset
86
+ person_id = args.video_id
87
+ wav_16k_name = f"data/processed/videos/{person_id}/aud.wav"
88
+ hubert_npy_name = f"data/processed/videos/{person_id}/aud_hubert.npy"
89
+ speech_16k, _ = sf.read(wav_16k_name)
90
+ hubert_hidden = get_hubert_from_16k_speech(speech_16k)
91
+ np.save(hubert_npy_name, hubert_hidden.detach().numpy())
92
+ print(f"Saved at {hubert_npy_name}")
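As a sanity check on the kernel/stride bookkeeping in get_hubert_from_16k_speech, here is the frame-count arithmetic for a hypothetical 5-second 16 kHz waveform:

    # HuBERT's CNN stack acts like one Conv1D with kernel 400 and stride 320 (see the comments above).
    kernel, stride, sr = 400, 320, 16000
    num_samples = 5 * sr                                    # hypothetical 5 s of 16 kHz audio
    expected_T = (num_samples - (kernel - stride)) // stride
    print(expected_T)                                       # 249 frames, i.e. ~50 frames per second

Since the rest of the pipeline runs video at 25 fps, this roughly 50 Hz feature rate pairs two HuBERT frames with each video frame.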
data_gen/utils/process_audio/extract_mel_f0.py ADDED
@@ -0,0 +1,148 @@
1
+ import numpy as np
2
+ import torch
3
+ import glob
4
+ import os
5
+ import tqdm
6
+ import librosa
7
+ import parselmouth
8
+ from utils.commons.pitch_utils import f0_to_coarse
9
+ from utils.commons.multiprocess_utils import multiprocess_run_tqdm
10
+ from utils.commons.os_utils import multiprocess_glob
11
+ from utils.audio.io import save_wav
12
+
13
+ from moviepy.editor import VideoFileClip
14
+ from utils.commons.hparams import hparams, set_hparams
15
+
16
+ def resample_wav(wav_name, out_name, sr=16000):
17
+ wav_raw, sr = librosa.core.load(wav_name, sr=sr)
18
+ save_wav(wav_raw, out_name, sr)
19
+
20
+ def split_wav(mp4_name, wav_name=None):
21
+ if wav_name is None:
22
+ wav_name = mp4_name.replace(".mp4", ".wav").replace("/video/", "/audio/")
23
+ if os.path.exists(wav_name):
24
+ return wav_name
25
+ os.makedirs(os.path.dirname(wav_name), exist_ok=True)
26
+
27
+ video = VideoFileClip(mp4_name,verbose=False)
28
+ dur = video.duration
29
+ audio = video.audio
30
+ assert audio is not None
31
+ audio.write_audiofile(wav_name,fps=16000,verbose=False,logger=None)
32
+ return wav_name
33
+
34
+ def librosa_pad_lr(x, fsize, fshift, pad_sides=1):
35
+ '''compute right padding (final frame) or both sides padding (first and final frames)
36
+ '''
37
+ assert pad_sides in (1, 2)
38
+ # return int(fsize // 2)
39
+ pad = (x.shape[0] // fshift + 1) * fshift - x.shape[0]
40
+ if pad_sides == 1:
41
+ return 0, pad
42
+ else:
43
+ return pad // 2, pad // 2 + pad % 2
44
+
45
+ def extract_mel_from_fname(wav_path,
46
+ fft_size=512,
47
+ hop_size=320,
48
+ win_length=512,
49
+ window="hann",
50
+ num_mels=80,
51
+ fmin=80,
52
+ fmax=7600,
53
+ eps=1e-6,
54
+ sample_rate=16000,
55
+ min_level_db=-100):
56
+ if isinstance(wav_path, str):
57
+ wav, _ = librosa.core.load(wav_path, sr=sample_rate)
58
+ else:
59
+ wav = wav_path
60
+
61
+ # get amplitude spectrogram
62
+ x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size,
63
+ win_length=win_length, window=window, center=False)
64
+ spc = np.abs(x_stft) # (n_bins, T)
65
+
66
+ # get mel basis
67
+ fmin = 0 if fmin == -1 else fmin
68
+ fmax = sample_rate / 2 if fmax == -1 else fmax
69
+ mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax)
70
+ mel = mel_basis @ spc
71
+
72
+ mel = np.log10(np.maximum(eps, mel)) # (n_mel_bins, T)
73
+ mel = mel.T
74
+
75
+ l_pad, r_pad = librosa_pad_lr(wav, fft_size, hop_size, 1)
76
+ wav = np.pad(wav, (l_pad, r_pad), mode='constant', constant_values=0.0)
77
+
78
+ return wav.T, mel
79
+
80
+ def extract_f0_from_wav_and_mel(wav, mel,
81
+ hop_size=320,
82
+ audio_sample_rate=16000,
83
+ ):
84
+ time_step = hop_size / audio_sample_rate * 1000
85
+ f0_min = 80
86
+ f0_max = 750
87
+ f0 = parselmouth.Sound(wav, audio_sample_rate).to_pitch_ac(
88
+ time_step=time_step / 1000, voicing_threshold=0.6,
89
+ pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']
90
+
91
+ delta_l = len(mel) - len(f0)
92
+ assert np.abs(delta_l) <= 8
93
+ if delta_l > 0:
94
+ f0 = np.concatenate([f0, [f0[-1]] * delta_l], 0)
95
+ f0 = f0[:len(mel)]
96
+ pitch_coarse = f0_to_coarse(f0)
97
+ return f0, pitch_coarse
98
+
99
+
100
+ def extract_mel_f0_from_fname(wav_name=None, out_name=None):
101
+ try:
102
+ out_name = wav_name.replace(".wav", "_mel_f0.npy").replace("/audio/", "/mel_f0/")
103
+ os.makedirs(os.path.dirname(out_name), exist_ok=True)
104
+
105
+ wav, mel = extract_mel_from_fname(wav_name)
106
+ f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)
107
+ out_dict = {
108
+ "mel": mel, # [T, 80]
109
+ "f0": f0,
110
+ }
111
+ np.save(out_name, out_dict)
112
+ except Exception as e:
113
+ print(e)
114
+
115
+ def extract_mel_f0_from_video_name(mp4_name, wav_name=None, out_name=None):
116
+ if mp4_name.endswith(".mp4"):
117
+ wav_name = split_wav(mp4_name, wav_name)
118
+ if out_name is None:
119
+ out_name = mp4_name.replace(".mp4", "_mel_f0.npy").replace("/video/", "/mel_f0/")
120
+ elif mp4_name.endswith(".wav"):
121
+ wav_name = mp4_name
122
+ if out_name is None:
123
+ out_name = mp4_name.replace(".wav", "_mel_f0.npy").replace("/audio/", "/mel_f0/")
124
+
125
+ os.makedirs(os.path.dirname(out_name), exist_ok=True)
126
+
127
+ wav, mel = extract_mel_from_fname(wav_name)
128
+
129
+ f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)
130
+ out_dict = {
131
+ "mel": mel, # [T, 80]
132
+ "f0": f0,
133
+ }
134
+ np.save(out_name, out_dict)
135
+
136
+
137
+ if __name__ == '__main__':
138
+ from argparse import ArgumentParser
139
+ parser = ArgumentParser()
140
+ parser.add_argument('--video_id', type=str, default='May', help='')
141
+ args = parser.parse_args()
142
+ ### Process Single Long Audio for NeRF dataset
143
+ person_id = args.video_id
144
+
145
+ wav_16k_name = f"data/processed/videos/{person_id}/aud.wav"
146
+ out_name = f"data/processed/videos/{person_id}/aud_mel_f0.npy"
147
+ extract_mel_f0_from_video_name(wav_16k_name, out_name=out_name) # keyword arg so it is not mistaken for wav_name
148
+ print(f"Saved at {out_name}")
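The mel and F0 features above share that framing: hop_size=320 at 16 kHz gives 50 frames per second, again two audio frames per 25 fps video frame. A small sketch feeding a synthetic waveform through the extractors (assumes the repo's dependencies such as librosa and parselmouth are installed; the clip is placeholder noise):

    import numpy as np
    from data_gen.utils.process_audio.extract_mel_f0 import extract_mel_from_fname, extract_f0_from_wav_and_mel

    wav = np.random.randn(2 * 16000).astype(np.float32) * 0.01   # 2 s of placeholder audio at 16 kHz
    wav_padded, mel = extract_mel_from_fname(wav)                # mel: [T, 80], hop 320 -> ~50 frames/s
    f0, f0_coarse = extract_f0_from_wav_and_mel(wav_padded, mel)
    print(mel.shape, f0.shape)                                   # roughly (99, 80) and (99,)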
data_gen/utils/process_audio/resample_audio_to_16k.py ADDED
@@ -0,0 +1,49 @@
1
+ import os, glob
2
+ from utils.commons.os_utils import multiprocess_glob
3
+ from utils.commons.multiprocess_utils import multiprocess_run_tqdm
4
+
5
+
6
+ def extract_wav16k_job(audio_name:str):
7
+ out_path = audio_name.replace("/audio_raw/","/audio/",1)
8
+ assert out_path != audio_name # prevent inplace
9
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
10
+ ffmpeg_path = "/usr/bin/ffmpeg"
11
+
12
+ cmd = f'{ffmpeg_path} -i {audio_name} -ar 16000 -v quiet -y {out_path}'
13
+ os.system(cmd)
14
+
15
+ if __name__ == '__main__':
16
+ import argparse, glob, tqdm, random
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument("--aud_dir", default='/home/tiger/datasets/raw/CMLR/audio_raw/')
19
+ parser.add_argument("--ds_name", default='CMLR')
20
+ parser.add_argument("--num_workers", default=64, type=int)
21
+ parser.add_argument("--process_id", default=0, type=int)
22
+ parser.add_argument("--total_process", default=1, type=int)
23
+ args = parser.parse_args()
24
+ print(f"args {args}")
25
+
26
+ aud_dir = args.aud_dir
27
+ ds_name = args.ds_name
28
+ if ds_name in ['CMLR']:
29
+ aud_name_pattern = os.path.join(aud_dir, "*/*/*.wav")
30
+ aud_names = multiprocess_glob(aud_name_pattern)
31
+ else:
32
+ raise NotImplementedError()
33
+ aud_names = sorted(aud_names)
34
+ print(f"total audio number : {len(aud_names)}")
35
+ print(f"first {aud_names[0]} last {aud_names[-1]}")
36
+ # exit()
37
+ process_id = args.process_id
38
+ total_process = args.total_process
39
+ if total_process > 1:
40
+ assert process_id <= total_process -1
41
+ num_samples_per_process = len(aud_names) // total_process
42
+ if process_id == total_process:
43
+ aud_names = aud_names[process_id * num_samples_per_process : ]
44
+ else:
45
+ aud_names = aud_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
46
+
47
+ for i, res in multiprocess_run_tqdm(extract_wav16k_job, aud_names, num_workers=args.num_workers, desc="resampling videos"):
48
+ pass
49
+
data_gen/utils/process_image/extract_lm2d.py ADDED
@@ -0,0 +1,197 @@
1
+ import os
2
+ os.environ["OMP_NUM_THREADS"] = "1"
3
+ import sys
4
+
5
+ import glob
6
+ import cv2
7
+ import tqdm
8
+ import numpy as np
9
+ from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker
10
+ from utils.commons.multiprocess_utils import multiprocess_run_tqdm
11
+ import warnings
12
+ warnings.filterwarnings('ignore')
13
+
14
+ import random
15
+ random.seed(42)
16
+
17
+ import pickle
18
+ import json
19
+ import gzip
20
+ from typing import Any
21
+
22
+ def load_file(filename, is_gzip: bool = False, is_json: bool = False) -> Any:
23
+ if is_json:
24
+ if is_gzip:
25
+ with gzip.open(filename, "rt", encoding="utf-8") as f: # text mode is required when passing an encoding
26
+ loaded_object = json.load(f)
27
+ return loaded_object
28
+ else:
29
+ with open(filename, "r", encoding="utf-8") as f:
30
+ loaded_object = json.load(f)
31
+ return loaded_object
32
+ else:
33
+ if is_gzip:
34
+ with gzip.open(filename, "rb") as f:
35
+ loaded_object = pickle.load(f)
36
+ return loaded_object
37
+ else:
38
+ with open(filename, "rb") as f:
39
+ loaded_object = pickle.load(f)
40
+ return loaded_object
41
+
42
+ def save_file(filename, content, is_gzip: bool = False, is_json: bool = False) -> None:
43
+ if is_json:
44
+ if is_gzip:
45
+ with gzip.open(filename, "wt", encoding="utf-8") as f: # text mode is required when passing an encoding
46
+ json.dump(content, f)
47
+ else:
48
+ with open(filename, "w", encoding="utf-8") as f:
49
+ json.dump(content, f)
50
+ else:
51
+ if is_gzip:
52
+ with gzip.open(filename, "wb") as f:
53
+ pickle.dump(content, f)
54
+ else:
55
+ with open(filename, "wb") as f:
56
+ pickle.dump(content, f)
57
+
58
+ face_landmarker = None
59
+
60
+ def extract_lms_mediapipe_job(img):
61
+ if img is None:
62
+ return None
63
+ global face_landmarker
64
+ if face_landmarker is None:
65
+ face_landmarker = MediapipeLandmarker()
66
+ lm478 = face_landmarker.extract_lm478_from_img(img)
67
+ return lm478
68
+
69
+ def extract_landmark_job(img_name):
70
+ try:
71
+ # if img_name == 'datasets/PanoHeadGen/raw/images/multi_view/chunk_0/seed0000002.png':
72
+ # print(1)
73
+ # input()
74
+ out_name = img_name.replace("/images_512/", "/lms_2d/").replace(".png","_lms.npy")
75
+ if os.path.exists(out_name):
76
+ print("out exists, skip...")
77
+ return
78
+ try:
79
+ os.makedirs(os.path.dirname(out_name), exist_ok=True)
80
+ except:
81
+ pass
82
+ img = cv2.imread(img_name)[:,:,::-1]
83
+
84
+ if img is not None:
85
+ lm468 = extract_lms_mediapipe_job(img)
86
+ if lm468 is not None:
87
+ np.save(out_name, lm468)
88
+ # print("Hahaha, solve one item!!!")
89
+ except Exception as e:
90
+ print(e)
91
+ pass
92
+
93
+ def out_exist_job(img_name):
94
+ out_name = img_name.replace("/images_512/", "/lms_2d/").replace(".png","_lms.npy")
95
+ if os.path.exists(out_name):
96
+ return None
97
+ else:
98
+ return img_name
99
+
100
+ # def get_todo_img_names(img_names):
101
+ # todo_img_names = []
102
+ # for i, res in multiprocess_run_tqdm(out_exist_job, img_names, num_workers=64):
103
+ # if res is not None:
104
+ # todo_img_names.append(res)
105
+ # return todo_img_names
106
+
107
+
108
+ if __name__ == '__main__':
109
+ import argparse, glob, tqdm, random
110
+ parser = argparse.ArgumentParser()
111
+ parser.add_argument("--img_dir", default='/home/tiger/datasets/raw/FFHQ/images_512/')
112
+ parser.add_argument("--ds_name", default='FFHQ')
113
+ parser.add_argument("--num_workers", default=64, type=int)
114
+ parser.add_argument("--process_id", default=0, type=int)
115
+ parser.add_argument("--total_process", default=1, type=int)
116
+ parser.add_argument("--reset", action='store_true')
117
+ parser.add_argument("--img_names_file", default="img_names.pkl", type=str)
118
+ parser.add_argument("--load_img_names", action="store_true")
119
+
120
+ args = parser.parse_args()
121
+ print(f"args {args}")
122
+ img_dir = args.img_dir
123
+ img_names_file = os.path.join(img_dir, args.img_names_file)
124
+ if args.load_img_names:
125
+ img_names = load_file(img_names_file)
126
+ print(f"load image names from {img_names_file}")
127
+ else:
128
+ if args.ds_name == 'FFHQ_MV':
129
+ img_name_pattern1 = os.path.join(img_dir, "ref_imgs/*.png")
130
+ img_names1 = glob.glob(img_name_pattern1)
131
+ img_name_pattern2 = os.path.join(img_dir, "mv_imgs/*.png")
132
+ img_names2 = glob.glob(img_name_pattern2)
133
+ img_names = img_names1 + img_names2
134
+ img_names = sorted(img_names)
135
+ elif args.ds_name == 'FFHQ':
136
+ img_name_pattern = os.path.join(img_dir, "*.png")
137
+ img_names = glob.glob(img_name_pattern)
138
+ img_names = sorted(img_names)
139
+ elif args.ds_name == "PanoHeadGen":
140
+ # img_name_patterns = ["ref/*/*.png", "multi_view/*/*.png", "reverse/*/*.png"]
141
+ img_name_patterns = ["ref/*/*.png"]
142
+ img_names = []
143
+ for img_name_pattern in img_name_patterns:
144
+ img_name_pattern_full = os.path.join(img_dir, img_name_pattern)
145
+ img_names_part = glob.glob(img_name_pattern_full)
146
+ img_names.extend(img_names_part)
147
+ img_names = sorted(img_names)
148
+
149
+ # save image names
150
+ if not args.load_img_names:
151
+ save_file(img_names_file, img_names)
152
+ print(f"save image names in {img_names_file}")
153
+
154
+ print(f"total images number: {len(img_names)}")
155
+
156
+
157
+ process_id = args.process_id
158
+ total_process = args.total_process
159
+ if total_process > 1:
160
+ assert process_id <= total_process -1
161
+ num_samples_per_process = len(img_names) // total_process
162
+ if process_id == total_process:
163
+ img_names = img_names[process_id * num_samples_per_process : ]
164
+ else:
165
+ img_names = img_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
166
+
167
+ # if not args.reset:
168
+ # img_names = get_todo_img_names(img_names)
169
+
170
+
171
+ print(f"todo_image {img_names[:10]}")
172
+ print(f"processing images number in this process: {len(img_names)}")
173
+ # print(f"todo images number: {len(img_names)}")
174
+ # input()
175
+ # exit()
176
+
177
+ if args.num_workers == 1:
178
+ index = 0
179
+ for img_name in tqdm.tqdm(img_names, desc=f"Root process {args.process_id}: extracting MP-based landmark2d"):
180
+ try:
181
+ extract_landmark_job(img_name)
182
+ except Exception as e:
183
+ print(e)
184
+ pass
185
+ if index % max(1, int(len(img_names) * 0.003)) == 0:
186
+ print(f"processed {index} / {len(img_names)}")
187
+ sys.stdout.flush()
188
+ index += 1
189
+ else:
190
+ for i, res in multiprocess_run_tqdm(
191
+ extract_landmark_job, img_names,
192
+ num_workers=args.num_workers,
193
+ desc=f"Root {args.process_id}: extracting MP-based landmark2d"):
194
+ # if index % max(1, int(len(img_names) * 0.003)) == 0:
195
+ print(f"processed {i+1} / {len(img_names)}")
196
+ sys.stdout.flush()
197
+ print(f"Root {args.process_id}: Finished extracting.")
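The --process_id / --total_process flags shard the image list across independent jobs; the same slicing idiom recurs in the other preprocessing scripts here. A standalone sketch of the split (the file names are placeholders):

    img_names = [f"{i:05d}.png" for i in range(10)]      # placeholder file names
    total_process, process_id = 3, 1
    num_per_proc = len(img_names) // total_process
    shard = img_names[process_id * num_per_proc : (process_id + 1) * num_per_proc]
    print(shard)                                         # ['00003.png', '00004.png', '00005.png']

Note that the remainder (here '00009.png') would only be picked up by the process_id == total_process branch, which the preceding assert makes unreachable, so lists that do not divide evenly drop a few trailing items.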
data_gen/utils/process_image/extract_segment_imgs.py ADDED
@@ -0,0 +1,114 @@
1
+ import os
2
+ os.environ["OMP_NUM_THREADS"] = "1"
3
+
4
+ import glob
5
+ import cv2
6
+ import tqdm
7
+ import numpy as np
8
+ import PIL
9
+ from utils.commons.tensor_utils import convert_to_np
10
+ import torch
11
+ import mediapipe as mp
12
+ from utils.commons.multiprocess_utils import multiprocess_run_tqdm
13
+ from data_gen.utils.mp_feature_extractors.mp_segmenter import MediapipeSegmenter
14
+ from data_gen.utils.process_video.extract_segment_imgs import inpaint_torso_job, extract_background, save_rgb_image_to_path
15
+ seg_model = MediapipeSegmenter()
16
+
17
+
18
+ def extract_segment_job(img_name):
19
+ try:
20
+ img = cv2.imread(img_name)
21
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
22
+
23
+ segmap = seg_model._cal_seg_map(img)
24
+ bg_img = extract_background([img], [segmap])
25
+ out_img_name = img_name.replace("/images_512/",f"/bg_img/").replace(".mp4", ".jpg")
26
+ save_rgb_image_to_path(bg_img, out_img_name)
27
+
28
+ com_img = img.copy()
29
+ bg_part = segmap[0].astype(bool)[..., None].repeat(3,axis=-1)
30
+ com_img[bg_part] = bg_img[bg_part]
31
+ out_img_name = img_name.replace("/images_512/",f"/com_imgs/")
32
+ save_rgb_image_to_path(com_img, out_img_name)
33
+
34
+ for mode in ['head', 'torso', 'person', 'torso_with_bg', 'bg']:
35
+ out_img, _ = seg_model._seg_out_img_with_segmap(img, segmap, mode=mode)
36
+ out_img_name = img_name.replace("/images_512/",f"/{mode}_imgs/")
37
+ out_img = cv2.cvtColor(out_img, cv2.COLOR_RGB2BGR)
38
+ try: os.makedirs(os.path.dirname(out_img_name), exist_ok=True)
39
+ except: pass
40
+ cv2.imwrite(out_img_name, out_img)
41
+
42
+ inpaint_torso_img, inpaint_torso_with_bg_img, _, _ = inpaint_torso_job(img, segmap)
43
+ out_img_name = img_name.replace("/images_512/",f"/inpaint_torso_imgs/")
44
+ save_rgb_image_to_path(inpaint_torso_img, out_img_name)
45
+ inpaint_torso_with_bg_img[bg_part] = bg_img[bg_part]
46
+ out_img_name = img_name.replace("/images_512/",f"/inpaint_torso_with_com_bg_imgs/")
47
+ save_rgb_image_to_path(inpaint_torso_with_bg_img, out_img_name)
48
+ return 0
49
+ except Exception as e:
50
+ print(e)
51
+ return 1
52
+
53
+ def out_exist_job(img_name):
54
+ out_name1 = img_name.replace("/images_512/", "/head_imgs/")
55
+ out_name2 = img_name.replace("/images_512/", "/com_imgs/")
56
+ out_name3 = img_name.replace("/images_512/", "/inpaint_torso_with_com_bg_imgs/")
57
+
58
+ if os.path.exists(out_name1) and os.path.exists(out_name2) and os.path.exists(out_name3):
59
+ return None
60
+ else:
61
+ return img_name
62
+
63
+ def get_todo_img_names(img_names):
64
+ todo_img_names = []
65
+ for i, res in multiprocess_run_tqdm(out_exist_job, img_names, num_workers=64):
66
+ if res is not None:
67
+ todo_img_names.append(res)
68
+ return todo_img_names
69
+
70
+
71
+ if __name__ == '__main__':
72
+ import argparse, glob, tqdm, random
73
+ parser = argparse.ArgumentParser()
74
+ parser.add_argument("--img_dir", default='./images_512')
75
+ # parser.add_argument("--img_dir", default='/home/tiger/datasets/raw/FFHQ/images_512')
76
+ parser.add_argument("--ds_name", default='FFHQ')
77
+ parser.add_argument("--num_workers", default=1, type=int)
78
+ parser.add_argument("--seed", default=0, type=int)
79
+ parser.add_argument("--process_id", default=0, type=int)
80
+ parser.add_argument("--total_process", default=1, type=int)
81
+ parser.add_argument("--reset", action='store_true')
82
+
83
+ args = parser.parse_args()
84
+ img_dir = args.img_dir
85
+ if args.ds_name == 'FFHQ_MV':
86
+ img_name_pattern1 = os.path.join(img_dir, "ref_imgs/*.png")
87
+ img_names1 = glob.glob(img_name_pattern1)
88
+ img_name_pattern2 = os.path.join(img_dir, "mv_imgs/*.png")
89
+ img_names2 = glob.glob(img_name_pattern2)
90
+ img_names = img_names1 + img_names2
91
+ elif args.ds_name == 'FFHQ':
92
+ img_name_pattern = os.path.join(img_dir, "*.png")
93
+ img_names = glob.glob(img_name_pattern)
94
+
95
+ img_names = sorted(img_names)
96
+ random.seed(args.seed)
97
+ random.shuffle(img_names)
98
+
99
+ process_id = args.process_id
100
+ total_process = args.total_process
101
+ if total_process > 1:
102
+ assert process_id <= total_process -1
103
+ num_samples_per_process = len(img_names) // total_process
104
+ if process_id == total_process:
105
+ img_names = img_names[process_id * num_samples_per_process : ]
106
+ else:
107
+ img_names = img_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
108
+
109
+ if not args.reset:
110
+ img_names = get_todo_img_names(img_names)
111
+ print(f"todo images number: {len(img_names)}")
112
+
113
+ for vid_name in multiprocess_run_tqdm(extract_segment_job ,img_names, desc=f"Root process {args.process_id}: extracting segment images", num_workers=args.num_workers):
114
+ pass
data_gen/utils/process_image/fit_3dmm_landmark.py ADDED
@@ -0,0 +1,369 @@
1
+ from numpy.core.numeric import require
2
+ from numpy.lib.function_base import quantile
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import copy
6
+ import numpy as np
7
+
8
+ import os
9
+ import sys
10
+ import cv2
11
+ import argparse
12
+ import tqdm
13
+ from utils.commons.multiprocess_utils import multiprocess_run_tqdm
14
+ from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker
15
+
16
+ from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
17
+ import pickle
18
+
19
+ face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
20
+ camera_distance=10, focal=1015, keypoint_mode='mediapipe')
21
+ face_model.to("cuda")
22
+
23
+
24
+ index_lm68_from_lm468 = [127,234,93,132,58,136,150,176,152,400,379,365,288,361,323,454,356,70,63,105,66,107,336,296,334,293,300,168,197,5,4,75,97,2,326,305,
25
+ 33,160,158,133,153,144,362,385,387,263,373,380,61,40,37,0,267,270,291,321,314,17,84,91,78,81,13,311,308,402,14,178]
26
+
27
+ dir_path = os.path.dirname(os.path.realpath(__file__))
28
+
29
+ LAMBDA_REG_ID = 0.3
30
+ LAMBDA_REG_EXP = 0.05
31
+
32
+ def save_file(name, content):
33
+ with open(name, "wb") as f:
34
+ pickle.dump(content, f)
35
+
36
+ def load_file(name):
37
+ with open(name, "rb") as f:
38
+ content = pickle.load(f)
39
+ return content
40
+
41
+ def cal_lan_loss_mp(proj_lan, gt_lan):
42
+ # [B, 68, 2]
43
+ loss = (proj_lan - gt_lan).pow(2)
44
+ # loss = (proj_lan - gt_lan).abs()
45
+ unmatch_mask = [ 93, 127, 132, 234, 323, 356, 361, 454]
46
+ eye = [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249]
47
+ inner_lip = [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95]
48
+ outer_lip = [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146]
49
+ weights = torch.ones_like(loss)
50
+ weights[:, eye] = 5
51
+ weights[:, inner_lip] = 2
52
+ weights[:, outer_lip] = 2
53
+ weights[:, unmatch_mask] = 0
54
+ loss = loss * weights
55
+ return torch.mean(loss)
56
+
57
+ def cal_lan_loss(proj_lan, gt_lan):
58
+ # [B, 68, 2]
59
+ loss = (proj_lan - gt_lan)** 2
60
+ # use the ldm weights from deep3drecon, see deep_3drecon/deep_3drecon_models/losses.py
61
+ weights = torch.zeros_like(loss)
62
+ weights = torch.ones_like(loss)
63
+ weights[:, 36:48, :] = 3 # eye 12 points
64
+ weights[:, -8:, :] = 3 # inner lip 8 points
65
+ weights[:, 28:31, :] = 3 # nose 3 points
66
+ loss = loss * weights
67
+ return torch.mean(loss)
68
+
69
+ def set_requires_grad(tensor_list):
70
+ for tensor in tensor_list:
71
+ tensor.requires_grad = True
72
+
73
+ def read_video_to_frames(img_name):
74
+ frames = []
75
+ cap = cv2.VideoCapture(img_name)
76
+ while cap.isOpened():
77
+ ret, frame_bgr = cap.read()
78
+ if frame_bgr is None:
79
+ break
80
+ frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
81
+ frames.append(frame_rgb)
82
+ return np.stack(frames)
83
+
84
+ @torch.enable_grad()
85
+ def fit_3dmm_for_a_image(img_name, debug=False, keypoint_mode='mediapipe', device="cuda:0", save=True):
86
+ img = cv2.imread(img_name)
87
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
88
+ img_h, img_w = img.shape[0], img.shape[1]
89
+ assert img_h == img_w
90
+ num_frames = 1
91
+
92
+ lm_name = img_name.replace("/images_512/", "/lms_2d/").replace(".png", "_lms.npy")
93
+ if lm_name.endswith('_lms.npy') and os.path.exists(lm_name):
94
+ lms = np.load(lm_name)
95
+ else:
96
+ # print("lms_2d file not found, try to extract it from image...")
97
+ try:
98
+ landmarker = MediapipeLandmarker()
99
+ lms = landmarker.extract_lm478_from_img_name(img_name)
100
+ # lms = landmarker.extract_lm478_from_img(img)
101
+ except Exception as e:
102
+ print(e)
103
+ return
104
+ if lms is None:
105
+ print("get None lms_2d, please check whether each frame has one head, exiting...")
106
+ return
107
+ lms = lms[:468].reshape([468,2])
108
+ lms = torch.FloatTensor(lms).to(device=device)
109
+ lms[..., 1] = img_h - lms[..., 1] # flip the height axis
110
+
111
+ if keypoint_mode == 'mediapipe':
112
+ cal_lan_loss_fn = cal_lan_loss_mp
113
+ out_name = img_name.replace("/images_512/", "/coeff_fit_mp/").replace(".png", "_coeff_fit_mp.npy")
114
+ else:
115
+ cal_lan_loss_fn = cal_lan_loss
116
+ out_name = img_name.replace("/images_512/", "/coeff_fit_lm68/").replace(".png", "_coeff_fit_lm68.npy")
117
+ try:
118
+ os.makedirs(os.path.dirname(out_name), exist_ok=True)
119
+ except:
120
+ pass
121
+
122
+ id_dim, exp_dim = 80, 64
123
+ sel_ids = np.arange(0, num_frames, 40)
124
+ sel_num = sel_ids.shape[0]
125
+ arg_focal = face_model.focal
126
+
127
+ h = w = face_model.center * 2
128
+ img_scale_factor = img_h / h
129
+ lms /= img_scale_factor
130
+ cxy = torch.tensor((w / 2.0, h / 2.0), dtype=torch.float).to(device=device)
131
+
132
+ id_para = lms.new_zeros((num_frames, id_dim), requires_grad=True) # lms.new_zeros((1, id_dim), requires_grad=True)
133
+ exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
134
+ euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True)
135
+ trans = lms.new_zeros((num_frames, 3), requires_grad=True)
136
+
137
+ focal_length = lms.new_zeros(1, requires_grad=True)
138
+ focal_length.data += arg_focal
139
+
140
+ set_requires_grad([id_para, exp_para, euler_angle, trans])
141
+
142
+ optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=.1)
143
+ optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=.1)
144
+
145
+ # keep the other parameters at their initialization; first optimize euler and trans only
146
+ for _ in range(200):
147
+ proj_geo = face_model.compute_for_landmark_fit(
148
+ id_para, exp_para, euler_angle, trans)
149
+ loss_lan = cal_lan_loss_fn(proj_geo[:, :, :2], lms.detach())
150
+ loss = loss_lan
151
+ optimizer_frame.zero_grad()
152
+ loss.backward()
153
+ optimizer_frame.step()
154
+ # print(f"loss_lan: {loss_lan.item():.2f}, euler_abs_mean: {euler_angle.abs().mean().item():.4f}, euler_std: {euler_angle.std().item():.4f}, euler_min: {euler_angle.min().item():.4f}, euler_max: {euler_angle.max().item():.4f}")
155
+ # print(f"trans_z_mean: {trans[...,2].mean().item():.4f}, trans_z_std: {trans[...,2].std().item():.4f}, trans_min: {trans[...,2].min().item():.4f}, trans_max: {trans[...,2].max().item():.4f}")
156
+
157
+ for param_group in optimizer_frame.param_groups:
158
+ param_group['lr'] = 0.1
159
+
160
+ # "jointly roughly training id exp euler trans"
161
+ for _ in range(200):
162
+ proj_geo = face_model.compute_for_landmark_fit(
163
+ id_para, exp_para, euler_angle, trans)
164
+ loss_lan = cal_lan_loss_fn(
165
+ proj_geo[:, :, :2], lms.detach())
166
+ loss_regid = torch.mean(id_para*id_para) # regularization
167
+ loss_regexp = torch.mean(exp_para * exp_para)
168
+
169
+ loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP
170
+ optimizer_idexp.zero_grad()
171
+ optimizer_frame.zero_grad()
172
+ loss.backward()
173
+ optimizer_idexp.step()
174
+ optimizer_frame.step()
175
+ # print(f"loss_lan: {loss_lan.item():.2f}, loss_reg_id: {loss_regid.item():.2f},loss_reg_exp: {loss_regexp.item():.2f},")
176
+ # print(f"euler_abs_mean: {euler_angle.abs().mean().item():.4f}, euler_std: {euler_angle.std().item():.4f}, euler_min: {euler_angle.min().item():.4f}, euler_max: {euler_angle.max().item():.4f}")
177
+ # print(f"trans_z_mean: {trans[...,2].mean().item():.4f}, trans_z_std: {trans[...,2].std().item():.4f}, trans_min: {trans[...,2].min().item():.4f}, trans_max: {trans[...,2].max().item():.4f}")
178
+
179
+ # start fine training, initialized from the roughly trained results
180
+ id_para_ = lms.new_zeros((num_frames, id_dim), requires_grad=True)
181
+ id_para_.data = id_para.data.clone()
182
+ id_para = id_para_
183
+ exp_para_ = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
184
+ exp_para_.data = exp_para.data.clone()
185
+ exp_para = exp_para_
186
+ euler_angle_ = lms.new_zeros((num_frames, 3), requires_grad=True)
187
+ euler_angle_.data = euler_angle.data.clone()
188
+ euler_angle = euler_angle_
189
+ trans_ = lms.new_zeros((num_frames, 3), requires_grad=True)
190
+ trans_.data = trans.data.clone()
191
+ trans = trans_
192
+
193
+ batch_size = 1
194
+
195
+ # "fine fitting the 3DMM in batches"
196
+ for i in range(int((num_frames-1)/batch_size+1)):
197
+ if (i+1)*batch_size > num_frames:
198
+ start_n = num_frames-batch_size
199
+ sel_ids = np.arange(max(num_frames-batch_size,0), num_frames)
200
+ else:
201
+ start_n = i*batch_size
202
+ sel_ids = np.arange(i*batch_size, i*batch_size+batch_size)
203
+ sel_lms = lms[sel_ids]
204
+
205
+ sel_id_para = id_para.new_zeros(
206
+ (batch_size, id_dim), requires_grad=True)
207
+ sel_id_para.data = id_para[sel_ids].clone()
208
+ sel_exp_para = exp_para.new_zeros(
209
+ (batch_size, exp_dim), requires_grad=True)
210
+ sel_exp_para.data = exp_para[sel_ids].clone()
211
+ sel_euler_angle = euler_angle.new_zeros(
212
+ (batch_size, 3), requires_grad=True)
213
+ sel_euler_angle.data = euler_angle[sel_ids].clone()
214
+ sel_trans = trans.new_zeros((batch_size, 3), requires_grad=True)
215
+ sel_trans.data = trans[sel_ids].clone()
216
+
217
+ set_requires_grad([sel_id_para, sel_exp_para, sel_euler_angle, sel_trans])
218
+ optimizer_cur_batch = torch.optim.Adam(
219
+ [sel_id_para, sel_exp_para, sel_euler_angle, sel_trans], lr=0.005)
220
+
221
+ for j in range(50):
222
+ proj_geo = face_model.compute_for_landmark_fit(
223
+ sel_id_para, sel_exp_para, sel_euler_angle, sel_trans)
224
+ loss_lan = cal_lan_loss_fn(
225
+ proj_geo[:, :, :2], lms.unsqueeze(0).detach())
226
+
227
+ loss_regid = torch.mean(sel_id_para*sel_id_para) # regularization
228
+ loss_regexp = torch.mean(sel_exp_para*sel_exp_para)
229
+ loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP
230
+ optimizer_cur_batch.zero_grad()
231
+ loss.backward()
232
+ optimizer_cur_batch.step()
233
+ print(f"batch {i} | loss_lan: {loss_lan.item():.2f}, loss_reg_id: {loss_regid.item():.2f},loss_reg_exp: {loss_regexp.item():.2f}")
234
+ id_para[sel_ids].data = sel_id_para.data.clone()
235
+ exp_para[sel_ids].data = sel_exp_para.data.clone()
236
+ euler_angle[sel_ids].data = sel_euler_angle.data.clone()
237
+ trans[sel_ids].data = sel_trans.data.clone()
238
+
239
+ coeff_dict = {'id': id_para.detach().cpu().numpy(), 'exp': exp_para.detach().cpu().numpy(),
240
+ 'euler': euler_angle.detach().cpu().numpy(), 'trans': trans.detach().cpu().numpy()}
241
+ if save:
242
+ np.save(out_name, coeff_dict, allow_pickle=True)
243
+
244
+ if debug:
245
+ import imageio
246
+ debug_name = img_name.replace("/images_512/", "/coeff_fit_mp_debug/").replace(".png", "_debug.png").replace(".jpg", "_debug.jpg")
247
+ try: os.makedirs(os.path.dirname(debug_name), exist_ok=True)
248
+ except: pass
249
+ proj_geo = face_model.compute_for_landmark_fit(id_para, exp_para, euler_angle, trans)
250
+ lm68s = proj_geo[:,:,:2].detach().cpu().numpy() # [T, 68,2]
251
+ lm68s = lm68s * img_scale_factor
252
+ lms = lms * img_scale_factor
253
+ lm68s[..., 1] = img_h - lm68s[..., 1] # flip the height axis
254
+ lms[..., 1] = img_h - lms[..., 1] # flip the height axis
255
+ lm68s = lm68s.astype(int)
256
+ lm68s = lm68s.reshape([-1,2])
257
+ lms = lms.cpu().numpy().astype(int).reshape([-1,2])
258
+ for lm in lm68s:
259
+ img = cv2.circle(img, lm, 1, (0, 0, 255), thickness=-1)
260
+ for gt_lm in lms:
261
+ img = cv2.circle(img, gt_lm, 2, (255, 0, 0), thickness=1)
262
+ imageio.imwrite(debug_name, img)
263
+ print(f"debug img saved at {debug_name}")
264
+ return coeff_dict
265
+
266
+ def out_exist_job(vid_name):
267
+ out_name = vid_name.replace("/images_512/", "/coeff_fit_mp/").replace(".png","_coeff_fit_mp.npy")
268
+ # if os.path.exists(out_name) or not os.path.exists(lms_name):
269
+ if os.path.exists(out_name):
270
+ return None
271
+ else:
272
+ return vid_name
273
+
274
+ def get_todo_img_names(img_names):
275
+ todo_img_names = []
276
+ for i, res in multiprocess_run_tqdm(out_exist_job, img_names, num_workers=16):
277
+ if res is not None:
278
+ todo_img_names.append(res)
279
+ return todo_img_names
280
+
281
+
282
+ if __name__ == '__main__':
283
+ import argparse, glob, tqdm
284
+ parser = argparse.ArgumentParser()
285
+ parser.add_argument("--img_dir", default='/home/tiger/datasets/raw/FFHQ/images_512')
286
+ parser.add_argument("--ds_name", default='FFHQ')
287
+ parser.add_argument("--seed", default=0, type=int)
288
+ parser.add_argument("--process_id", default=0, type=int)
289
+ parser.add_argument("--total_process", default=1, type=int)
290
+ parser.add_argument("--keypoint_mode", default='mediapipe', type=str)
291
+ parser.add_argument("--debug", action='store_true')
292
+ parser.add_argument("--reset", action='store_true')
293
+ parser.add_argument("--device", default="cuda:0", type=str)
294
+ parser.add_argument("--output_log", action='store_true')
295
+ parser.add_argument("--load_names", action="store_true")
296
+
297
+ args = parser.parse_args()
298
+ img_dir = args.img_dir
299
+ load_names = args.load_names
300
+
301
+ print(f"args {args}")
302
+
303
+ if args.ds_name == 'single_img':
304
+ img_names = [img_dir]
305
+ else:
306
+ img_names_path = os.path.join(img_dir, "img_dir.pkl")
307
+ if os.path.exists(img_names_path) and load_names:
308
+ print(f"loading vid names from {img_names_path}")
309
+ img_names = load_file(img_names_path)
310
+ else:
311
+ if args.ds_name == 'FFHQ_MV':
312
+ img_name_pattern1 = os.path.join(img_dir, "ref_imgs/*.png")
313
+ img_names1 = glob.glob(img_name_pattern1)
314
+ img_name_pattern2 = os.path.join(img_dir, "mv_imgs/*.png")
315
+ img_names2 = glob.glob(img_name_pattern2)
316
+ img_names = img_names1 + img_names2
317
+ img_names = sorted(img_names)
318
+ elif args.ds_name == 'FFHQ':
319
+ img_name_pattern = os.path.join(img_dir, "*.png")
320
+ img_names = glob.glob(img_name_pattern)
321
+ img_names = sorted(img_names)
322
+ elif args.ds_name == "PanoHeadGen":
323
+ img_name_patterns = ["ref/*/*.png"]
324
+ img_names = []
325
+ for img_name_pattern in img_name_patterns:
326
+ img_name_pattern_full = os.path.join(img_dir, img_name_pattern)
327
+ img_names_part = glob.glob(img_name_pattern_full)
328
+ img_names.extend(img_names_part)
329
+ img_names = sorted(img_names)
330
+ print(f"saving image names to {img_names_path}")
331
+ save_file(img_names_path, img_names)
332
+
333
+ # import random
334
+ # random.seed(args.seed)
335
+ # random.shuffle(img_names)
336
+
337
+ face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
338
+ camera_distance=10, focal=1015, keypoint_mode=args.keypoint_mode)
339
+ face_model.to(torch.device(args.device))
340
+
341
+ process_id = args.process_id
342
+ total_process = args.total_process
343
+ if total_process > 1:
344
+ assert process_id <= total_process -1 and process_id >= 0
345
+ num_samples_per_process = len(img_names) // total_process
346
+ if process_id == total_process - 1:
347
+ img_names = img_names[process_id * num_samples_per_process : ]
348
+ else:
349
+ img_names = img_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
350
+ print(f"image names number (before fileter): {len(img_names)}")
351
+
352
+
353
+ if not args.reset:
354
+ img_names = get_todo_img_names(img_names)
355
+
356
+ print(f"image names number (after fileter): {len(img_names)}")
357
+ for i in tqdm.trange(len(img_names), desc=f"process {process_id}: fitting 3dmm ..."):
358
+ img_name = img_names[i]
359
+ try:
360
+ fit_3dmm_for_a_image(img_name, args.debug, device=args.device)
361
+ except Exception as e:
362
+ print(img_name, e)
363
+ if args.output_log and i % max(int(len(img_names) * 0.003), 1) == 0:
364
+ print(f"process {process_id}: {i + 1} / {len(img_names)} done")
365
+ sys.stdout.flush()
366
+ sys.stderr.flush()
367
+
368
+ print(f"process {process_id}: fitting 3dmm all done")
369
+
data_gen/utils/process_video/euler2quaterion.py ADDED
@@ -0,0 +1,35 @@
1
+ import numpy as np
2
+ import torch
3
+ import math
4
+ import numba
5
+ from scipy.spatial.transform import Rotation as R
6
+
7
+ def euler2quaterion(euler, use_radian=True):
8
+ """
9
+ euler: np.array, [batch, 3]
10
+ return: the quaternion, np.array, [batch, 4]
11
+ """
12
+ r = R.from_euler('xyz',euler, degrees=not use_radian)
13
+ return r.as_quat()
14
+
15
+ def quaterion2euler(quat, use_radian=True):
16
+ """
17
+ quat: np.array, [batch, 4]
18
+ return: the euler, np.array, [batch, 3]
19
+ """
20
+ r = R.from_quat(quat)
21
+ return r.as_euler('xyz', degrees=not use_radian)
22
+
23
+ def rot2quaterion(rot):
24
+ r = R.from_matrix(rot)
25
+ return r.as_quat()
26
+
27
+ def quaterion2rot(quat):
28
+ r = R.from_quat(quat)
29
+ return r.as_matrix()
30
+
31
+ if __name__ == '__main__':
32
+ euler = np.array([89.999,89.999,89.999] * 100).reshape([100,3])
33
+ q = euler2quaterion(euler, use_radian=False)
34
+ e = quaterion2euler(q, use_radian=False)
35
+ print(" ")
data_gen/utils/process_video/extract_blink.py ADDED
@@ -0,0 +1,50 @@
1
+ import numpy as np
2
+ from data_util.face3d_helper import Face3DHelper
3
+ from utils.commons.tensor_utils import convert_to_tensor
4
+
5
+ def polygon_area(x, y):
6
+ """
7
+ x: [T, K=6]
8
+ y: [T, K=6]
9
+ return: [T,]
10
+ """
11
+ x_ = x - x.mean(axis=-1, keepdims=True)
12
+ y_ = y - y.mean(axis=-1, keepdims=True)
13
+ correction = x_[:,-1] * y_[:,0] - y_[:,-1]* x_[:,0]
14
+ main_area = (x_[:,:-1] * y_[:,1:]).sum(axis=-1) - (y_[:,:-1] * x_[:,1:]).sum(axis=-1)
15
+ return 0.5 * np.abs(main_area + correction)
16
+
17
+ def get_eye_area_percent(id, exp, face3d_helper):
18
+ id = convert_to_tensor(id)
19
+ exp = convert_to_tensor(exp)
20
+ cano_lm3d = face3d_helper.reconstruct_cano_lm3d(id, exp)
21
+ cano_lm2d = (cano_lm3d[..., :2] + 1) / 2
22
+ lms = cano_lm2d.cpu().numpy()
23
+ eyes_left = slice(36, 42)
24
+ eyes_right = slice(42, 48)
25
+ area_left = polygon_area(lms[:, eyes_left, 0], lms[:, eyes_left, 1])
26
+ area_right = polygon_area(lms[:, eyes_right, 0], lms[:, eyes_right, 1])
27
+ # area percentage of two eyes of the whole image...
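+ # cano_lm2d is normalized to [0, 1], so the whole-image area equals 1; the division by 1 below keeps that explicit.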
28
+ area_percent = (area_left + area_right) / 1 * 100 # recommend threshold is 0.25%
29
+ return area_percent # [T,]
30
+
31
+
32
+ if __name__ == '__main__':
33
+ import numpy as np
34
+ import imageio
35
+ import cv2
36
+ import torch
37
+ from data_gen.utils.process_video.extract_lm2d import extract_lms_mediapipe_job, read_video_to_frames, index_lm68_from_lm468
38
+ from data_gen.utils.process_video.fit_3dmm_landmark import fit_3dmm_for_a_video
39
+ from data_util.face3d_helper import Face3DHelper
40
+
41
+ face3d_helper = Face3DHelper()
42
+ video_name = 'data/raw/videos/May_10s.mp4'
43
+ frames = read_video_to_frames(video_name)
44
+ coeff = fit_3dmm_for_a_video(video_name, save=False)
45
+ area_percent = get_eye_area_percent(torch.tensor(coeff['id']), torch.tensor(coeff['exp']), face3d_helper)
46
+ writer = imageio.get_writer("1.mp4", fps=25)
47
+ for idx, frame in enumerate(frames):
48
+ frame = cv2.putText(frame, f"{area_percent[idx]:.2f}", org=(128,128), fontFace=cv2.FONT_HERSHEY_COMPLEX, fontScale=1, color=(255,0,0), thickness=1)
49
+ writer.append_data(frame)
50
+ writer.close()
data_gen/utils/process_video/extract_lm2d.py ADDED
@@ -0,0 +1,164 @@
1
+ import os
2
+ os.environ["OMP_NUM_THREADS"] = "1"
3
+ import sys
4
+ import glob
5
+ import cv2
6
+ import pickle
7
+ import tqdm
8
+ import numpy as np
9
+ import mediapipe as mp
10
+ from utils.commons.multiprocess_utils import multiprocess_run_tqdm
11
+ from utils.commons.os_utils import multiprocess_glob
12
+ from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker
13
+ import warnings
14
+ import traceback
15
+
16
+ warnings.filterwarnings('ignore')
17
+
18
+ """
19
+ The Face_alignment-based lm68 has been deprecated because:
20
+ 1. its prediction accuracy around the eyes is very low;
21
+ 2. it cannot accurately predict the occluded jawline at large yaw angles, so the 3DMM GT labels themselves become wrong at large poses, which hurts performance.
22
+ We now use the mediapipe-based lm68 instead.
23
+ """
24
+ # def extract_landmarks(ori_imgs_dir):
25
+
26
+ # print(f'[INFO] ===== extract face landmarks from {ori_imgs_dir} =====')
27
+
28
+ # fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)
29
+ # image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.png'))
30
+ # for image_path in tqdm.tqdm(image_paths):
31
+ # out_name = image_path.replace("/images_512/", "/lms_2d/").replace(".png",".lms")
32
+ # if os.path.exists(out_name):
33
+ # continue
34
+ # input = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3]
35
+ # input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB)
36
+ # preds = fa.get_landmarks(input)
37
+ # if preds is None:
38
+ # print(f"Skip {image_path} for no face detected")
39
+ # continue
40
+ # if len(preds) > 0:
41
+ # lands = preds[0].reshape(-1, 2)[:,:2]
42
+ # os.makedirs(os.path.dirname(out_name), exist_ok=True)
43
+ # np.savetxt(out_name, lands, '%f')
44
+ # del fa
45
+ # print(f'[INFO] ===== extracted face landmarks =====')
46
+
47
+ def save_file(name, content):
48
+ with open(name, "wb") as f:
49
+ pickle.dump(content, f)
50
+
51
+ def load_file(name):
52
+ with open(name, "rb") as f:
53
+ content = pickle.load(f)
54
+ return content
55
+
56
+
57
+ face_landmarker = None
58
+
59
+ def extract_landmark_job(video_name, nerf=False):
60
+ try:
61
+ if nerf:
62
+ out_name = video_name.replace("/raw/", "/processed/").replace(".mp4","/lms_2d.npy")
63
+ else:
64
+ out_name = video_name.replace("/video/", "/lms_2d/").replace(".mp4","_lms.npy")
65
+ if os.path.exists(out_name):
66
+ # print("out exists, skip...")
67
+ return
68
+ try:
69
+ os.makedirs(os.path.dirname(out_name), exist_ok=True)
70
+ except:
71
+ pass
72
+ global face_landmarker
73
+ if face_landmarker is None:
74
+ face_landmarker = MediapipeLandmarker()
75
+ img_lm478, vid_lm478 = face_landmarker.extract_lm478_from_video_name(video_name)
76
+ lm478 = face_landmarker.combine_vid_img_lm478_to_lm478(img_lm478, vid_lm478)
77
+ np.save(out_name, lm478)
78
+ return True
79
+ # print("Hahaha, solve one item!!!")
80
+ except Exception as e:
81
+ traceback.print_exc()
82
+ return False
83
+
84
+ def out_exist_job(vid_name):
85
+ out_name = vid_name.replace("/video/", "/lms_2d/").replace(".mp4","_lms.npy")
86
+ if os.path.exists(out_name):
87
+ return None
88
+ else:
89
+ return vid_name
90
+
91
+ def get_todo_vid_names(vid_names):
92
+ if len(vid_names) == 1: # nerf
93
+ return vid_names
94
+ todo_vid_names = []
95
+ for i, res in multiprocess_run_tqdm(out_exist_job, vid_names, num_workers=128):
96
+ if res is not None:
97
+ todo_vid_names.append(res)
98
+ return todo_vid_names
99
+
100
+ if __name__ == '__main__':
101
+ import argparse, glob, tqdm, random
102
+ parser = argparse.ArgumentParser()
103
+ parser.add_argument("--vid_dir", default='nerf')
104
+ parser.add_argument("--ds_name", default='data/raw/videos/May.mp4')
105
+ parser.add_argument("--num_workers", default=2, type=int)
106
+ parser.add_argument("--process_id", default=0, type=int)
107
+ parser.add_argument("--total_process", default=1, type=int)
108
+ parser.add_argument("--reset", action="store_true")
109
+ parser.add_argument("--load_names", action="store_true")
110
+
111
+ args = parser.parse_args()
112
+ vid_dir = args.vid_dir
113
+ ds_name = args.ds_name
114
+ load_names = args.load_names
115
+
116
+ if ds_name.lower() == 'nerf': # process a single video
117
+ vid_names = [vid_dir]
118
+ out_names = [video_name.replace("/raw/", "/processed/").replace(".mp4","/lms_2d.npy") for video_name in vid_names]
119
+ else: # process the whole dataset
120
+ if ds_name in ['lrs3_trainval']:
121
+ vid_name_pattern = os.path.join(vid_dir, "*/*.mp4")
122
+ elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
123
+ vid_name_pattern = os.path.join(vid_dir, "*.mp4")
124
+ elif ds_name in ['lrs2', 'lrs3', 'voxceleb2', 'CMLR']:
125
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
126
+ elif ds_name in ["RAVDESS", 'VFHQ']:
127
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
128
+ else:
129
+ raise NotImplementedError()
130
+
131
+ vid_names_path = os.path.join(vid_dir, "vid_names.pkl")
132
+ if os.path.exists(vid_names_path) and load_names:
133
+ print(f"loading vid names from {vid_names_path}")
134
+ vid_names = load_file(vid_names_path)
135
+ else:
136
+ vid_names = multiprocess_glob(vid_name_pattern)
137
+ vid_names = sorted(vid_names)
138
+ if not load_names:
139
+ print(f"saving vid names to {vid_names_path}")
140
+ save_file(vid_names_path, vid_names)
141
+ out_names = [video_name.replace("/video/", "/lms_2d/").replace(".mp4","_lms.npy") for video_name in vid_names]
142
+
143
+ process_id = args.process_id
144
+ total_process = args.total_process
145
+ if total_process > 1:
146
+ assert process_id <= total_process -1
147
+ num_samples_per_process = len(vid_names) // total_process
148
+ if process_id == total_process - 1:
149
+ vid_names = vid_names[process_id * num_samples_per_process : ]
150
+ else:
151
+ vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
152
+
153
+ if not args.reset:
154
+ vid_names = get_todo_vid_names(vid_names)
155
+ print(f"todo videos number: {len(vid_names)}")
156
+
157
+ fail_cnt = 0
158
+ job_args = [(vid_name, ds_name=='nerf') for vid_name in vid_names]
159
+ for (i, res) in multiprocess_run_tqdm(extract_landmark_job, job_args, num_workers=args.num_workers, desc=f"Root {args.process_id}: extracting MP-based landmark2d"):
160
+ if res is False:
161
+ fail_cnt += 1
162
+ print(f"finished {i + 1} / {len(vid_names)} = {(i + 1) / len(vid_names):.4f}, failed {fail_cnt} / {i + 1} = {fail_cnt / (i + 1):.4f}")
163
+ sys.stdout.flush()
164
+ pass
data_gen/utils/process_video/extract_segment_imgs.py ADDED
@@ -0,0 +1,494 @@
1
+ import os
2
+ os.environ["OMP_NUM_THREADS"] = "1"
3
+ import random
4
+ import glob
5
+ import cv2
6
+ import tqdm
7
+ import numpy as np
8
+ from typing import Union
9
+ from utils.commons.tensor_utils import convert_to_np
10
+ from utils.commons.os_utils import multiprocess_glob
11
+ import pickle
12
+ import traceback
13
+ import multiprocessing
14
+ from utils.commons.multiprocess_utils import multiprocess_run_tqdm
15
+ from scipy.ndimage import binary_erosion, binary_dilation
16
+ from sklearn.neighbors import NearestNeighbors
17
+ from mediapipe.tasks.python import vision
18
+ from data_gen.utils.mp_feature_extractors.mp_segmenter import MediapipeSegmenter, encode_segmap_mask_to_image, decode_segmap_mask_from_image, job_cal_seg_map_for_image
19
+
20
+ seg_model = None
21
+ segmenter = None
22
+ mat_model = None
23
+ lama_model = None
24
+ lama_config = None
25
+
26
+ from data_gen.utils.process_video.split_video_to_imgs import extract_img_job
27
+
28
+ BG_NAME_MAP = {
29
+ "knn": "",
30
+ }
31
+ FRAME_SELECT_INTERVAL = 5
32
+ SIM_METHOD = "mse"
33
+ SIM_THRESHOLD = 3
34
+
35
+ def save_file(name, content):
36
+ with open(name, "wb") as f:
37
+ pickle.dump(content, f)
38
+
39
+ def load_file(name):
40
+ with open(name, "rb") as f:
41
+ content = pickle.load(f)
42
+ return content
43
+
44
+ def save_rgb_alpha_image_to_path(img, alpha, img_path):
45
+ try: os.makedirs(os.path.dirname(img_path), exist_ok=True)
46
+ except: pass
47
+ cv2.imwrite(img_path, np.concatenate([cv2.cvtColor(img, cv2.COLOR_RGB2BGR), alpha], axis=-1))
48
+
49
+ def save_rgb_image_to_path(img, img_path):
50
+ try: os.makedirs(os.path.dirname(img_path), exist_ok=True)
51
+ except: pass
52
+ cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
53
+
54
+ def load_rgb_image_to_path(img_path):
55
+ return cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
56
+
57
+ def image_similarity(x: np.ndarray, y: np.ndarray, method="mse"):
58
+ if method == "mse":
59
+ return np.mean((x - y) ** 2)
60
+ else:
61
+ raise NotImplementedError
62
+
63
+ def extract_background(img_lst, segmap_mask_lst=None, method="knn", device='cpu', mix_bg=True):
64
+ """
65
+ img_lst: list of rgb ndarray
66
+ method: "knn"
67
+ """
68
+ global segmenter
69
+ global seg_model
70
+ global mat_model
71
+ global lama_model
72
+ global lama_config
73
+
74
+ assert len(img_lst) > 0
75
+ if segmap_mask_lst is not None:
76
+ assert len(segmap_mask_lst) == len(img_lst)
77
+ else:
78
+ del segmenter
79
+ del seg_model
80
+ seg_model = MediapipeSegmenter()
81
+ segmenter = vision.ImageSegmenter.create_from_options(seg_model.video_options)
82
+
83
+ def get_segmap_mask(img_lst, segmap_mask_lst, index):
84
+ if segmap_mask_lst is not None:
85
+ segmap = refresh_segment_mask(segmap_mask_lst[index])
86
+ else:
87
+ segmap = seg_model._cal_seg_map(refresh_image(img_lst[index]), segmenter=segmenter)
88
+ return segmap
89
+
90
+ if method == "knn":
91
+ num_frames = len(img_lst)
92
+ if num_frames < 100:
93
+ FRAME_SELECT_INTERVAL = 5
94
+ elif num_frames < 10000:
95
+ FRAME_SELECT_INTERVAL = 20
96
+ else:
97
+ FRAME_SELECT_INTERVAL = num_frames // 500
98
+
99
+ img_lst = img_lst[::FRAME_SELECT_INTERVAL] if num_frames > FRAME_SELECT_INTERVAL else img_lst[0:1]
100
+
101
+ if segmap_mask_lst is not None:
102
+ segmap_mask_lst = segmap_mask_lst[::FRAME_SELECT_INTERVAL] if num_frames > FRAME_SELECT_INTERVAL else segmap_mask_lst[0:1]
103
+ assert len(img_lst) == len(segmap_mask_lst)
104
+ # get H/W
105
+ h, w = refresh_image(img_lst[0]).shape[:2]
106
+
107
+ # nearest neighbors
108
+ all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose() # [512*512, 2] coordinate grid
109
+ distss = []
110
+ for idx, img in tqdm.tqdm(enumerate(img_lst), desc='combining backgrounds...', total=len(img_lst)):
111
+ segmap = get_segmap_mask(img_lst=img_lst, segmap_mask_lst=segmap_mask_lst, index=idx)
112
+ bg = (segmap[0]).astype(bool) # [h,w] bool mask
113
+ fg_xys = np.stack(np.nonzero(~bg)).transpose(1, 0) # [N_nonbg,2] coordinate of non-bg pixels
114
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
115
+ dists, _ = nbrs.kneighbors(all_xys) # [512*512, 1] distance to nearest non-bg pixel
116
+ distss.append(dists)
117
+
118
+ distss = np.stack(distss) # [B, 512*512, 1]
119
+ max_dist = np.max(distss, 0) # [512*512, 1]
120
+ max_id = np.argmax(distss, 0) # id of frame
121
+
122
+ bc_pixs = max_dist > 10 # pixels that appear as background in at least one frame; a pixel counts as background when its distance to the nearest non-bg pixel exceeds 10
123
+ bc_pixs_id = np.nonzero(bc_pixs)
124
+ bc_ids = max_id[bc_pixs]
125
+
126
+ # TODO: maybe we should reimplement here to avoid memory costs?
127
+ # though there is upper limits of images here
128
+ num_pixs = distss.shape[1]
129
+ bg_img = np.zeros((h*w, 3), dtype=np.uint8)
130
+ img_lst = [refresh_image(img) for img in img_lst]
131
+ imgs = np.stack(img_lst).reshape(-1, num_pixs, 3)
132
+ bg_img[bc_pixs_id, :] = imgs[bc_ids, bc_pixs_id, :] # for pixels that are surely background, sample their color directly from the corresponding frame
133
+ bg_img = bg_img.reshape(h, w, 3)
134
+
135
+ max_dist = max_dist.reshape(h, w)
136
+ bc_pixs = max_dist > 10 # 5
137
+ bg_xys = np.stack(np.nonzero(~bc_pixs)).transpose()
138
+ fg_xys = np.stack(np.nonzero(bc_pixs)).transpose()
139
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
140
+ distances, indices = nbrs.kneighbors(bg_xys) # for the remaining (non-bg) pixels, use KNN to find the nearest confident bg pixel
141
+ bg_fg_xys = fg_xys[indices[:, 0]]
142
+ bg_img[bg_xys[:, 0], bg_xys[:, 1], :] = bg_img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :]
143
+ else:
144
+ raise NotImplementedError # deprecated
145
+
146
+ return bg_img
147
+
148
+ def inpaint_torso_job(gt_img, segmap):
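+ # Vertical inpainting: for each torso/neck column whose top pixel borders the head region, the top
+ # color is repeated upwards for L pixels with gradual darkening (x0.98 per step); the neck inpaint
+ # area is additionally Gaussian-blurred to hide vertical seams.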
149
+ bg_part = (segmap[0]).astype(bool)
150
+ head_part = (segmap[1] + segmap[3] + segmap[5]).astype(bool)
151
+ neck_part = (segmap[2]).astype(bool)
152
+ torso_part = (segmap[4]).astype(bool)
153
+ img = gt_img.copy()
154
+ img[head_part] = 0
155
+
156
+ # torso part "vertical" in-painting...
157
+ L = 8 + 1
158
+ torso_coords = np.stack(np.nonzero(torso_part), axis=-1) # [M, 2]
159
+ # lexsort: sort 2D coords first by y then by x,
160
+ # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
161
+ inds = np.lexsort((torso_coords[:, 0], torso_coords[:, 1]))
162
+ torso_coords = torso_coords[inds]
163
+ # choose the top pixel for each column
164
+ u, uid, ucnt = np.unique(torso_coords[:, 1], return_index=True, return_counts=True)
165
+ top_torso_coords = torso_coords[uid] # [m, 2]
166
+ # only keep top-is-head pixels
167
+ top_torso_coords_up = top_torso_coords.copy() - np.array([1, 0]) # [N, 2]
168
+ mask = head_part[tuple(top_torso_coords_up.T)]
169
+ if mask.any():
170
+ top_torso_coords = top_torso_coords[mask]
171
+ # get the color
172
+ top_torso_colors = gt_img[tuple(top_torso_coords.T)] # [m, 3]
173
+ # construct inpaint coords (vertically up, or minus in x)
174
+ inpaint_torso_coords = top_torso_coords[None].repeat(L, 0) # [L, m, 2]
175
+ inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
176
+ inpaint_torso_coords += inpaint_offsets
177
+ inpaint_torso_coords = inpaint_torso_coords.reshape(-1, 2) # [Lm, 2]
178
+ inpaint_torso_colors = top_torso_colors[None].repeat(L, 0) # [L, m, 3]
179
+ darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
180
+ inpaint_torso_colors = (inpaint_torso_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
181
+ # set color
182
+ img[tuple(inpaint_torso_coords.T)] = inpaint_torso_colors
183
+ inpaint_torso_mask = np.zeros_like(img[..., 0]).astype(bool)
184
+ inpaint_torso_mask[tuple(inpaint_torso_coords.T)] = True
185
+ else:
186
+ inpaint_torso_mask = None
187
+
188
+ # neck part "vertical" in-painting...
189
+ push_down = 4
190
+ L = 48 + push_down + 1
191
+ neck_part = binary_dilation(neck_part, structure=np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=bool), iterations=3)
192
+ neck_coords = np.stack(np.nonzero(neck_part), axis=-1) # [M, 2]
193
+ # lexsort: sort 2D coords first by y then by x,
194
+ # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
195
+ inds = np.lexsort((neck_coords[:, 0], neck_coords[:, 1]))
196
+ neck_coords = neck_coords[inds]
197
+ # choose the top pixel for each column
198
+ u, uid, ucnt = np.unique(neck_coords[:, 1], return_index=True, return_counts=True)
199
+ top_neck_coords = neck_coords[uid] # [m, 2]
200
+ # only keep top-is-head pixels
201
+ top_neck_coords_up = top_neck_coords.copy() - np.array([1, 0])
202
+ mask = head_part[tuple(top_neck_coords_up.T)]
203
+ top_neck_coords = top_neck_coords[mask]
204
+ # push these top down for 4 pixels to make the neck inpainting more natural...
205
+ offset_down = np.minimum(ucnt[mask] - 1, push_down)
206
+ top_neck_coords += np.stack([offset_down, np.zeros_like(offset_down)], axis=-1)
207
+ # get the color
208
+ top_neck_colors = gt_img[tuple(top_neck_coords.T)] # [m, 3]
209
+ # construct inpaint coords (vertically up, or minus in x)
210
+ inpaint_neck_coords = top_neck_coords[None].repeat(L, 0) # [L, m, 2]
211
+ inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
212
+ inpaint_neck_coords += inpaint_offsets
213
+ inpaint_neck_coords = inpaint_neck_coords.reshape(-1, 2) # [Lm, 2]
214
+ inpaint_neck_colors = top_neck_colors[None].repeat(L, 0) # [L, m, 3]
215
+ darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
216
+ inpaint_neck_colors = (inpaint_neck_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
217
+ # set color
218
+ img[tuple(inpaint_neck_coords.T)] = inpaint_neck_colors
219
+ # apply blurring to the inpaint area to avoid vertical-line artifacts...
220
+ inpaint_mask = np.zeros_like(img[..., 0]).astype(bool)
221
+ inpaint_mask[tuple(inpaint_neck_coords.T)] = True
222
+
223
+ blur_img = img.copy()
224
+ blur_img = cv2.GaussianBlur(blur_img, (5, 5), cv2.BORDER_DEFAULT)
225
+ img[inpaint_mask] = blur_img[inpaint_mask]
226
+
227
+ # set mask
228
+ torso_img_mask = (neck_part | torso_part | inpaint_mask)
229
+ torso_with_bg_img_mask = (bg_part | neck_part | torso_part | inpaint_mask)
230
+ if inpaint_torso_mask is not None:
231
+ torso_img_mask = torso_img_mask | inpaint_torso_mask
232
+ torso_with_bg_img_mask = torso_with_bg_img_mask | inpaint_torso_mask
233
+
234
+ torso_img = img.copy()
235
+ torso_img[~torso_img_mask] = 0
236
+ torso_with_bg_img = img.copy()
237
+ torso_with_bg_img[~torso_with_bg_img_mask] = 0
238
+
239
+ return torso_img, torso_img_mask, torso_with_bg_img, torso_with_bg_img_mask
240
+
241
+ def load_segment_mask_from_file(filename: str):
242
+ encoded_segmap = load_rgb_image_to_path(filename)
243
+ segmap_mask = decode_segmap_mask_from_image(encoded_segmap)
244
+ return segmap_mask
245
+
246
+ # load segment mask to memory if not loaded yet
247
+ def refresh_segment_mask(segmap_mask: Union[str, np.ndarray]):
248
+ if isinstance(segmap_mask, str):
249
+ segmap_mask = load_segment_mask_from_file(segmap_mask)
250
+ return segmap_mask
251
+
252
+ # load segment mask to memory if not loaded yet
253
+ def refresh_image(image: Union[str, np.ndarray]):
254
+ if isinstance(image, str):
255
+ image = load_rgb_image_to_path(image)
256
+ return image
257
+
258
+ def generate_segment_imgs_job(img_name, segmap, img):
259
+ out_img_name = segmap_name = img_name.replace("/gt_imgs/", "/segmaps/").replace(".jpg", ".png") # 存成jpg的话,pixel value会有误差
260
+ try: os.makedirs(os.path.dirname(out_img_name), exist_ok=True)
261
+ except: pass
262
+ encoded_segmap = encode_segmap_mask_to_image(segmap)
263
+ save_rgb_image_to_path(encoded_segmap, out_img_name)
264
+
265
+ for mode in ['head', 'torso', 'person', 'bg']:
266
+ out_img, mask = seg_model._seg_out_img_with_segmap(img, segmap, mode=mode)
267
+ img_alpha = 255 * np.ones((img.shape[0], img.shape[1], 1), dtype=np.uint8) # alpha
268
+ mask = mask[0][..., None]
269
+ img_alpha[~mask] = 0
270
+ out_img_name = img_name.replace("/gt_imgs/", f"/{mode}_imgs/").replace(".jpg", ".png")
271
+ save_rgb_alpha_image_to_path(out_img, img_alpha, out_img_name)
272
+
273
+ inpaint_torso_img, inpaint_torso_img_mask, inpaint_torso_with_bg_img, inpaint_torso_with_bg_img_mask = inpaint_torso_job(img, segmap)
274
+ img_alpha = 255 * np.ones((img.shape[0], img.shape[1], 1), dtype=np.uint8) # alpha
275
+ img_alpha[~inpaint_torso_img_mask[..., None]] = 0
276
+ out_img_name = img_name.replace("/gt_imgs/", f"/inpaint_torso_imgs/").replace(".jpg", ".png")
277
+ save_rgb_alpha_image_to_path(inpaint_torso_img, img_alpha, out_img_name)
278
+ return segmap_name
279
+
280
+ def segment_and_generate_for_image_job(img_name, img, segmenter_options=None, segmenter=None, store_in_memory=False):
281
+ img = refresh_image(img)
282
+ segmap_mask, segmap_image = job_cal_seg_map_for_image(img, segmenter_options=segmenter_options, segmenter=segmenter)
283
+ segmap_name = generate_segment_imgs_job(img_name=img_name, segmap=segmap_mask, img=img)
284
+ if store_in_memory:
285
+ return segmap_mask
286
+ else:
287
+ return segmap_name
288
+
289
+ def extract_segment_job(
290
+ video_name,
291
+ nerf=False,
292
+ background_method='knn',
293
+ device="cpu",
294
+ total_gpus=0,
295
+ mix_bg=True,
296
+ store_in_memory=False, # set to True to speed up a bit of preprocess, but leads to HUGE memory costs (100GB for 5-min video)
297
+ force_single_process=False, # turn this on if you find multi-process does not work on your environment
298
+ ):
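+ # Pipeline: (1) split the video into gt_imgs with ffmpeg if needed, (2) run Mediapipe segmentation per
+ # frame and save head/torso/person/bg/inpaint_torso images, (3) estimate a static background with the
+ # KNN method, (4) composite background-filled com_imgs.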
299
+ global segmenter
300
+ global seg_model
301
+ del segmenter
302
+ del seg_model
303
+ seg_model = MediapipeSegmenter()
304
+ segmenter = vision.ImageSegmenter.create_from_options(seg_model.options)
305
+ # nerf means that we extract only one video, so can enable multi-process acceleration
306
+ multiprocess_enable = nerf and not force_single_process
307
+ try:
308
+ if "cuda" in device:
309
+ # determine which cuda index from subprocess id
310
+ pname = multiprocessing.current_process().name
311
+ pid = int(pname.rsplit("-", 1)[-1]) - 1
312
+ cuda_id = pid % total_gpus
313
+ device = f"cuda:{cuda_id}"
314
+
315
+ if nerf: # single video
316
+ raw_img_dir = video_name.replace(".mp4", "/gt_imgs/").replace("/raw/","/processed/")
317
+ else: # whole dataset
318
+ raw_img_dir = video_name.replace(".mp4", "").replace("/video/", "/gt_imgs/")
319
+ if not os.path.exists(raw_img_dir):
320
+ extract_img_job(video_name, raw_img_dir) # use ffmpeg to split video into imgs
321
+
322
+ img_names = glob.glob(os.path.join(raw_img_dir, "*.jpg"))
323
+
324
+ img_lst = []
325
+
326
+ for img_name in img_names:
327
+ if store_in_memory:
328
+ img = load_rgb_image_to_path(img_name)
329
+ else:
330
+ img = img_name
331
+ img_lst.append(img)
332
+
333
+ print("| Extracting Segmaps && Saving...")
334
+ args = []
335
+ segmap_mask_lst = []
336
+ # preparing parameters for segment
337
+ for i in range(len(img_lst)):
338
+ img_name = img_names[i]
339
+ img = img_lst[i]
340
+ if multiprocess_enable: # create seg_model in subprocesses here
341
+ options = seg_model.options
342
+ segmenter_arg = None
343
+ else: # use seg_model of this process
344
+ options = None
345
+ segmenter_arg = segmenter
346
+ arg = (img_name, img, options, segmenter_arg, store_in_memory)
347
+ args.append(arg)
348
+
349
+ if multiprocess_enable:
350
+ for (_, res) in multiprocess_run_tqdm(segment_and_generate_for_image_job, args=args, num_workers=16, desc='generating segment images in multi-processes...'):
351
+ segmap_mask = res
352
+ segmap_mask_lst.append(segmap_mask)
353
+ else:
354
+ for index in tqdm.tqdm(range(len(img_lst)), desc="generating segment images in single-process..."):
355
+ segmap_mask = segment_and_generate_for_image_job(*args[index])
356
+ segmap_mask_lst.append(segmap_mask)
357
+ print("| Extracted Segmaps Done.")
358
+
359
+ print("| Extracting background...")
360
+ bg_prefix_name = f"bg{BG_NAME_MAP[background_method]}"
361
+ bg_img = extract_background(img_lst, segmap_mask_lst, method=background_method, device=device, mix_bg=mix_bg)
362
+ if nerf:
363
+ out_img_name = video_name.replace("/raw/", "/processed/").replace(".mp4", f"/{bg_prefix_name}.jpg")
364
+ else:
365
+ out_img_name = video_name.replace("/video/", f"/{bg_prefix_name}_img/").replace(".mp4", ".jpg")
366
+ save_rgb_image_to_path(bg_img, out_img_name)
367
+ print("| Extracted background done.")
368
+
369
+ print("| Extracting com_imgs...")
370
+ com_prefix_name = f"com{BG_NAME_MAP[background_method]}"
371
+ for i in tqdm.trange(len(img_names), desc='extracting com_imgs'):
372
+ img_name = img_names[i]
373
+ com_img = refresh_image(img_lst[i]).copy()
374
+ segmap = refresh_segment_mask(segmap_mask_lst[i])
375
+ bg_part = segmap[0].astype(bool)[..., None].repeat(3,axis=-1)
376
+ com_img[bg_part] = bg_img[bg_part]
377
+ out_img_name = img_name.replace("/gt_imgs/", f"/{com_prefix_name}_imgs/")
378
+ save_rgb_image_to_path(com_img, out_img_name)
379
+ print("| Extracted com_imgs done.")
380
+
381
+ return 0
382
+ except Exception as e:
383
+ print(str(type(e)), e)
384
+ traceback.print_exc()
385
+ return 1
386
+
387
+ def out_exist_job(vid_name, background_method='knn'):
388
+ com_prefix_name = f"com{BG_NAME_MAP[background_method]}"
389
+ img_dir = vid_name.replace("/video/", "/gt_imgs/").replace(".mp4", "")
390
+ out_dir1 = img_dir.replace("/gt_imgs/", "/head_imgs/")
391
+ out_dir2 = img_dir.replace("/gt_imgs/", f"/{com_prefix_name}_imgs/")
392
+
393
+ if os.path.exists(img_dir) and os.path.exists(out_dir1) and os.path.exists(out_dir2):
394
+ num_frames = len(os.listdir(img_dir))
395
+ if len(os.listdir(out_dir1)) == num_frames and len(os.listdir(out_dir2)) == num_frames:
396
+ return None
397
+ else:
398
+ return vid_name
399
+ else:
400
+ return vid_name
401
+
402
+ def get_todo_vid_names(vid_names, background_method='knn'):
403
+ if len(vid_names) == 1: # nerf
404
+ return vid_names
405
+ todo_vid_names = []
406
+ fn_args = [(vid_name, background_method) for vid_name in vid_names]
407
+ for i, res in multiprocess_run_tqdm(out_exist_job, fn_args, num_workers=16, desc="checking todo videos..."):
408
+ if res is not None:
409
+ todo_vid_names.append(res)
410
+ return todo_vid_names
411
+
412
+ if __name__ == '__main__':
413
+ import argparse, glob, tqdm, random
414
+ parser = argparse.ArgumentParser()
415
+ parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/TH1KH_512/video')
416
+ parser.add_argument("--ds_name", default='TH1KH_512')
417
+ parser.add_argument("--num_workers", default=48, type=int)
418
+ parser.add_argument("--seed", default=0, type=int)
419
+ parser.add_argument("--process_id", default=0, type=int)
420
+ parser.add_argument("--total_process", default=1, type=int)
421
+ parser.add_argument("--reset", action='store_true')
422
+ parser.add_argument("--load_names", action="store_true")
423
+ parser.add_argument("--background_method", choices=['knn', 'mat', 'ddnm', 'lama'], type=str, default='knn')
424
+ parser.add_argument("--total_gpus", default=0, type=int) # zero gpus means utilizing cpu
425
+ parser.add_argument("--no_mix_bg", action="store_true")
426
+ parser.add_argument("--store_in_memory", action="store_true") # set to True to speed up preprocess, but leads to high memory costs
427
+ parser.add_argument("--force_single_process", action="store_true") # turn this on if you find multi-process does not work on your environment
428
+
429
+ args = parser.parse_args()
430
+ vid_dir = args.vid_dir
431
+ ds_name = args.ds_name
432
+ load_names = args.load_names
433
+ background_method = args.background_method
434
+ total_gpus = args.total_gpus
435
+ mix_bg = not args.no_mix_bg
436
+ store_in_memory = args.store_in_memory
437
+ force_single_process = args.force_single_process
438
+
439
+ devices = os.environ.get('CUDA_VISIBLE_DEVICES', '').split(",")
440
+ for d in devices[:total_gpus]:
441
+ os.system(f'pkill -f "voidgpu{d}"')
442
+
443
+ if ds_name.lower() == 'nerf': # process a single video
444
+ vid_names = [vid_dir]
445
+ out_names = [video_name.replace("/raw/", "/processed/").replace(".mp4","_lms.npy") for video_name in vid_names]
446
+ else: # process the whole dataset
447
+ if ds_name in ['lrs3_trainval']:
448
+ vid_name_pattern = os.path.join(vid_dir, "*/*.mp4")
449
+ elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
450
+ vid_name_pattern = os.path.join(vid_dir, "*.mp4")
451
+ elif ds_name in ['lrs2', 'lrs3', 'voxceleb2']:
452
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
453
+ elif ds_name in ["RAVDESS", 'VFHQ']:
454
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
455
+ else:
456
+ raise NotImplementedError()
457
+
458
+ vid_names_path = os.path.join(vid_dir, "vid_names.pkl")
459
+ if os.path.exists(vid_names_path) and load_names:
460
+ print(f"loading vid names from {vid_names_path}")
461
+ vid_names = load_file(vid_names_path)
462
+ else:
463
+ vid_names = multiprocess_glob(vid_name_pattern)
464
+ vid_names = sorted(vid_names)
465
+ print(f"saving vid names to {vid_names_path}")
466
+ save_file(vid_names_path, vid_names)
467
+
468
+ vid_names = sorted(vid_names)
469
+ random.seed(args.seed)
470
+ random.shuffle(vid_names)
471
+
472
+ process_id = args.process_id
473
+ total_process = args.total_process
474
+ if total_process > 1:
475
+ assert process_id <= total_process -1
476
+ num_samples_per_process = len(vid_names) // total_process
477
+ if process_id == total_process - 1:
478
+ vid_names = vid_names[process_id * num_samples_per_process : ]
479
+ else:
480
+ vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
481
+
482
+ if not args.reset:
483
+ vid_names = get_todo_vid_names(vid_names, background_method)
484
+ print(f"todo videos number: {len(vid_names)}")
485
+
486
+ device = "cuda" if total_gpus > 0 else "cpu"
487
+ extract_job = extract_segment_job
488
+ fn_args = [(vid_name, ds_name=='nerf', background_method, device, total_gpus, mix_bg, store_in_memory, force_single_process) for i, vid_name in enumerate(vid_names)]
489
+
490
+ if ds_name == 'nerf': # process a single video
491
+ extract_job(*fn_args[0])
492
+ else:
493
+ for vid_name in multiprocess_run_tqdm(extract_job, fn_args, desc=f"Root process {args.process_id}: segment images", num_workers=args.num_workers):
494
+ pass
data_gen/utils/process_video/fit_3dmm_landmark.py ADDED
@@ -0,0 +1,565 @@
1
+ # This is a script for efficient 3DMM coefficient extraction.
2
+ # It can reconstruct an accurate 3D face in real time.
3
+ # It is built upon BFM 2009 model and mediapipe landmark extractor.
4
+ # It is authored by ZhenhuiYe ([email protected]), free to contact him for any suggestion on improvement!
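+ # Example invocation for a single video (hypothetical, assuming the repo root is on PYTHONPATH):
+ #   python data_gen/utils/process_video/fit_3dmm_landmark.py --ds_name nerf --vid_dir data/raw/videos/May_10s.mp4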
5
+
6
+ from numpy.core.numeric import require
7
+ from numpy.lib.function_base import quantile
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import copy
11
+ import numpy as np
12
+
13
+ import random
14
+ import pickle
15
+ import os
16
+ import sys
17
+ import cv2
18
+ import argparse
19
+ import tqdm
20
+ from utils.commons.multiprocess_utils import multiprocess_run_tqdm
21
+ from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker, read_video_to_frames
22
+ from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
23
+ from deep_3drecon.secc_renderer import SECC_Renderer
24
+ from utils.commons.os_utils import multiprocess_glob
25
+
26
+
27
+ face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
28
+ camera_distance=10, focal=1015, keypoint_mode='mediapipe')
29
+ face_model.to(torch.device("cuda:0"))
30
+
31
+ dir_path = os.path.dirname(os.path.realpath(__file__))
32
+
33
+
34
+ def draw_axes(img, pitch, yaw, roll, tx, ty, size=50):
35
+ # yaw = -yaw
36
+ pitch = - pitch
37
+ roll = - roll
38
+ rotation_matrix = cv2.Rodrigues(np.array([pitch, yaw, roll]))[0].astype(np.float64)
39
+ axes_points = np.array([
40
+ [1, 0, 0, 0],
41
+ [0, 1, 0, 0],
42
+ [0, 0, 1, 0]
43
+ ], dtype=np.float64)
44
+ axes_points = rotation_matrix @ axes_points
45
+ axes_points = (axes_points[:2, :] * size).astype(int)
46
+ axes_points[0, :] = axes_points[0, :] + tx
47
+ axes_points[1, :] = axes_points[1, :] + ty
48
+
49
+ new_img = img.copy()
50
+ cv2.line(new_img, tuple(axes_points[:, 3].ravel()), tuple(axes_points[:, 0].ravel()), (255, 0, 0), 3)
51
+ cv2.line(new_img, tuple(axes_points[:, 3].ravel()), tuple(axes_points[:, 1].ravel()), (0, 255, 0), 3)
52
+ cv2.line(new_img, tuple(axes_points[:, 3].ravel()), tuple(axes_points[:, 2].ravel()), (0, 0, 255), 3)
53
+ return new_img
54
+
55
+ def save_file(name, content):
56
+ with open(name, "wb") as f:
57
+ pickle.dump(content, f)
58
+
59
+ def load_file(name):
60
+ with open(name, "rb") as f:
61
+ content = pickle.load(f)
62
+ return content
63
+
64
+ def cal_lap_loss(in_tensor):
65
+ # [T, 68, 2]
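+ # Temporal Laplacian smoothness: each coefficient track is convolved along time with the kernel
+ # (-0.5, 1.0, -0.5) (edge frames replicated) and the mean squared response is penalized.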
66
+ t = in_tensor.shape[0]
67
+ in_tensor = in_tensor.reshape([t, -1]).permute(1,0).unsqueeze(1) # [c, 1, t]
68
+ in_tensor = torch.cat([in_tensor[:, :, 0:1], in_tensor, in_tensor[:, :, -1:]], dim=-1)
69
+ lap_kernel = torch.Tensor((-0.5, 1.0, -0.5)).reshape([1,1,3]).float().to(in_tensor.device) # [1, 1, kw]
70
+ loss_lap = 0
71
+
72
+ out_tensor = F.conv1d(in_tensor, lap_kernel)
73
+ loss_lap += torch.mean(out_tensor**2)
74
+ return loss_lap
75
+
76
+ def cal_vel_loss(ldm):
77
+ # [B, 68, 2]
78
+ vel = ldm[1:] - ldm[:-1]
79
+ return torch.mean(torch.abs(vel))
80
+
81
+ def cal_lan_loss(proj_lan, gt_lan):
82
+ # [B, 68, 2]
83
+ loss = (proj_lan - gt_lan)** 2
84
+ # use the ldm weights from deep3drecon, see deep_3drecon/deep_3drecon_models/losses.py
85
+ weights = torch.zeros_like(loss)
86
+ weights = torch.ones_like(loss)
87
+ weights[:, 36:48, :] = 3 # eye 12 points
88
+ weights[:, -8:, :] = 3 # inner lip 8 points
89
+ weights[:, 28:31, :] = 3 # nose 3 points
90
+ loss = loss * weights
91
+ return torch.mean(loss)
92
+
93
+ def cal_lan_loss_mp(proj_lan, gt_lan, mean:bool=True):
94
+ # [B, 468, 2] (mediapipe landmarks)
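+ # Weighted squared error over the mediapipe landmarks: eyes x3, upper eyelids x20, lips x5;
+ # a few contour points without a BFM counterpart (unmatch_mask) are zeroed out.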
95
+ loss = (proj_lan - gt_lan).pow(2)
96
+ # loss = (proj_lan - gt_lan).abs()
97
+ unmatch_mask = [ 93, 127, 132, 234, 323, 356, 361, 454]
98
+ upper_eye = [161,160,159,158,157] + [388,387,386,385,384]
99
+ eye = [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249]
100
+ inner_lip = [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95]
101
+ outer_lip = [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146]
102
+ weights = torch.ones_like(loss)
103
+ weights[:, eye] = 3
104
+ weights[:, upper_eye] = 20
105
+ weights[:, inner_lip] = 5
106
+ weights[:, outer_lip] = 5
107
+ weights[:, unmatch_mask] = 0
108
+ loss = loss * weights
109
+ if mean:
110
+ loss = torch.mean(loss)
111
+ return loss
112
+
113
+ def cal_acceleration_loss(trans):
114
+ vel = trans[1:] - trans[:-1]
115
+ acc = vel[1:] - vel[:-1]
116
+ return torch.mean(torch.abs(acc))
117
+
118
+ def cal_acceleration_ldm_loss(ldm):
119
+ # [B, 68, 2]
120
+ vel = ldm[1:] - ldm[:-1]
121
+ acc = vel[1:] - vel[:-1]
122
+ lip_weight = 0.25 # we don't want to smooth the lips too much
123
+ acc[48:68] *= lip_weight
124
+ return torch.mean(torch.abs(acc))
125
+
126
+ def set_requires_grad(tensor_list):
127
+ for tensor in tensor_list:
128
+ tensor.requires_grad = True
129
+
130
+ @torch.enable_grad()
131
+ def fit_3dmm_for_a_video(
132
+ video_name,
133
+ nerf=False, # use the file name convention for GeneFace++
134
+ id_mode='global',
135
+ debug=False,
136
+ keypoint_mode='mediapipe',
137
+ large_yaw_threshold=9999999.9,
138
+ save=True
139
+ ) -> bool: # True: good, False: bad
140
+ assert video_name.endswith(".mp4"), "this function only support video as input"
141
+ if id_mode == 'global':
142
+ LAMBDA_REG_ID = 0.2
143
+ LAMBDA_REG_EXP = 0.6
144
+ LAMBDA_REG_LAP = 1.0
145
+ LAMBDA_REG_VEL_ID = 0.0 # laplacian is all you need for temporal consistency
146
+ LAMBDA_REG_VEL_EXP = 0.0 # laplacian is all you need for temporal consistency
147
+ else:
148
+ LAMBDA_REG_ID = 0.3
149
+ LAMBDA_REG_EXP = 0.05
150
+ LAMBDA_REG_LAP = 1.0
151
+ LAMBDA_REG_VEL_ID = 0.0 # laplacian is all you need for temporal consistency
152
+ LAMBDA_REG_VEL_EXP = 0.0 # laplacian is all you need for temporal consistency
153
+
154
+ frames = read_video_to_frames(video_name) # [T, H, W, 3]
155
+ img_h, img_w = frames.shape[1], frames.shape[2]
156
+ assert img_h == img_w
157
+ num_frames = len(frames)
158
+
159
+ if nerf: # single video
160
+ lm_name = video_name.replace("/raw/", "/processed/").replace(".mp4","/lms_2d.npy")
161
+ else:
162
+ lm_name = video_name.replace("/video/", "/lms_2d/").replace(".mp4", "_lms.npy")
163
+
164
+ if os.path.exists(lm_name):
165
+ lms = np.load(lm_name)
166
+ else:
167
+ print(f"lms_2d file not found, try to extract it from video... {lm_name}")
168
+ try:
169
+ landmarker = MediapipeLandmarker()
170
+ img_lm478, vid_lm478 = landmarker.extract_lm478_from_frames(frames, anti_smooth_factor=20)
171
+ lms = landmarker.combine_vid_img_lm478_to_lm478(img_lm478, vid_lm478)
172
+ except Exception as e:
173
+ print(e)
174
+ return False
175
+ if lms is None:
176
+ print(f"get None lms_2d, please check whether each frame has one head, exiting... {lm_name}")
177
+ return False
178
+ lms = lms[:, :468, :]
179
+ lms = torch.FloatTensor(lms).cuda()
180
+ lms[..., 1] = img_h - lms[..., 1] # flip the height axis
181
+
182
+ if keypoint_mode == 'mediapipe':
183
+ # default
184
+ cal_lan_loss_fn = cal_lan_loss_mp
185
+ if nerf: # single video
186
+ out_name = video_name.replace("/raw/", "/processed/").replace(".mp4", "/coeff_fit_mp.npy")
187
+ else:
188
+ out_name = video_name.replace("/video/", "/coeff_fit_mp/").replace(".mp4", "_coeff_fit_mp.npy")
189
+ else:
190
+ # lm68 is less accurate than mp
191
+ cal_lan_loss_fn = cal_lan_loss
192
+ if nerf: # single video
193
+ out_name = video_name.replace("/raw/", "/processed/").replace(".mp4", "_coeff_fit_lm68.npy")
194
+ else:
195
+ out_name = video_name.replace("/video/", "/coeff_fit_lm68/").replace(".mp4", "_coeff_fit_lm68.npy")
196
+ try:
197
+ os.makedirs(os.path.dirname(out_name), exist_ok=True)
198
+ except:
199
+ pass
200
+
201
+ id_dim, exp_dim = 80, 64
202
+ sel_ids = np.arange(0, num_frames, 40)
203
+
204
+ h = w = face_model.center * 2
205
+ img_scale_factor = img_h / h
206
+ lms /= img_scale_factor # rescale lms into [0,224]
207
+
208
+ if id_mode == 'global':
209
+ # default choice by GeneFace++ and later works
210
+ id_para = lms.new_zeros((1, id_dim), requires_grad=True)
211
+ elif id_mode == 'finegrained':
212
+ # legacy choice by GeneFace1 (ICLR 2023)
213
+ id_para = lms.new_zeros((num_frames, id_dim), requires_grad=True)
214
+ else: raise NotImplementedError(f"id mode {id_mode} not supported! we only support global or finegrained.")
215
+ exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
216
+ euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True)
217
+ trans = lms.new_zeros((num_frames, 3), requires_grad=True)
218
+
219
+ set_requires_grad([id_para, exp_para, euler_angle, trans])
220
+
221
+ optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=.1)
222
+ optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=.1)
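+ # Fitting runs in three stages: (1) 200 steps optimizing only euler/trans (pose warm-up),
+ # (2) 200 steps jointly refining id/exp/pose over all frames, (3) fine fitting in 50-frame batches
+ # with a fresh Adam optimizer per batch.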
223
+
224
+ # initialize the other parameters; first optimize only euler and trans
225
+ for _ in range(200):
226
+ if id_mode == 'global':
227
+ proj_geo = face_model.compute_for_landmark_fit(
228
+ id_para.expand((num_frames, id_dim)), exp_para, euler_angle, trans)
229
+ else:
230
+ proj_geo = face_model.compute_for_landmark_fit(
231
+ id_para, exp_para, euler_angle, trans)
232
+ loss_lan = cal_lan_loss_fn(proj_geo[:, :, :2], lms.detach())
233
+ loss = loss_lan
234
+ optimizer_frame.zero_grad()
235
+ loss.backward()
236
+ optimizer_frame.step()
237
+
238
+ # print(f"loss_lan: {loss_lan.item():.2f}, euler_abs_mean: {euler_angle.abs().mean().item():.4f}, euler_std: {euler_angle.std().item():.4f}, euler_min: {euler_angle.min().item():.4f}, euler_max: {euler_angle.max().item():.4f}")
239
+ # print(f"trans_z_mean: {trans[...,2].mean().item():.4f}, trans_z_std: {trans[...,2].std().item():.4f}, trans_min: {trans[...,2].min().item():.4f}, trans_max: {trans[...,2].max().item():.4f}")
240
+
241
+ for param_group in optimizer_frame.param_groups:
242
+ param_group['lr'] = 0.1
243
+
244
+ # "jointly roughly training id exp euler trans"
245
+ for _ in range(200):
246
+ ret = {}
247
+ if id_mode == 'global':
248
+ proj_geo = face_model.compute_for_landmark_fit(
249
+ id_para.expand((num_frames, id_dim)), exp_para, euler_angle, trans, ret)
250
+ else:
251
+ proj_geo = face_model.compute_for_landmark_fit(
252
+ id_para, exp_para, euler_angle, trans, ret)
253
+ loss_lan = cal_lan_loss_fn(
254
+ proj_geo[:, :, :2], lms.detach())
255
+ # loss_lap = cal_lap_loss(proj_geo)
256
+ # the laplacian term has little effect on euler, but greatly improves trans
257
+ loss_lap = cal_lap_loss(id_para) + cal_lap_loss(exp_para) + cal_lap_loss(euler_angle) * 0.3 + cal_lap_loss(trans) * 0.3
258
+
259
+ loss_regid = torch.mean(id_para*id_para) # regularization
260
+ loss_regexp = torch.mean(exp_para * exp_para)
261
+
262
+ loss_vel_id = cal_vel_loss(id_para)
263
+ loss_vel_exp = cal_vel_loss(exp_para)
264
+ loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP + loss_vel_id * LAMBDA_REG_VEL_ID + loss_vel_exp * LAMBDA_REG_VEL_EXP + loss_lap * LAMBDA_REG_LAP
265
+ optimizer_idexp.zero_grad()
266
+ optimizer_frame.zero_grad()
267
+ loss.backward()
268
+ optimizer_idexp.step()
269
+ optimizer_frame.step()
270
+
271
+ # print(f"loss_lan: {loss_lan.item():.2f}, loss_reg_id: {loss_regid.item():.2f},loss_reg_exp: {loss_regexp.item():.2f},")
272
+ # print(f"euler_abs_mean: {euler_angle.abs().mean().item():.4f}, euler_std: {euler_angle.std().item():.4f}, euler_min: {euler_angle.min().item():.4f}, euler_max: {euler_angle.max().item():.4f}")
273
+ # print(f"trans_z_mean: {trans[...,2].mean().item():.4f}, trans_z_std: {trans[...,2].std().item():.4f}, trans_min: {trans[...,2].min().item():.4f}, trans_max: {trans[...,2].max().item():.4f}")
274
+
275
+ # start fine training, initialize from the roughly trained results
276
+ if id_mode == 'global':
277
+ id_para_ = lms.new_zeros((1, id_dim), requires_grad=False)
278
+ else:
279
+ id_para_ = lms.new_zeros((num_frames, id_dim), requires_grad=True)
280
+ id_para_.data = id_para.data.clone()
281
+ id_para = id_para_
282
+ exp_para_ = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
283
+ exp_para_.data = exp_para.data.clone()
284
+ exp_para = exp_para_
285
+ euler_angle_ = lms.new_zeros((num_frames, 3), requires_grad=True)
286
+ euler_angle_.data = euler_angle.data.clone()
287
+ euler_angle = euler_angle_
288
+ trans_ = lms.new_zeros((num_frames, 3), requires_grad=True)
289
+ trans_.data = trans.data.clone()
290
+ trans = trans_
291
+
292
+ batch_size = 50
293
+ # "fine fitting the 3DMM in batches"
294
+ for i in range(int((num_frames-1)/batch_size+1)):
295
+ if (i+1)*batch_size > num_frames:
296
+ start_n = num_frames-batch_size
297
+ sel_ids = np.arange(max(num_frames-batch_size,0), num_frames)
298
+ else:
299
+ start_n = i*batch_size
300
+ sel_ids = np.arange(i*batch_size, i*batch_size+batch_size)
301
+ sel_lms = lms[sel_ids]
302
+
303
+ if id_mode == 'global':
304
+ sel_id_para = id_para.expand((sel_ids.shape[0], id_dim))
305
+ else:
306
+ sel_id_para = id_para.new_zeros((batch_size, id_dim), requires_grad=True)
307
+ sel_id_para.data = id_para[sel_ids].clone()
308
+ sel_exp_para = exp_para.new_zeros(
309
+ (batch_size, exp_dim), requires_grad=True)
310
+ sel_exp_para.data = exp_para[sel_ids].clone()
311
+ sel_euler_angle = euler_angle.new_zeros(
312
+ (batch_size, 3), requires_grad=True)
313
+ sel_euler_angle.data = euler_angle[sel_ids].clone()
314
+ sel_trans = trans.new_zeros((batch_size, 3), requires_grad=True)
315
+ sel_trans.data = trans[sel_ids].clone()
316
+
317
+ if id_mode == 'global':
318
+ set_requires_grad([sel_exp_para, sel_euler_angle, sel_trans])
319
+ optimizer_cur_batch = torch.optim.Adam(
320
+ [sel_exp_para, sel_euler_angle, sel_trans], lr=0.005)
321
+ else:
322
+ set_requires_grad([sel_id_para, sel_exp_para, sel_euler_angle, sel_trans])
323
+ optimizer_cur_batch = torch.optim.Adam(
324
+ [sel_id_para, sel_exp_para, sel_euler_angle, sel_trans], lr=0.005)
325
+
326
+ for j in range(50):
327
+ ret = {}
328
+ proj_geo = face_model.compute_for_landmark_fit(
329
+ sel_id_para, sel_exp_para, sel_euler_angle, sel_trans, ret)
330
+ loss_lan = cal_lan_loss_fn(
331
+ proj_geo[:, :, :2], lms[sel_ids].detach())
332
+
333
+ # loss_lap = cal_lap_loss(proj_geo)
334
+ loss_lap = cal_lap_loss(sel_id_para) + cal_lap_loss(sel_exp_para) + cal_lap_loss(sel_euler_angle) * 0.3 + cal_lap_loss(sel_trans) * 0.3
335
+ loss_vel_id = cal_vel_loss(sel_id_para)
336
+ loss_vel_exp = cal_vel_loss(sel_exp_para)
337
+ log_dict = {
338
+ 'loss_vel_id': loss_vel_id,
339
+ 'loss_vel_exp': loss_vel_exp,
340
+ 'loss_vel_euler': cal_vel_loss(sel_euler_angle),
341
+ 'loss_vel_trans': cal_vel_loss(sel_trans),
342
+ }
343
+ loss_regid = torch.mean(sel_id_para*sel_id_para) # regularization
344
+ loss_regexp = torch.mean(sel_exp_para*sel_exp_para)
345
+ loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP + loss_lap * LAMBDA_REG_LAP + loss_vel_id * LAMBDA_REG_VEL_ID + loss_vel_exp * LAMBDA_REG_VEL_EXP
346
+
347
+ optimizer_cur_batch.zero_grad()
348
+ loss.backward()
349
+ optimizer_cur_batch.step()
350
+
351
+ if debug:
352
+ print(f"batch {i} | loss_lan: {loss_lan.item():.2f}, loss_reg_id: {loss_regid.item():.2f},loss_reg_exp: {loss_regexp.item():.2f},loss_lap_ldm:{loss_lap.item():.4f}")
353
+ print("|--------" + ', '.join([f"{k}: {v:.4f}" for k,v in log_dict.items()]))
354
+ if id_mode != 'global':
355
+ id_para[sel_ids].data = sel_id_para.data.clone()
356
+ exp_para[sel_ids].data = sel_exp_para.data.clone()
357
+ euler_angle[sel_ids].data = sel_euler_angle.data.clone()
358
+ trans[sel_ids].data = sel_trans.data.clone()
359
+
360
+ coeff_dict = {'id': id_para.detach().cpu().numpy(), 'exp': exp_para.detach().cpu().numpy(),
361
+ 'euler': euler_angle.detach().cpu().numpy(), 'trans': trans.detach().cpu().numpy()}
362
+
363
+ # filter data by side-view pose
364
+ # bad_yaw = False
365
+ # yaws = [] # not so accurate
366
+ # for index in range(coeff_dict["trans"].shape[0]):
367
+ # yaw = coeff_dict["euler"][index][1]
368
+ # yaw = np.abs(yaw)
369
+ # yaws.append(yaw)
370
+ # if yaw > large_yaw_threshold:
371
+ # bad_yaw = True
372
+
373
+ if debug:
374
+ import imageio
375
+ from utils.visualization.vis_cam3d.camera_pose_visualizer import CameraPoseVisualizer
376
+ from data_util.face3d_helper import Face3DHelper
377
+ from data_gen.utils.process_video.extract_blink import get_eye_area_percent
378
+ face3d_helper = Face3DHelper('deep_3drecon/BFM', keypoint_mode='mediapipe')
379
+
380
+ t = coeff_dict['exp'].shape[0]
381
+ if len(coeff_dict['id']) == 1:
382
+ coeff_dict['id'] = np.repeat(coeff_dict['id'], t, axis=0)
383
+ idexp_lm3d = face3d_helper.reconstruct_idexp_lm3d_np(coeff_dict['id'], coeff_dict['exp']).reshape([t, -1])
384
+ cano_lm3d = idexp_lm3d / 10 + face3d_helper.key_mean_shape.squeeze().reshape([1, -1]).cpu().numpy()
385
+ cano_lm3d = cano_lm3d.reshape([t, -1, 3])
386
+ WH = 512
387
+ cano_lm3d = (cano_lm3d * WH/2 + WH/2).astype(int)
388
+
389
+ with torch.no_grad():
390
+ rot = ParametricFaceModel.compute_rotation(euler_angle)
391
+ extrinsic = torch.zeros([rot.shape[0], 4, 4]).to(rot.device)
392
+ extrinsic[:, :3,:3] = rot
393
+ extrinsic[:, :3, 3] = trans # / 10
394
+ extrinsic[:, 3, 3] = 1
395
+ extrinsic = extrinsic.cpu().numpy()
396
+
397
+ xy_camera_visualizer = CameraPoseVisualizer(xlim=[extrinsic[:,0,3].min().item()-0.5,extrinsic[:,0,3].max().item()+0.5],ylim=[extrinsic[:,1,3].min().item()-0.5,extrinsic[:,1,3].max().item()+0.5], zlim=[extrinsic[:,2,3].min().item()-0.5,extrinsic[:,2,3].max().item()+0.5], view_mode='xy')
398
+ xz_camera_visualizer = CameraPoseVisualizer(xlim=[extrinsic[:,0,3].min().item()-0.5,extrinsic[:,0,3].max().item()+0.5],ylim=[extrinsic[:,1,3].min().item()-0.5,extrinsic[:,1,3].max().item()+0.5], zlim=[extrinsic[:,2,3].min().item()-0.5,extrinsic[:,2,3].max().item()+0.5], view_mode='xz')
399
+
400
+ if nerf:
401
+ debug_name = video_name.replace("/raw/", "/processed/").replace(".mp4", "/debug_fit_3dmm.mp4")
402
+ else:
403
+ debug_name = video_name.replace("/video/", "/coeff_fit_debug/").replace(".mp4", "_debug.mp4")
404
+ try:
405
+ os.makedirs(os.path.dirname(debug_name), exist_ok=True)
406
+ except: pass
407
+ writer = imageio.get_writer(debug_name, fps=25)
408
+ if id_mode == 'global':
409
+ id_para = id_para.repeat([exp_para.shape[0], 1])
410
+ proj_geo = face_model.compute_for_landmark_fit(id_para, exp_para, euler_angle, trans)
411
+ lm68s = proj_geo[:,:,:2].detach().cpu().numpy() # [T, 68,2]
412
+ lm68s = lm68s * img_scale_factor
413
+ lms = lms * img_scale_factor
414
+ lm68s[..., 1] = img_h - lm68s[..., 1] # flip the height axis
415
+ lms[..., 1] = img_h - lms[..., 1] # flip the height axis
416
+ lm68s = lm68s.astype(int)
417
+ for i in tqdm.trange(min(250, len(frames)), desc=f'rendering debug video to {debug_name}..'):
418
+ xy_cam3d_img = xy_camera_visualizer.extrinsic2pyramid(extrinsic[i], focal_len_scaled=0.25)
419
+ xy_cam3d_img = cv2.resize(xy_cam3d_img, (512,512))
420
+ xz_cam3d_img = xz_camera_visualizer.extrinsic2pyramid(extrinsic[i], focal_len_scaled=0.25)
421
+ xz_cam3d_img = cv2.resize(xz_cam3d_img, (512,512))
422
+
423
+ img = copy.deepcopy(frames[i])
424
+ img2 = copy.deepcopy(frames[i])
425
+
426
+ img = draw_axes(img, euler_angle[i,0].item(), euler_angle[i,1].item(), euler_angle[i,2].item(), lm68s[i][4][0].item(), lm68s[i, 4][1].item(), size=50)
427
+
428
+ gt_lm_color = (255, 0, 0)
429
+
430
+ for lm in lm68s[i]:
431
+ img = cv2.circle(img, lm, 1, (0, 0, 255), thickness=-1) # blue
432
+ for gt_lm in lms[i]:
433
+ img2 = cv2.circle(img2, gt_lm.cpu().numpy().astype(int), 2, gt_lm_color, thickness=1)
434
+
435
+ cano_lm3d_img = np.ones([WH, WH, 3], dtype=np.uint8) * 255
436
+ for j in range(len(cano_lm3d[i])):
437
+ x, y, _ = cano_lm3d[i, j]
438
+ color = (255,0,0)
439
+ cano_lm3d_img = cv2.circle(cano_lm3d_img, center=(x,y), radius=3, color=color, thickness=-1)
440
+ cano_lm3d_img = cv2.flip(cano_lm3d_img, 0)
441
+
442
+ _, secc_img = secc_renderer(id_para[0:1], exp_para[i:i+1], euler_angle[i:i+1]*0, trans[i:i+1]*0)
443
+ secc_img = (secc_img +1)*127.5
444
+ secc_img = F.interpolate(secc_img, size=(img_h, img_w))
445
+ secc_img = secc_img.permute(0, 2,3,1).int().cpu().numpy()[0]
446
+ out_img1 = np.concatenate([img, img2, secc_img], axis=1).astype(np.uint8)
447
+ font = cv2.FONT_HERSHEY_SIMPLEX
448
+ out_img2 = np.concatenate([xy_cam3d_img, xz_cam3d_img, cano_lm3d_img], axis=1).astype(np.uint8)
449
+ out_img = np.concatenate([out_img1, out_img2], axis=0)
450
+ writer.append_data(out_img)
451
+ writer.close()
452
+
453
+ # if bad_yaw:
454
+ # print(f"Skip {video_name} due to TOO LARGE YAW")
455
+ # return False
456
+
457
+ if save:
458
+ np.save(out_name, coeff_dict, allow_pickle=True)
459
+ return coeff_dict
460
+
461
+ def out_exist_job(vid_name):
462
+ out_name = vid_name.replace("/video/", "/coeff_fit_mp/").replace(".mp4","_coeff_fit_mp.npy")
463
+ lms_name = vid_name.replace("/video/", "/lms_2d/").replace(".mp4","_lms.npy")
464
+ if os.path.exists(out_name) or not os.path.exists(lms_name):
465
+ return None
466
+ else:
467
+ return vid_name
468
+
469
+ def get_todo_vid_names(vid_names):
470
+ if len(vid_names) == 1: # single video, nerf
471
+ return vid_names
472
+ todo_vid_names = []
473
+ for i, res in multiprocess_run_tqdm(out_exist_job, vid_names, num_workers=16):
474
+ if res is not None:
475
+ todo_vid_names.append(res)
476
+ return todo_vid_names
477
+
478
+
479
+ if __name__ == '__main__':
480
+ import argparse, glob, tqdm
481
+ parser = argparse.ArgumentParser()
482
+ # parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/CelebV-HQ/video')
483
+ parser.add_argument("--vid_dir", default='data/raw/videos/May_10s.mp4')
484
+ parser.add_argument("--ds_name", default='nerf') # 'nerf' | 'CelebV-HQ' | 'TH1KH_512' | etc
485
+ parser.add_argument("--seed", default=0, type=int)
486
+ parser.add_argument("--process_id", default=0, type=int)
487
+ parser.add_argument("--total_process", default=1, type=int)
488
+ parser.add_argument("--id_mode", default='global', type=str) # global | finegrained
489
+ parser.add_argument("--keypoint_mode", default='mediapipe', type=str)
490
+ parser.add_argument("--large_yaw_threshold", default=9999999.9, type=float) # could be 0.7
491
+ parser.add_argument("--debug", action='store_true')
492
+ parser.add_argument("--reset", action='store_true')
493
+ parser.add_argument("--load_names", action="store_true")
494
+
495
+ args = parser.parse_args()
496
+ vid_dir = args.vid_dir
497
+ ds_name = args.ds_name
498
+ load_names = args.load_names
499
+
500
+ print(f"args {args}")
501
+
502
+ if ds_name.lower() == 'nerf': # process a single video
503
+ vid_names = [vid_dir]
504
+ out_names = [video_name.replace("/raw/", "/processed/").replace(".mp4","_coeff_fit_mp.npy") for video_name in vid_names]
505
+ else: # process the whole dataset
506
+ if ds_name in ['lrs3_trainval']:
507
+ vid_name_pattern = os.path.join(vid_dir, "*/*.mp4")
508
+ elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
509
+ vid_name_pattern = os.path.join(vid_dir, "*.mp4")
510
+ elif ds_name in ['lrs2', 'lrs3', 'voxceleb2', 'CMLR']:
511
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
512
+ elif ds_name in ["RAVDESS", 'VFHQ']:
513
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
514
+ else:
515
+ raise NotImplementedError()
516
+
517
+ vid_names_path = os.path.join(vid_dir, "vid_names.pkl")
518
+ if os.path.exists(vid_names_path) and load_names:
519
+ print(f"loading vid names from {vid_names_path}")
520
+ vid_names = load_file(vid_names_path)
521
+ else:
522
+ vid_names = multiprocess_glob(vid_name_pattern)
523
+ vid_names = sorted(vid_names)
524
+ print(f"saving vid names to {vid_names_path}")
525
+ save_file(vid_names_path, vid_names)
526
+ out_names = [video_name.replace("/video/", "/coeff_fit_mp/").replace(".mp4","_coeff_fit_mp.npy") for video_name in vid_names]
527
+
528
+ print(vid_names[:10])
529
+ random.seed(args.seed)
530
+ random.shuffle(vid_names)
531
+
532
+ face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
533
+ camera_distance=10, focal=1015, keypoint_mode=args.keypoint_mode)
534
+ face_model.to(torch.device("cuda:0"))
535
+ secc_renderer = SECC_Renderer(512)
536
+ secc_renderer.to("cuda:0")
537
+
538
+ process_id = args.process_id
539
+ total_process = args.total_process
540
+ if total_process > 1:
541
+ assert process_id <= total_process -1
542
+ num_samples_per_process = len(vid_names) // total_process
543
+ if process_id == total_process - 1: # the last process takes the remainder
544
+ vid_names = vid_names[process_id * num_samples_per_process : ]
545
+ else:
546
+ vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
547
+
548
+ if not args.reset:
549
+ vid_names = get_todo_vid_names(vid_names)
550
+
551
+ failed_img_names = []
552
+ for i in tqdm.trange(len(vid_names), desc=f"process {process_id}: fitting 3dmm ..."):
553
+ img_name = vid_names[i]
554
+ try:
555
+ is_person_specific_data = ds_name=='nerf'
556
+ success = fit_3dmm_for_a_video(img_name, is_person_specific_data, args.id_mode, args.debug, large_yaw_threshold=args.large_yaw_threshold)
557
+ if not success:
558
+ failed_img_names.append(img_name)
559
+ except Exception as e:
560
+ print(img_name, e)
561
+ failed_img_names.append(img_name)
562
+ print(f"finished {i + 1} / {len(vid_names)} = {(i + 1) / len(vid_names):.4f}, failed {len(failed_img_names)} / {i + 1} = {len(failed_img_names) / (i + 1):.4f}")
563
+ sys.stdout.flush()
564
+ print(f"all failed image names: {failed_img_names}")
565
+ print(f"All finished!")
data_gen/utils/process_video/inpaint_torso_imgs.py ADDED
@@ -0,0 +1,193 @@
1
+ import cv2
2
+ import os
3
+ import numpy as np
+ import glob
+ import tqdm
4
+ from utils.commons.multiprocess_utils import multiprocess_run_tqdm
5
+ from scipy.ndimage import binary_erosion, binary_dilation
6
+
7
+ from tasks.eg3ds.loss_utils.segment_loss.mp_segmenter import MediapipeSegmenter
8
+ seg_model = MediapipeSegmenter()
9
+
10
+ def inpaint_torso_job(video_name, idx=None, total=None):
11
+ raw_img_dir = video_name.replace(".mp4", "").replace("/video/","/gt_imgs/")
12
+ img_names = glob.glob(os.path.join(raw_img_dir, "*.jpg"))
13
+
14
+ for image_path in tqdm.tqdm(img_names):
15
+ # read ori image
16
+ ori_image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3]
17
+ segmap = seg_model._cal_seg_map(cv2.cvtColor(ori_image, cv2.COLOR_BGR2RGB))
18
+ head_part = (segmap[1] + segmap[3] + segmap[5]).astype(bool)
19
+ torso_part = (segmap[4]).astype(bool)
20
+ neck_part = (segmap[2]).astype(bool)
21
+ bg_part = segmap[0].astype(bool)
22
+ head_image = cv2.imread(image_path.replace("/gt_imgs/", "/head_imgs/"), cv2.IMREAD_UNCHANGED) # [H, W, 3]
23
+ torso_image = cv2.imread(image_path.replace("/gt_imgs/", "/torso_imgs/"), cv2.IMREAD_UNCHANGED) # [H, W, 3]
24
+ bg_image = cv2.imread(image_path.replace("/gt_imgs/", "/bg_imgs/"), cv2.IMREAD_UNCHANGED) # [H, W, 3]
25
+
26
+ # head_part = (head_image[...,0] != 0) & (head_image[...,1] != 0) & (head_image[...,2] != 0)
27
+ # torso_part = (torso_image[...,0] != 0) & (torso_image[...,1] != 0) & (torso_image[...,2] != 0)
28
+ # bg_part = (bg_image[...,0] != 0) & (bg_image[...,1] != 0) & (bg_image[...,2] != 0)
29
+
30
+ # get gt image
31
+ gt_image = ori_image.copy()
32
+ gt_image[bg_part] = bg_image[bg_part]
33
+ cv2.imwrite(image_path.replace('ori_imgs', 'gt_imgs'), gt_image)
34
+
35
+ # get torso image
36
+ torso_image = gt_image.copy() # rgb
37
+ torso_image[head_part] = 0
38
+ torso_alpha = 255 * np.ones((gt_image.shape[0], gt_image.shape[1], 1), dtype=np.uint8) # alpha
39
+
40
+ # torso part "vertical" in-painting...
41
+ L = 8 + 1
42
+ torso_coords = np.stack(np.nonzero(torso_part), axis=-1) # [M, 2]
43
+ # lexsort: sort 2D coords first by y then by x,
44
+ # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
45
+ inds = np.lexsort((torso_coords[:, 0], torso_coords[:, 1]))
46
+ torso_coords = torso_coords[inds]
47
+ # choose the top pixel for each column
48
+ u, uid, ucnt = np.unique(torso_coords[:, 1], return_index=True, return_counts=True)
49
+ top_torso_coords = torso_coords[uid] # [m, 2]
50
+ # only keep top-is-head pixels
51
+ top_torso_coords_up = top_torso_coords.copy() - np.array([1, 0]) # [N, 2]
52
+ mask = head_part[tuple(top_torso_coords_up.T)]
53
+ if mask.any():
54
+ top_torso_coords = top_torso_coords[mask]
55
+ # get the color
56
+ top_torso_colors = gt_image[tuple(top_torso_coords.T)] # [m, 3]
57
+ # construct inpaint coords (vertically up, or minus in x)
58
+ inpaint_torso_coords = top_torso_coords[None].repeat(L, 0) # [L, m, 2]
59
+ inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
60
+ inpaint_torso_coords += inpaint_offsets
61
+ inpaint_torso_coords = inpaint_torso_coords.reshape(-1, 2) # [Lm, 2]
62
+ inpaint_torso_colors = top_torso_colors[None].repeat(L, 0) # [L, m, 3]
63
+ darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
64
+ inpaint_torso_colors = (inpaint_torso_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
65
+ # set color
66
+ torso_image[tuple(inpaint_torso_coords.T)] = inpaint_torso_colors
67
+
68
+ inpaint_torso_mask = np.zeros_like(torso_image[..., 0]).astype(bool)
69
+ inpaint_torso_mask[tuple(inpaint_torso_coords.T)] = True
70
+ else:
71
+ inpaint_torso_mask = None
72
+
73
+ # neck part "vertical" in-painting...
74
+ push_down = 4
75
+ L = 48 + push_down + 1
76
+
77
+ neck_part = binary_dilation(neck_part, structure=np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=bool), iterations=3)
78
+
79
+ neck_coords = np.stack(np.nonzero(neck_part), axis=-1) # [M, 2]
80
+ # lexsort: sort 2D coords first by y then by x,
81
+ # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
82
+ inds = np.lexsort((neck_coords[:, 0], neck_coords[:, 1]))
83
+ neck_coords = neck_coords[inds]
84
+ # choose the top pixel for each column
85
+ u, uid, ucnt = np.unique(neck_coords[:, 1], return_index=True, return_counts=True)
86
+ top_neck_coords = neck_coords[uid] # [m, 2]
87
+ # only keep top-is-head pixels
88
+ top_neck_coords_up = top_neck_coords.copy() - np.array([1, 0])
89
+ mask = head_part[tuple(top_neck_coords_up.T)]
90
+
91
+ top_neck_coords = top_neck_coords[mask]
92
+ # push these top down for 4 pixels to make the neck inpainting more natural...
93
+ offset_down = np.minimum(ucnt[mask] - 1, push_down)
94
+ top_neck_coords += np.stack([offset_down, np.zeros_like(offset_down)], axis=-1)
95
+ # get the color
96
+ top_neck_colors = gt_image[tuple(top_neck_coords.T)] # [m, 3]
97
+ # construct inpaint coords (vertically up, or minus in x)
98
+ inpaint_neck_coords = top_neck_coords[None].repeat(L, 0) # [L, m, 2]
99
+ inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
100
+ inpaint_neck_coords += inpaint_offsets
101
+ inpaint_neck_coords = inpaint_neck_coords.reshape(-1, 2) # [Lm, 2]
102
+ inpaint_neck_colors = top_neck_colors[None].repeat(L, 0) # [L, m, 3]
103
+ darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
104
+ inpaint_neck_colors = (inpaint_neck_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
105
+ # set color
106
+ torso_image[tuple(inpaint_neck_coords.T)] = inpaint_neck_colors
107
+
108
+ # apply blurring to the inpaint area to avoid vertical-line artifacts...
109
+ inpaint_mask = np.zeros_like(torso_image[..., 0]).astype(bool)
110
+ inpaint_mask[tuple(inpaint_neck_coords.T)] = True
111
+
112
+ blur_img = torso_image.copy()
113
+ blur_img = cv2.GaussianBlur(blur_img, (5, 5), cv2.BORDER_DEFAULT)
114
+
115
+ torso_image[inpaint_mask] = blur_img[inpaint_mask]
116
+
117
+ # set mask
118
+ mask = (neck_part | torso_part | inpaint_mask)
119
+ if inpaint_torso_mask is not None:
120
+ mask = mask | inpaint_torso_mask
121
+ torso_image[~mask] = 0
122
+ torso_alpha[~mask] = 0
123
+
124
+ cv2.imwrite("0.png", np.concatenate([torso_image, torso_alpha], axis=-1))
125
+
126
+ print(f'[INFO] ===== extracted torso and gt images =====')
127
+
128
+
129
+ def out_exist_job(vid_name):
130
+ out_dir1 = vid_name.replace("/video/", "/inpaint_torso_imgs/").replace(".mp4","")
131
+ out_dir2 = vid_name.replace("/video/", "/inpaint_torso_with_bg_imgs/").replace(".mp4","")
132
+ out_dir3 = vid_name.replace("/video/", "/torso_imgs/").replace(".mp4","")
133
+ out_dir4 = vid_name.replace("/video/", "/torso_with_bg_imgs/").replace(".mp4","")
134
+
135
+ if os.path.exists(out_dir1) and os.path.exists(out_dir2) and os.path.exists(out_dir3) and os.path.exists(out_dir4):
136
+ num_frames = len(os.listdir(out_dir1))
137
+ if len(os.listdir(out_dir1)) == num_frames and len(os.listdir(out_dir2)) == num_frames and len(os.listdir(out_dir3)) == num_frames and len(os.listdir(out_dir4)) == num_frames:
138
+ return None
139
+ else:
140
+ return vid_name
141
+ else:
142
+ return vid_name
143
+
144
+ def get_todo_vid_names(vid_names):
145
+ todo_vid_names = []
146
+ for i, res in multiprocess_run_tqdm(out_exist_job, vid_names, num_workers=16):
147
+ if res is not None:
148
+ todo_vid_names.append(res)
149
+ return todo_vid_names
150
+
151
+ if __name__ == '__main__':
152
+ import argparse, glob, tqdm, random
153
+ parser = argparse.ArgumentParser()
154
+ parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/CelebV-HQ/video')
155
+ parser.add_argument("--ds_name", default='CelebV-HQ')
156
+ parser.add_argument("--num_workers", default=48, type=int)
157
+ parser.add_argument("--seed", default=0, type=int)
158
+ parser.add_argument("--process_id", default=0, type=int)
159
+ parser.add_argument("--total_process", default=1, type=int)
160
+ parser.add_argument("--reset", action='store_true')
161
+
162
+ inpaint_torso_job('/home/tiger/datasets/raw/CelebV-HQ/video/dgdEr-mXQT4_8.mp4')
163
+ # args = parser.parse_args()
164
+ # vid_dir = args.vid_dir
165
+ # ds_name = args.ds_name
166
+ # if ds_name in ['lrs3_trainval']:
167
+ # mp4_name_pattern = os.path.join(vid_dir, "*/*.mp4")
168
+ # if ds_name in ['TH1KH_512', 'CelebV-HQ']:
169
+ # vid_names = glob.glob(os.path.join(vid_dir, "*.mp4"))
170
+ # elif ds_name in ['lrs2', 'lrs3', 'voxceleb2']:
171
+ # vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
172
+ # vid_names = glob.glob(vid_name_pattern)
173
+ # vid_names = sorted(vid_names)
174
+ # random.seed(args.seed)
175
+ # random.shuffle(vid_names)
176
+
177
+ # process_id = args.process_id
178
+ # total_process = args.total_process
179
+ # if total_process > 1:
180
+ # assert process_id <= total_process -1
181
+ # num_samples_per_process = len(vid_names) // total_process
182
+ # if process_id == total_process:
183
+ # vid_names = vid_names[process_id * num_samples_per_process : ]
184
+ # else:
185
+ # vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
186
+
187
+ # if not args.reset:
188
+ # vid_names = get_todo_vid_names(vid_names)
189
+ # print(f"todo videos number: {len(vid_names)}")
190
+
191
+ # fn_args = [(vid_name,i,len(vid_names)) for i, vid_name in enumerate(vid_names)]
192
+ # for vid_name in multiprocess_run_tqdm(inpaint_torso_job ,fn_args, desc=f"Root process {args.process_id}: extracting segment images", num_workers=args.num_workers):
193
+ # pass
data_gen/utils/process_video/resample_video_to_25fps_resize_to_512.py ADDED
@@ -0,0 +1,87 @@
1
+ import os, glob
2
+ import cv2
3
+ from utils.commons.os_utils import multiprocess_glob
4
+ from utils.commons.multiprocess_utils import multiprocess_run_tqdm
5
+
6
+ def get_video_infos(video_path):
7
+ vid_cap = cv2.VideoCapture(video_path)
8
+ height = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
9
+ width = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
10
+ fps = vid_cap.get(cv2.CAP_PROP_FPS)
11
+ total_frames = int(vid_cap.get(cv2.CAP_PROP_FRAME_COUNT))
12
+ return {'height': height, 'width': width, 'fps': fps, 'total_frames':total_frames}
13
+
14
+ def extract_img_job(video_name:str):
15
+ out_path = video_name.replace("/video_raw/","/video/",1)
16
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
17
+ ffmpeg_path = "/usr/bin/ffmpeg"
18
+ vid_info = get_video_infos(video_name)
19
+ assert vid_info['width'] == vid_info['height']
20
+ cmd = f'{ffmpeg_path} -i {video_name} -vf fps={25},scale=w=512:h=512 -q:v 1 -c:v libx264 -pix_fmt yuv420p -b:v 2000k -v quiet -y {out_path}'
21
+ os.system(cmd)
22
+
23
+ def extract_img_job_crop(video_name:str):
24
+ out_path = video_name.replace("/video_raw/","/video/",1)
25
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
26
+ ffmpeg_path = "/usr/bin/ffmpeg"
27
+ vid_info = get_video_infos(video_name)
28
+ wh = min(vid_info['width'], vid_info['height'])
29
+ cmd = f'{ffmpeg_path} -i {video_name} -vf fps={25},crop={wh}:{wh},scale=w=512:h=512 -q:v 1 -c:v libx264 -pix_fmt yuv420p -b:v 2000k -v quiet -y {out_path}'
30
+ os.system(cmd)
31
+
32
+ def extract_img_job_crop_ravdess(video_name:str):
33
+ out_path = video_name.replace("/video_raw/","/video/",1)
34
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
35
+ ffmpeg_path = "/usr/bin/ffmpeg"
36
+ cmd = f'{ffmpeg_path} -i {video_name} -vf fps={25},crop=720:720,scale=w=512:h=512 -q:v 1 -c:v libx264 -pix_fmt yuv420p -b:v 2000k -v quiet -y {out_path}'
37
+ os.system(cmd)
38
+
39
+ if __name__ == '__main__':
40
+ import argparse, glob, tqdm, random
41
+ parser = argparse.ArgumentParser()
42
+ parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/CelebV-HQ/video_raw/')
43
+ parser.add_argument("--ds_name", default='CelebV-HQ')
44
+ parser.add_argument("--num_workers", default=32, type=int)
45
+ parser.add_argument("--process_id", default=0, type=int)
46
+ parser.add_argument("--total_process", default=1, type=int)
47
+ args = parser.parse_args()
48
+ print(f"args {args}")
49
+
50
+ vid_dir = args.vid_dir
51
+ ds_name = args.ds_name
52
+ if ds_name in ['lrs3_trainval']:
53
+ vid_names = multiprocess_glob(os.path.join(vid_dir, "*/*.mp4"))
54
+ elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
55
+ vid_names = multiprocess_glob(os.path.join(vid_dir, "*.mp4"))
56
+ elif ds_name in ['lrs2', 'lrs3', 'voxceleb2', 'CMLR']:
57
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
58
+ vid_names = multiprocess_glob(vid_name_pattern)
59
+ elif ds_name in ["RAVDESS", 'VFHQ']:
60
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
61
+ vid_names = multiprocess_glob(vid_name_pattern)
62
+ else:
63
+ raise NotImplementedError()
64
+ vid_names = sorted(vid_names)
65
+ print(f"total video number : {len(vid_names)}")
66
+ print(f"first {vid_names[0]} last {vid_names[-1]}")
67
+ # exit()
68
+ process_id = args.process_id
69
+ total_process = args.total_process
70
+ if total_process > 1:
71
+ assert process_id <= total_process -1
72
+ num_samples_per_process = len(vid_names) // total_process
73
+ if process_id == total_process - 1: # the last process takes the remainder
74
+ vid_names = vid_names[process_id * num_samples_per_process : ]
75
+ else:
76
+ vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
77
+
78
+ if ds_name == "RAVDESS":
79
+ for i, res in multiprocess_run_tqdm(extract_img_job_crop_ravdess, vid_names, num_workers=args.num_workers, desc="resampling videos"):
80
+ pass
81
+ elif ds_name == "CMLR":
82
+ for i, res in multiprocess_run_tqdm(extract_img_job_crop, vid_names, num_workers=args.num_workers, desc="resampling videos"):
83
+ pass
84
+ else:
85
+ for i, res in multiprocess_run_tqdm(extract_img_job, vid_names, num_workers=args.num_workers, desc="resampling videos"):
86
+ pass
87
+
data_gen/utils/process_video/split_video_to_imgs.py ADDED
@@ -0,0 +1,53 @@
1
+ import os, glob
2
+ from utils.commons.multiprocess_utils import multiprocess_run_tqdm
3
+
4
+ from data_gen.utils.path_converter import PathConverter, pc
5
+
6
+ # mp4_names = glob.glob("/home/tiger/datasets/raw/CelebV-HQ/video/*.mp4")
7
+
8
+ def extract_img_job(video_name, raw_img_dir=None):
9
+ if raw_img_dir is not None:
10
+ out_path = raw_img_dir
11
+ else:
12
+ out_path = pc.to(video_name.replace(".mp4", ""), "vid", "gt")
13
+ os.makedirs(out_path, exist_ok=True)
14
+ ffmpeg_path = "/usr/bin/ffmpeg"
15
+ cmd = f'{ffmpeg_path} -i {video_name} -vf fps={25},scale=w=512:h=512 -qmin 1 -q:v 1 -start_number 0 -v quiet {os.path.join(out_path, "%8d.jpg")}'
16
+ os.system(cmd)
17
+
18
+ if __name__ == '__main__':
19
+ import argparse, glob, tqdm, random
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/CelebV-HQ/video')
22
+ parser.add_argument("--ds_name", default='CelebV-HQ')
23
+ parser.add_argument("--num_workers", default=64, type=int)
24
+ parser.add_argument("--process_id", default=0, type=int)
25
+ parser.add_argument("--total_process", default=1, type=int)
26
+ args = parser.parse_args()
27
+ vid_dir = args.vid_dir
28
+ ds_name = args.ds_name
29
+ if ds_name in ['lrs3_trainval']:
30
+ vid_names = glob.glob(os.path.join(vid_dir, "*/*.mp4"))
31
+ elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
32
+ vid_names = glob.glob(os.path.join(vid_dir, "*.mp4"))
33
+ elif ds_name in ['lrs2', 'lrs3', 'voxceleb2']:
34
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
35
+ vid_names = glob.glob(vid_name_pattern)
36
+ elif ds_name in ["RAVDESS", 'VFHQ']:
37
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
38
+ vid_names = glob.glob(vid_name_pattern)
39
+ vid_names = sorted(vid_names)
40
+
41
+ process_id = args.process_id
42
+ total_process = args.total_process
43
+ if total_process > 1:
44
+ assert process_id <= total_process -1
45
+ num_samples_per_process = len(vid_names) // total_process
46
+ if process_id == total_process - 1: # the last process takes the remainder
47
+ vid_names = vid_names[process_id * num_samples_per_process : ]
48
+ else:
49
+ vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
50
+
51
+ for i, res in multiprocess_run_tqdm(extract_img_job, vid_names, num_workers=args.num_workers, desc="extracting images"):
52
+ pass
53
+
data_util/face3d_helper.py ADDED
@@ -0,0 +1,309 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ from scipy.io import loadmat
6
+
7
+ from deep_3drecon.deep_3drecon_models.bfm import perspective_projection
8
+
9
+
10
+ class Face3DHelper(nn.Module):
11
+ def __init__(self, bfm_dir='deep_3drecon/BFM', keypoint_mode='lm68', use_gpu=True):
12
+ super().__init__()
13
+ self.keypoint_mode = keypoint_mode # lm68 | mediapipe
14
+ self.bfm_dir = bfm_dir
15
+ self.load_3dmm()
16
+ if use_gpu: self.to("cuda")
17
+
18
+ def load_3dmm(self):
19
+ model = loadmat(os.path.join(self.bfm_dir, "BFM_model_front.mat"))
20
+ self.register_buffer('mean_shape',torch.from_numpy(model['meanshape'].transpose()).float()) # mean face shape. [3*N, 1], N=35709, xyz=3, ==> 3*N=107127
21
+ mean_shape = self.mean_shape.reshape([-1, 3])
22
+ # re-center
23
+ mean_shape = mean_shape - torch.mean(mean_shape, dim=0, keepdims=True)
24
+ self.mean_shape = mean_shape.reshape([-1, 1])
25
+ self.register_buffer('id_base',torch.from_numpy(model['idBase']).float()) # identity basis. [3*N,80], we have 80 eigen faces for identity
26
+ self.register_buffer('exp_base',torch.from_numpy(model['exBase']).float()) # expression basis. [3*N,64], we have 64 eigen faces for expression
27
+
28
+ self.register_buffer('mean_texure',torch.from_numpy(model['meantex'].transpose()).float()) # mean face texture. [3*N,1] (0-255)
29
+ self.register_buffer('tex_base',torch.from_numpy(model['texBase']).float()) # texture basis. [3*N,80], rgb=3
30
+
31
+ self.register_buffer('point_buf',torch.from_numpy(model['point_buf']).float()) # triangle indices for each vertex that lies in. starts from 1. [N,8] (1-F)
32
+ self.register_buffer('face_buf',torch.from_numpy(model['tri']).float()) # vertex indices in each triangle. starts from 1. [F,3] (1-N)
33
+ if self.keypoint_mode == 'mediapipe':
34
+ self.register_buffer('key_points', torch.from_numpy(np.load("deep_3drecon/BFM/index_mp468_from_mesh35709.npy").astype(np.int64)))
35
+ unmatch_mask = self.key_points < 0
36
+ self.key_points[unmatch_mask] = 0
37
+ else:
38
+ self.register_buffer('key_points',torch.from_numpy(model['keypoints'].squeeze().astype(np.int_)).long()) # vertex indices of 68 facial landmarks. starts from 1. [68,1]
39
+
40
+
41
+ self.register_buffer('key_mean_shape',self.mean_shape.reshape([-1,3])[self.key_points,:])
42
+ self.register_buffer('key_id_base', self.id_base.reshape([-1,3,80])[self.key_points, :, :].reshape([-1,80]))
43
+ self.register_buffer('key_exp_base', self.exp_base.reshape([-1,3,64])[self.key_points, :, :].reshape([-1,64]))
44
+ self.key_id_base_np = self.key_id_base.cpu().numpy()
45
+ self.key_exp_base_np = self.key_exp_base.cpu().numpy()
46
+
47
+ self.register_buffer('persc_proj', torch.tensor(perspective_projection(focal=1015, center=112)))
48
+ def split_coeff(self, coeff):
49
+ """
50
+ coeff: Tensor[B, T, c=257] or [T, c=257]
51
+ """
52
+ ret_dict = {
53
+ 'identity': coeff[..., :80], # identity, [b, t, c=80]
54
+ 'expression': coeff[..., 80:144], # expression, [b, t, c=64]
55
+ 'texture': coeff[..., 144:224], # texture, [b, t, c=80]
56
+ 'euler': coeff[..., 224:227], # euler angles for pose, [b, t, c=3]
57
+ 'translation': coeff[..., 254:257], # translation, [b, t, c=3]
58
+ 'gamma': coeff[..., 227:254] # lighting, [b, t, c=27]
59
+ }
60
+ return ret_dict
61
+
62
+ def reconstruct_face_mesh(self, id_coeff, exp_coeff):
63
+ """
64
+ Generate a pose-independent 3D face mesh!
65
+ id_coeff: Tensor[T, c=80]
66
+ exp_coeff: Tensor[T, c=64]
67
+ """
68
+ id_coeff = id_coeff.to(self.key_id_base.device)
69
+ exp_coeff = exp_coeff.to(self.key_id_base.device)
70
+ mean_face = self.mean_shape.squeeze().reshape([1, -1]) # [3N, 1] ==> [1, 3N]
71
+ id_base, exp_base = self.id_base, self.exp_base # [3*N, C]
72
+ identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3N] ==> [t,3N]
73
+ expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3N] ==> [t,3N]
74
+
75
+ face = mean_face + identity_diff_face + expression_diff_face # [t,3N]
76
+ face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
77
+ # re-centering the face with mean_xyz, so the face will be in [-1, 1]
78
+ # mean_xyz = self.mean_shape.squeeze().reshape([-1,3]).mean(dim=0) # [1, 3]
79
+ # face_mesh = face - mean_xyz.unsqueeze(0) # [t,N,3]
80
+ return face
81
+
82
+ def reconstruct_cano_lm3d(self, id_coeff, exp_coeff):
83
+ """
84
+ Generate 3D landmark with keypoint base!
85
+ id_coeff: Tensor[T, c=80]
86
+ exp_coeff: Tensor[T, c=64]
87
+ """
88
+ id_coeff = id_coeff.to(self.key_id_base.device)
89
+ exp_coeff = exp_coeff.to(self.key_id_base.device)
90
+ mean_face = self.key_mean_shape.squeeze().reshape([1, -1]) # [3*68, 1] ==> [1, 3*68]
91
+ id_base, exp_base = self.key_id_base, self.key_exp_base # [3*68, C]
92
+ identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
93
+ expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
94
+
95
+ face = mean_face + identity_diff_face + expression_diff_face # [t,3N]
96
+ face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
97
+ # re-centering the face with mean_xyz, so the face will be in [-1, 1]
98
+ # mean_xyz = self.key_mean_shape.squeeze().reshape([-1,3]).mean(dim=0) # [1, 3]
99
+ # lm3d = face - mean_xyz.unsqueeze(0) # [t,N,3]
100
+ return face
101
+
102
+ def reconstruct_lm3d(self, id_coeff, exp_coeff, euler, trans, to_camera=True):
103
+ """
104
+ Generate 3D landmark with keypoint base!
105
+ id_coeff: Tensor[T, c=80]
106
+ exp_coeff: Tensor[T, c=64]
107
+ """
108
+ id_coeff = id_coeff.to(self.key_id_base.device)
109
+ exp_coeff = exp_coeff.to(self.key_id_base.device)
110
+ mean_face = self.key_mean_shape.squeeze().reshape([1, -1]) # [3*68, 1] ==> [1, 3*68]
111
+ id_base, exp_base = self.key_id_base, self.key_exp_base # [3*68, C]
112
+ identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
113
+ expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
114
+
115
+ face = mean_face + identity_diff_face + expression_diff_face # [t,3N]
116
+ face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
117
+ # re-centering the face with mean_xyz, so the face will be in [-1, 1]
118
+ rot = self.compute_rotation(euler)
119
+ # transform
120
+ lm3d = face @ rot + trans.unsqueeze(1) # [t, N, 3]
121
+ # to camera
122
+ if to_camera:
123
+ lm3d[...,-1] = 10 - lm3d[...,-1]
124
+ return lm3d
125
+
126
+ def reconstruct_lm2d_nerf(self, id_coeff, exp_coeff, euler, trans):
127
+ lm2d = self.reconstruct_lm2d(id_coeff, exp_coeff, euler, trans, to_camera=False)
128
+ lm2d[..., 0] = 1 - lm2d[..., 0]
129
+ lm2d[..., 1] = 1 - lm2d[..., 1]
130
+ return lm2d
131
+
132
+ def reconstruct_lm2d(self, id_coeff, exp_coeff, euler, trans, to_camera=True):
133
+ """
134
+ Generate 3D landmark with keypoint base!
135
+ id_coeff: Tensor[T, c=80]
136
+ exp_coeff: Tensor[T, c=64]
137
+ """
138
+ is_btc_flag = True if id_coeff.ndim == 3 else False
139
+ if is_btc_flag:
140
+ b,t,_ = id_coeff.shape
141
+ id_coeff = id_coeff.reshape([b*t,-1])
142
+ exp_coeff = exp_coeff.reshape([b*t,-1])
143
+ euler = euler.reshape([b*t,-1])
144
+ trans = trans.reshape([b*t,-1])
145
+ id_coeff = id_coeff.to(self.key_id_base.device)
146
+ exp_coeff = exp_coeff.to(self.key_id_base.device)
147
+ mean_face = self.key_mean_shape.squeeze().reshape([1, -1]) # [3*68, 1] ==> [1, 3*68]
148
+ id_base, exp_base = self.key_id_base, self.key_exp_base # [3*68, C]
149
+ identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
150
+ expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
151
+
152
+ face = mean_face + identity_diff_face + expression_diff_face # [t,3N]
153
+ face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
154
+ # re-centering the face with mean_xyz, so the face will be in [-1, 1]
155
+ rot = self.compute_rotation(euler)
156
+ # transform
157
+ lm3d = face @ rot + trans.unsqueeze(1) # [t, N, 3]
158
+ # to camera
159
+ if to_camera:
160
+ lm3d[...,-1] = 10 - lm3d[...,-1]
161
+ # to image_plane
162
+ lm3d = lm3d @ self.persc_proj
163
+ lm2d = lm3d[..., :2] / lm3d[..., 2:]
164
+ # flip
165
+ lm2d[..., 1] = 224 - lm2d[..., 1]
166
+ lm2d /= 224
167
+ if is_btc_flag:
168
+ return lm2d.reshape([b,t,-1,2])
169
+ return lm2d
170
+
171
+ def compute_rotation(self, euler):
172
+ """
173
+ Return:
174
+ rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat
175
+
176
+ Parameters:
177
+ euler -- torch.tensor, size (B, 3), radian
178
+ """
179
+
180
+ batch_size = euler.shape[0]
181
+ euler = euler.to(self.key_id_base.device)
182
+ ones = torch.ones([batch_size, 1]).to(self.key_id_base.device)
183
+ zeros = torch.zeros([batch_size, 1]).to(self.key_id_base.device)
184
+ x, y, z = euler[:, :1], euler[:, 1:2], euler[:, 2:],
185
+
186
+ rot_x = torch.cat([
187
+ ones, zeros, zeros,
188
+ zeros, torch.cos(x), -torch.sin(x),
189
+ zeros, torch.sin(x), torch.cos(x)
190
+ ], dim=1).reshape([batch_size, 3, 3])
191
+
192
+ rot_y = torch.cat([
193
+ torch.cos(y), zeros, torch.sin(y),
194
+ zeros, ones, zeros,
195
+ -torch.sin(y), zeros, torch.cos(y)
196
+ ], dim=1).reshape([batch_size, 3, 3])
197
+
198
+ rot_z = torch.cat([
199
+ torch.cos(z), -torch.sin(z), zeros,
200
+ torch.sin(z), torch.cos(z), zeros,
201
+ zeros, zeros, ones
202
+ ], dim=1).reshape([batch_size, 3, 3])
203
+
204
+ rot = rot_z @ rot_y @ rot_x
205
+ return rot.permute(0, 2, 1)
206
+
207
+ def reconstruct_idexp_lm3d(self, id_coeff, exp_coeff):
208
+ """
209
+ Generate 3D landmark with keypoint base!
210
+ id_coeff: Tensor[T, c=80]
211
+ exp_coeff: Tensor[T, c=64]
212
+ """
213
+ id_coeff = id_coeff.to(self.key_id_base.device)
214
+ exp_coeff = exp_coeff.to(self.key_id_base.device)
215
+ id_base, exp_base = self.key_id_base, self.key_exp_base # [3*68, C]
216
+ identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
217
+ expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
218
+
219
+ face = identity_diff_face + expression_diff_face # [t,3N]
220
+ face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
221
+ lm3d = face * 10
222
+ return lm3d
223
+
224
+ def reconstruct_idexp_lm3d_np(self, id_coeff, exp_coeff):
225
+ """
226
+ Generate 3D landmark with keypoint base!
227
+ id_coeff: Tensor[T, c=80]
228
+ exp_coeff: Tensor[T, c=64]
229
+ """
230
+ id_base, exp_base = self.key_id_base_np, self.key_exp_base_np # [3*68, C]
231
+ identity_diff_face = np.dot(id_coeff, id_base.T) # [t,c],[c,3*68] ==> [t,3*68]
232
+ expression_diff_face = np.dot(exp_coeff, exp_base.T) # [t,c],[c,3*68] ==> [t,3*68]
233
+
234
+ face = identity_diff_face + expression_diff_face # [t,3N]
235
+ face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
236
+ lm3d = face * 10
237
+ return lm3d
238
+
239
+ def get_eye_mouth_lm_from_lm3d(self, lm3d):
240
+ eye_lm = lm3d[:, 17:48] # [T, 31, 3]
241
+ mouth_lm = lm3d[:, 48:68] # [T, 20, 3]
242
+ return eye_lm, mouth_lm
243
+
244
+ def get_eye_mouth_lm_from_lm3d_batch(self, lm3d):
245
+ eye_lm = lm3d[:, :, 17:48] # [T, 31, 3]
246
+ mouth_lm = lm3d[:, :, 48:68] # [T, 20, 3]
247
+ return eye_lm, mouth_lm
248
+
249
+ def close_mouth_for_idexp_lm3d(self, idexp_lm3d, freeze_as_first_frame=True):
250
+ idexp_lm3d = idexp_lm3d.reshape([-1, 68,3])
251
+ num_frames = idexp_lm3d.shape[0]
252
+ eps = 0.0
253
+ # [n_landmarks=68, xyz=3]; x: left-right, y: up-down, z: depth
254
+ idexp_lm3d[:,49:54, 1] = (idexp_lm3d[:,49:54, 1] + idexp_lm3d[:,range(59,54,-1), 1])/2 + eps * 2
255
+ idexp_lm3d[:,range(59,54,-1), 1] = (idexp_lm3d[:,49:54, 1] + idexp_lm3d[:,range(59,54,-1), 1])/2 - eps * 2
256
+
257
+ idexp_lm3d[:,61:64, 1] = (idexp_lm3d[:,61:64, 1] + idexp_lm3d[:,range(67,64,-1), 1])/2 + eps
258
+ idexp_lm3d[:,range(67,64,-1), 1] = (idexp_lm3d[:,61:64, 1] + idexp_lm3d[:,range(67,64,-1), 1])/2 - eps
259
+
260
+ idexp_lm3d[:,49:54, 1] += (0.03 - idexp_lm3d[:,49:54, 1].mean(dim=1) + idexp_lm3d[:,61:64, 1].mean(dim=1)).unsqueeze(1).repeat([1,5])
261
+ idexp_lm3d[:,range(59,54,-1), 1] += (-0.03 - idexp_lm3d[:,range(59,54,-1), 1].mean(dim=1) + idexp_lm3d[:,range(67,64,-1), 1].mean(dim=1)).unsqueeze(1).repeat([1,5])
262
+
263
+ if freeze_as_first_frame:
264
+ idexp_lm3d[:, 48:68,] = idexp_lm3d[0, 48:68].unsqueeze(0).clone().repeat([num_frames, 1,1])*0
265
+ return idexp_lm3d.cpu()
266
+
267
+ def close_eyes_for_idexp_lm3d(self, idexp_lm3d):
268
+ idexp_lm3d = idexp_lm3d.reshape([-1, 68,3])
269
+ eps = 0.003
270
+ idexp_lm3d[:,37:39, 1] = (idexp_lm3d[:,37:39, 1] + idexp_lm3d[:,range(41,39,-1), 1])/2 + eps
271
+ idexp_lm3d[:,range(41,39,-1), 1] = (idexp_lm3d[:,37:39, 1] + idexp_lm3d[:,range(41,39,-1), 1])/2 - eps
272
+
273
+ idexp_lm3d[:,43:45, 1] = (idexp_lm3d[:,43:45, 1] + idexp_lm3d[:,range(47,45,-1), 1])/2 + eps
274
+ idexp_lm3d[:,range(47,45,-1), 1] = (idexp_lm3d[:,43:45, 1] + idexp_lm3d[:,range(47,45,-1), 1])/2 - eps
275
+
276
+ return idexp_lm3d
277
+
278
+ if __name__ == '__main__':
279
+ import cv2
280
+
281
+ font = cv2.FONT_HERSHEY_SIMPLEX
282
+
283
+ face_mesh_helper = Face3DHelper('deep_3drecon/BFM')
284
+ coeff_npy = 'data/coeff_fit_mp/crop_nana_003_coeff_fit_mp.npy'
285
+ coeff_dict = np.load(coeff_npy, allow_pickle=True).tolist()
286
+ lm3d = face_mesh_helper.reconstruct_lm2d(torch.tensor(coeff_dict['id']).cuda(), torch.tensor(coeff_dict['exp']).cuda(), torch.tensor(coeff_dict['euler']).cuda(), torch.tensor(coeff_dict['trans']).cuda() )
287
+
288
+ WH = 512
289
+ lm3d = (lm3d * WH).cpu().int().numpy()
290
+ eye_idx = list(range(36,48))
291
+ mouth_idx = list(range(48,68))
292
+ import imageio
293
+ debug_name = 'debug_lm3d.mp4'
294
+ writer = imageio.get_writer(debug_name, fps=25)
295
+ for i_img in range(len(lm3d)):
296
+ lm2d = lm3d[i_img ,:, :2] # [68, 2]
297
+ img = np.ones([WH, WH, 3], dtype=np.uint8) * 255
298
+ for i in range(len(lm2d)):
299
+ x, y = lm2d[i]
300
+ if i in eye_idx:
301
+ color = (0,0,255)
302
+ elif i in mouth_idx:
303
+ color = (0,255,0)
304
+ else:
305
+ color = (255,0,0)
306
+ img = cv2.circle(img, center=(x,y), radius=3, color=color, thickness=-1)
307
+ img = cv2.putText(img, f"{i}", org=(x,y), fontFace=font, fontScale=0.3, color=(255,0,0))
308
+ writer.append_data(img)
309
+ writer.close()
deep_3drecon/BFM/.gitkeep ADDED
File without changes
deep_3drecon/BFM/basel_53201.txt ADDED
The diff for this file is too large to render. See raw diff
 
deep_3drecon/BFM/index_mp468_from_mesh35709_v1.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0d238a90df0c55075c9cea43dab76348421379a75c204931e34dbd2c11fb4b65
3
+ size 3872
deep_3drecon/BFM/index_mp468_from_mesh35709_v2.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fe95e2bb10ac1e54804006184d7de3c5ccd0eb98a5f1bd28e00b9f3569f6ce5a
3
+ size 3872
deep_3drecon/BFM/index_mp468_from_mesh35709_v3.1.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:053b8cce8424b722db6ec5b068514eb007a23b4c5afd629449eb08746e643211
3
+ size 3872
deep_3drecon/BFM/index_mp468_from_mesh35709_v3.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b007b3619dd02892b38349ba3d4b10e32bc2eff201c265f25d6ed62f67dbd51
3
+ size 3872
deep_3drecon/BFM/select_vertex_id.mat ADDED
Binary file (62.3 kB). View file
 
deep_3drecon/BFM/similarity_Lm3D_all.mat ADDED
Binary file (994 Bytes). View file
 
deep_3drecon/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .reconstructor import *
deep_3drecon/bfm_left_eye_faces.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9651756ea2c0fac069a1edf858ed1f125eddc358fa74c529a370c1e7b5730d28
3
+ size 4680
deep_3drecon/bfm_right_eye_faces.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:28cb5bbacf578d30a3d5006ec28c617fe5a3ecaeeeb87d9433a884e0f0301a2e
3
+ size 4648
deep_3drecon/data_preparation.py ADDED
@@ -0,0 +1,45 @@
1
+ """This script is the data preparation script for Deep3DFaceRecon_pytorch
2
+ """
3
+
4
+ import os
5
+ import numpy as np
6
+ import argparse
7
+ from util.detect_lm68 import detect_68p,load_lm_graph
8
+ from util.skin_mask import get_skin_mask
9
+ from util.generate_list import check_list, write_list
10
+ import warnings
11
+ warnings.filterwarnings("ignore")
12
+
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument('--data_root', type=str, default='datasets', help='root directory for training data')
15
+ parser.add_argument('--img_folder', nargs="+", required=True, help='folders of training images')
16
+ parser.add_argument('--mode', type=str, default='train', help='train or val')
17
+ opt = parser.parse_args()
18
+
19
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
20
+
21
+ def data_prepare(folder_list,mode):
22
+
23
+ lm_sess,input_op,output_op = load_lm_graph('./checkpoints/lm_model/68lm_detector.pb') # load a tensorflow version 68-landmark detector
24
+
25
+ for img_folder in folder_list:
26
+ detect_68p(img_folder,lm_sess,input_op,output_op) # detect landmarks for images
27
+ get_skin_mask(img_folder) # generate skin attention mask for images
28
+
29
+ # create files that record path to all training data
30
+ msks_list = []
31
+ for img_folder in folder_list:
32
+ path = os.path.join(img_folder, 'mask')
33
+ msks_list += ['/'.join([img_folder, 'mask', i]) for i in sorted(os.listdir(path)) if 'jpg' in i or
34
+ 'png' in i or 'jpeg' in i or 'PNG' in i]
35
+
36
+ imgs_list = [i.replace('mask/', '') for i in msks_list]
37
+ lms_list = [i.replace('mask', 'landmarks') for i in msks_list]
38
+ lms_list = ['.'.join(i.split('.')[:-1]) + '.txt' for i in lms_list]
39
+
40
+ lms_list_final, imgs_list_final, msks_list_final = check_list(lms_list, imgs_list, msks_list) # check if the path is valid
41
+ write_list(lms_list_final, imgs_list_final, msks_list_final, mode=mode) # save files
42
+
43
+ if __name__ == '__main__':
44
+ print('Datasets:',opt.img_folder)
45
+ data_prepare([os.path.join(opt.data_root,folder) for folder in opt.img_folder],opt.mode)
deep_3drecon/deep_3drecon_models/__init__.py ADDED
@@ -0,0 +1,67 @@
1
+ """This package contains modules related to objective functions, optimizations, and network architectures.
2
+
3
+ To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4
+ You need to implement the following five functions:
5
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6
+ -- <set_input>: unpack data from dataset and apply preprocessing.
7
+ -- <forward>: produce intermediate results.
8
+ -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10
+
11
+ In the function <__init__>, you need to define four lists:
12
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
13
+ -- self.model_names (str list): define networks used in our training.
14
+ -- self.visual_names (str list): specify the images that you want to display and save.
15
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example of its usage.
16
+
17
+ Now you can use the model class by specifying flag '--model dummy'.
18
+ See our template model class 'template_model.py' for more details.
19
+ """
20
+
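+ # Illustrative sketch only (not part of this package), kept as comments so it has no runtime effect:
+ # under the assumptions in the docstring above, a hypothetical 'dummy_model.py' might look roughly
+ # like this. 'DummyModel' and '--lambda_dummy' are made-up names used purely for illustration.
+ #
+ #   from .base_model import BaseModel
+ #
+ #   class DummyModel(BaseModel):
+ #       @staticmethod
+ #       def modify_commandline_options(parser, is_train=True):
+ #           parser.add_argument('--lambda_dummy', type=float, default=1.0)  # model-specific option
+ #           return parser
+ #
+ #       def __init__(self, opt):
+ #           BaseModel.__init__(self, opt)
+ #           self.loss_names, self.model_names, self.visual_names, self.optimizers = [], [], [], []
+ #
+ #       def set_input(self, input):
+ #           self.data = input  # unpack data from the dataloader
+ #
+ #       def forward(self):
+ #           pass  # produce intermediate results
+ #
+ #       def optimize_parameters(self):
+ #           pass  # compute losses, gradients, and update weights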
21
+ import importlib
22
+ from .base_model import BaseModel
23
+
24
+
25
+ def find_model_using_name(model_name):
26
+ """Import the module "models/[model_name]_model.py".
27
+
28
+ In the file, the class called DatasetNameModel() will
29
+ be instantiated. It has to be a subclass of BaseModel,
30
+ and it is case-insensitive.
31
+ """
32
+ model_filename = "deep_3drecon_models." + model_name + "_model"
33
+ modellib = importlib.import_module(model_filename)
34
+ model = None
35
+ target_model_name = model_name.replace('_', '') + 'model'
36
+ for name, cls in modellib.__dict__.items():
37
+ if name.lower() == target_model_name.lower() \
38
+ and issubclass(cls, BaseModel):
39
+ model = cls
40
+
41
+ if model is None:
42
+ print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43
+ exit(0)
44
+
45
+ return model
46
+
47
+
48
+ def get_option_setter(model_name):
49
+ """Return the static method <modify_commandline_options> of the model class."""
50
+ model_class = find_model_using_name(model_name)
51
+ return model_class.modify_commandline_options
52
+
53
+
54
+ def create_model(opt):
55
+ """Create a model given the option.
56
+
57
+ This function instantiates the model class specified by 'opt.model'.
58
+ This is the main interface between this package and 'train.py'/'test.py'
59
+
60
+ Example:
61
+ >>> from models import create_model
62
+ >>> model = create_model(opt)
63
+ """
64
+ model = find_model_using_name(opt.model)
65
+ instance = model(opt)
66
+ print("model [%s] was created" % type(instance).__name__)
67
+ return instance
deep_3drecon/deep_3drecon_models/arcface_torch/README.md ADDED
@@ -0,0 +1,218 @@
1
+ # Distributed Arcface Training in Pytorch
2
+
3
+ The "arcface_torch" repository is the official implementation of the ArcFace algorithm. It supports distributed and sparse training with multiple distributed training examples, including several memory-saving techniques such as mixed precision training and gradient checkpointing. It also supports training for ViT models and datasets including WebFace42M and Glint360K, two of the largest open-source datasets. Additionally, the repository comes with a built-in tool for converting to ONNX format, making it easy to submit to MFR evaluation systems.
4
+
5
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/killing-two-birds-with-one-stone-efficient/face-verification-on-ijb-c)](https://paperswithcode.com/sota/face-verification-on-ijb-c?p=killing-two-birds-with-one-stone-efficient)
6
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/killing-two-birds-with-one-stone-efficient/face-verification-on-ijb-b)](https://paperswithcode.com/sota/face-verification-on-ijb-b?p=killing-two-birds-with-one-stone-efficient)
7
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/killing-two-birds-with-one-stone-efficient/face-verification-on-agedb-30)](https://paperswithcode.com/sota/face-verification-on-agedb-30?p=killing-two-birds-with-one-stone-efficient)
8
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/killing-two-birds-with-one-stone-efficient/face-verification-on-cfp-fp)](https://paperswithcode.com/sota/face-verification-on-cfp-fp?p=killing-two-birds-with-one-stone-efficient)
9
+
10
+ ## Requirements
11
+
12
+ To take advantage of the latest PyTorch features, we have upgraded to version 1.12.0.
13
+
14
+ - Install [PyTorch](https://pytorch.org/get-started/previous-versions/) (torch>=1.12.0).
15
+ - (Optional) Install [DALI](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/); see our doc [install_dali.md](docs/install_dali.md).
16
+ - `pip install -r requirement.txt`.
17
+
18
+ ## How to Train
19
+
20
+ To train a model, execute the `train.py` script with the path to the configuration files. The sample commands provided below demonstrate the process of conducting distributed training.
21
+
22
+ ### 1. To run on one GPU:
23
+
24
+ ```shell
25
+ python train_v2.py configs/ms1mv3_r50_onegpu
26
+ ```
27
+
28
+ Note:
29
+ It is not recommended to use a single GPU for training, as this may result in longer training times and suboptimal performance. For best results, we suggest using multiple GPUs or a GPU cluster.
30
+
31
+
32
+ ### 2. To run on a machine with 8 GPUs:
33
+
34
+ ```shell
35
+ torchrun --nproc_per_node=8 train.py configs/ms1mv3_r50
36
+ ```
37
+
38
+ ### 3. To run on 2 machines with 8 GPUs each:
39
+
40
+ Node 0:
41
+
42
+ ```shell
43
+ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=12581 train.py configs/wf42m_pfc02_16gpus_r100
44
+ ```
45
+
46
+ Node 1:
47
+
48
+ ```shell
49
+ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=12581 train.py configs/wf42m_pfc02_16gpus_r100
50
+ ```
51
+
52
+ ### 4. Run ViT-B on a machine with 24k batchsize:
53
+
54
+ ```shell
55
+ torchrun --nproc_per_node=8 train_v2.py configs/wf42m_pfc03_40epoch_8gpu_vit_b
56
+ ```
57
+
58
+
59
+ ## Download Datasets or Prepare Datasets
60
+ - [MS1MV2](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_#ms1m-arcface-85k-ids58m-images-57) (87k IDs, 5.8M images)
61
+ - [MS1MV3](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_#ms1m-retinaface) (93k IDs, 5.2M images)
62
+ - [Glint360K](https://github.com/deepinsight/insightface/tree/master/recognition/partial_fc#4-download) (360k IDs, 17.1M images)
63
+ - [WebFace42M](docs/prepare_webface42m.md) (2M IDs, 42.5M images)
64
+ - [Your Dataset, Click Here!](docs/prepare_custom_dataset.md)
65
+
66
+ Note:
67
+ If you want to use DALI for data reading, please use the script 'scripts/shuffle_rec.py' to shuffle the InsightFace style rec before using it.
68
+ Example:
69
+
70
+ `python scripts/shuffle_rec.py ms1m-retinaface-t1`
71
+
72
+ You will get the "shuffled_ms1m-retinaface-t1" folder, where the samples in the "train.rec" file are shuffled.
73
+
74
+
75
+ ## Model Zoo
76
+
77
+ - The models are available for non-commercial research purposes only.
78
+ - All models can be found here:
79
+ - [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
80
+ - [OneDrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
81
+
82
+ ### Performance on IJB-C and [**ICCV2021-MFR**](https://github.com/deepinsight/insightface/blob/master/challenges/mfr/README.md)
83
+
84
+ The ICCV2021-MFR test set consists of non-celebrities, so we can ensure that it has very little overlap with publicly available face
85
+ recognition training sets such as MS1M and CASIA, which are mostly collected from online celebrities.
86
+ As a result, we can evaluate the fair performance of different algorithms.
87
+
88
+ For the **ICCV2021-MFR-ALL** set, TAR is measured with an all-to-all 1:1 protocol, at FAR less than 0.000001 (1e-6). The
89
+ globalised multi-racial test set contains 242,143 identities and 1,624,305 images.
90
+
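+ The snippet below is not the official MFR evaluation code; it is a minimal sketch of how a TAR at a fixed FAR (e.g. 1e-6) could be estimated from 1:1 verification scores, assuming hypothetical `scores` and `labels` arrays where `scores` holds the similarity of every evaluated pair and `labels` marks genuine pairs with 1 and impostor pairs with 0.
+
+ ```python
+ import numpy as np
+
+ def tar_at_far(scores, labels, far=1e-6):
+     """TAR at a given FAR for 1:1 verification; scores/labels are 1-D arrays over pairs."""
+     neg = np.sort(scores[labels == 0])[::-1]      # impostor scores, highest first
+     k = max(int(len(neg) * far), 1)               # number of false accepts the FAR budget allows
+     threshold = neg[k - 1]                        # accept pairs scoring at least this value
+     pos = scores[labels == 1]
+     return float((pos >= threshold).mean())       # fraction of genuine pairs accepted
+ ```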
91
+
92
+ #### 1. Training on Single-Host GPU
93
+
94
+ | Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | log |
95
+ |:---------------|:--------------------|:------------|:------------|:------------|:------------------------------------------------------------------------------------------------------------------------------------|
96
+ | MS1MV2 | mobilefacenet-0.45G | 62.07 | 93.61 | 90.28 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv2_mbf/training.log) |
97
+ | MS1MV2 | r50 | 75.13 | 95.97 | 94.07 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv2_r50/training.log) |
98
+ | MS1MV2 | r100 | 78.12 | 96.37 | 94.27 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv2_r100/training.log) |
99
+ | MS1MV3 | mobilefacenet-0.45G | 63.78 | 94.23 | 91.33 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_mbf/training.log) |
100
+ | MS1MV3 | r50 | 79.14 | 96.37 | 94.47 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_r50/training.log) |
101
+ | MS1MV3 | r100 | 81.97 | 96.85 | 95.02 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_r100/training.log) |
102
+ | Glint360K | mobilefacenet-0.45G | 70.18 | 95.04 | 92.62 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_mbf/training.log) |
103
+ | Glint360K | r50 | 86.34 | 97.16 | 95.81 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_r50/training.log) |
104
+ | Glint360k | r100 | 89.52 | 97.55 | 96.38 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_r100/training.log) |
105
+ | WF4M | r100 | 89.87 | 97.19 | 95.48 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf4m_r100/training.log) |
106
+ | WF12M-PFC-0.2 | r100 | 94.75 | 97.60 | 95.90 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf12m_pfc02_r100/training.log) |
107
+ | WF12M-PFC-0.3 | r100 | 94.71 | 97.64 | 96.01 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf12m_pfc03_r100/training.log) |
108
+ | WF12M | r100 | 94.69 | 97.59 | 95.97 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf12m_r100/training.log) |
109
+ | WF42M-PFC-0.2 | r100 | 96.27 | 97.70 | 96.31 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf42m_pfc02_r100/training.log) |
110
+ | WF42M-PFC-0.2 | ViT-T-1.5G | 92.04 | 97.27 | 95.68 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf42m_pfc02_40epoch_8gpu_vit_t/training.log) |
111
+ | WF42M-PFC-0.3 | ViT-B-11G | 97.16 | 97.91 | 97.05 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_b_8gpu/training.log) |
112
+
113
+ #### 2. Training on Multi-Host GPU
114
+
115
+ | Datasets | Backbone(bs*gpus) | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | Throughput | log |
116
+ |:-----------------|:------------------|:------------|:------------|:------------|:-----------|:-------------------------------------------------------------------------------------------------------------------------------------------|
117
+ | WF42M-PFC-0.2 | r50(512*8) | 93.83 | 97.53 | 96.16 | ~5900 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r50_bs4k_pfc02/training.log) |
118
+ | WF42M-PFC-0.2 | r50(512*16) | 93.96 | 97.46 | 96.12 | ~11000 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r50_lr01_pfc02_bs8k_16gpus/training.log) |
119
+ | WF42M-PFC-0.2 | r50(128*32) | 94.04 | 97.48 | 95.94 | ~17000 | click me |
120
+ | WF42M-PFC-0.2 | r100(128*16) | 96.28 | 97.80 | 96.57 | ~5200 | click me |
121
+ | WF42M-PFC-0.2 | r100(256*16) | 96.69 | 97.85 | 96.63 | ~5200 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r100_bs4k_pfc02/training.log) |
122
+ | WF42M-PFC-0.0018 | r100(512*32) | 93.08 | 97.51 | 95.88 | ~10000 | click me |
123
+ | WF42M-PFC-0.2 | r100(128*32) | 96.57 | 97.83 | 96.50 | ~9800 | click me |
124
+
125
+ `r100(128*32)` means the backbone is r100, the batch size per GPU is 128, and the number of GPUs is 32.
126
+
127
+
128
+
129
+ #### 3. ViT For Face Recognition
130
+
131
+ | Datasets | Backbone(bs) | FLOPs | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | Throughput | log |
132
+ |:--------------|:--------------|:------|:------------|:------------|:------------|:-----------|:-----------------------------------------------------------------------------------------------------------------------------|
133
+ | WF42M-PFC-0.3 | r18(128*32) | 2.6 | 79.13 | 95.77 | 93.36 | - | click me |
134
+ | WF42M-PFC-0.3 | r50(128*32) | 6.3 | 94.03 | 97.48 | 95.94 | - | click me |
135
+ | WF42M-PFC-0.3 | r100(128*32) | 12.1 | 96.69 | 97.82 | 96.45 | - | click me |
136
+ | WF42M-PFC-0.3 | r200(128*32) | 23.5 | 97.70 | 97.97 | 96.93 | - | click me |
137
+ | WF42M-PFC-0.3 | VIT-T(384*64) | 1.5 | 92.24 | 97.31 | 95.97 | ~35000 | click me |
138
+ | WF42M-PFC-0.3 | VIT-S(384*64) | 5.7 | 95.87 | 97.73 | 96.57 | ~25000 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_s_64gpu/training.log) |
139
+ | WF42M-PFC-0.3 | VIT-B(384*64) | 11.4 | 97.42 | 97.90 | 97.04 | ~13800 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_b_64gpu/training.log) |
140
+ | WF42M-PFC-0.3 | VIT-L(384*64) | 25.3 | 97.85 | 98.00 | 97.23 | ~9406 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_l_64gpu/training.log) |
141
+
142
+ `WF42M` means WebFace42M, and `PFC-0.3` means the negative class centers are sampled at a rate of 0.3.
143
+
144
+ #### 4. Noisy Datasets
145
+
146
+ | Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | log |
147
+ |:-------------------------|:---------|:------------|:------------|:------------|:---------|
148
+ | WF12M-Flip(40%) | r50 | 43.87 | 88.35 | 80.78 | click me |
149
+ | WF12M-Flip(40%)-PFC-0.1* | r50 | 80.20 | 96.11 | 93.79 | click me |
150
+ | WF12M-Conflict | r50 | 79.93 | 95.30 | 91.56 | click me |
151
+ | WF12M-Conflict-PFC-0.3* | r50 | 91.68 | 97.28 | 95.75 | click me |
152
+
153
+ `WF12M` means WebFace12M; `PFC-0.1*` denotes PFC-0.1 with additional abnormal inter-class filtering.
154
+
155
+
156
+
157
+ ## Speed Benchmark
158
+ <div><img src="https://github.com/anxiangsir/insightface_arcface_log/blob/master/pfc_exp.png" width = "90%" /></div>
159
+
160
+
161
+ **Arcface-Torch** is an efficient tool for training large-scale face recognition models. When the number of classes in the training set exceeds one million, the Partial FC sampling strategy maintains the same accuracy while training several times faster and using less GPU memory. Partial FC is a sparse variant of the model-parallel architecture for large-scale face recognition: it uses a sparse softmax that dynamically samples a subset of class centers for each training batch, so only a sparse portion of the parameters is updated in each iteration, which significantly reduces GPU memory and compute requirements. With Partial FC it is possible to train on datasets with up to 29 million identities, the largest to date. Partial FC also supports multi-machine distributed training and mixed-precision training.
162
+
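+ The idea can be sketched in a few lines of PyTorch (a simplified illustration, not the repository's actual `partial_fc.py`, which additionally shards the class centers across GPUs and applies the ArcFace margin): keep the centers of the classes present in the current batch, add random negative centers until the configured sample rate is reached, and compute the softmax only over that subset.
+
+ ```
+ import torch
+ import torch.nn.functional as F
+
+ def sample_centers(weight, labels, sample_rate=0.1):
+     """Keep the positive class centers of this batch plus random negatives."""
+     num_classes = weight.size(0)
+     positives = labels.unique()
+     num_sample = max(int(num_classes * sample_rate), positives.numel())
+     score = torch.rand(num_classes, device=weight.device)
+     score[positives] = 2.0                          # force positives to be kept
+     index = torch.topk(score, k=num_sample)[1].sort()[0]
+     new_labels = torch.searchsorted(index, labels)  # remap labels to the subset
+     return weight[index], new_labels
+
+ def partial_fc_loss(embeddings, labels, weight, sample_rate=0.1, scale=64.0):
+     sub_weight, sub_labels = sample_centers(weight, labels, sample_rate)
+     logits = F.linear(F.normalize(embeddings), F.normalize(sub_weight))
+     return F.cross_entropy(logits * scale, sub_labels)
+ ```
+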
163
+
164
+
165
+ For more details, see
166
+ [speed_benchmark.md](docs/speed_benchmark.md) in docs.
167
+
168
+ > 1. Training Speed of Various Parallel Techniques (Samples per Second) on a Tesla V100 32GB x 8 System (Higher is Better)
169
+
170
+ `-` means training failed because of GPU memory limitations.
171
+
172
+ | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
173
+ |:--------------------------------|:--------------|:---------------|:---------------|
174
+ | 125000 | 4681 | 4824 | 5004 |
175
+ | 1400000 | **1672** | 3043 | 4738 |
176
+ | 5500000 | **-** | **1389** | 3975 |
177
+ | 8000000 | **-** | **-** | 3565 |
178
+ | 16000000 | **-** | **-** | 2679 |
179
+ | 29000000 | **-** | **-** | **1855** |
180
+
181
+ > 2. GPU Memory Utilization of Various Parallel Techniques (MB per GPU) on a Tesla V100 32GB x 8 System (Lower is Better)
182
+
183
+ | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
184
+ |:--------------------------------|:--------------|:---------------|:---------------|
185
+ | 125000 | 7358 | 5306 | 4868 |
186
+ | 1400000 | 32252 | 11178 | 6056 |
187
+ | 5500000 | **-** | 32188 | 9854 |
188
+ | 8000000 | **-** | **-** | 12310 |
189
+ | 16000000 | **-** | **-** | 19950 |
190
+ | 29000000 | **-** | **-** | 32324 |
191
+
192
+
193
+ ## Citations
194
+
195
+ ```
196
+ @inproceedings{deng2019arcface,
197
+ title={Arcface: Additive angular margin loss for deep face recognition},
198
+ author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
199
+ booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
200
+ pages={4690--4699},
201
+ year={2019}
202
+ }
203
+ @inproceedings{An_2022_CVPR,
204
+ author={An, Xiang and Deng, Jiankang and Guo, Jia and Feng, Ziyong and Zhu, XuHan and Yang, Jing and Liu, Tongliang},
205
+ title={Killing Two Birds With One Stone: Efficient and Robust Training of Face Recognition CNNs by Partial FC},
206
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
207
+ month={June},
208
+ year={2022},
209
+ pages={4042--4051}
210
+ }
211
+ @inproceedings{zhu2021webface260m,
212
+ title={Webface260m: A benchmark unveiling the power of million-scale deep face recognition},
213
+ author={Zhu, Zheng and Huang, Guan and Deng, Jiankang and Ye, Yun and Huang, Junjie and Chen, Xinze and Zhu, Jiagang and Yang, Tian and Lu, Jiwen and Du, Dalong and Zhou, Jie},
214
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
215
+ pages={10492--10502},
216
+ year={2021}
217
+ }
218
+ ```
deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__init__.py ADDED
@@ -0,0 +1,85 @@
1
+ from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
2
+ from .mobilefacenet import get_mbf
3
+
4
+
5
+ def get_model(name, **kwargs):
6
+ # resnet
7
+ if name == "r18":
8
+ return iresnet18(False, **kwargs)
9
+ elif name == "r34":
10
+ return iresnet34(False, **kwargs)
11
+ elif name == "r50":
12
+ return iresnet50(False, **kwargs)
13
+ elif name == "r100":
14
+ return iresnet100(False, **kwargs)
15
+ elif name == "r200":
16
+ return iresnet200(False, **kwargs)
17
+ elif name == "r2060":
18
+ from .iresnet2060 import iresnet2060
19
+ return iresnet2060(False, **kwargs)
20
+
21
+ elif name == "mbf":
22
+ fp16 = kwargs.get("fp16", False)
23
+ num_features = kwargs.get("num_features", 512)
24
+ return get_mbf(fp16=fp16, num_features=num_features)
25
+
26
+ elif name == "mbf_large":
27
+ from .mobilefacenet import get_mbf_large
28
+ fp16 = kwargs.get("fp16", False)
29
+ num_features = kwargs.get("num_features", 512)
30
+ return get_mbf_large(fp16=fp16, num_features=num_features)
31
+
32
+ elif name == "vit_t":
33
+ num_features = kwargs.get("num_features", 512)
34
+ from .vit import VisionTransformer
35
+ return VisionTransformer(
36
+ img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12,
37
+ num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1)
38
+
39
+ elif name == "vit_t_dp005_mask0": # For WebFace42M
40
+ num_features = kwargs.get("num_features", 512)
41
+ from .vit import VisionTransformer
42
+ return VisionTransformer(
43
+ img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12,
44
+ num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0)
45
+
46
+ elif name == "vit_s":
47
+ num_features = kwargs.get("num_features", 512)
48
+ from .vit import VisionTransformer
49
+ return VisionTransformer(
50
+ img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12,
51
+ num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1)
52
+
53
+ elif name == "vit_s_dp005_mask_0": # For WebFace42M
54
+ num_features = kwargs.get("num_features", 512)
55
+ from .vit import VisionTransformer
56
+ return VisionTransformer(
57
+ img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12,
58
+ num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0)
59
+
60
+ elif name == "vit_b":
61
+ # this is a feature
62
+ num_features = kwargs.get("num_features", 512)
63
+ from .vit import VisionTransformer
64
+ return VisionTransformer(
65
+ img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24,
66
+ num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1, using_checkpoint=True)
67
+
68
+ elif name == "vit_b_dp005_mask_005": # For WebFace42M
69
+ # this is a feature
70
+ num_features = kwargs.get("num_features", 512)
71
+ from .vit import VisionTransformer
72
+ return VisionTransformer(
73
+ img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24,
74
+ num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True)
75
+
76
+ elif name == "vit_l_dp005_mask_005": # For WebFace42M
77
+ # this is a feature
78
+ num_features = kwargs.get("num_features", 512)
79
+ from .vit import VisionTransformer
80
+ return VisionTransformer(
81
+ img_size=112, patch_size=9, num_classes=num_features, embed_dim=768, depth=24,
82
+ num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True)
83
+
84
+ else:
85
+ raise ValueError(f"Unsupported backbone name: {name}")
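+
+ # A minimal usage sketch: every branch above returns an embedding backbone that
+ # maps 112x112 aligned RGB face crops to `num_features`-dimensional features, e.g.
+ #   model = get_model("r50", fp16=False, num_features=512)
+ #   feats = model(torch.randn(4, 3, 112, 112))  # -> shape [4, 512]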
deep_3drecon/deep_3drecon_models/arcface_torch/backbones/iresnet.py ADDED
@@ -0,0 +1,194 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.utils.checkpoint import checkpoint
4
+
5
+ __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
6
+ using_ckpt = False  # set True to wrap IBasicBlock.forward in gradient checkpointing (saves memory at extra compute cost)
7
+
8
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
9
+ """3x3 convolution with padding"""
10
+ return nn.Conv2d(in_planes,
11
+ out_planes,
12
+ kernel_size=3,
13
+ stride=stride,
14
+ padding=dilation,
15
+ groups=groups,
16
+ bias=False,
17
+ dilation=dilation)
18
+
19
+
20
+ def conv1x1(in_planes, out_planes, stride=1):
21
+ """1x1 convolution"""
22
+ return nn.Conv2d(in_planes,
23
+ out_planes,
24
+ kernel_size=1,
25
+ stride=stride,
26
+ bias=False)
27
+
28
+
29
+ class IBasicBlock(nn.Module):
30
+ expansion = 1
31
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
32
+ groups=1, base_width=64, dilation=1):
33
+ super(IBasicBlock, self).__init__()
34
+ if groups != 1 or base_width != 64:
35
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
36
+ if dilation > 1:
37
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
38
+ self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
39
+ self.conv1 = conv3x3(inplanes, planes)
40
+ self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
41
+ self.prelu = nn.PReLU(planes)
42
+ self.conv2 = conv3x3(planes, planes, stride)
43
+ self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
44
+ self.downsample = downsample
45
+ self.stride = stride
46
+
47
+ def forward_impl(self, x):
48
+ identity = x
49
+ out = self.bn1(x)
50
+ out = self.conv1(out)
51
+ out = self.bn2(out)
52
+ out = self.prelu(out)
53
+ out = self.conv2(out)
54
+ out = self.bn3(out)
55
+ if self.downsample is not None:
56
+ identity = self.downsample(x)
57
+ out += identity
58
+ return out
59
+
60
+ def forward(self, x):
61
+ if self.training and using_ckpt:
62
+ return checkpoint(self.forward_impl, x)
63
+ else:
64
+ return self.forward_impl(x)
65
+
66
+
67
+ class IResNet(nn.Module):
68
+ fc_scale = 7 * 7
69
+ def __init__(self,
70
+ block, layers, dropout=0, num_features=512, zero_init_residual=False,
71
+ groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
72
+ super(IResNet, self).__init__()
73
+ self.extra_gflops = 0.0
74
+ self.fp16 = fp16
75
+ self.inplanes = 64
76
+ self.dilation = 1
77
+ if replace_stride_with_dilation is None:
78
+ replace_stride_with_dilation = [False, False, False]
79
+ if len(replace_stride_with_dilation) != 3:
80
+ raise ValueError("replace_stride_with_dilation should be None "
81
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
82
+ self.groups = groups
83
+ self.base_width = width_per_group
84
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
85
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
86
+ self.prelu = nn.PReLU(self.inplanes)
87
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
88
+ self.layer2 = self._make_layer(block,
89
+ 128,
90
+ layers[1],
91
+ stride=2,
92
+ dilate=replace_stride_with_dilation[0])
93
+ self.layer3 = self._make_layer(block,
94
+ 256,
95
+ layers[2],
96
+ stride=2,
97
+ dilate=replace_stride_with_dilation[1])
98
+ self.layer4 = self._make_layer(block,
99
+ 512,
100
+ layers[3],
101
+ stride=2,
102
+ dilate=replace_stride_with_dilation[2])
103
+ self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
104
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
105
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
106
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
107
+ nn.init.constant_(self.features.weight, 1.0)
108
+ self.features.weight.requires_grad = False
109
+
110
+ for m in self.modules():
111
+ if isinstance(m, nn.Conv2d):
112
+ nn.init.normal_(m.weight, 0, 0.1)
113
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
114
+ nn.init.constant_(m.weight, 1)
115
+ nn.init.constant_(m.bias, 0)
116
+
117
+ if zero_init_residual:
118
+ for m in self.modules():
119
+ if isinstance(m, IBasicBlock):
120
+ nn.init.constant_(m.bn2.weight, 0)
121
+
122
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
123
+ downsample = None
124
+ previous_dilation = self.dilation
125
+ if dilate:
126
+ self.dilation *= stride
127
+ stride = 1
128
+ if stride != 1 or self.inplanes != planes * block.expansion:
129
+ downsample = nn.Sequential(
130
+ conv1x1(self.inplanes, planes * block.expansion, stride),
131
+ nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
132
+ )
133
+ layers = []
134
+ layers.append(
135
+ block(self.inplanes, planes, stride, downsample, self.groups,
136
+ self.base_width, previous_dilation))
137
+ self.inplanes = planes * block.expansion
138
+ for _ in range(1, blocks):
139
+ layers.append(
140
+ block(self.inplanes,
141
+ planes,
142
+ groups=self.groups,
143
+ base_width=self.base_width,
144
+ dilation=self.dilation))
145
+
146
+ return nn.Sequential(*layers)
147
+
148
+ def forward(self, x):
149
+ with torch.cuda.amp.autocast(self.fp16):
150
+ x = self.conv1(x)
151
+ x = self.bn1(x)
152
+ x = self.prelu(x)
153
+ x = self.layer1(x)
154
+ x = self.layer2(x)
155
+ x = self.layer3(x)
156
+ x = self.layer4(x)
157
+ x = self.bn2(x)
158
+ x = torch.flatten(x, 1)
159
+ x = self.dropout(x)
160
+ x = self.fc(x.float() if self.fp16 else x)
161
+ x = self.features(x)
162
+ return x
163
+
164
+
165
+ def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
166
+ model = IResNet(block, layers, **kwargs)
167
+ if pretrained:
168
+ raise ValueError()
169
+ return model
170
+
171
+
172
+ def iresnet18(pretrained=False, progress=True, **kwargs):
173
+ return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
174
+ progress, **kwargs)
175
+
176
+
177
+ def iresnet34(pretrained=False, progress=True, **kwargs):
178
+ return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
179
+ progress, **kwargs)
180
+
181
+
182
+ def iresnet50(pretrained=False, progress=True, **kwargs):
183
+ return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
184
+ progress, **kwargs)
185
+
186
+
187
+ def iresnet100(pretrained=False, progress=True, **kwargs):
188
+ return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
189
+ progress, **kwargs)
190
+
191
+
192
+ def iresnet200(pretrained=False, progress=True, **kwargs):
193
+ return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
194
+ progress, **kwargs)
deep_3drecon/deep_3drecon_models/arcface_torch/backbones/iresnet2060.py ADDED
@@ -0,0 +1,176 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ assert torch.__version__ >= "1.8.1"
5
+ from torch.utils.checkpoint import checkpoint_sequential
6
+
7
+ __all__ = ['iresnet2060']
8
+
9
+
10
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
11
+ """3x3 convolution with padding"""
12
+ return nn.Conv2d(in_planes,
13
+ out_planes,
14
+ kernel_size=3,
15
+ stride=stride,
16
+ padding=dilation,
17
+ groups=groups,
18
+ bias=False,
19
+ dilation=dilation)
20
+
21
+
22
+ def conv1x1(in_planes, out_planes, stride=1):
23
+ """1x1 convolution"""
24
+ return nn.Conv2d(in_planes,
25
+ out_planes,
26
+ kernel_size=1,
27
+ stride=stride,
28
+ bias=False)
29
+
30
+
31
+ class IBasicBlock(nn.Module):
32
+ expansion = 1
33
+
34
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
35
+ groups=1, base_width=64, dilation=1):
36
+ super(IBasicBlock, self).__init__()
37
+ if groups != 1 or base_width != 64:
38
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
39
+ if dilation > 1:
40
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
41
+ self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, )
42
+ self.conv1 = conv3x3(inplanes, planes)
43
+ self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, )
44
+ self.prelu = nn.PReLU(planes)
45
+ self.conv2 = conv3x3(planes, planes, stride)
46
+ self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, )
47
+ self.downsample = downsample
48
+ self.stride = stride
49
+
50
+ def forward(self, x):
51
+ identity = x
52
+ out = self.bn1(x)
53
+ out = self.conv1(out)
54
+ out = self.bn2(out)
55
+ out = self.prelu(out)
56
+ out = self.conv2(out)
57
+ out = self.bn3(out)
58
+ if self.downsample is not None:
59
+ identity = self.downsample(x)
60
+ out += identity
61
+ return out
62
+
63
+
64
+ class IResNet(nn.Module):
65
+ fc_scale = 7 * 7
66
+
67
+ def __init__(self,
68
+ block, layers, dropout=0, num_features=512, zero_init_residual=False,
69
+ groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
70
+ super(IResNet, self).__init__()
71
+ self.fp16 = fp16
72
+ self.inplanes = 64
73
+ self.dilation = 1
74
+ if replace_stride_with_dilation is None:
75
+ replace_stride_with_dilation = [False, False, False]
76
+ if len(replace_stride_with_dilation) != 3:
77
+ raise ValueError("replace_stride_with_dilation should be None "
78
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
79
+ self.groups = groups
80
+ self.base_width = width_per_group
81
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
82
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
83
+ self.prelu = nn.PReLU(self.inplanes)
84
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
85
+ self.layer2 = self._make_layer(block,
86
+ 128,
87
+ layers[1],
88
+ stride=2,
89
+ dilate=replace_stride_with_dilation[0])
90
+ self.layer3 = self._make_layer(block,
91
+ 256,
92
+ layers[2],
93
+ stride=2,
94
+ dilate=replace_stride_with_dilation[1])
95
+ self.layer4 = self._make_layer(block,
96
+ 512,
97
+ layers[3],
98
+ stride=2,
99
+ dilate=replace_stride_with_dilation[2])
100
+ self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, )
101
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
102
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
103
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
104
+ nn.init.constant_(self.features.weight, 1.0)
105
+ self.features.weight.requires_grad = False
106
+
107
+ for m in self.modules():
108
+ if isinstance(m, nn.Conv2d):
109
+ nn.init.normal_(m.weight, 0, 0.1)
110
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
111
+ nn.init.constant_(m.weight, 1)
112
+ nn.init.constant_(m.bias, 0)
113
+
114
+ if zero_init_residual:
115
+ for m in self.modules():
116
+ if isinstance(m, IBasicBlock):
117
+ nn.init.constant_(m.bn2.weight, 0)
118
+
119
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
120
+ downsample = None
121
+ previous_dilation = self.dilation
122
+ if dilate:
123
+ self.dilation *= stride
124
+ stride = 1
125
+ if stride != 1 or self.inplanes != planes * block.expansion:
126
+ downsample = nn.Sequential(
127
+ conv1x1(self.inplanes, planes * block.expansion, stride),
128
+ nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
129
+ )
130
+ layers = []
131
+ layers.append(
132
+ block(self.inplanes, planes, stride, downsample, self.groups,
133
+ self.base_width, previous_dilation))
134
+ self.inplanes = planes * block.expansion
135
+ for _ in range(1, blocks):
136
+ layers.append(
137
+ block(self.inplanes,
138
+ planes,
139
+ groups=self.groups,
140
+ base_width=self.base_width,
141
+ dilation=self.dilation))
142
+
143
+ return nn.Sequential(*layers)
144
+
145
+ def checkpoint(self, func, num_seg, x):
146
+ if self.training:
147
+ return checkpoint_sequential(func, num_seg, x)
148
+ else:
149
+ return func(x)
150
+
151
+ def forward(self, x):
152
+ with torch.cuda.amp.autocast(self.fp16):
153
+ x = self.conv1(x)
154
+ x = self.bn1(x)
155
+ x = self.prelu(x)
156
+ x = self.layer1(x)
157
+ x = self.checkpoint(self.layer2, 20, x)
158
+ x = self.checkpoint(self.layer3, 100, x)
159
+ x = self.layer4(x)
160
+ x = self.bn2(x)
161
+ x = torch.flatten(x, 1)
162
+ x = self.dropout(x)
163
+ x = self.fc(x.float() if self.fp16 else x)
164
+ x = self.features(x)
165
+ return x
166
+
167
+
168
+ def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
169
+ model = IResNet(block, layers, **kwargs)
170
+ if pretrained:
171
+ raise ValueError()
172
+ return model
173
+
174
+
175
+ def iresnet2060(pretrained=False, progress=True, **kwargs):
176
+ return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs)
deep_3drecon/deep_3drecon_models/arcface_torch/backbones/mobilefacenet.py ADDED
@@ -0,0 +1,147 @@
1
+ '''
2
+ Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
3
+ Original author cavalleria
4
+ '''
5
+
6
+ import torch.nn as nn
7
+ from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module
8
+ import torch
9
+
10
+
11
+ class Flatten(Module):
12
+ def forward(self, x):
13
+ return x.view(x.size(0), -1)
14
+
15
+
16
+ class ConvBlock(Module):
17
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
18
+ super(ConvBlock, self).__init__()
19
+ self.layers = nn.Sequential(
20
+ Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),
21
+ BatchNorm2d(num_features=out_c),
22
+ PReLU(num_parameters=out_c)
23
+ )
24
+
25
+ def forward(self, x):
26
+ return self.layers(x)
27
+
28
+
29
+ class LinearBlock(Module):
30
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
31
+ super(LinearBlock, self).__init__()
32
+ self.layers = nn.Sequential(
33
+ Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
34
+ BatchNorm2d(num_features=out_c)
35
+ )
36
+
37
+ def forward(self, x):
38
+ return self.layers(x)
39
+
40
+
41
+ class DepthWise(Module):
42
+ def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
43
+ super(DepthWise, self).__init__()
44
+ self.residual = residual
45
+ self.layers = nn.Sequential(
46
+ ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
47
+ ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),
48
+ LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
49
+ )
50
+
51
+ def forward(self, x):
52
+ short_cut = None
53
+ if self.residual:
54
+ short_cut = x
55
+ x = self.layers(x)
56
+ if self.residual:
57
+ output = short_cut + x
58
+ else:
59
+ output = x
60
+ return output
61
+
62
+
63
+ class Residual(Module):
64
+ def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
65
+ super(Residual, self).__init__()
66
+ modules = []
67
+ for _ in range(num_block):
68
+ modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))
69
+ self.layers = Sequential(*modules)
70
+
71
+ def forward(self, x):
72
+ return self.layers(x)
73
+
74
+
75
+ class GDC(Module):
76
+ def __init__(self, embedding_size):
77
+ super(GDC, self).__init__()
78
+ self.layers = nn.Sequential(
79
+ LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),
80
+ Flatten(),
81
+ Linear(512, embedding_size, bias=False),
82
+ BatchNorm1d(embedding_size))
83
+
84
+ def forward(self, x):
85
+ return self.layers(x)
86
+
87
+
88
+ class MobileFaceNet(Module):
89
+ def __init__(self, fp16=False, num_features=512, blocks=(1, 4, 6, 2), scale=2):
90
+ super(MobileFaceNet, self).__init__()
91
+ self.scale = scale
92
+ self.fp16 = fp16
93
+ self.layers = nn.ModuleList()
94
+ self.layers.append(
95
+ ConvBlock(3, 64 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1))
96
+ )
97
+ if blocks[0] == 1:
98
+ self.layers.append(
99
+ ConvBlock(64 * self.scale, 64 * self.scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
100
+ )
101
+ else:
102
+ self.layers.append(
103
+ Residual(64 * self.scale, num_block=blocks[0], groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
104
+ )
105
+
106
+ self.layers.extend(
107
+ [
108
+ DepthWise(64 * self.scale, 64 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
109
+ Residual(64 * self.scale, num_block=blocks[1], groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
110
+ DepthWise(64 * self.scale, 128 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
111
+ Residual(128 * self.scale, num_block=blocks[2], groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
112
+ DepthWise(128 * self.scale, 128 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
113
+ Residual(128 * self.scale, num_block=blocks[3], groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
114
+ ])
115
+
116
+ self.conv_sep = ConvBlock(128 * self.scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
117
+ self.features = GDC(num_features)
118
+ self._initialize_weights()
119
+
120
+ def _initialize_weights(self):
121
+ for m in self.modules():
122
+ if isinstance(m, nn.Conv2d):
123
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
124
+ if m.bias is not None:
125
+ m.bias.data.zero_()
126
+ elif isinstance(m, nn.BatchNorm2d):
127
+ m.weight.data.fill_(1)
128
+ m.bias.data.zero_()
129
+ elif isinstance(m, nn.Linear):
130
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
131
+ if m.bias is not None:
132
+ m.bias.data.zero_()
133
+
134
+ def forward(self, x):
135
+ with torch.cuda.amp.autocast(self.fp16):
136
+ for func in self.layers:
137
+ x = func(x)
138
+ x = self.conv_sep(x.float() if self.fp16 else x)
139
+ x = self.features(x)
140
+ return x
141
+
142
+
143
+ def get_mbf(fp16, num_features, blocks=(1, 4, 6, 2), scale=2):
144
+ return MobileFaceNet(fp16, num_features, blocks, scale=scale)
145
+
146
+ def get_mbf_large(fp16, num_features, blocks=(2, 8, 12, 4), scale=4):
147
+ return MobileFaceNet(fp16, num_features, blocks, scale=scale)
deep_3drecon/deep_3drecon_models/arcface_torch/backbones/vit.py ADDED
@@ -0,0 +1,280 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
4
+ from typing import Optional, Callable
5
+
6
+ class Mlp(nn.Module):
7
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):
8
+ super().__init__()
9
+ out_features = out_features or in_features
10
+ hidden_features = hidden_features or in_features
11
+ self.fc1 = nn.Linear(in_features, hidden_features)
12
+ self.act = act_layer()
13
+ self.fc2 = nn.Linear(hidden_features, out_features)
14
+ self.drop = nn.Dropout(drop)
15
+
16
+ def forward(self, x):
17
+ x = self.fc1(x)
18
+ x = self.act(x)
19
+ x = self.drop(x)
20
+ x = self.fc2(x)
21
+ x = self.drop(x)
22
+ return x
23
+
24
+
25
+ class VITBatchNorm(nn.Module):
26
+ def __init__(self, num_features):
27
+ super().__init__()
28
+ self.num_features = num_features
29
+ self.bn = nn.BatchNorm1d(num_features=num_features)
30
+
31
+ def forward(self, x):
32
+ return self.bn(x)
33
+
34
+
35
+ class Attention(nn.Module):
36
+ def __init__(self,
37
+ dim: int,
38
+ num_heads: int = 8,
39
+ qkv_bias: bool = False,
40
+ qk_scale: Optional[None] = None,
41
+ attn_drop: float = 0.,
42
+ proj_drop: float = 0.):
43
+ super().__init__()
44
+ self.num_heads = num_heads
45
+ head_dim = dim // num_heads
46
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
47
+ self.scale = qk_scale or head_dim ** -0.5
48
+
49
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
50
+ self.attn_drop = nn.Dropout(attn_drop)
51
+ self.proj = nn.Linear(dim, dim)
52
+ self.proj_drop = nn.Dropout(proj_drop)
53
+
54
+ def forward(self, x):
55
+
56
+ with torch.cuda.amp.autocast(True):
57
+ batch_size, num_token, embed_dim = x.shape
58
+ #qkv is [3,batch_size,num_heads,num_token, embed_dim//num_heads]
59
+ qkv = self.qkv(x).reshape(
60
+ batch_size, num_token, 3, self.num_heads, embed_dim // self.num_heads).permute(2, 0, 3, 1, 4)
61
+ with torch.cuda.amp.autocast(False):
62
+ q, k, v = qkv[0].float(), qkv[1].float(), qkv[2].float()
63
+ attn = (q @ k.transpose(-2, -1)) * self.scale
64
+ attn = attn.softmax(dim=-1)
65
+ attn = self.attn_drop(attn)
66
+ x = (attn @ v).transpose(1, 2).reshape(batch_size, num_token, embed_dim)
67
+ with torch.cuda.amp.autocast(True):
68
+ x = self.proj(x)
69
+ x = self.proj_drop(x)
70
+ return x
71
+
72
+
73
+ class Block(nn.Module):
74
+
75
+ def __init__(self,
76
+ dim: int,
77
+ num_heads: int,
78
+ num_patches: int,
79
+ mlp_ratio: float = 4.,
80
+ qkv_bias: bool = False,
81
+ qk_scale: Optional[None] = None,
82
+ drop: float = 0.,
83
+ attn_drop: float = 0.,
84
+ drop_path: float = 0.,
85
+ act_layer: Callable = nn.ReLU6,
86
+ norm_layer: str = "ln",
87
+ patch_n: int = 144):
88
+ super().__init__()
89
+
90
+ if norm_layer == "bn":
91
+ self.norm1 = VITBatchNorm(num_features=num_patches)
92
+ self.norm2 = VITBatchNorm(num_features=num_patches)
93
+ elif norm_layer == "ln":
94
+ self.norm1 = nn.LayerNorm(dim)
95
+ self.norm2 = nn.LayerNorm(dim)
96
+
97
+ self.attn = Attention(
98
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
99
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
100
+ self.drop_path = DropPath(
101
+ drop_path) if drop_path > 0. else nn.Identity()
102
+ mlp_hidden_dim = int(dim * mlp_ratio)
103
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
104
+ act_layer=act_layer, drop=drop)
105
+ self.extra_gflops = (num_heads * patch_n * (dim//num_heads)*patch_n * 2) / (1000**3)
106
+
107
+ def forward(self, x):
108
+ x = x + self.drop_path(self.attn(self.norm1(x)))
109
+ with torch.cuda.amp.autocast(True):
110
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
111
+ return x
112
+
113
+
114
+ class PatchEmbed(nn.Module):
115
+ def __init__(self, img_size=108, patch_size=9, in_channels=3, embed_dim=768):
116
+ super().__init__()
117
+ img_size = to_2tuple(img_size)
118
+ patch_size = to_2tuple(patch_size)
119
+ num_patches = (img_size[1] // patch_size[1]) * \
120
+ (img_size[0] // patch_size[0])
121
+ self.img_size = img_size
122
+ self.patch_size = patch_size
123
+ self.num_patches = num_patches
124
+ self.proj = nn.Conv2d(in_channels, embed_dim,
125
+ kernel_size=patch_size, stride=patch_size)
126
+
127
+ def forward(self, x):
128
+ batch_size, channels, height, width = x.shape
129
+ assert height == self.img_size[0] and width == self.img_size[1], \
130
+ f"Input image size ({height}*{width}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
131
+ x = self.proj(x).flatten(2).transpose(1, 2)
132
+ return x
133
+
134
+
135
+ class VisionTransformer(nn.Module):
136
+ """ Vision Transformer with support for patch or hybrid CNN input stage
137
+ """
138
+
139
+ def __init__(self,
140
+ img_size: int = 112,
141
+ patch_size: int = 16,
142
+ in_channels: int = 3,
143
+ num_classes: int = 1000,
144
+ embed_dim: int = 768,
145
+ depth: int = 12,
146
+ num_heads: int = 12,
147
+ mlp_ratio: float = 4.,
148
+ qkv_bias: bool = False,
149
+ qk_scale: Optional[None] = None,
150
+ drop_rate: float = 0.,
151
+ attn_drop_rate: float = 0.,
152
+ drop_path_rate: float = 0.,
153
+ hybrid_backbone: Optional[None] = None,
154
+ norm_layer: str = "ln",
155
+ mask_ratio = 0.1,
156
+ using_checkpoint = False,
157
+ ):
158
+ super().__init__()
159
+ self.num_classes = num_classes
160
+ # num_features for consistency with other models
161
+ self.num_features = self.embed_dim = embed_dim
162
+
163
+ if hybrid_backbone is not None:
164
+ raise ValueError
165
+ else:
166
+ self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim)
167
+ self.mask_ratio = mask_ratio
168
+ self.using_checkpoint = using_checkpoint
169
+ num_patches = self.patch_embed.num_patches
170
+ self.num_patches = num_patches
171
+
172
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
173
+ self.pos_drop = nn.Dropout(p=drop_rate)
174
+
175
+ # stochastic depth decay rule
176
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
177
+ patch_n = (img_size//patch_size)**2
178
+ self.blocks = nn.ModuleList(
179
+ [
180
+ Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
181
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
182
+ num_patches=num_patches, patch_n=patch_n)
183
+ for i in range(depth)]
184
+ )
185
+ self.extra_gflops = 0.0
186
+ for _block in self.blocks:
187
+ self.extra_gflops += _block.extra_gflops
188
+
189
+ if norm_layer == "ln":
190
+ self.norm = nn.LayerNorm(embed_dim)
191
+ elif norm_layer == "bn":
192
+ self.norm = VITBatchNorm(self.num_patches)
193
+
194
+ # features head
195
+ self.feature = nn.Sequential(
196
+ nn.Linear(in_features=embed_dim * num_patches, out_features=embed_dim, bias=False),
197
+ nn.BatchNorm1d(num_features=embed_dim, eps=2e-5),
198
+ nn.Linear(in_features=embed_dim, out_features=num_classes, bias=False),
199
+ nn.BatchNorm1d(num_features=num_classes, eps=2e-5)
200
+ )
201
+
202
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
203
+ torch.nn.init.normal_(self.mask_token, std=.02)
204
+ trunc_normal_(self.pos_embed, std=.02)
205
+ # trunc_normal_(self.cls_token, std=.02)
206
+ self.apply(self._init_weights)
207
+
208
+ def _init_weights(self, m):
209
+ if isinstance(m, nn.Linear):
210
+ trunc_normal_(m.weight, std=.02)
211
+ if isinstance(m, nn.Linear) and m.bias is not None:
212
+ nn.init.constant_(m.bias, 0)
213
+ elif isinstance(m, nn.LayerNorm):
214
+ nn.init.constant_(m.bias, 0)
215
+ nn.init.constant_(m.weight, 1.0)
216
+
217
+ @torch.jit.ignore
218
+ def no_weight_decay(self):
219
+ return {'pos_embed', 'cls_token'}
220
+
221
+ def get_classifier(self):
222
+ return self.head
223
+
224
+ def random_masking(self, x, mask_ratio=0.1):
225
+ """
226
+ Perform per-sample random masking by per-sample shuffling.
227
+ Per-sample shuffling is done by argsort random noise.
228
+ x: [N, L, D], sequence
229
+ """
230
+ N, L, D = x.size() # batch, length, dim
231
+ len_keep = int(L * (1 - mask_ratio))
232
+
233
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
234
+
235
+ # sort noise for each sample
236
+ # ascend: small is keep, large is remove
237
+ ids_shuffle = torch.argsort(noise, dim=1)
238
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
239
+
240
+ # keep the first subset
241
+ ids_keep = ids_shuffle[:, :len_keep]
242
+ x_masked = torch.gather(
243
+ x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
244
+
245
+ # generate the binary mask: 0 is keep, 1 is remove
246
+ mask = torch.ones([N, L], device=x.device)
247
+ mask[:, :len_keep] = 0
248
+ # unshuffle to get the binary mask
249
+ mask = torch.gather(mask, dim=1, index=ids_restore)
250
+
251
+ return x_masked, mask, ids_restore
252
+
253
+ def forward_features(self, x):
254
+ B = x.shape[0]
255
+ x = self.patch_embed(x)
256
+ x = x + self.pos_embed
257
+ x = self.pos_drop(x)
258
+
259
+ if self.training and self.mask_ratio > 0:
260
+ x, _, ids_restore = self.random_masking(x)
261
+
262
+ for func in self.blocks:
263
+ if self.using_checkpoint and self.training:
264
+ from torch.utils.checkpoint import checkpoint
265
+ x = checkpoint(func, x)
266
+ else:
267
+ x = func(x)
268
+ x = self.norm(x.float())
269
+
270
+ if self.training and self.mask_ratio > 0:
271
+ mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
272
+ x_ = torch.cat([x[:, :, :], mask_tokens], dim=1) # no cls token
273
+ x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
274
+ x = x_
275
+ return torch.reshape(x, (B, self.num_patches * self.embed_dim))
276
+
277
+ def forward(self, x):
278
+ x = self.forward_features(x)
279
+ x = self.feature(x)
280
+ return x