Update app.py
app.py CHANGED
@@ -1,8 +1,9 @@
 import torch
 import gradio as gr
-from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer,
+from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoProcessor
 import soundfile as sf
 import numpy as np
+import importlib  # Used below to resolve the TTS model class dynamically
 
 # --- Whisper (ASR) Setup ---
 ASR_MODEL_NAME = "openai/whisper-large-v2"
@@ -18,17 +19,26 @@ transcribe_token_id = all_special_ids[-5]
 translate_token_id = all_special_ids[-6]
 
 # --- FastSpeech2 (TTS) Setup ---
-
-
-#
+TTS_MODEL_NAME = "your_username/fastspeech2-en-ljspeech"  # OR "facebook/fastspeech2-en-ljspeech" after PR
+
+# 1. Load the config (now it should exist!)
+tts_config = AutoConfig.from_pretrained(TTS_MODEL_NAME)
+
+# 2. Dynamically import the model class named in the config
+module_name = tts_config.model_type  # module location under transformers.models
+module = importlib.import_module(f".{module_name}", package="transformers.models")
+model_class = getattr(module, tts_config.architectures[0])
+
+
+# 3. Load the processor and model.
 tts_processor = AutoProcessor.from_pretrained(TTS_MODEL_NAME)
-tts_model =
+tts_model = model_class.from_pretrained(TTS_MODEL_NAME, config=tts_config)
 
 tts_device = "cuda" if torch.cuda.is_available() else "cpu"
 tts_model = tts_model.to(tts_device)
 
 # --- Vicuna (LLM) Setup ---
-VICUNA_MODEL_NAME = "lmsys/vicuna-33b-v1.3" #
+VICUNA_MODEL_NAME = "lmsys/vicuna-33b-v1.3"  # Use a smaller model if needed
 vicuna_device = "cuda" if torch.cuda.is_available() else "cpu"
 vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
 vicuna_model = AutoModelForCausalLM.from_pretrained(
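A note on step 2 above: transformers.models.<model_type> only exists for architectures bundled with the library, and FastSpeech2 is not one of them, so the getattr lookup can fail. Below is a minimal guarded sketch; TTS_MODEL_NAME is the same placeholder repo id, and the trust_remote_code fallback assumes the repo ships its own modeling code:

    import importlib
    from transformers import AutoConfig, AutoModel

    TTS_MODEL_NAME = "your_username/fastspeech2-en-ljspeech"  # placeholder repo id
    config = AutoConfig.from_pretrained(TTS_MODEL_NAME)
    try:
        # Bundled architectures live under transformers.models.<model_type>
        module = importlib.import_module(f".{config.model_type}", package="transformers.models")
        model_class = getattr(module, config.architectures[0])
        tts_model = model_class.from_pretrained(TTS_MODEL_NAME, config=config)
    except (ModuleNotFoundError, AttributeError):
        # Assumed fallback: the repo provides custom modeling code
        tts_model = AutoModel.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True)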
@@ -38,10 +48,6 @@ vicuna_model = AutoModelForCausalLM.from_pretrained(
     device_map="auto",
 )
 
-# --- ASR and TTS Functions (and Gradio Interface) ---
-# (Rest of your code - transcribe_audio, synthesize_speech, Gradio setup)
-# ... (same as before, but using tts_model, tts_processor) ...
-
 # --- ASR Function ---
 def transcribe_audio(microphone, state, task="transcribe"):
     if microphone is None:
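For orientation, the from_pretrained call this hunk lands in is only partially visible (the diff elides lines 35-37). A hypothetical reconstruction follows; only device_map="auto" is confirmed by the diff, and the other keyword arguments are common-but-assumed companions for a 33B checkpoint:

    vicuna_model = AutoModelForCausalLM.from_pretrained(
        VICUNA_MODEL_NAME,
        torch_dtype=torch.float16,   # assumption, not in the diff: halves memory
        low_cpu_mem_usage=True,      # assumption, not in the diff
        device_map="auto",           # confirmed by the diff
    )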
@@ -73,7 +79,7 @@ def synthesize_speech(text):
     inputs = tts_processor(text=text, return_tensors="pt")
     inputs = {key: value.to(tts_device) for key, value in inputs.items()}
     with torch.no_grad():
-
+        output = tts_model(**inputs).waveform  # Use .waveform
     output = output.cpu()
     waveform = output.squeeze().numpy()
     return (tts_processor.feature_extractor.sampling_rate, waveform)
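The tuple returned on the last line, (sampling_rate, numpy waveform), is the format Gradio's Audio component accepts as an output value. A short usage sketch; the sample text and filename are illustrative:

    sr, waveform = synthesize_speech("Hello from FastSpeech2!")
    sf.write("speech.wav", waveform, sr)  # soundfile is imported as sf in app.py

    # Or wire it straight into a Gradio interface:
    demo = gr.Interface(fn=synthesize_speech, inputs="text", outputs=gr.Audio())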