ford442 committed on
Commit
d7978a0
·
verified ·
1 Parent(s): 6fe541d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -7
app.py CHANGED
@@ -1,11 +1,12 @@
1
  import torch
2
  import gradio as gr
3
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoProcessor
4
- # Note: No AutoModelForTextToSpeech needed
5
  import soundfile as sf
6
  import numpy as np
7
- from espnet2.bin.tts_inference import Text2Speech # Import Text2Speech from espnet2
8
-
 
 
9
 
10
  # --- Whisper (ASR) Setup ---
11
  ASR_MODEL_NAME = "openai/whisper-large-v2"
@@ -20,12 +21,24 @@ all_special_ids = asr_pipe.tokenizer.all_special_ids
20
  transcribe_token_id = all_special_ids[-5]
21
  translate_token_id = all_special_ids[-6]
22
 
23
- # --- VITS (TTS) Setup - Using espnet2 ---
24
  TTS_MODEL_NAME = "espnet/kan_bayashi_ljspeech_vits"
25
  tts_device = "cuda" if torch.cuda.is_available() else "cpu"
26
 
27
- # Load the Text2Speech model from espnet2
28
- tts_model = Text2Speech.from_pretrained(TTS_MODEL_NAME).to(tts_device)
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
 
31
  # --- Vicuna (LLM) Setup ---
 
1
  import torch
2
  import gradio as gr
3
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
 
4
  import soundfile as sf
5
  import numpy as np
6
+ # No more fairseq imports
7
+ # espnet2 is imported further down, after the TTS model download
8
+ import IPython.display as ipd # For notebook use (optional)
9
+ import os, pathlib
10
 
11
  # --- Whisper (ASR) Setup ---
12
  ASR_MODEL_NAME = "openai/whisper-large-v2"
 
21
  transcribe_token_id = all_special_ids[-5]
22
  translate_token_id = all_special_ids[-6]
23
 
24
+ # --- VITS (TTS) Setup ---
25
  TTS_MODEL_NAME = "espnet/kan_bayashi_ljspeech_vits"
26
  tts_device = "cuda" if torch.cuda.is_available() else "cpu"
27
 
28
+ # Download the ESPnet model (if it hasn't been downloaded yet)
29
+ # The try-except block prints an installation hint if the download fails, then re-raises to stop execution.
30
+ try:
31
+ from espnet_model_zoo.downloader import ModelDownloader
32
+ d = ModelDownloader()
33
+ tts_model_path = d.download_and_unpack(TTS_MODEL_NAME)
34
+ except Exception as e:
35
+ print(f"Error downloading ESPnet model: {e}")
36
+ print("Make sure you have espnet_model_zoo installed: `pip install espnet_model_zoo`")
37
+ raise # Re-raise the exception to stop execution
38
+
39
+ # Now import and set up the text-to-speech model
40
+ from espnet2.bin.tts_inference import Text2Speech
41
+ tts_model = Text2Speech(tts_model_path, device=tts_device)
42
 
43
 
44
  # --- Vicuna (LLM) Setup ---