ford442 committed
Commit 9ddc7f6 · verified · 1 Parent(s): 69cfc54

Update app.py

Files changed (1): app.py (+24 -11)
app.py CHANGED
@@ -47,27 +47,40 @@ _preload_and_load_models()
 tts = Text2Speech.from_pretrained("espnet/kan-bayashi_ljspeech_vits",device='cuda')
 
 @spaces.GPU(required=True)
-def process_audio(microphone, state, task="transcribe"):
+def process_audio(microphone, state, answer_mode):
     if microphone is None:
         return state, state, None
-    asr_pipe.model.config.forced_decoder_ids = [
-        [2, transcribe_token_id if task == "transcribe" else translate_token_id]
-    ]
+    asr_pipe.model.config.forced_decoder_ids = [[2, transcribe_token_id ]]
     text = asr_pipe(microphone)["text"]
     system_prompt = """You are a friendly and enthusiastic tutor for young children (ages 6-9).
     You answer questions clearly and simply, using age-appropriate language.
     You are also a little bit silly and like to make jokes."""
     prompt = f"{system_prompt}\nUser: {text}"
-    with torch.no_grad():
-        vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to('cuda')
+    #with torch.no_grad():
+    vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to('cuda')
+    if answer_mode == 'slow':
         vicuna_output = vicuna_model.generate(
             **vicuna_input,
-            max_length = 96,
+            max_length = 512,
+            min_new_tokens = 256,
+            do_sample = True
+        )
+    if answer_mode == 'medium':
+        vicuna_output = vicuna_model.generate(
+            **vicuna_input,
+            max_length = 128,
             min_new_tokens = 64,
             do_sample = True
         )
-    vicuna_response = vicuna_tokenizer.decode(vicuna_output[0], skip_special_tokens=True)
-    vicuna_response = vicuna_response.replace(prompt, "").strip()
+    if answer_mode == 'fast':
+        vicuna_output = vicuna_model.generate(
+            **vicuna_input,
+            max_length = 42,
+            min_new_tokens = 16,
+            do_sample = True
+        )
+    vicuna_response = vicuna_tokenizer.decode(vicuna_output[0], skip_special_tokens=True)
+    vicuna_response = vicuna_response.replace(prompt, "").strip()
     updated_state = state + "\nUser: " + text + "\n" + "Tutor: " + vicuna_response
     try:
         #with torch.no_grad():
@@ -94,11 +107,11 @@ with gr.Blocks(title="Whisper, Vicuna, & TTS Demo") as demo: # Updated title
     mic_input = gr.Audio(sources="microphone", type="filepath", label="Speak Here")
     transcription_output = gr.Textbox(lines=5, label="Transcription and Vicuna Response")
     audio_output = gr.Audio(label="Synthesized Speech", type="numpy", autoplay=True)
-    audio_output = gr.Radio(["transcribe", "translate"]
+    answer_mode = gr.Radio(["fast", "medium", "slow"]
     transcription_state = gr.State(value="")
     mic_input.change(
         fn=process_audio,
-        inputs=[mic_input, transcription_state, , label="Task", value="transcribe")],
+        inputs=[mic_input, transcription_state, answer_mode)],
         outputs=[transcription_output, transcription_state, audio_output]
     )
 
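A note on the ASR side of this change: the handler now hard-codes forced_decoder_ids to the transcribe token on every call. That works, but forced_decoder_ids is deprecated in recent transformers releases in favor of per-call generation arguments. A minimal sketch of that alternative, assuming asr_pipe is a transformers automatic-speech-recognition pipeline wrapping a Whisper checkpoint; the checkpoint name below is a placeholder, not necessarily the one this Space loads:

    from transformers import pipeline

    # Placeholder checkpoint; app.py builds its own asr_pipe elsewhere.
    asr_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-small")

    def transcribe(audio_path):
        # Pin the task per call instead of mutating model.config.forced_decoder_ids.
        result = asr_pipe(audio_path, generate_kwargs={"task": "transcribe", "language": "english"})
        return result["text"]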
 
 
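On the generation side, the three generate() branches differ only in their length caps, and max_length counts the prompt tokens as well as the reply, so the 'fast' cap of 42 is smaller than the multi-line system prompt alone and can leave no room to answer. A sketch of one way to collapse the branches, assuming the same vicuna_model and vicuna_tokenizer as in app.py; the per-mode numbers are illustrative, and max_new_tokens is used so each cap applies only to newly generated tokens:

    import torch

    # Illustrative per-mode caps, not the commit's values. max_new_tokens bounds
    # only the generated tokens; max_length would also count the prompt.
    GENERATION_MODES = {
        "slow":   {"max_new_tokens": 256, "min_new_tokens": 128},
        "medium": {"max_new_tokens": 96,  "min_new_tokens": 64},
        "fast":   {"max_new_tokens": 32,  "min_new_tokens": 16},
    }

    def generate_reply(prompt, answer_mode):
        # Fall back to "medium" so an unexpected mode cannot leave vicuna_output
        # unbound, which the three independent `if` blocks would allow.
        caps = GENERATION_MODES.get(answer_mode, GENERATION_MODES["medium"])
        vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to("cuda")
        with torch.no_grad():
            vicuna_output = vicuna_model.generate(**vicuna_input, do_sample=True, **caps)
        vicuna_response = vicuna_tokenizer.decode(vicuna_output[0], skip_special_tokens=True)
        # Strip the echoed prompt, as the committed code does.
        return vicuna_response.replace(prompt, "").strip()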
 
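On the UI side, the added gr.Radio(...) line is cut off before its closing parenthesis, and the new inputs= list still carries a stray ) from the line it replaced, so the wiring will not parse as committed. A sketch of the intended hookup with both calls closed, assuming the component names above; the radio's label and default value are guesses:

    import gradio as gr

    with gr.Blocks(title="Whisper, Vicuna, & TTS Demo") as demo:
        mic_input = gr.Audio(sources="microphone", type="filepath", label="Speak Here")
        transcription_output = gr.Textbox(lines=5, label="Transcription and Vicuna Response")
        audio_output = gr.Audio(label="Synthesized Speech", type="numpy", autoplay=True)
        # Label and default value are assumed; the committed line is truncated.
        answer_mode = gr.Radio(["fast", "medium", "slow"], label="Answer Mode", value="fast")
        transcription_state = gr.State(value="")
        mic_input.change(
            fn=process_audio,  # defined earlier in app.py
            inputs=[mic_input, transcription_state, answer_mode],
            outputs=[transcription_output, transcription_state, audio_output],
        )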