from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Sequence, Union

from langchain_core.documents import Document as LCDocument
from langchain_core.runnables import run_in_executor
from langchain_core.stores import BaseStore
from langchain_core.vectorstores import VectorStore
from pytorch_ie.annotations import MultiSpan, Span

from .serializable_store import SerializableStore
from .span_embeddings import SpanEmbeddings


class SpanVectorStore(VectorStore, BaseStore, SerializableStore, ABC):
    """Abstract base class for vector stores specialized in storing and retrieving embeddings
    for text spans within documents."""

    METADATA_SPAN_KEY: str = "pie_labeled_span"
    """Key for the span data in the (langchain) document metadata."""

    SPAN_START_KEY: str = "start"
    """Key for the start of the span in the span data."""

    SPAN_END_KEY: str = "end"
    """Key for the end of the span in the span data."""

    METADATA_VECTOR_KEY: str = "vector"
    """Key for the vector in the (langchain) document metadata."""

    RELEVANCE_SCORE_KEY: str = "relevance_score"
    """Key for the relevance score in the (langchain) document metadata."""

    def __init__(
        self,
        label_mapping: Optional[Dict[str, List[str]]] = None,
        **kwargs: Any,
    ):
        """Initialize the SpanVectorStore.

        Args:
            label_mapping: Mapping from query span labels to target span labels. If provided,
                only spans with a label contained in the mapping entry for the query span's
                label are retrieved.
            **kwargs: Additional arguments.
        """
        self.label_mapping = label_mapping
        super().__init__(**kwargs)

    @property
    def embeddings(self) -> SpanEmbeddings:
        """Get the dense embeddings instance that is being used.

        Raises:
            ValueError: If the embeddings are not an instance of `SpanEmbeddings`.

        Returns:
            SpanEmbeddings: The embeddings instance.
        """
        result = super().embeddings
        if not isinstance(result, SpanEmbeddings):
            raise ValueError(f"Embeddings must be of type SpanEmbeddings, but got: {result}")
        return result

    @abstractmethod
    def get_by_ids_with_vectors(self, ids: Sequence[Union[str, int]], /) -> List[LCDocument]:
        """Get documents by their ids.

        Args:
            ids: List of document ids.

        Returns:
            List of documents, each including its vector in the metadata at key
            `METADATA_VECTOR_KEY`.
        """
        ...

    @abstractmethod
    def construct_filter(
        self,
        query_span: Union[Span, MultiSpan],
        metadata_doc_id_key: str,
        doc_id_whitelist: Optional[Sequence[str]] = None,
        doc_id_blacklist: Optional[Sequence[str]] = None,
    ) -> Any:
        """Construct a filter for the retrieval. It should enforce that:

        - if the query span is labeled, the retrieved spans have the same label, or, if a
          label mapping is provided, a label contained in the mapping entry for the query
          span's label,
        - if `doc_id_whitelist` is provided, the retrieved spans come from documents in the
          whitelist,
        - if `doc_id_blacklist` is provided, the retrieved spans do not come from documents
          in the blacklist.

        Args:
            query_span: The query span.
            metadata_doc_id_key: The key in the metadata that holds the document id.
            doc_id_whitelist: A list of document ids to restrict the retrieval to.
            doc_id_blacklist: A list of document ids to exclude from the retrieval.

        Returns:
            A filter object.
        """
        ...

    @abstractmethod
    def similarity_search_with_score_by_vector(
        self, embedding: list[float], k: int = 4, **kwargs: Any
    ) -> list[LCDocument]:
        """Return docs most similar to embedding vector.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            **kwargs: Arguments to pass to the search method.

        Returns:
            List of Documents most similar to the query vector.
        """
        ...
    async def asimilarity_search_with_score_by_vector(
        self, embedding: list[float], k: int = 4, **kwargs: Any
    ) -> list[LCDocument]:
        """Async return docs most similar to embedding vector.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            **kwargs: Arguments to pass to the search method.

        Returns:
            List of Documents most similar to the query vector.
        """
        # This is a temporary workaround to make the similarity search
        # asynchronous. The proper solution is to make the similarity search
        # asynchronous in the vector store implementations.
        return await run_in_executor(
            None, self.similarity_search_with_score_by_vector, embedding, k=k, **kwargs
        )
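

# --- Illustrative sketch (hypothetical, not part of the abstract API) ----------------
# The helper below sketches how a concrete subclass *might* translate the constraints
# described in `construct_filter` into a backend-agnostic filter object, here a plain
# dictionary with "must" / "must_not" clauses. The dictionary format, the "label"
# metadata key, and the helper name are assumptions for illustration only; real
# implementations (e.g. a Qdrant- or Elasticsearch-backed store) would build
# backend-specific filter objects instead.
def _example_construct_filter(
    query_span: Union[Span, MultiSpan],
    metadata_doc_id_key: str,
    label_mapping: Optional[Dict[str, List[str]]] = None,
    doc_id_whitelist: Optional[Sequence[str]] = None,
    doc_id_blacklist: Optional[Sequence[str]] = None,
) -> Dict[str, Any]:
    must: List[Dict[str, Any]] = []
    must_not: List[Dict[str, Any]] = []

    # only labeled spans (e.g. LabeledSpan) carry a label attribute
    label = getattr(query_span, "label", None)
    if label is not None:
        if label_mapping is not None:
            # with a mapping, only the target labels mapped from the query label are allowed
            target_labels = label_mapping.get(label, [])
        else:
            # without a mapping, the retrieved span must carry the same label
            target_labels = [label]
        must.append({"key": "label", "match_any": target_labels})

    if doc_id_whitelist is not None:
        must.append({"key": metadata_doc_id_key, "match_any": list(doc_id_whitelist)})
    if doc_id_blacklist is not None:
        must_not.append({"key": metadata_doc_id_key, "match_any": list(doc_id_blacklist)})

    return {"must": must, "must_not": must_not}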