Update app.py
Browse files
app.py
CHANGED
@@ -54,16 +54,21 @@ else:
|
|
54 |
print('noise prediction')
|
55 |
scheduler = DDIMScheduler(**diff_config["ddim"]['diffusers'])
|
56 |
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
noise = torch.
|
61 |
-
timesteps = torch.randint(
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
|
|
65 |
|
66 |
|
|
|
67 |
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
68 |
"""
|
69 |
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
@@ -112,6 +117,7 @@ def sample_diffusion(mixture, timbre, ddim_steps=50, eta=0, seed=2023, guidance_
|
|
112 |
|
113 |
@spaces.GPU
|
114 |
def tse(gt_file_input, text_input, num_infer_steps, eta, seed, guidance_scale, guidance_rescale):
|
|
|
115 |
with torch.no_grad():
|
116 |
# mixture, _ = librosa.load(gt_file_input, sr=sample_rate)
|
117 |
mixture, sr = torchaudio.load(gt_file_input)
|
|
|
54 |
print('noise prediction')
|
55 |
scheduler = DDIMScheduler(**diff_config["ddim"]['diffusers'])
|
56 |
|
57 |
+
@spaces.GPU
def reset_scheduler_dtype():
    """Run one throwaway ``add_noise`` call on the module-level scheduler.

    Random latents of shape (1, 128, 128) are created on the CUDA device,
    together with matching noise and one random training timestep per
    batch item, and passed to ``scheduler.add_noise``. The noised result
    is discarded; only the call itself matters.

    NOTE(review): presumably the call is made so the scheduler's internal
    tensors end up matching the CUDA inputs -- confirm against the
    diffusers scheduler implementation.

    Returns:
        A fixed status string.
    """
    dummy_latents = torch.randn((1, 128, 128), device="cuda")
    dummy_noise = torch.randn_like(dummy_latents)

    # One timestep per batch item, drawn from the scheduler's training range.
    batch_size = dummy_latents.shape[0]
    max_timestep = scheduler.config.num_train_timesteps
    step_ids = torch.randint(
        0,
        max_timestep,
        (batch_size,),
        device=dummy_latents.device,
    )

    # The return value is intentionally ignored.
    scheduler.add_noise(dummy_latents, dummy_noise, step_ids)
    return "Scheduler dtype reset completed."
|
69 |
|
70 |
|
71 |
+
@spaces.GPU
|
72 |
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
73 |
"""
|
74 |
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
|
|
117 |
|
118 |
@spaces.GPU
|
119 |
def tse(gt_file_input, text_input, num_infer_steps, eta, seed, guidance_scale, guidance_rescale):
|
120 |
+
reset_scheduler_dtype()
|
121 |
with torch.no_grad():
|
122 |
# mixture, _ = librosa.load(gt_file_input, sr=sample_rate)
|
123 |
mixture, sr = torchaudio.load(gt_file_input)
|