import spaces
import torch
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoModel, AutoProcessor
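# --- ASR: Whisper large-v2 through the transformers pipeline, chunked so
# recordings longer than 30 s are processed in pieces ---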
ASR_MODEL_NAME = "openai/whisper-large-v2"
asr_pipe = pipeline(
    task="automatic-speech-recognition",
    model=ASR_MODEL_NAME,
    chunk_length_s=30,
    device='cuda',
)
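# Whisper keeps its task tokens among the tokenizer's special tokens; counting
# from the end picks out <|transcribe|> and <|translate|> so either task can be
# forced at decode time. Note: this positional indexing is tied to the
# whisper-large-v2 vocabulary and may not hold for other checkpoints.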
all_special_ids = asr_pipe.tokenizer.all_special_ids
transcribe_token_id = all_special_ids[-5]
translate_token_id = all_special_ids[-6]
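# --- TTS: Bark-small converts Vicuna's text reply back into speech ---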
TTS_MODEL_NAME = "suno/bark-small"
tts_processor = AutoProcessor.from_pretrained(TTS_MODEL_NAME)
tts_model = AutoModel.from_pretrained(TTS_MODEL_NAME).to('cuda')
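# --- LLM: Vicuna-7B v1.5 in float16; device_map="auto" places it on the GPU ---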
VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5"
vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
vicuna_model = AutoModelForCausalLM.from_pretrained(
    VICUNA_MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
)
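# Assumption: this Space runs on ZeroGPU (hence `import spaces`), where a GPU
# is attached only while a @spaces.GPU-decorated function executes; the
# duration below is a guess sized for one ASR + LLM + TTS round trip.
@spaces.GPU(duration=120)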
def process_audio(microphone, state, task="transcribe"):
    if microphone is None:
        return state, state, None
    # Force Whisper into transcription (or translation) mode for this call.
    asr_pipe.model.config.forced_decoder_ids = [
        [2, transcribe_token_id if task == "transcribe" else translate_token_id]
    ]
    text = asr_pipe(microphone)["text"]
    system_prompt = """You are a friendly and enthusiastic tutor for young children (ages 6-9).
You answer questions clearly and simply, using age-appropriate language.
You are also a little bit silly and like to make jokes."""
    # Vicuna v1.5 follows a USER:/ASSISTANT: chat template; ending the prompt
    # with "ASSISTANT:" cues the model to answer rather than continue the user.
    prompt = f"{system_prompt}\nUSER: {text}\nASSISTANT:"
    with torch.no_grad():
        vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to(vicuna_model.device)
        vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=128)
    # Decode only the newly generated tokens; string-replacing the prompt out
    # of the full decode is fragile when the tokenizer alters whitespace.
    new_tokens = vicuna_output[0][vicuna_input.input_ids.shape[-1]:]
    vicuna_response = vicuna_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    updated_state = state + "\n" + vicuna_response
    try:
        with torch.no_grad():
            inputs = tts_processor(vicuna_response, return_tensors="pt").to('cuda')
            output = tts_model.generate(**inputs, do_sample=True)
        # Gradio's numpy audio output expects a (sample_rate, waveform) tuple.
        waveform_np = output[0].cpu().numpy()
        audio_output = (tts_model.generation_config.sample_rate, waveform_np)
    except Exception as e:
        print(f"Error in speech synthesis: {e}")
        audio_output = None
    return updated_state, updated_state, audio_output
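# --- Gradio UI: microphone in; running transcript and synthesized audio out ---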
with gr.Blocks(title="Whisper, Vicuna, & Bark Demo") as demo:
    gr.Markdown("# Speech-to-Text-to-Speech Demo with Vicuna and Bark")
    gr.Markdown("Speak into your microphone, get a transcription, Vicuna will process it, and then you'll hear the result!")
    with gr.Tab("Transcribe & Synthesize"):
        mic_input = gr.Audio(sources="microphone", type="filepath", label="Speak Here")
        transcription_output = gr.Textbox(lines=5, label="Transcription and Vicuna Response")
        audio_output = gr.Audio(label="Synthesized Speech", type="numpy")
        transcription_state = gr.State(value="")
        mic_input.change(
            fn=process_audio,  # Call the combined ASR -> Vicuna -> TTS function
            inputs=[mic_input, transcription_state],
            outputs=[transcription_output, transcription_state, audio_output],
        )
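# share=False is fine here: the Space serves the app directly, so no tunnel link is needed.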
demo.launch(share=False)