import spaces
import torch
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import soundfile as sf
import numpy as np
from espnet2.bin.tts_inference import Text2Speech
import os
from huggingface_hub import snapshot_download
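
# Assumed dependencies for this Space (a guess based on the imports above,
# Gradio 3.x era; not a confirmed requirements.txt):
#   pip install "gradio<4" transformers accelerate bitsandbytes \
#       espnet soundfile huggingface_hub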
# ... (Whisper ASR setup remains the same as before)
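# --- Whisper (ASR) Setup ---
# Hedged reconstruction of the elided block above: the checkpoint name
# ("openai/whisper-small") and the task-token lookup are assumptions, not the
# Space's confirmed configuration. The names asr_pipe, transcribe_token_id,
# and translate_token_id are required by transcribe_audio() below.
ASR_MODEL_NAME = "openai/whisper-small"
asr_device = 0 if torch.cuda.is_available() else -1
asr_pipe = pipeline(
    task="automatic-speech-recognition",
    model=ASR_MODEL_NAME,
    chunk_length_s=30,
    device=asr_device,
)
# Whisper's tokenizer defines dedicated task tokens for transcription vs.
# translation; transcribe_audio() forces one of them via forced_decoder_ids.
transcribe_token_id = asr_pipe.tokenizer.convert_tokens_to_ids("<|transcribe|>")
translate_token_id = asr_pipe.tokenizer.convert_tokens_to_ids("<|translate|>")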
# --- VITS (TTS) Setup ---
TTS_MODEL_NAME = "espnet/kan_bayashi_ljspeech_vits"
tts_device = "cuda" if torch.cuda.is_available() else "cpu"

# Download the ESPnet model files and capture the local download path.
model_dir = "vits_model"
os.makedirs(model_dir, exist_ok=True)
download_path = snapshot_download(repo_id=TTS_MODEL_NAME, local_dir=model_dir, local_dir_use_symlinks=False)
print(f"Downloaded ESPnet model to: {download_path}")

# Construct *absolute* paths to the config and model files.
config_path = os.path.join(download_path, "exp/tts_train_vits_raw_phn_tacotron_g2p_en_no_space/config.yaml")
model_path = os.path.join(download_path, "exp/tts_train_vits_raw_phn_tacotron_g2p_en_no_space/train.total_count.ave_10best.pth")

# Load the Text2Speech model from the downloaded files.
tts_model = Text2Speech(train_config=config_path, model_file=model_path, device=tts_device)
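
# Optional startup smoke test (an addition, not in the original Space):
# synthesizing one short utterance here surfaces checkpoint/config problems
# at boot rather than on the first user request. Remove to speed up startup.
with torch.no_grad():
    _tts_check = tts_model("Text to speech is ready.")
print(f"TTS smoke test passed: fs={tts_model.fs}, samples={_tts_check['wav'].shape[0]}")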
# --- Vicuna (LLM) Setup ---
VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5"
vicuna_device = "cuda" if torch.cuda.is_available() else "cpu"
vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
vicuna_model = AutoModelForCausalLM.from_pretrained(
    VICUNA_MODEL_NAME,
    load_in_8bit=True,  # requires bitsandbytes; roughly halves memory vs. fp16
    torch_dtype=torch.float16,
    device_map="auto",
)
# --- ASR Function ---
def transcribe_audio(microphone, state, task="transcribe"):
    """Transcribe microphone audio with Whisper, then answer with Vicuna."""
    if microphone is None:
        return state, state

    # Force Whisper into the requested task (transcribe vs. translate).
    asr_pipe.model.config.forced_decoder_ids = [
        [2, transcribe_token_id if task == "transcribe" else translate_token_id]
    ]
    text = asr_pipe(microphone)["text"]

    # --- VICUNA INTEGRATION ---
    system_prompt = """You are a friendly and enthusiastic tutor for young children (ages 6-9).
You answer questions clearly and simply, using age-appropriate language.
You are also a little bit silly and like to make jokes."""
    prompt = f"{system_prompt}\nUser: {text}"

    with torch.no_grad():
        vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to(vicuna_device)
        vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=128)
        # Decode only the newly generated tokens; string-replacing the prompt
        # out of the full decode is fragile because detokenization can alter it.
        new_tokens = vicuna_output[0][vicuna_input["input_ids"].shape[1]:]
        vicuna_response = vicuna_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    updated_state = state + "\n" + vicuna_response
    return updated_state, updated_state
# --- TTS Function (Using espnet2) ---
def synthesize_speech(text):
    """Synthesize `text` with the VITS model; returns (sample_rate, waveform)."""
    try:
        with torch.no_grad():
            output = tts_model(text)
        waveform_np = output["wav"].cpu().numpy()
        return (tts_model.fs, waveform_np)
    except Exception as e:
        print(f"TTS synthesis failed: {e}")
        return None  # Gradio treats None as "no audio" instead of raising
# --- Gradio Interface ---
with gr.Blocks(title="Whisper, Vicuna, & VITS Demo") as demo:
    gr.Markdown("# Speech-to-Text-to-Speech Demo with Vicuna and VITS")
    gr.Markdown("Speak into your microphone, get a transcription, Vicuna will process it, and then you'll hear the result!")

    with gr.Tab("Transcribe & Synthesize"):
        # `optional=True` is deprecated; transcribe_audio() already handles None.
        mic_input = gr.Audio(source="microphone", type="filepath", label="Speak Here")
        transcription_output = gr.Textbox(lines=5, label="Transcription and Vicuna Response")
        audio_output = gr.Audio(label="Synthesized Speech", type="numpy")
        transcription_state = gr.State(value="")

        mic_input.change(
            fn=transcribe_audio,
            inputs=[mic_input, transcription_state],
            outputs=[transcription_output, transcription_state],
        ).then(
            fn=synthesize_speech,
            inputs=transcription_output,
            outputs=audio_output,
        )

# `launch(enable_queue=True)` is deprecated in Gradio 3.x; queue explicitly.
demo.queue().launch(share=False)
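
# Usage note (assumption): on Hugging Face Spaces this file runs automatically
# as app.py; locally, `python app.py` starts the server and prints a local URL.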