noahlenz committed on
Commit
b178707
·
verified ·
1 Parent(s): 09388fe

Update binoculars/detector.py

Browse files
Files changed (1) hide show
  1. binoculars/detector.py +11 -1
binoculars/detector.py CHANGED
@@ -24,6 +24,7 @@ class Binoculars(object):
24
  use_bfloat16: bool = True,
25
  max_token_observed: int = 512,
26
  mode: str = "low-fpr",
 
27
  ) -> None:
28
  assert_tokenizer_consistency(observer_name_or_path, performer_name_or_path)
29
 
@@ -37,11 +38,20 @@ class Binoculars(object):
37
  # Load models with memory-efficient settings
38
  model_kwargs = {
39
  "device_map": "auto",
40
- "load_in_8bit": True,
41
  "trust_remote_code": True,
42
  "token": huggingface_config["TOKEN"]
43
  }
44
 
 
 
 
 
 
 
 
 
 
 
45
  self.observer_model = AutoModelForCausalLM.from_pretrained(observer_name_or_path, **model_kwargs)
46
  self.performer_model = AutoModelForCausalLM.from_pretrained(performer_name_or_path, **model_kwargs)
47
 
 
24
  use_bfloat16: bool = True,
25
  max_token_observed: int = 512,
26
  mode: str = "low-fpr",
27
+ quantize: bool = True
28
  ) -> None:
29
  assert_tokenizer_consistency(observer_name_or_path, performer_name_or_path)
30
 
 
38
  # Load models with memory-efficient settings
39
  model_kwargs = {
40
  "device_map": "auto",
 
41
  "trust_remote_code": True,
42
  "token": huggingface_config["TOKEN"]
43
  }
44
 
45
+ if quantize:
46
+ try:
47
+ import bitsandbytes as bnb
48
+ model_kwargs["load_in_8bit"] = True
49
+ except ImportError:
50
+ print("bitsandbytes not available. Falling back to full precision.")
51
+ model_kwargs["torch_dtype"] = torch.bfloat16 if use_bfloat16 else torch.float32
52
+ else:
53
+ model_kwargs["torch_dtype"] = torch.bfloat16 if use_bfloat16 else torch.float32
54
+
55
  self.observer_model = AutoModelForCausalLM.from_pretrained(observer_name_or_path, **model_kwargs)
56
  self.performer_model = AutoModelForCausalLM.from_pretrained(performer_name_or_path, **model_kwargs)
57