Spaces:
Runtime error
Runtime error
Update binoculars/detector.py
Browse files- binoculars/detector.py +11 -1
binoculars/detector.py
CHANGED
@@ -24,6 +24,7 @@ class Binoculars(object):
|
|
24 |
use_bfloat16: bool = True,
|
25 |
max_token_observed: int = 512,
|
26 |
mode: str = "low-fpr",
|
|
|
27 |
) -> None:
|
28 |
assert_tokenizer_consistency(observer_name_or_path, performer_name_or_path)
|
29 |
|
@@ -37,11 +38,20 @@ class Binoculars(object):
|
|
37 |
# Load models with memory-efficient settings
|
38 |
model_kwargs = {
|
39 |
"device_map": "auto",
|
40 |
-
"load_in_8bit": True,
|
41 |
"trust_remote_code": True,
|
42 |
"token": huggingface_config["TOKEN"]
|
43 |
}
|
44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
self.observer_model = AutoModelForCausalLM.from_pretrained(observer_name_or_path, **model_kwargs)
|
46 |
self.performer_model = AutoModelForCausalLM.from_pretrained(performer_name_or_path, **model_kwargs)
|
47 |
|
|
|
24 |
use_bfloat16: bool = True,
|
25 |
max_token_observed: int = 512,
|
26 |
mode: str = "low-fpr",
|
27 |
+
quantize: bool = True
|
28 |
) -> None:
|
29 |
assert_tokenizer_consistency(observer_name_or_path, performer_name_or_path)
|
30 |
|
|
|
38 |
# Load models with memory-efficient settings
|
39 |
model_kwargs = {
|
40 |
"device_map": "auto",
|
|
|
41 |
"trust_remote_code": True,
|
42 |
"token": huggingface_config["TOKEN"]
|
43 |
}
|
44 |
|
45 |
+
if quantize:
|
46 |
+
try:
|
47 |
+
import bitsandbytes as bnb
|
48 |
+
model_kwargs["load_in_8bit"] = True
|
49 |
+
except ImportError:
|
50 |
+
print("bitsandbytes not available. Falling back to full precision.")
|
51 |
+
model_kwargs["torch_dtype"] = torch.bfloat16 if use_bfloat16 else torch.float32
|
52 |
+
else:
|
53 |
+
model_kwargs["torch_dtype"] = torch.bfloat16 if use_bfloat16 else torch.float32
|
54 |
+
|
55 |
self.observer_model = AutoModelForCausalLM.from_pretrained(observer_name_or_path, **model_kwargs)
|
56 |
self.performer_model = AutoModelForCausalLM.from_pretrained(performer_name_or_path, **model_kwargs)
|
57 |
|