kevinwang676 commited on
Commit
05e87f6
·
verified ·
1 Parent(s): 6988d86

Create flow.py

Browse files
Files changed (1) hide show
  1. flow.py +217 -0
flow.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import threading
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from matcha.models.components.flow_matching import BASECFM
18
+
19
+
20
+ class ConditionalCFM(BASECFM):
21
+ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
22
+ super().__init__(
23
+ n_feats=in_channels,
24
+ cfm_params=cfm_params,
25
+ n_spks=n_spks,
26
+ spk_emb_dim=spk_emb_dim,
27
+ )
28
+ self.t_scheduler = cfm_params.t_scheduler
29
+ self.training_cfg_rate = cfm_params.training_cfg_rate
30
+ self.inference_cfg_rate = cfm_params.inference_cfg_rate
31
+ in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
32
+ # Just change the architecture of the estimator here
33
+ self.estimator = estimator
34
+ self.lock = threading.Lock()
35
+
36
+ @torch.inference_mode()
37
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
38
+ """Forward diffusion
39
+
40
+ Args:
41
+ mu (torch.Tensor): output of encoder
42
+ shape: (batch_size, n_feats, mel_timesteps)
43
+ mask (torch.Tensor): output_mask
44
+ shape: (batch_size, 1, mel_timesteps)
45
+ n_timesteps (int): number of diffusion steps
46
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
47
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
48
+ shape: (batch_size, spk_emb_dim)
49
+ cond: Not used but kept for future purposes
50
+
51
+ Returns:
52
+ sample: generated mel-spectrogram
53
+ shape: (batch_size, n_feats, mel_timesteps)
54
+ """
55
+
56
+ z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
57
+ cache_size = flow_cache.shape[2]
58
+ # fix prompt and overlap part mu and z
59
+ if cache_size != 0:
60
+ z[:, :, :cache_size] = flow_cache[:, :, :, 0]
61
+ mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
62
+ z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
63
+ mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
64
+ flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
65
+
66
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
67
+ if self.t_scheduler == 'cosine':
68
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
69
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
70
+
71
+ def solve_euler(self, x, t_span, mu, mask, spks, cond):
72
+ """
73
+ Fixed euler solver for ODEs.
74
+ Args:
75
+ x (torch.Tensor): random noise
76
+ t_span (torch.Tensor): n_timesteps interpolated
77
+ shape: (n_timesteps + 1,)
78
+ mu (torch.Tensor): output of encoder
79
+ shape: (batch_size, n_feats, mel_timesteps)
80
+ mask (torch.Tensor): output_mask
81
+ shape: (batch_size, 1, mel_timesteps)
82
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
83
+ shape: (batch_size, spk_emb_dim)
84
+ cond: Not used but kept for future purposes
85
+ """
86
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
87
+ t = t.unsqueeze(dim=0)
88
+
89
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
90
+ # Or in future might add like a return_all_steps flag
91
+ sol = []
92
+
93
+ # Do not use concat, it may cause memory format changed and trt infer with wrong results!
94
+ x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
95
+ mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
96
+ mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
97
+ t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
98
+ spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
99
+ cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
100
+ for step in range(1, len(t_span)):
101
+ # Classifier-Free Guidance inference introduced in VoiceBox
102
+ x_in[:] = x
103
+ mask_in[:] = mask
104
+ mu_in[0] = mu
105
+ t_in[:] = t.unsqueeze(0)
106
+ spks_in[0] = spks
107
+ cond_in[0] = cond
108
+ dphi_dt = self.forward_estimator(
109
+ x_in, mask_in,
110
+ mu_in, t_in,
111
+ spks_in,
112
+ cond_in
113
+ )
114
+ dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
115
+ dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
116
+ x = x + dt * dphi_dt
117
+ t = t + dt
118
+ sol.append(x)
119
+ if step < len(t_span) - 1:
120
+ dt = t_span[step + 1] - t
121
+
122
+ return sol[-1].float()
123
+
124
+ def forward_estimator(self, x, mask, mu, t, spks, cond):
125
+ if isinstance(self.estimator, torch.nn.Module):
126
+ return self.estimator.forward(x, mask, mu, t, spks, cond)
127
+ else:
128
+ with self.lock:
129
+ self.estimator.set_input_shape('x', (2, 80, x.size(2)))
130
+ self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
131
+ self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
132
+ self.estimator.set_input_shape('t', (2,))
133
+ self.estimator.set_input_shape('spks', (2, 80))
134
+ self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
135
+ # run trt engine
136
+ self.estimator.execute_v2([x.contiguous().data_ptr(),
137
+ mask.contiguous().data_ptr(),
138
+ mu.contiguous().data_ptr(),
139
+ t.contiguous().data_ptr(),
140
+ spks.contiguous().data_ptr(),
141
+ cond.contiguous().data_ptr(),
142
+ x.data_ptr()])
143
+ return x
144
+
145
+ def compute_loss(self, x1, mask, mu, spks=None, cond=None):
146
+ """Computes diffusion loss
147
+
148
+ Args:
149
+ x1 (torch.Tensor): Target
150
+ shape: (batch_size, n_feats, mel_timesteps)
151
+ mask (torch.Tensor): target mask
152
+ shape: (batch_size, 1, mel_timesteps)
153
+ mu (torch.Tensor): output of encoder
154
+ shape: (batch_size, n_feats, mel_timesteps)
155
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
156
+ shape: (batch_size, spk_emb_dim)
157
+
158
+ Returns:
159
+ loss: conditional flow matching loss
160
+ y: conditional flow
161
+ shape: (batch_size, n_feats, mel_timesteps)
162
+ """
163
+ b, _, t = mu.shape
164
+
165
+ # random timestep
166
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
167
+ if self.t_scheduler == 'cosine':
168
+ t = 1 - torch.cos(t * 0.5 * torch.pi)
169
+ # sample noise p(x_0)
170
+ z = torch.randn_like(x1)
171
+
172
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
173
+ u = x1 - (1 - self.sigma_min) * z
174
+
175
+ # during training, we randomly drop condition to trade off mode coverage and sample fidelity
176
+ if self.training_cfg_rate > 0:
177
+ cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
178
+ mu = mu * cfg_mask.view(-1, 1, 1)
179
+ spks = spks * cfg_mask.view(-1, 1)
180
+ cond = cond * cfg_mask.view(-1, 1, 1)
181
+
182
+ pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
183
+ loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
184
+ return loss, y
185
+
186
+
187
+ class CausalConditionalCFM(ConditionalCFM):
188
+ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
189
+ super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
190
+ self.rand_noise = torch.randn([1, 80, 50 * 300])
191
+
192
+ @torch.inference_mode()
193
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
194
+ """Forward diffusion
195
+
196
+ Args:
197
+ mu (torch.Tensor): output of encoder
198
+ shape: (batch_size, n_feats, mel_timesteps)
199
+ mask (torch.Tensor): output_mask
200
+ shape: (batch_size, 1, mel_timesteps)
201
+ n_timesteps (int): number of diffusion steps
202
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
203
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
204
+ shape: (batch_size, spk_emb_dim)
205
+ cond: Not used but kept for future purposes
206
+
207
+ Returns:
208
+ sample: generated mel-spectrogram
209
+ shape: (batch_size, n_feats, mel_timesteps)
210
+ """
211
+
212
+ z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
213
+ # fix prompt and overlap part mu and z
214
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
215
+ if self.t_scheduler == 'cosine':
216
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
217
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None