from __future__ import annotations

import json
import logging
import os
import uuid
from collections import defaultdict
from itertools import islice
from typing import (
    Any,
    Dict,
    Generator,
    Iterable,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import numpy as np
from langchain_core.documents import Document as LCDocument
from langchain_qdrant import QdrantVectorStore, RetrievalMode
from pytorch_ie.annotations import LabeledMultiSpan, LabeledSpan, MultiSpan, Span
from qdrant_client import QdrantClient, models
from qdrant_client.http.models import Record

from .span_embeddings import SpanEmbeddings
from .span_vectorstore import SpanVectorStore

logger = logging.getLogger(__name__)


class QdrantSpanVectorStore(SpanVectorStore, QdrantVectorStore):
    """An implementation of the SpanVectorStore interface that uses Qdrant
    as the backend for storing and retrieving span embeddings."""

    EMBEDDINGS_FILE = "embeddings.npy"
    PAYLOADS_FILE = "payloads.json"
    INDEX_FILE = "index.json"
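    # these file names are used by `_save_to_directory` / `_load_from_directory` to
    # persist the collection content on disk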

    def __init__(
        self,
        client: QdrantClient,
        collection_name: str,
        embedding: SpanEmbeddings,
        vector_params: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        if not client.collection_exists(collection_name):
            logger.info(f'Collection "{collection_name}" does not exist. Creating it now.')
            # fall back to cosine distance if no explicit vector params are provided,
            # since models.VectorParams requires a distance metric
            vector_params = vector_params or {"distance": models.Distance.COSINE}
            client.create_collection(
                collection_name=collection_name,
                vectors_config=models.VectorParams(size=embedding.embedding_dim, **vector_params),
            )
        else:
            logger.info(f'Collection "{collection_name}" already exists.')
        super().__init__(
            client=client, collection_name=collection_name, embedding=embedding, **kwargs
        )

    def __len__(self):
        return self.client.get_collection(collection_name=self.collection_name).points_count

    def get_by_ids_with_vectors(self, ids: Sequence[str | int], /) -> List[LCDocument]:
        results = self.client.retrieve(
            self.collection_name, ids, with_payload=True, with_vectors=True
        )

        return [
            self._document_from_point(
                result,
                self.collection_name,
                self.content_payload_key,
                self.metadata_payload_key,
            )
            for result in results
        ]

    def construct_filter(
        self,
        query_span: Union[Span, MultiSpan],
        metadata_doc_id_key: str,
        doc_id_whitelist: Optional[Sequence[str]] = None,
        doc_id_blacklist: Optional[Sequence[str]] = None,
    ) -> Optional[models.Filter]:
        """Construct a filter for the retrieval. It should enforce that:
        - if the query span is labeled, the retrieved span has the same label, or, if a
          label mapping is provided, a label that the mapping assigns to the query span's label,
        - if `doc_id_whitelist` is provided, the retrieved span comes from a document in the whitelist, and
        - if `doc_id_blacklist` is provided, the retrieved span does not come from a document in the blacklist.

        Args:
            query_span: The query span.
            metadata_doc_id_key: The key in the metadata that holds the document id.
            doc_id_whitelist: A list of document ids to restrict the retrieval to.
            doc_id_blacklist: A list of document ids to exclude from the retrieval.

        Returns:
            A Qdrant filter object, or None if no constraints apply.
        """
        filter_kwargs = defaultdict(list)
        # if the span is labeled, enforce that the retrieved span has the same label
        if isinstance(query_span, (LabeledSpan, LabeledMultiSpan)):
            if self.label_mapping is not None:
                target_labels = self.label_mapping.get(query_span.label, [])
            else:
                target_labels = [query_span.label]
            filter_kwargs["must"].append(
                models.FieldCondition(
                    key=f"metadata.{self.METADATA_SPAN_KEY}.label",
                    match=models.MatchAny(any=target_labels),
                )
            )
        elif self.label_mapping is not None:
            raise TypeError("Label mapping is only supported for labeled spans")

        if doc_id_blacklist is not None and doc_id_whitelist is not None:
            overlap = set(doc_id_whitelist) & set(doc_id_blacklist)
            if len(overlap) > 0:
                raise ValueError(
                    f"Overlap between doc_id_whitelist and doc_id_blacklist: {overlap}"
                )

        if doc_id_whitelist is not None:
            filter_kwargs["must"].append(
                models.FieldCondition(
                    key=f"metadata.{metadata_doc_id_key}",
                    match=(
                        models.MatchValue(value=doc_id_whitelist[0])
                        if len(doc_id_whitelist) == 1
                        else models.MatchAny(any=doc_id_whitelist)
                    ),
                )
            )
        if doc_id_blacklist is not None:
            filter_kwargs["must_not"].append(
                models.FieldCondition(
                    key=f"metadata.{metadata_doc_id_key}",
                    match=(
                        models.MatchValue(value=doc_id_blacklist[0])
                        if len(doc_id_blacklist) == 1
                        else models.MatchAny(any=doc_id_blacklist)
                    ),
                )
            )
        if len(filter_kwargs) > 0:
            return models.Filter(**filter_kwargs)
        return None
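
    # Sketch of a resulting filter (illustrative only): for a LabeledSpan labeled "person",
    # no label mapping, and doc_id_whitelist=["doc-1"], construct_filter roughly yields
    #   models.Filter(
    #       must=[
    #           models.FieldCondition(
    #               key="metadata.<METADATA_SPAN_KEY>.label",
    #               match=models.MatchAny(any=["person"]),
    #           ),
    #           models.FieldCondition(
    #               key="metadata.<metadata_doc_id_key>",
    #               match=models.MatchValue(value="doc-1"),
    #           ),
    #       ]
    #   )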

    @classmethod
    def _document_from_point(
        cls,
        scored_point: Any,
        collection_name: str,
        content_payload_key: str,
        metadata_payload_key: str,
    ) -> LCDocument:
        metadata = scored_point.payload.get(metadata_payload_key) or {}
        metadata["_collection_name"] = collection_name
        if hasattr(scored_point, "score"):
            metadata[cls.RELEVANCE_SCORE_KEY] = scored_point.score
        # the vector attribute exists but is None when points are fetched with with_vectors=False
        if getattr(scored_point, "vector", None) is not None:
            metadata[cls.METADATA_VECTOR_KEY] = scored_point.vector
        return LCDocument(
            id=scored_point.id,
            page_content=scored_point.payload.get(content_payload_key, ""),
            metadata=metadata,
        )

    def _build_vectors_with_metadata(
        self,
        texts: Iterable[str],
        metadatas: List[dict],
    ) -> List[models.VectorStruct]:
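        # unlike the base QdrantVectorStore, the vectors are built from the span
        # (start, end) offsets stored in the metadata: the full texts are passed to
        # `SpanEmbeddings.embed_document_spans` together with those offsets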
        starts = [metadata[self.METADATA_SPAN_KEY][self.SPAN_START_KEY] for metadata in metadatas]
        ends = [metadata[self.METADATA_SPAN_KEY][self.SPAN_END_KEY] for metadata in metadatas]
        if self.retrieval_mode == RetrievalMode.DENSE:
            batch_embeddings = self.embeddings.embed_document_spans(list(texts), starts, ends)
            return [
                {
                    self.vector_name: vector,
                }
                for vector in batch_embeddings
            ]

        elif self.retrieval_mode == RetrievalMode.SPARSE:
            raise NotImplementedError("Sparse retrieval mode is not yet implemented.")

        elif self.retrieval_mode == RetrievalMode.HYBRID:
            raise NotImplementedError("Hybrid retrieval mode is not yet implemented.")
        else:
            raise ValueError(f"Unknown retrieval mode to build vectors: {self.retrieval_mode}.")

    def _build_payloads_from_metadata(
        self,
        metadatas: Iterable[dict],
        metadata_payload_key: str,
    ) -> List[dict]:
        payloads = [{metadata_payload_key: metadata} for metadata in metadatas]

        return payloads

    def _generate_batches(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[Sequence[str | int]] = None,
        batch_size: int = 64,
    ) -> Generator[tuple[list[str | int], list[models.PointStruct]], Any, None]:
        """Generate batches of points to index. Same as in `QdrantVectorStore` but metadata is used
        to build both the vectors and the payloads."""

        texts_iterator = iter(texts)
        if metadatas is None:
            raise ValueError("Metadata must be provided to generate batches.")
        metadatas_iterator = iter(metadatas)
        # generate one id per metadata entry; using `metadatas` here avoids consuming
        # `texts` a second time, which would break for one-shot iterators
        ids_iterator = iter(ids or [uuid.uuid4().hex for _ in metadatas])

        while batch_texts := list(islice(texts_iterator, batch_size)):
            batch_metadatas = list(islice(metadatas_iterator, batch_size))
            batch_ids = list(islice(ids_iterator, batch_size))
            points = [
                models.PointStruct(
                    id=point_id,
                    vector=vector,
                    payload=payload,
                )
                for point_id, vector, payload in zip(
                    batch_ids,
                    self._build_vectors_with_metadata(batch_texts, metadatas=batch_metadatas),
                    # we do not save the text in the payload because the text is the full
                    # document which is usually already saved in the docstore
                    self._build_payloads_from_metadata(
                        metadatas=batch_metadatas,
                        metadata_payload_key=self.metadata_payload_key,
                    ),
                )
                if vector[self.vector_name] is not None
            ]

            yield [point.id for point in points], points

    def similarity_search_with_score_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        filter: Optional[models.Filter] = None,
        search_params: Optional[models.SearchParams] = None,
        offset: int = 0,
        score_threshold: Optional[float] = None,
        consistency: Optional[models.ReadConsistency] = None,
        **kwargs: Any,
    ) -> List[Tuple[LCDocument, float]]:
        """Return docs most similar to query vector.

        Returns:
            List of documents most similar to the query vector and the score for each.
        """
        query_options = {
            "collection_name": self.collection_name,
            "query_filter": filter,
            "search_params": search_params,
            "limit": k,
            "offset": offset,
            "with_payload": True,
            "with_vectors": False,
            "score_threshold": score_threshold,
            "consistency": consistency,
            **kwargs,
        }

        results = self.client.query_points(
            query=embedding,
            using=self.vector_name,
            **query_options,
        ).points

        return [
            (
                self._document_from_point(
                    result,
                    self.collection_name,
                    self.content_payload_key,
                    self.metadata_payload_key,
                ),
                result.score,
            )
            for result in results
        ]

    def _as_indices_vectors_payloads(self) -> Tuple[List[str], np.ndarray, List[Any]]:
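        # dump all points of the collection as (ids, dense vectors, payloads);
        # used by `_save_to_directory` to persist the collection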
        data, _ = self.client.scroll(
            collection_name=self.collection_name, with_vectors=True, limit=len(self)
        )
        vectors_np = np.array([point.vector for point in data])
        payloads = [point.payload for point in data]
        emb_ids = [point.id for point in data]
        return emb_ids, vectors_np, payloads

    # TODO: or use create_snapshot and restore_snapshot?
    def _save_to_directory(self, path: str, **kwargs) -> None:
        indices, vectors, payloads = self._as_indices_vectors_payloads()
        np.save(os.path.join(path, self.EMBEDDINGS_FILE), vectors)
        with open(os.path.join(path, self.PAYLOADS_FILE), "w") as f:
            json.dump(payloads, f, indent=2)
        with open(os.path.join(path, self.INDEX_FILE), "w") as f:
            json.dump(indices, f)

    def _load_from_directory(self, path: str, **kwargs) -> None:
        with open(os.path.join(path, self.INDEX_FILE), "r") as f:
            index = json.load(f)
        embeddings_np: np.ndarray = np.load(os.path.join(path, self.EMBEDDINGS_FILE))
        with open(os.path.join(path, self.PAYLOADS_FILE), "r") as f:
            payloads = json.load(f)
        points = models.Batch(ids=index, vectors=embeddings_np.tolist(), payloads=payloads)
        self.client.upsert(
            collection_name=self.collection_name,
            points=points,
        )
        logger.info(f"Loaded {len(index)} points into collection {self.collection_name}.")

    def mget(self, keys: Sequence[str]) -> list[Optional[Record]]:
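        # part of a simple key-value (BaseStore-like) interface over the collection
        # (together with mset, mdelete and yield_keys); note that `client.retrieve`
        # only returns points that exist, so missing keys are silently omitted rather
        # than returned as None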
        return self.client.retrieve(
            self.collection_name, ids=keys, with_payload=True, with_vectors=True
        )

    def mset(self, key_value_pairs: Sequence[tuple[str, Record]]) -> None:
        self.client.upsert(
            collection_name=self.collection_name,
            points=models.Batch(
                ids=[key for key, _ in key_value_pairs],
                vectors=[value.vector for _, value in key_value_pairs],
                payloads=[value.payload for _, value in key_value_pairs],
            ),
        )

    def mdelete(self, keys: Sequence[str]) -> None:
        self.client.delete(collection_name=self.collection_name, points_selector=keys)

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        for point in self.client.scroll(
            collection_name=self.collection_name,
            with_vectors=False,
            with_payload=False,
            limit=len(self),
        )[0]:
            # honor the prefix filter from the key-value store interface
            if prefix is None or str(point.id).startswith(prefix):
                yield point.id
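

# Minimal usage sketch (illustrative only): `my_embeddings` is assumed to be a concrete
# `SpanEmbeddings` implementation and `query_vector` a pre-computed span embedding
# (List[float]); these names are placeholders, not part of this module.
#
#   client = QdrantClient(":memory:")
#   store = QdrantSpanVectorStore(
#       client=client,
#       collection_name="spans",
#       embedding=my_embeddings,
#       vector_params={"distance": models.Distance.COSINE},
#   )
#   results = store.similarity_search_with_score_by_vector(query_vector, k=5)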