OpenSound commited on
Commit
9c80a14
·
verified ·
1 Parent(s): e0f527e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -8
app.py CHANGED
@@ -54,16 +54,21 @@ else:
54
  print('noise prediction')
55
  scheduler = DDIMScheduler(**diff_config["ddim"]['diffusers'])
56
 
57
- # these steps reset dtype of noise_scheduler params
58
- latents = torch.randn((1, 128, 128),
59
- device=device)
60
- noise = torch.randn(latents.shape).to(latents.device)
61
- timesteps = torch.randint(0, scheduler.config.num_train_timesteps,
62
- (noise.shape[0],),
63
- device=latents.device).long()
64
- _ = scheduler.add_noise(latents, noise, timesteps)
 
 
 
 
65
 
66
 
 
67
  def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
68
  """
69
  Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -112,6 +117,7 @@ def sample_diffusion(mixture, timbre, ddim_steps=50, eta=0, seed=2023, guidance_
112
 
113
  @spaces.GPU
114
  def tse(gt_file_input, text_input, num_infer_steps, eta, seed, guidance_scale, guidance_rescale):
 
115
  with torch.no_grad():
116
  # mixture, _ = librosa.load(gt_file_input, sr=sample_rate)
117
  mixture, sr = torchaudio.load(gt_file_input)
 
54
  print('noise prediction')
55
  scheduler = DDIMScheduler(**diff_config["ddim"]['diffusers'])
56
 
57
+ @spaces.GPU
58
+ def reset_scheduler_dtype():
59
+ latents = torch.randn((1, 128, 128), device="cuda")
60
+ noise = torch.randn_like(latents)
61
+ timesteps = torch.randint(
62
+ 0,
63
+ scheduler.config.num_train_timesteps,
64
+ (latents.shape[0],),
65
+ device=latents.device
66
+ )
67
+ _ = scheduler.add_noise(latents, noise, timesteps)
68
+ return "Scheduler dtype reset completed."
69
 
70
 
71
+ @spaces.GPU
72
  def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
73
  """
74
  Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
 
117
 
118
  @spaces.GPU
119
  def tse(gt_file_input, text_input, num_infer_steps, eta, seed, guidance_scale, guidance_rescale):
120
+ reset_scheduler_dtype()
121
  with torch.no_grad():
122
  # mixture, _ = librosa.load(gt_file_input, sr=sample_rate)
123
  mixture, sr = torchaudio.load(gt_file_input)