# Source: vicuna-clip / app.py (Hugging Face Space by ford442)
# Last commit: 913ceff "Update app.py" (verified), 4.05 kB
import spaces
import torch
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import soundfile as sf
import numpy as np
from espnet2.bin.tts_inference import Text2Speech
import IPython.display as ipd
import os
from huggingface_hub import snapshot_download
# --- Whisper (ASR) Setup ---
# NOTE(review): this section used to be only a placeholder comment, which left
# `asr_pipe`, `transcribe_token_id` and `translate_token_id` undefined for
# transcribe_audio. The checkpoint below is an assumption — confirm it is the
# intended Whisper model.
ASR_MODEL_NAME = "openai/whisper-large-v2"
asr_device = "cuda" if torch.cuda.is_available() else "cpu"
asr_pipe = pipeline(
    task="automatic-speech-recognition",
    model=ASR_MODEL_NAME,
    chunk_length_s=30,  # long recordings are transcribed in 30 s windows
    device=asr_device,
)
# Resolve Whisper's task tokens by name rather than by their position in the
# special-token list, which differs between tokenizer versions.
transcribe_token_id = asr_pipe.tokenizer.convert_tokens_to_ids("<|transcribe|>")
translate_token_id = asr_pipe.tokenizer.convert_tokens_to_ids("<|translate|>")

# --- VITS (TTS) Setup ---
TTS_MODEL_NAME = "espnet/kan_bayashi_ljspeech_vits"
tts_device = "cuda" if torch.cuda.is_available() else "cpu"
# Download the ESPnet model repository into a local folder (created if missing)
# and remember where it landed.
model_dir = "vits_model"
os.makedirs(model_dir, exist_ok=True)
download_path = snapshot_download(repo_id=TTS_MODEL_NAME, local_dir=model_dir, local_dir_use_symlinks=False)
print(f"Downloaded ESPnet model to: {download_path}")  # Print the path!
# Build absolute paths to the config and checkpoint inside the downloaded experiment folder.
_vits_exp_dir = os.path.join(
    os.path.abspath(download_path),
    "exp",
    "tts_train_vits_raw_phn_tacotron_g2p_en_no_space",
)
config_path = os.path.join(_vits_exp_dir, "config.yaml")
model_path = os.path.join(_vits_exp_dir, "train.total_count.ave_10best.pth")
# Load the Text2Speech model from the downloaded files.
tts_model = Text2Speech(train_config=config_path, model_file=model_path, device=tts_device)
# --- Vicuna (LLM) Setup ---
VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5"
vicuna_device = "cuda" if torch.cuda.is_available() else "cpu"
vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
if vicuna_device == "cuda":
    # 8-bit weights are loaded through bitsandbytes, which requires a CUDA GPU.
    # The bare `load_in_8bit=` kwarg is deprecated in favour of a quantization config.
    vicuna_model = AutoModelForCausalLM.from_pretrained(
        VICUNA_MODEL_NAME,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        torch_dtype=torch.float16,
        device_map="auto",
    )
else:
    # CPU fallback: no bitsandbytes quantization, and full precision because
    # float16 support on CPU is limited.
    vicuna_model = AutoModelForCausalLM.from_pretrained(
        VICUNA_MODEL_NAME,
        torch_dtype=torch.float32,
    )
# --- ASR Function ---
VICUNA_SYSTEM_PROMPT = """You are a friendly and enthusiastic tutor for young children (ages 6-9).
You answer questions clearly and simply, using age-appropriate language.
You are also a little bit silly and like to make jokes."""


def build_vicuna_prompt(text, system_prompt=VICUNA_SYSTEM_PROMPT):
    """Wrap one user utterance in Vicuna's single-turn conversation template.

    Vicuna v1.1+ models expect ``"<system> USER: <message> ASSISTANT:"``; ending
    the prompt with ``ASSISTANT:`` cues the model to answer rather than to keep
    writing the user's turn.

    Args:
        text: The user's message.
        system_prompt: Instruction text placed before the conversation.

    Returns:
        The formatted prompt string.
    """
    return f"{system_prompt} USER: {text} ASSISTANT:"


def transcribe_audio(microphone, state, task="transcribe"):
    """Transcribe a recording, ask Vicuna to reply, and append the reply to the log.

    Args:
        microphone: Path to the recorded audio (the UI uses ``type="filepath"``),
            or ``None`` when there is no recording.
        state: Text accumulated so far.
        task: ``"transcribe"`` keeps the spoken language; any other value asks
            Whisper to translate.

    Returns:
        ``(updated_state, updated_state)`` where ``updated_state`` is ``state``
        followed by ``"\\n"`` and the model's reply. ``state`` is returned
        unchanged when there is no recording or nothing was transcribed.
    """
    if microphone is None:
        return state, state

    # Select the Whisper task per call instead of mutating the shared model
    # config, which is deprecated and would leak between queued requests.
    whisper_task = "transcribe" if task == "transcribe" else "translate"
    text = asr_pipe(microphone, generate_kwargs={"task": whisper_task})["text"].strip()
    if not text:
        # Nothing intelligible was heard; don't ask the LLM to answer silence.
        return state, state

    # --- VICUNA INTEGRATION ---
    prompt = build_vicuna_prompt(text)
    vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to(vicuna_model.device)
    with torch.no_grad():
        vicuna_output = vicuna_model.generate(
            **vicuna_input,
            max_new_tokens=128,
            pad_token_id=vicuna_tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens: stripping the prompt from the
    # decoded text is unreliable because a tokenizer round-trip can change
    # whitespace, which would leak the system prompt into the reply.
    prompt_length = vicuna_input["input_ids"].shape[-1]
    vicuna_response = vicuna_tokenizer.decode(
        vicuna_output[0, prompt_length:], skip_special_tokens=True
    ).strip()

    updated_state = state + "\n" + vicuna_response
    return updated_state, updated_state
# --- TTS Function (Using espnet2) ---
def synthesize_speech(text):
    """Convert text to speech with the ESPnet VITS model.

    Args:
        text: Text to speak.

    Returns:
        ``(sample_rate, waveform)``, with the waveform as a NumPy array, ready
        for a ``gr.Audio(type="numpy")`` output. Returns ``None`` (no audio)
        when ``text`` is empty/blank or synthesis fails; a ``(None, None)``
        tuple cannot be rendered by Gradio.
    """
    if not text or not text.strip():
        return None
    try:
        with torch.no_grad():
            output = tts_model(text)
        waveform_np = output["wav"].cpu().numpy()
        return (tts_model.fs, waveform_np)
    except Exception as err:  # best-effort: a TTS failure must not crash the UI
        print(f"TTS synthesis failed: {err!r}")
        return None
# --- Gradio Interface ---
# Gradio 4 renamed Audio's `source="..."` to `sources=[...]` and removed both
# `optional=` and `launch(enable_queue=...)`; use the spelling the installed
# major version understands.
_GRADIO_MAJOR = int(gr.__version__.split(".")[0])
_MIC_SOURCE_KWARGS = (
    {"sources": ["microphone"]} if _GRADIO_MAJOR >= 4 else {"source": "microphone"}
)


def _transcribe_and_track_reply(microphone, state):
    """Run transcribe_audio and also return only the reply it just appended.

    transcribe_audio returns the previous state followed by "\\n" and the reply,
    so the reply is whatever follows that prefix ("" when nothing was added).
    """
    updated_state, _ = transcribe_audio(microphone, state)
    prefix = state + "\n"
    reply = updated_state[len(prefix):] if updated_state.startswith(prefix) else ""
    return updated_state, updated_state, reply


def _speak_reply(reply):
    """Synthesize just the latest reply; return None (no audio) if there is nothing to play."""
    if not reply:
        return None
    audio = synthesize_speech(reply)
    # A (rate, None) tuple cannot be rendered by gr.Audio, so map it to "no audio".
    if audio is None or audio[1] is None:
        return None
    return audio


with gr.Blocks(title="Whisper, Vicuna, & VITS Demo") as demo:
    gr.Markdown("# Speech-to-Text-to-Speech Demo with Vicuna and VITS")
    gr.Markdown("Speak into your microphone, get a transcription, Vicuna will process it, and then you'll hear the result!")
    with gr.Tab("Transcribe & Synthesize"):
        mic_input = gr.Audio(type="filepath", label="Speak Here", **_MIC_SOURCE_KWARGS)
        transcription_output = gr.Textbox(lines=5, label="Transcription and Vicuna Response")
        audio_output = gr.Audio(label="Synthesized Speech", type="numpy")
    transcription_state = gr.State(value="")
    # Holds only the newest reply so TTS doesn't re-read the whole history each turn.
    latest_reply = gr.State(value="")
    mic_input.change(
        fn=_transcribe_and_track_reply,
        inputs=[mic_input, transcription_state],
        outputs=[transcription_output, transcription_state, latest_reply],
    ).then(
        fn=_speak_reply,
        inputs=latest_reply,
        outputs=audio_output,
    )

demo.queue()
demo.launch(share=False)