3v324v23 commited on
Commit
ed3e9c8
·
1 Parent(s): 6d21994
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -234,7 +234,10 @@ def main():
234
 
235
  # Prepare the conditioning
236
  cond_tokens, cond_str_tokens = description2tokens(description, metadata.word2id , cfg)
237
- cond_tokens = torch.tensor(cond_tokens).long().cuda()
 
 
 
238
  if pointer_words is not None:
239
  numberical_conditioning = [float(description["cost_to_pointer"][key]) for key in pointer_words if key in description["cost_to_pointer"]]
240
  else:
 
234
 
235
  # Prepare the conditioning
236
  cond_tokens, cond_str_tokens = description2tokens(description, metadata.word2id , cfg)
237
+ if is_cuda:
238
+ cond_tokens = torch.tensor(cond_tokens).long().cuda()
239
+ else:
240
+ cond_tokens = torch.tensor(cond_tokens).long()
241
  if pointer_words is not None:
242
  numberical_conditioning = [float(description["cost_to_pointer"][key]) for key in pointer_words if key in description["cost_to_pointer"]]
243
  else: