|
import logging |
|
import os |
|
import uuid |
|
from collections import defaultdict |
|
from copy import copy |
|
from enum import Enum |
|
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Type, Union |
|
|
|
from langchain_core.callbacks import ( |
|
AsyncCallbackManagerForRetrieverRun, |
|
CallbackManagerForRetrieverRun, |
|
) |
|
from langchain_core.documents import BaseDocumentCompressor |
|
from langchain_core.documents import Document as LCDocument |
|
from langchain_core.retrievers import BaseRetriever |
|
from pydantic import Field |
|
from pytorch_ie.annotations import LabeledMultiSpan, LabeledSpan, MultiSpan, Span |
|
from pytorch_ie.core.document import BaseAnnotationList |
|
from pytorch_ie.documents import ( |
|
TextBasedDocument, |
|
TextDocumentWithLabeledMultiSpans, |
|
TextDocumentWithLabeledSpans, |
|
TextDocumentWithSpans, |
|
) |
|
|
|
from ..utils import parse_config |
|
from .pie_document_store import PieDocumentStore |
|
from .serializable_store import SerializableStore |
|
from .span_vectorstore import SpanVectorStore |
|
|
|
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)

# Metadata key under which a parent document stores the mapping from its span (child)
# document ids to the index of the respective span in the document's span layer.
METADATA_KEY_CHILD_ID2IDX: str = "child_id2idx"
|
|
|
|
|
class SpanNotFoundError(ValueError):
    """Raised when a span id cannot be resolved.

    Args:
        span_id: Id of the span that was looked up.
        doc_id: Id of the document that was searched, if the lookup was scoped to a
            single document.
        msg: Custom error message. If None, a default message is built from the ids.
    """

    def __init__(self, span_id: str, doc_id: Optional[str] = None, msg: Optional[str] = None):
        if msg is None:
            msg = (
                f"Span with id [{span_id}] not found in the vectorstore"
                if doc_id is None
                else f"Span with id [{span_id}] not found in document [{doc_id}]"
            )
        super().__init__(msg)
        self.span_id = span_id
        self.doc_id = doc_id
|
|
|
|
|
class DocumentNotFoundError(ValueError):
    """Raised when a document id is not present in the docstore.

    Args:
        doc_id: Id of the document that was looked up.
        msg: Custom error message. If empty or None, a default message is used.
    """

    def __init__(self, doc_id: str, msg: Optional[str] = None):
        if not msg:
            msg = f"Document with id [{doc_id}] not found in the docstore"
        super().__init__(msg)
        self.doc_id = doc_id
|
|
|
|
|
class SearchType(str, Enum):
    """Kinds of vectorstore search the retriever can run."""

    # plain similarity search
    similarity = "similarity"
    # similarity search with a score threshold
    similarity_score_threshold = "similarity_score_threshold"
    # similarity search re-ranked with Maximal Marginal Relevance
    mmr = "mmr"
|
|
|
|
|
class DocumentAwareSpanRetriever(BaseRetriever, SerializableStore):
    """Retriever for contextualized text spans, i.e. spans within text documents.

    It accepts span ids as queries and retrieves spans together with their containing
    document. Note that the query span (and its document) must already be in the
    retriever's store.

    Parent documents live in `docstore`. Every span of a parent's base annotation layer
    becomes a child ("span") document in `vectorstore`. The parent records the mapping
    from span document ids to span indices under the metadata key
    `METADATA_KEY_CHILD_ID2IDX`.
    """

    pie_document_type: Type[TextBasedDocument]
    """The pie document type. It determines the name of the span annotation layer
    (see `pie_annotation_layer_name`)."""
    use_predicted_annotations_key: str = "use_predicted_annotations"
    """The metadata key that stores whether to use the predicted annotations (True) or the
    gold annotations (False) of a document."""
    retrieve_from_same_document: bool = False
    """Whether to retrieve spans exclusively from the same document as the query span."""
    retrieve_from_different_documents: bool = False
    """Whether to retrieve spans exclusively from different documents than the query span."""

    vectorstore: SpanVectorStore
    """The underlying vectorstore to use to store small chunks
    and their embedding vectors"""
    docstore: PieDocumentStore
    """The storage interface for the parent documents"""
    id_key: str = "doc_id"
    """The key to use to track the parent id. This will be stored in the
    metadata of child documents."""
    search_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments to pass to the search function."""
    search_type: SearchType = SearchType.similarity
    """Type of search to perform (similarity / similarity_score_threshold / mmr)"""

    child_metadata_fields: Optional[Sequence[str]] = None
    """Metadata fields to leave in child documents. If None, leave all parent document
    metadata.
    """

    compressor: Optional[BaseDocumentCompressor] = None
    """Compressor for compressing retrieved documents."""
    compressor_context_size: int = 50
    """Size of the context (in characters) to use around the retrieved spans when
    compressing. Also used around the query span if `compressor_query_context_size`
    is None."""
    compressor_query_context_size: Optional[int] = 10
    """Size of the context to use around the query span when compressing. If None, will
    use the same value as `compressor_context_size`."""

    @classmethod
    def instantiate_from_config(
        cls, config: Dict[str, Any], overwrites: Optional[Dict[str, Any]] = None
    ) -> "DocumentAwareSpanRetriever":
        """Instantiate a retriever from a configuration dictionary via Hydra.

        Args:
            config: The configuration passed to `hydra.utils.instantiate`.
            overwrites: Optional keyword arguments passed on to `hydra.utils.instantiate`.
        """
        # imported lazily so that hydra is only required when this constructor is used
        from hydra.utils import instantiate

        return instantiate(config, **(overwrites or {}))

    @classmethod
    def instantiate_from_config_string(
        cls, config_string: str, format: str, overwrites: Optional[Dict[str, Any]] = None
    ) -> "DocumentAwareSpanRetriever":
        """Instantiate a retriever from a configuration string in the given `format`."""
        return cls.instantiate_from_config(
            parse_config(config_string, format=format), overwrites=overwrites
        )

    @classmethod
    def instantiate_from_config_file(
        cls, config_path: str, overwrites: Optional[Dict[str, Any]] = None
    ) -> "DocumentAwareSpanRetriever":
        """Instantiate a retriever from a configuration file.

        The format is derived from the file extension: `.json` files are parsed as JSON,
        `.yaml` and `.yml` files as YAML.

        Raises:
            ValueError: If the file extension is not supported.
        """
        if config_path.endswith(".json"):
            config_format = "json"
        elif config_path.endswith((".yaml", ".yml")):
            config_format = "yaml"
        else:
            raise ValueError(f"Unsupported file extension: {config_path}")
        with open(config_path, "r", encoding="utf-8") as file:
            config_string = file.read()
        return cls.instantiate_from_config_string(
            config_string, format=config_format, overwrites=overwrites
        )

    @property
    def pie_annotation_layer_name(self) -> str:
        """The name of the span annotation layer, derived from `pie_document_type`.

        Raises:
            ValueError: If `pie_document_type` is not one of the supported document types.
        """
        if issubclass(self.pie_document_type, TextDocumentWithSpans):
            return "spans"
        elif issubclass(self.pie_document_type, TextDocumentWithLabeledSpans):
            return "labeled_spans"
        elif issubclass(self.pie_document_type, TextDocumentWithLabeledMultiSpans):
            return "labeled_multi_spans"
        else:
            raise ValueError(
                f"Unsupported pie document type: {self.pie_document_type}. "
                "Must be one of TextDocumentWithSpans, TextDocumentWithLabeledSpans, "
                "or TextDocumentWithLabeledMultiSpans."
            )

    def _span_to_dict(self, span: Union[Span, MultiSpan]) -> dict:
        """Serialize a span annotation into a plain dict for the span document metadata.

        A `Span` stores scalar start/end offsets; a `MultiSpan` stores tuples with the
        start and end offsets of its slices. Labeled variants also store "label" and
        "score".

        Raises:
            ValueError: If the span type is not supported.
        """
        span_dict = {}
        if isinstance(span, Span):
            span_dict[self.vectorstore.SPAN_START_KEY] = span.start
            span_dict[self.vectorstore.SPAN_END_KEY] = span.end
            span_dict["type"] = "Span"
        elif isinstance(span, MultiSpan):
            starts, ends = zip(*span.slices)
            span_dict[self.vectorstore.SPAN_START_KEY] = starts
            span_dict[self.vectorstore.SPAN_END_KEY] = ends
            span_dict["type"] = "MultiSpan"
        else:
            raise ValueError(f"Unsupported span type: {type(span)}")
        if isinstance(span, (LabeledSpan, LabeledMultiSpan)):
            span_dict["label"] = span.label
            span_dict["score"] = span.score
        return span_dict

    def _dict_to_span(self, span_dict: dict) -> Union[Span, MultiSpan]:
        """Inverse of `_span_to_dict`: rebuild a (labeled) span or multi-span.

        Raises:
            ValueError: If the stored span type is not supported.
        """
        if span_dict["type"] == "Span":
            kwargs = dict(
                start=span_dict[self.vectorstore.SPAN_START_KEY],
                end=span_dict[self.vectorstore.SPAN_END_KEY],
            )
            if "label" in span_dict:
                kwargs["label"] = span_dict["label"]
                kwargs["score"] = span_dict["score"]
                return LabeledSpan(**kwargs)
            else:
                return Span(**kwargs)
        elif span_dict["type"] == "MultiSpan":
            starts = span_dict[self.vectorstore.SPAN_START_KEY]
            ends = span_dict[self.vectorstore.SPAN_END_KEY]
            slices = tuple((start, end) for start, end in zip(starts, ends))
            kwargs = dict(slices=slices)
            if "label" in span_dict:
                kwargs["label"] = span_dict["label"]
                kwargs["score"] = span_dict["score"]
                return LabeledMultiSpan(**kwargs)
            else:
                return MultiSpan(**kwargs)
        else:
            raise ValueError(f"Unsupported span type: {span_dict['type']}")

    def use_predicted_annotations(self, doc: LCDocument) -> bool:
        """Check if the document uses predicted annotations (defaults to True)."""
        return doc.metadata.get(self.use_predicted_annotations_key, True)

    def get_document(self, doc_id: str) -> LCDocument:
        """Get a parent document by its id.

        Raises:
            DocumentNotFoundError: If no document with this id is in the docstore.
            ValueError: If the docstore returns more than one document.
        """
        documents = self.docstore.mget([doc_id])
        if len(documents) == 0 or documents[0] is None:
            raise DocumentNotFoundError(doc_id=doc_id)
        if len(documents) > 1:
            raise ValueError(f"Multiple documents found with id: {doc_id}")
        return documents[0]

    def get_span_document(self, span_id: str, with_vector: bool = False) -> LCDocument:
        """Get a span document by its id, optionally including its vector.

        Raises:
            SpanNotFoundError: If no span document with this id is in the vectorstore.
            ValueError: If the vectorstore returns more than one span document.
        """
        if with_vector:
            span_docs = self.vectorstore.get_by_ids_with_vectors([span_id])
        else:
            span_docs = self.vectorstore.get_by_ids([span_id])
        if len(span_docs) == 0 or span_docs[0] is None:
            raise SpanNotFoundError(span_id=span_id)
        if len(span_docs) > 1:
            raise ValueError(f"Multiple span documents found with id: {span_id}")
        return span_docs[0]

    def get_base_layer(
        self, pie_document: TextBasedDocument, use_predicted_annotations: bool
    ) -> BaseAnnotationList:
        """Get the span annotation layer of the pie document (its predictions or gold
        annotations).

        Raises:
            ValueError: If the document has no layer named `pie_annotation_layer_name`.
        """
        if self.pie_annotation_layer_name not in pie_document:
            raise ValueError(
                f'The pie document must contain the annotation layer "{self.pie_annotation_layer_name}"'
            )
        layer = pie_document[self.pie_annotation_layer_name]
        return layer.predictions if use_predicted_annotations else layer

    def get_span_by_id(self, span_id: str) -> Union[Span, MultiSpan]:
        """Get a span annotation by its span document id."""
        span_doc = self.get_span_document(span_id)
        doc_id = span_doc.metadata[self.id_key]
        doc = self.get_document(doc_id)
        return self.get_span_from_doc_by_id(doc=doc, span_id=span_id)

    def get_span_from_doc_by_id(self, doc: LCDocument, span_id: str) -> Union[Span, MultiSpan]:
        """Get the span annotation with the given span document id from a parent document.

        Raises:
            SpanNotFoundError: If the span id is not in the document's span id mapping.
        """
        base_layer = self.get_base_layer(
            self.docstore.unwrap(doc),
            use_predicted_annotations=self.use_predicted_annotations(doc),
        )
        span_idx = doc.metadata[METADATA_KEY_CHILD_ID2IDX].get(span_id)
        if span_idx is None:
            raise SpanNotFoundError(span_id=span_id, doc_id=doc.id)
        return base_layer[span_idx]

    def get_span_id2idx_from_doc(self, doc: Union[LCDocument, str]) -> Dict[str, int]:
        """Get all span ids from a document.

        Args:
            doc: Document or document id

        Returns:
            Dictionary mapping span ids to their index in the base layer.
        """
        if isinstance(doc, str):
            doc = self.get_document(doc)
        return doc.metadata[METADATA_KEY_CHILD_ID2IDX]

    def prepare_search_kwargs(
        self,
        span_id: str,
        doc_id_whitelist: Optional[List[str]] = None,
        doc_id_blacklist: Optional[List[str]] = None,
        kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[dict, LCDocument]:
        """Build the keyword arguments for the vectorstore search for a query span.

        Args:
            span_id: Id of the query span document.
            doc_id_whitelist: Optional ids of parent documents to restrict the search to.
                The list is not modified.
            doc_id_blacklist: Optional ids of parent documents to exclude from the search.
                The list is not modified.
            kwargs: Extra search arguments that override `self.search_kwargs`.

        Returns:
            The search kwargs (with the query "embedding" and, if any, a "filter") and the
            query span document (fetched with its vector, with the unwrapped parent pie
            document attached to its metadata).

        Raises:
            ValueError: If both `retrieve_from_same_document` and
                `retrieve_from_different_documents` are enabled.
        """
        # check the configuration before hitting the stores
        if self.retrieve_from_different_documents and self.retrieve_from_same_document:
            raise ValueError("Cannot retrieve from both same and different documents")

        query_span_doc = self.get_span_document(span_id, with_vector=True)
        query_doc_id = query_span_doc.metadata[self.id_key]
        query_doc = self.get_document(query_doc_id)

        # Attach the parent pie document to the query span document, presumably so that
        # `self.docstore.unwrap` works on it (see `_prepare_doc_for_compression`).
        query_span_doc.metadata[self.docstore.METADATA_KEY_PIE_DOCUMENT] = self.docstore.unwrap(
            query_doc
        )

        search_kwargs = copy(self.search_kwargs)
        search_kwargs.update(kwargs or {})

        query_span = self.get_span_from_doc_by_id(doc=query_doc, span_id=span_id)

        # work on copies so that the caller's lists are not modified
        whitelist = list(doc_id_whitelist) if doc_id_whitelist is not None else None
        blacklist = list(doc_id_blacklist) if doc_id_blacklist is not None else None

        if self.retrieve_from_same_document:
            if whitelist is None:
                whitelist = [query_doc_id]
            elif query_doc_id not in whitelist:
                whitelist.append(query_doc_id)

        if self.retrieve_from_different_documents:
            if blacklist is None:
                blacklist = [query_doc_id]
            elif query_doc_id not in blacklist:
                blacklist.append(query_doc_id)

        query_filter = self.vectorstore.construct_filter(
            query_span=query_span,
            metadata_doc_id_key=self.id_key,
            doc_id_whitelist=whitelist,
            doc_id_blacklist=blacklist,
        )
        if query_filter is not None:
            search_kwargs["filter"] = query_filter

        search_kwargs["embedding"] = query_span_doc.metadata[self.vectorstore.METADATA_VECTOR_KEY]
        return search_kwargs, query_span_doc

    def _prepare_query_for_compression(self, query_doc: LCDocument) -> str:
        """Return the query span with its (query) context as text for the compressor."""
        return self._prepare_doc_for_compression(
            query_doc, context_size=self.compressor_query_context_size
        ).page_content

    def _prepare_doc_for_compression(
        self, doc: LCDocument, context_size: Optional[int] = None
    ) -> LCDocument:
        """Reduce a span document to its span plus surrounding context, in place.

        The page content is replaced by the parent text around the span, extended by
        `context_size` characters on both sides (clipped to the text bounds). For
        multi-spans only the first slice is used. Any "relevance_score" is moved to
        "relevance_score_without_reranking", presumably so that a reranking compressor
        can set its own score.

        Args:
            doc: The span document to prepare (modified and returned).
            context_size: Context size in characters. Defaults to
                `compressor_context_size`.
        """
        if context_size is None:
            context_size = self.compressor_context_size
        pie_doc: TextBasedDocument = self.docstore.unwrap(doc)
        text = pie_doc.text
        span_dict = doc.metadata[self.vectorstore.METADATA_SPAN_KEY]
        span_start = span_dict[self.vectorstore.SPAN_START_KEY]
        span_end = span_dict[self.vectorstore.SPAN_END_KEY]
        # Multi-spans are stored as tuples by `_span_to_dict`; they may come back as lists
        # after serialization, so accept both.
        if isinstance(span_start, (list, tuple)):
            span_start = span_start[0]
        if isinstance(span_end, (list, tuple)):
            span_end = span_end[0]
        context_start = span_start - context_size
        context_end = span_end + context_size
        doc.page_content = text[max(0, context_start) : min(context_end, len(text))]

        if "relevance_score" in doc.metadata:
            doc.metadata["relevance_score_without_reranking"] = doc.metadata.pop("relevance_score")
        return doc

    def _collect_parent_doc_ids(self, span_docs: List[LCDocument]) -> List[str]:
        """Return the unique parent document ids of `span_docs` in order of first occurrence.

        Raises:
            ValueError: If a span document has no parent id (`id_key`) in its metadata.
        """
        doc_ids: Dict[str, None] = {}
        for span_doc in span_docs:
            if self.id_key not in span_doc.metadata:
                raise ValueError(f"Metadata must contain the key {self.id_key}")
            doc_ids.setdefault(span_doc.metadata[self.id_key], None)
        return list(doc_ids)

    def _attach_parent_documents(
        self,
        span_docs: List[LCDocument],
        doc_ids: List[str],
        docs: Sequence[Optional[LCDocument]],
        query: str,
    ) -> None:
        """Enrich the span documents in place with data from their parent documents.

        Each span document gets the parent metadata, the resolved span annotation
        ("attached_span") and the query span id ("query_span_id").

        Args:
            span_docs: The retrieved span documents.
            doc_ids: The parent document ids, aligned with `docs`.
            docs: The parent documents as returned by the docstore.
            query: The query span id.

        Raises:
            DocumentNotFoundError: If a parent document is missing from the docstore.
        """
        doc_id2doc = dict(zip(doc_ids, docs))
        for span_doc in span_docs:
            doc_id = span_doc.metadata[self.id_key]
            doc = doc_id2doc.get(doc_id)
            if doc is None:
                raise DocumentNotFoundError(doc_id=doc_id)
            span_doc.metadata.update(doc.metadata)
            span_doc.metadata["attached_span"] = self.get_span_from_doc_by_id(
                doc=doc, span_id=span_doc.id
            )
            span_doc.metadata["query_span_id"] = query

    def _prepare_compression_inputs(
        self, span_docs: List[LCDocument], query_span_doc: LCDocument
    ) -> Tuple[List[LCDocument], str]:
        """Return the span documents and the query text prepared for the compressor."""
        prepared_docs = [self._prepare_doc_for_compression(span_doc) for span_doc in span_docs]
        prepared_query = self._prepare_query_for_compression(query_span_doc)
        return prepared_docs, prepared_query

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        doc_id_whitelist: Optional[List[str]] = None,
        doc_id_blacklist: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[LCDocument]:
        """Get span documents relevant to a query span

        Args:
            query: The span id to find relevant spans for
            run_manager: The callbacks handler to use
            doc_id_whitelist: Optional ids of parent documents to restrict the search to
            doc_id_blacklist: Optional ids of parent documents to exclude from the search
            **kwargs: Extra arguments for the vectorstore search
        Returns:
            List of relevant span documents with metadata from the parent document. The
            query span itself is excluded. If a compressor is set, the compressed documents
            are returned.
        """
        search_kwargs, query_span_doc = self.prepare_search_kwargs(
            span_id=query,
            kwargs=kwargs,
            doc_id_whitelist=doc_id_whitelist,
            doc_id_blacklist=doc_id_blacklist,
        )
        if self.search_type == SearchType.mmr:
            span_docs = self.vectorstore.max_marginal_relevance_search_by_vector(**search_kwargs)
        elif self.search_type == SearchType.similarity_score_threshold:
            # NOTE(review): no threshold is applied here; presumably the vectorstore reads it
            # from `search_kwargs` - confirm.
            sub_docs_and_similarities = self.vectorstore.similarity_search_with_score_by_vector(
                **search_kwargs
            )
            span_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
        else:
            span_docs = self.vectorstore.similarity_search_by_vector(**search_kwargs)

        doc_ids = self._collect_parent_doc_ids(span_docs)
        docs = self.docstore.mget(doc_ids)
        self._attach_parent_documents(span_docs, doc_ids=doc_ids, docs=docs, query=query)

        # the query span is usually among the results, but should not be returned
        span_docs_filtered = [
            span_doc for span_doc in span_docs if span_doc.id != query_span_doc.id
        ]
        if self.compressor is None:
            return span_docs_filtered
        if not span_docs_filtered:
            return []
        prepared_docs, prepared_query = self._prepare_compression_inputs(
            span_docs_filtered, query_span_doc
        )
        compressed_docs = self.compressor.compress_documents(
            documents=prepared_docs, query=prepared_query, callbacks=run_manager.get_child()
        )
        return list(compressed_docs)

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        doc_id_whitelist: Optional[List[str]] = None,
        doc_id_blacklist: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[LCDocument]:
        """Asynchronously get span documents relevant to a query span

        Args:
            query: The span id to find relevant spans for
            run_manager: The callbacks handler to use
            doc_id_whitelist: Optional ids of parent documents to restrict the search to
            doc_id_blacklist: Optional ids of parent documents to exclude from the search
            **kwargs: Extra arguments for the vectorstore search
        Returns:
            List of relevant span documents with metadata from the parent document. The
            query span itself is excluded. If a compressor is set, the compressed documents
            are returned.
        """
        search_kwargs, query_span_doc = self.prepare_search_kwargs(
            span_id=query,
            kwargs=kwargs,
            doc_id_whitelist=doc_id_whitelist,
            doc_id_blacklist=doc_id_blacklist,
        )
        if self.search_type == SearchType.mmr:
            span_docs = await self.vectorstore.amax_marginal_relevance_search_by_vector(
                **search_kwargs
            )
        elif self.search_type == SearchType.similarity_score_threshold:
            # NOTE(review): no threshold is applied here; presumably the vectorstore reads it
            # from `search_kwargs` - confirm.
            sub_docs_and_similarities = (
                await self.vectorstore.asimilarity_search_with_score_by_vector(**search_kwargs)
            )
            span_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
        else:
            span_docs = await self.vectorstore.asimilarity_search_by_vector(**search_kwargs)

        doc_ids = self._collect_parent_doc_ids(span_docs)
        docs = await self.docstore.amget(doc_ids)
        self._attach_parent_documents(span_docs, doc_ids=doc_ids, docs=docs, query=query)

        # the query span is usually among the results, but should not be returned
        span_docs_filtered = [
            span_doc for span_doc in span_docs if span_doc.id != query_span_doc.id
        ]
        if self.compressor is None:
            return span_docs_filtered
        if not span_docs_filtered:
            return []
        prepared_docs, prepared_query = self._prepare_compression_inputs(
            span_docs_filtered, query_span_doc
        )
        compressed_docs = await self.compressor.acompress_documents(
            documents=prepared_docs, query=prepared_query, callbacks=run_manager.get_child()
        )
        return list(compressed_docs)

    def create_span_documents(
        self, documents: List[LCDocument]
    ) -> Tuple[List[LCDocument], Dict[str, int]]:
        """Create one span document per span in the base layer of each document.

        Each span document gets a new random UUID as id, the full document text as page
        content, the parent metadata (without the child id mapping) and the serialized
        span under `vectorstore.METADATA_SPAN_KEY`.

        Returns:
            The span documents and a mapping from span document id to span index.
        """
        span_docs = []
        id2idx = {}
        for i, doc in enumerate(documents):
            pie_doc, metadata = self.docstore.unwrap_with_metadata(doc)
            base_layer = self.get_base_layer(
                pie_doc, use_predicted_annotations=self.use_predicted_annotations(doc)
            )
            if len(base_layer) == 0:
                logger.warning(f"No spans found in document {i} (id: {doc.id})")
            # the child id mapping belongs to the parent only, do not copy it to its spans
            shared_metadata = {
                k: v for k, v in metadata.items() if k != METADATA_KEY_CHILD_ID2IDX
            }
            for idx, span in enumerate(base_layer):
                span_metadata = dict(shared_metadata)
                span_metadata[self.vectorstore.METADATA_SPAN_KEY] = self._span_to_dict(span)
                new_doc = LCDocument(
                    id=str(uuid.uuid4()), page_content=pie_doc.text, metadata=span_metadata
                )
                span_docs.append(new_doc)
                id2idx[new_doc.id] = idx
        return span_docs, id2idx

    def _split_docs_for_adding(
        self,
        documents: List[LCDocument],
        ids: Optional[List[str]] = None,
        add_to_docstore: bool = True,
    ) -> Tuple[List[LCDocument], List[Tuple[str, LCDocument]]]:
        """Create the span documents for `documents` and pair the documents with their ids.

        Side effect: the span id mapping is stored in each document's metadata under
        `METADATA_KEY_CHILD_ID2IDX`.

        Returns:
            All span documents (each with the parent id under `id_key`) and the
            (parent id, parent document) pairs.

        Raises:
            ValueError: If ids are missing while `add_to_docstore` is False, if the number
                of ids does not match the number of documents, or if ids are not unique.
        """
        if ids is None:
            doc_ids = [doc.id for doc in documents]
            if not add_to_docstore:
                raise ValueError("If ids are not passed in, `add_to_docstore` MUST be True")
        else:
            if len(documents) != len(ids):
                raise ValueError(
                    "Got uneven list of documents and ids. "
                    "If `ids` is provided, should be same length as `documents`."
                )
            doc_ids = ids

        if len(set(doc_ids)) != len(doc_ids):
            raise ValueError("IDs must be unique")

        docs = []
        full_docs = []
        for _id, doc in zip(doc_ids, documents):
            sub_docs, sub_doc_id2idx = self.create_span_documents([doc])
            if self.child_metadata_fields is not None:
                for sub_doc in sub_docs:
                    sub_doc.metadata = {k: sub_doc.metadata[k] for k in self.child_metadata_fields}
            for sub_doc in sub_docs:
                # link each span document to its parent
                sub_doc.metadata[self.id_key] = _id
            docs.extend(sub_docs)
            doc.metadata[METADATA_KEY_CHILD_ID2IDX] = sub_doc_id2idx
            full_docs.append((_id, doc))

        return docs, full_docs

    def remove_missing_span_ids_from_document(
        self, document: LCDocument, span_ids: Set[str]
    ) -> LCDocument:
        """Remove invalid span ids from the span to idx mapping
        of the document.

        The input document is not modified; a copy with a fresh metadata dict is returned.
        A warning lists the spans that were dropped.

        Args:
            document: Document to remove invalid span ids from
            span_ids: Set of valid span ids

        Returns:
            Document with invalid span ids removed
        """
        span_id2idx = document.metadata[METADATA_KEY_CHILD_ID2IDX]
        filtered_span_id2idx = {
            span_id: idx for span_id, idx in span_id2idx.items() if span_id in span_ids
        }
        missed_span_ids = set(span_id2idx) - set(span_ids)
        if missed_span_ids:
            layer = self.get_base_layer(
                self.docstore.unwrap(document),
                use_predicted_annotations=self.use_predicted_annotations(document),
            )
            resolved_missed_spans = [
                layer[span_id2idx[span_id]].resolve() for span_id in missed_span_ids
            ]
            logger.warning(
                f"Document {document.id} contains spans that can not be added to the "
                f"vectorstore because no vector could be calculated:\n{resolved_missed_spans}.\n"
                "These spans will be not queryable."
            )
        new_doc = copy(document)
        # a shallow copy shares the metadata dict with `document`, so assign a new one
        new_doc.metadata = {**document.metadata, METADATA_KEY_CHILD_ID2IDX: filtered_span_id2idx}
        return new_doc

    def add_documents(
        self,
        documents: List[LCDocument],
        ids: Optional[List[str]] = None,
        add_to_docstore: bool = True,
        **kwargs: Any,
    ) -> None:
        """Adds documents to the docstore and vectorstores.

        Args:
            documents: List of documents to add
            ids: Optional list of ids for documents. If provided should be the same
                length as the list of documents. Can be provided if parent documents
                are already in the document store and you don't want to re-add
                to the docstore. If not provided, the ids of the documents
                (`document.id`) are used.
            add_to_docstore: Boolean of whether to add documents to docstore.
                This can be false if and only if `ids` are provided. You may want
                to set this to False if the documents are already in the docstore
                and you don't want to re-add them.
            **kwargs: Passed on to `vectorstore.add_documents`.
        """
        docs, full_docs = self._split_docs_for_adding(documents, ids, add_to_docstore)
        added_span_ids = set(self.vectorstore.add_documents(docs, **kwargs))
        # span documents the vectorstore did not add must not be referenced by their parent
        full_docs = [
            (doc_id, self.remove_missing_span_ids_from_document(doc, added_span_ids))
            for doc_id, doc in full_docs
        ]
        if add_to_docstore:
            self.docstore.mset(full_docs)

    async def aadd_documents(
        self,
        documents: List[LCDocument],
        ids: Optional[List[str]] = None,
        add_to_docstore: bool = True,
        **kwargs: Any,
    ) -> None:
        """Asynchronously add documents to the docstore and vectorstores.

        See `add_documents` for the arguments.
        """
        docs, full_docs = self._split_docs_for_adding(documents, ids, add_to_docstore)
        added_span_ids = set(await self.vectorstore.aadd_documents(docs, **kwargs))
        # span documents the vectorstore did not add must not be referenced by their parent
        full_docs = [
            (doc_id, self.remove_missing_span_ids_from_document(doc, added_span_ids))
            for doc_id, doc in full_docs
        ]
        if add_to_docstore:
            await self.docstore.amset(full_docs)

    def delete_documents(self, ids: List[str]) -> None:
        """Remove documents from the docstore and vectorstores.

        Ids of documents that are not in the docstore are skipped when collecting the
        span documents to delete.

        Args:
            ids: List of ids to remove
        """
        child_ids: List[str] = []
        for doc in self.docstore.mget(ids):
            if doc is None:
                continue
            # the keys of the mapping are the ids of the span documents
            child_ids.extend(doc.metadata[METADATA_KEY_CHILD_ID2IDX])

        if child_ids:
            self.vectorstore.delete(child_ids)
        self.docstore.mdelete(ids)

    async def adelete_documents(self, ids: List[str]) -> None:
        """Asynchronously remove documents from the docstore and vectorstores.

        Ids of documents that are not in the docstore are skipped when collecting the
        span documents to delete.

        Args:
            ids: List of ids to remove
        """
        child_ids: List[str] = []
        docs: List[Optional[LCDocument]] = await self.docstore.amget(ids)
        for doc in docs:
            if doc is None:
                continue
            # the keys of the mapping are the ids of the span documents
            child_ids.extend(doc.metadata[METADATA_KEY_CHILD_ID2IDX])

        if child_ids:
            await self.vectorstore.adelete(child_ids)
        await self.docstore.amdelete(ids)

    def add_pie_documents(
        self,
        documents: Iterable[TextBasedDocument],
        use_predicted_annotations: bool,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Add pie documents to the retriever.

        Documents whose ids are already in the docstore are deleted first (including their
        span documents) and then re-added.

        Args:
            documents: Iterable of pie documents to add
            use_predicted_annotations: Whether to use the predicted annotations or the gold annotations
            metadata: Optional metadata to add to each document
        """
        metadata = dict(metadata or {})
        metadata[self.use_predicted_annotations_key] = use_predicted_annotations
        docs = [self.docstore.wrap(doc, **metadata) for doc in documents]

        # remove stale versions of the documents (and their span documents) first;
        # documents that are not in the docstore yet come back as None
        existing_docs = self.docstore.mget([doc.id for doc in docs])
        existing_doc_ids = [doc.id for doc in existing_docs if doc is not None]
        if existing_doc_ids:
            self.delete_documents(existing_doc_ids)

        self.add_documents(docs)

    def _save_to_directory(self, path: str, **kwargs) -> None:
        """Save the docstore and the vectorstore into subdirectories of `path`."""
        logger.info(f'Saving docstore and vectorstore to "{path}" ...')
        self.docstore.save_to_directory(os.path.join(path, "docstore"))
        self.vectorstore.save_to_directory(os.path.join(path, "vectorstore"))

    def _load_from_directory(self, path: str, **kwargs) -> None:
        """Load the docstore and the vectorstore from subdirectories of `path`."""
        logger.info(f'Loading docstore and vectorstore from "{path}" ...')
        self.docstore.load_from_directory(os.path.join(path, "docstore"))
        self.vectorstore.load_from_directory(os.path.join(path, "vectorstore"))
|
|
|
|
|
# Metadata key of a span document that maps relation labels to the (tail span id, score)
# pairs of the relations in which the span is the head.
METADATA_KEY_RELATION_LABEL2TAILS_WITH_SCORES: str = "relation_label2tails_with_scores"
|
|
|
|
|
class DocumentAwareSpanRetrieverWithRelations(DocumentAwareSpanRetriever):
    """Retriever for related contextualized text spans, i.e. spans linked by relations
    to reference spans that are similar to the query span. It accepts span ids as queries
    and retrieves spans with their containing document and the reference span.

    Each span document stores, under `METADATA_KEY_RELATION_LABEL2TAILS_WITH_SCORES`, a
    mapping from relation label to the (tail span id, relation score) pairs of the
    relations in which the span is the head.
    """

    relation_layer_name: str = "binary_relations"
    """The name of the relation annotation layer in the pie document."""
    relation_labels: Optional[List[str]] = None
    """The list of relation labels to consider. If None, all relations are considered."""
    span_labels: Optional[List[str]] = None
    """The list of span labels to consider for related (tail) spans. If None, all spans
    are considered. If set, the spans must be labeled."""
    reversed_relations_suffix: Optional[str] = None
    """If set, non-symmetric relations are also considered in reverse direction
    (tail -> head), with this suffix appended to the relation label."""
    symmetric_relations: Optional[List[str]] = None
    """The list of relation labels that are symmetric, i.e. also considered in reverse
    direction with the same label."""

    def get_relation_layer(
        self, pie_document: TextBasedDocument, use_predicted_annotations: bool
    ) -> BaseAnnotationList:
        """Get the relation layer of the pie document (its predictions or gold annotations).

        Raises:
            ValueError: If the document has no layer named `relation_layer_name`.
        """
        if self.relation_layer_name not in pie_document:
            raise ValueError(
                f'The pie document must contain the annotation layer "{self.relation_layer_name}"'
            )
        layer = pie_document[self.relation_layer_name]
        return layer.predictions if use_predicted_annotations else layer

    def _build_relation_index(
        self, relations: BaseAnnotationList, span2id: Dict[Any, str]
    ) -> Dict[str, Dict[str, List[Tuple[str, float]]]]:
        """Map head span ids to relation labels to (tail span id, relation score) pairs.

        Relations are only added if their (possibly reversed) label is in
        `relation_labels` (or that is None). Symmetric relations are also added from tail
        to head with the same label. If `reversed_relations_suffix` is set, non-symmetric
        relations are also added from tail to head with the suffixed label.
        """
        head2label2tails_with_scores: Dict[str, Dict[str, List[Tuple[str, float]]]] = (
            defaultdict(lambda: defaultdict(list))
        )
        for relation in relations:
            label = relation.label
            is_symmetric = (
                self.symmetric_relations is not None and label in self.symmetric_relations
            )
            # span ids are looked up lazily: relations that are filtered out may
            # reference spans that are not in the base layer
            if self.relation_labels is None or label in self.relation_labels:
                head2label2tails_with_scores[span2id[relation.head]][label].append(
                    (span2id[relation.tail], relation.score)
                )
                if is_symmetric:
                    head2label2tails_with_scores[span2id[relation.tail]][label].append(
                        (span2id[relation.head], relation.score)
                    )
            if self.reversed_relations_suffix is not None and not is_symmetric:
                reversed_label = f"{label}{self.reversed_relations_suffix}"
                if self.relation_labels is None or reversed_label in self.relation_labels:
                    head2label2tails_with_scores[span2id[relation.tail]][reversed_label].append(
                        (span2id[relation.head], relation.score)
                    )
        return head2label2tails_with_scores

    def create_span_documents(
        self, documents: List[LCDocument]
    ) -> Tuple[List[LCDocument], Dict[str, int]]:
        """Create one span document per span in the base layer of each document.

        In addition to the data stored by the parent class, each span document holds its
        outgoing relations under `METADATA_KEY_RELATION_LABEL2TAILS_WITH_SCORES`.

        Returns:
            The span documents and a mapping from span document id to span index.

        Raises:
            ValueError: If the base layer contains equal spans (ids must be unique).
        """
        span_docs = []
        id2idx = {}
        for i, doc in enumerate(documents):
            pie_doc, metadata = self.docstore.unwrap_with_metadata(doc)
            use_predicted_annotations = self.use_predicted_annotations(doc)
            base_layer = self.get_base_layer(
                pie_doc, use_predicted_annotations=use_predicted_annotations
            )
            if len(base_layer) == 0:
                logger.warning(f"No spans found in document {i} (id: {doc.id})")
            id2span = {str(uuid.uuid4()): span for span in base_layer}
            span2id = {span: span_id for span_id, span in id2span.items()}
            # equal spans would collapse into one entry of span2id
            if len(id2span) != len(span2id):
                raise ValueError("Span ids and spans must be unique")
            relations = self.get_relation_layer(
                pie_doc, use_predicted_annotations=use_predicted_annotations
            )
            head2label2tails_with_scores = self._build_relation_index(relations, span2id)

            # the child id mapping belongs to the parent only, do not copy it to its spans
            shared_metadata = {
                k: v for k, v in metadata.items() if k != METADATA_KEY_CHILD_ID2IDX
            }
            for idx, span in enumerate(base_layer):
                span_id = span2id[span]
                span_metadata = dict(shared_metadata)
                span_metadata[self.vectorstore.METADATA_SPAN_KEY] = self._span_to_dict(span)
                span_metadata[METADATA_KEY_RELATION_LABEL2TAILS_WITH_SCORES] = dict(
                    head2label2tails_with_scores[span_id]
                )
                new_doc = LCDocument(id=span_id, page_content=pie_doc.text, metadata=span_metadata)
                span_docs.append(new_doc)
                id2idx[span_id] = idx
        return span_docs, id2idx

    def _get_related_span_documents(
        self, similar_span_docs: List[LCDocument]
    ) -> List[LCDocument]:
        """Follow one hop of relations from each similar (head) span document.

        For every stored relation of a head span, a document for the tail span is created,
        unless the tail is the query span, the tail is not in its document's span id
        mapping, or its label is not in `span_labels`.

        Raises:
            ValueError: If `span_labels` is set but a tail span is not labeled.
        """
        related_docs: List[LCDocument] = []
        # several similar spans usually share a parent document, so fetch each one once
        doc_cache: Dict[str, LCDocument] = {}
        for head_span_doc in similar_span_docs:
            doc_id = head_span_doc.metadata[self.id_key]
            if doc_id not in doc_cache:
                doc_cache[doc_id] = self.get_document(doc_id)
            doc = doc_cache[doc_id]
            query_span_id = head_span_doc.metadata["query_span_id"]

            for relation_label, tails_with_score in head_span_doc.metadata[
                METADATA_KEY_RELATION_LABEL2TAILS_WITH_SCORES
            ].items():
                for tail_id, relation_score in tails_with_score:
                    # do not return the query span as related to itself
                    if tail_id == query_span_id:
                        continue

                    try:
                        attached_tail_span = self.get_span_from_doc_by_id(doc=doc, span_id=tail_id)
                    except SpanNotFoundError:
                        # e.g. the span got no vector and was removed from the span id
                        # mapping by `remove_missing_span_ids_from_document`
                        logger.warning(
                            f"Tail span with id [{tail_id}] not found in the vectorstore. Skipping."
                        )
                        continue

                    if self.span_labels is not None:
                        if not isinstance(attached_tail_span, (LabeledSpan, LabeledMultiSpan)):
                            raise ValueError(
                                "Span must be a labeled span if span_labels is provided"
                            )
                        if attached_tail_span.label not in self.span_labels:
                            continue

                    related_docs.append(
                        LCDocument(
                            id=tail_id,
                            page_content="",
                            metadata={
                                "relation_score": relation_score,
                                "head_id": head_span_doc.id,
                                "relation_label": relation_label,
                                "attached_tail_span": attached_tail_span,
                                **head_span_doc.metadata,
                            },
                        )
                    )
        return related_docs

    def _get_relevant_documents(
        self,
        query: str,
        return_related: bool = False,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[LCDocument]:
        """Get span documents relevant to a query span. We follow one hop of relations.

        Args:
            query: The span id to find relevant spans for
            return_related: Whether to return related spans instead of the similar spans
            run_manager: The callbacks handler to use
            **kwargs: Passed on to the parent retrieval (e.g. doc id white/blacklists)
        Returns:
            List of relevant span documents with metadata from the parent document
        """
        similar_span_docs = super()._get_relevant_documents(
            query=query, run_manager=run_manager, **kwargs
        )
        if not return_related:
            return similar_span_docs
        return self._get_related_span_documents(similar_span_docs)

    async def _aget_relevant_documents(
        self,
        query: str,
        return_related: bool = False,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[LCDocument]:
        """Asynchronously get span documents relevant to a query span. We follow one hop of
        relations.

        Consumes `return_related` here, so that it is not passed on to the vectorstore
        search.

        Args:
            query: The span id to find relevant spans for
            return_related: Whether to return related spans instead of the similar spans
            run_manager: The callbacks handler to use
            **kwargs: Passed on to the parent retrieval (e.g. doc id white/blacklists)
        Returns:
            List of relevant span documents with metadata from the parent document
        """
        similar_span_docs = await super()._aget_relevant_documents(
            query=query, run_manager=run_manager, **kwargs
        )
        if not return_related:
            return similar_span_docs
        return self._get_related_span_documents(similar_span_docs)
|
|