Spaces: Running on Zero
Update app.py
Browse files
app.py
CHANGED
@@ -19,19 +19,19 @@ transcribe_token_id = all_special_ids[-5]
|
|
19 |
translate_token_id = all_special_ids[-6]
|
20 |
|
21 |
# --- FastSpeech2 (TTS) Setup ---
|
22 |
-
TTS_MODEL_NAME = "facebook/fastspeech2-en-ljspeech"
|
23 |
|
24 |
-
# 1. Load the config
|
25 |
-
tts_config = AutoConfig.from_pretrained(TTS_MODEL_NAME)
|
26 |
|
27 |
-
# 2. Dynamically import the model class.
|
28 |
-
module_name = tts_config.architectures[0]
|
29 |
-
module = importlib.import_module(f".{tts_config.model_type}")
|
30 |
model_class = getattr(module, tts_config.architectures[0])
|
31 |
|
32 |
|
33 |
# 3. Load the processor and model.
|
34 |
-
tts_processor = AutoProcessor.from_pretrained(TTS_MODEL_NAME)
|
35 |
tts_model = model_class.from_pretrained(TTS_MODEL_NAME, config=tts_config)
|
36 |
|
37 |
tts_device = "cuda" if torch.cuda.is_available() else "cpu"
|
@@ -44,7 +44,7 @@ vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
|
|
44 |
vicuna_model = AutoModelForCausalLM.from_pretrained(
|
45 |
VICUNA_MODEL_NAME,
|
46 |
load_in_8bit=False,
|
47 |
-
torch_dtype=torch.float16,
|
48 |
device_map="auto",
|
49 |
)
|
50 |
|
|
|
19 |
translate_token_id = all_special_ids[-6]
|
20 |
|
21 |
# --- FastSpeech2 (TTS) Setup ---
|
22 |
+
TTS_MODEL_NAME = "your_username/fastspeech2-en-ljspeech" # OR "facebook/fastspeech2-en-ljspeech" after PR
|
23 |
|
24 |
+
# 1. Load the config. We DO need trust_remote_code here, and we explain why below.
|
25 |
+
tts_config = AutoConfig.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True)
|
26 |
|
27 |
+
# 2. Dynamically import the model class. This is *still* the correct way.
|
28 |
+
module_name = tts_config.architectures[0]
|
29 |
+
module = importlib.import_module(f"transformers.models.{tts_config.model_type}") # Corrected module path
|
30 |
model_class = getattr(module, tts_config.architectures[0])
|
31 |
|
32 |
|
33 |
# 3. Load the processor and model.
|
34 |
+
tts_processor = AutoProcessor.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True) # Keep this for now
|
35 |
tts_model = model_class.from_pretrained(TTS_MODEL_NAME, config=tts_config)
|
36 |
|
37 |
tts_device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
44 |
vicuna_model = AutoModelForCausalLM.from_pretrained(
|
45 |
VICUNA_MODEL_NAME,
|
46 |
load_in_8bit=False,
|
47 |
+
torch_dtype=torch.float32,
|
48 |
device_map="auto",
|
49 |
)
|
50 |
|