from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Sequence, Union
from langchain_core.documents import Document as LCDocument
from langchain_core.runnables import run_in_executor
from langchain_core.stores import BaseStore
from langchain_core.vectorstores import VectorStore
from pytorch_ie.annotations import MultiSpan, Span
from .serializable_store import SerializableStore
from .span_embeddings import SpanEmbeddings
class SpanVectorStore(VectorStore, BaseStore, SerializableStore, ABC):
"""Abstract base class for vector stores specialized in storing
and retrieving embeddings for text spans within documents."""
METADATA_SPAN_KEY: str = "pie_labeled_span"
"""Key for the span data in the (langchain) document metadata."""
SPAN_START_KEY: str = "start"
"""Key for the start of the span in the span data."""
SPAN_END_KEY: str = "end"
"""Key for the end of the span in the span data."""
METADATA_VECTOR_KEY: str = "vector"
"""Key for the vector in the (langchain) document metadata."""
RELEVANCE_SCORE_KEY: str = "relevance_score"
"""Key for the relevance score in the (langchain) document metadata."""
def __init__(
self,
label_mapping: Optional[Dict[str, List[str]]] = None,
**kwargs: Any,
):
"""Initialize the SpanVectorStore.
Args:
label_mapping: Mapping from query span labels to target span labels. If provided,
only spans with a label in the mapping for the query span's label are retrieved.
**kwargs: Additional arguments.
"""
self.label_mapping = label_mapping
super().__init__(**kwargs)
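    # Illustrative (hypothetical) label mapping: a query span labeled "ORG" would
    # then only retrieve spans labeled "ORG" or "COMPANY":
    #
    #     label_mapping = {"ORG": ["ORG", "COMPANY"]}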
@property
def embeddings(self) -> SpanEmbeddings:
"""Get the dense embeddings instance that is being used.
Raises:
ValueError: If embeddings are `None`.
Returns:
Embeddings: An instance of `Embeddings`.
"""
result = super().embeddings
if not isinstance(result, SpanEmbeddings):
raise ValueError(f"Embeddings must be of type SpanEmbeddings, but got: {result}")
return result
@abstractmethod
def get_by_ids_with_vectors(self, ids: Sequence[Union[str, int]], /) -> List[LCDocument]:
"""Get documents by their ids.
Args:
ids: List of document ids.
Returns:
            List of documents, including their vectors in the metadata under `METADATA_VECTOR_KEY`.
"""
...
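    # Sketch of the expected usage (the `store` instance and the ids are hypothetical):
    #
    #     docs = store.get_by_ids_with_vectors(["span-id-1", "span-id-2"])
    #     vector = docs[0].metadata[SpanVectorStore.METADATA_VECTOR_KEY]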
@abstractmethod
def construct_filter(
self,
query_span: Union[Span, MultiSpan],
metadata_doc_id_key: str,
doc_id_whitelist: Optional[Sequence[str]] = None,
doc_id_blacklist: Optional[Sequence[str]] = None,
) -> Any:
"""Construct a filter for the retrieval. It should enforce that:
- if the span is labeled, the retrieved span has the same label, or
- if, in addition, a label mapping is provided, the retrieved span has a label that is in the mapping for the query span's label
- if `doc_id_whitelist` is provided, the retrieved span is from a document in the whitelist
- if `doc_id_blacklist` is provided, the retrieved span is not from a document in the blacklist
Args:
query_span: The query span.
metadata_doc_id_key: The key in the metadata that holds the document id.
doc_id_whitelist: A list of document ids to restrict the retrieval to.
doc_id_blacklist: A list of document ids to exclude from the retrieval.
Returns:
A filter object.
"""
...
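    # A concrete subclass might translate the contract above into a backend-specific
    # filter object. As a rough, dict-based sketch (assuming the query span carries a
    # `label` attribute; the actual filter format depends on the vector store backend):
    #
    #     label = getattr(query_span, "label", None)
    #     allowed_labels = (self.label_mapping or {}).get(label, [label])
    #     return {
    #         "label": {"in": allowed_labels},
    #         metadata_doc_id_key: {"in": doc_id_whitelist, "not_in": doc_id_blacklist},
    #     }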
@abstractmethod
def similarity_search_with_score_by_vector(
self, embedding: list[float], k: int = 4, **kwargs: Any
) -> list[LCDocument]:
"""Return docs most similar to embedding vector.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
**kwargs: Arguments to pass to the search method.
Returns:
            List of Documents most similar to the query vector, with the relevance score
            stored in the document metadata under `RELEVANCE_SCORE_KEY`.
"""
...
async def asimilarity_search_with_score_by_vector(
self, embedding: list[float], k: int = 4, **kwargs: Any
) -> list[LCDocument]:
"""Async return docs most similar to embedding vector.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
**kwargs: Arguments to pass to the search method.
Returns:
            List of Documents most similar to the query vector, with the relevance score
            stored in the document metadata under `RELEVANCE_SCORE_KEY`.
"""
# This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations.
return await run_in_executor(
None, self.similarity_search_with_score_by_vector, embedding, k=k, **kwargs
)