|
from abc import ABC, abstractmethod |
|
from typing import Any, Dict, List, Optional, Sequence, Union |
|
|
|
from langchain_core.documents import Document as LCDocument |
|
from langchain_core.runnables import run_in_executor |
|
from langchain_core.stores import BaseStore |
|
from langchain_core.vectorstores import VectorStore |
|
from pytorch_ie.annotations import MultiSpan, Span |
|
|
|
from .serializable_store import SerializableStore |
|
from .span_embeddings import SpanEmbeddings |
|
|
|
|
|
class SpanVectorStore(VectorStore, BaseStore, SerializableStore, ABC):
    """Abstract base class for vector stores specialized in storing
    and retrieving embeddings for text spans within documents."""

    METADATA_SPAN_KEY: str = "pie_labeled_span"
    """Key for the span data in the (langchain) document metadata."""
    SPAN_START_KEY: str = "start"
    """Key for the start of the span in the span data."""
    SPAN_END_KEY: str = "end"
    """Key for the end of the span in the span data."""
    METADATA_VECTOR_KEY: str = "vector"
    """Key for the vector in the (langchain) document metadata."""
    RELEVANCE_SCORE_KEY: str = "relevance_score"
    """Key for the relevance score in the (langchain) document metadata."""

    def __init__(
        self,
        label_mapping: Optional[Dict[str, List[str]]] = None,
        **kwargs: Any,
    ):
        """Initialize the SpanVectorStore.

        Args:
            label_mapping: Mapping from query span labels to target span labels. If provided,
                only spans with a label in the mapping for the query span's label are
                retrieved. Enforcing this is the job of the `construct_filter`
                implementation of concrete subclasses.
            **kwargs: Additional arguments, forwarded unchanged to `super().__init__`.
        """
        self.label_mapping = label_mapping
        super().__init__(**kwargs)

    @property
    def embeddings(self) -> SpanEmbeddings:
        """Get the span embeddings instance that is being used.

        The value is taken from the parent class's `embeddings` property and
        checked to be a `SpanEmbeddings` instance.

        Raises:
            ValueError: If the embeddings are not an instance of `SpanEmbeddings`
                (this includes the case where they are `None`).

        Returns:
            SpanEmbeddings: The span embeddings instance.
        """
        result = super().embeddings
        if not isinstance(result, SpanEmbeddings):
            raise ValueError(f"Embeddings must be of type SpanEmbeddings, but got: {result}")
        return result

    @abstractmethod
    def get_by_ids_with_vectors(self, ids: Sequence[Union[str, int]], /) -> List[LCDocument]:
        """Get documents by their ids, including their stored vectors.

        Args:
            ids: Sequence of document ids (positional-only).

        Returns:
            List of documents, each with its vector in the metadata at key
            `METADATA_VECTOR_KEY`.
        """
        ...

    @abstractmethod
    def construct_filter(
        self,
        query_span: Union[Span, MultiSpan],
        metadata_doc_id_key: str,
        doc_id_whitelist: Optional[Sequence[str]] = None,
        doc_id_blacklist: Optional[Sequence[str]] = None,
    ) -> Any:
        """Construct a filter for the retrieval. It should enforce that:

        - if the span is labeled, the retrieved span has the same label, or
        - if, in addition, a label mapping is provided (`self.label_mapping`), the
          retrieved span has a label that is in the mapping for the query span's label
        - if `doc_id_whitelist` is provided, the retrieved span is from a document in
          the whitelist
        - if `doc_id_blacklist` is provided, the retrieved span is not from a document
          in the blacklist

        Args:
            query_span: The query span.
            metadata_doc_id_key: The key in the metadata that holds the document id.
            doc_id_whitelist: A list of document ids to restrict the retrieval to.
            doc_id_blacklist: A list of document ids to exclude from the retrieval.

        Returns:
            A filter object whose type depends on the concrete vector store backend.
        """
        ...

    @abstractmethod
    def similarity_search_with_score_by_vector(
        self, embedding: List[float], k: int = 4, **kwargs: Any
    ) -> List[LCDocument]:
        """Return the documents most similar to the given embedding vector.

        Unlike what the name may suggest, this returns plain documents, not
        `(document, score)` tuples. NOTE(review): implementations presumably store
        the score in each document's metadata under `RELEVANCE_SCORE_KEY` —
        confirm against the concrete subclasses.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of documents to return. Defaults to 4.
            **kwargs: Arguments to pass to the search method.

        Returns:
            List of documents most similar to the query vector.
        """
        ...

    async def asimilarity_search_with_score_by_vector(
        self, embedding: List[float], k: int = 4, **kwargs: Any
    ) -> List[LCDocument]:
        """Async version of `similarity_search_with_score_by_vector`.

        This default implementation delegates to the synchronous method via
        langchain's `run_in_executor` (executor `None`, i.e. the default one), so
        the blocking search does not run on the event loop thread. Subclasses
        with a native async backend may override it.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of documents to return. Defaults to 4.
            **kwargs: Arguments to pass to the search method.

        Returns:
            List of documents most similar to the query vector.
        """
        return await run_in_executor(
            None, self.similarity_search_with_score_by_vector, embedding, k=k, **kwargs
        )
|
|