ford442 committed on
Commit 821d0bc (verified)
1 Parent(s): 2048877

Update app.py

Files changed (1)
  1. app.py +23 -17
app.py CHANGED
@@ -4,8 +4,9 @@ import gradio as gr
 from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoModel
 import soundfile as sf
 import numpy as np
-import requests
-import os
+# Import the TTS pipeline
+from espnet2.bin.tts_inference import Text2Speech
+from espnet2.utils.types import get_fastspeech_config
 
 # Load Whisper model
 ASR_MODEL_NAME = "openai/whisper-medium.en"
@@ -32,11 +33,28 @@ def _preload_and_load_models():
 
 _preload_and_load_models()
 
+
+# Load the TTS model locally
+TTS_MODEL_PATH = "path/to/your/espnet/kan-bayashi_ljspeech_vits"  # Replace with the actual path
+TTS_CONFIG_PATH = os.path.join(TTS_MODEL_PATH, "config.yaml")  # Replace with your config.yaml
+TTS_VOCAB_PATH = os.path.join(TTS_MODEL_PATH, "train.json")  # Replace with your train.json
+
+tts = Text2Speech(
+    TTS_MODEL_PATH,
+    TTS_CONFIG_PATH,
+    TTS_VOCAB_PATH,
+    device="cuda" if torch.cuda.is_available() else "cpu",
+    # You can customize the speed and other parameters here if needed
+)
+fastspeech_config = get_fastspeech_config(TTS_CONFIG_PATH)
+
+
+
+
 @spaces.GPU(required=True)
 def process_audio(microphone, state, task="transcribe"):
     if microphone is None:
         return state, state, None
-
     asr_pipe.model.config.forced_decoder_ids = [
         [2, transcribe_token_id if task == "transcribe" else translate_token_id]
     ]
@@ -45,27 +63,15 @@ def process_audio(microphone, state, task="transcribe"):
     You answer questions clearly and simply, using age-appropriate language.
     You are also a little bit silly and like to make jokes."""
     prompt = f"{system_prompt}\nUser: {text}"
-
     with torch.no_grad():
         vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to('cuda')
         vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=192)
         vicuna_response = vicuna_tokenizer.decode(vicuna_output[0], skip_special_tokens=True)  # Access the first sequence [0]
         vicuna_response = vicuna_response.replace(prompt, "").strip()
     updated_state = state + "\n" + vicuna_response
-
     try:
-        API_URL = "https://api-inference.huggingface.co/models/espnet/kan-bayashi_ljspeech_vits"
-        headers = {"Authorization": f"Bearer {os.environ['HUGGINGFACEHUB_API_TOKEN']}"}
-        payloads = {'inputs': vicuna_response}  # Use Vicuna's response for TTS
-        response = requests.post(API_URL, headers=headers, json=payloads)
-        response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)
-
-        audio_data = response.content
-        # Convert bytes to numpy array (adjust sampling rate if needed)
-        audio_arr = np.frombuffer(audio_data, dtype=np.int16)  # Assumes 16-bit PCM
-        SAMPLE_RATE = 22050  # Common for this model; you might need to check the actual value
-        audio_arr = audio_arr.reshape(-1, 1).astype(np.float32) / np.iinfo(np.int16).max  # Normalize
-        audio_arr = audio_arr.flatten()  # Make it 1D
+        wav, sr = tts([vicuna_response])[0]
+        audio_arr = wav.cpu().numpy()
         audio_output = (SAMPLE_RATE, audio_arr)
         #sf.write('generated_audio.wav', audio_arr, SAMPLE_RATE)
     except requests.exceptions.RequestException as e:
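As committed, the new local-TTS path has a few loose ends: `os.path.join` is still called even though `import os` was removed; `espnet2.utils.types` does not appear to export a `get_fastspeech_config` helper (nor is a FastSpeech config needed for a VITS model); `SAMPLE_RATE` is no longer defined anywhere once the old HTTP block is deleted; and the surviving `except requests.exceptions.RequestException` handler will itself raise a `NameError` now that `requests` is gone. The sketch below shows how this model is typically loaded and called through the documented espnet2 interface; the model tag and surrounding variable names mirror the app, but treat it as a minimal approximation rather than the committed code.

# Minimal sketch, assuming the espnet and espnet_model_zoo packages are
# installed; from_pretrained resolves "kan-bayashi/ljspeech_vits" from the
# model zoo, so no local config.yaml or train.json paths are required.
import torch
from espnet2.bin.tts_inference import Text2Speech

tts = Text2Speech.from_pretrained(
    "kan-bayashi/ljspeech_vits",
    device="cuda" if torch.cuda.is_available() else "cpu",
)

with torch.no_grad():
    result = tts("Hello there!")  # returns a dict, not a (wav, sr) tuple
audio_arr = result["wav"].view(-1).cpu().numpy()  # 1-D float waveform
SAMPLE_RATE = tts.fs  # sampling rate is an attribute of the loaded model
audio_output = (SAMPLE_RATE, audio_arr)

With this interface nothing in the `try` block performs an HTTP request anymore, so the handler would catch `Exception` (or a narrower error type) rather than `requests.exceptions.RequestException`.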