Karl El Hajal committed
Commit a180d8c · 1 Parent(s): f920a6c

Add kNN-TTS code and demo interface

.gitignore ADDED
@@ -0,0 +1,169 @@
1
+ # SPDX-FileCopyrightText: Github
2
+ #
3
+ # SPDX-License-Identifier: CC0-1.0
4
+
5
+ # Source:
6
+ # https://github.com/github/gitignore/blob/8779ee73af62c669e7ca371aaab8399d87127693/Python.gitignore
7
+
8
+ # Byte-compiled / optimized / DLL files
9
+ __pycache__/
10
+ *.py[cod]
11
+ *$py.class
12
+
13
+ # C extensions
14
+ *.so
15
+
16
+ # Distribution / packaging
17
+ .Python
18
+ build/
19
+ develop-eggs/
20
+ dist/
21
+ downloads/
22
+ eggs/
23
+ .eggs/
24
+ lib/
25
+ lib64/
26
+ parts/
27
+ sdist/
28
+ var/
29
+ wheels/
30
+ share/python-wheels/
31
+ *.egg-info/
32
+ .installed.cfg
33
+ *.egg
34
+ MANIFEST
35
+
36
+ # PyInstaller
37
+ # Usually these files are written by a python script from a template
38
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
39
+ *.manifest
40
+ *.spec
41
+
42
+ # Installer logs
43
+ pip-log.txt
44
+ pip-delete-this-directory.txt
45
+
46
+ # Unit test / coverage reports
47
+ htmlcov/
48
+ .tox/
49
+ .nox/
50
+ .coverage
51
+ .coverage.*
52
+ .cache
53
+ nosetests.xml
54
+ coverage.xml
55
+ *.cover
56
+ *.py,cover
57
+ .hypothesis/
58
+ .pytest_cache/
59
+ cover/
60
+
61
+ # Translations
62
+ *.mo
63
+ *.pot
64
+
65
+ # Django stuff:
66
+ *.log
67
+ local_settings.py
68
+ db.sqlite3
69
+ db.sqlite3-journal
70
+
71
+ # Flask stuff:
72
+ instance/
73
+ .webassets-cache
74
+
75
+ # Scrapy stuff:
76
+ .scrapy
77
+
78
+ # Sphinx documentation
79
+ docs/_build/
80
+
81
+ # PyBuilder
82
+ .pybuilder/
83
+ target/
84
+
85
+ # Jupyter Notebook
86
+ .ipynb_checkpoints
87
+
88
+ # IPython
89
+ profile_default/
90
+ ipython_config.py
91
+
92
+ # pyenv
93
+ # For a library or package, you might want to ignore these files since the code is
94
+ # intended to run in multiple environments; otherwise, check them in:
95
+ # .python-version
96
+
97
+ # pipenv
98
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
100
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
101
+ # install all needed dependencies.
102
+ #Pipfile.lock
103
+
104
+ # poetry
105
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
106
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
107
+ # commonly ignored for libraries.
108
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
109
+ #poetry.lock
110
+
111
+ # pdm
112
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113
+ #pdm.lock
114
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
115
+ # in version control.
116
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
117
+ .pdm.toml
118
+ .pdm-python
119
+ .pdm-build/
120
+
121
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
122
+ __pypackages__/
123
+
124
+ # Celery stuff
125
+ celerybeat-schedule
126
+ celerybeat.pid
127
+
128
+ # SageMath parsed files
129
+ *.sage.py
130
+
131
+ # Environments
132
+ .env
133
+ .venv
134
+ env/
135
+ venv/
136
+ ENV/
137
+ env.bak/
138
+ venv.bak/
139
+
140
+ # Spyder project settings
141
+ .spyderproject
142
+ .spyproject
143
+
144
+ # Rope project settings
145
+ .ropeproject
146
+
147
+ # mkdocs documentation
148
+ /site
149
+
150
+ # mypy
151
+ .mypy_cache/
152
+ .dmypy.json
153
+ dmypy.json
154
+
155
+ # Pyre type checker
156
+ .pyre/
157
+
158
+ # pytype static type analyzer
159
+ .pytype/
160
+
161
+ # Cython debug symbols
162
+ cython_debug/
163
+
164
+ # PyCharm
165
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
166
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
167
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
168
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
169
+ #.idea/
LICENSES/CC0-1.0.txt ADDED
@@ -0,0 +1,121 @@
1
+ Creative Commons Legal Code
2
+
3
+ CC0 1.0 Universal
4
+
5
+ CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE
6
+ LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN
7
+ ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS
8
+ INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES
9
+ REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS
10
+ PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM
11
+ THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED
12
+ HEREUNDER.
13
+
14
+ Statement of Purpose
15
+
16
+ The laws of most jurisdictions throughout the world automatically confer
17
+ exclusive Copyright and Related Rights (defined below) upon the creator
18
+ and subsequent owner(s) (each and all, an "owner") of an original work of
19
+ authorship and/or a database (each, a "Work").
20
+
21
+ Certain owners wish to permanently relinquish those rights to a Work for
22
+ the purpose of contributing to a commons of creative, cultural and
23
+ scientific works ("Commons") that the public can reliably and without fear
24
+ of later claims of infringement build upon, modify, incorporate in other
25
+ works, reuse and redistribute as freely as possible in any form whatsoever
26
+ and for any purposes, including without limitation commercial purposes.
27
+ These owners may contribute to the Commons to promote the ideal of a free
28
+ culture and the further production of creative, cultural and scientific
29
+ works, or to gain reputation or greater distribution for their Work in
30
+ part through the use and efforts of others.
31
+
32
+ For these and/or other purposes and motivations, and without any
33
+ expectation of additional consideration or compensation, the person
34
+ associating CC0 with a Work (the "Affirmer"), to the extent that he or she
35
+ is an owner of Copyright and Related Rights in the Work, voluntarily
36
+ elects to apply CC0 to the Work and publicly distribute the Work under its
37
+ terms, with knowledge of his or her Copyright and Related Rights in the
38
+ Work and the meaning and intended legal effect of CC0 on those rights.
39
+
40
+ 1. Copyright and Related Rights. A Work made available under CC0 may be
41
+ protected by copyright and related or neighboring rights ("Copyright and
42
+ Related Rights"). Copyright and Related Rights include, but are not
43
+ limited to, the following:
44
+
45
+ i. the right to reproduce, adapt, distribute, perform, display,
46
+ communicate, and translate a Work;
47
+ ii. moral rights retained by the original author(s) and/or performer(s);
48
+ iii. publicity and privacy rights pertaining to a person's image or
49
+ likeness depicted in a Work;
50
+ iv. rights protecting against unfair competition in regards to a Work,
51
+ subject to the limitations in paragraph 4(a), below;
52
+ v. rights protecting the extraction, dissemination, use and reuse of data
53
+ in a Work;
54
+ vi. database rights (such as those arising under Directive 96/9/EC of the
55
+ European Parliament and of the Council of 11 March 1996 on the legal
56
+ protection of databases, and under any national implementation
57
+ thereof, including any amended or successor version of such
58
+ directive); and
59
+ vii. other similar, equivalent or corresponding rights throughout the
60
+ world based on applicable law or treaty, and any national
61
+ implementations thereof.
62
+
63
+ 2. Waiver. To the greatest extent permitted by, but not in contravention
64
+ of, applicable law, Affirmer hereby overtly, fully, permanently,
65
+ irrevocably and unconditionally waives, abandons, and surrenders all of
66
+ Affirmer's Copyright and Related Rights and associated claims and causes
67
+ of action, whether now known or unknown (including existing as well as
68
+ future claims and causes of action), in the Work (i) in all territories
69
+ worldwide, (ii) for the maximum duration provided by applicable law or
70
+ treaty (including future time extensions), (iii) in any current or future
71
+ medium and for any number of copies, and (iv) for any purpose whatsoever,
72
+ including without limitation commercial, advertising or promotional
73
+ purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each
74
+ member of the public at large and to the detriment of Affirmer's heirs and
75
+ successors, fully intending that such Waiver shall not be subject to
76
+ revocation, rescission, cancellation, termination, or any other legal or
77
+ equitable action to disrupt the quiet enjoyment of the Work by the public
78
+ as contemplated by Affirmer's express Statement of Purpose.
79
+
80
+ 3. Public License Fallback. Should any part of the Waiver for any reason
81
+ be judged legally invalid or ineffective under applicable law, then the
82
+ Waiver shall be preserved to the maximum extent permitted taking into
83
+ account Affirmer's express Statement of Purpose. In addition, to the
84
+ extent the Waiver is so judged Affirmer hereby grants to each affected
85
+ person a royalty-free, non transferable, non sublicensable, non exclusive,
86
+ irrevocable and unconditional license to exercise Affirmer's Copyright and
87
+ Related Rights in the Work (i) in all territories worldwide, (ii) for the
88
+ maximum duration provided by applicable law or treaty (including future
89
+ time extensions), (iii) in any current or future medium and for any number
90
+ of copies, and (iv) for any purpose whatsoever, including without
91
+ limitation commercial, advertising or promotional purposes (the
92
+ "License"). The License shall be deemed effective as of the date CC0 was
93
+ applied by Affirmer to the Work. Should any part of the License for any
94
+ reason be judged legally invalid or ineffective under applicable law, such
95
+ partial invalidity or ineffectiveness shall not invalidate the remainder
96
+ of the License, and in such case Affirmer hereby affirms that he or she
97
+ will not (i) exercise any of his or her remaining Copyright and Related
98
+ Rights in the Work or (ii) assert any associated claims and causes of
99
+ action with respect to the Work, in either case contrary to Affirmer's
100
+ express Statement of Purpose.
101
+
102
+ 4. Limitations and Disclaimers.
103
+
104
+ a. No trademark or patent rights held by Affirmer are waived, abandoned,
105
+ surrendered, licensed or otherwise affected by this document.
106
+ b. Affirmer offers the Work as-is and makes no representations or
107
+ warranties of any kind concerning the Work, express, implied,
108
+ statutory or otherwise, including without limitation warranties of
109
+ title, merchantability, fitness for a particular purpose, non
110
+ infringement, or the absence of latent or other defects, accuracy, or
111
+ the present or absence of errors, whether or not discoverable, all to
112
+ the greatest extent permissible under applicable law.
113
+ c. Affirmer disclaims responsibility for clearing rights of other persons
114
+ that may apply to the Work or any use thereof, including without
115
+ limitation any person's Copyright and Related Rights in the Work.
116
+ Further, Affirmer disclaims responsibility for obtaining any necessary
117
+ consents, permissions or other rights required for any use of the
118
+ Work.
119
+ d. Affirmer understands and acknowledges that Creative Commons is not a
120
+ party to this document and has no duty or obligation with respect to
121
+ this CC0 or use of the Work.
LICENSES/MIT.txt ADDED
@@ -0,0 +1,9 @@
1
+ MIT License
2
+
3
+ Copyright (c) <year> <copyright holders>
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6
+
7
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8
+
9
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
LICENSES/MPL-2.0.txt ADDED
@@ -0,0 +1,373 @@
1
+ Mozilla Public License Version 2.0
2
+ ==================================
3
+
4
+ 1. Definitions
5
+ --------------
6
+
7
+ 1.1. "Contributor"
8
+ means each individual or legal entity that creates, contributes to
9
+ the creation of, or owns Covered Software.
10
+
11
+ 1.2. "Contributor Version"
12
+ means the combination of the Contributions of others (if any) used
13
+ by a Contributor and that particular Contributor's Contribution.
14
+
15
+ 1.3. "Contribution"
16
+ means Covered Software of a particular Contributor.
17
+
18
+ 1.4. "Covered Software"
19
+ means Source Code Form to which the initial Contributor has attached
20
+ the notice in Exhibit A, the Executable Form of such Source Code
21
+ Form, and Modifications of such Source Code Form, in each case
22
+ including portions thereof.
23
+
24
+ 1.5. "Incompatible With Secondary Licenses"
25
+ means
26
+
27
+ (a) that the initial Contributor has attached the notice described
28
+ in Exhibit B to the Covered Software; or
29
+
30
+ (b) that the Covered Software was made available under the terms of
31
+ version 1.1 or earlier of the License, but not also under the
32
+ terms of a Secondary License.
33
+
34
+ 1.6. "Executable Form"
35
+ means any form of the work other than Source Code Form.
36
+
37
+ 1.7. "Larger Work"
38
+ means a work that combines Covered Software with other material, in
39
+ a separate file or files, that is not Covered Software.
40
+
41
+ 1.8. "License"
42
+ means this document.
43
+
44
+ 1.9. "Licensable"
45
+ means having the right to grant, to the maximum extent possible,
46
+ whether at the time of the initial grant or subsequently, any and
47
+ all of the rights conveyed by this License.
48
+
49
+ 1.10. "Modifications"
50
+ means any of the following:
51
+
52
+ (a) any file in Source Code Form that results from an addition to,
53
+ deletion from, or modification of the contents of Covered
54
+ Software; or
55
+
56
+ (b) any new file in Source Code Form that contains any Covered
57
+ Software.
58
+
59
+ 1.11. "Patent Claims" of a Contributor
60
+ means any patent claim(s), including without limitation, method,
61
+ process, and apparatus claims, in any patent Licensable by such
62
+ Contributor that would be infringed, but for the grant of the
63
+ License, by the making, using, selling, offering for sale, having
64
+ made, import, or transfer of either its Contributions or its
65
+ Contributor Version.
66
+
67
+ 1.12. "Secondary License"
68
+ means either the GNU General Public License, Version 2.0, the GNU
69
+ Lesser General Public License, Version 2.1, the GNU Affero General
70
+ Public License, Version 3.0, or any later versions of those
71
+ licenses.
72
+
73
+ 1.13. "Source Code Form"
74
+ means the form of the work preferred for making modifications.
75
+
76
+ 1.14. "You" (or "Your")
77
+ means an individual or a legal entity exercising rights under this
78
+ License. For legal entities, "You" includes any entity that
79
+ controls, is controlled by, or is under common control with You. For
80
+ purposes of this definition, "control" means (a) the power, direct
81
+ or indirect, to cause the direction or management of such entity,
82
+ whether by contract or otherwise, or (b) ownership of more than
83
+ fifty percent (50%) of the outstanding shares or beneficial
84
+ ownership of such entity.
85
+
86
+ 2. License Grants and Conditions
87
+ --------------------------------
88
+
89
+ 2.1. Grants
90
+
91
+ Each Contributor hereby grants You a world-wide, royalty-free,
92
+ non-exclusive license:
93
+
94
+ (a) under intellectual property rights (other than patent or trademark)
95
+ Licensable by such Contributor to use, reproduce, make available,
96
+ modify, display, perform, distribute, and otherwise exploit its
97
+ Contributions, either on an unmodified basis, with Modifications, or
98
+ as part of a Larger Work; and
99
+
100
+ (b) under Patent Claims of such Contributor to make, use, sell, offer
101
+ for sale, have made, import, and otherwise transfer either its
102
+ Contributions or its Contributor Version.
103
+
104
+ 2.2. Effective Date
105
+
106
+ The licenses granted in Section 2.1 with respect to any Contribution
107
+ become effective for each Contribution on the date the Contributor first
108
+ distributes such Contribution.
109
+
110
+ 2.3. Limitations on Grant Scope
111
+
112
+ The licenses granted in this Section 2 are the only rights granted under
113
+ this License. No additional rights or licenses will be implied from the
114
+ distribution or licensing of Covered Software under this License.
115
+ Notwithstanding Section 2.1(b) above, no patent license is granted by a
116
+ Contributor:
117
+
118
+ (a) for any code that a Contributor has removed from Covered Software;
119
+ or
120
+
121
+ (b) for infringements caused by: (i) Your and any other third party's
122
+ modifications of Covered Software, or (ii) the combination of its
123
+ Contributions with other software (except as part of its Contributor
124
+ Version); or
125
+
126
+ (c) under Patent Claims infringed by Covered Software in the absence of
127
+ its Contributions.
128
+
129
+ This License does not grant any rights in the trademarks, service marks,
130
+ or logos of any Contributor (except as may be necessary to comply with
131
+ the notice requirements in Section 3.4).
132
+
133
+ 2.4. Subsequent Licenses
134
+
135
+ No Contributor makes additional grants as a result of Your choice to
136
+ distribute the Covered Software under a subsequent version of this
137
+ License (see Section 10.2) or under the terms of a Secondary License (if
138
+ permitted under the terms of Section 3.3).
139
+
140
+ 2.5. Representation
141
+
142
+ Each Contributor represents that the Contributor believes its
143
+ Contributions are its original creation(s) or it has sufficient rights
144
+ to grant the rights to its Contributions conveyed by this License.
145
+
146
+ 2.6. Fair Use
147
+
148
+ This License is not intended to limit any rights You have under
149
+ applicable copyright doctrines of fair use, fair dealing, or other
150
+ equivalents.
151
+
152
+ 2.7. Conditions
153
+
154
+ Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
155
+ in Section 2.1.
156
+
157
+ 3. Responsibilities
158
+ -------------------
159
+
160
+ 3.1. Distribution of Source Form
161
+
162
+ All distribution of Covered Software in Source Code Form, including any
163
+ Modifications that You create or to which You contribute, must be under
164
+ the terms of this License. You must inform recipients that the Source
165
+ Code Form of the Covered Software is governed by the terms of this
166
+ License, and how they can obtain a copy of this License. You may not
167
+ attempt to alter or restrict the recipients' rights in the Source Code
168
+ Form.
169
+
170
+ 3.2. Distribution of Executable Form
171
+
172
+ If You distribute Covered Software in Executable Form then:
173
+
174
+ (a) such Covered Software must also be made available in Source Code
175
+ Form, as described in Section 3.1, and You must inform recipients of
176
+ the Executable Form how they can obtain a copy of such Source Code
177
+ Form by reasonable means in a timely manner, at a charge no more
178
+ than the cost of distribution to the recipient; and
179
+
180
+ (b) You may distribute such Executable Form under the terms of this
181
+ License, or sublicense it under different terms, provided that the
182
+ license for the Executable Form does not attempt to limit or alter
183
+ the recipients' rights in the Source Code Form under this License.
184
+
185
+ 3.3. Distribution of a Larger Work
186
+
187
+ You may create and distribute a Larger Work under terms of Your choice,
188
+ provided that You also comply with the requirements of this License for
189
+ the Covered Software. If the Larger Work is a combination of Covered
190
+ Software with a work governed by one or more Secondary Licenses, and the
191
+ Covered Software is not Incompatible With Secondary Licenses, this
192
+ License permits You to additionally distribute such Covered Software
193
+ under the terms of such Secondary License(s), so that the recipient of
194
+ the Larger Work may, at their option, further distribute the Covered
195
+ Software under the terms of either this License or such Secondary
196
+ License(s).
197
+
198
+ 3.4. Notices
199
+
200
+ You may not remove or alter the substance of any license notices
201
+ (including copyright notices, patent notices, disclaimers of warranty,
202
+ or limitations of liability) contained within the Source Code Form of
203
+ the Covered Software, except that You may alter any license notices to
204
+ the extent required to remedy known factual inaccuracies.
205
+
206
+ 3.5. Application of Additional Terms
207
+
208
+ You may choose to offer, and to charge a fee for, warranty, support,
209
+ indemnity or liability obligations to one or more recipients of Covered
210
+ Software. However, You may do so only on Your own behalf, and not on
211
+ behalf of any Contributor. You must make it absolutely clear that any
212
+ such warranty, support, indemnity, or liability obligation is offered by
213
+ You alone, and You hereby agree to indemnify every Contributor for any
214
+ liability incurred by such Contributor as a result of warranty, support,
215
+ indemnity or liability terms You offer. You may include additional
216
+ disclaimers of warranty and limitations of liability specific to any
217
+ jurisdiction.
218
+
219
+ 4. Inability to Comply Due to Statute or Regulation
220
+ ---------------------------------------------------
221
+
222
+ If it is impossible for You to comply with any of the terms of this
223
+ License with respect to some or all of the Covered Software due to
224
+ statute, judicial order, or regulation then You must: (a) comply with
225
+ the terms of this License to the maximum extent possible; and (b)
226
+ describe the limitations and the code they affect. Such description must
227
+ be placed in a text file included with all distributions of the Covered
228
+ Software under this License. Except to the extent prohibited by statute
229
+ or regulation, such description must be sufficiently detailed for a
230
+ recipient of ordinary skill to be able to understand it.
231
+
232
+ 5. Termination
233
+ --------------
234
+
235
+ 5.1. The rights granted under this License will terminate automatically
236
+ if You fail to comply with any of its terms. However, if You become
237
+ compliant, then the rights granted under this License from a particular
238
+ Contributor are reinstated (a) provisionally, unless and until such
239
+ Contributor explicitly and finally terminates Your grants, and (b) on an
240
+ ongoing basis, if such Contributor fails to notify You of the
241
+ non-compliance by some reasonable means prior to 60 days after You have
242
+ come back into compliance. Moreover, Your grants from a particular
243
+ Contributor are reinstated on an ongoing basis if such Contributor
244
+ notifies You of the non-compliance by some reasonable means, this is the
245
+ first time You have received notice of non-compliance with this License
246
+ from such Contributor, and You become compliant prior to 30 days after
247
+ Your receipt of the notice.
248
+
249
+ 5.2. If You initiate litigation against any entity by asserting a patent
250
+ infringement claim (excluding declaratory judgment actions,
251
+ counter-claims, and cross-claims) alleging that a Contributor Version
252
+ directly or indirectly infringes any patent, then the rights granted to
253
+ You by any and all Contributors for the Covered Software under Section
254
+ 2.1 of this License shall terminate.
255
+
256
+ 5.3. In the event of termination under Sections 5.1 or 5.2 above, all
257
+ end user license agreements (excluding distributors and resellers) which
258
+ have been validly granted by You or Your distributors under this License
259
+ prior to termination shall survive termination.
260
+
261
+ ************************************************************************
262
+ * *
263
+ * 6. Disclaimer of Warranty *
264
+ * ------------------------- *
265
+ * *
266
+ * Covered Software is provided under this License on an "as is" *
267
+ * basis, without warranty of any kind, either expressed, implied, or *
268
+ * statutory, including, without limitation, warranties that the *
269
+ * Covered Software is free of defects, merchantable, fit for a *
270
+ * particular purpose or non-infringing. The entire risk as to the *
271
+ * quality and performance of the Covered Software is with You. *
272
+ * Should any Covered Software prove defective in any respect, You *
273
+ * (not any Contributor) assume the cost of any necessary servicing, *
274
+ * repair, or correction. This disclaimer of warranty constitutes an *
275
+ * essential part of this License. No use of any Covered Software is *
276
+ * authorized under this License except under this disclaimer. *
277
+ * *
278
+ ************************************************************************
279
+
280
+ ************************************************************************
281
+ * *
282
+ * 7. Limitation of Liability *
283
+ * -------------------------- *
284
+ * *
285
+ * Under no circumstances and under no legal theory, whether tort *
286
+ * (including negligence), contract, or otherwise, shall any *
287
+ * Contributor, or anyone who distributes Covered Software as *
288
+ * permitted above, be liable to You for any direct, indirect, *
289
+ * special, incidental, or consequential damages of any character *
290
+ * including, without limitation, damages for lost profits, loss of *
291
+ * goodwill, work stoppage, computer failure or malfunction, or any *
292
+ * and all other commercial damages or losses, even if such party *
293
+ * shall have been informed of the possibility of such damages. This *
294
+ * limitation of liability shall not apply to liability for death or *
295
+ * personal injury resulting from such party's negligence to the *
296
+ * extent applicable law prohibits such limitation. Some *
297
+ * jurisdictions do not allow the exclusion or limitation of *
298
+ * incidental or consequential damages, so this exclusion and *
299
+ * limitation may not apply to You. *
300
+ * *
301
+ ************************************************************************
302
+
303
+ 8. Litigation
304
+ -------------
305
+
306
+ Any litigation relating to this License may be brought only in the
307
+ courts of a jurisdiction where the defendant maintains its principal
308
+ place of business and such litigation shall be governed by laws of that
309
+ jurisdiction, without reference to its conflict-of-law provisions.
310
+ Nothing in this Section shall prevent a party's ability to bring
311
+ cross-claims or counter-claims.
312
+
313
+ 9. Miscellaneous
314
+ ----------------
315
+
316
+ This License represents the complete agreement concerning the subject
317
+ matter hereof. If any provision of this License is held to be
318
+ unenforceable, such provision shall be reformed only to the extent
319
+ necessary to make it enforceable. Any law or regulation which provides
320
+ that the language of a contract shall be construed against the drafter
321
+ shall not be used to construe this License against a Contributor.
322
+
323
+ 10. Versions of the License
324
+ ---------------------------
325
+
326
+ 10.1. New Versions
327
+
328
+ Mozilla Foundation is the license steward. Except as provided in Section
329
+ 10.3, no one other than the license steward has the right to modify or
330
+ publish new versions of this License. Each version will be given a
331
+ distinguishing version number.
332
+
333
+ 10.2. Effect of New Versions
334
+
335
+ You may distribute the Covered Software under the terms of the version
336
+ of the License under which You originally received the Covered Software,
337
+ or under the terms of any subsequent version published by the license
338
+ steward.
339
+
340
+ 10.3. Modified Versions
341
+
342
+ If you create software not governed by this License, and you want to
343
+ create a new license for such software, you may create and use a
344
+ modified version of this License if you rename the license and remove
345
+ any references to the name of the license steward (except to note that
346
+ such modified license differs from this License).
347
+
348
+ 10.4. Distributing Source Code Form that is Incompatible With Secondary
349
+ Licenses
350
+
351
+ If You choose to distribute Source Code Form that is Incompatible With
352
+ Secondary Licenses under the terms of this version of the License, the
353
+ notice described in Exhibit B of this License must be attached.
354
+
355
+ Exhibit A - Source Code Form License Notice
356
+ -------------------------------------------
357
+
358
+ This Source Code Form is subject to the terms of the Mozilla Public
359
+ License, v. 2.0. If a copy of the MPL was not distributed with this
360
+ file, You can obtain one at https://mozilla.org/MPL/2.0/.
361
+
362
+ If it is not possible or desirable to put the notice in a particular
363
+ file, then You may include the notice in a location (such as a LICENSE
364
+ file in a relevant directory) where a recipient would be likely to look
365
+ for such a notice.
366
+
367
+ You may add additional accurate notices of copyright ownership.
368
+
369
+ Exhibit B - "Incompatible With Secondary Licenses" Notice
370
+ ---------------------------------------------------------
371
+
372
+ This Source Code Form is "Incompatible With Secondary Licenses", as
373
+ defined by the Mozilla Public License, v. 2.0.
REUSE.toml ADDED
@@ -0,0 +1,26 @@
1
+ # SPDX-FileCopyrightText: 2024 Idiap Research Institute
2
+ # SPDX-FileContributor: Karl El Hajal
3
+ #
4
+ # SPDX-License-Identifier: MIT
5
+
6
+ version = 1
7
+ SPDX-PackageName = "knn_tts"
8
+ SPDX-PackageSupplier = "Karl El Hajal <[email protected]>"
9
+
10
+ [[annotations]]
11
+ path = "poetry.lock"
12
+ precedence = "aggregate"
13
+ SPDX-FileCopyrightText = "2024 Idiap Research Institute"
14
+ SPDX-License-Identifier = "CC0-1.0"
15
+
16
+ [[annotations]]
17
+ path = "knn_tts/vocoder/hifigan/config_v1_wavlm.json"
18
+ precedence = "aggregate"
19
+ SPDX-FileCopyrightText = "2020 Jungil Kong"
20
+ SPDX-License-Identifier = "MIT"
21
+
22
+ [[annotations]]
23
+ path = "assets/diagram.png"
24
+ precedence = "aggregate"
25
+ SPDX-FileCopyrightText = "2024 Idiap Research Institute"
26
+ SPDX-License-Identifier = "MIT"
app.py ADDED
@@ -0,0 +1,240 @@
1
+ # SPDX-FileCopyrightText: 2024 Idiap Research Institute
2
+ # SPDX-FileContributor: Karl El Hajal
3
+ #
4
+ # SPDX-License-Identifier: MIT
5
+
6
+ import os
7
+ import zipfile
8
+ import gradio as gr
9
+ import spaces
10
+ from huggingface_hub import snapshot_download
11
+
12
+ from knn_tts.synthesizer import Synthesizer
13
+ from knn_tts.utils import get_vocoder_checkpoint_path
14
+
15
+ # Check if target_feats directory exists, if not, unzip target_feats.zip
16
+ if not os.path.exists("target_feats"):
17
+ if os.path.exists("target_feats.zip"):
18
+ with zipfile.ZipFile("target_feats.zip", "r") as zip_ref:
19
+ zip_ref.extractall(".")
20
+ else:
21
+ raise FileNotFoundError("target_feats.zip not found.")
22
+
23
+ SAMPLE_RATE = 16000
24
+
25
+ CHECKPOINTS_DIR = "./checkpoints"
26
+
27
+ tts_checkpoints_dir = snapshot_download(repo_id="idiap/kNN-TTS", local_dir=CHECKPOINTS_DIR)
28
+ vocoder_checkpoint_path = get_vocoder_checkpoint_path(CHECKPOINTS_DIR)
29
+
30
+ tts_checkpoint_name = "best_model_646135.pth"
31
+ synthesizer = Synthesizer(tts_checkpoints_dir, tts_checkpoint_name, vocoder_checkpoint_path, model_name="glowtts")
32
+
33
+ target_speakers = {
34
+ "Libri 7127":{
35
+ "feats_path": "target_feats/LibriSpeech-test-clean/7127/wavlm",
36
+ },
37
+ "Libri 7729":{
38
+ "feats_path": "target_feats/LibriSpeech-test-clean/7729/wavlm",
39
+ },
40
+ "Libri 6829":{
41
+ "feats_path": "target_feats/LibriSpeech-test-clean/6829/wavlm",
42
+ },
43
+ "Libri 8555":{
44
+ "feats_path": "target_feats/LibriSpeech-test-clean/8555/wavlm",
45
+ },
46
+ "Thorsten Neutral": {
47
+ "feats_path": "target_feats/Thorsten/neutral/wavlm/",
48
+ },
49
+ "Thorsten Whisper": {
50
+ "feats_path": "target_feats/Thorsten/whisper/wavlm/",
51
+ },
52
+ "ESD 0018 Neutral":{
53
+ "feats_path": "target_feats/ESD/0018/neutral/wavlm/",
54
+ },
55
+ "ESD 0018 Surprised":{
56
+ "feats_path": "target_feats/ESD/0018/surprised/wavlm/",
57
+ },
58
+ }
59
+
60
+ @spaces.GPU
61
+ def run(text_input, target_speaker, lambda_rate, topk, weighted_average):
62
+ feats_path = target_speakers[target_speaker]["feats_path"]
63
+ wav = synthesizer(text_input, feats_path, interpolation_rate=lambda_rate, knnvc_topk=topk, weighted_average=weighted_average, max_target_num_files=500)
64
+ wav = (SAMPLE_RATE, wav.squeeze().cpu().numpy())
65
+ return wav
66
+
67
+
68
+ def get_title(text, size=1):
69
+ return f"""
70
+ <center>
71
+
72
+ <h{size}> {text} </h{size}>
73
+
74
+ </center>
75
+ """
76
+
77
+ def create_gradio_interface():
78
+ with gr.Blocks(
79
+ theme=gr.themes.Default(
80
+ text_size="lg",
81
+ ),
82
+ title="kNN-TTS"
83
+ ) as iface:
84
+
85
+ gr.HTML(get_title("kNN-TTS: kNN Retrieval for Simple and Effective Zero-Shot Multi-speaker Text-to-Speech", size=1))
86
+
87
+ with gr.Tabs():
88
+ with gr.TabItem("Generate Speech"):
89
+ with gr.Row():
90
+ # Left column - inputs
91
+ with gr.Column():
92
+ gr.Markdown("## Input")
93
+ text_box = gr.Textbox(
94
+ lines=3,
95
+ placeholder="Enter the text to convert to speech...",
96
+ label="Text",
97
+ elem_id="text-input"
98
+ )
99
+
100
+ target_speaker_dropdown = gr.Dropdown(
101
+ choices=list(target_speakers.keys()),
102
+ value="Libri 7127",
103
+ label="Target Voice",
104
+ elem_id="target-voice"
105
+ )
106
+
107
+ rate_slider = gr.Slider(
108
+ minimum=0.0,
109
+ maximum=2.0,
110
+ value=1.0,
111
+ step=0.01,
112
+ label="Voice Morphing (λ)",
113
+ info="Higher values give more weight to target voice characteristics"
114
+ )
115
+
116
+ with gr.Accordion("Advanced Settings", open=False):
117
+ k_slider = gr.Slider(
118
+ minimum=1,
119
+ maximum=50,
120
+ value=4,
121
+ step=1,
122
+ label="Top-k Retrieval",
123
+ info="k closest neighbors to retrieve"
124
+ )
125
+ weighted_toggle = gr.Checkbox(
126
+ label="Use Weighted Averaging",
127
+ value=False,
128
+ info="Weight neighbors by similarity distance"
129
+ )
130
+
131
+ submit_button = gr.Button("Generate Audio", variant="primary", size="lg")
132
+
133
+ # Right column - outputs
134
+ with gr.Column():
135
+ gr.Markdown("## Generated Audio")
136
+ with gr.Group():
137
+ audio_output = gr.Audio(
138
+ type="numpy",
139
+ label="Output Speech",
140
+ elem_id="audio-output"
141
+ )
142
+ with gr.Row():
143
+ clear_btn = gr.ClearButton([text_box, target_speaker_dropdown, rate_slider, audio_output], variant="secondary", size="lg")
144
+
145
+ # Example section
146
+ with gr.Row():
147
+ gr.Examples(
148
+ examples=[
149
+ ["I think foosball is a combination of football and shish kebabs.", "Thorsten Whisper", 1.0, 8, True],
150
+ ["I think foosball is a combination of football and shish kebabs.", "Thorsten Neutral", 1.0, 4, False],
151
+ ["If you're traveling in the north country fair.", "Libri 7127", 1.0, 4, False],
152
+ ["Like a vision she dances across the porch as the radio plays.", "Libri 7729", 1.0, 8, True],
153
+ ["There weren't another other way to be.", "Libri 6829", 1.0, 4, False],
154
+ ],
155
+ inputs=[text_box, target_speaker_dropdown, rate_slider, k_slider, weighted_toggle],
156
+ outputs=audio_output,
157
+ fn=run,
158
+ cache_examples=True
159
+ )
160
+
161
+ # Additional tabs
162
+ with gr.TabItem("Model Details"):
163
+ with gr.Row():
164
+ with gr.Column():
165
+ gr.Markdown("""
166
+ ## kNN-TTS Technical Details
167
+
168
+ kNN-TTS uses self-supervised learning (SSL) features and kNN retrieval to achieve robust zero-shot multi-speaker TTS.
169
+
170
+ ### Key Components
171
+
172
+ 1. **Feature Extraction**: We extract discrete representations from target speaker speech using a pre-trained SSL encoder. We use the 6th layer of WavLM Large.
173
+ 2. **Text-to-SSL**: We train a lightweight TTS model to predict the same representations from text. For simplicity, we train it on a single-speaker dataset.
174
+ 3. **Retrieval Mechanism**: We use kNN to find, for each unit in the generated features, its closest matches in the target voice unit database.
175
+ 4. **Voice Morphing**: By linearly interpolating the source and selected target speaker features, we can morph the two voices. The interpolation parameter λ controls the balance between source and target characteristics (see the sketch below).
176
+ 5. **Vocoder**: We use a pre-trained vocoder to convert the resulting features to a waveform.
177
+
178
+ ### Performance
179
+
180
+ Our simple and efficient model achieves results comparable to state-of-the-art models while being trained on 100 to 1000× less transcribed data.
181
+ This framework is therefore particularly well-suited for low-resource domains.
182
+
183
+ For more details, please refer to our paper (https://arxiv.org/abs/2408.10771).
184
+ """)
185
+ with gr.Column():
186
+ gr.Image("assets/diagram.png", label="Model Architecture", scale=0.3, show_label=False, show_download_button=False, show_fullscreen_button=False)
187
+
188
+ with gr.TabItem("About"):
189
+ gr.Markdown("""
190
+ ## About the Project
191
+
192
+ This demo showcases kNN-TTS, a lightweight zero-shot text-to-speech synthesis model.
193
+
194
+ ### Authors
195
+
196
+ - Karl El Hajal
197
+ - Ajinkya Kulkarni
198
+ - Enno Hermann
199
+ - Mathew Magimai.-Doss
200
+
201
+ ### Citation
202
+
203
+ If you use kNN-TTS in your research, please cite our paper:
204
+
205
+ ```
206
+ @misc{hajal2025knntts,
207
+ title={kNN Retrieval for Simple and Effective Zero-Shot Multi-speaker Text-to-Speech},
208
+ author={Karl El Hajal and Ajinkya Kulkarni and Enno Hermann and Mathew Magimai.-Doss},
209
+ year={2025},
210
+ eprint={2408.10771},
211
+ archivePrefix={arXiv},
212
+ primaryClass={eess.AS},
213
+ url={https://arxiv.org/abs/2408.10771},
214
+ }
215
+ ```
216
+
217
+ ### Acknowledgments
218
+
219
+ The target voices featured in this demo were sourced from the following datasets:
220
+
221
+ - [Thorsten Dataset](https://www.thorsten-voice.de/)
222
+ - [LibriSpeech Dataset](https://www.openslr.org/12)
223
+ - [Emotional Speech Dataset (ESD)](https://hltsingapore.github.io/ESD/)
224
+
225
+ ### License
226
+
227
+ This project is licensed under the MIT License.
228
+ """)
229
+
230
+ # Event handlers
231
+ submit_button.click(
232
+ fn=run,
233
+ inputs=[text_box, target_speaker_dropdown, rate_slider, k_slider, weighted_toggle],
234
+ outputs=[audio_output]
235
+ )
236
+
237
+ return iface
238
+
239
+ demo = create_gradio_interface()
240
+ demo.launch(share=True, debug=False)
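
The retrieval and morphing steps described in the Model Details tab above boil down to a nearest-neighbour lookup followed by a linear blend. The following is a minimal sketch, assuming frame-level WavLM features as `(frames, dim)` torch tensors; the function name and arguments are illustrative and mirror the demo's λ, top-k, and weighted-averaging controls rather than the exact `Synthesizer` internals.

```python
import torch
import torch.nn.functional as F

def knn_morph(source_feats, target_feats, topk=4, lambda_rate=1.0, weighted_average=False):
    """Illustrative sketch: match each source frame to its k nearest target frames
    and blend the two voices.

    source_feats: (T_src, D) frames predicted by the text-to-SSL model
    target_feats: (T_tgt, D) frames extracted from the target speaker's recordings
    """
    # Cosine distance between every source frame and every target frame.
    dists = 1.0 - F.normalize(source_feats, dim=-1) @ F.normalize(target_feats, dim=-1).T

    # k closest target frames for each source frame.
    knn_dists, knn_idx = dists.topk(topk, dim=-1, largest=False)
    neighbors = target_feats[knn_idx]  # (T_src, k, D)

    if weighted_average:
        # Closer neighbours contribute more to the match.
        weights = torch.softmax(-knn_dists, dim=-1).unsqueeze(-1)
        matched = (weights * neighbors).sum(dim=1)
    else:
        matched = neighbors.mean(dim=1)

    # Voice morphing: λ = 0 keeps the source voice, λ = 1 fully adopts the target match.
    return (1.0 - lambda_rate) * source_feats + lambda_rate * matched
```

The blended frames would then be passed to the pre-trained vocoder to obtain the output waveform.
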
assets/diagram.png ADDED
knn_tts/__init__.py ADDED
File without changes
knn_tts/ssl/WavLM/WavLM.py ADDED
@@ -0,0 +1,742 @@
1
+ # SPDX-FileCopyrightText: 2021 Microsoft
2
+ #
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ # --------------------------------------------------------
6
+ # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
7
+ # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
8
+ # Copyright (c) 2021 Microsoft
9
+ # Licensed under The MIT License [see LICENSE for details]
10
+ # Based on fairseq code bases
11
+ # https://github.com/pytorch/fairseq
12
+ # --------------------------------------------------------
13
+
14
+ import logging
15
+ import math
16
+ from typing import List, Optional, Tuple
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn.functional as F
21
+ from torch import nn
22
+ from torch.nn import LayerNorm
23
+
24
+ from .modules import (
25
+ Fp32GroupNorm,
26
+ Fp32LayerNorm,
27
+ GLU_Linear,
28
+ GradMultiply,
29
+ MultiheadAttention,
30
+ SamePad,
31
+ TransposeLast,
32
+ get_activation_fn,
33
+ init_bert_params,
34
+ )
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
+ # pylint: disable=too-many-branches, too-many-statements
40
+ def compute_mask_indices(
41
+ shape: Tuple[int, int],
42
+ padding_mask: Optional[torch.Tensor],
43
+ mask_prob: float,
44
+ mask_length: int,
45
+ mask_type: str = "static",
46
+ mask_other: float = 0.0,
47
+ min_masks: int = 0,
48
+ no_overlap: bool = False,
49
+ min_space: int = 0,
50
+ ) -> np.ndarray:
51
+ """
52
+ Computes random mask spans for a given shape
53
+
54
+ Args:
55
+ shape: the shape for which to compute masks.
56
+ should be of size 2 where first element is batch size and 2nd is timesteps
57
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
58
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
59
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
60
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
61
+ mask_type: how to compute mask lengths
62
+ static = fixed size
63
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
64
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
65
+ poisson = sample from poisson distribution with lambda = mask length
66
+ min_masks: minimum number of masked spans
67
+ no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
68
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
69
+ """
70
+
71
+ bsz, all_sz = shape
72
+ mask = np.full((bsz, all_sz), False)
73
+
74
+ all_num_mask = int(
75
+ # add a random number for probabilistic rounding
76
+ mask_prob * all_sz / float(mask_length) + np.random.rand()
77
+ )
78
+
79
+ all_num_mask = max(min_masks, all_num_mask)
80
+
81
+ mask_idcs = []
82
+ for i in range(bsz):
83
+ if padding_mask is not None:
84
+ sz = all_sz - padding_mask[i].long().sum().item()
85
+ num_mask = int(
86
+ # add a random number for probabilistic rounding
87
+ mask_prob * sz / float(mask_length) + np.random.rand()
88
+ )
89
+ num_mask = max(min_masks, num_mask)
90
+ else:
91
+ sz = all_sz
92
+ num_mask = all_num_mask
93
+
94
+ if mask_type == "static":
95
+ lengths = np.full(num_mask, mask_length)
96
+ elif mask_type == "uniform":
97
+ lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
98
+ elif mask_type == "normal":
99
+ lengths = np.random.normal(mask_length, mask_other, size=num_mask)
100
+ lengths = [max(1, int(round(x))) for x in lengths]
101
+ elif mask_type == "poisson":
102
+ lengths = np.random.poisson(mask_length, size=num_mask)
103
+ lengths = [int(round(x)) for x in lengths]
104
+ else:
105
+ raise ValueError("unknown mask selection " + mask_type)
106
+
107
+ if sum(lengths) == 0:
108
+ lengths[0] = min(mask_length, sz - 1)
109
+
110
+ if no_overlap:
111
+ mask_idc = []
112
+
113
+ def arrange(s, e, length, keep_length):
114
+ span_start = np.random.randint(s, e - length)
115
+ mask_idc.extend(span_start + i for i in range(length)) # pylint: disable=cell-var-from-loop
116
+
117
+ new_parts = []
118
+ if span_start - s - min_space >= keep_length:
119
+ new_parts.append((s, span_start - min_space + 1))
120
+ if e - span_start - keep_length - min_space > keep_length:
121
+ new_parts.append((span_start + length + min_space, e))
122
+ return new_parts
123
+
124
+ parts = [(0, sz)]
125
+ min_length = min(lengths)
126
+ for length in sorted(lengths, reverse=True):
127
+ lens = np.fromiter(
128
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
129
+ int,  # builtin int; np.int was removed in NumPy 1.24
130
+ )
131
+ l_sum = np.sum(lens)
132
+ if l_sum == 0:
133
+ break
134
+ probs = lens / np.sum(lens)
135
+ c = np.random.choice(len(parts), p=probs)
136
+ s, e = parts.pop(c)
137
+ parts.extend(arrange(s, e, length, min_length))
138
+ mask_idc = np.asarray(mask_idc)
139
+ else:
140
+ min_len = min(lengths)
141
+ if sz - min_len <= num_mask:
142
+ min_len = sz - num_mask - 1
143
+
144
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
145
+
146
+ mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
147
+
148
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
149
+
150
+ min_len = min(len(m) for m in mask_idcs)
151
+ for i, mask_idc in enumerate(mask_idcs):
152
+ if len(mask_idc) > min_len:
153
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
154
+ mask[i, mask_idc] = True
155
+
156
+ return mask
157
+
158
+
159
+ class WavLMConfig:
160
+ def __init__(self, cfg=None):
161
+ self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
162
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
163
+
164
+ self.encoder_embed_dim: int = 768 # encoder embedding dimension
165
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
166
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
167
+ self.activation_fn: str = "gelu" # activation function to use
168
+
169
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
170
+ self.conv_feature_layers: str = (
171
+ "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
172
+ )
173
+ self.conv_bias: bool = False # include bias in conv encoder
174
+ self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this
175
+
176
+ self.normalize: bool = False # normalize input to have 0 mean and unit variance during training
177
+
178
+ # dropouts
179
+ self.dropout: float = 0.1 # dropout probability for the transformer
180
+ self.attention_dropout: float = 0.1 # dropout probability for attention weights
181
+ self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
182
+ self.encoder_layerdrop: float = 0.0 # probability of dropping a transformer layer
183
+ self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
184
+ self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr)
185
+
186
+ # masking
187
+ self.mask_length: int = 10 # mask length
188
+ self.mask_prob: float = 0.65 # probability of replacing a token with mask
189
+ self.mask_selection: str = "static" # how to choose mask length
190
+ self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
191
+ self.no_mask_overlap: bool = False # whether to allow masks to overlap
192
+ self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled)
193
+
194
+ # channel masking
195
+ self.mask_channel_length: int = 10 # length of the mask for features (channels)
196
+ self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
197
+ self.mask_channel_selection: str = "static" # how to choose mask length for channel masking
198
+ self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
199
+ self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap
200
+ self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled)
201
+
202
+ # positional embeddings
203
+ self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
204
+ self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
205
+
206
+ # relative position embedding
207
+ self.relative_position_embedding: bool = False # apply relative position embedding
208
+ self.num_buckets: int = 320 # number of buckets for relative position embedding
209
+ self.max_distance: int = 1280 # maximum distance for relative position embedding
210
+ self.gru_rel_pos: bool = False # apply gated relative position embedding
211
+
212
+ if cfg is not None:
213
+ self.update(cfg)
214
+
215
+ def update(self, cfg: dict):
216
+ self.__dict__.update(cfg)
217
+
218
+
219
+ class WavLMModel(nn.Module):
220
+ def __init__(
221
+ self,
222
+ cfg: WavLMConfig,
223
+ ) -> None:
224
+ super().__init__()
225
+ logger.info("WavLM Config: %s", cfg.__dict__)
226
+
227
+ self.cfg = cfg
228
+ feature_enc_layers = eval(cfg.conv_feature_layers)
229
+ self.embed = feature_enc_layers[-1][0]
230
+
231
+ self.feature_extractor = ConvFeatureExtractionModel(
232
+ conv_layers=feature_enc_layers,
233
+ dropout=0.0,
234
+ mode=cfg.extractor_mode,
235
+ conv_bias=cfg.conv_bias,
236
+ )
237
+
238
+ self.post_extract_proj = nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None
239
+
240
+ self.mask_prob = cfg.mask_prob
241
+ self.mask_selection = cfg.mask_selection
242
+ self.mask_other = cfg.mask_other
243
+ self.mask_length = cfg.mask_length
244
+ self.no_mask_overlap = cfg.no_mask_overlap
245
+ self.mask_min_space = cfg.mask_min_space
246
+
247
+ self.mask_channel_prob = cfg.mask_channel_prob
248
+ self.mask_channel_selection = cfg.mask_channel_selection
249
+ self.mask_channel_other = cfg.mask_channel_other
250
+ self.mask_channel_length = cfg.mask_channel_length
251
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
252
+ self.mask_channel_min_space = cfg.mask_channel_min_space
253
+
254
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
255
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
256
+
257
+ self.feature_grad_mult = cfg.feature_grad_mult
258
+
259
+ self.mask_emb = nn.Parameter(torch.FloatTensor(cfg.encoder_embed_dim).uniform_())
260
+
261
+ self.encoder = TransformerEncoder(cfg)
262
+ self.layer_norm = LayerNorm(self.embed)
263
+
264
+ def apply_mask(self, x, padding_mask):
265
+ B, T, C = x.shape
266
+ if self.mask_prob > 0:
267
+ mask_indices = compute_mask_indices(
268
+ (B, T),
269
+ padding_mask,
270
+ self.mask_prob,
271
+ self.mask_length,
272
+ self.mask_selection,
273
+ self.mask_other,
274
+ min_masks=2,
275
+ no_overlap=self.no_mask_overlap,
276
+ min_space=self.mask_min_space,
277
+ )
278
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
279
+ x[mask_indices] = self.mask_emb
280
+ else:
281
+ mask_indices = None
282
+
283
+ if self.mask_channel_prob > 0:
284
+ mask_channel_indices = compute_mask_indices(
285
+ (B, C),
286
+ None,
287
+ self.mask_channel_prob,
288
+ self.mask_channel_length,
289
+ self.mask_channel_selection,
290
+ self.mask_channel_other,
291
+ no_overlap=self.no_mask_channel_overlap,
292
+ min_space=self.mask_channel_min_space,
293
+ )
294
+ mask_channel_indices = torch.from_numpy(mask_channel_indices).to(x.device).unsqueeze(1).expand(-1, T, -1)
295
+ x[mask_channel_indices] = 0
296
+
297
+ return x, mask_indices
298
+
299
+ def forward_padding_mask(
300
+ self,
301
+ features: torch.Tensor,
302
+ padding_mask: torch.Tensor,
303
+ ) -> torch.Tensor:
304
+ extra = padding_mask.size(1) % features.size(1)
305
+ if extra > 0:
306
+ padding_mask = padding_mask[:, :-extra]
307
+ padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
308
+ padding_mask = padding_mask.all(-1)
309
+ return padding_mask
310
+
311
+ def extract_features(
312
+ self,
313
+ source: torch.Tensor,
314
+ padding_mask: Optional[torch.Tensor] = None,
315
+ mask: bool = False,
316
+ ret_conv: bool = False,
317
+ output_layer: Optional[int] = None,
318
+ ret_layer_results: bool = False,
319
+ ):
320
+ if self.feature_grad_mult > 0:
321
+ features = self.feature_extractor(source)
322
+ if self.feature_grad_mult != 1.0:
323
+ features = GradMultiply.apply(features, self.feature_grad_mult)
324
+ else:
325
+ with torch.no_grad():
326
+ features = self.feature_extractor(source)
327
+
328
+ features = features.transpose(1, 2)
329
+ features = self.layer_norm(features)
330
+
331
+ if padding_mask is not None:
332
+ padding_mask = self.forward_padding_mask(features, padding_mask)
333
+
334
+ if self.post_extract_proj is not None:
335
+ features = self.post_extract_proj(features)
336
+
337
+ features = self.dropout_input(features)
338
+
339
+ if mask:
340
+ x, _ = self.apply_mask(features, padding_mask)
341
+ else:
342
+ x = features
343
+
344
+ # feature: (B, T, D), float
345
+ # target: (B, T), long
346
+ # x: (B, T, D), float
347
+ # padding_mask: (B, T), bool
348
+ # mask_indices: (B, T), bool
349
+ x, layer_results = self.encoder(
350
+ x,
351
+ padding_mask=padding_mask,
352
+ layer=None if output_layer is None else output_layer - 1,
353
+ )
354
+
355
+ res = {
356
+ "x": x,
357
+ "padding_mask": padding_mask,
358
+ "features": features,
359
+ "layer_results": layer_results,
360
+ }
361
+
362
+ feature = res["features"] if ret_conv else res["x"]
363
+ if ret_layer_results:
364
+ feature = (feature, res["layer_results"])
365
+ return feature, res["padding_mask"]
366
+
367
+ def forward(
368
+ self,
369
+ source: torch.Tensor,
370
+ padding_mask: Optional[torch.Tensor] = None,
371
+ mask: bool = False,
372
+ ret_conv: bool = False,
373
+ output_layer: Optional[int] = None,
374
+ ret_layer_results: bool = False,
375
+ ):
376
+ return self.extract_features(
377
+ source,
378
+ padding_mask=padding_mask,
379
+ mask=mask,
380
+ ret_conv=ret_conv,
381
+ output_layer=output_layer,
382
+ ret_layer_results=ret_layer_results,
383
+ )
384
+
385
+
386
+ class ConvFeatureExtractionModel(nn.Module):
387
+ def __init__(
388
+ self,
389
+ conv_layers: List[Tuple[int, int, int]],
390
+ dropout: float = 0.0,
391
+ mode: str = "default",
392
+ conv_bias: bool = False,
393
+ conv_type: str = "default",
394
+ ):
395
+ super().__init__()
396
+
397
+ assert mode in {"default", "layer_norm"}
398
+
399
+ def block(
400
+ n_in,
401
+ n_out,
402
+ k,
403
+ stride,
404
+ is_layer_norm=False,
405
+ is_group_norm=False,
406
+ conv_bias=False,
407
+ ):
408
+ def make_conv():
409
+ conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
410
+ nn.init.kaiming_normal_(conv.weight)
411
+ return conv
412
+
413
+ assert (is_layer_norm and is_group_norm) is False, "layer norm and group norm are exclusive"
414
+
415
+ if is_layer_norm:
416
+ return nn.Sequential(
417
+ make_conv(),
418
+ nn.Dropout(p=dropout),
419
+ nn.Sequential(
420
+ TransposeLast(),
421
+ Fp32LayerNorm(dim, elementwise_affine=True),
422
+ TransposeLast(),
423
+ ),
424
+ nn.GELU(),
425
+ )
426
+ if is_group_norm:
427
+ return nn.Sequential(
428
+ make_conv(),
429
+ nn.Dropout(p=dropout),
430
+ Fp32GroupNorm(dim, dim, affine=True),
431
+ nn.GELU(),
432
+ )
433
+
434
+ return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
435
+
436
+ self.conv_type = conv_type
437
+ if self.conv_type == "default":
438
+ in_d = 1
439
+ self.conv_layers = nn.ModuleList()
440
+ for i, cl in enumerate(conv_layers):
441
+ assert len(cl) == 3, "invalid conv definition: " + str(cl)
442
+ (dim, k, stride) = cl
443
+
444
+ self.conv_layers.append(
445
+ block(
446
+ in_d,
447
+ dim,
448
+ k,
449
+ stride,
450
+ is_layer_norm=mode == "layer_norm",
451
+ is_group_norm=mode == "default" and i == 0,
452
+ conv_bias=conv_bias,
453
+ )
454
+ )
455
+ in_d = dim
456
+ elif self.conv_type == "conv2d":
457
+ in_d = 1
458
+ self.conv_layers = nn.ModuleList()
459
+ for i, cl in enumerate(conv_layers):
460
+ assert len(cl) == 3
461
+ (dim, k, stride) = cl
462
+
463
+ self.conv_layers.append(torch.nn.Conv2d(in_d, dim, k, stride))
464
+ self.conv_layers.append(torch.nn.ReLU())
465
+ in_d = dim
466
+ elif self.conv_type == "custom":
467
+ in_d = 1
468
+ idim = 80
469
+ self.conv_layers = nn.ModuleList()
470
+ for i, cl in enumerate(conv_layers):
471
+ assert len(cl) == 3
472
+ (dim, k, stride) = cl
473
+ self.conv_layers.append(torch.nn.Conv2d(in_d, dim, k, stride, padding=1))
474
+ self.conv_layers.append(torch.nn.LayerNorm([dim, idim]))
475
+ self.conv_layers.append(torch.nn.ReLU())
476
+ in_d = dim
477
+ if (i + 1) % 2 == 0:
478
+ self.conv_layers.append(torch.nn.MaxPool2d(2, stride=2, ceil_mode=True))
479
+ idim = int(math.ceil(idim / 2))
480
+ else:
481
+ pass
482
+
483
+ def forward(self, x, mask=None): # pylint: disable=unused-argument
484
+ # BxT -> BxCxT
485
+ x = x.unsqueeze(1)
486
+ if self.conv_type == "custom":
487
+ for conv in self.conv_layers:
488
+ if isinstance(conv, nn.LayerNorm):
489
+ x = x.transpose(1, 2)
490
+ x = conv(x).transpose(1, 2)
491
+ else:
492
+ x = conv(x)
493
+ x = x.transpose(2, 3).contiguous()
494
+ x = x.view(x.size(0), -1, x.size(-1))
495
+ else:
496
+ for conv in self.conv_layers:
497
+ x = conv(x)
498
+ if self.conv_type == "conv2d":
499
+ b, c, t, f = x.size()
500
+ x = x.transpose(2, 3).contiguous().view(b, c * f, t)
501
+ return x
502
+
503
+
504
+ class TransformerEncoder(nn.Module):
505
+ def __init__(self, args):
506
+ super().__init__()
507
+
508
+ self.dropout = args.dropout
509
+ self.embedding_dim = args.encoder_embed_dim
510
+
511
+ self.pos_conv = nn.Conv1d(
512
+ self.embedding_dim,
513
+ self.embedding_dim,
514
+ kernel_size=args.conv_pos,
515
+ padding=args.conv_pos // 2,
516
+ groups=args.conv_pos_groups,
517
+ )
518
+ dropout = 0
519
+ std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
520
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
521
+ nn.init.constant_(self.pos_conv.bias, 0)
522
+
523
+ self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
524
+ self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
525
+
526
+ if hasattr(args, "relative_position_embedding"):
527
+ self.relative_position_embedding = args.relative_position_embedding
528
+ self.num_buckets = args.num_buckets
529
+ self.max_distance = args.max_distance
530
+ else:
531
+ self.relative_position_embedding = False
532
+ self.num_buckets = 0
533
+ self.max_distance = 0
534
+
535
+ self.layers = nn.ModuleList(
536
+ [
537
+ TransformerSentenceEncoderLayer(
538
+ embedding_dim=self.embedding_dim,
539
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
540
+ num_attention_heads=args.encoder_attention_heads,
541
+ dropout=self.dropout,
542
+ attention_dropout=args.attention_dropout,
543
+ activation_dropout=args.activation_dropout,
544
+ activation_fn=args.activation_fn,
545
+ layer_norm_first=args.layer_norm_first,
546
+ has_relative_attention_bias=(self.relative_position_embedding and i == 0),
547
+ num_buckets=self.num_buckets,
548
+ max_distance=self.max_distance,
549
+ gru_rel_pos=args.gru_rel_pos,
550
+ )
551
+ for i in range(args.encoder_layers)
552
+ ]
553
+ )
554
+
555
+ self.layer_norm_first = args.layer_norm_first
556
+ self.layer_norm = LayerNorm(self.embedding_dim)
557
+ self.layerdrop = args.encoder_layerdrop
558
+
559
+ self.apply(init_bert_params)
560
+
561
+ def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
562
+ x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)
563
+
564
+ if self.layer_norm_first and layer is None:
565
+ x = self.layer_norm(x)
566
+
567
+ return x, layer_results
568
+
569
+ def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None):
570
+ if padding_mask is not None:
571
+ x[padding_mask] = 0
572
+
573
+ x_conv = self.pos_conv(x.transpose(1, 2))
574
+ x_conv = x_conv.transpose(1, 2)
575
+ x += x_conv
576
+
577
+ if not self.layer_norm_first:
578
+ x = self.layer_norm(x)
579
+
580
+ x = F.dropout(x, p=self.dropout, training=self.training)
581
+
582
+ # B x T x C -> T x B x C
583
+ x = x.transpose(0, 1)
584
+
585
+ layer_results = []
586
+ z = None
587
+ if tgt_layer is not None:
588
+ layer_results.append((x, z))
589
+ r = None
590
+ pos_bias = None
591
+ for i, layer in enumerate(self.layers):
592
+ dropout_probability = np.random.random()
593
+ if not self.training or (dropout_probability > self.layerdrop):
594
+ x, z, pos_bias = layer(
595
+ x,
596
+ self_attn_padding_mask=padding_mask,
597
+ need_weights=False,
598
+ self_attn_mask=streaming_mask,
599
+ pos_bias=pos_bias,
600
+ )
601
+ if tgt_layer is not None:
602
+ layer_results.append((x, z))
603
+ if i == tgt_layer:
604
+ r = x
605
+ break
606
+
607
+ if r is not None:
608
+ x = r
609
+
610
+ # T x B x C -> B x T x C
611
+ x = x.transpose(0, 1)
612
+
613
+ return x, layer_results
614
+
615
+
616
+ class TransformerSentenceEncoderLayer(nn.Module):
617
+ """
618
+ Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
619
+ models.
620
+ """
621
+
622
+ def __init__(
623
+ self,
624
+ embedding_dim: float = 768,
625
+ ffn_embedding_dim: float = 3072,
626
+ num_attention_heads: float = 8,
627
+ dropout: float = 0.1,
628
+ attention_dropout: float = 0.1,
629
+ activation_dropout: float = 0.1,
630
+ activation_fn: str = "relu",
631
+ layer_norm_first: bool = False,
632
+ has_relative_attention_bias: bool = False,
633
+ num_buckets: int = 0,
634
+ max_distance: int = 0,
635
+ rescale_init: bool = False,
636
+ gru_rel_pos: bool = False,
637
+ ) -> None:
638
+ super().__init__()
639
+ # Initialize parameters
640
+ self.embedding_dim = embedding_dim
641
+ self.dropout = dropout
642
+ self.activation_dropout = activation_dropout
643
+
644
+ # Initialize blocks
645
+ self.activation_name = activation_fn
646
+ self.activation_fn = get_activation_fn(activation_fn)
647
+ self.self_attn = MultiheadAttention(
648
+ self.embedding_dim,
649
+ num_attention_heads,
650
+ dropout=attention_dropout,
651
+ self_attention=True,
652
+ has_relative_attention_bias=has_relative_attention_bias,
653
+ num_buckets=num_buckets,
654
+ max_distance=max_distance,
655
+ rescale_init=rescale_init,
656
+ gru_rel_pos=gru_rel_pos,
657
+ )
658
+
659
+ self.dropout1 = nn.Dropout(dropout)
660
+ self.dropout2 = nn.Dropout(self.activation_dropout)
661
+ self.dropout3 = nn.Dropout(dropout)
662
+
663
+ self.layer_norm_first = layer_norm_first
664
+
665
+ # layer norm associated with the self attention layer
666
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
667
+
668
+ if self.activation_name == "glu":
669
+ self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
670
+ else:
671
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
672
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
673
+
674
+ # layer norm associated with the position wise feed-forward NN
675
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
676
+
677
+ def forward(
678
+ self,
679
+ x: torch.Tensor,
680
+ self_attn_mask: torch.Tensor = None,
681
+ self_attn_padding_mask: torch.Tensor = None,
682
+ need_weights: bool = False,
683
+ pos_bias=None,
684
+ ):
685
+ """
686
+ LayerNorm is applied either before or after the self-attention/ffn
687
+ modules similar to the original Transformer implementation.
688
+ """
689
+ residual = x
690
+
691
+ if self.layer_norm_first:
692
+ x = self.self_attn_layer_norm(x)
693
+ x, attn, pos_bias = self.self_attn(
694
+ query=x,
695
+ key=x,
696
+ value=x,
697
+ key_padding_mask=self_attn_padding_mask,
698
+ need_weights=False,
699
+ attn_mask=self_attn_mask,
700
+ position_bias=pos_bias,
701
+ )
702
+ x = self.dropout1(x)
703
+ x = residual + x
704
+
705
+ residual = x
706
+ x = self.final_layer_norm(x)
707
+ if self.activation_name == "glu":
708
+ x = self.fc1(x)
709
+ else:
710
+ x = self.activation_fn(self.fc1(x))
711
+ x = self.dropout2(x)
712
+ x = self.fc2(x)
713
+ x = self.dropout3(x)
714
+ x = residual + x
715
+ else:
716
+ x, attn, pos_bias = self.self_attn(
717
+ query=x,
718
+ key=x,
719
+ value=x,
720
+ key_padding_mask=self_attn_padding_mask,
721
+ need_weights=need_weights,
722
+ attn_mask=self_attn_mask,
723
+ position_bias=pos_bias,
724
+ )
725
+
726
+ x = self.dropout1(x)
727
+ x = residual + x
728
+
729
+ x = self.self_attn_layer_norm(x)
730
+
731
+ residual = x
732
+ if self.activation_name == "glu":
733
+ x = self.fc1(x)
734
+ else:
735
+ x = self.activation_fn(self.fc1(x))
736
+ x = self.dropout2(x)
737
+ x = self.fc2(x)
738
+ x = self.dropout3(x)
739
+ x = residual + x
740
+ x = self.final_layer_norm(x)
741
+
742
+ return x, attn, pos_bias
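A minimal sketch (not part of the commit) of driving the extract_features method defined above directly. It assumes a checkpoint dict with "cfg" and "model" entries, such as the WavLM-Large checkpoint loaded later in knn_tts/ssl/models.py; the waveform below is random placeholder data.

import torch

from knn_tts.ssl.WavLM.WavLM import WavLMConfig, WavLMModel

cfg = WavLMConfig(checkpoint["cfg"])       # `checkpoint` assumed to be loaded elsewhere
model = WavLMModel(cfg)
model.load_state_dict(checkpoint["model"])
model.eval()

wav = torch.randn(1, 16000)                # one second of 16 kHz audio, batch size 1
with torch.no_grad():
    # features from transformer layer 6, shape (batch, n_frames, encoder_embed_dim)
    feats, padding_mask = model.extract_features(wav, output_layer=6)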
knn_tts/ssl/WavLM/__init__.py ADDED
File without changes
knn_tts/ssl/WavLM/modules.py ADDED
@@ -0,0 +1,763 @@
1
+ # SPDX-FileCopyrightText: 2021 Microsoft
2
+ #
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ # --------------------------------------------------------
6
+ # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
7
+ # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
8
+ # Copyright (c) 2021 Microsoft
9
+ # Licensed under The MIT License [see LICENSE for details]
10
+ # Based on fairseq code bases
11
+ # https://github.com/pytorch/fairseq
12
+ # --------------------------------------------------------
13
+
14
+ import math
15
+ import warnings
16
+ from typing import Dict, Optional, Tuple
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch import Tensor, nn
21
+ from torch.nn import Parameter
22
+
23
+
24
+ class TransposeLast(nn.Module):
25
+ def __init__(self, deconstruct_idx=None):
26
+ super().__init__()
27
+ self.deconstruct_idx = deconstruct_idx
28
+
29
+ def forward(self, x):
30
+ if self.deconstruct_idx is not None:
31
+ x = x[self.deconstruct_idx]
32
+ return x.transpose(-2, -1)
33
+
34
+
35
+ class Fp32LayerNorm(nn.LayerNorm):
36
+ def __init__(self, *args, **kwargs):
37
+ super().__init__(*args, **kwargs)
38
+
39
+ def forward(self, input): # pylint: disable=W0622
40
+ output = F.layer_norm(
41
+ input.float(),
42
+ self.normalized_shape,
43
+ self.weight.float() if self.weight is not None else None,
44
+ self.bias.float() if self.bias is not None else None,
45
+ self.eps,
46
+ )
47
+ return output.type_as(input)
48
+
49
+
50
+ class Fp32GroupNorm(nn.GroupNorm):
51
+ def __init__(self, *args, **kwargs):
52
+ super().__init__(*args, **kwargs)
53
+
54
+ def forward(self, input): # pylint: disable=W0622
55
+ output = F.group_norm(
56
+ input.float(),
57
+ self.num_groups,
58
+ self.weight.float() if self.weight is not None else None,
59
+ self.bias.float() if self.bias is not None else None,
60
+ self.eps,
61
+ )
62
+ return output.type_as(input)
63
+
64
+
65
+ # pylint: disable=W0221,W0223
66
+ class GradMultiply(torch.autograd.Function):
67
+ @staticmethod
68
+ def forward(ctx, x, scale):
69
+ ctx.scale = scale
70
+ res = x.new(x)
71
+ return res
72
+
73
+ @staticmethod
74
+ def backward(ctx, grad):
75
+ return grad * ctx.scale, None
76
+
77
+
78
+ class SamePad(nn.Module):
79
+ def __init__(self, kernel_size, causal=False):
80
+ super().__init__()
81
+ if causal:
82
+ self.remove = kernel_size - 1
83
+ else:
84
+ self.remove = 1 if kernel_size % 2 == 0 else 0
85
+
86
+ def forward(self, x):
87
+ if self.remove > 0:
88
+ x = x[:, :, : -self.remove]
89
+ return x
90
+
91
+
92
+ class Swish(nn.Module):
93
+ """Swish function"""
94
+
95
+ def __init__(self):
96
+ """Construct a Swish activation module."""
97
+ super().__init__()
98
+ self.act = torch.nn.Sigmoid()
99
+
100
+ def forward(self, x):
101
+ return x * self.act(x)
102
+
103
+
104
+ class GLU_Linear(nn.Module):
105
+ def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
106
+ super().__init__()
107
+
108
+ self.glu_type = glu_type
109
+ self.output_dim = output_dim
110
+
111
+ if glu_type == "sigmoid":
112
+ self.glu_act = torch.nn.Sigmoid()
113
+ elif glu_type == "swish":
114
+ self.glu_act = Swish()
115
+ elif glu_type == "relu":
116
+ self.glu_act = torch.nn.ReLU()
117
+ elif glu_type == "gelu":
118
+ self.glu_act = torch.nn.GELU()
119
+
120
+ if bias_in_glu:
121
+ self.linear = nn.Linear(input_dim, output_dim * 2, True)
122
+ else:
123
+ self.linear = nn.Linear(input_dim, output_dim * 2, False)
124
+
125
+ def forward(self, x):
126
+ # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
127
+ x = self.linear(x)
128
+
129
+ if self.glu_type == "bilinear":
130
+ x = x[:, :, 0 : self.output_dim] * x[:, :, self.output_dim : self.output_dim * 2]
131
+ else:
132
+ x = x[:, :, 0 : self.output_dim] * self.glu_act(x[:, :, self.output_dim : self.output_dim * 2])
133
+
134
+ return x
135
+
136
+
137
+ # pylint: disable=W0212
138
+ def gelu_accurate(x):
139
+ if not hasattr(gelu_accurate, "_a"):
140
+ gelu_accurate._a = math.sqrt(2 / math.pi)
141
+ return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
142
+
143
+
144
+ def gelu(x: torch.Tensor) -> torch.Tensor:
145
+ return F.gelu(x.float()).type_as(x) # pylint: disable=E1102
146
+
147
+
148
+ def get_activation_fn(activation: str):
149
+ """Returns the activation function corresponding to `activation`"""
150
+
151
+ activation_functions = {
152
+ "relu": F.relu,
153
+ "gelu": gelu,
154
+ "gelu_fast": gelu_accurate,
155
+ "gelu_accurate": gelu_accurate,
156
+ "tanh": torch.tanh,
157
+ "linear": lambda x: x,
158
+ "glu": lambda x: x,
159
+ }
160
+
161
+ if activation == "gelu_fast":
162
+ warnings.warn("--activation-fn=gelu_fast has been renamed to gelu_accurate")
163
+
164
+ if activation in activation_functions:
165
+ return activation_functions[activation]
166
+
167
+ raise RuntimeError(f"--activation-fn {activation} not supported")
168
+
169
+
170
+ def init_bert_params(module):
171
+ """
172
+ Initialize the weights specific to the BERT Model.
173
+ This overrides the default initializations depending on the specified arguments.
174
+ 1. If normal_init_linear_weights is set then weights of linear
175
+ layer will be initialized using the normal distribution and
176
+ bias will be set to the specified value.
177
+ 2. If normal_init_embed_weights is set then weights of embedding
178
+ layer will be initialized using the normal distribution.
179
+ 3. If normal_init_proj_weights is set then weights of
180
+ in_project_weight for MultiHeadAttention initialized using
181
+ the normal distribution (to be validated).
182
+ """
183
+
184
+ def normal_(data):
185
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
186
+ # so that the RNG is consistent with and without FSDP
187
+ data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
188
+
189
+ if isinstance(module, nn.Linear):
190
+ normal_(module.weight.data)
191
+ if module.bias is not None:
192
+ module.bias.data.zero_()
193
+ if isinstance(module, nn.Embedding):
194
+ normal_(module.weight.data)
195
+ if module.padding_idx is not None:
196
+ module.weight.data[module.padding_idx].zero_()
197
+ if isinstance(module, MultiheadAttention):
198
+ normal_(module.q_proj.weight.data)
199
+ normal_(module.k_proj.weight.data)
200
+ normal_(module.v_proj.weight.data)
201
+
202
+
203
+ def quant_noise(module, p, block_size):
204
+ """
205
+ Wraps modules and applies quantization noise to the weights for
206
+ subsequent quantization with Iterative Product Quantization as
207
+ described in "Training with Quantization Noise for Extreme Model Compression"
208
+
209
+ Args:
210
+ - module: nn.Module
211
+ - p: amount of Quantization Noise
212
+ - block_size: size of the blocks for subsequent quantization with iPQ
213
+
214
+ Remarks:
215
+ - Module weights must have the right sizes wrt the block size
216
+ - Only Linear, Embedding and Conv2d modules are supported for the moment
217
+ - For more detail on how to quantize by blocks with convolutional weights,
218
+ see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
219
+ - We implement the simplest form of noise here as stated in the paper
220
+ which consists in randomly dropping blocks
221
+ """
222
+
223
+ # if no quantization noise, don't register hook
224
+ if p <= 0:
225
+ return module
226
+
227
+ # supported modules
228
+ assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
229
+
230
+ # test whether module.weight has the right sizes wrt block_size
231
+ is_conv = module.weight.ndim == 4
232
+
233
+ # 2D matrix
234
+ if not is_conv:
235
+ assert module.weight.size(1) % block_size == 0, "Input features must be a multiple of block sizes"
236
+
237
+ # 4D matrix
238
+ else:
239
+ # 1x1 convolutions
240
+ if module.kernel_size == (1, 1):
241
+ assert module.in_channels % block_size == 0, "Input channels must be a multiple of block sizes"
242
+ # regular convolutions
243
+ else:
244
+ k = module.kernel_size[0] * module.kernel_size[1]
245
+ assert k % block_size == 0, "Kernel size must be a multiple of block size"
246
+
247
+ def _forward_pre_hook(mod, x): # pylint: disable=W0613
248
+ # no noise for evaluation
249
+ if mod.training:
250
+ if not is_conv:
251
+ # gather weight and sizes
252
+ weight = mod.weight
253
+ in_features = weight.size(1)
254
+ out_features = weight.size(0)
255
+
256
+ # split weight matrix into blocks and randomly drop selected blocks
257
+ mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
258
+ mask.bernoulli_(p)
259
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
260
+
261
+ else:
262
+ # gather weight and sizes
263
+ weight = mod.weight
264
+ in_channels = mod.in_channels
265
+ out_channels = mod.out_channels
266
+
267
+ # split weight matrix into blocks and randomly drop selected blocks
268
+ if mod.kernel_size == (1, 1):
269
+ mask = torch.zeros(
270
+ int(in_channels // block_size * out_channels),
271
+ device=weight.device,
272
+ )
273
+ mask.bernoulli_(p)
274
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
275
+ else:
276
+ mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
277
+ mask.bernoulli_(p)
278
+ mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
279
+
280
+ # scale weights and apply mask
281
+ mask = mask.to(torch.bool) # x.bool() is not currently supported in TorchScript
282
+ s = 1 / (1 - p)
283
+ mod.weight.data = s * weight.masked_fill(mask, 0)
284
+
285
+ module.register_forward_pre_hook(_forward_pre_hook)
286
+ return module
287
+
288
+
289
+ class MultiheadAttention(nn.Module):
290
+ """Multi-headed attention.
291
+
292
+ See "Attention Is All You Need" for more details.
293
+ """
294
+
295
+ def __init__(
296
+ self,
297
+ embed_dim,
298
+ num_heads,
299
+ kdim=None,
300
+ vdim=None,
301
+ dropout=0.0,
302
+ bias=True,
303
+ add_bias_kv=False,
304
+ add_zero_attn=False,
305
+ self_attention=False,
306
+ encoder_decoder_attention=False,
307
+ q_noise=0.0,
308
+ qn_block_size=8,
309
+ has_relative_attention_bias=False,
310
+ num_buckets=32,
311
+ max_distance=128,
312
+ gru_rel_pos=False,
313
+ rescale_init=False,
314
+ ):
315
+ super().__init__()
316
+ self.embed_dim = embed_dim
317
+ self.kdim = kdim if kdim is not None else embed_dim
318
+ self.vdim = vdim if vdim is not None else embed_dim
319
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
320
+
321
+ self.num_heads = num_heads
322
+ self.dropout_module = nn.Dropout(dropout)
323
+
324
+ self.has_relative_attention_bias = has_relative_attention_bias
325
+ self.num_buckets = num_buckets
326
+ self.max_distance = max_distance
327
+ if self.has_relative_attention_bias:
328
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
329
+
330
+ self.head_dim = embed_dim // num_heads
331
+ self.q_head_dim = self.head_dim
332
+ self.k_head_dim = self.head_dim
333
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
334
+ self.scaling = self.head_dim**-0.5
335
+
336
+ self.self_attention = self_attention
337
+ self.encoder_decoder_attention = encoder_decoder_attention
338
+
339
+ assert not self.self_attention or self.qkv_same_dim, "Self-attention requires query, key and " "value to be of the same size"
340
+
341
+ k_bias = True
342
+ if rescale_init:
343
+ k_bias = False
344
+
345
+ k_embed_dim = embed_dim
346
+ q_embed_dim = embed_dim
347
+
348
+ self.k_proj = quant_noise(nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size)
349
+ self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
350
+ self.q_proj = quant_noise(nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size)
351
+
352
+ self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
353
+
354
+ if add_bias_kv:
355
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
356
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
357
+ else:
358
+ self.bias_k = self.bias_v = None
359
+
360
+ self.add_zero_attn = add_zero_attn
361
+
362
+ self.gru_rel_pos = gru_rel_pos
363
+ if self.gru_rel_pos:
364
+ self.grep_linear = nn.Linear(self.q_head_dim, 8)
365
+ self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
366
+
367
+ self.reset_parameters()
368
+
369
+ def reset_parameters(self):
370
+ if self.qkv_same_dim:
371
+ # Empirically observed the convergence to be much better with
372
+ # the scaled initialization
373
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
374
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
375
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
376
+ else:
377
+ nn.init.xavier_uniform_(self.k_proj.weight)
378
+ nn.init.xavier_uniform_(self.v_proj.weight)
379
+ nn.init.xavier_uniform_(self.q_proj.weight)
380
+
381
+ nn.init.xavier_uniform_(self.out_proj.weight)
382
+ if self.out_proj.bias is not None:
383
+ nn.init.constant_(self.out_proj.bias, 0.0)
384
+ if self.bias_k is not None:
385
+ nn.init.xavier_normal_(self.bias_k)
386
+ if self.bias_v is not None:
387
+ nn.init.xavier_normal_(self.bias_v)
388
+ if self.has_relative_attention_bias:
389
+ nn.init.xavier_normal_(self.relative_attention_bias.weight)
390
+
391
+ def _relative_positions_bucket(self, relative_positions, bidirectional=True):
392
+ num_buckets = self.num_buckets
393
+ max_distance = self.max_distance
394
+ relative_buckets = 0
395
+
396
+ if bidirectional:
397
+ num_buckets = num_buckets // 2
398
+ relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
399
+ relative_positions = torch.abs(relative_positions)
400
+ else:
401
+ relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
402
+
403
+ max_exact = num_buckets // 2
404
+ is_small = relative_positions < max_exact
405
+
406
+ relative_postion_if_large = max_exact + (torch.log(relative_positions.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).to(torch.long)
407
+ relative_postion_if_large = torch.min(
408
+ relative_postion_if_large,
409
+ torch.full_like(relative_postion_if_large, num_buckets - 1),
410
+ )
411
+
412
+ relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
413
+ return relative_buckets
414
+
415
+ def compute_bias(self, query_length, key_length):
416
+ context_position = torch.arange(query_length, dtype=torch.long)[:, None]
417
+ memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
418
+ relative_position = memory_position - context_position
419
+ relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True)
420
+ relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
421
+ values = self.relative_attention_bias(relative_position_bucket)
422
+ values = values.permute([2, 0, 1])
423
+ return values
424
+
425
+ # pylint: disable=R0912, R0915
426
+ def forward(
427
+ self,
428
+ query,
429
+ key: Optional[Tensor],
430
+ value: Optional[Tensor],
431
+ key_padding_mask: Optional[Tensor] = None,
432
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
433
+ need_weights: bool = True,
434
+ static_kv: bool = False,
435
+ attn_mask: Optional[Tensor] = None,
436
+ before_softmax: bool = False,
437
+ need_head_weights: bool = False,
438
+ position_bias: Optional[Tensor] = None,
439
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
440
+ """Input shape: Time x Batch x Channel
441
+
442
+ Args:
443
+ key_padding_mask (ByteTensor, optional): mask to exclude
444
+ keys that are pads, of shape `(batch, src_len)`, where
445
+ padding elements are indicated by 1s.
446
+ need_weights (bool, optional): return the attention weights,
447
+ averaged over heads (default: False).
448
+ attn_mask (ByteTensor, optional): typically used to
449
+ implement causal attention, where the mask prevents the
450
+ attention from looking forward in time (default: None).
451
+ before_softmax (bool, optional): return the raw attention
452
+ weights and values before the attention softmax.
453
+ need_head_weights (bool, optional): return the attention
454
+ weights for each head. Implies *need_weights*. Default:
455
+ return the average attention weights over all heads.
456
+ """
457
+ if need_head_weights:
458
+ need_weights = True
459
+
460
+ is_tpu = query.device.type == "xla"
461
+
462
+ tgt_len, bsz, embed_dim = query.size()
463
+ src_len = tgt_len
464
+ assert embed_dim == self.embed_dim
465
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
466
+ if key is not None:
467
+ src_len, key_bsz, _ = key.size()
468
+ if not torch.jit.is_scripting():
469
+ assert key_bsz == bsz
470
+ assert value is not None
471
+ assert src_len, bsz == value.shape[:2]
472
+
473
+ if self.has_relative_attention_bias and position_bias is None:
474
+ position_bias = self.compute_bias(tgt_len, src_len)
475
+ position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
476
+
477
+ if (
478
+ not is_tpu # don't use PyTorch version on TPUs
479
+ and incremental_state is None
480
+ and not static_kv
481
+ # A workaround for quantization to work. Otherwise JIT compilation
482
+ # treats bias in linear module as method.
483
+ and not torch.jit.is_scripting()
484
+ and self.q_head_dim == self.head_dim
485
+ ):
486
+ assert key is not None and value is not None
487
+ assert attn_mask is None
488
+
489
+ attn_mask_rel_pos = None
490
+ if position_bias is not None:
491
+ attn_mask_rel_pos = position_bias
492
+ if self.gru_rel_pos:
493
+ query_layer = query.transpose(0, 1)
494
+ new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1)
495
+ query_layer = query_layer.view(*new_x_shape)
496
+ query_layer = query_layer.permute(0, 2, 1, 3)
497
+ _B, _H, _L, __ = query_layer.size()
498
+
499
+ gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
500
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
501
+ attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
502
+
503
+ attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
504
+ k_proj_bias = self.k_proj.bias
505
+ if k_proj_bias is None:
506
+ k_proj_bias = torch.zeros_like(self.q_proj.bias)
507
+
508
+ x, attn = F.multi_head_attention_forward(
509
+ query,
510
+ key,
511
+ value,
512
+ self.embed_dim,
513
+ self.num_heads,
514
+ torch.empty([0]),
515
+ torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
516
+ self.bias_k,
517
+ self.bias_v,
518
+ self.add_zero_attn,
519
+ self.dropout_module.p,
520
+ self.out_proj.weight,
521
+ self.out_proj.bias,
522
+ self.training,
523
+ # self.training or self.dropout_module.apply_during_inference,
524
+ key_padding_mask,
525
+ need_weights,
526
+ attn_mask_rel_pos,
527
+ use_separate_proj_weight=True,
528
+ q_proj_weight=self.q_proj.weight,
529
+ k_proj_weight=self.k_proj.weight,
530
+ v_proj_weight=self.v_proj.weight,
531
+ )
532
+ return x, attn, position_bias
533
+
534
+ if incremental_state is not None:
535
+ saved_state = self._get_input_buffer(incremental_state)
536
+ if saved_state is not None and "prev_key" in saved_state:
537
+ # previous time steps are cached - no need to recompute
538
+ # key and value if they are static
539
+ if static_kv:
540
+ assert self.encoder_decoder_attention and not self.self_attention
541
+ key = value = None
542
+ else:
543
+ saved_state = None
544
+
545
+ if self.self_attention:
546
+ q = self.q_proj(query)
547
+ k = self.k_proj(query)
548
+ v = self.v_proj(query)
549
+ elif self.encoder_decoder_attention:
550
+ # encoder-decoder attention
551
+ q = self.q_proj(query)
552
+ if key is None:
553
+ assert value is None
554
+ k = v = None
555
+ else:
556
+ k = self.k_proj(key)
557
+ v = self.v_proj(key)
558
+
559
+ else:
560
+ assert key is not None and value is not None
561
+ q = self.q_proj(query)
562
+ k = self.k_proj(key)
563
+ v = self.v_proj(value)
564
+ q *= self.scaling
565
+
566
+ if self.bias_k is not None:
567
+ assert self.bias_v is not None
568
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
569
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
570
+ if attn_mask is not None:
571
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
572
+ if key_padding_mask is not None:
573
+ key_padding_mask = torch.cat(
574
+ [
575
+ key_padding_mask,
576
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
577
+ ],
578
+ dim=1,
579
+ )
580
+
581
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.q_head_dim).transpose(0, 1)
582
+ if k is not None:
583
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.k_head_dim).transpose(0, 1)
584
+ if v is not None:
585
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
586
+
587
+ if saved_state is not None:
588
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
589
+ if "prev_key" in saved_state:
590
+ _prev_key = saved_state["prev_key"]
591
+ assert _prev_key is not None
592
+ prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
593
+ if static_kv:
594
+ k = prev_key
595
+ else:
596
+ assert k is not None
597
+ k = torch.cat([prev_key, k], dim=1)
598
+ src_len = k.size(1)
599
+ if "prev_value" in saved_state:
600
+ _prev_value = saved_state["prev_value"]
601
+ assert _prev_value is not None
602
+ prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
603
+ if static_kv:
604
+ v = prev_value
605
+ else:
606
+ assert v is not None
607
+ v = torch.cat([prev_value, v], dim=1)
608
+ prev_key_padding_mask: Optional[Tensor] = None
609
+ if "prev_key_padding_mask" in saved_state:
610
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
611
+ assert k is not None and v is not None
612
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
613
+ key_padding_mask=key_padding_mask,
614
+ prev_key_padding_mask=prev_key_padding_mask,
615
+ batch_size=bsz,
616
+ src_len=k.size(1),
617
+ static_kv=static_kv,
618
+ )
619
+
620
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
621
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
622
+ saved_state["prev_key_padding_mask"] = key_padding_mask
623
+ # In this branch incremental_state is never None
624
+ assert incremental_state is not None
625
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
626
+ assert k is not None
627
+ assert k.size(1) == src_len
628
+
629
+ # This is part of a workaround to get around fork/join parallelism
630
+ # not supporting Optional types.
631
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
632
+ key_padding_mask = None
633
+
634
+ if key_padding_mask is not None:
635
+ assert key_padding_mask.size(0) == bsz
636
+ assert key_padding_mask.size(1) == src_len
637
+
638
+ if self.add_zero_attn:
639
+ assert v is not None
640
+ src_len += 1
641
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
642
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
643
+ if attn_mask is not None:
644
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
645
+ if key_padding_mask is not None:
646
+ key_padding_mask = torch.cat(
647
+ [
648
+ key_padding_mask,
649
+ torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
650
+ ],
651
+ dim=1,
652
+ )
653
+
654
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
655
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
656
+
657
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
658
+
659
+ if attn_mask is not None:
660
+ attn_mask = attn_mask.unsqueeze(0)
661
+ attn_weights += attn_mask
662
+
663
+ if key_padding_mask is not None:
664
+ # don't attend to padding symbols
665
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
666
+ if not is_tpu:
667
+ attn_weights = attn_weights.masked_fill(
668
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
669
+ float("-inf"),
670
+ )
671
+ else:
672
+ attn_weights = attn_weights.transpose(0, 2)
673
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
674
+ attn_weights = attn_weights.transpose(0, 2)
675
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
676
+
677
+ if before_softmax:
678
+ return attn_weights, v, position_bias
679
+
680
+ if position_bias is not None:
681
+ if self.gru_rel_pos == 1:
682
+ query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
683
+ _B, _H, _L, __ = query_layer.size()
684
+ gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
685
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
686
+ position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
687
+
688
+ position_bias = position_bias.view(attn_weights.size())
689
+
690
+ attn_weights = attn_weights + position_bias
691
+
692
+ attn_weights_float = F.softmax(attn_weights, dim=-1)
693
+ attn_weights = attn_weights_float.type_as(attn_weights)
694
+ attn_probs = self.dropout_module(attn_weights)
695
+
696
+ assert v is not None
697
+ attn = torch.bmm(attn_probs, v)
698
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
699
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
700
+ attn = self.out_proj(attn)
701
+ attn_weights: Optional[Tensor] = None
702
+ if need_weights:
703
+ attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
704
+ if not need_head_weights:
705
+ # average attention weights over heads
706
+ attn_weights = attn_weights.mean(dim=0)
707
+
708
+ return attn, attn_weights, position_bias
709
+
710
+ @staticmethod
711
+ def _append_prev_key_padding_mask(
712
+ key_padding_mask: Optional[Tensor],
713
+ prev_key_padding_mask: Optional[Tensor],
714
+ batch_size: int,
715
+ src_len: int,
716
+ static_kv: bool,
717
+ ) -> Optional[Tensor]:
718
+ # saved key padding masks have shape (bsz, seq_len)
719
+ if prev_key_padding_mask is not None and static_kv:
720
+ new_key_padding_mask = prev_key_padding_mask
721
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
722
+ new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1)
723
+ # During incremental decoding, as the padding token enters and
724
+ # leaves the frame, there will be a time when prev or current
725
+ # is None
726
+ elif prev_key_padding_mask is not None:
727
+ if src_len > prev_key_padding_mask.size(1):
728
+ filler = torch.zeros(
729
+ (batch_size, src_len - prev_key_padding_mask.size(1)),
730
+ device=prev_key_padding_mask.device,
731
+ )
732
+ new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
733
+ else:
734
+ new_key_padding_mask = prev_key_padding_mask.float()
735
+ elif key_padding_mask is not None:
736
+ if src_len > key_padding_mask.size(1):
737
+ filler = torch.zeros(
738
+ (batch_size, src_len - key_padding_mask.size(1)),
739
+ device=key_padding_mask.device,
740
+ )
741
+ new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
742
+ else:
743
+ new_key_padding_mask = key_padding_mask.float()
744
+ else:
745
+ new_key_padding_mask = prev_key_padding_mask
746
+ return new_key_padding_mask
747
+
748
+ def _get_input_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]) -> Dict[str, Optional[Tensor]]:
749
+ result = self.get_incremental_state(incremental_state, "attn_state")
750
+ if result is not None:
751
+ return result
752
+ empty_result: Dict[str, Optional[Tensor]] = {}
753
+ return empty_result
754
+
755
+ def _set_input_buffer(
756
+ self,
757
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
758
+ buffer: Dict[str, Optional[Tensor]],
759
+ ):
760
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
761
+
762
+ def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): # pylint: disable=W0613
763
+ return attn_weights
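A small self-attention sketch (not part of the commit) exercising the MultiheadAttention module above; it follows the Time x Batch x Channel convention documented in forward(), and the sizes are illustrative only.

import torch

attn = MultiheadAttention(embed_dim=768, num_heads=8, self_attention=True)
x = torch.randn(50, 2, 768)    # (tgt_len, bsz, embed_dim)
out, weights, pos_bias = attn(query=x, key=x, value=x, need_weights=False)
print(out.shape)               # torch.Size([50, 2, 768])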
knn_tts/ssl/__init__.py ADDED
File without changes
knn_tts/ssl/models.py ADDED
@@ -0,0 +1,49 @@
1
+ # SPDX-FileCopyrightText: 2024 Idiap Research Institute
2
+ # SPDX-FileContributor: Karl El Hajal
3
+ #
4
+ # SPDX-License-Identifier: MIT
5
+
6
+ import torch
7
+ import torchaudio
8
+
9
+ from knn_tts.ssl.WavLM.WavLM import WavLMConfig, WavLMModel
10
+
11
+
12
+ class WavLM:
13
+ def __init__(self, device=None):
14
+ """
15
+ Load the WavLM-Large model (see: https://github.com/microsoft/unilm/tree/master/wavlm).
16
+ Torch hub checkpoint from https://github.com/bshall/knn-vc
17
+ """
18
+ if device is None:
19
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+ else:
21
+ self.device = torch.device(device)
22
+
23
+ checkpoint_url = "https://github.com/bshall/knn-vc/releases/download/v0.1/WavLM-Large.pt"
24
+ checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=self.device, progress=True)
25
+
26
+ cfg = WavLMConfig(checkpoint["cfg"])
27
+ self.model = WavLMModel(cfg)
28
+ self.model.load_state_dict(checkpoint["model"])
29
+ self.model = self.model.to(self.device)
30
+ self.model.eval()
31
+
32
+ self.model.sample_rate = 16000
33
+ self.SPEAKER_INFORMATION_LAYER = 6
34
+ num_parameters = sum(p.numel() for p in self.model.parameters())
35
+ print(f"WavLM-Large loaded with {num_parameters:,d} parameters.")
36
+
37
+ @torch.no_grad()
38
+ def extract_framewise_features(self, wav_path, output_layer=None):
39
+ audio, orig_sr = torchaudio.load(wav_path)
40
+ audio = audio.to(self.device)
41
+ if orig_sr != self.model.sample_rate:
42
+ audio = torchaudio.functional.resample(audio, orig_freq=orig_sr, new_freq=self.model.sample_rate)
43
+
44
+ if output_layer is None:
45
+ output_layer = self.SPEAKER_INFORMATION_LAYER
46
+
47
+ embeddings = self.model.extract_features(audio, output_layer=output_layer, ret_layer_results=False)[0]
48
+ embeddings = embeddings.squeeze(0)
49
+ return embeddings # [frame_len, 1024]
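A short usage sketch (not part of the commit) of the WavLM wrapper above; the audio path is a placeholder, and any input sample rate works since the wrapper resamples to 16 kHz internally.

from knn_tts.ssl.models import WavLM

ssl_model = WavLM(device="cpu")                                    # downloads the WavLM-Large checkpoint
feats = ssl_model.extract_framewise_features("path/to/audio.wav")  # defaults to layer 6
print(feats.shape)                                                 # (n_frames, 1024)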
knn_tts/synthesizer.py ADDED
@@ -0,0 +1,69 @@
1
+ # SPDX-FileCopyrightText: 2024 Idiap Research Institute
2
+ # SPDX-FileContributor: Karl El Hajal
3
+ #
4
+ # SPDX-License-Identifier: MIT
5
+
6
+ import torch
7
+ import torchaudio
8
+
9
+ from knn_tts.text_cleaners import clean_input_text
10
+ from knn_tts.tts.models import GlowTTS
11
+ from knn_tts.vc.knn import knn_vc, load_target_style_feats
12
+ from knn_tts.vocoder.models import HiFiGANWavLM
13
+
14
+
15
+ class Synthesizer:
16
+ def __init__(
17
+ self,
18
+ tts_model_base_path,
19
+ tts_model_checkpoint,
20
+ vocoder_checkpoint_path,
21
+ model_name="glowtts",
22
+ ):
23
+ self.model_name = model_name
24
+ self.model = GlowTTS(tts_model_base_path, tts_model_checkpoint)
25
+ self.vocoder = HiFiGANWavLM(checkpoint_path=vocoder_checkpoint_path, device=self.model.device)
26
+ self.target_style_feats_path = None
27
+ self.target_style_feats = None
28
+
29
+ def __call__(
30
+ self,
31
+ text_input,
32
+ target_style_feats_path,
33
+ knnvc_topk=4,
34
+ weighted_average=False,
35
+ interpolation_rate=1.0,
36
+ save_path=None,
37
+ timesteps=10,
38
+ max_target_num_files=1000,
39
+ ):
40
+ with torch.no_grad():
41
+ # Text-to-SSL
42
+ text_input = clean_input_text(text_input)
43
+ tts_feats = self.model(text_input, timesteps) # timesteps are used for GradTTS only
44
+
45
+ # kNN-VC
46
+ if interpolation_rate != 0.0:
47
+
48
+ if target_style_feats_path != self.target_style_feats_path:
49
+ self.target_style_feats_path = target_style_feats_path
50
+ self.target_style_feats = load_target_style_feats(target_style_feats_path, max_target_num_files)
51
+
52
+ selected_feats = knn_vc(
53
+ tts_feats,
54
+ self.target_style_feats,
55
+ topk=knnvc_topk,
56
+ weighted_average=weighted_average,
57
+ device=self.model.device,
58
+ )
59
+
60
+ converted_feats = interpolation_rate * selected_feats + (1.0 - interpolation_rate) * tts_feats
61
+ else:
62
+ converted_feats = tts_feats
63
+
64
+ # Vocoder
65
+ wav = self.vocoder(converted_feats.unsqueeze(0)).unsqueeze(0)
66
+
67
+ if save_path is not None:
68
+ torchaudio.save(save_path, wav, 16000)
69
+ return wav
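A usage sketch (not part of the commit) of the Synthesizer pipeline above (Text-to-SSL, optional kNN-VC conversion, then vocoding). All paths are placeholders for user-provided checkpoints and a directory of pre-extracted target-speaker WavLM features.

from knn_tts.synthesizer import Synthesizer

synth = Synthesizer(
    tts_model_base_path="checkpoints/glowtts",          # placeholder path
    tts_model_checkpoint="best_model.pth",              # placeholder checkpoint name
    vocoder_checkpoint_path="checkpoints/hifigan.pt",   # placeholder path
)

wav = synth(
    "Hello world, this is a test.",
    target_style_feats_path="feats/target_speaker/",    # placeholder directory of WavLM features
    knnvc_topk=4,
    interpolation_rate=1.0,                             # 1.0 = full conversion, 0.0 = no conversion
    save_path="output.wav",
)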
knn_tts/text_cleaners.py ADDED
@@ -0,0 +1,66 @@
1
+ # SPDX-FileCopyrightText: 2024 Idiap Research Institute
2
+ # SPDX-FileContributor: Karl El Hajal
3
+ #
4
+ # SPDX-License-Identifier: MIT
5
+
6
+ import re
7
+
8
+ from num2words import num2words
9
+
10
+ PUNCTUATION_TO_REPLACE_WITH_COMMA = ["(", ")", ":"]
11
+ PUNCTUATION = ".!?,"
12
+ SPACE = " "
13
+
14
+
15
+ def replace_some_punctuation_with_comma(text):
16
+ for punct in PUNCTUATION_TO_REPLACE_WITH_COMMA:
17
+ text = text.replace(punct, " , ")
18
+ return text
19
+
20
+
21
+ def split_string_on_punctuation(s):
22
+ substrings = []
23
+ current_substring = ""
24
+
25
+ for char in s:
26
+ current_substring += char
27
+ if char in PUNCTUATION:
28
+ substrings.append(current_substring.strip())
29
+ current_substring = ""
30
+
31
+ if current_substring:
32
+ substrings.append(current_substring)
33
+
34
+ return substrings
35
+
36
+
37
+ def remove_punctuation_only_substrings(substrings):
38
+ new_substrings = []
39
+ for substring in substrings:
40
+ if not all(c in PUNCTUATION + SPACE for c in substring):
41
+ new_substrings.append(substring)
42
+ return new_substrings
43
+
44
+
45
+ def clean_punctuation(text):
46
+ text = replace_some_punctuation_with_comma(text)
47
+ substrings = split_string_on_punctuation(text)
48
+ substrings = remove_punctuation_only_substrings(substrings)
49
+ text = SPACE.join(substrings)
50
+ return text
51
+
52
+
53
+ def clean_input_text(text):
54
+ text = text.lower().strip()
55
+ text = clean_punctuation(text)
56
+
57
+ def replace_numbers(match):
58
+ return num2words(int(match.group()))
59
+
60
+ text = re.sub(r"\d+", replace_numbers, text)
61
+ text = " ".join(text.split()) # Remove extra spaces
62
+
63
+ if text and text[-1] not in PUNCTUATION:
64
+ text += "."
65
+
66
+ return text
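The cleaner above lowercases the text, rewrites some punctuation as commas, drops punctuation-only fragments, expands digits with num2words, and ensures final punctuation. A rough before/after sketch (not part of the commit; exact spacing and trailing punctuation follow the rules above and may differ slightly):

from knn_tts.text_cleaners import clean_input_text

print(clean_input_text("I have 3 cats"))
# expected: "i have three cats."

print(clean_input_text("The meeting is at 3 PM (room 12)."))
# roughly: "the meeting is at three pm , room twelve ,"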
knn_tts/tts/GlowTTS/__init__.py ADDED
File without changes
knn_tts/tts/GlowTTS/glow_tts_ssl.py ADDED
@@ -0,0 +1,623 @@
1
+ # SPDX-FileCopyrightText: Coqui AI
2
+ # SPDX-FileContributor: Karl El Hajal
3
+ #
4
+ # SPDX-License-Identifier: MPL-2.0
5
+
6
+ import collections
7
+ import os
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.distributed as dist
12
+ from torch import nn
13
+ from torch.cuda.amp.autocast_mode import autocast
14
+ from torch.utils.data import DataLoader
15
+ from TTS.tts.configs.glow_tts_config import GlowTTSConfig
16
+ from TTS.tts.datasets.dataset import TTSDataset, noise_augment_audio
17
+ from TTS.tts.models.glow_tts import GlowTTS
18
+ from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
19
+ from TTS.tts.utils.synthesis import (
20
+ embedding_to_torch,
21
+ id_to_torch,
22
+ numpy_to_torch,
23
+ run_model_torch,
24
+ trim_silence,
25
+ )
26
+ from TTS.tts.utils.text.tokenizer import TTSTokenizer
27
+ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
28
+ from TTS.utils.audio import AudioProcessor
29
+
30
+ from knn_tts.vocoder.models import HiFiGANWavLM
31
+
32
+
33
+ def load_glow_tts_ssl(model_base_path, checkpoint, hifigan_checkpoint_path, device):
34
+ checkpoint_path = os.path.join(model_base_path, checkpoint)
35
+ config_path = os.path.join(model_base_path, "config.json")
36
+
37
+ config = GlowTTSConfig()
38
+ config.load_json(config_path)
39
+ ap = AudioProcessor.init_from_config(config)
40
+ tokenizer, _ = TTSTokenizer.init_from_config(config)
41
+
42
+ model = GlowTTSSSL(
43
+ config=config,
44
+ ap=ap,
45
+ tokenizer=tokenizer,
46
+ speaker_manager=None,
47
+ hifigan_checkpoint_path=hifigan_checkpoint_path,
48
+ )
49
+ model.load_checkpoint(config=config, checkpoint_path=checkpoint_path)
50
+ model.to(device)
51
+ model.eval()
52
+ return model
53
+
54
+
55
+ def ssl_synthesis(
56
+ model,
57
+ text,
58
+ vocoder,
59
+ return_ssl_feats=False,
60
+ speaker_id=None,
61
+ do_trim_silence=False,
62
+ d_vector=None,
63
+ language_id=None,
64
+ ):
65
+ # device
66
+ device = next(model.parameters()).device
67
+
68
+ language_name = None
69
+ if language_id is not None:
70
+ language = [k for k, v in model.language_manager.name_to_id.items() if v == language_id]
71
+ assert len(language) == 1, "language_id must be a valid language"
72
+ language_name = language[0]
73
+
74
+ # convert text to sequence of token IDs
75
+ text_inputs = np.asarray(
76
+ model.tokenizer.text_to_ids(text, language=language_name),
77
+ dtype=np.int32,
78
+ )
79
+
80
+ # pass tensors to backend
81
+ if speaker_id is not None:
82
+ speaker_id = id_to_torch(speaker_id, device=device)
83
+
84
+ if d_vector is not None:
85
+ d_vector = embedding_to_torch(d_vector, device=device)
86
+
87
+ if language_id is not None:
88
+ language_id = id_to_torch(language_id, device=device)
89
+
90
+ text_inputs = numpy_to_torch(text_inputs, torch.long, device=device)
91
+ text_inputs = text_inputs.unsqueeze(0)
92
+ # synthesize voice
93
+ outputs = run_model_torch(
94
+ model,
95
+ text_inputs,
96
+ speaker_id,
97
+ style_mel=None,
98
+ style_text=None,
99
+ d_vector=d_vector,
100
+ language_id=language_id,
101
+ )
102
+ model_outputs = outputs["model_outputs"]
103
+
104
+ if return_ssl_feats:
105
+ ssl_feats = model_outputs.squeeze()
106
+ return ssl_feats
107
+
108
+ model_outputs = model_outputs[0].data.cpu().numpy()
109
+ alignments = outputs["alignments"]
110
+
111
+ ssl_feats = model_outputs.squeeze()
112
+ wav = vocoder(torch.from_numpy(ssl_feats).unsqueeze(0)).unsqueeze(0).cpu().numpy()
113
+
114
+ if do_trim_silence:
115
+ wav = trim_silence(wav, model.ap)
116
+
117
+ return_dict = {
118
+ "wav": wav,
119
+ "alignments": alignments,
120
+ "text_inputs": text_inputs,
121
+ "outputs": outputs,
122
+ }
123
+ return return_dict
124
+
125
+
126
+ class GlowTTSSSLDataset(TTSDataset):
127
+ def __init__(self, *args, **kwargs):
128
+ super().__init__(*args, **kwargs)
129
+ self.pad_id = self.tokenizer.characters.pad_id
130
+
131
+ def load_data(self, idx):
132
+ item = self.samples[idx]
133
+
134
+ raw_text = item["text"]
135
+
136
+ wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32)
137
+
138
+ # apply noise for augmentation
139
+ if self.use_noise_augment:
140
+ wav = noise_augment_audio(wav)
141
+
142
+ # get token ids
143
+ token_ids = self.get_token_ids(idx, item["text"])
144
+
145
+ # get pre-computed attention maps
146
+ attn = None
147
+ if "alignment_file" in item:
148
+ attn = self.get_attn_mask(item["alignment_file"])
149
+
150
+ # after phonemization the text length may change
151
+ # this is a shameful 🤭 hack to prevent longer phonemes
152
+ # TODO: find a better fix # pylint: disable=fixme
153
+ if len(token_ids) > self.max_text_len or len(wav) < self.min_audio_len:
154
+ self.rescue_item_idx += 1
155
+ return self.load_data(self.rescue_item_idx)
156
+
157
+ # get f0 values
158
+ f0 = None
159
+ if self.compute_f0:
160
+ f0 = self.get_f0(idx)["f0"]
161
+ energy = None
162
+ if self.compute_energy:
163
+ energy = self.get_energy(idx)["energy"]
164
+
165
+ ssl_feats = torch.load(item["ssl_feats_file"])
166
+
167
+ if len(ssl_feats.size()) == 2:
168
+ ssl_feats = ssl_feats.unsqueeze(0)
169
+
170
+ sample = {
171
+ "raw_text": raw_text,
172
+ "token_ids": token_ids,
173
+ "wav": wav,
174
+ "pitch": f0,
175
+ "energy": energy,
176
+ "attn": attn,
177
+ "item_idx": item["audio_file"],
178
+ "speaker_name": item["speaker_name"],
179
+ "language_name": item["language"],
180
+ "wav_file_name": os.path.basename(item["audio_file"]),
181
+ "audio_unique_name": item["audio_unique_name"],
182
+ "ssl_feats": ssl_feats,
183
+ }
184
+ return sample
185
+
186
+ # pylint: disable=too-many-branches, too-many-statements
187
+ def collate_fn(self, batch):
188
+ # Puts each data field into a tensor with outer dimension batch size
189
+ if isinstance(batch[0], collections.abc.Mapping):
190
+ B = len(batch)
191
+
192
+ token_ids_lengths = np.array([len(d["token_ids"]) for d in batch])
193
+
194
+ # sort items with text input length for RNN efficiency
195
+ batch, token_ids_lengths, ids_sorted_decreasing = self._sort_batch(batch, token_ids_lengths)
196
+
197
+ # convert list of dicts to dict of lists
198
+ batch = {k: [dic[k] for dic in batch] for k in batch[0]}
199
+
200
+ # get language ids from language names
201
+ if self.language_id_mapping is not None:
202
+ language_ids = [self.language_id_mapping[ln] for ln in batch["language_name"]]
203
+ else:
204
+ language_ids = None
205
+ # get pre-computed d-vectors
206
+ if self.d_vector_mapping is not None:
207
+ embedding_keys = list(batch["audio_unique_name"])
208
+ d_vectors = [self.d_vector_mapping[w]["embedding"] for w in embedding_keys]
209
+ else:
210
+ d_vectors = None
211
+
212
+ # get numerical speaker ids from speaker names
213
+ if self.speaker_id_mapping:
214
+ speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]]
215
+ else:
216
+ speaker_ids = None
217
+ # compute features
218
+ mel = [self.ap.melspectrogram(w).astype("float32") for w in batch["wav"]]
219
+
220
+ mel_lengths = [m.shape[1] for m in mel]
221
+
222
+ # lengths adjusted by the reduction factor
223
+ mel_lengths_adjusted = [m.shape[1] + (self.outputs_per_step - (m.shape[1] % self.outputs_per_step)) if m.shape[1] % self.outputs_per_step else m.shape[1] for m in mel]
224
+
225
+ # PAD sequences with longest instance in the batch
226
+ token_ids = prepare_data(batch["token_ids"]).astype(np.int32)
227
+
228
+ # PAD features with longest instance
229
+ mel = prepare_tensor(mel, self.outputs_per_step)
230
+
231
+ # B x D x T --> B x T x D
232
+ mel = mel.transpose(0, 2, 1)
233
+
234
+ # convert things to pytorch
235
+ token_ids_lengths = torch.LongTensor(token_ids_lengths)
236
+ token_ids = torch.LongTensor(token_ids)
237
+ mel = torch.FloatTensor(mel).contiguous()
238
+ mel_lengths = torch.LongTensor(mel_lengths)
239
+
240
+ # format ssl_feats
241
+ ssl_feats_dim = batch["ssl_feats"][0].shape[2]
242
+ ssl_feats_lens = [emb.shape[1] for emb in batch["ssl_feats"]]
243
+ ssl_feats_lens = torch.LongTensor(ssl_feats_lens)
244
+ ssl_feats_lens_max = torch.max(ssl_feats_lens)
245
+ ssl_feats_rel_lens = ssl_feats_lens / ssl_feats_lens_max
246
+ ssl_feats_padded = torch.FloatTensor(B, ssl_feats_lens_max, ssl_feats_dim)
247
+ ssl_feats_padded = ssl_feats_padded.zero_() + self.pad_id
248
+
249
+ for i, emb in enumerate(batch["ssl_feats"]):
250
+ ssl_feats_padded[i, : emb.size(1), :] = torch.FloatTensor(emb)
251
+
252
+ stop_targets = [np.array([0.0] * (ssl_len - 1) + [1.0]) for ssl_len in ssl_feats_lens]
253
+ stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)
254
+ stop_targets = torch.FloatTensor(stop_targets)
255
+
256
+ # speaker vectors
257
+ if d_vectors is not None:
258
+ d_vectors = torch.FloatTensor(d_vectors)
259
+
260
+ if speaker_ids is not None:
261
+ speaker_ids = torch.LongTensor(speaker_ids)
262
+
263
+ if language_ids is not None:
264
+ language_ids = torch.LongTensor(language_ids)
265
+
266
+ # compute linear spectrogram
267
+ linear = None
268
+ if self.compute_linear_spec:
269
+ linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
270
+ linear = prepare_tensor(linear, self.outputs_per_step)
271
+ linear = linear.transpose(0, 2, 1)
272
+ assert mel.shape[1] == linear.shape[1]
273
+ linear = torch.FloatTensor(linear).contiguous()
274
+
275
+ # format waveforms
276
+ wav_padded = None
277
+ if self.return_wav:
278
+ wav_lengths = [w.shape[0] for w in batch["wav"]]
279
+ # max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
280
+ ssl_feats_hop_length = self.ap.hop_length
281
+ max_wav_len = max(ssl_feats_lens) * ssl_feats_hop_length
282
+ wav_lengths = torch.LongTensor(wav_lengths)
283
+ wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len)
284
+ for i, w in enumerate(batch["wav"]):
285
+ mel_length = mel_lengths_adjusted[i]
286
+ w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
287
+ w = w[: mel_length * self.ap.hop_length]
288
+ wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
289
+ wav_padded.transpose_(1, 2)
290
+
291
+ # format F0
292
+ if self.compute_f0:
293
+ pitch = prepare_data(batch["pitch"])
294
+ assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}"
295
+ pitch = torch.FloatTensor(pitch)[:, None, :].contiguous()  # B x 1 x T
296
+ else:
297
+ pitch = None
298
+ # format energy
299
+ if self.compute_energy:
300
+ energy = prepare_data(batch["energy"])
301
+ assert mel.shape[1] == energy.shape[1], f"[!] {mel.shape} vs {energy.shape}"
302
+ energy = torch.FloatTensor(energy)[:, None, :].contiguous()  # B x 1 x T
303
+ else:
304
+ energy = None
305
+ # format attention masks
306
+ attns = None
307
+ if batch["attn"][0] is not None:
308
+ attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing]
309
+ for idx, attn in enumerate(attns):
310
+ pad2 = mel.shape[1] - attn.shape[1]
311
+ pad1 = token_ids.shape[1] - attn.shape[0]
312
+ assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}"
313
+ attn = np.pad(attn, [[0, pad1], [0, pad2]])
314
+ attns[idx] = attn
315
+ attns = prepare_tensor(attns, self.outputs_per_step)
316
+ attns = torch.FloatTensor(attns).unsqueeze(1)
317
+
318
+ return {
319
+ "token_id": token_ids,
320
+ "token_id_lengths": token_ids_lengths,
321
+ "speaker_names": batch["speaker_name"],
322
+ "linear": linear,
323
+ "mel": mel,
324
+ "mel_lengths": mel_lengths,
325
+ "stop_targets": stop_targets,
326
+ "item_idxs": batch["item_idx"],
327
+ "d_vectors": d_vectors,
328
+ "speaker_ids": speaker_ids,
329
+ "attns": attns,
330
+ "waveform": wav_padded,
331
+ "raw_text": batch["raw_text"],
332
+ "pitch": pitch,
333
+ "energy": energy,
334
+ "language_ids": language_ids,
335
+ "audio_unique_names": batch["audio_unique_name"],
336
+ "ssl_feats": ssl_feats_padded,
337
+ "ssl_feats_lens": ssl_feats_lens,
338
+ "ssl_feats_rel_lens": ssl_feats_rel_lens,
339
+ }
340
+ return None
341
+
342
+
343
+ class GlowTTSSSL(GlowTTS):
344
+ def __init__(
345
+ self,
346
+ config: GlowTTSConfig,
347
+ ap=None,
348
+ tokenizer=None,
349
+ speaker_manager=None,
350
+ hifigan_checkpoint_path: str = "checkpoints/hifigan-wavlm/prematch_g_02500000.pt",
351
+ ):
352
+ super().__init__(config, ap, tokenizer, speaker_manager)
353
+ if hifigan_checkpoint_path is None:
354
+ self.vocoder = None
355
+ else:
356
+ self.vocoder = HiFiGANWavLM(
357
+ checkpoint_path=hifigan_checkpoint_path,
358
+ device=str(next(self.parameters()).device),
359
+ )
360
+
361
+ def train_step(self, batch: dict, criterion: nn.Module):
362
+ text_input = batch["text_input"]
363
+ text_lengths = batch["text_lengths"]
364
+ ssl_feats_input = batch["ssl_feats"]
365
+ ssl_feats_lengths = batch["ssl_feats_lens"]
366
+ d_vectors = batch["d_vectors"]
367
+ speaker_ids = batch["speaker_ids"]
368
+
369
+ if self.run_data_dep_init and self.training:
370
+ # compute data-dependent initialization of activation norm layers
371
+ self.unlock_act_norm_layers()
372
+ with torch.no_grad():
373
+ _ = self.forward(
374
+ text_input,
375
+ text_lengths,
376
+ ssl_feats_input,
377
+ ssl_feats_lengths,
378
+ aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
379
+ )
380
+ outputs = None
381
+ loss_dict = None
382
+ self.lock_act_norm_layers()
383
+ else:
384
+ # normal training step
385
+ outputs = self.forward(
386
+ text_input,
387
+ text_lengths,
388
+ ssl_feats_input,
389
+ ssl_feats_lengths,
390
+ aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
391
+ )
392
+
393
+ with autocast(enabled=False): # avoid mixed_precision in criterion
394
+ loss_dict = criterion(
395
+ outputs["z"].float(),
396
+ outputs["y_mean"].float(),
397
+ outputs["y_log_scale"].float(),
398
+ outputs["logdet"].float(),
399
+ ssl_feats_lengths,
400
+ outputs["durations_log"].float(),
401
+ outputs["total_durations_log"].float(),
402
+ text_lengths,
403
+ )
404
+
405
+ return outputs, loss_dict
406
+
407
+ @torch.no_grad()
408
+ def _create_logs(self, batch, outputs, ap):
409
+ alignments = outputs["alignments"]
410
+ text_input = batch["text_input"][:1] if batch["text_input"] is not None else None
411
+ text_lengths = batch["text_lengths"]
412
+ mel_input = batch["mel_input"]
413
+ d_vectors = batch["d_vectors"][:1] if batch["d_vectors"] is not None else None
414
+ speaker_ids = batch["speaker_ids"][:1] if batch["speaker_ids"] is not None else None
415
+
416
+ gt_spec = mel_input[0].data.cpu().numpy()
417
+ align_img = alignments[0].data.cpu().numpy()
418
+
419
+ # model runs reverse flow to predict spectrograms
420
+ with torch.no_grad():
421
+ pred_outputs = self.inference(
422
+ text_input,
423
+ aux_input={
424
+ "x_lengths": text_lengths[:1],
425
+ "d_vectors": d_vectors,
426
+ "speaker_ids": speaker_ids,
427
+ },
428
+ )
429
+
430
+ pred_ssl_feats = pred_outputs["model_outputs"]
431
+
432
+ pred_audio = self.vocoder(pred_ssl_feats).unsqueeze(0).cpu().detach().numpy()
433
+
434
+ pred_spec = self.ap.melspectrogram(pred_audio).astype("float32")
435
+ pred_spec = pred_spec.squeeze(axis=1)
436
+ pred_spec = pred_spec.T
437
+
438
+ figures = {
439
+ "prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
440
+ "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
441
+ "alignment": plot_alignment(align_img, output_fig=False),
442
+ }
443
+
444
+ return figures, {"audio": pred_audio}
445
+
446
+ def get_data_loader(
447
+ self,
448
+ config,
449
+ assets,
450
+ is_eval,
451
+ samples,
452
+ verbose: bool,
453
+ num_gpus: int,
454
+ rank: int = None,
455
+ ) -> "DataLoader":
456
+ if is_eval and not config.run_eval:
457
+ loader = None
458
+ else:
459
+ # setup multi-speaker attributes
460
+ if self.speaker_manager is not None:
461
+ if hasattr(config, "model_args"):
462
+ speaker_id_mapping = self.speaker_manager.name_to_id if config.model_args.use_speaker_embedding else None
463
+ d_vector_mapping = self.speaker_manager.embeddings if config.model_args.use_d_vector_file else None
464
+ config.use_d_vector_file = config.model_args.use_d_vector_file
465
+ else:
466
+ speaker_id_mapping = self.speaker_manager.name_to_id if config.use_speaker_embedding else None
467
+ d_vector_mapping = self.speaker_manager.embeddings if config.use_d_vector_file else None
468
+ else:
469
+ speaker_id_mapping = None
470
+ d_vector_mapping = None
471
+
472
+ # setup multi-lingual attributes
473
+ if self.language_manager is not None:
474
+ language_id_mapping = self.language_manager.name_to_id if self.args.use_language_embedding else None
475
+ else:
476
+ language_id_mapping = None
477
+
478
+ # init dataloader
479
+ dataset = GlowTTSSSLDataset(
480
+ outputs_per_step=config.r if "r" in config else 1,
481
+ compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
482
+ compute_f0=config.get("compute_f0", False),
483
+ f0_cache_path=config.get("f0_cache_path", None),
484
+ compute_energy=config.get("compute_energy", False),
485
+ energy_cache_path=config.get("energy_cache_path", None),
486
+ samples=samples,
487
+ ap=self.ap,
488
+ return_wav=config.return_wav if "return_wav" in config else False,
489
+ batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
490
+ min_text_len=config.min_text_len,
491
+ max_text_len=config.max_text_len,
492
+ min_audio_len=config.min_audio_len,
493
+ max_audio_len=config.max_audio_len,
494
+ phoneme_cache_path=config.phoneme_cache_path,
495
+ precompute_num_workers=config.precompute_num_workers,
496
+ use_noise_augment=False if is_eval else config.use_noise_augment,
497
+ speaker_id_mapping=speaker_id_mapping,
498
+ d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
499
+ tokenizer=self.tokenizer,
500
+ start_by_longest=config.start_by_longest,
501
+ language_id_mapping=language_id_mapping,
502
+ )
503
+
504
+ # wait all the DDP process to be ready
505
+ if num_gpus > 1:
506
+ dist.barrier()
507
+
508
+ # sort input sequences from short to long
509
+ dataset.preprocess_samples()
510
+
511
+ # get samplers
512
+ sampler = self.get_sampler(config, dataset, num_gpus)
513
+
514
+ loader = DataLoader(
515
+ dataset,
516
+ batch_size=config.eval_batch_size if is_eval else config.batch_size,
517
+ shuffle=config.shuffle if sampler is None else False, # if there is no other sampler
518
+ collate_fn=dataset.collate_fn,
519
+ drop_last=config.drop_last, # setting this False might cause issues in AMP training.
520
+ sampler=sampler,
521
+ num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
522
+ pin_memory=False,
523
+ )
524
+ return loader
525
+
526
+ def format_batch(self, batch):
527
+ # setup input batch
528
+ text_input = batch["token_id"]
529
+ text_lengths = batch["token_id_lengths"]
530
+ speaker_names = batch["speaker_names"]
531
+ linear_input = batch["linear"]
532
+ mel_input = batch["mel"]
533
+ mel_lengths = batch["mel_lengths"]
534
+ stop_targets = batch["stop_targets"]
535
+ item_idx = batch["item_idxs"]
536
+ d_vectors = batch["d_vectors"]
537
+ speaker_ids = batch["speaker_ids"]
538
+ attn_mask = batch["attns"]
539
+ waveform = batch["waveform"]
540
+ pitch = batch["pitch"]
541
+ energy = batch["energy"]
542
+ language_ids = batch["language_ids"]
543
+ max_text_length = torch.max(text_lengths.float())
544
+ max_spec_length = torch.max(mel_lengths.float())
545
+
546
+ # compute durations from attention masks
547
+ durations = None
548
+ if attn_mask is not None:
549
+ durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
550
+ for idx, am in enumerate(attn_mask):
551
+ # compute raw durations
552
+ c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1]
553
+ # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
554
+ c_idxs, counts = torch.unique(c_idxs, return_counts=True)
555
+ dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
556
+ dur[c_idxs] = counts
557
+ # smooth the durations and set any 0 duration to 1
558
+ # by cutting off from the largest duration indices.
559
+ extra_frames = dur.sum() - mel_lengths[idx]
560
+ largest_idxs = torch.argsort(-dur)[:extra_frames]
561
+ dur[largest_idxs] -= 1
562
+ assert dur.sum() == mel_lengths[idx], f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
563
+ durations[idx, : text_lengths[idx]] = dur
564
+
565
+ # set stop targets wrt reduction factor
566
+ stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1)
567
+ stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
568
+ stop_target_lengths = torch.divide(mel_lengths, self.config.r).ceil_()
569
+
570
+ return {
571
+ "text_input": text_input,
572
+ "text_lengths": text_lengths,
573
+ "speaker_names": speaker_names,
574
+ "mel_input": mel_input,
575
+ "mel_lengths": mel_lengths,
576
+ "linear_input": linear_input,
577
+ "stop_targets": stop_targets,
578
+ "stop_target_lengths": stop_target_lengths,
579
+ "attn_mask": attn_mask,
580
+ "durations": durations,
581
+ "speaker_ids": speaker_ids,
582
+ "d_vectors": d_vectors,
583
+ "max_text_length": float(max_text_length),
584
+ "max_spec_length": float(max_spec_length),
585
+ "item_idx": item_idx,
586
+ "waveform": waveform,
587
+ "pitch": pitch,
588
+ "energy": energy,
589
+ "language_ids": language_ids,
590
+ "audio_unique_names": batch["audio_unique_names"],
591
+ "ssl_feats": batch["ssl_feats"],
592
+ "ssl_feats_lens": batch["ssl_feats_lens"],
593
+ "ssl_feats_rel_lens": batch["ssl_feats_rel_lens"],
594
+ }
595
+
596
+ @torch.no_grad()
597
+ def test_run(self, assets):
598
+ print(" | > Synthesizing test sentences.")
599
+ test_audios = {}
600
+ test_figures = {}
601
+ test_sentences = self.config.test_sentences
602
+ aux_inputs = self._get_test_aux_input()
603
+ if len(test_sentences) == 0:
604
+ print(" | [!] No test sentences provided.")
605
+ else:
606
+ for idx, sen in enumerate(test_sentences):
607
+ outputs = ssl_synthesis(
608
+ self,
609
+ sen,
610
+ self.vocoder,
611
+ speaker_id=aux_inputs["speaker_id"],
612
+ d_vector=aux_inputs["d_vector"],
613
+ do_trim_silence=False,
614
+ )
615
+
616
+ test_audios[f"{idx}-audio"] = outputs["wav"]
617
+ test_figures[f"{idx}-prediction"] = plot_spectrogram(outputs["outputs"]["model_outputs"], self.ap, output_fig=False)
618
+ test_figures[f"{idx}-alignment"] = plot_alignment(outputs["alignments"], output_fig=False)
619
+ return {"figures": test_figures, "audios": test_audios}
620
+
621
+ def test_log(self, outputs, logger, assets, steps) -> None: # pylint:disable=W0613
622
+ logger.test_audios(steps, outputs["audios"], self.ap.sample_rate)
623
+ logger.test_figures(steps, outputs["figures"])
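
As a rough illustration of the SSL-feature padding done in GlowTTSSSLDataset.collate_fn above, here is a minimal, self-contained sketch; the tensor sizes are made up for illustration and are not values from the code:

import torch

# two items with different numbers of frames, each stored as (1, T_i, D)
batch_feats = [torch.randn(1, 40, 1024), torch.randn(1, 55, 1024)]
lens = torch.LongTensor([f.shape[1] for f in batch_feats])
max_len = int(lens.max())
padded = torch.zeros(len(batch_feats), max_len, batch_feats[0].shape[2])
for i, f in enumerate(batch_feats):
    padded[i, : f.shape[1], :] = f  # copy the real frames, leave the tail as padding
rel_lens = lens / max_len           # relative lengths, as returned in "ssl_feats_rel_lens"
print(padded.shape, rel_lens)       # torch.Size([2, 55, 1024]) tensor([0.7273, 1.0000])
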
knn_tts/tts/__init__.py ADDED
File without changes
knn_tts/tts/models.py ADDED
@@ -0,0 +1,33 @@

1
+ # SPDX-FileCopyrightText: 2024 Idiap Research Institute
2
+ # SPDX-FileContributor: Karl El Hajal
3
+ #
4
+ # SPDX-License-Identifier: MIT
5
+
6
+ import torch
7
+
8
+ from knn_tts.tts.GlowTTS.glow_tts_ssl import load_glow_tts_ssl, ssl_synthesis
9
+
10
+
11
+ class GlowTTS:
12
+ def __init__(
13
+ self,
14
+ tts_model_base_path,
15
+ tts_model_checkpoint,
16
+ vocoder_checkpoint_path=None,
17
+ device=None,
18
+ ):
19
+ if device is None:
20
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+ else:
22
+ self.device = torch.device(device)
23
+ self.model = load_glow_tts_ssl(
24
+ tts_model_base_path,
25
+ tts_model_checkpoint,
26
+ vocoder_checkpoint_path,
27
+ self.device,
28
+ )
29
+
30
+ def __call__(self, text, *args, **kwargs):
31
+ with torch.no_grad():
32
+ ssl_feats_out = ssl_synthesis(self.model, text, None, return_ssl_feats=True)
33
+ return ssl_feats_out
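
A minimal usage sketch for the wrapper above, assuming a trained checkpoint is available locally (the paths are placeholders, not values shipped with the repo):

from knn_tts.tts.models import GlowTTS

tts = GlowTTS("checkpoints/glow_tts_ssl", "best_model.pth")   # hypothetical paths
ssl_feats = tts("The quick brown fox jumps over the lazy dog.")
# ssl_feats holds the predicted self-supervised features, ready for kNN matching
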
knn_tts/utils.py ADDED
@@ -0,0 +1,36 @@
1
+ # SPDX-FileCopyrightText: 2024 Idiap Research Institute
2
+ # SPDX-FileContributor: Karl El Hajal
3
+ #
4
+ # SPDX-License-Identifier: MIT
5
+
6
+ import os
7
+
8
+ import requests
9
+
10
+
11
+ def find_wav_paths(root_dir):
12
+ wav_paths = []
13
+ for subdir, _, files in os.walk(root_dir):
14
+ for file in files:
15
+ if file.endswith(".wav"):
16
+ wav_paths.append(os.path.join(subdir, file))
17
+ return wav_paths
18
+
19
+ def get_vocoder_checkpoint_path(checkpoints_dir):
20
+ os.makedirs(checkpoints_dir, exist_ok=True) # Ensure directory exists
21
+
22
+ checkpoint_path = os.path.join(checkpoints_dir, "prematch_g_02500000.pt")
23
+ url = "https://github.com/bshall/knn-vc/releases/download/v0.1/prematch_g_02500000.pt"
24
+
25
+ if not os.path.exists(checkpoint_path):
26
+ print(f"Downloading checkpoint to {checkpoint_path}...")
27
+ response = requests.get(url, stream=True)
28
+ if response.status_code == 200:
29
+ with open(checkpoint_path, "wb") as f:
30
+ for chunk in response.iter_content(chunk_size=8192):
31
+ f.write(chunk)
32
+ print("Download complete.")
33
+ else:
34
+ raise Exception(f"Failed to download checkpoint: {response.status_code}")
35
+
36
+ return checkpoint_path
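
A short usage sketch for the helper above; it downloads the kNN-VC HiFiGAN checkpoint on first use and returns its local path:

from knn_tts.utils import get_vocoder_checkpoint_path

ckpt_path = get_vocoder_checkpoint_path("checkpoints/hifigan-wavlm")
print(ckpt_path)  # .../prematch_g_02500000.pt
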
knn_tts/vc/__init__.py ADDED
File without changes
knn_tts/vc/knn.py ADDED
@@ -0,0 +1,52 @@
1
+ # SPDX-FileCopyrightText: 2023 MediaLab, Department of Electrical & Electronic Engineering, Stellenbosch University
2
+ # SPDX-FileContributor: Karl El Hajal
3
+ #
4
+ # SPDX-License-Identifier: MIT
5
+
6
+ import os
7
+
8
+ import torch
9
+
10
+
11
+ def load_target_style_feats(feats_base_path, max_num_files=1000):
12
+ feats = []
13
+ for filepath in os.listdir(feats_base_path)[:max_num_files]:
14
+ if ".pt" in filepath:
15
+ filepath = os.path.join(feats_base_path, filepath)
16
+ feats.append(torch.load(filepath, weights_only=False))
17
+ feats = torch.concat(feats, dim=0).cpu()
18
+ return feats
19
+
20
+
21
+ def fast_cosine_dist(source_feats, matching_pool, device):
22
+ """Like torch.cdist, but fixed dim=-1 and for cosine distance."""
23
+ source_norms = torch.norm(source_feats, p=2, dim=-1).to(device)
24
+ matching_norms = torch.norm(matching_pool, p=2, dim=-1)
25
+ dotprod = -(torch.cdist(source_feats[None].to(device), matching_pool[None], p=2)[0] ** 2) + source_norms[:, None] ** 2 + matching_norms[None] ** 2
26
+ dotprod /= 2
27
+
28
+ dists = 1 - (dotprod / (source_norms[:, None] * matching_norms[None]))
29
+ return dists
30
+
31
+
32
+ @torch.inference_mode()
33
+ def knn_vc(source_frames, target_style_set, topk=4, weighted_average=False, device=None):
34
+ if device is None:
35
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
+ else:
37
+ device = torch.device(device)
38
+ target_style_set = target_style_set.to(device)
39
+ source_frames = source_frames.to(device)
40
+
41
+ dists = fast_cosine_dist(source_frames, target_style_set, device=device)
42
+ best = dists.topk(k=topk, largest=False, dim=-1)
43
+
44
+ if weighted_average:
45
+ weights = 1 / (best.values + 1e-8) # Adding a small value to avoid division by zero
46
+ weights /= weights.sum(dim=-1, keepdim=True) # Normalize weights
47
+ selected_frames = (target_style_set[best.indices] * weights[..., None]).sum(dim=1)
48
+ else:
49
+ selected_frames = target_style_set[best.indices].mean(dim=1)
50
+
51
+ return selected_frames
52
+
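A toy, self-contained sketch of the kNN selection step implemented by knn_vc above (the tensor sizes are made up for illustration): each source frame is replaced by the (optionally distance-weighted) mean of its k nearest target frames under cosine distance, which transfers the target voice while keeping the content carried by the source frames.

import torch
from knn_tts.vc.knn import knn_vc

source = torch.randn(100, 1024)         # frames of the synthesized utterance
target_pool = torch.randn(5000, 1024)   # matching pool from the target speaker
converted = knn_vc(source, target_pool, topk=4, weighted_average=True, device="cpu")
print(converted.shape)                  # torch.Size([100, 1024])
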
knn_tts/vocoder/__init__.py ADDED
File without changes
knn_tts/vocoder/hifigan/__init__.py ADDED
File without changes
knn_tts/vocoder/hifigan/config_v1_wavlm.json ADDED
@@ -0,0 +1,40 @@
1
+ {
2
+ "resblock": "1",
3
+ "num_gpus": 0,
4
+ "batch_size": 16,
5
+ "learning_rate": 0.0002,
6
+ "adam_b1": 0.8,
7
+ "adam_b2": 0.99,
8
+ "lr_decay": 0.999,
9
+ "seed": 1234,
10
+
11
+ "upsample_rates": [10,8,2,2],
12
+ "upsample_kernel_sizes": [20,16,4,4],
13
+ "upsample_initial_channel": 512,
14
+ "resblock_kernel_sizes": [3,7,11],
15
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16
+
17
+ "hubert_dim": 1024,
18
+ "hifi_dim": 512,
19
+
20
+ "segment_size": 7040,
21
+ "num_mels": 80,
22
+ "num_freq": 1025,
23
+ "n_fft": 1024,
24
+ "hop_size": 320,
25
+ "win_size": 1024,
26
+
27
+ "sampling_rate": 16000,
28
+
29
+ "fmin": 0,
30
+ "fmax": 8000,
31
+ "fmax_for_loss": null,
32
+
33
+ "num_workers": 4,
34
+
35
+ "dist_config": {
36
+ "dist_backend": "nccl",
37
+ "dist_url": "tcp://localhost:54321",
38
+ "world_size": 1
39
+ }
40
+ }
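
For orientation: the upsample_rates multiply to 10 × 8 × 2 × 2 = 320 samples per input frame, matching hop_size, so at sampling_rate 16000 the generator consumes features at 16000 / 320 = 50 frames per second (20 ms per frame); hubert_dim 1024 corresponds to the dimensionality of the WavLM-Large features used throughout the repo.
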
knn_tts/vocoder/hifigan/models.py ADDED
@@ -0,0 +1,407 @@
1
+ # SPDX-FileCopyrightText: 2020 Jungil Kong
2
+ #
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
9
+ from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
10
+
11
+ from knn_tts.vocoder.hifigan.utils import get_padding, init_weights
12
+
13
+ LRELU_SLOPE = 0.1
14
+
15
+
16
+ class ResBlock1(torch.nn.Module):
17
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
18
+ super().__init__()
19
+ self.h = h
20
+ self.convs1 = nn.ModuleList(
21
+ [
22
+ weight_norm(
23
+ Conv1d(
24
+ channels,
25
+ channels,
26
+ kernel_size,
27
+ 1,
28
+ dilation=dilation[0],
29
+ padding=get_padding(kernel_size, dilation[0]),
30
+ )
31
+ ),
32
+ weight_norm(
33
+ Conv1d(
34
+ channels,
35
+ channels,
36
+ kernel_size,
37
+ 1,
38
+ dilation=dilation[1],
39
+ padding=get_padding(kernel_size, dilation[1]),
40
+ )
41
+ ),
42
+ weight_norm(
43
+ Conv1d(
44
+ channels,
45
+ channels,
46
+ kernel_size,
47
+ 1,
48
+ dilation=dilation[2],
49
+ padding=get_padding(kernel_size, dilation[2]),
50
+ )
51
+ ),
52
+ ]
53
+ )
54
+ self.convs1.apply(init_weights)
55
+
56
+ self.convs2 = nn.ModuleList(
57
+ [
58
+ weight_norm(
59
+ Conv1d(
60
+ channels,
61
+ channels,
62
+ kernel_size,
63
+ 1,
64
+ dilation=1,
65
+ padding=get_padding(kernel_size, 1),
66
+ )
67
+ ),
68
+ weight_norm(
69
+ Conv1d(
70
+ channels,
71
+ channels,
72
+ kernel_size,
73
+ 1,
74
+ dilation=1,
75
+ padding=get_padding(kernel_size, 1),
76
+ )
77
+ ),
78
+ weight_norm(
79
+ Conv1d(
80
+ channels,
81
+ channels,
82
+ kernel_size,
83
+ 1,
84
+ dilation=1,
85
+ padding=get_padding(kernel_size, 1),
86
+ )
87
+ ),
88
+ ]
89
+ )
90
+ self.convs2.apply(init_weights)
91
+
92
+ def forward(self, x):
93
+ for c1, c2 in zip(self.convs1, self.convs2):
94
+ xt = F.leaky_relu(x, LRELU_SLOPE)
95
+ xt = c1(xt)
96
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
97
+ xt = c2(xt)
98
+ x = xt + x
99
+ return x
100
+
101
+ def remove_weight_norm(self):
102
+ for l in self.convs1:
103
+ remove_weight_norm(l)
104
+ for l in self.convs2:
105
+ remove_weight_norm(l)
106
+
107
+
108
+ class ResBlock2(torch.nn.Module):
109
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
110
+ super().__init__()
111
+ self.h = h
112
+ self.convs = nn.ModuleList(
113
+ [
114
+ weight_norm(
115
+ Conv1d(
116
+ channels,
117
+ channels,
118
+ kernel_size,
119
+ 1,
120
+ dilation=dilation[0],
121
+ padding=get_padding(kernel_size, dilation[0]),
122
+ )
123
+ ),
124
+ weight_norm(
125
+ Conv1d(
126
+ channels,
127
+ channels,
128
+ kernel_size,
129
+ 1,
130
+ dilation=dilation[1],
131
+ padding=get_padding(kernel_size, dilation[1]),
132
+ )
133
+ ),
134
+ ]
135
+ )
136
+ self.convs.apply(init_weights)
137
+
138
+ def forward(self, x):
139
+ for c in self.convs:
140
+ xt = F.leaky_relu(x, LRELU_SLOPE)
141
+ xt = c(xt)
142
+ x = xt + x
143
+ return x
144
+
145
+ def remove_weight_norm(self):
146
+ for l in self.convs:
147
+ remove_weight_norm(l)
148
+
149
+
150
+ class Generator(torch.nn.Module):
151
+ def __init__(self, h):
152
+ super().__init__()
153
+ self.h = h
154
+ self.lin_pre = nn.Linear(h.hubert_dim, h.hifi_dim)
155
+ self.num_kernels = len(h.resblock_kernel_sizes)
156
+ self.num_upsamples = len(h.upsample_rates)
157
+ self.conv_pre = weight_norm(Conv1d(h.hifi_dim, h.upsample_initial_channel, 7, 1, padding=3))
158
+ resblock = ResBlock1 if h.resblock == "1" else ResBlock2
159
+
160
+ self.ups = nn.ModuleList()
161
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
162
+ self.ups.append(
163
+ weight_norm(
164
+ ConvTranspose1d(
165
+ h.upsample_initial_channel // (2**i),
166
+ h.upsample_initial_channel // (2 ** (i + 1)),
167
+ k,
168
+ u,
169
+ padding=(k - u) // 2,
170
+ )
171
+ )
172
+ )
173
+
174
+ self.resblocks = nn.ModuleList()
175
+ for i in range(len(self.ups)):
176
+ ch = h.upsample_initial_channel // (2 ** (i + 1))
177
+ for _, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
178
+ self.resblocks.append(resblock(h, ch, k, d))
179
+
180
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
181
+ self.ups.apply(init_weights)
182
+ self.conv_post.apply(init_weights)
183
+
184
+ def forward(self, x):
185
+ """`x` as (bs, seq_len, dim), regular hifi assumes input of shape (bs, n_mels, seq_len)"""
186
+ x = self.lin_pre(x)
187
+ x = x.permute(0, 2, 1) # (bs, seq_len, dim) --> (bs, dim, seq_len)
188
+
189
+ x = self.conv_pre(x)
190
+ for i in range(self.num_upsamples):
191
+ x = F.leaky_relu(x, LRELU_SLOPE)
192
+ x = self.ups[i](x)
193
+ xs = None
194
+ for j in range(self.num_kernels):
195
+ if xs is None:
196
+ xs = self.resblocks[i * self.num_kernels + j](x)
197
+ else:
198
+ xs += self.resblocks[i * self.num_kernels + j](x)
199
+ x = xs / self.num_kernels
200
+ x = F.leaky_relu(x)
201
+ x = self.conv_post(x)
202
+ x = torch.tanh(x)
203
+
204
+ return x
205
+
206
+ def remove_weight_norm(self):
207
+ print("Removing weight norm...")
208
+ for l in self.ups:
209
+ remove_weight_norm(l)
210
+ for l in self.resblocks:
211
+ l.remove_weight_norm()
212
+ remove_weight_norm(self.conv_pre)
213
+ remove_weight_norm(self.conv_post)
214
+
215
+
216
+ class DiscriminatorP(torch.nn.Module):
217
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
218
+ super().__init__()
219
+ self.period = period
220
+ norm_f = weight_norm if not use_spectral_norm else spectral_norm
221
+ self.convs = nn.ModuleList(
222
+ [
223
+ norm_f(
224
+ Conv2d(
225
+ 1,
226
+ 32,
227
+ (kernel_size, 1),
228
+ (stride, 1),
229
+ padding=(get_padding(5, 1), 0),
230
+ )
231
+ ),
232
+ norm_f(
233
+ Conv2d(
234
+ 32,
235
+ 128,
236
+ (kernel_size, 1),
237
+ (stride, 1),
238
+ padding=(get_padding(5, 1), 0),
239
+ )
240
+ ),
241
+ norm_f(
242
+ Conv2d(
243
+ 128,
244
+ 512,
245
+ (kernel_size, 1),
246
+ (stride, 1),
247
+ padding=(get_padding(5, 1), 0),
248
+ )
249
+ ),
250
+ norm_f(
251
+ Conv2d(
252
+ 512,
253
+ 1024,
254
+ (kernel_size, 1),
255
+ (stride, 1),
256
+ padding=(get_padding(5, 1), 0),
257
+ )
258
+ ),
259
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
260
+ ]
261
+ )
262
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
263
+
264
+ def forward(self, x):
265
+ fmap = []
266
+
267
+ # 1d to 2d
268
+ b, c, t = x.shape
269
+ if t % self.period != 0: # pad first
270
+ n_pad = self.period - (t % self.period)
271
+ x = F.pad(x, (0, n_pad), "reflect")
272
+ t = t + n_pad
273
+ x = x.view(b, c, t // self.period, self.period)
274
+
275
+ for l in self.convs:
276
+ x = l(x)
277
+ x = F.leaky_relu(x, LRELU_SLOPE)
278
+ fmap.append(x)
279
+ x = self.conv_post(x)
280
+ fmap.append(x)
281
+ x = torch.flatten(x, 1, -1)
282
+
283
+ return x, fmap
284
+
285
+
286
+ class MultiPeriodDiscriminator(torch.nn.Module):
287
+ def __init__(self):
288
+ super().__init__()
289
+ self.discriminators = nn.ModuleList(
290
+ [
291
+ DiscriminatorP(2),
292
+ DiscriminatorP(3),
293
+ DiscriminatorP(5),
294
+ DiscriminatorP(7),
295
+ DiscriminatorP(11),
296
+ ]
297
+ )
298
+
299
+ def forward(self, y, y_hat):
300
+ y_d_rs = []
301
+ y_d_gs = []
302
+ fmap_rs = []
303
+ fmap_gs = []
304
+ for _, d in enumerate(self.discriminators):
305
+ y_d_r, fmap_r = d(y)
306
+ y_d_g, fmap_g = d(y_hat)
307
+ y_d_rs.append(y_d_r)
308
+ fmap_rs.append(fmap_r)
309
+ y_d_gs.append(y_d_g)
310
+ fmap_gs.append(fmap_g)
311
+
312
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
313
+
314
+
315
+ class DiscriminatorS(torch.nn.Module):
316
+ def __init__(self, use_spectral_norm=False):
317
+ super().__init__()
318
+ norm_f = weight_norm if not use_spectral_norm else spectral_norm
319
+ self.convs = nn.ModuleList(
320
+ [
321
+ norm_f(Conv1d(1, 128, 15, 1, padding=7)),
322
+ norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
323
+ norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
324
+ norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
325
+ norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
326
+ norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
327
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
328
+ ]
329
+ )
330
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
331
+
332
+ def forward(self, x):
333
+ fmap = []
334
+ for l in self.convs:
335
+ x = l(x)
336
+ x = F.leaky_relu(x, LRELU_SLOPE)
337
+ fmap.append(x)
338
+ x = self.conv_post(x)
339
+ fmap.append(x)
340
+ x = torch.flatten(x, 1, -1)
341
+
342
+ return x, fmap
343
+
344
+
345
+ class MultiScaleDiscriminator(torch.nn.Module):
346
+ def __init__(self):
347
+ super().__init__()
348
+ self.discriminators = nn.ModuleList(
349
+ [
350
+ DiscriminatorS(use_spectral_norm=True),
351
+ DiscriminatorS(),
352
+ DiscriminatorS(),
353
+ ]
354
+ )
355
+ self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)])
356
+
357
+ def forward(self, y, y_hat):
358
+ y_d_rs = []
359
+ y_d_gs = []
360
+ fmap_rs = []
361
+ fmap_gs = []
362
+ for i, d in enumerate(self.discriminators):
363
+ if i != 0:
364
+ y = self.meanpools[i - 1](y)
365
+ y_hat = self.meanpools[i - 1](y_hat)
366
+ y_d_r, fmap_r = d(y)
367
+ y_d_g, fmap_g = d(y_hat)
368
+ y_d_rs.append(y_d_r)
369
+ fmap_rs.append(fmap_r)
370
+ y_d_gs.append(y_d_g)
371
+ fmap_gs.append(fmap_g)
372
+
373
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
374
+
375
+
376
+ def feature_loss(fmap_r, fmap_g):
377
+ loss = 0
378
+ for dr, dg in zip(fmap_r, fmap_g):
379
+ for rl, gl in zip(dr, dg):
380
+ loss += torch.mean(torch.abs(rl - gl))
381
+
382
+ return loss * 2
383
+
384
+
385
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
386
+ loss = 0
387
+ r_losses = []
388
+ g_losses = []
389
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
390
+ r_loss = torch.mean((1 - dr) ** 2)
391
+ g_loss = torch.mean(dg**2)
392
+ loss += r_loss + g_loss
393
+ r_losses.append(r_loss.item())
394
+ g_losses.append(g_loss.item())
395
+
396
+ return loss, r_losses, g_losses
397
+
398
+
399
+ def generator_loss(disc_outputs):
400
+ loss = 0
401
+ gen_losses = []
402
+ for dg in disc_outputs:
403
+ l = torch.mean((1 - dg) ** 2)
404
+ gen_losses.append(l)
405
+ loss += l
406
+
407
+ return loss, gen_losses
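
A minimal shape check for the Generator above, using the bundled config_v1_wavlm.json (random weights and random input, for illustration only):

import json
import torch
from knn_tts.vocoder.hifigan.models import Generator
from knn_tts.vocoder.hifigan.utils import AttrDict

with open("knn_tts/vocoder/hifigan/config_v1_wavlm.json", encoding="utf-8") as f:
    h = AttrDict(json.load(f))
g = Generator(h)
feats = torch.randn(1, 50, h.hubert_dim)  # one second of 50 Hz SSL features
wav = g(feats)
print(wav.shape)  # torch.Size([1, 1, 16000]) -- 50 frames x 320-sample hop
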
knn_tts/vocoder/hifigan/utils.py ADDED
@@ -0,0 +1,19 @@
1
+ # SPDX-FileCopyrightText: 2020 Jungil Kong
2
+ #
3
+ # SPDX-License-Identifier: MIT
4
+
5
+
6
+ def init_weights(m, mean=0.0, std=0.01):
7
+ classname = m.__class__.__name__
8
+ if classname.find("Conv") != -1:
9
+ m.weight.data.normal_(mean, std)
10
+
11
+
12
+ def get_padding(kernel_size, dilation=1):
13
+ return int((kernel_size * dilation - dilation) / 2)
14
+
15
+
16
+ class AttrDict(dict):
17
+ def __init__(self, *args, **kwargs):
18
+ super().__init__(*args, **kwargs)
19
+ self.__dict__ = self
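
For reference, get_padding implements "same" padding for stride-1 dilated convolutions, e.g. get_padding(3, 5) = (3 * 5 - 5) / 2 = 5, and AttrDict simply exposes dict keys as attributes (d = AttrDict({"n_fft": 1024}); d.n_fft).
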
knn_tts/vocoder/models.py ADDED
@@ -0,0 +1,44 @@
1
+ # SPDX-FileCopyrightText: 2024 Idiap Research Institute
2
+ # SPDX-FileContributor: Karl El Hajal
3
+ #
4
+ # SPDX-License-Identifier: MIT
5
+
6
+ import json
7
+
8
+ import torch
9
+
10
+ from knn_tts.vocoder.hifigan.models import Generator as HiFiGAN
11
+ from knn_tts.vocoder.hifigan.utils import AttrDict
12
+
13
+
14
+ class HiFiGANWavLM:
15
+ def __init__(
16
+ self,
17
+ checkpoint_path,
18
+ config_path="knn_tts/vocoder/hifigan/config_v1_wavlm.json",
19
+ device=None,
20
+ ):
21
+ if device is None:
22
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ else:
24
+ self.device = torch.device(device)
25
+ with open(config_path, encoding="utf-8") as f:
26
+ data = f.read()
27
+ json_config = json.loads(data)
28
+ self.h = AttrDict(json_config)
29
+
30
+ self.generator = HiFiGAN(self.h).to(self.device)
31
+
32
+ state_dict_g = torch.load(checkpoint_path, weights_only=False, map_location=device)
33
+ self.generator.load_state_dict(state_dict_g["generator"])
34
+ self.generator.eval()
35
+ self.generator.remove_weight_norm()
36
+ num_parameters = sum(p.numel() for p in self.generator.parameters())
37
+ print(f"[HiFiGAN] Generator loaded with {num_parameters:,d} parameters.")
38
+
39
+ def __call__(self, x):
40
+ with torch.no_grad():
41
+ wav = self.generator(x.to(self.device))
42
+ wav = wav.squeeze(1)
43
+ wav = wav.cpu().squeeze()
44
+ return wav
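
A minimal usage sketch of the HiFiGANWavLM wrapper above, reusing the checkpoint helper from knn_tts.utils (the input features are random, for illustration):

import torch
from knn_tts.utils import get_vocoder_checkpoint_path
from knn_tts.vocoder.models import HiFiGANWavLM

ckpt = get_vocoder_checkpoint_path("checkpoints/hifigan-wavlm")
vocoder = HiFiGANWavLM(ckpt, device="cpu")
wav = vocoder(torch.randn(1, 50, 1024))  # -> 1D waveform tensor at 16 kHz
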
requirements.txt ADDED
@@ -0,0 +1,164 @@
1
+ absl-py==2.2.2
2
+ aiofiles==24.1.0
3
+ aiohappyeyeballs==2.6.1
4
+ aiohttp==3.11.16
5
+ aiosignal==1.3.2
6
+ annotated-types==0.7.0
7
+ anyascii==0.3.2
8
+ anyio==4.9.0
9
+ attrs==25.3.0
10
+ audioread==3.0.1
11
+ babel==2.17.0
12
+ blis==0.7.11
13
+ catalogue==2.0.10
14
+ certifi==2025.1.31
15
+ cffi==1.17.1
16
+ charset-normalizer==3.4.1
17
+ click==8.1.8
18
+ cloudpathlib==0.21.0
19
+ confection==0.1.5
20
+ contourpy==1.3.1
21
+ coqpit-config==0.2.0
22
+ coqui-tts==0.26.0
23
+ coqui-tts-trainer==0.2.3
24
+ cycler==0.12.1
25
+ cymem==2.0.11
26
+ Cython==3.0.12
27
+ dateparser==1.1.8
28
+ decorator==5.2.1
29
+ docopt==0.6.2
30
+ einops==0.8.1
31
+ encodec==0.1.1
32
+ fastapi==0.115.12
33
+ ffmpy==0.5.0
34
+ filelock==3.18.0
35
+ fonttools==4.57.0
36
+ frozenlist==1.5.0
37
+ fsspec==2025.3.2
38
+ gradio==5.24.0
39
+ gradio_client==1.8.0
40
+ groovy==0.1.2
41
+ grpcio==1.71.0
42
+ gruut==2.4.0
43
+ gruut-ipa==0.13.0
44
+ gruut_lang_de==2.0.1
45
+ gruut_lang_en==2.0.1
46
+ gruut_lang_es==2.0.1
47
+ gruut_lang_fr==2.0.2
48
+ h11==0.14.0
49
+ httpcore==1.0.7
50
+ httpx==0.28.1
51
+ huggingface-hub==0.30.2
52
+ idna==3.10
53
+ inflect==7.5.0
54
+ Jinja2==3.1.6
55
+ joblib==1.4.2
56
+ jsonlines==1.2.0
57
+ kiwisolver==1.4.8
58
+ langcodes==3.5.0
59
+ language_data==1.3.0
60
+ lazy_loader==0.4
61
+ librosa==0.11.0
62
+ llvmlite==0.44.0
63
+ marisa-trie==1.2.1
64
+ Markdown==3.7
65
+ markdown-it-py==3.0.0
66
+ MarkupSafe==3.0.2
67
+ matplotlib==3.10.1
68
+ mdurl==0.1.2
69
+ monotonic-alignment-search==0.1.1
70
+ more-itertools==10.6.0
71
+ mpmath==1.3.0
72
+ msgpack==1.1.0
73
+ multidict==6.2.0
74
+ murmurhash==1.0.12
75
+ networkx==3.4.2
76
+ num2words==0.5.14
77
+ numba==0.61.2
78
+ numpy==1.26.4
79
+ nvidia-cublas-cu12==12.4.5.8
80
+ nvidia-cuda-cupti-cu12==12.4.127
81
+ nvidia-cuda-nvrtc-cu12==12.4.127
82
+ nvidia-cuda-runtime-cu12==12.4.127
83
+ nvidia-cudnn-cu12==9.1.0.70
84
+ nvidia-cufft-cu12==11.2.1.3
85
+ nvidia-curand-cu12==10.3.5.147
86
+ nvidia-cusolver-cu12==11.6.1.9
87
+ nvidia-cusparse-cu12==12.3.1.170
88
+ nvidia-cusparselt-cu12==0.6.2
89
+ nvidia-nccl-cu12==2.21.5
90
+ nvidia-nvjitlink-cu12==12.4.127
91
+ nvidia-nvtx-cu12==12.4.127
92
+ orjson==3.10.16
93
+ packaging==24.2
94
+ pandas==2.2.3
95
+ pillow==11.1.0
96
+ pip==25.0.1
97
+ platformdirs==4.3.7
98
+ pooch==1.8.2
99
+ preshed==3.0.9
100
+ propcache==0.3.1
101
+ protobuf==6.30.2
102
+ psutil==5.9.8
103
+ pycparser==2.22
104
+ pydantic==2.11.3
105
+ pydantic_core==2.33.1
106
+ pydub==0.25.1
107
+ Pygments==2.19.1
108
+ pyparsing==3.2.3
109
+ pysbd==0.3.4
110
+ python-crfsuite==0.9.11
111
+ python-dateutil==2.9.0.post0
112
+ python-multipart==0.0.20
113
+ pytz==2025.2
114
+ PyYAML==6.0.2
115
+ regex==2024.11.6
116
+ requests==2.32.3
117
+ rich==14.0.0
118
+ ruff==0.11.4
119
+ safehttpx==0.1.6
120
+ safetensors==0.5.3
121
+ scikit-learn==1.6.1
122
+ scipy==1.15.2
123
+ semantic-version==2.10.0
124
+ setuptools==78.1.0
125
+ shellingham==1.5.4
126
+ six==1.17.0
127
+ smart-open==7.1.0
128
+ sniffio==1.3.1
129
+ soundfile==0.13.1
130
+ soxr==0.5.0.post1
131
+ spaces==0.34.2
132
+ spacy==3.7.5
133
+ spacy-legacy==3.0.12
134
+ spacy-loggers==1.0.5
135
+ srsly==2.5.1
136
+ starlette==0.46.1
137
+ SudachiDict-core==20250129
138
+ SudachiPy==0.6.10
139
+ sympy==1.13.1
140
+ tensorboard==2.19.0
141
+ tensorboard-data-server==0.7.2
142
+ thinc==8.2.5
143
+ threadpoolctl==3.6.0
144
+ tokenizers==0.21.1
145
+ tomlkit==0.13.2
146
+ torch==2.6.0
147
+ torchaudio==2.6.0
148
+ tqdm==4.67.1
149
+ transformers==4.51.1
150
+ triton==3.2.0
151
+ typeguard==4.4.2
152
+ typer==0.15.2
153
+ typing_extensions==4.13.1
154
+ typing-inspection==0.4.0
155
+ tzdata==2025.2
156
+ tzlocal==5.3.1
157
+ urllib3==2.3.0
158
+ uvicorn==0.34.0
159
+ wasabi==1.1.3
160
+ weasel==0.4.1
161
+ websockets==15.0.1
162
+ Werkzeug==3.1.3
163
+ wrapt==1.17.2
164
+ yarl==1.19.0