import torch
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
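# Requires: torch, transformers, gradio, and fairseq (plus bitsandbytes for 8-bit loading)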
from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
from fairseq.models.text_to_speech.hub_interface import TTSHubInterface
# --- Whisper (ASR) Setup ---
ASR_MODEL_NAME = "openai/whisper-large-v2"
asr_device = "cuda" if torch.cuda.is_available() else "cpu"
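# chunk_length_s=30 lets the pipeline handle audio longer than Whisper's
# 30-second context window by chunking and stitching segments.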
asr_pipe = pipeline(
    task="automatic-speech-recognition",
    model=ASR_MODEL_NAME,
    chunk_length_s=30,
    device=asr_device,
)
# Look up Whisper's task tokens by name; this is more robust than indexing
# into all_special_ids, whose ordering can vary across tokenizer versions.
transcribe_token_id = asr_pipe.tokenizer.convert_tokens_to_ids("<|transcribe|>")
translate_token_id = asr_pipe.tokenizer.convert_tokens_to_ids("<|translate|>")
# --- FastSpeech2 (TTS) Setup - Using fairseq ---
TTS_MODEL_NAME = "facebook/fastspeech2-en-ljspeech"
tts_device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the fairseq model, config, and task.
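# arg_overrides selects the HiFi-GAN vocoder; fp16 is disabled so inference
# also works on CPU.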
tts_models, tts_cfg, tts_task = load_model_ensemble_and_task_from_hf_hub(
    TTS_MODEL_NAME,
    arg_overrides={"vocoder": "hifigan", "fp16": False},
)
tts_model = tts_models[0]
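# Merge the task's data config (audio/vocoder settings) into the model config
# before building the generator.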
TTSHubInterface.update_cfg_with_data_cfg(tts_cfg, tts_task.data_cfg)
tts_generator = tts_task.build_generator(tts_model, tts_cfg)
# Move the fairseq model to the correct device.
tts_model.to(tts_device)
tts_model.eval() # Put the model in evaluation mode
# --- Vicuna (LLM) Setup ---
VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5" # Or your preferred Vicuna
vicuna_device = "cuda" if torch.cuda.is_available() else "cpu"
vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
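# load_in_8bit=True requires the bitsandbytes package; newer transformers
# versions express this as quantization_config=BitsAndBytesConfig(load_in_8bit=True).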
vicuna_model = AutoModelForCausalLM.from_pretrained(
    VICUNA_MODEL_NAME,
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
# --- ASR Function ---
def transcribe_audio(microphone, state, task="transcribe"):
    if microphone is None:
        return state, state
    # Force Whisper's task token (slot 2 of the decoder prompt) so the model
    # either transcribes in the source language or translates to English.
    asr_pipe.model.config.forced_decoder_ids = [
        [2, transcribe_token_id if task == "transcribe" else translate_token_id]
    ]
    text = asr_pipe(microphone)["text"]
    # --- Vicuna integration ---
    system_prompt = """You are a friendly and enthusiastic tutor for young children (ages 6-9).
You answer questions clearly and simply, using age-appropriate language.
You are also a little bit silly and like to make jokes."""
    # Note: Vicuna v1.5 was tuned on a "USER: ... ASSISTANT:" format; a plain
    # "User:" turn still works, but appending "ASSISTANT:" may improve replies.
    prompt = f"{system_prompt}\nUser: {text}"
    with torch.no_grad():
        vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to(vicuna_device)
        vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=128)
    # Strip the echoed prompt so only the newly generated reply remains.
    vicuna_response = vicuna_tokenizer.decode(vicuna_output[0], skip_special_tokens=True)
    vicuna_response = vicuna_response.replace(prompt, "").strip()
    updated_state = state + "\n" + vicuna_response
    return updated_state, updated_state
# --- TTS Function (Modified for fairseq) ---
def synthesize_speech(text):
    try:
        sample = TTSHubInterface.get_model_input(tts_task, text)
        # Move only the tensors in net_input to the target device; the sample
        # also contains None entries (e.g. prev_output_tokens) that a blanket
        # .cuda() call would choke on.
        sample["net_input"] = {
            k: v.to(tts_device) if torch.is_tensor(v) else v
            for k, v in sample["net_input"].items()
        }
        wav, rate = TTSHubInterface.get_prediction(tts_task, tts_model, tts_generator, sample)
        wav_numpy = wav.cpu().numpy()  # fairseq returns a tensor; gr.Audio wants NumPy
        return (rate, wav_numpy)  # (sample_rate, waveform) tuple for gr.Audio
    except Exception as e:
        print(f"TTS error: {e}")
        return None  # Clears the audio component instead of returning a malformed tuple
# --- Gradio Interface ---
with gr.Blocks(title="Whisper, Vicuna, & FastSpeech2 Demo") as demo:
    gr.Markdown("# Speech-to-Text-to-Speech Demo with Vicuna")
    gr.Markdown("Speak into your microphone, get a transcription, Vicuna will respond, and then you'll hear the result!")
    with gr.Tab("Transcribe & Synthesize"):
        mic_input = gr.Audio(source="microphone", type="filepath", label="Speak Here")
        transcription_output = gr.Textbox(lines=5, label="Transcription and Vicuna Response")
        audio_output = gr.Audio(label="Synthesized Speech", type="numpy")
        transcription_state = gr.State(value="")
        # Chain the steps: transcribe and generate a reply first, then synthesize.
        # Note: synthesize_speech receives the full textbox contents (the whole
        # conversation so far), not just the latest reply.
        mic_input.change(
            fn=transcribe_audio,
            inputs=[mic_input, transcription_state],
            outputs=[transcription_output, transcription_state],
        ).then(
            fn=synthesize_speech,
            inputs=transcription_output,
            outputs=audio_output,
        )

demo.queue()  # launch(enable_queue=True) is deprecated on recent Gradio 3.x
demo.launch(share=False)