Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -27,7 +27,7 @@ VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5"
|
|
27 |
vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
|
28 |
vicuna_model = AutoModelForCausalLM.from_pretrained(
|
29 |
VICUNA_MODEL_NAME,
|
30 |
-
torch_dtype=torch.
|
31 |
device_map="auto",
|
32 |
)
|
33 |
|
@@ -45,14 +45,14 @@ def process_audio(microphone, state, task="transcribe"):
|
|
45 |
prompt = f"{system_prompt}\nUser: {text}"
|
46 |
with torch.no_grad():
|
47 |
vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to('cuda')
|
48 |
-
vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=
|
49 |
vicuna_response = vicuna_tokenizer.decode(vicuna_output[0], skip_special_tokens=True)
|
50 |
vicuna_response = vicuna_response.replace(prompt, "").strip()
|
51 |
updated_state = state + "\n" + vicuna_response
|
52 |
try:
|
53 |
with torch.no_grad():
|
54 |
inputs = tts_processor(vicuna_response, return_tensors="pt").to('cuda')
|
55 |
-
output = tts_model.generate(**inputs, do_sample=
|
56 |
waveform_np = output[0].cpu().numpy()
|
57 |
audio_output = (tts_model.generation_config.sample_rate, waveform_np)
|
58 |
except Exception as e:
|
|
|
27 |
vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
|
28 |
vicuna_model = AutoModelForCausalLM.from_pretrained(
|
29 |
VICUNA_MODEL_NAME,
|
30 |
+
torch_dtype=torch.float16,
|
31 |
device_map="auto",
|
32 |
)
|
33 |
|
|
|
45 |
prompt = f"{system_prompt}\nUser: {text}"
|
46 |
with torch.no_grad():
|
47 |
vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to('cuda')
|
48 |
+
vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=192)
|
49 |
vicuna_response = vicuna_tokenizer.decode(vicuna_output[0], skip_special_tokens=True)
|
50 |
vicuna_response = vicuna_response.replace(prompt, "").strip()
|
51 |
updated_state = state + "\n" + vicuna_response
|
52 |
try:
|
53 |
with torch.no_grad():
|
54 |
inputs = tts_processor(vicuna_response, return_tensors="pt").to('cuda')
|
55 |
+
output = tts_model.generate(**inputs, do_sample=False)
|
56 |
waveform_np = output[0].cpu().numpy()
|
57 |
audio_output = (tts_model.generation_config.sample_rate, waveform_np)
|
58 |
except Exception as e:
|