File size: 5,042 Bytes
2cc87ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Sequence, Union

from langchain_core.documents import Document as LCDocument
from langchain_core.runnables import run_in_executor
from langchain_core.stores import BaseStore
from langchain_core.vectorstores import VectorStore
from pytorch_ie.annotations import MultiSpan, Span

from .serializable_store import SerializableStore
from .span_embeddings import SpanEmbeddings


class SpanVectorStore(VectorStore, BaseStore, SerializableStore, ABC):
    """Abstract interface for vector stores that persist and query
    embeddings of text spans occurring within documents."""

    METADATA_SPAN_KEY: str = "pie_labeled_span"
    """Metadata key under which the span payload is stored on a (langchain) document."""
    SPAN_START_KEY: str = "start"
    """Key within the span payload holding the span's start offset."""
    SPAN_END_KEY: str = "end"
    """Key within the span payload holding the span's end offset."""
    METADATA_VECTOR_KEY: str = "vector"
    """Metadata key under which the embedding vector is stored on a (langchain) document."""
    RELEVANCE_SCORE_KEY: str = "relevance_score"
    """Metadata key under which the relevance score is stored on a (langchain) document."""

    def __init__(
        self,
        label_mapping: Optional[Dict[str, List[str]]] = None,
        **kwargs: Any,
    ):
        """Create a new SpanVectorStore.

        Args:
            label_mapping: Maps a query span's label to the target span labels that
                are allowed in the results. When given, retrieval is restricted to
                spans whose label appears in the mapping entry for the query label.
            **kwargs: Forwarded to the parent class(es).
        """
        # NOTE(review): the attribute is assigned *before* delegating to super();
        # this ordering appears deliberate — keep it.
        self.label_mapping = label_mapping
        super().__init__(**kwargs)

    @property
    def embeddings(self) -> SpanEmbeddings:
        """The dense embeddings instance backing this store.

        Raises:
            ValueError: If the configured embeddings are not a `SpanEmbeddings`
                instance (including `None`).

        Returns:
            SpanEmbeddings: The embeddings object in use.
        """
        emb = super().embeddings
        if isinstance(emb, SpanEmbeddings):
            return emb
        raise ValueError(f"Embeddings must be of type SpanEmbeddings, but got: {emb}")

    @abstractmethod
    def get_by_ids_with_vectors(self, ids: Sequence[Union[str, int]], /) -> List[LCDocument]:
        """Fetch documents for the given ids, including their embedding vectors.

        Args:
            ids: Ids of the documents to fetch.

        Returns:
            The matching documents, each carrying its vector in the metadata
            under `metadata_vector_key`.
        """
        ...

    @abstractmethod
    def construct_filter(
        self,
        query_span: Union[Span, MultiSpan],
        metadata_doc_id_key: str,
        doc_id_whitelist: Optional[Sequence[str]] = None,
        doc_id_blacklist: Optional[Sequence[str]] = None,
    ) -> Any:
        """Build a retrieval filter enforcing label and document-id constraints:

        - a labeled query span only matches spans carrying the same label, or,
        - when a label mapping is configured, spans whose label is listed in the
          mapping entry for the query span's label;
        - `doc_id_whitelist`, when given, restricts results to the listed documents;
        - `doc_id_blacklist`, when given, excludes results from the listed documents.

        Args:
            query_span: The span used as the query.
            metadata_doc_id_key: Metadata key that holds the document id.
            doc_id_whitelist: Document ids the results must come from.
            doc_id_blacklist: Document ids the results must not come from.

        Returns:
            A backend-specific filter object.
        """
        ...

    @abstractmethod
    def similarity_search_with_score_by_vector(
        self, embedding: list[float], k: int = 4, **kwargs: Any
    ) -> list[LCDocument]:
        """Find the documents whose vectors are most similar to `embedding`.

        Args:
            embedding: Query embedding vector.
            k: Maximum number of documents to return. Defaults to 4.
            **kwargs: Forwarded to the underlying search implementation.

        Returns:
            The documents most similar to the query vector.
        """
        ...

    async def asimilarity_search_with_score_by_vector(
        self, embedding: list[float], k: int = 4, **kwargs: Any
    ) -> list[LCDocument]:
        """Asynchronous variant of `similarity_search_with_score_by_vector`.

        Args:
            embedding: Query embedding vector.
            k: Maximum number of documents to return. Defaults to 4.
            **kwargs: Forwarded to the underlying search implementation.

        Returns:
            The documents most similar to the query vector.
        """
        # Temporary workaround: run the synchronous search in an executor until
        # the vector store implementations provide native async search.
        return await run_in_executor(
            None, self.similarity_search_with_score_by_vector, embedding, k=k, **kwargs
        )