lshzhm committed
Commit 43d27a1 · 1 Parent(s): 791533f

gradio infer

Files changed (3):
  1. F5-TTS/src/f5_tts/infer/infer_cli_test.py +27 -22
  2. MMAudio/demo.py +35 -16
  3. app.py +35 -14
F5-TTS/src/f5_tts/infer/infer_cli_test.py CHANGED
@@ -21,7 +21,7 @@ from f5_tts.infer.utils_infer import (
     mel_spec_type,
     target_rms,
     cross_fade_duration,
-    nfe_step,
+    #nfe_step,
     cfg_strength,
     sway_sampling_coef,
     speed,
@@ -68,7 +68,7 @@ parser.add_argument(
     "--ckpt_file",
     type=str,
     help="The path to model checkpoint .pt, leave blank to use default",
-    default="",
+    default="./F5-TTS/ckpts/v2c/v2c_s44.pt",
 )
 parser.add_argument(
     "-v",
@@ -143,11 +143,11 @@ parser.add_argument(
     type=float,
     help=f"Duration of cross-fade between audio segments in seconds, default {cross_fade_duration}",
 )
-parser.add_argument(
-    "--nfe_step",
-    type=int,
-    help=f"The number of function evaluation (denoising steps), default {nfe_step}",
-)
+#parser.add_argument(
+#    "--nfe_step",
+#    type=int,
+#    help=f"The number of function evaluation (denoising steps), default {nfe_step}",
+#)
 parser.add_argument(
     "--cfg_strength",
     type=float,
@@ -177,7 +177,7 @@ parser.add_argument(
 parser.add_argument(
     "--end",
     type=int,
-    default=99999999,
+    default=1,
 )
 parser.add_argument(
     "--v2a_path",
@@ -239,7 +239,7 @@ ref_text = (
 gen_text = args.gen_text or config.get("gen_text", "Here we generate something just for test.")
 gen_file = args.gen_file or config.get("gen_file", "")
 
-output_dir = args.output_dir or config.get("output_dir", "tests")
+#output_dir = args.output_dir or config.get("output_dir", "tests")
 output_file = args.output_file or config.get(
     "output_file", f"infer_cli_{datetime.now().strftime(r'%Y%m%d_%H%M%S')}.wav"
 )
@@ -251,13 +251,13 @@ load_vocoder_from_local = args.load_vocoder_from_local or config.get("load_vocod
 vocoder_name = args.vocoder_name or config.get("vocoder_name", mel_spec_type)
 target_rms = args.target_rms or config.get("target_rms", target_rms)
 cross_fade_duration = args.cross_fade_duration or config.get("cross_fade_duration", cross_fade_duration)
-nfe_step = args.nfe_step or config.get("nfe_step", nfe_step)
+#nfe_step = args.nfe_step or config.get("nfe_step", nfe_step)
 cfg_strength = args.cfg_strength or config.get("cfg_strength", cfg_strength)
 sway_sampling_coef = args.sway_sampling_coef or config.get("sway_sampling_coef", sway_sampling_coef)
 speed = args.speed or config.get("speed", speed)
 fix_duration = args.fix_duration or config.get("fix_duration", fix_duration)
 
-print("############nfe_step", nfe_step, vocoder_name)
+#print("############nfe_step", nfe_step, vocoder_name)
 
 
 # patches for pip pkg user
@@ -280,12 +280,12 @@ if gen_file:
 
 # output path
 
-wave_path = Path(output_dir) / output_file
-# spectrogram_path = Path(output_dir) / "infer_cli_out.png"
-if save_chunk:
-    output_chunk_dir = os.path.join(output_dir, f"{Path(output_file).stem}_chunks")
-    if not os.path.exists(output_chunk_dir):
-        os.makedirs(output_chunk_dir)
+#wave_path = Path(output_dir) / output_file
+## spectrogram_path = Path(output_dir) / "infer_cli_out.png"
+#if save_chunk:
+#    output_chunk_dir = os.path.join(output_dir, f"{Path(output_file).stem}_chunks")
+#    if not os.path.exists(output_chunk_dir):
+#        os.makedirs(output_chunk_dir)
 
 
 # load vocoder
@@ -335,7 +335,7 @@ ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=vocoder_na
 # inference process
 
 
-def main(ref_audio, ref_text, gen_text, energy):
+def main(ref_audio, ref_text, gen_text, energy, nfe_step):
     main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
     if "voices" not in config:
         voices = {"main": main_voice}
@@ -431,9 +431,14 @@ def normalize_wav(waveform, waveform_ref):
     return waveform
 
 
-if __name__ == "__main__":
-
-    v2a_path = args.v2a_path
+#if __name__ == "__main__":
+def v2s_infer(output_dir, v2a_path, wav_p, txt_p, video, v2a_wav, txt, nfe_step):
+    #v2a_path = args.v2a_path
+    args.wav_p = wav_p
+    args.txt_p = txt_p
+    args.video = video
+    args.v2a_wav = v2a_wav
+    args.txt = txt
 
     if args.wav_p == "":
         scp = args.infer_list
@@ -493,7 +498,7 @@ if __name__ == "__main__":
         ####wav_gen, sr_gen = main(wav_p, txt_p, txt, [torch.zeros_like(energy_p), torch.zeros_like(energy)])
         ####wav_gen, sr_gen = main(wav_p, txt_p, txt, None)
         ####wav_gen, sr_gen = main(wav, txt, txt, None)
-        wav_gen, sr_gen = main(wav_p, txt_p, txt, [energy_p, energy])
+        wav_gen, sr_gen = main(wav_p, txt_p, txt, [energy_p, energy], nfe_step)
         ####wav_gen, sr_gen = main(wav, txt, txt, [energy.clone(), energy])
         wav_gen = torch.from_numpy(wav_gen).unsqueeze(0)
         assert(sr_gen == 24000)
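
The change above turns infer_cli_test.py from a script guarded by `if __name__ == "__main__":` into a module exposing v2s_infer(), so the Gradio app can run video-to-speech inference in-process. Below is a minimal usage sketch (not part of the commit): the file paths, reference text, and step count are placeholder assumptions, while the import path and argument order mirror the call made in app.py in this same commit.

import sys

sys.path.insert(0, "./F5-TTS/src/")
from f5_tts.infer.infer_cli_test import v2s_infer

output_dir = "./results"                 # placeholder: where generated audio/video go
audio_prompt = "./results/prompt.wav"    # placeholder: reference speech for the voice prompt
video_path = "./results/input.mp4"       # placeholder: input video
v2a_wav = "./results/input.flac"         # placeholder: audio produced by the V2A stage

v2s_infer(
    output_dir,        # output_dir
    output_dir,        # v2a_path: directory holding the V2A outputs
    audio_prompt,      # wav_p: reference speech prompt
    "reference text",  # txt_p: transcript of the reference speech
    video_path,        # video
    v2a_wav,           # v2a_wav
    "target text",     # txt: text to synthesize
    32,                # nfe_step: denoising steps (example value)
)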
MMAudio/demo.py CHANGED
@@ -29,16 +29,16 @@ log = logging.getLogger()
 
 
 @torch.inference_mode()
-def main():
+def v2a_load():
     setup_eval_logging()
 
     parser = ArgumentParser()
     parser.add_argument('--variant',
                         type=str,
-                        default='large_44k',
+                        #default='large_44k',
                         #default='small_16k',
                         #default='medium_44k',
-                        #default='small_44k',
+                        default='small_44k',
                         help='small_16k, small_44k, medium_44k, large_44k, large_44k_v2')
     parser.add_argument('--video', type=Path, help='Path to the video file')
     parser.add_argument('--prompt', type=str, help='Input prompt', default='')
@@ -64,23 +64,23 @@ def main():
     if args.variant not in all_model_cfg:
         raise ValueError(f'Unknown model variant: {args.variant}')
     model: ModelConfig = all_model_cfg[args.variant]
-    model.download_if_needed()
+    #model.download_if_needed()
     seq_cfg = model.seq_cfg
 
-    if args.video:
-        #video_path: Path = Path(args.video).expanduser()
-        video_path = args.video
-    else:
-        video_path = None
-    prompt: str = args.prompt
-    negative_prompt: str = args.negative_prompt
-    output_dir: str = args.output.expanduser()
+    #if args.video:
+    #    #video_path: Path = Path(args.video).expanduser()
+    #    video_path = args.video
+    #else:
+    #    video_path = None
+    #prompt: str = args.prompt
+    #negative_prompt: str = args.negative_prompt
+    #output_dir: str = args.output.expanduser()
     seed: int = args.seed
-    num_steps: int = args.num_steps
+    #num_steps: int = args.num_steps
     duration: float = args.duration
     cfg_strength: float = args.cfg_strength
     skip_video_composite: bool = args.skip_video_composite
-    mask_away_clip: bool = args.mask_away_clip
+    #mask_away_clip: bool = args.mask_away_clip
 
     device = 'cpu'
     if torch.cuda.is_available():
@@ -92,19 +92,26 @@ def main():
     print("full_precision", args.full_precision)
     dtype = torch.float32 if args.full_precision else torch.bfloat16
 
-    output_dir.mkdir(parents=True, exist_ok=True)
+    #output_dir.mkdir(parents=True, exist_ok=True)
 
     # load a pretrained model
     net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
     ####model.model_path = "/ailab-train/speech/zhanghaomin/codes3/MMAudio-main/output/exp_1/exp_1_shadow.pth"
+    model.model_path = "MMAudio" / model.model_path
+    print("model.model_path", model.model_path)
     net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
     log.info(f'Loaded weights from {model.model_path}')
 
     # misc setup
     rng = torch.Generator(device=device)
     rng.manual_seed(seed)
-    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
+    #fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
 
+    model.vae_path = "MMAudio" / model.vae_path
+    model.synchformer_ckpt = "MMAudio" / model.synchformer_ckpt
+    print("model.vae_path", model.vae_path)
+    print("model.synchformer_ckpt", model.synchformer_ckpt)
+    print("model.bigvgan_16k_path", model.bigvgan_16k_path)
     feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
                                   synchformer_ckpt=model.synchformer_ckpt,
                                   enable_conditions=True,
@@ -112,7 +119,19 @@ def main():
                                   bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
                                   need_vae_encoder=False)
     feature_utils = feature_utils.to(device, dtype).eval()
+    return net, seq_cfg, rng, feature_utils, args
+
 
+@torch.inference_mode()
+def v2a_infer(output_dir, video_path, prompt, num_steps, loaded):
+    net, seq_cfg, rng, feature_utils, args = loaded
+    negative_prompt = ""
+    duration = args.duration
+    cfg_strength = args.cfg_strength
+    skip_video_composite = args.skip_video_composite
+    mask_away_clip = args.mask_away_clip
+    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
+
     ####test_scp = "/ailab-train/speech/zhanghaomin/animation_dataset_v2a/test.scp"
     #test_scp = "/ailab-train/speech/zhanghaomin/datas/v2cdata/tmp.scp"
     #test_scp = "/ailab-train/speech/zhanghaomin/datas/v2cdata/test.scp"
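
With main() split into v2a_load() and v2a_infer(), the expensive setup (network weights, Synchformer checkpoint, vocoder features) is done once and reused per request; only the FlowMatching sampler is rebuilt with the requested num_steps. A minimal usage sketch (not part of the commit) follows; it assumes the script runs from the repo root with no extra CLI flags, since v2a_load() still parses sys.argv, and the paths and step count are placeholders.

import sys

sys.path.insert(0, "./MMAudio/")
from demo import v2a_load, v2a_infer

# Load once at startup; returns (net, seq_cfg, rng, feature_utils, args).
v2a_loaded = v2a_load()

# Reuse the loaded bundle for each request.
v2a_infer(
    "./results",            # output_dir (placeholder)
    "./results/input.mp4",  # video_path (placeholder)
    "",                     # prompt: empty string for unconditional generation
    25,                     # num_steps for the flow-matching sampler (example value)
    v2a_loaded,
)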
app.py CHANGED
@@ -22,18 +22,31 @@ import numpy as np
 
 from huggingface_hub import hf_hub_download
 
-model_path = "./F5-TTS/ckpts/v2c/"
+if True:
+    model_path = "./F5-TTS/ckpts/v2c/"
 
-if not os.path.exists(model_path):
-    os.makedirs(model_path)
+    if not os.path.exists(model_path):
+        os.makedirs(model_path)
 
-file_path = hf_hub_download(repo_id="lshzhm/DeepAudio-V1", filename="v2c_s44.pt", local_dir=model_path)
+    file_path = hf_hub_download(repo_id="lshzhm/DeepAudio-V1", filename="v2c_s44.pt", local_dir=model_path)
 
-print(f"Model saved at: {file_path}")
+    print(f"Model saved at: {file_path}")
 
 log = logging.getLogger()
 
 
+import sys
+sys.path.insert(0, "./F5-TTS/src/")
+from f5_tts.infer.infer_cli_test import v2s_infer
+
+
+import sys
+sys.path.insert(0, "./MMAudio/")
+from demo import v2a_load, v2a_infer
+
+v2a_loaded = v2a_load()
+
+
 #@spaces.GPU(duration=120)
 def video_to_audio_and_speech(video: gr.Video, prompt: str, v2a_num_steps: int, text: str, audio_prompt: gr.Audio, text_prompt: str, v2s_num_steps: int):
 
@@ -64,18 +77,26 @@ def video_to_audio_and_speech(video: gr.Video, prompt: str, v2a_num_steps: int,
     else:
         shutil.copy(audio_prompt, audio_p_path)
 
-    if prompt == "":
-        command = "cd ./MMAudio; python ./demo.py --variant small_44k --output %s --video %s --calc_energy 1 --num_steps %d" % (output_dir, video_path, v2a_num_steps)
-    else:
-        command = "cd ./MMAudio; python ./demo.py --variant small_44k --output %s --video %s --prompt %s --calc_energy 1 --num_steps %d" % (output_dir, video_path, prompt, v2a_num_steps)
-    print("v2a command", command)
-    os.system(command)
+    #if prompt == "":
+    #    command = "cd ./MMAudio; python ./demo.py --variant small_44k --output %s --video %s --calc_energy 1 --num_steps %d" % (output_dir, video_path, v2a_num_steps)
+    #else:
+    #    command = "cd ./MMAudio; python ./demo.py --variant small_44k --output %s --video %s --prompt %s --calc_energy 1 --num_steps %d" % (output_dir, video_path, prompt, v2a_num_steps)
+    #print("v2a command", command)
+    #os.system(command)
+
+
+    v2a_infer(output_dir, video_path, prompt, v2a_num_steps, v2a_loaded)
+
 
     video_gen = video_save_path[:-4]+".mp4.gen.mp4"
 
-    command = "python ./F5-TTS/src/f5_tts/infer/infer_cli_test.py --output_dir %s --start 0 --end 1 --ckpt_file ./F5-TTS/ckpts/v2c/v2c_s44.pt --v2a_path %s --wav_p %s --txt_p \"%s\" --video %s --v2a_wav %s --txt \"%s\" --nfe_step %d" % (output_dir, output_dir, audio_p_path, text_prompt, video_save_path, video_save_path[:-4]+".flac", text, v2s_num_steps)
-    print("v2s command", command, video_gen)
-    os.system(command)
+    #command = "python ./F5-TTS/src/f5_tts/infer/infer_cli_test.py --output_dir %s --start 0 --end 1 --ckpt_file ./F5-TTS/ckpts/v2c/v2c_s44.pt --v2a_path %s --wav_p %s --txt_p \"%s\" --video %s --v2a_wav %s --txt \"%s\" --nfe_step %d" % (output_dir, output_dir, audio_p_path, text_prompt, video_save_path, video_save_path[:-4]+".flac", text, v2s_num_steps)
+    #print("v2s command", command, video_gen)
+    #os.system(command)
+
+
+    v2s_infer(output_dir, output_dir, audio_p_path, text_prompt, video_save_path, video_save_path[:-4]+".flac", text, v2s_num_steps)
+
 
     return video_save_path, video_gen
 
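
Taken together, the handler's data flow is: v2a_infer generates audio for the copied video, and v2s_infer then consumes the stage-1 audio (app.py passes video_save_path[:-4]+".flac") together with the voice prompt to produce the dubbed result. A condensed sketch of that chaining (not part of the commit; paths, texts, and step counts are placeholders, and both repos are assumed to be importable as set up at the top of app.py):

# Assumes sys.path already contains ./MMAudio/ and ./F5-TTS/src/, as in app.py.
from demo import v2a_load, v2a_infer
from f5_tts.infer.infer_cli_test import v2s_infer

output_dir = "./results"                 # placeholder working directory
video_save_path = "./results/input.mp4"  # placeholder copy of the uploaded video
audio_p_path = "./results/prompt.wav"    # placeholder reference speech

loaded = v2a_load()

# Stage 1: video-to-audio for the copied video.
v2a_infer(output_dir, video_save_path, "", 25, loaded)

# Stage 2: video-to-speech, conditioned on the stage-1 audio and the voice prompt.
v2s_infer(output_dir, output_dir, audio_p_path, "reference transcript",
          video_save_path, video_save_path[:-4] + ".flac", "target text", 32)

video_gen = video_save_path[:-4] + ".mp4.gen.mp4"  # generated video path, as composed in app.py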