ford442 committed on
Commit
eb936fd
·
verified ·
1 Parent(s): a736521

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -27,7 +27,7 @@ VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5"
27
  vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
28
  vicuna_model = AutoModelForCausalLM.from_pretrained(
29
  VICUNA_MODEL_NAME,
30
- torch_dtype=torch.bfloat16,
31
  device_map="auto",
32
  )
33
 
@@ -45,14 +45,14 @@ def process_audio(microphone, state, task="transcribe"):
45
  prompt = f"{system_prompt}\nUser: {text}"
46
  with torch.no_grad():
47
  vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to('cuda')
48
- vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=256)
49
  vicuna_response = vicuna_tokenizer.decode(vicuna_output[0], skip_special_tokens=True)
50
  vicuna_response = vicuna_response.replace(prompt, "").strip()
51
  updated_state = state + "\n" + vicuna_response
52
  try:
53
  with torch.no_grad():
54
  inputs = tts_processor(vicuna_response, return_tensors="pt").to('cuda')
55
- output = tts_model.generate(**inputs, do_sample=True)
56
  waveform_np = output[0].cpu().numpy()
57
  audio_output = (tts_model.generation_config.sample_rate, waveform_np)
58
  except Exception as e:
 
27
  vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
28
  vicuna_model = AutoModelForCausalLM.from_pretrained(
29
  VICUNA_MODEL_NAME,
30
+ torch_dtype=torch.float16,
31
  device_map="auto",
32
  )
33
 
 
45
  prompt = f"{system_prompt}\nUser: {text}"
46
  with torch.no_grad():
47
  vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to('cuda')
48
+ vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=192)
49
  vicuna_response = vicuna_tokenizer.decode(vicuna_output[0], skip_special_tokens=True)
50
  vicuna_response = vicuna_response.replace(prompt, "").strip()
51
  updated_state = state + "\n" + vicuna_response
52
  try:
53
  with torch.no_grad():
54
  inputs = tts_processor(vicuna_response, return_tensors="pt").to('cuda')
55
+ output = tts_model.generate(**inputs, do_sample=False)
56
  waveform_np = output[0].cpu().numpy()
57
  audio_output = (tts_model.generation_config.sample_rate, waveform_np)
58
  except Exception as e: