import spaces
import torch
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import soundfile as sf
from espnet2.bin.tts_inference import Text2Speech
import os
from huggingface_hub import snapshot_download

# ... (Whisper setup remains the same; Vicuna is configured below)
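# A minimal reconstruction of the elided Whisper block, since asr_pipe,
# transcribe_token_id, and translate_token_id are referenced further down; the
# model id "openai/whisper-base" is an assumption, not taken from this file.
ASR_MODEL_NAME = "openai/whisper-base"
asr_device = 0 if torch.cuda.is_available() else -1  # pipeline takes a device index
asr_pipe = pipeline("automatic-speech-recognition", model=ASR_MODEL_NAME, device=asr_device)
transcribe_token_id = asr_pipe.tokenizer.convert_tokens_to_ids("<|transcribe|>")
translate_token_id = asr_pipe.tokenizer.convert_tokens_to_ids("<|translate|>")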

# --- VITS (TTS) Setup ---
TTS_MODEL_NAME = "espnet/kan_bayashi_ljspeech_vits"
tts_device = "cuda" if torch.cuda.is_available() else "cpu"

# Download the ESPnet model files and get the download path
model_dir = "vits_model"
os.makedirs(model_dir, exist_ok=True)

download_path = snapshot_download(repo_id=TTS_MODEL_NAME, local_dir=model_dir, local_dir_use_symlinks=False)
print(f"Downloaded ESPnet model to: {download_path}") # Print the path!

# Construct the full paths to the config and model files inside the snapshot.
config_path = os.path.join(download_path, "exp/tts_train_vits_raw_phn_tacotron_g2p_en_no_space/config.yaml")
model_path = os.path.join(download_path, "exp/tts_train_vits_raw_phn_tacotron_g2p_en_no_space/train.total_count.ave_10best.pth")

# Load the Text2Speech model using the downloaded files and absolute paths
tts_model = Text2Speech(train_config=config_path, model_file=model_path, device=tts_device)


# --- Vicuna (LLM) Setup ---
VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5"
vicuna_device = "cuda" if torch.cuda.is_available() else "cpu"
vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
vicuna_model = AutoModelForCausalLM.from_pretrained(
    VICUNA_MODEL_NAME,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # requires bitsandbytes
    torch_dtype=torch.float16,
    device_map="auto",
)
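
# Hedged alternative (assumption, not in the original): if bitsandbytes or a
# CUDA GPU is unavailable, a plain full-precision CPU load would look like:
#   vicuna_model = AutoModelForCausalLM.from_pretrained(
#       VICUNA_MODEL_NAME, torch_dtype=torch.float32, device_map={"": "cpu"}
#   )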

# --- ASR Function ---
@spaces.GPU  # ZeroGPU: request a GPU for the duration of this call
def transcribe_audio(microphone, state, task="transcribe"):
    if microphone is None:
        return state, state
    # Slot 2 of Whisper's decoder prompt carries the task token
    # (transcribe vs. translate); the language slot is left to auto-detection.
    asr_pipe.model.config.forced_decoder_ids = [
        [2, transcribe_token_id if task == "transcribe" else translate_token_id]
    ]
    text = asr_pipe(microphone)["text"]

    # --- VICUNA INTEGRATION ---
    system_prompt = (
        "You are a friendly and enthusiastic tutor for young children (ages 6-9). "
        "You answer questions clearly and simply, using age-appropriate language. "
        "You are also a little bit silly and like to make jokes."
    )

    # Vicuna v1.5 is trained on the "USER: ... ASSISTANT:" turn format.
    prompt = f"{system_prompt}\nUSER: {text}\nASSISTANT:"

    with torch.no_grad():
        vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to(vicuna_device)
        vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=128)
        # Decode only the newly generated tokens; decoding the full sequence and
        # string-replacing the prompt is fragile because decode can re-space text.
        new_tokens = vicuna_output[0][vicuna_input["input_ids"].shape[1]:]
        vicuna_response = vicuna_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    updated_state = state + "\n" + vicuna_response
    return updated_state, updated_state

# --- TTS Function (Using espnet2) ---
@spaces.GPU  # ZeroGPU: request a GPU for the duration of this call
def synthesize_speech(text):
    try:
        with torch.no_grad():
            output = tts_model(text)
        waveform_np = output["wav"].cpu().numpy()
        # Gradio's numpy audio format is a (sample_rate, waveform) tuple.
        return (tts_model.fs, waveform_np)
    except Exception as e:
        print(f"TTS error: {e}")
        # Return None, not (None, None): gr.Audio accepts None as "no audio"
        # but would raise on a malformed tuple.
        return None
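
# Hedged usage sketch (not wired into the UI): persists a synthesis result to
# disk via the soundfile import above; the helper name and default path are
# illustrative only.
def save_speech_to_wav(text, path="output.wav"):
    result = synthesize_speech(text)
    if result is not None:
        rate, waveform = result
        sf.write(path, waveform, rate)  # soundfile expects (file, data, samplerate)
    return path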

# --- Gradio Interface ---
with gr.Blocks(title="Whisper, Vicuna, & VITS Demo") as demo:
    gr.Markdown("# Speech-to-Text-to-Speech Demo with Vicuna and VITS")
    gr.Markdown("Speak into your microphone, get a transcription, Vicuna will process it, and then you'll hear the result!")

    with gr.Tab("Transcribe & Synthesize"):
        # Gradio 4.x: `sources` replaces `source`, and the old `optional` flag is
        # gone (an empty input simply arrives as None).
        mic_input = gr.Audio(sources=["microphone"], type="filepath", label="Speak Here")
        transcription_output = gr.Textbox(lines=5, label="Transcription and Vicuna Response")
        audio_output = gr.Audio(label="Synthesized Speech", type="numpy")
        transcription_state = gr.State(value="")

        # Note: this chain re-synthesizes the entire accumulated transcript, since
        # transcription_output holds the whole conversation so far.
        mic_input.change(
            fn=transcribe_audio,
            inputs=[mic_input, transcription_state],
            outputs=[transcription_output, transcription_state]
        ).then(
            fn=synthesize_speech,
            inputs=transcription_output,
            outputs=audio_output
        )

# queue() replaces the removed `enable_queue` launch flag in current Gradio.
demo.queue().launch(share=False)