ford442 committed
Commit 29b4682 · verified · 1 Parent(s): ebed5e1

Update app.py

Files changed (1):
  app.py  +113 -77
app.py CHANGED
@@ -1,35 +1,18 @@
-import spaces
-import torch
 import gradio as gr
-from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
-import soundfile as sf
-import numpy as np
-from espnet2.bin.tts_inference import Text2Speech
-import IPython.display as ipd
+import torchaudio
+import torch
+import torch.nn.functional as F
+from speechbrain.inference.speaker import EncoderClassifier
+from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, AutoTokenizer, AutoModelForCausalLM
+import noisereduce as nr
+import librosa
 import os
-from huggingface_hub import snapshot_download
-
-# ... (Whisper and Vicuna setup remain the same)
-# --- VITS (TTS) Setup ---
-TTS_MODEL_NAME = "espnet/speechlm_tts_v1"  # Updated Model Name
-tts_device = "cuda" if torch.cuda.is_available() else "cpu"
-
-model_dir = "speechlm_model"  # Updated directory name
-
-if os.path.exists(model_dir):
-    shutil.rmtree(model_dir)
-
-os.makedirs(model_dir)
-download_path = snapshot_download(repo_id=TTS_MODEL_NAME, local_dir=model_dir, local_dir_use_symlinks=False)
-print(f"Downloaded ESPnet model to: {download_path}")
-
-# --- KEY CHANGE: Adjust paths for speechlm_tts_v1 ---
-config_path = os.path.join(download_path, "exp/speechlm_tts_v1/config.yaml")
-model_path = os.path.join(download_path, "exp/speechlm_tts_v1/model.pth")
-
-tts_model = Text2Speech(train_config=config_path, model_file=model_path, device=tts_device)
-
-# --- Vicuna (LLM) Setup ---
+import shutil
+
+# --- Speaker Embedding Model ---
+classifier = EncoderClassifier.from_hparams(source="speechbrain/spkrec-xvect-voxceleb", savedir="pretrained_models/spkrec-xvect-voxceleb")
+
+# --- Vicuna Setup ---
 VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5"
 vicuna_device = "cuda" if torch.cuda.is_available() else "cpu"
 vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
@@ -40,62 +23,115 @@ vicuna_model = AutoModelForCausalLM.from_pretrained(
     device_map="auto",
 )
 
-# --- ASR Function ---
-def transcribe_audio(microphone, state, task="transcribe"):
-    if microphone is None:
-        return state, state
-    asr_pipe.model.config.forced_decoder_ids = [
-        [2, transcribe_token_id if task == "transcribe" else translate_token_id]
-    ]
-    text = asr_pipe(microphone)["text"]
-
-    # --- VICUNA INTEGRATION ---
-    system_prompt = """You are a friendly and enthusiastic tutor for young children (ages 6-9).
-    You answer questions clearly and simply, using age-appropriate language.
-    You are also a little bit silly and like to make jokes."""
-
-    prompt = f"{system_prompt}\nUser: {text}"
-
-    with torch.no_grad():
-        vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to(vicuna_device)
-        vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=128)
-        vicuna_response = vicuna_tokenizer.decode(vicuna_output[0], skip_special_tokens=True)
-        vicuna_response = vicuna_response.replace(prompt, "").strip()
-
-    updated_state = state + "\n" + vicuna_response
-    return updated_state, updated_state
+# --- Audio Processing Functions ---
+def f2embed(wav_file, classifier, size_embed):
+    signal, fs = stereo_to_mono(wav_file)
+    if signal is None:
+        return None
+    if fs != 16000:
+        signal, fs = resample_to_16000(signal, fs)
+        if signal is None:
+            return None
+    assert fs == 16000, fs
+    with torch.no_grad():
+        embeddings = classifier.encode_batch(signal)
+        embeddings = F.normalize(embeddings, dim=2)
+        embeddings = embeddings.squeeze().cpu().numpy()
+    assert embeddings.shape == (size_embed,), embeddings.shape
+    return embeddings
+
+def stereo_to_mono(wav_file):
+    try:
+        signal, fs = torchaudio.load(wav_file)
+        signal_np = signal.numpy()
+        if signal_np.shape[0] == 2:  # two channels: downmix to mono
+            signal_mono = librosa.to_mono(signal_np)
+            signal_mono = torch.from_numpy(signal_mono).unsqueeze(0)
+        else:
+            signal_mono = signal
+        return signal_mono, fs
+    except Exception as e:
+        print(f"Error in stereo_to_mono: {e}")
+        return None, None
+
+def resample_to_16000(signal, original_sr):
+    try:
+        signal_np = signal.numpy().flatten()
+        signal_resampled = librosa.resample(signal_np, orig_sr=original_sr, target_sr=16000)
+        signal_resampled = torch.from_numpy(signal_resampled).unsqueeze(0)
+        return signal_resampled, 16000
+    except Exception as e:
+        print(f"Error in resample_to_16000: {e}")
+        return None, None
+
+def reduce_noise(speech, noise_reduction_amount=0.5):
+    try:
+        denoised_speech = nr.reduce_noise(y=speech, sr=16000, prop_decrease=noise_reduction_amount)
+        return denoised_speech
+    except Exception as e:
+        print(f"Error in reduce_noise: {e}")
+        return speech
 
-# --- TTS Function (Using espnet2) ---
-def synthesize_speech(text):
+def process_audio(wav_file, text):
     try:
+        # --- Vicuna Text Processing ---
+        system_prompt = """You are a helpful assistant. Refine or expand the user's text as needed before it is converted to speech. You can correct grammar, add details, or make the text sound more natural."""
+        prompt = f"{system_prompt}\nUser: {text}"
+
         with torch.no_grad():
-            output = tts_model(text)
-            waveform_np = output["wav"].cpu().numpy()
-            return (tts_model.fs, waveform_np)
+            vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to(vicuna_device)
+            vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=256)
+            vicuna_response = vicuna_tokenizer.decode(vicuna_output[0], skip_special_tokens=True)  # decode the first (and only) sequence
+            vicuna_processed_text = vicuna_response.replace(prompt, "").strip()
+
+        print(f"Vicuna processed text: {vicuna_processed_text}")
+
+        # --- Speaker Embedding Extraction ---
+        speaker_embeddings = f2embed(wav_file, classifier, 512)
+        if speaker_embeddings is None:
+            return None, "Error in speaker embedding extraction"
+        embeddings = torch.tensor(speaker_embeddings).unsqueeze(0)
+
+        # --- SpeechT5 TTS with Vicuna's output ---
+        processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
+        model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
+        inputs = processor(text=vicuna_processed_text, return_tensors="pt")
+        inputs.update({"speaker_embeddings": embeddings})
+        vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
+        speech = model.generate_speech(inputs["input_ids"], speaker_embeddings=inputs["speaker_embeddings"], vocoder=vocoder)
+
+        # --- Noise Reduction ---
+        speech_denoised = reduce_noise(speech.cpu().numpy())  # noisereduce expects a NumPy array
+        return speech_denoised, 16000
 
     except Exception as e:
-        print(e)
-        return (None, None)
+        print(f"Error in process_audio: {e}")
+        return None, f"Error in audio processing: {e}"
+
 
 # --- Gradio Interface ---
-with gr.Blocks(title="Whisper, Vicuna, & VITS Demo") as demo:
-    gr.Markdown("# Speech-to-Text-to-Speech Demo with Vicuna and VITS")
-    gr.Markdown("Speak into your microphone, get a transcription, Vicuna will process it, and then you'll hear the result!")
-
-    with gr.Tab("Transcribe & Synthesize"):
-        mic_input = gr.Audio(source="microphone", type="filepath", optional=True, label="Speak Here")
-        transcription_output = gr.Textbox(lines=5, label="Transcription and Vicuna Response")
-        audio_output = gr.Audio(label="Synthesized Speech", type="numpy")
-        transcription_state = gr.State(value="")
-
-        mic_input.change(
-            fn=transcribe_audio,
-            inputs=[mic_input, transcription_state],
-            outputs=[transcription_output, transcription_state]
-        ).then(
-            fn=synthesize_speech,
-            inputs=transcription_output,
-            outputs=audio_output
-        )
-
-demo.launch(enable_queue=True, share=False)
+def gradio_interface(wav_file, text):
+    try:
+        if wav_file is None:
+            return "Error: Please upload an audio file."
+        if not text:
+            return "Error: Please enter text to synthesize."
+
+        processed_audio, rate = process_audio(wav_file, text)
+        if processed_audio is None:
+            return "Error occurred during processing. Check the console for details."
+
+        return (rate, processed_audio)
+    except Exception as e:
+        print(f"Error in gradio_interface: {e}")
+        return f"An unexpected error occurred: {e}"
+
+gr_interface = gr.Interface(
+    fn=gradio_interface,
+    inputs=[gr.Audio(type="filepath"), gr.Textbox(lines=2, placeholder="Enter text here...")],
+    outputs=gr.Audio(type="numpy"),
+    title="Text-to-Speech with Speaker Embeddings and Vicuna",
+    description="Upload a speaker audio file and enter text to convert the text to speech using the speaker's voice, enhanced by Vicuna.",
+)
+
+gr_interface.launch()
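For context, this commit swaps the ESPnet VITS path for SpeechT5 voice cloning: an x-vector speaker embedding conditions SpeechT5, and HiFi-GAN decodes the spectrogram. Below is a minimal, self-contained sketch of that synthesis path, outside Gradio, Vicuna, and the noise-reduction step; the file name `speaker.wav` and the sample text are placeholders, not part of the commit.

```python
import torch
import torch.nn.functional as F
import torchaudio
import soundfile as sf
from speechbrain.inference.speaker import EncoderClassifier
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan

# 1. Extract a 512-dim x-vector speaker embedding from a 16 kHz mono reference clip.
classifier = EncoderClassifier.from_hparams(
    source="speechbrain/spkrec-xvect-voxceleb",
    savedir="pretrained_models/spkrec-xvect-voxceleb",
)
wav, sr = torchaudio.load("speaker.wav")          # placeholder reference recording
wav = wav.mean(dim=0, keepdim=True)               # downmix to mono
if sr != 16000:
    wav = torchaudio.functional.resample(wav, sr, 16000)
with torch.no_grad():
    emb = classifier.encode_batch(wav)            # (1, 1, 512)
    emb = F.normalize(emb, dim=2).squeeze(1)      # (1, 512), the shape generate_speech expects

# 2. Synthesize with SpeechT5 conditioned on that embedding, decoded by HiFi-GAN.
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
inputs = processor(text="Hello from the cloned voice.", return_tensors="pt")
with torch.no_grad():
    # Returns a 1-D float waveform tensor at 16 kHz.
    speech = model.generate_speech(inputs["input_ids"], emb, vocoder=vocoder)

sf.write("out.wav", speech.numpy(), 16000)
```

The app builds on this same path: it first rewrites the user's text with Vicuna, feeds the rewritten text to the processor, and runs the generated waveform through noisereduce before returning it to the Gradio Audio component.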