ford442 committed
Commit 9f8fb3c · verified · 1 Parent(s): cc30e2e

Update app.py

Files changed (1): app.py +40 -39
app.py CHANGED
@@ -1,4 +1,4 @@
-#import spaces
+import spaces
 import torch
 import gradio as gr
 from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoModel, AutoProcessor
@@ -7,60 +7,65 @@ import numpy as np
 import IPython.display as ipd
 import os
 
-ASR_MODEL_NAME = "openai/whisper-large-v2"
-
-asr_pipe = pipeline(
-    task="automatic-speech-recognition",
-    model=ASR_MODEL_NAME,
-    chunk_length_s=30,
-    device='cuda',
-)
-
-all_special_ids = asr_pipe.tokenizer.all_special_ids
-transcribe_token_id = all_special_ids[-5]
-translate_token_id = all_special_ids[-6]
-
-TTS_MODEL_NAME = "suno/bark-small"
-tts_processor = AutoProcessor.from_pretrained(TTS_MODEL_NAME)
-tts_model = AutoModel.from_pretrained(TTS_MODEL_NAME).to('cuda')
-
-VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5"
-vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
-vicuna_model = AutoModelForCausalLM.from_pretrained(
-    VICUNA_MODEL_NAME,
-    torch_dtype=torch.float16,  # Use float16 for efficiency (if GPU supports it)
-    device_map="auto",  # Let transformers handle device placement
-) #.to('cuda')
-
-def transcribe_audio(microphone, state, task="transcribe"):
+# Define a decorator for GPU usage in Spaces
+@spaces.GPU(required=True)  # This decorator ensures GPU availability
+def process_audio(microphone, state, task="transcribe"):
+    ASR_MODEL_NAME = "openai/whisper-large-v2"
+
+    asr_pipe = pipeline(
+        task="automatic-speech-recognition",
+        model=ASR_MODEL_NAME,
+        chunk_length_s=30,
+        device='cuda',  # Explicitly set device to 'cuda' within the function
+    )
+
+    all_special_ids = asr_pipe.tokenizer.all_special_ids
+    transcribe_token_id = all_special_ids[-5]
+    translate_token_id = all_special_ids[-6]
+
+    TTS_MODEL_NAME = "suno/bark-small"
+    tts_processor = AutoProcessor.from_pretrained(TTS_MODEL_NAME)
+    tts_model = AutoModel.from_pretrained(TTS_MODEL_NAME).to('cuda')  # Explicitly set device to 'cuda' within the function
+
+    VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5"
+    vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
+    vicuna_model = AutoModelForCausalLM.from_pretrained(
+        VICUNA_MODEL_NAME,
+        torch_dtype=torch.float16,  # Use float16 for efficiency (if GPU supports it)
+        device_map="auto",  # Let transformers handle device placement
+    ) #.to('cuda')
+
+
     if microphone is None:
-        return state, state
+        return state, state, None  # Return None for audio if no microphone input
+
     asr_pipe.model.config.forced_decoder_ids = [
         [2, transcribe_token_id if task == "transcribe" else translate_token_id]
     ]
     text = asr_pipe(microphone)["text"]
     system_prompt = """You are a friendly and enthusiastic tutor for young children (ages 6-9).
-You answer questions clearly and simply, using age-appropriate language.
-You are also a little bit silly and like to make jokes."""
+    You answer questions clearly and simply, using age-appropriate language.
+    You are also a little bit silly and like to make jokes."""
     prompt = f"{system_prompt}\nUser: {text}"
     with torch.no_grad():
-        vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to('cuda')
+        vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to('cuda')  # Explicitly set device to 'cuda' within the function
     vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=128)
     vicuna_response = vicuna_tokenizer.decode(vicuna_output[0], skip_special_tokens=True)
     vicuna_response = vicuna_response.replace(prompt, "").strip()
     updated_state = state + "\n" + vicuna_response
-    return updated_state, updated_state
 
-def synthesize_speech(text):
     try:
         with torch.no_grad():
-            inputs = tts_processor(text, return_tensors="pt").to('cuda')
-            output = tts_model.generate(**inputs, do_sample=True)  # Bark generate
+            inputs = tts_processor(vicuna_response, return_tensors="pt").to('cuda')  # Explicitly set device to 'cuda' within the function
+            output = tts_model.generate(**inputs, do_sample=True)  # Bark generate
         waveform_np = output[0].cpu().numpy()
-        return (tts_model.generation_config.sample_rate, waveform_np)  # Bark sample rate
+        audio_output = (tts_model.generation_config.sample_rate, waveform_np)  # Bark sample rate
     except Exception as e:
-        print(e)
-        return (None, None)
+        print(f"Error in speech synthesis: {e}")
+        audio_output = None
+
+    return updated_state, updated_state, audio_output
+
 
 with gr.Blocks(title="Whisper, Vicuna, & Bark Demo") as demo:
     gr.Markdown("# Speech-to-Text-to-Speech Demo with Vicuna and Bark")
@@ -71,13 +76,9 @@ with gr.Blocks(title="Whisper, Vicuna, & Bark Demo") as demo:
     audio_output = gr.Audio(label="Synthesized Speech", type="numpy")
     transcription_state = gr.State(value="")
     mic_input.change(
-        fn=transcribe_audio,
+        fn=process_audio,  # Call the combined function
         inputs=[mic_input, transcription_state],
-        outputs=[transcription_output, transcription_state]
-    ).then(
-        fn=synthesize_speech,
-        inputs=transcription_output,
-        outputs=audio_output
+        outputs=[transcription_output, transcription_state, audio_output]
     )
 
 demo.launch(share=False)
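A note on the new layout: the commit moves every from_pretrained call inside the @spaces.GPU-decorated handler, so Whisper, Bark, and Vicuna are all rebuilt on each microphone event. On ZeroGPU Spaces the documented pattern is the opposite: construct models once in the global scope and keep only inference inside the decorated function. A minimal sketch of that layout, assuming the Hugging Face spaces package and the bare @spaces.GPU form (whether the required=True argument used in this commit is accepted should be checked against the installed spaces version):

import spaces
import torch
from transformers import pipeline

# Loaded once at startup; ZeroGPU permits CUDA placement in the global scope.
asr_pipe = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-large-v2",
    chunk_length_s=30,
    device="cuda",
)

@spaces.GPU  # a GPU is attached only while this function runs
def transcribe(microphone):
    # Inference only; no per-event model construction.
    return asr_pipe(microphone)["text"]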
 
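Selecting the task by indexing all_special_ids[-5] and [-6] depends on the exact ordering of Whisper's special tokens and can silently pick the wrong IDs across tokenizer versions. A less brittle sketch, using documented transformers APIs (worth verifying against the pinned release); variable names follow the commit:

# Option 1: pass the task straight to generate via the pipeline call.
text = asr_pipe(microphone, generate_kwargs={"task": task})["text"]

# Option 2: let the tokenizer build forced_decoder_ids for the task.
asr_pipe.model.config.forced_decoder_ids = (
    asr_pipe.tokenizer.get_decoder_prompt_ids(task=task)
)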
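Stripping the echoed prompt with vicuna_response.replace(prompt, "") misfires whenever generation reflows whitespace in the echoed text. Slicing the prompt tokens off before decoding avoids string matching entirely; a sketch reusing the commit's variable names:

# Decode only the newly generated tokens: vicuna_output[0] begins with the
# prompt tokens, so skip the first input_ids-length positions.
prompt_len = vicuna_input["input_ids"].shape[1]
vicuna_response = vicuna_tokenizer.decode(
    vicuna_output[0][prompt_len:], skip_special_tokens=True
).strip()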
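On the audio side, gr.Audio(type="numpy") expects a (sample_rate, ndarray) tuple, which is what the try block now builds; returning None for that slot (the no-microphone and error paths) simply leaves the component empty. Bark's generate returns a batch of waveforms, so squeezing to a 1-D float32 array is a safe normalization; a sketch with the commit's names (the squeeze/astype step is an addition, not in the commit):

import numpy as np

# Bark output has shape (batch, samples); Gradio wants (sr, 1-D float array).
waveform_np = output[0].cpu().numpy().squeeze().astype(np.float32)
audio_output = (tts_model.generation_config.sample_rate, waveform_np)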