ford442 commited on
Commit
892a58d
·
verified ·
1 Parent(s): 256d229

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -1
app.py CHANGED
@@ -1,3 +1,47 @@
 
 
 
1
  import gradio as gr
2
 
3
- gr.Interface.load("models/facebook/fastspeech2-en-ljspeech").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForTextToSpeech, AutoProcessor
3
+ import soundfile as sf # For saving the audio
4
  import gradio as gr
5
 
6
+ # 1. Choose the model and processor
7
+ model_name = "facebook/fastspeech2-en-ljspeech"
8
+
9
+ # 2. Load the processor and model
10
+ processor = AutoProcessor.from_pretrained(model_name)
11
+ model = AutoModelForTextToSpeech.from_pretrained(model_name)
12
+
13
+ # 3. Move the model to the GPU (if available)
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ model = model.to(device)
16
+
17
+ # 4. Define a function for text-to-speech
18
+ def synthesize_speech(text):
19
+ try:
20
+ inputs = processor(text=text, return_tensors="pt")
21
+ # Move input tensors to the same device as the model
22
+ inputs = {key: value.to(device) for key, value in inputs.items()}
23
+ with torch.no_grad(): # Disable gradient calculation during inference
24
+ output = model(**inputs).waveform
25
+ # Move to cpu before converting
26
+ output = output.cpu()
27
+
28
+ # Convert the output to a NumPy array (required by soundfile)
29
+ waveform = output.squeeze().numpy()
30
+
31
+ # Return the waveform and the sample rate (needed for Gradio)
32
+ return (processor.feature_extractor.sampling_rate, waveform)
33
+ except Exception as e:
34
+ print (e)
35
+ return (None, None) # in case of error
36
+
37
+ # 5. create interface
38
+ iface = gr.Interface(
39
+ fn=synthesize_speech,
40
+ inputs=gr.Textbox(lines=2, placeholder="Enter text here..."),
41
+ outputs=gr.Audio(label="Generated Speech", type="numpy"),
42
+ title="FastSpeech2 Text-to-Speech",
43
+ description="Enter text to synthesize speech using FastSpeech2.",
44
+ )
45
+
46
+ # 6. launch
47
+ iface.launch()