ford442 committed on
Commit a736521 · verified · 1 Parent(s): f055d9c

Update app.py

Files changed (1)
  1. app.py +4 -10
app.py CHANGED
@@ -7,7 +7,7 @@ import numpy as np
 import IPython.display as ipd
 import os
 
-ASR_MODEL_NAME = "openai/whisper-large-v2"
+ASR_MODEL_NAME = "openai/whisper-medium.en"
 asr_pipe = pipeline(
     task="automatic-speech-recognition",
     model=ASR_MODEL_NAME,
@@ -19,7 +19,7 @@ all_special_ids = asr_pipe.tokenizer.all_special_ids
 transcribe_token_id = all_special_ids[-5]
 translate_token_id = all_special_ids[-6]
 
-TTS_MODEL_NAME = "suno/bark-small"
+TTS_MODEL_NAME = "suno/bark"
 tts_processor = AutoProcessor.from_pretrained(TTS_MODEL_NAME)
 tts_model = AutoModel.from_pretrained(TTS_MODEL_NAME).to('cuda')
 
@@ -27,7 +27,7 @@ VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5"
 vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
 vicuna_model = AutoModelForCausalLM.from_pretrained(
     VICUNA_MODEL_NAME,
-    torch_dtype=torch.float16,
+    torch_dtype=torch.bfloat16,
     device_map="auto",
 )
 
@@ -35,7 +35,6 @@ vicuna_model = AutoModelForCausalLM.from_pretrained(
 def process_audio(microphone, state, task="transcribe"):
     if microphone is None:
         return state, state, None
-
     asr_pipe.model.config.forced_decoder_ids = [
         [2, transcribe_token_id if task == "transcribe" else translate_token_id]
     ]
@@ -44,14 +43,12 @@ def process_audio(microphone, state, task="transcribe"):
     You answer questions clearly and simply, using age-appropriate language.
     You are also a little bit silly and like to make jokes."""
     prompt = f"{system_prompt}\nUser: {text}"
-
     with torch.no_grad():
         vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to('cuda')
-        vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=128)
+        vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=256)
     vicuna_response = vicuna_tokenizer.decode(vicuna_output[0], skip_special_tokens=True)
     vicuna_response = vicuna_response.replace(prompt, "").strip()
     updated_state = state + "\n" + vicuna_response
-
     try:
         with torch.no_grad():
             inputs = tts_processor(vicuna_response, return_tensors="pt").to('cuda')
@@ -61,11 +58,8 @@ def process_audio(microphone, state, task="transcribe"):
     except Exception as e:
         print(f"Error in speech synthesis: {e}")
         audio_output = None
-
     return updated_state, updated_state, audio_output
 
-
-
 with gr.Blocks(title="Whisper, Vicuna, & Bark Demo") as demo:
     gr.Markdown("# Speech-to-Text-to-Speech Demo with Vicuna and Bark")
     gr.Markdown("Speak into your microphone, get a transcription, Vicuna will process it, and then you'll hear the result!")