import spaces
import torch
import gradio as gr
from transformers import pipeline, LlamaTokenizer, LlamaForCausalLM
import numpy as np
from espnet2.bin.tts_inference import Text2Speech
import nltk
import scipy.io.wavfile
from flask import Flask

app = Flask(__name__)  # Create the Flask app instance first
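# NOTE: the Flask instance is unused while the Gradio Blocks demo below serves the
# UI; it is kept only for the optional app.run() path at the bottom of the file.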
try:
    nltk.data.find('taggers/averaged_perceptron_tagger_eng')
except LookupError:
    nltk.download('averaged_perceptron_tagger_eng')
try:
    nltk.data.find('corpora/cmudict')  # Check for cmudict
except LookupError:
    nltk.download('cmudict')
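# These resources appear to be required by the English g2p frontend (g2p_en) that
# the ESPnet model below uses for grapheme-to-phoneme conversion.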
ASR_MODEL_NAME = "openai/whisper-medium.en"

asr_pipe = pipeline(
    task="automatic-speech-recognition",
    model=ASR_MODEL_NAME,
    chunk_length_s=30,
    device='cuda' if torch.cuda.is_available() else 'cpu',  # Use GPU if available
)

all_special_ids = asr_pipe.tokenizer.all_special_ids
transcribe_token_id = all_special_ids[-5]
translate_token_id = all_special_ids[-6]
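# Whisper keeps its task tokens (<|transcribe|>, <|translate|>) near the end of the
# special-ids list; these negative offsets are tokenizer-specific, so re-check them
# if you swap in a different ASR checkpoint.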
def _preload_and_load_models():
    global vicuna_tokenizer, vicuna_model
    # Alternatives: "EleutherAI/gpt-neo-2.7B", "lmsys/vicuna-13b-v1.5"
    VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5"
    vicuna_tokenizer = LlamaTokenizer.from_pretrained(VICUNA_MODEL_NAME)
    vicuna_model = LlamaForCausalLM.from_pretrained(
        VICUNA_MODEL_NAME,
        torch_dtype=torch.float16,  # half precision roughly halves GPU memory
        # device_map="auto" would shard across devices; we move explicitly instead
    ).to('cuda')  # Explicitly move to CUDA after loading

_preload_and_load_models()
tts = Text2Speech.from_pretrained("espnet/kan-bayashi_ljspeech_vits", device='cuda')
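# kan-bayashi_ljspeech_vits is an end-to-end VITS model; tts.fs exposes its native
# sampling rate (22050 Hz for LJSpeech-based recipes). A quick smoke test, if needed:
#   wav = tts("Hello there!")["wav"]; print(wav.shape, tts.fs)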
# Main handler, wired to both Gradio audio inputs below
@spaces.GPU  # attach a GPU for the duration of the call on ZeroGPU Spaces
def process_audio(microphone, audio_upload, state, answer_mode):
    if microphone:
        asr_pipe.model.config.forced_decoder_ids = [[2, transcribe_token_id]]
        text = asr_pipe(microphone)["text"]
    elif audio_upload:
        rate, data = scipy.io.wavfile.read(audio_upload)
        # The ASR pipeline expects mono float audio plus its sampling rate; raw
        # int16 samples would be misread, so normalize and downmix first.
        if np.issubdtype(data.dtype, np.integer):
            data = data.astype(np.float32) / np.iinfo(data.dtype).max
        if data.ndim > 1:
            data = data.mean(axis=1)
        asr_pipe.model.config.forced_decoder_ids = [[2, transcribe_token_id]]
        text = asr_pipe({"sampling_rate": rate, "raw": data})["text"]
    else:
        return state, state, None  # No audio input
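    # A single flat prompt is used here; Vicuna v1.5 was tuned on a
    # "USER: ... ASSISTANT:" chat template, so switching to that format may
    # improve responses for longer conversations.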
    system_prompt = """You are a friendly and enthusiastic tutor for young children (ages 6-9).
You answer questions clearly and simply, using age-appropriate language.
You are also a little bit silly and like to make jokes."""
    prompt = f"{system_prompt}\nUser: {text}"
    with torch.no_grad():
        vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to('cuda')
        # Pick a token budget per answer mode. Note that max_length counts the
        # prompt tokens as well as the generated ones; the if/elif chain (with a
        # fallback) guarantees vicuna_output is always defined.
        if answer_mode == 'slow':
            gen_kwargs = dict(max_length=512, min_new_tokens=256)
        elif answer_mode == 'medium':
            gen_kwargs = dict(max_length=128, min_new_tokens=64)
        else:  # 'fast', and the fallback for any unrecognized mode
            gen_kwargs = dict(max_length=42, min_new_tokens=16)
        vicuna_output = vicuna_model.generate(**vicuna_input, do_sample=True, **gen_kwargs)
    # Decode only the newly generated tokens; string-replacing the prompt is
    # unreliable because decoding does not round-trip it exactly.
    new_tokens = vicuna_output[0][vicuna_input["input_ids"].shape[1]:]
    vicuna_response = vicuna_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    updated_state = state + "\nUser: " + text + "\nTutor: " + vicuna_response
    try:
        output = tts(vicuna_response)
        audio_arr = output["wav"].cpu().numpy()
        # Peak-normalize to [-1, 1] so Gradio's numpy audio output plays cleanly
        peak = np.abs(audio_arr).max()
        if peak > 0:
            audio_arr = audio_arr / peak
        audio_output = (tts.fs, audio_arr)
    except Exception as e:
        print(f"Error in speech synthesis: {e}")
        audio_output = None
    return updated_state, updated_state, audio_output
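# Local sanity check for the handler (commented out; assumes a short WAV file on disk):
#   print(process_audio("sample.wav", None, "", "fast"))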
with gr.Blocks(title="Whisper, Vicuna, & TTS Demo") as demo:
    gr.Markdown("# Speech-to-Text-to-Speech Demo with Vicuna and ESPnet TTS")
    gr.Markdown("Speak into your microphone or upload an audio file, get a transcription, Vicuna will process it, and then you'll hear the result!")
    with gr.Tab("Transcribe & Synthesize"):
        with gr.Row():  # Side-by-side audio inputs
            mic_input = gr.Audio(sources=["microphone"], type="filepath", label="Speak Here")
            audio_upload = gr.Audio(sources=["upload"], type="filepath", label="Or Upload Audio File")
        transcription_output = gr.Textbox(lines=5, label="Transcription and Vicuna Response")
        audio_output = gr.Audio(label="Synthesized Speech", type="numpy", autoplay=True)
        answer_mode = gr.Radio(["fast", "medium", "slow"], value='medium', label="Answer Length")
        transcription_state = gr.State(value="")
        # .change fires when a recording finishes or an upload completes; both
        # inputs share the same handler and conversation state.
        mic_input.change(
            fn=process_audio,
            inputs=[mic_input, audio_upload, transcription_state, answer_mode],
            outputs=[transcription_output, transcription_state, audio_output],
        )
        audio_upload.change(
            fn=process_audio,
            inputs=[mic_input, audio_upload, transcription_state, answer_mode],
            outputs=[transcription_output, transcription_state, audio_output],
        )
if __name__ == '__main__':
    # app.run(debug=True, port=5000)  # Alternative: serve the Flask app instead
    demo.launch(share=False)