ford442 committed on
Commit
b56cef1
·
verified ·
1 Parent(s): 4d170e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -19
app.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
  import gradio as gr
3
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoProcessor
4
  import soundfile as sf
5
  import numpy as np
6
- import importlib # Import the importlib module
7
 
8
  # --- Whisper (ASR) Setup ---
9
  ASR_MODEL_NAME = "openai/whisper-large-v2"
@@ -19,35 +19,49 @@ transcribe_token_id = all_special_ids[-5]
19
  translate_token_id = all_special_ids[-6]
20
 
21
  # --- FastSpeech2 (TTS) Setup ---
22
- TTS_MODEL_NAME = "ford442/fastspeech2-en-ljspeech" # OR "facebook/fastspeech2-en-ljspeech" after PR
23
-
24
- # 1. Load the config. We DO need trust_remote_code here, and we explain why below.
25
- tts_config = AutoConfig.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True)
26
-
27
- # 2. Dynamically import the model class. This is *still* the correct way.
28
- module_name = tts_config.architectures[0]
29
- module = importlib.import_module(f"transformers.models.{tts_config.model_type}") # Corrected module path
30
- model_class = getattr(module, tts_config.architectures[0])
31
-
32
-
33
- # 3. Load the processor and model.
34
- tts_processor = AutoProcessor.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True) # Keep this for now
35
- tts_model = model_class.from_pretrained(TTS_MODEL_NAME, config=tts_config)
 
 
 
 
 
 
 
 
 
 
36
 
37
  tts_device = "cuda" if torch.cuda.is_available() else "cpu"
38
  tts_model = tts_model.to(tts_device)
39
 
 
40
  # --- Vicuna (LLM) Setup ---
41
- VICUNA_MODEL_NAME = "lmsys/vicuna-33b-v1.3" # Use a smaller model if needed
42
  vicuna_device = "cuda" if torch.cuda.is_available() else "cpu"
43
  vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
44
  vicuna_model = AutoModelForCausalLM.from_pretrained(
45
  VICUNA_MODEL_NAME,
46
- load_in_8bit=False,
47
- torch_dtype=torch.float32,
48
  device_map="auto",
49
  )
50
 
 
 
 
51
  # --- ASR Function ---
52
  def transcribe_audio(microphone, state, task="transcribe"):
53
  if microphone is None:
 
1
  import torch
2
  import gradio as gr
3
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoProcessor
4
  import soundfile as sf
5
  import numpy as np
6
+ import importlib
7
 
8
  # --- Whisper (ASR) Setup ---
9
  ASR_MODEL_NAME = "openai/whisper-large-v2"
 
19
  translate_token_id = all_special_ids[-6]
20
 
21
  # --- FastSpeech2 (TTS) Setup ---
22
+ TTS_MODEL_NAME = "ford442/fastspeech2-en-ljspeech" # OR "facebook/fastspeech2-en-ljspeech"
23
+
24
+ # 1. Load the processor (we still need trust_remote_code for this)
25
+ tts_processor = AutoProcessor.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True)
26
+
27
+ # 2. Load the model using the *custom* modeling file. This is the key.
28
+ # We CANNOT use AutoConfig or AutoModel here.
29
+ model_file_path = f"models--{TTS_MODEL_NAME.replace('/', '--')}/snapshots"
30
+
31
+ import os
32
+ # Find the commit hash - the Hugging Face hub cache stores each downloaded revision under snapshots/<commit-hash>/, so the hash is needed to build the path to the modeling file.
33
+ for d in os.listdir(os.path.expanduser(f"~/.cache/huggingface/hub/{model_file_path}")):
34
+ if os.path.isdir(os.path.expanduser(f"~/.cache/huggingface/hub/{model_file_path}/{d}")) and not d.startswith("."):
35
+ commit_hash = d
36
+ break
37
+ else:
38
+ raise ValueError ("Cannot find the model")
39
+ model_file_path += f"/{commit_hash}/modeling_fastspeech2.py"
40
+
41
+ # Use importlib to import the custom modeling file.
42
+ spec = importlib.util.spec_from_file_location("modeling_fastspeech2", os.path.expanduser(f"~/.cache/huggingface/hub/{model_file_path}"))
43
+ fastspeech2_module = importlib.util.module_from_spec(spec)
44
+ spec.loader.exec_module(fastspeech2_module)
45
+ tts_model = fastspeech2_module.FastSpeech2.from_pretrained(TTS_MODEL_NAME) #Use the actual class name!
46
 
47
  tts_device = "cuda" if torch.cuda.is_available() else "cpu"
48
  tts_model = tts_model.to(tts_device)
49
 
50
+
51
  # --- Vicuna (LLM) Setup ---
52
+ VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5" # Use a smaller model if needed
53
  vicuna_device = "cuda" if torch.cuda.is_available() else "cpu"
54
  vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
55
  vicuna_model = AutoModelForCausalLM.from_pretrained(
56
  VICUNA_MODEL_NAME,
57
+ load_in_8bit=True,
58
+ torch_dtype=torch.float16,
59
  device_map="auto",
60
  )
61
 
62
+ # --- ASR and TTS Functions (and Gradio Interface) ---
63
+ # (Same as before, but using tts_model and tts_processor)
64
+
65
  # --- ASR Function ---
66
  def transcribe_audio(microphone, state, task="transcribe"):
67
  if microphone is None: