Yadukrishnan committed on
Commit
f198d04
·
verified ·
1 Parent(s): 60b1e93

Create model_loader.py

Browse files
Files changed (1) hide show
  1. src/model_loader.py +20 -0
src/model_loader.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
3
+
4
+ MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
5
+
6
+ cached_model = None
7
+ cached_tokenizer = None
8
+
9
+ def load_model():
10
+ global cached_model, cached_tokenizer
11
+ if cached_model is None or cached_tokenizer is None:
12
+ bnb_config = BitsAndBytesConfig(
13
+ load_in_4bit=True,
14
+ bnb_4bit_use_double_quant=True,
15
+ bnb_4bit_quant_type="nf4",
16
+ bnb_4bit_compute_dtype=torch.bfloat16
17
+ )
18
+ cached_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, quantization_config=bnb_config)
19
+ cached_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
20
+ return cached_model, cached_tokenizer