ford442 commited on
Commit
c249a04
·
verified ·
1 Parent(s): 6bf9fab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -26
app.py CHANGED
@@ -7,41 +7,31 @@ import numpy as np
7
  import IPython.display as ipd
8
  import os
9
 
10
- # --- Whisper (ASR) Setup ---
11
  ASR_MODEL_NAME = "openai/whisper-large-v2"
12
- asr_device = "cuda" if torch.cuda.is_available() else "cpu"
13
  asr_pipe = pipeline(
14
  task="automatic-speech-recognition",
15
  model=ASR_MODEL_NAME,
16
  chunk_length_s=30,
17
- device=asr_device,
18
  )
 
19
  all_special_ids = asr_pipe.tokenizer.all_special_ids
20
  transcribe_token_id = all_special_ids[-5]
21
  translate_token_id = all_special_ids[-6]
22
 
23
- # --- Bark (TTS) Setup ---
24
  TTS_MODEL_NAME = "suno/bark-small"
25
- tts_device = "cuda" if torch.cuda.is_available() else "cpu"
26
-
27
- # Load the Bark model and processor
28
  tts_processor = AutoProcessor.from_pretrained(TTS_MODEL_NAME)
29
- tts_model = AutoModel.from_pretrained(TTS_MODEL_NAME).to(tts_device)
30
-
31
 
32
- # --- Vicuna (LLM) Setup ---
33
  VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5"
34
- vicuna_device = "cuda" if torch.cuda.is_available() else "cpu"
35
  vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
36
  vicuna_model = AutoModelForCausalLM.from_pretrained(
37
  VICUNA_MODEL_NAME,
38
- # load_in_8bit=True, # Remove 8-bit quantization (no bitsandbytes)
39
  torch_dtype=torch.float16, # Use float16 for efficiency (if GPU supports it)
40
  device_map="auto", # Let transformers handle device placement
41
- )
42
-
43
 
44
- # --- ASR Function ---
45
  def transcribe_audio(microphone, state, task="transcribe"):
46
  if microphone is None:
47
  return state, state
@@ -49,48 +39,37 @@ def transcribe_audio(microphone, state, task="transcribe"):
49
  [2, transcribe_token_id if task == "transcribe" else translate_token_id]
50
  ]
51
  text = asr_pipe(microphone)["text"]
52
-
53
- # --- VICUNA INTEGRATION ---
54
  system_prompt = """You are a friendly and enthusiastic tutor for young children (ages 6-9).
55
  You answer questions clearly and simply, using age-appropriate language.
56
  You are also a little bit silly and like to make jokes."""
57
-
58
  prompt = f"{system_prompt}\nUser: {text}"
59
-
60
  with torch.no_grad():
61
  vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to(vicuna_device)
62
  vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=128)
63
  vicuna_response = vicuna_tokenizer.decode(vicuna_output[0], skip_special_tokens=True)
64
  vicuna_response = vicuna_response.replace(prompt, "").strip()
65
-
66
  updated_state = state + "\n" + vicuna_response
67
  return updated_state, updated_state
68
 
69
- # --- TTS Function (Using Bark) ---
70
  def synthesize_speech(text):
71
  try:
72
  with torch.no_grad():
73
  inputs = tts_processor(text, return_tensors="pt").to(tts_device)
74
  output = tts_model.generate(**inputs, do_sample=True) #Bark generate
75
-
76
  waveform_np = output[0].cpu().numpy()
77
  return (tts_model.generation_config.sample_rate, waveform_np) #Bark sample rate
78
-
79
  except Exception as e:
80
  print(e)
81
  return (None, None)
82
 
83
- # --- Gradio Interface ---
84
  with gr.Blocks(title="Whisper, Vicuna, & Bark Demo") as demo:
85
  gr.Markdown("# Speech-to-Text-to-Speech Demo with Vicuna and Bark")
86
  gr.Markdown("Speak into your microphone, get a transcription, Vicuna will process it, and then you'll hear the result!")
87
-
88
  with gr.Tab("Transcribe & Synthesize"):
89
  mic_input = gr.Audio(sources="microphone", type="filepath", label="Speak Here")
90
  transcription_output = gr.Textbox(lines=5, label="Transcription and Vicuna Response")
91
  audio_output = gr.Audio(label="Synthesized Speech", type="numpy")
92
  transcription_state = gr.State(value="")
93
-
94
  mic_input.change(
95
  fn=transcribe_audio,
96
  inputs=[mic_input, transcription_state],
 
7
  import IPython.display as ipd
8
  import os
9
 
 
10
  ASR_MODEL_NAME = "openai/whisper-large-v2"
11
+
12
  asr_pipe = pipeline(
13
  task="automatic-speech-recognition",
14
  model=ASR_MODEL_NAME,
15
  chunk_length_s=30,
16
+ device='cuda',
17
  )
18
+
19
  all_special_ids = asr_pipe.tokenizer.all_special_ids
20
  transcribe_token_id = all_special_ids[-5]
21
  translate_token_id = all_special_ids[-6]
22
 
 
23
  TTS_MODEL_NAME = "suno/bark-small"
 
 
 
24
  tts_processor = AutoProcessor.from_pretrained(TTS_MODEL_NAME)
25
+ tts_model = AutoModel.from_pretrained(TTS_MODEL_NAME).to('cuda')
 
26
 
 
27
  VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5"
 
28
  vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
29
  vicuna_model = AutoModelForCausalLM.from_pretrained(
30
  VICUNA_MODEL_NAME,
 
31
  torch_dtype=torch.float16, # Use float16 for efficiency (if GPU supports it)
32
  device_map="auto", # Let transformers handle device placement
33
+ ).to('cuda')
 
34
 
 
35
  def transcribe_audio(microphone, state, task="transcribe"):
36
  if microphone is None:
37
  return state, state
 
39
  [2, transcribe_token_id if task == "transcribe" else translate_token_id]
40
  ]
41
  text = asr_pipe(microphone)["text"]
 
 
42
  system_prompt = """You are a friendly and enthusiastic tutor for young children (ages 6-9).
43
  You answer questions clearly and simply, using age-appropriate language.
44
  You are also a little bit silly and like to make jokes."""
 
45
  prompt = f"{system_prompt}\nUser: {text}"
 
46
  with torch.no_grad():
47
  vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to(vicuna_device)
48
  vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=128)
49
  vicuna_response = vicuna_tokenizer.decode(vicuna_output[0], skip_special_tokens=True)
50
  vicuna_response = vicuna_response.replace(prompt, "").strip()
 
51
  updated_state = state + "\n" + vicuna_response
52
  return updated_state, updated_state
53
 
 
54
  def synthesize_speech(text):
55
  try:
56
  with torch.no_grad():
57
  inputs = tts_processor(text, return_tensors="pt").to(tts_device)
58
  output = tts_model.generate(**inputs, do_sample=True) #Bark generate
 
59
  waveform_np = output[0].cpu().numpy()
60
  return (tts_model.generation_config.sample_rate, waveform_np) #Bark sample rate
 
61
  except Exception as e:
62
  print(e)
63
  return (None, None)
64
 
 
65
  with gr.Blocks(title="Whisper, Vicuna, & Bark Demo") as demo:
66
  gr.Markdown("# Speech-to-Text-to-Speech Demo with Vicuna and Bark")
67
  gr.Markdown("Speak into your microphone, get a transcription, Vicuna will process it, and then you'll hear the result!")
 
68
  with gr.Tab("Transcribe & Synthesize"):
69
  mic_input = gr.Audio(sources="microphone", type="filepath", label="Speak Here")
70
  transcription_output = gr.Textbox(lines=5, label="Transcription and Vicuna Response")
71
  audio_output = gr.Audio(label="Synthesized Speech", type="numpy")
72
  transcription_state = gr.State(value="")
 
73
  mic_input.change(
74
  fn=transcribe_audio,
75
  inputs=[mic_input, transcription_state],