Spaces: Running on Zero
Upload 116 files
This view is limited to 50 files because it contains too many changes.
- .gitattributes +47 -43
- .gitignore +28 -0
- README.md +12 -12
- app_v1v2.py +175 -0
- configs/astral_quantization/default_2048.yml +40 -0
- configs/astral_quantization/default_32.yml +40 -0
- configs/config.json +1 -0
- configs/inuse/.gitignore +0 -0
- configs/inuse/config.json +1 -0
- configs/presets/config_dit_mel_seed_uvit_whisper_base_f0_44k.yml +98 -0
- configs/presets/config_dit_mel_seed_uvit_whisper_small_wavenet.yml +91 -0
- configs/presets/config_dit_mel_seed_uvit_xlsr_tiny.yml +82 -0
- configs/v2/ar_base.yaml +0 -0
- configs/v2/dit_small.yaml +17 -0
- configs/v2/vc_wrapper.yaml +105 -0
- hf_utils.py +1 -1
- modules/__pycache__/audio.cpython-310.pyc +0 -0
- modules/__pycache__/commons.cpython-310.pyc +0 -0
- modules/__pycache__/commons.cpython-38.pyc +0 -0
- modules/__pycache__/diffusion_transformer.cpython-310.pyc +0 -0
- modules/__pycache__/flow_matching.cpython-310.pyc +0 -0
- modules/__pycache__/length_regulator.cpython-310.pyc +0 -0
- modules/__pycache__/rmvpe.cpython-310.pyc +0 -0
- modules/astral_quantization/__pycache__/bsq.cpython-310.pyc +0 -0
- modules/astral_quantization/__pycache__/convnext.cpython-310.pyc +0 -0
- modules/astral_quantization/__pycache__/default_model.cpython-310.pyc +0 -0
- modules/astral_quantization/bsq.py +569 -0
- modules/astral_quantization/convnext.py +209 -0
- modules/astral_quantization/default_model.py +73 -0
- modules/astral_quantization/transformer.py +254 -0
- modules/audio.py +82 -82
- modules/bigvgan/__pycache__/activations.cpython-310.pyc +0 -0
- modules/bigvgan/__pycache__/bigvgan.cpython-310.pyc +0 -0
- modules/bigvgan/__pycache__/env.cpython-310.pyc +0 -0
- modules/bigvgan/__pycache__/meldataset.cpython-310.pyc +0 -0
- modules/bigvgan/__pycache__/utils.cpython-310.pyc +0 -0
- modules/bigvgan/alias_free_activation/cuda/__pycache__/__init__.cpython-310.pyc +0 -0
- modules/bigvgan/alias_free_activation/cuda/__pycache__/activation1d.cpython-310.pyc +0 -0
- modules/bigvgan/alias_free_activation/cuda/__pycache__/load.cpython-310.pyc +0 -0
- modules/bigvgan/alias_free_activation/cuda/activation1d.py +2 -2
- modules/bigvgan/alias_free_activation/cuda/build/.ninja_deps +3 -0
- modules/bigvgan/alias_free_activation/cuda/build/.ninja_log +7 -0
- modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation.o +3 -0
- modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.cuda.o +3 -0
- modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.exp +0 -0
- modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.lib +0 -0
- modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.pyd +3 -0
- modules/bigvgan/alias_free_activation/cuda/build/build.ninja +38 -0
- modules/bigvgan/alias_free_activation/torch/__pycache__/__init__.cpython-310.pyc +0 -0
- modules/bigvgan/alias_free_activation/torch/__pycache__/act.cpython-310.pyc +0 -0
.gitattributes
CHANGED
@@ -1,43 +1,47 @@
 *.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
 *.bz2 filter=lfs diff=lfs merge=lfs -text
 *.ckpt filter=lfs diff=lfs merge=lfs -text
 *.ftz filter=lfs diff=lfs merge=lfs -text
 *.gz filter=lfs diff=lfs merge=lfs -text
 *.h5 filter=lfs diff=lfs merge=lfs -text
 *.joblib filter=lfs diff=lfs merge=lfs -text
 *.lfs.* filter=lfs diff=lfs merge=lfs -text
 *.mlmodel filter=lfs diff=lfs merge=lfs -text
 *.model filter=lfs diff=lfs merge=lfs -text
 *.msgpack filter=lfs diff=lfs merge=lfs -text
 *.npy filter=lfs diff=lfs merge=lfs -text
 *.npz filter=lfs diff=lfs merge=lfs -text
 *.onnx filter=lfs diff=lfs merge=lfs -text
 *.ot filter=lfs diff=lfs merge=lfs -text
 *.parquet filter=lfs diff=lfs merge=lfs -text
 *.pb filter=lfs diff=lfs merge=lfs -text
 *.pickle filter=lfs diff=lfs merge=lfs -text
 *.pkl filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
 *.rar filter=lfs diff=lfs merge=lfs -text
 *.safetensors filter=lfs diff=lfs merge=lfs -text
 saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
 *.tar filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
 *.wasm filter=lfs diff=lfs merge=lfs -text
 *.xz filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 examples/reference/dingzhen_0.wav filter=lfs diff=lfs merge=lfs -text
 examples/reference/s3p2.wav filter=lfs diff=lfs merge=lfs -text
 examples/source/source_s3.wav filter=lfs diff=lfs merge=lfs -text
 examples/source/source_s4.wav filter=lfs diff=lfs merge=lfs -text
 examples/source/Wiz[[:space:]]Khalifa,Charlie[[:space:]]Puth[[:space:]]-[[:space:]]See[[:space:]]You[[:space:]]Again[[:space:]]\[vocals\]_\[cut_28sec\].wav filter=lfs diff=lfs merge=lfs -text
 examples/reference/trump_0.wav filter=lfs diff=lfs merge=lfs -text
 examples/source/jay_0.wav filter=lfs diff=lfs merge=lfs -text
 examples/source/TECHNOPOLIS[[:space:]]-[[:space:]]2085[[:space:]]\[vocals\]_\[cut_14sec\].wav filter=lfs diff=lfs merge=lfs -text
+modules/bigvgan/alias_free_activation/cuda/build/.ninja_deps filter=lfs diff=lfs merge=lfs -text
+modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.cuda.o filter=lfs diff=lfs merge=lfs -text
+modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.pyd filter=lfs diff=lfs merge=lfs -text
+modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation.o filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,28 @@
# general things to ignore
.DS_Store
build/
build_contrib/
dist/
.cache/
*.egg-info/
*.egg
*.py[cod]
__pycache__/
*.so
*~

# IDE
.vscode/
.idea/

# misc
checkpoints/
test_waves/
reconstructed/
.python-version
ruff.log
/configs/inuse/
runs/
/garbages/
/flagged/
/experimental/
README.md
CHANGED
@@ -1,13 +1,13 @@
 ---
 title: Seed Voice Conversion
 emoji: 🎤🔄
 colorFrom: green
 colorTo: green
 sdk: gradio
-sdk_version:
-app_file:
+sdk_version: 5.23.0
+app_file: app_v1v2.py
 pinned: false
 license: gpl-3.0
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app_v1v2.py
ADDED
@@ -0,0 +1,175 @@
import spaces
import gradio as gr
import torch
import yaml
import argparse
from seed_vc_wrapper import SeedVCWrapper

# Set up device and torch configurations
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True

if hasattr(torch._inductor.config, "fx_graph_cache"):
    # Experimental feature to reduce compilation times, will be on by default in future
    torch._inductor.config.fx_graph_cache = True

dtype = torch.float16

def load_v2_models(args):
    from hydra.utils import instantiate
    from omegaconf import DictConfig
    cfg = DictConfig(yaml.safe_load(open("configs/v2/vc_wrapper.yaml", "r")))
    vc_wrapper = instantiate(cfg)
    vc_wrapper.load_checkpoints()
    vc_wrapper.to(device)
    vc_wrapper.eval()

    vc_wrapper.setup_ar_caches(max_batch_size=1, max_seq_len=4096, dtype=dtype, device=device)

    if args.compile:
        vc_wrapper.compile_ar()
        # vc_wrapper.compile_cfm()

    return vc_wrapper

def create_v1_interface():
    # Initialize the V1 wrapper
    vc_wrapper = SeedVCWrapper()

    # Set up Gradio interface
    description = ("Zero-shot voice conversion with in-context learning. For local deployment please check [GitHub repository](https://github.com/Plachtaa/seed-vc) "
                   "for details and updates.<br>Note that any reference audio will be forcefully clipped to 25s if beyond this length.<br> "
                   "If total duration of source and reference audio exceeds 30s, source audio will be processed in chunks.<br> "
                   "无需训练的 zero-shot 语音/歌声转换模型,若需本地部署查看[GitHub页面](https://github.com/Plachtaa/seed-vc)<br>"
                   "请注意,参考音频若超过 25 秒,则会被自动裁剪至此长度。<br>若源音频和参考音频的总时长超过 30 秒,源音频将被分段处理。")

    inputs = [
        gr.Audio(type="filepath", label="Source Audio / 源音频"),
        gr.Audio(type="filepath", label="Reference Audio / 参考音频"),
        gr.Slider(minimum=1, maximum=200, value=10, step=1, label="Diffusion Steps / 扩散步数",
                  info="10 by default, 50~100 for best quality / 默认为 10,50~100 为最佳质量"),
        gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Adjust / 长度调整",
                  info="<1.0 for speed-up speech, >1.0 for slow-down speech / <1.0 加速语速,>1.0 减慢语速"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="Inference CFG Rate",
                  info="has subtle influence / 有微小影响"),
        gr.Checkbox(label="Use F0 conditioned model / 启用F0输入", value=False,
                    info="Must set to true for singing voice conversion / 歌声转换时必须勾选"),
        gr.Checkbox(label="Auto F0 adjust / 自动F0调整", value=True,
                    info="Roughly adjust F0 to match target voice. Only works when F0 conditioned model is used. / 粗略调整 F0 以匹配目标音色,仅在勾选 '启用F0输入' 时生效"),
        gr.Slider(label='Pitch shift / 音调变换', minimum=-24, maximum=24, step=1, value=0,
                  info="Pitch shift in semitones, only works when F0 conditioned model is used / 半音数的音高变换,仅在勾选 '启用F0输入' 时生效"),
    ]

    examples = [
        ["examples/source/yae_0.wav", "examples/reference/dingzhen_0.wav", 25, 1.0, 0.7, False, True, 0],
        ["examples/source/jay_0.wav", "examples/reference/azuma_0.wav", 25, 1.0, 0.7, True, True, 0],
        ["examples/source/Wiz Khalifa,Charlie Puth - See You Again [vocals]_[cut_28sec].wav",
         "examples/reference/teio_0.wav", 100, 1.0, 0.7, True, False, 0],
        ["examples/source/TECHNOPOLIS - 2085 [vocals]_[cut_14sec].wav",
         "examples/reference/trump_0.wav", 50, 1.0, 0.7, True, False, -12],
    ]

    outputs = [
        gr.Audio(label="Stream Output Audio / 流式输出", streaming=True, format='mp3'),
        gr.Audio(label="Full Output Audio / 完整输出", streaming=False, format='wav')
    ]

    return gr.Interface(
        fn=vc_wrapper.convert_voice,
        description=description,
        inputs=inputs,
        outputs=outputs,
        title="Seed Voice Conversion V1 (Voice & Singing Voice Conversion)",
        examples=examples,
        cache_examples=False,
    )

def create_v2_interface(vc_wrapper):
    # Set up Gradio interface
    description = ("Zero-shot voice/style conversion with in-context learning. For local deployment please check [GitHub repository](https://github.com/Plachtaa/seed-vc) "
                   "for details and updates.<br>Note that any reference audio will be forcefully clipped to 25s if beyond this length.<br> "
                   "If total duration of source and reference audio exceeds 30s, source audio will be processed in chunks.<br> "
                   "Please click the 'convert style/emotion/accent' checkbox to convert the style, emotion, or accent of the source audio, or else only timbre conversion will be performed.<br> "
                   "Click the 'anonymization only' checkbox will ignore reference audio but convert source to an 'average voice' determined by model itself.<br> "
                   "无需训练的 zero-shot 语音/口音转换模型,若需本地部署查看[GitHub页面](https://github.com/Plachtaa/seed-vc)<br>"
                   "请注意,参考音频若超过 25 秒,则会被自动裁剪至此长度。<br>若源音频和参考音频的总时长超过 30 秒,源音频将被分段处理。"
                   "<br>请勾选 'convert style/emotion/accent' 以转换源音频的风格、情感或口音,否则仅执行音色转换。<br>"
                   "勾选 'anonymization only' 会无视参考音频而将源音频转换为某种由模型自身决定的 '平均音色'。<br>"

                   "Credits to [Vevo](https://github.com/open-mmlab/Amphion/tree/main/models/vc/vevo)"
                   )
    inputs = [
        gr.Audio(type="filepath", label="Source Audio / 源音频"),
        gr.Audio(type="filepath", label="Reference Audio / 参考音频"),
        gr.Slider(minimum=1, maximum=200, value=30, step=1, label="Diffusion Steps / 扩散步数",
                  info="30 by default, 50~100 for best quality / 默认为 30,50~100 为最佳质量"),
        gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Adjust / 长度调整",
                  info="<1.0 for speed-up speech, >1.0 for slow-down speech / <1.0 加速语速,>1.0 减慢语速"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.0, label="Intelligibility CFG Rate",
                  info="controls pronunciation intelligibility / 控制发音清晰度"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="Similarity CFG Rate",
                  info="controls similarity to reference audio / 控制与参考音频的相似度"),
        gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.9, label="Top-p",
                  info="AR model sampling top P"),
        gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Temperature",
                  info="AR model sampling temperature"),
        gr.Slider(minimum=1.0, maximum=3.0, step=0.1, value=1.0, label="Repetition Penalty",
                  info="AR model sampling repetition penalty"),
        gr.Checkbox(label="convert style/emotion/accent", value=False),
        gr.Checkbox(label="anonymization only", value=False),
    ]

    examples = [
        ["examples/source/yae_0.wav", "examples/reference/dingzhen_0.wav", 50, 1.0, 0.0, 0.7, 0.9, 1.0, 1.0, False, False],
        ["examples/source/jay_0.wav", "examples/reference/azuma_0.wav", 50, 1.0, 0.0, 0.7, 0.9, 1.0, 1.0, False, False],
    ]

    outputs = [
        gr.Audio(label="Stream Output Audio / 流式输出", streaming=True, format='mp3'),
        gr.Audio(label="Full Output Audio / 完整输出", streaming=False, format='wav')
    ]

    return gr.Interface(
        fn=vc_wrapper.convert_voice_with_streaming,
        description=description,
        inputs=inputs,
        outputs=outputs,
        title="Seed Voice Conversion V2 (Voice & Style Conversion)",
        examples=examples,
        cache_examples=False,
    )

def main(args):
    # Load V2 models
    vc_wrapper_v2 = load_v2_models(args)

    # Create interfaces
    v1_interface = create_v1_interface()
    v2_interface = create_v2_interface(vc_wrapper_v2)

    # Create tabs
    with gr.Blocks(title="Seed Voice Conversion") as demo:
        gr.Markdown("# Seed Voice Conversion")
        gr.Markdown("Choose between V1 (Voice & Singing Voice Conversion) or V2 (Voice & Style Conversion)")

        with gr.Tabs():
            with gr.TabItem("V2 - Voice & Style Conversion"):
                v2_interface.render()
            with gr.TabItem("V1 - Voice & Singing Voice Conversion"):
                v1_interface.render()

    # Launch the combined interface
    demo.launch()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--compile", type=bool, default=True)
    args = parser.parse_args()
    main(args)
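A note for local runs (an observation about the code above, not a change made in this commit): argparse with type=bool converts any non-empty string to True, so "python app_v1v2.py --compile False" still enables compilation. A minimal sketch of the usual workaround, using a hypothetical helper name that is not part of the repository:

    # hypothetical str2bool helper, for illustration only
    def str2bool(v: str) -> bool:
        return str(v).strip().lower() in ("1", "true", "yes", "y")

    # parser.add_argument("--compile", type=str2bool, default=True)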
configs/astral_quantization/default_2048.yml
ADDED
@@ -0,0 +1,40 @@
_target_: modules.astral_quantization.default_model.AstralQuantizer
tokenizer_name: "openai/whisper-small"
ssl_model_name: "facebook/hubert-large-ll60k"
ssl_output_layer: 18
encoder:
  _target_: modules.astral_quantization.convnext.ConvNeXtV2Stage
  dim: 512
  num_blocks: 12
  intermediate_dim: 1536
  dilation: 1
  input_dim: 1024
quantizer:
  _target_: modules.astral_quantization.bsq.BinarySphericalQuantize
  codebook_size: 2048 # codebook size, must be a power of 2
  dim: 512
  entropy_loss_weight: 0.1
  diversity_gamma: 1.0
  spherical: True
  enable_entropy_loss: True
  soft_entropy_loss: True
decoder:
  _target_: modules.astral_quantization.convnext.ConvNeXtV2Stage
  dim: 512
  num_blocks: 12
  intermediate_dim: 1536
  dilation: 1
  output_dim: 1024
  gin_channels: 192
asr_decoder:
  _target_: modules.astral_quantization.asr_decoder.ASRDecoder
  hidden_dim: 768
  num_heads: 12
  depth: 12
  block_size: 4096
  in_channels: 512
  n_vocab: 51866
  bos_id: 50528
  eos_id: 50527
  dropout_rate: 0.0
  attn_dropout_rate: 0.0
configs/astral_quantization/default_32.yml
ADDED
@@ -0,0 +1,40 @@
_target_: default_model.AstralQuantizer
tokenizer_name: "openai/whisper-small"
ssl_model_name: "facebook/hubert-large-ll60k"
ssl_output_layer: 18
encoder:
  _target_: modules.convnext.ConvNeXtV2Stage
  dim: 512
  num_blocks: 12
  intermediate_dim: 1536
  dilation: 1
  input_dim: 1024
quantizer:
  _target_: modules.bsq.BinarySphericalQuantize
  codebook_size: 32 # codebook size, must be a power of 2
  dim: 512
  entropy_loss_weight: 0.1
  diversity_gamma: 1.0
  spherical: True
  enable_entropy_loss: True
  soft_entropy_loss: True
decoder:
  _target_: modules.convnext.ConvNeXtV2Stage
  dim: 512
  num_blocks: 12
  intermediate_dim: 1536
  dilation: 1
  output_dim: 1024
  gin_channels: 192
asr_decoder:
  _target_: modules.asr_decoder.ASRDecoder
  hidden_dim: 768
  num_heads: 12
  depth: 12
  block_size: 4096
  in_channels: 512
  n_vocab: 51866
  bos_id: 50528
  eos_id: 50527
  dropout_rate: 0.0
  attn_dropout_rate: 0.0
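The two quantizer presets differ mainly in codebook_size and in how the _target_ paths are written (default_32.yml uses shorter module paths). With binary spherical quantization the per-codebook code width follows from the codebook size (codebook_dim = int(log2(codebook_size)) in modules/astral_quantization/bsq.py later in this diff), so these presets correspond to 11-bit and 5-bit codes respectively. A one-line check, for illustration only:

    # illustration: code width implied by each preset (see bsq.py below)
    from math import log2
    print({size: int(log2(size)) for size in (2048, 32)})  # {2048: 11, 32: 5}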
configs/config.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"reference_audio_path": "D:/FAcodec/test_waves/kobe_0.wav", "sg_hostapi": "MME", "sg_wasapi_exclusive": false, "sg_input_device": "\u9ea6\u514b\u98ce (Razer BlackShark V2 HS 2.4", "sg_output_device": "\u626c\u58f0\u5668 (Razer BlackShark V2 HS 2.4", "sr_type": "sr_model", "diffusion_steps": 10.0, "inference_cfg_rate": 0.0, "max_prompt_length": 3.0, "block_time": 0.7, "crossfade_length": 0.04, "extra_time": 0.5, "extra_time_right": 0.02}
|
configs/inuse/.gitignore
ADDED
File without changes
configs/inuse/config.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"reference_audio_path": "D:/seed-vc/examples/reference/trump_0.wav", "sg_hostapi": "MME", "sg_wasapi_exclusive": false, "sg_input_device": "\u9ea6\u514b\u98ce (Razer BlackShark V2 HS USB", "sg_output_device": "\u626c\u58f0\u5668 (Razer BlackShark V2 HS USB", "sr_type": "sr_model", "diffusion_steps": 8.0, "inference_cfg_rate": 0.7, "max_prompt_length": 3.0, "block_time": 0.58, "crossfade_length": 0.04, "extra_time_ce": 2.5, "extra_time": 0.5, "extra_time_right": 0.02}
|
configs/presets/config_dit_mel_seed_uvit_whisper_base_f0_44k.yml
ADDED
@@ -0,0 +1,98 @@
log_dir: "./runs"
save_freq: 1
log_interval: 10
save_interval: 1000
device: "cuda"
epochs: 1000 # number of epochs for first stage training (pre-training)
batch_size: 1
batch_length: 100 # maximum duration of audio in a batch (in seconds)
max_len: 80 # maximum number of frames
pretrained_model: "DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema.pth"
pretrained_encoder: ""
load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters

preprocess_params:
  sr: 44100
  spect_params:
    n_fft: 2048
    win_length: 2048
    hop_length: 512
    n_mels: 128
    fmin: 0
    fmax: "None"

model_params:
  dit_type: "DiT" # uDiT or DiT
  reg_loss_type: "l1" # l1 or l2

  timbre_shifter:
    se_db_path: "./modules/openvoice/checkpoints_v2/converter/se_db.pt"
    ckpt_path: './modules/openvoice/checkpoints_v2/converter'

  vocoder:
    type: "bigvgan"
    name: "nvidia/bigvgan_v2_44khz_128band_512x"

  speech_tokenizer:
    type: 'whisper'
    name: "openai/whisper-small"

  style_encoder:
    dim: 192
    campplus_path: "campplus_cn_common.bin"

  DAC:
    encoder_dim: 64
    encoder_rates: [2, 5, 5, 6]
    decoder_dim: 1536
    decoder_rates: [ 6, 5, 5, 2 ]
    sr: 24000

  length_regulator:
    channels: 768
    is_discrete: false
    in_channels: 768
    content_codebook_size: 2048
    sampling_ratios: [1, 1, 1, 1]
    vector_quantize: false
    n_codebooks: 1
    quantizer_dropout: 0.0
    f0_condition: true
    n_f0_bins: 256

  DiT:
    hidden_dim: 768
    num_heads: 12
    depth: 17
    class_dropout_prob: 0.1
    block_size: 8192
    in_channels: 128
    style_condition: true
    final_layer_type: 'mlp'
    target: 'mel' # mel or codec
    content_dim: 768
    content_codebook_size: 1024
    content_type: 'discrete'
    f0_condition: true
    n_f0_bins: 256
    content_codebooks: 1
    is_causal: false
    long_skip_connection: false
    zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token
    time_as_token: false
    style_as_token: false
    uvit_skip_connection: true
    add_resblock_in_transformer: false

  wavenet:
    hidden_dim: 768
    num_layers: 8
    kernel_size: 5
    dilation_rate: 1
    p_dropout: 0.2
    style_condition: true

loss_params:
  base_lr: 0.0001
  lambda_mel: 45
  lambda_kl: 1.0
configs/presets/config_dit_mel_seed_uvit_whisper_small_wavenet.yml
ADDED
@@ -0,0 +1,91 @@
log_dir: "./runs"
save_freq: 1
log_interval: 10
save_interval: 1000
device: "cuda"
epochs: 1000 # number of epochs for first stage training (pre-training)
batch_size: 2
batch_length: 100 # maximum duration of audio in a batch (in seconds)
max_len: 80 # maximum number of frames
pretrained_model: "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth"
pretrained_encoder: ""
load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters

preprocess_params:
  sr: 22050
  spect_params:
    n_fft: 1024
    win_length: 1024
    hop_length: 256
    n_mels: 80
    fmin: 0
    fmax: "None"

model_params:
  dit_type: "DiT" # uDiT or DiT
  reg_loss_type: "l1" # l1 or l2

  timbre_shifter:
    se_db_path: "./modules/openvoice/checkpoints_v2/converter/se_db.pt"
    ckpt_path: './modules/openvoice/checkpoints_v2/converter'

  speech_tokenizer:
    type: 'whisper'
    name: "openai/whisper-small"

  style_encoder:
    dim: 192
    campplus_path: "campplus_cn_common.bin"

  vocoder:
    type: "bigvgan"
    name: "nvidia/bigvgan_v2_22khz_80band_256x"

  length_regulator:
    channels: 512
    is_discrete: false
    in_channels: 768
    content_codebook_size: 2048
    sampling_ratios: [1, 1, 1, 1]
    vector_quantize: false
    n_codebooks: 1
    quantizer_dropout: 0.0
    f0_condition: false
    n_f0_bins: 512

  DiT:
    hidden_dim: 512
    num_heads: 8
    depth: 13
    class_dropout_prob: 0.1
    block_size: 8192
    in_channels: 80
    style_condition: true
    final_layer_type: 'wavenet'
    target: 'mel' # mel or codec
    content_dim: 512
    content_codebook_size: 1024
    content_type: 'discrete'
    f0_condition: false
    n_f0_bins: 512
    content_codebooks: 1
    is_causal: false
    long_skip_connection: true
    zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token
    time_as_token: false
    style_as_token: false
    uvit_skip_connection: true
    add_resblock_in_transformer: false

  wavenet:
    hidden_dim: 512
    num_layers: 8
    kernel_size: 5
    dilation_rate: 1
    p_dropout: 0.2
    style_condition: true

loss_params:
  base_lr: 0.0001
  lambda_mel: 45
  lambda_kl: 1.0
configs/presets/config_dit_mel_seed_uvit_xlsr_tiny.yml
ADDED
@@ -0,0 +1,82 @@
log_dir: "./runs/"
save_freq: 1
log_interval: 10
save_interval: 500
device: "cuda"
epochs: 1000 # number of epochs for first stage training (pre-training)
batch_size: 2
batch_length: 100 # maximum duration of audio in a batch (in seconds)
max_len: 80 # maximum number of frames
pretrained_model: "DiT_uvit_tat_xlsr_ema.pth"
pretrained_encoder: ""
load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters

preprocess_params:
  sr: 22050
  spect_params:
    n_fft: 1024
    win_length: 1024
    hop_length: 256
    n_mels: 80
    fmin: 0
    fmax: 8000

model_params:
  dit_type: "DiT" # uDiT or DiT
  reg_loss_type: "l1" # l1 or l2
  diffusion_type: "flow"

  timbre_shifter:
    se_db_path: "./modules/openvoice/checkpoints_v2/converter/se_db.pt"
    ckpt_path: './modules/openvoice/checkpoints_v2/converter'

  vocoder:
    type: "hifigan"

  speech_tokenizer:
    type: 'xlsr'
    output_layer: 12
    name: 'facebook/wav2vec2-xls-r-300m'

  style_encoder:
    dim: 192
    campplus_path: "campplus_cn_common.bin"

  length_regulator:
    channels: 384
    is_discrete: false
    in_channels: 1024
    content_codebook_size: 1024
    sampling_ratios: [1, 1, 1, 1]
    vector_quantize: false
    n_codebooks: 2
    quantizer_dropout: 0.0
    f0_condition: false
    n_f0_bins: 512

  DiT:
    hidden_dim: 384
    num_heads: 6
    depth: 9
    class_dropout_prob: 0.1
    block_size: 8192
    in_channels: 80
    style_condition: true
    final_layer_type: 'mlp'
    target: 'mel' # mel or betavae
    content_dim: 384
    content_codebook_size: 1024
    content_type: 'discrete'
    f0_condition: false
    n_f0_bins: 512
    content_codebooks: 1
    is_causal: false
    long_skip_connection: false
    zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token
    time_as_token: true
    style_as_token: true
    uvit_skip_connection: true
    add_resblock_in_transformer: false

loss_params:
  base_lr: 0.0001
configs/v2/ar_base.yaml
ADDED
File without changes
configs/v2/dit_small.yaml
ADDED
@@ -0,0 +1,17 @@
_target_: modules.v2.cfm.CFM
estimator:
  _target_: modules.v2.dit_wrapper.DiT
  time_as_token: true
  style_as_token: true
  uvit_skip_connection: false
  block_size: 8192
  depth: 13
  num_heads: 8
  hidden_dim: 512
  in_channels: 80
  content_dim: 512
  style_encoder_dim: 192
  class_dropout_prob: 0.1
  dropout_rate: 0.0
  attn_dropout_rate: 0.0
configs/v2/vc_wrapper.yaml
ADDED
@@ -0,0 +1,105 @@
_target_: modules.v2.vc_wrapper.VoiceConversionWrapper
sr: 22050
hop_size: 256
mel_fn:
  _target_: modules.audio.mel_spectrogram
  _partial_: true
  n_fft: 1024
  win_size: 1024
  hop_size: 256
  num_mels: 80
  sampling_rate: 22050
  fmin: 0
  fmax: null
  center: False
cfm:
  _target_: modules.v2.cfm.CFM
  estimator:
    _target_: modules.v2.dit_wrapper.DiT
    time_as_token: true
    style_as_token: true
    uvit_skip_connection: false
    block_size: 8192
    depth: 13
    num_heads: 8
    hidden_dim: 512
    in_channels: 80
    content_dim: 512
    style_encoder_dim: 192
    class_dropout_prob: 0.1
    dropout_rate: 0.0
    attn_dropout_rate: 0.0
cfm_length_regulator:
  _target_: modules.v2.length_regulator.InterpolateRegulator
  channels: 512
  is_discrete: true
  codebook_size: 2048
  sampling_ratios: [ 1, 1, 1, 1 ]
  f0_condition: false
ar:
  _target_: modules.v2.ar.NaiveWrapper
  model:
    _target_: modules.v2.ar.NaiveTransformer
    config:
      _target_: modules.v2.ar.NaiveModelArgs
      dropout: 0.0
      rope_base: 10000.0
      dim: 768
      head_dim: 64
      n_local_heads: 2
      intermediate_size: 2304
      n_head: 12
      n_layer: 12
      vocab_size: 2049 # 1 + 1 for eos
ar_length_regulator:
  _target_: modules.v2.length_regulator.InterpolateRegulator
  channels: 768
  is_discrete: true
  codebook_size: 32
  sampling_ratios: [ ]
  f0_condition: false
style_encoder:
  _target_: modules.campplus.DTDNN.CAMPPlus
  feat_dim: 80
  embedding_size: 192
content_extractor_narrow:
  _target_: modules.astral_quantization.default_model.AstralQuantizer
  tokenizer_name: "openai/whisper-small"
  ssl_model_name: "facebook/hubert-large-ll60k"
  ssl_output_layer: 18
  skip_ssl: true
  encoder: &bottleneck_encoder
    _target_: modules.astral_quantization.convnext.ConvNeXtV2Stage
    dim: 512
    num_blocks: 12
    intermediate_dim: 1536
    dilation: 1
    input_dim: 1024
  quantizer:
    _target_: modules.astral_quantization.bsq.BinarySphericalQuantize
    codebook_size: 32 # codebook size, must be a power of 2
    dim: 512
    entropy_loss_weight: 0.1
    diversity_gamma: 1.0
    spherical: True
    enable_entropy_loss: True
    soft_entropy_loss: True
content_extractor_wide:
  _target_: modules.astral_quantization.default_model.AstralQuantizer
  tokenizer_name: "openai/whisper-small"
  ssl_model_name: "facebook/hubert-large-ll60k"
  ssl_output_layer: 18
  encoder: *bottleneck_encoder
  quantizer:
    _target_: modules.astral_quantization.bsq.BinarySphericalQuantize
    codebook_size: 2048 # codebook size, must be a power of 2
    dim: 512
    entropy_loss_weight: 0.1
    diversity_gamma: 1.0
    spherical: True
    enable_entropy_loss: True
    soft_entropy_loss: True
vocoder:
  _target_: modules.bigvgan.bigvgan.BigVGAN.from_pretrained
  pretrained_model_name_or_path: "nvidia/bigvgan_v2_22khz_80band_256x"
  use_cuda_kernel: false
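This is the config that app_v1v2.py instantiates with Hydra. A minimal sketch of that load path, mirroring load_v2_models() in app_v1v2.py above (checkpoint download and device handling omitted):

    # sketch mirroring load_v2_models() in app_v1v2.py
    import yaml
    from omegaconf import DictConfig
    from hydra.utils import instantiate

    cfg = DictConfig(yaml.safe_load(open("configs/v2/vc_wrapper.yaml", "r")))
    vc_wrapper = instantiate(cfg)  # builds VoiceConversionWrapper and the sub-modules declared above
    vc_wrapper.load_checkpoints()  # then .to(device), .eval(), setup_ar_caches(...) as in the app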
hf_utils.py
CHANGED
@@ -2,7 +2,7 @@ import os
 from huggingface_hub import hf_hub_download


-def load_custom_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename=
+def load_custom_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename=None):
     os.makedirs("./checkpoints", exist_ok=True)
     model_path = hf_hub_download(repo_id=repo_id, filename=model_filename, cache_dir="./checkpoints")
     if config_filename is None:
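For context, a hedged usage sketch of the updated helper; the repo id below is a placeholder, not taken from this commit. With the new config_filename=None default the config argument can simply be omitted and the weights are cached under ./checkpoints:

    # placeholder repo id, for illustration only
    from hf_utils import load_custom_model_from_hf

    model_path = load_custom_model_from_hf("some-user/some-vc-checkpoint",
                                           model_filename="pytorch_model.bin")
    # when config_filename is None the helper presumably returns only the downloaded model path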
modules/__pycache__/audio.cpython-310.pyc
CHANGED
Binary files a/modules/__pycache__/audio.cpython-310.pyc and b/modules/__pycache__/audio.cpython-310.pyc differ

modules/__pycache__/commons.cpython-310.pyc
CHANGED
Binary files a/modules/__pycache__/commons.cpython-310.pyc and b/modules/__pycache__/commons.cpython-310.pyc differ

modules/__pycache__/commons.cpython-38.pyc
ADDED
Binary file (14.2 kB).

modules/__pycache__/diffusion_transformer.cpython-310.pyc
CHANGED
Binary files a/modules/__pycache__/diffusion_transformer.cpython-310.pyc and b/modules/__pycache__/diffusion_transformer.cpython-310.pyc differ

modules/__pycache__/flow_matching.cpython-310.pyc
CHANGED
Binary files a/modules/__pycache__/flow_matching.cpython-310.pyc and b/modules/__pycache__/flow_matching.cpython-310.pyc differ

modules/__pycache__/length_regulator.cpython-310.pyc
CHANGED
Binary files a/modules/__pycache__/length_regulator.cpython-310.pyc and b/modules/__pycache__/length_regulator.cpython-310.pyc differ

modules/__pycache__/rmvpe.cpython-310.pyc
ADDED
Binary file (17.6 kB).

modules/astral_quantization/__pycache__/bsq.cpython-310.pyc
ADDED
Binary file (12.7 kB).

modules/astral_quantization/__pycache__/convnext.cpython-310.pyc
ADDED
Binary file (6.87 kB).

modules/astral_quantization/__pycache__/default_model.cpython-310.pyc
ADDED
Binary file (2.8 kB).
modules/astral_quantization/bsq.py
ADDED
@@ -0,0 +1,569 @@
"""
Lookup Free Quantization
Proposed in https://arxiv.org/abs/2310.05737

In the simplest setup, each dimension is quantized into {-1, 1}.
An entropy penalty is used to encourage utilization.
"""

from math import log2, ceil
from functools import partial, cache
from collections import namedtuple
from contextlib import nullcontext

import torch.distributed as dist
from torch.distributed import nn as dist_nn

import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.nn import Module
from torch.amp import autocast

from einops import rearrange, reduce, pack, unpack

# constants

Return = namedtuple('Return', ['quantized', 'indices', 'entropy_aux_loss'])

LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment'])

# distributed helpers

@cache
def is_distributed():
    return dist.is_initialized() and dist.get_world_size() > 1

def maybe_distributed_mean(t):
    if not is_distributed():
        return t

    dist_nn.all_reduce(t)
    t = t / dist.get_world_size()
    return t

# helper functions

def exists(v):
    return v is not None

def identity(t):
    return t

def default(*args):
    for arg in args:
        if exists(arg):
            return arg() if callable(arg) else arg
    return None

def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

def l2norm(t):
    return F.normalize(t, dim = -1)

# entropy

def log(t, eps = 1e-5):
    return t.clamp(min = eps).log()

def entropy(prob):
    return (-prob * log(prob)).sum(dim=-1)

# cosine sim linear

class CosineSimLinear(Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        scale = 1.
    ):
        super().__init__()
        self.scale = scale
        self.weight = nn.Parameter(torch.randn(dim_in, dim_out))

    def forward(self, x):
        x = F.normalize(x, dim = -1)
        w = F.normalize(self.weight, dim = 0)
        return (x @ w) * self.scale

def soft_entropy_loss(u, tau=1.0, gamma=1.0):
    """
    Compute the soft entropy loss for Binary Spherical Quantization (BSQ).

    Args:
        u (torch.Tensor): Input latent embeddings of shape (batch_size, L).
        tau (float): Temperature scaling factor.
        gamma (float): Weight for the second entropy term.

    Returns:
        torch.Tensor: Soft entropy loss.
    """
    # Binary quantization: Generate implicit codebook corners
    L = u.size(1)  # Dimensionality of codebook
    corners = torch.tensor([-1.0, 1.0], device=u.device) / (L**0.5)

    # Compute soft quantization probabilities for all dimensions
    # q_hat(c|u) for each dimension
    prob_matrix = torch.sigmoid(2 * tau * corners.unsqueeze(1) * u.unsqueeze(2))  # Shape: (batch_size, L, 2)

    # Entropy of q_hat(c|u) (independent along each dimension)
    entropy_per_dim = -torch.sum(prob_matrix * prob_matrix.log(), dim=-1)  # Shape: (batch_size, L)
    entropy_term1 = entropy_per_dim.mean()

    # Expected probabilities for dataset entropy (approximation)
    expected_probs = prob_matrix.mean(dim=0)  # Mean across batch, shape: (L, 2)
    entropy_term2 = -torch.sum(expected_probs * expected_probs.log(), dim=-1).mean()

    # Final entropy loss
    loss = entropy_term1 - gamma * entropy_term2
    return loss

# class

class BinarySphericalQuantize(Module):
    def __init__(
        self,
        *,
        dim = None,
        codebook_size = None,
        entropy_loss_weight = 0.1,
        commitment_loss_weight = 0.,
        diversity_gamma = 1.,
        straight_through_activation = nn.Identity(),
        num_codebooks = 1,
        keep_num_codebooks_dim = None,
        codebook_scale = 1.,  # for residual LFQ, codebook scaled down by 2x at each layer
        frac_per_sample_entropy = 0.25,  # make less than 1. to only use a random fraction of the probs for per sample entropy
        has_projections = None,
        projection_has_bias = True,
        soft_clamp_input_value = None,
        cosine_sim_project_in = False,
        cosine_sim_project_in_scale = None,
        channel_first = None,
        experimental_softplus_entropy_loss = False,
        entropy_loss_offset = 5.,  # how much to shift the loss before softplus
        spherical = True,  # from https://arxiv.org/abs/2406.07548
        force_quantization_f32 = True,  # will force the quantization step to be full precision
        enable_entropy_loss = True,
        soft_entropy_loss = True,
    ):
        super().__init__()

        # some assert validations

        assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ'
        assert not exists(codebook_size) or log2(codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'

        codebook_size = default(codebook_size, lambda: 2 ** dim)
        self.codebook_size = codebook_size

        codebook_dim = int(log2(codebook_size))
        codebook_dims = codebook_dim * num_codebooks
        dim = default(dim, codebook_dims)

        has_projections = default(has_projections, dim != codebook_dims)

        if cosine_sim_project_in:
            cosine_sim_project_in = default(cosine_sim_project_in_scale, codebook_scale)
            project_in_klass = partial(CosineSimLinear, scale = cosine_sim_project_in)
        else:
            project_in_klass = partial(nn.Linear, bias = projection_has_bias)

        self.project_in = project_in_klass(dim, codebook_dims) if has_projections else nn.Identity()
        self.project_out = nn.Linear(codebook_dims, dim, bias = projection_has_bias) if has_projections else nn.Identity()
        self.has_projections = has_projections

        self.dim = dim
        self.codebook_dim = codebook_dim
        self.num_codebooks = num_codebooks

        keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
        assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
        self.keep_num_codebooks_dim = keep_num_codebooks_dim

        # channel first

        self.channel_first = channel_first

        # straight through activation

        self.activation = straight_through_activation

        # whether to use BSQ (binary spherical quantization)

        self.spherical = spherical
        self.maybe_l2norm = (lambda t: l2norm(t) * self.codebook_scale) if spherical else identity

        # entropy aux loss related weights

        assert 0 < frac_per_sample_entropy <= 1.
        self.frac_per_sample_entropy = frac_per_sample_entropy

        self.diversity_gamma = diversity_gamma
        self.entropy_loss_weight = entropy_loss_weight

        # codebook scale

        self.codebook_scale = codebook_scale

        # commitment loss

        self.commitment_loss_weight = commitment_loss_weight

        # whether to soft clamp the input value from -value to value

        self.soft_clamp_input_value = soft_clamp_input_value
        assert not exists(soft_clamp_input_value) or soft_clamp_input_value >= codebook_scale

        # whether to make the entropy loss positive through a softplus (experimental, please report if this worked or not in discussions)

        self.entropy_loss_offset = entropy_loss_offset
        self.experimental_softplus_entropy_loss = experimental_softplus_entropy_loss

        # for no auxiliary loss, during inference

        self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
        self.register_buffer('zero', torch.tensor(0.), persistent = False)

        # whether to force quantization step to be f32

        self.force_quantization_f32 = force_quantization_f32

        # codes
        self.enable_entropy_loss = enable_entropy_loss
        self.soft_entropy_loss = soft_entropy_loss
        if codebook_size <= 100000:
            all_codes = torch.arange(codebook_size)
            bits = ((all_codes[..., None].int() & self.mask) != 0).float()
            codebook = self.bits_to_codes(bits)

            self.register_buffer('codebook', codebook.float(), persistent = False)
        else:
            all_codes = torch.arange(pow(2, 16))
            mask = 2 ** torch.arange(16 - 1, -1, -1)
            bits = ((all_codes[..., None].int() & mask) != 0).float()
            codebook = self.bits_to_codes(bits)

            self.register_buffer('codebook', codebook.float(), persistent = False)

    def bits_to_codes(self, bits):
        return bits * self.codebook_scale * 2 - self.codebook_scale

    @property
    def dtype(self):
        return self.codebook.dtype

    def indices_to_codes(
        self,
        indices,
        project_out = True
    ):
        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
        should_transpose = default(self.channel_first, is_img_or_video)

        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, '... -> ... 1')

        # indices to codes, which are bits of either -1 or 1

        bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype)

        codes = self.bits_to_codes(bits)

        codes = self.maybe_l2norm(codes)

        codes = rearrange(codes, '... c d -> ... (c d)')

        # whether to project codes out to original dimensions
        # if the input feature dimensions were not log2(codebook size)

        if project_out:
            codes = self.project_out(codes)

        # rearrange codes back to original shape

        if should_transpose:
            codes = rearrange(codes, 'b ... d -> b d ...')

        return codes

    def bits_to_z(self, bits):
        # assert bits must contain only -1 and 1
        assert torch.all(bits.abs() == 1)
        quantized = bits.float()
        quantized = self.maybe_l2norm(quantized)
        z = self.project_out(quantized)
        return z

    def forward(
        self,
        x,
        inv_temperature = 100.,
        return_loss_breakdown = False,
        mask = None,
        return_bits = False
    ):
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension, which is also log2(codebook size)
        c - number of codebook dim
        """

        is_img_or_video = x.ndim >= 4
        should_transpose = default(self.channel_first, is_img_or_video)

        # standardize image or video into (batch, seq, dimension)

        if should_transpose:
            x = rearrange(x, 'b d ... -> b ... d')
            x, ps = pack_one(x, 'b * d')

        assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'

        x = self.project_in(x)

        # maybe soft clamp

        if exists(self.soft_clamp_input_value):
            clamp_value = self.soft_clamp_input_value
            x = (x / clamp_value).tanh() * clamp_value

        # split out number of codebooks

        x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks)

        # maybe l2norm

        x = self.maybe_l2norm(x)

        # whether to force quantization step to be full precision or not

        force_f32 = self.force_quantization_f32

        quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext

        with quantization_context():

            if force_f32:
                orig_dtype = x.dtype
                x = x.float()

            # quantize by eq 3.

            original_input = x

            codebook_value = torch.ones_like(x) * self.codebook_scale
            quantized = torch.where(x > 0, codebook_value, -codebook_value)
            if return_bits:
                return quantized

            # calculate indices

            indices = reduce((quantized > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')

            # maybe l2norm

            quantized = self.maybe_l2norm(quantized)

            # use straight-through gradients (optionally with custom activation fn) if training

            if self.training:
                x = self.activation(x)
                x = x + (quantized - x).detach()
            else:
                x = quantized

            # entropy aux loss
            if self.soft_entropy_loss:
                entropy_aux_loss = soft_entropy_loss(x, tau=1.0, gamma=1.0)
            elif self.training and self.enable_entropy_loss:

                if force_f32:
                    codebook = self.codebook.float()

                codebook = self.maybe_l2norm(codebook)

                # whether to only use a fraction of probs, for reducing memory

                if self.frac_per_sample_entropy < 1.:
                    # account for mask
                    if exists(mask):
                        original_input = original_input[mask]
                    original_input = rearrange(original_input, 'b n ... -> (b n) ...')

                    rand_mask = torch.randn(self.codebook_dim).argsort(dim = -1) < 16

                    sampled_input = original_input[..., rand_mask]

                    sampled_distance = -2 * einsum('... i d, j d -> ... i j', sampled_input, codebook)

                    sampled_prob = (-sampled_distance * inv_temperature).softmax(dim = -1)

                    per_sample_probs = sampled_prob
                else:
                    if exists(mask):
                        original_input = original_input[mask]
                    original_input = rearrange(original_input, 'b n ... -> (b n) ...')
                    # the same as euclidean distance up to a constant
                    distance = -2 * einsum('... i d, j d -> ... i j', original_input, codebook)

                    prob = (-distance * inv_temperature).softmax(dim = -1)

                    per_sample_probs = prob

                # calculate per sample entropy

                per_sample_entropy = entropy(per_sample_probs).mean()

                # distribution over all available tokens in the batch

                avg_prob = reduce(per_sample_probs, '... c d -> c d', 'mean')

                avg_prob = maybe_distributed_mean(avg_prob)

                codebook_entropy = entropy(avg_prob).mean()

                # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
                # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch

                entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
            else:
                # if not training, just return dummy 0
                entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero

            # whether to make the entropy loss positive or not through a (shifted) softplus

            if self.training and self.experimental_softplus_entropy_loss:
                entropy_aux_loss = F.softplus(entropy_aux_loss + self.entropy_loss_offset)

            # commit loss

            if self.training and self.commitment_loss_weight > 0.:

                commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')

                if exists(mask):
                    commit_loss = commit_loss[mask]

                commit_loss = commit_loss.mean()
            else:
                commit_loss = self.zero

            # input back to original dtype if needed

            if force_f32:
                x = x.type(orig_dtype)

        # merge back codebook dim

        x = rearrange(x, 'b n c d -> b n (c d)')

        # project out to feature dimension if needed

        x = self.project_out(x)

        # reconstitute image or video dimensions

        if should_transpose:
            x = unpack_one(x, ps, 'b * d')
            x = rearrange(x, 'b ... d -> b d ...')
477 |
+
|
478 |
+
indices = unpack_one(indices, ps, 'b * c')
|
479 |
+
|
480 |
+
# whether to remove single codebook dim
|
481 |
+
|
482 |
+
if not self.keep_num_codebooks_dim:
|
483 |
+
indices = rearrange(indices, '... 1 -> ...')
|
484 |
+
|
485 |
+
# complete aux loss
|
486 |
+
|
487 |
+
aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
|
488 |
+
|
489 |
+
# returns
|
490 |
+
|
491 |
+
ret = Return(x, indices, aux_loss)
|
492 |
+
|
493 |
+
if not return_loss_breakdown:
|
494 |
+
return ret
|
495 |
+
|
496 |
+
return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss)
|
497 |
+
|
498 |
+
class GroupedResidualBSQ(Module):
|
499 |
+
def __init__(
|
500 |
+
self,
|
501 |
+
*,
|
502 |
+
dim,
|
503 |
+
groups = 1,
|
504 |
+
accept_image_fmap = False,
|
505 |
+
**kwargs
|
506 |
+
):
|
507 |
+
super().__init__()
|
508 |
+
self.dim = dim
|
509 |
+
self.groups = groups
|
510 |
+
assert (dim % groups) == 0
|
511 |
+
dim_per_group = dim // groups
|
512 |
+
|
513 |
+
self.accept_image_fmap = accept_image_fmap
|
514 |
+
|
515 |
+
self.rvqs = nn.ModuleList([])
|
516 |
+
|
517 |
+
for _ in range(groups):
|
518 |
+
self.rvqs.append(LFQ(
|
519 |
+
dim = dim_per_group,
|
520 |
+
**kwargs
|
521 |
+
))
|
522 |
+
|
523 |
+
self.codebook_size = self.rvqs[0].codebook_size
|
524 |
+
|
525 |
+
@property
|
526 |
+
def codebooks(self):
|
527 |
+
return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))
|
528 |
+
|
529 |
+
@property
|
530 |
+
def split_dim(self):
|
531 |
+
return 1 if self.accept_image_fmap else -1
|
532 |
+
|
533 |
+
def get_codes_from_indices(self, indices):
|
534 |
+
codes = tuple(rvq.get_codes_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
|
535 |
+
return torch.stack(codes)
|
536 |
+
|
537 |
+
def get_output_from_indices(self, indices):
|
538 |
+
outputs = tuple(rvq.get_output_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
|
539 |
+
return torch.cat(outputs, dim = self.split_dim)
|
540 |
+
|
541 |
+
def forward(
|
542 |
+
self,
|
543 |
+
x,
|
544 |
+
return_all_codes = False
|
545 |
+
):
|
546 |
+
shape, split_dim = x.shape, self.split_dim
|
547 |
+
assert shape[split_dim] == self.dim
|
548 |
+
|
549 |
+
# split the feature dimension into groups
|
550 |
+
|
551 |
+
x = x.chunk(self.groups, dim = split_dim)
|
552 |
+
|
553 |
+
forward_kwargs = dict(
|
554 |
+
)
|
555 |
+
|
556 |
+
# invoke residual vq on each group
|
557 |
+
|
558 |
+
out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x))
|
559 |
+
out = tuple(zip(*out))
|
560 |
+
|
561 |
+
# otherwise, get all the zipped outputs and combine them
|
562 |
+
|
563 |
+
quantized, all_indices, *maybe_aux_loss = out
|
564 |
+
|
565 |
+
quantized = torch.cat(quantized, dim = split_dim)
|
566 |
+
all_indices = torch.stack(all_indices)
|
567 |
+
|
568 |
+
ret = (quantized, all_indices, *maybe_aux_loss)
|
569 |
+
return ret
|
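A minimal usage sketch for the quantizer above, assuming the single-codebook class is the LFQ that GroupedResidualBSQ instantiates and that it accepts dim and codebook_size in its constructor (both argument names are assumptions, not taken from this diff):

import torch

# Hypothetical construction; only forward()'s behaviour is shown in the file above.
quantizer = LFQ(dim=512, codebook_size=2 ** 16)     # assumed constructor arguments
x = torch.randn(2, 100, 512)                        # (batch, sequence, feature)

ret, breakdown = quantizer(x, return_loss_breakdown=True)
quantized, indices, aux_loss = ret                  # the Return namedtuple from forward()
per_sample_entropy, codebook_entropy, commit_loss = breakdown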
modules/astral_quantization/convnext.py
ADDED
@@ -0,0 +1,209 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List


class ConvNextV2LayerNorm(nn.Module):
    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
    width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError(f"Unsupported data format: {self.data_format}")
        self.normalized_shape = (normalized_shape,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.data_format == "channels_last":
            x = torch.nn.functional.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        elif self.data_format == "channels_first":
            input_dtype = x.dtype
            x = x.float()
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = x.to(dtype=input_dtype)
            x = self.weight[None, :, None] * x + self.bias[None, :, None]
        return x


class GRN(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        Gx = torch.norm(x, p=2, dim=1, keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x


class InterpolationLayer(nn.Module):
    def __init__(self, ):  # this is a default of 1 / 50 * (44100 / 512) / 4
        super().__init__()
        pass

    def forward(self, x: torch.Tensor, target_len: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        x = F.interpolate(x, size=target_len, mode='linear')
        return x


class ConvNeXtV2Stage(nn.Module):
    def __init__(
        self,
        dim: int = 512,
        intermediate_dim: int = 2048,
        num_blocks: int = 1,
        dilation: int = 1,
        downsample_layer_indices: List[int] = None,
        downsample_factors: List[int] = None,
        upsample_layer_indices: List[int] = None,
        upsample_factors: List[int] = None,
        interpolation_layer_indices: List[int] = None,
        input_dim: int = None,
        output_dim: int = None,
        gin_channels: int = 0,
    ):
        super().__init__()
        # maybe downsample layers
        if downsample_layer_indices is not None:
            assert downsample_factors is not None
            self.downsample_blocks = nn.ModuleList(
                [
                    nn.Sequential(
                        ConvNextV2LayerNorm(dim, data_format="channels_first"),
                        nn.Conv1d(
                            dim, dim, kernel_size=downsample_factor, stride=downsample_factor
                        ),
                    ) for _, downsample_factor in zip(downsample_layer_indices, downsample_factors)
                ]
            )
            self.downsample_layer_indices = downsample_layer_indices
        else:
            self.downsample_blocks = nn.ModuleList()
            self.downsample_layer_indices = []

        # maybe upsample layers
        if upsample_layer_indices is not None:
            assert upsample_factors is not None
            self.upsample_blocks = nn.ModuleList(
                [
                    nn.Sequential(
                        ConvNextV2LayerNorm(dim, data_format="channels_first"),
                        nn.ConvTranspose1d(
                            dim, dim, kernel_size=upsample_factor, stride=upsample_factor
                        ),
                    ) for _, upsample_factor in zip(upsample_layer_indices, upsample_factors)
                ]
            )
            self.upsample_layer_indices = upsample_layer_indices
        else:
            self.upsample_blocks = nn.ModuleList()
            self.upsample_layer_indices = []

        # maybe interpolation layers
        if interpolation_layer_indices is not None:
            self.interpolation_blocks = nn.ModuleList(
                [
                    InterpolationLayer()
                    for _ in interpolation_layer_indices
                ]
            )
            self.interpolation_layer_indices = interpolation_layer_indices
        else:
            self.interpolation_blocks = nn.ModuleList()
            self.interpolation_layer_indices = []

        # main blocks
        self.blocks = nn.ModuleList(
            [
                ConvNeXtV2Block(
                    dim=dim,
                    intermediate_dim=intermediate_dim,
                    dilation=dilation,
                )
                for _ in range(num_blocks)
            ]
        )
        # maybe input and output projections
        if input_dim is not None and input_dim != dim:
            self.input_projection = nn.Conv1d(input_dim, dim, kernel_size=1)
        else:
            self.input_projection = nn.Identity()
        if output_dim is not None and output_dim != dim:
            self.output_projection = nn.Conv1d(dim, output_dim, kernel_size=1)
        else:
            self.output_projection = nn.Identity()

        if gin_channels > 0:
            self.gin = nn.Conv1d(gin_channels, dim, kernel_size=1)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        x = self.input_projection(x)  # B, D, T
        if hasattr(self, 'gin'):
            g = kwargs['g']
            x = x + self.gin(g)
        # pad to a multiple of cumprod(downsample_factors)
        if len(self.downsample_blocks) > 0:
            downsample_factor = 1
            for factor in self.downsample_blocks:
                downsample_factor *= factor[1].stride[0]
            pad_len = downsample_factor - x.size(-1) % downsample_factor
            if pad_len > 0:
                x = torch.cat([x, torch.zeros_like(x[:, :, :pad_len])], dim=-1)

        # main blocks
        for layer_idx, block in enumerate(self.blocks):
            if layer_idx in self.downsample_layer_indices:
                x = self.downsample_blocks[self.downsample_layer_indices.index(layer_idx)](x)
            if layer_idx in self.upsample_layer_indices:
                x = self.upsample_blocks[self.upsample_layer_indices.index(layer_idx)](x)
            if layer_idx in self.interpolation_layer_indices:
                x = self.interpolation_blocks[self.interpolation_layer_indices.index(layer_idx)](x, target_len=kwargs['target_len'])
            x = block(x)
        x = self.output_projection(x)
        return x

    def setup_caches(self, *args, **kwargs):
        pass


class ConvNeXtV2Block(nn.Module):
    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        dilation: int = 1,
    ):
        super().__init__()
        padding = (dilation * (7 - 1)) // 2
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
        )  # depthwise conv
        self.norm = ConvNextV2LayerNorm(dim, data_format="channels_first")
        self.pwconv1 = nn.Linear(
            dim, intermediate_dim
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.grn = GRN(intermediate_dim)
        self.pwconv2 = nn.Linear(intermediate_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.dwconv(x)
        x = self.norm(x)
        x = x.transpose(1, 2)  # b d n -> b n d
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        x = x.transpose(1, 2)  # b n d -> b d n
        return residual + x
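A short sketch of driving ConvNeXtV2Stage on a (batch, channels, time) feature map; the sizes below are illustrative assumptions, not the values used by the presets in this repo:

import torch

stage = ConvNeXtV2Stage(
    dim=512,
    intermediate_dim=2048,
    num_blocks=4,
    input_dim=768,      # project 768-d SSL features down to the stage width (assumed sizes)
    output_dim=512,
)
feats = torch.randn(1, 768, 200)    # B, D, T
out = stage(feats)                  # -> (1, 512, 200); no down/up-sampling or interpolation configured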
modules/astral_quantization/default_model.py
ADDED
@@ -0,0 +1,73 @@
import torch
from transformers import AutoTokenizer, AutoModel, Wav2Vec2FeatureExtractor

class AstralQuantizer(torch.nn.Module):
    def __init__(
        self,
        tokenizer_name: str,
        ssl_model_name: str,
        ssl_output_layer: int,
        encoder: torch.nn.Module,
        quantizer: torch.nn.Module,
        skip_ssl: bool = False,
    ):
        super().__init__()
        self.encoder = encoder
        self.quantizer = quantizer
        self.tokenizer_name = tokenizer_name
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # Load SSL model from Huggingface
        self.ssl_model_name = ssl_model_name
        self.ssl_output_layer = ssl_output_layer
        self.ssl_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(ssl_model_name)

        if skip_ssl:  # in case the same SSL model has been loaded somewhere else
            self.ssl_model = None
        else:
            self.ssl_model = AutoModel.from_pretrained(ssl_model_name).eval()
            self.ssl_model.encoder.layers = self.ssl_model.encoder.layers[:ssl_output_layer]
            self.ssl_model.encoder.layer_norm = torch.nn.Identity()

    def load_separate_checkpoint(self, checkpoint_path):
        params = torch.load(checkpoint_path, map_location='cpu')['net']
        for key in params.keys():
            for k in list(params[key].keys()):
                if k.startswith("module."):
                    params[key][k[len("module."):]] = params[key][k]
                    del params[key][k]
        self.encoder.load_state_dict(params['encoder'])
        self.quantizer.load_state_dict(params['vq'])
        if self.decoder is not None:
            self.decoder.load_state_dict(params['decoder'])
        if self.asr_decoder is not None:
            self.asr_decoder.load_state_dict(params['predictor'], strict=False)

    def forward(self, waves_16k, wave_16k_lens, ssl_model=None):
        ssl_fn = self.ssl_model if self.ssl_model else ssl_model
        assert ssl_fn is not None, "In case in-class SSL model loading is skipped, external ssl_model must be provided"
        waves_16k_input_list = [
            waves_16k[bib, :wave_16k_lens[bib]].cpu().numpy()
            for bib in range(len(waves_16k))
        ]
        alt_inputs = self.ssl_feature_extractor(
            waves_16k_input_list,
            return_tensors='pt',
            return_attention_mask=True,
            padding=True,
            sampling_rate=16000
        ).to(waves_16k.device)
        feature_lens = alt_inputs.data['attention_mask'].sum(-1) // 320  # frame rate of hubert is 50 Hz

        outputs = ssl_fn(
            alt_inputs.input_values,
            attention_mask=alt_inputs.attention_mask,
        )
        last_hidden_states = outputs.last_hidden_state
        last_hidden_states = last_hidden_states[:, :feature_lens.max(), :]
        feature_lens = feature_lens.clamp(max=last_hidden_states.size(1))
        last_hidden_states = last_hidden_states.transpose(1, 2)
        x_hidden = self.encoder(last_hidden_states, feature_lens)
        x_hidden = x_hidden.transpose(1, 2)
        x_quantized, indices = self.quantizer(x_hidden)[:2]
        return x_quantized, indices, feature_lens
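A hedged sketch of calling AstralQuantizer.forward to obtain discrete content tokens from 16 kHz audio. The astral_quantizer instance is assumed to have been built elsewhere (e.g. from the astral_quantization configs); nothing below reflects the exact checkpoints this Space loads:

import torch

# `astral_quantizer` is a hypothetical, already-constructed AstralQuantizer instance.
wav = torch.randn(1, 16000 * 3).clamp(-1, 1)      # 3 s of 16 kHz audio, shape (batch, samples)
wav_lens = torch.tensor([wav.shape[-1]])

x_quantized, indices, feature_lens = astral_quantizer(wav, wav_lens)
# indices: roughly one token per 20 ms frame (50 Hz), per the `// 320` frame-rate comment above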
modules/astral_quantization/transformer.py
ADDED
@@ -0,0 +1,254 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
import time

def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)

class AdaptiveLayerNorm(nn.Module):
    r"""Adaptive Layer Normalization"""

    def __init__(self, d_model, norm) -> None:
        super(AdaptiveLayerNorm, self).__init__()
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        self.eps = self.norm.eps

    def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
        if embedding is None:
            return self.norm(input)
        weight, bias = torch.split(
            self.project_layer(embedding),
            split_size_or_sections=self.d_model,
            dim=-1,
        )
        return weight * self.norm(input) + bias


@dataclass
class ModelArgs:
    block_size: int = 2048
    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    intermediate_size: int = None
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    has_cross_attention: bool = False
    context_dim: int = 0
    is_causal: bool = False
    dropout_rate: float = 0.1
    attn_dropout_rate: float = 0.1

    def __post_init__(self):
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        # self.head_dim = self.dim // self.n_head

class Transformer(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config

        self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
        self.norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))

        self.max_batch_size = -1
        self.max_seq_length = config.block_size
        freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim,
                                         self.config.rope_base)
        self.register_buffer("freqs_cis", freqs_cis)

        causal_mask = torch.tril(
            torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)
        )
        self.register_buffer("causal_mask", causal_mask)

    def forward(self,
                x: Tensor,
                c: Tensor,
                input_pos: Optional[Tensor] = None,
                mask: Optional[Tensor] = None,
                context: Optional[Tensor] = None,
                context_input_pos: Optional[Tensor] = None,
                cross_attention_mask: Optional[Tensor] = None,
                ) -> Tensor:
        if mask is None:
            mask = self.causal_mask[:x.size(1), :x.size(1)]
        else:
            mask = mask[..., input_pos]
        freqs_cis = self.freqs_cis[input_pos]
        if context is not None:
            context_freqs_cis = self.freqs_cis[context_input_pos]
        else:
            context_freqs_cis = None
        skip_in_x_list = []
        for i, layer in enumerate(self.layers):
            x = layer(x, c, freqs_cis, mask, context, context_freqs_cis, cross_attention_mask)
        x = self.norm(x, c)
        return x


class TransformerBlock(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
        self.attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))

        if config.has_cross_attention:
            self.has_cross_attention = True
            self.cross_attention = Attention(config, is_cross_attention=True)
            self.cross_attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
        else:
            self.has_cross_attention = False

    def forward(self,
                x: Tensor,
                c: Tensor,
                freqs_cis: Tensor,
                mask: Tensor,
                context: Optional[Tensor] = None,
                context_freqs_cis: Optional[Tensor] = None,
                cross_attention_mask: Optional[Tensor] = None,
                ) -> Tensor:
        #time_attn_start = time.time()
        h = x + self.attention(self.attention_norm(x, c), freqs_cis, mask)
        #print(f"time take for attention of sequence length {x.shape[1]} is {time.time() - time_attn_start}")
        if self.has_cross_attention:
            h = h + self.cross_attention(self.cross_attention_norm(h, c), freqs_cis, cross_attention_mask, context, context_freqs_cis)
        out = h + self.feed_forward(self.ffn_norm(h, c))
        return out


class Attention(nn.Module):
    def __init__(self, config: ModelArgs, is_cross_attention: bool = False):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        if is_cross_attention:
            self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
            self.wkv = nn.Linear(config.context_dim, 2 * config.n_local_heads * config.head_dim, bias=False)
        else:
            self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.attn_dropout_rate = config.attn_dropout_rate

    def forward(self,
                x: Tensor,
                freqs_cis: Tensor,
                mask: Tensor,
                context: Optional[Tensor] = None,
                context_freqs_cis: Optional[Tensor] = None,
                ) -> Tensor:
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        if context is None:
            q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
            context_seqlen = seqlen
        else:
            q = self.wq(x)
            k, v = self.wkv(context).split([kv_size, kv_size], dim=-1)
            context_seqlen = context.shape[1]

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, context_freqs_cis if context_freqs_cis is not None else freqs_cis)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.attn_dropout_rate if self.training else 0.0)

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head)

        y = self.wo(y)
        return y


class FeedForward(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_freqs_cis(
    seq_len: int, n_elem: int, base: int = 10000,
    dtype: torch.dtype = torch.bfloat16
) -> Tensor:
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )

    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)
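A small sketch wiring ModelArgs into the Transformer above; the sizes are illustrative assumptions, and the conditioning tensor c simply feeds the AdaptiveLayerNorm layers:

import torch

# Illustrative sizes only; the repo's v2 configs define the real hyperparameters.
args = ModelArgs(block_size=1024, n_layer=4, n_head=8, head_dim=64, dim=512)
model = Transformer(args)

x = torch.randn(1, 128, 512)            # input features
c = torch.randn(1, 128, 512)            # conditioning embedding for AdaptiveLayerNorm
pos = torch.arange(128)                 # positions used to index the rotary freqs_cis buffer
out = model(x, c, input_pos=pos)        # -> (1, 128, 512), causal mask applied by default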
modules/audio.py
CHANGED
@@ -1,82 +1,82 @@
import numpy as np
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
from scipy.io.wavfile import read

MAX_WAV_VALUE = 32768.0


def load_wav(full_path):
    sampling_rate, data = read(full_path)
    return data, sampling_rate


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global mel_basis, hann_window  # pylint: disable=global-statement
    if f"{str(sampling_rate)}_{str(fmax)}_{str(y.device)}" not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(sampling_rate) + "_" + str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
    )
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[str(sampling_rate) + "_" + str(y.device)],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)], spec)
    spec = spectral_normalize_torch(spec)

    return spec
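A quick sketch of calling mel_spectrogram above; the STFT/mel settings are illustrative, the actual values come from the preset configs:

import torch

y = torch.rand(1, 44100) * 2 - 1        # 1 s of audio in [-1, 1], shape (batch, samples)
mel = mel_spectrogram(
    y, n_fft=2048, num_mels=128, sampling_rate=44100,
    hop_size=512, win_size=2048, fmin=0, fmax=None,
)
# mel: (batch, num_mels, frames), log-compressed via dynamic_range_compression_torch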
modules/bigvgan/__pycache__/activations.cpython-310.pyc
ADDED
Binary file (4 kB)

modules/bigvgan/__pycache__/bigvgan.cpython-310.pyc
ADDED
Binary file (11.8 kB)

modules/bigvgan/__pycache__/env.cpython-310.pyc
ADDED
Binary file (796 Bytes)

modules/bigvgan/__pycache__/meldataset.cpython-310.pyc
ADDED
Binary file (8.54 kB)

modules/bigvgan/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (2.84 kB)

modules/bigvgan/alias_free_activation/cuda/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (158 Bytes)

modules/bigvgan/alias_free_activation/cuda/__pycache__/activation1d.cpython-310.pyc
ADDED
Binary file (2.34 kB)

modules/bigvgan/alias_free_activation/cuda/__pycache__/load.cpython-310.pyc
ADDED
Binary file (1.99 kB)
modules/bigvgan/alias_free_activation/cuda/activation1d.py
CHANGED
@@ -3,10 +3,10 @@
 import torch
 import torch.nn as nn
-from 
+from ..torch.resample import UpSample1d, DownSample1d
 
 # load fused CUDA kernel: this enables importing anti_alias_activation_cuda
-from 
+from ..cuda import load
 
 anti_alias_activation_cuda = load.load()
modules/bigvgan/alias_free_activation/cuda/build/.ninja_deps
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e233713716a5778577f244b0f310944ff26d3079ce0e42491791da7d42e363c1
size 522068
modules/bigvgan/alias_free_activation/cuda/build/.ninja_log
ADDED
@@ -0,0 +1,7 @@
# ninja log v5
9 39554 7516864785377831 anti_alias_activation.o 3a177f31dd72c43c
13 152601 7516865914203767 anti_alias_activation_cuda.cuda.o 2d613e7382d803fd
152628 153062 7516865920541751 anti_alias_activation_cuda.pyd f6366e9bdfb27f7
128 50503 7654004565901584 anti_alias_activation.o 9ed3213f2e0d0858
133 176837 7654005827401976 anti_alias_activation_cuda.cuda.o a679b6661c609136
176839 177401 7654005835005523 anti_alias_activation_cuda.pyd f6366e9bdfb27f7
modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation.o
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:74c2824b05582070b69f51ec588aadb268c4fddf18fbb4590f901d1cdf32185c
size 3246655
modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.cuda.o
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:86c48de557041de7ebaff7926b5f346cc5e4e2dddc6cf5b88409f6cb161db0f4
size 4724513
modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.exp
ADDED
Binary file (25.1 kB)

modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.lib
ADDED
Binary file (43.7 kB)
modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.pyd
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:db37ea2dd31dfe67e68ee6019877d14638c41724ff9342c55f638f4d2cda3d03
size 2454528
modules/bigvgan/alias_free_activation/cuda/build/build.ninja
ADDED
@@ -0,0 +1,38 @@
ninja_required_version = 1.3
cxx = cl
nvcc = C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin\nvcc

cflags = -DTORCH_EXTENSION_NAME=anti_alias_activation_cuda -DTORCH_API_INCLUDE_EXTENSION_H -ID:\Anaconda\envs\vocos\lib\site-packages\torch\include -ID:\Anaconda\envs\vocos\lib\site-packages\torch\include\torch\csrc\api\include "-IC:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\include" -ID:\Anaconda\envs\vocos\Include /std:c++17 -O3 /MD /wd4819 /wd4251 /wd4244 /wd4267 /wd4275 /wd4018 /wd4190 /wd4624 /wd4067 /wd4068 /EHsc
post_cflags =
cuda_cflags = -Xcudafe --diag_suppress=dll_interface_conflict_dllexport_assumed -Xcudafe --diag_suppress=dll_interface_conflict_none_assumed -Xcudafe --diag_suppress=field_without_dll_interface -Xcudafe --diag_suppress=base_class_has_different_dll_interface -Xcompiler /EHsc -Xcompiler /wd4068 -Xcompiler /wd4067 -Xcompiler /wd4624 -Xcompiler /wd4190 -Xcompiler /wd4018 -Xcompiler /wd4275 -Xcompiler /wd4267 -Xcompiler /wd4244 -Xcompiler /wd4251 -Xcompiler /wd4819 -Xcompiler /MD -DTORCH_EXTENSION_NAME=anti_alias_activation_cuda -DTORCH_API_INCLUDE_EXTENSION_H -ID:\Anaconda\envs\vocos\lib\site-packages\torch\include -ID:\Anaconda\envs\vocos\lib\site-packages\torch\include\torch\csrc\api\include "-IC:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\include" -ID:\Anaconda\envs\vocos\Include -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 -std=c++17 -O3 -gencode arch=compute_70,code=sm_70 --use_fast_math -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ --expt-relaxed-constexpr --expt-extended-lambda -gencode arch=compute_80,code=sm_80
cuda_post_cflags =
cuda_dlink_post_cflags =
sycl_dlink_post_cflags =
ldflags = /DLL c10.lib c10_cuda.lib torch_cpu.lib torch_cuda.lib -INCLUDE:?warp_size@cuda@at@@YAHXZ torch.lib /LIBPATH:D:\Anaconda\envs\vocos\lib\site-packages\torch\lib torch_python.lib /LIBPATH:D:\Anaconda\envs\vocos\libs "/LIBPATH:C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\lib\x64" cudart.lib

rule compile
  command = cl /showIncludes $cflags -c $in /Fo$out $post_cflags
  deps = msvc

rule cuda_compile
  depfile = $out.d
  deps = gcc
  command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags

rule link
  command = "D$:\Visual Studio\VC\Tools\MSVC\14.29.30133\bin\Hostx86\x64/link.exe" $in /nologo $ldflags /out:$out

build anti_alias_activation.o: compile D$:\seed-vc\modules\bigvgan\alias_free_activation\cuda\anti_alias_activation.cpp
build anti_alias_activation_cuda.cuda.o: cuda_compile D$:\seed-vc\modules\bigvgan\alias_free_activation\cuda\anti_alias_activation_cuda.cu

build anti_alias_activation_cuda.pyd: link anti_alias_activation.o anti_alias_activation_cuda.cuda.o

default anti_alias_activation_cuda.pyd
modules/bigvgan/alias_free_activation/torch/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (217 Bytes)

modules/bigvgan/alias_free_activation/torch/__pycache__/act.cpython-310.pyc
ADDED
Binary file (1.05 kB)