tomaarsen (HF Staff) committed
Commit d7354d1 · verified · 1 Parent(s): dc8e28d

Add new CrossEncoder model

Files changed (9)
  1. README.md +396 -0
  2. config.json +60 -0
  3. model.py +420 -0
  4. model.safetensors +3 -0
  5. rotary.py +61 -0
  6. special_tokens_map.json +37 -0
  7. tokenizer.json +0 -0
  8. tokenizer_config.json +61 -0
  9. vocab.txt +0 -0
README.md ADDED
@@ -0,0 +1,396 @@
1
+ ---
2
+ language:
3
+ - en
4
+ license: apache-2.0
5
+ tags:
6
+ - sentence-transformers
7
+ - cross-encoder
8
+ - text-classification
9
+ - generated_from_trainer
10
+ - dataset_size:578402
11
+ - loss:BinaryCrossEntropyLoss
12
+ pipeline_tag: text-classification
13
+ library_name: sentence-transformers
14
+ metrics:
15
+ - map
16
+ - mrr@10
17
+ - ndcg@10
18
+ model-index:
19
+ - name: NeoBERT-medium trained on GooAQ
20
+ results: []
21
+ ---
22
+
23
+ # NeoBERT-medium trained on GooAQ
24
+
25
+ This is a [Cross Encoder](https://www.sbert.net/docs/cross_encoder/usage/usage.html) model trained using the [sentence-transformers](https://www.SBERT.net) library. It computes scores for pairs of texts, which can be used for semantic textual similarity, semantic search, paraphrase mining, text classification, clustering, and more.
26
+
27
+ ## Model Details
28
+
29
+ ### Model Description
30
+ - **Model Type:** Cross Encoder
31
+ <!-- - **Base model:** [Unknown](https://huggingface.co/unknown) -->
32
+ - **Maximum Sequence Length:** 4096 tokens
33
+ - **Number of Output Labels:** 1 label
34
+ <!-- - **Training Dataset:** Unknown -->
35
+ - **Language:** en
36
+ - **License:** apache-2.0
37
+
38
+ ### Model Sources
39
+
40
+ - **Documentation:** [Sentence Transformers Documentation](https://sbert.net)
41
+ - **Documentation:** [Cross Encoder Documentation](https://www.sbert.net/docs/cross_encoder/usage/usage.html)
42
+ - **Repository:** [Sentence Transformers on GitHub](https://github.com/UKPLab/sentence-transformers)
43
+ - **Hugging Face:** [Cross Encoders on Hugging Face](https://huggingface.co/models?library=sentence-transformers&other=cross-encoder)
44
+
45
+ ## Usage
46
+
47
+ ### Direct Usage (Sentence Transformers)
48
+
49
+ First install the Sentence Transformers library:
50
+
51
+ ```bash
52
+ pip install -U sentence-transformers
53
+ ```
54
+
55
+ Then you can load this model and run inference.
56
+ ```python
57
+ from sentence_transformers import CrossEncoder
58
+
59
+ # Download from the 🤗 Hub
60
+ model = CrossEncoder("tomaarsen/reranker-NeoBERT-gooaq-bce")
61
+ # Get scores for pairs of texts
62
+ pairs = [
63
+ ['what are the signs of a bad yeast infection?', "['Itching and irritation in the vagina and vulva.', 'A burning sensation, especially during intercourse or while urinating.', 'Redness and swelling of the vulva.', 'Vaginal pain and soreness.', 'Vaginal rash.', 'Thick, white, odor-free vaginal discharge with a cottage cheese appearance.', 'Watery vaginal discharge.']"],
64
+ ['what are the signs of a bad yeast infection?', 'Vaginal yeast infections can cause: itching and irritation in the vagina. redness, swelling, or itching of the vulva (the folds of skin outside the vagina) a thick, white discharge that can look like cottage cheese and is usually odorless, although it might smell like bread or yeast.'],
65
+ ['what are the signs of a bad yeast infection?', 'It can feel like itching or maybe even burning. Or you may experience swelling so extreme, it leads to sores. Whether your symptoms are mild or severe, a yeast infection can be uncomfortable. Also known as vaginal candidiasis, yeast infections are caused by a fungus.'],
66
+ ['what are the signs of a bad yeast infection?', 'Complications of untreated yeast infections If left untreated, vaginal candidiasis will most likely get worse, causing itching, redness, and inflammation in the area surrounding your vagina. This may lead to a skin infection if the inflamed area becomes cracked, or if continual scratching creates open or raw areas.'],
67
+ ['what are the signs of a bad yeast infection?', "Drinking alcohol may also put you at greater risk for yeast infections. So if you're worried about yeast infection symptoms, consider curbing your cocktails. Eating only yeast-free foods is one way some women try to control yeast infections."],
68
+ ]
69
+ scores = model.predict(pairs)
70
+ print(scores.shape)
71
+ # (5,)
72
+
73
+ # Or rank different texts based on similarity to a single text
74
+ ranks = model.rank(
75
+ 'what are the signs of a bad yeast infection?',
76
+ [
77
+ "['Itching and irritation in the vagina and vulva.', 'A burning sensation, especially during intercourse or while urinating.', 'Redness and swelling of the vulva.', 'Vaginal pain and soreness.', 'Vaginal rash.', 'Thick, white, odor-free vaginal discharge with a cottage cheese appearance.', 'Watery vaginal discharge.']",
78
+ 'Vaginal yeast infections can cause: itching and irritation in the vagina. redness, swelling, or itching of the vulva (the folds of skin outside the vagina) a thick, white discharge that can look like cottage cheese and is usually odorless, although it might smell like bread or yeast.',
79
+ 'It can feel like itching or maybe even burning. Or you may experience swelling so extreme, it leads to sores. Whether your symptoms are mild or severe, a yeast infection can be uncomfortable. Also known as vaginal candidiasis, yeast infections are caused by a fungus.',
80
+ 'Complications of untreated yeast infections If left untreated, vaginal candidiasis will most likely get worse, causing itching, redness, and inflammation in the area surrounding your vagina. This may lead to a skin infection if the inflamed area becomes cracked, or if continual scratching creates open or raw areas.',
81
+ "Drinking alcohol may also put you at greater risk for yeast infections. So if you're worried about yeast infection symptoms, consider curbing your cocktails. Eating only yeast-free foods is one way some women try to control yeast infections.",
82
+ ]
83
+ )
84
+ # [{'corpus_id': ..., 'score': ...}, {'corpus_id': ..., 'score': ...}, ...]
85
+ ```
86
+
87
+ ### Direct Usage (Transformers)
+
+ <details><summary>Click to see the direct usage in Transformers</summary>
+
+ Because `config.json` maps `AutoModelForSequenceClassification` to the custom `NeoBERTForSequenceClassification` in `model.py`, the checkpoint can also be scored without sentence-transformers. The snippet below is a minimal sketch under that assumption; the sigmoid mirrors the default `CrossEncoder` activation for single-label models.
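+
+ ```python
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+ model_id = "tomaarsen/reranker-NeoBERT-gooaq-bce"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ # trust_remote_code=True is required to load the custom NeoBERT architecture from model.py
+ model = AutoModelForSequenceClassification.from_pretrained(model_id, trust_remote_code=True)
+ model.eval()
+
+ features = tokenizer(
+     "what are the signs of a bad yeast infection?",
+     "Vaginal yeast infections can cause: itching and irritation in the vagina. ...",
+     truncation=True,
+     return_tensors="pt",
+ )
+ with torch.no_grad():
+     logits = model(**features).logits  # shape (1, 1): one relevance logit per pair
+ score = torch.sigmoid(logits.squeeze(-1))
+ print(score)
+ ```
+
+ </details>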
94
+
95
+ <!--
96
+ ### Downstream Usage (Sentence Transformers)
97
+
98
+ You can finetune this model on your own dataset.
99
+
100
+ <details><summary>Click to expand</summary>
101
+
102
+ </details>
103
+ -->
104
+
105
+ <!--
106
+ ### Out-of-Scope Use
107
+
108
+ *List how the model may foreseeably be misused and address what users ought not to do with the model.*
109
+ -->
110
+
111
+ ## Evaluation
112
+
113
+ ### Metrics
114
+
115
+ #### Cross Encoder Reranking
116
+
117
+ * Dataset: `gooaq-dev`
118
+ * Evaluated with [<code>CrossEncoderRerankingEvaluator</code>](https://sbert.net/docs/package_reference/cross_encoder/evaluation.html#sentence_transformers.cross_encoder.evaluation.CrossEncoderRerankingEvaluator)
119
+
120
+ | Metric | Value |
121
+ |:------------|:---------------------|
122
+ | map | 0.8039 (+0.2728) |
123
+ | mrr@10 | 0.8028 (+0.2789) |
124
+ | **ndcg@10** | **0.8475 (+0.2562)** |
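+
+ The evaluation can be reproduced along the lines of the sketch below. The sample format and output keys follow the `CrossEncoderRerankingEvaluator` documentation; the query and candidate documents shown here are placeholders rather than the actual `gooaq-dev` split.
+
+ ```python
+ from sentence_transformers import CrossEncoder
+ from sentence_transformers.cross_encoder.evaluation import CrossEncoderRerankingEvaluator
+
+ model = CrossEncoder("tomaarsen/reranker-NeoBERT-gooaq-bce")
+
+ # One entry per query: the known positive answer(s) plus mined negative candidates to rerank
+ samples = [
+     {
+         "query": "what are the signs of a bad yeast infection?",
+         "positive": ["Itching and irritation in the vagina and vulva. ..."],
+         "negative": [
+             "It can feel like itching or maybe even burning. ...",
+             "Drinking alcohol may also put you at greater risk for yeast infections. ...",
+         ],
+     },
+     # ... one dict per gooaq-dev query
+ ]
+
+ evaluator = CrossEncoderRerankingEvaluator(samples, at_k=10, name="gooaq-dev")
+ results = evaluator(model)
+ print(results)  # e.g. {"gooaq-dev_map": ..., "gooaq-dev_mrr@10": ..., "gooaq-dev_ndcg@10": ...}
+ ```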
125
+
126
+ <!--
127
+ ## Bias, Risks and Limitations
128
+
129
+ *What are the known or foreseeable issues stemming from this model? You could also flag here known failure cases or weaknesses of the model.*
130
+ -->
131
+
132
+ <!--
133
+ ### Recommendations
134
+
135
+ *What are recommendations with respect to the foreseeable issues? For example, filtering explicit content.*
136
+ -->
137
+
138
+ ## Training Details
139
+
140
+ ### Training Dataset
141
+
142
+ #### Unnamed Dataset
143
+
144
+ * Size: 578,402 training samples
145
+ * Columns: <code>question</code>, <code>answer</code>, and <code>label</code>
146
+ * Approximate statistics based on the first 1000 samples:
147
+ | | question | answer | label |
148
+ |:--------|:-----------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------------|:------------------------------------------------|
149
+ | type | string | string | int |
150
+ | details | <ul><li>min: 21 characters</li><li>mean: 43.81 characters</li><li>max: 91 characters</li></ul> | <ul><li>min: 51 characters</li><li>mean: 251.2 characters</li><li>max: 365 characters</li></ul> | <ul><li>0: ~82.90%</li><li>1: ~17.10%</li></ul> |
151
+ * Samples:
152
+ | question | answer | label |
153
+ |:----------------------------------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:---------------|
154
+ | <code>what are the signs of a bad yeast infection?</code> | <code>['Itching and irritation in the vagina and vulva.', 'A burning sensation, especially during intercourse or while urinating.', 'Redness and swelling of the vulva.', 'Vaginal pain and soreness.', 'Vaginal rash.', 'Thick, white, odor-free vaginal discharge with a cottage cheese appearance.', 'Watery vaginal discharge.']</code> | <code>1</code> |
155
+ | <code>what are the signs of a bad yeast infection?</code> | <code>Vaginal yeast infections can cause: itching and irritation in the vagina. redness, swelling, or itching of the vulva (the folds of skin outside the vagina) a thick, white discharge that can look like cottage cheese and is usually odorless, although it might smell like bread or yeast.</code> | <code>0</code> |
156
+ | <code>what are the signs of a bad yeast infection?</code> | <code>It can feel like itching or maybe even burning. Or you may experience swelling so extreme, it leads to sores. Whether your symptoms are mild or severe, a yeast infection can be uncomfortable. Also known as vaginal candidiasis, yeast infections are caused by a fungus.</code> | <code>0</code> |
157
+ * Loss: [<code>BinaryCrossEntropyLoss</code>](https://sbert.net/docs/package_reference/cross_encoder/losses.html#binarycrossentropyloss) with these parameters:
158
+ ```json
159
+ {
160
+ "activation_fct": "torch.nn.modules.linear.Identity",
161
+ "pos_weight": 5
162
+ }
163
+ ```
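+
+ As a rough sketch, the loss would be instantiated as below. The base checkpoint (`chandar-lab/NeoBERT`, as referenced in `config.json`) and the tensor form of `pos_weight` are assumptions; `pos_weight=5` up-weights the positive pairs, which make up only ~17% of the training data.
+
+ ```python
+ import torch
+ from sentence_transformers import CrossEncoder
+ from sentence_transformers.cross_encoder.losses import BinaryCrossEntropyLoss
+
+ # Assumed base model; the card itself does not state it explicitly
+ model = CrossEncoder("chandar-lab/NeoBERT", num_labels=1, trust_remote_code=True)
+
+ # Identity activation on the logit, with positives weighted 5x in the BCE term
+ loss = BinaryCrossEntropyLoss(model, pos_weight=torch.tensor(5.0))
+ ```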
164
+
165
+ ### Training Hyperparameters
166
+ #### Non-Default Hyperparameters
167
+
168
+ - `eval_strategy`: steps
169
+ - `per_device_train_batch_size`: 64
170
+ - `per_device_eval_batch_size`: 64
171
+ - `learning_rate`: 2e-05
172
+ - `num_train_epochs`: 1
173
+ - `warmup_ratio`: 0.1
174
+ - `seed`: 12
175
+ - `bf16`: True
176
+ - `dataloader_num_workers`: 4
177
+ - `load_best_model_at_end`: True
178
+
179
+ #### All Hyperparameters
180
+ <details><summary>Click to expand</summary>
181
+
182
+ - `overwrite_output_dir`: False
183
+ - `do_predict`: False
184
+ - `eval_strategy`: steps
185
+ - `prediction_loss_only`: True
186
+ - `per_device_train_batch_size`: 64
187
+ - `per_device_eval_batch_size`: 64
188
+ - `per_gpu_train_batch_size`: None
189
+ - `per_gpu_eval_batch_size`: None
190
+ - `gradient_accumulation_steps`: 1
191
+ - `eval_accumulation_steps`: None
192
+ - `torch_empty_cache_steps`: None
193
+ - `learning_rate`: 2e-05
194
+ - `weight_decay`: 0.0
195
+ - `adam_beta1`: 0.9
196
+ - `adam_beta2`: 0.999
197
+ - `adam_epsilon`: 1e-08
198
+ - `max_grad_norm`: 1.0
199
+ - `num_train_epochs`: 1
200
+ - `max_steps`: -1
201
+ - `lr_scheduler_type`: linear
202
+ - `lr_scheduler_kwargs`: {}
203
+ - `warmup_ratio`: 0.1
204
+ - `warmup_steps`: 0
205
+ - `log_level`: passive
206
+ - `log_level_replica`: warning
207
+ - `log_on_each_node`: True
208
+ - `logging_nan_inf_filter`: True
209
+ - `save_safetensors`: True
210
+ - `save_on_each_node`: False
211
+ - `save_only_model`: False
212
+ - `restore_callback_states_from_checkpoint`: False
213
+ - `no_cuda`: False
214
+ - `use_cpu`: False
215
+ - `use_mps_device`: False
216
+ - `seed`: 12
217
+ - `data_seed`: None
218
+ - `jit_mode_eval`: False
219
+ - `use_ipex`: False
220
+ - `bf16`: True
221
+ - `fp16`: False
222
+ - `fp16_opt_level`: O1
223
+ - `half_precision_backend`: auto
224
+ - `bf16_full_eval`: False
225
+ - `fp16_full_eval`: False
226
+ - `tf32`: None
227
+ - `local_rank`: 0
228
+ - `ddp_backend`: None
229
+ - `tpu_num_cores`: None
230
+ - `tpu_metrics_debug`: False
231
+ - `debug`: []
232
+ - `dataloader_drop_last`: False
233
+ - `dataloader_num_workers`: 4
234
+ - `dataloader_prefetch_factor`: None
235
+ - `past_index`: -1
236
+ - `disable_tqdm`: False
237
+ - `remove_unused_columns`: True
238
+ - `label_names`: None
239
+ - `load_best_model_at_end`: True
240
+ - `ignore_data_skip`: False
241
+ - `fsdp`: []
242
+ - `fsdp_min_num_params`: 0
243
+ - `fsdp_config`: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
244
+ - `fsdp_transformer_layer_cls_to_wrap`: None
245
+ - `accelerator_config`: {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}
246
+ - `deepspeed`: None
247
+ - `label_smoothing_factor`: 0.0
248
+ - `optim`: adamw_torch
249
+ - `optim_args`: None
250
+ - `adafactor`: False
251
+ - `group_by_length`: False
252
+ - `length_column_name`: length
253
+ - `ddp_find_unused_parameters`: None
254
+ - `ddp_bucket_cap_mb`: None
255
+ - `ddp_broadcast_buffers`: False
256
+ - `dataloader_pin_memory`: True
257
+ - `dataloader_persistent_workers`: False
258
+ - `skip_memory_metrics`: True
259
+ - `use_legacy_prediction_loop`: False
260
+ - `push_to_hub`: False
261
+ - `resume_from_checkpoint`: None
262
+ - `hub_model_id`: None
263
+ - `hub_strategy`: every_save
264
+ - `hub_private_repo`: None
265
+ - `hub_always_push`: False
266
+ - `gradient_checkpointing`: False
267
+ - `gradient_checkpointing_kwargs`: None
268
+ - `include_inputs_for_metrics`: False
269
+ - `include_for_metrics`: []
270
+ - `eval_do_concat_batches`: True
271
+ - `fp16_backend`: auto
272
+ - `push_to_hub_model_id`: None
273
+ - `push_to_hub_organization`: None
274
+ - `mp_parameters`:
275
+ - `auto_find_batch_size`: False
276
+ - `full_determinism`: False
277
+ - `torchdynamo`: None
278
+ - `ray_scope`: last
279
+ - `ddp_timeout`: 1800
280
+ - `torch_compile`: False
281
+ - `torch_compile_backend`: None
282
+ - `torch_compile_mode`: None
283
+ - `dispatch_batches`: None
284
+ - `split_batches`: None
285
+ - `include_tokens_per_second`: False
286
+ - `include_num_input_tokens_seen`: False
287
+ - `neftune_noise_alpha`: None
288
+ - `optim_target_modules`: None
289
+ - `batch_eval_metrics`: False
290
+ - `eval_on_start`: False
291
+ - `use_liger_kernel`: False
292
+ - `eval_use_gather_object`: False
293
+ - `average_tokens_across_devices`: False
294
+ - `prompts`: None
295
+ - `batch_sampler`: batch_sampler
296
+ - `multi_dataset_batch_sampler`: proportional
297
+
298
+ </details>
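+
+ Putting the pieces together, the run roughly corresponds to the sketch below, reusing the `model`, `loss`, and `evaluator` objects from the earlier sketches. `eval_steps`/`save_steps` of 1000 are inferred from the training logs, and the tiny datasets are hypothetical stand-ins for the 578,402 (question, answer, label) pairs described above.
+
+ ```python
+ from datasets import Dataset
+ from sentence_transformers.cross_encoder import CrossEncoderTrainer, CrossEncoderTrainingArguments
+
+ # Hypothetical stand-ins for the real train split and a held-out eval split
+ train_dataset = Dataset.from_dict({
+     "question": ["what are the signs of a bad yeast infection?"],
+     "answer": ["Vaginal yeast infections can cause: itching and irritation in the vagina. ..."],
+     "label": [1],
+ })
+ eval_dataset = train_dataset
+
+ args = CrossEncoderTrainingArguments(
+     output_dir="reranker-NeoBERT-gooaq-bce",
+     num_train_epochs=1,
+     per_device_train_batch_size=64,
+     per_device_eval_batch_size=64,
+     learning_rate=2e-5,
+     warmup_ratio=0.1,
+     bf16=True,
+     eval_strategy="steps",
+     eval_steps=1000,   # evaluation cadence inferred from the training logs
+     save_steps=1000,   # keep save and eval cadence aligned for load_best_model_at_end
+     load_best_model_at_end=True,
+     seed=12,
+     dataloader_num_workers=4,
+ )
+
+ trainer = CrossEncoderTrainer(
+     model=model,                # CrossEncoder from the loss sketch above
+     args=args,
+     train_dataset=train_dataset,
+     eval_dataset=eval_dataset,
+     loss=loss,                  # BinaryCrossEntropyLoss from above
+     evaluator=evaluator,        # CrossEncoderRerankingEvaluator from the Evaluation section
+ )
+ trainer.train()
+ ```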
299
+
300
+ ### Training Logs
301
+ | Epoch | Step | Training Loss | gooaq-dev_ndcg@10 |
302
+ |:----------:|:--------:|:-------------:|:--------------------:|
303
+ | -1 | -1 | - | 0.1489 (-0.4423) |
304
+ | 0.0001 | 1 | 1.328 | - |
305
+ | 0.0221 | 200 | 1.1586 | - |
306
+ | 0.0443 | 400 | 0.7765 | - |
307
+ | 0.0664 | 600 | 0.651 | - |
308
+ | 0.0885 | 800 | 0.6165 | - |
309
+ | 0.1106 | 1000 | 0.6434 | 0.7674 (+0.1762) |
310
+ | 0.1328 | 1200 | 0.5952 | - |
311
+ | 0.1549 | 1400 | 0.573 | - |
312
+ | 0.1770 | 1600 | 0.5538 | - |
313
+ | 0.1992 | 1800 | 0.5492 | - |
314
+ | 0.2213 | 2000 | 0.5452 | 0.8095 (+0.2182) |
315
+ | 0.2434 | 2200 | 0.5325 | - |
316
+ | 0.2655 | 2400 | 0.5178 | - |
317
+ | 0.2877 | 2600 | 0.5233 | - |
318
+ | 0.3098 | 2800 | 0.5079 | - |
319
+ | 0.3319 | 3000 | 0.5084 | 0.8178 (+0.2266) |
320
+ | 0.3541 | 3200 | 0.5104 | - |
321
+ | 0.3762 | 3400 | 0.5053 | - |
322
+ | 0.3983 | 3600 | 0.4892 | - |
323
+ | 0.4204 | 3800 | 0.4879 | - |
324
+ | 0.4426 | 4000 | 0.4969 | 0.8260 (+0.2348) |
325
+ | 0.4647 | 4200 | 0.492 | - |
326
+ | 0.4868 | 4400 | 0.4798 | - |
327
+ | 0.5090 | 4600 | 0.4708 | - |
328
+ | 0.5311 | 4800 | 0.4638 | - |
329
+ | 0.5532 | 5000 | 0.4746 | 0.8286 (+0.2374) |
330
+ | 0.5753 | 5200 | 0.4467 | - |
331
+ | 0.5975 | 5400 | 0.4615 | - |
332
+ | 0.6196 | 5600 | 0.452 | - |
333
+ | 0.6417 | 5800 | 0.4632 | - |
334
+ | 0.6639 | 6000 | 0.4517 | 0.8290 (+0.2378) |
335
+ | 0.6860 | 6200 | 0.447 | - |
336
+ | 0.7081 | 6400 | 0.4581 | - |
337
+ | 0.7303 | 6600 | 0.4521 | - |
338
+ | 0.7524 | 6800 | 0.4461 | - |
339
+ | 0.7745 | 7000 | 0.4418 | 0.8372 (+0.2459) |
340
+ | 0.7966 | 7200 | 0.4279 | - |
341
+ | 0.8188 | 7400 | 0.4136 | - |
342
+ | 0.8409 | 7600 | 0.4163 | - |
343
+ | 0.8630 | 7800 | 0.4099 | - |
344
+ | 0.8852 | 8000 | 0.4156 | 0.8431 (+0.2518) |
345
+ | 0.9073 | 8200 | 0.4146 | - |
346
+ | 0.9294 | 8400 | 0.4264 | - |
347
+ | 0.9515 | 8600 | 0.4261 | - |
348
+ | 0.9737 | 8800 | 0.4145 | - |
349
+ | **0.9958** | **9000** | **0.4219** | **0.8475 (+0.2562)** |
350
+ | -1 | -1 | - | 0.8475 (+0.2562) |
351
+
352
+ * The bold row denotes the saved checkpoint.
353
+
354
+ ### Framework Versions
355
+ - Python: 3.11.10
356
+ - Sentence Transformers: 3.5.0.dev0
357
+ - Transformers: 4.49.0
358
+ - PyTorch: 2.5.1+cu124
359
+ - Accelerate: 1.2.0
360
+ - Datasets: 2.21.0
361
+ - Tokenizers: 0.21.0
362
+
363
+ ## Citation
364
+
365
+ ### BibTeX
366
+
367
+ #### Sentence Transformers
368
+ ```bibtex
369
+ @inproceedings{reimers-2019-sentence-bert,
370
+ title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
371
+ author = "Reimers, Nils and Gurevych, Iryna",
372
+ booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
373
+ month = "11",
374
+ year = "2019",
375
+ publisher = "Association for Computational Linguistics",
376
+ url = "https://arxiv.org/abs/1908.10084",
377
+ }
378
+ ```
379
+
380
+ <!--
381
+ ## Glossary
382
+
383
+ *Clearly define terms in order to be accessible across audiences.*
384
+ -->
385
+
386
+ <!--
387
+ ## Model Card Authors
388
+
389
+ *Lists the people who create the model card, providing recognition and accountability for the detailed work that goes into its construction.*
390
+ -->
391
+
392
+ <!--
393
+ ## Model Card Contact
394
+
395
+ *Provides a way for people who have updates to the Model Card, suggestions, or questions, to contact the Model Card authors.*
396
+ -->
config.json ADDED
@@ -0,0 +1,60 @@
1
+ {
2
+ "_name_or_path": "../models/NeoBERT",
3
+ "architectures": [
4
+ "NeoBERTForSequenceClassification"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "model.NeoBERTConfig",
8
+ "AutoModel": "model.NeoBERTLMHead",
9
+ "AutoModelForMaskedLM": "model.NeoBERTLMHead",
10
+ "AutoModelForSequenceClassification": "chandar-lab/NeoBERT--model.NeoBERTForSequenceClassification"
11
+ },
12
+ "classifier_init_range": 0.02,
13
+ "decoder_init_range": 0.02,
14
+ "dim_head": 64,
15
+ "embedding_init_range": 0.02,
16
+ "hidden_size": 768,
17
+ "id2label": {
18
+ "0": "LABEL_0"
19
+ },
20
+ "intermediate_size": 3072,
21
+ "kwargs": {
22
+ "_commit_hash": null,
23
+ "architectures": [
24
+ "NeoBERTLMHead"
25
+ ],
26
+ "attn_implementation": null,
27
+ "auto_map": {
28
+ "AutoConfig": "model.NeoBERTConfig",
29
+ "AutoModel": "model.NeoBERTLMHead",
30
+ "AutoModelForMaskedLM": "model.NeoBERTLMHead",
31
+ "AutoModelForSequenceClassification": "chandar-lab/NeoBERT--model.NeoBERTForSequenceClassification"
32
+ },
33
+ "classifier_init_range": 0.02,
34
+ "dim_head": 64,
35
+ "kwargs": {
36
+ "classifier_init_range": 0.02,
37
+ "pretrained_model_name_or_path": "google-bert/bert-base-uncased",
38
+ "trust_remote_code": true
39
+ },
40
+ "model_type": "neobert",
41
+ "pretrained_model_name_or_path": "google-bert/bert-base-uncased",
42
+ "torch_dtype": "float32",
43
+ "transformers_version": "4.48.2",
44
+ "trust_remote_code": true
45
+ },
46
+ "label2id": {
47
+ "LABEL_0": 0
48
+ },
49
+ "max_length": 4096,
50
+ "model_type": "neobert",
51
+ "norm_eps": 1e-05,
52
+ "num_attention_heads": 12,
53
+ "num_hidden_layers": 28,
54
+ "pad_token_id": 0,
55
+ "pretrained_model_name_or_path": "google-bert/bert-base-uncased",
56
+ "torch_dtype": "float32",
57
+ "transformers_version": "4.49.0",
58
+ "trust_remote_code": true,
59
+ "vocab_size": 30522
60
+ }
model.py ADDED
@@ -0,0 +1,420 @@
1
+ # From https://github.com/facebookresearch/llama/blob/main/llama/model.py
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
7
+ from torch.nn.functional import scaled_dot_product_attention
8
+
9
+ from typing import Optional
10
+ import numpy as np
11
+
12
+ from xformers.ops import SwiGLU
13
+
14
+ try:
15
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
16
+
17
+ FLASH_ATTN_AVAILABLE = True
18
+ except ImportError:
19
+ FLASH_ATTN_AVAILABLE = False
20
+
21
+ from transformers import (
22
+ PreTrainedModel,
23
+ PretrainedConfig,
24
+ DataCollatorForLanguageModeling,
25
+ )
26
+ from transformers.modeling_outputs import (
27
+ BaseModelOutput,
28
+ MaskedLMOutput,
29
+ SequenceClassifierOutput,
30
+ )
31
+
32
+ from .rotary import precompute_freqs_cis, apply_rotary_emb
33
+
34
+
35
+ class DataCollatorWithPacking(DataCollatorForLanguageModeling):
36
+ def __init__(self, pack_sequences=False, **kwargs):
37
+ super().__init__(**kwargs)
38
+ self.pack_sequences = pack_sequences
39
+
40
+ def __call__(self, batch):
41
+ if self.pack_sequences:
42
+ # Add position_ids if not present
43
+ if "position_ids" not in batch[0]:
44
+ for item in batch:
45
+ item["position_ids"] = list(range(len(item["input_ids"])))
46
+
47
+ # Pack the sequences into a single list
48
+ input_ids_list = [item["input_ids"] for item in batch]
49
+ position_ids_list = [item["position_ids"] for item in batch]
50
+ seqlens = np.array([0] + [len(ids) for ids in input_ids_list])
51
+
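+ # cu_seqlens holds the cumulative sequence boundaries that flash_attn_varlen_func later uses to keep the packed sequences separate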
52
+ packed_batch = {
53
+ "position_ids": np.concatenate(position_ids_list, axis=0),
54
+ "input_ids": np.concatenate(input_ids_list, axis=0),
55
+ "cu_seqlens": np.cumsum(seqlens),
56
+ "max_seqlen": max(seqlens),
57
+ }
58
+
59
+ batch = super().__call__([packed_batch])
60
+ batch["cu_seqlens"] = batch["cu_seqlens"].to(torch.int32).squeeze()
61
+ else:
62
+ batch = super().__call__(batch)
63
+ batch["attention_mask"] = batch["attention_mask"].to(torch.bool)
64
+
65
+ return batch
66
+
67
+
68
+ class NeoBERTConfig(PretrainedConfig):
69
+ model_type = "neobert"
70
+
71
+ # All config parameters must have a default value.
72
+ def __init__(
73
+ self,
74
+ hidden_size: int = 768,
75
+ num_hidden_layers: int = 28,
76
+ num_attention_heads: int = 12,
77
+ intermediate_size: int = 3072,
78
+ embedding_init_range: float = 0.02,
79
+ decoder_init_range: float = 0.02,
80
+ norm_eps: float = 1e-06,
81
+ vocab_size: int = 30522,
82
+ pad_token_id: int = 0,
83
+ max_length: int = 1024,
84
+ **kwargs,
85
+ ):
86
+ super().__init__(**kwargs)
87
+
88
+ self.hidden_size = hidden_size
89
+ self.num_hidden_layers = num_hidden_layers
90
+ self.num_attention_heads = num_attention_heads
91
+ if hidden_size % num_attention_heads != 0:
92
+ raise ValueError("Hidden size must be divisible by the number of heads.")
93
+ self.dim_head = hidden_size // num_attention_heads
94
+ self.intermediate_size = intermediate_size
95
+ self.embedding_init_range = embedding_init_range
96
+ self.decoder_init_range = decoder_init_range
97
+ self.norm_eps = norm_eps
98
+ self.vocab_size = vocab_size
99
+ self.pad_token_id = pad_token_id
100
+ self.max_length = max_length
101
+ self.kwargs = kwargs
102
+
103
+
104
+ class EncoderBlock(nn.Module):
105
+ """Transformer encoder block."""
106
+
107
+ def __init__(self, config: NeoBERTConfig):
108
+ super().__init__()
109
+
110
+ self.config = config
111
+
112
+ # Attention
113
+ self.qkv = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size * 3, bias=False)
114
+ self.wo = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size, bias=False)
115
+
116
+ # Feedforward network
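+ # SwiGLU uses three projection matrices instead of two, so the hidden width is scaled by 2/3 (rounded up to a multiple of 8) to keep the parameter count comparable to a standard MLP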
117
+ multiple_of = 8
118
+ intermediate_size = int(2 * config.intermediate_size / 3)
119
+ intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
120
+ self.ffn = SwiGLU(config.hidden_size, intermediate_size, config.hidden_size, bias=False)
121
+
122
+ # Layer norms
123
+ self.attention_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
124
+ self.ffn_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
125
+
126
+ def forward(
127
+ self,
128
+ x: torch.Tensor,
129
+ attention_mask: torch.Tensor,
130
+ freqs_cis: torch.Tensor,
131
+ output_attentions: bool,
132
+ max_seqlen: int = None,
133
+ cu_seqlens: torch.Tensor = None,
134
+ ):
135
+ # Attention
136
+ attn_output, attn_weights = self._att_block(
137
+ self.attention_norm(x), attention_mask, freqs_cis, output_attentions, max_seqlen, cu_seqlens
138
+ )
139
+
140
+ # Residual
141
+ x = x + attn_output
142
+
143
+ # Feed-forward
144
+ x = x + self.ffn(self.ffn_norm(x))
145
+
146
+ return x, attn_weights
147
+
148
+ def _att_block(
149
+ self,
150
+ x: torch.Tensor,
151
+ attention_mask: torch.Tensor,
152
+ freqs_cis: torch.Tensor,
153
+ output_attentions: bool,
154
+ max_seqlen: int = None,
155
+ cu_seqlens: torch.Tensor = None,
156
+ ):
157
+ batch_size, seq_len, _ = x.shape
158
+
159
+ xq, xk, xv = self.qkv(x).view(batch_size, seq_len, self.config.num_attention_heads, self.config.dim_head * 3).chunk(3, axis=-1)
160
+
161
+ xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
162
+
163
+ # Attn block
164
+ attn_weights = None
165
+
166
+ # Flash attention if the tensors are packed
167
+ if cu_seqlens is not None:
168
+ attn = flash_attn_varlen_func(
169
+ q=xq.squeeze(0),
170
+ k=xk.squeeze(0),
171
+ v=xv.squeeze(0),
172
+ cu_seqlens_q=cu_seqlens,
173
+ cu_seqlens_k=cu_seqlens,
174
+ max_seqlen_q=max_seqlen,
175
+ max_seqlen_k=max_seqlen,
176
+ dropout_p=0.0,
177
+ causal=False,
178
+ )
179
+ # Eager attention if attention weights are needed in the output
180
+ elif output_attentions:
181
+ attn_weights = xq.permute(0, 2, 1, 3) @ xk.permute(0, 2, 3, 1) / (xq.size(-1) ** 0.5)
182
+ if attention_mask is not None:
183
+ attn_weights = attn_weights * attention_mask
184
+ attn_weights = attn_weights.softmax(-1)
185
+ attn = attn_weights @ xv.permute(0, 2, 1, 3)
186
+ attn = attn.transpose(1, 2)
187
+ # Fall back to SDPA otherwise
188
+ else:
189
+ attn = scaled_dot_product_attention(
190
+ query=xq.transpose(1, 2),
191
+ key=xk.transpose(1, 2),
192
+ value=xv.transpose(1, 2),
193
+ attn_mask=attention_mask.bool(),
194
+ dropout_p=0,
195
+ ).transpose(1, 2)
196
+
197
+ return self.wo(attn.reshape(batch_size, seq_len, self.config.num_attention_heads * self.config.dim_head)), attn_weights
198
+
199
+
200
+ class NeoBERTPreTrainedModel(PreTrainedModel):
201
+ config_class = NeoBERTConfig
202
+ _supports_cache_class = True
203
+
204
+ def _init_weights(self, module):
205
+ if isinstance(module, nn.Linear):
206
+ module.weight.data.uniform_(-self.config.decoder_init_range, self.config.decoder_init_range)
207
+ elif isinstance(module, nn.Embedding):
208
+ module.weight.data.uniform_(-self.config.embedding_init_range, self.config.embedding_init_range)
209
+
210
+
211
+ class NeoBERT(NeoBERTPreTrainedModel):
212
+ config_class = NeoBERTConfig
213
+
214
+ def __init__(self, config: NeoBERTConfig):
215
+ super().__init__(config)
216
+
217
+ self.config = config
218
+
219
+ self.encoder = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
220
+
221
+ # Ensures freqs_cis is moved to the same device as the model. Non-persistent buffers are not saved in the state_dict.
222
+ freqs_cis = precompute_freqs_cis(config.hidden_size // config.num_attention_heads, config.max_length)
223
+ self.register_buffer("freqs_cis", freqs_cis, persistent=False)
224
+
225
+ self.transformer_encoder = nn.ModuleList()
226
+ for _ in range(config.num_hidden_layers):
227
+ self.transformer_encoder.append(EncoderBlock(config))
228
+
229
+ self.layer_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
230
+
231
+ # Initialize weights and apply final processing
232
+ self.post_init()
233
+
234
+ def forward(
235
+ self,
236
+ input_ids: torch.Tensor,
237
+ position_ids: torch.Tensor = None,
238
+ max_seqlen: int = None,
239
+ cu_seqlens: torch.Tensor = None,
240
+ attention_mask: torch.Tensor = None,
241
+ output_hidden_states: bool = False,
242
+ output_attentions: bool = False,
243
+ **kwargs,
244
+ ):
245
+ # Initialize
246
+ hidden_states, attentions = [], []
247
+
248
+ # Expand and repeat: (Batch, Length) -> (Batch, Heads, Length, Length)
249
+ if attention_mask is not None:
250
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).repeat(1, self.config.num_attention_heads, attention_mask.size(-1), 1)
251
+
252
+ # Checks to be done if inputs are packed sequences
253
+ if cu_seqlens is not None:
254
+ assert (
255
+ FLASH_ATTN_AVAILABLE
256
+ ), "Flash-attention is not available. Please ''pip install flash_attn'', or provide un-packed sequences."
257
+ assert not output_attentions, "Output attentions is not supported when sequences are packed."
258
+ assert max_seqlen is not None, "Missing max_seqlen. It must be provided when cu_seqlens are not None."
259
+ assert input_ids.shape[0] == 1, "Cumulative sequence lengths are provided but input_ids are not packed."
260
+ assert input_ids.is_cuda, "Packing uses an implementation of flash-attention and is only supported on GPU."
261
+
262
+ # RoPE
263
+ freqs_cis = self.freqs_cis[position_ids] if position_ids is not None else self.freqs_cis[: input_ids.shape[1]].unsqueeze(0)
264
+
265
+ # Embedding
266
+ x = self.encoder(input_ids)
267
+
268
+ # Transformer encoder
269
+ for layer in self.transformer_encoder:
270
+ x, attn = layer(x, attention_mask, freqs_cis, output_attentions, max_seqlen, cu_seqlens)
271
+ if output_hidden_states:
272
+ hidden_states.append(x)
273
+ if output_attentions:
274
+ attentions.append(attn)
275
+
276
+ # Final normalization layer
277
+ x = self.layer_norm(x)
278
+
279
+ # Return the output of the last hidden layer
280
+ return BaseModelOutput(
281
+ last_hidden_state=x,
282
+ hidden_states=hidden_states if output_hidden_states else None,
283
+ attentions=attentions if output_attentions else None,
284
+ )
285
+
286
+
287
+ class NeoBERTLMHead(NeoBERTPreTrainedModel):
288
+ config_class = NeoBERTConfig
289
+
290
+ def __init__(self, config: NeoBERTConfig):
291
+ super().__init__(config)
292
+
293
+ self.config = config
294
+
295
+ self.model = NeoBERT(config)
296
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
297
+
298
+ self.post_init()
299
+
300
+ def forward(
301
+ self,
302
+ input_ids: torch.Tensor,
303
+ position_ids: torch.Tensor = None,
304
+ max_seqlen: int = None,
305
+ cu_seqlens: torch.Tensor = None,
306
+ attention_mask: torch.Tensor = None,
307
+ output_hidden_states: bool = False,
308
+ output_attentions: bool = False,
309
+ **kwargs,
310
+ ):
311
+
312
+ output = self.model.forward(
313
+ input_ids,
314
+ position_ids,
315
+ max_seqlen,
316
+ cu_seqlens,
317
+ attention_mask,
318
+ output_hidden_states,
319
+ output_attentions,
320
+ )
321
+ logits = self.decoder(output.last_hidden_state)
322
+
323
+ return MaskedLMOutput(
324
+ hidden_states=output.hidden_states if output_hidden_states else None,
325
+ attentions=output.attentions if output_attentions else None,
326
+ logits=logits,
327
+ )
328
+
329
+
330
+ class NeoBERTForSequenceClassification(NeoBERTPreTrainedModel):
331
+ config_class = NeoBERTConfig
332
+
333
+ def __init__(self, config: NeoBERTConfig):
334
+ super().__init__(config)
335
+
336
+ self.config = config
337
+
338
+ self.num_labels = getattr(config, "num_labels", 2)
339
+ self.classifier_dropout = getattr(config, "classifier_dropout", 0.1)
340
+ self.classifier_init_range = getattr(config, "classifier_init_range", 0.02)
341
+
342
+ self.model = NeoBERT(config)
343
+
344
+ self.dense = nn.Linear(self.config.hidden_size, self.config.hidden_size)
345
+ self.dropout = nn.Dropout(self.classifier_dropout)
346
+ self.classifier = nn.Linear(self.config.hidden_size, self.num_labels)
347
+
348
+ self.post_init()
349
+
350
+ def _init_weights(self, module):
351
+ if isinstance(module, nn.Linear):
352
+ module.weight.data.normal_(mean=0.0, std=self.classifier_init_range)
353
+ if module.bias is not None:
354
+ module.bias.data.zero_()
355
+
356
+ def forward(
357
+ self,
358
+ input_ids: torch.Tensor,
359
+ position_ids: torch.Tensor = None,
360
+ max_seqlen: int = None,
361
+ cu_seqlens: torch.Tensor = None,
362
+ attention_mask: torch.Tensor = None,
363
+ output_hidden_states: bool = False,
364
+ output_attentions: bool = False,
365
+ labels: Optional[torch.Tensor] = None,
366
+ return_dict: Optional[bool] = None,
367
+ ):
368
+
369
+ output = self.model.forward(
370
+ input_ids,
371
+ position_ids,
372
+ max_seqlen,
373
+ cu_seqlens,
374
+ attention_mask,
375
+ output_hidden_states,
376
+ output_attentions,
377
+ )
378
+ hidden_states = output.last_hidden_state
379
+
380
+ x = hidden_states[:, 0, :]
381
+ x = self.dropout(x)
382
+ x = self.dense(x)
383
+ x = torch.tanh(x)
384
+ x = self.dropout(x)
385
+
386
+ logits = self.classifier(x)
387
+
388
+ loss = None
389
+ if labels is not None:
390
+ if self.config.problem_type is None:
391
+ if self.num_labels == 1:
392
+ self.config.problem_type = "regression"
393
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
394
+ self.config.problem_type = "single_label_classification"
395
+ else:
396
+ self.config.problem_type = "multi_label_classification"
397
+
398
+ if self.config.problem_type == "regression":
399
+ loss_fct = MSELoss()
400
+ if self.num_labels == 1:
401
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
402
+ else:
403
+ loss = loss_fct(logits, labels)
404
+ elif self.config.problem_type == "single_label_classification":
405
+ loss_fct = CrossEntropyLoss()
406
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
407
+ elif self.config.problem_type == "multi_label_classification":
408
+ loss_fct = BCEWithLogitsLoss()
409
+ loss = loss_fct(logits, labels)
410
+
411
+ if not return_dict:
412
+ result = (logits,)
413
+ return ((loss,) + result) if loss is not None else result
414
+
415
+ return SequenceClassifierOutput(
416
+ loss=loss,
417
+ logits=logits,
418
+ hidden_states=output.hidden_states if output_hidden_states else None,
419
+ attentions=output.attentions if output_attentions else None,
420
+ )
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1e79b9fccea0b7f46362b245105068cc9b4dcef5880ef562d483ddf382032298
3
+ size 889047508
rotary.py ADDED
@@ -0,0 +1,61 @@
1
+ # From https://github.com/facebookresearch/llama/blob/main/llama/model.py
2
+
3
+ import torch
4
+ from typing import Tuple
5
+
6
+
7
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
8
+ """
9
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
10
+
11
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
12
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
13
+ The returned tensor contains complex values in complex64 data type.
14
+
15
+ Args:
16
+ dim (int): Dimension of the frequency tensor.
17
+ end (int): End index for precomputing frequencies.
18
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
19
+
20
+ Returns:
21
+ torch.Tensor: Precomputed frequency tensor with complex exponentials.
22
+ """
23
+
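+ # Inverse-frequency schedule: freqs[i] = theta ** (-2i / dim); position t is later rotated by angle t * freqs[i] for each pair of head dimensions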
24
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
25
+ t = torch.arange(end, device=freqs.device)
26
+ freqs = torch.outer(t, freqs).float()
27
+ return torch.polar(torch.ones_like(freqs), freqs)
28
+
29
+
30
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
31
+ assert freqs_cis.shape[1:] == (x.shape[1], x.shape[-1])
32
+ return freqs_cis.contiguous().unsqueeze(2)
33
+
34
+
35
+ def apply_rotary_emb(
36
+ xq: torch.Tensor,
37
+ xk: torch.Tensor,
38
+ freqs_cis: torch.Tensor,
39
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
40
+ """
41
+ Apply rotary embeddings to input tensors using the given frequency tensor.
42
+
43
+ This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
44
+ frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
45
+ is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
46
+ returned as real tensors.
47
+
48
+ Args:
49
+ xq (torch.Tensor): Query tensor to apply rotary embeddings.
50
+ xk (torch.Tensor): Key tensor to apply rotary embeddings.
51
+ freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.
52
+
53
+ Returns:
54
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
55
+ """
56
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
57
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
58
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
59
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
60
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
61
+ return xq_out.type_as(xq), xk_out.type_as(xk)
special_tokens_map.json ADDED
@@ -0,0 +1,37 @@
1
+ {
2
+ "cls_token": {
3
+ "content": "[CLS]",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "mask_token": {
10
+ "content": "[MASK]",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "[PAD]",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "sep_token": {
24
+ "content": "[SEP]",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "unk_token": {
31
+ "content": "[UNK]",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ }
37
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,61 @@
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "clean_up_tokenization_spaces": false,
45
+ "cls_token": "[CLS]",
46
+ "do_lower_case": true,
47
+ "extra_special_tokens": {},
48
+ "mask_token": "[MASK]",
49
+ "model_input_names": [
50
+ "input_ids",
51
+ "attention_mask"
52
+ ],
53
+ "model_max_length": 4096,
54
+ "pad_token": "[PAD]",
55
+ "sep_token": "[SEP]",
56
+ "strip_accents": null,
57
+ "tokenize_chinese_chars": true,
58
+ "tokenizer_class": "BertTokenizer",
59
+ "unk_token": "[UNK]",
60
+ "vocab_size": 30522
61
+ }
vocab.txt ADDED
The diff for this file is too large to render. See raw diff