3v324v23 commited on
Commit
ddc036e
·
1 Parent(s): ed3e9c8
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -242,7 +242,11 @@ def main():
242
  numberical_conditioning = [float(description["cost_to_pointer"][key]) for key in pointer_words if key in description["cost_to_pointer"]]
243
  else:
244
  numberical_conditioning = []
245
- conditioning = {"symbolic_conditioning": cond_tokens, "numerical_conditioning": torch.tensor(numberical_conditioning,device="cuda").float()}
 
 
 
 
246
  #conditioning = {"symbolic_conditioning": torch.tensor([1,2],device="cuda").long(), "numerical_conditioning": torch.tensor([],device="cuda").float()}
247
 
248
  st.markdown("#### NSR")
 
242
  numberical_conditioning = [float(description["cost_to_pointer"][key]) for key in pointer_words if key in description["cost_to_pointer"]]
243
  else:
244
  numberical_conditioning = []
245
+
246
+ if is_cuda:
247
+ conditioning = {"symbolic_conditioning": cond_tokens, "numerical_conditioning": torch.tensor(numberical_conditioning,device="cuda").float()}
248
+ else:
249
+ conditioning = {"symbolic_conditioning": cond_tokens, "numerical_conditioning": torch.tensor(numberical_conditioning).float()}
250
  #conditioning = {"symbolic_conditioning": torch.tensor([1,2],device="cuda").long(), "numerical_conditioning": torch.tensor([],device="cuda").float()}
251
 
252
  st.markdown("#### NSR")