from __future__ import annotations

import json
import logging
import os
import uuid
from collections import defaultdict
from itertools import islice
from typing import (
    Any,
    Dict,
    Generator,
    Iterable,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import numpy as np
from langchain_core.documents import Document as LCDocument
from langchain_qdrant import QdrantVectorStore, RetrievalMode
from pytorch_ie.annotations import LabeledMultiSpan, LabeledSpan, MultiSpan, Span
from qdrant_client import QdrantClient, models
from qdrant_client.http.models import Record

from .span_embeddings import SpanEmbeddings
from .span_vectorstore import SpanVectorStore

logger = logging.getLogger(__name__)


class QdrantSpanVectorStore(SpanVectorStore, QdrantVectorStore):
    """An implementation of the SpanVectorStore interface that uses Qdrant as the
    backend for storing and retrieving span embeddings."""

    # File names used by _save_to_directory / _load_from_directory to persist the
    # collection content (vectors, payloads, and point ids).
    EMBEDDINGS_FILE = "embeddings.npy"
    PAYLOADS_FILE = "payloads.json"
    INDEX_FILE = "index.json"

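    # The collection is created lazily with a single dense vector sized to the span
    # embedder's output; if the collection already exists, it is reused as-is.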
    def __init__(
        self,
        client: QdrantClient,
        collection_name: str,
        embedding: SpanEmbeddings,
        vector_params: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        if not client.collection_exists(collection_name):
            logger.info(f'Collection "{collection_name}" does not exist. Creating it now.')
            client.create_collection(
                collection_name=collection_name,
                # vector_params may be None; it should at least provide the distance metric.
                vectors_config=models.VectorParams(
                    size=embedding.embedding_dim, **(vector_params or {})
                ),
            )
        else:
            logger.info(f'Collection "{collection_name}" already exists.')
        super().__init__(
            client=client, collection_name=collection_name, embedding=embedding, **kwargs
        )

    def __len__(self):
        return self.client.get_collection(collection_name=self.collection_name).points_count

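    # Retrieve points by id and expose their stored vectors in the document metadata
    # (copied under METADATA_VECTOR_KEY by _document_from_point).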
    def get_by_ids_with_vectors(self, ids: Sequence[str | int], /) -> List[LCDocument]:
        results = self.client.retrieve(
            self.collection_name, ids, with_payload=True, with_vectors=True
        )

        return [
            self._document_from_point(
                result,
                self.collection_name,
                self.content_payload_key,
                self.metadata_payload_key,
            )
            for result in results
        ]

    def construct_filter(
        self,
        query_span: Union[Span, MultiSpan],
        metadata_doc_id_key: str,
        doc_id_whitelist: Optional[Sequence[str]] = None,
        doc_id_blacklist: Optional[Sequence[str]] = None,
    ) -> Optional[models.Filter]:
        """Construct a filter for the retrieval. It enforces that:

        - if the query span is labeled, the retrieved span has the same label, or,
          if a label mapping is provided, a label that the mapping assigns to the
          query span's label,
        - if `doc_id_whitelist` is provided, the retrieved span comes from a document
          in the whitelist,
        - if `doc_id_blacklist` is provided, the retrieved span does not come from a
          document in the blacklist.

        Args:
            query_span: The query span.
            metadata_doc_id_key: The key in the metadata that holds the document id.
            doc_id_whitelist: A list of document ids to restrict the retrieval to.
            doc_id_blacklist: A list of document ids to exclude from the retrieval.

        Returns:
            A filter object, or None if no conditions apply.
        """
        filter_kwargs = defaultdict(list)

        if isinstance(query_span, (LabeledSpan, LabeledMultiSpan)):
            if self.label_mapping is not None:
                target_labels = self.label_mapping.get(query_span.label, [])
            else:
                target_labels = [query_span.label]
            filter_kwargs["must"].append(
                models.FieldCondition(
                    key=f"metadata.{self.METADATA_SPAN_KEY}.label",
                    match=models.MatchAny(any=target_labels),
                )
            )
        elif self.label_mapping is not None:
            raise TypeError("Label mapping is only supported for labeled spans")

        if doc_id_blacklist is not None and doc_id_whitelist is not None:
            overlap = set(doc_id_whitelist) & set(doc_id_blacklist)
            if len(overlap) > 0:
                raise ValueError(
                    f"Overlap between doc_id_whitelist and doc_id_blacklist: {overlap}"
                )

        if doc_id_whitelist is not None:
            filter_kwargs["must"].append(
                models.FieldCondition(
                    key=f"metadata.{metadata_doc_id_key}",
                    match=(
                        models.MatchValue(value=doc_id_whitelist[0])
                        if len(doc_id_whitelist) == 1
                        else models.MatchAny(any=doc_id_whitelist)
                    ),
                )
            )
        if doc_id_blacklist is not None:
            filter_kwargs["must_not"].append(
                models.FieldCondition(
                    key=f"metadata.{metadata_doc_id_key}",
                    match=(
                        models.MatchValue(value=doc_id_blacklist[0])
                        if len(doc_id_blacklist) == 1
                        else models.MatchAny(any=doc_id_blacklist)
                    ),
                )
            )
        if len(filter_kwargs) > 0:
            return models.Filter(**filter_kwargs)
        return None

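    # Conversion from a Qdrant point to a LangChain Document. In addition to the
    # payload-based content and metadata, the similarity score and the raw vector
    # (when present on the point) are copied into the document metadata.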
    @classmethod
    def _document_from_point(
        cls,
        scored_point: Any,
        collection_name: str,
        content_payload_key: str,
        metadata_payload_key: str,
    ) -> LCDocument:
        metadata = scored_point.payload.get(metadata_payload_key) or {}
        metadata["_collection_name"] = collection_name
        if hasattr(scored_point, "score"):
            metadata[cls.RELEVANCE_SCORE_KEY] = scored_point.score
        if hasattr(scored_point, "vector"):
            metadata[cls.METADATA_VECTOR_KEY] = scored_point.vector
        return LCDocument(
            id=scored_point.id,
            page_content=scored_point.payload.get(content_payload_key, ""),
            metadata=metadata,
        )

    def _build_vectors_with_metadata(
        self,
        texts: Iterable[str],
        metadatas: List[dict],
    ) -> List[models.VectorStruct]:
        starts = [metadata[self.METADATA_SPAN_KEY][self.SPAN_START_KEY] for metadata in metadatas]
        ends = [metadata[self.METADATA_SPAN_KEY][self.SPAN_END_KEY] for metadata in metadatas]
        if self.retrieval_mode == RetrievalMode.DENSE:
            batch_embeddings = self.embeddings.embed_document_spans(list(texts), starts, ends)
            return [
                {
                    self.vector_name: vector,
                }
                for vector in batch_embeddings
            ]

        elif self.retrieval_mode == RetrievalMode.SPARSE:
            raise NotImplementedError("Sparse retrieval mode is not yet implemented.")

        elif self.retrieval_mode == RetrievalMode.HYBRID:
            raise NotImplementedError("Hybrid retrieval mode is not yet implemented.")
        else:
            raise ValueError(f"Unknown retrieval mode to build vectors: {self.retrieval_mode}")

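    # Unlike the base QdrantVectorStore, the page content is not stored in the payload;
    # only the metadata (which includes the span offsets and label) is persisted.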
    def _build_payloads_from_metadata(
        self,
        metadatas: Iterable[dict],
        metadata_payload_key: str,
    ) -> List[dict]:
        payloads = [{metadata_payload_key: metadata} for metadata in metadatas]

        return payloads

    def _generate_batches(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[Sequence[str | int]] = None,
        batch_size: int = 64,
    ) -> Generator[tuple[list[str | int], list[models.PointStruct]], Any, None]:
        """Generate batches of points to index. Same as in `QdrantVectorStore`, but the
        metadata is used to build the vectors and payloads."""

        texts_iterator = iter(texts)
        if metadatas is None:
            raise ValueError("Metadata must be provided to generate batches.")
        metadatas_iterator = iter(metadatas)
        # If no ids are given, generate a fresh UUID per text. Note that this iterates over
        # `texts` a second time, so `texts` must be a re-iterable sequence in that case.
        ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])

        while batch_texts := list(islice(texts_iterator, batch_size)):
            batch_metadatas = list(islice(metadatas_iterator, batch_size))
            batch_ids = list(islice(ids_iterator, batch_size))
            points = [
                models.PointStruct(
                    id=point_id,
                    vector=vector,
                    payload=payload,
                )
                for point_id, vector, payload in zip(
                    batch_ids,
                    self._build_vectors_with_metadata(batch_texts, metadatas=batch_metadatas),
                    self._build_payloads_from_metadata(
                        metadatas=batch_metadatas,
                        metadata_payload_key=self.metadata_payload_key,
                    ),
                )
                # Drop entries whose dense vector is None.
                if vector[self.vector_name] is not None
            ]

            yield [point.id for point in points], points

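    # Dense similarity search with a precomputed span embedding, queried against the
    # named dense vector via `query_points`.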
    def similarity_search_with_score_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        filter: Optional[models.Filter] = None,
        search_params: Optional[models.SearchParams] = None,
        offset: int = 0,
        score_threshold: Optional[float] = None,
        consistency: Optional[models.ReadConsistency] = None,
        **kwargs: Any,
    ) -> List[Tuple[LCDocument, float]]:
        """Return docs most similar to the query vector.

        Returns:
            List of documents most similar to the query vector, together with the
            similarity score for each.
        """
        query_options = {
            "collection_name": self.collection_name,
            "query_filter": filter,
            "search_params": search_params,
            "limit": k,
            "offset": offset,
            "with_payload": True,
            "with_vectors": False,
            "score_threshold": score_threshold,
            "consistency": consistency,
            **kwargs,
        }

        results = self.client.query_points(
            query=embedding,
            using=self.vector_name,
            **query_options,
        ).points

        return [
            (
                self._document_from_point(
                    result,
                    self.collection_name,
                    self.content_payload_key,
                    self.metadata_payload_key,
                ),
                result.score,
            )
            for result in results
        ]

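    # Export the full collection (ids, vectors, payloads) in a single scroll request.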
    def _as_indices_vectors_payloads(self) -> Tuple[List[str], np.ndarray, List[Any]]:
        data, _ = self.client.scroll(
            collection_name=self.collection_name, with_vectors=True, limit=len(self)
        )
        vectors_np = np.array([point.vector for point in data])
        payloads = [point.payload for point in data]
        emb_ids = [point.id for point in data]
        return emb_ids, vectors_np, payloads

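    # Persistence: embeddings are written as a NumPy array (EMBEDDINGS_FILE), payloads and
    # point ids as JSON (PAYLOADS_FILE, INDEX_FILE); _load_from_directory upserts them back.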
    def _save_to_directory(self, path: str, **kwargs) -> None:
        indices, vectors, payloads = self._as_indices_vectors_payloads()
        np.save(os.path.join(path, self.EMBEDDINGS_FILE), vectors)
        with open(os.path.join(path, self.PAYLOADS_FILE), "w") as f:
            json.dump(payloads, f, indent=2)
        with open(os.path.join(path, self.INDEX_FILE), "w") as f:
            json.dump(indices, f)

    def _load_from_directory(self, path: str, **kwargs) -> None:
        with open(os.path.join(path, self.INDEX_FILE), "r") as f:
            index = json.load(f)
        embeddings_np: np.ndarray = np.load(os.path.join(path, self.EMBEDDINGS_FILE))
        with open(os.path.join(path, self.PAYLOADS_FILE), "r") as f:
            payloads = json.load(f)
        points = models.Batch(ids=index, vectors=embeddings_np.tolist(), payloads=payloads)
        self.client.upsert(
            collection_name=self.collection_name,
            points=points,
        )
        logger.info(f"Loaded {len(index)} points into collection {self.collection_name}.")

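    # Simple key-value access to the collection (mget / mset / mdelete / yield_keys),
    # following the naming convention of LangChain's BaseStore interface.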
    def mget(self, keys: Sequence[str]) -> list[Optional[Record]]:
        return self.client.retrieve(
            self.collection_name, ids=keys, with_payload=True, with_vectors=True
        )

    def mset(self, key_value_pairs: Sequence[tuple[str, Record]]) -> None:
        self.client.upsert(
            collection_name=self.collection_name,
            points=models.Batch(
                ids=[key for key, _ in key_value_pairs],
                vectors=[value.vector for _, value in key_value_pairs],
                payloads=[value.payload for _, value in key_value_pairs],
            ),
        )

    def mdelete(self, keys: Sequence[str]) -> None:
        self.client.delete(collection_name=self.collection_name, points_selector=keys)

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        for point in self.client.scroll(
            collection_name=self.collection_name,
            with_vectors=False,
            with_payload=False,
            limit=len(self),
        )[0]:
            yield point.id