# vicuna-clip / app.py
import spaces
import torch
import gradio as gr
from transformers import pipeline, LlamaTokenizer, LlamaForCausalLM
import numpy as np
from espnet2.bin.tts_inference import Text2Speech
import nltk
import scipy.io.wavfile
from flask import Flask, request, jsonify

app = Flask(__name__)  # Flask instance; its /api/predict route below is inert unless app.run() is enabled
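
# ESPnet's English text frontend relies on NLTK for grapheme-to-phoneme
# conversion, so make sure the tagger and CMU pronunciation dictionary exist.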
try:
    nltk.data.find('taggers/averaged_perceptron_tagger_eng')
except LookupError:
    nltk.download('averaged_perceptron_tagger_eng')
try:
    nltk.data.find('corpora/cmudict')  # Check for cmudict
except LookupError:
    nltk.download('cmudict')
ASR_MODEL_NAME = "openai/whisper-medium.en"
asr_pipe = pipeline(
    task="automatic-speech-recognition",
    model=ASR_MODEL_NAME,
    chunk_length_s=30,
    device='cuda' if torch.cuda.is_available() else 'cpu',  # Use GPU if available
)
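
# Whisper keeps its task tokens (e.g. <|transcribe|>, <|translate|>) among the
# tokenizer's special tokens; indexing them by position, as below, is a
# model-specific shortcut that can break if the special-token layout changes.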
all_special_ids = asr_pipe.tokenizer.all_special_ids
transcribe_token_id = all_special_ids[-5]
translate_token_id = all_special_ids[-6]
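
# Load the Vicuna tokenizer and weights once at startup; the globals let the
# Gradio callback reuse the same loaded model across requests.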
def _preload_and_load_models():
    global vicuna_tokenizer, vicuna_model
    # VICUNA_MODEL_NAME = "EleutherAI/gpt-neo-2.7B"  # Alternative model
    # VICUNA_MODEL_NAME = "lmsys/vicuna-13b-v1.5"    # Alternative model
    VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5"
    vicuna_tokenizer = LlamaTokenizer.from_pretrained(VICUNA_MODEL_NAME)
    vicuna_model = LlamaForCausalLM.from_pretrained(
        VICUNA_MODEL_NAME,
        torch_dtype=torch.float16,
        # device_map="auto",  # alternative to the explicit .to('cuda') below
    ).to('cuda')  # Explicitly move to CUDA after loading
_preload_and_load_models()
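
# VITS single-speaker TTS from ESPnet, used to voice the tutor's replies.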
tts = Text2Speech.from_pretrained("espnet/kan-bayashi_ljspeech_vits", device='cuda')
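
# Hedged sanity-check sketch (run manually; not part of the app flow):
#   out = tts("Hello there!")
#   scipy.io.wavfile.write("check.wav", tts.fs, out["wav"].cpu().numpy())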
@app.route('/api/predict', methods=['POST'])  # Flask endpoint (inert while only the Gradio app runs)
@spaces.GPU(required=True)
def process_audio(microphone, audio_upload, state, answer_mode):
    if microphone:
        asr_pipe.model.config.forced_decoder_ids = [[2, transcribe_token_id]]
        text = asr_pipe(microphone)["text"]
    elif audio_upload:
        # Assumes a 16-bit PCM WAV upload: convert to float32 in [-1, 1] and
        # hand the pipeline the sampling rate so it can resample correctly
        rate, data = scipy.io.wavfile.read(audio_upload)
        data = data.astype(np.float32) / 32768.0
        asr_pipe.model.config.forced_decoder_ids = [[2, transcribe_token_id]]
        text = asr_pipe({"sampling_rate": rate, "raw": data})["text"]
    else:
        return state, state, None  # No audio input
    system_prompt = """You are a friendly and enthusiastic tutor for young children (ages 6-9).
You answer questions clearly and simply, using age-appropriate language.
You are also a little bit silly and like to make jokes."""
    prompt = f"{system_prompt}\nUser: {text}"
    with torch.no_grad():
        vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to('cuda')
        # Budget with max_new_tokens instead of max_length: max_length counts the
        # prompt too, so a long prompt could leave no room for the reply
        if answer_mode == 'slow':
            vicuna_output = vicuna_model.generate(
                **vicuna_input,
                max_new_tokens=512,
                min_new_tokens=256,
                do_sample=True
            )
        elif answer_mode == 'medium':
            vicuna_output = vicuna_model.generate(
                **vicuna_input,
                max_new_tokens=128,
                min_new_tokens=64,
                do_sample=True
            )
        else:  # 'fast' (default branch also guards against unexpected values)
            vicuna_output = vicuna_model.generate(
                **vicuna_input,
                max_new_tokens=42,
                min_new_tokens=16,
                do_sample=True
            )
    vicuna_response = vicuna_tokenizer.decode(vicuna_output[0], skip_special_tokens=True)
    # The decoded text echoes the prompt; strip it so only the reply remains
    vicuna_response = vicuna_response.replace(prompt, "").strip()
    updated_state = state + "\nUser: " + text + "\nTutor: " + vicuna_response
    try:
        output = tts(vicuna_response)
        wav = output["wav"]
        sr = tts.fs
        audio_arr = wav.cpu().numpy()
        # Peak-normalize to [-1, 1]; guard against an all-zero waveform
        peak = np.abs(audio_arr).max()
        if peak > 0:
            audio_arr = audio_arr / peak
        audio_output = (sr, audio_arr)
    except Exception as e:
        print(f"Error in speech synthesis: {e}")
        audio_output = None
    return updated_state, updated_state, audio_output
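
# Gradio UI: audio in (microphone or upload), running transcript plus the
# synthesized reply out.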
with gr.Blocks(title="Whisper, Vicuna, & TTS Demo") as demo:
    gr.Markdown("# Speech-to-Text-to-Speech Demo with Vicuna and Hugging Face TTS")
    gr.Markdown("Speak into your microphone or upload a clip, get a transcription, let Vicuna respond, and hear the result!")
    with gr.Tab("Transcribe & Synthesize"):
        with gr.Row():
            mic_input = gr.Audio(sources=["microphone"], type="filepath", label="Speak Here")
            audio_upload = gr.Audio(sources=["upload"], type="filepath", label="Or Upload Audio File")
        transcription_output = gr.Textbox(lines=5, label="Transcription and Vicuna Response")
        audio_output = gr.Audio(label="Synthesized Speech", type="numpy", autoplay=True)
        answer_mode = gr.Radio(["fast", "medium", "slow"], value='medium', label="Answer Length")
        transcription_state = gr.State(value="")
        mic_input.change(
            fn=process_audio,
            inputs=[mic_input, audio_upload, transcription_state, answer_mode],
            outputs=[transcription_output, transcription_state, audio_output]
        )
        audio_upload.change(
            fn=process_audio,
            inputs=[mic_input, audio_upload, transcription_state, answer_mode],
            outputs=[transcription_output, transcription_state, audio_output]
        )
if __name__ == '__main__':
    # To expose the Flask endpoint instead, run: app.run(debug=True, port=5000)
    demo.launch(share=False)