ford442 committed
Commit 69cfc54 · verified · 1 Parent(s): 590e946

Update app.py

Files changed (1): app.py (+15, -9)
app.py CHANGED
@@ -35,8 +35,8 @@ def _preload_and_load_models():
     #VICUNA_MODEL_NAME = "EleutherAI/gpt-neo-2.7B" # Or another model
     #VICUNA_MODEL_NAME = "lmsys/vicuna-13b-v1.5" # Or another model
     VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5" # Or another model
-    vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
-    vicuna_model = AutoModelForCausalLM.from_pretrained(
+    vicuna_tokenizer = LlamaTokenizer.from_pretrained(VICUNA_MODEL_NAME)
+    vicuna_model = LlamaForCausalLM.from_pretrained(
         VICUNA_MODEL_NAME,
         torch_dtype=torch.float16,
         # device_map="auto", # or .to('cuda')
@@ -60,16 +60,21 @@ def process_audio(microphone, state, task="transcribe"):
     prompt = f"{system_prompt}\nUser: {text}"
     with torch.no_grad():
         vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to('cuda')
-        vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=128)
+        vicuna_output = vicuna_model.generate(
+            **vicuna_input,
+            max_length=96,
+            min_new_tokens=64,
+            do_sample=True
+        )
     vicuna_response = vicuna_tokenizer.decode(vicuna_output[0], skip_special_tokens=True)
     vicuna_response = vicuna_response.replace(prompt, "").strip()
     updated_state = state + "\nUser: " + text + "\n" + "Tutor: " + vicuna_response
     try:
-        with torch.no_grad():
-            output = tts(vicuna_response)
-            wav = output["wav"]
-            sr = tts.fs
-            audio_arr = wav.cpu().numpy()
+        #with torch.no_grad():
+        output = tts(vicuna_response)
+        wav = output["wav"]
+        sr = tts.fs
+        audio_arr = wav.cpu().numpy()
         SAMPLE_RATE = sr
         audio_arr = audio_arr / np.abs(audio_arr).max()
         audio_output = (SAMPLE_RATE, audio_arr)
@@ -89,10 +94,11 @@ with gr.Blocks(title="Whisper, Vicuna, & TTS Demo") as demo: # Updated title
     mic_input = gr.Audio(sources="microphone", type="filepath", label="Speak Here")
     transcription_output = gr.Textbox(lines=5, label="Transcription and Vicuna Response")
     audio_output = gr.Audio(label="Synthesized Speech", type="numpy", autoplay=True)
+    audio_output = gr.Radio(["transcribe", "translate"], label="Task", value="transcribe")
     transcription_state = gr.State(value="")
     mic_input.change(
         fn=process_audio,
-        inputs=[mic_input, transcription_state, gr.Radio(["transcribe", "translate"], label="Task", value="transcribe")],
+        inputs=[mic_input, transcription_state, audio_output],
         outputs=[transcription_output, transcription_state, audio_output]
     )
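
A note on the new generate() call: in transformers, max_length caps the total sequence (prompt plus generated tokens), while min_new_tokens counts only generated tokens, so any prompt longer than 32 tokens turns max_length=96 and min_new_tokens=64 into conflicting constraints. A minimal sketch of an alternative, assuming the intent was to bound the reply itself rather than the whole sequence:

    # Sketch only: swaps max_length for max_new_tokens so the budget applies to
    # the generated reply alone, regardless of prompt length. The 96/64 values
    # mirror the commit and are not tuned here.
    vicuna_output = vicuna_model.generate(
        **vicuna_input,
        max_new_tokens=96,
        min_new_tokens=64,
        do_sample=True,
    )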
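
As committed, the Blocks hunk reassigns audio_output from the gr.Audio player to the new gr.Radio, so the outputs list ends up pointing at the Radio and the synthesized speech has no component to land in. A minimal sketch of the likely intent, with task_input as a hypothetical name for the Radio (the commit itself does not use it):

    # Sketch only: "task_input" is a name chosen here so the Radio no longer
    # shadows the gr.Audio component that the outputs list relies on.
    task_input = gr.Radio(["transcribe", "translate"], label="Task", value="transcribe")
    transcription_state = gr.State(value="")
    mic_input.change(
        fn=process_audio,
        inputs=[mic_input, transcription_state, task_input],
        outputs=[transcription_output, transcription_state, audio_output],  # audio_output stays the gr.Audio
    )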