Spaces:
Build error
Build error
Commit
·
8eb4303
0
Parent(s):
init project
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +12 -0
- .gitignore +201 -0
- LICENSE +21 -0
- README-zh.md +157 -0
- README.md +155 -0
- checkpoints/.gitkeep +0 -0
- data_gen/eg3d/convert_to_eg3d_convention.py +146 -0
- data_gen/runs/binarizer_nerf.py +335 -0
- data_gen/runs/binarizer_th1kh.py +100 -0
- data_gen/runs/nerf/process_guide.md +49 -0
- data_gen/runs/nerf/run.sh +51 -0
- data_gen/utils/mp_feature_extractors/face_landmarker.py +130 -0
- data_gen/utils/mp_feature_extractors/face_landmarker.task +3 -0
- data_gen/utils/mp_feature_extractors/mp_segmenter.py +303 -0
- data_gen/utils/mp_feature_extractors/selfie_multiclass_256x256.tflite +3 -0
- data_gen/utils/path_converter.py +24 -0
- data_gen/utils/process_audio/extract_hubert.py +92 -0
- data_gen/utils/process_audio/extract_mel_f0.py +148 -0
- data_gen/utils/process_audio/resample_audio_to_16k.py +49 -0
- data_gen/utils/process_image/extract_lm2d.py +197 -0
- data_gen/utils/process_image/extract_segment_imgs.py +114 -0
- data_gen/utils/process_image/fit_3dmm_landmark.py +369 -0
- data_gen/utils/process_video/euler2quaterion.py +35 -0
- data_gen/utils/process_video/extract_blink.py +50 -0
- data_gen/utils/process_video/extract_lm2d.py +164 -0
- data_gen/utils/process_video/extract_segment_imgs.py +494 -0
- data_gen/utils/process_video/fit_3dmm_landmark.py +565 -0
- data_gen/utils/process_video/inpaint_torso_imgs.py +193 -0
- data_gen/utils/process_video/resample_video_to_25fps_resize_to_512.py +87 -0
- data_gen/utils/process_video/split_video_to_imgs.py +53 -0
- data_util/face3d_helper.py +309 -0
- deep_3drecon/BFM/.gitkeep +0 -0
- deep_3drecon/BFM/basel_53201.txt +0 -0
- deep_3drecon/BFM/index_mp468_from_mesh35709_v1.npy +3 -0
- deep_3drecon/BFM/index_mp468_from_mesh35709_v2.npy +3 -0
- deep_3drecon/BFM/index_mp468_from_mesh35709_v3.1.npy +3 -0
- deep_3drecon/BFM/index_mp468_from_mesh35709_v3.npy +3 -0
- deep_3drecon/BFM/select_vertex_id.mat +0 -0
- deep_3drecon/BFM/similarity_Lm3D_all.mat +0 -0
- deep_3drecon/__init__.py +1 -0
- deep_3drecon/bfm_left_eye_faces.npy +3 -0
- deep_3drecon/bfm_right_eye_faces.npy +3 -0
- deep_3drecon/data_preparation.py +45 -0
- deep_3drecon/deep_3drecon_models/__init__.py +67 -0
- deep_3drecon/deep_3drecon_models/arcface_torch/README.md +218 -0
- deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__init__.py +85 -0
- deep_3drecon/deep_3drecon_models/arcface_torch/backbones/iresnet.py +194 -0
- deep_3drecon/deep_3drecon_models/arcface_torch/backbones/iresnet2060.py +176 -0
- deep_3drecon/deep_3drecon_models/arcface_torch/backbones/mobilefacenet.py +147 -0
- deep_3drecon/deep_3drecon_models/arcface_torch/backbones/vit.py +280 -0
.gitattributes
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data_gen/utils/mp_feature_extractors/selfie_multiclass_256x256.tflite filter=lfs diff=lfs merge=lfs -text
|
2 |
+
data_gen/utils/mp_feature_extractors/face_landmarker.task filter=lfs diff=lfs merge=lfs -text
|
3 |
+
utils/audio/pitch/bin/ReaperF0 filter=lfs diff=lfs merge=lfs -text
|
4 |
+
utils/audio/pitch/bin/ExtractF0ByStraight filter=lfs diff=lfs merge=lfs -text
|
5 |
+
utils/audio/pitch/bin/InterpF0 filter=lfs diff=lfs merge=lfs -text
|
6 |
+
deep_3drecon/bfm_left_eye_faces.npy filter=lfs diff=lfs merge=lfs -text
|
7 |
+
deep_3drecon/bfm_right_eye_faces.npy filter=lfs diff=lfs merge=lfs -text
|
8 |
+
deep_3drecon/ncc_code.npy filter=lfs diff=lfs merge=lfs -text
|
9 |
+
deep_3drecon/BFM/index_mp468_from_mesh35709_v1.npy filter=lfs diff=lfs merge=lfs -text
|
10 |
+
deep_3drecon/BFM/index_mp468_from_mesh35709_v2.npy filter=lfs diff=lfs merge=lfs -text
|
11 |
+
deep_3drecon/BFM/index_mp468_from_mesh35709_v3.1.npy filter=lfs diff=lfs merge=lfs -text
|
12 |
+
deep_3drecon/BFM/index_mp468_from_mesh35709_v3.npy filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# big files
|
2 |
+
data_util/face_tracking/3DMM/01_MorphableModel.mat
|
3 |
+
data_util/face_tracking/3DMM/3DMM_info.npy
|
4 |
+
|
5 |
+
!/deep_3drecon/BFM/.gitkeep
|
6 |
+
deep_3drecon/BFM/Exp_Pca.bin
|
7 |
+
deep_3drecon/BFM/01_MorphableModel.mat
|
8 |
+
deep_3drecon/BFM/BFM_model_front.mat
|
9 |
+
deep_3drecon/network/FaceReconModel.pb
|
10 |
+
deep_3drecon/checkpoints/*
|
11 |
+
|
12 |
+
.vscode
|
13 |
+
### Project ignore
|
14 |
+
./checkpoints/*
|
15 |
+
/checkpoints/*
|
16 |
+
!/checkpoints/.gitkeep
|
17 |
+
/data/*
|
18 |
+
!/data/.gitkeep
|
19 |
+
infer_out
|
20 |
+
rsync
|
21 |
+
.idea
|
22 |
+
.DS_Store
|
23 |
+
bak
|
24 |
+
tmp
|
25 |
+
*.tar.gz
|
26 |
+
mos
|
27 |
+
nbs
|
28 |
+
/configs_usr/*
|
29 |
+
!/configs_usr/.gitkeep
|
30 |
+
/egs_usr/*
|
31 |
+
!/egs_usr/.gitkeep
|
32 |
+
/rnnoise
|
33 |
+
#/usr/*
|
34 |
+
#!/usr/.gitkeep
|
35 |
+
scripts_usr
|
36 |
+
|
37 |
+
# Created by .ignore support plugin (hsz.mobi)
|
38 |
+
### Python template
|
39 |
+
# Byte-compiled / optimized / DLL files
|
40 |
+
__pycache__/
|
41 |
+
*.py[cod]
|
42 |
+
*$py.class
|
43 |
+
|
44 |
+
# C extensions
|
45 |
+
*.so
|
46 |
+
|
47 |
+
# Distribution / packaging
|
48 |
+
.Python
|
49 |
+
build/
|
50 |
+
develop-eggs/
|
51 |
+
dist/
|
52 |
+
downloads/
|
53 |
+
eggs/
|
54 |
+
.eggs/
|
55 |
+
lib/
|
56 |
+
lib64/
|
57 |
+
parts/
|
58 |
+
sdist/
|
59 |
+
var/
|
60 |
+
wheels/
|
61 |
+
pip-wheel-metadata/
|
62 |
+
share/python-wheels/
|
63 |
+
*.egg-info/
|
64 |
+
.installed.cfg
|
65 |
+
*.egg
|
66 |
+
MANIFEST
|
67 |
+
|
68 |
+
# PyInstaller
|
69 |
+
# Usually these files are written by a python script from a template
|
70 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
71 |
+
*.manifest
|
72 |
+
*.spec
|
73 |
+
|
74 |
+
# Installer logs
|
75 |
+
pip-log.txt
|
76 |
+
pip-delete-this-directory.txt
|
77 |
+
|
78 |
+
# Unit test / coverage reports
|
79 |
+
htmlcov/
|
80 |
+
.tox/
|
81 |
+
.nox/
|
82 |
+
.coverage
|
83 |
+
.coverage.*
|
84 |
+
.cache
|
85 |
+
nosetests.xml
|
86 |
+
coverage.xml
|
87 |
+
*.cover
|
88 |
+
.hypothesis/
|
89 |
+
.pytest_cache/
|
90 |
+
|
91 |
+
# Translations
|
92 |
+
*.mo
|
93 |
+
*.pot
|
94 |
+
|
95 |
+
# Django stuff:
|
96 |
+
*.log
|
97 |
+
local_settings.py
|
98 |
+
db.sqlite3
|
99 |
+
db.sqlite3-journal
|
100 |
+
|
101 |
+
# Flask stuff:
|
102 |
+
instance/
|
103 |
+
.webassets-cache
|
104 |
+
|
105 |
+
# Scrapy stuff:
|
106 |
+
.scrapy
|
107 |
+
|
108 |
+
# Sphinx documentation
|
109 |
+
docs/_build/
|
110 |
+
|
111 |
+
# PyBuilder
|
112 |
+
target/
|
113 |
+
|
114 |
+
# Jupyter Notebook
|
115 |
+
.ipynb_checkpoints
|
116 |
+
|
117 |
+
# IPython
|
118 |
+
profile_default/
|
119 |
+
ipython_config.py
|
120 |
+
|
121 |
+
# pyenv
|
122 |
+
.python-version
|
123 |
+
|
124 |
+
# pipenv
|
125 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
126 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
127 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
128 |
+
# install all needed dependencies.
|
129 |
+
#Pipfile.lock
|
130 |
+
|
131 |
+
# celery beat schedule file
|
132 |
+
celerybeat-schedule
|
133 |
+
|
134 |
+
# SageMath parsed files
|
135 |
+
*.sage.py
|
136 |
+
|
137 |
+
# Environments
|
138 |
+
.env
|
139 |
+
.venv
|
140 |
+
env/
|
141 |
+
venv/
|
142 |
+
ENV/
|
143 |
+
env.bak/
|
144 |
+
venv.bak/
|
145 |
+
|
146 |
+
# Spyder project settings
|
147 |
+
.spyderproject
|
148 |
+
.spyproject
|
149 |
+
|
150 |
+
# Rope project settings
|
151 |
+
.ropeproject
|
152 |
+
|
153 |
+
# mkdocs documentation
|
154 |
+
/site
|
155 |
+
|
156 |
+
# mypy
|
157 |
+
.mypy_cache/
|
158 |
+
.dmypy.json
|
159 |
+
dmypy.json
|
160 |
+
|
161 |
+
# Pyre type checker
|
162 |
+
.pyre/
|
163 |
+
data_util/deepspeech_features/deepspeech-0.9.2-models.pbmm
|
164 |
+
deep_3drecon/mesh_renderer/bazel-bin
|
165 |
+
deep_3drecon/mesh_renderer/bazel-mesh_renderer
|
166 |
+
deep_3drecon/mesh_renderer/bazel-out
|
167 |
+
deep_3drecon/mesh_renderer/bazel-testlogs
|
168 |
+
|
169 |
+
.nfs*
|
170 |
+
infer_outs/*
|
171 |
+
|
172 |
+
*.pth
|
173 |
+
venv_113/*
|
174 |
+
*.pt
|
175 |
+
experiments/trials
|
176 |
+
flame_3drecon/*
|
177 |
+
|
178 |
+
temp/
|
179 |
+
/kill.sh
|
180 |
+
/datasets
|
181 |
+
data_util/imagenet_classes.txt
|
182 |
+
process_data_May.sh
|
183 |
+
/env_prepare_reproduce.md
|
184 |
+
/my_debug.py
|
185 |
+
|
186 |
+
utils/metrics/shape_predictor_68_face_landmarks.dat
|
187 |
+
*.mp4
|
188 |
+
_torchshow/
|
189 |
+
*.png
|
190 |
+
*.jpg
|
191 |
+
|
192 |
+
*.mrc
|
193 |
+
|
194 |
+
deep_3drecon/BFM/BFM_exp_idx.mat
|
195 |
+
deep_3drecon/BFM/BFM_front_idx.mat
|
196 |
+
deep_3drecon/BFM/facemodel_info.mat
|
197 |
+
deep_3drecon/BFM/index_mp468_from_mesh35709.npy
|
198 |
+
deep_3drecon/BFM/mediapipe_in_bfm53201.npy
|
199 |
+
deep_3drecon/BFM/std_exp.txt
|
200 |
+
!data/raw/examples/*
|
201 |
+
/heckpoints_mimictalk
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2024 ZhenhuiYe
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README-zh.md
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MimicTalk: Mimicking a personalized and expressive 3D talking face in few minutes | NeurIPS 2024
|
2 |
+
[](https://arxiv.org/abs/2401.08503)| [](https://github.com/yerfor/MimicTalk) | [English Readme](./README.md)
|
4 |
+
|
5 |
+
这个仓库是MimicTalk的官方PyTorch实现, 用于实现特定说话人的高表现力的虚拟人视频合成。该仓库代码基于我们先前的工作[Real3D-Portrait](https://github.com/yerfor/Real3DPortrait) (ICLR 2024),即基于NeRF的one-shot说话人合成,这让Mimictalk的训练加速且效果增强。您可以访问我们的[项目页面](https://mimictalk.github.io/)以观看Demo视频, 阅读我们的[论文](https://arxiv.org/abs/2410.06734)以了解技术细节。
|
6 |
+
|
7 |
+
<p align="center">
|
8 |
+
<br>
|
9 |
+
<img src="assets/mimictalk.png" width="100%"/>
|
10 |
+
<br>
|
11 |
+
</p>
|
12 |
+
|
13 |
+
# 快速上手!
|
14 |
+
## 安装环境
|
15 |
+
请参照[环境配置文档](docs/prepare_env/install_guide-zh.md),配置Conda环境`mimictalk`
|
16 |
+
## 下载预训练与第三方模型
|
17 |
+
### 3DMM BFM模型
|
18 |
+
下载3DMM BFM模型:[Google Drive](https://drive.google.com/drive/folders/1o4t5YIw7w4cMUN4bgU9nPf6IyWVG1bEk?usp=sharing) 或 [BaiduYun Disk](https://pan.baidu.com/s/1aqv1z_qZ23Vp2VP4uxxblQ?pwd=m9q5 ) 提取码: m9q5
|
19 |
+
|
20 |
+
|
21 |
+
下载完成后,放置全部的文件到`deep_3drecon/BFM`里,文件结构如下:
|
22 |
+
```
|
23 |
+
deep_3drecon/BFM/
|
24 |
+
├── 01_MorphableModel.mat
|
25 |
+
├── BFM_exp_idx.mat
|
26 |
+
├── BFM_front_idx.mat
|
27 |
+
├── BFM_model_front.mat
|
28 |
+
├── Exp_Pca.bin
|
29 |
+
├── facemodel_info.mat
|
30 |
+
├── index_mp468_from_mesh35709.npy
|
31 |
+
├── mediapipe_in_bfm53201.npy
|
32 |
+
└── std_exp.txt
|
33 |
+
```
|
34 |
+
|
35 |
+
### 预训练模型
|
36 |
+
下载预训练的MimicTalk相关Checkpoints:[Google Drive](https://drive.google.com/drive/folders/1Kc6ueDO9HFDN3BhtJCEKNCZtyKHSktaA?usp=sharing) or [BaiduYun Disk](https://pan.baidu.com/s/1nQKyGV5JB6rJtda7qsThUg?pwd=mimi) 提取码: mimi
|
37 |
+
|
38 |
+
下载完成后,放置全部的文件到`checkpoints`与`checkpoints_mimictalk`里并解压,文件结构如下:
|
39 |
+
```
|
40 |
+
checkpoints/
|
41 |
+
├── mimictalk_orig
|
42 |
+
│ └── os_secc2plane_torso
|
43 |
+
│ ├── config.yaml
|
44 |
+
│ └── model_ckpt_steps_100000.ckpt
|
45 |
+
|-- 240112_icl_audio2secc_vox2_cmlr
|
46 |
+
│ ├── config.yaml
|
47 |
+
│ └── model_ckpt_steps_1856000.ckpt
|
48 |
+
└── pretrained_ckpts
|
49 |
+
└── mit_b0.pth
|
50 |
+
|
51 |
+
checkpoints_mimictalk/
|
52 |
+
└── German_20s
|
53 |
+
├── config.yaml
|
54 |
+
└── model_ckpt_steps_10000.ckpt
|
55 |
+
```
|
56 |
+
|
57 |
+
## MimicTalk训练与推理的最简命令
|
58 |
+
```
|
59 |
+
python inference/train_mimictalk_on_a_video.py # train the model, this may take 10 minutes for 2,000 steps
|
60 |
+
python inference/mimictalk_infer.py # infer the model
|
61 |
+
```
|
62 |
+
|
63 |
+
|
64 |
+
# 训练与推理细节
|
65 |
+
我们目前提供了**命令行(CLI)**与**Gradio WebUI**推理方式。音频驱动推理的人像信息来自于`torso_ckpt`,因此需要至少再提供`driving audio`用于推理。另外,可以提供`style video`让模型能够预测与该视频风格一致的说话人动作。
|
66 |
+
|
67 |
+
首先,切换至项目根目录并启用Conda环境:
|
68 |
+
```bash
|
69 |
+
cd <Real3DPortraitRoot>
|
70 |
+
conda activate mimictalk
|
71 |
+
export PYTHONPATH=./
|
72 |
+
export HF_ENDPOINT=https://hf-mirror.com
|
73 |
+
```
|
74 |
+
|
75 |
+
## Gradio WebUI推理
|
76 |
+
启动Gradio WebUI,按照提示上传素材,点击`Training`按钮进行训练;训练完成后点击`Generate`按钮即可推理:
|
77 |
+
```bash
|
78 |
+
python inference/app_mimictalk.py
|
79 |
+
```
|
80 |
+
|
81 |
+
## 命令行特定说话人训练
|
82 |
+
|
83 |
+
需要至少提供`source video`,训练指令:
|
84 |
+
```bash
|
85 |
+
python inference/train_mimictalk_on_a_video.py \
|
86 |
+
--video_id <PATH_TO_SOURCE_VIDEO> \
|
87 |
+
--max_updates <UPDATES_NUMBER> \
|
88 |
+
--work_dir <PATH_TO_SAVING_CKPT>
|
89 |
+
```
|
90 |
+
|
91 |
+
一些可选参数注释:
|
92 |
+
|
93 |
+
- `--torso_ckpt` 预训练的Real3D-Portrait模型
|
94 |
+
- `--max_updates` 训练更新次数
|
95 |
+
- `--batch_size` 训练的batch size: `1` 需要约8GB显存; `2`需要约15GB显存
|
96 |
+
- `--lr_triplane` triplane的学习率:对于视频输入, 应为0.1; 对于图片输入,应为0.001
|
97 |
+
- `--work_dir` 未指定时,将默认存储在`checkpoints_mimictalk/`中
|
98 |
+
|
99 |
+
指令示例:
|
100 |
+
```bash
|
101 |
+
python inference/train_mimictalk_on_a_video.py \
|
102 |
+
--video_id data/raw/videos/German_20s.mp4 \
|
103 |
+
--max_updates 2000 \
|
104 |
+
--work_dir checkpoints_mimictalk/German_20s
|
105 |
+
```
|
106 |
+
|
107 |
+
## 命令行推理
|
108 |
+
|
109 |
+
需要至少提供`driving audio`,可选提供`driving style`,推理指令:
|
110 |
+
```bash
|
111 |
+
python inference/mimictalk_infer.py \
|
112 |
+
--drv_aud <PATH_TO_AUDIO> \
|
113 |
+
--drv_style <PATH_TO_STYLE_VIDEO, OPTIONAL> \
|
114 |
+
--drv_pose <PATH_TO_POSE_VIDEO, OPTIONAL> \
|
115 |
+
--bg_img <PATH_TO_BACKGROUND_IMAGE, OPTIONAL> \
|
116 |
+
--out_name <PATH_TO_OUTPUT_VIDEO, OPTIONAL>
|
117 |
+
```
|
118 |
+
|
119 |
+
一些可选参数注释:
|
120 |
+
- `--drv_pose` 指定时提供了运动pose信息,不指定则为静态运动
|
121 |
+
- `--bg_img` 指定时提供了背景信息,不指定则为source image提取的背景
|
122 |
+
- `--mouth_amp` 嘴部张幅参数,值越大张幅越大
|
123 |
+
- `--map_to_init_pose` 值为`True`时,首帧的pose将被映��到source pose,后续帧也作相同变换
|
124 |
+
- `--temperature` 代表audio2motion的采样温度,值越大结果越多样,但同时精确度越低
|
125 |
+
- `--out_name` 不指定时,结果将保存在`infer_out/tmp/`中
|
126 |
+
- `--out_mode` 值为`final`时,只输出说话人视频;值为`concat_debug`时,同时输出一些可视化的中间结果
|
127 |
+
|
128 |
+
推理命令例子:
|
129 |
+
```bash
|
130 |
+
python inference/mimictalk_infer.py \
|
131 |
+
--drv_aud data/raw/examples/Obama_5s.wav \
|
132 |
+
--drv_pose data/raw/examples/German_20s.mp4 \
|
133 |
+
--drv_style data/raw/examples/German_20s.mp4 \
|
134 |
+
--bg_img data/raw/examples/bg.png \
|
135 |
+
--out_name output.mp4 \
|
136 |
+
--out_mode final
|
137 |
+
```
|
138 |
+
|
139 |
+
# 声明
|
140 |
+
任何组织或个人未经本人同意,不得使用本文提及的任何技术生成他人说话的视频,包括但不限于政府领导人、政界人士、社会名流等。如不遵守本条款,则可能违反版权法。
|
141 |
+
|
142 |
+
# 引用我们
|
143 |
+
如果这个仓库对你有帮助,请考虑引用我们的工作:
|
144 |
+
```
|
145 |
+
@inproceedings{ye2024mimicktalk,
|
146 |
+
author = {Ye, Zhenhui and Zhong, Tianyun and Ren, Yi and Yang, Jiaqi and Li, Weichuang and Huang, Jiangwei and Jiang, Ziyue and He, Jinzheng and Huang, Rongjie and Liu, Jinglin and Zhang, Chen and Yin, Xiang and Ma, Zejun and Zhao, Zhou},
|
147 |
+
title = {MimicTalk: Mimicking a personalized and expressive 3D talking face in few minutes},
|
148 |
+
journal = {NeurIPS},
|
149 |
+
year = {2024},
|
150 |
+
}
|
151 |
+
@inproceedings{ye2024real3d,
|
152 |
+
title = {Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis},
|
153 |
+
author = {Ye, Zhenhui and Zhong, Tianyun and Ren, Yi and Yang, Jiaqi and Li, Weichuang and Huang, Jiawei and Jiang, Ziyue and He, Jinzheng and Huang, Rongjie and Liu, Jinglin and others},
|
154 |
+
journal = {ICLR},
|
155 |
+
year={2024}
|
156 |
+
}
|
157 |
+
```
|
README.md
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MimicTalk: Mimicking a personalized and expressive 3D talking face in few minutes | NeurIPS 2024
|
2 |
+
[](https://arxiv.org/abs/2401.08503)| [](https://github.com/yerfor/MimicTalk) | [中文文档](./README-zh.md)
|
4 |
+
|
5 |
+
This is the official repo of MimicTalk with Pytorch implementation, for training a personalized and expressive talking avatar in minutes. The code is built upon our previous work, [Real3D-Portrait](https://github.com/yerfor/Real3DPortrait) (ICLR 2024), which is a one-shot NeRF-based talking avatar system and enables the fast training and good quality of our MimicTalk. You can visit our [Demo Page](https://mimictalk.github.io/) for watching demo videos, and read our [Paper](https://arxiv.org/abs/2410.06734) for technical details.
|
6 |
+
|
7 |
+
<p align="center">
|
8 |
+
<br>
|
9 |
+
<img src="assets/mimictalk.png" width="100%"/>
|
10 |
+
<br>
|
11 |
+
</p>
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
# Quick Start!
|
16 |
+
## Environment Installation
|
17 |
+
Please refer to [Installation Guide](docs/prepare_env/install_guide.md), prepare a Conda environment `mimictalk`.
|
18 |
+
## Download Pre-trained & Third-Party Models
|
19 |
+
### 3DMM BFM Model
|
20 |
+
Download 3DMM BFM Model from [Google Drive](https://drive.google.com/drive/folders/1o4t5YIw7w4cMUN4bgU9nPf6IyWVG1bEk?usp=sharing) or [BaiduYun Disk](https://pan.baidu.com/s/1aqv1z_qZ23Vp2VP4uxxblQ?pwd=m9q5 ) with Password m9q5.
|
21 |
+
|
22 |
+
|
23 |
+
Put all the files in `deep_3drecon/BFM`, the file structure will be like this:
|
24 |
+
```
|
25 |
+
deep_3drecon/BFM/
|
26 |
+
├── 01_MorphableModel.mat
|
27 |
+
├── BFM_exp_idx.mat
|
28 |
+
├── BFM_front_idx.mat
|
29 |
+
├── BFM_model_front.mat
|
30 |
+
├── Exp_Pca.bin
|
31 |
+
├── facemodel_info.mat
|
32 |
+
├── index_mp468_from_mesh35709.npy
|
33 |
+
└── std_exp.txt
|
34 |
+
```
|
35 |
+
|
36 |
+
### Pre-trained Real3D-Portrait & MimicTalk
|
37 |
+
Download Pre-trained MimicTalk Checkpoints:[Google Drive](https://drive.google.com/drive/folders/1Kc6ueDO9HFDN3BhtJCEKNCZtyKHSktaA?usp=sharing) or [BaiduYun Disk](https://pan.baidu.com/s/1nQKyGV5JB6rJtda7qsThUg?pwd=mimi) with Password `mimi`
|
38 |
+
|
39 |
+
|
40 |
+
Put the zip files in `checkpoints` & `checkpoints_mimictalk` and unzip them, the file structure will be like this:
|
41 |
+
```
|
42 |
+
checkpoints/
|
43 |
+
├── mimictalk_orig
|
44 |
+
│ └── os_secc2plane_torso
|
45 |
+
│ ├── config.yaml
|
46 |
+
│ └── model_ckpt_steps_100000.ckpt
|
47 |
+
|-- 240112_icl_audio2secc_vox2_cmlr
|
48 |
+
│ ├── config.yaml
|
49 |
+
│ └── model_ckpt_steps_1856000.ckpt
|
50 |
+
└── pretrained_ckpts
|
51 |
+
└── mit_b0.pth
|
52 |
+
|
53 |
+
checkpoints_mimictalk/
|
54 |
+
└── German_20s
|
55 |
+
├── config.yaml
|
56 |
+
└── model_ckpt_steps_10000.ckpt
|
57 |
+
```
|
58 |
+
|
59 |
+
## Train & Infer MimicTalk in two lines
|
60 |
+
```
|
61 |
+
python inference/train_mimictalk_on_a_video.py # train the model, this may take 10 minutes for 2,000 steps
|
62 |
+
python inference/mimictalk_infer.py # infer the model
|
63 |
+
```
|
64 |
+
|
65 |
+
# Detailed options for train & infer
|
66 |
+
Currently, we provide **CLI**, **Gradio WebUI** for inference. We support Audio-Driven talking head generation for specific-person (which is from `torso_ckpt`), and at least prepare `driving audio` for inference. Optionly, providing `style video` enables model to predict corressponding talking style with it.
|
67 |
+
|
68 |
+
Firstly, switch to project folder and activate conda environment:
|
69 |
+
```bash
|
70 |
+
cd <mimictalkRoot>
|
71 |
+
conda activate mimictalk
|
72 |
+
export PYTHONPATH=./
|
73 |
+
export HF_ENDPOINT=https://hf-mirror.com
|
74 |
+
```
|
75 |
+
|
76 |
+
## Gradio WebUI
|
77 |
+
Run Gradio WebUI demo, upload resouces in webpage,click `Training` button to train a person-specific MimicTalk model, and then click `Generate` button to inference with arbitary audio and style:
|
78 |
+
```bash
|
79 |
+
python inference/app_mimictalk.py
|
80 |
+
```
|
81 |
+
|
82 |
+
## CLI Training for specific-person video
|
83 |
+
Provide `source video` for specific-person:
|
84 |
+
```bash
|
85 |
+
python inference/train_mimictalk_on_a_video.py \
|
86 |
+
--video_id <PATH_TO_SOURCE_VIDEO> \
|
87 |
+
--max_updates <UPDATES_NUMBER> \
|
88 |
+
--work_dir <PATH_TO_SAVING_CKPT>
|
89 |
+
```
|
90 |
+
|
91 |
+
Some training optional parameters:
|
92 |
+
- `--torso_ckpt` Pre-trained Real3d-Portrait checkpoints path
|
93 |
+
- `--max_updates` The number of training updates.
|
94 |
+
- `--batch_size` Batch size during training: `1` needs about 8GB VRAM; `2` needs about 15GB
|
95 |
+
- `--lr_triplane` Learning rate of triplane: for video, 0.1; for an image, 0.001
|
96 |
+
- `--work_dir` When not assigned, the results will be stored at `checkpoints_mimictalk/`.
|
97 |
+
|
98 |
+
Commandline example:
|
99 |
+
```bash
|
100 |
+
python inference/train_mimictalk_on_a_video.py \
|
101 |
+
--video_id data/raw/videos/German_20s.mp4 \
|
102 |
+
--max_updates 2000 \
|
103 |
+
--work_dir checkpoints_mimictalk/German_20s
|
104 |
+
```
|
105 |
+
|
106 |
+
## CLI Inference
|
107 |
+
|
108 |
+
Provide `driving audio` and `driving style` (Optionly):
|
109 |
+
```bash
|
110 |
+
python inference/mimictalk_infer.py \
|
111 |
+
--drv_aud <PATH_TO_AUDIO> \
|
112 |
+
--drv_style <PATH_TO_STYLE_VIDEO, OPTIONAL> \
|
113 |
+
--drv_pose <PATH_TO_POSE_VIDEO, OPTIONAL> \
|
114 |
+
--bg_img <PATH_TO_BACKGROUND_IMAGE, OPTIONAL> \
|
115 |
+
--out_name <PATH_TO_OUTPUT_VIDEO, OPTIONAL>
|
116 |
+
```
|
117 |
+
|
118 |
+
Some inference optional parameters:
|
119 |
+
- `--drv_pose` provide motion pose information, default to be static poses
|
120 |
+
- `--bg_img` provide background information, default to be image extracted from source
|
121 |
+
- `--map_to_init_pose` when set to `True`, the initial pose will be mapped to source pose, and other poses will be equally transformed
|
122 |
+
- `--temperature` stands for the sampling temperature of audio2motion, higher for more diverse results at the expense of lower accuracy
|
123 |
+
- `--out_name` When not assigned, the results will be stored at `infer_out/tmp/`.
|
124 |
+
- `--out_mode` When `final`, only outputs the final result; when `concat_debug`, also outputs visualization of several intermediate process.
|
125 |
+
|
126 |
+
Commandline example:
|
127 |
+
```bash
|
128 |
+
python inference/mimictalk_infer.py \
|
129 |
+
--drv_aud data/raw/examples/Obama_5s.wav \
|
130 |
+
--drv_pose data/raw/examples/German_20s.mp4 \
|
131 |
+
--drv_style data/raw/examples/German_20s.mp4 \
|
132 |
+
--bg_img data/raw/examples/bg.png \
|
133 |
+
--out_name output.mp4 \
|
134 |
+
--out_mode final
|
135 |
+
```
|
136 |
+
|
137 |
+
# Disclaimer
|
138 |
+
Any organization or individual is prohibited from using any technology mentioned in this paper to generate someone's talking video without his/her consent, including but not limited to government leaders, political figures, and celebrities. If you do not comply with this item, you could be in violation of copyright laws.
|
139 |
+
|
140 |
+
# Citation
|
141 |
+
If you found this repo helpful to your work, please consider cite us:
|
142 |
+
```
|
143 |
+
@inproceedings{ye2024mimicktalk,
|
144 |
+
author = {Ye, Zhenhui and Zhong, Tianyun and Ren, Yi and Yang, Jiaqi and Li, Weichuang and Huang, Jiangwei and Jiang, Ziyue and He, Jinzheng and Huang, Rongjie and Liu, Jinglin and Zhang, Chen and Yin, Xiang and Ma, Zejun and Zhao, Zhou},
|
145 |
+
title = {MimicTalk: Mimicking a personalized and expressive 3D talking face in few minutes},
|
146 |
+
journal = {NeurIPS},
|
147 |
+
year = {2024},
|
148 |
+
}
|
149 |
+
@inproceedings{ye2024real3d,
|
150 |
+
title = {Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis},
|
151 |
+
author = {Ye, Zhenhui and Zhong, Tianyun and Ren, Yi and Yang, Jiaqi and Li, Weichuang and Huang, Jiawei and Jiang, Ziyue and He, Jinzheng and Huang, Rongjie and Liu, Jinglin and others},
|
152 |
+
journal = {ICLR},
|
153 |
+
year={2024}
|
154 |
+
}
|
155 |
+
```
|
checkpoints/.gitkeep
ADDED
File without changes
|
data_gen/eg3d/convert_to_eg3d_convention.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import copy
|
4 |
+
from utils.commons.tensor_utils import convert_to_tensor, convert_to_np
|
5 |
+
from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
|
6 |
+
|
7 |
+
|
8 |
+
def _fix_intrinsics(intrinsics):
|
9 |
+
"""
|
10 |
+
intrinsics: [3,3], not batch-wise
|
11 |
+
"""
|
12 |
+
# unnormalized normalized
|
13 |
+
|
14 |
+
# [[ f_x, s=0, x_0] [[ f_x/size_x, s=0, x_0/size_x=0.5]
|
15 |
+
# [ 0, f_y, y_0] -> [ 0, f_y/size_y, y_0/size_y=0.5]
|
16 |
+
# [ 0, 0, 1 ]] [ 0, 0, 1 ]]
|
17 |
+
intrinsics = np.array(intrinsics).copy()
|
18 |
+
assert intrinsics.shape == (3, 3), intrinsics
|
19 |
+
intrinsics[0,0] = 2985.29/700
|
20 |
+
intrinsics[1,1] = 2985.29/700
|
21 |
+
intrinsics[0,2] = 1/2
|
22 |
+
intrinsics[1,2] = 1/2
|
23 |
+
assert intrinsics[0,1] == 0
|
24 |
+
assert intrinsics[2,2] == 1
|
25 |
+
assert intrinsics[1,0] == 0
|
26 |
+
assert intrinsics[2,0] == 0
|
27 |
+
assert intrinsics[2,1] == 0
|
28 |
+
return intrinsics
|
29 |
+
|
30 |
+
# Used in original submission
|
31 |
+
def _fix_pose_orig(pose):
|
32 |
+
"""
|
33 |
+
pose: [4,4], not batch-wise
|
34 |
+
"""
|
35 |
+
pose = np.array(pose).copy()
|
36 |
+
location = pose[:3, 3]
|
37 |
+
radius = np.linalg.norm(location)
|
38 |
+
pose[:3, 3] = pose[:3, 3]/radius * 2.7
|
39 |
+
return pose
|
40 |
+
|
41 |
+
|
42 |
+
def get_eg3d_convention_camera_pose_intrinsic(item):
|
43 |
+
"""
|
44 |
+
item: a dict during binarize
|
45 |
+
|
46 |
+
"""
|
47 |
+
if item['euler'].ndim == 1:
|
48 |
+
angle = convert_to_tensor(copy.copy(item['euler']))
|
49 |
+
trans = copy.deepcopy(item['trans'])
|
50 |
+
|
51 |
+
# handle the difference of euler axis between eg3d and ours
|
52 |
+
# see data_gen/process_ffhq_for_eg3d/transplant_eg3d_ckpt_into_our_convention.ipynb
|
53 |
+
# angle += torch.tensor([0, 3.1415926535, 3.1415926535], device=angle.device)
|
54 |
+
R = ParametricFaceModel.compute_rotation(angle.unsqueeze(0))[0].cpu().numpy()
|
55 |
+
trans[2] += -10
|
56 |
+
c = -np.dot(R, trans)
|
57 |
+
pose = np.eye(4)
|
58 |
+
pose[:3,:3] = R
|
59 |
+
c *= 0.27 # normalize camera radius
|
60 |
+
c[1] += 0.006 # additional offset used in submission
|
61 |
+
c[2] += 0.161 # additional offset used in submission
|
62 |
+
pose[0,3] = c[0]
|
63 |
+
pose[1,3] = c[1]
|
64 |
+
pose[2,3] = c[2]
|
65 |
+
|
66 |
+
focal = 2985.29 # = 1015*1024/224*(300/466.285),
|
67 |
+
# todo: 如果修改了fit 3dmm阶段的camera intrinsic,这里也要跟着改
|
68 |
+
pp = 512#112
|
69 |
+
w = 1024#224
|
70 |
+
h = 1024#224
|
71 |
+
|
72 |
+
K = np.eye(3)
|
73 |
+
K[0][0] = focal
|
74 |
+
K[1][1] = focal
|
75 |
+
K[0][2] = w/2.0
|
76 |
+
K[1][2] = h/2.0
|
77 |
+
convention_K = _fix_intrinsics(K)
|
78 |
+
|
79 |
+
Rot = np.eye(3)
|
80 |
+
Rot[0, 0] = 1
|
81 |
+
Rot[1, 1] = -1
|
82 |
+
Rot[2, 2] = -1
|
83 |
+
pose[:3, :3] = np.dot(pose[:3, :3], Rot) # permute axes
|
84 |
+
convention_pose = _fix_pose_orig(pose)
|
85 |
+
|
86 |
+
item['c2w'] = pose
|
87 |
+
item['convention_c2w'] = convention_pose
|
88 |
+
item['intrinsics'] = convention_K
|
89 |
+
return item
|
90 |
+
else:
|
91 |
+
num_samples = len(item['euler'])
|
92 |
+
eulers_all = convert_to_tensor(copy.deepcopy(item['euler'])) # [B, 3]
|
93 |
+
trans_all = copy.deepcopy(item['trans']) # [B, 3]
|
94 |
+
|
95 |
+
# handle the difference of euler axis between eg3d and ours
|
96 |
+
# see data_gen/process_ffhq_for_eg3d/transplant_eg3d_ckpt_into_our_convention.ipynb
|
97 |
+
# eulers_all += torch.tensor([0, 3.1415926535, 3.1415926535], device=eulers_all.device).unsqueeze(0).repeat([eulers_all.shape[0],1])
|
98 |
+
|
99 |
+
intrinsics = []
|
100 |
+
poses = []
|
101 |
+
convention_poses = []
|
102 |
+
for i in range(num_samples):
|
103 |
+
angle = eulers_all[i]
|
104 |
+
trans = trans_all[i]
|
105 |
+
R = ParametricFaceModel.compute_rotation(angle.unsqueeze(0))[0].cpu().numpy()
|
106 |
+
trans[2] += -10
|
107 |
+
c = -np.dot(R, trans)
|
108 |
+
pose = np.eye(4)
|
109 |
+
pose[:3,:3] = R
|
110 |
+
c *= 0.27 # normalize camera radius
|
111 |
+
c[1] += 0.006 # additional offset used in submission
|
112 |
+
c[2] += 0.161 # additional offset used in submission
|
113 |
+
pose[0,3] = c[0]
|
114 |
+
pose[1,3] = c[1]
|
115 |
+
pose[2,3] = c[2]
|
116 |
+
|
117 |
+
focal = 2985.29 # = 1015*1024/224*(300/466.285),
|
118 |
+
# todo: 如果修改了fit 3dmm阶段的camera intrinsic,这里也要跟着改
|
119 |
+
pp = 512#112
|
120 |
+
w = 1024#224
|
121 |
+
h = 1024#224
|
122 |
+
|
123 |
+
K = np.eye(3)
|
124 |
+
K[0][0] = focal
|
125 |
+
K[1][1] = focal
|
126 |
+
K[0][2] = w/2.0
|
127 |
+
K[1][2] = h/2.0
|
128 |
+
convention_K = _fix_intrinsics(K)
|
129 |
+
intrinsics.append(convention_K)
|
130 |
+
|
131 |
+
Rot = np.eye(3)
|
132 |
+
Rot[0, 0] = 1
|
133 |
+
Rot[1, 1] = -1
|
134 |
+
Rot[2, 2] = -1
|
135 |
+
pose[:3, :3] = np.dot(pose[:3, :3], Rot)
|
136 |
+
convention_pose = _fix_pose_orig(pose)
|
137 |
+
convention_poses.append(convention_pose)
|
138 |
+
poses.append(pose)
|
139 |
+
|
140 |
+
intrinsics = np.stack(intrinsics) # [B, 3, 3]
|
141 |
+
poses = np.stack(poses) # [B, 4, 4]
|
142 |
+
convention_poses = np.stack(convention_poses) # [B, 4, 4]
|
143 |
+
item['intrinsics'] = intrinsics
|
144 |
+
item['c2w'] = poses
|
145 |
+
item['convention_c2w'] = convention_poses
|
146 |
+
return item
|
data_gen/runs/binarizer_nerf.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import math
|
4 |
+
import json
|
5 |
+
import imageio
|
6 |
+
import torch
|
7 |
+
import tqdm
|
8 |
+
import cv2
|
9 |
+
|
10 |
+
from data_util.face3d_helper import Face3DHelper
|
11 |
+
from utils.commons.euler2rot import euler_trans_2_c2w, c2w_to_euler_trans
|
12 |
+
from data_gen.utils.process_video.euler2quaterion import euler2quaterion, quaterion2euler
|
13 |
+
from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
|
14 |
+
|
15 |
+
|
16 |
+
def euler2rot(euler_angle):
|
17 |
+
batch_size = euler_angle.shape[0]
|
18 |
+
theta = euler_angle[:, 0].reshape(-1, 1, 1)
|
19 |
+
phi = euler_angle[:, 1].reshape(-1, 1, 1)
|
20 |
+
psi = euler_angle[:, 2].reshape(-1, 1, 1)
|
21 |
+
one = torch.ones(batch_size, 1, 1).to(euler_angle.device)
|
22 |
+
zero = torch.zeros(batch_size, 1, 1).to(euler_angle.device)
|
23 |
+
rot_x = torch.cat((
|
24 |
+
torch.cat((one, zero, zero), 1),
|
25 |
+
torch.cat((zero, theta.cos(), theta.sin()), 1),
|
26 |
+
torch.cat((zero, -theta.sin(), theta.cos()), 1),
|
27 |
+
), 2)
|
28 |
+
rot_y = torch.cat((
|
29 |
+
torch.cat((phi.cos(), zero, -phi.sin()), 1),
|
30 |
+
torch.cat((zero, one, zero), 1),
|
31 |
+
torch.cat((phi.sin(), zero, phi.cos()), 1),
|
32 |
+
), 2)
|
33 |
+
rot_z = torch.cat((
|
34 |
+
torch.cat((psi.cos(), -psi.sin(), zero), 1),
|
35 |
+
torch.cat((psi.sin(), psi.cos(), zero), 1),
|
36 |
+
torch.cat((zero, zero, one), 1)
|
37 |
+
), 2)
|
38 |
+
return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))
|
39 |
+
|
40 |
+
|
41 |
+
def rot2euler(rot_mat):
|
42 |
+
batch_size = len(rot_mat)
|
43 |
+
# we assert that y in in [-0.5pi, 0.5pi]
|
44 |
+
cos_y = torch.sqrt(rot_mat[:, 1, 2] * rot_mat[:, 1, 2] + rot_mat[:, 2, 2] * rot_mat[:, 2, 2])
|
45 |
+
theta_x = torch.atan2(-rot_mat[:, 1, 2], rot_mat[:, 2, 2])
|
46 |
+
theta_y = torch.atan2(rot_mat[:, 2, 0], cos_y)
|
47 |
+
theta_z = torch.atan2(rot_mat[:, 0, 1], rot_mat[:, 0, 0])
|
48 |
+
euler_angles = torch.zeros([batch_size, 3])
|
49 |
+
euler_angles[:, 0] = theta_x
|
50 |
+
euler_angles[:, 1] = theta_y
|
51 |
+
euler_angles[:, 2] = theta_z
|
52 |
+
return euler_angles
|
53 |
+
|
54 |
+
index_lm68_from_lm468 = [127,234,93,132,58,136,150,176,152,400,379,365,288,361,323,454,356,70,63,105,66,107,336,296,334,293,300,168,197,5,4,75,97,2,326,305,
|
55 |
+
33,160,158,133,153,144,362,385,387,263,373,380,61,40,37,0,267,270,291,321,314,17,84,91,78,81,13,311,308,402,14,178]
|
56 |
+
|
57 |
+
def plot_lm2d(lm2d):
|
58 |
+
WH = 512
|
59 |
+
img = np.ones([WH, WH, 3], dtype=np.uint8) * 255
|
60 |
+
|
61 |
+
for i in range(len(lm2d)):
|
62 |
+
x, y = lm2d[i]
|
63 |
+
color = (255,0,0)
|
64 |
+
img = cv2.circle(img, center=(int(x),int(y)), radius=3, color=color, thickness=-1)
|
65 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
66 |
+
for i in range(len(lm2d)):
|
67 |
+
x, y = lm2d[i]
|
68 |
+
img = cv2.putText(img, f"{i}", org=(int(x),int(y)), fontFace=font, fontScale=0.3, color=(255,0,0))
|
69 |
+
return img
|
70 |
+
|
71 |
+
def get_face_rect(lms, h, w):
|
72 |
+
"""
|
73 |
+
lms: [68, 2]
|
74 |
+
h, w: int
|
75 |
+
return: [4,]
|
76 |
+
"""
|
77 |
+
assert len(lms) == 68
|
78 |
+
# min_x, max_x = np.min(lms, 0)[0], np.max(lms, 0)[0]
|
79 |
+
min_x, max_x = np.min(lms[:, 0]), np.max(lms[:, 0])
|
80 |
+
cx = int((min_x+max_x)/2.0)
|
81 |
+
cy = int(lms[27, 1])
|
82 |
+
h_w = int((max_x-cx)*1.5)
|
83 |
+
h_h = int((lms[8, 1]-cy)*1.15)
|
84 |
+
rect_x = cx - h_w
|
85 |
+
rect_y = cy - h_h
|
86 |
+
if rect_x < 0:
|
87 |
+
rect_x = 0
|
88 |
+
if rect_y < 0:
|
89 |
+
rect_y = 0
|
90 |
+
rect_w = min(w-1-rect_x, 2*h_w)
|
91 |
+
rect_h = min(h-1-rect_y, 2*h_h)
|
92 |
+
# rect = np.array((rect_x, rect_y, rect_w, rect_h), dtype=np.int32)
|
93 |
+
# rect = [rect_x, rect_y, rect_w, rect_h]
|
94 |
+
rect = [rect_x, rect_x + rect_w, rect_y, rect_y + rect_h] # min_j, max_j, min_i, max_i
|
95 |
+
return rect # this x is width, y is height
|
96 |
+
|
97 |
+
def get_lip_rect(lms, h, w):
|
98 |
+
"""
|
99 |
+
lms: [68, 2]
|
100 |
+
h, w: int
|
101 |
+
return: [4,]
|
102 |
+
"""
|
103 |
+
# this x is width, y is height
|
104 |
+
# for lms, lms[:, 0] is width, lms[:, 1] is height
|
105 |
+
assert len(lms) == 68
|
106 |
+
lips = slice(48, 60)
|
107 |
+
lms = lms[lips]
|
108 |
+
min_x, max_x = np.min(lms[:, 0]), np.max(lms[:, 0])
|
109 |
+
min_y, max_y = np.min(lms[:, 1]), np.max(lms[:, 1])
|
110 |
+
cx = int((min_x+max_x)/2.0)
|
111 |
+
cy = int((min_y+max_y)/2.0)
|
112 |
+
h_w = int((max_x-cx)*1.2)
|
113 |
+
h_h = int((max_y-cy)*1.2)
|
114 |
+
|
115 |
+
h_w = max(h_w, h_h)
|
116 |
+
h_h = h_w
|
117 |
+
|
118 |
+
rect_x = cx - h_w
|
119 |
+
rect_y = cy - h_h
|
120 |
+
rect_w = 2*h_w
|
121 |
+
rect_h = 2*h_h
|
122 |
+
if rect_x < 0:
|
123 |
+
rect_x = 0
|
124 |
+
if rect_y < 0:
|
125 |
+
rect_y = 0
|
126 |
+
|
127 |
+
if rect_x + rect_w > w:
|
128 |
+
rect_x = w - rect_w
|
129 |
+
if rect_y + rect_h > h:
|
130 |
+
rect_y = h - rect_h
|
131 |
+
|
132 |
+
rect = [rect_x, rect_x + rect_w, rect_y, rect_y + rect_h] # min_j, max_j, min_i, max_i
|
133 |
+
return rect # this x is width, y is height
|
134 |
+
|
135 |
+
|
136 |
+
# def get_lip_rect(lms, h, w):
|
137 |
+
# """
|
138 |
+
# lms: [68, 2]
|
139 |
+
# h, w: int
|
140 |
+
# return: [4,]
|
141 |
+
# """
|
142 |
+
# assert len(lms) == 68
|
143 |
+
# lips = slice(48, 60)
|
144 |
+
# # this x is width, y is height
|
145 |
+
# xmin, xmax = int(lms[lips, 1].min()), int(lms[lips, 1].max())
|
146 |
+
# ymin, ymax = int(lms[lips, 0].min()), int(lms[lips, 0].max())
|
147 |
+
# # padding to H == W
|
148 |
+
# cx = (xmin + xmax) // 2
|
149 |
+
# cy = (ymin + ymax) // 2
|
150 |
+
# l = max(xmax - xmin, ymax - ymin) // 2
|
151 |
+
# xmin = max(0, cx - l)
|
152 |
+
# xmax = min(h, cx + l)
|
153 |
+
# ymin = max(0, cy - l)
|
154 |
+
# ymax = min(w, cy + l)
|
155 |
+
# lip_rect = [xmin, xmax, ymin, ymax]
|
156 |
+
# return lip_rect
|
157 |
+
|
158 |
+
def get_win_conds(conds, idx, smo_win_size=8, pad_option='zero'):
|
159 |
+
"""
|
160 |
+
conds: [b, t=16, h=29]
|
161 |
+
idx: long, time index of the selected frame
|
162 |
+
"""
|
163 |
+
idx = max(0, idx)
|
164 |
+
idx = min(idx, conds.shape[0]-1)
|
165 |
+
smo_half_win_size = smo_win_size//2
|
166 |
+
left_i = idx - smo_half_win_size
|
167 |
+
right_i = idx + (smo_win_size - smo_half_win_size)
|
168 |
+
pad_left, pad_right = 0, 0
|
169 |
+
if left_i < 0:
|
170 |
+
pad_left = -left_i
|
171 |
+
left_i = 0
|
172 |
+
if right_i > conds.shape[0]:
|
173 |
+
pad_right = right_i - conds.shape[0]
|
174 |
+
right_i = conds.shape[0]
|
175 |
+
conds_win = conds[left_i:right_i]
|
176 |
+
if pad_left > 0:
|
177 |
+
if pad_option == 'zero':
|
178 |
+
conds_win = np.concatenate([np.zeros_like(conds_win)[:pad_left], conds_win], axis=0)
|
179 |
+
elif pad_option == 'edge':
|
180 |
+
edge_value = conds[0][np.newaxis, ...]
|
181 |
+
conds_win = np.concatenate([edge_value] * pad_left + [conds_win], axis=0)
|
182 |
+
else:
|
183 |
+
raise NotImplementedError
|
184 |
+
if pad_right > 0:
|
185 |
+
if pad_option == 'zero':
|
186 |
+
conds_win = np.concatenate([conds_win, np.zeros_like(conds_win)[:pad_right]], axis=0)
|
187 |
+
elif pad_option == 'edge':
|
188 |
+
edge_value = conds[-1][np.newaxis, ...]
|
189 |
+
conds_win = np.concatenate([conds_win] + [edge_value] * pad_right , axis=0)
|
190 |
+
else:
|
191 |
+
raise NotImplementedError
|
192 |
+
assert conds_win.shape[0] == smo_win_size
|
193 |
+
return conds_win
|
194 |
+
|
195 |
+
|
196 |
+
def load_processed_data(processed_dir):
|
197 |
+
# load necessary files
|
198 |
+
background_img_name = os.path.join(processed_dir, "bg.jpg")
|
199 |
+
assert os.path.exists(background_img_name)
|
200 |
+
head_img_dir = os.path.join(processed_dir, "head_imgs")
|
201 |
+
torso_img_dir = os.path.join(processed_dir, "inpaint_torso_imgs")
|
202 |
+
gt_img_dir = os.path.join(processed_dir, "gt_imgs")
|
203 |
+
|
204 |
+
hubert_npy_name = os.path.join(processed_dir, "aud_hubert.npy")
|
205 |
+
mel_f0_npy_name = os.path.join(processed_dir, "aud_mel_f0.npy")
|
206 |
+
coeff_npy_name = os.path.join(processed_dir, "coeff_fit_mp.npy")
|
207 |
+
lm2d_npy_name = os.path.join(processed_dir, "lms_2d.npy")
|
208 |
+
|
209 |
+
ret_dict = {}
|
210 |
+
|
211 |
+
ret_dict['bg_img'] = imageio.imread(background_img_name)
|
212 |
+
ret_dict['H'], ret_dict['W'] = ret_dict['bg_img'].shape[:2]
|
213 |
+
ret_dict['focal'], ret_dict['cx'], ret_dict['cy'] = face_model.focal, face_model.center, face_model.center
|
214 |
+
|
215 |
+
print("loading lm2d coeff ...")
|
216 |
+
lm2d_arr = np.load(lm2d_npy_name)
|
217 |
+
face_rect_lst = []
|
218 |
+
lip_rect_lst = []
|
219 |
+
for lm2d in lm2d_arr:
|
220 |
+
if len(lm2d) in [468, 478]:
|
221 |
+
lm2d = lm2d[index_lm68_from_lm468]
|
222 |
+
face_rect = get_face_rect(lm2d, ret_dict['H'], ret_dict['W'])
|
223 |
+
lip_rect = get_lip_rect(lm2d, ret_dict['H'], ret_dict['W'])
|
224 |
+
face_rect_lst.append(face_rect)
|
225 |
+
lip_rect_lst.append(lip_rect)
|
226 |
+
face_rects = np.stack(face_rect_lst, axis=0) # [T, 4]
|
227 |
+
|
228 |
+
print("loading fitted 3dmm coeff ...")
|
229 |
+
coeff_dict = np.load(coeff_npy_name, allow_pickle=True).tolist()
|
230 |
+
identity_arr = coeff_dict['id']
|
231 |
+
exp_arr = coeff_dict['exp']
|
232 |
+
ret_dict['id'] = identity_arr
|
233 |
+
ret_dict['exp'] = exp_arr
|
234 |
+
euler_arr = ret_dict['euler'] = coeff_dict['euler']
|
235 |
+
trans_arr = ret_dict['trans'] = coeff_dict['trans']
|
236 |
+
print("calculating lm3d ...")
|
237 |
+
idexp_lm3d_arr = face3d_helper.reconstruct_idexp_lm3d(torch.from_numpy(identity_arr), torch.from_numpy(exp_arr)).cpu().numpy().reshape([-1, 68*3])
|
238 |
+
len_motion = len(idexp_lm3d_arr)
|
239 |
+
video_idexp_lm3d_mean = idexp_lm3d_arr.mean(axis=0)
|
240 |
+
video_idexp_lm3d_std = idexp_lm3d_arr.std(axis=0)
|
241 |
+
ret_dict['idexp_lm3d'] = idexp_lm3d_arr
|
242 |
+
ret_dict['idexp_lm3d_mean'] = video_idexp_lm3d_mean
|
243 |
+
ret_dict['idexp_lm3d_std'] = video_idexp_lm3d_std
|
244 |
+
|
245 |
+
# now we convert the euler_trans from deep3d convention to adnerf convention
|
246 |
+
eulers = torch.FloatTensor(euler_arr)
|
247 |
+
trans = torch.FloatTensor(trans_arr)
|
248 |
+
rots = face_model.compute_rotation(eulers) # rotation matrix is a better intermediate for convention-transplan than euler
|
249 |
+
|
250 |
+
# handle the camera pose to geneface's convention
|
251 |
+
trans[:, 2] = 10 - trans[:, 2] # 抵消fit阶段的to_camera操作,即trans[...,2] = 10 - trans[...,2]
|
252 |
+
rots = rots.permute(0, 2, 1)
|
253 |
+
trans[:, 2] = - trans[:,2] # 因为intrinsic proj不同
|
254 |
+
# below is the NeRF camera preprocessing strategy, see `save_transforms` in data_util/process.py
|
255 |
+
trans = trans / 10.0
|
256 |
+
rots_inv = rots.permute(0, 2, 1)
|
257 |
+
trans_inv = - torch.bmm(rots_inv, trans.unsqueeze(2))
|
258 |
+
|
259 |
+
pose = torch.eye(4, dtype=torch.float32).unsqueeze(0).repeat([len_motion, 1, 1]) # [T, 4, 4]
|
260 |
+
pose[:, :3, :3] = rots_inv
|
261 |
+
pose[:, :3, 3] = trans_inv[:, :, 0]
|
262 |
+
c2w_transform_matrices = pose.numpy()
|
263 |
+
|
264 |
+
# process the audio features used for postnet training
|
265 |
+
print("loading hubert ...")
|
266 |
+
hubert_features = np.load(hubert_npy_name)
|
267 |
+
print("loading Mel and F0 ...")
|
268 |
+
mel_f0_features = np.load(mel_f0_npy_name, allow_pickle=True).tolist()
|
269 |
+
|
270 |
+
ret_dict['hubert'] = hubert_features
|
271 |
+
ret_dict['mel'] = mel_f0_features['mel']
|
272 |
+
ret_dict['f0'] = mel_f0_features['f0']
|
273 |
+
|
274 |
+
# obtaining train samples
|
275 |
+
frame_indices = list(range(len_motion))
|
276 |
+
num_train = len_motion // 11 * 10
|
277 |
+
train_indices = frame_indices[:num_train]
|
278 |
+
val_indices = frame_indices[num_train:]
|
279 |
+
|
280 |
+
for split in ['train', 'val']:
|
281 |
+
if split == 'train':
|
282 |
+
indices = train_indices
|
283 |
+
samples = []
|
284 |
+
ret_dict['train_samples'] = samples
|
285 |
+
elif split == 'val':
|
286 |
+
indices = val_indices
|
287 |
+
samples = []
|
288 |
+
ret_dict['val_samples'] = samples
|
289 |
+
|
290 |
+
for idx in indices:
|
291 |
+
sample = {}
|
292 |
+
sample['idx'] = idx
|
293 |
+
sample['head_img_fname'] = os.path.join(head_img_dir,f"{idx:08d}.png")
|
294 |
+
sample['torso_img_fname'] = os.path.join(torso_img_dir,f"{idx:08d}.png")
|
295 |
+
sample['gt_img_fname'] = os.path.join(gt_img_dir,f"{idx:08d}.jpg")
|
296 |
+
# assert os.path.exists(sample['head_img_fname']) and os.path.exists(sample['torso_img_fname']) and os.path.exists(sample['gt_img_fname'])
|
297 |
+
sample['face_rect'] = face_rects[idx]
|
298 |
+
sample['lip_rect'] = lip_rect_lst[idx]
|
299 |
+
sample['c2w'] = c2w_transform_matrices[idx]
|
300 |
+
samples.append(sample)
|
301 |
+
return ret_dict
|
302 |
+
|
303 |
+
|
304 |
+
class Binarizer:
|
305 |
+
def __init__(self):
|
306 |
+
self.data_dir = 'data/'
|
307 |
+
|
308 |
+
def parse(self, video_id):
|
309 |
+
processed_dir = os.path.join(self.data_dir, 'processed/videos', video_id)
|
310 |
+
binary_dir = os.path.join(self.data_dir, 'binary/videos', video_id)
|
311 |
+
out_fname = os.path.join(binary_dir, "trainval_dataset.npy")
|
312 |
+
os.makedirs(binary_dir, exist_ok=True)
|
313 |
+
ret = load_processed_data(processed_dir)
|
314 |
+
mel_name = os.path.join(processed_dir, 'aud_mel_f0.npy')
|
315 |
+
mel_f0_dict = np.load(mel_name, allow_pickle=True).tolist()
|
316 |
+
ret.update(mel_f0_dict)
|
317 |
+
np.save(out_fname, ret, allow_pickle=True)
|
318 |
+
|
319 |
+
|
320 |
+
|
321 |
+
if __name__ == '__main__':
|
322 |
+
from argparse import ArgumentParser
|
323 |
+
parser = ArgumentParser()
|
324 |
+
parser.add_argument('--video_id', type=str, default='May', help='')
|
325 |
+
args = parser.parse_args()
|
326 |
+
### Process Single Long Audio for NeRF dataset
|
327 |
+
video_id = args.video_id
|
328 |
+
face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
|
329 |
+
camera_distance=10, focal=1015)
|
330 |
+
face_model.to("cpu")
|
331 |
+
face3d_helper = Face3DHelper()
|
332 |
+
|
333 |
+
binarizer = Binarizer()
|
334 |
+
binarizer.parse(video_id)
|
335 |
+
print(f"Binarization for {video_id} Done!")
|
data_gen/runs/binarizer_th1kh.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
from scipy.misc import face
|
4 |
+
import torch
|
5 |
+
from tqdm import trange
|
6 |
+
import pickle
|
7 |
+
from copy import deepcopy
|
8 |
+
|
9 |
+
from data_util.face3d_helper import Face3DHelper
|
10 |
+
from utils.commons.indexed_datasets import IndexedDataset, IndexedDatasetBuilder
|
11 |
+
|
12 |
+
|
13 |
+
def load_video_npy(fn):
|
14 |
+
assert fn.endswith("_coeff_fit_mp.npy")
|
15 |
+
ret_dict = np.load(fn,allow_pickle=True).item()
|
16 |
+
video_dict = {
|
17 |
+
'euler': ret_dict['euler'], # [T, 3]
|
18 |
+
'trans': ret_dict['trans'], # [T, 3]
|
19 |
+
'id': ret_dict['id'], # [T, 80]
|
20 |
+
'exp': ret_dict['exp'], # [T, 64]
|
21 |
+
}
|
22 |
+
return video_dict
|
23 |
+
|
24 |
+
def cal_lm3d_in_video_dict(video_dict, face3d_helper):
|
25 |
+
identity = video_dict['id']
|
26 |
+
exp = video_dict['exp']
|
27 |
+
idexp_lm3d = face3d_helper.reconstruct_idexp_lm3d(identity, exp).cpu().numpy()
|
28 |
+
video_dict['idexp_lm3d'] = idexp_lm3d
|
29 |
+
|
30 |
+
|
31 |
+
def load_audio_npy(fn):
|
32 |
+
assert fn.endswith(".npy")
|
33 |
+
ret_dict = np.load(fn,allow_pickle=True).item()
|
34 |
+
audio_dict = {
|
35 |
+
"mel": ret_dict['mel'], # [T, 80]
|
36 |
+
"f0": ret_dict['f0'], # [T,1]
|
37 |
+
}
|
38 |
+
return audio_dict
|
39 |
+
|
40 |
+
|
41 |
+
if __name__ == '__main__':
|
42 |
+
face3d_helper = Face3DHelper(use_gpu=False)
|
43 |
+
|
44 |
+
import glob,tqdm
|
45 |
+
prefixs = ['val', 'train']
|
46 |
+
binarized_ds_path = "data/binary/th1kh"
|
47 |
+
os.makedirs(binarized_ds_path, exist_ok=True)
|
48 |
+
for prefix in prefixs:
|
49 |
+
databuilder = IndexedDatasetBuilder(os.path.join(binarized_ds_path, prefix), gzip=False, default_idx_size=1024*1024*1024*2)
|
50 |
+
raw_base_dir = '/mnt/bn/ailabrenyi/entries/yezhenhui/datasets/raw/TH1KH_512/video'
|
51 |
+
mp4_names = glob.glob(os.path.join(raw_base_dir, '*.mp4'))
|
52 |
+
mp4_names = mp4_names[:1000]
|
53 |
+
cnt = 0
|
54 |
+
scnt = 0
|
55 |
+
pbar = tqdm.tqdm(enumerate(mp4_names), total=len(mp4_names))
|
56 |
+
for i, mp4_name in pbar:
|
57 |
+
cnt += 1
|
58 |
+
if prefix == 'train':
|
59 |
+
if i % 100 == 0:
|
60 |
+
continue
|
61 |
+
else:
|
62 |
+
if i % 100 != 0:
|
63 |
+
continue
|
64 |
+
hubert_npy_name = mp4_name.replace("/video/", "/hubert/").replace(".mp4", "_hubert.npy")
|
65 |
+
audio_npy_name = mp4_name.replace("/video/", "/mel_f0/").replace(".mp4", "_mel_f0.npy")
|
66 |
+
video_npy_name = mp4_name.replace("/video/", "/coeff_fit_mp/").replace(".mp4", "_coeff_fit_mp.npy")
|
67 |
+
if not os.path.exists(audio_npy_name):
|
68 |
+
print(f"Skip item for audio npy not found.")
|
69 |
+
continue
|
70 |
+
if not os.path.exists(video_npy_name):
|
71 |
+
print(f"Skip item for video npy not found.")
|
72 |
+
continue
|
73 |
+
if (not os.path.exists(hubert_npy_name)):
|
74 |
+
print(f"Skip item for hubert_npy not found.")
|
75 |
+
continue
|
76 |
+
audio_dict = load_audio_npy(audio_npy_name)
|
77 |
+
hubert = np.load(hubert_npy_name)
|
78 |
+
video_dict = load_video_npy(video_npy_name)
|
79 |
+
com_img_dir = mp4_name.replace("/video/", "/com_imgs/").replace(".mp4", "")
|
80 |
+
num_com_imgs = len(glob.glob(os.path.join(com_img_dir, '*')))
|
81 |
+
num_frames = len(video_dict['exp'])
|
82 |
+
if num_com_imgs != num_frames:
|
83 |
+
print(f"Skip item for length mismatch.")
|
84 |
+
continue
|
85 |
+
mel = audio_dict['mel']
|
86 |
+
if mel.shape[0] < 32: # the video is shorter than 0.6s
|
87 |
+
print(f"Skip item for too short.")
|
88 |
+
continue
|
89 |
+
|
90 |
+
audio_dict.update(video_dict)
|
91 |
+
audio_dict['item_id'] = os.path.basename(mp4_name)[:-4]
|
92 |
+
audio_dict['hubert'] = hubert # [T_x, hid=1024]
|
93 |
+
audio_dict['img_dir'] = com_img_dir
|
94 |
+
|
95 |
+
|
96 |
+
databuilder.add_item(audio_dict)
|
97 |
+
scnt += 1
|
98 |
+
pbar.set_postfix({'success': scnt, 'success rate': scnt / cnt})
|
99 |
+
databuilder.finalize()
|
100 |
+
print(f"{prefix} set has {cnt} samples!")
|
data_gen/runs/nerf/process_guide.md
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 温馨提示:第一次执行可以先一步步跑完下面的命令行,把环境跑通后,之后可以直接运行同目录的run.sh,一键完成下面的所有步骤。
|
2 |
+
|
3 |
+
# Step0. 将视频Crop到512x512分辨率,25FPS,确保每一帧都有目标人脸
|
4 |
+
```
|
5 |
+
ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 data/raw/videos/${VIDEO_ID}_512.mp4
|
6 |
+
mv data/raw/videos/${VIDEO_ID}.mp4 data/raw/videos/${VIDEO_ID}_to_rm.mp4
|
7 |
+
mv data/raw/videos/${VIDEO_ID}_512.mp4 data/raw/videos/${VIDEO_ID}.mp4
|
8 |
+
```
|
9 |
+
# step1: 提取音频特征, 如mel, f0, hubuert, esperanto
|
10 |
+
```
|
11 |
+
export CUDA_VISIBLE_DEVICES=0
|
12 |
+
export VIDEO_ID=May
|
13 |
+
mkdir -p data/processed/videos/${VIDEO_ID}
|
14 |
+
ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -f wav -ar 16000 data/processed/videos/${VIDEO_ID}/aud.wav
|
15 |
+
python data_gen/utils/process_audio/extract_hubert.py --video_id=${VIDEO_ID}
|
16 |
+
python data_gen/utils/process_audio/extract_mel_f0.py --video_id=${VIDEO_ID}
|
17 |
+
```
|
18 |
+
|
19 |
+
# Step2. 提取图片
|
20 |
+
```
|
21 |
+
export VIDEO_ID=May
|
22 |
+
export CUDA_VISIBLE_DEVICES=0
|
23 |
+
mkdir -p data/processed/videos/${VIDEO_ID}/gt_imgs
|
24 |
+
ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -start_number 0 data/processed/videos/${VIDEO_ID}/gt_imgs/%08d.jpg
|
25 |
+
python data_gen/utils/process_video/extract_segment_imgs.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 # extract image, segmap, and background
|
26 |
+
```
|
27 |
+
|
28 |
+
# Step3. 提取lm2d_mediapipe
|
29 |
+
### 提取2D landmark用于之后Fit 3DMM
|
30 |
+
### num_workers是本机上的CPU worker数量;total_process是使用的机器数;process_id是本机的编号
|
31 |
+
|
32 |
+
```
|
33 |
+
export VIDEO_ID=May
|
34 |
+
python data_gen/utils/process_video/extract_lm2d.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4
|
35 |
+
```
|
36 |
+
|
37 |
+
# Step3. fit 3dmm
|
38 |
+
```
|
39 |
+
export VIDEO_ID=May
|
40 |
+
export CUDA_VISIBLE_DEVICES=0
|
41 |
+
python data_gen/utils/process_video/fit_3dmm_landmark.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 --reset --debug --id_mode=global
|
42 |
+
```
|
43 |
+
|
44 |
+
# Step4. Binarize
|
45 |
+
```
|
46 |
+
export VIDEO_ID=May
|
47 |
+
python data_gen/runs/binarizer_nerf.py --video_id=${VIDEO_ID}
|
48 |
+
```
|
49 |
+
可以看到在`data/binary/videos/Mayssss`目录下得到了数据集。
|
data_gen/runs/nerf/run.sh
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# usage: CUDA_VISIBLE_DEVICES=0 bash data_gen/runs/nerf/run.sh <VIDEO_ID>
|
2 |
+
# please place video to data/raw/videos/${VIDEO_ID}.mp4
|
3 |
+
VIDEO_ID=$1
|
4 |
+
echo Processing $VIDEO_ID
|
5 |
+
|
6 |
+
echo Resizing the video to 512x512
|
7 |
+
ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -y data/raw/videos/${VIDEO_ID}_512.mp4
|
8 |
+
mv data/raw/videos/${VIDEO_ID}.mp4 data/raw/videos/${VIDEO_ID}_to_rm.mp4
|
9 |
+
mv data/raw/videos/${VIDEO_ID}_512.mp4 data/raw/videos/${VIDEO_ID}.mp4
|
10 |
+
echo Done
|
11 |
+
echo The old video is moved to data/raw/videos/${VIDEO_ID}.mp4 data/raw/videos/${VIDEO_ID}_to_rm.mp4
|
12 |
+
|
13 |
+
echo mkdir -p data/processed/videos/${VIDEO_ID}
|
14 |
+
mkdir -p data/processed/videos/${VIDEO_ID}
|
15 |
+
echo Done
|
16 |
+
|
17 |
+
# extract audio file from the training video
|
18 |
+
echo ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -f wav -ar 16000 -v quiet -y data/processed/videos/${VIDEO_ID}/aud.wav
|
19 |
+
ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -f wav -ar 16000 -v quiet -y data/processed/videos/${VIDEO_ID}/aud.wav
|
20 |
+
echo Done
|
21 |
+
|
22 |
+
# extract hubert_mel_f0 from audio
|
23 |
+
echo python data_gen/utils/process_audio/extract_hubert.py --video_id=${VIDEO_ID}
|
24 |
+
python data_gen/utils/process_audio/extract_hubert.py --video_id=${VIDEO_ID}
|
25 |
+
echo python data_gen/utils/process_audio/extract_mel_f0.py --video_id=${VIDEO_ID}
|
26 |
+
python data_gen/utils/process_audio/extract_mel_f0.py --video_id=${VIDEO_ID}
|
27 |
+
echo Done
|
28 |
+
|
29 |
+
# extract segment images
|
30 |
+
echo mkdir -p data/processed/videos/${VIDEO_ID}/gt_imgs
|
31 |
+
mkdir -p data/processed/videos/${VIDEO_ID}/gt_imgs
|
32 |
+
echo ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -start_number 0 -v quiet data/processed/videos/${VIDEO_ID}/gt_imgs/%08d.jpg
|
33 |
+
ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -start_number 0 -v quiet data/processed/videos/${VIDEO_ID}/gt_imgs/%08d.jpg
|
34 |
+
echo Done
|
35 |
+
|
36 |
+
echo python data_gen/utils/process_video/extract_segment_imgs.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 # extract image, segmap, and background
|
37 |
+
python data_gen/utils/process_video/extract_segment_imgs.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 # extract image, segmap, and background
|
38 |
+
echo Done
|
39 |
+
|
40 |
+
echo python data_gen/utils/process_video/extract_lm2d.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4
|
41 |
+
python data_gen/utils/process_video/extract_lm2d.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4
|
42 |
+
echo Done
|
43 |
+
|
44 |
+
pkill -f void*
|
45 |
+
echo python data_gen/utils/process_video/fit_3dmm_landmark.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 --reset --debug --id_mode=global
|
46 |
+
python data_gen/utils/process_video/fit_3dmm_landmark.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 --reset --debug --id_mode=global
|
47 |
+
echo Done
|
48 |
+
|
49 |
+
echo python data_gen/runs/binarizer_nerf.py --video_id=${VIDEO_ID}
|
50 |
+
python data_gen/runs/binarizer_nerf.py --video_id=${VIDEO_ID}
|
51 |
+
echo Done
|
data_gen/utils/mp_feature_extractors/face_landmarker.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import mediapipe as mp
|
2 |
+
from mediapipe.tasks import python
|
3 |
+
from mediapipe.tasks.python import vision
|
4 |
+
import numpy as np
|
5 |
+
import cv2
|
6 |
+
import os
|
7 |
+
import copy
|
8 |
+
|
9 |
+
# simplified mediapipe ldm at https://github.com/k-m-irfan/simplified_mediapipe_face_landmarks
|
10 |
+
index_lm141_from_lm478 = [70,63,105,66,107,55,65,52,53,46] + [300,293,334,296,336,285,295,282,283,276] + [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249] + [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95] + [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146] + [10,338,297,332,284,251,389,356,454,323,361,288,397,365,379,378,400,377,152,148,176,149,150,136,172,58,132,93,234,127,162,21,54,103,67,109] + [468,469,470,471,472] + [473,474,475,476,477] + [64,4,294]
|
11 |
+
# lm141 without iris
|
12 |
+
index_lm131_from_lm478 = [70,63,105,66,107,55,65,52,53,46] + [300,293,334,296,336,285,295,282,283,276] + [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249] + [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95] + [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146] + [10,338,297,332,284,251,389,356,454,323,361,288,397,365,379,378,400,377,152,148,176,149,150,136,172,58,132,93,234,127,162,21,54,103,67,109] + [64,4,294]
|
13 |
+
|
14 |
+
# face alignment lm68
|
15 |
+
index_lm68_from_lm478 = [127,234,93,132,58,136,150,176,152,400,379,365,288,361,323,454,356,70,63,105,66,107,336,296,334,293,300,168,197,5,4,75,97,2,326,305,
|
16 |
+
33,160,158,133,153,144,362,385,387,263,373,380,61,40,37,0,267,270,291,321,314,17,84,91,78,81,13,311,308,402,14,178]
|
17 |
+
# used for weights for key parts
|
18 |
+
unmatch_mask_from_lm478 = [ 93, 127, 132, 234, 323, 356, 361, 454]
|
19 |
+
index_eye_from_lm478 = [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249]
|
20 |
+
index_innerlip_from_lm478 = [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95]
|
21 |
+
index_outerlip_from_lm478 = [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146]
|
22 |
+
index_withinmouth_from_lm478 = [76, 62] + [184, 183, 74, 72, 73, 41, 72, 38, 11, 12, 302, 268, 303, 271, 304, 272, 408, 407] + [292, 306] + [325, 307, 319, 320, 403, 404, 316, 315, 15, 16, 86, 85, 179, 180, 89, 90, 96, 77]
|
23 |
+
index_mouth_from_lm478 = index_innerlip_from_lm478 + index_outerlip_from_lm478 + index_withinmouth_from_lm478
|
24 |
+
|
25 |
+
index_yaw_from_lm68 = list(range(0, 17))
|
26 |
+
index_brow_from_lm68 = list(range(17, 27))
|
27 |
+
index_nose_from_lm68 = list(range(27, 36))
|
28 |
+
index_eye_from_lm68 = list(range(36, 48))
|
29 |
+
index_mouth_from_lm68 = list(range(48, 68))
|
30 |
+
|
31 |
+
|
32 |
+
def read_video_to_frames(video_name):
|
33 |
+
frames = []
|
34 |
+
cap = cv2.VideoCapture(video_name)
|
35 |
+
while cap.isOpened():
|
36 |
+
ret, frame_bgr = cap.read()
|
37 |
+
if frame_bgr is None:
|
38 |
+
break
|
39 |
+
frames.append(frame_bgr)
|
40 |
+
frames = np.stack(frames)
|
41 |
+
frames = np.flip(frames, -1) # BGR ==> RGB
|
42 |
+
return frames
|
43 |
+
|
44 |
+
class MediapipeLandmarker:
|
45 |
+
def __init__(self):
|
46 |
+
model_path = 'data_gen/utils/mp_feature_extractors/face_landmarker.task'
|
47 |
+
if not os.path.exists(model_path):
|
48 |
+
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
49 |
+
print("downloading face_landmarker model from mediapipe...")
|
50 |
+
model_url = 'https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/latest/face_landmarker.task'
|
51 |
+
os.system(f"wget {model_url}")
|
52 |
+
os.system(f"mv face_landmarker.task {model_path}")
|
53 |
+
print("download success")
|
54 |
+
base_options = python.BaseOptions(model_asset_path=model_path)
|
55 |
+
self.image_mode_options = vision.FaceLandmarkerOptions(base_options=base_options,
|
56 |
+
running_mode=vision.RunningMode.IMAGE, # IMAGE, VIDEO, LIVE_STREAM
|
57 |
+
num_faces=1)
|
58 |
+
self.video_mode_options = vision.FaceLandmarkerOptions(base_options=base_options,
|
59 |
+
running_mode=vision.RunningMode.VIDEO, # IMAGE, VIDEO, LIVE_STREAM
|
60 |
+
num_faces=1)
|
61 |
+
|
62 |
+
def extract_lm478_from_img_name(self, img_name):
|
63 |
+
img = cv2.imread(img_name)
|
64 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
65 |
+
img_lm478 = self.extract_lm478_from_img(img)
|
66 |
+
return img_lm478
|
67 |
+
|
68 |
+
def extract_lm478_from_img(self, img):
|
69 |
+
img_landmarker = vision.FaceLandmarker.create_from_options(self.image_mode_options)
|
70 |
+
frame = mp.Image(image_format=mp.ImageFormat.SRGB, data=img.astype(np.uint8))
|
71 |
+
img_face_landmarker_result = img_landmarker.detect(image=frame)
|
72 |
+
img_ldm_i = img_face_landmarker_result.face_landmarks[0]
|
73 |
+
img_face_landmarks = np.array([[l.x, l.y, l.z] for l in img_ldm_i])
|
74 |
+
H, W, _ = img.shape
|
75 |
+
img_lm478 = np.array(img_face_landmarks)[:, :2] * np.array([W, H]).reshape([1,2]) # [478, 2]
|
76 |
+
return img_lm478
|
77 |
+
|
78 |
+
def extract_lm478_from_video_name(self, video_name, fps=25, anti_smooth_factor=2):
|
79 |
+
frames = read_video_to_frames(video_name)
|
80 |
+
img_lm478, vid_lm478 = self.extract_lm478_from_frames(frames, fps, anti_smooth_factor)
|
81 |
+
return img_lm478, vid_lm478
|
82 |
+
|
83 |
+
def extract_lm478_from_frames(self, frames, fps=25, anti_smooth_factor=20):
|
84 |
+
"""
|
85 |
+
frames: RGB, uint8
|
86 |
+
anti_smooth_factor: float, 对video模式的interval进行修改, 1代表无修改, 越大越接近image mode
|
87 |
+
"""
|
88 |
+
img_mpldms = []
|
89 |
+
vid_mpldms = []
|
90 |
+
img_landmarker = vision.FaceLandmarker.create_from_options(self.image_mode_options)
|
91 |
+
vid_landmarker = vision.FaceLandmarker.create_from_options(self.video_mode_options)
|
92 |
+
|
93 |
+
for i in range(len(frames)):
|
94 |
+
frame = mp.Image(image_format=mp.ImageFormat.SRGB, data=frames[i].astype(np.uint8))
|
95 |
+
img_face_landmarker_result = img_landmarker.detect(image=frame)
|
96 |
+
vid_face_landmarker_result = vid_landmarker.detect_for_video(image=frame, timestamp_ms=int((1000/fps)*anti_smooth_factor*i))
|
97 |
+
try:
|
98 |
+
img_ldm_i = img_face_landmarker_result.face_landmarks[0]
|
99 |
+
vid_ldm_i = vid_face_landmarker_result.face_landmarks[0]
|
100 |
+
except:
|
101 |
+
print(f"Warning: failed detect ldm in idx={i}, use previous frame results.")
|
102 |
+
img_face_landmarks = np.array([[l.x, l.y, l.z] for l in img_ldm_i])
|
103 |
+
vid_face_landmarks = np.array([[l.x, l.y, l.z] for l in vid_ldm_i])
|
104 |
+
img_mpldms.append(img_face_landmarks)
|
105 |
+
vid_mpldms.append(vid_face_landmarks)
|
106 |
+
img_lm478 = np.stack(img_mpldms)[..., :2]
|
107 |
+
vid_lm478 = np.stack(vid_mpldms)[..., :2]
|
108 |
+
bs, H, W, _ = frames.shape
|
109 |
+
img_lm478 = np.array(img_lm478)[..., :2] * np.array([W, H]).reshape([1,1,2]) # [T, 478, 2]
|
110 |
+
vid_lm478 = np.array(vid_lm478)[..., :2] * np.array([W, H]).reshape([1,1,2]) # [T, 478, 2]
|
111 |
+
return img_lm478, vid_lm478
|
112 |
+
|
113 |
+
def combine_vid_img_lm478_to_lm68(self, img_lm478, vid_lm478):
|
114 |
+
img_lm68 = img_lm478[:, index_lm68_from_lm478]
|
115 |
+
vid_lm68 = vid_lm478[:, index_lm68_from_lm478]
|
116 |
+
combined_lm68 = copy.deepcopy(img_lm68)
|
117 |
+
combined_lm68[:, index_yaw_from_lm68] = vid_lm68[:, index_yaw_from_lm68]
|
118 |
+
combined_lm68[:, index_brow_from_lm68] = vid_lm68[:, index_brow_from_lm68]
|
119 |
+
combined_lm68[:, index_nose_from_lm68] = vid_lm68[:, index_nose_from_lm68]
|
120 |
+
return combined_lm68
|
121 |
+
|
122 |
+
def combine_vid_img_lm478_to_lm478(self, img_lm478, vid_lm478):
|
123 |
+
combined_lm478 = copy.deepcopy(vid_lm478)
|
124 |
+
combined_lm478[:, index_mouth_from_lm478] = img_lm478[:, index_mouth_from_lm478]
|
125 |
+
combined_lm478[:, index_eye_from_lm478] = img_lm478[:, index_eye_from_lm478]
|
126 |
+
return combined_lm478
|
127 |
+
|
128 |
+
if __name__ == '__main__':
|
129 |
+
landmarker = MediapipeLandmarker()
|
130 |
+
ret = landmarker.extract_lm478_from_video_name("00000.mp4")
|
data_gen/utils/mp_feature_extractors/face_landmarker.task
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:64184e229b263107bc2b804c6625db1341ff2bb731874b0bcc2fe6544e0bc9ff
|
3 |
+
size 3758596
|
data_gen/utils/mp_feature_extractors/mp_segmenter.py
ADDED
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import copy
|
3 |
+
import numpy as np
|
4 |
+
import tqdm
|
5 |
+
import mediapipe as mp
|
6 |
+
import torch
|
7 |
+
from mediapipe.tasks import python
|
8 |
+
from mediapipe.tasks.python import vision
|
9 |
+
from utils.commons.multiprocess_utils import multiprocess_run_tqdm, multiprocess_run
|
10 |
+
from utils.commons.tensor_utils import convert_to_np
|
11 |
+
from sklearn.neighbors import NearestNeighbors
|
12 |
+
|
13 |
+
def scatter_np(condition_img, classSeg=5):
|
14 |
+
# def scatter(condition_img, classSeg=19, label_size=(512, 512)):
|
15 |
+
batch, c, height, width = condition_img.shape
|
16 |
+
# if height != label_size[0] or width != label_size[1]:
|
17 |
+
# condition_img= F.interpolate(condition_img, size=label_size, mode='nearest')
|
18 |
+
input_label = np.zeros([batch, classSeg, condition_img.shape[2], condition_img.shape[3]]).astype(np.int_)
|
19 |
+
# input_label = torch.zeros(batch, classSeg, *label_size, device=condition_img.device)
|
20 |
+
np.put_along_axis(input_label, condition_img, 1, 1)
|
21 |
+
return input_label
|
22 |
+
|
23 |
+
def scatter(condition_img, classSeg=19):
|
24 |
+
# def scatter(condition_img, classSeg=19, label_size=(512, 512)):
|
25 |
+
batch, c, height, width = condition_img.size()
|
26 |
+
# if height != label_size[0] or width != label_size[1]:
|
27 |
+
# condition_img= F.interpolate(condition_img, size=label_size, mode='nearest')
|
28 |
+
input_label = torch.zeros(batch, classSeg, condition_img.shape[2], condition_img.shape[3], device=condition_img.device)
|
29 |
+
# input_label = torch.zeros(batch, classSeg, *label_size, device=condition_img.device)
|
30 |
+
return input_label.scatter_(1, condition_img.long(), 1)
|
31 |
+
|
32 |
+
def encode_segmap_mask_to_image(segmap):
|
33 |
+
# rgb
|
34 |
+
_,h,w = segmap.shape
|
35 |
+
encoded_img = np.ones([h,w,3],dtype=np.uint8) * 255
|
36 |
+
colors = [(255,255,255),(255,255,0),(255,0,255),(0,255,255),(255,0,0),(0,255,0)]
|
37 |
+
for i, color in enumerate(colors):
|
38 |
+
mask = segmap[i].astype(int)
|
39 |
+
index = np.where(mask != 0)
|
40 |
+
encoded_img[index[0], index[1], :] = np.array(color)
|
41 |
+
return encoded_img.astype(np.uint8)
|
42 |
+
|
43 |
+
def decode_segmap_mask_from_image(encoded_img):
|
44 |
+
# rgb
|
45 |
+
colors = [(255,255,255),(255,255,0),(255,0,255),(0,255,255),(255,0,0),(0,255,0)]
|
46 |
+
bg = (encoded_img[..., 0] == 255) & (encoded_img[..., 1] == 255) & (encoded_img[..., 2] == 255)
|
47 |
+
hair = (encoded_img[..., 0] == 255) & (encoded_img[..., 1] == 255) & (encoded_img[..., 2] == 0)
|
48 |
+
body_skin = (encoded_img[..., 0] == 255) & (encoded_img[..., 1] == 0) & (encoded_img[..., 2] == 255)
|
49 |
+
face_skin = (encoded_img[..., 0] == 0) & (encoded_img[..., 1] == 255) & (encoded_img[..., 2] == 255)
|
50 |
+
clothes = (encoded_img[..., 0] == 255) & (encoded_img[..., 1] == 0) & (encoded_img[..., 2] == 0)
|
51 |
+
others = (encoded_img[..., 0] == 0) & (encoded_img[..., 1] == 255) & (encoded_img[..., 2] == 0)
|
52 |
+
segmap = np.stack([bg, hair, body_skin, face_skin, clothes, others], axis=0)
|
53 |
+
return segmap.astype(np.uint8)
|
54 |
+
|
55 |
+
def read_video_frame(video_name, frame_id):
|
56 |
+
# https://blog.csdn.net/bby1987/article/details/108923361
|
57 |
+
# frame_num = video_capture.get(cv2.CAP_PROP_FRAME_COUNT) # ==> 总帧数
|
58 |
+
# fps = video_capture.get(cv2.CAP_PROP_FPS) # ==> 帧率
|
59 |
+
# width = video_capture.get(cv2.CAP_PROP_FRAME_WIDTH) # ==> 视频宽度
|
60 |
+
# height = video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT) # ==> 视频高度
|
61 |
+
# pos = video_capture.get(cv2.CAP_PROP_POS_FRAMES) # ==> 句柄位置
|
62 |
+
# video_capture.set(cv2.CAP_PROP_POS_FRAMES, 1000) # ==> 设置句柄位置
|
63 |
+
# pos = video_capture.get(cv2.CAP_PROP_POS_FRAMES) # ==> 此时 pos = 1000.0
|
64 |
+
# video_capture.release()
|
65 |
+
vr = cv2.VideoCapture(video_name)
|
66 |
+
vr.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
|
67 |
+
_, frame = vr.read()
|
68 |
+
return frame
|
69 |
+
|
70 |
+
def decode_segmap_mask_from_segmap_video_frame(video_frame):
|
71 |
+
# video_frame: 0~255 BGR, obtained by read_video_frame
|
72 |
+
def assign_values(array):
|
73 |
+
remainder = array % 40 # 计算数组中每个值与40的余数
|
74 |
+
assigned_values = np.where(remainder <= 20, array - remainder, array + (40 - remainder))
|
75 |
+
return assigned_values
|
76 |
+
segmap = video_frame.mean(-1)
|
77 |
+
segmap = assign_values(segmap) // 40 # [H, W] with value 0~5
|
78 |
+
segmap_mask = scatter_np(segmap[None, None, ...], classSeg=6)[0] # [6, H, W]
|
79 |
+
return segmap.astype(np.uint8)
|
80 |
+
|
81 |
+
def extract_background(img_lst, segmap_lst=None):
|
82 |
+
"""
|
83 |
+
img_lst: list of rgb ndarray
|
84 |
+
"""
|
85 |
+
# only use 1/20 images
|
86 |
+
num_frames = len(img_lst)
|
87 |
+
img_lst = img_lst[::20] if num_frames > 20 else img_lst[0:1]
|
88 |
+
|
89 |
+
if segmap_lst is not None:
|
90 |
+
segmap_lst = segmap_lst[::20] if num_frames > 20 else segmap_lst[0:1]
|
91 |
+
assert len(img_lst) == len(segmap_lst)
|
92 |
+
# get H/W
|
93 |
+
h, w = img_lst[0].shape[:2]
|
94 |
+
|
95 |
+
# nearest neighbors
|
96 |
+
all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose()
|
97 |
+
distss = []
|
98 |
+
for idx, img in enumerate(img_lst):
|
99 |
+
if segmap_lst is not None:
|
100 |
+
segmap = segmap_lst[idx]
|
101 |
+
else:
|
102 |
+
segmap = seg_model._cal_seg_map(img)
|
103 |
+
bg = (segmap[0]).astype(bool)
|
104 |
+
fg_xys = np.stack(np.nonzero(~bg)).transpose(1, 0)
|
105 |
+
nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
|
106 |
+
dists, _ = nbrs.kneighbors(all_xys)
|
107 |
+
distss.append(dists)
|
108 |
+
|
109 |
+
distss = np.stack(distss)
|
110 |
+
max_dist = np.max(distss, 0)
|
111 |
+
max_id = np.argmax(distss, 0)
|
112 |
+
|
113 |
+
bc_pixs = max_dist > 10 # 5
|
114 |
+
bc_pixs_id = np.nonzero(bc_pixs)
|
115 |
+
bc_ids = max_id[bc_pixs]
|
116 |
+
|
117 |
+
num_pixs = distss.shape[1]
|
118 |
+
imgs = np.stack(img_lst).reshape(-1, num_pixs, 3)
|
119 |
+
|
120 |
+
bg_img = np.zeros((h*w, 3), dtype=np.uint8)
|
121 |
+
bg_img[bc_pixs_id, :] = imgs[bc_ids, bc_pixs_id, :]
|
122 |
+
bg_img = bg_img.reshape(h, w, 3)
|
123 |
+
|
124 |
+
max_dist = max_dist.reshape(h, w)
|
125 |
+
bc_pixs = max_dist > 10 # 5
|
126 |
+
bg_xys = np.stack(np.nonzero(~bc_pixs)).transpose()
|
127 |
+
fg_xys = np.stack(np.nonzero(bc_pixs)).transpose()
|
128 |
+
nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
|
129 |
+
distances, indices = nbrs.kneighbors(bg_xys)
|
130 |
+
bg_fg_xys = fg_xys[indices[:, 0]]
|
131 |
+
bg_img[bg_xys[:, 0], bg_xys[:, 1], :] = bg_img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :]
|
132 |
+
return bg_img
|
133 |
+
|
134 |
+
|
135 |
+
global_segmenter = None
|
136 |
+
def job_cal_seg_map_for_image(img, segmenter_options=None, segmenter=None):
|
137 |
+
"""
|
138 |
+
被 MediapipeSegmenter.multiprocess_cal_seg_map_for_a_video所使用, 专门用来处理单个长视频.
|
139 |
+
"""
|
140 |
+
global global_segmenter
|
141 |
+
if segmenter is not None:
|
142 |
+
segmenter_actual = segmenter
|
143 |
+
else:
|
144 |
+
global_segmenter = vision.ImageSegmenter.create_from_options(segmenter_options) if global_segmenter is None else global_segmenter
|
145 |
+
segmenter_actual = global_segmenter
|
146 |
+
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
|
147 |
+
out = segmenter_actual.segment(mp_image)
|
148 |
+
segmap = out.category_mask.numpy_view().copy() # [H, W]
|
149 |
+
|
150 |
+
segmap_mask = scatter_np(segmap[None, None, ...], classSeg=6)[0] # [6, H, W]
|
151 |
+
segmap_image = segmap[:, :, None].repeat(3, 2).astype(float)
|
152 |
+
segmap_image = (segmap_image * 40).astype(np.uint8)
|
153 |
+
|
154 |
+
return segmap_mask, segmap_image
|
155 |
+
|
156 |
+
class MediapipeSegmenter:
|
157 |
+
def __init__(self):
|
158 |
+
model_path = 'data_gen/utils/mp_feature_extractors/selfie_multiclass_256x256.tflite'
|
159 |
+
if not os.path.exists(model_path):
|
160 |
+
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
161 |
+
print("downloading segmenter model from mediapipe...")
|
162 |
+
os.system(f"wget https://storage.googleapis.com/mediapipe-models/image_segmenter/selfie_multiclass_256x256/float32/latest/selfie_multiclass_256x256.tflite")
|
163 |
+
os.system(f"mv selfie_multiclass_256x256.tflite {model_path}")
|
164 |
+
print("download success")
|
165 |
+
base_options = python.BaseOptions(model_asset_path=model_path)
|
166 |
+
self.options = vision.ImageSegmenterOptions(base_options=base_options,running_mode=vision.RunningMode.IMAGE, output_category_mask=True)
|
167 |
+
self.video_options = vision.ImageSegmenterOptions(base_options=base_options,running_mode=vision.RunningMode.VIDEO, output_category_mask=True)
|
168 |
+
|
169 |
+
def multiprocess_cal_seg_map_for_a_video(self, imgs, num_workers=4):
|
170 |
+
"""
|
171 |
+
并行处理单个长视频
|
172 |
+
imgs: list of rgb array in 0~255
|
173 |
+
"""
|
174 |
+
segmap_masks = []
|
175 |
+
segmap_images = []
|
176 |
+
img_lst = [(self.options, imgs[i]) for i in range(len(imgs))]
|
177 |
+
for (i, res) in multiprocess_run_tqdm(job_cal_seg_map_for_image, args=img_lst, num_workers=num_workers, desc='extracting from a video in multi-process'):
|
178 |
+
segmap_mask, segmap_image = res
|
179 |
+
segmap_masks.append(segmap_mask)
|
180 |
+
segmap_images.append(segmap_image)
|
181 |
+
return segmap_masks, segmap_images
|
182 |
+
|
183 |
+
def _cal_seg_map_for_video(self, imgs, segmenter=None, return_onehot_mask=True, return_segmap_image=True):
|
184 |
+
segmenter = vision.ImageSegmenter.create_from_options(self.video_options) if segmenter is None else segmenter
|
185 |
+
assert return_onehot_mask or return_segmap_image # you should at least return one
|
186 |
+
segmap_masks = []
|
187 |
+
segmap_images = []
|
188 |
+
for i in tqdm.trange(len(imgs), desc="extracting segmaps from a video..."):
|
189 |
+
img = imgs[i]
|
190 |
+
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
|
191 |
+
out = segmenter.segment_for_video(mp_image, 40 * i)
|
192 |
+
segmap = out.category_mask.numpy_view().copy() # [H, W]
|
193 |
+
|
194 |
+
if return_onehot_mask:
|
195 |
+
segmap_mask = scatter_np(segmap[None, None, ...], classSeg=6)[0] # [6, H, W]
|
196 |
+
segmap_masks.append(segmap_mask)
|
197 |
+
if return_segmap_image:
|
198 |
+
segmap_image = segmap[:, :, None].repeat(3, 2).astype(float)
|
199 |
+
segmap_image = (segmap_image * 40).astype(np.uint8)
|
200 |
+
segmap_images.append(segmap_image)
|
201 |
+
|
202 |
+
if return_onehot_mask and return_segmap_image:
|
203 |
+
return segmap_masks, segmap_images
|
204 |
+
elif return_onehot_mask:
|
205 |
+
return segmap_masks
|
206 |
+
elif return_segmap_image:
|
207 |
+
return segmap_images
|
208 |
+
|
209 |
+
def _cal_seg_map(self, img, segmenter=None, return_onehot_mask=True):
|
210 |
+
"""
|
211 |
+
segmenter: vision.ImageSegmenter.create_from_options(options)
|
212 |
+
img: numpy, [H, W, 3], 0~255
|
213 |
+
segmap: [C, H, W]
|
214 |
+
0 - background
|
215 |
+
1 - hair
|
216 |
+
2 - body-skin
|
217 |
+
3 - face-skin
|
218 |
+
4 - clothes
|
219 |
+
5 - others (accessories)
|
220 |
+
"""
|
221 |
+
assert img.ndim == 3
|
222 |
+
segmenter = vision.ImageSegmenter.create_from_options(self.options) if segmenter is None else segmenter
|
223 |
+
image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
|
224 |
+
out = segmenter.segment(image)
|
225 |
+
segmap = out.category_mask.numpy_view().copy() # [H, W]
|
226 |
+
if return_onehot_mask:
|
227 |
+
segmap = scatter_np(segmap[None, None, ...], classSeg=6)[0] # [6, H, W]
|
228 |
+
return segmap
|
229 |
+
|
230 |
+
def _seg_out_img_with_segmap(self, img, segmap, mode='head'):
|
231 |
+
"""
|
232 |
+
img: [h,w,c], img is in 0~255, np
|
233 |
+
"""
|
234 |
+
#
|
235 |
+
img = copy.deepcopy(img)
|
236 |
+
if mode == 'head':
|
237 |
+
selected_mask = segmap[[1,3,5] , :, :].sum(axis=0)[None,:] > 0.5 # glasses 也属于others
|
238 |
+
img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
|
239 |
+
# selected_mask = segmap[[1,3] , :, :].sum(dim=0, keepdim=True) > 0.5
|
240 |
+
elif mode == 'person':
|
241 |
+
selected_mask = segmap[[1,2,3,4,5], :, :].sum(axis=0)[None,:] > 0.5
|
242 |
+
img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
|
243 |
+
elif mode == 'torso':
|
244 |
+
selected_mask = segmap[[2,4], :, :].sum(axis=0)[None,:] > 0.5
|
245 |
+
img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
|
246 |
+
elif mode == 'torso_with_bg':
|
247 |
+
selected_mask = segmap[[0, 2,4], :, :].sum(axis=0)[None,:] > 0.5
|
248 |
+
img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
|
249 |
+
elif mode == 'bg':
|
250 |
+
selected_mask = segmap[[0], :, :].sum(axis=0)[None,:] > 0.5 # only seg out 0, which means background
|
251 |
+
img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
|
252 |
+
elif mode == 'full':
|
253 |
+
pass
|
254 |
+
else:
|
255 |
+
raise NotImplementedError()
|
256 |
+
return img, selected_mask
|
257 |
+
|
258 |
+
def _seg_out_img(self, img, segmenter=None, mode='head'):
|
259 |
+
"""
|
260 |
+
imgs [H, W, 3] 0-255
|
261 |
+
return : person_img [B, 3, H, W]
|
262 |
+
"""
|
263 |
+
segmenter = vision.ImageSegmenter.create_from_options(self.options) if segmenter is None else segmenter
|
264 |
+
segmap = self._cal_seg_map(img, segmenter=segmenter, return_onehot_mask=True) # [B, 19, H, W]
|
265 |
+
return self._seg_out_img_with_segmap(img, segmap, mode=mode)
|
266 |
+
|
267 |
+
def seg_out_imgs(self, img, mode='head'):
|
268 |
+
"""
|
269 |
+
api for pytorch img, -1~1
|
270 |
+
img: [B, 3, H, W], -1~1
|
271 |
+
"""
|
272 |
+
device = img.device
|
273 |
+
img = convert_to_np(img.permute(0, 2, 3, 1)) # [B, H, W, 3]
|
274 |
+
img = ((img + 1) * 127.5).astype(np.uint8)
|
275 |
+
img_lst = [copy.deepcopy(img[i]) for i in range(len(img))]
|
276 |
+
out_lst = []
|
277 |
+
for im in img_lst:
|
278 |
+
out = self._seg_out_img(im, mode=mode)
|
279 |
+
out_lst.append(out)
|
280 |
+
seg_imgs = np.stack(out_lst) # [B, H, W, 3]
|
281 |
+
seg_imgs = (seg_imgs - 127.5) / 127.5
|
282 |
+
seg_imgs = torch.from_numpy(seg_imgs).permute(0, 3, 1, 2).to(device)
|
283 |
+
return seg_imgs
|
284 |
+
|
285 |
+
if __name__ == '__main__':
|
286 |
+
import imageio, cv2, tqdm
|
287 |
+
import torchshow as ts
|
288 |
+
img = imageio.imread("1.png")
|
289 |
+
img = cv2.resize(img, (512,512))
|
290 |
+
|
291 |
+
seg_model = MediapipeSegmenter()
|
292 |
+
img = torch.tensor(img).unsqueeze(0).repeat([1, 1, 1, 1]).permute(0, 3,1,2)
|
293 |
+
img = (img-127.5)/127.5
|
294 |
+
out = seg_model.seg_out_imgs(img, 'torso')
|
295 |
+
ts.save(out,"torso.png")
|
296 |
+
out = seg_model.seg_out_imgs(img, 'head')
|
297 |
+
ts.save(out,"head.png")
|
298 |
+
out = seg_model.seg_out_imgs(img, 'bg')
|
299 |
+
ts.save(out,"bg.png")
|
300 |
+
img = convert_to_np(img.permute(0, 2, 3, 1)) # [B, H, W, 3]
|
301 |
+
img = ((img + 1) * 127.5).astype(np.uint8)
|
302 |
+
bg = extract_background(img)
|
303 |
+
ts.save(bg,"bg2.png")
|
data_gen/utils/mp_feature_extractors/selfie_multiclass_256x256.tflite
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c6748b1253a99067ef71f7e26ca71096cd449baefa8f101900ea23016507e0e0
|
3 |
+
size 16371837
|
data_gen/utils/path_converter.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
|
4 |
+
class PathConverter():
|
5 |
+
def __init__(self):
|
6 |
+
self.prefixs = {
|
7 |
+
"vid": "/video/",
|
8 |
+
"gt": "/gt_imgs/",
|
9 |
+
"head": "/head_imgs/",
|
10 |
+
"torso": "/torso_imgs/",
|
11 |
+
"person": "/person_imgs/",
|
12 |
+
"torso_with_bg": "/torso_with_bg_imgs/",
|
13 |
+
"single_bg": "/bg_img/",
|
14 |
+
"bg": "/bg_imgs/",
|
15 |
+
"segmaps": "/segmaps/",
|
16 |
+
"inpaint_torso": "/inpaint_torso_imgs/",
|
17 |
+
"com": "/com_imgs/",
|
18 |
+
"inpaint_torso_with_com_bg": "/inpaint_torso_with_com_bg_imgs/",
|
19 |
+
}
|
20 |
+
|
21 |
+
def to(self, path: str, old_pattern: str, new_pattern: str):
|
22 |
+
return path.replace(self.prefixs[old_pattern], self.prefixs[new_pattern], 1)
|
23 |
+
|
24 |
+
pc = PathConverter()
|
data_gen/utils/process_audio/extract_hubert.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import Wav2Vec2Processor, HubertModel
|
2 |
+
import soundfile as sf
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import os
|
6 |
+
from utils.commons.hparams import set_hparams, hparams
|
7 |
+
|
8 |
+
|
9 |
+
wav2vec2_processor = None
|
10 |
+
hubert_model = None
|
11 |
+
|
12 |
+
|
13 |
+
def get_hubert_from_16k_wav(wav_16k_name):
|
14 |
+
speech_16k, _ = sf.read(wav_16k_name)
|
15 |
+
hubert = get_hubert_from_16k_speech(speech_16k)
|
16 |
+
return hubert
|
17 |
+
|
18 |
+
@torch.no_grad()
|
19 |
+
def get_hubert_from_16k_speech(speech, device="cuda:0"):
|
20 |
+
global hubert_model, wav2vec2_processor
|
21 |
+
local_path = '/home/tiger/.cache/huggingface/hub/models--facebook--hubert-large-ls960-ft/snapshots/ece5fabbf034c1073acae96d5401b25be96709d8'
|
22 |
+
if hubert_model is None:
|
23 |
+
print("Loading the HuBERT Model...")
|
24 |
+
print("Loading the Wav2Vec2 Processor...")
|
25 |
+
if os.path.exists(local_path):
|
26 |
+
hubert_model = HubertModel.from_pretrained(local_path)
|
27 |
+
wav2vec2_processor = Wav2Vec2Processor.from_pretrained(local_path)
|
28 |
+
else:
|
29 |
+
hubert_model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
|
30 |
+
wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
|
31 |
+
hubert_model = hubert_model.to(device)
|
32 |
+
|
33 |
+
if speech.ndim ==2:
|
34 |
+
speech = speech[:, 0] # [T, 2] ==> [T,]
|
35 |
+
|
36 |
+
input_values_all = wav2vec2_processor(speech, return_tensors="pt", sampling_rate=16000).input_values # [1, T]
|
37 |
+
input_values_all = input_values_all.to(device)
|
38 |
+
# For long audio sequence, due to the memory limitation, we cannot process them in one run
|
39 |
+
# HuBERT process the wav with a CNN of stride [5,2,2,2,2,2], making a stride of 320
|
40 |
+
# Besides, the kernel is [10,3,3,3,3,2,2], making 400 a fundamental unit to get 1 time step.
|
41 |
+
# So the CNN is euqal to a big Conv1D with kernel k=400 and stride s=320
|
42 |
+
# We have the equation to calculate out time step: T = floor((t-k)/s)
|
43 |
+
# To prevent overlap, we set each clip length of (K+S*(N-1)), where N is the expected length T of this clip
|
44 |
+
# The start point of next clip should roll back with a length of (kernel-stride) so it is stride * N
|
45 |
+
kernel = 400
|
46 |
+
stride = 320
|
47 |
+
clip_length = stride * 1000
|
48 |
+
num_iter = input_values_all.shape[1] // clip_length
|
49 |
+
expected_T = (input_values_all.shape[1] - (kernel-stride)) // stride
|
50 |
+
res_lst = []
|
51 |
+
for i in range(num_iter):
|
52 |
+
if i == 0:
|
53 |
+
start_idx = 0
|
54 |
+
end_idx = clip_length - stride + kernel
|
55 |
+
else:
|
56 |
+
start_idx = clip_length * i
|
57 |
+
end_idx = start_idx + (clip_length - stride + kernel)
|
58 |
+
input_values = input_values_all[:, start_idx: end_idx]
|
59 |
+
hidden_states = hubert_model.forward(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]
|
60 |
+
res_lst.append(hidden_states[0])
|
61 |
+
if num_iter > 0:
|
62 |
+
input_values = input_values_all[:, clip_length * num_iter:]
|
63 |
+
else:
|
64 |
+
input_values = input_values_all
|
65 |
+
|
66 |
+
if input_values.shape[1] >= kernel: # if the last batch is shorter than kernel_size, skip it
|
67 |
+
hidden_states = hubert_model(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]
|
68 |
+
res_lst.append(hidden_states[0])
|
69 |
+
ret = torch.cat(res_lst, dim=0).cpu() # [T, 1024]
|
70 |
+
|
71 |
+
assert abs(ret.shape[0] - expected_T) <= 1
|
72 |
+
if ret.shape[0] < expected_T: # if skipping the last short
|
73 |
+
ret = torch.cat([ret, ret[:, -1:, :].repeat([1,expected_T-ret.shape[0],1])], dim=1)
|
74 |
+
else:
|
75 |
+
ret = ret[:expected_T]
|
76 |
+
|
77 |
+
return ret
|
78 |
+
|
79 |
+
|
80 |
+
if __name__ == '__main__':
|
81 |
+
from argparse import ArgumentParser
|
82 |
+
parser = ArgumentParser()
|
83 |
+
parser.add_argument('--video_id', type=str, default='May', help='')
|
84 |
+
args = parser.parse_args()
|
85 |
+
### Process Single Long Audio for NeRF dataset
|
86 |
+
person_id = args.video_id
|
87 |
+
wav_16k_name = f"data/processed/videos/{person_id}/aud.wav"
|
88 |
+
hubert_npy_name = f"data/processed/videos/{person_id}/aud_hubert.npy"
|
89 |
+
speech_16k, _ = sf.read(wav_16k_name)
|
90 |
+
hubert_hidden = get_hubert_from_16k_speech(speech_16k)
|
91 |
+
np.save(hubert_npy_name, hubert_hidden.detach().numpy())
|
92 |
+
print(f"Saved at {hubert_npy_name}")
|
data_gen/utils/process_audio/extract_mel_f0.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import glob
|
4 |
+
import os
|
5 |
+
import tqdm
|
6 |
+
import librosa
|
7 |
+
import parselmouth
|
8 |
+
from utils.commons.pitch_utils import f0_to_coarse
|
9 |
+
from utils.commons.multiprocess_utils import multiprocess_run_tqdm
|
10 |
+
from utils.commons.os_utils import multiprocess_glob
|
11 |
+
from utils.audio.io import save_wav
|
12 |
+
|
13 |
+
from moviepy.editor import VideoFileClip
|
14 |
+
from utils.commons.hparams import hparams, set_hparams
|
15 |
+
|
16 |
+
def resample_wav(wav_name, out_name, sr=16000):
|
17 |
+
wav_raw, sr = librosa.core.load(wav_name, sr=sr)
|
18 |
+
save_wav(wav_raw, out_name, sr)
|
19 |
+
|
20 |
+
def split_wav(mp4_name, wav_name=None):
|
21 |
+
if wav_name is None:
|
22 |
+
wav_name = mp4_name.replace(".mp4", ".wav").replace("/video/", "/audio/")
|
23 |
+
if os.path.exists(wav_name):
|
24 |
+
return wav_name
|
25 |
+
os.makedirs(os.path.dirname(wav_name), exist_ok=True)
|
26 |
+
|
27 |
+
video = VideoFileClip(mp4_name,verbose=False)
|
28 |
+
dur = video.duration
|
29 |
+
audio = video.audio
|
30 |
+
assert audio is not None
|
31 |
+
audio.write_audiofile(wav_name,fps=16000,verbose=False,logger=None)
|
32 |
+
return wav_name
|
33 |
+
|
34 |
+
def librosa_pad_lr(x, fsize, fshift, pad_sides=1):
|
35 |
+
'''compute right padding (final frame) or both sides padding (first and final frames)
|
36 |
+
'''
|
37 |
+
assert pad_sides in (1, 2)
|
38 |
+
# return int(fsize // 2)
|
39 |
+
pad = (x.shape[0] // fshift + 1) * fshift - x.shape[0]
|
40 |
+
if pad_sides == 1:
|
41 |
+
return 0, pad
|
42 |
+
else:
|
43 |
+
return pad // 2, pad // 2 + pad % 2
|
44 |
+
|
45 |
+
def extract_mel_from_fname(wav_path,
|
46 |
+
fft_size=512,
|
47 |
+
hop_size=320,
|
48 |
+
win_length=512,
|
49 |
+
window="hann",
|
50 |
+
num_mels=80,
|
51 |
+
fmin=80,
|
52 |
+
fmax=7600,
|
53 |
+
eps=1e-6,
|
54 |
+
sample_rate=16000,
|
55 |
+
min_level_db=-100):
|
56 |
+
if isinstance(wav_path, str):
|
57 |
+
wav, _ = librosa.core.load(wav_path, sr=sample_rate)
|
58 |
+
else:
|
59 |
+
wav = wav_path
|
60 |
+
|
61 |
+
# get amplitude spectrogram
|
62 |
+
x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size,
|
63 |
+
win_length=win_length, window=window, center=False)
|
64 |
+
spc = np.abs(x_stft) # (n_bins, T)
|
65 |
+
|
66 |
+
# get mel basis
|
67 |
+
fmin = 0 if fmin == -1 else fmin
|
68 |
+
fmax = sample_rate / 2 if fmax == -1 else fmax
|
69 |
+
mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
70 |
+
mel = mel_basis @ spc
|
71 |
+
|
72 |
+
mel = np.log10(np.maximum(eps, mel)) # (n_mel_bins, T)
|
73 |
+
mel = mel.T
|
74 |
+
|
75 |
+
l_pad, r_pad = librosa_pad_lr(wav, fft_size, hop_size, 1)
|
76 |
+
wav = np.pad(wav, (l_pad, r_pad), mode='constant', constant_values=0.0)
|
77 |
+
|
78 |
+
return wav.T, mel
|
79 |
+
|
80 |
+
def extract_f0_from_wav_and_mel(wav, mel,
|
81 |
+
hop_size=320,
|
82 |
+
audio_sample_rate=16000,
|
83 |
+
):
|
84 |
+
time_step = hop_size / audio_sample_rate * 1000
|
85 |
+
f0_min = 80
|
86 |
+
f0_max = 750
|
87 |
+
f0 = parselmouth.Sound(wav, audio_sample_rate).to_pitch_ac(
|
88 |
+
time_step=time_step / 1000, voicing_threshold=0.6,
|
89 |
+
pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']
|
90 |
+
|
91 |
+
delta_l = len(mel) - len(f0)
|
92 |
+
assert np.abs(delta_l) <= 8
|
93 |
+
if delta_l > 0:
|
94 |
+
f0 = np.concatenate([f0, [f0[-1]] * delta_l], 0)
|
95 |
+
f0 = f0[:len(mel)]
|
96 |
+
pitch_coarse = f0_to_coarse(f0)
|
97 |
+
return f0, pitch_coarse
|
98 |
+
|
99 |
+
|
100 |
+
def extract_mel_f0_from_fname(wav_name=None, out_name=None):
|
101 |
+
try:
|
102 |
+
out_name = wav_name.replace(".wav", "_mel_f0.npy").replace("/audio/", "/mel_f0/")
|
103 |
+
os.makedirs(os.path.dirname(out_name), exist_ok=True)
|
104 |
+
|
105 |
+
wav, mel = extract_mel_from_fname(wav_name)
|
106 |
+
f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)
|
107 |
+
out_dict = {
|
108 |
+
"mel": mel, # [T, 80]
|
109 |
+
"f0": f0,
|
110 |
+
}
|
111 |
+
np.save(out_name, out_dict)
|
112 |
+
except Exception as e:
|
113 |
+
print(e)
|
114 |
+
|
115 |
+
def extract_mel_f0_from_video_name(mp4_name, wav_name=None, out_name=None):
|
116 |
+
if mp4_name.endswith(".mp4"):
|
117 |
+
wav_name = split_wav(mp4_name, wav_name)
|
118 |
+
if out_name is None:
|
119 |
+
out_name = mp4_name.replace(".mp4", "_mel_f0.npy").replace("/video/", "/mel_f0/")
|
120 |
+
elif mp4_name.endswith(".wav"):
|
121 |
+
wav_name = mp4_name
|
122 |
+
if out_name is None:
|
123 |
+
out_name = mp4_name.replace(".wav", "_mel_f0.npy").replace("/audio/", "/mel_f0/")
|
124 |
+
|
125 |
+
os.makedirs(os.path.dirname(out_name), exist_ok=True)
|
126 |
+
|
127 |
+
wav, mel = extract_mel_from_fname(wav_name)
|
128 |
+
|
129 |
+
f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)
|
130 |
+
out_dict = {
|
131 |
+
"mel": mel, # [T, 80]
|
132 |
+
"f0": f0,
|
133 |
+
}
|
134 |
+
np.save(out_name, out_dict)
|
135 |
+
|
136 |
+
|
137 |
+
if __name__ == '__main__':
|
138 |
+
from argparse import ArgumentParser
|
139 |
+
parser = ArgumentParser()
|
140 |
+
parser.add_argument('--video_id', type=str, default='May', help='')
|
141 |
+
args = parser.parse_args()
|
142 |
+
### Process Single Long Audio for NeRF dataset
|
143 |
+
person_id = args.video_id
|
144 |
+
|
145 |
+
wav_16k_name = f"data/processed/videos/{person_id}/aud.wav"
|
146 |
+
out_name = f"data/processed/videos/{person_id}/aud_mel_f0.npy"
|
147 |
+
extract_mel_f0_from_video_name(wav_16k_name, out_name)
|
148 |
+
print(f"Saved at {out_name}")
|
data_gen/utils/process_audio/resample_audio_to_16k.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, glob
|
2 |
+
from utils.commons.os_utils import multiprocess_glob
|
3 |
+
from utils.commons.multiprocess_utils import multiprocess_run_tqdm
|
4 |
+
|
5 |
+
|
6 |
+
def extract_wav16k_job(audio_name:str):
|
7 |
+
out_path = audio_name.replace("/audio_raw/","/audio/",1)
|
8 |
+
assert out_path != audio_name # prevent inplace
|
9 |
+
os.makedirs(os.path.dirname(out_path), exist_ok=True)
|
10 |
+
ffmpeg_path = "/usr/bin/ffmpeg"
|
11 |
+
|
12 |
+
cmd = f'{ffmpeg_path} -i {audio_name} -ar 16000 -v quiet -y {out_path}'
|
13 |
+
os.system(cmd)
|
14 |
+
|
15 |
+
if __name__ == '__main__':
|
16 |
+
import argparse, glob, tqdm, random
|
17 |
+
parser = argparse.ArgumentParser()
|
18 |
+
parser.add_argument("--aud_dir", default='/home/tiger/datasets/raw/CMLR/audio_raw/')
|
19 |
+
parser.add_argument("--ds_name", default='CMLR')
|
20 |
+
parser.add_argument("--num_workers", default=64, type=int)
|
21 |
+
parser.add_argument("--process_id", default=0, type=int)
|
22 |
+
parser.add_argument("--total_process", default=1, type=int)
|
23 |
+
args = parser.parse_args()
|
24 |
+
print(f"args {args}")
|
25 |
+
|
26 |
+
aud_dir = args.aud_dir
|
27 |
+
ds_name = args.ds_name
|
28 |
+
if ds_name in ['CMLR']:
|
29 |
+
aud_name_pattern = os.path.join(aud_dir, "*/*/*.wav")
|
30 |
+
aud_names = multiprocess_glob(aud_name_pattern)
|
31 |
+
else:
|
32 |
+
raise NotImplementedError()
|
33 |
+
aud_names = sorted(aud_names)
|
34 |
+
print(f"total audio number : {len(aud_names)}")
|
35 |
+
print(f"first {aud_names[0]} last {aud_names[-1]}")
|
36 |
+
# exit()
|
37 |
+
process_id = args.process_id
|
38 |
+
total_process = args.total_process
|
39 |
+
if total_process > 1:
|
40 |
+
assert process_id <= total_process -1
|
41 |
+
num_samples_per_process = len(aud_names) // total_process
|
42 |
+
if process_id == total_process:
|
43 |
+
aud_names = aud_names[process_id * num_samples_per_process : ]
|
44 |
+
else:
|
45 |
+
aud_names = aud_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
|
46 |
+
|
47 |
+
for i, res in multiprocess_run_tqdm(extract_wav16k_job, aud_names, num_workers=args.num_workers, desc="resampling videos"):
|
48 |
+
pass
|
49 |
+
|
data_gen/utils/process_image/extract_lm2d.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
3 |
+
import sys
|
4 |
+
|
5 |
+
import glob
|
6 |
+
import cv2
|
7 |
+
import tqdm
|
8 |
+
import numpy as np
|
9 |
+
from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker
|
10 |
+
from utils.commons.multiprocess_utils import multiprocess_run_tqdm
|
11 |
+
import warnings
|
12 |
+
warnings.filterwarnings('ignore')
|
13 |
+
|
14 |
+
import random
|
15 |
+
random.seed(42)
|
16 |
+
|
17 |
+
import pickle
|
18 |
+
import json
|
19 |
+
import gzip
|
20 |
+
from typing import Any
|
21 |
+
|
22 |
+
def load_file(filename, is_gzip: bool = False, is_json: bool = False) -> Any:
|
23 |
+
if is_json:
|
24 |
+
if is_gzip:
|
25 |
+
with gzip.open(filename, "r", encoding="utf-8") as f:
|
26 |
+
loaded_object = json.load(f)
|
27 |
+
return loaded_object
|
28 |
+
else:
|
29 |
+
with open(filename, "r", encoding="utf-8") as f:
|
30 |
+
loaded_object = json.load(f)
|
31 |
+
return loaded_object
|
32 |
+
else:
|
33 |
+
if is_gzip:
|
34 |
+
with gzip.open(filename, "rb") as f:
|
35 |
+
loaded_object = pickle.load(f)
|
36 |
+
return loaded_object
|
37 |
+
else:
|
38 |
+
with open(filename, "rb") as f:
|
39 |
+
loaded_object = pickle.load(f)
|
40 |
+
return loaded_object
|
41 |
+
|
42 |
+
def save_file(filename, content, is_gzip: bool = False, is_json: bool = False) -> None:
|
43 |
+
if is_json:
|
44 |
+
if is_gzip:
|
45 |
+
with gzip.open(filename, "w", encoding="utf-8") as f:
|
46 |
+
json.dump(content, f)
|
47 |
+
else:
|
48 |
+
with open(filename, "w", encoding="utf-8") as f:
|
49 |
+
json.dump(content, f)
|
50 |
+
else:
|
51 |
+
if is_gzip:
|
52 |
+
with gzip.open(filename, "wb") as f:
|
53 |
+
pickle.dump(content, f)
|
54 |
+
else:
|
55 |
+
with open(filename, "wb") as f:
|
56 |
+
pickle.dump(content, f)
|
57 |
+
|
58 |
+
face_landmarker = None
|
59 |
+
|
60 |
+
def extract_lms_mediapipe_job(img):
|
61 |
+
if img is None:
|
62 |
+
return None
|
63 |
+
global face_landmarker
|
64 |
+
if face_landmarker is None:
|
65 |
+
face_landmarker = MediapipeLandmarker()
|
66 |
+
lm478 = face_landmarker.extract_lm478_from_img(img)
|
67 |
+
return lm478
|
68 |
+
|
69 |
+
def extract_landmark_job(img_name):
|
70 |
+
try:
|
71 |
+
# if img_name == 'datasets/PanoHeadGen/raw/images/multi_view/chunk_0/seed0000002.png':
|
72 |
+
# print(1)
|
73 |
+
# input()
|
74 |
+
out_name = img_name.replace("/images_512/", "/lms_2d/").replace(".png","_lms.npy")
|
75 |
+
if os.path.exists(out_name):
|
76 |
+
print("out exists, skip...")
|
77 |
+
return
|
78 |
+
try:
|
79 |
+
os.makedirs(os.path.dirname(out_name), exist_ok=True)
|
80 |
+
except:
|
81 |
+
pass
|
82 |
+
img = cv2.imread(img_name)[:,:,::-1]
|
83 |
+
|
84 |
+
if img is not None:
|
85 |
+
lm468 = extract_lms_mediapipe_job(img)
|
86 |
+
if lm468 is not None:
|
87 |
+
np.save(out_name, lm468)
|
88 |
+
# print("Hahaha, solve one item!!!")
|
89 |
+
except Exception as e:
|
90 |
+
print(e)
|
91 |
+
pass
|
92 |
+
|
93 |
+
def out_exist_job(img_name):
|
94 |
+
out_name = img_name.replace("/images_512/", "/lms_2d/").replace(".png","_lms.npy")
|
95 |
+
if os.path.exists(out_name):
|
96 |
+
return None
|
97 |
+
else:
|
98 |
+
return img_name
|
99 |
+
|
100 |
+
# def get_todo_img_names(img_names):
|
101 |
+
# todo_img_names = []
|
102 |
+
# for i, res in multiprocess_run_tqdm(out_exist_job, img_names, num_workers=64):
|
103 |
+
# if res is not None:
|
104 |
+
# todo_img_names.append(res)
|
105 |
+
# return todo_img_names
|
106 |
+
|
107 |
+
|
108 |
+
if __name__ == '__main__':
|
109 |
+
import argparse, glob, tqdm, random
|
110 |
+
parser = argparse.ArgumentParser()
|
111 |
+
parser.add_argument("--img_dir", default='/home/tiger/datasets/raw/FFHQ/images_512/')
|
112 |
+
parser.add_argument("--ds_name", default='FFHQ')
|
113 |
+
parser.add_argument("--num_workers", default=64, type=int)
|
114 |
+
parser.add_argument("--process_id", default=0, type=int)
|
115 |
+
parser.add_argument("--total_process", default=1, type=int)
|
116 |
+
parser.add_argument("--reset", action='store_true')
|
117 |
+
parser.add_argument("--img_names_file", default="img_names.pkl", type=str)
|
118 |
+
parser.add_argument("--load_img_names", action="store_true")
|
119 |
+
|
120 |
+
args = parser.parse_args()
|
121 |
+
print(f"args {args}")
|
122 |
+
img_dir = args.img_dir
|
123 |
+
img_names_file = os.path.join(img_dir, args.img_names_file)
|
124 |
+
if args.load_img_names:
|
125 |
+
img_names = load_file(img_names_file)
|
126 |
+
print(f"load image names from {img_names_file}")
|
127 |
+
else:
|
128 |
+
if args.ds_name == 'FFHQ_MV':
|
129 |
+
img_name_pattern1 = os.path.join(img_dir, "ref_imgs/*.png")
|
130 |
+
img_names1 = glob.glob(img_name_pattern1)
|
131 |
+
img_name_pattern2 = os.path.join(img_dir, "mv_imgs/*.png")
|
132 |
+
img_names2 = glob.glob(img_name_pattern2)
|
133 |
+
img_names = img_names1 + img_names2
|
134 |
+
img_names = sorted(img_names)
|
135 |
+
elif args.ds_name == 'FFHQ':
|
136 |
+
img_name_pattern = os.path.join(img_dir, "*.png")
|
137 |
+
img_names = glob.glob(img_name_pattern)
|
138 |
+
img_names = sorted(img_names)
|
139 |
+
elif args.ds_name == "PanoHeadGen":
|
140 |
+
# img_name_patterns = ["ref/*/*.png", "multi_view/*/*.png", "reverse/*/*.png"]
|
141 |
+
img_name_patterns = ["ref/*/*.png"]
|
142 |
+
img_names = []
|
143 |
+
for img_name_pattern in img_name_patterns:
|
144 |
+
img_name_pattern_full = os.path.join(img_dir, img_name_pattern)
|
145 |
+
img_names_part = glob.glob(img_name_pattern_full)
|
146 |
+
img_names.extend(img_names_part)
|
147 |
+
img_names = sorted(img_names)
|
148 |
+
|
149 |
+
# save image names
|
150 |
+
if not args.load_img_names:
|
151 |
+
save_file(img_names_file, img_names)
|
152 |
+
print(f"save image names in {img_names_file}")
|
153 |
+
|
154 |
+
print(f"total images number: {len(img_names)}")
|
155 |
+
|
156 |
+
|
157 |
+
process_id = args.process_id
|
158 |
+
total_process = args.total_process
|
159 |
+
if total_process > 1:
|
160 |
+
assert process_id <= total_process -1
|
161 |
+
num_samples_per_process = len(img_names) // total_process
|
162 |
+
if process_id == total_process:
|
163 |
+
img_names = img_names[process_id * num_samples_per_process : ]
|
164 |
+
else:
|
165 |
+
img_names = img_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
|
166 |
+
|
167 |
+
# if not args.reset:
|
168 |
+
# img_names = get_todo_img_names(img_names)
|
169 |
+
|
170 |
+
|
171 |
+
print(f"todo_image {img_names[:10]}")
|
172 |
+
print(f"processing images number in this process: {len(img_names)}")
|
173 |
+
# print(f"todo images number: {len(img_names)}")
|
174 |
+
# input()
|
175 |
+
# exit()
|
176 |
+
|
177 |
+
if args.num_workers == 1:
|
178 |
+
index = 0
|
179 |
+
for img_name in tqdm.tqdm(img_names, desc=f"Root process {args.process_id}: extracting MP-based landmark2d"):
|
180 |
+
try:
|
181 |
+
extract_landmark_job(img_name)
|
182 |
+
except Exception as e:
|
183 |
+
print(e)
|
184 |
+
pass
|
185 |
+
if index % max(1, int(len(img_names) * 0.003)) == 0:
|
186 |
+
print(f"processed {index} / {len(img_names)}")
|
187 |
+
sys.stdout.flush()
|
188 |
+
index += 1
|
189 |
+
else:
|
190 |
+
for i, res in multiprocess_run_tqdm(
|
191 |
+
extract_landmark_job, img_names,
|
192 |
+
num_workers=args.num_workers,
|
193 |
+
desc=f"Root {args.process_id}: extracing MP-based landmark2d"):
|
194 |
+
# if index % max(1, int(len(img_names) * 0.003)) == 0:
|
195 |
+
print(f"processed {i+1} / {len(img_names)}")
|
196 |
+
sys.stdout.flush()
|
197 |
+
print(f"Root {args.process_id}: Finished extracting.")
|
data_gen/utils/process_image/extract_segment_imgs.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
3 |
+
|
4 |
+
import glob
|
5 |
+
import cv2
|
6 |
+
import tqdm
|
7 |
+
import numpy as np
|
8 |
+
import PIL
|
9 |
+
from utils.commons.tensor_utils import convert_to_np
|
10 |
+
import torch
|
11 |
+
import mediapipe as mp
|
12 |
+
from utils.commons.multiprocess_utils import multiprocess_run_tqdm
|
13 |
+
from data_gen.utils.mp_feature_extractors.mp_segmenter import MediapipeSegmenter
|
14 |
+
from data_gen.utils.process_video.extract_segment_imgs import inpaint_torso_job, extract_background, save_rgb_image_to_path
|
15 |
+
seg_model = MediapipeSegmenter()
|
16 |
+
|
17 |
+
|
18 |
+
def extract_segment_job(img_name):
|
19 |
+
try:
|
20 |
+
img = cv2.imread(img_name)
|
21 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
22 |
+
|
23 |
+
segmap = seg_model._cal_seg_map(img)
|
24 |
+
bg_img = extract_background([img], [segmap])
|
25 |
+
out_img_name = img_name.replace("/images_512/",f"/bg_img/").replace(".mp4", ".jpg")
|
26 |
+
save_rgb_image_to_path(bg_img, out_img_name)
|
27 |
+
|
28 |
+
com_img = img.copy()
|
29 |
+
bg_part = segmap[0].astype(bool)[..., None].repeat(3,axis=-1)
|
30 |
+
com_img[bg_part] = bg_img[bg_part]
|
31 |
+
out_img_name = img_name.replace("/images_512/",f"/com_imgs/")
|
32 |
+
save_rgb_image_to_path(com_img, out_img_name)
|
33 |
+
|
34 |
+
for mode in ['head', 'torso', 'person', 'torso_with_bg', 'bg']:
|
35 |
+
out_img, _ = seg_model._seg_out_img_with_segmap(img, segmap, mode=mode)
|
36 |
+
out_img_name = img_name.replace("/images_512/",f"/{mode}_imgs/")
|
37 |
+
out_img = cv2.cvtColor(out_img, cv2.COLOR_RGB2BGR)
|
38 |
+
try: os.makedirs(os.path.dirname(out_img_name), exist_ok=True)
|
39 |
+
except: pass
|
40 |
+
cv2.imwrite(out_img_name, out_img)
|
41 |
+
|
42 |
+
inpaint_torso_img, inpaint_torso_with_bg_img, _, _ = inpaint_torso_job(img, segmap)
|
43 |
+
out_img_name = img_name.replace("/images_512/",f"/inpaint_torso_imgs/")
|
44 |
+
save_rgb_image_to_path(inpaint_torso_img, out_img_name)
|
45 |
+
inpaint_torso_with_bg_img[bg_part] = bg_img[bg_part]
|
46 |
+
out_img_name = img_name.replace("/images_512/",f"/inpaint_torso_with_com_bg_imgs/")
|
47 |
+
save_rgb_image_to_path(inpaint_torso_with_bg_img, out_img_name)
|
48 |
+
return 0
|
49 |
+
except Exception as e:
|
50 |
+
print(e)
|
51 |
+
return 1
|
52 |
+
|
53 |
+
def out_exist_job(img_name):
|
54 |
+
out_name1 = img_name.replace("/images_512/", "/head_imgs/")
|
55 |
+
out_name2 = img_name.replace("/images_512/", "/com_imgs/")
|
56 |
+
out_name3 = img_name.replace("/images_512/", "/inpaint_torso_with_com_bg_imgs/")
|
57 |
+
|
58 |
+
if os.path.exists(out_name1) and os.path.exists(out_name2) and os.path.exists(out_name3):
|
59 |
+
return None
|
60 |
+
else:
|
61 |
+
return img_name
|
62 |
+
|
63 |
+
def get_todo_img_names(img_names):
|
64 |
+
todo_img_names = []
|
65 |
+
for i, res in multiprocess_run_tqdm(out_exist_job, img_names, num_workers=64):
|
66 |
+
if res is not None:
|
67 |
+
todo_img_names.append(res)
|
68 |
+
return todo_img_names
|
69 |
+
|
70 |
+
|
71 |
+
if __name__ == '__main__':
|
72 |
+
import argparse, glob, tqdm, random
|
73 |
+
parser = argparse.ArgumentParser()
|
74 |
+
parser.add_argument("--img_dir", default='./images_512')
|
75 |
+
# parser.add_argument("--img_dir", default='/home/tiger/datasets/raw/FFHQ/images_512')
|
76 |
+
parser.add_argument("--ds_name", default='FFHQ')
|
77 |
+
parser.add_argument("--num_workers", default=1, type=int)
|
78 |
+
parser.add_argument("--seed", default=0, type=int)
|
79 |
+
parser.add_argument("--process_id", default=0, type=int)
|
80 |
+
parser.add_argument("--total_process", default=1, type=int)
|
81 |
+
parser.add_argument("--reset", action='store_true')
|
82 |
+
|
83 |
+
args = parser.parse_args()
|
84 |
+
img_dir = args.img_dir
|
85 |
+
if args.ds_name == 'FFHQ_MV':
|
86 |
+
img_name_pattern1 = os.path.join(img_dir, "ref_imgs/*.png")
|
87 |
+
img_names1 = glob.glob(img_name_pattern1)
|
88 |
+
img_name_pattern2 = os.path.join(img_dir, "mv_imgs/*.png")
|
89 |
+
img_names2 = glob.glob(img_name_pattern2)
|
90 |
+
img_names = img_names1 + img_names2
|
91 |
+
elif args.ds_name == 'FFHQ':
|
92 |
+
img_name_pattern = os.path.join(img_dir, "*.png")
|
93 |
+
img_names = glob.glob(img_name_pattern)
|
94 |
+
|
95 |
+
img_names = sorted(img_names)
|
96 |
+
random.seed(args.seed)
|
97 |
+
random.shuffle(img_names)
|
98 |
+
|
99 |
+
process_id = args.process_id
|
100 |
+
total_process = args.total_process
|
101 |
+
if total_process > 1:
|
102 |
+
assert process_id <= total_process -1
|
103 |
+
num_samples_per_process = len(img_names) // total_process
|
104 |
+
if process_id == total_process:
|
105 |
+
img_names = img_names[process_id * num_samples_per_process : ]
|
106 |
+
else:
|
107 |
+
img_names = img_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
|
108 |
+
|
109 |
+
if not args.reset:
|
110 |
+
img_names = get_todo_img_names(img_names)
|
111 |
+
print(f"todo images number: {len(img_names)}")
|
112 |
+
|
113 |
+
for vid_name in multiprocess_run_tqdm(extract_segment_job ,img_names, desc=f"Root process {args.process_id}: extracting segment images", num_workers=args.num_workers):
|
114 |
+
pass
|
data_gen/utils/process_image/fit_3dmm_landmark.py
ADDED
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from numpy.core.numeric import require
|
2 |
+
from numpy.lib.function_base import quantile
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import copy
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
import cv2
|
11 |
+
import argparse
|
12 |
+
import tqdm
|
13 |
+
from utils.commons.multiprocess_utils import multiprocess_run_tqdm
|
14 |
+
from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker
|
15 |
+
|
16 |
+
from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
|
17 |
+
import pickle
|
18 |
+
|
19 |
+
face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
|
20 |
+
camera_distance=10, focal=1015, keypoint_mode='mediapipe')
|
21 |
+
face_model.to("cuda")
|
22 |
+
|
23 |
+
|
24 |
+
index_lm68_from_lm468 = [127,234,93,132,58,136,150,176,152,400,379,365,288,361,323,454,356,70,63,105,66,107,336,296,334,293,300,168,197,5,4,75,97,2,326,305,
|
25 |
+
33,160,158,133,153,144,362,385,387,263,373,380,61,40,37,0,267,270,291,321,314,17,84,91,78,81,13,311,308,402,14,178]
|
26 |
+
|
27 |
+
dir_path = os.path.dirname(os.path.realpath(__file__))
|
28 |
+
|
29 |
+
LAMBDA_REG_ID = 0.3
|
30 |
+
LAMBDA_REG_EXP = 0.05
|
31 |
+
|
32 |
+
def save_file(name, content):
|
33 |
+
with open(name, "wb") as f:
|
34 |
+
pickle.dump(content, f)
|
35 |
+
|
36 |
+
def load_file(name):
|
37 |
+
with open(name, "rb") as f:
|
38 |
+
content = pickle.load(f)
|
39 |
+
return content
|
40 |
+
|
41 |
+
def cal_lan_loss_mp(proj_lan, gt_lan):
|
42 |
+
# [B, 68, 2]
|
43 |
+
loss = (proj_lan - gt_lan).pow(2)
|
44 |
+
# loss = (proj_lan - gt_lan).abs()
|
45 |
+
unmatch_mask = [ 93, 127, 132, 234, 323, 356, 361, 454]
|
46 |
+
eye = [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249]
|
47 |
+
inner_lip = [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95]
|
48 |
+
outer_lip = [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146]
|
49 |
+
weights = torch.ones_like(loss)
|
50 |
+
weights[:, eye] = 5
|
51 |
+
weights[:, inner_lip] = 2
|
52 |
+
weights[:, outer_lip] = 2
|
53 |
+
weights[:, unmatch_mask] = 0
|
54 |
+
loss = loss * weights
|
55 |
+
return torch.mean(loss)
|
56 |
+
|
57 |
+
def cal_lan_loss(proj_lan, gt_lan):
|
58 |
+
# [B, 68, 2]
|
59 |
+
loss = (proj_lan - gt_lan)** 2
|
60 |
+
# use the ldm weights from deep3drecon, see deep_3drecon/deep_3drecon_models/losses.py
|
61 |
+
weights = torch.zeros_like(loss)
|
62 |
+
weights = torch.ones_like(loss)
|
63 |
+
weights[:, 36:48, :] = 3 # eye 12 points
|
64 |
+
weights[:, -8:, :] = 3 # inner lip 8 points
|
65 |
+
weights[:, 28:31, :] = 3 # nose 3 points
|
66 |
+
loss = loss * weights
|
67 |
+
return torch.mean(loss)
|
68 |
+
|
69 |
+
def set_requires_grad(tensor_list):
|
70 |
+
for tensor in tensor_list:
|
71 |
+
tensor.requires_grad = True
|
72 |
+
|
73 |
+
def read_video_to_frames(img_name):
|
74 |
+
frames = []
|
75 |
+
cap = cv2.VideoCapture(img_name)
|
76 |
+
while cap.isOpened():
|
77 |
+
ret, frame_bgr = cap.read()
|
78 |
+
if frame_bgr is None:
|
79 |
+
break
|
80 |
+
frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
|
81 |
+
frames.append(frame_rgb)
|
82 |
+
return np.stack(frames)
|
83 |
+
|
84 |
+
@torch.enable_grad()
|
85 |
+
def fit_3dmm_for_a_image(img_name, debug=False, keypoint_mode='mediapipe', device="cuda:0", save=True):
|
86 |
+
img = cv2.imread(img_name)
|
87 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
88 |
+
img_h, img_w = img.shape[0], img.shape[0]
|
89 |
+
assert img_h == img_w
|
90 |
+
num_frames = 1
|
91 |
+
|
92 |
+
lm_name = img_name.replace("/images_512/", "/lms_2d/").replace(".png", "_lms.npy")
|
93 |
+
if lm_name.endswith('_lms.npy') and os.path.exists(lm_name):
|
94 |
+
lms = np.load(lm_name)
|
95 |
+
else:
|
96 |
+
# print("lms_2d file not found, try to extract it from image...")
|
97 |
+
try:
|
98 |
+
landmarker = MediapipeLandmarker()
|
99 |
+
lms = landmarker.extract_lm478_from_img_name(img_name)
|
100 |
+
# lms = landmarker.extract_lm478_from_img(img)
|
101 |
+
except Exception as e:
|
102 |
+
print(e)
|
103 |
+
return
|
104 |
+
if lms is None:
|
105 |
+
print("get None lms_2d, please check whether each frame has one head, exiting...")
|
106 |
+
return
|
107 |
+
lms = lms[:468].reshape([468,2])
|
108 |
+
lms = torch.FloatTensor(lms).to(device=device)
|
109 |
+
lms[..., 1] = img_h - lms[..., 1] # flip the height axis
|
110 |
+
|
111 |
+
if keypoint_mode == 'mediapipe':
|
112 |
+
cal_lan_loss_fn = cal_lan_loss_mp
|
113 |
+
out_name = img_name.replace("/images_512/", "/coeff_fit_mp/").replace(".png", "_coeff_fit_mp.npy")
|
114 |
+
else:
|
115 |
+
cal_lan_loss_fn = cal_lan_loss
|
116 |
+
out_name = img_name.replace("/images_512/", "/coeff_fit_lm68/").replace(".png", "_coeff_fit_lm68.npy")
|
117 |
+
try:
|
118 |
+
os.makedirs(os.path.dirname(out_name), exist_ok=True)
|
119 |
+
except:
|
120 |
+
pass
|
121 |
+
|
122 |
+
id_dim, exp_dim = 80, 64
|
123 |
+
sel_ids = np.arange(0, num_frames, 40)
|
124 |
+
sel_num = sel_ids.shape[0]
|
125 |
+
arg_focal = face_model.focal
|
126 |
+
|
127 |
+
h = w = face_model.center * 2
|
128 |
+
img_scale_factor = img_h / h
|
129 |
+
lms /= img_scale_factor
|
130 |
+
cxy = torch.tensor((w / 2.0, h / 2.0), dtype=torch.float).to(device=device)
|
131 |
+
|
132 |
+
id_para = lms.new_zeros((num_frames, id_dim), requires_grad=True) # lms.new_zeros((1, id_dim), requires_grad=True)
|
133 |
+
exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
|
134 |
+
euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True)
|
135 |
+
trans = lms.new_zeros((num_frames, 3), requires_grad=True)
|
136 |
+
|
137 |
+
focal_length = lms.new_zeros(1, requires_grad=True)
|
138 |
+
focal_length.data += arg_focal
|
139 |
+
|
140 |
+
set_requires_grad([id_para, exp_para, euler_angle, trans])
|
141 |
+
|
142 |
+
optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=.1)
|
143 |
+
optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=.1)
|
144 |
+
|
145 |
+
# 其他参数初始化,先训练euler和trans
|
146 |
+
for _ in range(200):
|
147 |
+
proj_geo = face_model.compute_for_landmark_fit(
|
148 |
+
id_para, exp_para, euler_angle, trans)
|
149 |
+
loss_lan = cal_lan_loss_fn(proj_geo[:, :, :2], lms.detach())
|
150 |
+
loss = loss_lan
|
151 |
+
optimizer_frame.zero_grad()
|
152 |
+
loss.backward()
|
153 |
+
optimizer_frame.step()
|
154 |
+
# print(f"loss_lan: {loss_lan.item():.2f}, euler_abs_mean: {euler_angle.abs().mean().item():.4f}, euler_std: {euler_angle.std().item():.4f}, euler_min: {euler_angle.min().item():.4f}, euler_max: {euler_angle.max().item():.4f}")
|
155 |
+
# print(f"trans_z_mean: {trans[...,2].mean().item():.4f}, trans_z_std: {trans[...,2].std().item():.4f}, trans_min: {trans[...,2].min().item():.4f}, trans_max: {trans[...,2].max().item():.4f}")
|
156 |
+
|
157 |
+
for param_group in optimizer_frame.param_groups:
|
158 |
+
param_group['lr'] = 0.1
|
159 |
+
|
160 |
+
# "jointly roughly training id exp euler trans"
|
161 |
+
for _ in range(200):
|
162 |
+
proj_geo = face_model.compute_for_landmark_fit(
|
163 |
+
id_para, exp_para, euler_angle, trans)
|
164 |
+
loss_lan = cal_lan_loss_fn(
|
165 |
+
proj_geo[:, :, :2], lms.detach())
|
166 |
+
loss_regid = torch.mean(id_para*id_para) # 正则化
|
167 |
+
loss_regexp = torch.mean(exp_para * exp_para)
|
168 |
+
|
169 |
+
loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP
|
170 |
+
optimizer_idexp.zero_grad()
|
171 |
+
optimizer_frame.zero_grad()
|
172 |
+
loss.backward()
|
173 |
+
optimizer_idexp.step()
|
174 |
+
optimizer_frame.step()
|
175 |
+
# print(f"loss_lan: {loss_lan.item():.2f}, loss_reg_id: {loss_regid.item():.2f},loss_reg_exp: {loss_regexp.item():.2f},")
|
176 |
+
# print(f"euler_abs_mean: {euler_angle.abs().mean().item():.4f}, euler_std: {euler_angle.std().item():.4f}, euler_min: {euler_angle.min().item():.4f}, euler_max: {euler_angle.max().item():.4f}")
|
177 |
+
# print(f"trans_z_mean: {trans[...,2].mean().item():.4f}, trans_z_std: {trans[...,2].std().item():.4f}, trans_min: {trans[...,2].min().item():.4f}, trans_max: {trans[...,2].max().item():.4f}")
|
178 |
+
|
179 |
+
# start fine training, intialize from the roughly trained results
|
180 |
+
id_para_ = lms.new_zeros((num_frames, id_dim), requires_grad=True)
|
181 |
+
id_para_.data = id_para.data.clone()
|
182 |
+
id_para = id_para_
|
183 |
+
exp_para_ = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
|
184 |
+
exp_para_.data = exp_para.data.clone()
|
185 |
+
exp_para = exp_para_
|
186 |
+
euler_angle_ = lms.new_zeros((num_frames, 3), requires_grad=True)
|
187 |
+
euler_angle_.data = euler_angle.data.clone()
|
188 |
+
euler_angle = euler_angle_
|
189 |
+
trans_ = lms.new_zeros((num_frames, 3), requires_grad=True)
|
190 |
+
trans_.data = trans.data.clone()
|
191 |
+
trans = trans_
|
192 |
+
|
193 |
+
batch_size = 1
|
194 |
+
|
195 |
+
# "fine fitting the 3DMM in batches"
|
196 |
+
for i in range(int((num_frames-1)/batch_size+1)):
|
197 |
+
if (i+1)*batch_size > num_frames:
|
198 |
+
start_n = num_frames-batch_size
|
199 |
+
sel_ids = np.arange(max(num_frames-batch_size,0), num_frames)
|
200 |
+
else:
|
201 |
+
start_n = i*batch_size
|
202 |
+
sel_ids = np.arange(i*batch_size, i*batch_size+batch_size)
|
203 |
+
sel_lms = lms[sel_ids]
|
204 |
+
|
205 |
+
sel_id_para = id_para.new_zeros(
|
206 |
+
(batch_size, id_dim), requires_grad=True)
|
207 |
+
sel_id_para.data = id_para[sel_ids].clone()
|
208 |
+
sel_exp_para = exp_para.new_zeros(
|
209 |
+
(batch_size, exp_dim), requires_grad=True)
|
210 |
+
sel_exp_para.data = exp_para[sel_ids].clone()
|
211 |
+
sel_euler_angle = euler_angle.new_zeros(
|
212 |
+
(batch_size, 3), requires_grad=True)
|
213 |
+
sel_euler_angle.data = euler_angle[sel_ids].clone()
|
214 |
+
sel_trans = trans.new_zeros((batch_size, 3), requires_grad=True)
|
215 |
+
sel_trans.data = trans[sel_ids].clone()
|
216 |
+
|
217 |
+
set_requires_grad([sel_id_para, sel_exp_para, sel_euler_angle, sel_trans])
|
218 |
+
optimizer_cur_batch = torch.optim.Adam(
|
219 |
+
[sel_id_para, sel_exp_para, sel_euler_angle, sel_trans], lr=0.005)
|
220 |
+
|
221 |
+
for j in range(50):
|
222 |
+
proj_geo = face_model.compute_for_landmark_fit(
|
223 |
+
sel_id_para, sel_exp_para, sel_euler_angle, sel_trans)
|
224 |
+
loss_lan = cal_lan_loss_fn(
|
225 |
+
proj_geo[:, :, :2], lms.unsqueeze(0).detach())
|
226 |
+
|
227 |
+
loss_regid = torch.mean(sel_id_para*sel_id_para) # 正则化
|
228 |
+
loss_regexp = torch.mean(sel_exp_para*sel_exp_para)
|
229 |
+
loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP
|
230 |
+
optimizer_cur_batch.zero_grad()
|
231 |
+
loss.backward()
|
232 |
+
optimizer_cur_batch.step()
|
233 |
+
print(f"batch {i} | loss_lan: {loss_lan.item():.2f}, loss_reg_id: {loss_regid.item():.2f},loss_reg_exp: {loss_regexp.item():.2f}")
|
234 |
+
id_para[sel_ids].data = sel_id_para.data.clone()
|
235 |
+
exp_para[sel_ids].data = sel_exp_para.data.clone()
|
236 |
+
euler_angle[sel_ids].data = sel_euler_angle.data.clone()
|
237 |
+
trans[sel_ids].data = sel_trans.data.clone()
|
238 |
+
|
239 |
+
coeff_dict = {'id': id_para.detach().cpu().numpy(), 'exp': exp_para.detach().cpu().numpy(),
|
240 |
+
'euler': euler_angle.detach().cpu().numpy(), 'trans': trans.detach().cpu().numpy()}
|
241 |
+
if save:
|
242 |
+
np.save(out_name, coeff_dict, allow_pickle=True)
|
243 |
+
|
244 |
+
if debug:
|
245 |
+
import imageio
|
246 |
+
debug_name = img_name.replace("/images_512/", "/coeff_fit_mp_debug/").replace(".png", "_debug.png").replace(".jpg", "_debug.jpg")
|
247 |
+
try: os.makedirs(os.path.dirname(debug_name), exist_ok=True)
|
248 |
+
except: pass
|
249 |
+
proj_geo = face_model.compute_for_landmark_fit(id_para, exp_para, euler_angle, trans)
|
250 |
+
lm68s = proj_geo[:,:,:2].detach().cpu().numpy() # [T, 68,2]
|
251 |
+
lm68s = lm68s * img_scale_factor
|
252 |
+
lms = lms * img_scale_factor
|
253 |
+
lm68s[..., 1] = img_h - lm68s[..., 1] # flip the height axis
|
254 |
+
lms[..., 1] = img_h - lms[..., 1] # flip the height axis
|
255 |
+
lm68s = lm68s.astype(int)
|
256 |
+
lm68s = lm68s.reshape([-1,2])
|
257 |
+
lms = lms.cpu().numpy().astype(int).reshape([-1,2])
|
258 |
+
for lm in lm68s:
|
259 |
+
img = cv2.circle(img, lm, 1, (0, 0, 255), thickness=-1)
|
260 |
+
for gt_lm in lms:
|
261 |
+
img = cv2.circle(img, gt_lm, 2, (255, 0, 0), thickness=1)
|
262 |
+
imageio.imwrite(debug_name, img)
|
263 |
+
print(f"debug img saved at {debug_name}")
|
264 |
+
return coeff_dict
|
265 |
+
|
266 |
+
def out_exist_job(vid_name):
|
267 |
+
out_name = vid_name.replace("/images_512/", "/coeff_fit_mp/").replace(".png","_coeff_fit_mp.npy")
|
268 |
+
# if os.path.exists(out_name) or not os.path.exists(lms_name):
|
269 |
+
if os.path.exists(out_name):
|
270 |
+
return None
|
271 |
+
else:
|
272 |
+
return vid_name
|
273 |
+
|
274 |
+
def get_todo_img_names(img_names):
|
275 |
+
todo_img_names = []
|
276 |
+
for i, res in multiprocess_run_tqdm(out_exist_job, img_names, num_workers=16):
|
277 |
+
if res is not None:
|
278 |
+
todo_img_names.append(res)
|
279 |
+
return todo_img_names
|
280 |
+
|
281 |
+
|
282 |
+
if __name__ == '__main__':
|
283 |
+
import argparse, glob, tqdm
|
284 |
+
parser = argparse.ArgumentParser()
|
285 |
+
parser.add_argument("--img_dir", default='/home/tiger/datasets/raw/FFHQ/images_512')
|
286 |
+
parser.add_argument("--ds_name", default='FFHQ')
|
287 |
+
parser.add_argument("--seed", default=0, type=int)
|
288 |
+
parser.add_argument("--process_id", default=0, type=int)
|
289 |
+
parser.add_argument("--total_process", default=1, type=int)
|
290 |
+
parser.add_argument("--keypoint_mode", default='mediapipe', type=str)
|
291 |
+
parser.add_argument("--debug", action='store_true')
|
292 |
+
parser.add_argument("--reset", action='store_true')
|
293 |
+
parser.add_argument("--device", default="cuda:0", type=str)
|
294 |
+
parser.add_argument("--output_log", action='store_true')
|
295 |
+
parser.add_argument("--load_names", action="store_true")
|
296 |
+
|
297 |
+
args = parser.parse_args()
|
298 |
+
img_dir = args.img_dir
|
299 |
+
load_names = args.load_names
|
300 |
+
|
301 |
+
print(f"args {args}")
|
302 |
+
|
303 |
+
if args.ds_name == 'single_img':
|
304 |
+
img_names = [img_dir]
|
305 |
+
else:
|
306 |
+
img_names_path = os.path.join(img_dir, "img_dir.pkl")
|
307 |
+
if os.path.exists(img_names_path) and load_names:
|
308 |
+
print(f"loading vid names from {img_names_path}")
|
309 |
+
img_names = load_file(img_names_path)
|
310 |
+
else:
|
311 |
+
if args.ds_name == 'FFHQ_MV':
|
312 |
+
img_name_pattern1 = os.path.join(img_dir, "ref_imgs/*.png")
|
313 |
+
img_names1 = glob.glob(img_name_pattern1)
|
314 |
+
img_name_pattern2 = os.path.join(img_dir, "mv_imgs/*.png")
|
315 |
+
img_names2 = glob.glob(img_name_pattern2)
|
316 |
+
img_names = img_names1 + img_names2
|
317 |
+
img_names = sorted(img_names)
|
318 |
+
elif args.ds_name == 'FFHQ':
|
319 |
+
img_name_pattern = os.path.join(img_dir, "*.png")
|
320 |
+
img_names = glob.glob(img_name_pattern)
|
321 |
+
img_names = sorted(img_names)
|
322 |
+
elif args.ds_name == "PanoHeadGen":
|
323 |
+
img_name_patterns = ["ref/*/*.png"]
|
324 |
+
img_names = []
|
325 |
+
for img_name_pattern in img_name_patterns:
|
326 |
+
img_name_pattern_full = os.path.join(img_dir, img_name_pattern)
|
327 |
+
img_names_part = glob.glob(img_name_pattern_full)
|
328 |
+
img_names.extend(img_names_part)
|
329 |
+
img_names = sorted(img_names)
|
330 |
+
print(f"saving image names to {img_names_path}")
|
331 |
+
save_file(img_names_path, img_names)
|
332 |
+
|
333 |
+
# import random
|
334 |
+
# random.seed(args.seed)
|
335 |
+
# random.shuffle(img_names)
|
336 |
+
|
337 |
+
face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
|
338 |
+
camera_distance=10, focal=1015, keypoint_mode=args.keypoint_mode)
|
339 |
+
face_model.to(torch.device(args.device))
|
340 |
+
|
341 |
+
process_id = args.process_id
|
342 |
+
total_process = args.total_process
|
343 |
+
if total_process > 1:
|
344 |
+
assert process_id <= total_process -1 and process_id >= 0
|
345 |
+
num_samples_per_process = len(img_names) // total_process
|
346 |
+
if process_id == total_process:
|
347 |
+
img_names = img_names[process_id * num_samples_per_process : ]
|
348 |
+
else:
|
349 |
+
img_names = img_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
|
350 |
+
print(f"image names number (before fileter): {len(img_names)}")
|
351 |
+
|
352 |
+
|
353 |
+
if not args.reset:
|
354 |
+
img_names = get_todo_img_names(img_names)
|
355 |
+
|
356 |
+
print(f"image names number (after fileter): {len(img_names)}")
|
357 |
+
for i in tqdm.trange(len(img_names), desc=f"process {process_id}: fitting 3dmm ..."):
|
358 |
+
img_name = img_names[i]
|
359 |
+
try:
|
360 |
+
fit_3dmm_for_a_image(img_name, args.debug, device=args.device)
|
361 |
+
except Exception as e:
|
362 |
+
print(img_name, e)
|
363 |
+
if args.output_log and i % max(int(len(img_names) * 0.003), 1) == 0:
|
364 |
+
print(f"process {process_id}: {i + 1} / {len(img_names)} done")
|
365 |
+
sys.stdout.flush()
|
366 |
+
sys.stderr.flush()
|
367 |
+
|
368 |
+
print(f"process {process_id}: fitting 3dmm all done")
|
369 |
+
|
data_gen/utils/process_video/euler2quaterion.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import math
|
4 |
+
import numba
|
5 |
+
from scipy.spatial.transform import Rotation as R
|
6 |
+
|
7 |
+
def euler2quaterion(euler, use_radian=True):
|
8 |
+
"""
|
9 |
+
euler: np.array, [batch, 3]
|
10 |
+
return: the quaterion, np.array, [batch, 4]
|
11 |
+
"""
|
12 |
+
r = R.from_euler('xyz',euler, degrees=not use_radian)
|
13 |
+
return r.as_quat()
|
14 |
+
|
15 |
+
def quaterion2euler(quat, use_radian=True):
|
16 |
+
"""
|
17 |
+
quat: np.array, [batch, 4]
|
18 |
+
return: the euler, np.array, [batch, 3]
|
19 |
+
"""
|
20 |
+
r = R.from_quat(quat)
|
21 |
+
return r.as_euler('xyz', degrees=not use_radian)
|
22 |
+
|
23 |
+
def rot2quaterion(rot):
|
24 |
+
r = R.from_matrix(rot)
|
25 |
+
return r.as_quat()
|
26 |
+
|
27 |
+
def quaterion2rot(quat):
|
28 |
+
r = R.from_quat(quat)
|
29 |
+
return r.as_matrix()
|
30 |
+
|
31 |
+
if __name__ == '__main__':
|
32 |
+
euler = np.array([89.999,89.999,89.999] * 100).reshape([100,3])
|
33 |
+
q = euler2quaterion(euler, use_radian=False)
|
34 |
+
e = quaterion2euler(q, use_radian=False)
|
35 |
+
print(" ")
|
data_gen/utils/process_video/extract_blink.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from data_util.face3d_helper import Face3DHelper
|
3 |
+
from utils.commons.tensor_utils import convert_to_tensor
|
4 |
+
|
5 |
+
def polygon_area(x, y):
|
6 |
+
"""
|
7 |
+
x: [T, K=6]
|
8 |
+
y: [T, K=6]
|
9 |
+
return: [T,]
|
10 |
+
"""
|
11 |
+
x_ = x - x.mean(axis=-1, keepdims=True)
|
12 |
+
y_ = y - y.mean(axis=-1, keepdims=True)
|
13 |
+
correction = x_[:,-1] * y_[:,0] - y_[:,-1]* x_[:,0]
|
14 |
+
main_area = (x_[:,:-1] * y_[:,1:]).sum(axis=-1) - (y_[:,:-1] * x_[:,1:]).sum(axis=-1)
|
15 |
+
return 0.5 * np.abs(main_area + correction)
|
16 |
+
|
17 |
+
def get_eye_area_percent(id, exp, face3d_helper):
|
18 |
+
id = convert_to_tensor(id)
|
19 |
+
exp = convert_to_tensor(exp)
|
20 |
+
cano_lm3d = face3d_helper.reconstruct_cano_lm3d(id, exp)
|
21 |
+
cano_lm2d = (cano_lm3d[..., :2] + 1) / 2
|
22 |
+
lms = cano_lm2d.cpu().numpy()
|
23 |
+
eyes_left = slice(36, 42)
|
24 |
+
eyes_right = slice(42, 48)
|
25 |
+
area_left = polygon_area(lms[:, eyes_left, 0], lms[:, eyes_left, 1])
|
26 |
+
area_right = polygon_area(lms[:, eyes_right, 0], lms[:, eyes_right, 1])
|
27 |
+
# area percentage of two eyes of the whole image...
|
28 |
+
area_percent = (area_left + area_right) / 1 * 100 # recommend threshold is 0.25%
|
29 |
+
return area_percent # [T,]
|
30 |
+
|
31 |
+
|
32 |
+
if __name__ == '__main__':
|
33 |
+
import numpy as np
|
34 |
+
import imageio
|
35 |
+
import cv2
|
36 |
+
import torch
|
37 |
+
from data_gen.utils.process_video.extract_lm2d import extract_lms_mediapipe_job, read_video_to_frames, index_lm68_from_lm468
|
38 |
+
from data_gen.utils.process_video.fit_3dmm_landmark import fit_3dmm_for_a_video
|
39 |
+
from data_util.face3d_helper import Face3DHelper
|
40 |
+
|
41 |
+
face3d_helper = Face3DHelper()
|
42 |
+
video_name = 'data/raw/videos/May_10s.mp4'
|
43 |
+
frames = read_video_to_frames(video_name)
|
44 |
+
coeff = fit_3dmm_for_a_video(video_name, save=False)
|
45 |
+
area_percent = get_eye_area_percent(torch.tensor(coeff['id']), torch.tensor(coeff['exp']), face3d_helper)
|
46 |
+
writer = imageio.get_writer("1.mp4", fps=25)
|
47 |
+
for idx, frame in enumerate(frames):
|
48 |
+
frame = cv2.putText(frame, f"{area_percent[idx]:.2f}", org=(128,128), fontFace=cv2.FONT_HERSHEY_COMPLEX, fontScale=1, color=(255,0,0), thickness=1)
|
49 |
+
writer.append_data(frame)
|
50 |
+
writer.close()
|
data_gen/utils/process_video/extract_lm2d.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
3 |
+
import sys
|
4 |
+
import glob
|
5 |
+
import cv2
|
6 |
+
import pickle
|
7 |
+
import tqdm
|
8 |
+
import numpy as np
|
9 |
+
import mediapipe as mp
|
10 |
+
from utils.commons.multiprocess_utils import multiprocess_run_tqdm
|
11 |
+
from utils.commons.os_utils import multiprocess_glob
|
12 |
+
from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker
|
13 |
+
import warnings
|
14 |
+
import traceback
|
15 |
+
|
16 |
+
warnings.filterwarnings('ignore')
|
17 |
+
|
18 |
+
"""
|
19 |
+
基于Face_aligment的lm68已被弃用,因为其:
|
20 |
+
1. 对眼睛部位的预测精度极低
|
21 |
+
2. 无法在大偏转角度时准确预测被遮挡的下颚线, 导致大角度时3dmm的GT label就是有问题的, 从而影响性能
|
22 |
+
我们目前转而使用基于mediapipe的lm68
|
23 |
+
"""
|
24 |
+
# def extract_landmarks(ori_imgs_dir):
|
25 |
+
|
26 |
+
# print(f'[INFO] ===== extract face landmarks from {ori_imgs_dir} =====')
|
27 |
+
|
28 |
+
# fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)
|
29 |
+
# image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.png'))
|
30 |
+
# for image_path in tqdm.tqdm(image_paths):
|
31 |
+
# out_name = image_path.replace("/images_512/", "/lms_2d/").replace(".png",".lms")
|
32 |
+
# if os.path.exists(out_name):
|
33 |
+
# continue
|
34 |
+
# input = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3]
|
35 |
+
# input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB)
|
36 |
+
# preds = fa.get_landmarks(input)
|
37 |
+
# if preds is None:
|
38 |
+
# print(f"Skip {image_path} for no face detected")
|
39 |
+
# continue
|
40 |
+
# if len(preds) > 0:
|
41 |
+
# lands = preds[0].reshape(-1, 2)[:,:2]
|
42 |
+
# os.makedirs(os.path.dirname(out_name), exist_ok=True)
|
43 |
+
# np.savetxt(out_name, lands, '%f')
|
44 |
+
# del fa
|
45 |
+
# print(f'[INFO] ===== extracted face landmarks =====')
|
46 |
+
|
47 |
+
def save_file(name, content):
|
48 |
+
with open(name, "wb") as f:
|
49 |
+
pickle.dump(content, f)
|
50 |
+
|
51 |
+
def load_file(name):
|
52 |
+
with open(name, "rb") as f:
|
53 |
+
content = pickle.load(f)
|
54 |
+
return content
|
55 |
+
|
56 |
+
|
57 |
+
face_landmarker = None
|
58 |
+
|
59 |
+
def extract_landmark_job(video_name, nerf=False):
|
60 |
+
try:
|
61 |
+
if nerf:
|
62 |
+
out_name = video_name.replace("/raw/", "/processed/").replace(".mp4","/lms_2d.npy")
|
63 |
+
else:
|
64 |
+
out_name = video_name.replace("/video/", "/lms_2d/").replace(".mp4","_lms.npy")
|
65 |
+
if os.path.exists(out_name):
|
66 |
+
# print("out exists, skip...")
|
67 |
+
return
|
68 |
+
try:
|
69 |
+
os.makedirs(os.path.dirname(out_name), exist_ok=True)
|
70 |
+
except:
|
71 |
+
pass
|
72 |
+
global face_landmarker
|
73 |
+
if face_landmarker is None:
|
74 |
+
face_landmarker = MediapipeLandmarker()
|
75 |
+
img_lm478, vid_lm478 = face_landmarker.extract_lm478_from_video_name(video_name)
|
76 |
+
lm478 = face_landmarker.combine_vid_img_lm478_to_lm478(img_lm478, vid_lm478)
|
77 |
+
np.save(out_name, lm478)
|
78 |
+
return True
|
79 |
+
# print("Hahaha, solve one item!!!")
|
80 |
+
except Exception as e:
|
81 |
+
traceback.print_exc()
|
82 |
+
return False
|
83 |
+
|
84 |
+
def out_exist_job(vid_name):
|
85 |
+
out_name = vid_name.replace("/video/", "/lms_2d/").replace(".mp4","_lms.npy")
|
86 |
+
if os.path.exists(out_name):
|
87 |
+
return None
|
88 |
+
else:
|
89 |
+
return vid_name
|
90 |
+
|
91 |
+
def get_todo_vid_names(vid_names):
|
92 |
+
if len(vid_names) == 1: # nerf
|
93 |
+
return vid_names
|
94 |
+
todo_vid_names = []
|
95 |
+
for i, res in multiprocess_run_tqdm(out_exist_job, vid_names, num_workers=128):
|
96 |
+
if res is not None:
|
97 |
+
todo_vid_names.append(res)
|
98 |
+
return todo_vid_names
|
99 |
+
|
100 |
+
if __name__ == '__main__':
|
101 |
+
import argparse, glob, tqdm, random
|
102 |
+
parser = argparse.ArgumentParser()
|
103 |
+
parser.add_argument("--vid_dir", default='nerf')
|
104 |
+
parser.add_argument("--ds_name", default='data/raw/videos/May.mp4')
|
105 |
+
parser.add_argument("--num_workers", default=2, type=int)
|
106 |
+
parser.add_argument("--process_id", default=0, type=int)
|
107 |
+
parser.add_argument("--total_process", default=1, type=int)
|
108 |
+
parser.add_argument("--reset", action="store_true")
|
109 |
+
parser.add_argument("--load_names", action="store_true")
|
110 |
+
|
111 |
+
args = parser.parse_args()
|
112 |
+
vid_dir = args.vid_dir
|
113 |
+
ds_name = args.ds_name
|
114 |
+
load_names = args.load_names
|
115 |
+
|
116 |
+
if ds_name.lower() == 'nerf': # 处理单个视频
|
117 |
+
vid_names = [vid_dir]
|
118 |
+
out_names = [video_name.replace("/raw/", "/processed/").replace(".mp4","/lms_2d.npy") for video_name in vid_names]
|
119 |
+
else: # 处理整个数据集
|
120 |
+
if ds_name in ['lrs3_trainval']:
|
121 |
+
vid_name_pattern = os.path.join(vid_dir, "*/*.mp4")
|
122 |
+
elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
|
123 |
+
vid_name_pattern = os.path.join(vid_dir, "*.mp4")
|
124 |
+
elif ds_name in ['lrs2', 'lrs3', 'voxceleb2', 'CMLR']:
|
125 |
+
vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
|
126 |
+
elif ds_name in ["RAVDESS", 'VFHQ']:
|
127 |
+
vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
|
128 |
+
else:
|
129 |
+
raise NotImplementedError()
|
130 |
+
|
131 |
+
vid_names_path = os.path.join(vid_dir, "vid_names.pkl")
|
132 |
+
if os.path.exists(vid_names_path) and load_names:
|
133 |
+
print(f"loading vid names from {vid_names_path}")
|
134 |
+
vid_names = load_file(vid_names_path)
|
135 |
+
else:
|
136 |
+
vid_names = multiprocess_glob(vid_name_pattern)
|
137 |
+
vid_names = sorted(vid_names)
|
138 |
+
if not load_names:
|
139 |
+
print(f"saving vid names to {vid_names_path}")
|
140 |
+
save_file(vid_names_path, vid_names)
|
141 |
+
out_names = [video_name.replace("/video/", "/lms_2d/").replace(".mp4","_lms.npy") for video_name in vid_names]
|
142 |
+
|
143 |
+
process_id = args.process_id
|
144 |
+
total_process = args.total_process
|
145 |
+
if total_process > 1:
|
146 |
+
assert process_id <= total_process -1
|
147 |
+
num_samples_per_process = len(vid_names) // total_process
|
148 |
+
if process_id == total_process:
|
149 |
+
vid_names = vid_names[process_id * num_samples_per_process : ]
|
150 |
+
else:
|
151 |
+
vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
|
152 |
+
|
153 |
+
if not args.reset:
|
154 |
+
vid_names = get_todo_vid_names(vid_names)
|
155 |
+
print(f"todo videos number: {len(vid_names)}")
|
156 |
+
|
157 |
+
fail_cnt = 0
|
158 |
+
job_args = [(vid_name, ds_name=='nerf') for vid_name in vid_names]
|
159 |
+
for (i, res) in multiprocess_run_tqdm(extract_landmark_job, job_args, num_workers=args.num_workers, desc=f"Root {args.process_id}: extracing MP-based landmark2d"):
|
160 |
+
if res is False:
|
161 |
+
fail_cnt += 1
|
162 |
+
print(f"finished {i + 1} / {len(vid_names)} = {(i + 1) / len(vid_names):.4f}, failed {fail_cnt} / {i + 1} = {fail_cnt / (i + 1):.4f}")
|
163 |
+
sys.stdout.flush()
|
164 |
+
pass
|
data_gen/utils/process_video/extract_segment_imgs.py
ADDED
@@ -0,0 +1,494 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
3 |
+
import random
|
4 |
+
import glob
|
5 |
+
import cv2
|
6 |
+
import tqdm
|
7 |
+
import numpy as np
|
8 |
+
from typing import Union
|
9 |
+
from utils.commons.tensor_utils import convert_to_np
|
10 |
+
from utils.commons.os_utils import multiprocess_glob
|
11 |
+
import pickle
|
12 |
+
import traceback
|
13 |
+
import multiprocessing
|
14 |
+
from utils.commons.multiprocess_utils import multiprocess_run_tqdm
|
15 |
+
from scipy.ndimage import binary_erosion, binary_dilation
|
16 |
+
from sklearn.neighbors import NearestNeighbors
|
17 |
+
from mediapipe.tasks.python import vision
|
18 |
+
from data_gen.utils.mp_feature_extractors.mp_segmenter import MediapipeSegmenter, encode_segmap_mask_to_image, decode_segmap_mask_from_image, job_cal_seg_map_for_image
|
19 |
+
|
20 |
+
seg_model = None
|
21 |
+
segmenter = None
|
22 |
+
mat_model = None
|
23 |
+
lama_model = None
|
24 |
+
lama_config = None
|
25 |
+
|
26 |
+
from data_gen.utils.process_video.split_video_to_imgs import extract_img_job
|
27 |
+
|
28 |
+
BG_NAME_MAP = {
|
29 |
+
"knn": "",
|
30 |
+
}
|
31 |
+
FRAME_SELECT_INTERVAL = 5
|
32 |
+
SIM_METHOD = "mse"
|
33 |
+
SIM_THRESHOLD = 3
|
34 |
+
|
35 |
+
def save_file(name, content):
|
36 |
+
with open(name, "wb") as f:
|
37 |
+
pickle.dump(content, f)
|
38 |
+
|
39 |
+
def load_file(name):
|
40 |
+
with open(name, "rb") as f:
|
41 |
+
content = pickle.load(f)
|
42 |
+
return content
|
43 |
+
|
44 |
+
def save_rgb_alpha_image_to_path(img, alpha, img_path):
|
45 |
+
try: os.makedirs(os.path.dirname(img_path), exist_ok=True)
|
46 |
+
except: pass
|
47 |
+
cv2.imwrite(img_path, np.concatenate([cv2.cvtColor(img, cv2.COLOR_RGB2BGR), alpha], axis=-1))
|
48 |
+
|
49 |
+
def save_rgb_image_to_path(img, img_path):
|
50 |
+
try: os.makedirs(os.path.dirname(img_path), exist_ok=True)
|
51 |
+
except: pass
|
52 |
+
cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
|
53 |
+
|
54 |
+
def load_rgb_image_to_path(img_path):
|
55 |
+
return cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
|
56 |
+
|
57 |
+
def image_similarity(x: np.ndarray, y: np.ndarray, method="mse"):
|
58 |
+
if method == "mse":
|
59 |
+
return np.mean((x - y) ** 2)
|
60 |
+
else:
|
61 |
+
raise NotImplementedError
|
62 |
+
|
63 |
+
def extract_background(img_lst, segmap_mask_lst=None, method="knn", device='cpu', mix_bg=True):
|
64 |
+
"""
|
65 |
+
img_lst: list of rgb ndarray
|
66 |
+
method: "knn"
|
67 |
+
"""
|
68 |
+
global segmenter
|
69 |
+
global seg_model
|
70 |
+
global mat_model
|
71 |
+
global lama_model
|
72 |
+
global lama_config
|
73 |
+
|
74 |
+
assert len(img_lst) > 0
|
75 |
+
if segmap_mask_lst is not None:
|
76 |
+
assert len(segmap_mask_lst) == len(img_lst)
|
77 |
+
else:
|
78 |
+
del segmenter
|
79 |
+
del seg_model
|
80 |
+
seg_model = MediapipeSegmenter()
|
81 |
+
segmenter = vision.ImageSegmenter.create_from_options(seg_model.video_options)
|
82 |
+
|
83 |
+
def get_segmap_mask(img_lst, segmap_mask_lst, index):
|
84 |
+
if segmap_mask_lst is not None:
|
85 |
+
segmap = refresh_segment_mask(segmap_mask_lst[index])
|
86 |
+
else:
|
87 |
+
segmap = seg_model._cal_seg_map(refresh_image(img_lst[index]), segmenter=segmenter)
|
88 |
+
return segmap
|
89 |
+
|
90 |
+
if method == "knn":
|
91 |
+
num_frames = len(img_lst)
|
92 |
+
if num_frames < 100:
|
93 |
+
FRAME_SELECT_INTERVAL = 5
|
94 |
+
elif num_frames < 10000:
|
95 |
+
FRAME_SELECT_INTERVAL = 20
|
96 |
+
else:
|
97 |
+
FRAME_SELECT_INTERVAL = num_frames // 500
|
98 |
+
|
99 |
+
img_lst = img_lst[::FRAME_SELECT_INTERVAL] if num_frames > FRAME_SELECT_INTERVAL else img_lst[0:1]
|
100 |
+
|
101 |
+
if segmap_mask_lst is not None:
|
102 |
+
segmap_mask_lst = segmap_mask_lst[::FRAME_SELECT_INTERVAL] if num_frames > FRAME_SELECT_INTERVAL else segmap_mask_lst[0:1]
|
103 |
+
assert len(img_lst) == len(segmap_mask_lst)
|
104 |
+
# get H/W
|
105 |
+
h, w = refresh_image(img_lst[0]).shape[:2]
|
106 |
+
|
107 |
+
# nearest neighbors
|
108 |
+
all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose() # [512*512, 2] coordinate grid
|
109 |
+
distss = []
|
110 |
+
for idx, img in tqdm.tqdm(enumerate(img_lst), desc='combining backgrounds...', total=len(img_lst)):
|
111 |
+
segmap = get_segmap_mask(img_lst=img_lst, segmap_mask_lst=segmap_mask_lst, index=idx)
|
112 |
+
bg = (segmap[0]).astype(bool) # [h,w] bool mask
|
113 |
+
fg_xys = np.stack(np.nonzero(~bg)).transpose(1, 0) # [N_nonbg,2] coordinate of non-bg pixels
|
114 |
+
nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
|
115 |
+
dists, _ = nbrs.kneighbors(all_xys) # [512*512, 1] distance to nearest non-bg pixel
|
116 |
+
distss.append(dists)
|
117 |
+
|
118 |
+
distss = np.stack(distss) # [B, 512*512, 1]
|
119 |
+
max_dist = np.max(distss, 0) # [512*512, 1]
|
120 |
+
max_id = np.argmax(distss, 0) # id of frame
|
121 |
+
|
122 |
+
bc_pixs = max_dist > 10 # 在各个frame有一个出现过是bg的pixel,bg标准是离最近的non-bg pixel距离大于10
|
123 |
+
bc_pixs_id = np.nonzero(bc_pixs)
|
124 |
+
bc_ids = max_id[bc_pixs]
|
125 |
+
|
126 |
+
# TODO: maybe we should reimplement here to avoid memory costs?
|
127 |
+
# though there is upper limits of images here
|
128 |
+
num_pixs = distss.shape[1]
|
129 |
+
bg_img = np.zeros((h*w, 3), dtype=np.uint8)
|
130 |
+
img_lst = [refresh_image(img) for img in img_lst]
|
131 |
+
imgs = np.stack(img_lst).reshape(-1, num_pixs, 3)
|
132 |
+
bg_img[bc_pixs_id, :] = imgs[bc_ids, bc_pixs_id, :] # 对那些铁bg的pixel,直接去对应的image里面采样
|
133 |
+
bg_img = bg_img.reshape(h, w, 3)
|
134 |
+
|
135 |
+
max_dist = max_dist.reshape(h, w)
|
136 |
+
bc_pixs = max_dist > 10 # 5
|
137 |
+
bg_xys = np.stack(np.nonzero(~bc_pixs)).transpose()
|
138 |
+
fg_xys = np.stack(np.nonzero(bc_pixs)).transpose()
|
139 |
+
nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
|
140 |
+
distances, indices = nbrs.kneighbors(bg_xys) # 对non-bg img,用KNN找最近的bg pixel
|
141 |
+
bg_fg_xys = fg_xys[indices[:, 0]]
|
142 |
+
bg_img[bg_xys[:, 0], bg_xys[:, 1], :] = bg_img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :]
|
143 |
+
else:
|
144 |
+
raise NotImplementedError # deperated
|
145 |
+
|
146 |
+
return bg_img
|
147 |
+
|
148 |
+
def inpaint_torso_job(gt_img, segmap):
|
149 |
+
bg_part = (segmap[0]).astype(bool)
|
150 |
+
head_part = (segmap[1] + segmap[3] + segmap[5]).astype(bool)
|
151 |
+
neck_part = (segmap[2]).astype(bool)
|
152 |
+
torso_part = (segmap[4]).astype(bool)
|
153 |
+
img = gt_img.copy()
|
154 |
+
img[head_part] = 0
|
155 |
+
|
156 |
+
# torso part "vertical" in-painting...
|
157 |
+
L = 8 + 1
|
158 |
+
torso_coords = np.stack(np.nonzero(torso_part), axis=-1) # [M, 2]
|
159 |
+
# lexsort: sort 2D coords first by y then by x,
|
160 |
+
# ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
|
161 |
+
inds = np.lexsort((torso_coords[:, 0], torso_coords[:, 1]))
|
162 |
+
torso_coords = torso_coords[inds]
|
163 |
+
# choose the top pixel for each column
|
164 |
+
u, uid, ucnt = np.unique(torso_coords[:, 1], return_index=True, return_counts=True)
|
165 |
+
top_torso_coords = torso_coords[uid] # [m, 2]
|
166 |
+
# only keep top-is-head pixels
|
167 |
+
top_torso_coords_up = top_torso_coords.copy() - np.array([1, 0]) # [N, 2]
|
168 |
+
mask = head_part[tuple(top_torso_coords_up.T)]
|
169 |
+
if mask.any():
|
170 |
+
top_torso_coords = top_torso_coords[mask]
|
171 |
+
# get the color
|
172 |
+
top_torso_colors = gt_img[tuple(top_torso_coords.T)] # [m, 3]
|
173 |
+
# construct inpaint coords (vertically up, or minus in x)
|
174 |
+
inpaint_torso_coords = top_torso_coords[None].repeat(L, 0) # [L, m, 2]
|
175 |
+
inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
|
176 |
+
inpaint_torso_coords += inpaint_offsets
|
177 |
+
inpaint_torso_coords = inpaint_torso_coords.reshape(-1, 2) # [Lm, 2]
|
178 |
+
inpaint_torso_colors = top_torso_colors[None].repeat(L, 0) # [L, m, 3]
|
179 |
+
darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
|
180 |
+
inpaint_torso_colors = (inpaint_torso_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
|
181 |
+
# set color
|
182 |
+
img[tuple(inpaint_torso_coords.T)] = inpaint_torso_colors
|
183 |
+
inpaint_torso_mask = np.zeros_like(img[..., 0]).astype(bool)
|
184 |
+
inpaint_torso_mask[tuple(inpaint_torso_coords.T)] = True
|
185 |
+
else:
|
186 |
+
inpaint_torso_mask = None
|
187 |
+
|
188 |
+
# neck part "vertical" in-painting...
|
189 |
+
push_down = 4
|
190 |
+
L = 48 + push_down + 1
|
191 |
+
neck_part = binary_dilation(neck_part, structure=np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=bool), iterations=3)
|
192 |
+
neck_coords = np.stack(np.nonzero(neck_part), axis=-1) # [M, 2]
|
193 |
+
# lexsort: sort 2D coords first by y then by x,
|
194 |
+
# ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
|
195 |
+
inds = np.lexsort((neck_coords[:, 0], neck_coords[:, 1]))
|
196 |
+
neck_coords = neck_coords[inds]
|
197 |
+
# choose the top pixel for each column
|
198 |
+
u, uid, ucnt = np.unique(neck_coords[:, 1], return_index=True, return_counts=True)
|
199 |
+
top_neck_coords = neck_coords[uid] # [m, 2]
|
200 |
+
# only keep top-is-head pixels
|
201 |
+
top_neck_coords_up = top_neck_coords.copy() - np.array([1, 0])
|
202 |
+
mask = head_part[tuple(top_neck_coords_up.T)]
|
203 |
+
top_neck_coords = top_neck_coords[mask]
|
204 |
+
# push these top down for 4 pixels to make the neck inpainting more natural...
|
205 |
+
offset_down = np.minimum(ucnt[mask] - 1, push_down)
|
206 |
+
top_neck_coords += np.stack([offset_down, np.zeros_like(offset_down)], axis=-1)
|
207 |
+
# get the color
|
208 |
+
top_neck_colors = gt_img[tuple(top_neck_coords.T)] # [m, 3]
|
209 |
+
# construct inpaint coords (vertically up, or minus in x)
|
210 |
+
inpaint_neck_coords = top_neck_coords[None].repeat(L, 0) # [L, m, 2]
|
211 |
+
inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
|
212 |
+
inpaint_neck_coords += inpaint_offsets
|
213 |
+
inpaint_neck_coords = inpaint_neck_coords.reshape(-1, 2) # [Lm, 2]
|
214 |
+
inpaint_neck_colors = top_neck_colors[None].repeat(L, 0) # [L, m, 3]
|
215 |
+
darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
|
216 |
+
inpaint_neck_colors = (inpaint_neck_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
|
217 |
+
# set color
|
218 |
+
img[tuple(inpaint_neck_coords.T)] = inpaint_neck_colors
|
219 |
+
# apply blurring to the inpaint area to avoid vertical-line artifects...
|
220 |
+
inpaint_mask = np.zeros_like(img[..., 0]).astype(bool)
|
221 |
+
inpaint_mask[tuple(inpaint_neck_coords.T)] = True
|
222 |
+
|
223 |
+
blur_img = img.copy()
|
224 |
+
blur_img = cv2.GaussianBlur(blur_img, (5, 5), cv2.BORDER_DEFAULT)
|
225 |
+
img[inpaint_mask] = blur_img[inpaint_mask]
|
226 |
+
|
227 |
+
# set mask
|
228 |
+
torso_img_mask = (neck_part | torso_part | inpaint_mask)
|
229 |
+
torso_with_bg_img_mask = (bg_part | neck_part | torso_part | inpaint_mask)
|
230 |
+
if inpaint_torso_mask is not None:
|
231 |
+
torso_img_mask = torso_img_mask | inpaint_torso_mask
|
232 |
+
torso_with_bg_img_mask = torso_with_bg_img_mask | inpaint_torso_mask
|
233 |
+
|
234 |
+
torso_img = img.copy()
|
235 |
+
torso_img[~torso_img_mask] = 0
|
236 |
+
torso_with_bg_img = img.copy()
|
237 |
+
torso_img[~torso_with_bg_img_mask] = 0
|
238 |
+
|
239 |
+
return torso_img, torso_img_mask, torso_with_bg_img, torso_with_bg_img_mask
|
240 |
+
|
241 |
+
def load_segment_mask_from_file(filename: str):
|
242 |
+
encoded_segmap = load_rgb_image_to_path(filename)
|
243 |
+
segmap_mask = decode_segmap_mask_from_image(encoded_segmap)
|
244 |
+
return segmap_mask
|
245 |
+
|
246 |
+
# load segment mask to memory if not loaded yet
|
247 |
+
def refresh_segment_mask(segmap_mask: Union[str, np.ndarray]):
|
248 |
+
if isinstance(segmap_mask, str):
|
249 |
+
segmap_mask = load_segment_mask_from_file(segmap_mask)
|
250 |
+
return segmap_mask
|
251 |
+
|
252 |
+
# load segment mask to memory if not loaded yet
|
253 |
+
def refresh_image(image: Union[str, np.ndarray]):
|
254 |
+
if isinstance(image, str):
|
255 |
+
image = load_rgb_image_to_path(image)
|
256 |
+
return image
|
257 |
+
|
258 |
+
def generate_segment_imgs_job(img_name, segmap, img):
|
259 |
+
out_img_name = segmap_name = img_name.replace("/gt_imgs/", "/segmaps/").replace(".jpg", ".png") # 存成jpg的话,pixel value会有误差
|
260 |
+
try: os.makedirs(os.path.dirname(out_img_name), exist_ok=True)
|
261 |
+
except: pass
|
262 |
+
encoded_segmap = encode_segmap_mask_to_image(segmap)
|
263 |
+
save_rgb_image_to_path(encoded_segmap, out_img_name)
|
264 |
+
|
265 |
+
for mode in ['head', 'torso', 'person', 'bg']:
|
266 |
+
out_img, mask = seg_model._seg_out_img_with_segmap(img, segmap, mode=mode)
|
267 |
+
img_alpha = 255 * np.ones((img.shape[0], img.shape[1], 1), dtype=np.uint8) # alpha
|
268 |
+
mask = mask[0][..., None]
|
269 |
+
img_alpha[~mask] = 0
|
270 |
+
out_img_name = img_name.replace("/gt_imgs/", f"/{mode}_imgs/").replace(".jpg", ".png")
|
271 |
+
save_rgb_alpha_image_to_path(out_img, img_alpha, out_img_name)
|
272 |
+
|
273 |
+
inpaint_torso_img, inpaint_torso_img_mask, inpaint_torso_with_bg_img, inpaint_torso_with_bg_img_mask = inpaint_torso_job(img, segmap)
|
274 |
+
img_alpha = 255 * np.ones((img.shape[0], img.shape[1], 1), dtype=np.uint8) # alpha
|
275 |
+
img_alpha[~inpaint_torso_img_mask[..., None]] = 0
|
276 |
+
out_img_name = img_name.replace("/gt_imgs/", f"/inpaint_torso_imgs/").replace(".jpg", ".png")
|
277 |
+
save_rgb_alpha_image_to_path(inpaint_torso_img, img_alpha, out_img_name)
|
278 |
+
return segmap_name
|
279 |
+
|
280 |
+
def segment_and_generate_for_image_job(img_name, img, segmenter_options=None, segmenter=None, store_in_memory=False):
|
281 |
+
img = refresh_image(img)
|
282 |
+
segmap_mask, segmap_image = job_cal_seg_map_for_image(img, segmenter_options=segmenter_options, segmenter=segmenter)
|
283 |
+
segmap_name = generate_segment_imgs_job(img_name=img_name, segmap=segmap_mask, img=img)
|
284 |
+
if store_in_memory:
|
285 |
+
return segmap_mask
|
286 |
+
else:
|
287 |
+
return segmap_name
|
288 |
+
|
289 |
+
def extract_segment_job(
|
290 |
+
video_name,
|
291 |
+
nerf=False,
|
292 |
+
background_method='knn',
|
293 |
+
device="cpu",
|
294 |
+
total_gpus=0,
|
295 |
+
mix_bg=True,
|
296 |
+
store_in_memory=False, # set to True to speed up a bit of preprocess, but leads to HUGE memory costs (100GB for 5-min video)
|
297 |
+
force_single_process=False, # turn this on if you find multi-process does not work on your environment
|
298 |
+
):
|
299 |
+
global segmenter
|
300 |
+
global seg_model
|
301 |
+
del segmenter
|
302 |
+
del seg_model
|
303 |
+
seg_model = MediapipeSegmenter()
|
304 |
+
segmenter = vision.ImageSegmenter.create_from_options(seg_model.options)
|
305 |
+
# nerf means that we extract only one video, so can enable multi-process acceleration
|
306 |
+
multiprocess_enable = nerf and not force_single_process
|
307 |
+
try:
|
308 |
+
if "cuda" in device:
|
309 |
+
# determine which cuda index from subprocess id
|
310 |
+
pname = multiprocessing.current_process().name
|
311 |
+
pid = int(pname.rsplit("-", 1)[-1]) - 1
|
312 |
+
cuda_id = pid % total_gpus
|
313 |
+
device = f"cuda:{cuda_id}"
|
314 |
+
|
315 |
+
if nerf: # single video
|
316 |
+
raw_img_dir = video_name.replace(".mp4", "/gt_imgs/").replace("/raw/","/processed/")
|
317 |
+
else: # whole dataset
|
318 |
+
raw_img_dir = video_name.replace(".mp4", "").replace("/video/", "/gt_imgs/")
|
319 |
+
if not os.path.exists(raw_img_dir):
|
320 |
+
extract_img_job(video_name, raw_img_dir) # use ffmpeg to split video into imgs
|
321 |
+
|
322 |
+
img_names = glob.glob(os.path.join(raw_img_dir, "*.jpg"))
|
323 |
+
|
324 |
+
img_lst = []
|
325 |
+
|
326 |
+
for img_name in img_names:
|
327 |
+
if store_in_memory:
|
328 |
+
img = load_rgb_image_to_path(img_name)
|
329 |
+
else:
|
330 |
+
img = img_name
|
331 |
+
img_lst.append(img)
|
332 |
+
|
333 |
+
print("| Extracting Segmaps && Saving...")
|
334 |
+
args = []
|
335 |
+
segmap_mask_lst = []
|
336 |
+
# preparing parameters for segment
|
337 |
+
for i in range(len(img_lst)):
|
338 |
+
img_name = img_names[i]
|
339 |
+
img = img_lst[i]
|
340 |
+
if multiprocess_enable: # create seg_model in subprocesses here
|
341 |
+
options = seg_model.options
|
342 |
+
segmenter_arg = None
|
343 |
+
else: # use seg_model of this process
|
344 |
+
options = None
|
345 |
+
segmenter_arg = segmenter
|
346 |
+
arg = (img_name, img, options, segmenter_arg, store_in_memory)
|
347 |
+
args.append(arg)
|
348 |
+
|
349 |
+
if multiprocess_enable:
|
350 |
+
for (_, res) in multiprocess_run_tqdm(segment_and_generate_for_image_job, args=args, num_workers=16, desc='generating segment images in multi-processes...'):
|
351 |
+
segmap_mask = res
|
352 |
+
segmap_mask_lst.append(segmap_mask)
|
353 |
+
else:
|
354 |
+
for index in tqdm.tqdm(range(len(img_lst)), desc="generating segment images in single-process..."):
|
355 |
+
segmap_mask = segment_and_generate_for_image_job(*args[index])
|
356 |
+
segmap_mask_lst.append(segmap_mask)
|
357 |
+
print("| Extracted Segmaps Done.")
|
358 |
+
|
359 |
+
print("| Extracting background...")
|
360 |
+
bg_prefix_name = f"bg{BG_NAME_MAP[background_method]}"
|
361 |
+
bg_img = extract_background(img_lst, segmap_mask_lst, method=background_method, device=device, mix_bg=mix_bg)
|
362 |
+
if nerf:
|
363 |
+
out_img_name = video_name.replace("/raw/", "/processed/").replace(".mp4", f"/{bg_prefix_name}.jpg")
|
364 |
+
else:
|
365 |
+
out_img_name = video_name.replace("/video/", f"/{bg_prefix_name}_img/").replace(".mp4", ".jpg")
|
366 |
+
save_rgb_image_to_path(bg_img, out_img_name)
|
367 |
+
print("| Extracted background done.")
|
368 |
+
|
369 |
+
print("| Extracting com_imgs...")
|
370 |
+
com_prefix_name = f"com{BG_NAME_MAP[background_method]}"
|
371 |
+
for i in tqdm.trange(len(img_names), desc='extracting com_imgs'):
|
372 |
+
img_name = img_names[i]
|
373 |
+
com_img = refresh_image(img_lst[i]).copy()
|
374 |
+
segmap = refresh_segment_mask(segmap_mask_lst[i])
|
375 |
+
bg_part = segmap[0].astype(bool)[..., None].repeat(3,axis=-1)
|
376 |
+
com_img[bg_part] = bg_img[bg_part]
|
377 |
+
out_img_name = img_name.replace("/gt_imgs/", f"/{com_prefix_name}_imgs/")
|
378 |
+
save_rgb_image_to_path(com_img, out_img_name)
|
379 |
+
print("| Extracted com_imgs done.")
|
380 |
+
|
381 |
+
return 0
|
382 |
+
except Exception as e:
|
383 |
+
print(str(type(e)), e)
|
384 |
+
traceback.print_exc(e)
|
385 |
+
return 1
|
386 |
+
|
387 |
+
def out_exist_job(vid_name, background_method='knn'):
|
388 |
+
com_prefix_name = f"com{BG_NAME_MAP[background_method]}"
|
389 |
+
img_dir = vid_name.replace("/video/", "/gt_imgs/").replace(".mp4", "")
|
390 |
+
out_dir1 = img_dir.replace("/gt_imgs/", "/head_imgs/")
|
391 |
+
out_dir2 = img_dir.replace("/gt_imgs/", f"/{com_prefix_name}_imgs/")
|
392 |
+
|
393 |
+
if os.path.exists(img_dir) and os.path.exists(out_dir1) and os.path.exists(out_dir1) and os.path.exists(out_dir2) :
|
394 |
+
num_frames = len(os.listdir(img_dir))
|
395 |
+
if len(os.listdir(out_dir1)) == num_frames and len(os.listdir(out_dir2)) == num_frames:
|
396 |
+
return None
|
397 |
+
else:
|
398 |
+
return vid_name
|
399 |
+
else:
|
400 |
+
return vid_name
|
401 |
+
|
402 |
+
def get_todo_vid_names(vid_names, background_method='knn'):
|
403 |
+
if len(vid_names) == 1: # nerf
|
404 |
+
return vid_names
|
405 |
+
todo_vid_names = []
|
406 |
+
fn_args = [(vid_name, background_method) for vid_name in vid_names]
|
407 |
+
for i, res in multiprocess_run_tqdm(out_exist_job, fn_args, num_workers=16, desc="checking todo videos..."):
|
408 |
+
if res is not None:
|
409 |
+
todo_vid_names.append(res)
|
410 |
+
return todo_vid_names
|
411 |
+
|
412 |
+
if __name__ == '__main__':
|
413 |
+
import argparse, glob, tqdm, random
|
414 |
+
parser = argparse.ArgumentParser()
|
415 |
+
parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/TH1KH_512/video')
|
416 |
+
parser.add_argument("--ds_name", default='TH1KH_512')
|
417 |
+
parser.add_argument("--num_workers", default=48, type=int)
|
418 |
+
parser.add_argument("--seed", default=0, type=int)
|
419 |
+
parser.add_argument("--process_id", default=0, type=int)
|
420 |
+
parser.add_argument("--total_process", default=1, type=int)
|
421 |
+
parser.add_argument("--reset", action='store_true')
|
422 |
+
parser.add_argument("--load_names", action="store_true")
|
423 |
+
parser.add_argument("--background_method", choices=['knn', 'mat', 'ddnm', 'lama'], type=str, default='knn')
|
424 |
+
parser.add_argument("--total_gpus", default=0, type=int) # zero gpus means utilizing cpu
|
425 |
+
parser.add_argument("--no_mix_bg", action="store_true")
|
426 |
+
parser.add_argument("--store_in_memory", action="store_true") # set to True to speed up preprocess, but leads to high memory costs
|
427 |
+
parser.add_argument("--force_single_process", action="store_true") # turn this on if you find multi-process does not work on your environment
|
428 |
+
|
429 |
+
args = parser.parse_args()
|
430 |
+
vid_dir = args.vid_dir
|
431 |
+
ds_name = args.ds_name
|
432 |
+
load_names = args.load_names
|
433 |
+
background_method = args.background_method
|
434 |
+
total_gpus = args.total_gpus
|
435 |
+
mix_bg = not args.no_mix_bg
|
436 |
+
store_in_memory = args.store_in_memory
|
437 |
+
force_single_process = args.force_single_process
|
438 |
+
|
439 |
+
devices = os.environ.get('CUDA_VISIBLE_DEVICES', '').split(",")
|
440 |
+
for d in devices[:total_gpus]:
|
441 |
+
os.system(f'pkill -f "voidgpu{d}"')
|
442 |
+
|
443 |
+
if ds_name.lower() == 'nerf': # 处理单个视频
|
444 |
+
vid_names = [vid_dir]
|
445 |
+
out_names = [video_name.replace("/raw/", "/processed/").replace(".mp4","_lms.npy") for video_name in vid_names]
|
446 |
+
else: # 处理整个数据集
|
447 |
+
if ds_name in ['lrs3_trainval']:
|
448 |
+
vid_name_pattern = os.path.join(vid_dir, "*/*.mp4")
|
449 |
+
elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
|
450 |
+
vid_name_pattern = os.path.join(vid_dir, "*.mp4")
|
451 |
+
elif ds_name in ['lrs2', 'lrs3', 'voxceleb2']:
|
452 |
+
vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
|
453 |
+
elif ds_name in ["RAVDESS", 'VFHQ']:
|
454 |
+
vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
|
455 |
+
else:
|
456 |
+
raise NotImplementedError()
|
457 |
+
|
458 |
+
vid_names_path = os.path.join(vid_dir, "vid_names.pkl")
|
459 |
+
if os.path.exists(vid_names_path) and load_names:
|
460 |
+
print(f"loading vid names from {vid_names_path}")
|
461 |
+
vid_names = load_file(vid_names_path)
|
462 |
+
else:
|
463 |
+
vid_names = multiprocess_glob(vid_name_pattern)
|
464 |
+
vid_names = sorted(vid_names)
|
465 |
+
print(f"saving vid names to {vid_names_path}")
|
466 |
+
save_file(vid_names_path, vid_names)
|
467 |
+
|
468 |
+
vid_names = sorted(vid_names)
|
469 |
+
random.seed(args.seed)
|
470 |
+
random.shuffle(vid_names)
|
471 |
+
|
472 |
+
process_id = args.process_id
|
473 |
+
total_process = args.total_process
|
474 |
+
if total_process > 1:
|
475 |
+
assert process_id <= total_process -1
|
476 |
+
num_samples_per_process = len(vid_names) // total_process
|
477 |
+
if process_id == total_process:
|
478 |
+
vid_names = vid_names[process_id * num_samples_per_process : ]
|
479 |
+
else:
|
480 |
+
vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
|
481 |
+
|
482 |
+
if not args.reset:
|
483 |
+
vid_names = get_todo_vid_names(vid_names, background_method)
|
484 |
+
print(f"todo videos number: {len(vid_names)}")
|
485 |
+
|
486 |
+
device = "cuda" if total_gpus > 0 else "cpu"
|
487 |
+
extract_job = extract_segment_job
|
488 |
+
fn_args = [(vid_name, ds_name=='nerf', background_method, device, total_gpus, mix_bg, store_in_memory, force_single_process) for i, vid_name in enumerate(vid_names)]
|
489 |
+
|
490 |
+
if ds_name == 'nerf': # 处理单个视频
|
491 |
+
extract_job(*fn_args[0])
|
492 |
+
else:
|
493 |
+
for vid_name in multiprocess_run_tqdm(extract_job, fn_args, desc=f"Root process {args.process_id}: segment images", num_workers=args.num_workers):
|
494 |
+
pass
|
data_gen/utils/process_video/fit_3dmm_landmark.py
ADDED
@@ -0,0 +1,565 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This is a script for efficienct 3DMM coefficient extraction.
|
2 |
+
# It could reconstruct accurate 3D face in real-time.
|
3 |
+
# It is built upon BFM 2009 model and mediapipe landmark extractor.
|
4 |
+
# It is authored by ZhenhuiYe ([email protected]), free to contact him for any suggestion on improvement!
|
5 |
+
|
6 |
+
from numpy.core.numeric import require
|
7 |
+
from numpy.lib.function_base import quantile
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import copy
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
import random
|
14 |
+
import pickle
|
15 |
+
import os
|
16 |
+
import sys
|
17 |
+
import cv2
|
18 |
+
import argparse
|
19 |
+
import tqdm
|
20 |
+
from utils.commons.multiprocess_utils import multiprocess_run_tqdm
|
21 |
+
from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker, read_video_to_frames
|
22 |
+
from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
|
23 |
+
from deep_3drecon.secc_renderer import SECC_Renderer
|
24 |
+
from utils.commons.os_utils import multiprocess_glob
|
25 |
+
|
26 |
+
|
27 |
+
face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
|
28 |
+
camera_distance=10, focal=1015, keypoint_mode='mediapipe')
|
29 |
+
face_model.to(torch.device("cuda:0"))
|
30 |
+
|
31 |
+
dir_path = os.path.dirname(os.path.realpath(__file__))
|
32 |
+
|
33 |
+
|
34 |
+
def draw_axes(img, pitch, yaw, roll, tx, ty, size=50):
|
35 |
+
# yaw = -yaw
|
36 |
+
pitch = - pitch
|
37 |
+
roll = - roll
|
38 |
+
rotation_matrix = cv2.Rodrigues(np.array([pitch, yaw, roll]))[0].astype(np.float64)
|
39 |
+
axes_points = np.array([
|
40 |
+
[1, 0, 0, 0],
|
41 |
+
[0, 1, 0, 0],
|
42 |
+
[0, 0, 1, 0]
|
43 |
+
], dtype=np.float64)
|
44 |
+
axes_points = rotation_matrix @ axes_points
|
45 |
+
axes_points = (axes_points[:2, :] * size).astype(int)
|
46 |
+
axes_points[0, :] = axes_points[0, :] + tx
|
47 |
+
axes_points[1, :] = axes_points[1, :] + ty
|
48 |
+
|
49 |
+
new_img = img.copy()
|
50 |
+
cv2.line(new_img, tuple(axes_points[:, 3].ravel()), tuple(axes_points[:, 0].ravel()), (255, 0, 0), 3)
|
51 |
+
cv2.line(new_img, tuple(axes_points[:, 3].ravel()), tuple(axes_points[:, 1].ravel()), (0, 255, 0), 3)
|
52 |
+
cv2.line(new_img, tuple(axes_points[:, 3].ravel()), tuple(axes_points[:, 2].ravel()), (0, 0, 255), 3)
|
53 |
+
return new_img
|
54 |
+
|
55 |
+
def save_file(name, content):
|
56 |
+
with open(name, "wb") as f:
|
57 |
+
pickle.dump(content, f)
|
58 |
+
|
59 |
+
def load_file(name):
|
60 |
+
with open(name, "rb") as f:
|
61 |
+
content = pickle.load(f)
|
62 |
+
return content
|
63 |
+
|
64 |
+
def cal_lap_loss(in_tensor):
|
65 |
+
# [T, 68, 2]
|
66 |
+
t = in_tensor.shape[0]
|
67 |
+
in_tensor = in_tensor.reshape([t, -1]).permute(1,0).unsqueeze(1) # [c, 1, t]
|
68 |
+
in_tensor = torch.cat([in_tensor[:, :, 0:1], in_tensor, in_tensor[:, :, -1:]], dim=-1)
|
69 |
+
lap_kernel = torch.Tensor((-0.5, 1.0, -0.5)).reshape([1,1,3]).float().to(in_tensor.device) # [1, 1, kw]
|
70 |
+
loss_lap = 0
|
71 |
+
|
72 |
+
out_tensor = F.conv1d(in_tensor, lap_kernel)
|
73 |
+
loss_lap += torch.mean(out_tensor**2)
|
74 |
+
return loss_lap
|
75 |
+
|
76 |
+
def cal_vel_loss(ldm):
|
77 |
+
# [B, 68, 2]
|
78 |
+
vel = ldm[1:] - ldm[:-1]
|
79 |
+
return torch.mean(torch.abs(vel))
|
80 |
+
|
81 |
+
def cal_lan_loss(proj_lan, gt_lan):
|
82 |
+
# [B, 68, 2]
|
83 |
+
loss = (proj_lan - gt_lan)** 2
|
84 |
+
# use the ldm weights from deep3drecon, see deep_3drecon/deep_3drecon_models/losses.py
|
85 |
+
weights = torch.zeros_like(loss)
|
86 |
+
weights = torch.ones_like(loss)
|
87 |
+
weights[:, 36:48, :] = 3 # eye 12 points
|
88 |
+
weights[:, -8:, :] = 3 # inner lip 8 points
|
89 |
+
weights[:, 28:31, :] = 3 # nose 3 points
|
90 |
+
loss = loss * weights
|
91 |
+
return torch.mean(loss)
|
92 |
+
|
93 |
+
def cal_lan_loss_mp(proj_lan, gt_lan, mean:bool=True):
|
94 |
+
# [B, 68, 2]
|
95 |
+
loss = (proj_lan - gt_lan).pow(2)
|
96 |
+
# loss = (proj_lan - gt_lan).abs()
|
97 |
+
unmatch_mask = [ 93, 127, 132, 234, 323, 356, 361, 454]
|
98 |
+
upper_eye = [161,160,159,158,157] + [388,387,386,385,384]
|
99 |
+
eye = [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249]
|
100 |
+
inner_lip = [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95]
|
101 |
+
outer_lip = [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146]
|
102 |
+
weights = torch.ones_like(loss)
|
103 |
+
weights[:, eye] = 3
|
104 |
+
weights[:, upper_eye] = 20
|
105 |
+
weights[:, inner_lip] = 5
|
106 |
+
weights[:, outer_lip] = 5
|
107 |
+
weights[:, unmatch_mask] = 0
|
108 |
+
loss = loss * weights
|
109 |
+
if mean:
|
110 |
+
loss = torch.mean(loss)
|
111 |
+
return loss
|
112 |
+
|
113 |
+
def cal_acceleration_loss(trans):
|
114 |
+
vel = trans[1:] - trans[:-1]
|
115 |
+
acc = vel[1:] - vel[:-1]
|
116 |
+
return torch.mean(torch.abs(acc))
|
117 |
+
|
118 |
+
def cal_acceleration_ldm_loss(ldm):
|
119 |
+
# [B, 68, 2]
|
120 |
+
vel = ldm[1:] - ldm[:-1]
|
121 |
+
acc = vel[1:] - vel[:-1]
|
122 |
+
lip_weight = 0.25 # we dont want smooth the lip too much
|
123 |
+
acc[48:68] *= lip_weight
|
124 |
+
return torch.mean(torch.abs(acc))
|
125 |
+
|
126 |
+
def set_requires_grad(tensor_list):
|
127 |
+
for tensor in tensor_list:
|
128 |
+
tensor.requires_grad = True
|
129 |
+
|
130 |
+
@torch.enable_grad()
|
131 |
+
def fit_3dmm_for_a_video(
|
132 |
+
video_name,
|
133 |
+
nerf=False, # use the file name convention for GeneFace++
|
134 |
+
id_mode='global',
|
135 |
+
debug=False,
|
136 |
+
keypoint_mode='mediapipe',
|
137 |
+
large_yaw_threshold=9999999.9,
|
138 |
+
save=True
|
139 |
+
) -> bool: # True: good, False: bad
|
140 |
+
assert video_name.endswith(".mp4"), "this function only support video as input"
|
141 |
+
if id_mode == 'global':
|
142 |
+
LAMBDA_REG_ID = 0.2
|
143 |
+
LAMBDA_REG_EXP = 0.6
|
144 |
+
LAMBDA_REG_LAP = 1.0
|
145 |
+
LAMBDA_REG_VEL_ID = 0.0 # laplcaian is all you need for temporal consistency
|
146 |
+
LAMBDA_REG_VEL_EXP = 0.0 # laplcaian is all you need for temporal consistency
|
147 |
+
else:
|
148 |
+
LAMBDA_REG_ID = 0.3
|
149 |
+
LAMBDA_REG_EXP = 0.05
|
150 |
+
LAMBDA_REG_LAP = 1.0
|
151 |
+
LAMBDA_REG_VEL_ID = 0.0 # laplcaian is all you need for temporal consistency
|
152 |
+
LAMBDA_REG_VEL_EXP = 0.0 # laplcaian is all you need for temporal consistency
|
153 |
+
|
154 |
+
frames = read_video_to_frames(video_name) # [T, H, W, 3]
|
155 |
+
img_h, img_w = frames.shape[1], frames.shape[2]
|
156 |
+
assert img_h == img_w
|
157 |
+
num_frames = len(frames)
|
158 |
+
|
159 |
+
if nerf: # single video
|
160 |
+
lm_name = video_name.replace("/raw/", "/processed/").replace(".mp4","/lms_2d.npy")
|
161 |
+
else:
|
162 |
+
lm_name = video_name.replace("/video/", "/lms_2d/").replace(".mp4", "_lms.npy")
|
163 |
+
|
164 |
+
if os.path.exists(lm_name):
|
165 |
+
lms = np.load(lm_name)
|
166 |
+
else:
|
167 |
+
print(f"lms_2d file not found, try to extract it from video... {lm_name}")
|
168 |
+
try:
|
169 |
+
landmarker = MediapipeLandmarker()
|
170 |
+
img_lm478, vid_lm478 = landmarker.extract_lm478_from_frames(frames, anti_smooth_factor=20)
|
171 |
+
lms = landmarker.combine_vid_img_lm478_to_lm478(img_lm478, vid_lm478)
|
172 |
+
except Exception as e:
|
173 |
+
print(e)
|
174 |
+
return False
|
175 |
+
if lms is None:
|
176 |
+
print(f"get None lms_2d, please check whether each frame has one head, exiting... {lm_name}")
|
177 |
+
return False
|
178 |
+
lms = lms[:, :468, :]
|
179 |
+
lms = torch.FloatTensor(lms).cuda()
|
180 |
+
lms[..., 1] = img_h - lms[..., 1] # flip the height axis
|
181 |
+
|
182 |
+
if keypoint_mode == 'mediapipe':
|
183 |
+
# default
|
184 |
+
cal_lan_loss_fn = cal_lan_loss_mp
|
185 |
+
if nerf: # single video
|
186 |
+
out_name = video_name.replace("/raw/", "/processed/").replace(".mp4", "/coeff_fit_mp.npy")
|
187 |
+
else:
|
188 |
+
out_name = video_name.replace("/video/", "/coeff_fit_mp/").replace(".mp4", "_coeff_fit_mp.npy")
|
189 |
+
else:
|
190 |
+
# lm68 is less accurate than mp
|
191 |
+
cal_lan_loss_fn = cal_lan_loss
|
192 |
+
if nerf: # single video
|
193 |
+
out_name = video_name.replace("/raw/", "/processed/").replace(".mp4", "_coeff_fit_lm68.npy")
|
194 |
+
else:
|
195 |
+
out_name = video_name.replace("/video/", "/coeff_fit_lm68/").replace(".mp4", "_coeff_fit_lm68.npy")
|
196 |
+
try:
|
197 |
+
os.makedirs(os.path.dirname(out_name), exist_ok=True)
|
198 |
+
except:
|
199 |
+
pass
|
200 |
+
|
201 |
+
id_dim, exp_dim = 80, 64
|
202 |
+
sel_ids = np.arange(0, num_frames, 40)
|
203 |
+
|
204 |
+
h = w = face_model.center * 2
|
205 |
+
img_scale_factor = img_h / h
|
206 |
+
lms /= img_scale_factor # rescale lms into [0,224]
|
207 |
+
|
208 |
+
if id_mode == 'global':
|
209 |
+
# default choice by GeneFace++ and later works
|
210 |
+
id_para = lms.new_zeros((1, id_dim), requires_grad=True)
|
211 |
+
elif id_mode == 'finegrained':
|
212 |
+
# legacy choice by GeneFace1 (ICLR 2023)
|
213 |
+
id_para = lms.new_zeros((num_frames, id_dim), requires_grad=True)
|
214 |
+
else: raise NotImplementedError(f"id mode {id_mode} not supported! we only support global or finegrained.")
|
215 |
+
exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
|
216 |
+
euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True)
|
217 |
+
trans = lms.new_zeros((num_frames, 3), requires_grad=True)
|
218 |
+
|
219 |
+
set_requires_grad([id_para, exp_para, euler_angle, trans])
|
220 |
+
|
221 |
+
optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=.1)
|
222 |
+
optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=.1)
|
223 |
+
|
224 |
+
# 其他参数初始化,先训练euler和trans
|
225 |
+
for _ in range(200):
|
226 |
+
if id_mode == 'global':
|
227 |
+
proj_geo = face_model.compute_for_landmark_fit(
|
228 |
+
id_para.expand((num_frames, id_dim)), exp_para, euler_angle, trans)
|
229 |
+
else:
|
230 |
+
proj_geo = face_model.compute_for_landmark_fit(
|
231 |
+
id_para, exp_para, euler_angle, trans)
|
232 |
+
loss_lan = cal_lan_loss_fn(proj_geo[:, :, :2], lms.detach())
|
233 |
+
loss = loss_lan
|
234 |
+
optimizer_frame.zero_grad()
|
235 |
+
loss.backward()
|
236 |
+
optimizer_frame.step()
|
237 |
+
|
238 |
+
# print(f"loss_lan: {loss_lan.item():.2f}, euler_abs_mean: {euler_angle.abs().mean().item():.4f}, euler_std: {euler_angle.std().item():.4f}, euler_min: {euler_angle.min().item():.4f}, euler_max: {euler_angle.max().item():.4f}")
|
239 |
+
# print(f"trans_z_mean: {trans[...,2].mean().item():.4f}, trans_z_std: {trans[...,2].std().item():.4f}, trans_min: {trans[...,2].min().item():.4f}, trans_max: {trans[...,2].max().item():.4f}")
|
240 |
+
|
241 |
+
for param_group in optimizer_frame.param_groups:
|
242 |
+
param_group['lr'] = 0.1
|
243 |
+
|
244 |
+
# "jointly roughly training id exp euler trans"
|
245 |
+
for _ in range(200):
|
246 |
+
ret = {}
|
247 |
+
if id_mode == 'global':
|
248 |
+
proj_geo = face_model.compute_for_landmark_fit(
|
249 |
+
id_para.expand((num_frames, id_dim)), exp_para, euler_angle, trans, ret)
|
250 |
+
else:
|
251 |
+
proj_geo = face_model.compute_for_landmark_fit(
|
252 |
+
id_para, exp_para, euler_angle, trans, ret)
|
253 |
+
loss_lan = cal_lan_loss_fn(
|
254 |
+
proj_geo[:, :, :2], lms.detach())
|
255 |
+
# loss_lap = cal_lap_loss(proj_geo)
|
256 |
+
# laplacian对euler影响不大,但是对trans的提升很大
|
257 |
+
loss_lap = cal_lap_loss(id_para) + cal_lap_loss(exp_para) + cal_lap_loss(euler_angle) * 0.3 + cal_lap_loss(trans) * 0.3
|
258 |
+
|
259 |
+
loss_regid = torch.mean(id_para*id_para) # 正则化
|
260 |
+
loss_regexp = torch.mean(exp_para * exp_para)
|
261 |
+
|
262 |
+
loss_vel_id = cal_vel_loss(id_para)
|
263 |
+
loss_vel_exp = cal_vel_loss(exp_para)
|
264 |
+
loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP + loss_vel_id * LAMBDA_REG_VEL_ID + loss_vel_exp * LAMBDA_REG_VEL_EXP + loss_lap * LAMBDA_REG_LAP
|
265 |
+
optimizer_idexp.zero_grad()
|
266 |
+
optimizer_frame.zero_grad()
|
267 |
+
loss.backward()
|
268 |
+
optimizer_idexp.step()
|
269 |
+
optimizer_frame.step()
|
270 |
+
|
271 |
+
# print(f"loss_lan: {loss_lan.item():.2f}, loss_reg_id: {loss_regid.item():.2f},loss_reg_exp: {loss_regexp.item():.2f},")
|
272 |
+
# print(f"euler_abs_mean: {euler_angle.abs().mean().item():.4f}, euler_std: {euler_angle.std().item():.4f}, euler_min: {euler_angle.min().item():.4f}, euler_max: {euler_angle.max().item():.4f}")
|
273 |
+
# print(f"trans_z_mean: {trans[...,2].mean().item():.4f}, trans_z_std: {trans[...,2].std().item():.4f}, trans_min: {trans[...,2].min().item():.4f}, trans_max: {trans[...,2].max().item():.4f}")
|
274 |
+
|
275 |
+
# start fine training, intialize from the roughly trained results
|
276 |
+
if id_mode == 'global':
|
277 |
+
id_para_ = lms.new_zeros((1, id_dim), requires_grad=False)
|
278 |
+
else:
|
279 |
+
id_para_ = lms.new_zeros((num_frames, id_dim), requires_grad=True)
|
280 |
+
id_para_.data = id_para.data.clone()
|
281 |
+
id_para = id_para_
|
282 |
+
exp_para_ = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
|
283 |
+
exp_para_.data = exp_para.data.clone()
|
284 |
+
exp_para = exp_para_
|
285 |
+
euler_angle_ = lms.new_zeros((num_frames, 3), requires_grad=True)
|
286 |
+
euler_angle_.data = euler_angle.data.clone()
|
287 |
+
euler_angle = euler_angle_
|
288 |
+
trans_ = lms.new_zeros((num_frames, 3), requires_grad=True)
|
289 |
+
trans_.data = trans.data.clone()
|
290 |
+
trans = trans_
|
291 |
+
|
292 |
+
batch_size = 50
|
293 |
+
# "fine fitting the 3DMM in batches"
|
294 |
+
for i in range(int((num_frames-1)/batch_size+1)):
|
295 |
+
if (i+1)*batch_size > num_frames:
|
296 |
+
start_n = num_frames-batch_size
|
297 |
+
sel_ids = np.arange(max(num_frames-batch_size,0), num_frames)
|
298 |
+
else:
|
299 |
+
start_n = i*batch_size
|
300 |
+
sel_ids = np.arange(i*batch_size, i*batch_size+batch_size)
|
301 |
+
sel_lms = lms[sel_ids]
|
302 |
+
|
303 |
+
if id_mode == 'global':
|
304 |
+
sel_id_para = id_para.expand((sel_ids.shape[0], id_dim))
|
305 |
+
else:
|
306 |
+
sel_id_para = id_para.new_zeros((batch_size, id_dim), requires_grad=True)
|
307 |
+
sel_id_para.data = id_para[sel_ids].clone()
|
308 |
+
sel_exp_para = exp_para.new_zeros(
|
309 |
+
(batch_size, exp_dim), requires_grad=True)
|
310 |
+
sel_exp_para.data = exp_para[sel_ids].clone()
|
311 |
+
sel_euler_angle = euler_angle.new_zeros(
|
312 |
+
(batch_size, 3), requires_grad=True)
|
313 |
+
sel_euler_angle.data = euler_angle[sel_ids].clone()
|
314 |
+
sel_trans = trans.new_zeros((batch_size, 3), requires_grad=True)
|
315 |
+
sel_trans.data = trans[sel_ids].clone()
|
316 |
+
|
317 |
+
if id_mode == 'global':
|
318 |
+
set_requires_grad([sel_exp_para, sel_euler_angle, sel_trans])
|
319 |
+
optimizer_cur_batch = torch.optim.Adam(
|
320 |
+
[sel_exp_para, sel_euler_angle, sel_trans], lr=0.005)
|
321 |
+
else:
|
322 |
+
set_requires_grad([sel_id_para, sel_exp_para, sel_euler_angle, sel_trans])
|
323 |
+
optimizer_cur_batch = torch.optim.Adam(
|
324 |
+
[sel_id_para, sel_exp_para, sel_euler_angle, sel_trans], lr=0.005)
|
325 |
+
|
326 |
+
for j in range(50):
|
327 |
+
ret = {}
|
328 |
+
proj_geo = face_model.compute_for_landmark_fit(
|
329 |
+
sel_id_para, sel_exp_para, sel_euler_angle, sel_trans, ret)
|
330 |
+
loss_lan = cal_lan_loss_fn(
|
331 |
+
proj_geo[:, :, :2], lms[sel_ids].detach())
|
332 |
+
|
333 |
+
# loss_lap = cal_lap_loss(proj_geo)
|
334 |
+
loss_lap = cal_lap_loss(sel_id_para) + cal_lap_loss(sel_exp_para) + cal_lap_loss(sel_euler_angle) * 0.3 + cal_lap_loss(sel_trans) * 0.3
|
335 |
+
loss_vel_id = cal_vel_loss(sel_id_para)
|
336 |
+
loss_vel_exp = cal_vel_loss(sel_exp_para)
|
337 |
+
log_dict = {
|
338 |
+
'loss_vel_id': loss_vel_id,
|
339 |
+
'loss_vel_exp': loss_vel_exp,
|
340 |
+
'loss_vel_euler': cal_vel_loss(sel_euler_angle),
|
341 |
+
'loss_vel_trans': cal_vel_loss(sel_trans),
|
342 |
+
}
|
343 |
+
loss_regid = torch.mean(sel_id_para*sel_id_para) # 正则化
|
344 |
+
loss_regexp = torch.mean(sel_exp_para*sel_exp_para)
|
345 |
+
loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP + loss_lap * LAMBDA_REG_LAP + loss_vel_id * LAMBDA_REG_VEL_ID + loss_vel_exp * LAMBDA_REG_VEL_EXP
|
346 |
+
|
347 |
+
optimizer_cur_batch.zero_grad()
|
348 |
+
loss.backward()
|
349 |
+
optimizer_cur_batch.step()
|
350 |
+
|
351 |
+
if debug:
|
352 |
+
print(f"batch {i} | loss_lan: {loss_lan.item():.2f}, loss_reg_id: {loss_regid.item():.2f},loss_reg_exp: {loss_regexp.item():.2f},loss_lap_ldm:{loss_lap.item():.4f}")
|
353 |
+
print("|--------" + ', '.join([f"{k}: {v:.4f}" for k,v in log_dict.items()]))
|
354 |
+
if id_mode != 'global':
|
355 |
+
id_para[sel_ids].data = sel_id_para.data.clone()
|
356 |
+
exp_para[sel_ids].data = sel_exp_para.data.clone()
|
357 |
+
euler_angle[sel_ids].data = sel_euler_angle.data.clone()
|
358 |
+
trans[sel_ids].data = sel_trans.data.clone()
|
359 |
+
|
360 |
+
coeff_dict = {'id': id_para.detach().cpu().numpy(), 'exp': exp_para.detach().cpu().numpy(),
|
361 |
+
'euler': euler_angle.detach().cpu().numpy(), 'trans': trans.detach().cpu().numpy()}
|
362 |
+
|
363 |
+
# filter data by side-view pose
|
364 |
+
# bad_yaw = False
|
365 |
+
# yaws = [] # not so accurate
|
366 |
+
# for index in range(coeff_dict["trans"].shape[0]):
|
367 |
+
# yaw = coeff_dict["euler"][index][1]
|
368 |
+
# yaw = np.abs(yaw)
|
369 |
+
# yaws.append(yaw)
|
370 |
+
# if yaw > large_yaw_threshold:
|
371 |
+
# bad_yaw = True
|
372 |
+
|
373 |
+
if debug:
|
374 |
+
import imageio
|
375 |
+
from utils.visualization.vis_cam3d.camera_pose_visualizer import CameraPoseVisualizer
|
376 |
+
from data_util.face3d_helper import Face3DHelper
|
377 |
+
from data_gen.utils.process_video.extract_blink import get_eye_area_percent
|
378 |
+
face3d_helper = Face3DHelper('deep_3drecon/BFM', keypoint_mode='mediapipe')
|
379 |
+
|
380 |
+
t = coeff_dict['exp'].shape[0]
|
381 |
+
if len(coeff_dict['id']) == 1:
|
382 |
+
coeff_dict['id'] = np.repeat(coeff_dict['id'], t, axis=0)
|
383 |
+
idexp_lm3d = face3d_helper.reconstruct_idexp_lm3d_np(coeff_dict['id'], coeff_dict['exp']).reshape([t, -1])
|
384 |
+
cano_lm3d = idexp_lm3d / 10 + face3d_helper.key_mean_shape.squeeze().reshape([1, -1]).cpu().numpy()
|
385 |
+
cano_lm3d = cano_lm3d.reshape([t, -1, 3])
|
386 |
+
WH = 512
|
387 |
+
cano_lm3d = (cano_lm3d * WH/2 + WH/2).astype(int)
|
388 |
+
|
389 |
+
with torch.no_grad():
|
390 |
+
rot = ParametricFaceModel.compute_rotation(euler_angle)
|
391 |
+
extrinsic = torch.zeros([rot.shape[0], 4, 4]).to(rot.device)
|
392 |
+
extrinsic[:, :3,:3] = rot
|
393 |
+
extrinsic[:, :3, 3] = trans # / 10
|
394 |
+
extrinsic[:, 3, 3] = 1
|
395 |
+
extrinsic = extrinsic.cpu().numpy()
|
396 |
+
|
397 |
+
xy_camera_visualizer = CameraPoseVisualizer(xlim=[extrinsic[:,0,3].min().item()-0.5,extrinsic[:,0,3].max().item()+0.5],ylim=[extrinsic[:,1,3].min().item()-0.5,extrinsic[:,1,3].max().item()+0.5], zlim=[extrinsic[:,2,3].min().item()-0.5,extrinsic[:,2,3].max().item()+0.5], view_mode='xy')
|
398 |
+
xz_camera_visualizer = CameraPoseVisualizer(xlim=[extrinsic[:,0,3].min().item()-0.5,extrinsic[:,0,3].max().item()+0.5],ylim=[extrinsic[:,1,3].min().item()-0.5,extrinsic[:,1,3].max().item()+0.5], zlim=[extrinsic[:,2,3].min().item()-0.5,extrinsic[:,2,3].max().item()+0.5], view_mode='xz')
|
399 |
+
|
400 |
+
if nerf:
|
401 |
+
debug_name = video_name.replace("/raw/", "/processed/").replace(".mp4", "/debug_fit_3dmm.mp4")
|
402 |
+
else:
|
403 |
+
debug_name = video_name.replace("/video/", "/coeff_fit_debug/").replace(".mp4", "_debug.mp4")
|
404 |
+
try:
|
405 |
+
os.makedirs(os.path.dirname(debug_name), exist_ok=True)
|
406 |
+
except: pass
|
407 |
+
writer = imageio.get_writer(debug_name, fps=25)
|
408 |
+
if id_mode == 'global':
|
409 |
+
id_para = id_para.repeat([exp_para.shape[0], 1])
|
410 |
+
proj_geo = face_model.compute_for_landmark_fit(id_para, exp_para, euler_angle, trans)
|
411 |
+
lm68s = proj_geo[:,:,:2].detach().cpu().numpy() # [T, 68,2]
|
412 |
+
lm68s = lm68s * img_scale_factor
|
413 |
+
lms = lms * img_scale_factor
|
414 |
+
lm68s[..., 1] = img_h - lm68s[..., 1] # flip the height axis
|
415 |
+
lms[..., 1] = img_h - lms[..., 1] # flip the height axis
|
416 |
+
lm68s = lm68s.astype(int)
|
417 |
+
for i in tqdm.trange(min(250, len(frames)), desc=f'rendering debug video to {debug_name}..'):
|
418 |
+
xy_cam3d_img = xy_camera_visualizer.extrinsic2pyramid(extrinsic[i], focal_len_scaled=0.25)
|
419 |
+
xy_cam3d_img = cv2.resize(xy_cam3d_img, (512,512))
|
420 |
+
xz_cam3d_img = xz_camera_visualizer.extrinsic2pyramid(extrinsic[i], focal_len_scaled=0.25)
|
421 |
+
xz_cam3d_img = cv2.resize(xz_cam3d_img, (512,512))
|
422 |
+
|
423 |
+
img = copy.deepcopy(frames[i])
|
424 |
+
img2 = copy.deepcopy(frames[i])
|
425 |
+
|
426 |
+
img = draw_axes(img, euler_angle[i,0].item(), euler_angle[i,1].item(), euler_angle[i,2].item(), lm68s[i][4][0].item(), lm68s[i, 4][1].item(), size=50)
|
427 |
+
|
428 |
+
gt_lm_color = (255, 0, 0)
|
429 |
+
|
430 |
+
for lm in lm68s[i]:
|
431 |
+
img = cv2.circle(img, lm, 1, (0, 0, 255), thickness=-1) # blue
|
432 |
+
for gt_lm in lms[i]:
|
433 |
+
img2 = cv2.circle(img2, gt_lm.cpu().numpy().astype(int), 2, gt_lm_color, thickness=1)
|
434 |
+
|
435 |
+
cano_lm3d_img = np.ones([WH, WH, 3], dtype=np.uint8) * 255
|
436 |
+
for j in range(len(cano_lm3d[i])):
|
437 |
+
x, y, _ = cano_lm3d[i, j]
|
438 |
+
color = (255,0,0)
|
439 |
+
cano_lm3d_img = cv2.circle(cano_lm3d_img, center=(x,y), radius=3, color=color, thickness=-1)
|
440 |
+
cano_lm3d_img = cv2.flip(cano_lm3d_img, 0)
|
441 |
+
|
442 |
+
_, secc_img = secc_renderer(id_para[0:1], exp_para[i:i+1], euler_angle[i:i+1]*0, trans[i:i+1]*0)
|
443 |
+
secc_img = (secc_img +1)*127.5
|
444 |
+
secc_img = F.interpolate(secc_img, size=(img_h, img_w))
|
445 |
+
secc_img = secc_img.permute(0, 2,3,1).int().cpu().numpy()[0]
|
446 |
+
out_img1 = np.concatenate([img, img2, secc_img], axis=1).astype(np.uint8)
|
447 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
448 |
+
out_img2 = np.concatenate([xy_cam3d_img, xz_cam3d_img, cano_lm3d_img], axis=1).astype(np.uint8)
|
449 |
+
out_img = np.concatenate([out_img1, out_img2], axis=0)
|
450 |
+
writer.append_data(out_img)
|
451 |
+
writer.close()
|
452 |
+
|
453 |
+
# if bad_yaw:
|
454 |
+
# print(f"Skip {video_name} due to TOO LARGE YAW")
|
455 |
+
# return False
|
456 |
+
|
457 |
+
if save:
|
458 |
+
np.save(out_name, coeff_dict, allow_pickle=True)
|
459 |
+
return coeff_dict
|
460 |
+
|
461 |
+
def out_exist_job(vid_name):
|
462 |
+
out_name = vid_name.replace("/video/", "/coeff_fit_mp/").replace(".mp4","_coeff_fit_mp.npy")
|
463 |
+
lms_name = vid_name.replace("/video/", "/lms_2d/").replace(".mp4","_lms.npy")
|
464 |
+
if os.path.exists(out_name) or not os.path.exists(lms_name):
|
465 |
+
return None
|
466 |
+
else:
|
467 |
+
return vid_name
|
468 |
+
|
469 |
+
def get_todo_vid_names(vid_names):
|
470 |
+
if len(vid_names) == 1: # single video, nerf
|
471 |
+
return vid_names
|
472 |
+
todo_vid_names = []
|
473 |
+
for i, res in multiprocess_run_tqdm(out_exist_job, vid_names, num_workers=16):
|
474 |
+
if res is not None:
|
475 |
+
todo_vid_names.append(res)
|
476 |
+
return todo_vid_names
|
477 |
+
|
478 |
+
|
479 |
+
if __name__ == '__main__':
|
480 |
+
import argparse, glob, tqdm
|
481 |
+
parser = argparse.ArgumentParser()
|
482 |
+
# parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/CelebV-HQ/video')
|
483 |
+
parser.add_argument("--vid_dir", default='data/raw/videos/May_10s.mp4')
|
484 |
+
parser.add_argument("--ds_name", default='nerf') # 'nerf' | 'CelebV-HQ' | 'TH1KH_512' | etc
|
485 |
+
parser.add_argument("--seed", default=0, type=int)
|
486 |
+
parser.add_argument("--process_id", default=0, type=int)
|
487 |
+
parser.add_argument("--total_process", default=1, type=int)
|
488 |
+
parser.add_argument("--id_mode", default='global', type=str) # global | finegrained
|
489 |
+
parser.add_argument("--keypoint_mode", default='mediapipe', type=str)
|
490 |
+
parser.add_argument("--large_yaw_threshold", default=9999999.9, type=float) # could be 0.7
|
491 |
+
parser.add_argument("--debug", action='store_true')
|
492 |
+
parser.add_argument("--reset", action='store_true')
|
493 |
+
parser.add_argument("--load_names", action="store_true")
|
494 |
+
|
495 |
+
args = parser.parse_args()
|
496 |
+
vid_dir = args.vid_dir
|
497 |
+
ds_name = args.ds_name
|
498 |
+
load_names = args.load_names
|
499 |
+
|
500 |
+
print(f"args {args}")
|
501 |
+
|
502 |
+
if ds_name.lower() == 'nerf': # 处理单个视频
|
503 |
+
vid_names = [vid_dir]
|
504 |
+
out_names = [video_name.replace("/raw/", "/processed/").replace(".mp4","_coeff_fit_mp.npy") for video_name in vid_names]
|
505 |
+
else: # 处理整个数据集
|
506 |
+
if ds_name in ['lrs3_trainval']:
|
507 |
+
vid_name_pattern = os.path.join(vid_dir, "*/*.mp4")
|
508 |
+
elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
|
509 |
+
vid_name_pattern = os.path.join(vid_dir, "*.mp4")
|
510 |
+
elif ds_name in ['lrs2', 'lrs3', 'voxceleb2', 'CMLR']:
|
511 |
+
vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
|
512 |
+
elif ds_name in ["RAVDESS", 'VFHQ']:
|
513 |
+
vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
|
514 |
+
else:
|
515 |
+
raise NotImplementedError()
|
516 |
+
|
517 |
+
vid_names_path = os.path.join(vid_dir, "vid_names.pkl")
|
518 |
+
if os.path.exists(vid_names_path) and load_names:
|
519 |
+
print(f"loading vid names from {vid_names_path}")
|
520 |
+
vid_names = load_file(vid_names_path)
|
521 |
+
else:
|
522 |
+
vid_names = multiprocess_glob(vid_name_pattern)
|
523 |
+
vid_names = sorted(vid_names)
|
524 |
+
print(f"saving vid names to {vid_names_path}")
|
525 |
+
save_file(vid_names_path, vid_names)
|
526 |
+
out_names = [video_name.replace("/video/", "/coeff_fit_mp/").replace(".mp4","_coeff_fit_mp.npy") for video_name in vid_names]
|
527 |
+
|
528 |
+
print(vid_names[:10])
|
529 |
+
random.seed(args.seed)
|
530 |
+
random.shuffle(vid_names)
|
531 |
+
|
532 |
+
face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
|
533 |
+
camera_distance=10, focal=1015, keypoint_mode=args.keypoint_mode)
|
534 |
+
face_model.to(torch.device("cuda:0"))
|
535 |
+
secc_renderer = SECC_Renderer(512)
|
536 |
+
secc_renderer.to("cuda:0")
|
537 |
+
|
538 |
+
process_id = args.process_id
|
539 |
+
total_process = args.total_process
|
540 |
+
if total_process > 1:
|
541 |
+
assert process_id <= total_process -1
|
542 |
+
num_samples_per_process = len(vid_names) // total_process
|
543 |
+
if process_id == total_process:
|
544 |
+
vid_names = vid_names[process_id * num_samples_per_process : ]
|
545 |
+
else:
|
546 |
+
vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
|
547 |
+
|
548 |
+
if not args.reset:
|
549 |
+
vid_names = get_todo_vid_names(vid_names)
|
550 |
+
|
551 |
+
failed_img_names = []
|
552 |
+
for i in tqdm.trange(len(vid_names), desc=f"process {process_id}: fitting 3dmm ..."):
|
553 |
+
img_name = vid_names[i]
|
554 |
+
try:
|
555 |
+
is_person_specific_data = ds_name=='nerf'
|
556 |
+
success = fit_3dmm_for_a_video(img_name, is_person_specific_data, args.id_mode, args.debug, large_yaw_threshold=args.large_yaw_threshold)
|
557 |
+
if not success:
|
558 |
+
failed_img_names.append(img_name)
|
559 |
+
except Exception as e:
|
560 |
+
print(img_name, e)
|
561 |
+
failed_img_names.append(img_name)
|
562 |
+
print(f"finished {i + 1} / {len(vid_names)} = {(i + 1) / len(vid_names):.4f}, failed {len(failed_img_names)} / {i + 1} = {len(failed_img_names) / (i + 1):.4f}")
|
563 |
+
sys.stdout.flush()
|
564 |
+
print(f"all failed image names: {failed_img_names}")
|
565 |
+
print(f"All finished!")
|
data_gen/utils/process_video/inpaint_torso_imgs.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
from utils.commons.multiprocess_utils import multiprocess_run_tqdm
|
5 |
+
from scipy.ndimage import binary_erosion, binary_dilation
|
6 |
+
|
7 |
+
from tasks.eg3ds.loss_utils.segment_loss.mp_segmenter import MediapipeSegmenter
|
8 |
+
seg_model = MediapipeSegmenter()
|
9 |
+
|
10 |
+
def inpaint_torso_job(video_name, idx=None, total=None):
|
11 |
+
raw_img_dir = video_name.replace(".mp4", "").replace("/video/","/gt_imgs/")
|
12 |
+
img_names = glob.glob(os.path.join(raw_img_dir, "*.jpg"))
|
13 |
+
|
14 |
+
for image_path in tqdm.tqdm(img_names):
|
15 |
+
# read ori image
|
16 |
+
ori_image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3]
|
17 |
+
segmap = seg_model._cal_seg_map(cv2.cvtColor(ori_image, cv2.COLOR_BGR2RGB))
|
18 |
+
head_part = (segmap[1] + segmap[3] + segmap[5]).astype(np.bool)
|
19 |
+
torso_part = (segmap[4]).astype(np.bool)
|
20 |
+
neck_part = (segmap[2]).astype(np.bool)
|
21 |
+
bg_part = segmap[0].astype(np.bool)
|
22 |
+
head_image = cv2.imread(image_path.replace("/gt_imgs/", "/head_imgs/"), cv2.IMREAD_UNCHANGED) # [H, W, 3]
|
23 |
+
torso_image = cv2.imread(image_path.replace("/gt_imgs/", "/torso_imgs/"), cv2.IMREAD_UNCHANGED) # [H, W, 3]
|
24 |
+
bg_image = cv2.imread(image_path.replace("/gt_imgs/", "/bg_imgs/"), cv2.IMREAD_UNCHANGED) # [H, W, 3]
|
25 |
+
|
26 |
+
# head_part = (head_image[...,0] != 0) & (head_image[...,1] != 0) & (head_image[...,2] != 0)
|
27 |
+
# torso_part = (torso_image[...,0] != 0) & (torso_image[...,1] != 0) & (torso_image[...,2] != 0)
|
28 |
+
# bg_part = (bg_image[...,0] != 0) & (bg_image[...,1] != 0) & (bg_image[...,2] != 0)
|
29 |
+
|
30 |
+
# get gt image
|
31 |
+
gt_image = ori_image.copy()
|
32 |
+
gt_image[bg_part] = bg_image[bg_part]
|
33 |
+
cv2.imwrite(image_path.replace('ori_imgs', 'gt_imgs'), gt_image)
|
34 |
+
|
35 |
+
# get torso image
|
36 |
+
torso_image = gt_image.copy() # rgb
|
37 |
+
torso_image[head_part] = 0
|
38 |
+
torso_alpha = 255 * np.ones((gt_image.shape[0], gt_image.shape[1], 1), dtype=np.uint8) # alpha
|
39 |
+
|
40 |
+
# torso part "vertical" in-painting...
|
41 |
+
L = 8 + 1
|
42 |
+
torso_coords = np.stack(np.nonzero(torso_part), axis=-1) # [M, 2]
|
43 |
+
# lexsort: sort 2D coords first by y then by x,
|
44 |
+
# ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
|
45 |
+
inds = np.lexsort((torso_coords[:, 0], torso_coords[:, 1]))
|
46 |
+
torso_coords = torso_coords[inds]
|
47 |
+
# choose the top pixel for each column
|
48 |
+
u, uid, ucnt = np.unique(torso_coords[:, 1], return_index=True, return_counts=True)
|
49 |
+
top_torso_coords = torso_coords[uid] # [m, 2]
|
50 |
+
# only keep top-is-head pixels
|
51 |
+
top_torso_coords_up = top_torso_coords.copy() - np.array([1, 0]) # [N, 2]
|
52 |
+
mask = head_part[tuple(top_torso_coords_up.T)]
|
53 |
+
if mask.any():
|
54 |
+
top_torso_coords = top_torso_coords[mask]
|
55 |
+
# get the color
|
56 |
+
top_torso_colors = gt_image[tuple(top_torso_coords.T)] # [m, 3]
|
57 |
+
# construct inpaint coords (vertically up, or minus in x)
|
58 |
+
inpaint_torso_coords = top_torso_coords[None].repeat(L, 0) # [L, m, 2]
|
59 |
+
inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
|
60 |
+
inpaint_torso_coords += inpaint_offsets
|
61 |
+
inpaint_torso_coords = inpaint_torso_coords.reshape(-1, 2) # [Lm, 2]
|
62 |
+
inpaint_torso_colors = top_torso_colors[None].repeat(L, 0) # [L, m, 3]
|
63 |
+
darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
|
64 |
+
inpaint_torso_colors = (inpaint_torso_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
|
65 |
+
# set color
|
66 |
+
torso_image[tuple(inpaint_torso_coords.T)] = inpaint_torso_colors
|
67 |
+
|
68 |
+
inpaint_torso_mask = np.zeros_like(torso_image[..., 0]).astype(bool)
|
69 |
+
inpaint_torso_mask[tuple(inpaint_torso_coords.T)] = True
|
70 |
+
else:
|
71 |
+
inpaint_torso_mask = None
|
72 |
+
|
73 |
+
# neck part "vertical" in-painting...
|
74 |
+
push_down = 4
|
75 |
+
L = 48 + push_down + 1
|
76 |
+
|
77 |
+
neck_part = binary_dilation(neck_part, structure=np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=bool), iterations=3)
|
78 |
+
|
79 |
+
neck_coords = np.stack(np.nonzero(neck_part), axis=-1) # [M, 2]
|
80 |
+
# lexsort: sort 2D coords first by y then by x,
|
81 |
+
# ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
|
82 |
+
inds = np.lexsort((neck_coords[:, 0], neck_coords[:, 1]))
|
83 |
+
neck_coords = neck_coords[inds]
|
84 |
+
# choose the top pixel for each column
|
85 |
+
u, uid, ucnt = np.unique(neck_coords[:, 1], return_index=True, return_counts=True)
|
86 |
+
top_neck_coords = neck_coords[uid] # [m, 2]
|
87 |
+
# only keep top-is-head pixels
|
88 |
+
top_neck_coords_up = top_neck_coords.copy() - np.array([1, 0])
|
89 |
+
mask = head_part[tuple(top_neck_coords_up.T)]
|
90 |
+
|
91 |
+
top_neck_coords = top_neck_coords[mask]
|
92 |
+
# push these top down for 4 pixels to make the neck inpainting more natural...
|
93 |
+
offset_down = np.minimum(ucnt[mask] - 1, push_down)
|
94 |
+
top_neck_coords += np.stack([offset_down, np.zeros_like(offset_down)], axis=-1)
|
95 |
+
# get the color
|
96 |
+
top_neck_colors = gt_image[tuple(top_neck_coords.T)] # [m, 3]
|
97 |
+
# construct inpaint coords (vertically up, or minus in x)
|
98 |
+
inpaint_neck_coords = top_neck_coords[None].repeat(L, 0) # [L, m, 2]
|
99 |
+
inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
|
100 |
+
inpaint_neck_coords += inpaint_offsets
|
101 |
+
inpaint_neck_coords = inpaint_neck_coords.reshape(-1, 2) # [Lm, 2]
|
102 |
+
inpaint_neck_colors = top_neck_colors[None].repeat(L, 0) # [L, m, 3]
|
103 |
+
darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
|
104 |
+
inpaint_neck_colors = (inpaint_neck_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
|
105 |
+
# set color
|
106 |
+
torso_image[tuple(inpaint_neck_coords.T)] = inpaint_neck_colors
|
107 |
+
|
108 |
+
# apply blurring to the inpaint area to avoid vertical-line artifects...
|
109 |
+
inpaint_mask = np.zeros_like(torso_image[..., 0]).astype(bool)
|
110 |
+
inpaint_mask[tuple(inpaint_neck_coords.T)] = True
|
111 |
+
|
112 |
+
blur_img = torso_image.copy()
|
113 |
+
blur_img = cv2.GaussianBlur(blur_img, (5, 5), cv2.BORDER_DEFAULT)
|
114 |
+
|
115 |
+
torso_image[inpaint_mask] = blur_img[inpaint_mask]
|
116 |
+
|
117 |
+
# set mask
|
118 |
+
mask = (neck_part | torso_part | inpaint_mask)
|
119 |
+
if inpaint_torso_mask is not None:
|
120 |
+
mask = mask | inpaint_torso_mask
|
121 |
+
torso_image[~mask] = 0
|
122 |
+
torso_alpha[~mask] = 0
|
123 |
+
|
124 |
+
cv2.imwrite("0.png", np.concatenate([torso_image, torso_alpha], axis=-1))
|
125 |
+
|
126 |
+
print(f'[INFO] ===== extracted torso and gt images =====')
|
127 |
+
|
128 |
+
|
129 |
+
def out_exist_job(vid_name):
|
130 |
+
out_dir1 = vid_name.replace("/video/", "/inpaint_torso_imgs/").replace(".mp4","")
|
131 |
+
out_dir2 = vid_name.replace("/video/", "/inpaint_torso_with_bg_imgs/").replace(".mp4","")
|
132 |
+
out_dir3 = vid_name.replace("/video/", "/torso_imgs/").replace(".mp4","")
|
133 |
+
out_dir4 = vid_name.replace("/video/", "/torso_with_bg_imgs/").replace(".mp4","")
|
134 |
+
|
135 |
+
if os.path.exists(out_dir1) and os.path.exists(out_dir1) and os.path.exists(out_dir2) and os.path.exists(out_dir3) and os.path.exists(out_dir4):
|
136 |
+
num_frames = len(os.listdir(out_dir1))
|
137 |
+
if len(os.listdir(out_dir1)) == num_frames and len(os.listdir(out_dir2)) == num_frames and len(os.listdir(out_dir3)) == num_frames and len(os.listdir(out_dir4)) == num_frames:
|
138 |
+
return None
|
139 |
+
else:
|
140 |
+
return vid_name
|
141 |
+
else:
|
142 |
+
return vid_name
|
143 |
+
|
144 |
+
def get_todo_vid_names(vid_names):
|
145 |
+
todo_vid_names = []
|
146 |
+
for i, res in multiprocess_run_tqdm(out_exist_job, vid_names, num_workers=16):
|
147 |
+
if res is not None:
|
148 |
+
todo_vid_names.append(res)
|
149 |
+
return todo_vid_names
|
150 |
+
|
151 |
+
if __name__ == '__main__':
|
152 |
+
import argparse, glob, tqdm, random
|
153 |
+
parser = argparse.ArgumentParser()
|
154 |
+
parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/CelebV-HQ/video')
|
155 |
+
parser.add_argument("--ds_name", default='CelebV-HQ')
|
156 |
+
parser.add_argument("--num_workers", default=48, type=int)
|
157 |
+
parser.add_argument("--seed", default=0, type=int)
|
158 |
+
parser.add_argument("--process_id", default=0, type=int)
|
159 |
+
parser.add_argument("--total_process", default=1, type=int)
|
160 |
+
parser.add_argument("--reset", action='store_true')
|
161 |
+
|
162 |
+
inpaint_torso_job('/home/tiger/datasets/raw/CelebV-HQ/video/dgdEr-mXQT4_8.mp4')
|
163 |
+
# args = parser.parse_args()
|
164 |
+
# vid_dir = args.vid_dir
|
165 |
+
# ds_name = args.ds_name
|
166 |
+
# if ds_name in ['lrs3_trainval']:
|
167 |
+
# mp4_name_pattern = os.path.join(vid_dir, "*/*.mp4")
|
168 |
+
# if ds_name in ['TH1KH_512', 'CelebV-HQ']:
|
169 |
+
# vid_names = glob.glob(os.path.join(vid_dir, "*.mp4"))
|
170 |
+
# elif ds_name in ['lrs2', 'lrs3', 'voxceleb2']:
|
171 |
+
# vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
|
172 |
+
# vid_names = glob.glob(vid_name_pattern)
|
173 |
+
# vid_names = sorted(vid_names)
|
174 |
+
# random.seed(args.seed)
|
175 |
+
# random.shuffle(vid_names)
|
176 |
+
|
177 |
+
# process_id = args.process_id
|
178 |
+
# total_process = args.total_process
|
179 |
+
# if total_process > 1:
|
180 |
+
# assert process_id <= total_process -1
|
181 |
+
# num_samples_per_process = len(vid_names) // total_process
|
182 |
+
# if process_id == total_process:
|
183 |
+
# vid_names = vid_names[process_id * num_samples_per_process : ]
|
184 |
+
# else:
|
185 |
+
# vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
|
186 |
+
|
187 |
+
# if not args.reset:
|
188 |
+
# vid_names = get_todo_vid_names(vid_names)
|
189 |
+
# print(f"todo videos number: {len(vid_names)}")
|
190 |
+
|
191 |
+
# fn_args = [(vid_name,i,len(vid_names)) for i, vid_name in enumerate(vid_names)]
|
192 |
+
# for vid_name in multiprocess_run_tqdm(inpaint_torso_job ,fn_args, desc=f"Root process {args.process_id}: extracting segment images", num_workers=args.num_workers):
|
193 |
+
# pass
|
data_gen/utils/process_video/resample_video_to_25fps_resize_to_512.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, glob
|
2 |
+
import cv2
|
3 |
+
from utils.commons.os_utils import multiprocess_glob
|
4 |
+
from utils.commons.multiprocess_utils import multiprocess_run_tqdm
|
5 |
+
|
6 |
+
def get_video_infos(video_path):
|
7 |
+
vid_cap = cv2.VideoCapture(video_path)
|
8 |
+
height = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
9 |
+
width = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
10 |
+
fps = vid_cap.get(cv2.CAP_PROP_FPS)
|
11 |
+
total_frames = int(vid_cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
12 |
+
return {'height': height, 'width': width, 'fps': fps, 'total_frames':total_frames}
|
13 |
+
|
14 |
+
def extract_img_job(video_name:str):
|
15 |
+
out_path = video_name.replace("/video_raw/","/video/",1)
|
16 |
+
os.makedirs(os.path.dirname(out_path), exist_ok=True)
|
17 |
+
ffmpeg_path = "/usr/bin/ffmpeg"
|
18 |
+
vid_info = get_video_infos(video_name)
|
19 |
+
assert vid_info['width'] == vid_info['height']
|
20 |
+
cmd = f'{ffmpeg_path} -i {video_name} -vf fps={25},scale=w=512:h=512 -q:v 1 -c:v libx264 -pix_fmt yuv420p -b:v 2000k -v quiet -y {out_path}'
|
21 |
+
os.system(cmd)
|
22 |
+
|
23 |
+
def extract_img_job_crop(video_name:str):
|
24 |
+
out_path = video_name.replace("/video_raw/","/video/",1)
|
25 |
+
os.makedirs(os.path.dirname(out_path), exist_ok=True)
|
26 |
+
ffmpeg_path = "/usr/bin/ffmpeg"
|
27 |
+
vid_info = get_video_infos(video_name)
|
28 |
+
wh = min(vid_info['width'], vid_info['height'])
|
29 |
+
cmd = f'{ffmpeg_path} -i {video_name} -vf fps={25},crop={wh}:{wh},scale=w=512:h=512 -q:v 1 -c:v libx264 -pix_fmt yuv420p -b:v 2000k -v quiet -y {out_path}'
|
30 |
+
os.system(cmd)
|
31 |
+
|
32 |
+
def extract_img_job_crop_ravdess(video_name:str):
|
33 |
+
out_path = video_name.replace("/video_raw/","/video/",1)
|
34 |
+
os.makedirs(os.path.dirname(out_path), exist_ok=True)
|
35 |
+
ffmpeg_path = "/usr/bin/ffmpeg"
|
36 |
+
cmd = f'{ffmpeg_path} -i {video_name} -vf fps={25},crop=720:720,scale=w=512:h=512 -q:v 1 -c:v libx264 -pix_fmt yuv420p -b:v 2000k -v quiet -y {out_path}'
|
37 |
+
os.system(cmd)
|
38 |
+
|
39 |
+
if __name__ == '__main__':
|
40 |
+
import argparse, glob, tqdm, random
|
41 |
+
parser = argparse.ArgumentParser()
|
42 |
+
parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/CelebV-HQ/video_raw/')
|
43 |
+
parser.add_argument("--ds_name", default='CelebV-HQ')
|
44 |
+
parser.add_argument("--num_workers", default=32, type=int)
|
45 |
+
parser.add_argument("--process_id", default=0, type=int)
|
46 |
+
parser.add_argument("--total_process", default=1, type=int)
|
47 |
+
args = parser.parse_args()
|
48 |
+
print(f"args {args}")
|
49 |
+
|
50 |
+
vid_dir = args.vid_dir
|
51 |
+
ds_name = args.ds_name
|
52 |
+
if ds_name in ['lrs3_trainval']:
|
53 |
+
mp4_name_pattern = os.path.join(vid_dir, "*/*.mp4")
|
54 |
+
elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
|
55 |
+
vid_names = multiprocess_glob(os.path.join(vid_dir, "*.mp4"))
|
56 |
+
elif ds_name in ['lrs2', 'lrs3', 'voxceleb2', 'CMLR']:
|
57 |
+
vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
|
58 |
+
vid_names = multiprocess_glob(vid_name_pattern)
|
59 |
+
elif ds_name in ["RAVDESS", 'VFHQ']:
|
60 |
+
vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
|
61 |
+
vid_names = multiprocess_glob(vid_name_pattern)
|
62 |
+
else:
|
63 |
+
raise NotImplementedError()
|
64 |
+
vid_names = sorted(vid_names)
|
65 |
+
print(f"total video number : {len(vid_names)}")
|
66 |
+
print(f"first {vid_names[0]} last {vid_names[-1]}")
|
67 |
+
# exit()
|
68 |
+
process_id = args.process_id
|
69 |
+
total_process = args.total_process
|
70 |
+
if total_process > 1:
|
71 |
+
assert process_id <= total_process -1
|
72 |
+
num_samples_per_process = len(vid_names) // total_process
|
73 |
+
if process_id == total_process:
|
74 |
+
vid_names = vid_names[process_id * num_samples_per_process : ]
|
75 |
+
else:
|
76 |
+
vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
|
77 |
+
|
78 |
+
if ds_name == "RAVDESS":
|
79 |
+
for i, res in multiprocess_run_tqdm(extract_img_job_crop_ravdess, vid_names, num_workers=args.num_workers, desc="resampling videos"):
|
80 |
+
pass
|
81 |
+
elif ds_name == "CMLR":
|
82 |
+
for i, res in multiprocess_run_tqdm(extract_img_job_crop, vid_names, num_workers=args.num_workers, desc="resampling videos"):
|
83 |
+
pass
|
84 |
+
else:
|
85 |
+
for i, res in multiprocess_run_tqdm(extract_img_job, vid_names, num_workers=args.num_workers, desc="resampling videos"):
|
86 |
+
pass
|
87 |
+
|
data_gen/utils/process_video/split_video_to_imgs.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, glob
|
2 |
+
from utils.commons.multiprocess_utils import multiprocess_run_tqdm
|
3 |
+
|
4 |
+
from data_gen.utils.path_converter import PathConverter, pc
|
5 |
+
|
6 |
+
# mp4_names = glob.glob("/home/tiger/datasets/raw/CelebV-HQ/video/*.mp4")
|
7 |
+
|
8 |
+
def extract_img_job(video_name, raw_img_dir=None):
|
9 |
+
if raw_img_dir is not None:
|
10 |
+
out_path = raw_img_dir
|
11 |
+
else:
|
12 |
+
out_path = pc.to(video_name.replace(".mp4", ""), "vid", "gt")
|
13 |
+
os.makedirs(out_path, exist_ok=True)
|
14 |
+
ffmpeg_path = "/usr/bin/ffmpeg"
|
15 |
+
cmd = f'{ffmpeg_path} -i {video_name} -vf fps={25},scale=w=512:h=512 -qmin 1 -q:v 1 -start_number 0 -v quiet {os.path.join(out_path, "%8d.jpg")}'
|
16 |
+
os.system(cmd)
|
17 |
+
|
18 |
+
if __name__ == '__main__':
|
19 |
+
import argparse, glob, tqdm, random
|
20 |
+
parser = argparse.ArgumentParser()
|
21 |
+
parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/CelebV-HQ/video')
|
22 |
+
parser.add_argument("--ds_name", default='CelebV-HQ')
|
23 |
+
parser.add_argument("--num_workers", default=64, type=int)
|
24 |
+
parser.add_argument("--process_id", default=0, type=int)
|
25 |
+
parser.add_argument("--total_process", default=1, type=int)
|
26 |
+
args = parser.parse_args()
|
27 |
+
vid_dir = args.vid_dir
|
28 |
+
ds_name = args.ds_name
|
29 |
+
if ds_name in ['lrs3_trainval']:
|
30 |
+
mp4_name_pattern = os.path.join(vid_dir, "*/*.mp4")
|
31 |
+
elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
|
32 |
+
vid_names = glob.glob(os.path.join(vid_dir, "*.mp4"))
|
33 |
+
elif ds_name in ['lrs2', 'lrs3', 'voxceleb2']:
|
34 |
+
vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
|
35 |
+
vid_names = glob.glob(vid_name_pattern)
|
36 |
+
elif ds_name in ["RAVDESS", 'VFHQ']:
|
37 |
+
vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
|
38 |
+
vid_names = glob.glob(vid_name_pattern)
|
39 |
+
vid_names = sorted(vid_names)
|
40 |
+
|
41 |
+
process_id = args.process_id
|
42 |
+
total_process = args.total_process
|
43 |
+
if total_process > 1:
|
44 |
+
assert process_id <= total_process -1
|
45 |
+
num_samples_per_process = len(vid_names) // total_process
|
46 |
+
if process_id == total_process:
|
47 |
+
vid_names = vid_names[process_id * num_samples_per_process : ]
|
48 |
+
else:
|
49 |
+
vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
|
50 |
+
|
51 |
+
for i, res in multiprocess_run_tqdm(extract_img_job, vid_names, num_workers=args.num_workers, desc="extracting images"):
|
52 |
+
pass
|
53 |
+
|
data_util/face3d_helper.py
ADDED
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from scipy.io import loadmat
|
6 |
+
|
7 |
+
from deep_3drecon.deep_3drecon_models.bfm import perspective_projection
|
8 |
+
|
9 |
+
|
10 |
+
class Face3DHelper(nn.Module):
|
11 |
+
def __init__(self, bfm_dir='deep_3drecon/BFM', keypoint_mode='lm68', use_gpu=True):
|
12 |
+
super().__init__()
|
13 |
+
self.keypoint_mode = keypoint_mode # lm68 | mediapipe
|
14 |
+
self.bfm_dir = bfm_dir
|
15 |
+
self.load_3dmm()
|
16 |
+
if use_gpu: self.to("cuda")
|
17 |
+
|
18 |
+
def load_3dmm(self):
|
19 |
+
model = loadmat(os.path.join(self.bfm_dir, "BFM_model_front.mat"))
|
20 |
+
self.register_buffer('mean_shape',torch.from_numpy(model['meanshape'].transpose()).float()) # mean face shape. [3*N, 1], N=35709, xyz=3, ==> 3*N=107127
|
21 |
+
mean_shape = self.mean_shape.reshape([-1, 3])
|
22 |
+
# re-center
|
23 |
+
mean_shape = mean_shape - torch.mean(mean_shape, dim=0, keepdims=True)
|
24 |
+
self.mean_shape = mean_shape.reshape([-1, 1])
|
25 |
+
self.register_buffer('id_base',torch.from_numpy(model['idBase']).float()) # identity basis. [3*N,80], we have 80 eigen faces for identity
|
26 |
+
self.register_buffer('exp_base',torch.from_numpy(model['exBase']).float()) # expression basis. [3*N,64], we have 64 eigen faces for expression
|
27 |
+
|
28 |
+
self.register_buffer('mean_texure',torch.from_numpy(model['meantex'].transpose()).float()) # mean face texture. [3*N,1] (0-255)
|
29 |
+
self.register_buffer('tex_base',torch.from_numpy(model['texBase']).float()) # texture basis. [3*N,80], rgb=3
|
30 |
+
|
31 |
+
self.register_buffer('point_buf',torch.from_numpy(model['point_buf']).float()) # triangle indices for each vertex that lies in. starts from 1. [N,8] (1-F)
|
32 |
+
self.register_buffer('face_buf',torch.from_numpy(model['tri']).float()) # vertex indices in each triangle. starts from 1. [F,3] (1-N)
|
33 |
+
if self.keypoint_mode == 'mediapipe':
|
34 |
+
self.register_buffer('key_points', torch.from_numpy(np.load("deep_3drecon/BFM/index_mp468_from_mesh35709.npy").astype(np.int64)))
|
35 |
+
unmatch_mask = self.key_points < 0
|
36 |
+
self.key_points[unmatch_mask] = 0
|
37 |
+
else:
|
38 |
+
self.register_buffer('key_points',torch.from_numpy(model['keypoints'].squeeze().astype(np.int_)).long()) # vertex indices of 68 facial landmarks. starts from 1. [68,1]
|
39 |
+
|
40 |
+
|
41 |
+
self.register_buffer('key_mean_shape',self.mean_shape.reshape([-1,3])[self.key_points,:])
|
42 |
+
self.register_buffer('key_id_base', self.id_base.reshape([-1,3,80])[self.key_points, :, :].reshape([-1,80]))
|
43 |
+
self.register_buffer('key_exp_base', self.exp_base.reshape([-1,3,64])[self.key_points, :, :].reshape([-1,64]))
|
44 |
+
self.key_id_base_np = self.key_id_base.cpu().numpy()
|
45 |
+
self.key_exp_base_np = self.key_exp_base.cpu().numpy()
|
46 |
+
|
47 |
+
self.register_buffer('persc_proj', torch.tensor(perspective_projection(focal=1015, center=112)))
|
48 |
+
def split_coeff(self, coeff):
|
49 |
+
"""
|
50 |
+
coeff: Tensor[B, T, c=257] or [T, c=257]
|
51 |
+
"""
|
52 |
+
ret_dict = {
|
53 |
+
'identity': coeff[..., :80], # identity, [b, t, c=80]
|
54 |
+
'expression': coeff[..., 80:144], # expression, [b, t, c=80]
|
55 |
+
'texture': coeff[..., 144:224], # texture, [b, t, c=80]
|
56 |
+
'euler': coeff[..., 224:227], # euler euler for pose, [b, t, c=3]
|
57 |
+
'translation': coeff[..., 254:257], # translation, [b, t, c=3]
|
58 |
+
'gamma': coeff[..., 227:254] # lighting, [b, t, c=27]
|
59 |
+
}
|
60 |
+
return ret_dict
|
61 |
+
|
62 |
+
def reconstruct_face_mesh(self, id_coeff, exp_coeff):
|
63 |
+
"""
|
64 |
+
Generate a pose-independent 3D face mesh!
|
65 |
+
id_coeff: Tensor[T, c=80]
|
66 |
+
exp_coeff: Tensor[T, c=64]
|
67 |
+
"""
|
68 |
+
id_coeff = id_coeff.to(self.key_id_base.device)
|
69 |
+
exp_coeff = exp_coeff.to(self.key_id_base.device)
|
70 |
+
mean_face = self.mean_shape.squeeze().reshape([1, -1]) # [3N, 1] ==> [1, 3N]
|
71 |
+
id_base, exp_base = self.id_base, self.exp_base # [3*N, C]
|
72 |
+
identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3N] ==> [t,3N]
|
73 |
+
expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3N] ==> [t,3N]
|
74 |
+
|
75 |
+
face = mean_face + identity_diff_face + expression_diff_face # [t,3N]
|
76 |
+
face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
|
77 |
+
# re-centering the face with mean_xyz, so the face will be in [-1, 1]
|
78 |
+
# mean_xyz = self.mean_shape.squeeze().reshape([-1,3]).mean(dim=0) # [1, 3]
|
79 |
+
# face_mesh = face - mean_xyz.unsqueeze(0) # [t,N,3]
|
80 |
+
return face
|
81 |
+
|
82 |
+
def reconstruct_cano_lm3d(self, id_coeff, exp_coeff):
|
83 |
+
"""
|
84 |
+
Generate 3D landmark with keypoint base!
|
85 |
+
id_coeff: Tensor[T, c=80]
|
86 |
+
exp_coeff: Tensor[T, c=64]
|
87 |
+
"""
|
88 |
+
id_coeff = id_coeff.to(self.key_id_base.device)
|
89 |
+
exp_coeff = exp_coeff.to(self.key_id_base.device)
|
90 |
+
mean_face = self.key_mean_shape.squeeze().reshape([1, -1]) # [3*68, 1] ==> [1, 3*68]
|
91 |
+
id_base, exp_base = self.key_id_base, self.key_exp_base # [3*68, C]
|
92 |
+
identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
|
93 |
+
expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
|
94 |
+
|
95 |
+
face = mean_face + identity_diff_face + expression_diff_face # [t,3N]
|
96 |
+
face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
|
97 |
+
# re-centering the face with mean_xyz, so the face will be in [-1, 1]
|
98 |
+
# mean_xyz = self.key_mean_shape.squeeze().reshape([-1,3]).mean(dim=0) # [1, 3]
|
99 |
+
# lm3d = face - mean_xyz.unsqueeze(0) # [t,N,3]
|
100 |
+
return face
|
101 |
+
|
102 |
+
def reconstruct_lm3d(self, id_coeff, exp_coeff, euler, trans, to_camera=True):
|
103 |
+
"""
|
104 |
+
Generate 3D landmark with keypoint base!
|
105 |
+
id_coeff: Tensor[T, c=80]
|
106 |
+
exp_coeff: Tensor[T, c=64]
|
107 |
+
"""
|
108 |
+
id_coeff = id_coeff.to(self.key_id_base.device)
|
109 |
+
exp_coeff = exp_coeff.to(self.key_id_base.device)
|
110 |
+
mean_face = self.key_mean_shape.squeeze().reshape([1, -1]) # [3*68, 1] ==> [1, 3*68]
|
111 |
+
id_base, exp_base = self.key_id_base, self.key_exp_base # [3*68, C]
|
112 |
+
identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
|
113 |
+
expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
|
114 |
+
|
115 |
+
face = mean_face + identity_diff_face + expression_diff_face # [t,3N]
|
116 |
+
face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
|
117 |
+
# re-centering the face with mean_xyz, so the face will be in [-1, 1]
|
118 |
+
rot = self.compute_rotation(euler)
|
119 |
+
# transform
|
120 |
+
lm3d = face @ rot + trans.unsqueeze(1) # [t, N, 3]
|
121 |
+
# to camera
|
122 |
+
if to_camera:
|
123 |
+
lm3d[...,-1] = 10 - lm3d[...,-1]
|
124 |
+
return lm3d
|
125 |
+
|
126 |
+
def reconstruct_lm2d_nerf(self, id_coeff, exp_coeff, euler, trans):
|
127 |
+
lm2d = self.reconstruct_lm2d(id_coeff, exp_coeff, euler, trans, to_camera=False)
|
128 |
+
lm2d[..., 0] = 1 - lm2d[..., 0]
|
129 |
+
lm2d[..., 1] = 1 - lm2d[..., 1]
|
130 |
+
return lm2d
|
131 |
+
|
132 |
+
def reconstruct_lm2d(self, id_coeff, exp_coeff, euler, trans, to_camera=True):
|
133 |
+
"""
|
134 |
+
Generate 3D landmark with keypoint base!
|
135 |
+
id_coeff: Tensor[T, c=80]
|
136 |
+
exp_coeff: Tensor[T, c=64]
|
137 |
+
"""
|
138 |
+
is_btc_flag = True if id_coeff.ndim == 3 else False
|
139 |
+
if is_btc_flag:
|
140 |
+
b,t,_ = id_coeff.shape
|
141 |
+
id_coeff = id_coeff.reshape([b*t,-1])
|
142 |
+
exp_coeff = exp_coeff.reshape([b*t,-1])
|
143 |
+
euler = euler.reshape([b*t,-1])
|
144 |
+
trans = trans.reshape([b*t,-1])
|
145 |
+
id_coeff = id_coeff.to(self.key_id_base.device)
|
146 |
+
exp_coeff = exp_coeff.to(self.key_id_base.device)
|
147 |
+
mean_face = self.key_mean_shape.squeeze().reshape([1, -1]) # [3*68, 1] ==> [1, 3*68]
|
148 |
+
id_base, exp_base = self.key_id_base, self.key_exp_base # [3*68, C]
|
149 |
+
identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
|
150 |
+
expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
|
151 |
+
|
152 |
+
face = mean_face + identity_diff_face + expression_diff_face # [t,3N]
|
153 |
+
face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
|
154 |
+
# re-centering the face with mean_xyz, so the face will be in [-1, 1]
|
155 |
+
rot = self.compute_rotation(euler)
|
156 |
+
# transform
|
157 |
+
lm3d = face @ rot + trans.unsqueeze(1) # [t, N, 3]
|
158 |
+
# to camera
|
159 |
+
if to_camera:
|
160 |
+
lm3d[...,-1] = 10 - lm3d[...,-1]
|
161 |
+
# to image_plane
|
162 |
+
lm3d = lm3d @ self.persc_proj
|
163 |
+
lm2d = lm3d[..., :2] / lm3d[..., 2:]
|
164 |
+
# flip
|
165 |
+
lm2d[..., 1] = 224 - lm2d[..., 1]
|
166 |
+
lm2d /= 224
|
167 |
+
if is_btc_flag:
|
168 |
+
return lm2d.reshape([b,t,-1,2])
|
169 |
+
return lm2d
|
170 |
+
|
171 |
+
def compute_rotation(self, euler):
|
172 |
+
"""
|
173 |
+
Return:
|
174 |
+
rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat
|
175 |
+
|
176 |
+
Parameters:
|
177 |
+
euler -- torch.tensor, size (B, 3), radian
|
178 |
+
"""
|
179 |
+
|
180 |
+
batch_size = euler.shape[0]
|
181 |
+
euler = euler.to(self.key_id_base.device)
|
182 |
+
ones = torch.ones([batch_size, 1]).to(self.key_id_base.device)
|
183 |
+
zeros = torch.zeros([batch_size, 1]).to(self.key_id_base.device)
|
184 |
+
x, y, z = euler[:, :1], euler[:, 1:2], euler[:, 2:],
|
185 |
+
|
186 |
+
rot_x = torch.cat([
|
187 |
+
ones, zeros, zeros,
|
188 |
+
zeros, torch.cos(x), -torch.sin(x),
|
189 |
+
zeros, torch.sin(x), torch.cos(x)
|
190 |
+
], dim=1).reshape([batch_size, 3, 3])
|
191 |
+
|
192 |
+
rot_y = torch.cat([
|
193 |
+
torch.cos(y), zeros, torch.sin(y),
|
194 |
+
zeros, ones, zeros,
|
195 |
+
-torch.sin(y), zeros, torch.cos(y)
|
196 |
+
], dim=1).reshape([batch_size, 3, 3])
|
197 |
+
|
198 |
+
rot_z = torch.cat([
|
199 |
+
torch.cos(z), -torch.sin(z), zeros,
|
200 |
+
torch.sin(z), torch.cos(z), zeros,
|
201 |
+
zeros, zeros, ones
|
202 |
+
], dim=1).reshape([batch_size, 3, 3])
|
203 |
+
|
204 |
+
rot = rot_z @ rot_y @ rot_x
|
205 |
+
return rot.permute(0, 2, 1)
|
206 |
+
|
207 |
+
def reconstruct_idexp_lm3d(self, id_coeff, exp_coeff):
|
208 |
+
"""
|
209 |
+
Generate 3D landmark with keypoint base!
|
210 |
+
id_coeff: Tensor[T, c=80]
|
211 |
+
exp_coeff: Tensor[T, c=64]
|
212 |
+
"""
|
213 |
+
id_coeff = id_coeff.to(self.key_id_base.device)
|
214 |
+
exp_coeff = exp_coeff.to(self.key_id_base.device)
|
215 |
+
id_base, exp_base = self.key_id_base, self.key_exp_base # [3*68, C]
|
216 |
+
identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
|
217 |
+
expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
|
218 |
+
|
219 |
+
face = identity_diff_face + expression_diff_face # [t,3N]
|
220 |
+
face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
|
221 |
+
lm3d = face * 10
|
222 |
+
return lm3d
|
223 |
+
|
224 |
+
def reconstruct_idexp_lm3d_np(self, id_coeff, exp_coeff):
|
225 |
+
"""
|
226 |
+
Generate 3D landmark with keypoint base!
|
227 |
+
id_coeff: Tensor[T, c=80]
|
228 |
+
exp_coeff: Tensor[T, c=64]
|
229 |
+
"""
|
230 |
+
id_base, exp_base = self.key_id_base_np, self.key_exp_base_np # [3*68, C]
|
231 |
+
identity_diff_face = np.dot(id_coeff, id_base.T) # [t,c],[c,3*68] ==> [t,3*68]
|
232 |
+
expression_diff_face = np.dot(exp_coeff, exp_base.T) # [t,c],[c,3*68] ==> [t,3*68]
|
233 |
+
|
234 |
+
face = identity_diff_face + expression_diff_face # [t,3N]
|
235 |
+
face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
|
236 |
+
lm3d = face * 10
|
237 |
+
return lm3d
|
238 |
+
|
239 |
+
def get_eye_mouth_lm_from_lm3d(self, lm3d):
|
240 |
+
eye_lm = lm3d[:, 17:48] # [T, 31, 3]
|
241 |
+
mouth_lm = lm3d[:, 48:68] # [T, 20, 3]
|
242 |
+
return eye_lm, mouth_lm
|
243 |
+
|
244 |
+
def get_eye_mouth_lm_from_lm3d_batch(self, lm3d):
|
245 |
+
eye_lm = lm3d[:, :, 17:48] # [T, 31, 3]
|
246 |
+
mouth_lm = lm3d[:, :, 48:68] # [T, 20, 3]
|
247 |
+
return eye_lm, mouth_lm
|
248 |
+
|
249 |
+
def close_mouth_for_idexp_lm3d(self, idexp_lm3d, freeze_as_first_frame=True):
|
250 |
+
idexp_lm3d = idexp_lm3d.reshape([-1, 68,3])
|
251 |
+
num_frames = idexp_lm3d.shape[0]
|
252 |
+
eps = 0.0
|
253 |
+
# [n_landmarks=68,xyz=3], x 代表左右,y代表上下,z代表深度
|
254 |
+
idexp_lm3d[:,49:54, 1] = (idexp_lm3d[:,49:54, 1] + idexp_lm3d[:,range(59,54,-1), 1])/2 + eps * 2
|
255 |
+
idexp_lm3d[:,range(59,54,-1), 1] = (idexp_lm3d[:,49:54, 1] + idexp_lm3d[:,range(59,54,-1), 1])/2 - eps * 2
|
256 |
+
|
257 |
+
idexp_lm3d[:,61:64, 1] = (idexp_lm3d[:,61:64, 1] + idexp_lm3d[:,range(67,64,-1), 1])/2 + eps
|
258 |
+
idexp_lm3d[:,range(67,64,-1), 1] = (idexp_lm3d[:,61:64, 1] + idexp_lm3d[:,range(67,64,-1), 1])/2 - eps
|
259 |
+
|
260 |
+
idexp_lm3d[:,49:54, 1] += (0.03 - idexp_lm3d[:,49:54, 1].mean(dim=1) + idexp_lm3d[:,61:64, 1].mean(dim=1)).unsqueeze(1).repeat([1,5])
|
261 |
+
idexp_lm3d[:,range(59,54,-1), 1] += (-0.03 - idexp_lm3d[:,range(59,54,-1), 1].mean(dim=1) + idexp_lm3d[:,range(67,64,-1), 1].mean(dim=1)).unsqueeze(1).repeat([1,5])
|
262 |
+
|
263 |
+
if freeze_as_first_frame:
|
264 |
+
idexp_lm3d[:, 48:68,] = idexp_lm3d[0, 48:68].unsqueeze(0).clone().repeat([num_frames, 1,1])*0
|
265 |
+
return idexp_lm3d.cpu()
|
266 |
+
|
267 |
+
def close_eyes_for_idexp_lm3d(self, idexp_lm3d):
|
268 |
+
idexp_lm3d = idexp_lm3d.reshape([-1, 68,3])
|
269 |
+
eps = 0.003
|
270 |
+
idexp_lm3d[:,37:39, 1] = (idexp_lm3d[:,37:39, 1] + idexp_lm3d[:,range(41,39,-1), 1])/2 + eps
|
271 |
+
idexp_lm3d[:,range(41,39,-1), 1] = (idexp_lm3d[:,37:39, 1] + idexp_lm3d[:,range(41,39,-1), 1])/2 - eps
|
272 |
+
|
273 |
+
idexp_lm3d[:,43:45, 1] = (idexp_lm3d[:,43:45, 1] + idexp_lm3d[:,range(47,45,-1), 1])/2 + eps
|
274 |
+
idexp_lm3d[:,range(47,45,-1), 1] = (idexp_lm3d[:,43:45, 1] + idexp_lm3d[:,range(47,45,-1), 1])/2 - eps
|
275 |
+
|
276 |
+
return idexp_lm3d
|
277 |
+
|
278 |
+
if __name__ == '__main__':
|
279 |
+
import cv2
|
280 |
+
|
281 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
282 |
+
|
283 |
+
face_mesh_helper = Face3DHelper('deep_3drecon/BFM')
|
284 |
+
coeff_npy = 'data/coeff_fit_mp/crop_nana_003_coeff_fit_mp.npy'
|
285 |
+
coeff_dict = np.load(coeff_npy, allow_pickle=True).tolist()
|
286 |
+
lm3d = face_mesh_helper.reconstruct_lm2d(torch.tensor(coeff_dict['id']).cuda(), torch.tensor(coeff_dict['exp']).cuda(), torch.tensor(coeff_dict['euler']).cuda(), torch.tensor(coeff_dict['trans']).cuda() )
|
287 |
+
|
288 |
+
WH = 512
|
289 |
+
lm3d = (lm3d * WH).cpu().int().numpy()
|
290 |
+
eye_idx = list(range(36,48))
|
291 |
+
mouth_idx = list(range(48,68))
|
292 |
+
import imageio
|
293 |
+
debug_name = 'debug_lm3d.mp4'
|
294 |
+
writer = imageio.get_writer(debug_name, fps=25)
|
295 |
+
for i_img in range(len(lm3d)):
|
296 |
+
lm2d = lm3d[i_img ,:, :2] # [68, 2]
|
297 |
+
img = np.ones([WH, WH, 3], dtype=np.uint8) * 255
|
298 |
+
for i in range(len(lm2d)):
|
299 |
+
x, y = lm2d[i]
|
300 |
+
if i in eye_idx:
|
301 |
+
color = (0,0,255)
|
302 |
+
elif i in mouth_idx:
|
303 |
+
color = (0,255,0)
|
304 |
+
else:
|
305 |
+
color = (255,0,0)
|
306 |
+
img = cv2.circle(img, center=(x,y), radius=3, color=color, thickness=-1)
|
307 |
+
img = cv2.putText(img, f"{i}", org=(x,y), fontFace=font, fontScale=0.3, color=(255,0,0))
|
308 |
+
writer.append_data(img)
|
309 |
+
writer.close()
|
deep_3drecon/BFM/.gitkeep
ADDED
File without changes
|
deep_3drecon/BFM/basel_53201.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
deep_3drecon/BFM/index_mp468_from_mesh35709_v1.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0d238a90df0c55075c9cea43dab76348421379a75c204931e34dbd2c11fb4b65
|
3 |
+
size 3872
|
deep_3drecon/BFM/index_mp468_from_mesh35709_v2.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fe95e2bb10ac1e54804006184d7de3c5ccd0eb98a5f1bd28e00b9f3569f6ce5a
|
3 |
+
size 3872
|
deep_3drecon/BFM/index_mp468_from_mesh35709_v3.1.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:053b8cce8424b722db6ec5b068514eb007a23b4c5afd629449eb08746e643211
|
3 |
+
size 3872
|
deep_3drecon/BFM/index_mp468_from_mesh35709_v3.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5b007b3619dd02892b38349ba3d4b10e32bc2eff201c265f25d6ed62f67dbd51
|
3 |
+
size 3872
|
deep_3drecon/BFM/select_vertex_id.mat
ADDED
Binary file (62.3 kB). View file
|
|
deep_3drecon/BFM/similarity_Lm3D_all.mat
ADDED
Binary file (994 Bytes). View file
|
|
deep_3drecon/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .reconstructor import *
|
deep_3drecon/bfm_left_eye_faces.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9651756ea2c0fac069a1edf858ed1f125eddc358fa74c529a370c1e7b5730d28
|
3 |
+
size 4680
|
deep_3drecon/bfm_right_eye_faces.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:28cb5bbacf578d30a3d5006ec28c617fe5a3ecaeeeb87d9433a884e0f0301a2e
|
3 |
+
size 4648
|
deep_3drecon/data_preparation.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This script is the data preparation script for Deep3DFaceRecon_pytorch
|
2 |
+
"""
|
3 |
+
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
import argparse
|
7 |
+
from util.detect_lm68 import detect_68p,load_lm_graph
|
8 |
+
from util.skin_mask import get_skin_mask
|
9 |
+
from util.generate_list import check_list, write_list
|
10 |
+
import warnings
|
11 |
+
warnings.filterwarnings("ignore")
|
12 |
+
|
13 |
+
parser = argparse.ArgumentParser()
|
14 |
+
parser.add_argument('--data_root', type=str, default='datasets', help='root directory for training data')
|
15 |
+
parser.add_argument('--img_folder', nargs="+", required=True, help='folders of training images')
|
16 |
+
parser.add_argument('--mode', type=str, default='train', help='train or val')
|
17 |
+
opt = parser.parse_args()
|
18 |
+
|
19 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
20 |
+
|
21 |
+
def data_prepare(folder_list,mode):
|
22 |
+
|
23 |
+
lm_sess,input_op,output_op = load_lm_graph('./checkpoints/lm_model/68lm_detector.pb') # load a tensorflow version 68-landmark detector
|
24 |
+
|
25 |
+
for img_folder in folder_list:
|
26 |
+
detect_68p(img_folder,lm_sess,input_op,output_op) # detect landmarks for images
|
27 |
+
get_skin_mask(img_folder) # generate skin attention mask for images
|
28 |
+
|
29 |
+
# create files that record path to all training data
|
30 |
+
msks_list = []
|
31 |
+
for img_folder in folder_list:
|
32 |
+
path = os.path.join(img_folder, 'mask')
|
33 |
+
msks_list += ['/'.join([img_folder, 'mask', i]) for i in sorted(os.listdir(path)) if 'jpg' in i or
|
34 |
+
'png' in i or 'jpeg' in i or 'PNG' in i]
|
35 |
+
|
36 |
+
imgs_list = [i.replace('mask/', '') for i in msks_list]
|
37 |
+
lms_list = [i.replace('mask', 'landmarks') for i in msks_list]
|
38 |
+
lms_list = ['.'.join(i.split('.')[:-1]) + '.txt' for i in lms_list]
|
39 |
+
|
40 |
+
lms_list_final, imgs_list_final, msks_list_final = check_list(lms_list, imgs_list, msks_list) # check if the path is valid
|
41 |
+
write_list(lms_list_final, imgs_list_final, msks_list_final, mode=mode) # save files
|
42 |
+
|
43 |
+
if __name__ == '__main__':
|
44 |
+
print('Datasets:',opt.img_folder)
|
45 |
+
data_prepare([os.path.join(opt.data_root,folder) for folder in opt.img_folder],opt.mode)
|
deep_3drecon/deep_3drecon_models/__init__.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This package contains modules related to objective functions, optimizations, and network architectures.
|
2 |
+
|
3 |
+
To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
|
4 |
+
You need to implement the following five functions:
|
5 |
+
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
|
6 |
+
-- <set_input>: unpack data from dataset and apply preprocessing.
|
7 |
+
-- <forward>: produce intermediate results.
|
8 |
+
-- <optimize_parameters>: calculate loss, gradients, and update network weights.
|
9 |
+
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
|
10 |
+
|
11 |
+
In the function <__init__>, you need to define four lists:
|
12 |
+
-- self.loss_names (str list): specify the training losses that you want to plot and save.
|
13 |
+
-- self.model_names (str list): define networks used in our training.
|
14 |
+
-- self.visual_names (str list): specify the images that you want to display and save.
|
15 |
+
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.
|
16 |
+
|
17 |
+
Now you can use the model class by specifying flag '--model dummy'.
|
18 |
+
See our template model class 'template_model.py' for more details.
|
19 |
+
"""
|
20 |
+
|
21 |
+
import importlib
|
22 |
+
from .base_model import BaseModel
|
23 |
+
|
24 |
+
|
25 |
+
def find_model_using_name(model_name):
|
26 |
+
"""Import the module "models/[model_name]_model.py".
|
27 |
+
|
28 |
+
In the file, the class called DatasetNameModel() will
|
29 |
+
be instantiated. It has to be a subclass of BaseModel,
|
30 |
+
and it is case-insensitive.
|
31 |
+
"""
|
32 |
+
model_filename = "deep_3drecon_models." + model_name + "_model"
|
33 |
+
modellib = importlib.import_module(model_filename)
|
34 |
+
model = None
|
35 |
+
target_model_name = model_name.replace('_', '') + 'model'
|
36 |
+
for name, cls in modellib.__dict__.items():
|
37 |
+
if name.lower() == target_model_name.lower() \
|
38 |
+
and issubclass(cls, BaseModel):
|
39 |
+
model = cls
|
40 |
+
|
41 |
+
if model is None:
|
42 |
+
print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
|
43 |
+
exit(0)
|
44 |
+
|
45 |
+
return model
|
46 |
+
|
47 |
+
|
48 |
+
def get_option_setter(model_name):
|
49 |
+
"""Return the static method <modify_commandline_options> of the model class."""
|
50 |
+
model_class = find_model_using_name(model_name)
|
51 |
+
return model_class.modify_commandline_options
|
52 |
+
|
53 |
+
|
54 |
+
def create_model(opt):
|
55 |
+
"""Create a model given the option.
|
56 |
+
|
57 |
+
This function warps the class CustomDatasetDataLoader.
|
58 |
+
This is the main interface between this package and 'train.py'/'test.py'
|
59 |
+
|
60 |
+
Example:
|
61 |
+
>>> from models import create_model
|
62 |
+
>>> model = create_model(opt)
|
63 |
+
"""
|
64 |
+
model = find_model_using_name(opt.model)
|
65 |
+
instance = model(opt)
|
66 |
+
print("model [%s] was created" % type(instance).__name__)
|
67 |
+
return instance
|
deep_3drecon/deep_3drecon_models/arcface_torch/README.md
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Distributed Arcface Training in Pytorch
|
2 |
+
|
3 |
+
The "arcface_torch" repository is the official implementation of the ArcFace algorithm. It supports distributed and sparse training with multiple distributed training examples, including several memory-saving techniques such as mixed precision training and gradient checkpointing. It also supports training for ViT models and datasets including WebFace42M and Glint360K, two of the largest open-source datasets. Additionally, the repository comes with a built-in tool for converting to ONNX format, making it easy to submit to MFR evaluation systems.
|
4 |
+
|
5 |
+
[](https://paperswithcode.com/sota/face-verification-on-ijb-c?p=killing-two-birds-with-one-stone-efficient)
|
6 |
+
[](https://paperswithcode.com/sota/face-verification-on-ijb-b?p=killing-two-birds-with-one-stone-efficient)
|
7 |
+
[](https://paperswithcode.com/sota/face-verification-on-agedb-30?p=killing-two-birds-with-one-stone-efficient)
|
8 |
+
[](https://paperswithcode.com/sota/face-verification-on-cfp-fp?p=killing-two-birds-with-one-stone-efficient)
|
9 |
+
|
10 |
+
## Requirements
|
11 |
+
|
12 |
+
To avail the latest features of PyTorch, we have upgraded to version 1.12.0.
|
13 |
+
|
14 |
+
- Install [PyTorch](https://pytorch.org/get-started/previous-versions/) (torch>=1.12.0).
|
15 |
+
- (Optional) Install [DALI](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/), our doc for [install_dali.md](docs/install_dali.md).
|
16 |
+
- `pip install -r requirement.txt`.
|
17 |
+
|
18 |
+
## How to Training
|
19 |
+
|
20 |
+
To train a model, execute the `train.py` script with the path to the configuration files. The sample commands provided below demonstrate the process of conducting distributed training.
|
21 |
+
|
22 |
+
### 1. To run on one GPU:
|
23 |
+
|
24 |
+
```shell
|
25 |
+
python train_v2.py configs/ms1mv3_r50_onegpu
|
26 |
+
```
|
27 |
+
|
28 |
+
Note:
|
29 |
+
It is not recommended to use a single GPU for training, as this may result in longer training times and suboptimal performance. For best results, we suggest using multiple GPUs or a GPU cluster.
|
30 |
+
|
31 |
+
|
32 |
+
### 2. To run on a machine with 8 GPUs:
|
33 |
+
|
34 |
+
```shell
|
35 |
+
torchrun --nproc_per_node=8 train.py configs/ms1mv3_r50
|
36 |
+
```
|
37 |
+
|
38 |
+
### 3. To run on 2 machines with 8 GPUs each:
|
39 |
+
|
40 |
+
Node 0:
|
41 |
+
|
42 |
+
```shell
|
43 |
+
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=12581 train.py configs/wf42m_pfc02_16gpus_r100
|
44 |
+
```
|
45 |
+
|
46 |
+
Node 1:
|
47 |
+
|
48 |
+
```shell
|
49 |
+
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=12581 train.py configs/wf42m_pfc02_16gpus_r100
|
50 |
+
```
|
51 |
+
|
52 |
+
### 4. Run ViT-B on a machine with 24k batchsize:
|
53 |
+
|
54 |
+
```shell
|
55 |
+
torchrun --nproc_per_node=8 train_v2.py configs/wf42m_pfc03_40epoch_8gpu_vit_b
|
56 |
+
```
|
57 |
+
|
58 |
+
|
59 |
+
## Download Datasets or Prepare Datasets
|
60 |
+
- [MS1MV2](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_#ms1m-arcface-85k-ids58m-images-57) (87k IDs, 5.8M images)
|
61 |
+
- [MS1MV3](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_#ms1m-retinaface) (93k IDs, 5.2M images)
|
62 |
+
- [Glint360K](https://github.com/deepinsight/insightface/tree/master/recognition/partial_fc#4-download) (360k IDs, 17.1M images)
|
63 |
+
- [WebFace42M](docs/prepare_webface42m.md) (2M IDs, 42.5M images)
|
64 |
+
- [Your Dataset, Click Here!](docs/prepare_custom_dataset.md)
|
65 |
+
|
66 |
+
Note:
|
67 |
+
If you want to use DALI for data reading, please use the script 'scripts/shuffle_rec.py' to shuffle the InsightFace style rec before using it.
|
68 |
+
Example:
|
69 |
+
|
70 |
+
`python scripts/shuffle_rec.py ms1m-retinaface-t1`
|
71 |
+
|
72 |
+
You will get the "shuffled_ms1m-retinaface-t1" folder, where the samples in the "train.rec" file are shuffled.
|
73 |
+
|
74 |
+
|
75 |
+
## Model Zoo
|
76 |
+
|
77 |
+
- The models are available for non-commercial research purposes only.
|
78 |
+
- All models can be found in here.
|
79 |
+
- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
|
80 |
+
- [OneDrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
|
81 |
+
|
82 |
+
### Performance on IJB-C and [**ICCV2021-MFR**](https://github.com/deepinsight/insightface/blob/master/challenges/mfr/README.md)
|
83 |
+
|
84 |
+
ICCV2021-MFR testset consists of non-celebrities so we can ensure that it has very few overlap with public available face
|
85 |
+
recognition training set, such as MS1M and CASIA as they mostly collected from online celebrities.
|
86 |
+
As the result, we can evaluate the FAIR performance for different algorithms.
|
87 |
+
|
88 |
+
For **ICCV2021-MFR-ALL** set, TAR is measured on all-to-all 1:1 protocal, with FAR less than 0.000001(e-6). The
|
89 |
+
globalised multi-racial testset contains 242,143 identities and 1,624,305 images.
|
90 |
+
|
91 |
+
|
92 |
+
#### 1. Training on Single-Host GPU
|
93 |
+
|
94 |
+
| Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | log |
|
95 |
+
|:---------------|:--------------------|:------------|:------------|:------------|:------------------------------------------------------------------------------------------------------------------------------------|
|
96 |
+
| MS1MV2 | mobilefacenet-0.45G | 62.07 | 93.61 | 90.28 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv2_mbf/training.log) |
|
97 |
+
| MS1MV2 | r50 | 75.13 | 95.97 | 94.07 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv2_r50/training.log) |
|
98 |
+
| MS1MV2 | r100 | 78.12 | 96.37 | 94.27 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv2_r100/training.log) |
|
99 |
+
| MS1MV3 | mobilefacenet-0.45G | 63.78 | 94.23 | 91.33 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_mbf/training.log) |
|
100 |
+
| MS1MV3 | r50 | 79.14 | 96.37 | 94.47 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_r50/training.log) |
|
101 |
+
| MS1MV3 | r100 | 81.97 | 96.85 | 95.02 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_r100/training.log) |
|
102 |
+
| Glint360K | mobilefacenet-0.45G | 70.18 | 95.04 | 92.62 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_mbf/training.log) |
|
103 |
+
| Glint360K | r50 | 86.34 | 97.16 | 95.81 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_r50/training.log) |
|
104 |
+
| Glint360k | r100 | 89.52 | 97.55 | 96.38 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_r100/training.log) |
|
105 |
+
| WF4M | r100 | 89.87 | 97.19 | 95.48 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf4m_r100/training.log) |
|
106 |
+
| WF12M-PFC-0.2 | r100 | 94.75 | 97.60 | 95.90 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf12m_pfc02_r100/training.log) |
|
107 |
+
| WF12M-PFC-0.3 | r100 | 94.71 | 97.64 | 96.01 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf12m_pfc03_r100/training.log) |
|
108 |
+
| WF12M | r100 | 94.69 | 97.59 | 95.97 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf12m_r100/training.log) |
|
109 |
+
| WF42M-PFC-0.2 | r100 | 96.27 | 97.70 | 96.31 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf42m_pfc02_r100/training.log) |
|
110 |
+
| WF42M-PFC-0.2 | ViT-T-1.5G | 92.04 | 97.27 | 95.68 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf42m_pfc02_40epoch_8gpu_vit_t/training.log) |
|
111 |
+
| WF42M-PFC-0.3 | ViT-B-11G | 97.16 | 97.91 | 97.05 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_b_8gpu/training.log) |
|
112 |
+
|
113 |
+
#### 2. Training on Multi-Host GPU
|
114 |
+
|
115 |
+
| Datasets | Backbone(bs*gpus) | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | Throughout | log |
|
116 |
+
|:-----------------|:------------------|:------------|:------------|:------------|:-----------|:-------------------------------------------------------------------------------------------------------------------------------------------|
|
117 |
+
| WF42M-PFC-0.2 | r50(512*8) | 93.83 | 97.53 | 96.16 | ~5900 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r50_bs4k_pfc02/training.log) |
|
118 |
+
| WF42M-PFC-0.2 | r50(512*16) | 93.96 | 97.46 | 96.12 | ~11000 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r50_lr01_pfc02_bs8k_16gpus/training.log) |
|
119 |
+
| WF42M-PFC-0.2 | r50(128*32) | 94.04 | 97.48 | 95.94 | ~17000 | click me |
|
120 |
+
| WF42M-PFC-0.2 | r100(128*16) | 96.28 | 97.80 | 96.57 | ~5200 | click me |
|
121 |
+
| WF42M-PFC-0.2 | r100(256*16) | 96.69 | 97.85 | 96.63 | ~5200 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r100_bs4k_pfc02/training.log) |
|
122 |
+
| WF42M-PFC-0.0018 | r100(512*32) | 93.08 | 97.51 | 95.88 | ~10000 | click me |
|
123 |
+
| WF42M-PFC-0.2 | r100(128*32) | 96.57 | 97.83 | 96.50 | ~9800 | click me |
|
124 |
+
|
125 |
+
`r100(128*32)` means backbone is r100, batchsize per gpu is 128, the number of gpus is 32.
|
126 |
+
|
127 |
+
|
128 |
+
|
129 |
+
#### 3. ViT For Face Recognition
|
130 |
+
|
131 |
+
| Datasets | Backbone(bs) | FLOPs | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | Throughout | log |
|
132 |
+
|:--------------|:--------------|:------|:------------|:------------|:------------|:-----------|:-----------------------------------------------------------------------------------------------------------------------------|
|
133 |
+
| WF42M-PFC-0.3 | r18(128*32) | 2.6 | 79.13 | 95.77 | 93.36 | - | click me |
|
134 |
+
| WF42M-PFC-0.3 | r50(128*32) | 6.3 | 94.03 | 97.48 | 95.94 | - | click me |
|
135 |
+
| WF42M-PFC-0.3 | r100(128*32) | 12.1 | 96.69 | 97.82 | 96.45 | - | click me |
|
136 |
+
| WF42M-PFC-0.3 | r200(128*32) | 23.5 | 97.70 | 97.97 | 96.93 | - | click me |
|
137 |
+
| WF42M-PFC-0.3 | VIT-T(384*64) | 1.5 | 92.24 | 97.31 | 95.97 | ~35000 | click me |
|
138 |
+
| WF42M-PFC-0.3 | VIT-S(384*64) | 5.7 | 95.87 | 97.73 | 96.57 | ~25000 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_s_64gpu/training.log) |
|
139 |
+
| WF42M-PFC-0.3 | VIT-B(384*64) | 11.4 | 97.42 | 97.90 | 97.04 | ~13800 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_b_64gpu/training.log) |
|
140 |
+
| WF42M-PFC-0.3 | VIT-L(384*64) | 25.3 | 97.85 | 98.00 | 97.23 | ~9406 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_l_64gpu/training.log) |
|
141 |
+
|
142 |
+
`WF42M` means WebFace42M, `PFC-0.3` means negivate class centers sample rate is 0.3.
|
143 |
+
|
144 |
+
#### 4. Noisy Datasets
|
145 |
+
|
146 |
+
| Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | log |
|
147 |
+
|:-------------------------|:---------|:------------|:------------|:------------|:---------|
|
148 |
+
| WF12M-Flip(40%) | r50 | 43.87 | 88.35 | 80.78 | click me |
|
149 |
+
| WF12M-Flip(40%)-PFC-0.1* | r50 | 80.20 | 96.11 | 93.79 | click me |
|
150 |
+
| WF12M-Conflict | r50 | 79.93 | 95.30 | 91.56 | click me |
|
151 |
+
| WF12M-Conflict-PFC-0.3* | r50 | 91.68 | 97.28 | 95.75 | click me |
|
152 |
+
|
153 |
+
`WF12M` means WebFace12M, `+PFC-0.1*` denotes additional abnormal inter-class filtering.
|
154 |
+
|
155 |
+
|
156 |
+
|
157 |
+
## Speed Benchmark
|
158 |
+
<div><img src="https://github.com/anxiangsir/insightface_arcface_log/blob/master/pfc_exp.png" width = "90%" /></div>
|
159 |
+
|
160 |
+
|
161 |
+
**Arcface-Torch** is an efficient tool for training large-scale face recognition training sets. When the number of classes in the training sets exceeds one million, the partial FC sampling strategy maintains the same accuracy while providing several times faster training performance and lower GPU memory utilization. The partial FC is a sparse variant of the model parallel architecture for large-scale face recognition, utilizing a sparse softmax that dynamically samples a subset of class centers for each training batch. During each iteration, only a sparse portion of the parameters are updated, leading to a significant reduction in GPU memory requirements and computational demands. With the partial FC approach, it is possible to train sets with up to 29 million identities, the largest to date. Furthermore, the partial FC method supports multi-machine distributed training and mixed precision training.
|
162 |
+
|
163 |
+
|
164 |
+
|
165 |
+
More details see
|
166 |
+
[speed_benchmark.md](docs/speed_benchmark.md) in docs.
|
167 |
+
|
168 |
+
> 1. Training Speed of Various Parallel Techniques (Samples per Second) on a Tesla V100 32GB x 8 System (Higher is Optimal)
|
169 |
+
|
170 |
+
`-` means training failed because of gpu memory limitations.
|
171 |
+
|
172 |
+
| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
|
173 |
+
|:--------------------------------|:--------------|:---------------|:---------------|
|
174 |
+
| 125000 | 4681 | 4824 | 5004 |
|
175 |
+
| 1400000 | **1672** | 3043 | 4738 |
|
176 |
+
| 5500000 | **-** | **1389** | 3975 |
|
177 |
+
| 8000000 | **-** | **-** | 3565 |
|
178 |
+
| 16000000 | **-** | **-** | 2679 |
|
179 |
+
| 29000000 | **-** | **-** | **1855** |
|
180 |
+
|
181 |
+
> 2. GPU Memory Utilization of Various Parallel Techniques (MB per GPU) on a Tesla V100 32GB x 8 System (Lower is Optimal)
|
182 |
+
|
183 |
+
| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
|
184 |
+
|:--------------------------------|:--------------|:---------------|:---------------|
|
185 |
+
| 125000 | 7358 | 5306 | 4868 |
|
186 |
+
| 1400000 | 32252 | 11178 | 6056 |
|
187 |
+
| 5500000 | **-** | 32188 | 9854 |
|
188 |
+
| 8000000 | **-** | **-** | 12310 |
|
189 |
+
| 16000000 | **-** | **-** | 19950 |
|
190 |
+
| 29000000 | **-** | **-** | 32324 |
|
191 |
+
|
192 |
+
|
193 |
+
## Citations
|
194 |
+
|
195 |
+
```
|
196 |
+
@inproceedings{deng2019arcface,
|
197 |
+
title={Arcface: Additive angular margin loss for deep face recognition},
|
198 |
+
author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
|
199 |
+
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
|
200 |
+
pages={4690--4699},
|
201 |
+
year={2019}
|
202 |
+
}
|
203 |
+
@inproceedings{An_2022_CVPR,
|
204 |
+
author={An, Xiang and Deng, Jiankang and Guo, Jia and Feng, Ziyong and Zhu, XuHan and Yang, Jing and Liu, Tongliang},
|
205 |
+
title={Killing Two Birds With One Stone: Efficient and Robust Training of Face Recognition CNNs by Partial FC},
|
206 |
+
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
207 |
+
month={June},
|
208 |
+
year={2022},
|
209 |
+
pages={4042-4051}
|
210 |
+
}
|
211 |
+
@inproceedings{zhu2021webface260m,
|
212 |
+
title={Webface260m: A benchmark unveiling the power of million-scale deep face recognition},
|
213 |
+
author={Zhu, Zheng and Huang, Guan and Deng, Jiankang and Ye, Yun and Huang, Junjie and Chen, Xinze and Zhu, Jiagang and Yang, Tian and Lu, Jiwen and Du, Dalong and Zhou, Jie},
|
214 |
+
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
|
215 |
+
pages={10492--10502},
|
216 |
+
year={2021}
|
217 |
+
}
|
218 |
+
```
|
deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__init__.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
|
2 |
+
from .mobilefacenet import get_mbf
|
3 |
+
|
4 |
+
|
5 |
+
def get_model(name, **kwargs):
|
6 |
+
# resnet
|
7 |
+
if name == "r18":
|
8 |
+
return iresnet18(False, **kwargs)
|
9 |
+
elif name == "r34":
|
10 |
+
return iresnet34(False, **kwargs)
|
11 |
+
elif name == "r50":
|
12 |
+
return iresnet50(False, **kwargs)
|
13 |
+
elif name == "r100":
|
14 |
+
return iresnet100(False, **kwargs)
|
15 |
+
elif name == "r200":
|
16 |
+
return iresnet200(False, **kwargs)
|
17 |
+
elif name == "r2060":
|
18 |
+
from .iresnet2060 import iresnet2060
|
19 |
+
return iresnet2060(False, **kwargs)
|
20 |
+
|
21 |
+
elif name == "mbf":
|
22 |
+
fp16 = kwargs.get("fp16", False)
|
23 |
+
num_features = kwargs.get("num_features", 512)
|
24 |
+
return get_mbf(fp16=fp16, num_features=num_features)
|
25 |
+
|
26 |
+
elif name == "mbf_large":
|
27 |
+
from .mobilefacenet import get_mbf_large
|
28 |
+
fp16 = kwargs.get("fp16", False)
|
29 |
+
num_features = kwargs.get("num_features", 512)
|
30 |
+
return get_mbf_large(fp16=fp16, num_features=num_features)
|
31 |
+
|
32 |
+
elif name == "vit_t":
|
33 |
+
num_features = kwargs.get("num_features", 512)
|
34 |
+
from .vit import VisionTransformer
|
35 |
+
return VisionTransformer(
|
36 |
+
img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12,
|
37 |
+
num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1)
|
38 |
+
|
39 |
+
elif name == "vit_t_dp005_mask0": # For WebFace42M
|
40 |
+
num_features = kwargs.get("num_features", 512)
|
41 |
+
from .vit import VisionTransformer
|
42 |
+
return VisionTransformer(
|
43 |
+
img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12,
|
44 |
+
num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0)
|
45 |
+
|
46 |
+
elif name == "vit_s":
|
47 |
+
num_features = kwargs.get("num_features", 512)
|
48 |
+
from .vit import VisionTransformer
|
49 |
+
return VisionTransformer(
|
50 |
+
img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12,
|
51 |
+
num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1)
|
52 |
+
|
53 |
+
elif name == "vit_s_dp005_mask_0": # For WebFace42M
|
54 |
+
num_features = kwargs.get("num_features", 512)
|
55 |
+
from .vit import VisionTransformer
|
56 |
+
return VisionTransformer(
|
57 |
+
img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12,
|
58 |
+
num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0)
|
59 |
+
|
60 |
+
elif name == "vit_b":
|
61 |
+
# this is a feature
|
62 |
+
num_features = kwargs.get("num_features", 512)
|
63 |
+
from .vit import VisionTransformer
|
64 |
+
return VisionTransformer(
|
65 |
+
img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24,
|
66 |
+
num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1, using_checkpoint=True)
|
67 |
+
|
68 |
+
elif name == "vit_b_dp005_mask_005": # For WebFace42M
|
69 |
+
# this is a feature
|
70 |
+
num_features = kwargs.get("num_features", 512)
|
71 |
+
from .vit import VisionTransformer
|
72 |
+
return VisionTransformer(
|
73 |
+
img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24,
|
74 |
+
num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True)
|
75 |
+
|
76 |
+
elif name == "vit_l_dp005_mask_005": # For WebFace42M
|
77 |
+
# this is a feature
|
78 |
+
num_features = kwargs.get("num_features", 512)
|
79 |
+
from .vit import VisionTransformer
|
80 |
+
return VisionTransformer(
|
81 |
+
img_size=112, patch_size=9, num_classes=num_features, embed_dim=768, depth=24,
|
82 |
+
num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True)
|
83 |
+
|
84 |
+
else:
|
85 |
+
raise ValueError()
|
deep_3drecon/deep_3drecon_models/arcface_torch/backbones/iresnet.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.utils.checkpoint import checkpoint
|
4 |
+
|
5 |
+
__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
|
6 |
+
using_ckpt = False
|
7 |
+
|
8 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
9 |
+
"""3x3 convolution with padding"""
|
10 |
+
return nn.Conv2d(in_planes,
|
11 |
+
out_planes,
|
12 |
+
kernel_size=3,
|
13 |
+
stride=stride,
|
14 |
+
padding=dilation,
|
15 |
+
groups=groups,
|
16 |
+
bias=False,
|
17 |
+
dilation=dilation)
|
18 |
+
|
19 |
+
|
20 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
21 |
+
"""1x1 convolution"""
|
22 |
+
return nn.Conv2d(in_planes,
|
23 |
+
out_planes,
|
24 |
+
kernel_size=1,
|
25 |
+
stride=stride,
|
26 |
+
bias=False)
|
27 |
+
|
28 |
+
|
29 |
+
class IBasicBlock(nn.Module):
|
30 |
+
expansion = 1
|
31 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
32 |
+
groups=1, base_width=64, dilation=1):
|
33 |
+
super(IBasicBlock, self).__init__()
|
34 |
+
if groups != 1 or base_width != 64:
|
35 |
+
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
36 |
+
if dilation > 1:
|
37 |
+
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
38 |
+
self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
|
39 |
+
self.conv1 = conv3x3(inplanes, planes)
|
40 |
+
self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
|
41 |
+
self.prelu = nn.PReLU(planes)
|
42 |
+
self.conv2 = conv3x3(planes, planes, stride)
|
43 |
+
self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
|
44 |
+
self.downsample = downsample
|
45 |
+
self.stride = stride
|
46 |
+
|
47 |
+
def forward_impl(self, x):
|
48 |
+
identity = x
|
49 |
+
out = self.bn1(x)
|
50 |
+
out = self.conv1(out)
|
51 |
+
out = self.bn2(out)
|
52 |
+
out = self.prelu(out)
|
53 |
+
out = self.conv2(out)
|
54 |
+
out = self.bn3(out)
|
55 |
+
if self.downsample is not None:
|
56 |
+
identity = self.downsample(x)
|
57 |
+
out += identity
|
58 |
+
return out
|
59 |
+
|
60 |
+
def forward(self, x):
|
61 |
+
if self.training and using_ckpt:
|
62 |
+
return checkpoint(self.forward_impl, x)
|
63 |
+
else:
|
64 |
+
return self.forward_impl(x)
|
65 |
+
|
66 |
+
|
67 |
+
class IResNet(nn.Module):
|
68 |
+
fc_scale = 7 * 7
|
69 |
+
def __init__(self,
|
70 |
+
block, layers, dropout=0, num_features=512, zero_init_residual=False,
|
71 |
+
groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
|
72 |
+
super(IResNet, self).__init__()
|
73 |
+
self.extra_gflops = 0.0
|
74 |
+
self.fp16 = fp16
|
75 |
+
self.inplanes = 64
|
76 |
+
self.dilation = 1
|
77 |
+
if replace_stride_with_dilation is None:
|
78 |
+
replace_stride_with_dilation = [False, False, False]
|
79 |
+
if len(replace_stride_with_dilation) != 3:
|
80 |
+
raise ValueError("replace_stride_with_dilation should be None "
|
81 |
+
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
82 |
+
self.groups = groups
|
83 |
+
self.base_width = width_per_group
|
84 |
+
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
|
85 |
+
self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
|
86 |
+
self.prelu = nn.PReLU(self.inplanes)
|
87 |
+
self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
|
88 |
+
self.layer2 = self._make_layer(block,
|
89 |
+
128,
|
90 |
+
layers[1],
|
91 |
+
stride=2,
|
92 |
+
dilate=replace_stride_with_dilation[0])
|
93 |
+
self.layer3 = self._make_layer(block,
|
94 |
+
256,
|
95 |
+
layers[2],
|
96 |
+
stride=2,
|
97 |
+
dilate=replace_stride_with_dilation[1])
|
98 |
+
self.layer4 = self._make_layer(block,
|
99 |
+
512,
|
100 |
+
layers[3],
|
101 |
+
stride=2,
|
102 |
+
dilate=replace_stride_with_dilation[2])
|
103 |
+
self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
|
104 |
+
self.dropout = nn.Dropout(p=dropout, inplace=True)
|
105 |
+
self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
|
106 |
+
self.features = nn.BatchNorm1d(num_features, eps=1e-05)
|
107 |
+
nn.init.constant_(self.features.weight, 1.0)
|
108 |
+
self.features.weight.requires_grad = False
|
109 |
+
|
110 |
+
for m in self.modules():
|
111 |
+
if isinstance(m, nn.Conv2d):
|
112 |
+
nn.init.normal_(m.weight, 0, 0.1)
|
113 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
114 |
+
nn.init.constant_(m.weight, 1)
|
115 |
+
nn.init.constant_(m.bias, 0)
|
116 |
+
|
117 |
+
if zero_init_residual:
|
118 |
+
for m in self.modules():
|
119 |
+
if isinstance(m, IBasicBlock):
|
120 |
+
nn.init.constant_(m.bn2.weight, 0)
|
121 |
+
|
122 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
123 |
+
downsample = None
|
124 |
+
previous_dilation = self.dilation
|
125 |
+
if dilate:
|
126 |
+
self.dilation *= stride
|
127 |
+
stride = 1
|
128 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
129 |
+
downsample = nn.Sequential(
|
130 |
+
conv1x1(self.inplanes, planes * block.expansion, stride),
|
131 |
+
nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
|
132 |
+
)
|
133 |
+
layers = []
|
134 |
+
layers.append(
|
135 |
+
block(self.inplanes, planes, stride, downsample, self.groups,
|
136 |
+
self.base_width, previous_dilation))
|
137 |
+
self.inplanes = planes * block.expansion
|
138 |
+
for _ in range(1, blocks):
|
139 |
+
layers.append(
|
140 |
+
block(self.inplanes,
|
141 |
+
planes,
|
142 |
+
groups=self.groups,
|
143 |
+
base_width=self.base_width,
|
144 |
+
dilation=self.dilation))
|
145 |
+
|
146 |
+
return nn.Sequential(*layers)
|
147 |
+
|
148 |
+
def forward(self, x):
|
149 |
+
with torch.cuda.amp.autocast(self.fp16):
|
150 |
+
x = self.conv1(x)
|
151 |
+
x = self.bn1(x)
|
152 |
+
x = self.prelu(x)
|
153 |
+
x = self.layer1(x)
|
154 |
+
x = self.layer2(x)
|
155 |
+
x = self.layer3(x)
|
156 |
+
x = self.layer4(x)
|
157 |
+
x = self.bn2(x)
|
158 |
+
x = torch.flatten(x, 1)
|
159 |
+
x = self.dropout(x)
|
160 |
+
x = self.fc(x.float() if self.fp16 else x)
|
161 |
+
x = self.features(x)
|
162 |
+
return x
|
163 |
+
|
164 |
+
|
165 |
+
def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
|
166 |
+
model = IResNet(block, layers, **kwargs)
|
167 |
+
if pretrained:
|
168 |
+
raise ValueError()
|
169 |
+
return model
|
170 |
+
|
171 |
+
|
172 |
+
def iresnet18(pretrained=False, progress=True, **kwargs):
|
173 |
+
return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
|
174 |
+
progress, **kwargs)
|
175 |
+
|
176 |
+
|
177 |
+
def iresnet34(pretrained=False, progress=True, **kwargs):
|
178 |
+
return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
|
179 |
+
progress, **kwargs)
|
180 |
+
|
181 |
+
|
182 |
+
def iresnet50(pretrained=False, progress=True, **kwargs):
|
183 |
+
return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
|
184 |
+
progress, **kwargs)
|
185 |
+
|
186 |
+
|
187 |
+
def iresnet100(pretrained=False, progress=True, **kwargs):
|
188 |
+
return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
|
189 |
+
progress, **kwargs)
|
190 |
+
|
191 |
+
|
192 |
+
def iresnet200(pretrained=False, progress=True, **kwargs):
|
193 |
+
return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
|
194 |
+
progress, **kwargs)
|
deep_3drecon/deep_3drecon_models/arcface_torch/backbones/iresnet2060.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
assert torch.__version__ >= "1.8.1"
|
5 |
+
from torch.utils.checkpoint import checkpoint_sequential
|
6 |
+
|
7 |
+
__all__ = ['iresnet2060']
|
8 |
+
|
9 |
+
|
10 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
11 |
+
"""3x3 convolution with padding"""
|
12 |
+
return nn.Conv2d(in_planes,
|
13 |
+
out_planes,
|
14 |
+
kernel_size=3,
|
15 |
+
stride=stride,
|
16 |
+
padding=dilation,
|
17 |
+
groups=groups,
|
18 |
+
bias=False,
|
19 |
+
dilation=dilation)
|
20 |
+
|
21 |
+
|
22 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
23 |
+
"""1x1 convolution"""
|
24 |
+
return nn.Conv2d(in_planes,
|
25 |
+
out_planes,
|
26 |
+
kernel_size=1,
|
27 |
+
stride=stride,
|
28 |
+
bias=False)
|
29 |
+
|
30 |
+
|
31 |
+
class IBasicBlock(nn.Module):
|
32 |
+
expansion = 1
|
33 |
+
|
34 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
35 |
+
groups=1, base_width=64, dilation=1):
|
36 |
+
super(IBasicBlock, self).__init__()
|
37 |
+
if groups != 1 or base_width != 64:
|
38 |
+
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
39 |
+
if dilation > 1:
|
40 |
+
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
41 |
+
self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, )
|
42 |
+
self.conv1 = conv3x3(inplanes, planes)
|
43 |
+
self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, )
|
44 |
+
self.prelu = nn.PReLU(planes)
|
45 |
+
self.conv2 = conv3x3(planes, planes, stride)
|
46 |
+
self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, )
|
47 |
+
self.downsample = downsample
|
48 |
+
self.stride = stride
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
identity = x
|
52 |
+
out = self.bn1(x)
|
53 |
+
out = self.conv1(out)
|
54 |
+
out = self.bn2(out)
|
55 |
+
out = self.prelu(out)
|
56 |
+
out = self.conv2(out)
|
57 |
+
out = self.bn3(out)
|
58 |
+
if self.downsample is not None:
|
59 |
+
identity = self.downsample(x)
|
60 |
+
out += identity
|
61 |
+
return out
|
62 |
+
|
63 |
+
|
64 |
+
class IResNet(nn.Module):
|
65 |
+
fc_scale = 7 * 7
|
66 |
+
|
67 |
+
def __init__(self,
|
68 |
+
block, layers, dropout=0, num_features=512, zero_init_residual=False,
|
69 |
+
groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
|
70 |
+
super(IResNet, self).__init__()
|
71 |
+
self.fp16 = fp16
|
72 |
+
self.inplanes = 64
|
73 |
+
self.dilation = 1
|
74 |
+
if replace_stride_with_dilation is None:
|
75 |
+
replace_stride_with_dilation = [False, False, False]
|
76 |
+
if len(replace_stride_with_dilation) != 3:
|
77 |
+
raise ValueError("replace_stride_with_dilation should be None "
|
78 |
+
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
79 |
+
self.groups = groups
|
80 |
+
self.base_width = width_per_group
|
81 |
+
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
|
82 |
+
self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
|
83 |
+
self.prelu = nn.PReLU(self.inplanes)
|
84 |
+
self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
|
85 |
+
self.layer2 = self._make_layer(block,
|
86 |
+
128,
|
87 |
+
layers[1],
|
88 |
+
stride=2,
|
89 |
+
dilate=replace_stride_with_dilation[0])
|
90 |
+
self.layer3 = self._make_layer(block,
|
91 |
+
256,
|
92 |
+
layers[2],
|
93 |
+
stride=2,
|
94 |
+
dilate=replace_stride_with_dilation[1])
|
95 |
+
self.layer4 = self._make_layer(block,
|
96 |
+
512,
|
97 |
+
layers[3],
|
98 |
+
stride=2,
|
99 |
+
dilate=replace_stride_with_dilation[2])
|
100 |
+
self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, )
|
101 |
+
self.dropout = nn.Dropout(p=dropout, inplace=True)
|
102 |
+
self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
|
103 |
+
self.features = nn.BatchNorm1d(num_features, eps=1e-05)
|
104 |
+
nn.init.constant_(self.features.weight, 1.0)
|
105 |
+
self.features.weight.requires_grad = False
|
106 |
+
|
107 |
+
for m in self.modules():
|
108 |
+
if isinstance(m, nn.Conv2d):
|
109 |
+
nn.init.normal_(m.weight, 0, 0.1)
|
110 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
111 |
+
nn.init.constant_(m.weight, 1)
|
112 |
+
nn.init.constant_(m.bias, 0)
|
113 |
+
|
114 |
+
if zero_init_residual:
|
115 |
+
for m in self.modules():
|
116 |
+
if isinstance(m, IBasicBlock):
|
117 |
+
nn.init.constant_(m.bn2.weight, 0)
|
118 |
+
|
119 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
120 |
+
downsample = None
|
121 |
+
previous_dilation = self.dilation
|
122 |
+
if dilate:
|
123 |
+
self.dilation *= stride
|
124 |
+
stride = 1
|
125 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
126 |
+
downsample = nn.Sequential(
|
127 |
+
conv1x1(self.inplanes, planes * block.expansion, stride),
|
128 |
+
nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
|
129 |
+
)
|
130 |
+
layers = []
|
131 |
+
layers.append(
|
132 |
+
block(self.inplanes, planes, stride, downsample, self.groups,
|
133 |
+
self.base_width, previous_dilation))
|
134 |
+
self.inplanes = planes * block.expansion
|
135 |
+
for _ in range(1, blocks):
|
136 |
+
layers.append(
|
137 |
+
block(self.inplanes,
|
138 |
+
planes,
|
139 |
+
groups=self.groups,
|
140 |
+
base_width=self.base_width,
|
141 |
+
dilation=self.dilation))
|
142 |
+
|
143 |
+
return nn.Sequential(*layers)
|
144 |
+
|
145 |
+
def checkpoint(self, func, num_seg, x):
|
146 |
+
if self.training:
|
147 |
+
return checkpoint_sequential(func, num_seg, x)
|
148 |
+
else:
|
149 |
+
return func(x)
|
150 |
+
|
151 |
+
def forward(self, x):
|
152 |
+
with torch.cuda.amp.autocast(self.fp16):
|
153 |
+
x = self.conv1(x)
|
154 |
+
x = self.bn1(x)
|
155 |
+
x = self.prelu(x)
|
156 |
+
x = self.layer1(x)
|
157 |
+
x = self.checkpoint(self.layer2, 20, x)
|
158 |
+
x = self.checkpoint(self.layer3, 100, x)
|
159 |
+
x = self.layer4(x)
|
160 |
+
x = self.bn2(x)
|
161 |
+
x = torch.flatten(x, 1)
|
162 |
+
x = self.dropout(x)
|
163 |
+
x = self.fc(x.float() if self.fp16 else x)
|
164 |
+
x = self.features(x)
|
165 |
+
return x
|
166 |
+
|
167 |
+
|
168 |
+
def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
|
169 |
+
model = IResNet(block, layers, **kwargs)
|
170 |
+
if pretrained:
|
171 |
+
raise ValueError()
|
172 |
+
return model
|
173 |
+
|
174 |
+
|
175 |
+
def iresnet2060(pretrained=False, progress=True, **kwargs):
|
176 |
+
return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs)
|
deep_3drecon/deep_3drecon_models/arcface_torch/backbones/mobilefacenet.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
|
3 |
+
Original author cavalleria
|
4 |
+
'''
|
5 |
+
|
6 |
+
import torch.nn as nn
|
7 |
+
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module
|
8 |
+
import torch
|
9 |
+
|
10 |
+
|
11 |
+
class Flatten(Module):
|
12 |
+
def forward(self, x):
|
13 |
+
return x.view(x.size(0), -1)
|
14 |
+
|
15 |
+
|
16 |
+
class ConvBlock(Module):
|
17 |
+
def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
|
18 |
+
super(ConvBlock, self).__init__()
|
19 |
+
self.layers = nn.Sequential(
|
20 |
+
Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),
|
21 |
+
BatchNorm2d(num_features=out_c),
|
22 |
+
PReLU(num_parameters=out_c)
|
23 |
+
)
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
return self.layers(x)
|
27 |
+
|
28 |
+
|
29 |
+
class LinearBlock(Module):
|
30 |
+
def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
|
31 |
+
super(LinearBlock, self).__init__()
|
32 |
+
self.layers = nn.Sequential(
|
33 |
+
Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
|
34 |
+
BatchNorm2d(num_features=out_c)
|
35 |
+
)
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
return self.layers(x)
|
39 |
+
|
40 |
+
|
41 |
+
class DepthWise(Module):
|
42 |
+
def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
|
43 |
+
super(DepthWise, self).__init__()
|
44 |
+
self.residual = residual
|
45 |
+
self.layers = nn.Sequential(
|
46 |
+
ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
|
47 |
+
ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),
|
48 |
+
LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
|
49 |
+
)
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
short_cut = None
|
53 |
+
if self.residual:
|
54 |
+
short_cut = x
|
55 |
+
x = self.layers(x)
|
56 |
+
if self.residual:
|
57 |
+
output = short_cut + x
|
58 |
+
else:
|
59 |
+
output = x
|
60 |
+
return output
|
61 |
+
|
62 |
+
|
63 |
+
class Residual(Module):
|
64 |
+
def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
|
65 |
+
super(Residual, self).__init__()
|
66 |
+
modules = []
|
67 |
+
for _ in range(num_block):
|
68 |
+
modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))
|
69 |
+
self.layers = Sequential(*modules)
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
return self.layers(x)
|
73 |
+
|
74 |
+
|
75 |
+
class GDC(Module):
|
76 |
+
def __init__(self, embedding_size):
|
77 |
+
super(GDC, self).__init__()
|
78 |
+
self.layers = nn.Sequential(
|
79 |
+
LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),
|
80 |
+
Flatten(),
|
81 |
+
Linear(512, embedding_size, bias=False),
|
82 |
+
BatchNorm1d(embedding_size))
|
83 |
+
|
84 |
+
def forward(self, x):
|
85 |
+
return self.layers(x)
|
86 |
+
|
87 |
+
|
88 |
+
class MobileFaceNet(Module):
|
89 |
+
def __init__(self, fp16=False, num_features=512, blocks=(1, 4, 6, 2), scale=2):
|
90 |
+
super(MobileFaceNet, self).__init__()
|
91 |
+
self.scale = scale
|
92 |
+
self.fp16 = fp16
|
93 |
+
self.layers = nn.ModuleList()
|
94 |
+
self.layers.append(
|
95 |
+
ConvBlock(3, 64 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1))
|
96 |
+
)
|
97 |
+
if blocks[0] == 1:
|
98 |
+
self.layers.append(
|
99 |
+
ConvBlock(64 * self.scale, 64 * self.scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
|
100 |
+
)
|
101 |
+
else:
|
102 |
+
self.layers.append(
|
103 |
+
Residual(64 * self.scale, num_block=blocks[0], groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
|
104 |
+
)
|
105 |
+
|
106 |
+
self.layers.extend(
|
107 |
+
[
|
108 |
+
DepthWise(64 * self.scale, 64 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
|
109 |
+
Residual(64 * self.scale, num_block=blocks[1], groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
|
110 |
+
DepthWise(64 * self.scale, 128 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
|
111 |
+
Residual(128 * self.scale, num_block=blocks[2], groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
|
112 |
+
DepthWise(128 * self.scale, 128 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
|
113 |
+
Residual(128 * self.scale, num_block=blocks[3], groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
|
114 |
+
])
|
115 |
+
|
116 |
+
self.conv_sep = ConvBlock(128 * self.scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
|
117 |
+
self.features = GDC(num_features)
|
118 |
+
self._initialize_weights()
|
119 |
+
|
120 |
+
def _initialize_weights(self):
|
121 |
+
for m in self.modules():
|
122 |
+
if isinstance(m, nn.Conv2d):
|
123 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
124 |
+
if m.bias is not None:
|
125 |
+
m.bias.data.zero_()
|
126 |
+
elif isinstance(m, nn.BatchNorm2d):
|
127 |
+
m.weight.data.fill_(1)
|
128 |
+
m.bias.data.zero_()
|
129 |
+
elif isinstance(m, nn.Linear):
|
130 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
131 |
+
if m.bias is not None:
|
132 |
+
m.bias.data.zero_()
|
133 |
+
|
134 |
+
def forward(self, x):
|
135 |
+
with torch.cuda.amp.autocast(self.fp16):
|
136 |
+
for func in self.layers:
|
137 |
+
x = func(x)
|
138 |
+
x = self.conv_sep(x.float() if self.fp16 else x)
|
139 |
+
x = self.features(x)
|
140 |
+
return x
|
141 |
+
|
142 |
+
|
143 |
+
def get_mbf(fp16, num_features, blocks=(1, 4, 6, 2), scale=2):
|
144 |
+
return MobileFaceNet(fp16, num_features, blocks, scale=scale)
|
145 |
+
|
146 |
+
def get_mbf_large(fp16, num_features, blocks=(2, 8, 12, 4), scale=4):
|
147 |
+
return MobileFaceNet(fp16, num_features, blocks, scale=scale)
|
deep_3drecon/deep_3drecon_models/arcface_torch/backbones/vit.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
4 |
+
from typing import Optional, Callable
|
5 |
+
|
6 |
+
class Mlp(nn.Module):
|
7 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):
|
8 |
+
super().__init__()
|
9 |
+
out_features = out_features or in_features
|
10 |
+
hidden_features = hidden_features or in_features
|
11 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
12 |
+
self.act = act_layer()
|
13 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
14 |
+
self.drop = nn.Dropout(drop)
|
15 |
+
|
16 |
+
def forward(self, x):
|
17 |
+
x = self.fc1(x)
|
18 |
+
x = self.act(x)
|
19 |
+
x = self.drop(x)
|
20 |
+
x = self.fc2(x)
|
21 |
+
x = self.drop(x)
|
22 |
+
return x
|
23 |
+
|
24 |
+
|
25 |
+
class VITBatchNorm(nn.Module):
|
26 |
+
def __init__(self, num_features):
|
27 |
+
super().__init__()
|
28 |
+
self.num_features = num_features
|
29 |
+
self.bn = nn.BatchNorm1d(num_features=num_features)
|
30 |
+
|
31 |
+
def forward(self, x):
|
32 |
+
return self.bn(x)
|
33 |
+
|
34 |
+
|
35 |
+
class Attention(nn.Module):
|
36 |
+
def __init__(self,
|
37 |
+
dim: int,
|
38 |
+
num_heads: int = 8,
|
39 |
+
qkv_bias: bool = False,
|
40 |
+
qk_scale: Optional[None] = None,
|
41 |
+
attn_drop: float = 0.,
|
42 |
+
proj_drop: float = 0.):
|
43 |
+
super().__init__()
|
44 |
+
self.num_heads = num_heads
|
45 |
+
head_dim = dim // num_heads
|
46 |
+
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
47 |
+
self.scale = qk_scale or head_dim ** -0.5
|
48 |
+
|
49 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
50 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
51 |
+
self.proj = nn.Linear(dim, dim)
|
52 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
53 |
+
|
54 |
+
def forward(self, x):
|
55 |
+
|
56 |
+
with torch.cuda.amp.autocast(True):
|
57 |
+
batch_size, num_token, embed_dim = x.shape
|
58 |
+
#qkv is [3,batch_size,num_heads,num_token, embed_dim//num_heads]
|
59 |
+
qkv = self.qkv(x).reshape(
|
60 |
+
batch_size, num_token, 3, self.num_heads, embed_dim // self.num_heads).permute(2, 0, 3, 1, 4)
|
61 |
+
with torch.cuda.amp.autocast(False):
|
62 |
+
q, k, v = qkv[0].float(), qkv[1].float(), qkv[2].float()
|
63 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
64 |
+
attn = attn.softmax(dim=-1)
|
65 |
+
attn = self.attn_drop(attn)
|
66 |
+
x = (attn @ v).transpose(1, 2).reshape(batch_size, num_token, embed_dim)
|
67 |
+
with torch.cuda.amp.autocast(True):
|
68 |
+
x = self.proj(x)
|
69 |
+
x = self.proj_drop(x)
|
70 |
+
return x
|
71 |
+
|
72 |
+
|
73 |
+
class Block(nn.Module):
|
74 |
+
|
75 |
+
def __init__(self,
|
76 |
+
dim: int,
|
77 |
+
num_heads: int,
|
78 |
+
num_patches: int,
|
79 |
+
mlp_ratio: float = 4.,
|
80 |
+
qkv_bias: bool = False,
|
81 |
+
qk_scale: Optional[None] = None,
|
82 |
+
drop: float = 0.,
|
83 |
+
attn_drop: float = 0.,
|
84 |
+
drop_path: float = 0.,
|
85 |
+
act_layer: Callable = nn.ReLU6,
|
86 |
+
norm_layer: str = "ln",
|
87 |
+
patch_n: int = 144):
|
88 |
+
super().__init__()
|
89 |
+
|
90 |
+
if norm_layer == "bn":
|
91 |
+
self.norm1 = VITBatchNorm(num_features=num_patches)
|
92 |
+
self.norm2 = VITBatchNorm(num_features=num_patches)
|
93 |
+
elif norm_layer == "ln":
|
94 |
+
self.norm1 = nn.LayerNorm(dim)
|
95 |
+
self.norm2 = nn.LayerNorm(dim)
|
96 |
+
|
97 |
+
self.attn = Attention(
|
98 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
99 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
100 |
+
self.drop_path = DropPath(
|
101 |
+
drop_path) if drop_path > 0. else nn.Identity()
|
102 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
103 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
|
104 |
+
act_layer=act_layer, drop=drop)
|
105 |
+
self.extra_gflops = (num_heads * patch_n * (dim//num_heads)*patch_n * 2) / (1000**3)
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
x = x + self.drop_path(self.attn(self.norm1(x)))
|
109 |
+
with torch.cuda.amp.autocast(True):
|
110 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
111 |
+
return x
|
112 |
+
|
113 |
+
|
114 |
+
class PatchEmbed(nn.Module):
|
115 |
+
def __init__(self, img_size=108, patch_size=9, in_channels=3, embed_dim=768):
|
116 |
+
super().__init__()
|
117 |
+
img_size = to_2tuple(img_size)
|
118 |
+
patch_size = to_2tuple(patch_size)
|
119 |
+
num_patches = (img_size[1] // patch_size[1]) * \
|
120 |
+
(img_size[0] // patch_size[0])
|
121 |
+
self.img_size = img_size
|
122 |
+
self.patch_size = patch_size
|
123 |
+
self.num_patches = num_patches
|
124 |
+
self.proj = nn.Conv2d(in_channels, embed_dim,
|
125 |
+
kernel_size=patch_size, stride=patch_size)
|
126 |
+
|
127 |
+
def forward(self, x):
|
128 |
+
batch_size, channels, height, width = x.shape
|
129 |
+
assert height == self.img_size[0] and width == self.img_size[1], \
|
130 |
+
f"Input image size ({height}*{width}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
131 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
132 |
+
return x
|
133 |
+
|
134 |
+
|
135 |
+
class VisionTransformer(nn.Module):
|
136 |
+
""" Vision Transformer with support for patch or hybrid CNN input stage
|
137 |
+
"""
|
138 |
+
|
139 |
+
def __init__(self,
|
140 |
+
img_size: int = 112,
|
141 |
+
patch_size: int = 16,
|
142 |
+
in_channels: int = 3,
|
143 |
+
num_classes: int = 1000,
|
144 |
+
embed_dim: int = 768,
|
145 |
+
depth: int = 12,
|
146 |
+
num_heads: int = 12,
|
147 |
+
mlp_ratio: float = 4.,
|
148 |
+
qkv_bias: bool = False,
|
149 |
+
qk_scale: Optional[None] = None,
|
150 |
+
drop_rate: float = 0.,
|
151 |
+
attn_drop_rate: float = 0.,
|
152 |
+
drop_path_rate: float = 0.,
|
153 |
+
hybrid_backbone: Optional[None] = None,
|
154 |
+
norm_layer: str = "ln",
|
155 |
+
mask_ratio = 0.1,
|
156 |
+
using_checkpoint = False,
|
157 |
+
):
|
158 |
+
super().__init__()
|
159 |
+
self.num_classes = num_classes
|
160 |
+
# num_features for consistency with other models
|
161 |
+
self.num_features = self.embed_dim = embed_dim
|
162 |
+
|
163 |
+
if hybrid_backbone is not None:
|
164 |
+
raise ValueError
|
165 |
+
else:
|
166 |
+
self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim)
|
167 |
+
self.mask_ratio = mask_ratio
|
168 |
+
self.using_checkpoint = using_checkpoint
|
169 |
+
num_patches = self.patch_embed.num_patches
|
170 |
+
self.num_patches = num_patches
|
171 |
+
|
172 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
173 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
174 |
+
|
175 |
+
# stochastic depth decay rule
|
176 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
|
177 |
+
patch_n = (img_size//patch_size)**2
|
178 |
+
self.blocks = nn.ModuleList(
|
179 |
+
[
|
180 |
+
Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
181 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
182 |
+
num_patches=num_patches, patch_n=patch_n)
|
183 |
+
for i in range(depth)]
|
184 |
+
)
|
185 |
+
self.extra_gflops = 0.0
|
186 |
+
for _block in self.blocks:
|
187 |
+
self.extra_gflops += _block.extra_gflops
|
188 |
+
|
189 |
+
if norm_layer == "ln":
|
190 |
+
self.norm = nn.LayerNorm(embed_dim)
|
191 |
+
elif norm_layer == "bn":
|
192 |
+
self.norm = VITBatchNorm(self.num_patches)
|
193 |
+
|
194 |
+
# features head
|
195 |
+
self.feature = nn.Sequential(
|
196 |
+
nn.Linear(in_features=embed_dim * num_patches, out_features=embed_dim, bias=False),
|
197 |
+
nn.BatchNorm1d(num_features=embed_dim, eps=2e-5),
|
198 |
+
nn.Linear(in_features=embed_dim, out_features=num_classes, bias=False),
|
199 |
+
nn.BatchNorm1d(num_features=num_classes, eps=2e-5)
|
200 |
+
)
|
201 |
+
|
202 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
203 |
+
torch.nn.init.normal_(self.mask_token, std=.02)
|
204 |
+
trunc_normal_(self.pos_embed, std=.02)
|
205 |
+
# trunc_normal_(self.cls_token, std=.02)
|
206 |
+
self.apply(self._init_weights)
|
207 |
+
|
208 |
+
def _init_weights(self, m):
|
209 |
+
if isinstance(m, nn.Linear):
|
210 |
+
trunc_normal_(m.weight, std=.02)
|
211 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
212 |
+
nn.init.constant_(m.bias, 0)
|
213 |
+
elif isinstance(m, nn.LayerNorm):
|
214 |
+
nn.init.constant_(m.bias, 0)
|
215 |
+
nn.init.constant_(m.weight, 1.0)
|
216 |
+
|
217 |
+
@torch.jit.ignore
|
218 |
+
def no_weight_decay(self):
|
219 |
+
return {'pos_embed', 'cls_token'}
|
220 |
+
|
221 |
+
def get_classifier(self):
|
222 |
+
return self.head
|
223 |
+
|
224 |
+
def random_masking(self, x, mask_ratio=0.1):
|
225 |
+
"""
|
226 |
+
Perform per-sample random masking by per-sample shuffling.
|
227 |
+
Per-sample shuffling is done by argsort random noise.
|
228 |
+
x: [N, L, D], sequence
|
229 |
+
"""
|
230 |
+
N, L, D = x.size() # batch, length, dim
|
231 |
+
len_keep = int(L * (1 - mask_ratio))
|
232 |
+
|
233 |
+
noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
|
234 |
+
|
235 |
+
# sort noise for each sample
|
236 |
+
# ascend: small is keep, large is remove
|
237 |
+
ids_shuffle = torch.argsort(noise, dim=1)
|
238 |
+
ids_restore = torch.argsort(ids_shuffle, dim=1)
|
239 |
+
|
240 |
+
# keep the first subset
|
241 |
+
ids_keep = ids_shuffle[:, :len_keep]
|
242 |
+
x_masked = torch.gather(
|
243 |
+
x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
|
244 |
+
|
245 |
+
# generate the binary mask: 0 is keep, 1 is remove
|
246 |
+
mask = torch.ones([N, L], device=x.device)
|
247 |
+
mask[:, :len_keep] = 0
|
248 |
+
# unshuffle to get the binary mask
|
249 |
+
mask = torch.gather(mask, dim=1, index=ids_restore)
|
250 |
+
|
251 |
+
return x_masked, mask, ids_restore
|
252 |
+
|
253 |
+
def forward_features(self, x):
|
254 |
+
B = x.shape[0]
|
255 |
+
x = self.patch_embed(x)
|
256 |
+
x = x + self.pos_embed
|
257 |
+
x = self.pos_drop(x)
|
258 |
+
|
259 |
+
if self.training and self.mask_ratio > 0:
|
260 |
+
x, _, ids_restore = self.random_masking(x)
|
261 |
+
|
262 |
+
for func in self.blocks:
|
263 |
+
if self.using_checkpoint and self.training:
|
264 |
+
from torch.utils.checkpoint import checkpoint
|
265 |
+
x = checkpoint(func, x)
|
266 |
+
else:
|
267 |
+
x = func(x)
|
268 |
+
x = self.norm(x.float())
|
269 |
+
|
270 |
+
if self.training and self.mask_ratio > 0:
|
271 |
+
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
|
272 |
+
x_ = torch.cat([x[:, :, :], mask_tokens], dim=1) # no cls token
|
273 |
+
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
|
274 |
+
x = x_
|
275 |
+
return torch.reshape(x, (B, self.num_patches * self.embed_dim))
|
276 |
+
|
277 |
+
def forward(self, x):
|
278 |
+
x = self.forward_features(x)
|
279 |
+
x = self.feature(x)
|
280 |
+
return x
|