File size: 5,042 Bytes
2cc87ec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Sequence, Union
from langchain_core.documents import Document as LCDocument
from langchain_core.runnables import run_in_executor
from langchain_core.stores import BaseStore
from langchain_core.vectorstores import VectorStore
from pytorch_ie.annotations import MultiSpan, Span
from .serializable_store import SerializableStore
from .span_embeddings import SpanEmbeddings
class SpanVectorStore(VectorStore, BaseStore, SerializableStore, ABC):
    """Abstract base class for vector stores specialized in storing
    and retrieving embeddings for text spans within documents.

    Subclasses must implement the abstract methods declared here:
    `get_by_ids_with_vectors`, `construct_filter` and
    `similarity_search_with_score_by_vector`.
    """

    METADATA_SPAN_KEY: str = "pie_labeled_span"
    """Key for the span data in the (langchain) document metadata."""
    SPAN_START_KEY: str = "start"
    """Key for the start of the span in the span data."""
    SPAN_END_KEY: str = "end"
    """Key for the end of the span in the span data."""
    METADATA_VECTOR_KEY: str = "vector"
    """Key for the vector in the (langchain) document metadata."""
    RELEVANCE_SCORE_KEY: str = "relevance_score"
    """Key for the relevance score in the (langchain) document metadata."""

    def __init__(
        self,
        label_mapping: Optional[Dict[str, List[str]]] = None,
        **kwargs: Any,
    ):
        """Initialize the SpanVectorStore.

        Args:
            label_mapping: Mapping from query span labels to target span labels. If provided,
                only spans with a label in the mapping for the query span's label are retrieved.
            **kwargs: Additional keyword arguments, forwarded unchanged to the next
                `__init__` in the method resolution order.
        """
        self.label_mapping = label_mapping
        super().__init__(**kwargs)

    @property
    def embeddings(self) -> SpanEmbeddings:
        """Get the span embeddings instance that is being used.

        Raises:
            ValueError: If the embeddings provided by the base class are not an
                instance of `SpanEmbeddings` (this includes the case where they are `None`).

        Returns:
            SpanEmbeddings: The embeddings instance, narrowed to `SpanEmbeddings`.
        """
        result = super().embeddings
        # Narrow the base-class embeddings to the span-specific type; a `None`
        # value also fails this isinstance check and is reported the same way.
        if not isinstance(result, SpanEmbeddings):
            raise ValueError(f"Embeddings must be of type SpanEmbeddings, but got: {result}")
        return result

    @abstractmethod
    def get_by_ids_with_vectors(self, ids: Sequence[Union[str, int]], /) -> List[LCDocument]:
        """Get documents by their ids, including their embedding vectors.

        Args:
            ids: Sequence of document ids.

        Returns:
            List of documents including their vectors in the metadata at key
            `METADATA_VECTOR_KEY`.
        """
        ...

    @abstractmethod
    def construct_filter(
        self,
        query_span: Union[Span, MultiSpan],
        metadata_doc_id_key: str,
        doc_id_whitelist: Optional[Sequence[str]] = None,
        doc_id_blacklist: Optional[Sequence[str]] = None,
    ) -> Any:
        """Construct a filter for the retrieval.

        The filter should enforce that:

        - if the query span is labeled and no `label_mapping` is set, the retrieved
          span has the same label as the query span;
        - if the query span is labeled and a `label_mapping` is set, the retrieved
          span has a label that is in the mapping for the query span's label;
        - if `doc_id_whitelist` is provided, the retrieved span is from a document
          in the whitelist;
        - if `doc_id_blacklist` is provided, the retrieved span is not from a
          document in the blacklist.

        Args:
            query_span: The query span.
            metadata_doc_id_key: The key in the metadata that holds the document id.
            doc_id_whitelist: A list of document ids to restrict the retrieval to.
            doc_id_blacklist: A list of document ids to exclude from the retrieval.

        Returns:
            A filter object in whatever format the concrete store expects.
        """
        ...

    @abstractmethod
    def similarity_search_with_score_by_vector(
        self, embedding: list[float], k: int = 4, **kwargs: Any
    ) -> list[LCDocument]:
        """Return docs most similar to embedding vector.

        Note: unlike langchain's `similarity_search_with_score`, this returns plain
        documents rather than (document, score) tuples. The score is presumably
        stored in each document's metadata under `RELEVANCE_SCORE_KEY`; confirm
        against the concrete implementations.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            **kwargs: Arguments to pass to the search method.

        Returns:
            List of Documents most similar to the query vector.
        """
        ...

    async def asimilarity_search_with_score_by_vector(
        self, embedding: list[float], k: int = 4, **kwargs: Any
    ) -> list[LCDocument]:
        """Async return docs most similar to embedding vector.

        Delegates to the synchronous `similarity_search_with_score_by_vector`
        via langchain's `run_in_executor` with the default executor (`None`).

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            **kwargs: Arguments to pass to the search method.

        Returns:
            List of Documents most similar to the query vector.
        """
        # This is a temporary workaround to make the similarity search
        # asynchronous. The proper solution is to make the similarity search
        # asynchronous in the vector store implementations.
        return await run_in_executor(
            None, self.similarity_search_with_score_by_vector, embedding, k=k, **kwargs
        )
|