ford442 committed on
Commit f5ebbd5 · verified · 1 Parent(s): 892a58d

Update app.py

Files changed (1)
  1. app.py +83 -35
app.py CHANGED
@@ -1,47 +1,95 @@
  import torch
- from transformers import AutoModelForTextToSpeech, AutoProcessor
- import soundfile as sf # For saving the audio
  import gradio as gr
+ from transformers import pipeline, AutoModelForTextToSpeech, AutoProcessor
+ import soundfile as sf
+ import numpy as np # Import numpy

- # 1. Choose the model and processor
- model_name = "facebook/fastspeech2-en-ljspeech"
+ # --- Whisper (ASR) Setup ---
+ ASR_MODEL_NAME = "openai/whisper-large-v2"
+ asr_device = "cuda" if torch.cuda.is_available() else "cpu"
+ asr_pipe = pipeline(
+     task="automatic-speech-recognition",
+     model=ASR_MODEL_NAME,
+     chunk_length_s=30,
+     device=asr_device,
+ )
+ all_special_ids = asr_pipe.tokenizer.all_special_ids
+ transcribe_token_id = all_special_ids[-5]
+ translate_token_id = all_special_ids[-6]

- # 2. Load the processor and model
- processor = AutoProcessor.from_pretrained(model_name)
- model = AutoModelForTextToSpeech.from_pretrained(model_name)
+ # --- FastSpeech2 (TTS) Setup ---
+ TTS_MODEL_NAME = "facebook/fastspeech2-en-ljspeech"
+ tts_processor = AutoProcessor.from_pretrained(TTS_MODEL_NAME)
+ tts_model = AutoModelForTextToSpeech.from_pretrained(TTS_MODEL_NAME)
+ tts_device = "cuda" if torch.cuda.is_available() else "cpu"
+ tts_model = tts_model.to(tts_device)

- # 3. Move the model to the GPU (if available)
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model = model.to(device)
+ # --- ASR Function ---
+ def transcribe_audio(microphone, state, task="transcribe"):
+     if microphone is None: # Handle case where no audio is provided
+         return state, state
+     asr_pipe.model.config.forced_decoder_ids = [
+         [2, transcribe_token_id if task == "transcribe" else translate_token_id]
+     ]
+     text = asr_pipe(microphone)["text"]
+     updated_state = state + "\n" + text
+     return updated_state, updated_state

- # 4. Define a function for text-to-speech
+ # --- TTS Function ---
  def synthesize_speech(text):
      try:
-         inputs = processor(text=text, return_tensors="pt")
-         # Move input tensors to the same device as the model
-         inputs = {key: value.to(device) for key, value in inputs.items()}
-         with torch.no_grad(): # Disable gradient calculation during inference
-             output = model(**inputs).waveform
-         # Move to cpu before converting
+         inputs = tts_processor(text=text, return_tensors="pt")
+         inputs = {key: value.to(tts_device) for key, value in inputs.items()}
+         with torch.no_grad():
+             output = tts_model(**inputs).waveform
          output = output.cpu()
-
-         # Convert the output to a NumPy array (required by soundfile)
          waveform = output.squeeze().numpy()
-
-         # Return the waveform and the sample rate (needed for Gradio)
-         return (processor.feature_extractor.sampling_rate, waveform)
+         return (tts_processor.feature_extractor.sampling_rate, waveform)
      except Exception as e:
-         print (e)
-         return (None, None) # in case of error
-
- # 5. create interface
- iface = gr.Interface(
-     fn=synthesize_speech,
-     inputs=gr.Textbox(lines=2, placeholder="Enter text here..."),
-     outputs=gr.Audio(label="Generated Speech", type="numpy"),
-     title="FastSpeech2 Text-to-Speech",
-     description="Enter text to synthesize speech using FastSpeech2.",
- )
+         print(e)
+         return (None, None)
+
+ # --- Gradio Interface ---
+ with gr.Blocks(title="Whisper & FastSpeech2 Demo") as demo:
+     gr.Markdown("# Speech-to-Text-to-Speech Demo")
+     gr.Markdown("Speak into your microphone, get a transcription, and then hear it spoken back!")
+
+     with gr.Tab("Transcribe"):
+         mic_input = gr.Audio(source="microphone", type="filepath", optional=True)
+         transcription_output = gr.Textbox(lines=5, label="Transcription")
+         transcription_state = gr.State(value="") # State to accumulate transcription
+         transcribe_btn = gr.Button("Transcribe")
+
+         transcribe_btn.click(
+             fn=transcribe_audio,
+             inputs=[mic_input, transcription_state],
+             outputs=[transcription_output, transcription_state],
+         )
+
+     with gr.Tab("Synthesize"):
+         text_input = gr.Textbox(lines=5, label="Text to Speak", placeholder="Enter text here...")
+         audio_output = gr.Audio(label="Generated Speech", type="numpy")
+         synthesize_btn = gr.Button("Synthesize")

- # 6. launch
- iface.launch()
+         synthesize_btn.click(
+             fn=synthesize_speech,
+             inputs=text_input,
+             outputs=audio_output,
+         )
+     with gr.Tab("Combined"):
+         # combined interface. Speak to transcribe, auto synthesize
+         mic_input_c = gr.Audio(source="microphone", type="filepath", optional=True, label="Speak Here")
+         transcription_output_c = gr.Textbox(lines=5, label="Transcription")
+         audio_output_c = gr.Audio(label="Synthesized Speech", type="numpy")
+         transcription_state_c = gr.State(value="") # State to accumulate transcription
+         #transcribe and output audio
+         mic_input_c.change(
+             fn=transcribe_audio,
+             inputs=[mic_input_c, transcription_state_c],
+             outputs=[transcription_output_c, transcription_state_c]
+         ).then(
+             fn=synthesize_speech,
+             inputs=transcription_output_c,
+             outputs=audio_output_c
+         )
+ demo.launch(enable_queue=True)