noahlenz commited on
Commit
c939101
Β·
verified Β·
1 Parent(s): b178707

Update binoculars/detector.py

Browse files
Files changed (1) hide show
  1. binoculars/detector.py +6 -6
binoculars/detector.py CHANGED
@@ -99,10 +99,10 @@ class Binoculars(object):
99
  binoculars_scores = binoculars_scores.tolist()
100
  return binoculars_scores[0] if isinstance(input_text, str) else binoculars_scores
101
 
102
- def predict(self, input_text: Union[list[str], str]) -> Union[list[str], str]:
103
  binoculars_scores = np.array(self.compute_score(input_text))
104
- pred = np.where(binoculars_scores < self.threshold,
105
- "Most likely AI-generated",
106
- "Most likely human-generated"
107
- ).tolist()
108
- return pred[0] if isinstance(input_text, str) else pred
 
99
  binoculars_scores = binoculars_scores.tolist()
100
  return binoculars_scores[0] if isinstance(input_text, str) else binoculars_scores
101
 
102
+ def predict(self, input_text: Union[list[str], str]) -> Union[str, list[str]]:
103
  binoculars_scores = np.array(self.compute_score(input_text))
104
+
105
+ if isinstance(input_text, str):
106
+ return "Most likely AI-generated" if binoculars_scores < self.threshold else "Most likely human-generated"
107
+ else:
108
+ return ["Most likely AI-generated" if score < self.threshold else "Most likely human-generated" for score in binoculars_scores]