Update app.py
app.py CHANGED
@@ -1,13 +1,8 @@
 import torch
 import gradio as gr
-from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
+from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoModelForTextToSpeech, AutoProcessor
 import soundfile as sf
 import numpy as np
-import fairseq
-import IPython.display as ipd
-import os # Import the 'os' module
-
-commit_hash = "8798153927c22132778bef7b507d389474fa3589" # Example - find a suitable one!
 
 # --- Whisper (ASR) Setup ---
 ASR_MODEL_NAME = "openai/whisper-large-v2"
@@ -22,27 +17,16 @@ all_special_ids = asr_pipe.tokenizer.all_special_ids
 transcribe_token_id = all_special_ids[-5]
 translate_token_id = all_special_ids[-6]
 
-# --- FastSpeech2 (TTS) Setup ---
-TTS_MODEL_NAME = "
+# --- VITS (TTS) Setup - Using transformers ---
+TTS_MODEL_NAME = "espnet/kan_bayashi_ljspeech_vits" # Changed to VITS model
+tts_processor = AutoProcessor.from_pretrained(TTS_MODEL_NAME)
+tts_model = AutoModelForTextToSpeech.from_pretrained(TTS_MODEL_NAME)
 tts_device = "cuda" if torch.cuda.is_available() else "cpu"
+tts_model = tts_model.to(tts_device)
 
-# Download the model files if they don't exist
-if not os.path.exists("fastspeech2_model"):
-    os.makedirs("fastspeech2_model")
-    print("Downloading FastSpeech2 model...")
-    os.system(f"wget https://huggingface.co/{TTS_MODEL_NAME}/resolve/{commit_hash}/pytorch_model.pt -O fastspeech2_model/pytorch_model.pt")
-    os.system(f"wget https://huggingface.co/{TTS_MODEL_NAME}/resolve/{commit_hash}/vocab.txt -O fastspeech2_model/vocab.txt")
-    print("Download complete.")
-
-# Load the model using fairseq 0.10.2 compatible methods.
-tts_model_path = "fastspeech2_model/pytorch_model.pt" # Path to the downloaded model
-tts_model, tts_cfg, tts_task = fairseq.checkpoint_utils.load_model_ensemble_and_task([tts_model_path])
-tts_model = tts_model[0]
-tts_model.to(tts_device)
-tts_model.eval()
 
 # --- Vicuna (LLM) Setup ---
-VICUNA_MODEL_NAME = "lmsys/vicuna-
+VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5" # Or your preferred Vicuna
 vicuna_device = "cuda" if torch.cuda.is_available() else "cpu"
 vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
 vicuna_model = AutoModelForCausalLM.from_pretrained(
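Note on this hunk: `AutoModelForTextToSpeech` does not appear to be a class transformers ships (recent releases expose `AutoModelForTextToWaveform` and `AutoModelForTextToSpectrogram` for TTS), and `espnet/kan_bayashi_ljspeech_vits` is an ESPnet-format checkpoint rather than a transformers one, so this load may fail at runtime. Below is a minimal sketch of a VITS load that avoids the auto classes, assuming transformers >= 4.33; `VitsModel` and the `facebook/mms-tts-eng` checkpoint are stand-ins I am substituting, not names from this commit.

# Sketch only: VitsModel and facebook/mms-tts-eng are assumed stand-ins,
# not the class/checkpoint this commit uses.
import torch
from transformers import VitsModel, AutoTokenizer

TTS_MODEL_NAME = "facebook/mms-tts-eng"
tts_tokenizer = AutoTokenizer.from_pretrained(TTS_MODEL_NAME)
tts_model = VitsModel.from_pretrained(TTS_MODEL_NAME)
tts_device = "cuda" if torch.cuda.is_available() else "cpu"
tts_model = tts_model.to(tts_device).eval()  # inference mode; no gradients needed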
@@ -76,39 +60,28 @@ def transcribe_audio(microphone, state, task="transcribe"):
 
     updated_state = state + "\n" + vicuna_response
     return updated_state, updated_state
-
+
+# --- TTS Function (Simplified for VITS) ---
 def synthesize_speech(text):
     try:
-
-
-
-
-
-
-        else:
-            sample = sample
-
-        # Generate
-        generator = tts_task.build_generator([tts_model], tts_cfg.task) # Pass the task
-        output = generator.generate([tts_model], sample) # Generate using the generator
+        inputs = tts_processor(text=text, return_tensors="pt")
+        inputs = {key: value.to(tts_device) for key, value in inputs.items()}
+        with torch.no_grad():
+            output = tts_model(**inputs).spectrogram # VITS models often output a spectrogram
+        # Convert spectrogram to waveform using the vocoder
+        waveform = tts_model.vocoder(output).squeeze()
 
-        # Extract waveform and sample rate.
-        waveform = output[0][0]['waveform']
-        sample_rate = tts_cfg.task.sample_rate # Get the rate
-
-        # Convert to NumPy (and ensure CPU)
         waveform_np = waveform.cpu().numpy()
-
-        return (
-
+        # VITS models use a sample rate of 22050
+        return (22050, waveform_np)
 
     except Exception as e:
         print(e)
         return (None, None)
 
 # --- Gradio Interface ---
-with gr.Blocks(title="Whisper, Vicuna, &
-    gr.Markdown("# Speech-to-Text-to-Speech Demo with Vicuna")
+with gr.Blocks(title="Whisper, Vicuna, & VITS Demo") as demo: # Updated title
+    gr.Markdown("# Speech-to-Text-to-Speech Demo with Vicuna and VITS")
     gr.Markdown("Speak into your microphone, get a transcription, Vicuna will process it, and then you'll hear the result!")
 
     with gr.Tab("Transcribe & Synthesize"):
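A related note on the new synthesize_speech: the spectrogram-plus-vocoder flow added here matches two-stage models like SpeechT5, whereas VITS is end-to-end and its transformers implementation returns a waveform directly; hardcoding 22050 can also disagree with a given checkpoint. A hedged sketch of the function against the `VitsModel` stand-in from the previous note, reading the sample rate from the model config:

# Sketch only: assumes the VitsModel/tokenizer stand-ins loaded above.
def synthesize_speech(text):
    try:
        inputs = tts_tokenizer(text=text, return_tensors="pt")
        inputs = {key: value.to(tts_device) for key, value in inputs.items()}
        with torch.no_grad():
            # VITS is end-to-end: the model output already carries a waveform.
            waveform = tts_model(**inputs).waveform.squeeze()
        waveform_np = waveform.cpu().numpy()
        # Read the rate from the checkpoint instead of hardcoding 22050.
        return (tts_model.config.sampling_rate, waveform_np)
    except Exception as e:
        print(e)
        return (None, None)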
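The diff cuts off inside the "Transcribe & Synthesize" tab, so the component wiring is not visible here. A hedged sketch of how such a tab is commonly wired in gradio Blocks, chaining the transcribe_audio and synthesize_speech functions defined earlier in app.py; the component names and the gradio 4.x `sources=` argument are illustrative assumptions, not from this commit.

# Sketch only: continues inside the "with gr.Blocks(...) as demo:" block above;
# component names are illustrative, not taken from this commit.
    with gr.Tab("Transcribe & Synthesize"):
        mic_input = gr.Audio(sources=["microphone"], type="filepath", label="Speak here")
        conversation_box = gr.Textbox(label="Conversation", lines=8)
        audio_output = gr.Audio(label="Synthesized speech")
        state = gr.State(value="")

        # Transcribe + Vicuna first, then feed the updated conversation to TTS.
        mic_input.change(
            transcribe_audio,
            inputs=[mic_input, state],
            outputs=[conversation_box, state],
        ).then(
            synthesize_speech,
            inputs=conversation_box,
            outputs=audio_output,
        )

demo.launch()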