import torch
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
from fairseq.models.text_to_speech.hub_interface import TTSHubInterface

# --- Whisper (ASR) Setup ---
ASR_MODEL_NAME = "openai/whisper-large-v2"
asr_device = "cuda" if torch.cuda.is_available() else "cpu"
asr_pipe = pipeline(
    task="automatic-speech-recognition",
    model=ASR_MODEL_NAME,
    chunk_length_s=30,
    device=asr_device,
)
all_special_ids = asr_pipe.tokenizer.all_special_ids
transcribe_token_id = all_special_ids[-5]
translate_token_id = all_special_ids[-6]
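# Note: this index-based lookup assumes the ordering of Whisper's special tokens,
# where <|transcribe|> and <|translate|> sit near the end of all_special_ids.
# An equivalent, more explicit lookup would be
# asr_pipe.tokenizer.convert_tokens_to_ids("<|transcribe|>").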

# --- FastSpeech2 (TTS) Setup - Using fairseq ---
TTS_MODEL_NAME = "facebook/fastspeech2-en-ljspeech"
tts_device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the fairseq model, config, and task.
tts_models, tts_cfg, tts_task = load_model_ensemble_and_task_from_hf_hub(
    TTS_MODEL_NAME,
    arg_overrides={"vocoder": "hifigan", "fp16": False}
)
tts_model = tts_models[0]
TTSHubInterface.update_cfg_with_data_cfg(tts_cfg, tts_task.data_cfg)
# fairseq's text-to-speech task expects a list of models here.
tts_generator = tts_task.build_generator([tts_model], tts_cfg)

# Move the fairseq model to the correct device.
tts_model.to(tts_device)
tts_model.eval() # Put the model in evaluation mode
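# Assumes fairseq is installed (pip install fairseq); the FastSpeech2 checkpoint and
# vocoder config are downloaded from the Hugging Face Hub on first run.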


# --- Vicuna (LLM) Setup ---
VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5"  # Or your preferred Vicuna
vicuna_device = "cuda" if torch.cuda.is_available() else "cpu"
vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
vicuna_model = AutoModelForCausalLM.from_pretrained(
    VICUNA_MODEL_NAME,
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
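# load_in_8bit requires the bitsandbytes and accelerate packages and a CUDA GPU;
# on a CPU-only machine, drop load_in_8bit and device_map and load in full precision.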

# --- ASR Function ---
def transcribe_audio(microphone, state, task="transcribe"):
    if microphone is None:
        return state, state
    asr_pipe.model.config.forced_decoder_ids = [
        [2, transcribe_token_id if task == "transcribe" else translate_token_id]
    ]
    text = asr_pipe(microphone)["text"]

    # --- VICUNA INTEGRATION ---
    system_prompt = """You are a friendly and enthusiastic tutor for young children (ages 6-9).
      You answer questions clearly and simply, using age-appropriate language.
      You are also a little bit silly and like to make jokes."""

    # Vicuna v1.5 expects a "USER: ... ASSISTANT:" conversation format.
    prompt = f"{system_prompt}\nUSER: {text}\nASSISTANT:"

    with torch.no_grad():
        vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to(vicuna_device)
        vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=128)
        # Decode only the newly generated tokens so the prompt is not echoed back.
        new_tokens = vicuna_output[0][vicuna_input["input_ids"].shape[1]:]
        vicuna_response = vicuna_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    updated_state = state + "\n" + vicuna_response
    return updated_state, updated_state

# --- TTS Function (Modified for fairseq) ---
def synthesize_speech(text):
    try:
        sample = TTSHubInterface.get_model_input(tts_task, text)

        # Move only tensor inputs to the TTS device; some entries
        # (e.g. prev_output_tokens) may be None and cannot be moved.
        sample["net_input"] = {
            k: v.to(tts_device) if torch.is_tensor(v) else v
            for k, v in sample["net_input"].items()
        }

        wav, rate = TTSHubInterface.get_prediction(tts_task, tts_model, tts_generator, sample)
        wav_numpy = wav.cpu().numpy() # fairseq returns a tensor, not a numpy array

        return (rate, wav_numpy)  # Return rate and NumPy array

    except Exception as e:
        print(f"TTS synthesis failed: {e}")
        # Returning None leaves the audio output empty instead of passing an invalid tuple.
        return None
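# Hypothetical local sanity check (not part of the Gradio app):
#   rate, wav = synthesize_speech("Hello! Nice to meet you.")
#   # wav is a float NumPy array sampled at `rate` Hz, matching gr.Audio(type="numpy").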

# --- Gradio Interface ---
with gr.Blocks(title="Whisper, Vicuna, & FastSpeech2 Demo") as demo:
    gr.Markdown("# Speech-to-Text-to-Speech Demo with Vicuna")
    gr.Markdown("Speak into your microphone, get a transcription, Vicuna will process it, and then you'll hear the result!")

    with gr.Tab("Transcribe & Synthesize"):
        # Gradio 3.x syntax; in Gradio 4+ use sources=["microphone"] instead of source.
        mic_input = gr.Audio(source="microphone", type="filepath", label="Speak Here")
        transcription_output = gr.Textbox(lines=5, label="Transcription and Vicuna Response")
        audio_output = gr.Audio(label="Synthesized Speech", type="numpy")
        transcription_state = gr.State(value="")

        mic_input.change(
            fn=transcribe_audio,
            inputs=[mic_input, transcription_state],
            outputs=[transcription_output, transcription_state]
        ).then(
            fn=synthesize_speech,
            inputs=transcription_output,
            outputs=audio_output
        )

# queue() replaces the deprecated enable_queue launch argument.
demo.queue()
demo.launch(share=False)