ford442 committed
Commit a5a3ff6 · verified · 1 Parent(s): 2dbdb2a

Update app.py

Files changed (1):
  1. app.py +13 -7
app.py CHANGED
@@ -2,11 +2,13 @@ import spaces
 import torch
 import gradio as gr
 from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoModel
-import soundfile as sf
+#import soundfile as sf  # Removed: Not directly used for outputting audio to Gradio
 import numpy as np
 from espnet2.bin.tts_inference import Text2Speech
-import yaml  # Import yaml for config loading
-import os
+import yaml  # Import yaml for config loading (though not used in the current code, kept for potential future use)
+import os  # Kept for potential future use (e.g., if loading config from files)
+import requests  # Corrected: Import the 'requests' library
+
 
 # Load Whisper model
 ASR_MODEL_NAME = "openai/whisper-medium.en"
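
Note on the imports above: the `tts` object used in the next hunk is an espnet2 `Text2Speech` instance, which is callable and exposes its sampling rate as `tts.fs`. A minimal construction sketch, assuming the espnet_model_zoo package is installed; the model tag is illustrative, not taken from this repo:

from espnet2.bin.tts_inference import Text2Speech

# Hypothetical model tag, for illustration only; app.py may load a different model or config.
tts = Text2Speech.from_pretrained("kan-bayashi/ljspeech_vits")
print(tts.fs)  # sampling rate of the synthesized waveform, used as SAMPLE_RATE below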
@@ -52,15 +54,19 @@ def process_audio(microphone, state, task="transcribe"):
     vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=192)
     vicuna_response = vicuna_tokenizer.decode(vicuna_output[0], skip_special_tokens=True)  # Access the first sequence [0]
     vicuna_response = vicuna_response.replace(prompt, "").strip()
-    updated_state = state + "\n" + vicuna_response
+    updated_state = state + "\nUser: " + text + "\n" + "Tutor: " + vicuna_response  # Include user input in state
     try:
         with torch.no_grad():
-            wav, sr = tts([vicuna_response])[0]
+            # The espnet TTS model outputs a dictionary
+            output = tts(vicuna_response)
+            wav = output["wav"]
+            sr = tts.fs  # Get the sampling rate from the tts object
             audio_arr = wav.cpu().numpy()
+
             SAMPLE_RATE = sr
             audio_arr = audio_arr / np.abs(audio_arr).max()  # Normalize to -1 to 1
             audio_output = (SAMPLE_RATE, audio_arr)
-            #sf.write('generated_audio.wav', audio_arr, SAMPLE_RATE)
+            #sf.write('generated_audio.wav', audio_arr, SAMPLE_RATE)  # Removed writing to file
     except requests.exceptions.RequestException as e:
         print(f"Error in Hugging Face API request: {e}")
         audio_output = None
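
The substantive fix in this hunk: recent ESPnet `Text2Speech` objects return a dict whose waveform sits under the "wav" key, rather than the (wav, sr) tuple the old line tried to unpack, and the sampling rate comes from the `tts.fs` attribute. A self-contained sketch of the same output path, assuming a `tts` object as constructed above; the zero-peak guard is an illustrative addition, not part of this commit:

import numpy as np
import torch

def synthesize_for_gradio(tts, text):
    # Run ESPnet TTS and return the (sample_rate, ndarray) tuple that gr.Audio accepts.
    with torch.no_grad():
        output = tts(text)   # dict; the waveform tensor is under the "wav" key
        wav = output["wav"]
    audio_arr = wav.cpu().numpy()
    peak = np.abs(audio_arr).max()
    if peak > 0:             # illustrative guard against division by zero on silent output
        audio_arr = audio_arr / peak
    return tts.fs, audio_arr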
@@ -79,7 +85,7 @@ with gr.Blocks(title="Whisper, Vicuna, & TTS Demo") as demo:  # Updated title
     transcription_state = gr.State(value="")
     mic_input.change(
         fn=process_audio,  # Call the combined function
-        inputs=[mic_input, transcription_state],
+        inputs=[mic_input, transcription_state, gr.Radio(["transcribe", "translate"], label="Task", value="transcribe")],
         outputs=[transcription_output, transcription_state, audio_output]
     )
 
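
On the new inputs= line: instantiating the gr.Radio inline inside the .change() call creates the component at that point in the `with gr.Blocks()` context, but the more conventional Gradio pattern is to define the component once and reference it by name, so other code can read or update it. A sketch of that layout, reusing the handler and component names from this diff; the gr.Audio and gr.Textbox constructor arguments are assumptions (Gradio 4.x style), since those lines are not shown in this commit:

import gradio as gr

with gr.Blocks(title="Whisper, Vicuna, & TTS Demo") as demo:
    mic_input = gr.Audio(sources=["microphone"], type="filepath", label="Speak here")
    task_choice = gr.Radio(["transcribe", "translate"], label="Task", value="transcribe")
    transcription_output = gr.Textbox(label="Transcription")
    audio_output = gr.Audio(label="Tutor response")
    transcription_state = gr.State(value="")

    mic_input.change(
        fn=process_audio,  # the combined handler defined earlier in app.py
        inputs=[mic_input, transcription_state, task_choice],
        outputs=[transcription_output, transcription_state, audio_output],
    )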
 
 