Update app.py
app.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
 import gradio as gr
-from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer,
+from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoProcessor
 import soundfile as sf
 import numpy as np
-import importlib
+import importlib

 # --- Whisper (ASR) Setup ---
 ASR_MODEL_NAME = "openai/whisper-large-v2"
@@ -19,35 +19,49 @@ transcribe_token_id = all_special_ids[-5]
 translate_token_id = all_special_ids[-6]

 # --- FastSpeech2 (TTS) Setup ---
-TTS_MODEL_NAME = "ford442/fastspeech2-en-ljspeech" # OR "facebook/fastspeech2-en-ljspeech"
-
-# 1. Load the
-
-
-# 2.
-
-
-
-
-
-
-
-
+TTS_MODEL_NAME = "ford442/fastspeech2-en-ljspeech" # OR "facebook/fastspeech2-en-ljspeech"
+
+# 1. Load the processor (we still need trust_remote_code for this)
+tts_processor = AutoProcessor.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True)
+
+# 2. Load the model using the *custom* modeling file. This is the key.
+# We CANNOT use AutoConfig or AutoModel here.
+model_file_path = f"models--{TTS_MODEL_NAME.replace('/', '--')}/snapshots"
+
+import os
+# Find the commit hash - this is needed because of the way Hugging Face caches models.
+for d in os.listdir(os.path.expanduser(f"~/.cache/huggingface/hub/{model_file_path}")):
+    if os.path.isdir(os.path.expanduser(f"~/.cache/huggingface/hub/{model_file_path}/{d}")) and not d.startswith("."):
+        commit_hash = d
+        break
+else:
+    raise ValueError ("Cannot find the model")
+model_file_path += f"/{commit_hash}/modeling_fastspeech2.py"
+
+# Use importlib to import the custom modeling file.
+spec = importlib.util.spec_from_file_location("modeling_fastspeech2", os.path.expanduser(f"~/.cache/huggingface/hub/{model_file_path}"))
+fastspeech2_module = importlib.util.module_from_spec(spec)
+spec.loader.exec_module(fastspeech2_module)
+tts_model = fastspeech2_module.FastSpeech2.from_pretrained(TTS_MODEL_NAME) #Use the actual class name!

 tts_device = "cuda" if torch.cuda.is_available() else "cpu"
 tts_model = tts_model.to(tts_device)

+
 # --- Vicuna (LLM) Setup ---
-VICUNA_MODEL_NAME = "lmsys/vicuna-
+VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5" # Use a smaller model if needed
 vicuna_device = "cuda" if torch.cuda.is_available() else "cpu"
 vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
 vicuna_model = AutoModelForCausalLM.from_pretrained(
     VICUNA_MODEL_NAME,
-    load_in_8bit=
-    torch_dtype=torch.
+    load_in_8bit=True,
+    torch_dtype=torch.float16,
     device_map="auto",
 )

+# --- ASR and TTS Functions (and Gradio Interface) ---
+# (Same as before, but using tts_model and tts_processor)
+
 # --- ASR Function ---
 def transcribe_audio(microphone, state, task="transcribe"):
     if microphone is None:
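A note on the cache-walking block added above: scanning ~/.cache/huggingface/hub by hand works, but the snapshots layout is an implementation detail of the hub cache. A minimal alternative sketch, assuming the repo really does ship a modeling_fastspeech2.py that defines a FastSpeech2 class (as the commit implies), is to let huggingface_hub resolve the file path and import the module from there:

# Sketch only: resolve the custom modeling file via huggingface_hub instead of
# walking the cache directory. Assumes the repo contains modeling_fastspeech2.py
# and that it defines a FastSpeech2 class, as the commit above suggests.
import importlib.util
from huggingface_hub import hf_hub_download

TTS_MODEL_NAME = "ford442/fastspeech2-en-ljspeech"

# Downloads the file if needed, otherwise returns the cached local path.
modeling_path = hf_hub_download(repo_id=TTS_MODEL_NAME, filename="modeling_fastspeech2.py")

spec = importlib.util.spec_from_file_location("modeling_fastspeech2", modeling_path)
fastspeech2_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(fastspeech2_module)
tts_model = fastspeech2_module.FastSpeech2.from_pretrained(TTS_MODEL_NAME)

Note that importlib.util is imported explicitly here; a bare import importlib is not guaranteed to expose the util submodule.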
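For the Vicuna block, newer transformers releases prefer an explicit BitsAndBytesConfig over the bare load_in_8bit= keyword. A sketch of the equivalent call, assuming bitsandbytes and accelerate are installed and a CUDA device is available:

# Sketch only: the same 8-bit Vicuna load expressed with an explicit quantization config.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5"
vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
vicuna_model = AutoModelForCausalLM.from_pretrained(
    VICUNA_MODEL_NAME,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,      # dtype for the non-quantized modules
    device_map="auto",              # requires accelerate; places layers on the available GPU(s)
)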
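The hunk cuts off inside transcribe_audio. The transcribe_token_id / translate_token_id lines follow the pattern of the stock Whisper Gradio demo, so a plausible completion, assuming the unchanged setup earlier in app.py created an automatic-speech-recognition pipeline named pipe, looks roughly like this; the pipe name and the state handling are assumptions, not part of the commit:

# Sketch only: one common way this handler is finished in Whisper demo Spaces.
def transcribe_audio(microphone, state, task="transcribe"):
    if microphone is None:
        return state, state  # nothing recorded yet
    # Force Whisper into transcribe or translate mode via the special token ids
    # computed during setup (all_special_ids[-5] / all_special_ids[-6]).
    pipe.model.config.forced_decoder_ids = [
        [2, transcribe_token_id if task == "transcribe" else translate_token_id]
    ]
    text = pipe(microphone)["text"]
    state += text + " "
    return state, state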