Initial commit
Browse files- .gitignore +13 -0
- .python-version +1 -0
- CONDITIONING_README.md +96 -0
- Dockerfile +11 -0
- LICENSE +202 -0
- README.md +140 -11
- assets/ArchitectureDiagram.png +0 -0
- assets/ZonosHeader.png +0 -0
- assets/exampleaudio.mp3 +0 -0
- assets/silence_100ms.wav +0 -0
- docker-compose.yml +16 -0
- gradio_interface.py +372 -0
- pyproject.toml +40 -0
- sample.py +20 -0
- uv.lock +0 -0
- zonos/autoencoder.py +26 -0
- zonos/backbone.py +50 -0
- zonos/codebook_pattern.py +12 -0
- zonos/conditioning.py +373 -0
- zonos/config.py +38 -0
- zonos/model.py +270 -0
- zonos/sampling.py +141 -0
- zonos/speaker_cloning.py +406 -0
.gitignore
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Python-generated files
|
2 |
+
__pycache__/
|
3 |
+
*.py[oc]
|
4 |
+
build/
|
5 |
+
dist/
|
6 |
+
wheels/
|
7 |
+
*.egg-info
|
8 |
+
|
9 |
+
# Virtual environments
|
10 |
+
.venv
|
11 |
+
|
12 |
+
# Misc.
|
13 |
+
.ipynb_checkpoints/
|
.python-version
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
3.12
|
CONDITIONING_README.md
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Conditioning explanations
|
2 |
+
Here we will list out all the conditionings the model accepts as well as a short description and some tips for optimal use. For conditionings with a learned unconditional, they can be set to that to allow the model to infer an appropriate setting.
|
3 |
+
### espeak
|
4 |
+
- **Type:** `EspeakPhonemeConditioner`
|
5 |
+
- **Description:**
|
6 |
+
Responsible for cleaning, phonemicizing, tokenizing, and embedding the text provided to the model. This is the text pre-processing pipeline. If you would like to change how a word is pronounced or enter raw phonemes you can do that here.
|
7 |
+
---
|
8 |
+
### speaker
|
9 |
+
- **Type:** `PassthroughConditioner`
|
10 |
+
- **Attributes:**
|
11 |
+
- **cond_dim:** `128`
|
12 |
+
- **uncond_type:** `learned`
|
13 |
+
- **projection:** `linear`
|
14 |
+
- **Description:**
|
15 |
+
An embedded representation of the speakers voice. We use [these](https://huggingface.co/Zyphra/Zonos-v0.1-speaker-embedding) speaker embedding models. It can capture a surprising amount of detail from the reference clip and supports arbitrary length input. Try to input clean reference clips containing only speech. It can be valid to concatenate multiple clean samples from the same speaker into one long sample and may lead to better cloning. If the speaker clip is very long, it is advisable to cut out long speech-free background music segments if they exist. If the reference clip is yielding noisy outputs with denoising enabled we recommend doing source separation before cloning.
|
16 |
+
---
|
17 |
+
### emotion
|
18 |
+
- **Type:** `FourierConditioner`
|
19 |
+
- **Attributes:**
|
20 |
+
- **input_dim:** `8`
|
21 |
+
- **uncond_type:** `learned`
|
22 |
+
- **Description:**
|
23 |
+
Encodes emotion in an 8D vector. Included emotions are Happiness, Sadness, Disgust, Fear, Surprise, Anger, Other, Neutral in that order. This vector tends to be entangled with various other conditioning inputs. More notably, it's entangled with text based on the text sentiment (eg. Angry texts will be more effectively conditioned to be angry, but if you try to make it sound sad it will be a lot less effective). It's also entangled with pitch standard deviation since larger values there tend to correlate to more emotional utterances. It's also heavily correlated with VQScore and DNSMOS as these conditionings favor neutral speech. It's also possible to do a form of "negative prompting" by doing CFG where the unconditional branch is set to a highly neutral emotion vector instead of the true unconditional value, doing this will exaggerate the emotions as it pushes the model away from being neutral.
|
24 |
+
---
|
25 |
+
### fmax
|
26 |
+
- **Type:** `FourierConditioner`
|
27 |
+
- **Attributes:**
|
28 |
+
- **min_val:** `0`
|
29 |
+
- **max_val:** `24000`
|
30 |
+
- **uncond_type:** `learned`
|
31 |
+
- **Description:**
|
32 |
+
Specifies the max frequency of the audio. For best results select 22050 or 24000 as these correspond to 44.1 and 48KHz audio respectively. They should not be any different in terms of actual max frequency since the model's sampling rate is 44.1KHz but they represent different slices of data which lead to slightly different voicing. Selecting a lower value generally produces lower-quality results both in terms of acoustics and voicing.
|
33 |
+
---
|
34 |
+
### pitch_std
|
35 |
+
- **Type:** `FourierConditioner`
|
36 |
+
- **Attributes:**
|
37 |
+
- **min_val:** `0`
|
38 |
+
- **max_val:** `400`
|
39 |
+
- **uncond_type:** `learned`
|
40 |
+
- **Description:**
|
41 |
+
Specifies the standard deviation of the pitch of the output audio. Wider variations of pitch tend to be more correlated with expressive speech. Good values are from 20-45 for normal speech and 60-150 for expressive speech. Higher than that generally tend to be crazier samples.
|
42 |
+
---
|
43 |
+
### speaking_rate
|
44 |
+
- **Type:** `FourierConditioner`
|
45 |
+
- **Attributes:**
|
46 |
+
- **min_val:** `0`
|
47 |
+
- **max_val:** `40`
|
48 |
+
- **uncond_type:** `learned`
|
49 |
+
- **Description:**
|
50 |
+
Specifies the number of phonemes to be read per second. When entering a long text, it is advisable to adjust the speaking rate such that the number of phonemes is readable within the generation length. For example, if your generation length is 10 seconds, and your input is 300 phonemes, you would want either 30 phonemes per second (which is very very fast) or to generate a longer sample. The model's maximum is 30 seconds. Please note that unrealistic speaking rates can be OOD for the model and create undesirable effects, so at the 30-second limit, it can be better to cut the text short and do multiple generations than to feed the model the entire prompt and have an unrealistically low speaking rate.
|
51 |
+
---
|
52 |
+
### language_id
|
53 |
+
- **Type:** `IntegerConditioner`
|
54 |
+
- **Attributes:**
|
55 |
+
- **min_val:** `-1`
|
56 |
+
- **max_val:** `126`
|
57 |
+
- **uncond_type:** `learned`
|
58 |
+
- **Description:**
|
59 |
+
Indicates which language the output should be in. A mapping for these values can be found in the [conditioning section](https://github.com/Zyphra/Zonos/blob/3807c8e04bd4beaadb9502b3df1ffa4b0350e3f7/zonos/conditioning.py#L308C1-L376C21) of Zonos.
|
60 |
+
---
|
61 |
+
### vqscore_8
|
62 |
+
- **Type:** `FourierConditioner`
|
63 |
+
- **Attributes:**
|
64 |
+
- **input_dim:** `8`
|
65 |
+
- **min_val:** `0.5`
|
66 |
+
- **max_val:** `0.8`
|
67 |
+
- **uncond_type:** `learned`
|
68 |
+
- **Description:**
|
69 |
+
Encodes the desired [VQScore](https://github.com/JasonSWFu/VQscore) value for the output audio. VQScore is an unsupervised speech quality (cleanliness) estimation method that we found has superior generalization and reduced biases compared to supervised methods like DNSMOS. A good value for our model is 0.78 for high-quality speech. The eight dimensions correspond to consecutive 1/8th chunks of the audio. (eg. for an 8-second output, the first dimension represents the quality of the first second only). For inference, we generally set all 8 dimensions to the same value. This has an unfortunately strong correlation with expressiveness, so for expressive speech, we recommend setting it to unconditional.
|
70 |
+
---
|
71 |
+
### ctc_loss
|
72 |
+
- **Type:** `FourierConditioner`
|
73 |
+
- **Attributes:**
|
74 |
+
- **min_val:** `-1.0`
|
75 |
+
- **max_val:** `1000`
|
76 |
+
- **uncond_type:** `learned`
|
77 |
+
- **Description:**
|
78 |
+
Encodes loss values from a [CTC](https://en.wikipedia.org/wiki/Connectionist_temporal_classification) (Connectionist Temporal Classification) setup, this indicates how well the training-time transcription matched with the audio according to a CTC model. For inference always use low values (eg. 0.0 or 1.0)
|
79 |
+
---
|
80 |
+
### dnsmos_ovrl
|
81 |
+
- **Type:** `FourierConditioner`
|
82 |
+
- **Attributes:**
|
83 |
+
- **min_val:** `1`
|
84 |
+
- **max_val:** `5`
|
85 |
+
- **uncond_type:** `learned`
|
86 |
+
- **Description:**
|
87 |
+
A [MOS](https://arxiv.org/abs/2110.01763) score for the output audio. This is similar to VQScore and tends to have a stronger entanglement with emotions. It additionally has a strong entanglement with languages. Set to 4.0 for very clean and neutral English speech, else we recommend setting it to unconditional.
|
88 |
+
---
|
89 |
+
### speaker_noised
|
90 |
+
- **Type:** `IntegerConditioner`
|
91 |
+
- **Attributes:**
|
92 |
+
- **min_val:** `0`
|
93 |
+
- **max_val:** `1`
|
94 |
+
- **uncond_type:** `learned`
|
95 |
+
- **Description:**
|
96 |
+
Indicates if the speaker embedding is noisy or not. If checked this lets the model clean (denoise) the input speaker embedding. When this is set to True, VQScore and DNSMOS will have a lot more power to clean the speaker embedding, so for very noisy input samples we recommend setting this to True and specifying a high VQScore value. If your speaker cloning outputs sound echo-y or do weird things, setting this to True will help.
|
Dockerfile
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel
|
2 |
+
RUN pip install uv
|
3 |
+
|
4 |
+
RUN apt update && \
|
5 |
+
apt install -y espeak-ng && \
|
6 |
+
rm -rf /var/lib/apt/lists/*
|
7 |
+
|
8 |
+
WORKDIR /app
|
9 |
+
COPY . ./
|
10 |
+
|
11 |
+
RUN uv pip install --system -e . && uv pip install --system -e .[compile]
|
LICENSE
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
Apache License
|
3 |
+
Version 2.0, January 2004
|
4 |
+
http://www.apache.org/licenses/
|
5 |
+
|
6 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
7 |
+
|
8 |
+
1. Definitions.
|
9 |
+
|
10 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
11 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
12 |
+
|
13 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
14 |
+
the copyright owner that is granting the License.
|
15 |
+
|
16 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
17 |
+
other entities that control, are controlled by, or are under common
|
18 |
+
control with that entity. For the purposes of this definition,
|
19 |
+
"control" means (i) the power, direct or indirect, to cause the
|
20 |
+
direction or management of such entity, whether by contract or
|
21 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
22 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
23 |
+
|
24 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
25 |
+
exercising permissions granted by this License.
|
26 |
+
|
27 |
+
"Source" form shall mean the preferred form for making modifications,
|
28 |
+
including but not limited to software source code, documentation
|
29 |
+
source, and configuration files.
|
30 |
+
|
31 |
+
"Object" form shall mean any form resulting from mechanical
|
32 |
+
transformation or translation of a Source form, including but
|
33 |
+
not limited to compiled object code, generated documentation,
|
34 |
+
and conversions to other media types.
|
35 |
+
|
36 |
+
"Work" shall mean the work of authorship, whether in Source or
|
37 |
+
Object form, made available under the License, as indicated by a
|
38 |
+
copyright notice that is included in or attached to the work
|
39 |
+
(an example is provided in the Appendix below).
|
40 |
+
|
41 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
42 |
+
form, that is based on (or derived from) the Work and for which the
|
43 |
+
editorial revisions, annotations, elaborations, or other modifications
|
44 |
+
represent, as a whole, an original work of authorship. For the purposes
|
45 |
+
of this License, Derivative Works shall not include works that remain
|
46 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
47 |
+
the Work and Derivative Works thereof.
|
48 |
+
|
49 |
+
"Contribution" shall mean any work of authorship, including
|
50 |
+
the original version of the Work and any modifications or additions
|
51 |
+
to that Work or Derivative Works thereof, that is intentionally
|
52 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
53 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
54 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
55 |
+
means any form of electronic, verbal, or written communication sent
|
56 |
+
to the Licensor or its representatives, including but not limited to
|
57 |
+
communication on electronic mailing lists, source code control systems,
|
58 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
59 |
+
Licensor for the purpose of discussing and improving the Work, but
|
60 |
+
excluding communication that is conspicuously marked or otherwise
|
61 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
62 |
+
|
63 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
64 |
+
on behalf of whom a Contribution has been received by Licensor and
|
65 |
+
subsequently incorporated within the Work.
|
66 |
+
|
67 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
68 |
+
this License, each Contributor hereby grants to You a perpetual,
|
69 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
70 |
+
copyright license to reproduce, prepare Derivative Works of,
|
71 |
+
publicly display, publicly perform, sublicense, and distribute the
|
72 |
+
Work and such Derivative Works in Source or Object form.
|
73 |
+
|
74 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
77 |
+
(except as stated in this section) patent license to make, have made,
|
78 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
79 |
+
where such license applies only to those patent claims licensable
|
80 |
+
by such Contributor that are necessarily infringed by their
|
81 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
82 |
+
with the Work to which such Contribution(s) was submitted. If You
|
83 |
+
institute patent litigation against any entity (including a
|
84 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
85 |
+
or a Contribution incorporated within the Work constitutes direct
|
86 |
+
or contributory patent infringement, then any patent licenses
|
87 |
+
granted to You under this License for that Work shall terminate
|
88 |
+
as of the date such litigation is filed.
|
89 |
+
|
90 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
91 |
+
Work or Derivative Works thereof in any medium, with or without
|
92 |
+
modifications, and in Source or Object form, provided that You
|
93 |
+
meet the following conditions:
|
94 |
+
|
95 |
+
(a) You must give any other recipients of the Work or
|
96 |
+
Derivative Works a copy of this License; and
|
97 |
+
|
98 |
+
(b) You must cause any modified files to carry prominent notices
|
99 |
+
stating that You changed the files; and
|
100 |
+
|
101 |
+
(c) You must retain, in the Source form of any Derivative Works
|
102 |
+
that You distribute, all copyright, patent, trademark, and
|
103 |
+
attribution notices from the Source form of the Work,
|
104 |
+
excluding those notices that do not pertain to any part of
|
105 |
+
the Derivative Works; and
|
106 |
+
|
107 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
108 |
+
distribution, then any Derivative Works that You distribute must
|
109 |
+
include a readable copy of the attribution notices contained
|
110 |
+
within such NOTICE file, excluding those notices that do not
|
111 |
+
pertain to any part of the Derivative Works, in at least one
|
112 |
+
of the following places: within a NOTICE text file distributed
|
113 |
+
as part of the Derivative Works; within the Source form or
|
114 |
+
documentation, if provided along with the Derivative Works; or,
|
115 |
+
within a display generated by the Derivative Works, if and
|
116 |
+
wherever such third-party notices normally appear. The contents
|
117 |
+
of the NOTICE file are for informational purposes only and
|
118 |
+
do not modify the License. You may add Your own attribution
|
119 |
+
notices within Derivative Works that You distribute, alongside
|
120 |
+
or as an addendum to the NOTICE text from the Work, provided
|
121 |
+
that such additional attribution notices cannot be construed
|
122 |
+
as modifying the License.
|
123 |
+
|
124 |
+
You may add Your own copyright statement to Your modifications and
|
125 |
+
may provide additional or different license terms and conditions
|
126 |
+
for use, reproduction, or distribution of Your modifications, or
|
127 |
+
for any such Derivative Works as a whole, provided Your use,
|
128 |
+
reproduction, and distribution of the Work otherwise complies with
|
129 |
+
the conditions stated in this License.
|
130 |
+
|
131 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
132 |
+
any Contribution intentionally submitted for inclusion in the Work
|
133 |
+
by You to the Licensor shall be under the terms and conditions of
|
134 |
+
this License, without any additional terms or conditions.
|
135 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
136 |
+
the terms of any separate license agreement you may have executed
|
137 |
+
with Licensor regarding such Contributions.
|
138 |
+
|
139 |
+
6. Trademarks. This License does not grant permission to use the trade
|
140 |
+
names, trademarks, service marks, or product names of the Licensor,
|
141 |
+
except as required for reasonable and customary use in describing the
|
142 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
143 |
+
|
144 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
145 |
+
agreed to in writing, Licensor provides the Work (and each
|
146 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
147 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
148 |
+
implied, including, without limitation, any warranties or conditions
|
149 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
150 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
151 |
+
appropriateness of using or redistributing the Work and assume any
|
152 |
+
risks associated with Your exercise of permissions under this License.
|
153 |
+
|
154 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
155 |
+
whether in tort (including negligence), contract, or otherwise,
|
156 |
+
unless required by applicable law (such as deliberate and grossly
|
157 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
158 |
+
liable to You for damages, including any direct, indirect, special,
|
159 |
+
incidental, or consequential damages of any character arising as a
|
160 |
+
result of this License or out of the use or inability to use the
|
161 |
+
Work (including but not limited to damages for loss of goodwill,
|
162 |
+
work stoppage, computer failure or malfunction, or any and all
|
163 |
+
other commercial damages or losses), even if such Contributor
|
164 |
+
has been advised of the possibility of such damages.
|
165 |
+
|
166 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
167 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
169 |
+
or other liability obligations and/or rights consistent with this
|
170 |
+
License. However, in accepting such obligations, You may act only
|
171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
172 |
+
of any other Contributor, and only if You agree to indemnify,
|
173 |
+
defend, and hold each Contributor harmless for any liability
|
174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
175 |
+
of your accepting any such warranty or additional liability.
|
176 |
+
|
177 |
+
END OF TERMS AND CONDITIONS
|
178 |
+
|
179 |
+
APPENDIX: How to apply the Apache License to your work.
|
180 |
+
|
181 |
+
To apply the Apache License to your work, attach the following
|
182 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
183 |
+
replaced with your own identifying information. (Don't include
|
184 |
+
the brackets!) The text should be enclosed in the appropriate
|
185 |
+
comment syntax for the file format. We also recommend that a
|
186 |
+
file or class name and description of purpose be included on the
|
187 |
+
same "printed page" as the copyright notice for easier
|
188 |
+
identification within third-party archives.
|
189 |
+
|
190 |
+
Copyright [yyyy] [name of copyright owner]
|
191 |
+
|
192 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
193 |
+
you may not use this file except in compliance with the License.
|
194 |
+
You may obtain a copy of the License at
|
195 |
+
|
196 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
197 |
+
|
198 |
+
Unless required by applicable law or agreed to in writing, software
|
199 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
200 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
201 |
+
See the License for the specific language governing permissions and
|
202 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,14 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
12 |
---
|
13 |
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Zonos-v0.1
|
2 |
+
|
3 |
+
<div align="center">
|
4 |
+
<img src="assets/ZonosHeader.png"
|
5 |
+
alt="Alt text"
|
6 |
+
style="width: 500px;
|
7 |
+
height: auto;
|
8 |
+
object-position: center top;">
|
9 |
+
</div>
|
10 |
+
|
11 |
+
---
|
12 |
+
|
13 |
+
Zonos-v0.1 is a leading open-weight text-to-speech model trained on more than 200k hours of varied multilingual speech, delivering expressiveness and quality on par with—or even surpassing—top TTS providers.
|
14 |
+
|
15 |
+
Our model enables highly natural speech generation from text prompts when given a speaker embedding or audio prefix, and can accurately perform speech cloning when given a reference clip spanning just a few seconds. The conditioning setup also allows for fine control over speaking rate, pitch variation, audio quality, and emotions such as happiness, fear, sadness, and anger. The model outputs speech natively at 44kHz.
|
16 |
+
|
17 |
+
##### For more details and speech samples, check out our blog [here](https://www.zyphra.com/post/beta-release-of-zonos-v0-1)
|
18 |
+
|
19 |
+
##### We also have a hosted version available at [maia.zyphra.com/audio](https://maia.zyphra.com/audio)
|
20 |
+
|
21 |
---
|
22 |
+
|
23 |
+
Zonos follows a straightforward architecture: text normalization and phonemization via eSpeak, followed by DAC token prediction through a transformer or hybrid backbone. An overview of the architecture can be seen below.
|
24 |
+
|
25 |
+
<div align="center">
|
26 |
+
<img src="assets/ArchitectureDiagram.png"
|
27 |
+
alt="Alt text"
|
28 |
+
style="width: 1000px;
|
29 |
+
height: auto;
|
30 |
+
object-position: center top;">
|
31 |
+
</div>
|
32 |
+
|
33 |
---
|
34 |
|
35 |
+
## Usage
|
36 |
+
|
37 |
+
### Python
|
38 |
+
|
39 |
+
```python
|
40 |
+
import torch
|
41 |
+
import torchaudio
|
42 |
+
from zonos.model import Zonos
|
43 |
+
from zonos.conditioning import make_cond_dict
|
44 |
+
|
45 |
+
# model = Zonos.from_pretrained("Zyphra/Zonos-v0.1-hybrid", device="cuda")
|
46 |
+
model = Zonos.from_pretrained("Zyphra/Zonos-v0.1-transformer", device="cuda")
|
47 |
+
|
48 |
+
wav, sampling_rate = torchaudio.load("assets/exampleaudio.mp3")
|
49 |
+
speaker = model.make_speaker_embedding(wav, sampling_rate)
|
50 |
+
|
51 |
+
cond_dict = make_cond_dict(text="Hello, world!", speaker=speaker, language="en-us")
|
52 |
+
conditioning = model.prepare_conditioning(cond_dict)
|
53 |
+
|
54 |
+
codes = model.generate(conditioning)
|
55 |
+
|
56 |
+
wavs = model.autoencoder.decode(codes).cpu()
|
57 |
+
torchaudio.save("sample.wav", wavs[0], model.autoencoder.sampling_rate)
|
58 |
+
```
|
59 |
+
|
60 |
+
### Gradio interface (recommended)
|
61 |
+
|
62 |
+
```bash
|
63 |
+
uv run gradio_interface.py
|
64 |
+
# python gradio_interface.py
|
65 |
+
```
|
66 |
+
|
67 |
+
This should produce a `sample.wav` file in your project root directory.
|
68 |
+
|
69 |
+
_For repeated sampling we highly recommend using the gradio interface instead, as the minimal example needs to load the model every time it is run._
|
70 |
+
|
71 |
+
## Features
|
72 |
+
|
73 |
+
- Zero-shot TTS with voice cloning: Input desired text and a 10-30s speaker sample to generate high quality TTS output
|
74 |
+
- Audio prefix inputs: Add text plus an audio prefix for even richer speaker matching. Audio prefixes can be used to elicit behaviours such as whispering which can otherwise be challenging to replicate when cloning from speaker embeddings
|
75 |
+
- Multilingual support: Zonos-v0.1 supports English, Japanese, Chinese, French, and German
|
76 |
+
- Audio quality and emotion control: Zonos offers fine-grained control of many aspects of the generated audio. These include speaking rate, pitch, maximum frequency, audio quality, and various emotions such as happiness, anger, sadness, and fear.
|
77 |
+
- Fast: our model runs with a real-time factor of ~2x on an RTX 4090
|
78 |
+
- Gradio WebUI: Zonos comes packaged with an easy to use gradio interface to generate speech
|
79 |
+
- Simple installation and deployment: Zonos can be installed and deployed simply using the docker file packaged with our repository.
|
80 |
+
|
81 |
+
## Installation
|
82 |
+
|
83 |
+
**At the moment this repository only supports Linux systems (preferably Ubuntu 22.04/24.04) with recent NVIDIA GPUs (3000-series or newer, 6GB+ VRAM).**
|
84 |
+
|
85 |
+
See also [Docker Installation](#docker-installation)
|
86 |
+
|
87 |
+
#### System dependencies
|
88 |
+
|
89 |
+
Zonos depends on the eSpeak library phonemization. You can install it on Ubuntu with the following command:
|
90 |
+
|
91 |
+
```bash
|
92 |
+
apt install -y espeak-ng
|
93 |
+
```
|
94 |
+
|
95 |
+
#### Python dependencies
|
96 |
+
|
97 |
+
We highly recommend using a recent version of [uv](https://docs.astral.sh/uv/#installation) for installation. If you don't have uv installed, you can install it via pip: `pip install -U uv`.
|
98 |
+
|
99 |
+
##### Installing into a new uv virtual environment (recommended)
|
100 |
+
|
101 |
+
```bash
|
102 |
+
uv sync
|
103 |
+
uv sync --extra compile
|
104 |
+
```
|
105 |
+
|
106 |
+
##### Installing into the system/actived environment using uv
|
107 |
+
|
108 |
+
```bash
|
109 |
+
uv pip install -e .
|
110 |
+
uv pip install -e .[compile]
|
111 |
+
```
|
112 |
+
|
113 |
+
##### Installing into the system/actived environment using pip
|
114 |
+
|
115 |
+
```bash
|
116 |
+
pip install -e .
|
117 |
+
pip install --no-build-isolation -e .[compile]
|
118 |
+
```
|
119 |
+
|
120 |
+
##### Confirm that it's working
|
121 |
+
|
122 |
+
For convenience we provide a minimal example to check that the installation works:
|
123 |
+
|
124 |
+
```bash
|
125 |
+
uv run sample.py
|
126 |
+
# python sample.py
|
127 |
+
```
|
128 |
+
|
129 |
+
## Docker installation
|
130 |
+
|
131 |
+
```bash
|
132 |
+
git clone https://github.com/Zyphra/Zonos.git
|
133 |
+
cd Zonos
|
134 |
+
|
135 |
+
# For gradio
|
136 |
+
docker compose up
|
137 |
+
|
138 |
+
# Or for development you can do
|
139 |
+
docker build -t Zonos .
|
140 |
+
docker run -it --gpus=all --net=host -v /path/to/Zonos:/Zonos -t Zonos
|
141 |
+
cd /Zonos
|
142 |
+
python sample.py # this will generate a sample.wav in /Zonos
|
143 |
+
```
|
assets/ArchitectureDiagram.png
ADDED
![]() |
assets/ZonosHeader.png
ADDED
![]() |
assets/exampleaudio.mp3
ADDED
Binary file (820 kB). View file
|
|
assets/silence_100ms.wav
ADDED
Binary file (9.43 kB). View file
|
|
docker-compose.yml
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
version: '3.8'
|
2 |
+
|
3 |
+
services:
|
4 |
+
zonos:
|
5 |
+
build:
|
6 |
+
context: .
|
7 |
+
dockerfile: Dockerfile
|
8 |
+
container_name: zonos_container
|
9 |
+
runtime: nvidia
|
10 |
+
network_mode: "host"
|
11 |
+
stdin_open: true
|
12 |
+
tty: true
|
13 |
+
command: ["python3", "gradio_interface.py"]
|
14 |
+
environment:
|
15 |
+
- NVIDIA_VISIBLE_DEVICES=0
|
16 |
+
- GRADIO_SHARE=False
|
gradio_interface.py
ADDED
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchaudio
|
3 |
+
import gradio as gr
|
4 |
+
from os import getenv
|
5 |
+
|
6 |
+
from zonos.model import Zonos
|
7 |
+
from zonos.conditioning import make_cond_dict, supported_language_codes
|
8 |
+
|
9 |
+
device = "cuda"
|
10 |
+
CURRENT_MODEL_TYPE = None
|
11 |
+
CURRENT_MODEL = None
|
12 |
+
|
13 |
+
|
14 |
+
def load_model_if_needed(model_choice: str):
|
15 |
+
global CURRENT_MODEL_TYPE, CURRENT_MODEL
|
16 |
+
if CURRENT_MODEL_TYPE != model_choice:
|
17 |
+
if CURRENT_MODEL is not None:
|
18 |
+
del CURRENT_MODEL
|
19 |
+
torch.cuda.empty_cache()
|
20 |
+
print(f"Loading {model_choice} model...")
|
21 |
+
CURRENT_MODEL = Zonos.from_pretrained(model_choice, device=device)
|
22 |
+
CURRENT_MODEL.requires_grad_(False).eval()
|
23 |
+
CURRENT_MODEL_TYPE = model_choice
|
24 |
+
print(f"{model_choice} model loaded successfully!")
|
25 |
+
return CURRENT_MODEL
|
26 |
+
|
27 |
+
|
28 |
+
def update_ui(model_choice):
|
29 |
+
"""
|
30 |
+
Dynamically show/hide UI elements based on the model's conditioners.
|
31 |
+
We do NOT display 'language_id' or 'ctc_loss' even if they exist in the model.
|
32 |
+
"""
|
33 |
+
model = load_model_if_needed(model_choice)
|
34 |
+
cond_names = [c.name for c in model.prefix_conditioner.conditioners]
|
35 |
+
print("Conditioners in this model:", cond_names)
|
36 |
+
|
37 |
+
text_update = gr.update(visible=("espeak" in cond_names))
|
38 |
+
language_update = gr.update(visible=("espeak" in cond_names))
|
39 |
+
speaker_audio_update = gr.update(visible=("speaker" in cond_names))
|
40 |
+
prefix_audio_update = gr.update(visible=True)
|
41 |
+
emotion1_update = gr.update(visible=("emotion" in cond_names))
|
42 |
+
emotion2_update = gr.update(visible=("emotion" in cond_names))
|
43 |
+
emotion3_update = gr.update(visible=("emotion" in cond_names))
|
44 |
+
emotion4_update = gr.update(visible=("emotion" in cond_names))
|
45 |
+
emotion5_update = gr.update(visible=("emotion" in cond_names))
|
46 |
+
emotion6_update = gr.update(visible=("emotion" in cond_names))
|
47 |
+
emotion7_update = gr.update(visible=("emotion" in cond_names))
|
48 |
+
emotion8_update = gr.update(visible=("emotion" in cond_names))
|
49 |
+
vq_single_slider_update = gr.update(visible=("vqscore_8" in cond_names))
|
50 |
+
fmax_slider_update = gr.update(visible=("fmax" in cond_names))
|
51 |
+
pitch_std_slider_update = gr.update(visible=("pitch_std" in cond_names))
|
52 |
+
speaking_rate_slider_update = gr.update(visible=("speaking_rate" in cond_names))
|
53 |
+
dnsmos_slider_update = gr.update(visible=("dnsmos_ovrl" in cond_names))
|
54 |
+
speaker_noised_checkbox_update = gr.update(visible=("speaker_noised" in cond_names))
|
55 |
+
unconditional_keys_update = gr.update(
|
56 |
+
choices=[name for name in cond_names if name not in ("espeak", "language_id")]
|
57 |
+
)
|
58 |
+
|
59 |
+
return (
|
60 |
+
text_update,
|
61 |
+
language_update,
|
62 |
+
speaker_audio_update,
|
63 |
+
prefix_audio_update,
|
64 |
+
emotion1_update,
|
65 |
+
emotion2_update,
|
66 |
+
emotion3_update,
|
67 |
+
emotion4_update,
|
68 |
+
emotion5_update,
|
69 |
+
emotion6_update,
|
70 |
+
emotion7_update,
|
71 |
+
emotion8_update,
|
72 |
+
vq_single_slider_update,
|
73 |
+
fmax_slider_update,
|
74 |
+
pitch_std_slider_update,
|
75 |
+
speaking_rate_slider_update,
|
76 |
+
dnsmos_slider_update,
|
77 |
+
speaker_noised_checkbox_update,
|
78 |
+
unconditional_keys_update,
|
79 |
+
)
|
80 |
+
|
81 |
+
|
82 |
+
def generate_audio(
|
83 |
+
model_choice,
|
84 |
+
text,
|
85 |
+
language,
|
86 |
+
speaker_audio,
|
87 |
+
prefix_audio,
|
88 |
+
e1,
|
89 |
+
e2,
|
90 |
+
e3,
|
91 |
+
e4,
|
92 |
+
e5,
|
93 |
+
e6,
|
94 |
+
e7,
|
95 |
+
e8,
|
96 |
+
vq_single,
|
97 |
+
fmax,
|
98 |
+
pitch_std,
|
99 |
+
speaking_rate,
|
100 |
+
dnsmos_ovrl,
|
101 |
+
speaker_noised,
|
102 |
+
cfg_scale,
|
103 |
+
min_p,
|
104 |
+
seed,
|
105 |
+
randomize_seed,
|
106 |
+
unconditional_keys,
|
107 |
+
progress=gr.Progress(),
|
108 |
+
):
|
109 |
+
"""
|
110 |
+
Generates audio based on the provided UI parameters.
|
111 |
+
We do NOT use language_id or ctc_loss even if the model has them.
|
112 |
+
"""
|
113 |
+
selected_model = load_model_if_needed(model_choice)
|
114 |
+
|
115 |
+
speaker_noised_bool = bool(speaker_noised)
|
116 |
+
fmax = float(fmax)
|
117 |
+
pitch_std = float(pitch_std)
|
118 |
+
speaking_rate = float(speaking_rate)
|
119 |
+
dnsmos_ovrl = float(dnsmos_ovrl)
|
120 |
+
cfg_scale = float(cfg_scale)
|
121 |
+
min_p = float(min_p)
|
122 |
+
seed = int(seed)
|
123 |
+
max_new_tokens = 86 * 30
|
124 |
+
|
125 |
+
if randomize_seed:
|
126 |
+
seed = torch.randint(0, 2**32 - 1, (1,)).item()
|
127 |
+
torch.manual_seed(seed)
|
128 |
+
|
129 |
+
speaker_embedding = None
|
130 |
+
if speaker_audio is not None and "speaker" not in unconditional_keys:
|
131 |
+
wav, sr = torchaudio.load(speaker_audio)
|
132 |
+
speaker_embedding = selected_model.make_speaker_embedding(wav, sr)
|
133 |
+
speaker_embedding = speaker_embedding.to(device, dtype=torch.bfloat16)
|
134 |
+
|
135 |
+
audio_prefix_codes = None
|
136 |
+
if prefix_audio is not None:
|
137 |
+
wav_prefix, sr_prefix = torchaudio.load(prefix_audio)
|
138 |
+
wav_prefix = wav_prefix.mean(0, keepdim=True)
|
139 |
+
wav_prefix = torchaudio.functional.resample(wav_prefix, sr_prefix, selected_model.autoencoder.sampling_rate)
|
140 |
+
wav_prefix = wav_prefix.to(device, dtype=torch.float32)
|
141 |
+
with torch.autocast(device, dtype=torch.float32):
|
142 |
+
audio_prefix_codes = selected_model.autoencoder.encode(wav_prefix.unsqueeze(0))
|
143 |
+
|
144 |
+
emotion_tensor = torch.tensor(list(map(float, [e1, e2, e3, e4, e5, e6, e7, e8])), device=device)
|
145 |
+
|
146 |
+
vq_val = float(vq_single)
|
147 |
+
vq_tensor = torch.tensor([vq_val] * 8, device=device).unsqueeze(0)
|
148 |
+
|
149 |
+
cond_dict = make_cond_dict(
|
150 |
+
text=text,
|
151 |
+
language=language,
|
152 |
+
speaker=speaker_embedding,
|
153 |
+
emotion=emotion_tensor,
|
154 |
+
vqscore_8=vq_tensor,
|
155 |
+
fmax=fmax,
|
156 |
+
pitch_std=pitch_std,
|
157 |
+
speaking_rate=speaking_rate,
|
158 |
+
dnsmos_ovrl=dnsmos_ovrl,
|
159 |
+
speaker_noised=speaker_noised_bool,
|
160 |
+
device=device,
|
161 |
+
unconditional_keys=unconditional_keys,
|
162 |
+
)
|
163 |
+
conditioning = selected_model.prepare_conditioning(cond_dict)
|
164 |
+
|
165 |
+
estimated_generation_duration = 30 * len(text) / 400
|
166 |
+
estimated_total_steps = int(estimated_generation_duration * 86)
|
167 |
+
|
168 |
+
def update_progress(_frame: torch.Tensor, step: int, _total_steps: int) -> bool:
|
169 |
+
progress((step, estimated_total_steps))
|
170 |
+
return True
|
171 |
+
|
172 |
+
codes = selected_model.generate(
|
173 |
+
prefix_conditioning=conditioning,
|
174 |
+
audio_prefix_codes=audio_prefix_codes,
|
175 |
+
max_new_tokens=max_new_tokens,
|
176 |
+
cfg_scale=cfg_scale,
|
177 |
+
batch_size=1,
|
178 |
+
sampling_params=dict(min_p=min_p),
|
179 |
+
callback=update_progress,
|
180 |
+
)
|
181 |
+
|
182 |
+
wav_out = selected_model.autoencoder.decode(codes).cpu().detach()
|
183 |
+
sr_out = selected_model.autoencoder.sampling_rate
|
184 |
+
if wav_out.dim() == 2 and wav_out.size(0) > 1:
|
185 |
+
wav_out = wav_out[0:1, :]
|
186 |
+
return (sr_out, wav_out.squeeze().numpy()), seed
|
187 |
+
|
188 |
+
|
189 |
+
def build_interface():
|
190 |
+
with gr.Blocks() as demo:
|
191 |
+
with gr.Row():
|
192 |
+
with gr.Column():
|
193 |
+
model_choice = gr.Dropdown(
|
194 |
+
choices=["Zyphra/Zonos-v0.1-transformer", "Zyphra/Zonos-v0.1-hybrid"],
|
195 |
+
value="Zyphra/Zonos-v0.1-transformer",
|
196 |
+
label="Zonos Model Type",
|
197 |
+
info="Select the model variant to use.",
|
198 |
+
)
|
199 |
+
text = gr.Textbox(
|
200 |
+
label="Text to Synthesize",
|
201 |
+
value="Zonos uses eSpeak for text to phoneme conversion!",
|
202 |
+
lines=4,
|
203 |
+
max_length=500, # approximately
|
204 |
+
)
|
205 |
+
language = gr.Dropdown(
|
206 |
+
choices=supported_language_codes,
|
207 |
+
value="en-us",
|
208 |
+
label="Language Code",
|
209 |
+
info="Select a language code.",
|
210 |
+
)
|
211 |
+
prefix_audio = gr.Audio(
|
212 |
+
value="assets/silence_100ms.wav",
|
213 |
+
label="Optional Prefix Audio (continue from this audio)",
|
214 |
+
type="filepath",
|
215 |
+
)
|
216 |
+
with gr.Column():
|
217 |
+
speaker_audio = gr.Audio(
|
218 |
+
label="Optional Speaker Audio (for cloning)",
|
219 |
+
type="filepath",
|
220 |
+
)
|
221 |
+
speaker_noised_checkbox = gr.Checkbox(label="Denoise Speaker?", value=False)
|
222 |
+
|
223 |
+
with gr.Row():
|
224 |
+
with gr.Column():
|
225 |
+
gr.Markdown("## Conditioning Parameters")
|
226 |
+
dnsmos_slider = gr.Slider(1.0, 5.0, value=4.0, step=0.1, label="DNSMOS Overall")
|
227 |
+
fmax_slider = gr.Slider(0, 24000, value=24000, step=1, label="Fmax (Hz)")
|
228 |
+
vq_single_slider = gr.Slider(0.5, 0.8, 0.78, 0.01, label="VQ Score")
|
229 |
+
pitch_std_slider = gr.Slider(0.0, 300.0, value=45.0, step=1, label="Pitch Std")
|
230 |
+
speaking_rate_slider = gr.Slider(5.0, 30.0, value=15.0, step=0.5, label="Speaking Rate")
|
231 |
+
|
232 |
+
with gr.Column():
|
233 |
+
gr.Markdown("## Generation Parameters")
|
234 |
+
cfg_scale_slider = gr.Slider(1.0, 5.0, 2.0, 0.1, label="CFG Scale")
|
235 |
+
min_p_slider = gr.Slider(0.0, 1.0, 0.15, 0.01, label="Min P")
|
236 |
+
seed_number = gr.Number(label="Seed", value=420, precision=0)
|
237 |
+
randomize_seed_toggle = gr.Checkbox(label="Randomize Seed (before generation)", value=True)
|
238 |
+
|
239 |
+
with gr.Accordion("Advanced Parameters", open=False):
|
240 |
+
gr.Markdown(
|
241 |
+
"### Unconditional Toggles\n"
|
242 |
+
"Checking a box will make the model ignore the corresponding conditioning value and make it unconditional.\n"
|
243 |
+
'Practically this means the given conditioning feature will be unconstrained and "filled in automatically".'
|
244 |
+
)
|
245 |
+
with gr.Row():
|
246 |
+
unconditional_keys = gr.CheckboxGroup(
|
247 |
+
[
|
248 |
+
"speaker",
|
249 |
+
"emotion",
|
250 |
+
"vqscore_8",
|
251 |
+
"fmax",
|
252 |
+
"pitch_std",
|
253 |
+
"speaking_rate",
|
254 |
+
"dnsmos_ovrl",
|
255 |
+
"speaker_noised",
|
256 |
+
],
|
257 |
+
value=["emotion"],
|
258 |
+
label="Unconditional Keys",
|
259 |
+
)
|
260 |
+
|
261 |
+
gr.Markdown(
|
262 |
+
"### Emotion Sliders\n"
|
263 |
+
"Warning: The way these sliders work is not intuitive and may require some trial and error to get the desired effect.\n"
|
264 |
+
"Certain configurations can cause the model to become unstable. Setting emotion to unconditional may help."
|
265 |
+
)
|
266 |
+
with gr.Row():
|
267 |
+
emotion1 = gr.Slider(0.0, 1.0, 1.0, 0.05, label="Happiness")
|
268 |
+
emotion2 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Sadness")
|
269 |
+
emotion3 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Disgust")
|
270 |
+
emotion4 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Fear")
|
271 |
+
with gr.Row():
|
272 |
+
emotion5 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Surprise")
|
273 |
+
emotion6 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Anger")
|
274 |
+
emotion7 = gr.Slider(0.0, 1.0, 0.1, 0.05, label="Other")
|
275 |
+
emotion8 = gr.Slider(0.0, 1.0, 0.2, 0.05, label="Neutral")
|
276 |
+
|
277 |
+
with gr.Column():
|
278 |
+
generate_button = gr.Button("Generate Audio")
|
279 |
+
output_audio = gr.Audio(label="Generated Audio", type="numpy", autoplay=True)
|
280 |
+
|
281 |
+
model_choice.change(
|
282 |
+
fn=update_ui,
|
283 |
+
inputs=[model_choice],
|
284 |
+
outputs=[
|
285 |
+
text,
|
286 |
+
language,
|
287 |
+
speaker_audio,
|
288 |
+
prefix_audio,
|
289 |
+
emotion1,
|
290 |
+
emotion2,
|
291 |
+
emotion3,
|
292 |
+
emotion4,
|
293 |
+
emotion5,
|
294 |
+
emotion6,
|
295 |
+
emotion7,
|
296 |
+
emotion8,
|
297 |
+
vq_single_slider,
|
298 |
+
fmax_slider,
|
299 |
+
pitch_std_slider,
|
300 |
+
speaking_rate_slider,
|
301 |
+
dnsmos_slider,
|
302 |
+
speaker_noised_checkbox,
|
303 |
+
unconditional_keys,
|
304 |
+
],
|
305 |
+
)
|
306 |
+
|
307 |
+
# On page load, trigger the same UI refresh
|
308 |
+
demo.load(
|
309 |
+
fn=update_ui,
|
310 |
+
inputs=[model_choice],
|
311 |
+
outputs=[
|
312 |
+
text,
|
313 |
+
language,
|
314 |
+
speaker_audio,
|
315 |
+
prefix_audio,
|
316 |
+
emotion1,
|
317 |
+
emotion2,
|
318 |
+
emotion3,
|
319 |
+
emotion4,
|
320 |
+
emotion5,
|
321 |
+
emotion6,
|
322 |
+
emotion7,
|
323 |
+
emotion8,
|
324 |
+
vq_single_slider,
|
325 |
+
fmax_slider,
|
326 |
+
pitch_std_slider,
|
327 |
+
speaking_rate_slider,
|
328 |
+
dnsmos_slider,
|
329 |
+
speaker_noised_checkbox,
|
330 |
+
unconditional_keys,
|
331 |
+
],
|
332 |
+
)
|
333 |
+
|
334 |
+
# Generate audio on button click
|
335 |
+
generate_button.click(
|
336 |
+
fn=generate_audio,
|
337 |
+
inputs=[
|
338 |
+
model_choice,
|
339 |
+
text,
|
340 |
+
language,
|
341 |
+
speaker_audio,
|
342 |
+
prefix_audio,
|
343 |
+
emotion1,
|
344 |
+
emotion2,
|
345 |
+
emotion3,
|
346 |
+
emotion4,
|
347 |
+
emotion5,
|
348 |
+
emotion6,
|
349 |
+
emotion7,
|
350 |
+
emotion8,
|
351 |
+
vq_single_slider,
|
352 |
+
fmax_slider,
|
353 |
+
pitch_std_slider,
|
354 |
+
speaking_rate_slider,
|
355 |
+
dnsmos_slider,
|
356 |
+
speaker_noised_checkbox,
|
357 |
+
cfg_scale_slider,
|
358 |
+
min_p_slider,
|
359 |
+
seed_number,
|
360 |
+
randomize_seed_toggle,
|
361 |
+
unconditional_keys,
|
362 |
+
],
|
363 |
+
outputs=[output_audio, seed_number],
|
364 |
+
)
|
365 |
+
|
366 |
+
return demo
|
367 |
+
|
368 |
+
|
369 |
+
if __name__ == "__main__":
|
370 |
+
demo = build_interface()
|
371 |
+
share = getenv("GRADIO_SHARE", "False").lower() in ("true", "1", "t")
|
372 |
+
demo.launch(server_name="0.0.0.0", server_port=7860, share=share)
|
pyproject.toml
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[project]
|
2 |
+
name = "zonos"
|
3 |
+
version = "0.1.0"
|
4 |
+
description = "Text-to-speech by Zyphra"
|
5 |
+
readme = "README.md"
|
6 |
+
requires-python = ">=3.10"
|
7 |
+
dependencies = [
|
8 |
+
"torch>=2.5.1",
|
9 |
+
"setuptools",
|
10 |
+
"packaging",
|
11 |
+
"inflect>=7.5.0",
|
12 |
+
"kanjize>=1.5.0",
|
13 |
+
"numpy>=2.2.2",
|
14 |
+
"phonemizer>=3.3.0",
|
15 |
+
"sudachidict-full>=20241021",
|
16 |
+
"sudachipy>=0.6.10",
|
17 |
+
"torchaudio>=2.5.1",
|
18 |
+
"transformers>=4.48.1",
|
19 |
+
"soundfile>=0.13.1",
|
20 |
+
"huggingface-hub>=0.28.1",
|
21 |
+
"gradio>=5.15.0",
|
22 |
+
]
|
23 |
+
|
24 |
+
# Note: Not actually optional. All the modeling code lives in mamba-ssm.
|
25 |
+
# We put them here to make the two-stage installation process easier.
|
26 |
+
[project.optional-dependencies]
|
27 |
+
compile = [
|
28 |
+
"flash-attn>=2.7.3",
|
29 |
+
"mamba-ssm>=2.2.4",
|
30 |
+
"causal-conv1d>=1.5.0.post8",
|
31 |
+
]
|
32 |
+
|
33 |
+
[tool.setuptools.packages.find]
|
34 |
+
include = ["zonos"]
|
35 |
+
|
36 |
+
[tool.uv]
|
37 |
+
no-build-isolation-package = ["flash-attn", "mamba-ssm", "causal-conv1d"]
|
38 |
+
|
39 |
+
[tool.ruff]
|
40 |
+
line-length = 120
|
sample.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchaudio
|
3 |
+
from zonos.model import Zonos
|
4 |
+
from zonos.conditioning import make_cond_dict
|
5 |
+
|
6 |
+
# Use the hybrid with "Zyphra/Zonos-v0.1-hybrid"
|
7 |
+
model = Zonos.from_pretrained("Zyphra/Zonos-v0.1-transformer", device="cuda")
|
8 |
+
|
9 |
+
wav, sampling_rate = torchaudio.load("assets/exampleaudio.mp3")
|
10 |
+
speaker = model.make_speaker_embedding(wav, sampling_rate)
|
11 |
+
|
12 |
+
torch.manual_seed(421)
|
13 |
+
|
14 |
+
cond_dict = make_cond_dict(text="Hello, world!", speaker=speaker, language="en-us")
|
15 |
+
conditioning = model.prepare_conditioning(cond_dict)
|
16 |
+
|
17 |
+
codes = model.generate(conditioning)
|
18 |
+
|
19 |
+
wavs = model.autoencoder.decode(codes).cpu()
|
20 |
+
torchaudio.save("sample.wav", wavs[0], model.autoencoder.sampling_rate)
|
uv.lock
ADDED
The diff for this file is too large to render.
See raw diff
|
|
zonos/autoencoder.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torchaudio
|
5 |
+
from transformers.models.dac import DacModel
|
6 |
+
|
7 |
+
|
8 |
+
class DACAutoencoder:
|
9 |
+
def __init__(self):
|
10 |
+
super().__init__()
|
11 |
+
self.dac = DacModel.from_pretrained("descript/dac_44khz")
|
12 |
+
self.dac.eval().requires_grad_(False)
|
13 |
+
self.codebook_size = self.dac.config.codebook_size
|
14 |
+
self.num_codebooks = self.dac.quantizer.n_codebooks
|
15 |
+
self.sampling_rate = self.dac.config.sampling_rate
|
16 |
+
|
17 |
+
def preprocess(self, wav: torch.Tensor, sr: int) -> torch.Tensor:
|
18 |
+
wav = torchaudio.functional.resample(wav, sr, 44_100)
|
19 |
+
right_pad = math.ceil(wav.shape[-1] / 512) * 512 - wav.shape[-1]
|
20 |
+
return torch.nn.functional.pad(wav, (0, right_pad))
|
21 |
+
|
22 |
+
def encode(self, wav: torch.Tensor) -> torch.Tensor:
|
23 |
+
return self.dac.encode(wav).audio_codes
|
24 |
+
|
25 |
+
def decode(self, codes: torch.Tensor) -> torch.Tensor:
|
26 |
+
return self.dac.decode(audio_codes=codes).audio_values.unsqueeze(1)
|
zonos/backbone.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from mamba_ssm.models.mixer_seq_simple import create_block
|
4 |
+
from mamba_ssm.ops.triton.layer_norm import layer_norm_fn
|
5 |
+
from mamba_ssm.utils.generation import InferenceParams
|
6 |
+
|
7 |
+
from zonos.config import BackboneConfig
|
8 |
+
|
9 |
+
|
10 |
+
class ZonosBackbone(nn.Module):
|
11 |
+
def __init__(self, config: BackboneConfig):
|
12 |
+
super().__init__()
|
13 |
+
self.config = config
|
14 |
+
|
15 |
+
self.layers = nn.ModuleList(
|
16 |
+
[
|
17 |
+
create_block(
|
18 |
+
d_model=config.d_model,
|
19 |
+
d_intermediate=config.d_intermediate
|
20 |
+
if (i not in config.attn_layer_idx)
|
21 |
+
else config.attn_mlp_d_intermediate,
|
22 |
+
ssm_cfg=config.ssm_cfg,
|
23 |
+
layer_idx=i,
|
24 |
+
attn_layer_idx=config.attn_layer_idx,
|
25 |
+
attn_cfg=config.attn_cfg,
|
26 |
+
norm_epsilon=config.norm_epsilon,
|
27 |
+
residual_in_fp32=config.residual_in_fp32,
|
28 |
+
fused_add_norm=True,
|
29 |
+
rms_norm=config.rms_norm,
|
30 |
+
)
|
31 |
+
for i in range(config.n_layer)
|
32 |
+
]
|
33 |
+
)
|
34 |
+
|
35 |
+
self.norm_f = nn.LayerNorm(config.d_model, eps=config.norm_epsilon)
|
36 |
+
|
37 |
+
def forward(self, hidden_states: torch.Tensor, inference_params: InferenceParams | None = None):
|
38 |
+
residual = None
|
39 |
+
for layer in self.layers:
|
40 |
+
hidden_states, residual = layer(hidden_states, residual, inference_params)
|
41 |
+
|
42 |
+
return layer_norm_fn(
|
43 |
+
hidden_states,
|
44 |
+
self.norm_f.weight,
|
45 |
+
self.norm_f.bias,
|
46 |
+
residual,
|
47 |
+
eps=self.norm_f.eps,
|
48 |
+
residual_in_fp32=self.config.residual_in_fp32,
|
49 |
+
is_rms_norm=self.config.rms_norm,
|
50 |
+
)
|
zonos/codebook_pattern.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
|
5 |
+
def apply_delay_pattern(codes: torch.Tensor, mask_token: int):
|
6 |
+
codes = F.pad(codes, (0, codes.shape[1]), value=mask_token)
|
7 |
+
return torch.stack([codes[:, k].roll(k + 1) for k in range(codes.shape[1])], dim=1)
|
8 |
+
|
9 |
+
|
10 |
+
def revert_delay_pattern(codes: torch.Tensor):
|
11 |
+
_, n_q, seq_len = codes.shape
|
12 |
+
return torch.stack([codes[:, k, k + 1 : seq_len - n_q + k + 1] for k in range(n_q)], dim=1)
|
zonos/conditioning.py
ADDED
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import cache
|
2 |
+
from typing import Any, Literal, Iterable
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from zonos.config import PrefixConditionerConfig
|
8 |
+
|
9 |
+
|
10 |
+
class Conditioner(nn.Module):
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
output_dim: int,
|
14 |
+
name: str,
|
15 |
+
cond_dim: int | None = None,
|
16 |
+
projection: Literal["none", "linear", "mlp"] = "none",
|
17 |
+
uncond_type: Literal["learned", "none"] = "none",
|
18 |
+
**kwargs,
|
19 |
+
):
|
20 |
+
super().__init__()
|
21 |
+
self.name = name
|
22 |
+
self.output_dim = output_dim
|
23 |
+
self.cond_dim = cond_dim = cond_dim or output_dim
|
24 |
+
|
25 |
+
if projection == "linear":
|
26 |
+
self.project = nn.Linear(cond_dim, output_dim)
|
27 |
+
elif projection == "mlp":
|
28 |
+
self.project = nn.Sequential(
|
29 |
+
nn.Linear(cond_dim, output_dim),
|
30 |
+
nn.SiLU(),
|
31 |
+
nn.Linear(output_dim, output_dim),
|
32 |
+
)
|
33 |
+
else:
|
34 |
+
self.project = nn.Identity()
|
35 |
+
|
36 |
+
self.uncond_vector = None
|
37 |
+
if uncond_type == "learned":
|
38 |
+
self.uncond_vector = nn.Parameter(torch.zeros(output_dim))
|
39 |
+
|
40 |
+
def apply_cond(self, *inputs: Any) -> torch.Tensor:
|
41 |
+
raise NotImplementedError()
|
42 |
+
|
43 |
+
def forward(self, inputs: tuple[Any, ...] | None) -> torch.Tensor:
|
44 |
+
if inputs is None:
|
45 |
+
assert self.uncond_vector is not None
|
46 |
+
return self.uncond_vector.data.view(1, 1, -1)
|
47 |
+
|
48 |
+
cond = self.apply_cond(*inputs)
|
49 |
+
cond = self.project(cond)
|
50 |
+
return cond
|
51 |
+
|
52 |
+
|
53 |
+
# ------- ESPEAK CONTAINMENT ZONE ------------------------------------------------------------------------------------------------------------------------------------------------
|
54 |
+
import re
|
55 |
+
import unicodedata
|
56 |
+
|
57 |
+
import inflect
|
58 |
+
import torch
|
59 |
+
import torch.nn as nn
|
60 |
+
from kanjize import number2kanji
|
61 |
+
from phonemizer.backend import EspeakBackend
|
62 |
+
from sudachipy import Dictionary, SplitMode
|
63 |
+
|
64 |
+
# --- Number normalization code from https://github.com/daniilrobnikov/vits2/blob/main/text/normalize_numbers.py ---
|
65 |
+
|
66 |
+
_inflect = inflect.engine()
|
67 |
+
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
|
68 |
+
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
|
69 |
+
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
|
70 |
+
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
|
71 |
+
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
|
72 |
+
_number_re = re.compile(r"[0-9]+")
|
73 |
+
|
74 |
+
|
75 |
+
def _remove_commas(m: re.Match) -> str:
|
76 |
+
return m.group(1).replace(",", "")
|
77 |
+
|
78 |
+
|
79 |
+
def _expand_decimal_point(m: re.Match) -> str:
|
80 |
+
return m.group(1).replace(".", " point ")
|
81 |
+
|
82 |
+
|
83 |
+
def _expand_dollars(m: re.Match) -> str:
|
84 |
+
match = m.group(1)
|
85 |
+
parts = match.split(".")
|
86 |
+
if len(parts) > 2:
|
87 |
+
return match + " dollars" # Unexpected format
|
88 |
+
dollars = int(parts[0]) if parts[0] else 0
|
89 |
+
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
90 |
+
if dollars and cents:
|
91 |
+
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
92 |
+
cent_unit = "cent" if cents == 1 else "cents"
|
93 |
+
return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
|
94 |
+
elif dollars:
|
95 |
+
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
96 |
+
return "%s %s" % (dollars, dollar_unit)
|
97 |
+
elif cents:
|
98 |
+
cent_unit = "cent" if cents == 1 else "cents"
|
99 |
+
return "%s %s" % (cents, cent_unit)
|
100 |
+
else:
|
101 |
+
return "zero dollars"
|
102 |
+
|
103 |
+
|
104 |
+
def _expand_ordinal(m: re.Match) -> str:
|
105 |
+
return _inflect.number_to_words(m.group(0))
|
106 |
+
|
107 |
+
|
108 |
+
def _expand_number(m: re.Match) -> str:
|
109 |
+
num = int(m.group(0))
|
110 |
+
if num > 1000 and num < 3000:
|
111 |
+
if num == 2000:
|
112 |
+
return "two thousand"
|
113 |
+
elif num > 2000 and num < 2010:
|
114 |
+
return "two thousand " + _inflect.number_to_words(num % 100)
|
115 |
+
elif num % 100 == 0:
|
116 |
+
return _inflect.number_to_words(num // 100) + " hundred"
|
117 |
+
else:
|
118 |
+
return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
|
119 |
+
else:
|
120 |
+
return _inflect.number_to_words(num, andword="")
|
121 |
+
|
122 |
+
|
123 |
+
def normalize_numbers(text: str) -> str:
|
124 |
+
text = re.sub(_comma_number_re, _remove_commas, text)
|
125 |
+
text = re.sub(_pounds_re, r"\1 pounds", text)
|
126 |
+
text = re.sub(_dollars_re, _expand_dollars, text)
|
127 |
+
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
128 |
+
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
129 |
+
text = re.sub(_number_re, _expand_number, text)
|
130 |
+
return text
|
131 |
+
|
132 |
+
|
133 |
+
# --- Number normalization code end ---
|
134 |
+
|
135 |
+
|
136 |
+
PAD_ID, UNK_ID, BOS_ID, EOS_ID = 0, 1, 2, 3
|
137 |
+
SPECIAL_TOKEN_IDS = [PAD_ID, UNK_ID, BOS_ID, EOS_ID]
|
138 |
+
|
139 |
+
_punctuation = ';:,.!?¡¿—…"«»“”() *~-/\\&'
|
140 |
+
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
141 |
+
_letters_ipa = (
|
142 |
+
"ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
|
143 |
+
)
|
144 |
+
|
145 |
+
symbols = [*_punctuation, *_letters, *_letters_ipa]
|
146 |
+
_symbol_to_id = {s: i for i, s in enumerate(symbols, start=len(SPECIAL_TOKEN_IDS))}
|
147 |
+
|
148 |
+
|
149 |
+
def _get_symbol_id(s: str) -> int:
|
150 |
+
return _symbol_to_id.get(s, 1)
|
151 |
+
|
152 |
+
|
153 |
+
def get_symbol_ids(text: str) -> list[int]:
|
154 |
+
return list(map(_get_symbol_id, text))
|
155 |
+
|
156 |
+
|
157 |
+
def tokenize_phonemes(phonemes: list[str]) -> tuple[torch.Tensor, list[int]]:
|
158 |
+
phoneme_ids = [[BOS_ID, *get_symbol_ids(phonemes), EOS_ID] for phonemes in phonemes]
|
159 |
+
lengths = list(map(len, phoneme_ids))
|
160 |
+
longest = max(lengths)
|
161 |
+
phoneme_ids = [[PAD_ID] * (longest - len(ids)) + ids for ids in phoneme_ids]
|
162 |
+
return torch.tensor(phoneme_ids), lengths
|
163 |
+
|
164 |
+
|
165 |
+
def normalize_jp_text(text: str, tokenizer=Dictionary(dict="full").create()) -> str:
|
166 |
+
text = unicodedata.normalize("NFKC", text)
|
167 |
+
text = re.sub(r"\d+", lambda m: number2kanji(int(m[0])), text)
|
168 |
+
final_text = " ".join([x.reading_form() for x in tokenizer.tokenize(text, SplitMode.A)])
|
169 |
+
return final_text
|
170 |
+
|
171 |
+
|
172 |
+
def clean(texts: list[str], languages: list[str]) -> list[str]:
|
173 |
+
texts_out = []
|
174 |
+
for text, language in zip(texts, languages):
|
175 |
+
if "ja" in language:
|
176 |
+
text = normalize_jp_text(text)
|
177 |
+
else:
|
178 |
+
text = normalize_numbers(text)
|
179 |
+
texts_out.append(text)
|
180 |
+
return texts_out
|
181 |
+
|
182 |
+
|
183 |
+
@cache
|
184 |
+
def get_backend(language: str) -> "EspeakBackend":
|
185 |
+
import logging
|
186 |
+
|
187 |
+
from phonemizer.backend import EspeakBackend
|
188 |
+
|
189 |
+
logger = logging.getLogger("phonemizer")
|
190 |
+
backend = EspeakBackend(
|
191 |
+
language,
|
192 |
+
preserve_punctuation=True,
|
193 |
+
with_stress=True,
|
194 |
+
punctuation_marks=_punctuation,
|
195 |
+
logger=logger,
|
196 |
+
)
|
197 |
+
logger.setLevel(logging.ERROR)
|
198 |
+
return backend
|
199 |
+
|
200 |
+
|
201 |
+
def phonemize(texts: list[str], languages: list[str]) -> list[str]:
|
202 |
+
texts = clean(texts, languages)
|
203 |
+
|
204 |
+
batch_phonemes = []
|
205 |
+
for text, language in zip(texts, languages):
|
206 |
+
backend = get_backend(language)
|
207 |
+
phonemes = backend.phonemize([text], strip=True)
|
208 |
+
batch_phonemes.append(phonemes[0])
|
209 |
+
|
210 |
+
return batch_phonemes
|
211 |
+
|
212 |
+
|
213 |
+
class EspeakPhonemeConditioner(Conditioner):
|
214 |
+
def __init__(self, output_dim: int, **kwargs):
|
215 |
+
super().__init__(output_dim, **kwargs)
|
216 |
+
self.phoneme_embedder = nn.Embedding(len(SPECIAL_TOKEN_IDS) + len(symbols), output_dim)
|
217 |
+
|
218 |
+
def apply_cond(self, texts: list[str], languages: list[str]) -> torch.Tensor:
|
219 |
+
"""
|
220 |
+
Args:
|
221 |
+
texts: list of texts to convert to phonemes
|
222 |
+
languages: ISO 639-1 -or otherwise eSpeak compatible- language code
|
223 |
+
"""
|
224 |
+
device = self.phoneme_embedder.weight.device
|
225 |
+
|
226 |
+
phonemes = phonemize(texts, languages)
|
227 |
+
phoneme_ids, _ = tokenize_phonemes(phonemes)
|
228 |
+
phoneme_embeds = self.phoneme_embedder(phoneme_ids.to(device))
|
229 |
+
|
230 |
+
return phoneme_embeds
|
231 |
+
|
232 |
+
|
233 |
+
# ------- ESPEAK CONTAINMENT ZONE ------------------------------------------------------------------------------------------------------------------------------------------------
|
234 |
+
|
235 |
+
|
236 |
+
class FourierConditioner(Conditioner):
|
237 |
+
def __init__(
|
238 |
+
self,
|
239 |
+
output_dim: int,
|
240 |
+
input_dim: int = 1,
|
241 |
+
std: float = 1.0,
|
242 |
+
min_val: float = 0.0,
|
243 |
+
max_val: float = 1.0,
|
244 |
+
**kwargs,
|
245 |
+
):
|
246 |
+
assert output_dim % 2 == 0
|
247 |
+
super().__init__(output_dim, **kwargs)
|
248 |
+
self.register_buffer("weight", torch.randn([output_dim // 2, input_dim]) * std)
|
249 |
+
self.input_dim, self.min_val, self.max_val = input_dim, min_val, max_val
|
250 |
+
|
251 |
+
def apply_cond(self, x: torch.Tensor) -> torch.Tensor:
|
252 |
+
assert x.shape[-1] == self.input_dim
|
253 |
+
x = (x - self.min_val) / (self.max_val - self.min_val) # [batch_size, seq_len, input_dim]
|
254 |
+
f = 2 * torch.pi * x.to(self.weight.dtype) @ self.weight.T # [batch_size, seq_len, output_dim // 2]
|
255 |
+
return torch.cat([f.cos(), f.sin()], dim=-1) # [batch_size, seq_len, output_dim]
|
256 |
+
|
257 |
+
|
258 |
+
class IntegerConditioner(Conditioner):
|
259 |
+
def __init__(self, output_dim: int, min_val: int = 0, max_val: int = 512, **kwargs):
|
260 |
+
super().__init__(output_dim, **kwargs)
|
261 |
+
self.min_val = min_val
|
262 |
+
self.max_val = max_val
|
263 |
+
self.int_embedder = nn.Embedding(max_val - min_val + 1, output_dim)
|
264 |
+
|
265 |
+
def apply_cond(self, x: torch.Tensor) -> torch.Tensor:
|
266 |
+
assert x.shape[-1] == 1
|
267 |
+
return self.int_embedder(x.squeeze(-1) - self.min_val) # [batch_size, seq_len, output_dim]
|
268 |
+
|
269 |
+
|
270 |
+
class PassthroughConditioner(Conditioner):
|
271 |
+
def __init__(self, output_dim: int, **kwargs):
|
272 |
+
super().__init__(output_dim, **kwargs)
|
273 |
+
|
274 |
+
def apply_cond(self, x: torch.Tensor) -> torch.Tensor:
|
275 |
+
assert x.shape[-1] == self.cond_dim
|
276 |
+
return x
|
277 |
+
|
278 |
+
|
279 |
+
_cond_cls_map = {
|
280 |
+
"PassthroughConditioner": PassthroughConditioner,
|
281 |
+
"EspeakPhonemeConditioner": EspeakPhonemeConditioner,
|
282 |
+
"FourierConditioner": FourierConditioner,
|
283 |
+
"IntegerConditioner": IntegerConditioner,
|
284 |
+
}
|
285 |
+
|
286 |
+
|
287 |
+
def build_conditioners(conditioners: list[dict], output_dim: int) -> list[Conditioner]:
|
288 |
+
return [_cond_cls_map[config["type"]](output_dim, **config) for config in conditioners]
|
289 |
+
|
290 |
+
|
291 |
+
class PrefixConditioner(Conditioner):
|
292 |
+
def __init__(self, config: PrefixConditionerConfig, output_dim: int):
|
293 |
+
super().__init__(output_dim, "prefix", projection=config.projection)
|
294 |
+
self.conditioners = nn.ModuleList(build_conditioners(config.conditioners, output_dim))
|
295 |
+
self.norm = nn.LayerNorm(output_dim)
|
296 |
+
self.required_keys = {c.name for c in self.conditioners if c.uncond_vector is None}
|
297 |
+
|
298 |
+
def forward(self, cond_dict: dict) -> torch.Tensor:
|
299 |
+
if not set(cond_dict).issuperset(self.required_keys):
|
300 |
+
raise ValueError(f"Missing required keys: {self.required_keys - set(cond_dict)}")
|
301 |
+
conds = []
|
302 |
+
for conditioner in self.conditioners:
|
303 |
+
conds.append(conditioner(cond_dict.get(conditioner.name)))
|
304 |
+
max_bsz = max(map(len, conds))
|
305 |
+
assert all(c.shape[0] in (max_bsz, 1) for c in conds)
|
306 |
+
conds = [c.expand(max_bsz, -1, -1) for c in conds]
|
307 |
+
return self.norm(self.project(torch.cat(conds, dim=-2)))
|
308 |
+
|
309 |
+
|
310 |
+
supported_language_codes = [
|
311 |
+
'af', 'am', 'an', 'ar', 'as', 'az', 'ba', 'bg', 'bn', 'bpy', 'bs', 'ca', 'cmn',
|
312 |
+
'cs', 'cy', 'da', 'de', 'el', 'en-029', 'en-gb', 'en-gb-scotland', 'en-gb-x-gbclan',
|
313 |
+
'en-gb-x-gbcwmd', 'en-gb-x-rp', 'en-us', 'eo', 'es', 'es-419', 'et', 'eu', 'fa',
|
314 |
+
'fa-latn', 'fi', 'fr-be', 'fr-ch', 'fr-fr', 'ga', 'gd', 'gn', 'grc', 'gu', 'hak',
|
315 |
+
'hi', 'hr', 'ht', 'hu', 'hy', 'hyw', 'ia', 'id', 'is', 'it', 'ja', 'jbo', 'ka',
|
316 |
+
'kk', 'kl', 'kn', 'ko', 'kok', 'ku', 'ky', 'la', 'lfn', 'lt', 'lv', 'mi', 'mk',
|
317 |
+
'ml', 'mr', 'ms', 'mt', 'my', 'nb', 'nci', 'ne', 'nl', 'om', 'or', 'pa', 'pap',
|
318 |
+
'pl', 'pt', 'pt-br', 'py', 'quc', 'ro', 'ru', 'ru-lv', 'sd', 'shn', 'si', 'sk',
|
319 |
+
'sl', 'sq', 'sr', 'sv', 'sw', 'ta', 'te', 'tn', 'tr', 'tt', 'ur', 'uz', 'vi',
|
320 |
+
'vi-vn-x-central', 'vi-vn-x-south', 'yue'
|
321 |
+
] # fmt: off
|
322 |
+
|
323 |
+
|
324 |
+
def make_cond_dict(
|
325 |
+
text: str = "It would be nice to have time for testing, indeed.",
|
326 |
+
language: str = "en-us",
|
327 |
+
speaker: torch.Tensor | None = None,
|
328 |
+
emotion: list[float] = [0.3077, 0.0256, 0.0256, 0.0256, 0.0256, 0.0256, 0.2564, 0.3077],
|
329 |
+
fmax: float = 22050.0,
|
330 |
+
pitch_std: float = 20.0,
|
331 |
+
speaking_rate: float = 15.0,
|
332 |
+
vqscore_8: list[float] = [0.78] * 8,
|
333 |
+
ctc_loss: float = 0.0,
|
334 |
+
dnsmos_ovrl: float = 4.0,
|
335 |
+
speaker_noised: bool = False,
|
336 |
+
unconditional_keys: Iterable[str] = {"vqscore_8", "dnsmos_ovrl"},
|
337 |
+
device: str = "cuda",
|
338 |
+
) -> dict:
|
339 |
+
"""
|
340 |
+
A helper to build the 'cond_dict' that the model expects.
|
341 |
+
By default, it will generate a random speaker embedding
|
342 |
+
"""
|
343 |
+
assert language.lower() in supported_language_codes, "Please pick a supported language"
|
344 |
+
|
345 |
+
language_code_to_id = {lang: i for i, lang in enumerate(supported_language_codes)}
|
346 |
+
|
347 |
+
cond_dict = {
|
348 |
+
"espeak": ([text], [language]),
|
349 |
+
"speaker": speaker,
|
350 |
+
"emotion": emotion,
|
351 |
+
"fmax": fmax,
|
352 |
+
"pitch_std": pitch_std,
|
353 |
+
"speaking_rate": speaking_rate,
|
354 |
+
"language_id": language_code_to_id[language],
|
355 |
+
"vqscore_8": vqscore_8,
|
356 |
+
"ctc_loss": ctc_loss,
|
357 |
+
"dnsmos_ovrl": dnsmos_ovrl,
|
358 |
+
"speaker_noised": int(speaker_noised),
|
359 |
+
}
|
360 |
+
|
361 |
+
for k in unconditional_keys:
|
362 |
+
cond_dict.pop(k, None)
|
363 |
+
|
364 |
+
for k, v in cond_dict.items():
|
365 |
+
if isinstance(v, (float, int, list)):
|
366 |
+
v = torch.tensor(v)
|
367 |
+
if isinstance(v, torch.Tensor):
|
368 |
+
cond_dict[k] = v.view(1, 1, -1).to(device)
|
369 |
+
|
370 |
+
if k == "emotion":
|
371 |
+
cond_dict[k] /= cond_dict[k].sum(dim=-1)
|
372 |
+
|
373 |
+
return cond_dict
|
zonos/config.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import Literal
|
3 |
+
|
4 |
+
|
5 |
+
@dataclass
|
6 |
+
class BackboneConfig:
|
7 |
+
d_model: int = 1024
|
8 |
+
d_intermediate: int = 0
|
9 |
+
attn_mlp_d_intermediate: int = 0
|
10 |
+
n_layer: int = 16
|
11 |
+
ssm_cfg: dict = field(default_factory=dict)
|
12 |
+
attn_layer_idx: list = field(default_factory=list)
|
13 |
+
attn_cfg: dict = field(default_factory=dict)
|
14 |
+
rms_norm: bool = False
|
15 |
+
residual_in_fp32: bool = False
|
16 |
+
norm_epsilon: float = 1e-5
|
17 |
+
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class PrefixConditionerConfig:
|
21 |
+
conditioners: list[dict]
|
22 |
+
projection: Literal["none", "linear", "mlp"]
|
23 |
+
|
24 |
+
|
25 |
+
@dataclass
|
26 |
+
class ZonosConfig:
|
27 |
+
backbone: BackboneConfig
|
28 |
+
prefix_conditioner: PrefixConditionerConfig
|
29 |
+
eos_token_id: int = 1024
|
30 |
+
masked_token_id: int = 1025
|
31 |
+
|
32 |
+
@classmethod
|
33 |
+
def from_dict(cls, d: dict) -> "ZonosConfig":
|
34 |
+
d = d.copy()
|
35 |
+
backbone_config = BackboneConfig(**d.pop("backbone"))
|
36 |
+
prefix_conditioner_config = PrefixConditionerConfig(**d.pop("prefix_conditioner"))
|
37 |
+
config = cls(backbone_config, prefix_conditioner_config, **d)
|
38 |
+
return config
|
zonos/model.py
ADDED
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from typing import Callable
|
3 |
+
|
4 |
+
import safetensors
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from huggingface_hub import hf_hub_download
|
8 |
+
from mamba_ssm.utils.generation import InferenceParams
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
from zonos.autoencoder import DACAutoencoder
|
12 |
+
from zonos.backbone import ZonosBackbone
|
13 |
+
from zonos.codebook_pattern import apply_delay_pattern, revert_delay_pattern
|
14 |
+
from zonos.conditioning import PrefixConditioner
|
15 |
+
from zonos.config import ZonosConfig
|
16 |
+
from zonos.sampling import sample_from_logits
|
17 |
+
from zonos.speaker_cloning import SpeakerEmbeddingLDA
|
18 |
+
|
19 |
+
|
20 |
+
class Zonos(nn.Module):
|
21 |
+
def __init__(self, config: ZonosConfig):
|
22 |
+
super().__init__()
|
23 |
+
self.config = config
|
24 |
+
dim = config.backbone.d_model
|
25 |
+
self.eos_token_id = config.eos_token_id
|
26 |
+
self.masked_token_id = config.masked_token_id
|
27 |
+
|
28 |
+
self.autoencoder = DACAutoencoder()
|
29 |
+
self.backbone = ZonosBackbone(config.backbone)
|
30 |
+
self.prefix_conditioner = PrefixConditioner(config.prefix_conditioner, dim)
|
31 |
+
self.spk_clone_model = None
|
32 |
+
|
33 |
+
# TODO: pad to multiple of at least 8
|
34 |
+
self.embeddings = nn.ModuleList([nn.Embedding(1026, dim) for _ in range(self.autoencoder.num_codebooks)])
|
35 |
+
self.heads = nn.ModuleList([nn.Linear(dim, 1025, bias=False) for _ in range(self.autoencoder.num_codebooks)])
|
36 |
+
|
37 |
+
self._cg_graph = None
|
38 |
+
self._cg_batch_size = None
|
39 |
+
self._cg_input_ids = None
|
40 |
+
self._cg_logits = None
|
41 |
+
self._cg_inference_params = None
|
42 |
+
self._cg_scale = None
|
43 |
+
|
44 |
+
@classmethod
|
45 |
+
def from_pretrained(cls, repo_id: str, revision: str | None = None, device: str = "cuda") -> "Zonos":
|
46 |
+
config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
|
47 |
+
model_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
|
48 |
+
return cls.from_local(config_path, model_path, device)
|
49 |
+
|
50 |
+
@classmethod
|
51 |
+
def from_local(cls, config_path: str, model_path: str, device: str = "cuda") -> "Zonos":
|
52 |
+
config = ZonosConfig.from_dict(json.load(open(config_path)))
|
53 |
+
model = cls(config).to(device, torch.bfloat16)
|
54 |
+
model.autoencoder.dac.to(device)
|
55 |
+
|
56 |
+
sd = model.state_dict()
|
57 |
+
with safetensors.safe_open(model_path, framework="pt") as f:
|
58 |
+
for k in f.keys():
|
59 |
+
sd[k] = f.get_tensor(k)
|
60 |
+
model.load_state_dict(sd)
|
61 |
+
|
62 |
+
return model
|
63 |
+
|
64 |
+
def make_speaker_embedding(self, wav: torch.Tensor, sr: int) -> torch.Tensor:
|
65 |
+
"""Generate a speaker embedding from an audio clip."""
|
66 |
+
if self.spk_clone_model is None:
|
67 |
+
self.spk_clone_model = SpeakerEmbeddingLDA()
|
68 |
+
_, spk_embedding = self.spk_clone_model(wav.to(self.spk_clone_model.device), sr)
|
69 |
+
return spk_embedding.unsqueeze(0).bfloat16()
|
70 |
+
|
71 |
+
def embed_codes(self, codes: torch.Tensor) -> torch.Tensor:
|
72 |
+
return sum(emb(codes[:, i]) for i, emb in enumerate(self.embeddings))
|
73 |
+
|
74 |
+
def apply_heads(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
75 |
+
return torch.stack([head(hidden_states) for head in self.heads], dim=1)
|
76 |
+
|
77 |
+
def _compute_logits(
|
78 |
+
self, hidden_states: torch.Tensor, inference_params: InferenceParams, cfg_scale: float
|
79 |
+
) -> torch.Tensor:
|
80 |
+
"""
|
81 |
+
Pass `hidden_states` into `backbone` and `multi_head`, applying
|
82 |
+
classifier-free guidance if `cfg_scale != 1.0`.
|
83 |
+
"""
|
84 |
+
last_hidden_states = self.backbone(hidden_states, inference_params)[:, -1, :].unsqueeze(1)
|
85 |
+
logits = self.apply_heads(last_hidden_states).squeeze(2).float()
|
86 |
+
if cfg_scale != 1.0:
|
87 |
+
cond_logits, uncond_logits = logits.chunk(2)
|
88 |
+
logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
|
89 |
+
return logits
|
90 |
+
|
91 |
+
def _decode_one_token(
|
92 |
+
self,
|
93 |
+
input_ids: torch.Tensor,
|
94 |
+
inference_params: InferenceParams,
|
95 |
+
cfg_scale: float,
|
96 |
+
) -> torch.Tensor:
|
97 |
+
"""
|
98 |
+
Single-step decode. Prepares the hidden states, possibly replicates them
|
99 |
+
for CFG, and then delegates to `_compute_logits`.
|
100 |
+
|
101 |
+
Below we wrap this function with a simple CUDA Graph capturing mechanism,
|
102 |
+
doing 3 warmup steps if needed and then capturing or replaying the graph.
|
103 |
+
We only recapture if the batch size changes.
|
104 |
+
"""
|
105 |
+
# TODO: support cfg_scale==1
|
106 |
+
if cfg_scale == 1.0:
|
107 |
+
hidden_states = self.embed_codes(input_ids)
|
108 |
+
return self._compute_logits(hidden_states, inference_params, cfg_scale)
|
109 |
+
|
110 |
+
bsz = input_ids.size(0)
|
111 |
+
|
112 |
+
need_capture = (self._cg_graph is None) or (self._cg_batch_size != bsz)
|
113 |
+
|
114 |
+
if need_capture:
|
115 |
+
self._cg_graph = None
|
116 |
+
|
117 |
+
self._cg_batch_size = bsz
|
118 |
+
self._cg_inference_params = inference_params
|
119 |
+
self._cg_scale = cfg_scale
|
120 |
+
|
121 |
+
for _ in range(3):
|
122 |
+
hidden_states = self.embed_codes(input_ids)
|
123 |
+
hidden_states = hidden_states.repeat(2, 1, 1) # because cfg != 1.0
|
124 |
+
logits = self._compute_logits(hidden_states, inference_params, cfg_scale)
|
125 |
+
|
126 |
+
self._cg_input_ids = input_ids.clone()
|
127 |
+
self._cg_logits = torch.empty_like(logits)
|
128 |
+
|
129 |
+
g = torch.cuda.CUDAGraph()
|
130 |
+
|
131 |
+
def capture_region():
|
132 |
+
hidden_states_local = self.embed_codes(self._cg_input_ids)
|
133 |
+
hidden_states_local = hidden_states_local.repeat(2, 1, 1)
|
134 |
+
self._cg_logits = self._compute_logits(hidden_states_local, self._cg_inference_params, self._cg_scale)
|
135 |
+
|
136 |
+
with torch.cuda.graph(g):
|
137 |
+
capture_region()
|
138 |
+
|
139 |
+
self._cg_graph = g
|
140 |
+
|
141 |
+
else:
|
142 |
+
self._cg_input_ids.copy_(input_ids)
|
143 |
+
|
144 |
+
self._cg_graph.replay()
|
145 |
+
|
146 |
+
return self._cg_logits
|
147 |
+
|
148 |
+
def _prefill(
|
149 |
+
self,
|
150 |
+
prefix_hidden_states: torch.Tensor,
|
151 |
+
input_ids: torch.Tensor,
|
152 |
+
inference_params: InferenceParams,
|
153 |
+
cfg_scale: float,
|
154 |
+
) -> torch.Tensor:
|
155 |
+
"""
|
156 |
+
"Prefill" mode: we already have `prefix_hidden_states`, and we want
|
157 |
+
to append new embeddings, then compute the logits.
|
158 |
+
"""
|
159 |
+
# Replicate input_ids if CFG is enabled
|
160 |
+
if cfg_scale != 1.0:
|
161 |
+
input_ids = input_ids.expand(prefix_hidden_states.shape[0], -1, -1)
|
162 |
+
hidden_states = torch.cat([prefix_hidden_states, self.embed_codes(input_ids)], dim=1)
|
163 |
+
return self._compute_logits(hidden_states, inference_params, cfg_scale)
|
164 |
+
|
165 |
+
def setup_cache(self, batch_size: int, max_seqlen: int, dtype: torch.dtype = torch.bfloat16) -> InferenceParams:
|
166 |
+
key_value_memory_dict = {
|
167 |
+
i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
|
168 |
+
for i, layer in enumerate(self.backbone.layers)
|
169 |
+
}
|
170 |
+
lengths_per_sample = torch.full((batch_size,), 0, dtype=torch.int32, device="cuda")
|
171 |
+
return InferenceParams(max_seqlen, batch_size, 0, 0, key_value_memory_dict, lengths_per_sample)
|
172 |
+
|
173 |
+
def prepare_conditioning(self, cond_dict: dict, uncond_dict: dict | None = None) -> torch.Tensor:
|
174 |
+
if uncond_dict is None:
|
175 |
+
uncond_dict = {k: cond_dict[k] for k in self.prefix_conditioner.required_keys}
|
176 |
+
return torch.cat(
|
177 |
+
[
|
178 |
+
self.prefix_conditioner(cond_dict),
|
179 |
+
self.prefix_conditioner(uncond_dict),
|
180 |
+
]
|
181 |
+
)
|
182 |
+
|
183 |
+
@torch.inference_mode()
|
184 |
+
def generate(
|
185 |
+
self,
|
186 |
+
prefix_conditioning: torch.Tensor, # [bsz, cond_seq_len, d_model]
|
187 |
+
audio_prefix_codes: torch.Tensor | None = None, # [bsz, 9, prefix_audio_seq_len]
|
188 |
+
max_new_tokens: int = 86 * 30,
|
189 |
+
cfg_scale: float = 2.0,
|
190 |
+
batch_size: int = 1,
|
191 |
+
sampling_params: dict = dict(min_p=0.1),
|
192 |
+
progress_bar: bool = True,
|
193 |
+
callback: Callable[[torch.Tensor, int, int], bool] | None = None,
|
194 |
+
):
|
195 |
+
assert cfg_scale != 1, "TODO: add support for cfg_scale=1"
|
196 |
+
prefix_audio_len = 0 if audio_prefix_codes is None else audio_prefix_codes.shape[2]
|
197 |
+
|
198 |
+
unknown_token = -1
|
199 |
+
audio_seq_len = prefix_audio_len + max_new_tokens
|
200 |
+
seq_len = prefix_conditioning.shape[1] + audio_seq_len
|
201 |
+
|
202 |
+
inference_params = self.setup_cache(batch_size=batch_size * 2, max_seqlen=seq_len)
|
203 |
+
|
204 |
+
codes = torch.full((batch_size, 9, audio_seq_len), unknown_token, device="cuda")
|
205 |
+
if audio_prefix_codes is not None:
|
206 |
+
codes[..., :prefix_audio_len] = audio_prefix_codes
|
207 |
+
|
208 |
+
delayed_codes = apply_delay_pattern(codes, self.masked_token_id)
|
209 |
+
|
210 |
+
delayed_prefix_audio_codes = delayed_codes[..., : prefix_audio_len + 1]
|
211 |
+
|
212 |
+
logits = self._prefill(prefix_conditioning, delayed_prefix_audio_codes, inference_params, cfg_scale)
|
213 |
+
next_token = sample_from_logits(logits, **sampling_params)
|
214 |
+
|
215 |
+
offset = delayed_prefix_audio_codes.shape[2]
|
216 |
+
frame = delayed_codes[..., offset : offset + 1]
|
217 |
+
frame.masked_scatter_(frame == unknown_token, next_token)
|
218 |
+
|
219 |
+
prefix_length = prefix_conditioning.shape[1] + prefix_audio_len + 1
|
220 |
+
inference_params.seqlen_offset += prefix_length
|
221 |
+
inference_params.lengths_per_sample[:] += prefix_length
|
222 |
+
|
223 |
+
logit_bias = torch.zeros_like(logits)
|
224 |
+
logit_bias[:, 1:, self.eos_token_id] = -torch.inf # only allow codebook 0 to predict EOS
|
225 |
+
|
226 |
+
stopping = torch.zeros(batch_size, dtype=torch.bool, device="cuda")
|
227 |
+
max_steps = delayed_codes.shape[2] - offset
|
228 |
+
remaining_steps = torch.full((batch_size,), max_steps, device="cuda")
|
229 |
+
progress = tqdm(total=max_steps, desc="Generating", disable=not progress_bar)
|
230 |
+
|
231 |
+
step = 0
|
232 |
+
while torch.max(remaining_steps) > 0:
|
233 |
+
offset += 1
|
234 |
+
input_ids = delayed_codes[..., offset - 1 : offset]
|
235 |
+
logits = self._decode_one_token(input_ids, inference_params, cfg_scale)
|
236 |
+
|
237 |
+
next_token = sample_from_logits(logits, generated_tokens=delayed_codes[..., :offset], **sampling_params)
|
238 |
+
eos_in_cb0 = next_token[:, 0] == self.eos_token_id
|
239 |
+
|
240 |
+
remaining_steps[eos_in_cb0[:, 0]] = torch.minimum(remaining_steps[eos_in_cb0[:, 0]], torch.tensor(9))
|
241 |
+
stopping |= eos_in_cb0[:, 0]
|
242 |
+
|
243 |
+
eos_codebook_idx = 9 - remaining_steps
|
244 |
+
eos_codebook_idx = torch.clamp(eos_codebook_idx, max=9 - 1)
|
245 |
+
for i in range(next_token.shape[0]):
|
246 |
+
if stopping[i]:
|
247 |
+
idx = eos_codebook_idx[i].item()
|
248 |
+
next_token[i, :idx] = self.masked_token_id
|
249 |
+
next_token[i, idx] = self.eos_token_id
|
250 |
+
|
251 |
+
frame = delayed_codes[..., offset : offset + 1]
|
252 |
+
frame.masked_scatter_(frame == unknown_token, next_token)
|
253 |
+
inference_params.seqlen_offset += 1
|
254 |
+
inference_params.lengths_per_sample[:] += 1
|
255 |
+
|
256 |
+
remaining_steps -= 1
|
257 |
+
|
258 |
+
progress.update()
|
259 |
+
step += 1
|
260 |
+
|
261 |
+
if callback is not None and not callback(frame, step, max_steps):
|
262 |
+
break
|
263 |
+
|
264 |
+
out_codes = revert_delay_pattern(delayed_codes)
|
265 |
+
out_codes.masked_fill_(out_codes >= 1024, 0)
|
266 |
+
out_codes = out_codes[..., : offset - 9]
|
267 |
+
|
268 |
+
self._cg_graph = None # reset cuda graph to avoid cache changes
|
269 |
+
|
270 |
+
return out_codes
|
zonos/sampling.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
|
5 |
+
"""torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
|
6 |
+
|
7 |
+
Args:
|
8 |
+
input (torch.Tensor): The input tensor containing probabilities.
|
9 |
+
num_samples (int): Number of samples to draw.
|
10 |
+
replacement (bool): Whether to draw with replacement or not.
|
11 |
+
Keywords args:
|
12 |
+
generator (torch.Generator): A pseudorandom number generator for sampling.
|
13 |
+
Returns:
|
14 |
+
torch.Tensor: Last dimension contains num_samples indices
|
15 |
+
sampled from the multinomial probability distribution
|
16 |
+
located in the last dimension of tensor input.
|
17 |
+
"""
|
18 |
+
|
19 |
+
if num_samples == 1:
|
20 |
+
q = torch.empty_like(input).exponential_(1, generator=generator)
|
21 |
+
return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64)
|
22 |
+
|
23 |
+
input_ = input.reshape(-1, input.shape[-1])
|
24 |
+
output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
|
25 |
+
output = output_.reshape(*list(input.shape[:-1]), -1)
|
26 |
+
return output
|
27 |
+
|
28 |
+
|
29 |
+
def apply_top_k(
|
30 |
+
probs: torch.Tensor,
|
31 |
+
k: int,
|
32 |
+
) -> torch.Tensor:
|
33 |
+
"""Sample next token from top K values along the last dimension of the input probs tensor.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
|
37 |
+
k (int): The k in “top-k”.
|
38 |
+
Returns:
|
39 |
+
torch.Tensor: Sampled tokens.
|
40 |
+
"""
|
41 |
+
v, _ = torch.topk(probs, min(k, probs.size(-1)))
|
42 |
+
pivot = v.select(-1, -1).unsqueeze(-1)
|
43 |
+
probs = torch.where(probs < pivot, 0.0, probs)
|
44 |
+
probs.div_(probs.sum(dim=-1, keepdim=True))
|
45 |
+
return probs
|
46 |
+
|
47 |
+
|
48 |
+
def apply_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
|
49 |
+
"""Sample next token from top P probabilities along the last dimension of the input probs tensor.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
|
53 |
+
p (int): The p in “top-p”.
|
54 |
+
Returns:
|
55 |
+
torch.Tensor: Sampled tokens.
|
56 |
+
"""
|
57 |
+
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
58 |
+
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
59 |
+
mask = probs_sum - probs_sort > p
|
60 |
+
probs_sort *= (~mask).float()
|
61 |
+
probs = probs.scatter(-1, probs_idx, probs_sort)
|
62 |
+
probs.div_(probs.sum(dim=-1, keepdim=True))
|
63 |
+
return probs
|
64 |
+
|
65 |
+
|
66 |
+
def apply_min_p(probs: torch.Tensor, min_p: float) -> torch.Tensor:
|
67 |
+
"""Sample next token using min-p sampling.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
scores (torch.FloatTensor): Input logits with token candidates on the last dimension.
|
71 |
+
min_p (float): Minimum token probability, scaled by the probability of the most likely token.
|
72 |
+
Must be between 0 and 1. Typical values are in the 0.01-0.2 range.
|
73 |
+
Returns:
|
74 |
+
torch.Tensor: Sampled tokens.
|
75 |
+
"""
|
76 |
+
top_probs, _ = probs.max(dim=-1, keepdim=True)
|
77 |
+
tokens_to_remove = probs < (min_p * top_probs)
|
78 |
+
probs = probs.masked_fill(tokens_to_remove, 0.0)
|
79 |
+
probs.div_(probs.sum(dim=-1, keepdim=True))
|
80 |
+
return probs
|
81 |
+
|
82 |
+
|
83 |
+
def modify_logit_for_repetition_penalty(
|
84 |
+
logits: torch.Tensor,
|
85 |
+
generated_tokens: torch.Tensor,
|
86 |
+
repetition_penalty: float,
|
87 |
+
repetition_penalty_window: int,
|
88 |
+
):
|
89 |
+
"""See https://arxiv.org/abs/1909.05858
|
90 |
+
Apply repetition penalty over a sliding window of the last `repetition_penalty_window` tokens.
|
91 |
+
logits: (batch_size, n_codebooks, vocab_size)
|
92 |
+
generated_tokens: (batch_size, n_codebooks, seq_len)
|
93 |
+
"""
|
94 |
+
generated_tokens = generated_tokens[..., -repetition_penalty_window:]
|
95 |
+
generated_tokens = generated_tokens.clamp_max(logits.shape[-1] - 1).to(torch.int64)
|
96 |
+
rp = torch.full_like(logits, repetition_penalty)
|
97 |
+
factors = torch.ones_like(logits).scatter_reduce(2, generated_tokens, rp, reduce="prod")
|
98 |
+
return torch.where(logits <= 0, logits * factors, logits / factors)
|
99 |
+
|
100 |
+
|
101 |
+
def sample_from_logits(
|
102 |
+
logits: torch.Tensor,
|
103 |
+
temperature: float = 1.0,
|
104 |
+
top_p: float = 0.0,
|
105 |
+
top_k: int = 0,
|
106 |
+
min_p: float = 0.0,
|
107 |
+
generated_tokens: torch.Tensor | None = None,
|
108 |
+
repetition_penalty: float = 3.0,
|
109 |
+
repetition_penalty_window: float = 2,
|
110 |
+
) -> torch.Tensor:
|
111 |
+
"""Sample next token from logits using temperature, top-p, top-k, or min-p sampling.
|
112 |
+
|
113 |
+
Args:
|
114 |
+
logits (torch.Tensor): Input logits with token candidates on the last dimension.
|
115 |
+
temperature (float): Sampling temperature. Lower temperature results in more deterministic samples.
|
116 |
+
top_p (float): The p in “top-p”.
|
117 |
+
top_k (int): The k in “top-k”.
|
118 |
+
min_p (float): Minimum token probability, scaled by the probability of the most likely token.
|
119 |
+
Must be between 0 and 1. Typical values are in the 0.01-0.2 range.
|
120 |
+
|
121 |
+
Returns:
|
122 |
+
torch.Tensor: Sampled tokens.
|
123 |
+
"""
|
124 |
+
if repetition_penalty != 1.0 and generated_tokens is not None:
|
125 |
+
logits = modify_logit_for_repetition_penalty(logits, generated_tokens, repetition_penalty, repetition_penalty_window)
|
126 |
+
|
127 |
+
if temperature > 0:
|
128 |
+
probs = torch.softmax(logits / temperature, dim=-1)
|
129 |
+
|
130 |
+
if top_p > 0:
|
131 |
+
probs = apply_top_p(probs, top_p)
|
132 |
+
if top_k > 0:
|
133 |
+
probs = apply_top_k(probs, top_k)
|
134 |
+
if min_p > 0:
|
135 |
+
probs = apply_min_p(probs, min_p)
|
136 |
+
|
137 |
+
next_token = multinomial(probs, num_samples=1)
|
138 |
+
else:
|
139 |
+
next_token = torch.argmax(logits, dim=-1, keepdim=True)
|
140 |
+
|
141 |
+
return next_token # [batch_size, num_codebooks, 1]
|
zonos/speaker_cloning.py
ADDED
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from functools import cache
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torchaudio
|
8 |
+
from huggingface_hub import hf_hub_download
|
9 |
+
import os
|
10 |
+
|
11 |
+
|
12 |
+
class logFbankCal(nn.Module):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
sample_rate: int = 16_000,
|
16 |
+
n_fft: int = 512,
|
17 |
+
win_length: float = 0.025,
|
18 |
+
hop_length: float = 0.01,
|
19 |
+
n_mels: int = 80,
|
20 |
+
):
|
21 |
+
super().__init__()
|
22 |
+
self.fbankCal = torchaudio.transforms.MelSpectrogram(
|
23 |
+
sample_rate=sample_rate,
|
24 |
+
n_fft=n_fft,
|
25 |
+
win_length=int(win_length * sample_rate),
|
26 |
+
hop_length=int(hop_length * sample_rate),
|
27 |
+
n_mels=n_mels,
|
28 |
+
)
|
29 |
+
|
30 |
+
def forward(self, x):
|
31 |
+
out = self.fbankCal(x)
|
32 |
+
out = torch.log(out + 1e-6)
|
33 |
+
out = out - out.mean(axis=2).unsqueeze(dim=2)
|
34 |
+
return out
|
35 |
+
|
36 |
+
|
37 |
+
class ASP(nn.Module):
|
38 |
+
# Attentive statistics pooling
|
39 |
+
def __init__(self, in_planes, acoustic_dim):
|
40 |
+
super(ASP, self).__init__()
|
41 |
+
outmap_size = int(acoustic_dim / 8)
|
42 |
+
self.out_dim = in_planes * 8 * outmap_size * 2
|
43 |
+
|
44 |
+
self.attention = nn.Sequential(
|
45 |
+
nn.Conv1d(in_planes * 8 * outmap_size, 128, kernel_size=1),
|
46 |
+
nn.ReLU(),
|
47 |
+
nn.BatchNorm1d(128),
|
48 |
+
nn.Conv1d(128, in_planes * 8 * outmap_size, kernel_size=1),
|
49 |
+
nn.Softmax(dim=2),
|
50 |
+
)
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
x = x.reshape(x.size()[0], -1, x.size()[-1])
|
54 |
+
w = self.attention(x)
|
55 |
+
mu = torch.sum(x * w, dim=2)
|
56 |
+
sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))
|
57 |
+
x = torch.cat((mu, sg), 1)
|
58 |
+
|
59 |
+
x = x.view(x.size()[0], -1)
|
60 |
+
return x
|
61 |
+
|
62 |
+
|
63 |
+
class SimAMBasicBlock(nn.Module):
|
64 |
+
expansion = 1
|
65 |
+
|
66 |
+
def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1):
|
67 |
+
super(SimAMBasicBlock, self).__init__()
|
68 |
+
self.conv1 = ConvLayer(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
69 |
+
self.bn1 = NormLayer(planes)
|
70 |
+
self.conv2 = ConvLayer(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
71 |
+
self.bn2 = NormLayer(planes)
|
72 |
+
self.relu = nn.ReLU(inplace=True)
|
73 |
+
self.sigmoid = nn.Sigmoid()
|
74 |
+
|
75 |
+
self.downsample = nn.Sequential()
|
76 |
+
if stride != 1 or in_planes != self.expansion * planes:
|
77 |
+
self.downsample = nn.Sequential(
|
78 |
+
ConvLayer(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
79 |
+
NormLayer(self.expansion * planes),
|
80 |
+
)
|
81 |
+
|
82 |
+
def forward(self, x):
|
83 |
+
out = self.relu(self.bn1(self.conv1(x)))
|
84 |
+
out = self.bn2(self.conv2(out))
|
85 |
+
out = self.SimAM(out)
|
86 |
+
out += self.downsample(x)
|
87 |
+
out = self.relu(out)
|
88 |
+
return out
|
89 |
+
|
90 |
+
def SimAM(self, X, lambda_p=1e-4):
|
91 |
+
n = X.shape[2] * X.shape[3] - 1
|
92 |
+
d = (X - X.mean(dim=[2, 3], keepdim=True)).pow(2)
|
93 |
+
v = d.sum(dim=[2, 3], keepdim=True) / n
|
94 |
+
E_inv = d / (4 * (v + lambda_p)) + 0.5
|
95 |
+
return X * self.sigmoid(E_inv)
|
96 |
+
|
97 |
+
|
98 |
+
class BasicBlock(nn.Module):
|
99 |
+
expansion = 1
|
100 |
+
|
101 |
+
def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1):
|
102 |
+
super(BasicBlock, self).__init__()
|
103 |
+
self.conv1 = ConvLayer(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
104 |
+
self.bn1 = NormLayer(planes)
|
105 |
+
self.conv2 = ConvLayer(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
106 |
+
self.bn2 = NormLayer(planes)
|
107 |
+
self.relu = nn.ReLU(inplace=True)
|
108 |
+
|
109 |
+
self.downsample = nn.Sequential()
|
110 |
+
if stride != 1 or in_planes != self.expansion * planes:
|
111 |
+
self.downsample = nn.Sequential(
|
112 |
+
ConvLayer(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
113 |
+
NormLayer(self.expansion * planes),
|
114 |
+
)
|
115 |
+
|
116 |
+
def forward(self, x):
|
117 |
+
out = self.relu(self.bn1(self.conv1(x)))
|
118 |
+
out = self.bn2(self.conv2(out))
|
119 |
+
out += self.downsample(x)
|
120 |
+
out = self.relu(out)
|
121 |
+
return out
|
122 |
+
|
123 |
+
|
124 |
+
class Bottleneck(nn.Module):
|
125 |
+
expansion = 4
|
126 |
+
|
127 |
+
def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1):
|
128 |
+
super(Bottleneck, self).__init__()
|
129 |
+
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
|
130 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
131 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
132 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
133 |
+
self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
|
134 |
+
self.bn3 = nn.BatchNorm2d(self.expansion * planes)
|
135 |
+
|
136 |
+
self.shortcut = nn.Sequential()
|
137 |
+
if stride != 1 or in_planes != self.expansion * planes:
|
138 |
+
self.shortcut = nn.Sequential(
|
139 |
+
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
140 |
+
nn.BatchNorm2d(self.expansion * planes),
|
141 |
+
)
|
142 |
+
|
143 |
+
def forward(self, x):
|
144 |
+
out = F.relu(self.bn1(self.conv1(x)))
|
145 |
+
out = F.relu(self.bn2(self.conv2(out)))
|
146 |
+
out = self.bn3(self.conv3(out))
|
147 |
+
out += self.shortcut(x)
|
148 |
+
out = F.relu(out)
|
149 |
+
return out
|
150 |
+
|
151 |
+
|
152 |
+
class ResNet(nn.Module):
|
153 |
+
def __init__(self, in_planes, block, num_blocks, in_ch=1, feat_dim="2d", **kwargs):
|
154 |
+
super(ResNet, self).__init__()
|
155 |
+
if feat_dim == "1d":
|
156 |
+
self.NormLayer = nn.BatchNorm1d
|
157 |
+
self.ConvLayer = nn.Conv1d
|
158 |
+
elif feat_dim == "2d":
|
159 |
+
self.NormLayer = nn.BatchNorm2d
|
160 |
+
self.ConvLayer = nn.Conv2d
|
161 |
+
elif feat_dim == "3d":
|
162 |
+
self.NormLayer = nn.BatchNorm3d
|
163 |
+
self.ConvLayer = nn.Conv3d
|
164 |
+
else:
|
165 |
+
print("error")
|
166 |
+
|
167 |
+
self.in_planes = in_planes
|
168 |
+
|
169 |
+
self.conv1 = self.ConvLayer(in_ch, in_planes, kernel_size=3, stride=1, padding=1, bias=False)
|
170 |
+
self.bn1 = self.NormLayer(in_planes)
|
171 |
+
self.relu = nn.ReLU(inplace=True)
|
172 |
+
self.layer1 = self._make_layer(block, in_planes, num_blocks[0], stride=1, block_id=1)
|
173 |
+
self.layer2 = self._make_layer(block, in_planes * 2, num_blocks[1], stride=2, block_id=2)
|
174 |
+
self.layer3 = self._make_layer(block, in_planes * 4, num_blocks[2], stride=2, block_id=3)
|
175 |
+
self.layer4 = self._make_layer(block, in_planes * 8, num_blocks[3], stride=2, block_id=4)
|
176 |
+
|
177 |
+
def _make_layer(self, block, planes, num_blocks, stride, block_id=1):
|
178 |
+
strides = [stride] + [1] * (num_blocks - 1)
|
179 |
+
layers = []
|
180 |
+
for stride in strides:
|
181 |
+
layers.append(block(self.ConvLayer, self.NormLayer, self.in_planes, planes, stride, block_id))
|
182 |
+
self.in_planes = planes * block.expansion
|
183 |
+
return nn.Sequential(*layers)
|
184 |
+
|
185 |
+
def forward(self, x):
|
186 |
+
x = self.relu(self.bn1(self.conv1(x)))
|
187 |
+
x = self.layer1(x)
|
188 |
+
x = self.layer2(x)
|
189 |
+
x = self.layer3(x)
|
190 |
+
x = self.layer4(x)
|
191 |
+
return x
|
192 |
+
|
193 |
+
|
194 |
+
def ResNet293(in_planes: int, **kwargs):
|
195 |
+
return ResNet(in_planes, SimAMBasicBlock, [10, 20, 64, 3], **kwargs)
|
196 |
+
|
197 |
+
|
198 |
+
class ResNet293_based(nn.Module):
|
199 |
+
def __init__(
|
200 |
+
self,
|
201 |
+
in_planes: int = 64,
|
202 |
+
embd_dim: int = 256,
|
203 |
+
acoustic_dim: int = 80,
|
204 |
+
featCal=None,
|
205 |
+
dropout: float = 0,
|
206 |
+
**kwargs,
|
207 |
+
):
|
208 |
+
super(ResNet293_based, self).__init__()
|
209 |
+
self.featCal = featCal
|
210 |
+
self.front = ResNet293(in_planes)
|
211 |
+
block_expansion = SimAMBasicBlock.expansion
|
212 |
+
self.pooling = ASP(in_planes * block_expansion, acoustic_dim)
|
213 |
+
self.bottleneck = nn.Linear(self.pooling.out_dim, embd_dim)
|
214 |
+
self.drop = nn.Dropout(dropout) if dropout else None
|
215 |
+
|
216 |
+
def forward(self, x):
|
217 |
+
x = self.featCal(x)
|
218 |
+
x = self.front(x.unsqueeze(dim=1))
|
219 |
+
x = self.pooling(x)
|
220 |
+
if self.drop:
|
221 |
+
x = self.drop(x)
|
222 |
+
x = self.bottleneck(x)
|
223 |
+
return x
|
224 |
+
|
225 |
+
|
226 |
+
class SEModule(nn.Module):
|
227 |
+
def __init__(self, channels, bottleneck=128):
|
228 |
+
super(SEModule, self).__init__()
|
229 |
+
self.se = nn.Sequential(
|
230 |
+
nn.AdaptiveAvgPool1d(1),
|
231 |
+
nn.Conv1d(channels, bottleneck, kernel_size=1, padding=0),
|
232 |
+
nn.ReLU(),
|
233 |
+
# nn.BatchNorm1d(bottleneck), # Removed
|
234 |
+
nn.Conv1d(bottleneck, channels, kernel_size=1, padding=0),
|
235 |
+
nn.Sigmoid(),
|
236 |
+
)
|
237 |
+
|
238 |
+
def forward(self, input):
|
239 |
+
x = self.se(input)
|
240 |
+
return input * x
|
241 |
+
|
242 |
+
|
243 |
+
class Bottle2neck(nn.Module):
|
244 |
+
def __init__(self, inplanes, planes, kernel_size=None, dilation=None, scale=8):
|
245 |
+
super(Bottle2neck, self).__init__()
|
246 |
+
width = int(math.floor(planes / scale))
|
247 |
+
self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
|
248 |
+
self.bn1 = nn.BatchNorm1d(width * scale)
|
249 |
+
self.nums = scale - 1
|
250 |
+
convs = []
|
251 |
+
bns = []
|
252 |
+
num_pad = math.floor(kernel_size / 2) * dilation
|
253 |
+
for i in range(self.nums):
|
254 |
+
convs.append(nn.Conv1d(width, width, kernel_size=kernel_size, dilation=dilation, padding=num_pad))
|
255 |
+
bns.append(nn.BatchNorm1d(width))
|
256 |
+
self.convs = nn.ModuleList(convs)
|
257 |
+
self.bns = nn.ModuleList(bns)
|
258 |
+
self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
|
259 |
+
self.bn3 = nn.BatchNorm1d(planes)
|
260 |
+
self.relu = nn.ReLU()
|
261 |
+
self.width = width
|
262 |
+
self.se = SEModule(planes)
|
263 |
+
|
264 |
+
def forward(self, x):
|
265 |
+
residual = x
|
266 |
+
out = self.conv1(x)
|
267 |
+
out = self.relu(out)
|
268 |
+
out = self.bn1(out)
|
269 |
+
|
270 |
+
spx = torch.split(out, self.width, 1)
|
271 |
+
for i in range(self.nums):
|
272 |
+
if i == 0:
|
273 |
+
sp = spx[i]
|
274 |
+
else:
|
275 |
+
sp = sp + spx[i]
|
276 |
+
sp = self.convs[i](sp)
|
277 |
+
sp = self.relu(sp)
|
278 |
+
sp = self.bns[i](sp)
|
279 |
+
if i == 0:
|
280 |
+
out = sp
|
281 |
+
else:
|
282 |
+
out = torch.cat((out, sp), 1)
|
283 |
+
out = torch.cat((out, spx[self.nums]), 1)
|
284 |
+
|
285 |
+
out = self.conv3(out)
|
286 |
+
out = self.relu(out)
|
287 |
+
out = self.bn3(out)
|
288 |
+
|
289 |
+
out = self.se(out)
|
290 |
+
out += residual
|
291 |
+
return out
|
292 |
+
|
293 |
+
|
294 |
+
class ECAPA_TDNN(nn.Module):
|
295 |
+
def __init__(self, C, featCal):
|
296 |
+
super(ECAPA_TDNN, self).__init__()
|
297 |
+
self.featCal = featCal
|
298 |
+
self.conv1 = nn.Conv1d(80, C, kernel_size=5, stride=1, padding=2)
|
299 |
+
self.relu = nn.ReLU()
|
300 |
+
self.bn1 = nn.BatchNorm1d(C)
|
301 |
+
self.layer1 = Bottle2neck(C, C, kernel_size=3, dilation=2, scale=8)
|
302 |
+
self.layer2 = Bottle2neck(C, C, kernel_size=3, dilation=3, scale=8)
|
303 |
+
self.layer3 = Bottle2neck(C, C, kernel_size=3, dilation=4, scale=8)
|
304 |
+
# I fixed the shape of the output from MFA layer, that is close to the setting from ECAPA paper.
|
305 |
+
self.layer4 = nn.Conv1d(3 * C, 1536, kernel_size=1)
|
306 |
+
self.attention = nn.Sequential(
|
307 |
+
nn.Conv1d(4608, 256, kernel_size=1),
|
308 |
+
nn.ReLU(),
|
309 |
+
nn.BatchNorm1d(256),
|
310 |
+
nn.Tanh(), # Added
|
311 |
+
nn.Conv1d(256, 1536, kernel_size=1),
|
312 |
+
nn.Softmax(dim=2),
|
313 |
+
)
|
314 |
+
self.bn5 = nn.BatchNorm1d(3072)
|
315 |
+
self.fc6 = nn.Linear(3072, 192)
|
316 |
+
self.bn6 = nn.BatchNorm1d(192)
|
317 |
+
|
318 |
+
def forward(self, x):
|
319 |
+
x = self.featCal(x)
|
320 |
+
x = self.conv1(x)
|
321 |
+
x = self.relu(x)
|
322 |
+
x = self.bn1(x)
|
323 |
+
|
324 |
+
x1 = self.layer1(x)
|
325 |
+
x2 = self.layer2(x + x1)
|
326 |
+
x3 = self.layer3(x + x1 + x2)
|
327 |
+
|
328 |
+
x = self.layer4(torch.cat((x1, x2, x3), dim=1))
|
329 |
+
x = self.relu(x)
|
330 |
+
|
331 |
+
t = x.size()[-1]
|
332 |
+
|
333 |
+
global_x = torch.cat(
|
334 |
+
(
|
335 |
+
x,
|
336 |
+
torch.mean(x, dim=2, keepdim=True).repeat(1, 1, t),
|
337 |
+
torch.sqrt(torch.var(x, dim=2, keepdim=True).clamp(min=1e-4)).repeat(1, 1, t),
|
338 |
+
),
|
339 |
+
dim=1,
|
340 |
+
)
|
341 |
+
|
342 |
+
w = self.attention(global_x)
|
343 |
+
|
344 |
+
mu = torch.sum(x * w, dim=2)
|
345 |
+
sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-4))
|
346 |
+
|
347 |
+
x = torch.cat((mu, sg), 1)
|
348 |
+
x = self.bn5(x)
|
349 |
+
x = self.fc6(x)
|
350 |
+
x = self.bn6(x)
|
351 |
+
|
352 |
+
return x
|
353 |
+
|
354 |
+
|
355 |
+
class SpeakerEmbedding(nn.Module):
|
356 |
+
def __init__(self, ckpt_path: str = "ResNet293_SimAM_ASP_base.pt", device: str = "cuda"):
|
357 |
+
super().__init__()
|
358 |
+
self.device = device
|
359 |
+
with torch.device(device):
|
360 |
+
self.model = ResNet293_based()
|
361 |
+
self.model.load_state_dict(torch.load(ckpt_path, weights_only=True, mmap=True))
|
362 |
+
self.model.featCal = logFbankCal()
|
363 |
+
|
364 |
+
self.requires_grad_(False).eval()
|
365 |
+
|
366 |
+
@property
|
367 |
+
def dtype(self):
|
368 |
+
return next(self.parameters()).dtype
|
369 |
+
|
370 |
+
@cache
|
371 |
+
def _get_resampler(self, orig_sample_rate: int):
|
372 |
+
return torchaudio.transforms.Resample(orig_sample_rate, 16_000).to(self.device)
|
373 |
+
|
374 |
+
def prepare_input(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
|
375 |
+
assert wav.ndim < 3
|
376 |
+
if wav.ndim == 2:
|
377 |
+
wav = wav.mean(0, keepdim=True)
|
378 |
+
wav = self._get_resampler(sample_rate)(wav)
|
379 |
+
return wav
|
380 |
+
|
381 |
+
def forward(self, wav: torch.Tensor, sample_rate: int):
|
382 |
+
wav = self.prepare_input(wav, sample_rate).to(self.device, self.dtype)
|
383 |
+
return self.model(wav).to(wav.device)
|
384 |
+
|
385 |
+
class SpeakerEmbeddingLDA(nn.Module):
|
386 |
+
def __init__(
|
387 |
+
self,
|
388 |
+
device: str = "cuda",
|
389 |
+
):
|
390 |
+
super().__init__()
|
391 |
+
spk_model_path = hf_hub_download(repo_id="Zyphra/Zonos-v0.1-speaker-embedding", filename="ResNet293_SimAM_ASP_base.pt")
|
392 |
+
lda_spk_model_path = hf_hub_download(repo_id="Zyphra/Zonos-v0.1-speaker-embedding", filename="ResNet293_SimAM_ASP_base_LDA-128.pt")
|
393 |
+
|
394 |
+
self.device = device
|
395 |
+
with torch.device(device):
|
396 |
+
self.model = SpeakerEmbedding(spk_model_path, device)
|
397 |
+
lda_sd = torch.load(lda_spk_model_path, weights_only=True)
|
398 |
+
out_features, in_features = lda_sd["weight"].shape
|
399 |
+
self.lda = nn.Linear(in_features, out_features, bias=True, dtype=torch.float32)
|
400 |
+
self.lda.load_state_dict(lda_sd)
|
401 |
+
|
402 |
+
self.requires_grad_(False).eval()
|
403 |
+
|
404 |
+
def forward(self, wav: torch.Tensor, sample_rate: int):
|
405 |
+
emb = self.model(wav, sample_rate).to(torch.float32)
|
406 |
+
return emb, self.lda(emb)
|