Spaces:
Runtime error
Runtime error
Update binoculars/detector.py
Browse files
- binoculars/detector.py +16 -25
binoculars/detector.py
CHANGED
@@ -15,9 +15,7 @@ torch.set_grad_enabled(False)
|
|
15 |
BINOCULARS_ACCURACY_THRESHOLD = 0.9015310749276843 # optimized for f1-score
|
16 |
BINOCULARS_FPR_THRESHOLD = 0.8536432310785527 # optimized for low-fpr
|
17 |
|
18 |
-
DEVICE_1 = "cuda:0" if torch.cuda.is_available() else "cpu"
|
19 |
-
DEVICE_2 = "cuda:1" if torch.cuda.device_count() > 1 else DEVICE_1
|
20 |
-
|
21 |
|
22 |
class Binoculars(object):
|
23 |
def __init__(self,
|
@@ -36,20 +34,16 @@ class Binoculars(object):
|
|
36 |
else:
|
37 |
raise ValueError(f"Invalid mode: {mode}")
|
38 |
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
torch_dtype=torch.bfloat16 if use_bfloat16
|
50 |
-
else torch.float32,
|
51 |
-
token=huggingface_config["TOKEN"]
|
52 |
-
)
|
53 |
|
54 |
self.observer_model.eval()
|
55 |
self.performer_model.eval()
|
@@ -76,15 +70,13 @@ class Binoculars(object):
|
|
76 |
padding="longest" if batch_size > 1 else False,
|
77 |
truncation=True,
|
78 |
max_length=self.max_token_observed,
|
79 |
-
return_token_type_ids=False)
|
80 |
return encodings
|
81 |
|
82 |
@torch.inference_mode()
|
83 |
def _get_logits(self, encodings: transformers.BatchEncoding) -> torch.Tensor:
|
84 |
-
observer_logits = self.observer_model(**encodings.to(DEVICE_1)).logits
|
85 |
-
performer_logits = self.performer_model(**encodings.to(DEVICE_2)).logits
|
86 |
-
if DEVICE_1 != "cpu":
|
87 |
-
torch.cuda.synchronize()
|
88 |
return observer_logits, performer_logits
|
89 |
|
90 |
def compute_score(self, input_text: Union[list[str], str]) -> Union[float, list[float]]:
|
@@ -92,8 +84,7 @@ class Binoculars(object):
|
|
92 |
encodings = self._tokenize(batch)
|
93 |
observer_logits, performer_logits = self._get_logits(encodings)
|
94 |
ppl = perplexity(encodings, performer_logits)
|
95 |
-
x_ppl = entropy(observer_logits.to(DEVICE_1), performer_logits.to(DEVICE_1),
|
96 |
-
encodings.to(DEVICE_1), self.tokenizer.pad_token_id)
|
97 |
binoculars_scores = ppl / x_ppl
|
98 |
binoculars_scores = binoculars_scores.tolist()
|
99 |
return binoculars_scores[0] if isinstance(input_text, str) else binoculars_scores
|
@@ -104,4 +95,4 @@ class Binoculars(object):
|
|
104 |
"Most likely AI-generated",
|
105 |
"Most likely human-generated"
|
106 |
).tolist()
|
107 |
-
return pred
|
|
|
15 |
BINOCULARS_ACCURACY_THRESHOLD = 0.9015310749276843 # optimized for f1-score
|
16 |
BINOCULARS_FPR_THRESHOLD = 0.8536432310785527 # optimized for low-fpr
|
17 |
|
18 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
19 |
|
20 |
class Binoculars(object):
|
21 |
def __init__(self,
|
|
|
34 |
else:
|
35 |
raise ValueError(f"Invalid mode: {mode}")
|
36 |
|
37 |
+
# Load models with memory-efficient settings
|
38 |
+
model_kwargs = {
|
39 |
+
"device_map": "auto",
|
40 |
+
"load_in_8bit": True,
|
41 |
+
"trust_remote_code": True,
|
42 |
+
"token": huggingface_config["TOKEN"]
|
43 |
+
}
|
44 |
+
|
45 |
+
self.observer_model = AutoModelForCausalLM.from_pretrained(observer_name_or_path, **model_kwargs)
|
46 |
+
self.performer_model = AutoModelForCausalLM.from_pretrained(performer_name_or_path, **model_kwargs)
|
|
|
|
|
|
|
|
|
47 |
|
48 |
self.observer_model.eval()
|
49 |
self.performer_model.eval()
|
|
|
70 |
padding="longest" if batch_size > 1 else False,
|
71 |
truncation=True,
|
72 |
max_length=self.max_token_observed,
|
73 |
+
return_token_type_ids=False)
|
74 |
return encodings
|
75 |
|
76 |
@torch.inference_mode()
|
77 |
def _get_logits(self, encodings: transformers.BatchEncoding) -> torch.Tensor:
|
78 |
+
observer_logits = self.observer_model(**encodings).logits
|
79 |
+
performer_logits = self.performer_model(**encodings).logits
|
|
|
|
|
80 |
return observer_logits, performer_logits
|
81 |
|
82 |
def compute_score(self, input_text: Union[list[str], str]) -> Union[float, list[float]]:
|
|
|
84 |
encodings = self._tokenize(batch)
|
85 |
observer_logits, performer_logits = self._get_logits(encodings)
|
86 |
ppl = perplexity(encodings, performer_logits)
|
87 |
+
x_ppl = entropy(observer_logits, performer_logits, encodings, self.tokenizer.pad_token_id)
|
|
|
88 |
binoculars_scores = ppl / x_ppl
|
89 |
binoculars_scores = binoculars_scores.tolist()
|
90 |
return binoculars_scores[0] if isinstance(input_text, str) else binoculars_scores
|
|
|
95 |
"Most likely AI-generated",
|
96 |
"Most likely human-generated"
|
97 |
).tolist()
|
98 |
+
return pred[0] if isinstance(input_text, str) else pred
|