electroglyph committed
Commit b480395 · verified · 1 Parent(s): 732bd9b

Delete benchmark.py

Files changed (1)
  1. benchmark.py +0 -292
benchmark.py DELETED
@@ -1,292 +0,0 @@
- import json
- import onnxruntime as rt
- import transformers
- from qdrant_client import QdrantClient, models
- import queue
- from threading import Thread, Lock
- import time
- from pyatomix import AtomicInt
-
- # adjust these settings as needed
- TOKENIZER_PATH = "."
- ORIG_MODEL_PATH = "model_uint8.onnx"
- ORIG_DATATYPE = models.Datatype.FLOAT32
- ORIG_COLLECTION_NAME = "baseline"
- COMPARE_MODEL_PATH = "snowflake2_m_uint8.onnx"
- COMPARE_DATATYPE = models.Datatype.UINT8
- COMPARE_COLLECTION_NAME = "compare"
- EMBEDDING_DIM = 768  # size of the model output
- STAT_RANGES = [
-     10,
-     20,
-     50,
- ]  # stats will be calculated for each range: top 10, top 20, etc.
- STATS = {}
- STAT_LOCK = Lock()
- BATCH_SIZE = 1000  # this many token/id pairs will be processed at a time
- THREADS = 8  # number of threads to use
- # Qdrant client settings here
- CLIENT_URL = "http://127.0.0.1"
- CLIENT_PORT = 6333
- CLIENT_GRPC_PORT = 6334
- CLIENT_USE_GRPC = True
- FINISHED = AtomicInt(0)
-
-
- def collect_tokens() -> list[str] | None:
-     print("Attempting to grab tokens from tokenizer...")
-     with open(f"{TOKENIZER_PATH}/tokenizer.json", "r") as f:
-         t = f.read()
-     j = json.loads(t)
-     v = j["model"]["vocab"]
-     toks = [x[0] for x in v]
-     print(f"Found {len(toks)} tokens.")
-     return toks
-
-
- def init_worker(q: queue.Queue, model_path: str, collection_name: str):
-     try:
-         session = rt.InferenceSession(model_path, providers=["CPUExecutionProvider"])
-     except Exception as e:
-         print(f"Error loading ONNX model: {e}")
-         return
-     tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH)
-     client = QdrantClient(
-         url=CLIENT_URL,
-         port=CLIENT_PORT,
-         grpc_port=CLIENT_GRPC_PORT,
-         prefer_grpc=CLIENT_USE_GRPC,
-     )
-     global FINISHED
-     while True:
-         try:
-             chunk = q.get(False)
-         except queue.Empty:
-             return
-         batch = []
-         for c in chunk:
-             FINISHED += 1
-             # c[0] == id, c[1] == token; we want each id to always be associated with the same token across models
-             enc = tokenizer(c[1])  # this could've been batched...
-             embeddings = session.run(
-                 None,
-                 {
-                     "input_ids": [enc.input_ids],
-                     "attention_mask": [enc.attention_mask],
-                 },
-             )
-             batch.append(  # [1][0] == sentence_embedding
-                 models.PointStruct(id=c[0], vector={"dense": embeddings[1][0]})
-             )
-         client.batch_update_points(
-             collection_name=collection_name,
-             update_operations=[models.UpsertOperation(upsert=models.PointsList(points=batch))],
-             wait=False,
-         )
-
-
- def init_collection(collection_name: str, model_path: str, datatype: models.Datatype) -> bool:
-     client = QdrantClient(
-         url=CLIENT_URL,
-         port=CLIENT_PORT,
-         grpc_port=CLIENT_GRPC_PORT,
-         prefer_grpc=CLIENT_USE_GRPC,
-     )
-     if client.collection_exists(collection_name):
-         info = client.get_collection(collection_name)
-         print(f"Collection '{collection_name}' already exists, skipping init.")
-         print(f"{info.points_count} points in collection.")
-         return True
-     res = client.create_collection(
-         collection_name=collection_name,
-         vectors_config={
-             "dense": models.VectorParams(
-                 size=EMBEDDING_DIM,
-                 distance=models.Distance.COSINE,
-                 on_disk=False,
-                 datatype=datatype,
-             ),
-         },
-         hnsw_config=models.HnswConfigDiff(m=0),  # no index
-         on_disk_payload=False,
-     )
-     if not res:
-         print("Error creating collection.")
-         return False
-     else:
-         print("Collection created.")
-     toks = collect_tokens()
-     FINISHED.store(0)
-     if toks:
-         ids = [x for x in range(len(toks))]
-         # align Qdrant IDs with the token for later analysis
-         pairs = list(zip(ids, toks))
-         # lists of (Qdrant ID, token)
-         chunks = [pairs[i : i + BATCH_SIZE] for i in range(0, len(pairs), BATCH_SIZE)]
-         q = queue.Queue()
-         for c in chunks:
-             q.put(c)
-         for _ in range(THREADS):
-             t = Thread(target=init_worker, args=[q, model_path, collection_name])
-             t.start()
-         count = 0
-         while FINISHED.load() < len(toks):
-             time.sleep(0.5)
-             count += 1
-             if count == 20:  # update every 10 seconds or so
-                 print(f"approximately {q.qsize() * BATCH_SIZE} items left in queue...")
-                 count = 0
-         print(f"Done with collection init, {len(toks)} tokens upserted.")
-         # enable indexing
-         client.update_collection(collection_name=collection_name, hnsw_config=models.HnswConfigDiff(m=16))
-         return True
-     else:
-         print("Failed to grab tokens from tokenizer.")
-         return False
-
-
- def count_mismatches(list1, list2) -> int:
-     count = 0
-     assert len(list1) == len(list2)
-     for i in range(len(list1)):
-         if list1[i] != list2[i]:
-             count += 1
-     return count
-
-
- def score_results(
-     list1: list,
-     list2: list,
- ):
-     assert len(list1) == len(list2)
-     global STATS
-     for x in STAT_RANGES:
-         with STAT_LOCK:
-             # STATS = {range: {"exact": AtomicInt, ...}}
-             d = STATS.get(x)
-             if d is None:
-                 d = {
-                     "exact": AtomicInt(0),
-                     "off_by_1": AtomicInt(0),
-                     "off_by_2": AtomicInt(0),
-                     "off_by_3": AtomicInt(0),
-                     "off_by_4": AtomicInt(0),
-                     "off_by_5": AtomicInt(0),
-                     "missing": AtomicInt(0),
-                 }
-                 STATS[x] = d
-         for i in range(x):
-             if list1[i] == list2[i]:
-                 d["exact"] += 1
-             else:
-                 if list1[i] in list2:
-                     i2 = list2.index(list1[i])
-                     val = abs(i2 - i)
-                     if val == 1:
-                         d["off_by_1"] += 1
-                     elif val == 2:
-                         d["off_by_2"] += 1
-                     elif val == 3:
-                         d["off_by_3"] += 1
-                     elif val == 4:
-                         d["off_by_4"] += 1
-                     else:
-                         d["off_by_5"] += 1
-                 else:
-                     d["missing"] += 1
-
-
- def main_worker(q: queue.Queue, limit: int):
-     global FINISHED
-     tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH)
-     orig_session = rt.InferenceSession(ORIG_MODEL_PATH, providers=["CPUExecutionProvider"])
-     compare_session = rt.InferenceSession(COMPARE_MODEL_PATH, providers=["CPUExecutionProvider"])
-     client = QdrantClient(
-         url=CLIENT_URL,
-         port=CLIENT_PORT,
-         grpc_port=CLIENT_GRPC_PORT,
-         prefer_grpc=CLIENT_USE_GRPC,
-     )
-     while True:
-         try:
-             chunk = q.get(False)
-         except queue.Empty:
-             return
-         # each c is a token; embed it with both models and compare the ranked query results
-         for c in chunk:
-             enc = tokenizer(c)
-             oe = orig_session.run(
-                 None,
-                 {"input_ids": [enc.input_ids], "attention_mask": [enc.attention_mask]},
-             )
-             ce = compare_session.run(
-                 None,
-                 {"input_ids": [enc.input_ids], "attention_mask": [enc.attention_mask]},
-             )
-             oresult = client.query_points(
-                 collection_name=ORIG_COLLECTION_NAME,
-                 using="dense",
-                 query=oe[1][0],
-                 limit=limit + 5,  # for our scoring metric we want to look slightly past the end
-             )
-             cresult = client.query_points(
-                 collection_name=COMPARE_COLLECTION_NAME,
-                 using="dense",
-                 query=ce[1][0],
-                 limit=limit + 5,
-             )
-             oids = [p.id for p in oresult.points]
-             cids = [p.id for p in cresult.points]
-             score_results(
-                 oids,
-                 cids,
-             )
-             FINISHED += 1
-
-
- def main():
-     if not init_collection(ORIG_COLLECTION_NAME, ORIG_MODEL_PATH, ORIG_DATATYPE):
-         print("Failed to initialize original model values, exiting.")
-         return
-     if not init_collection(COMPARE_COLLECTION_NAME, COMPARE_MODEL_PATH, COMPARE_DATATYPE):
-         print("Failed to initialize secondary model values, exiting.")
-         return
-     toks = collect_tokens()
-     limit = 0
-     for x in STAT_RANGES:
-         if x > limit:
-             limit = x
-     FINISHED.store(0)
-     if toks:
-         chunks = [toks[i : i + BATCH_SIZE] for i in range(0, len(toks), BATCH_SIZE)]
-         q = queue.Queue()
-         for c in chunks:
-             q.put(c)
-         print("Starting analysis.")
-         for _ in range(THREADS):
-             t = Thread(
-                 target=main_worker,
-                 args=[q, limit],
-             )
-             t.start()
-         count = 0
-         while FINISHED.load() < len(toks):
-             time.sleep(0.5)
-             count += 1
-             if count == 20:  # update every 10 seconds or so
-                 print(f"approximately {q.qsize() * BATCH_SIZE} items left in queue...")
-                 count = 0
-         print("Done with analysis.")
-         with STAT_LOCK:
-             for k, v in STATS.items():
-                 print(f"Stats for top {k} query results across entire token range:")
-                 print(f"exact    : {(float(v['exact'].load()) / (len(toks) * k)) * 100:.2f}%")
-                 print(f"off by 1 : {(float(v['off_by_1'].load()) / (len(toks) * k)) * 100:.2f}%")
-                 print(f"off by 2 : {(float(v['off_by_2'].load()) / (len(toks) * k)) * 100:.2f}%")
-                 print(f"off by 3 : {(float(v['off_by_3'].load()) / (len(toks) * k)) * 100:.2f}%")
-                 print(f"off by 4 : {(float(v['off_by_4'].load()) / (len(toks) * k)) * 100:.2f}%")
-                 print(f"off by 5+: {(float(v['off_by_5'].load()) / (len(toks) * k)) * 100:.2f}%")
-                 print(f"missing  : {(float(v['missing'].load()) / (len(toks) * k)) * 100:.2f}%\n")
-
-
- main()