import spaces
import torch
import gradio as gr
from transformers import pipeline, LlamaTokenizer, LlamaForCausalLM
import numpy as np
from espnet2.bin.tts_inference import Text2Speech
import nltk
import scipy.io.wavfile
from flask import Flask, request, jsonify

app = Flask(__name__)  # Flask app instance for the optional HTTP endpoint below

# Fetch the NLTK resources used by the English g2p frontend, if missing.
try:
    nltk.data.find('taggers/averaged_perceptron_tagger_eng')
except LookupError:
    nltk.download('averaged_perceptron_tagger_eng')
try:
    nltk.data.find('corpora/cmudict')
except LookupError:
    nltk.download('cmudict')

ASR_MODEL_NAME = "openai/whisper-medium.en"
asr_pipe = pipeline(
    task="automatic-speech-recognition",
    model=ASR_MODEL_NAME,
    chunk_length_s=30,
    device='cuda' if torch.cuda.is_available() else 'cpu', # Use GPU if available
)

# Whisper exposes its task tokens among the tokenizer's special ids; the
# transcribe/translate ids sit at fixed offsets from the end of that list.
# This indexing is version-sensitive, so pin your transformers version.
all_special_ids = asr_pipe.tokenizer.all_special_ids
transcribe_token_id = all_special_ids[-5]
translate_token_id = all_special_ids[-6]
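# For multilingual Whisper checkpoints, newer transformers releases let the
# pipeline select the task directly instead of poking forced_decoder_ids,
# e.g. asr_pipe(audio, generate_kwargs={"task": "transcribe"}); the ".en"
# model used here is transcribe-only, so the explicit ids below suffice.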

def _preload_and_load_models():
    global vicuna_tokenizer, vicuna_model
    VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5"  # e.g. "lmsys/vicuna-13b-v1.5" for higher quality
    vicuna_tokenizer = LlamaTokenizer.from_pretrained(VICUNA_MODEL_NAME)
    vicuna_model = LlamaForCausalLM.from_pretrained(
        VICUNA_MODEL_NAME,
        torch_dtype=torch.float16,  # fp16 halves memory; device_map="auto" is an alternative to .to('cuda')
    ).to('cuda')  # explicitly move to CUDA after loading

_preload_and_load_models()

tts = Text2Speech.from_pretrained("espnet/kan-bayashi_ljspeech_vits", device='cuda')
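# tts.fs carries the model's native sampling rate (22.05 kHz for this
# LJSpeech VITS checkpoint) and is reused when packaging audio for Gradio.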

# NOTE: this Flask route is kept from the original wiring, but a plain POST
# would not map its JSON body onto these positional parameters; in practice
# the function is driven by the Gradio events registered below.
@app.route('/api/predict', methods=['POST'])
@spaces.GPU(required=True)
def process_audio(microphone, audio_upload, state, answer_mode):
    audio_source = None
    if microphone:
        audio_source = microphone
        asr_pipe.model.config.forced_decoder_ids = [[2, transcribe_token_id]]
        text = asr_pipe(audio_source)["text"]
    elif audio_upload:
        audio_source = audio_upload
        rate, data = scipy.io.wavfile.read(audio_source)
        # The pipeline needs float audio and its sampling rate: convert integer
        # PCM to float32 in [-1, 1] and mix multi-channel input down to mono.
        if np.issubdtype(data.dtype, np.integer):
            data = data.astype(np.float32) / np.iinfo(data.dtype).max
        if data.ndim > 1:
            data = data.mean(axis=1)
        asr_pipe.model.config.forced_decoder_ids = [[2, transcribe_token_id]]
        text = asr_pipe({"sampling_rate": rate, "raw": data})["text"]
    else:
        return state, state, None  # no audio input provided
    system_prompt = """You are a friendly and enthusiastic tutor for young children (ages 6-9).
        You answer questions clearly and simply, using age-appropriate language.
        You are also a little bit silly and like to make jokes."""
    prompt = f"{system_prompt}\nUser: {text}"
    with torch.no_grad():
        vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to('cuda')
        # Token budgets per answer mode, using max_new_tokens (not max_length)
        # so the prompt length never eats the generation budget; fall back to
        # 'medium' so vicuna_output is always defined.
        mode_budgets = {
            'slow':   dict(max_new_tokens=512, min_new_tokens=256),
            'medium': dict(max_new_tokens=128, min_new_tokens=64),
            'fast':   dict(max_new_tokens=42, min_new_tokens=16),
        }
        budget = mode_budgets.get(answer_mode, mode_budgets['medium'])
        vicuna_output = vicuna_model.generate(**vicuna_input, do_sample=True, **budget)
        # Decode only the newly generated tokens; string-replacing the prompt
        # out of the full decode is unreliable because decoding does not
        # round-trip the input text exactly.
        prompt_len = vicuna_input["input_ids"].shape[1]
        vicuna_response = vicuna_tokenizer.decode(
            vicuna_output[0][prompt_len:], skip_special_tokens=True
        ).strip()
    updated_state = state + "\nUser: " + text + "\n" + "Tutor: " + vicuna_response
    try:
        with torch.no_grad():
            output = tts(vicuna_response)
        audio_arr = output["wav"].cpu().numpy()
        # Peak-normalize to [-1, 1] for the Gradio numpy audio component,
        # guarding against silent (all-zero) output.
        peak = np.abs(audio_arr).max()
        if peak > 0:
            audio_arr = audio_arr / peak
        audio_output = (tts.fs, audio_arr)
    except Exception as e:
        print(f"Error in speech synthesis: {e}")
        audio_output = None
    return updated_state, updated_state, audio_output
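
# A minimal sketch of a JSON endpoint that could drive process_audio over
# HTTP. This route, its field names, and its response shape are assumptions
# for illustration, not part of the original Gradio wiring; it expects a
# server-side audio file path in the request body.
@app.route('/api/predict_json', methods=['POST'])
def predict_json():
    payload = request.get_json(force=True)
    new_state, _, audio = process_audio(
        payload.get('audio_path'),           # used as the "microphone" filepath
        None,                                # no separate upload in this sketch
        payload.get('state', ''),
        payload.get('answer_mode', 'medium'),
    )
    return jsonify({'state': new_state, 'has_audio': audio is not None})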

with gr.Blocks(title="Whisper, Vicuna, & TTS Demo") as demo:
    gr.Markdown("# Speech-to-Text-to-Speech Demo with Vicuna and ESPnet TTS")
    gr.Markdown("Speak into your microphone or upload an audio file; Whisper transcribes it, Vicuna answers, and you'll hear the synthesized result.")
    with gr.Tab("Transcribe & Synthesize"):
        with gr.Row():  # mic and upload side by side
            mic_input = gr.Audio(sources=["microphone"], type="filepath", label="Speak Here")
            audio_upload = gr.Audio(sources=["upload"], type="filepath", label="Or Upload Audio File")
        transcription_output = gr.Textbox(lines=5, label="Transcription and Vicuna Response")
        audio_output = gr.Audio(label="Synthesized Speech", type="numpy", autoplay=True)
        answer_mode = gr.Radio(["fast", "medium", "slow"], value='medium', label="Answer length")
        transcription_state = gr.State(value="")

        mic_input.change(
            fn=process_audio,
            inputs=[mic_input, audio_upload, transcription_state, answer_mode],
            outputs=[transcription_output, transcription_state, audio_output]
        )
        audio_upload.change(
            fn=process_audio,
            inputs=[mic_input, audio_upload, transcription_state, answer_mode],
            outputs=[transcription_output, transcription_state, audio_output]
        )
        
if __name__ == '__main__':
    # app.run() and demo.launch() both block, so the original order meant the
    # Gradio UI never started. Launch Gradio here; run the Flask app in a
    # separate process (e.g. app.run(port=5000)) if the HTTP API is needed.
    demo.launch(share=False)