ford442 committed · verified
Commit df3b410 · Parent(s): 6199585

Update app.py

Files changed (1)
  1. app.py +17 -11
app.py CHANGED
@@ -1,8 +1,9 @@
 import torch
 import gradio as gr
-from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoModelForTextToSpeech, AutoProcessor
+from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoProcessor
 import soundfile as sf
 import numpy as np
+import importlib  # needed to resolve the model class named in the config
 
 # --- Whisper (ASR) Setup ---
 ASR_MODEL_NAME = "openai/whisper-large-v2"
@@ -18,17 +19,26 @@ transcribe_token_id = all_special_ids[-5]
 translate_token_id = all_special_ids[-6]
 
 # --- FastSpeech2 (TTS) Setup ---
-# Use your fork, or the original if/when the change is merged.
-TTS_MODEL_NAME = "ford442/fastspeech2-en-ljspeech" # OR "facebook/fastspeech2-en-ljspeech"
-# Now we can use AutoModelForTextToSpeech!
+TTS_MODEL_NAME = "your_username/fastspeech2-en-ljspeech"  # OR "facebook/fastspeech2-en-ljspeech" after the PR is merged
+
+# 1. Load the config (it exists on the Hub now).
+tts_config = AutoConfig.from_pretrained(TTS_MODEL_NAME)
+
+# 2. Resolve the model class dynamically from the config.
+class_name = tts_config.architectures[0]  # concrete class name stored in config.json
+module = importlib.import_module(f".{tts_config.model_type}", package="transformers.models")
+model_class = getattr(module, class_name)
+
+
+# 3. Load the processor and model.
 tts_processor = AutoProcessor.from_pretrained(TTS_MODEL_NAME)
-tts_model = AutoModelForTextToSpeech.from_pretrained(TTS_MODEL_NAME)
+tts_model = model_class.from_pretrained(TTS_MODEL_NAME, config=tts_config)
 
 tts_device = "cuda" if torch.cuda.is_available() else "cpu"
 tts_model = tts_model.to(tts_device)
 
 # --- Vicuna (LLM) Setup ---
-VICUNA_MODEL_NAME = "lmsys/vicuna-33b-v1.3" # Or your preferred Vicuna model
+VICUNA_MODEL_NAME = "lmsys/vicuna-33b-v1.3"  # Use a smaller model if needed
 vicuna_device = "cuda" if torch.cuda.is_available() else "cpu"
 vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
 vicuna_model = AutoModelForCausalLM.from_pretrained(
@@ -38,10 +48,6 @@ vicuna_model = AutoModelForCausalLM.from_pretrained(
     device_map="auto",
 )
 
-# --- ASR and TTS Functions (and Gradio Interface) ---
-# (Rest of your code - transcribe_audio, synthesize_speech, Gradio setup)
-# ... (same as before, but using tts_model, tts_processor) ...
-
 # --- ASR Function ---
 def transcribe_audio(microphone, state, task="transcribe"):
     if microphone is None:
@@ -73,7 +79,7 @@ def synthesize_speech(text):
     inputs = tts_processor(text=text, return_tensors="pt")
     inputs = {key: value.to(tts_device) for key, value in inputs.items()}
     with torch.no_grad():
-        output = tts_model(**inputs).waveform
+        output = tts_model(**inputs).waveform  # FastSpeech2 exposes the audio as .waveform
     output = output.cpu()
     waveform = output.squeeze().numpy()
     return (tts_processor.feature_extractor.sampling_rate, waveform)
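
Note on the pattern this commit introduces: resolving the model class from the checkpoint's config avoids hard-coding a class that stock transformers may not export yet. Below is a minimal standalone sketch of that resolution step, assuming the installed transformers build actually ships the architecture named in config.json (the premise of the fork this commit targets); the repo id is a placeholder.

import importlib

from transformers import AutoConfig

MODEL_NAME = "your_username/fastspeech2-en-ljspeech"  # placeholder repo id

# config.model_type names the submodule under transformers.models
# (e.g. "bert" -> transformers.models.bert), and config.architectures
# lists the concrete class name(s) the checkpoint was saved with.
config = AutoConfig.from_pretrained(MODEL_NAME)
module = importlib.import_module(f".{config.model_type}", package="transformers.models")
model_class = getattr(module, config.architectures[0])

model = model_class.from_pretrained(MODEL_NAME, config=config)
print(type(model).__name__)  # the concrete class resolved from config.json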
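
And a quick smoke test for the TTS path end to end (a sketch, assuming app.py has run so tts_processor, tts_model, tts_device, and synthesize_speech are in scope; the output filename is arbitrary):

import soundfile as sf

# synthesize_speech returns (sampling_rate, waveform) as defined above.
sampling_rate, waveform = synthesize_speech("Hello from FastSpeech2.")

# Write the waveform to disk so it can be played back outside Gradio.
sf.write("tts_check.wav", waveform, sampling_rate)
print(f"wrote {waveform.shape[0]} samples at {sampling_rate} Hz")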