Tom Aarsen committed
Commit 23a9cf3 · 1 Parent(s): be3c10d

Integrate with (Sentence) Transformers

Files changed (4)
  1. README.md +58 -3
  2. config.json +16 -3
  3. model.safetensors +2 -2
  4. modeling.py +88 -0
README.md CHANGED
@@ -1,11 +1,66 @@
  ---
  license: apache-2.0
  pipeline_tag: text-ranking
- library_name: lightning-ir
+ language:
+ - en
+ library_name: sentence-transformers
  base_model:
  - google/electra-base-discriminator
+ tags:
+ - transformers
  ---

- This model was introduced in the paper [A Systematic Investigation of Distilling Large Language Models into Cross-Encoders for Passage Re-ranking](https://arxiv.org/abs/2405.07920).
+ ## Cross-Encoder for Text Ranking

- For code, examples and more, please visit https://github.com/webis-de/msmarco-llm-distillation.
+ This model is a port of the [webis/monoelectra-base](https://huggingface.co/webis/monoelectra-base) model from [lightning-ir](https://github.com/webis-de/lightning-ir) to [Sentence Transformers](https://sbert.net/) and [Transformers](https://huggingface.co/docs/transformers).
+
+ The original model was introduced in the paper [A Systematic Investigation of Distilling Large Language Models into Cross-Encoders for Passage Re-ranking](https://arxiv.org/abs/2405.07920). See https://github.com/webis-de/rank-distillm for the code used to train the original model.
+
+ The model can be used as a reranker in a two-stage "retrieve-rerank" pipeline, where it reorders passages returned by a retriever model (e.g. an embedding model or BM25) for a given query. See [SBERT.net Retrieve & Re-rank](https://www.sbert.net/examples/applications/retrieve_rerank/README.html) for more details; a minimal end-to-end sketch follows the Sentence Transformers usage example below.
+
+ ## Usage with Sentence Transformers
+
+ Using the model is straightforward once you have [Sentence Transformers](https://www.sbert.net/) installed:
+
+ ```bash
+ pip install sentence-transformers
+ ```
+
+ Then you can use the pre-trained model like this:
+
+ ```python
+ from sentence_transformers import CrossEncoder
+
+ model = CrossEncoder("cross-encoder/monoelectra-base", trust_remote_code=True)
+ scores = model.predict([
+     ("How many people live in Berlin?", "Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers."),
+     ("How many people live in Berlin?", "Berlin is well known for its museums."),
+ ])
+ print(scores)
+ # [ 8.607138 -4.320078]
+ ```
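+
+ Putting the two stages from the introduction together, a minimal "retrieve-rerank" sketch could look as follows. The bi-encoder `multi-qa-MiniLM-L6-cos-v1` and the toy corpus are purely illustrative choices; any retriever (or BM25) works the same way.
+
+ ```python
+ from sentence_transformers import CrossEncoder, SentenceTransformer, util
+
+ corpus = [
+     "Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.",
+     "Berlin is well known for its museums.",
+     "New York City is famous for the Metropolitan Museum of Art.",
+ ]
+ query = "How many people live in Berlin?"
+
+ # Stage 1: retrieve candidate passages with a bi-encoder (illustrative choice).
+ retriever = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1")
+ corpus_embeddings = retriever.encode(corpus, convert_to_tensor=True)
+ query_embedding = retriever.encode(query, convert_to_tensor=True)
+ hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=3)[0]
+
+ # Stage 2: rerank the retrieved candidates with this cross-encoder.
+ reranker = CrossEncoder("cross-encoder/monoelectra-base", trust_remote_code=True)
+ pairs = [(query, corpus[hit["corpus_id"]]) for hit in hits]
+ scores = reranker.predict(pairs)
+ for (_, passage), score in sorted(zip(pairs, scores), key=lambda x: x[1], reverse=True):
+     print(f"{score:.2f}\t{passage}")
+ ```
+
+ Recent Sentence Transformers versions also provide `CrossEncoder.rank(query, passages)`, which scores and sorts a list of passages for one query in a single call.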
+
+ ## Usage with Transformers
+
+ ```python
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ import torch
+
+ model = AutoModelForSequenceClassification.from_pretrained("cross-encoder/monoelectra-base", trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained("cross-encoder/monoelectra-base")
+
+ # Pass the queries as `text` and the passages as `text_pair`, so that each
+ # query is paired with its corresponding passage.
+ features = tokenizer(
+     ["How many people live in Berlin?", "How many people live in Berlin?"],
+     ["Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.", "New York City is famous for the Metropolitan Museum of Art."],
+     padding=True,
+     truncation=True,
+     return_tensors="pt",
+ )
+
+ model.eval()
+ with torch.no_grad():
+     scores = model(**features).logits
+     print(scores)
+ ```
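+
+ The printed scores are single unbounded relevance logits, one per (query, passage) pair. If you prefer values in the range (0, 1), applying a sigmoid is a monotonic transform and leaves the ranking unchanged:
+
+ ```python
+ probabilities = torch.sigmoid(scores)  # same ordering as the raw logits
+ ```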
config.json CHANGED
@@ -1,8 +1,11 @@
  {
    "architectures": [
-     "CrossEncoderElectraModel"
+     "WebisCrossEncoderForSequenceClassification"
    ],
    "attention_probs_dropout_prob": 0.1,
+   "auto_map": {
+     "AutoModelForSequenceClassification": "modeling.WebisCrossEncoderForSequenceClassification"
+   },
    "backbone_model_type": "electra",
    "classifier_dropout": null,
    "doc_length": 256,
@@ -10,23 +13,33 @@
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "hidden_size": 768,
+   "id2label": {
+     "0": "LABEL_0"
+   },
    "initializer_range": 0.02,
    "intermediate_size": 3072,
+   "label2id": {
+     "LABEL_0": 0
+   },
    "layer_norm_eps": 1e-12,
    "max_position_embeddings": 512,
-   "model_type": "cross-encoder",
+   "model_type": "electra",
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
    "pad_token_id": 0,
    "pooling_strategy": "first",
    "position_embedding_type": "absolute",
    "query_length": 32,
+   "sentence_transformers": {
+     "activation_fn": "torch.nn.modules.linear.Identity",
+     "version": "4.0.1"
+   },
    "summary_activation": "gelu",
    "summary_last_dropout": 0.1,
    "summary_type": "first",
    "summary_use_proj": true,
    "torch_dtype": "float32",
-   "transformers_version": "4.41.2",
+   "transformers_version": "4.49.0",
    "type_vocab_size": 2,
    "use_cache": true,
    "vocab_size": 30522
model.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:763a09add59ddfc02d87d3211d40374769b2c7060f92507a885f3917a33f7caa
- size 435592020
+ oid sha256:5516a5f0510d6b44fc7415d3b283118f935c6438391e44e0850d079c0e644796
+ size 435593564
modeling.py ADDED
@@ -0,0 +1,88 @@
+
+ from typing import Optional, Tuple, Union
+ import torch
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+ from transformers import ElectraPreTrainedModel, ElectraModel, ElectraConfig
+ from transformers.modeling_outputs import SequenceClassifierOutput
+
+
+ class WebisCrossEncoderForSequenceClassification(ElectraPreTrainedModel):
+     def __init__(self, config: ElectraConfig):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.config = config
+         self.electra = ElectraModel(config)
+         self.linear = torch.nn.Linear(config.hidden_size, config.num_labels, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         discriminator_hidden_states = self.electra(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         sequence_output = discriminator_hidden_states[0]
+         logits = self.linear(sequence_output[:, 0, :])  # Take [CLS] token representation for classification
+
+         loss = None
+         if labels is not None:
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+
+             if self.config.problem_type == "regression":
+                 loss_fct = MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = BCEWithLogitsLoss()
+                 loss = loss_fct(logits, labels)
+
+         if not return_dict:
+             output = (logits,) + discriminator_hidden_states[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=discriminator_hidden_states.hidden_states,
+             attentions=discriminator_hidden_states.attentions,
+         )
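
Because the config registers a single label, passing `labels` to `forward` falls through to the regression branch above (an MSE loss on the raw score). A minimal sketch of that path, using an illustrative float relevance target rather than real training data:

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "cross-encoder/monoelectra-base", trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained("cross-encoder/monoelectra-base")

batch = tokenizer(
    ["How many people live in Berlin?"],
    ["Berlin is well known for its museums."],
    padding=True, truncation=True, return_tensors="pt",
)
labels = torch.tensor([0.0])  # illustrative target: 0.0 = not relevant
outputs = model(**batch, labels=labels)
print(outputs.loss, outputs.logits)
```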