# vicuna-clip / app.py
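"""Gradio Space: speech in, speech out.

Flow: Whisper transcribes microphone audio, a causal language model drafts a
child-friendly tutor reply, and an ESPnet VITS model speaks the reply back.
"""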
import spaces
import torch
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoModel
import numpy as np
from espnet2.bin.tts_inference import Text2Speech
import yaml  # Not used yet; kept for potential future config loading
import os  # Not used yet; kept for potential future use (e.g., loading config files)
import requests  # Only used to catch RequestException in process_audio
import nltk
# Download the NLTK resources required by the TTS text (grapheme-to-phoneme) frontend
try:
nltk.data.find('taggers/averaged_perceptron_tagger_eng')
except LookupError:
nltk.download('averaged_perceptron_tagger_eng')
try:
nltk.data.find('corpora/cmudict') # Check for cmudict
except LookupError:
nltk.download('cmudict')
# Load Whisper model
ASR_MODEL_NAME = "openai/whisper-medium.en"
asr_pipe = pipeline(
task="automatic-speech-recognition",
model=ASR_MODEL_NAME,
chunk_length_s=30,
device='cuda' if torch.cuda.is_available() else 'cpu', # Use GPU if available
)
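# The transcribe/translate task-token ids sit at fixed offsets from the end of the
# tokenizer's special-token list (same pattern as the stock Whisper ASR demos).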
all_special_ids = asr_pipe.tokenizer.all_special_ids
transcribe_token_id = all_special_ids[-5]
translate_token_id = all_special_ids[-6]
def _preload_and_load_models():
    global vicuna_tokenizer, vicuna_model
    # Note: despite the variable names, this Space currently loads GPT-Neo 2.7B.
    VICUNA_MODEL_NAME = "EleutherAI/gpt-neo-2.7B"
    vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
    vicuna_model = AutoModelForCausalLM.from_pretrained(
        VICUNA_MODEL_NAME,
        torch_dtype=torch.float16,
        device_map="auto",  # alternatively, load on CPU and move with .to('cuda')
    )
_preload_and_load_models()
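# ESPnet2 text-to-speech: a VITS model trained on the single-speaker LJSpeech dataset.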
tts = Text2Speech.from_pretrained("espnet/kan-bayashi_ljspeech_vits")
@spaces.GPU(required=True)
def process_audio(microphone, state, task="transcribe"):
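    """Transcribe the recording, generate a tutor reply, and synthesize it as speech.

    Returns the updated conversation text (once for display, once for state) plus a
    (sample_rate, waveform) tuple for the audio output, or None if synthesis fails.
    """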
if microphone is None:
return state, state, None
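    # Force the selected task token (transcribe vs. translate) at decoder position 2.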
asr_pipe.model.config.forced_decoder_ids = [
[2, transcribe_token_id if task == "transcribe" else translate_token_id]
]
text = asr_pipe(microphone)["text"]
system_prompt = """You are a friendly and enthusiastic tutor for young children (ages 6-9).
You answer questions clearly and simply, using age-appropriate language.
You are also a little bit silly and like to make jokes."""
prompt = f"{system_prompt}\nUser: {text}"
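    # Generate a reply, then strip the echoed prompt so only the tutor's answer remains.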
with torch.no_grad():
vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to('cuda')
vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=192)
vicuna_response = vicuna_tokenizer.decode(vicuna_output[0], skip_special_tokens=True) # Access the first sequence [0]
vicuna_response = vicuna_response.replace(prompt, "").strip()
updated_state = state + "\nUser: " + text + "\n" + "Tutor: " + vicuna_response # Include user input in state
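    # Synthesize the reply; on any failure, return the text without audio.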
try:
with torch.no_grad():
# The espnet TTS model outputs a dictionary
output = tts(vicuna_response)
wav = output["wav"]
sr = tts.fs # Get the sampling rate from the tts object
            audio_arr = wav.cpu().numpy()
            audio_arr = audio_arr / np.abs(audio_arr).max()  # Normalize to the [-1, 1] range
            audio_output = (sr, audio_arr)  # Gradio expects a (sample_rate, np.ndarray) tuple
    except requests.exceptions.RequestException as e:
        # Local ESPnet inference makes no HTTP requests; this handler is kept defensively.
        print(f"Error in Hugging Face API request: {e}")
        audio_output = None
except Exception as e:
print(f"Error in speech synthesis: {e}")
audio_output = None
return updated_state, updated_state, audio_output
with gr.Blocks(title="Whisper, Vicuna, & TTS Demo") as demo:
    gr.Markdown("# Speech-to-Text-to-Speech Demo with Vicuna and Hugging Face TTS")
gr.Markdown("Speak into your microphone, get a transcription, Vicuna will process it, and then you'll hear the result!")
with gr.Tab("Transcribe & Synthesize"):
mic_input = gr.Audio(sources="microphone", type="filepath", label="Speak Here")
transcription_output = gr.Textbox(lines=5, label="Transcription and Vicuna Response")
audio_output = gr.Audio(label="Synthesized Speech", type="numpy") # Important: type="numpy"
transcription_state = gr.State(value="")
        task_select = gr.Radio(["transcribe", "translate"], label="Task", value="transcribe")
        mic_input.change(
            fn=process_audio,  # Call the combined transcribe/generate/synthesize function
            inputs=[mic_input, transcription_state, task_select],
            outputs=[transcription_output, transcription_state, audio_output]
        )
demo.launch(share=False)