ford442 committed
Commit 5304a42 · verified · 1 Parent(s): 61c01b2

Update app.py

Files changed (1):
app.py +8 -8
app.py CHANGED
@@ -19,19 +19,19 @@ transcribe_token_id = all_special_ids[-5]
 translate_token_id = all_special_ids[-6]
 
 # --- FastSpeech2 (TTS) Setup ---
-TTS_MODEL_NAME = "ford442/fastspeech2-en-ljspeech" # OR "facebook/fastspeech2-en-ljspeech" after PR
+TTS_MODEL_NAME = "your_username/fastspeech2-en-ljspeech" # OR "facebook/fastspeech2-en-ljspeech" after PR
 
-# 1. Load the config (now it should exist!)
-tts_config = AutoConfig.from_pretrained(TTS_MODEL_NAME)
+# 1. Load the config. We DO need trust_remote_code here, and we explain why below.
+tts_config = AutoConfig.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True)
 
-# 2. Dynamically import the model class. This is the correct way.
-module_name = tts_config.architectures[0] # Get model class name from config
-module = importlib.import_module(f".{tts_config._name_or_path}", package="transformers.models")
+# 2. Dynamically import the model class. This is *still* the correct way.
+module_name = tts_config.architectures[0]
+module = importlib.import_module(f"transformers.models.{tts_config.model_type}") # Corrected module path
 model_class = getattr(module, tts_config.architectures[0])
 
 
 # 3. Load the processor and model.
-tts_processor = AutoProcessor.from_pretrained(TTS_MODEL_NAME)
+tts_processor = AutoProcessor.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True) # Keep this for now
 tts_model = model_class.from_pretrained(TTS_MODEL_NAME, config=tts_config)
 
 tts_device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -44,7 +44,7 @@ vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
 vicuna_model = AutoModelForCausalLM.from_pretrained(
     VICUNA_MODEL_NAME,
     load_in_8bit=False,
-    torch_dtype=torch.bfloat16,
+    torch_dtype=torch.float32,
     device_map="auto",
 )
 
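Note on the fix: importlib.import_module needs the package that actually contains the architecture class, which for transformers-native models is transformers.models.<model_type>; the old path built from _name_or_path (the repo id) is not an importable module. A minimal sketch of that pattern follows. It uses "bert-base-uncased" purely as a stand-in checkpoint that is known to ship with transformers (an assumption for illustration, not the FastSpeech2 repo above, which may only load with trust_remote_code), and it shows torch.float32 as the conservative dtype the second hunk switches to.

# Minimal sketch of the dynamic-import pattern from the diff above.
# Assumption: "bert-base-uncased" is a stand-in; the FastSpeech2 repo in
# app.py may instead require trust_remote_code to resolve its classes.
import importlib

import torch
from transformers import AutoConfig

config = AutoConfig.from_pretrained("bert-base-uncased")

# config.model_type == "bert"; config.architectures == ["BertForMaskedLM"]
module = importlib.import_module(f"transformers.models.{config.model_type}")
model_class = getattr(module, config.architectures[0])

# float32 is the safe default; bfloat16 only pays off on hardware that supports it.
model = model_class.from_pretrained("bert-base-uncased", torch_dtype=torch.float32)
print(type(model).__name__)  # -> BertForMaskedLM

The same two lines (import_module on model_type, getattr on architectures[0]) are what the corrected app.py relies on, so the sketch is only meant to show why the module path change makes the import resolvable.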