Karl El Hajal committed
Commit a180d8c · 1 Parent(s): f920a6c

Add kNN-TTS code and demo interface
- .gitignore +169 -0
- LICENSES/CC0-1.0.txt +121 -0
- LICENSES/MIT.txt +9 -0
- LICENSES/MPL-2.0.txt +373 -0
- REUSE.toml +26 -0
- app.py +240 -0
- assets/diagram.png +0 -0
- knn_tts/__init__.py +0 -0
- knn_tts/ssl/WavLM/WavLM.py +742 -0
- knn_tts/ssl/WavLM/__init__.py +0 -0
- knn_tts/ssl/WavLM/modules.py +763 -0
- knn_tts/ssl/__init__.py +0 -0
- knn_tts/ssl/models.py +49 -0
- knn_tts/synthesizer.py +69 -0
- knn_tts/text_cleaners.py +66 -0
- knn_tts/tts/GlowTTS/__init__.py +0 -0
- knn_tts/tts/GlowTTS/glow_tts_ssl.py +623 -0
- knn_tts/tts/__init__.py +0 -0
- knn_tts/tts/models.py +33 -0
- knn_tts/utils.py +36 -0
- knn_tts/vc/__init__.py +0 -0
- knn_tts/vc/knn.py +52 -0
- knn_tts/vocoder/__init__.py +0 -0
- knn_tts/vocoder/hifigan/__init__.py +0 -0
- knn_tts/vocoder/hifigan/config_v1_wavlm.json +40 -0
- knn_tts/vocoder/hifigan/models.py +407 -0
- knn_tts/vocoder/hifigan/utils.py +19 -0
- knn_tts/vocoder/models.py +44 -0
- requirements.txt +164 -0
.gitignore
ADDED
@@ -0,0 +1,169 @@
# SPDX-FileCopyrightText: Github
#
# SPDX-License-Identifier: CC0-1.0

# Source:
# https://github.com/github/gitignore/blob/8779ee73af62c669e7ca371aaab8399d87127693/Python.gitignore

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
LICENSES/CC0-1.0.txt
ADDED
@@ -0,0 +1,121 @@
Creative Commons Legal Code

CC0 1.0 Universal

CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE
LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN
ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS
INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES
REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS
PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM
THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED
HEREUNDER.

Statement of Purpose

The laws of most jurisdictions throughout the world automatically confer
exclusive Copyright and Related Rights (defined below) upon the creator
and subsequent owner(s) (each and all, an "owner") of an original work of
authorship and/or a database (each, a "Work").

Certain owners wish to permanently relinquish those rights to a Work for
the purpose of contributing to a commons of creative, cultural and
scientific works ("Commons") that the public can reliably and without fear
of later claims of infringement build upon, modify, incorporate in other
works, reuse and redistribute as freely as possible in any form whatsoever
and for any purposes, including without limitation commercial purposes.
These owners may contribute to the Commons to promote the ideal of a free
culture and the further production of creative, cultural and scientific
works, or to gain reputation or greater distribution for their Work in
part through the use and efforts of others.

For these and/or other purposes and motivations, and without any
expectation of additional consideration or compensation, the person
associating CC0 with a Work (the "Affirmer"), to the extent that he or she
is an owner of Copyright and Related Rights in the Work, voluntarily
elects to apply CC0 to the Work and publicly distribute the Work under its
terms, with knowledge of his or her Copyright and Related Rights in the
Work and the meaning and intended legal effect of CC0 on those rights.

1. Copyright and Related Rights. A Work made available under CC0 may be
protected by copyright and related or neighboring rights ("Copyright and
Related Rights"). Copyright and Related Rights include, but are not
limited to, the following:

  i. the right to reproduce, adapt, distribute, perform, display,
     communicate, and translate a Work;
 ii. moral rights retained by the original author(s) and/or performer(s);
iii. publicity and privacy rights pertaining to a person's image or
     likeness depicted in a Work;
 iv. rights protecting against unfair competition in regards to a Work,
     subject to the limitations in paragraph 4(a), below;
  v. rights protecting the extraction, dissemination, use and reuse of data
     in a Work;
 vi. database rights (such as those arising under Directive 96/9/EC of the
     European Parliament and of the Council of 11 March 1996 on the legal
     protection of databases, and under any national implementation
     thereof, including any amended or successor version of such
     directive); and
vii. other similar, equivalent or corresponding rights throughout the
     world based on applicable law or treaty, and any national
     implementations thereof.

2. Waiver. To the greatest extent permitted by, but not in contravention
of, applicable law, Affirmer hereby overtly, fully, permanently,
irrevocably and unconditionally waives, abandons, and surrenders all of
Affirmer's Copyright and Related Rights and associated claims and causes
of action, whether now known or unknown (including existing as well as
future claims and causes of action), in the Work (i) in all territories
worldwide, (ii) for the maximum duration provided by applicable law or
treaty (including future time extensions), (iii) in any current or future
medium and for any number of copies, and (iv) for any purpose whatsoever,
including without limitation commercial, advertising or promotional
purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each
member of the public at large and to the detriment of Affirmer's heirs and
successors, fully intending that such Waiver shall not be subject to
revocation, rescission, cancellation, termination, or any other legal or
equitable action to disrupt the quiet enjoyment of the Work by the public
as contemplated by Affirmer's express Statement of Purpose.

3. Public License Fallback. Should any part of the Waiver for any reason
be judged legally invalid or ineffective under applicable law, then the
Waiver shall be preserved to the maximum extent permitted taking into
account Affirmer's express Statement of Purpose. In addition, to the
extent the Waiver is so judged Affirmer hereby grants to each affected
person a royalty-free, non transferable, non sublicensable, non exclusive,
irrevocable and unconditional license to exercise Affirmer's Copyright and
Related Rights in the Work (i) in all territories worldwide, (ii) for the
maximum duration provided by applicable law or treaty (including future
time extensions), (iii) in any current or future medium and for any number
of copies, and (iv) for any purpose whatsoever, including without
limitation commercial, advertising or promotional purposes (the
"License"). The License shall be deemed effective as of the date CC0 was
applied by Affirmer to the Work. Should any part of the License for any
reason be judged legally invalid or ineffective under applicable law, such
partial invalidity or ineffectiveness shall not invalidate the remainder
of the License, and in such case Affirmer hereby affirms that he or she
will not (i) exercise any of his or her remaining Copyright and Related
Rights in the Work or (ii) assert any associated claims and causes of
action with respect to the Work, in either case contrary to Affirmer's
express Statement of Purpose.

4. Limitations and Disclaimers.

 a. No trademark or patent rights held by Affirmer are waived, abandoned,
    surrendered, licensed or otherwise affected by this document.
 b. Affirmer offers the Work as-is and makes no representations or
    warranties of any kind concerning the Work, express, implied,
    statutory or otherwise, including without limitation warranties of
    title, merchantability, fitness for a particular purpose, non
    infringement, or the absence of latent or other defects, accuracy, or
    the present or absence of errors, whether or not discoverable, all to
    the greatest extent permissible under applicable law.
 c. Affirmer disclaims responsibility for clearing rights of other persons
    that may apply to the Work or any use thereof, including without
    limitation any person's Copyright and Related Rights in the Work.
    Further, Affirmer disclaims responsibility for obtaining any necessary
    consents, permissions or other rights required for any use of the
    Work.
 d. Affirmer understands and acknowledges that Creative Commons is not a
    party to this document and has no duty or obligation with respect to
    this CC0 or use of the Work.
LICENSES/MIT.txt
ADDED
@@ -0,0 +1,9 @@
MIT License

Copyright (c) <year> <copyright holders>

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
LICENSES/MPL-2.0.txt
ADDED
@@ -0,0 +1,373 @@
Mozilla Public License Version 2.0
==================================

1. Definitions
--------------

1.1. "Contributor"
    means each individual or legal entity that creates, contributes to
    the creation of, or owns Covered Software.

1.2. "Contributor Version"
    means the combination of the Contributions of others (if any) used
    by a Contributor and that particular Contributor's Contribution.

1.3. "Contribution"
    means Covered Software of a particular Contributor.

1.4. "Covered Software"
    means Source Code Form to which the initial Contributor has attached
    the notice in Exhibit A, the Executable Form of such Source Code
    Form, and Modifications of such Source Code Form, in each case
    including portions thereof.

1.5. "Incompatible With Secondary Licenses"
    means

    (a) that the initial Contributor has attached the notice described
        in Exhibit B to the Covered Software; or

    (b) that the Covered Software was made available under the terms of
        version 1.1 or earlier of the License, but not also under the
        terms of a Secondary License.

1.6. "Executable Form"
    means any form of the work other than Source Code Form.

1.7. "Larger Work"
    means a work that combines Covered Software with other material, in
    a separate file or files, that is not Covered Software.

1.8. "License"
    means this document.

1.9. "Licensable"
    means having the right to grant, to the maximum extent possible,
    whether at the time of the initial grant or subsequently, any and
    all of the rights conveyed by this License.

1.10. "Modifications"
    means any of the following:

    (a) any file in Source Code Form that results from an addition to,
        deletion from, or modification of the contents of Covered
        Software; or

    (b) any new file in Source Code Form that contains any Covered
        Software.

1.11. "Patent Claims" of a Contributor
    means any patent claim(s), including without limitation, method,
    process, and apparatus claims, in any patent Licensable by such
    Contributor that would be infringed, but for the grant of the
    License, by the making, using, selling, offering for sale, having
    made, import, or transfer of either its Contributions or its
    Contributor Version.

1.12. "Secondary License"
    means either the GNU General Public License, Version 2.0, the GNU
    Lesser General Public License, Version 2.1, the GNU Affero General
    Public License, Version 3.0, or any later versions of those
    licenses.

1.13. "Source Code Form"
    means the form of the work preferred for making modifications.

1.14. "You" (or "Your")
    means an individual or a legal entity exercising rights under this
    License. For legal entities, "You" includes any entity that
    controls, is controlled by, or is under common control with You. For
    purposes of this definition, "control" means (a) the power, direct
    or indirect, to cause the direction or management of such entity,
    whether by contract or otherwise, or (b) ownership of more than
    fifty percent (50%) of the outstanding shares or beneficial
    ownership of such entity.

2. License Grants and Conditions
--------------------------------

2.1. Grants

Each Contributor hereby grants You a world-wide, royalty-free,
non-exclusive license:

(a) under intellectual property rights (other than patent or trademark)
    Licensable by such Contributor to use, reproduce, make available,
    modify, display, perform, distribute, and otherwise exploit its
    Contributions, either on an unmodified basis, with Modifications, or
    as part of a Larger Work; and

(b) under Patent Claims of such Contributor to make, use, sell, offer
    for sale, have made, import, and otherwise transfer either its
    Contributions or its Contributor Version.

2.2. Effective Date

The licenses granted in Section 2.1 with respect to any Contribution
become effective for each Contribution on the date the Contributor first
distributes such Contribution.

2.3. Limitations on Grant Scope

The licenses granted in this Section 2 are the only rights granted under
this License. No additional rights or licenses will be implied from the
distribution or licensing of Covered Software under this License.
Notwithstanding Section 2.1(b) above, no patent license is granted by a
Contributor:

(a) for any code that a Contributor has removed from Covered Software;
    or

(b) for infringements caused by: (i) Your and any other third party's
    modifications of Covered Software, or (ii) the combination of its
    Contributions with other software (except as part of its Contributor
    Version); or

(c) under Patent Claims infringed by Covered Software in the absence of
    its Contributions.

This License does not grant any rights in the trademarks, service marks,
or logos of any Contributor (except as may be necessary to comply with
the notice requirements in Section 3.4).

2.4. Subsequent Licenses

No Contributor makes additional grants as a result of Your choice to
distribute the Covered Software under a subsequent version of this
License (see Section 10.2) or under the terms of a Secondary License (if
permitted under the terms of Section 3.3).

2.5. Representation

Each Contributor represents that the Contributor believes its
Contributions are its original creation(s) or it has sufficient rights
to grant the rights to its Contributions conveyed by this License.

2.6. Fair Use

This License is not intended to limit any rights You have under
applicable copyright doctrines of fair use, fair dealing, or other
equivalents.

2.7. Conditions

Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
in Section 2.1.

3. Responsibilities
-------------------

3.1. Distribution of Source Form

All distribution of Covered Software in Source Code Form, including any
Modifications that You create or to which You contribute, must be under
the terms of this License. You must inform recipients that the Source
Code Form of the Covered Software is governed by the terms of this
License, and how they can obtain a copy of this License. You may not
attempt to alter or restrict the recipients' rights in the Source Code
Form.

3.2. Distribution of Executable Form

If You distribute Covered Software in Executable Form then:

(a) such Covered Software must also be made available in Source Code
    Form, as described in Section 3.1, and You must inform recipients of
    the Executable Form how they can obtain a copy of such Source Code
    Form by reasonable means in a timely manner, at a charge no more
    than the cost of distribution to the recipient; and

(b) You may distribute such Executable Form under the terms of this
    License, or sublicense it under different terms, provided that the
    license for the Executable Form does not attempt to limit or alter
    the recipients' rights in the Source Code Form under this License.

3.3. Distribution of a Larger Work

You may create and distribute a Larger Work under terms of Your choice,
provided that You also comply with the requirements of this License for
the Covered Software. If the Larger Work is a combination of Covered
Software with a work governed by one or more Secondary Licenses, and the
Covered Software is not Incompatible With Secondary Licenses, this
License permits You to additionally distribute such Covered Software
under the terms of such Secondary License(s), so that the recipient of
the Larger Work may, at their option, further distribute the Covered
Software under the terms of either this License or such Secondary
License(s).

3.4. Notices

You may not remove or alter the substance of any license notices
(including copyright notices, patent notices, disclaimers of warranty,
or limitations of liability) contained within the Source Code Form of
the Covered Software, except that You may alter any license notices to
the extent required to remedy known factual inaccuracies.

3.5. Application of Additional Terms

You may choose to offer, and to charge a fee for, warranty, support,
indemnity or liability obligations to one or more recipients of Covered
Software. However, You may do so only on Your own behalf, and not on
behalf of any Contributor. You must make it absolutely clear that any
such warranty, support, indemnity, or liability obligation is offered by
You alone, and You hereby agree to indemnify every Contributor for any
liability incurred by such Contributor as a result of warranty, support,
indemnity or liability terms You offer. You may include additional
disclaimers of warranty and limitations of liability specific to any
jurisdiction.

4. Inability to Comply Due to Statute or Regulation
---------------------------------------------------

If it is impossible for You to comply with any of the terms of this
License with respect to some or all of the Covered Software due to
statute, judicial order, or regulation then You must: (a) comply with
the terms of this License to the maximum extent possible; and (b)
describe the limitations and the code they affect. Such description must
be placed in a text file included with all distributions of the Covered
Software under this License. Except to the extent prohibited by statute
or regulation, such description must be sufficiently detailed for a
recipient of ordinary skill to be able to understand it.

5. Termination
--------------

5.1. The rights granted under this License will terminate automatically
if You fail to comply with any of its terms. However, if You become
compliant, then the rights granted under this License from a particular
Contributor are reinstated (a) provisionally, unless and until such
Contributor explicitly and finally terminates Your grants, and (b) on an
ongoing basis, if such Contributor fails to notify You of the
non-compliance by some reasonable means prior to 60 days after You have
come back into compliance. Moreover, Your grants from a particular
Contributor are reinstated on an ongoing basis if such Contributor
notifies You of the non-compliance by some reasonable means, this is the
first time You have received notice of non-compliance with this License
from such Contributor, and You become compliant prior to 30 days after
Your receipt of the notice.

5.2. If You initiate litigation against any entity by asserting a patent
infringement claim (excluding declaratory judgment actions,
counter-claims, and cross-claims) alleging that a Contributor Version
directly or indirectly infringes any patent, then the rights granted to
You by any and all Contributors for the Covered Software under Section
2.1 of this License shall terminate.

5.3. In the event of termination under Sections 5.1 or 5.2 above, all
end user license agreements (excluding distributors and resellers) which
have been validly granted by You or Your distributors under this License
prior to termination shall survive termination.

************************************************************************
*                                                                      *
*  6. Disclaimer of Warranty                                           *
*  -------------------------                                           *
*                                                                      *
*  Covered Software is provided under this License on an "as is"       *
*  basis, without warranty of any kind, either expressed, implied, or  *
*  statutory, including, without limitation, warranties that the       *
*  Covered Software is free of defects, merchantable, fit for a        *
*  particular purpose or non-infringing. The entire risk as to the     *
*  quality and performance of the Covered Software is with You.        *
*  Should any Covered Software prove defective in any respect, You     *
*  (not any Contributor) assume the cost of any necessary servicing,   *
*  repair, or correction. This disclaimer of warranty constitutes an   *
*  essential part of this License. No use of any Covered Software is   *
*  authorized under this License except under this disclaimer.         *
*                                                                      *
************************************************************************

************************************************************************
*                                                                      *
*  7. Limitation of Liability                                          *
*  --------------------------                                          *
*                                                                      *
*  Under no circumstances and under no legal theory, whether tort      *
*  (including negligence), contract, or otherwise, shall any           *
*  Contributor, or anyone who distributes Covered Software as          *
*  permitted above, be liable to You for any direct, indirect,         *
*  special, incidental, or consequential damages of any character      *
*  including, without limitation, damages for lost profits, loss of    *
*  goodwill, work stoppage, computer failure or malfunction, or any    *
*  and all other commercial damages or losses, even if such party      *
*  shall have been informed of the possibility of such damages. This   *
*  limitation of liability shall not apply to liability for death or   *
*  personal injury resulting from such party's negligence to the       *
*  extent applicable law prohibits such limitation. Some               *
*  jurisdictions do not allow the exclusion or limitation of           *
*  incidental or consequential damages, so this exclusion and          *
*  limitation may not apply to You.                                    *
*                                                                      *
************************************************************************

8. Litigation
-------------

Any litigation relating to this License may be brought only in the
courts of a jurisdiction where the defendant maintains its principal
place of business and such litigation shall be governed by laws of that
jurisdiction, without reference to its conflict-of-law provisions.
Nothing in this Section shall prevent a party's ability to bring
cross-claims or counter-claims.

9. Miscellaneous
----------------

This License represents the complete agreement concerning the subject
matter hereof. If any provision of this License is held to be
unenforceable, such provision shall be reformed only to the extent
necessary to make it enforceable. Any law or regulation which provides
that the language of a contract shall be construed against the drafter
shall not be used to construe this License against a Contributor.

10. Versions of the License
---------------------------

10.1. New Versions

Mozilla Foundation is the license steward. Except as provided in Section
10.3, no one other than the license steward has the right to modify or
publish new versions of this License. Each version will be given a
distinguishing version number.

10.2. Effect of New Versions

You may distribute the Covered Software under the terms of the version
of the License under which You originally received the Covered Software,
or under the terms of any subsequent version published by the license
steward.

10.3. Modified Versions

If you create software not governed by this License, and you want to
create a new license for such software, you may create and use a
modified version of this License if you rename the license and remove
any references to the name of the license steward (except to note that
such modified license differs from this License).

10.4. Distributing Source Code Form that is Incompatible With Secondary
Licenses

If You choose to distribute Source Code Form that is Incompatible With
Secondary Licenses under the terms of this version of the License, the
notice described in Exhibit B of this License must be attached.

Exhibit A - Source Code Form License Notice
-------------------------------------------

  This Source Code Form is subject to the terms of the Mozilla Public
  License, v. 2.0. If a copy of the MPL was not distributed with this
  file, You can obtain one at https://mozilla.org/MPL/2.0/.

If it is not possible or desirable to put the notice in a particular
file, then You may include the notice in a location (such as a LICENSE
file in a relevant directory) where a recipient would be likely to look
for such a notice.

You may add additional accurate notices of copyright ownership.

Exhibit B - "Incompatible With Secondary Licenses" Notice
---------------------------------------------------------

  This Source Code Form is "Incompatible With Secondary Licenses", as
  defined by the Mozilla Public License, v. 2.0.
REUSE.toml
ADDED
@@ -0,0 +1,26 @@
# SPDX-FileCopyrightText: 2024 Idiap Research Institute
# SPDX-FileContributor: Karl El Hajal
#
# SPDX-License-Identifier: MIT

version = 1
SPDX-PackageName = "knn_tts"
SPDX-PackageSupplier = "Karl El Hajal <[email protected]>"

[[annotations]]
path = "poetry.lock"
precedence = "aggregate"
SPDX-FileCopyrightText = "2024 Idiap Research Institute"
SPDX-License-Identifier = "CC0-1.0"

[[annotations]]
path = "knn_tts/vocoder/hifigan/config_v1_wavlm.json"
precedence = "aggregate"
SPDX-FileCopyrightText = "2020 Jungil Kong"
SPDX-License-Identifier = "MIT"

[[annotations]]
path = "assets/diagram.png"
precedence = "aggregate"
SPDX-FileCopyrightText = "2024 Idiap Research Institute"
SPDX-License-Identifier = "MIT"
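REUSE.toml declares per-file copyright and license metadata following the REUSE specification's TOML format. As a side note, compliance of the whole tree can be checked with the `reuse` tool (https://github.com/fsfe/reuse-tool); a minimal sketch, assuming the tool is installed (`pip install reuse`):

```python
import subprocess

# Verify that every file in the repository carries copyright and license
# information per the REUSE specification; exits non-zero on violations.
subprocess.run(["reuse", "lint"], check=True)
```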
app.py
ADDED
@@ -0,0 +1,240 @@
# SPDX-FileCopyrightText: 2024 Idiap Research Institute
# SPDX-FileContributor: Karl El Hajal
#
# SPDX-License-Identifier: MIT

import os
import zipfile
import gradio as gr
import spaces
from huggingface_hub import snapshot_download

from knn_tts.synthesizer import Synthesizer
from knn_tts.utils import get_vocoder_checkpoint_path

# Check if target_feats directory exists, if not, unzip target_feats.zip
if not os.path.exists("target_feats"):
    if os.path.exists("target_feats.zip"):
        with zipfile.ZipFile("target_feats.zip", "r") as zip_ref:
            zip_ref.extractall(".")
    else:
        raise FileNotFoundError("target_feats.zip not found.")

SAMPLE_RATE = 16000

CHECKPOINTS_DIR = "./checkpoints"

tts_checkpoints_dir = snapshot_download(repo_id="idiap/kNN-TTS", local_dir=CHECKPOINTS_DIR)
vocoder_checkpoint_path = get_vocoder_checkpoint_path(CHECKPOINTS_DIR)

tts_checkpoint_name = "best_model_646135.pth"
synthesizer = Synthesizer(tts_checkpoints_dir, tts_checkpoint_name, vocoder_checkpoint_path, model_name="glowtts")

target_speakers = {
    "Libri 7127": {
        "feats_path": "target_feats/LibriSpeech-test-clean/7127/wavlm",
    },
    "Libri 7729": {
        "feats_path": "target_feats/LibriSpeech-test-clean/7729/wavlm",
    },
    "Libri 6829": {
        "feats_path": "target_feats/LibriSpeech-test-clean/6829/wavlm",
    },
    "Libri 8555": {
        "feats_path": "target_feats/LibriSpeech-test-clean/8555/wavlm",
    },
    "Thorsten Neutral": {
        "feats_path": "target_feats/Thorsten/neutral/wavlm/",
    },
    "Thorsten Whisper": {
        "feats_path": "target_feats/Thorsten/whisper/wavlm/",
    },
    "ESD 0018 Neutral": {
        "feats_path": "target_feats/ESD/0018/neutral/wavlm/",
    },
    "ESD 0018 Surprised": {
        "feats_path": "target_feats/ESD/0018/surprised/wavlm/",
    },
}

@spaces.GPU
def run(text_input, target_speaker, lambda_rate, topk, weighted_average):
    feats_path = target_speakers[target_speaker]["feats_path"]
    wav = synthesizer(text_input, feats_path, interpolation_rate=lambda_rate, knnvc_topk=topk, weighted_average=weighted_average, max_target_num_files=500)
    wav = (SAMPLE_RATE, wav.squeeze().cpu().numpy())
    return wav


def get_title(text, size=1):
    return f"""
    <center>

    <h{size}> {text} </h{size}>

    </center>
    """

def create_gradio_interface():
    with gr.Blocks(
        theme=gr.themes.Default(
            text_size="lg",
        ),
        title="kNN-TTS"
    ) as iface:

        gr.HTML(get_title("kNN-TTS: kNN Retrieval for Simple and Effective Zero-Shot Multi-speaker Text-to-Speech", size=1))

        with gr.Tabs():
            with gr.TabItem("Generate Speech"):
                with gr.Row():
                    # Left column - inputs
                    with gr.Column():
                        gr.Markdown("## Input")
                        text_box = gr.Textbox(
                            lines=3,
                            placeholder="Enter the text to convert to speech...",
                            label="Text",
                            elem_id="text-input"
                        )

                        target_speaker_dropdown = gr.Dropdown(
                            choices=list(target_speakers.keys()),
                            value="Libri 7127",
                            label="Target Voice",
                            elem_id="target-voice"
                        )

                        rate_slider = gr.Slider(
                            minimum=0.0,
                            maximum=2.0,
                            value=1.0,
                            step=0.01,
                            label="Voice Morphing (λ)",
                            info="Higher values give more weight to target voice characteristics"
                        )

                        with gr.Accordion("Advanced Settings", open=False):
                            k_slider = gr.Slider(
                                minimum=1,
                                maximum=50,
                                value=4,
                                step=1,
                                label="Top-k Retrieval",
                                info="k closest neighbors to retrieve"
                            )
                            weighted_toggle = gr.Checkbox(
                                label="Use Weighted Averaging",
                                value=False,
                                info="Weight neighbors by similarity distance"
                            )

                        submit_button = gr.Button("Generate Audio", variant="primary", size="lg")

                    # Right column - outputs
                    with gr.Column():
                        gr.Markdown("## Generated Audio")
                        with gr.Group():
                            audio_output = gr.Audio(
                                type="numpy",
                                label="Output Speech",
                                elem_id="audio-output"
                            )
                        with gr.Row():
                            clear_btn = gr.ClearButton([text_box, target_speaker_dropdown, rate_slider, audio_output], variant="secondary", size="lg")

                # Example section
                with gr.Row():
                    gr.Examples(
                        examples=[
                            ["I think foosball is a combination of football and shish kebabs.", "Thorsten Whisper", 1.0, 8, True],
                            ["I think foosball is a combination of football and shish kebabs.", "Thorsten Neutral", 1.0, 4, False],
                            ["If you're traveling in the north country fair.", "Libri 7127", 1.0, 4, False],
                            ["Like a vision she dances across the porch as the radio plays.", "Libri 7729", 1.0, 8, True],
                            ["There weren't another other way to be.", "Libri 6829", 1.0, 4, False],
                        ],
                        inputs=[text_box, target_speaker_dropdown, rate_slider, k_slider, weighted_toggle],
                        outputs=audio_output,
                        fn=run,
                        cache_examples=True
                    )

            # Additional tabs
            with gr.TabItem("Model Details"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("""
                        ## kNN-TTS Technical Details

                        kNN-TTS uses self-supervised learning (SSL) features and kNN retrieval to achieve robust zero-shot multi-speaker TTS.

                        ### Key Components

                        1. **Feature Extraction**: We extract discrete representations from target speaker speech using a pre-trained SSL encoder. We use the 6th layer of WavLM Large.
                        2. **Text-to-SSL**: We train a lightweight TTS model to predict the same representations from text. For simplicity, we train on a single-speaker dataset.
                        3. **Retrieval Mechanism**: We use kNN to find, for each unit in the generated features, its closest matches in the target voice unit database.
                        4. **Voice Morphing**: By linearly interpolating the source and selected target speaker features, we can morph the two voices. The interpolation parameter λ controls the balance between source and target characteristics.
                        5. **Vocoder**: We use a pre-trained vocoder to convert the resulting features into a waveform.

                        ### Performance

                        Our simple and efficient model achieves results comparable to state-of-the-art models while being trained on 100 to 1000× less transcribed data.
                        This framework is therefore particularly well-suited for low-resource domains.

                        For more details, please refer to our paper (https://arxiv.org/abs/2408.10771).
                        """)
                    with gr.Column():
                        gr.Image("assets/diagram.png", label="Model Architecture", scale=0.3, show_label=False, show_download_button=False, show_fullscreen_button=False)

            with gr.TabItem("About"):
                gr.Markdown("""
                ## About the Project

                This demo showcases kNN-TTS, a lightweight zero-shot text-to-speech synthesis model.

                ### Authors

                - Karl El Hajal
                - Ajinkya Kulkarni
                - Enno Hermann
                - Mathew Magimai.-Doss

                ### Citation

                If you use kNN-TTS in your research, please cite our paper:

                ```
                @misc{hajal2025knntts,
                    title={kNN Retrieval for Simple and Effective Zero-Shot Multi-speaker Text-to-Speech},
                    author={Karl El Hajal and Ajinkya Kulkarni and Enno Hermann and Mathew Magimai.-Doss},
                    year={2025},
                    eprint={2408.10771},
                    archivePrefix={arXiv},
                    primaryClass={eess.AS},
                    url={https://arxiv.org/abs/2408.10771},
                }
                ```

                ### Acknowledgments

                The target voices featured in this demo were sourced from the following datasets:

                - [Thorsten Dataset](https://www.thorsten-voice.de/)
                - [LibriSpeech Dataset](https://www.openslr.org/12)
                - [Emotional Speech Dataset (ESD)](https://hltsingapore.github.io/ESD/)

                ### License

                This project is licensed under the MIT License.
                """)

        # Event handlers
        submit_button.click(
            fn=run,
            inputs=[text_box, target_speaker_dropdown, rate_slider, k_slider, weighted_toggle],
            outputs=[audio_output]
        )

    return iface

demo = create_gradio_interface()
demo.launch(share=True, debug=False)
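The Model Details tab above describes the retrieval and morphing steps that `Synthesizer` performs internally (implemented in the commit's `knn_tts/vc/knn.py`, not shown in this excerpt). A minimal sketch of that idea, retrieving the k nearest target frames for each generated frame, averaging them, and interpolating with λ: the function name `knn_morph`, the (T, D)/(N, D) tensor shapes, and the plain Euclidean distance are illustrative assumptions, not the repository's actual implementation.

```python
import torch

def knn_morph(source_feats, target_feats, topk=4, lam=1.0, weighted=False):
    """Hypothetical sketch of kNN retrieval + voice morphing.

    source_feats: (T, D) WavLM frames predicted by the Text-to-SSL model.
    target_feats: (N, D) WavLM frames extracted from the target speaker.
    """
    dists = torch.cdist(source_feats, target_feats)              # (T, N) pairwise distances
    knn_dists, knn_idx = dists.topk(topk, dim=1, largest=False)  # k nearest targets per frame

    neighbors = target_feats[knn_idx]                            # (T, k, D)
    if weighted:
        # Weight closer neighbors more heavily (softmax over negative distance).
        weights = torch.softmax(-knn_dists, dim=1).unsqueeze(-1)
        matched = (weights * neighbors).sum(dim=1)
    else:
        matched = neighbors.mean(dim=1)                          # plain average of the k matches

    # lam = 0 keeps the source voice, lam = 1 fully adopts the target match.
    return (1 - lam) * source_feats + lam * matched
```

Note that the demo's λ slider ranges up to 2.0, which extrapolates past the retrieved target features rather than interpolating between the two voices.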
assets/diagram.png
ADDED
(binary image file: model architecture diagram)
knn_tts/__init__.py
ADDED
File without changes
knn_tts/ssl/WavLM/WavLM.py
ADDED
@@ -0,0 +1,742 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: 2021 Microsoft
|
2 |
+
#
|
3 |
+
# SPDX-License-Identifier: MIT
|
4 |
+
|
5 |
+
# --------------------------------------------------------
|
6 |
+
# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
|
7 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
|
8 |
+
# Copyright (c) 2021 Microsoft
|
9 |
+
# Licensed under The MIT License [see LICENSE for details]
|
10 |
+
# Based on fairseq code bases
|
11 |
+
# https://github.com/pytorch/fairseq
|
12 |
+
# --------------------------------------------------------
|
13 |
+
|
14 |
+
import logging
|
15 |
+
import math
|
16 |
+
from typing import List, Optional, Tuple
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
import torch.nn.functional as F
|
21 |
+
from torch import nn
|
22 |
+
from torch.nn import LayerNorm
|
23 |
+
|
24 |
+
from .modules import (
|
25 |
+
Fp32GroupNorm,
|
26 |
+
Fp32LayerNorm,
|
27 |
+
GLU_Linear,
|
28 |
+
GradMultiply,
|
29 |
+
MultiheadAttention,
|
30 |
+
SamePad,
|
31 |
+
TransposeLast,
|
32 |
+
get_activation_fn,
|
33 |
+
init_bert_params,
|
34 |
+
)
|
35 |
+
|
36 |
+
logger = logging.getLogger(__name__)
|
37 |
+
|
38 |
+
|
39 |
+
# pylint: disable=too-many-branches, too-many-statements
|
40 |
+
def compute_mask_indices(
|
41 |
+
shape: Tuple[int, int],
|
42 |
+
padding_mask: Optional[torch.Tensor],
|
43 |
+
mask_prob: float,
|
44 |
+
mask_length: int,
|
45 |
+
mask_type: str = "static",
|
46 |
+
mask_other: float = 0.0,
|
47 |
+
min_masks: int = 0,
|
48 |
+
no_overlap: bool = False,
|
49 |
+
min_space: int = 0,
|
50 |
+
) -> np.ndarray:
|
51 |
+
"""
|
52 |
+
Computes random mask spans for a given shape
|
53 |
+
|
54 |
+
Args:
|
55 |
+
shape: the the shape for which to compute masks.
|
56 |
+
should be of size 2 where first element is batch size and 2nd is timesteps
|
57 |
+
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
|
58 |
+
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
|
59 |
+
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
|
60 |
+
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
|
61 |
+
mask_type: how to compute mask lengths
|
62 |
+
static = fixed size
|
63 |
+
uniform = sample from uniform distribution [mask_other, mask_length*2]
|
64 |
+
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
|
65 |
+
poisson = sample from possion distribution with lambda = mask length
|
66 |
+
min_masks: minimum number of masked spans
|
67 |
+
no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
|
68 |
+
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
|
69 |
+
"""
|
70 |
+
|
    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    all_num_mask = int(
        # add a random number for probabilistic rounding
        mask_prob * all_sz / float(mask_length) + np.random.rand()
    )

    all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if padding_mask is not None:
            sz = all_sz - padding_mask[i].long().sum().item()
            num_mask = int(
                # add a random number for probabilistic rounding
                mask_prob * sz / float(mask_length) + np.random.rand()
            )
            num_mask = max(min_masks, num_mask)
        else:
            sz = all_sz
            num_mask = all_num_mask

        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
        elif mask_type == "normal":
            lengths = np.random.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
        elif mask_type == "poisson":
            lengths = np.random.poisson(mask_length, size=num_mask)
            lengths = [int(round(x)) for x in lengths]
        else:
            raise ValueError("unknown mask selection " + mask_type)

        if sum(lengths) == 0:
            lengths[0] = min(mask_length, sz - 1)

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                span_start = np.random.randint(s, e - length)
                mask_idc.extend(span_start + i for i in range(length))  # pylint: disable=cell-var-from-loop

                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - keep_length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                lens = np.fromiter(
                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
                    int,  # np.int was removed in NumPy 1.24; the built-in int is equivalent here
                )
                l_sum = np.sum(lens)
                if l_sum == 0:
                    break
                probs = lens / np.sum(lens)
                c = np.random.choice(len(parts), p=probs)
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            mask_idc = np.asarray(mask_idc)
        else:
            min_len = min(lengths)
            if sz - min_len <= num_mask:
                min_len = sz - num_mask - 1

            mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)

            mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])

        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))

    min_len = min(len(m) for m in mask_idcs)
    for i, mask_idc in enumerate(mask_idcs):
        if len(mask_idc) > min_len:
            mask_idc = np.random.choice(mask_idc, min_len, replace=False)
        mask[i, mask_idc] = True

    return mask

class WavLMConfig:
    def __init__(self, cfg=None):
        self.extractor_mode: str = "default"  # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
        self.encoder_layers: int = 12  # num encoder layers in the transformer

        self.encoder_embed_dim: int = 768  # encoder embedding dimension
        self.encoder_ffn_embed_dim: int = 3072  # encoder embedding dimension for FFN
        self.encoder_attention_heads: int = 12  # num encoder attention heads
        self.activation_fn: str = "gelu"  # activation function to use

        self.layer_norm_first: bool = False  # apply layernorm first in the transformer
        self.conv_feature_layers: str = (
            "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2"  # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
        )
        self.conv_bias: bool = False  # include bias in conv encoder
        self.feature_grad_mult: float = 1.0  # multiply feature extractor var grads by this

        self.normalize: bool = False  # normalize input to have 0 mean and unit variance during training

        # dropouts
        self.dropout: float = 0.1  # dropout probability for the transformer
        self.attention_dropout: float = 0.1  # dropout probability for attention weights
        self.activation_dropout: float = 0.0  # dropout probability after activation in FFN
        self.encoder_layerdrop: float = 0.0  # probability of dropping a transformer layer
        self.dropout_input: float = 0.0  # dropout to apply to the input (after feat extr)
        self.dropout_features: float = 0.0  # dropout to apply to the features (after feat extr)

        # masking
        self.mask_length: int = 10  # mask length
        self.mask_prob: float = 0.65  # probability of replacing a token with mask
        self.mask_selection: str = "static"  # how to choose mask length
        self.mask_other: float = 0  # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
        self.no_mask_overlap: bool = False  # whether to allow masks to overlap
        self.mask_min_space: int = 1  # min space between spans (if no overlap is enabled)

        # channel masking
        self.mask_channel_length: int = 10  # length of the mask for features (channels)
        self.mask_channel_prob: float = 0.0  # probability of replacing a feature with 0
        self.mask_channel_selection: str = "static"  # how to choose mask length for channel masking
        self.mask_channel_other: float = 0  # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
        self.no_mask_channel_overlap: bool = False  # whether to allow channel masks to overlap
        self.mask_channel_min_space: int = 1  # min space between spans (if no overlap is enabled)

        # positional embeddings
        self.conv_pos: int = 128  # number of filters for convolutional positional embeddings
        self.conv_pos_groups: int = 16  # number of groups for convolutional positional embedding

        # relative position embedding
        self.relative_position_embedding: bool = False  # apply relative position embedding
        self.num_buckets: int = 320  # number of buckets for relative position embedding
        self.max_distance: int = 1280  # maximum distance for relative position embedding
        self.gru_rel_pos: bool = False  # apply gated relative position embedding

        if cfg is not None:
            self.update(cfg)

    def update(self, cfg: dict):
        self.__dict__.update(cfg)

class WavLMModel(nn.Module):
    def __init__(
        self,
        cfg: WavLMConfig,
    ) -> None:
        super().__init__()
        logger.info("WavLM Config: %s", cfg.__dict__)

        self.cfg = cfg
        feature_enc_layers = eval(cfg.conv_feature_layers)
        self.embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )

        self.post_extract_proj = nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult

        self.mask_emb = nn.Parameter(torch.FloatTensor(cfg.encoder_embed_dim).uniform_())

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

    def apply_mask(self, x, padding_mask):
        B, T, C = x.shape
        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = torch.from_numpy(mask_channel_indices).to(x.device).unsqueeze(1).expand(-1, T, -1)
            x[mask_channel_indices] = 0

        return x, mask_indices

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def extract_features(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = False,
        ret_conv: bool = False,
        output_layer: Optional[int] = None,
        ret_layer_results: bool = False,
    ):
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(source)

        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)

        if mask:
            x, _ = self.apply_mask(features, padding_mask)
        else:
            x = features

        # feature: (B, T, D), float
        # target: (B, T), long
        # x: (B, T, D), float
        # padding_mask: (B, T), bool
        # mask_indices: (B, T), bool
        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=None if output_layer is None else output_layer - 1,
        )

        res = {
            "x": x,
            "padding_mask": padding_mask,
            "features": features,
            "layer_results": layer_results,
        }

        feature = res["features"] if ret_conv else res["x"]
        if ret_layer_results:
            feature = (feature, res["layer_results"])
        return feature, res["padding_mask"]

    def forward(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = False,
        ret_conv: bool = False,
        output_layer: Optional[int] = None,
        ret_layer_results: bool = False,
    ):
        return self.extract_features(
            source,
            padding_mask=padding_mask,
            mask=mask,
            ret_conv=ret_conv,
            output_layer=output_layer,
            ret_layer_results=ret_layer_results,
        )

+
class ConvFeatureExtractionModel(nn.Module):
|
387 |
+
def __init__(
|
388 |
+
self,
|
389 |
+
conv_layers: List[Tuple[int, int, int]],
|
390 |
+
dropout: float = 0.0,
|
391 |
+
mode: str = "default",
|
392 |
+
conv_bias: bool = False,
|
393 |
+
conv_type: str = "default",
|
394 |
+
):
|
395 |
+
super().__init__()
|
396 |
+
|
397 |
+
assert mode in {"default", "layer_norm"}
|
398 |
+
|
399 |
+
def block(
|
400 |
+
n_in,
|
401 |
+
n_out,
|
402 |
+
k,
|
403 |
+
stride,
|
404 |
+
is_layer_norm=False,
|
405 |
+
is_group_norm=False,
|
406 |
+
conv_bias=False,
|
407 |
+
):
|
408 |
+
def make_conv():
|
409 |
+
conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
|
410 |
+
nn.init.kaiming_normal_(conv.weight)
|
411 |
+
return conv
|
412 |
+
|
413 |
+
assert (is_layer_norm and is_group_norm) is False, "layer norm and group norm are exclusive"
|
414 |
+
|
415 |
+
if is_layer_norm:
|
416 |
+
return nn.Sequential(
|
417 |
+
make_conv(),
|
418 |
+
nn.Dropout(p=dropout),
|
419 |
+
nn.Sequential(
|
420 |
+
TransposeLast(),
|
421 |
+
Fp32LayerNorm(dim, elementwise_affine=True),
|
422 |
+
TransposeLast(),
|
423 |
+
),
|
424 |
+
nn.GELU(),
|
425 |
+
)
|
426 |
+
if is_group_norm:
|
427 |
+
return nn.Sequential(
|
428 |
+
make_conv(),
|
429 |
+
nn.Dropout(p=dropout),
|
430 |
+
Fp32GroupNorm(dim, dim, affine=True),
|
431 |
+
nn.GELU(),
|
432 |
+
)
|
433 |
+
|
434 |
+
return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
|
435 |
+
|
436 |
+
self.conv_type = conv_type
|
437 |
+
if self.conv_type == "default":
|
438 |
+
in_d = 1
|
439 |
+
self.conv_layers = nn.ModuleList()
|
440 |
+
for i, cl in enumerate(conv_layers):
|
441 |
+
assert len(cl) == 3, "invalid conv definition: " + str(cl)
|
442 |
+
(dim, k, stride) = cl
|
443 |
+
|
444 |
+
self.conv_layers.append(
|
445 |
+
block(
|
446 |
+
in_d,
|
447 |
+
dim,
|
448 |
+
k,
|
449 |
+
stride,
|
450 |
+
is_layer_norm=mode == "layer_norm",
|
451 |
+
is_group_norm=mode == "default" and i == 0,
|
452 |
+
conv_bias=conv_bias,
|
453 |
+
)
|
454 |
+
)
|
455 |
+
in_d = dim
|
456 |
+
elif self.conv_type == "conv2d":
|
457 |
+
in_d = 1
|
458 |
+
self.conv_layers = nn.ModuleList()
|
459 |
+
for i, cl in enumerate(conv_layers):
|
460 |
+
assert len(cl) == 3
|
461 |
+
(dim, k, stride) = cl
|
462 |
+
|
463 |
+
self.conv_layers.append(torch.nn.Conv2d(in_d, dim, k, stride))
|
464 |
+
self.conv_layers.append(torch.nn.ReLU())
|
465 |
+
in_d = dim
|
466 |
+
elif self.conv_type == "custom":
|
467 |
+
in_d = 1
|
468 |
+
idim = 80
|
469 |
+
self.conv_layers = nn.ModuleList()
|
470 |
+
for i, cl in enumerate(conv_layers):
|
471 |
+
assert len(cl) == 3
|
472 |
+
(dim, k, stride) = cl
|
473 |
+
self.conv_layers.append(torch.nn.Conv2d(in_d, dim, k, stride, padding=1))
|
474 |
+
self.conv_layers.append(torch.nn.LayerNorm([dim, idim]))
|
475 |
+
self.conv_layers.append(torch.nn.ReLU())
|
476 |
+
in_d = dim
|
477 |
+
if (i + 1) % 2 == 0:
|
478 |
+
self.conv_layers.append(torch.nn.MaxPool2d(2, stride=2, ceil_mode=True))
|
479 |
+
idim = int(math.ceil(idim / 2))
|
480 |
+
else:
|
481 |
+
pass
|
482 |
+
|
483 |
+
def forward(self, x, mask=None): # pylint: disable=unused-argument
|
484 |
+
# BxT -> BxCxT
|
485 |
+
x = x.unsqueeze(1)
|
486 |
+
if self.conv_type == "custom":
|
487 |
+
for conv in self.conv_layers:
|
488 |
+
if isinstance(conv, nn.LayerNorm):
|
489 |
+
x = x.transpose(1, 2)
|
490 |
+
x = conv(x).transpose(1, 2)
|
491 |
+
else:
|
492 |
+
x = conv(x)
|
493 |
+
x = x.transpose(2, 3).contiguous()
|
494 |
+
x = x.view(x.size(0), -1, x.size(-1))
|
495 |
+
else:
|
496 |
+
for conv in self.conv_layers:
|
497 |
+
x = conv(x)
|
498 |
+
if self.conv_type == "conv2d":
|
499 |
+
b, c, t, f = x.size()
|
500 |
+
x = x.transpose(2, 3).contiguous().view(b, c * f, t)
|
501 |
+
return x
|
502 |
+
|
503 |
+
|
class TransformerEncoder(nn.Module):
    def __init__(self, args):
        super().__init__()

        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim

        self.pos_conv = nn.Conv1d(
            self.embedding_dim,
            self.embedding_dim,
            kernel_size=args.conv_pos,
            padding=args.conv_pos // 2,
            groups=args.conv_pos_groups,
        )
        dropout = 0
        std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
        nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
        nn.init.constant_(self.pos_conv.bias, 0)

        self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
        self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())

        if hasattr(args, "relative_position_embedding"):
            self.relative_position_embedding = args.relative_position_embedding
            self.num_buckets = args.num_buckets
            self.max_distance = args.max_distance
        else:
            self.relative_position_embedding = False
            self.num_buckets = 0
            self.max_distance = 0

        self.layers = nn.ModuleList(
            [
                TransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=args.encoder_ffn_embed_dim,
                    num_attention_heads=args.encoder_attention_heads,
                    dropout=self.dropout,
                    attention_dropout=args.attention_dropout,
                    activation_dropout=args.activation_dropout,
                    activation_fn=args.activation_fn,
                    layer_norm_first=args.layer_norm_first,
                    has_relative_attention_bias=(self.relative_position_embedding and i == 0),
                    num_buckets=self.num_buckets,
                    max_distance=self.max_distance,
                    gru_rel_pos=args.gru_rel_pos,
                )
                for i in range(args.encoder_layers)
            ]
        )

        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
        self.layerdrop = args.encoder_layerdrop

        self.apply(init_bert_params)

    def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
        x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)

        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)

        return x, layer_results

    def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None):
        if padding_mask is not None:
            x[padding_mask] = 0

        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x += x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        layer_results = []
        z = None
        if tgt_layer is not None:
            layer_results.append((x, z))
        r = None
        pos_bias = None
        for i, layer in enumerate(self.layers):
            dropout_probability = np.random.random()
            if not self.training or (dropout_probability > self.layerdrop):
                x, z, pos_bias = layer(
                    x,
                    self_attn_padding_mask=padding_mask,
                    need_weights=False,
                    self_attn_mask=streaming_mask,
                    pos_bias=pos_bias,
                )
            if tgt_layer is not None:
                layer_results.append((x, z))
            if i == tgt_layer:
                r = x
                break

        if r is not None:
            x = r

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, layer_results

class TransformerSentenceEncoderLayer(nn.Module):
    """
    Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
    models.
    """

    def __init__(
        self,
        embedding_dim: float = 768,
        ffn_embedding_dim: float = 3072,
        num_attention_heads: float = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        layer_norm_first: bool = False,
        has_relative_attention_bias: bool = False,
        num_buckets: int = 0,
        max_distance: int = 0,
        rescale_init: bool = False,
        gru_rel_pos: bool = False,
    ) -> None:
        super().__init__()
        # Initialize parameters
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        # Initialize blocks
        self.activation_name = activation_fn
        self.activation_fn = get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            self_attention=True,
            has_relative_attention_bias=has_relative_attention_bias,
            num_buckets=num_buckets,
            max_distance=max_distance,
            rescale_init=rescale_init,
            gru_rel_pos=gru_rel_pos,
        )

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(self.activation_dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.layer_norm_first = layer_norm_first

        # layer norm associated with the self attention layer
        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)

        if self.activation_name == "glu":
            self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
        else:
            self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

        # layer norm associated with the position wise feed-forward NN
        self.final_layer_norm = LayerNorm(self.embedding_dim)

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
        need_weights: bool = False,
        pos_bias=None,
    ):
        """
        LayerNorm is applied either before or after the self-attention/ffn
        modules similar to the original Transformer implementation.
        """
        residual = x

        if self.layer_norm_first:
            x = self.self_attn_layer_norm(x)
            x, attn, pos_bias = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=False,
                attn_mask=self_attn_mask,
                position_bias=pos_bias,
            )
            x = self.dropout1(x)
            x = residual + x

            residual = x
            x = self.final_layer_norm(x)
            if self.activation_name == "glu":
                x = self.fc1(x)
            else:
                x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            x = self.dropout3(x)
            x = residual + x
        else:
            x, attn, pos_bias = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=need_weights,
                attn_mask=self_attn_mask,
                position_bias=pos_bias,
            )

            x = self.dropout1(x)
            x = residual + x

            x = self.self_attn_layer_norm(x)

            residual = x
            if self.activation_name == "glu":
                x = self.fc1(x)
            else:
                x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            x = self.dropout3(x)
            x = residual + x
            x = self.final_layer_norm(x)

        return x, attn, pos_bias
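As a quick illustration of the span-masking helper defined above, here is a minimal sketch (not part of the committed file; the import path and all values are illustrative) that masks a toy batch of two 50-frame sequences with fixed-length spans:

import torch

from knn_tts.ssl.WavLM.WavLM import compute_mask_indices  # assumed import path

mask = compute_mask_indices(
    (2, 50),      # shape: (batch_size, timesteps)
    None,         # no padding mask in this toy example
    0.65,         # mask_prob: target roughly 65% of frames
    10,           # mask_length: span length used by "static" masking
    "static",     # mask_type: fixed-size spans
    min_masks=2,  # always mask at least two spans per sequence
)
mask = torch.from_numpy(mask)  # boolean tensor of shape (2, 50)
print(mask.sum(dim=-1))        # number of masked frames per sequence

Because spans may overlap, the per-sequence counts typically come out somewhat below the 65% target, exactly as the docstring warns.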
knn_tts/ssl/WavLM/__init__.py
ADDED
File without changes
knn_tts/ssl/WavLM/modules.py
ADDED
@@ -0,0 +1,763 @@
# SPDX-FileCopyrightText: 2021 Microsoft
#
# SPDX-License-Identifier: MIT

# --------------------------------------------------------
# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------

import math
import warnings
from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import Parameter

class TransposeLast(nn.Module):
    def __init__(self, deconstruct_idx=None):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx

    def forward(self, x):
        if self.deconstruct_idx is not None:
            x = x[self.deconstruct_idx]
        return x.transpose(-2, -1)


class Fp32LayerNorm(nn.LayerNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):  # pylint: disable=W0622
        output = F.layer_norm(
            input.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)


class Fp32GroupNorm(nn.GroupNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):  # pylint: disable=W0622
        output = F.group_norm(
            input.float(),
            self.num_groups,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)


# pylint: disable=W0221,W0223
class GradMultiply(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        res = x.new(x)
        return res

    @staticmethod
    def backward(ctx, grad):
        return grad * ctx.scale, None


class SamePad(nn.Module):
    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal:
            self.remove = kernel_size - 1
        else:
            self.remove = 1 if kernel_size % 2 == 0 else 0

    def forward(self, x):
        if self.remove > 0:
            x = x[:, :, : -self.remove]
        return x


class Swish(nn.Module):
    """Swish function"""

    def __init__(self):
        """Construct a Swish activation module."""
        super().__init__()
        self.act = torch.nn.Sigmoid()

    def forward(self, x):
        return x * self.act(x)


class GLU_Linear(nn.Module):
    def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
        super().__init__()

        self.glu_type = glu_type
        self.output_dim = output_dim

        if glu_type == "sigmoid":
            self.glu_act = torch.nn.Sigmoid()
        elif glu_type == "swish":
            self.glu_act = Swish()
        elif glu_type == "relu":
            self.glu_act = torch.nn.ReLU()
        elif glu_type == "gelu":
            self.glu_act = torch.nn.GELU()

        if bias_in_glu:
            self.linear = nn.Linear(input_dim, output_dim * 2, True)
        else:
            self.linear = nn.Linear(input_dim, output_dim * 2, False)

    def forward(self, x):
        # the input is assumed to carry the channel (dim) axis in the last dimension of the tensor
        x = self.linear(x)

        if self.glu_type == "bilinear":
            x = x[:, :, 0 : self.output_dim] * x[:, :, self.output_dim : self.output_dim * 2]
        else:
            x = x[:, :, 0 : self.output_dim] * self.glu_act(x[:, :, self.output_dim : self.output_dim * 2])

        return x


# pylint: disable=W0212
def gelu_accurate(x):
    if not hasattr(gelu_accurate, "_a"):
        gelu_accurate._a = math.sqrt(2 / math.pi)
    return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))


def gelu(x: torch.Tensor) -> torch.Tensor:
    return F.gelu(x.float()).type_as(x)  # pylint: disable=E1102


def get_activation_fn(activation: str):
    """Returns the activation function corresponding to `activation`"""

    activation_functions = {
        "relu": F.relu,
        "gelu": gelu,
        "gelu_fast": gelu_accurate,
        "gelu_accurate": gelu_accurate,
        "tanh": torch.tanh,
        "linear": lambda x: x,
        "glu": lambda x: x,
    }

    if activation == "gelu_fast":
        warnings.warn("--activation-fn=gelu_fast has been renamed to gelu_accurate")

    if activation in activation_functions:
        return activation_functions[activation]

    raise RuntimeError(f"--activation-fn {activation} not supported")


def init_bert_params(module):
    """
    Initialize the weights specific to the BERT Model.
    This overrides the default initializations depending on the specified arguments.
    1. If normal_init_linear_weights is set then weights of linear
       layer will be initialized using the normal distribution and
       bias will be set to the specified value.
    2. If normal_init_embed_weights is set then weights of embedding
       layer will be initialized using the normal distribution.
    3. If normal_init_proj_weights is set then weights of
       in_project_weight for MultiHeadAttention initialized using
       the normal distribution (to be validated).
    """

    def normal_(data):
        # with FSDP, module params will be on CUDA, so we cast them back to CPU
        # so that the RNG is consistent with and without FSDP
        data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    if isinstance(module, MultiheadAttention):
        normal_(module.q_proj.weight.data)
        normal_(module.k_proj.weight.data)
        normal_(module.v_proj.weight.data)


def quant_noise(module, p, block_size):
    """
    Wraps modules and applies quantization noise to the weights for
    subsequent quantization with Iterative Product Quantization as
    described in "Training with Quantization Noise for Extreme Model Compression"

    Args:
        - module: nn.Module
        - p: amount of Quantization Noise
        - block_size: size of the blocks for subsequent quantization with iPQ

    Remarks:
        - Module weights must have the right sizes wrt the block size
        - Only Linear, Embedding and Conv2d modules are supported for the moment
        - For more detail on how to quantize by blocks with convolutional weights,
          see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
        - We implement the simplest form of noise here as stated in the paper
          which consists in randomly dropping blocks
    """

    # if no quantization noise, don't register hook
    if p <= 0:
        return module

    # supported modules
    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))

    # test whether module.weight has the right sizes wrt block_size
    is_conv = module.weight.ndim == 4

    # 2D matrix
    if not is_conv:
        assert module.weight.size(1) % block_size == 0, "Input features must be a multiple of block sizes"

    # 4D matrix
    else:
        # 1x1 convolutions
        if module.kernel_size == (1, 1):
            assert module.in_channels % block_size == 0, "Input channels must be a multiple of block sizes"
        # regular convolutions
        else:
            k = module.kernel_size[0] * module.kernel_size[1]
            assert k % block_size == 0, "Kernel size must be a multiple of block size"

    def _forward_pre_hook(mod, x):  # pylint: disable=W0613
        # no noise for evaluation
        if mod.training:
            if not is_conv:
                # gather weight and sizes
                weight = mod.weight
                in_features = weight.size(1)
                out_features = weight.size(0)

                # split weight matrix into blocks and randomly drop selected blocks
                mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
                mask.bernoulli_(p)
                mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)

            else:
                # gather weight and sizes
                weight = mod.weight
                in_channels = mod.in_channels
                out_channels = mod.out_channels

                # split weight matrix into blocks and randomly drop selected blocks
                if mod.kernel_size == (1, 1):
                    mask = torch.zeros(
                        int(in_channels // block_size * out_channels),
                        device=weight.device,
                    )
                    mask.bernoulli_(p)
                    mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
                else:
                    mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
                    mask.bernoulli_(p)
                    mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])

            # scale weights and apply mask
            mask = mask.to(torch.bool)  # x.bool() is not currently supported in TorchScript
            s = 1 / (1 - p)
            mod.weight.data = s * weight.masked_fill(mask, 0)

    module.register_forward_pre_hook(_forward_pre_hook)
    return module

289 |
+
class MultiheadAttention(nn.Module):
|
290 |
+
"""Multi-headed attention.
|
291 |
+
|
292 |
+
See "Attention Is All You Need" for more details.
|
293 |
+
"""
|
294 |
+
|
295 |
+
def __init__(
|
296 |
+
self,
|
297 |
+
embed_dim,
|
298 |
+
num_heads,
|
299 |
+
kdim=None,
|
300 |
+
vdim=None,
|
301 |
+
dropout=0.0,
|
302 |
+
bias=True,
|
303 |
+
add_bias_kv=False,
|
304 |
+
add_zero_attn=False,
|
305 |
+
self_attention=False,
|
306 |
+
encoder_decoder_attention=False,
|
307 |
+
q_noise=0.0,
|
308 |
+
qn_block_size=8,
|
309 |
+
has_relative_attention_bias=False,
|
310 |
+
num_buckets=32,
|
311 |
+
max_distance=128,
|
312 |
+
gru_rel_pos=False,
|
313 |
+
rescale_init=False,
|
314 |
+
):
|
315 |
+
super().__init__()
|
316 |
+
self.embed_dim = embed_dim
|
317 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
318 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
319 |
+
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
320 |
+
|
321 |
+
self.num_heads = num_heads
|
322 |
+
self.dropout_module = nn.Dropout(dropout)
|
323 |
+
|
324 |
+
self.has_relative_attention_bias = has_relative_attention_bias
|
325 |
+
self.num_buckets = num_buckets
|
326 |
+
self.max_distance = max_distance
|
327 |
+
if self.has_relative_attention_bias:
|
328 |
+
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
|
329 |
+
|
330 |
+
self.head_dim = embed_dim // num_heads
|
331 |
+
self.q_head_dim = self.head_dim
|
332 |
+
self.k_head_dim = self.head_dim
|
333 |
+
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
334 |
+
self.scaling = self.head_dim**-0.5
|
335 |
+
|
336 |
+
self.self_attention = self_attention
|
337 |
+
self.encoder_decoder_attention = encoder_decoder_attention
|
338 |
+
|
339 |
+
assert not self.self_attention or self.qkv_same_dim, "Self-attention requires query, key and " "value to be of the same size"
|
340 |
+
|
341 |
+
k_bias = True
|
342 |
+
if rescale_init:
|
343 |
+
k_bias = False
|
344 |
+
|
345 |
+
k_embed_dim = embed_dim
|
346 |
+
q_embed_dim = embed_dim
|
347 |
+
|
348 |
+
self.k_proj = quant_noise(nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size)
|
349 |
+
self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
|
350 |
+
self.q_proj = quant_noise(nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size)
|
351 |
+
|
352 |
+
self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
|
353 |
+
|
354 |
+
if add_bias_kv:
|
355 |
+
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
356 |
+
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
|
357 |
+
else:
|
358 |
+
self.bias_k = self.bias_v = None
|
359 |
+
|
360 |
+
self.add_zero_attn = add_zero_attn
|
361 |
+
|
362 |
+
self.gru_rel_pos = gru_rel_pos
|
363 |
+
if self.gru_rel_pos:
|
364 |
+
self.grep_linear = nn.Linear(self.q_head_dim, 8)
|
365 |
+
self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
|
366 |
+
|
367 |
+
self.reset_parameters()
|
368 |
+
|
369 |
+
def reset_parameters(self):
|
370 |
+
if self.qkv_same_dim:
|
371 |
+
# Empirically observed the convergence to be much better with
|
372 |
+
# the scaled initialization
|
373 |
+
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
374 |
+
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
375 |
+
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
376 |
+
else:
|
377 |
+
nn.init.xavier_uniform_(self.k_proj.weight)
|
378 |
+
nn.init.xavier_uniform_(self.v_proj.weight)
|
379 |
+
nn.init.xavier_uniform_(self.q_proj.weight)
|
380 |
+
|
381 |
+
nn.init.xavier_uniform_(self.out_proj.weight)
|
382 |
+
if self.out_proj.bias is not None:
|
383 |
+
nn.init.constant_(self.out_proj.bias, 0.0)
|
384 |
+
if self.bias_k is not None:
|
385 |
+
nn.init.xavier_normal_(self.bias_k)
|
386 |
+
if self.bias_v is not None:
|
387 |
+
nn.init.xavier_normal_(self.bias_v)
|
388 |
+
if self.has_relative_attention_bias:
|
389 |
+
nn.init.xavier_normal_(self.relative_attention_bias.weight)
|
390 |
+
|
391 |
+
def _relative_positions_bucket(self, relative_positions, bidirectional=True):
|
392 |
+
num_buckets = self.num_buckets
|
393 |
+
max_distance = self.max_distance
|
394 |
+
relative_buckets = 0
|
395 |
+
|
396 |
+
if bidirectional:
|
397 |
+
num_buckets = num_buckets // 2
|
398 |
+
relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
|
399 |
+
relative_positions = torch.abs(relative_positions)
|
400 |
+
else:
|
401 |
+
relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
|
402 |
+
|
403 |
+
max_exact = num_buckets // 2
|
404 |
+
is_small = relative_positions < max_exact
|
405 |
+
|
406 |
+
relative_postion_if_large = max_exact + (torch.log(relative_positions.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).to(torch.long)
|
407 |
+
relative_postion_if_large = torch.min(
|
408 |
+
relative_postion_if_large,
|
409 |
+
torch.full_like(relative_postion_if_large, num_buckets - 1),
|
410 |
+
)
|
411 |
+
|
412 |
+
relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
|
413 |
+
return relative_buckets
|
414 |
+
|
415 |
+
def compute_bias(self, query_length, key_length):
|
416 |
+
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
|
417 |
+
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
|
418 |
+
relative_position = memory_position - context_position
|
419 |
+
relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True)
|
420 |
+
relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
|
421 |
+
values = self.relative_attention_bias(relative_position_bucket)
|
422 |
+
values = values.permute([2, 0, 1])
|
423 |
+
return values
|
424 |
+
|
425 |
+
# pylint: disable=R0912, R0915
|
426 |
+
def forward(
|
427 |
+
self,
|
428 |
+
query,
|
429 |
+
key: Optional[Tensor],
|
430 |
+
value: Optional[Tensor],
|
431 |
+
key_padding_mask: Optional[Tensor] = None,
|
432 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
433 |
+
need_weights: bool = True,
|
434 |
+
static_kv: bool = False,
|
435 |
+
attn_mask: Optional[Tensor] = None,
|
436 |
+
before_softmax: bool = False,
|
437 |
+
need_head_weights: bool = False,
|
438 |
+
position_bias: Optional[Tensor] = None,
|
439 |
+
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
|
440 |
+
"""Input shape: Time x Batch x Channel
|
441 |
+
|
442 |
+
Args:
|
443 |
+
key_padding_mask (ByteTensor, optional): mask to exclude
|
444 |
+
keys that are pads, of shape `(batch, src_len)`, where
|
445 |
+
padding elements are indicated by 1s.
|
446 |
+
need_weights (bool, optional): return the attention weights,
|
447 |
+
averaged over heads (default: False).
|
448 |
+
attn_mask (ByteTensor, optional): typically used to
|
449 |
+
implement causal attention, where the mask prevents the
|
450 |
+
attention from looking forward in time (default: None).
|
451 |
+
before_softmax (bool, optional): return the raw attention
|
452 |
+
weights and values before the attention softmax.
|
453 |
+
need_head_weights (bool, optional): return the attention
|
454 |
+
weights for each head. Implies *need_weights*. Default:
|
455 |
+
return the average attention weights over all heads.
|
456 |
+
"""
|
457 |
+
if need_head_weights:
|
458 |
+
need_weights = True
|
459 |
+
|
460 |
+
is_tpu = query.device.type == "xla"
|
461 |
+
|
462 |
+
tgt_len, bsz, embed_dim = query.size()
|
463 |
+
src_len = tgt_len
|
464 |
+
assert embed_dim == self.embed_dim
|
465 |
+
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
466 |
+
if key is not None:
|
467 |
+
src_len, key_bsz, _ = key.size()
|
468 |
+
if not torch.jit.is_scripting():
|
469 |
+
assert key_bsz == bsz
|
470 |
+
assert value is not None
|
471 |
+
assert src_len, bsz == value.shape[:2]
|
472 |
+
|
473 |
+
if self.has_relative_attention_bias and position_bias is None:
|
474 |
+
position_bias = self.compute_bias(tgt_len, src_len)
|
475 |
+
position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
|
476 |
+
|
477 |
+
if (
|
478 |
+
not is_tpu # don't use PyTorch version on TPUs
|
479 |
+
and incremental_state is None
|
480 |
+
and not static_kv
|
481 |
+
# A workaround for quantization to work. Otherwise JIT compilation
|
482 |
+
# treats bias in linear module as method.
|
483 |
+
and not torch.jit.is_scripting()
|
484 |
+
and self.q_head_dim == self.head_dim
|
485 |
+
):
|
486 |
+
assert key is not None and value is not None
|
487 |
+
assert attn_mask is None
|
488 |
+
|
489 |
+
attn_mask_rel_pos = None
|
490 |
+
if position_bias is not None:
|
491 |
+
attn_mask_rel_pos = position_bias
|
492 |
+
if self.gru_rel_pos:
|
493 |
+
query_layer = query.transpose(0, 1)
|
494 |
+
new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1)
|
495 |
+
query_layer = query_layer.view(*new_x_shape)
|
496 |
+
query_layer = query_layer.permute(0, 2, 1, 3)
|
497 |
+
_B, _H, _L, __ = query_layer.size()
|
498 |
+
|
499 |
+
gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
|
500 |
+
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
|
501 |
+
attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
|
502 |
+
|
503 |
+
attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
|
504 |
+
k_proj_bias = self.k_proj.bias
|
505 |
+
if k_proj_bias is None:
|
506 |
+
k_proj_bias = torch.zeros_like(self.q_proj.bias)
|
507 |
+
|
508 |
+
x, attn = F.multi_head_attention_forward(
|
509 |
+
query,
|
510 |
+
key,
|
511 |
+
value,
|
512 |
+
self.embed_dim,
|
513 |
+
self.num_heads,
|
514 |
+
torch.empty([0]),
|
515 |
+
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
|
516 |
+
self.bias_k,
|
517 |
+
self.bias_v,
|
518 |
+
self.add_zero_attn,
|
519 |
+
self.dropout_module.p,
|
520 |
+
self.out_proj.weight,
|
521 |
+
self.out_proj.bias,
|
522 |
+
self.training,
|
523 |
+
# self.training or self.dropout_module.apply_during_inference,
|
524 |
+
key_padding_mask,
|
525 |
+
need_weights,
|
526 |
+
attn_mask_rel_pos,
|
527 |
+
use_separate_proj_weight=True,
|
528 |
+
q_proj_weight=self.q_proj.weight,
|
529 |
+
k_proj_weight=self.k_proj.weight,
|
530 |
+
v_proj_weight=self.v_proj.weight,
|
531 |
+
)
|
532 |
+
return x, attn, position_bias
|
533 |
+
|
534 |
+
if incremental_state is not None:
|
535 |
+
saved_state = self._get_input_buffer(incremental_state)
|
536 |
+
if saved_state is not None and "prev_key" in saved_state:
|
537 |
+
# previous time steps are cached - no need to recompute
|
538 |
+
# key and value if they are static
|
539 |
+
if static_kv:
|
540 |
+
assert self.encoder_decoder_attention and not self.self_attention
|
541 |
+
key = value = None
|
542 |
+
else:
|
543 |
+
saved_state = None
|
544 |
+
|
545 |
+
if self.self_attention:
|
546 |
+
q = self.q_proj(query)
|
547 |
+
k = self.k_proj(query)
|
548 |
+
v = self.v_proj(query)
|
549 |
+
elif self.encoder_decoder_attention:
|
550 |
+
# encoder-decoder attention
|
551 |
+
q = self.q_proj(query)
|
552 |
+
if key is None:
|
553 |
+
assert value is None
|
554 |
+
k = v = None
|
555 |
+
else:
|
556 |
+
k = self.k_proj(key)
|
557 |
+
v = self.v_proj(key)
|
558 |
+
|
559 |
+
else:
|
560 |
+
assert key is not None and value is not None
|
561 |
+
q = self.q_proj(query)
|
562 |
+
k = self.k_proj(key)
|
563 |
+
v = self.v_proj(value)
|
564 |
+
q *= self.scaling
|
565 |
+
|
566 |
+
if self.bias_k is not None:
|
567 |
+
assert self.bias_v is not None
|
568 |
+
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
569 |
+
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
570 |
+
if attn_mask is not None:
|
571 |
+
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
|
572 |
+
if key_padding_mask is not None:
|
573 |
+
key_padding_mask = torch.cat(
|
574 |
+
[
|
575 |
+
key_padding_mask,
|
576 |
+
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
|
577 |
+
],
|
578 |
+
dim=1,
|
579 |
+
)
|
580 |
+
|
581 |
+
q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.q_head_dim).transpose(0, 1)
|
582 |
+
if k is not None:
|
583 |
+
k = k.contiguous().view(-1, bsz * self.num_heads, self.k_head_dim).transpose(0, 1)
|
584 |
+
if v is not None:
|
585 |
+
v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
586 |
+
|
587 |
+
if saved_state is not None:
|
588 |
+
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
589 |
+
if "prev_key" in saved_state:
|
590 |
+
_prev_key = saved_state["prev_key"]
|
591 |
+
assert _prev_key is not None
|
592 |
+
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
|
593 |
+
if static_kv:
|
594 |
+
k = prev_key
|
595 |
+
else:
|
596 |
+
assert k is not None
|
597 |
+
k = torch.cat([prev_key, k], dim=1)
|
598 |
+
src_len = k.size(1)
|
599 |
+
if "prev_value" in saved_state:
|
600 |
+
_prev_value = saved_state["prev_value"]
|
601 |
+
assert _prev_value is not None
|
602 |
+
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
|
603 |
+
if static_kv:
|
604 |
+
v = prev_value
|
605 |
+
else:
|
606 |
+
assert v is not None
|
607 |
+
v = torch.cat([prev_value, v], dim=1)
|
608 |
+
prev_key_padding_mask: Optional[Tensor] = None
|
609 |
+
if "prev_key_padding_mask" in saved_state:
|
610 |
+
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
|
611 |
+
assert k is not None and v is not None
|
612 |
+
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
|
613 |
+
key_padding_mask=key_padding_mask,
|
614 |
+
prev_key_padding_mask=prev_key_padding_mask,
|
615 |
+
batch_size=bsz,
|
616 |
+
src_len=k.size(1),
|
617 |
+
static_kv=static_kv,
|
618 |
+
)
|
619 |
+
|
620 |
+
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
|
621 |
+
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
|
622 |
+
saved_state["prev_key_padding_mask"] = key_padding_mask
|
623 |
+
# In this branch incremental_state is never None
|
624 |
+
assert incremental_state is not None
|
625 |
+
incremental_state = self._set_input_buffer(incremental_state, saved_state)
|
626 |
+
assert k is not None
|
627 |
+
assert k.size(1) == src_len
|
628 |
+
|
629 |
+
# This is part of a workaround to get around fork/join parallelism
|
630 |
+
# not supporting Optional types.
|
631 |
+
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
632 |
+
key_padding_mask = None
|
633 |
+
|
634 |
+
if key_padding_mask is not None:
|
635 |
+
assert key_padding_mask.size(0) == bsz
|
636 |
+
assert key_padding_mask.size(1) == src_len
|
637 |
+
|
638 |
+
if self.add_zero_attn:
|
639 |
+
assert v is not None
|
640 |
+
src_len += 1
|
641 |
+
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
|
642 |
+
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
|
643 |
+
if attn_mask is not None:
|
644 |
+
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
|
645 |
+
if key_padding_mask is not None:
|
646 |
+
key_padding_mask = torch.cat(
|
647 |
+
[
|
648 |
+
key_padding_mask,
|
649 |
+
torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
|
650 |
+
],
|
651 |
+
dim=1,
|
652 |
+
)
|
653 |
+
|
654 |
+
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
655 |
+
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
656 |
+
|
657 |
+
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
658 |
+
|
659 |
+
if attn_mask is not None:
|
660 |
+
attn_mask = attn_mask.unsqueeze(0)
|
661 |
+
attn_weights += attn_mask
|
662 |
+
|
663 |
+
if key_padding_mask is not None:
|
664 |
+
# don't attend to padding symbols
|
665 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
666 |
+
if not is_tpu:
|
667 |
+
attn_weights = attn_weights.masked_fill(
|
668 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
|
669 |
+
float("-inf"),
|
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v, position_bias

        if position_bias is not None:
            if self.gru_rel_pos == 1:
                query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
                _B, _H, _L, __ = query_layer.size()
                gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
                gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
                position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias

            position_bias = position_bias.view(attn_weights.size())

            attn_weights = attn_weights + position_bias

        attn_weights_float = F.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights, position_bias

    @staticmethod
    def _append_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1)
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None
        elif prev_key_padding_mask is not None:
            if src_len > prev_key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - prev_key_padding_mask.size(1)),
                    device=prev_key_padding_mask.device,
                )
                new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
            else:
                new_key_padding_mask = prev_key_padding_mask.float()
        elif key_padding_mask is not None:
            if src_len > key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - key_padding_mask.size(1)),
                    device=key_padding_mask.device,
                )
                new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
            else:
                new_key_padding_mask = key_padding_mask.float()
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    def _get_input_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]) -> Dict[str, Optional[Tensor]]:
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        empty_result: Dict[str, Optional[Tensor]] = {}
        return empty_result

    def _set_input_buffer(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        buffer: Dict[str, Optional[Tensor]],
    ):
        return self.set_incremental_state(incremental_state, "attn_state", buffer)

    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):  # pylint: disable=W0613
        return attn_weights
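The `gru_rel_pos` branch above is WavLM's gated relative position bias: a gate computed from the query rescales `position_bias` per head and time step before it is added to the attention logits. A standalone sketch with hypothetical shapes, using a plain `nn.Linear` and an all-ones tensor as stand-ins for `self.grep_linear` and `self.grep_a`:

import torch

bsz, num_heads, tgt_len, q_head_dim = 2, 4, 5, 8
query_layer = torch.randn(bsz, num_heads, tgt_len, q_head_dim)
grep_linear = torch.nn.Linear(q_head_dim, 8)   # stand-in for self.grep_linear
grep_a = torch.ones(1, num_heads, 1, 1)        # stand-in for self.grep_a

gate_a, gate_b = torch.sigmoid(
    grep_linear(query_layer).view(bsz, num_heads, tgt_len, 2, 4).sum(-1)
).chunk(2, dim=-1)
gate = gate_a * (gate_b * grep_a - 1.0) + 2.0  # lies in (1, 2) when grep_a == 1
print(gate.shape)  # torch.Size([2, 4, 5, 1]) — one scalar gate per head and step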
knn_tts/ssl/__init__.py
ADDED
File without changes
knn_tts/ssl/models.py
ADDED
@@ -0,0 +1,49 @@
# SPDX-FileCopyrightText: 2024 Idiap Research Institute
# SPDX-FileContributor: Karl El Hajal
#
# SPDX-License-Identifier: MIT

import torch
import torchaudio

from knn_tts.ssl.WavLM.WavLM import WavLMConfig, WavLMModel


class WavLM:
    def __init__(self, device=None):
        """
        Load the WavLM-Large model (see: https://github.com/microsoft/unilm/tree/master/wavlm).
        Torch hub checkpoint from https://github.com/bshall/knn-vc
        """
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        checkpoint_url = "https://github.com/bshall/knn-vc/releases/download/v0.1/WavLM-Large.pt"
        checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=self.device, progress=True)

        cfg = WavLMConfig(checkpoint["cfg"])
        self.model = WavLMModel(cfg)
        self.model.load_state_dict(checkpoint["model"])
        self.model = self.model.to(self.device)
        self.model.eval()

        self.model.sample_rate = 16000
        self.SPEAKER_INFORMATION_LAYER = 6
        num_parameters = sum(p.numel() for p in self.model.parameters())
        print(f"WavLM-Large loaded with {num_parameters:,d} parameters.")

    @torch.no_grad()
    def extract_framewise_features(self, wav_path, output_layer=None):
        audio, orig_sr = torchaudio.load(wav_path)
        audio = audio.to(self.device)
        if orig_sr != self.model.sample_rate:
            audio = torchaudio.functional.resample(audio, orig_freq=orig_sr, new_freq=self.model.sample_rate)

        if output_layer is None:
            output_layer = self.SPEAKER_INFORMATION_LAYER

        embeddings = self.model.extract_features(audio, output_layer=output_layer, ret_layer_results=False)[0]
        embeddings = embeddings.squeeze(0)
        return embeddings  # [frame_len, 1024]
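A minimal usage sketch for the wrapper above (the wav path is a hypothetical placeholder): features are extracted at layer 6 by default, the layer kNN-VC uses for speaker-related content.

from knn_tts.ssl.models import WavLM

ssl_model = WavLM()  # downloads the WavLM-Large checkpoint on first use
feats = ssl_model.extract_framewise_features("example.wav")  # hypothetical path
print(feats.shape)  # torch.Size([frame_len, 1024]), ~50 frames per second of audio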
knn_tts/synthesizer.py
ADDED
@@ -0,0 +1,69 @@
# SPDX-FileCopyrightText: 2024 Idiap Research Institute
# SPDX-FileContributor: Karl El Hajal
#
# SPDX-License-Identifier: MIT

import torch
import torchaudio

from knn_tts.text_cleaners import clean_input_text
from knn_tts.tts.models import GlowTTS
from knn_tts.vc.knn import knn_vc, load_target_style_feats
from knn_tts.vocoder.models import HiFiGANWavLM


class Synthesizer:
    def __init__(
        self,
        tts_model_base_path,
        tts_model_checkpoint,
        vocoder_checkpoint_path,
        model_name="glowtts",
    ):
        self.model_name = model_name
        self.model = GlowTTS(tts_model_base_path, tts_model_checkpoint)
        self.vocoder = HiFiGANWavLM(checkpoint_path=vocoder_checkpoint_path, device=self.model.device)
        self.target_style_feats_path = None
        self.target_style_feats = None

    def __call__(
        self,
        text_input,
        target_style_feats_path,
        knnvc_topk=4,
        weighted_average=False,
        interpolation_rate=1.0,
        save_path=None,
        timesteps=10,
        max_target_num_files=1000,
    ):
        with torch.no_grad():
            # Text-to-SSL
            text_input = clean_input_text(text_input)
            tts_feats = self.model(text_input, timesteps)  # timesteps are used for GradTTS only

            # kNN-VC
            if interpolation_rate != 0.0:
                if target_style_feats_path != self.target_style_feats_path:
                    self.target_style_feats_path = target_style_feats_path
                    self.target_style_feats = load_target_style_feats(target_style_feats_path, max_target_num_files)

                selected_feats = knn_vc(
                    tts_feats,
                    self.target_style_feats,
                    topk=knnvc_topk,
                    weighted_average=weighted_average,
                    device=self.model.device,
                )

                converted_feats = interpolation_rate * selected_feats + (1.0 - interpolation_rate) * tts_feats
            else:
                converted_feats = tts_feats

            # Vocoder
            wav = self.vocoder(converted_feats.unsqueeze(0)).unsqueeze(0)

        if save_path is not None:
            torchaudio.save(save_path, wav, 16000)
        return wav
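A usage sketch of the whole pipeline (all paths are hypothetical placeholders): the text is cleaned, turned into WavLM-space features by GlowTTS, matched frame-by-frame against the target speaker's pre-extracted features, optionally interpolated, and vocoded to 16 kHz audio.

from knn_tts.synthesizer import Synthesizer

synth = Synthesizer(
    "checkpoints/glowtts",  # hypothetical: directory holding config.json
    "best_model.pth",       # hypothetical checkpoint file name
    "checkpoints/hifigan/prematch_g_02500000.pt",
)
wav = synth(
    "Hello world.",
    "target_speaker_feats/",  # hypothetical: directory of .pt WavLM features
    knnvc_topk=4,
    interpolation_rate=1.0,
    save_path="output.wav",
)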
knn_tts/text_cleaners.py
ADDED
@@ -0,0 +1,66 @@
# SPDX-FileCopyrightText: 2024 Idiap Research Institute
# SPDX-FileContributor: Karl El Hajal
#
# SPDX-License-Identifier: MIT

import re

from num2words import num2words

PUNCTUATION_TO_REPLACE_WITH_COMMA = ["(", ")", ":"]
PUNCTUATION = ".!?,"
SPACE = " "


def replace_some_punctuation_with_comma(text):
    for punct in PUNCTUATION_TO_REPLACE_WITH_COMMA:
        text = text.replace(punct, " , ")
    return text


def split_string_on_punctuation(s):
    substrings = []
    current_substring = ""

    for char in s:
        current_substring += char
        if char in PUNCTUATION:
            substrings.append(current_substring.strip())
            current_substring = ""

    if current_substring:
        substrings.append(current_substring)

    return substrings


def remove_punctuation_only_substrings(substrings):
    new_substrings = []
    for substring in substrings:
        if not all(c in PUNCTUATION + SPACE for c in substring):
            new_substrings.append(substring)
    return new_substrings


def clean_punctuation(text):
    text = replace_some_punctuation_with_comma(text)
    substrings = split_string_on_punctuation(text)
    substrings = remove_punctuation_only_substrings(substrings)
    text = SPACE.join(substrings)
    return text


def clean_input_text(text):
    text = text.lower().strip()
    text = clean_punctuation(text)

    def replace_numbers(match):
        return num2words(int(match.group()))

    text = re.sub(r"\d+", replace_numbers, text)
    text = " ".join(text.split())  # Remove extra spaces

    if text and text[-1] not in PUNCTUATION:
        text += "."

    return text
knn_tts/tts/GlowTTS/__init__.py
ADDED
File without changes
knn_tts/tts/GlowTTS/glow_tts_ssl.py
ADDED
@@ -0,0 +1,623 @@
# SPDX-FileCopyrightText: Coqui AI
# SPDX-FileContributor: Karl El Hajal
#
# SPDX-License-Identifier: MPL-2.0

import collections
import os

import numpy as np
import torch
import torch.distributed as dist
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
from torch.utils.data import DataLoader
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
from TTS.tts.datasets.dataset import TTSDataset, noise_augment_audio
from TTS.tts.models.glow_tts import GlowTTS
from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
from TTS.tts.utils.synthesis import (
    embedding_to_torch,
    id_to_torch,
    numpy_to_torch,
    run_model_torch,
    trim_silence,
)
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor

from knn_tts.vocoder.models import HiFiGANWavLM


def load_glow_tts_ssl(model_base_path, checkpoint, hifigan_checkpoint_path, device):
    checkpoint_path = os.path.join(model_base_path, checkpoint)
    config_path = os.path.join(model_base_path, "config.json")

    config = GlowTTSConfig()
    config.load_json(config_path)
    ap = AudioProcessor.init_from_config(config)
    tokenizer, _ = TTSTokenizer.init_from_config(config)

    model = GlowTTSSSL(
        config=config,
        ap=ap,
        tokenizer=tokenizer,
        speaker_manager=None,
        hifigan_checkpoint_path=hifigan_checkpoint_path,
    )
    model.load_checkpoint(config=config, checkpoint_path=checkpoint_path)
    model.to(device)
    model.eval()
    return model


def ssl_synthesis(
    model,
    text,
    vocoder,
    return_ssl_feats=False,
    speaker_id=None,
    do_trim_silence=False,
    d_vector=None,
    language_id=None,
):
    # device
    device = next(model.parameters()).device

    language_name = None
    if language_id is not None:
        language = [k for k, v in model.language_manager.name_to_id.items() if v == language_id]
        assert len(language) == 1, "language_id must be a valid language"
        language_name = language[0]

    # convert text to sequence of token IDs
    text_inputs = np.asarray(
        model.tokenizer.text_to_ids(text, language=language_name),
        dtype=np.int32,
    )

    # pass tensors to backend
    if speaker_id is not None:
        speaker_id = id_to_torch(speaker_id, device=device)

    if d_vector is not None:
        d_vector = embedding_to_torch(d_vector, device=device)

    if language_id is not None:
        language_id = id_to_torch(language_id, device=device)

    text_inputs = numpy_to_torch(text_inputs, torch.long, device=device)
    text_inputs = text_inputs.unsqueeze(0)
    # synthesize voice
    outputs = run_model_torch(
        model,
        text_inputs,
        speaker_id,
        style_mel=None,
        style_text=None,
        d_vector=d_vector,
        language_id=language_id,
    )
    model_outputs = outputs["model_outputs"]

    if return_ssl_feats:
        ssl_feats = model_outputs.squeeze()
        return ssl_feats

    model_outputs = model_outputs[0].data.cpu().numpy()
    alignments = outputs["alignments"]

    ssl_feats = model_outputs.squeeze()
    wav = vocoder(torch.from_numpy(ssl_feats).unsqueeze(0)).unsqueeze(0).cpu().numpy()

    if do_trim_silence:
        wav = trim_silence(wav, model.ap)

    return_dict = {
        "wav": wav,
        "alignments": alignments,
        "text_inputs": text_inputs,
        "outputs": outputs,
    }
    return return_dict


class GlowTTSSSLDataset(TTSDataset):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pad_id = self.tokenizer.characters.pad_id

    def load_data(self, idx):
        item = self.samples[idx]

        raw_text = item["text"]

        wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32)

        # apply noise for augmentation
        if self.use_noise_augment:
            wav = noise_augment_audio(wav)

        # get token ids
        token_ids = self.get_token_ids(idx, item["text"])

        # get pre-computed attention maps
        attn = None
        if "alignment_file" in item:
            attn = self.get_attn_mask(item["alignment_file"])

        # after phonemization the text length may change
        # this is a shameful 🤭 hack to prevent longer phonemes
        # TODO: find a better fix  # pylint: disable=fixme
        if len(token_ids) > self.max_text_len or len(wav) < self.min_audio_len:
            self.rescue_item_idx += 1
            return self.load_data(self.rescue_item_idx)

        # get f0 values
        f0 = None
        if self.compute_f0:
            f0 = self.get_f0(idx)["f0"]
        energy = None
        if self.compute_energy:
            energy = self.get_energy(idx)["energy"]

        ssl_feats = torch.load(item["ssl_feats_file"])

        if len(ssl_feats.size()) == 2:
            ssl_feats = ssl_feats.unsqueeze(0)

        sample = {
            "raw_text": raw_text,
            "token_ids": token_ids,
            "wav": wav,
            "pitch": f0,
            "energy": energy,
            "attn": attn,
            "item_idx": item["audio_file"],
            "speaker_name": item["speaker_name"],
            "language_name": item["language"],
            "wav_file_name": os.path.basename(item["audio_file"]),
            "audio_unique_name": item["audio_unique_name"],
            "ssl_feats": ssl_feats,
        }
        return sample

    # pylint: disable=too-many-branches, too-many-statements
    def collate_fn(self, batch):
        # Puts each data field into a tensor with outer dimension batch size
        if isinstance(batch[0], collections.abc.Mapping):
            B = len(batch)

            token_ids_lengths = np.array([len(d["token_ids"]) for d in batch])

            # sort items with text input length for RNN efficiency
            batch, token_ids_lengths, ids_sorted_decreasing = self._sort_batch(batch, token_ids_lengths)

            # convert list of dicts to dict of lists
            batch = {k: [dic[k] for dic in batch] for k in batch[0]}

            # get language ids from language names
            if self.language_id_mapping is not None:
                language_ids = [self.language_id_mapping[ln] for ln in batch["language_name"]]
            else:
                language_ids = None
            # get pre-computed d-vectors
            if self.d_vector_mapping is not None:
                embedding_keys = list(batch["audio_unique_name"])
                d_vectors = [self.d_vector_mapping[w]["embedding"] for w in embedding_keys]
            else:
                d_vectors = None

            # get numerical speaker ids from speaker names
            if self.speaker_id_mapping:
                speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]]
            else:
                speaker_ids = None
            # compute features
            mel = [self.ap.melspectrogram(w).astype("float32") for w in batch["wav"]]

            mel_lengths = [m.shape[1] for m in mel]

            # lengths adjusted by the reduction factor
            mel_lengths_adjusted = [m.shape[1] + (self.outputs_per_step - (m.shape[1] % self.outputs_per_step)) if m.shape[1] % self.outputs_per_step else m.shape[1] for m in mel]

            # PAD sequences with longest instance in the batch
            token_ids = prepare_data(batch["token_ids"]).astype(np.int32)

            # PAD features with longest instance
            mel = prepare_tensor(mel, self.outputs_per_step)

            # B x D x T --> B x T x D
            mel = mel.transpose(0, 2, 1)

            # convert things to pytorch
            token_ids_lengths = torch.LongTensor(token_ids_lengths)
            token_ids = torch.LongTensor(token_ids)
            mel = torch.FloatTensor(mel).contiguous()
            mel_lengths = torch.LongTensor(mel_lengths)

            # format ssl_feats
            ssl_feats_dim = batch["ssl_feats"][0].shape[2]
            ssl_feats_lens = [emb.shape[1] for emb in batch["ssl_feats"]]
            ssl_feats_lens = torch.LongTensor(ssl_feats_lens)
            ssl_feats_lens_max = torch.max(ssl_feats_lens)
            ssl_feats_rel_lens = ssl_feats_lens / ssl_feats_lens_max
            ssl_feats_padded = torch.FloatTensor(B, ssl_feats_lens_max, ssl_feats_dim)
            ssl_feats_padded = ssl_feats_padded.zero_() + self.pad_id

            for i, emb in enumerate(batch["ssl_feats"]):
                ssl_feats_padded[i, : emb.size(1), :] = torch.FloatTensor(emb)

            stop_targets = [np.array([0.0] * (ssl_len - 1) + [1.0]) for ssl_len in ssl_feats_lens]
            stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)
            stop_targets = torch.FloatTensor(stop_targets)

            # speaker vectors
            if d_vectors is not None:
                d_vectors = torch.FloatTensor(d_vectors)

            if speaker_ids is not None:
                speaker_ids = torch.LongTensor(speaker_ids)

            if language_ids is not None:
                language_ids = torch.LongTensor(language_ids)

            # compute linear spectrogram
            linear = None
            if self.compute_linear_spec:
                linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
                linear = prepare_tensor(linear, self.outputs_per_step)
                linear = linear.transpose(0, 2, 1)
                assert mel.shape[1] == linear.shape[1]
                linear = torch.FloatTensor(linear).contiguous()

            # format waveforms
            wav_padded = None
            if self.return_wav:
                wav_lengths = [w.shape[0] for w in batch["wav"]]
                # max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
                ssl_feats_hop_length = self.ap.hop_length
                max_wav_len = max(ssl_feats_lens) * ssl_feats_hop_length
                wav_lengths = torch.LongTensor(wav_lengths)
                wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len)
                for i, w in enumerate(batch["wav"]):
                    mel_length = mel_lengths_adjusted[i]
                    w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
                    w = w[: mel_length * self.ap.hop_length]
                    wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
                wav_padded.transpose_(1, 2)

            # format F0
            if self.compute_f0:
                pitch = prepare_data(batch["pitch"])
                assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}"
                pitch = torch.FloatTensor(pitch)[:, None, :].contiguous()  # B x 1 x T
            else:
                pitch = None
            # format energy
            if self.compute_energy:
                energy = prepare_data(batch["energy"])
                assert mel.shape[1] == energy.shape[1], f"[!] {mel.shape} vs {energy.shape}"
                energy = torch.FloatTensor(energy)[:, None, :].contiguous()  # B x 1 x T
            else:
                energy = None
            # format attention masks
            attns = None
            if batch["attn"][0] is not None:
                attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing]
                for idx, attn in enumerate(attns):
                    pad2 = mel.shape[1] - attn.shape[1]
                    pad1 = token_ids.shape[1] - attn.shape[0]
                    assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}"
                    attn = np.pad(attn, [[0, pad1], [0, pad2]])
                    attns[idx] = attn
                attns = prepare_tensor(attns, self.outputs_per_step)
                attns = torch.FloatTensor(attns).unsqueeze(1)

            return {
                "token_id": token_ids,
                "token_id_lengths": token_ids_lengths,
                "speaker_names": batch["speaker_name"],
                "linear": linear,
                "mel": mel,
                "mel_lengths": mel_lengths,
                "stop_targets": stop_targets,
                "item_idxs": batch["item_idx"],
                "d_vectors": d_vectors,
                "speaker_ids": speaker_ids,
                "attns": attns,
                "waveform": wav_padded,
                "raw_text": batch["raw_text"],
                "pitch": pitch,
                "energy": energy,
                "language_ids": language_ids,
                "audio_unique_names": batch["audio_unique_name"],
                "ssl_feats": ssl_feats_padded,
                "ssl_feats_lens": ssl_feats_lens,
                "ssl_feats_rel_lens": ssl_feats_rel_lens,
            }
        return None


class GlowTTSSSL(GlowTTS):
    def __init__(
        self,
        config: GlowTTSConfig,
        ap=None,
        tokenizer=None,
        speaker_manager=None,
        hifigan_checkpoint_path: str = "checkpoints/hifigan-wavlm/prematch_g_02500000.pt",
    ):
        super().__init__(config, ap, tokenizer, speaker_manager)
        if hifigan_checkpoint_path is None:
            self.vocoder = None
        else:
            self.vocoder = HiFiGANWavLM(
                checkpoint_path=hifigan_checkpoint_path,
                device=str(next(self.parameters()).device),
            )

    def train_step(self, batch: dict, criterion: nn.Module):
        text_input = batch["text_input"]
        text_lengths = batch["text_lengths"]
        ssl_feats_input = batch["ssl_feats"]
        ssl_feats_lengths = batch["ssl_feats_lens"]
        d_vectors = batch["d_vectors"]
        speaker_ids = batch["speaker_ids"]

        if self.run_data_dep_init and self.training:
            # compute data-dependent initialization of activation norm layers
            self.unlock_act_norm_layers()
            with torch.no_grad():
                _ = self.forward(
                    text_input,
                    text_lengths,
                    ssl_feats_input,
                    ssl_feats_lengths,
                    aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
                )
            outputs = None
            loss_dict = None
            self.lock_act_norm_layers()
        else:
            # normal training step
            outputs = self.forward(
                text_input,
                text_lengths,
                ssl_feats_input,
                ssl_feats_lengths,
                aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
            )

            with autocast(enabled=False):  # avoid mixed_precision in criterion
                loss_dict = criterion(
                    outputs["z"].float(),
                    outputs["y_mean"].float(),
                    outputs["y_log_scale"].float(),
                    outputs["logdet"].float(),
                    ssl_feats_lengths,
                    outputs["durations_log"].float(),
                    outputs["total_durations_log"].float(),
                    text_lengths,
                )

        return outputs, loss_dict

    @torch.no_grad()
    def _create_logs(self, batch, outputs, ap):
        alignments = outputs["alignments"]
        text_input = batch["text_input"][:1] if batch["text_input"] is not None else None
        text_lengths = batch["text_lengths"]
        mel_input = batch["mel_input"]
        d_vectors = batch["d_vectors"][:1] if batch["d_vectors"] is not None else None
        speaker_ids = batch["speaker_ids"][:1] if batch["speaker_ids"] is not None else None

        gt_spec = mel_input[0].data.cpu().numpy()
        align_img = alignments[0].data.cpu().numpy()

        # model runs reverse flow to predict spectrograms
        with torch.no_grad():
            pred_outputs = self.inference(
                text_input,
                aux_input={
                    "x_lengths": text_lengths[:1],
                    "d_vectors": d_vectors,
                    "speaker_ids": speaker_ids,
                },
            )

        pred_ssl_feats = pred_outputs["model_outputs"]

        pred_audio = self.vocoder(pred_ssl_feats).unsqueeze(0).cpu().detach().numpy()

        pred_spec = self.ap.melspectrogram(pred_audio).astype("float32")
        pred_spec = pred_spec.squeeze(axis=1)
        pred_spec = pred_spec.T

        figures = {
            "prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
            "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
            "alignment": plot_alignment(align_img, output_fig=False),
        }

        return figures, {"audio": pred_audio}

    def get_data_loader(
        self,
        config,
        assets,
        is_eval,
        samples,
        verbose: bool,
        num_gpus: int,
        rank: int = None,
    ) -> "DataLoader":
        if is_eval and not config.run_eval:
            loader = None
        else:
            # setup multi-speaker attributes
            if self.speaker_manager is not None:
                if hasattr(config, "model_args"):
                    speaker_id_mapping = self.speaker_manager.name_to_id if config.model_args.use_speaker_embedding else None
                    d_vector_mapping = self.speaker_manager.embeddings if config.model_args.use_d_vector_file else None
                    config.use_d_vector_file = config.model_args.use_d_vector_file
                else:
                    speaker_id_mapping = self.speaker_manager.name_to_id if config.use_speaker_embedding else None
                    d_vector_mapping = self.speaker_manager.embeddings if config.use_d_vector_file else None
            else:
                speaker_id_mapping = None
                d_vector_mapping = None

            # setup multi-lingual attributes
            if self.language_manager is not None:
                language_id_mapping = self.language_manager.name_to_id if self.args.use_language_embedding else None
            else:
                language_id_mapping = None

            # init dataloader
            dataset = GlowTTSSSLDataset(
                outputs_per_step=config.r if "r" in config else 1,
                compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
                compute_f0=config.get("compute_f0", False),
                f0_cache_path=config.get("f0_cache_path", None),
                compute_energy=config.get("compute_energy", False),
                energy_cache_path=config.get("energy_cache_path", None),
                samples=samples,
                ap=self.ap,
                return_wav=config.return_wav if "return_wav" in config else False,
                batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
                min_text_len=config.min_text_len,
                max_text_len=config.max_text_len,
                min_audio_len=config.min_audio_len,
                max_audio_len=config.max_audio_len,
                phoneme_cache_path=config.phoneme_cache_path,
                precompute_num_workers=config.precompute_num_workers,
                use_noise_augment=False if is_eval else config.use_noise_augment,
                speaker_id_mapping=speaker_id_mapping,
                d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
                tokenizer=self.tokenizer,
                start_by_longest=config.start_by_longest,
                language_id_mapping=language_id_mapping,
            )

            # wait for all DDP processes to be ready
            if num_gpus > 1:
                dist.barrier()

            # sort input sequences from short to long
            dataset.preprocess_samples()

            # get samplers
            sampler = self.get_sampler(config, dataset, num_gpus)

            loader = DataLoader(
                dataset,
                batch_size=config.eval_batch_size if is_eval else config.batch_size,
                shuffle=config.shuffle if sampler is None else False,  # if there is no other sampler
                collate_fn=dataset.collate_fn,
                drop_last=config.drop_last,  # setting this False might cause issues in AMP training.
                sampler=sampler,
                num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                pin_memory=False,
            )
        return loader

    def format_batch(self, batch):
        # setup input batch
        text_input = batch["token_id"]
        text_lengths = batch["token_id_lengths"]
        speaker_names = batch["speaker_names"]
        linear_input = batch["linear"]
        mel_input = batch["mel"]
        mel_lengths = batch["mel_lengths"]
        stop_targets = batch["stop_targets"]
        item_idx = batch["item_idxs"]
        d_vectors = batch["d_vectors"]
        speaker_ids = batch["speaker_ids"]
        attn_mask = batch["attns"]
        waveform = batch["waveform"]
        pitch = batch["pitch"]
        energy = batch["energy"]
        language_ids = batch["language_ids"]
        max_text_length = torch.max(text_lengths.float())
        max_spec_length = torch.max(mel_lengths.float())

        # compute durations from attention masks
        durations = None
        if attn_mask is not None:
            durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
            for idx, am in enumerate(attn_mask):
                # compute raw durations
                c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1]
                # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
                c_idxs, counts = torch.unique(c_idxs, return_counts=True)
                dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
                dur[c_idxs] = counts
                # smooth the durations and set any 0 duration to 1
                # by cutting off from the largest duration indices.
                extra_frames = dur.sum() - mel_lengths[idx]
                largest_idxs = torch.argsort(-dur)[:extra_frames]
                dur[largest_idxs] -= 1
                assert dur.sum() == mel_lengths[idx], f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
                durations[idx, : text_lengths[idx]] = dur

        # set stop targets wrt reduction factor
        stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
        stop_target_lengths = torch.divide(mel_lengths, self.config.r).ceil_()

        return {
            "text_input": text_input,
            "text_lengths": text_lengths,
            "speaker_names": speaker_names,
            "mel_input": mel_input,
            "mel_lengths": mel_lengths,
            "linear_input": linear_input,
            "stop_targets": stop_targets,
            "stop_target_lengths": stop_target_lengths,
            "attn_mask": attn_mask,
            "durations": durations,
            "speaker_ids": speaker_ids,
            "d_vectors": d_vectors,
            "max_text_length": float(max_text_length),
            "max_spec_length": float(max_spec_length),
            "item_idx": item_idx,
            "waveform": waveform,
            "pitch": pitch,
            "energy": energy,
            "language_ids": language_ids,
            "audio_unique_names": batch["audio_unique_names"],
            "ssl_feats": batch["ssl_feats"],
            "ssl_feats_lens": batch["ssl_feats_lens"],
            "ssl_feats_rel_lens": batch["ssl_feats_rel_lens"],
        }

    @torch.no_grad()
    def test_run(self, assets):
        print(" | > Synthesizing test sentences.")
        test_audios = {}
        test_figures = {}
        test_sentences = self.config.test_sentences
        aux_inputs = self._get_test_aux_input()
        if len(test_sentences) == 0:
            print(" | [!] No test sentences provided.")
        else:
            for idx, sen in enumerate(test_sentences):
                outputs = ssl_synthesis(
                    self,
                    sen,
                    self.vocoder,
                    speaker_id=aux_inputs["speaker_id"],
                    d_vector=aux_inputs["d_vector"],
                    do_trim_silence=False,
                )

                test_audios[f"{idx}-audio"] = outputs["wav"]
                test_figures[f"{idx}-prediction"] = plot_spectrogram(outputs["outputs"]["model_outputs"], self.ap, output_fig=False)
                test_figures[f"{idx}-alignment"] = plot_alignment(outputs["alignments"], output_fig=False)
        return {"figures": test_figures, "audios": test_audios}

    def test_log(self, outputs, logger, assets, steps) -> None:  # pylint:disable=W0613
        logger.test_audios(steps, outputs["audios"], self.ap.sample_rate)
        logger.test_figures(steps, outputs["figures"])
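One subtle step in format_batch above is the duration trimming: per-token frame counts derived from the alignment can overshoot the spectrogram length, so frames are shaved off the largest tokens until the counts sum exactly. A standalone sketch with made-up numbers:

import torch

mel_len = 10
dur = torch.tensor([3, 4, 5])        # raw per-token counts, sum = 12
extra_frames = dur.sum() - mel_len   # 2 frames too many
largest_idxs = torch.argsort(-dur)[:extra_frames]
dur[largest_idxs] -= 1               # take one frame from each of the two largest
assert dur.sum() == mel_len          # dur is now tensor([3, 3, 4])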
knn_tts/tts/__init__.py
ADDED
File without changes
knn_tts/tts/models.py
ADDED
@@ -0,0 +1,33 @@
# SPDX-FileCopyrightText: 2024 Idiap Research Institute
# SPDX-FileContributor: Karl El Hajal
#
# SPDX-License-Identifier: MIT

import torch

from knn_tts.tts.GlowTTS.glow_tts_ssl import load_glow_tts_ssl, ssl_synthesis


class GlowTTS:
    def __init__(
        self,
        tts_model_base_path,
        tts_model_checkpoint,
        vocoder_checkpoint_path=None,
        device=None,
    ):
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        self.model = load_glow_tts_ssl(
            tts_model_base_path,
            tts_model_checkpoint,
            vocoder_checkpoint_path,
            self.device,
        )

    def __call__(self, text, *args, **kwargs):
        with torch.no_grad():
            ssl_feats_out = ssl_synthesis(self.model, text, None, return_ssl_feats=True)
        return ssl_feats_out
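A sketch of this wrapper in isolation (paths hypothetical): note it returns WavLM-space features rather than audio, which is what the kNN matching step consumes.

from knn_tts.tts.models import GlowTTS

tts = GlowTTS("checkpoints/glowtts", "best_model.pth")  # hypothetical paths
feats = tts("hello world.")
print(feats.shape)  # expected [frame_len, 1024], same space as WavLM features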
knn_tts/utils.py
ADDED
@@ -0,0 +1,36 @@
# SPDX-FileCopyrightText: 2024 Idiap Research Institute
# SPDX-FileContributor: Karl El Hajal
#
# SPDX-License-Identifier: MIT

import os

import requests


def find_wav_paths(root_dir):
    wav_paths = []
    for subdir, _, files in os.walk(root_dir):
        for file in files:
            if file.endswith(".wav"):
                wav_paths.append(os.path.join(subdir, file))
    return wav_paths


def get_vocoder_checkpoint_path(checkpoints_dir):
    os.makedirs(checkpoints_dir, exist_ok=True)  # Ensure directory exists

    checkpoint_path = os.path.join(checkpoints_dir, "prematch_g_02500000.pt")
    url = "https://github.com/bshall/knn-vc/releases/download/v0.1/prematch_g_02500000.pt"

    if not os.path.exists(checkpoint_path):
        print(f"Downloading checkpoint to {checkpoint_path}...")
        response = requests.get(url, stream=True)
        if response.status_code == 200:
            with open(checkpoint_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            print("Download complete.")
        else:
            raise Exception(f"Failed to download checkpoint: {response.status_code}")

    return checkpoint_path
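Both helpers in one hypothetical snippet:

from knn_tts.utils import find_wav_paths, get_vocoder_checkpoint_path

vocoder_ckpt = get_vocoder_checkpoint_path("checkpoints/hifigan")  # downloads once
wav_paths = find_wav_paths("data/target_speaker")  # hypothetical directory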
knn_tts/vc/__init__.py
ADDED
File without changes
knn_tts/vc/knn.py
ADDED
@@ -0,0 +1,52 @@
# SPDX-FileCopyrightText: 2023 MediaLab, Department of Electrical & Electronic Engineering, Stellenbosch University
# SPDX-FileContributor: Karl El Hajal
#
# SPDX-License-Identifier: MIT

import os

import torch


def load_target_style_feats(feats_base_path, max_num_files=1000):
    feats = []
    for filepath in os.listdir(feats_base_path)[:max_num_files]:
        if ".pt" in filepath:
            filepath = os.path.join(feats_base_path, filepath)
            feats.append(torch.load(filepath, weights_only=False))
    feats = torch.concat(feats, dim=0).cpu()
    return feats


def fast_cosine_dist(source_feats, matching_pool, device):
    """Like torch.cdist, but fixed dim=-1 and for cosine distance."""
    source_norms = torch.norm(source_feats, p=2, dim=-1).to(device)
    matching_norms = torch.norm(matching_pool, p=2, dim=-1)
    dotprod = -(torch.cdist(source_feats[None].to(device), matching_pool[None], p=2)[0] ** 2) + source_norms[:, None] ** 2 + matching_norms[None] ** 2
    dotprod /= 2

    dists = 1 - (dotprod / (source_norms[:, None] * matching_norms[None]))
    return dists


@torch.inference_mode()
def knn_vc(source_frames, target_style_set, topk=4, weighted_average=False, device=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)
    target_style_set = target_style_set.to(device)
    source_frames = source_frames.to(device)

    dists = fast_cosine_dist(source_frames, target_style_set, device=device)
    best = dists.topk(k=topk, largest=False, dim=-1)

    if weighted_average:
        weights = 1 / (best.values + 1e-8)  # Adding a small value to avoid division by zero
        weights /= weights.sum(dim=-1, keepdim=True)  # Normalize weights
        selected_frames = (target_style_set[best.indices] * weights[..., None]).sum(dim=1)
    else:
        selected_frames = target_style_set[best.indices].mean(dim=1)

    return selected_frames
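A shape-level sanity sketch for knn_vc with random stand-in features: each source frame is replaced by the mean (or inverse-distance-weighted mean) of its topk nearest target frames under cosine distance.

import torch

from knn_tts.vc.knn import knn_vc

source = torch.randn(100, 1024)   # stand-in for synthesized WavLM frames
target = torch.randn(5000, 1024)  # stand-in for the pooled target-speaker frames
out = knn_vc(source, target, topk=4, weighted_average=True, device="cpu")
print(out.shape)  # torch.Size([100, 1024]) — one matched frame per source frame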
knn_tts/vocoder/__init__.py
ADDED
File without changes
knn_tts/vocoder/hifigan/__init__.py
ADDED
File without changes
knn_tts/vocoder/hifigan/config_v1_wavlm.json
ADDED
@@ -0,0 +1,40 @@
{
    "resblock": "1",
    "num_gpus": 0,
    "batch_size": 16,
    "learning_rate": 0.0002,
    "adam_b1": 0.8,
    "adam_b2": 0.99,
    "lr_decay": 0.999,
    "seed": 1234,

    "upsample_rates": [10,8,2,2],
    "upsample_kernel_sizes": [20,16,4,4],
    "upsample_initial_channel": 512,
    "resblock_kernel_sizes": [3,7,11],
    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],

    "hubert_dim": 1024,
    "hifi_dim": 512,

    "segment_size": 7040,
    "num_mels": 80,
    "num_freq": 1025,
    "n_fft": 1024,
    "hop_size": 320,
    "win_size": 1024,

    "sampling_rate": 16000,

    "fmin": 0,
    "fmax": 8000,
    "fmax_for_loss": null,

    "num_workers": 4,

    "dist_config": {
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:54321",
        "world_size": 1
    }
}
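A detail worth noting in this config: the product of "upsample_rates" (10·8·2·2 = 320) equals "hop_size", so the generator emits exactly 320 samples (20 ms at 16 kHz) per input WavLM frame. A quick consistency check:

import json
from math import prod

with open("knn_tts/vocoder/hifigan/config_v1_wavlm.json", encoding="utf-8") as f:
    cfg = json.load(f)
assert prod(cfg["upsample_rates"]) == cfg["hop_size"] == 320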
knn_tts/vocoder/hifigan/models.py
ADDED
@@ -0,0 +1,407 @@
# SPDX-FileCopyrightText: 2020 Jungil Kong
#
# SPDX-License-Identifier: MIT

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm

from knn_tts.vocoder.hifigan.utils import get_padding, init_weights

LRELU_SLOPE = 0.1


class ResBlock1(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()
        self.h = h
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[2],
                        padding=get_padding(kernel_size, dilation[2]),
                    )
                ),
            ]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class ResBlock2(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super().__init__()
        self.h = h
        self.convs = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1]),
                    )
                ),
            ]
        )
        self.convs.apply(init_weights)

    def forward(self, x):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)


class Generator(torch.nn.Module):
    def __init__(self, h):
        super().__init__()
        self.h = h
        self.lin_pre = nn.Linear(h.hubert_dim, h.hifi_dim)
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.conv_pre = weight_norm(Conv1d(h.hifi_dim, h.upsample_initial_channel, 7, 1, padding=3))
        resblock = ResBlock1 if h.resblock == "1" else ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        h.upsample_initial_channel // (2**i),
                        h.upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for _, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
                self.resblocks.append(resblock(h, ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        """`x` as (bs, seq_len, dim), regular hifi assumes input of shape (bs, n_mels, seq_len)"""
        x = self.lin_pre(x)
        x = x.permute(0, 2, 1)  # (bs, seq_len, dim) --> (bs, dim, seq_len)

        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print("Removing weight norm...")
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)


class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super().__init__()
        self.period = period
        norm_f = weight_norm if not use_spectral_norm else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(
                    Conv2d(
                        1,
                        32,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(5, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        32,
                        128,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(5, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        128,
                        512,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(5, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        512,
                        1024,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(5, 1), 0),
                    )
                ),
                norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
            ]
        )
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorP(2),
                DiscriminatorP(3),
                DiscriminatorP(5),
                DiscriminatorP(7),
                DiscriminatorP(11),
            ]
        )

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for _, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super().__init__()
        norm_f = weight_norm if not use_spectral_norm else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 128, 15, 1, padding=7)),
                norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
                norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
                norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiScaleDiscriminator(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorS(use_spectral_norm=True),
                DiscriminatorS(),
                DiscriminatorS(),
            ]
        )
        self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)])

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            if i != 0:
                y = self.meanpools[i - 1](y)
                y_hat = self.meanpools[i - 1](y_hat)
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            loss += torch.mean(torch.abs(rl - gl))

    return loss * 2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1 - dr) ** 2)
        g_loss = torch.mean(dg**2)
        loss += r_loss + g_loss
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        l = torch.mean((1 - dg) ** 2)
        gen_losses.append(l)
        loss += l

    return loss, gen_losses
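A shape sketch for the Generator above under the repo's WavLM config: each ConvTranspose1d stage with stride u and padding (k - u) // 2 multiplies the sequence length by exactly u, so T input frames yield T · 320 output samples.

import json

import torch

from knn_tts.vocoder.hifigan.models import Generator
from knn_tts.vocoder.hifigan.utils import AttrDict

with open("knn_tts/vocoder/hifigan/config_v1_wavlm.json", encoding="utf-8") as f:
    h = AttrDict(json.load(f))

gen = Generator(h).eval()
with torch.no_grad():
    wav = gen(torch.randn(1, 50, h.hubert_dim))  # (batch, frames, 1024)
print(wav.shape)  # torch.Size([1, 1, 16000]) — 50 frames * 320 = 1 s at 16 kHz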
knn_tts/vocoder/hifigan/utils.py
ADDED
@@ -0,0 +1,19 @@
1 |
+
# SPDX-FileCopyrightText: 2020 Jungil Kong
|
2 |
+
#
|
3 |
+
# SPDX-License-Identifier: MIT
|
4 |
+
|
5 |
+
|
6 |
+
def init_weights(m, mean=0.0, std=0.01):
|
7 |
+
classname = m.__class__.__name__
|
8 |
+
if classname.find("Conv") != -1:
|
9 |
+
m.weight.data.normal_(mean, std)
|
10 |
+
|
11 |
+
|
12 |
+
def get_padding(kernel_size, dilation=1):
|
13 |
+
return int((kernel_size * dilation - dilation) / 2)
|
14 |
+
|
15 |
+
|
16 |
+
class AttrDict(dict):
|
17 |
+
def __init__(self, *args, **kwargs):
|
18 |
+
super().__init__(*args, **kwargs)
|
19 |
+
self.__dict__ = self
|
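A quick illustration of these helpers (hypothetical values, not part of the commit): get_padding returns "same" padding for an odd kernel under dilation, and AttrDict exposes dict keys as attributes, which is how the JSON vocoder config is accessed elsewhere in the repo.

from knn_tts.vocoder.hifigan.utils import AttrDict, get_padding

# "Same" padding: output length matches input length for stride 1.
assert get_padding(3) == 1
assert get_padding(7, dilation=2) == 6

# Keys become attributes; the field names below are illustrative only.
h = AttrDict({"num_mels": 80, "upsample_rates": [8, 8, 2, 2]})
assert h.num_mels == 80 and h["num_mels"] == 80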
knn_tts/vocoder/models.py
ADDED
@@ -0,0 +1,44 @@
# SPDX-FileCopyrightText: 2024 Idiap Research Institute
# SPDX-FileContributor: Karl El Hajal
#
# SPDX-License-Identifier: MIT

import json

import torch

from knn_tts.vocoder.hifigan.models import Generator as HiFiGAN
from knn_tts.vocoder.hifigan.utils import AttrDict


class HiFiGANWavLM:
    def __init__(
        self,
        checkpoint_path,
        config_path="knn_tts/vocoder/hifigan/config_v1_wavlm.json",
        device=None,
    ):
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        with open(config_path, encoding="utf-8") as f:
            data = f.read()
        json_config = json.loads(data)
        self.h = AttrDict(json_config)

        self.generator = HiFiGAN(self.h).to(self.device)

        # Map to the resolved device (not the raw argument, which may be None).
        state_dict_g = torch.load(checkpoint_path, weights_only=False, map_location=self.device)
        self.generator.load_state_dict(state_dict_g["generator"])
        self.generator.eval()
        self.generator.remove_weight_norm()
        num_parameters = sum(p.numel() for p in self.generator.parameters())
        print(f"[HiFiGAN] Generator loaded with {num_parameters:,d} parameters.")

    def __call__(self, x):
        with torch.no_grad():
            wav = self.generator(x.to(self.device))
            wav = wav.squeeze(1)
        wav = wav.cpu().squeeze()
        return wav
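A hedged usage sketch for this wrapper (not part of the commit; the checkpoint path, feature dimensions, and config field are assumptions): the Generator consumes frame-level SSL features laid out channels-first, and the loaded config is exposed on vocoder.h.

import torch
import soundfile as sf

from knn_tts.vocoder.models import HiFiGANWavLM

# Placeholder checkpoint path; a real checkpoint is required.
vocoder = HiFiGANWavLM(checkpoint_path="checkpoints/hifigan_wavlm.pt")

# Stand-in WavLM features shaped (batch, feature_dim, frames); the
# feature_dim must match the Generator's input channels in the config.
feats = torch.randn(1, 1024, 200)

wav = vocoder(feats)  # 1-D float tensor on CPU
# sampling_rate is assumed to be a field of config_v1_wavlm.json.
sf.write("output.wav", wav.numpy(), vocoder.h.sampling_rate)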
requirements.txt
ADDED
@@ -0,0 +1,164 @@
absl-py==2.2.2
aiofiles==24.1.0
aiohappyeyeballs==2.6.1
aiohttp==3.11.16
aiosignal==1.3.2
annotated-types==0.7.0
anyascii==0.3.2
anyio==4.9.0
attrs==25.3.0
audioread==3.0.1
babel==2.17.0
blis==0.7.11
catalogue==2.0.10
certifi==2025.1.31
cffi==1.17.1
charset-normalizer==3.4.1
click==8.1.8
cloudpathlib==0.21.0
confection==0.1.5
contourpy==1.3.1
coqpit-config==0.2.0
coqui-tts==0.26.0
coqui-tts-trainer==0.2.3
cycler==0.12.1
cymem==2.0.11
Cython==3.0.12
dateparser==1.1.8
decorator==5.2.1
docopt==0.6.2
einops==0.8.1
encodec==0.1.1
fastapi==0.115.12
ffmpy==0.5.0
filelock==3.18.0
fonttools==4.57.0
frozenlist==1.5.0
fsspec==2025.3.2
gradio==5.24.0
gradio_client==1.8.0
groovy==0.1.2
grpcio==1.71.0
gruut==2.4.0
gruut-ipa==0.13.0
gruut_lang_de==2.0.1
gruut_lang_en==2.0.1
gruut_lang_es==2.0.1
gruut_lang_fr==2.0.2
h11==0.14.0
httpcore==1.0.7
httpx==0.28.1
huggingface-hub==0.30.2
idna==3.10
inflect==7.5.0
Jinja2==3.1.6
joblib==1.4.2
jsonlines==1.2.0
kiwisolver==1.4.8
langcodes==3.5.0
language_data==1.3.0
lazy_loader==0.4
librosa==0.11.0
llvmlite==0.44.0
marisa-trie==1.2.1
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==3.0.2
matplotlib==3.10.1
mdurl==0.1.2
monotonic-alignment-search==0.1.1
more-itertools==10.6.0
mpmath==1.3.0
msgpack==1.1.0
multidict==6.2.0
murmurhash==1.0.12
networkx==3.4.2
num2words==0.5.14
numba==0.61.2
numpy==1.26.4
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-cusparselt-cu12==0.6.2
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
orjson==3.10.16
packaging==24.2
pandas==2.2.3
pillow==11.1.0
pip==25.0.1
platformdirs==4.3.7
pooch==1.8.2
preshed==3.0.9
propcache==0.3.1
protobuf==6.30.2
psutil==5.9.8
pycparser==2.22
pydantic==2.11.3
pydantic_core==2.33.1
pydub==0.25.1
Pygments==2.19.1
pyparsing==3.2.3
pysbd==0.3.4
python-crfsuite==0.9.11
python-dateutil==2.9.0.post0
python-multipart==0.0.20
pytz==2025.2
PyYAML==6.0.2
regex==2024.11.6
requests==2.32.3
rich==14.0.0
ruff==0.11.4
safehttpx==0.1.6
safetensors==0.5.3
scikit-learn==1.6.1
scipy==1.15.2
semantic-version==2.10.0
setuptools==78.1.0
shellingham==1.5.4
six==1.17.0
smart-open==7.1.0
sniffio==1.3.1
soundfile==0.13.1
soxr==0.5.0.post1
spaces==0.34.2
spacy==3.7.5
spacy-legacy==3.0.12
spacy-loggers==1.0.5
srsly==2.5.1
starlette==0.46.1
SudachiDict-core==20250129
SudachiPy==0.6.10
sympy==1.13.1
tensorboard==2.19.0
tensorboard-data-server==0.7.2
thinc==8.2.5
threadpoolctl==3.6.0
tokenizers==0.21.1
tomlkit==0.13.2
torch==2.6.0
torchaudio==2.6.0
tqdm==4.67.1
transformers==4.51.1
triton==3.2.0
typeguard==4.4.2
typer==0.15.2
typing_extensions==4.13.1
typing-inspection==0.4.0
tzdata==2025.2
tzlocal==5.3.1
urllib3==2.3.0
uvicorn==0.34.0
wasabi==1.1.3
weasel==0.4.1
websockets==15.0.1
Werkzeug==3.1.3
wrapt==1.17.2
yarl==1.19.0
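To reproduce this environment locally, pip install -r requirements.txt in a fresh virtual environment should suffice. Note that the nvidia-*-cu12 and triton pins assume a Linux host with a CUDA 12-compatible driver; CPU-only setups would likely need to install a CPU build of torch==2.6.0 first and skip those wheels.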