Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -1,6 +1,6 @@
 import torch
 import gradio as gr
-from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer,
+from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoConfig
 import soundfile as sf
 import numpy as np
 
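Note on this hunk: the old import line ended in a dangling comma, which is a SyntaxError, and the new line fixes it while adding AutoConfig. The TTS block later in this diff also calls AutoProcessor and AutoModelForTextToSpeech, which this import line still does not cover; AutoModelForTextToSpeech is not a class transformers exports, the nearest real auto-class being AutoModelForTextToWaveform. A sketch of an import block that would cover every name the new file uses, under that assumption:

    import torch
    import gradio as gr
    import soundfile as sf
    import numpy as np
    from transformers import (
        pipeline,
        AutoModelForCausalLM,
        AutoTokenizer,
        AutoConfig,
        AutoProcessor,               # called by the TTS setup below but never imported in this commit
        AutoModelForTextToWaveform,  # assumption: stand-in for the nonexistent AutoModelForTextToSpeech
    )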
@@ -17,26 +17,23 @@ all_special_ids = asr_pipe.tokenizer.all_special_ids
 transcribe_token_id = all_special_ids[-5]
 translate_token_id = all_special_ids[-6]
 
-
 # --- FastSpeech2 (TTS) Setup ---
 TTS_MODEL_NAME = "facebook/fastspeech2-en-ljspeech"
 
-#
-
-
-
-
-
-
-
-exit() # Stop if we can't load
-
+# Load the config (we'll need it for the model class)
+tts_config = AutoConfig.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True)
+
+# Load the processor and model, using trust_remote_code
+tts_processor = AutoProcessor.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True)
+tts_model = AutoModelForTextToSpeech.from_pretrained(TTS_MODEL_NAME, config=tts_config, trust_remote_code=True)
+
+
 tts_device = "cuda" if torch.cuda.is_available() else "cpu"
 tts_model = tts_model.to(tts_device)
+
 # --- Vicuna (LLM) Setup ---
-VICUNA_MODEL_NAME = "lmsys/vicuna-33b-v1.3"
+VICUNA_MODEL_NAME = "lmsys/vicuna-33b-v1.3" # Or a smaller Vicuna model
 vicuna_device = "cuda" if torch.cuda.is_available() else "cpu"
-
 vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
 vicuna_model = AutoModelForCausalLM.from_pretrained(
     VICUNA_MODEL_NAME,
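A caveat on the new TTS setup: facebook/fastspeech2-en-ljspeech is published as a fairseq checkpoint, so loading it through the transformers auto-classes is unlikely to work even with trust_remote_code=True. A minimal sketch of the fairseq route instead, assuming the fairseq package is installed (this mirrors the usage shown on the model card and is not part of this commit):

    from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
    from fairseq.models.text_to_speech.hub_interface import TTSHubInterface

    models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
        "facebook/fastspeech2-en-ljspeech",
        arg_overrides={"vocoder": "hifigan", "fp16": False},
    )
    model = models[0]
    TTSHubInterface.update_cfg_with_data_cfg(cfg, task.data_cfg)
    generator = task.build_generator(models, cfg)

    # One synthesis round-trip: text in, (waveform, sample_rate) out.
    sample = TTSHubInterface.get_model_input(task, "Hello, this is a test.")
    wav, rate = TTSHubInterface.get_prediction(task, model, generator, sample)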
@@ -45,9 +42,6 @@ vicuna_model = AutoModelForCausalLM.from_pretrained(
     device_map="auto",
 )
 
-# --- ASR and TTS Functions (and Gradio Interface) ---
-# (Rest of your code - transcribe_audio, synthesize_speech, Gradio setup)
-# ... (same as before, but using tts_model, tts_processor, and tts_config) ...
 # --- ASR Function ---
 def transcribe_audio(microphone, state, task="transcribe"):
     if microphone is None:
@@ -64,11 +58,11 @@ def transcribe_audio(microphone, state, task="transcribe"):
 
     prompt = f"{system_prompt}\nUser: {text}"
 
-    with torch.no_grad():
-
-
-
-
+    with torch.no_grad():
+        vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to(vicuna_device)
+        vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=128)
+        vicuna_response = vicuna_tokenizer.decode(vicuna_output[0], skip_special_tokens=True)
+        vicuna_response = vicuna_response.replace(prompt, "").strip()
 
     updated_state = state + "\n" + vicuna_response
     return updated_state, updated_state
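The new vicuna_response.replace(prompt, "") line only strips the prompt when the model echoes it back verbatim. A more robust variant, sketched here as an alternative rather than what the commit does, is to decode only the tokens generated after the input:

    with torch.no_grad():
        vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to(vicuna_device)
        vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=128)
        # Slice off the prompt tokens instead of string-replacing them afterwards.
        new_tokens = vicuna_output[0][vicuna_input["input_ids"].shape[1]:]
        vicuna_response = vicuna_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()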
@@ -79,7 +73,7 @@ def synthesize_speech(text):
     inputs = tts_processor(text=text, return_tensors="pt")
     inputs = {key: value.to(tts_device) for key, value in inputs.items()}
     with torch.no_grad():
-        output = tts_model
+        output = tts_model(**inputs).waveform # Use the model directly, it outputs a waveform
     output = output.cpu()
     waveform = output.squeeze().numpy()
     return (tts_processor.feature_extractor.sampling_rate, waveform)
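The fixed line assumes the model's forward pass returns an output object with a .waveform attribute. That holds for transformers-native TTS models such as VITS, but whether a remote-code FastSpeech2 wrapper exposes the same attribute is an assumption. For reference, a minimal sketch against a model where the attribute is documented (facebook/mms-tts-eng, a VITS checkpoint; swapping models is not what this commit does):

    import torch
    from transformers import AutoTokenizer, VitsModel

    vits_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
    vits_model = VitsModel.from_pretrained("facebook/mms-tts-eng")

    inputs = vits_tokenizer("Hello from the demo.", return_tensors="pt")
    with torch.no_grad():
        waveform = vits_model(**inputs).waveform  # shape: (batch, num_samples)
    sampling_rate = vits_model.config.sampling_rate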
@@ -88,13 +82,13 @@ def synthesize_speech(text):
     return (None, None)
 
 # --- Gradio Interface ---
-with gr.Blocks(title="Whisper, Vicuna, & FastSpeech2 Demo") as demo:
+with gr.Blocks(title="Whisper, Vicuna, & FastSpeech2 Demo") as demo:
     gr.Markdown("# Speech-to-Text-to-Speech Demo with Vicuna")
     gr.Markdown("Speak into your microphone, get a transcription, Vicuna will process it, and then you'll hear the result!")
 
     with gr.Tab("Transcribe & Synthesize"):
         mic_input = gr.Audio(source="microphone", type="filepath", optional=True, label="Speak Here")
-        transcription_output = gr.Textbox(lines=5, label="Transcription and Vicuna Response")
+        transcription_output = gr.Textbox(lines=5, label="Transcription and Vicuna Response")
         audio_output = gr.Audio(label="Synthesized Speech", type="numpy")
         transcription_state = gr.State(value="")
 
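One compatibility note on this hunk: gr.Audio(source="microphone", type="filepath", optional=True, ...) is Gradio 3.x style. In Gradio 4 the source argument became sources (a list) and optional was removed. A sketch of the equivalent component, assuming a Gradio 4 environment (the Space itself may pin Gradio 3):

    # Gradio 4.x equivalent of the mic_input line above.
    mic_input = gr.Audio(sources=["microphone"], type="filepath", label="Speak Here")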
@@ -104,8 +98,8 @@ with gr.Blocks(title="Whisper, Vicuna, & FastSpeech2 Demo") as demo: # More des
         outputs=[transcription_output, transcription_state]
     ).then(
         fn=synthesize_speech,
-        inputs=transcription_output,
+        inputs=transcription_output,
         outputs=audio_output
     )
 
-demo.launch(enable_queue=True)
+demo.launch(enable_queue=True, share=False) # share=False is usually better for local development
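A similar note for the last line: enable_queue was removed as a launch() argument in later Gradio releases; the replacement is to enable the queue on the Blocks object before launching. Sketch, assuming a Gradio version that provides .queue():

    demo.queue()              # replaces launch(enable_queue=True)
    demo.launch(share=False)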
|