ford442 committed on
Commit
8213d9e
·
verified ·
1 Parent(s): b56cef1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -40
app.py CHANGED
@@ -1,9 +1,11 @@
1
  import torch
2
  import gradio as gr
3
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoProcessor
4
  import soundfile as sf
5
  import numpy as np
6
- import importlib
 
 
7
 
8
  # --- Whisper (ASR) Setup ---
9
  ASR_MODEL_NAME = "openai/whisper-large-v2"
@@ -18,38 +20,26 @@ all_special_ids = asr_pipe.tokenizer.all_special_ids
18
  transcribe_token_id = all_special_ids[-5]
19
  translate_token_id = all_special_ids[-6]
20
 
21
- # --- FastSpeech2 (TTS) Setup ---
22
- TTS_MODEL_NAME = "ford442/fastspeech2-en-ljspeech" # OR "facebook/fastspeech2-en-ljspeech"
23
-
24
- # 1. Load the processor (we still need trust_remote_code for this)
25
- tts_processor = AutoProcessor.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True)
26
-
27
- # 2. Load the model using the *custom* modeling file. This is the key.
28
- # We CANNOT use AutoConfig or AutoModel here.
29
- model_file_path = f"models--{TTS_MODEL_NAME.replace('/', '--')}/snapshots"
30
-
31
- import os
32
- # Find the commit hash - this is needed because of the way Hugging Face caches models.
33
- for d in os.listdir(os.path.expanduser(f"~/.cache/huggingface/hub/{model_file_path}")):
34
- if os.path.isdir(os.path.expanduser(f"~/.cache/huggingface/hub/{model_file_path}/{d}")) and not d.startswith("."):
35
- commit_hash = d
36
- break
37
- else:
38
- raise ValueError ("Cannot find the model")
39
- model_file_path += f"/{commit_hash}/modeling_fastspeech2.py"
40
 
41
- # Use importlib to import the custom modeling file.
42
- spec = importlib.util.spec_from_file_location("modeling_fastspeech2", os.path.expanduser(f"~/.cache/huggingface/hub/{model_file_path}"))
43
- fastspeech2_module = importlib.util.module_from_spec(spec)
44
- spec.loader.exec_module(fastspeech2_module)
45
- tts_model = fastspeech2_module.FastSpeech2.from_pretrained(TTS_MODEL_NAME) #Use the actual class name!
 
 
 
46
 
47
- tts_device = "cuda" if torch.cuda.is_available() else "cpu"
48
- tts_model = tts_model.to(tts_device)
 
49
 
50
 
51
  # --- Vicuna (LLM) Setup ---
52
- VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5" # Use a smaller model if needed
53
  vicuna_device = "cuda" if torch.cuda.is_available() else "cpu"
54
  vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
55
  vicuna_model = AutoModelForCausalLM.from_pretrained(
@@ -59,9 +49,6 @@ vicuna_model = AutoModelForCausalLM.from_pretrained(
59
  device_map="auto",
60
  )
61
 
62
- # --- ASR and TTS Functions (and Gradio Interface) ---
63
- # (Same as before, but using tts_model and tts_processor)
64
-
65
  # --- ASR Function ---
66
  def transcribe_audio(microphone, state, task="transcribe"):
67
  if microphone is None:
@@ -87,16 +74,22 @@ def transcribe_audio(microphone, state, task="transcribe"):
87
  updated_state = state + "\n" + vicuna_response
88
  return updated_state, updated_state
89
 
90
- # --- TTS Function ---
91
  def synthesize_speech(text):
92
  try:
93
- inputs = tts_processor(text=text, return_tensors="pt")
94
- inputs = {key: value.to(tts_device) for key, value in inputs.items()}
95
- with torch.no_grad():
96
- output = tts_model(**inputs).waveform # Use .waveform
97
- output = output.cpu()
98
- waveform = output.squeeze().numpy()
99
- return (tts_processor.feature_extractor.sampling_rate, waveform)
 
 
 
 
 
 
100
  except Exception as e:
101
  print(e)
102
  return (None, None)
 
1
  import torch
2
  import gradio as gr
3
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
4
  import soundfile as sf
5
  import numpy as np
6
+ from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
7
+ from fairseq.models.text_to_speech.hub_interface import TTSHubInterface
8
+ import IPython.display as ipd # We still need this if running in a notebook
9
 
10
  # --- Whisper (ASR) Setup ---
11
  ASR_MODEL_NAME = "openai/whisper-large-v2"
 
20
  transcribe_token_id = all_special_ids[-5]
21
  translate_token_id = all_special_ids[-6]
22
 
23
+ # --- FastSpeech2 (TTS) Setup - Using fairseq ---
24
+ TTS_MODEL_NAME = "facebook/fastspeech2-en-ljspeech"
25
+ tts_device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ # Load the fairseq model, config, and task.
28
+ tts_models, tts_cfg, tts_task = load_model_ensemble_and_task_from_hf_hub(
29
+ TTS_MODEL_NAME,
30
+ arg_overrides={"vocoder": "hifigan", "fp16": False}
31
+ )
32
+ tts_model = tts_models[0]
33
+ TTSHubInterface.update_cfg_with_data_cfg(tts_cfg, tts_task.data_cfg)
34
+ tts_generator = tts_task.build_generator(tts_model, tts_cfg)
35
 
36
+ # Move the fairseq model to the correct device.
37
+ tts_model.to(tts_device)
38
+ tts_model.eval() # Put the model in evaluation mode
39
 
40
 
41
  # --- Vicuna (LLM) Setup ---
42
+ VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5" # Or your preferred Vicuna
43
  vicuna_device = "cuda" if torch.cuda.is_available() else "cpu"
44
  vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
45
  vicuna_model = AutoModelForCausalLM.from_pretrained(
 
49
  device_map="auto",
50
  )
51
 
 
 
 
52
  # --- ASR Function ---
53
  def transcribe_audio(microphone, state, task="transcribe"):
54
  if microphone is None:
 
74
  updated_state = state + "\n" + vicuna_response
75
  return updated_state, updated_state
76
 
77
+ # --- TTS Function (Modified for fairseq) ---
78
  def synthesize_speech(text):
79
  try:
80
+ sample = TTSHubInterface.get_model_input(tts_task, text)
81
+
82
+ # Move input tensors to the correct device
83
+ if torch.cuda.is_available():
84
+ sample['net_input'] = {k: v.cuda() for k, v in sample['net_input'].items()}
85
+ else:
86
+ sample['net_input'] = {k: v.cpu() for k, v in sample['net_input'].items()}
87
+
88
+ wav, rate = TTSHubInterface.get_prediction(tts_task, tts_model, tts_generator, sample)
89
+ wav_numpy = wav.cpu().numpy() # fairseq returns a tensor, not a numpy array
90
+
91
+ return (rate, wav_numpy) # Return rate and NumPy array
92
+
93
  except Exception as e:
94
  print(e)
95
  return (None, None)