ford442 committed
Commit 03d2efe · verified · Parent: 31c2346

Update app.py

Files changed (1)
  1. app.py +19 -46
app.py CHANGED
@@ -1,13 +1,8 @@
 import torch
 import gradio as gr
-from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
+from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoModelForTextToSpeech, AutoProcessor
 import soundfile as sf
 import numpy as np
-import fairseq
-import IPython.display as ipd
-import os  # Import the 'os' module
-
-commit_hash = "8798153927c22132778bef7b507d389474fa3589"  # Example - find a suitable one!
 
 # --- Whisper (ASR) Setup ---
 ASR_MODEL_NAME = "openai/whisper-large-v2"
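
The Whisper pipeline construction itself sits between these hunks and is unchanged by the commit. A minimal sketch of what that setup presumably looks like with the stock transformers pipeline API (an illustration, not the commit's exact code); it also shows a sturdier way to get the task token ids than the all_special_ids[-5] / all_special_ids[-6] indexing visible in the next hunk, which can break across tokenizer versions:

import torch
from transformers import pipeline

ASR_MODEL_NAME = "openai/whisper-large-v2"
asr_device = 0 if torch.cuda.is_available() else -1  # pipeline device index

# Sketch, assuming default chunking/batching; the real app.py may differ.
asr_pipe = pipeline(
    task="automatic-speech-recognition",
    model=ASR_MODEL_NAME,
    device=asr_device,
)

# Look the special tokens up by name instead of by position in all_special_ids.
transcribe_token_id = asr_pipe.tokenizer.convert_tokens_to_ids("<|transcribe|>")
translate_token_id = asr_pipe.tokenizer.convert_tokens_to_ids("<|translate|>")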
@@ -22,27 +17,16 @@ all_special_ids = asr_pipe.tokenizer.all_special_ids
 transcribe_token_id = all_special_ids[-5]
 translate_token_id = all_special_ids[-6]
 
-# --- FastSpeech2 (TTS) Setup - Using fairseq 0.10.2 ---
-TTS_MODEL_NAME = "facebook/fastspeech2-en-ljspeech"
+# --- VITS (TTS) Setup - Using transformers ---
+TTS_MODEL_NAME = "espnet/kan_bayashi_ljspeech_vits"  # Changed to VITS model
+tts_processor = AutoProcessor.from_pretrained(TTS_MODEL_NAME)
+tts_model = AutoModelForTextToSpeech.from_pretrained(TTS_MODEL_NAME)
 tts_device = "cuda" if torch.cuda.is_available() else "cpu"
+tts_model = tts_model.to(tts_device)
 
-# Download the model files if they don't exist
-if not os.path.exists("fastspeech2_model"):
-    os.makedirs("fastspeech2_model")
-    print("Downloading FastSpeech2 model...")
-    os.system(f"wget https://huggingface.co/{TTS_MODEL_NAME}/resolve/{commit_hash}/pytorch_model.pt -O fastspeech2_model/pytorch_model.pt")
-    os.system(f"wget https://huggingface.co/{TTS_MODEL_NAME}/resolve/{commit_hash}/vocab.txt -O fastspeech2_model/vocab.txt")
-    print("Download complete.")
-
-# Load the model using fairseq 0.10.2 compatible methods.
-tts_model_path = "fastspeech2_model/pytorch_model.pt"  # Path to the downloaded model
-tts_model, tts_cfg, tts_task = fairseq.checkpoint_utils.load_model_ensemble_and_task([tts_model_path])
-tts_model = tts_model[0]
-tts_model.to(tts_device)
-tts_model.eval()
 
 # --- Vicuna (LLM) Setup ---
-VICUNA_MODEL_NAME = "lmsys/vicuna-33b-v1.3"  # Use a smaller model if needed
+VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5"  # Or your preferred Vicuna
 vicuna_device = "cuda" if torch.cuda.is_available() else "cpu"
 vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
 vicuna_model = AutoModelForCausalLM.from_pretrained(
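
A caveat on the new TTS block: espnet/kan_bayashi_ljspeech_vits is an ESPnet checkpoint, and most transformers releases ship VITS as VitsModel rather than an AutoModelForTextToSpeech auto class, so the added lines may fail at import or load time. A minimal sketch of a transformers-native alternative, assuming the facebook/mms-tts-eng VITS checkpoint (that model choice is this note's assumption, not the commit's):

import torch
from transformers import VitsModel, AutoTokenizer

# Assumption: MMS English VITS checkpoint swapped in for the ESPnet repo id.
TTS_MODEL_NAME = "facebook/mms-tts-eng"
tts_tokenizer = AutoTokenizer.from_pretrained(TTS_MODEL_NAME)
tts_model = VitsModel.from_pretrained(TTS_MODEL_NAME)
tts_device = "cuda" if torch.cuda.is_available() else "cpu"
tts_model = tts_model.to(tts_device).eval()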
@@ -76,39 +60,28 @@ def transcribe_audio(microphone, state, task="transcribe"):
 
     updated_state = state + "\n" + vicuna_response
     return updated_state, updated_state
-# --- TTS Function ---
+
+# --- TTS Function (Simplified for VITS) ---
 def synthesize_speech(text):
     try:
-        # Preprocess using fairseq's task.
-        sample = tts_task.build_dataset_for_inference([text], [len(text)])
-
-        # Move to device
-        if tts_device == 'cuda':
-            sample = fairseq.utils.move_to_cuda(sample)
-        else:
-            sample = sample
-
-        # Generate
-        generator = tts_task.build_generator([tts_model], tts_cfg.task)  # Pass the task
-        output = generator.generate([tts_model], sample)  # Generate using the generator
+        inputs = tts_processor(text=text, return_tensors="pt")
+        inputs = {key: value.to(tts_device) for key, value in inputs.items()}
+        with torch.no_grad():
+            output = tts_model(**inputs).spectrogram  # VITS models often output a spectrogram
+            # Convert spectrogram to waveform using the vocoder
+            waveform = tts_model.vocoder(output).squeeze()
 
-        # Extract waveform and sample rate.
-        waveform = output[0][0]['waveform']
-        sample_rate = tts_cfg.task.sample_rate  # Get the rate
-
-        # Convert to NumPy (and ensure CPU)
         waveform_np = waveform.cpu().numpy()
-
-        return (sample_rate, waveform_np)
-
+        # VITS models use a sample rate of 22050
+        return (22050, waveform_np)
 
     except Exception as e:
         print(e)
         return (None, None)
 
 # --- Gradio Interface ---
-with gr.Blocks(title="Whisper, Vicuna, & FastSpeech2 Demo") as demo:
-    gr.Markdown("# Speech-to-Text-to-Speech Demo with Vicuna")
+with gr.Blocks(title="Whisper, Vicuna, & VITS Demo") as demo:  # Updated title
+    gr.Markdown("# Speech-to-Text-to-Speech Demo with Vicuna and VITS")
     gr.Markdown("Speak into your microphone, get a transcription, Vicuna will process it, and then you'll hear the result!")
 
     with gr.Tab("Transcribe & Synthesize"):
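
On the same assumption, VitsModel is end-to-end: its output exposes a waveform tensor, and there is no .spectrogram field or model.vocoder module, so the added synthesize_speech body would likely raise at runtime. A hedged rewrite against the VitsModel/tokenizer pair sketched above, reading the sample rate from the model config instead of hard-coding 22050:

def synthesize_speech(text):
    try:
        inputs = tts_tokenizer(text, return_tensors="pt")
        inputs = {key: value.to(tts_device) for key, value in inputs.items()}
        with torch.no_grad():
            # VitsModel returns the waveform directly; no vocoder step needed.
            waveform = tts_model(**inputs).waveform.squeeze()
        waveform_np = waveform.cpu().numpy()
        # MMS VITS checkpoints use 16000 Hz; read it rather than assume 22050.
        return (tts_model.config.sampling_rate, waveform_np)
    except Exception as e:
        print(e)
        return (None, None)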
 
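The diff cuts off inside the Transcribe & Synthesize tab, so the event wiring is not shown. A hypothetical sketch of how the pieces could connect under Gradio 4.x; the component names and the stop_recording/then chaining are this note's assumptions, while transcribe_audio and synthesize_speech are the functions from the diff:

import gradio as gr

with gr.Blocks(title="Whisper, Vicuna, & VITS Demo") as demo:
    gr.Markdown("# Speech-to-Text-to-Speech Demo with Vicuna and VITS")

    with gr.Tab("Transcribe & Synthesize"):
        mic_input = gr.Audio(sources=["microphone"], type="filepath", label="Speak here")
        conversation = gr.Textbox(lines=8, label="Conversation")
        tts_output = gr.Audio(label="Synthesized reply")
        chat_state = gr.State("")

        # transcribe_audio returns (updated_state, updated_state): one copy
        # refreshes the textbox, the other persists as session state; the
        # refreshed text is then fed to TTS.
        mic_input.stop_recording(
            transcribe_audio,
            inputs=[mic_input, chat_state],
            outputs=[conversation, chat_state],
        ).then(
            synthesize_speech,
            inputs=conversation,
            outputs=tts_output,
        )

demo.launch()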