ford442 committed
Commit 447f99a · verified · 1 parent: 2dc1223

Update app.py

Files changed (1): app.py +21 -27
app.py CHANGED

@@ -1,6 +1,6 @@
 import torch
 import gradio as gr
-from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, SpeechEncoderDecoderModel, AutoProcessor, FastSpeech2Config, FastSpeech2Model
+from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoConfig
 import soundfile as sf
 import numpy as np
 
@@ -17,26 +17,23 @@ all_special_ids = asr_pipe.tokenizer.all_special_ids
 transcribe_token_id = all_special_ids[-5]
 translate_token_id = all_special_ids[-6]
 
-
 # --- FastSpeech2 (TTS) Setup ---
 TTS_MODEL_NAME = "facebook/fastspeech2-en-ljspeech"
 
-# Try loading the processor and config with trust_remote_code
-try:
-    tts_processor = AutoProcessor.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True)
-    tts_config = FastSpeech2Config.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True)
-    tts_model = SpeechEncoderDecoderModel.from_pretrained(TTS_MODEL_NAME, config=tts_config, trust_remote_code=True)
-except ValueError as e:
-    print(f"Error loading with trust_remote_code: {e}")
-    # Fallback to manual loading (explained below)
-    exit()  # Stop if we can't load
-
+# Load the config (we'll need it for the model class)
+tts_config = AutoConfig.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True)
+
+# Load the processor and model, using trust_remote_code
+tts_processor = AutoProcessor.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True)
+tts_model = AutoModelForTextToSpeech.from_pretrained(TTS_MODEL_NAME, config=tts_config, trust_remote_code=True)
+
+
 tts_device = "cuda" if torch.cuda.is_available() else "cpu"
 tts_model = tts_model.to(tts_device)
+
 # --- Vicuna (LLM) Setup ---
-VICUNA_MODEL_NAME = "lmsys/vicuna-33b-v1.3"
+VICUNA_MODEL_NAME = "lmsys/vicuna-33b-v1.3"  # Or a smaller Vicuna model
 vicuna_device = "cuda" if torch.cuda.is_available() else "cpu"
-
 vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
 vicuna_model = AutoModelForCausalLM.from_pretrained(
     VICUNA_MODEL_NAME,
@@ -45,9 +42,6 @@ vicuna_model = AutoModelForCausalLM.from_pretrained(
     device_map="auto",
 )
 
-# --- ASR and TTS Functions (and Gradio Interface) ---
-# (Rest of your code - transcribe_audio, synthesize_speech, Gradio setup)
-# ... (same as before, but using tts_model, tts_processor, and tts_config) ...
 # --- ASR Function ---
 def transcribe_audio(microphone, state, task="transcribe"):
     if microphone is None:
@@ -64,11 +58,11 @@ def transcribe_audio(microphone, state, task="transcribe"):
 
     prompt = f"{system_prompt}\nUser: {text}"
 
-    with torch.no_grad():  # Disable gradient calculation
-        vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to(vicuna_device)
-        vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=128)  # Limit response length
-        vicuna_response = vicuna_tokenizer.decode(vicuna_output[0], skip_special_tokens=True)
-        vicuna_response = vicuna_response.replace(prompt, "").strip()
+    with torch.no_grad():
+        vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to(vicuna_device)
+        vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=128)
+        vicuna_response = vicuna_tokenizer.decode(vicuna_output[0], skip_special_tokens=True)
+        vicuna_response = vicuna_response.replace(prompt, "").strip()
 
     updated_state = state + "\n" + vicuna_response
     return updated_state, updated_state
@@ -79,7 +73,7 @@ def synthesize_speech(text):
         inputs = tts_processor(text=text, return_tensors="pt")
         inputs = {key: value.to(tts_device) for key, value in inputs.items()}
        with torch.no_grad():
-            output = tts_model.generate(**inputs)
+            output = tts_model(**inputs).waveform  # Use the model directly, it outputs a waveform
        output = output.cpu()
        waveform = output.squeeze().numpy()
        return (tts_processor.feature_extractor.sampling_rate, waveform)
@@ -88,13 +82,13 @@ def synthesize_speech(text):
         return (None, None)
 
 # --- Gradio Interface ---
-with gr.Blocks(title="Whisper, Vicuna, & FastSpeech2 Demo") as demo:  # More descriptive title
+with gr.Blocks(title="Whisper, Vicuna, & FastSpeech2 Demo") as demo:
     gr.Markdown("# Speech-to-Text-to-Speech Demo with Vicuna")
     gr.Markdown("Speak into your microphone, get a transcription, Vicuna will process it, and then you'll hear the result!")
 
     with gr.Tab("Transcribe & Synthesize"):
         mic_input = gr.Audio(source="microphone", type="filepath", optional=True, label="Speak Here")
-        transcription_output = gr.Textbox(lines=5, label="Transcription and Vicuna Response")  # Combined output
+        transcription_output = gr.Textbox(lines=5, label="Transcription and Vicuna Response")
         audio_output = gr.Audio(label="Synthesized Speech", type="numpy")
         transcription_state = gr.State(value="")
 
@@ -104,8 +98,8 @@ with gr.Blocks(title="Whisper, Vicuna, & FastSpeech2 Demo") as demo:  # More descriptive title
             outputs=[transcription_output, transcription_state]
         ).then(
             fn=synthesize_speech,
-            inputs=transcription_output,  # Use the combined output as input for TTS
+            inputs=transcription_output,
             outputs=audio_output
         )
 
-demo.launch(enable_queue=True)
+demo.launch(enable_queue=True, share=False)  # share=False is usually better for local development
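
Note on this change: the rewritten import line brings in AutoConfig but drops AutoProcessor and never imports AutoModelForTextToSpeech, both of which the new TTS setup block uses, so app.py as committed would raise a NameError at startup. A minimal sketch of the imports the new block appears to assume (hedged: AutoModelForTextToSpeech is not a standard transformers export, and recent releases expose AutoModelForTextToWaveform as the closest auto class, so the exact name may need adjusting):

    # Sketch only: the imports the new TTS block appears to assume.
    from transformers import AutoProcessor
    try:
        # Name as referenced by this commit; may not exist in the installed version.
        from transformers import AutoModelForTextToSpeech
    except ImportError:
        # Closest standard auto class in recent transformers releases (assumption).
        from transformers import AutoModelForTextToWaveform as AutoModelForTextToSpeech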
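
The Vicuna block kept by this commit strips the echoed prompt with vicuna_response.replace(prompt, ""), which fails silently whenever decoding normalizes whitespace or special tokens inside the prompt. A more robust alternative (a sketch, not what the commit does; it reuses the vicuna_input and vicuna_output names from transcribe_audio) is to decode only the newly generated tokens:

    # Sketch: slice off the prompt by token count instead of string replacement.
    prompt_len = vicuna_input["input_ids"].shape[-1]
    vicuna_response = vicuna_tokenizer.decode(
        vicuna_output[0][prompt_len:], skip_special_tokens=True
    ).strip()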
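
Compatibility note: the Gradio calls this commit keeps, gr.Audio(source="microphone", ..., optional=True) and demo.launch(enable_queue=True, ...), are Gradio 3.x APIs. If the Space is pinned to Gradio 4.x, roughly the following adjustments would be needed (a sketch under that assumption):

    # Sketch assuming Gradio 4.x; 'demo' is the gr.Blocks object from app.py.
    import gradio as gr
    mic_input = gr.Audio(sources=["microphone"], type="filepath", label="Speak Here")
    # 'source=' became 'sources=[...]' and 'optional=' was removed in 4.x.
    demo.queue().launch(share=False)
    # launch(enable_queue=True) was removed in 4.x; queueing is enabled via .queue().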