import logging
from typing import Any, Dict, List, Optional, Union

import torch
from pydantic import BaseModel, ConfigDict, Field
from transformers import pipeline

from ..hf_pipeline import FeatureExtractionPipelineWithStriding
from .span_embeddings import SpanEmbeddings

logger = logging.getLogger(__name__)

DEFAULT_MODEL_NAME = "allenai/specter2_base"


class HuggingFaceSpanEmbeddings(BaseModel, SpanEmbeddings):
    """An implementation of SpanEmbeddings using a modified HuggingFace Transformers
    feature-extraction pipeline, adapted for long text inputs by chunking with optional stride
    (see src.hf_pipeline.FeatureExtractionPipelineWithStriding).

    Note that calculating embeddings for multiple spans is efficient when all spans for a
    text are passed in a single call to embed_document_spans: the token embeddings are computed
    only once per unique text, and the span embeddings are simply pooled from them.

    It accepts any model that can be used with the HuggingFace feature-extraction pipeline,
    including models with adapters such as SPECTER2 (see https://huggingface.co/allenai/specter2).
    In that case, the model should be loaded beforehand and passed via the 'model' parameter
    instead of a model identifier.
    See https://huggingface.co/docs/transformers/main_classes/pipelines for further information.
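
    For example, a SPECTER2 base model with its proximity adapter could be loaded roughly
    like this (a sketch, assuming the ``adapters`` package is installed and the adapter
    names from the SPECTER2 model card):

        .. code-block:: python

            from adapters import AutoAdapterModel

            model = AutoAdapterModel.from_pretrained("allenai/specter2_base")
            model.load_adapter(
                "allenai/specter2", source="hf", load_as="proximity", set_active=True
            )
            hf = HuggingFaceSpanEmbeddings(
                model=model,
                # tokenizer must be given explicitly when passing a loaded model
                pipeline_kwargs={"tokenizer": "allenai/specter2_base"},
                # may be needed if the checkpoint's tokenizer config lacks a max length
                model_max_length=512,
            )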

    To use, you should have the ``transformers`` python package installed.

    Example:
        .. code-block:: python

            model = "allenai/specter2_base"
            pipeline_kwargs = {'device': 'cpu', 'stride': 64, 'batch_size': 8}
            hf = HuggingFaceSpanEmbeddings(
                model=model,
                pipeline_kwargs=pipeline_kwargs,
            )

            text = "This is a test sentence."

            # calculate embeddings for text[0:4]="This" and text[15:23]="sentence"
            embeddings = hf.embed_document_spans(texts=[text, text], starts=[0, 15], ends=[4, 23])
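
            # a list of lists pools a single embedding over multiple spans of the same
            # text, e.g. over "This" and "sentence" together
            multi_embedding = hf.embed_document_spans(
                texts=[text], starts=[[0, 15]], ends=[[4, 23]]
            )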
    """

    client: Any = None  #: :meta private:
    model: Optional[Any] = DEFAULT_MODEL_NAME
    """Model name to use, or a preloaded model instance."""
    pooling_strategy: str = "mean"
    """How to pool the token embeddings within a span; either "mean" or "max"."""
    pipeline_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass to the Huggingface pipeline constructor."""
    encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass when calling the pipeline."""
    # show_progress: bool = False
    # """Whether to show a progress bar."""
    model_max_length: Optional[int] = None
    """The maximum input length of the model. Required for some model checkpoints with outdated configs."""

    def __init__(self, **kwargs: Any):
        """Initialize the sentence_transformer."""
        super().__init__(**kwargs)

        self.client = pipeline(
            "feature-extraction",
            model=self.model,
            pipeline_class=FeatureExtractionPipelineWithStriding,
            trust_remote_code=True,
            **self.pipeline_kwargs,
        )

        # The Transformers library is buggy since 4.40.0
        # (see https://github.com/huggingface/transformers/issues/30643),
        # so the max length may need to be set manually, e.g. to 512
        if self.model_max_length is not None:
            self.client.tokenizer.model_max_length = self.model_max_length

        # Check if the model has a valid max length
        max_input_size = self.client.tokenizer.model_max_length
        if max_input_size > 1e5:  # A high threshold to catch "unlimited" values
            raise ValueError(
                "The tokenizer does not specify a valid `model_max_length` attribute. "
                "Consider setting it manually by passing `model_max_length` to the "
                "HuggingFaceSpanEmbeddings constructor."
            )

    model_config = ConfigDict(
        extra="forbid",
        protected_namespaces=(),
    )

    def get_span_embedding(
        self,
        last_hidden_state: torch.Tensor,
        offset_mapping: torch.Tensor,
        start: Union[int, List[int]],
        end: Union[int, List[int]],
        **unused_kwargs,
    ) -> Optional[List[float]]:
        """Pool the span embeddings."""
        if isinstance(start, int):
            start = [start]
        if isinstance(end, int):
            end = [end]
        if len(start) != len(end):
            raise ValueError("start and end should have the same length.")
        if len(start) == 0:
            raise ValueError("start and end should not be empty.")
        if last_hidden_state.shape[0] != 1:
            raise ValueError("last_hidden_state should have a batch size of 1.")
        if last_hidden_state.shape[0] != offset_mapping.shape[0]:
            raise ValueError(
                "last_hidden_state and offset_mapping should have the same batch size."
            )
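        # drop the batch dimension (checked above to be exactly 1)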
        offset_mapping = offset_mapping[0]
        last_hidden_state = last_hidden_state[0]

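        # a token contributes to the span embedding if its character offsets lie fully
        # within any of the requested (start, end) ranges (union over all spans)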
        mask = (start[0] <= offset_mapping[:, 0]) & (offset_mapping[:, 1] <= end[0])
        for s, e in zip(start[1:], end[1:]):
            mask = mask | ((s <= offset_mapping[:, 0]) & (offset_mapping[:, 1] <= e))
        span_embeddings = last_hidden_state[mask]
        if span_embeddings.shape[0] == 0:
            return None
        if self.pooling_strategy == "mean":
            return span_embeddings.mean(dim=0).tolist()
        elif self.pooling_strategy == "max":
            return span_embeddings.max(dim=0).values.tolist()
        else:
            raise ValueError(f"Unknown pooling strategy: {self.pooling_strategy}")

    def embed_document_spans(
        self,
        texts: List[str],
        starts: Union[List[int], List[List[int]]],
        ends: Union[List[int], List[List[int]]],
    ) -> List[Optional[List[float]]]:
        """Compute doc embeddings using a HuggingFace transformer model.

        Args:
            texts: The list of texts to embed.
            starts: The list of start indices or list of lists of start indices (multi-span).
            ends: The list of end indices or list of lists of end indices (multi-span).

        Returns:
            List of span embeddings, one per entry in ``texts`` (None if no token falls inside the span).
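
        Example:
            .. code-block:: python

                # a sketch, assuming ``hf`` is constructed as in the class docstring;
                # an empty character range matches no tokens and therefore yields None
                embeddings = hf.embed_document_spans(texts=["Some text."], starts=[5], ends=[5])
                assert embeddings[0] is None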
        """
        pipeline_kwargs = self.encode_kwargs.copy()
        pipeline_kwargs["return_offset_mapping"] = True
        # we enable long text handling by providing the stride parameter
        if pipeline_kwargs.get("stride", None) is None:
            pipeline_kwargs["stride"] = 0
        # when stride is positive, we need to create unique embeddings per token
        if pipeline_kwargs["stride"] > 0:
            pipeline_kwargs["create_unique_embeddings_per_token"] = True
        # we ask for tensors to efficiently compute the span embeddings
        pipeline_kwargs["return_tensors"] = True

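        # embed each unique text only once; all spans then index into the shared results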
        unique_texts = sorted(set(texts))
        text2unique_idx = {text: i for i, text in enumerate(unique_texts)}
        idx2unique_idx = {i: text2unique_idx[text] for i, text in enumerate(texts)}
        pipeline_results = self.client(unique_texts, **pipeline_kwargs)
        embeddings = [
            self.get_span_embedding(
                start=starts[idx], end=ends[idx], **pipeline_results[idx2unique_idx[idx]]
            )
            for idx in range(len(texts))
        ]
        return embeddings

    def embed_query_span(
        self, text: str, start: Union[int, List[int]], end: Union[int, List[int]]
    ) -> Optional[List[float]]:
        """Compute query embeddings using a HuggingFace transformer model.

        Args:
            text: The text to embed.
            start: The start index or list of start indices (multi-span).
            end: The end index or list of end indices (multi-span).

        Returns:
            The span embedding for the text, or None if no token falls inside the span(s).
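
        Example:
            .. code-block:: python

                # a sketch, assuming ``hf`` is constructed as in the class docstring
                embedding = hf.embed_query_span("This is a test sentence.", start=15, end=23)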
        """
        starts: Union[List[int], List[List[int]]] = [start]  # type: ignore[assignment]
        ends: Union[List[int], List[List[int]]] = [end]  # type: ignore[assignment]
        return self.embed_document_spans([text], starts=starts, ends=ends)[0]

    @property
    def embedding_dim(self) -> int:
        """Get the embedding dimension."""
        return self.client.model.config.hidden_size