from __future__ import annotations
import json
import logging
import os
import uuid
from collections import defaultdict
from itertools import count, islice
from typing import (
Any,
Dict,
Generator,
Iterable,
Iterator,
List,
Optional,
Sequence,
Tuple,
Union,
)
import numpy as np
from langchain_core.documents import Document as LCDocument
from langchain_qdrant import QdrantVectorStore, RetrievalMode
from pytorch_ie.annotations import LabeledMultiSpan, LabeledSpan, MultiSpan, Span
from qdrant_client import QdrantClient, models
from qdrant_client.http.models import Record
from .span_embeddings import SpanEmbeddings
from .span_vectorstore import SpanVectorStore
logger = logging.getLogger(__name__)
class QdrantSpanVectorStore(SpanVectorStore, QdrantVectorStore):
"""An implementation of the SpanVectorStore interface that uses Qdrant
as backend for storing and retrieving span embeddings."""
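    # Construction sketch (hedged): `MySpanEmbeddings` is a hypothetical SpanEmbeddings
    # implementation, and an in-memory Qdrant client is used for illustration only:
    #   client = QdrantClient(location=":memory:")
    #   store = QdrantSpanVectorStore(
    #       client=client,
    #       collection_name="span_embeddings",
    #       embedding=MySpanEmbeddings(),
    #       vector_params={"distance": models.Distance.COSINE},
    #   )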
EMBEDDINGS_FILE = "embeddings.npy"
PAYLOADS_FILE = "payloads.json"
INDEX_FILE = "index.json"
def __init__(
self,
client: QdrantClient,
collection_name: str,
embedding: SpanEmbeddings,
vector_params: Optional[Dict[str, Any]] = None,
**kwargs,
):
if not client.collection_exists(collection_name):
logger.info(f'Collection "{collection_name}" does not exist. Creating it now.')
client.create_collection(
collection_name=collection_name,
                vectors_config=models.VectorParams(
                    # vector_params is expected to provide at least the distance metric,
                    # e.g. {"distance": models.Distance.COSINE}
                    size=embedding.embedding_dim,
                    **(vector_params or {}),
                ),
)
else:
logger.info(f'Collection "{collection_name}" already exists.')
super().__init__(
client=client, collection_name=collection_name, embedding=embedding, **kwargs
)
def __len__(self):
return self.client.get_collection(collection_name=self.collection_name).points_count
def get_by_ids_with_vectors(self, ids: Sequence[str | int], /) -> List[LCDocument]:
results = self.client.retrieve(
self.collection_name, ids, with_payload=True, with_vectors=True
)
return [
self._document_from_point(
result,
self.collection_name,
self.content_payload_key,
self.metadata_payload_key,
)
for result in results
]
def construct_filter(
self,
query_span: Union[Span, MultiSpan],
metadata_doc_id_key: str,
doc_id_whitelist: Optional[Sequence[str]] = None,
doc_id_blacklist: Optional[Sequence[str]] = None,
) -> Optional[models.Filter]:
"""Construct a filter for the retrieval. It should enforce that:
- if the span is labeled, the retrieved span has the same label, or
- if, in addition, a label mapping is provided, the retrieved span has a label that is in the mapping for the query span's label
- if `doc_id_whitelist` is provided, the retrieved span is from a document in the whitelist
- if `doc_id_blacklist` is provided, the retrieved span is not from a document in the blacklist
Args:
query_span: The query span.
metadata_doc_id_key: The key in the metadata that holds the document id.
doc_id_whitelist: A list of document ids to restrict the retrieval to.
doc_id_blacklist: A list of document ids to exclude from the retrieval.
Returns:
            A filter object, or None if there are no constraints to apply.
"""
filter_kwargs = defaultdict(list)
# if the span is labeled, enforce that the retrieved span has the same label
if isinstance(query_span, (LabeledSpan, LabeledMultiSpan)):
if self.label_mapping is not None:
target_labels = self.label_mapping.get(query_span.label, [])
else:
target_labels = [query_span.label]
filter_kwargs["must"].append(
models.FieldCondition(
key=f"metadata.{self.METADATA_SPAN_KEY}.label",
match=models.MatchAny(any=target_labels),
)
)
elif self.label_mapping is not None:
raise TypeError("Label mapping is only supported for labeled spans")
if doc_id_blacklist is not None and doc_id_whitelist is not None:
overlap = set(doc_id_whitelist) & set(doc_id_blacklist)
if len(overlap) > 0:
raise ValueError(
f"Overlap between doc_id_whitelist and doc_id_blacklist: {overlap}"
)
if doc_id_whitelist is not None:
filter_kwargs["must"].append(
models.FieldCondition(
key=f"metadata.{metadata_doc_id_key}",
match=(
models.MatchValue(value=doc_id_whitelist[0])
if len(doc_id_whitelist) == 1
else models.MatchAny(any=doc_id_whitelist)
),
)
)
if doc_id_blacklist is not None:
filter_kwargs["must_not"].append(
models.FieldCondition(
key=f"metadata.{metadata_doc_id_key}",
match=(
models.MatchValue(value=doc_id_blacklist[0])
if len(doc_id_blacklist) == 1
else models.MatchAny(any=doc_id_blacklist)
),
)
)
if len(filter_kwargs) > 0:
return models.Filter(**filter_kwargs)
return None
@classmethod
def _document_from_point(
cls,
scored_point: Any,
collection_name: str,
content_payload_key: str,
metadata_payload_key: str,
) -> LCDocument:
metadata = scored_point.payload.get(metadata_payload_key) or {}
metadata["_collection_name"] = collection_name
if hasattr(scored_point, "score"):
metadata[cls.RELEVANCE_SCORE_KEY] = scored_point.score
if hasattr(scored_point, "vector"):
metadata[cls.METADATA_VECTOR_KEY] = scored_point.vector
return LCDocument(
id=scored_point.id,
page_content=scored_point.payload.get(content_payload_key, ""),
metadata=metadata,
)
def _build_vectors_with_metadata(
self,
texts: Iterable[str],
metadatas: List[dict],
) -> List[models.VectorStruct]:
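        # Each metadata entry is expected to carry the span offsets (and, for labeled
        # spans, the label) under METADATA_SPAN_KEY, e.g. (key names are illustrative):
        #   {"<span_key>": {"<start_key>": 10, "<end_key>": 25, "label": "claim"}, ...}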
starts = [metadata[self.METADATA_SPAN_KEY][self.SPAN_START_KEY] for metadata in metadatas]
ends = [metadata[self.METADATA_SPAN_KEY][self.SPAN_END_KEY] for metadata in metadatas]
if self.retrieval_mode == RetrievalMode.DENSE:
batch_embeddings = self.embeddings.embed_document_spans(list(texts), starts, ends)
return [
{
self.vector_name: vector,
}
for vector in batch_embeddings
]
elif self.retrieval_mode == RetrievalMode.SPARSE:
raise ValueError("Sparse retrieval mode is not yet implemented.")
elif self.retrieval_mode == RetrievalMode.HYBRID:
raise NotImplementedError("Hybrid retrieval mode is not yet implemented.")
else:
raise ValueError(f"Unknown retrieval mode. {self.retrieval_mode} to build vectors.")
def _build_payloads_from_metadata(
self,
metadatas: Iterable[dict],
metadata_payload_key: str,
) -> List[dict]:
payloads = [{metadata_payload_key: metadata} for metadata in metadatas]
return payloads
def _generate_batches(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[Sequence[str | int]] = None,
batch_size: int = 64,
) -> Generator[tuple[list[str | int], list[models.PointStruct]], Any, None]:
"""Generate batches of points to index. Same as in `QdrantVectorStore` but metadata is used
to build vectors and payloads."""
texts_iterator = iter(texts)
if metadatas is None:
raise ValueError("Metadata must be provided to generate batches.")
metadatas_iterator = iter(metadatas)
        # generate ids on demand instead of re-iterating `texts`, which would exhaust
        # `texts_iterator` if `texts` is a one-shot iterator
        ids_iterator = iter(ids) if ids else (uuid.uuid4().hex for _ in count())
while batch_texts := list(islice(texts_iterator, batch_size)):
batch_metadatas = list(islice(metadatas_iterator, batch_size))
batch_ids = list(islice(ids_iterator, batch_size))
points = [
models.PointStruct(
id=point_id,
vector=vector,
payload=payload,
)
for point_id, vector, payload in zip(
batch_ids,
self._build_vectors_with_metadata(batch_texts, metadatas=batch_metadatas),
# we do not save the text in the payload because the text is the full
# document which is usually already saved in the docstore
self._build_payloads_from_metadata(
metadatas=batch_metadatas,
metadata_payload_key=self.metadata_payload_key,
),
)
if vector[self.vector_name] is not None
]
yield [point.id for point in points], points
def similarity_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[models.Filter] = None,
search_params: Optional[models.SearchParams] = None,
offset: int = 0,
score_threshold: Optional[float] = None,
consistency: Optional[models.ReadConsistency] = None,
**kwargs: Any,
) -> List[Tuple[LCDocument, float]]:
"""Return docs most similar to query vector.
Returns:
List of documents most similar to the query text and distance for each.
"""
query_options = {
"collection_name": self.collection_name,
"query_filter": filter,
"search_params": search_params,
"limit": k,
"offset": offset,
"with_payload": True,
"with_vectors": False,
"score_threshold": score_threshold,
"consistency": consistency,
**kwargs,
}
results = self.client.query_points(
query=embedding,
using=self.vector_name,
**query_options,
).points
return [
(
self._document_from_point(
result,
self.collection_name,
self.content_payload_key,
self.metadata_payload_key,
),
result.score,
)
for result in results
]
def _as_indices_vectors_payloads(self) -> Tuple[List[str], np.ndarray, List[Any]]:
data, _ = self.client.scroll(
collection_name=self.collection_name, with_vectors=True, limit=len(self)
)
vectors_np = np.array([point.vector for point in data])
payloads = [point.payload for point in data]
emb_ids = [point.id for point in data]
return emb_ids, vectors_np, payloads
# TODO: or use create_snapshot and restore_snapshot?
def _save_to_directory(self, path: str, **kwargs) -> None:
indices, vectors, payloads = self._as_indices_vectors_payloads()
np.save(os.path.join(path, self.EMBEDDINGS_FILE), vectors)
with open(os.path.join(path, self.PAYLOADS_FILE), "w") as f:
json.dump(payloads, f, indent=2)
with open(os.path.join(path, self.INDEX_FILE), "w") as f:
json.dump(indices, f)
def _load_from_directory(self, path: str, **kwargs) -> None:
with open(os.path.join(path, self.INDEX_FILE), "r") as f:
index = json.load(f)
embeddings_np: np.ndarray = np.load(os.path.join(path, self.EMBEDDINGS_FILE))
with open(os.path.join(path, self.PAYLOADS_FILE), "r") as f:
payloads = json.load(f)
points = models.Batch(ids=index, vectors=embeddings_np.tolist(), payloads=payloads)
self.client.upsert(
collection_name=self.collection_name,
points=points,
)
logger.info(f"Loaded {len(index)} points into collection {self.collection_name}.")
    def mget(self, keys: Sequence[str]) -> list[Optional[Record]]:
        # note: Qdrant's retrieve() silently skips ids that do not exist, so the result
        # may contain fewer entries than `keys` (missing keys are not returned as None)
        return self.client.retrieve(
            self.collection_name, ids=keys, with_payload=True, with_vectors=True
        )
def mset(self, key_value_pairs: Sequence[tuple[str, Record]]) -> None:
self.client.upsert(
collection_name=self.collection_name,
points=models.Batch(
ids=[key for key, _ in key_value_pairs],
vectors=[value.vector for _, value in key_value_pairs],
payloads=[value.payload for _, value in key_value_pairs],
),
)
def mdelete(self, keys: Sequence[str]) -> None:
self.client.delete(collection_name=self.collection_name, points_selector=keys)
    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        for point in self.client.scroll(
            collection_name=self.collection_name,
            with_vectors=False,
            with_payload=False,
            limit=len(self),
        )[0]:
            # apply the optional key prefix filter (point ids may be ints or UUID strings)
            if prefix is None or str(point.id).startswith(prefix):
                yield point.id
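# End-to-end sketch (comments only; `document_text`, `span_metadata`, and `query_vector`
# are placeholders, the metadata layout is an assumption, and `store` is constructed as
# sketched in the class body above):
#   store.add_texts(texts=[document_text], metadatas=[span_metadata])  # inherited from QdrantVectorStore
#   hits = store.similarity_search_with_score_by_vector(query_vector, k=4)
#   docs_with_vectors = store.get_by_ids_with_vectors([doc.id for doc, _ in hits])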