Update app.py
app.py
CHANGED
@@ -2,11 +2,13 @@ import spaces
 import torch
 import gradio as gr
 from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoModel
-import soundfile as sf
+#import soundfile as sf # Removed: Not directly used for outputting audio to Gradio
 import numpy as np
 from espnet2.bin.tts_inference import Text2Speech
-import yaml # Import yaml for config loading
-import os
+import yaml # Import yaml for config loading (though not used in the current code, kept for potential future use)
+import os # Kept for potential future use (e.g., if loading config from files)
+import requests # Corrected: Import the 'requests' library
+
 
 # Load Whisper model
 ASR_MODEL_NAME = "openai/whisper-medium.en"
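For context on the unchanged lines closing this hunk: app.py loads Whisper through the transformers pipeline. The diff does not show that call, so the sketch below is only typical usage of ASR_MODEL_NAME, not the Space's exact code:

from transformers import pipeline

ASR_MODEL_NAME = "openai/whisper-medium.en"

# Typical transformers usage: the pipeline bundles feature extraction
# and decoding for Whisper into one callable object.
asr = pipeline(task="automatic-speech-recognition", model=ASR_MODEL_NAME)

def transcribe(audio_path):
    # Accepts a filepath (or a raw numpy array) and returns a dict
    # whose "text" field holds the transcription.
    return asr(audio_path)["text"]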
@@ -52,15 +54,19 @@ def process_audio(microphone, state, task="transcribe"):
     vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=192)
     vicuna_response = vicuna_tokenizer.decode(vicuna_output[0], skip_special_tokens=True) # Access the first sequence [0]
     vicuna_response = vicuna_response.replace(prompt, "").strip()
-    updated_state = state + "\n" + vicuna_response
+    updated_state = state + "\nUser: " + text + "\n" + "Tutor: " + vicuna_response # Include user input in state
     try:
         with torch.no_grad():
-
+            # The espnet TTS model outputs a dictionary
+            output = tts(vicuna_response)
+            wav = output["wav"]
+            sr = tts.fs # Get the sampling rate from the tts object
             audio_arr = wav.cpu().numpy()
+
             SAMPLE_RATE = sr
             audio_arr = audio_arr / np.abs(audio_arr).max() # Normalize to -1 to 1
             audio_output = (SAMPLE_RATE, audio_arr)
-            #sf.write('generated_audio.wav', audio_arr, SAMPLE_RATE)
+            #sf.write('generated_audio.wav', audio_arr, SAMPLE_RATE) # Removed writing to file
     except requests.exceptions.RequestException as e:
         print(f"Error in Hugging Face API request: {e}")
         audio_output = None
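A minimal standalone sketch of the TTS path this hunk introduces, assuming tts is an espnet2 Text2Speech instance as imported at the top of app.py; the model tag below is illustrative, since the diff does not show how tts is constructed:

import numpy as np
import torch
from espnet2.bin.tts_inference import Text2Speech

# Illustrative model tag; the Space may load a different one.
tts = Text2Speech.from_pretrained("espnet/kan-bayashi_ljspeech_vits")

def synthesize(text):
    with torch.no_grad():
        output = tts(text)   # espnet2 returns a dict of tensors
        wav = output["wav"]  # 1-D waveform tensor
    audio_arr = wav.cpu().numpy()
    audio_arr = audio_arr / np.abs(audio_arr).max()  # normalize to [-1, 1]
    return (tts.fs, audio_arr)  # (sample_rate, array): the tuple form gr.Audio accepts

Note that tts() runs locally, so a failure inside the try block would not be a requests.exceptions.RequestException; the except clause as written only covers network errors.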
@@ -79,7 +85,7 @@ with gr.Blocks(title="Whisper, Vicuna, & TTS Demo") as demo: # Updated title
     transcription_state = gr.State(value="")
     mic_input.change(
         fn=process_audio, # Call the combined function
-        inputs=[mic_input, transcription_state],
+        inputs=[mic_input, transcription_state, gr.Radio(["transcribe", "translate"], label="Task", value="transcribe")],
         outputs=[transcription_output, transcription_state, audio_output]
     )
 
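The new inputs= list instantiates the gr.Radio inline inside the event registration, which renders it wherever the change() call happens to execute in the layout. A common alternative is to declare the component first and pass it by reference; a sketch under the assumption that the other components are created roughly as named in this diff (their constructor arguments are not shown, so those below are illustrative):

import gradio as gr

with gr.Blocks(title="Whisper, Vicuna, & TTS Demo") as demo:
    # Gradio 4.x style; older versions use source="microphone" instead.
    mic_input = gr.Audio(sources=["microphone"], type="filepath", label="Speak")
    task_choice = gr.Radio(["transcribe", "translate"], label="Task", value="transcribe")
    transcription_output = gr.Textbox(label="Conversation")
    audio_output = gr.Audio(label="Tutor speech")
    transcription_state = gr.State(value="")

    mic_input.change(
        fn=process_audio,  # process_audio as defined in app.py
        inputs=[mic_input, transcription_state, task_choice],
        outputs=[transcription_output, transcription_state, audio_output],
    )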
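Finally, on the headline fix: the commit adds import requests because process_audio catches requests.exceptions.RequestException. Python resolves the class named in an except clause only when an exception actually propagates, so the missing import would have surfaced as a NameError at error-handling time rather than at startup. A sketch of the behavior (flaky_call is a hypothetical stand-in for the network request):

import requests  # without this import, the except clause below raises NameError

def flaky_call():
    # Hypothetical stand-in for a Hugging Face API request.
    raise requests.exceptions.ConnectionError("simulated network failure")

def handle():
    try:
        flaky_call()
    except requests.exceptions.RequestException as e:
        # This name is only looked up when an exception propagates here,
        # which is why the bug hid until a request actually failed.
        print(f"Request failed: {e}")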