import logging
from typing import Dict, Optional, Sequence, Tuple, Union

import gradio as gr
import pandas as pd
from pytorch_ie import Annotation
from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
from typing_extensions import Protocol

from src.langchain_modules import DocumentAwareSpanRetriever
from src.langchain_modules.span_retriever import DocumentAwareSpanRetrieverWithRelations
from src.utils import parse_config

logger = logging.getLogger(__name__)


def _is_blank(value: Optional[str]) -> bool:
    """Return True if *value* is None, empty, or whitespace-only.

    UI components may hand over ``None`` instead of an empty string when nothing is
    selected; treating both the same lets callers raise a clear ``gr.Error``
    instead of an ``AttributeError`` from ``None.strip()``.
    """
    return value is None or not value.strip()


def get_document_as_dict(retriever: DocumentAwareSpanRetriever, doc_id: str) -> Dict:
    """Fetch a document from the retriever and convert it to a dict.

    Args:
        retriever: Retriever whose docstore holds the document.
        doc_id: Identifier of the document to fetch.

    Returns:
        The result of ``retriever.docstore.as_dict`` for the fetched document.
    """
    document = retriever.get_document(doc_id=doc_id)
    return retriever.docstore.as_dict(document)


def load_retriever(
    config_str: str,
    config_format: str,
    device: str = "cpu",
    previous_retriever: Optional[DocumentAwareSpanRetrieverWithRelations] = None,
) -> DocumentAwareSpanRetrieverWithRelations:
    """Instantiate a retriever from a serialized config, optionally copying prior content.

    Args:
        config_str: Serialized retriever configuration.
        config_format: Format of ``config_str``; forwarded to ``parse_config``.
        device: Written into ``vectorstore.embedding.pipeline_kwargs.device`` of the
            parsed config before instantiation.
        previous_retriever: If given, every document from its docstore and every
            span from its vectorstore is copied into the new retriever.

    Returns:
        The newly instantiated retriever.

    Raises:
        gr.Error: If parsing, instantiation, or copying fails (the original
            exception is chained as the cause).
    """
    try:
        retriever_config = parse_config(config_str, format=config_format)
        # set device for the embeddings pipeline
        retriever_config["vectorstore"]["embedding"]["pipeline_kwargs"]["device"] = device
        result = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config(retriever_config)
        # if a previous retriever is provided, load all documents and vectors from the previous retriever
        if previous_retriever is not None:
            # documents
            all_doc_ids = list(previous_retriever.docstore.yield_keys())
            gr.Info(f"Storing {len(all_doc_ids)} documents from previous retriever...")
            all_docs = previous_retriever.docstore.mget(all_doc_ids)
            result.docstore.mset([(doc.id, doc) for doc in all_docs])
            # spans (with vectors)
            all_span_ids = list(previous_retriever.vectorstore.yield_keys())
            all_spans = previous_retriever.vectorstore.mget(all_span_ids)
            result.vectorstore.mset([(span.id, span) for span in all_spans])

        gr.Info("Retriever loaded successfully.")
        return result
    except Exception as e:
        raise gr.Error(f"Failed to load retriever: {e}") from e


def retrieve_similar_spans(
    retriever: DocumentAwareSpanRetriever,
    query_span_id: str,
    **kwargs,
) -> pd.DataFrame:
    """Retrieve spans similar to the given query span.

    Args:
        retriever: Retriever to query.
        query_span_id: Id of the query span; passed as ``input`` to ``retriever.invoke``.
        **kwargs: Forwarded to ``retriever.invoke``.

    Returns:
        DataFrame with columns ``doc_id, score, label, text, span_id``, sorted by
        ``score`` descending and rounded to 3 decimals.

    Raises:
        gr.Error: If no query span is given or retrieval fails.
    """
    if _is_blank(query_span_id):
        raise gr.Error("No query span selected.")
    try:
        retrieval_result = retriever.invoke(input=query_span_id, **kwargs)
        records = []
        for similar_span_doc in retrieval_result:
            pie_doc, metadata = retriever.docstore.unwrap_with_metadata(similar_span_doc)
            span_ann = metadata["attached_span"]
            records.append(
                {
                    "doc_id": pie_doc.id,
                    "span_id": similar_span_doc.id,
                    "score": metadata["relevance_score"],
                    "label": span_ann.label,
                    "text": str(span_ann),
                }
            )
        return (
            pd.DataFrame(records, columns=["doc_id", "score", "label", "text", "span_id"])
            .sort_values(by="score", ascending=False)
            .round(3)
        )
    except Exception as e:
        raise gr.Error(f"Failed to retrieve similar ADUs: {e}") from e


def retrieve_relevant_spans(
    retriever: DocumentAwareSpanRetriever,
    query_span_id: str,
    relation_label_mapping: Optional[dict[str, str]] = None,
    **kwargs,
) -> pd.DataFrame:
    """Retrieve spans related (via relations) to spans similar to the query span.

    Calls ``retriever.invoke`` with ``return_related=True``. Each result describes a
    reference span (``ref_*`` columns, taken from ``attached_span``) and the tail
    span of a relation attached to it (``label``/``text``/``span_id``).

    Args:
        retriever: Retriever to query.
        query_span_id: Id of the query span.
        relation_label_mapping: Optional mapping used to rename relation labels
            in the ``type`` column; unmapped labels are kept as-is.
        **kwargs: Forwarded to ``retriever.invoke``.

    Returns:
        DataFrame sorted by ``ref_score`` descending and rounded to 3 decimals.

    Raises:
        gr.Error: If no query span is given or retrieval fails.
    """
    if _is_blank(query_span_id):
        raise gr.Error("No query span selected.")
    try:
        relation_label_mapping = relation_label_mapping or {}
        retrieval_result = retriever.invoke(input=query_span_id, return_related=True, **kwargs)
        records = []
        for relevant_span_doc in retrieval_result:
            pie_doc, metadata = retriever.docstore.unwrap_with_metadata(relevant_span_doc)
            span_ann = metadata["attached_span"]
            tail_span_ann = metadata["attached_tail_span"]
            mapped_relation_label = relation_label_mapping.get(
                metadata["relation_label"], metadata["relation_label"]
            )
            records.append(
                {
                    "doc_id": pie_doc.id,
                    "type": mapped_relation_label,
                    "rel_score": metadata["relation_score"],
                    "text": str(tail_span_ann),
                    "span_id": relevant_span_doc.id,
                    "label": tail_span_ann.label,
                    "ref_score": metadata["relevance_score"],
                    "ref_label": span_ann.label,
                    "ref_text": str(span_ann),
                    "ref_span_id": metadata["head_id"],
                }
            )
        return (
            pd.DataFrame(
                records,
                columns=[
                    "type",
                    # omitted for now, we get no valid relation scores for the generative model
                    # "rel_score",
                    "ref_score",
                    "label",
                    "text",
                    "ref_label",
                    "ref_text",
                    "doc_id",
                    "span_id",
                    "ref_span_id",
                ],
            )
            .sort_values(by=["ref_score"], ascending=False)
            .round(3)
        )
    except Exception as e:
        raise gr.Error(f"Failed to retrieve relevant ADUs: {e}") from e


class RetrieverCallable(Protocol):
    """Signature of a per-span retrieval function (e.g. ``retrieve_similar_spans``)."""

    def __call__(
        self,
        retriever: DocumentAwareSpanRetriever,
        query_span_id: str,
        **kwargs,
    ) -> Optional[pd.DataFrame]: ...


def _retrieve_for_all_spans(
    retriever: DocumentAwareSpanRetriever,
    query_doc_id: str,
    retrieve_func: RetrieverCallable,
    query_span_id_column: str = "query_span_id",
    query_span_text_column: Optional[str] = None,
    **kwargs,
) -> Optional[pd.DataFrame]:
    """Run *retrieve_func* for every span of a document and concatenate the results.

    Args:
        retriever: Retriever to query.
        query_doc_id: Id of the document whose spans are used as queries.
        retrieve_func: Per-span retrieval function.
        query_span_id_column: Name of the column added to each result holding the
            query span id.
        query_span_text_column: If given, name of an extra column holding the
            string form of the query span.
        **kwargs: Forwarded to *retrieve_func*.

    Returns:
        The concatenated results (fresh integer index), or None if every span
        produced no rows.

    Raises:
        gr.Error: If no document id is given or any retrieval fails.
    """
    if _is_blank(query_doc_id):
        raise gr.Error("No query document selected.")
    try:
        span_id2idx = retriever.get_span_id2idx_from_doc(query_doc_id)
        gr.Info(f"Retrieving results for {len(span_id2idx)} ADUs in document {query_doc_id}...")
        span_results = {
            query_span_id: retrieve_func(
                retriever=retriever,
                query_span_id=query_span_id,
                **kwargs,
            )
            for query_span_id in span_id2idx
        }
        span_results_not_empty = {
            query_span_id: df
            for query_span_id, df in span_results.items()
            if df is not None and not df.empty
        }

        # add column with query_span_id
        for query_span_id, query_span_result in span_results_not_empty.items():
            query_span_result[query_span_id_column] = query_span_id
            if query_span_text_column is not None:
                query_span = retriever.get_span_by_id(span_id=query_span_id)
                query_span_result[query_span_text_column] = str(query_span)

        if len(span_results_not_empty) == 0:
            gr.Info(f"No results found for any ADU in document {query_doc_id}.")
            return None
        else:
            result = pd.concat(span_results_not_empty.values(), ignore_index=True)
            gr.Info(f"Retrieved {len(result)} ADUs for document {query_doc_id}.")
            return result
    except Exception as e:
        raise gr.Error(
            f'Failed to retrieve results for all ADUs in document "{query_doc_id}": {e}'
        ) from e


def retrieve_all_similar_spans(
    retriever: DocumentAwareSpanRetriever,
    query_doc_id: str,
    **kwargs,
) -> Optional[pd.DataFrame]:
    """Run ``retrieve_similar_spans`` for every span of *query_doc_id*."""
    return _retrieve_for_all_spans(
        retriever=retriever,
        query_doc_id=query_doc_id,
        retrieve_func=retrieve_similar_spans,
        **kwargs,
    )


def retrieve_all_relevant_spans(
    retriever: DocumentAwareSpanRetriever,
    query_doc_id: str,
    **kwargs,
) -> Optional[pd.DataFrame]:
    """Run ``retrieve_relevant_spans`` for every span of *query_doc_id*."""
    return _retrieve_for_all_spans(
        retriever=retriever,
        query_doc_id=query_doc_id,
        retrieve_func=retrieve_relevant_spans,
        **kwargs,
    )


class RetrieverForAllSpansCallable(Protocol):
    """Signature of a per-document retrieval function (e.g. ``retrieve_all_similar_spans``)."""

    def __call__(
        self,
        retriever: DocumentAwareSpanRetriever,
        query_doc_id: str,
        **kwargs,
    ) -> Optional[pd.DataFrame]: ...


def _retrieve_for_all_documents(
    retriever: DocumentAwareSpanRetriever,
    retrieve_func: RetrieverForAllSpansCallable,
    query_doc_id_column: str = "query_doc_id",
    **kwargs,
) -> Optional[pd.DataFrame]:
    """Run *retrieve_func* for every document in the docstore and concatenate the results.

    Args:
        retriever: Retriever whose docstore keys are used as query document ids.
        retrieve_func: Per-document retrieval function.
        query_doc_id_column: Name of the column added to each result holding the
            query document id.
        **kwargs: Forwarded to *retrieve_func*.

    Returns:
        The concatenated results in docstore key order (fresh integer index), or
        None if no document produced rows.

    Raises:
        gr.Error: If any retrieval fails.
    """
    try:
        all_doc_ids = list(retriever.docstore.yield_keys())
        gr.Info(f"Retrieving results for {len(all_doc_ids)} documents...")
        doc_results = {
            doc_id: retrieve_func(retriever=retriever, query_doc_id=doc_id, **kwargs)
            for doc_id in all_doc_ids
        }
        doc_results_not_empty = {
            doc_id: df for doc_id, df in doc_results.items() if df is not None and not df.empty
        }
        # add column with query_doc_id
        for doc_id, doc_result in doc_results_not_empty.items():
            doc_result[query_doc_id_column] = doc_id

        if len(doc_results_not_empty) == 0:
            gr.Info("No results found for any document.")
            return None
        else:
            # Concatenate the frames themselves (not the dict): passing a mapping makes
            # pandas use its keys as `keys`, which some versions sort, so row order would
            # no longer follow docstore order. This also matches _retrieve_for_all_spans.
            result = pd.concat(list(doc_results_not_empty.values()), ignore_index=True)
            gr.Info(f"Retrieved {len(result)} ADUs for all documents.")
            return result
    except Exception as e:
        raise gr.Error(f"Failed to retrieve results for all documents: {e}") from e


def retrieve_all_similar_spans_for_all_documents(
    retriever: DocumentAwareSpanRetriever,
    **kwargs,
) -> Optional[pd.DataFrame]:
    """Run ``retrieve_all_similar_spans`` for every document in the docstore."""
    return _retrieve_for_all_documents(
        retriever=retriever,
        retrieve_func=retrieve_all_similar_spans,
        **kwargs,
    )


def retrieve_all_relevant_spans_for_all_documents(
    retriever: DocumentAwareSpanRetriever,
    **kwargs,
) -> Optional[pd.DataFrame]:
    """Run ``retrieve_all_relevant_spans`` for every document in the docstore."""
    return _retrieve_for_all_documents(
        retriever=retriever,
        retrieve_func=retrieve_all_relevant_spans,
        **kwargs,
    )


def get_text_spans_and_relations_from_document(
    retriever: DocumentAwareSpanRetrieverWithRelations, document_id: str
) -> Tuple[
    str,
    Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
    Dict[str, int],
    Sequence[BinaryRelation],
]:
    """Collect the text, span layer, span-id index, and relation layer of a document.

    Whether predicted or gold annotations are used is decided by
    ``retriever.use_predicted_annotations(document)``.

    Args:
        retriever: Retriever holding the document.
        document_id: Id of the document.

    Returns:
        Tuple ``(text, spans, span_id2idx, relations)``.
    """
    document = retriever.get_document(doc_id=document_id)
    pie_document = retriever.docstore.unwrap(document)
    use_predicted_annotations = retriever.use_predicted_annotations(document)
    spans = retriever.get_base_layer(
        pie_document=pie_document, use_predicted_annotations=use_predicted_annotations
    )
    relations = retriever.get_relation_layer(
        pie_document=pie_document, use_predicted_annotations=use_predicted_annotations
    )
    # NOTE(review): here get_span_id2idx_from_doc receives the document object, while
    # _retrieve_for_all_spans passes a document id string — confirm the method accepts both.
    span_id2idx = retriever.get_span_id2idx_from_doc(document)
    return pie_document.text, spans, span_id2idx, relations


def get_span_annotation(
    retriever: DocumentAwareSpanRetriever,
    span_id: str,
) -> Annotation:
    """Look up a span annotation by id.

    Raises:
        gr.Error: If no span id is given or the lookup fails.
    """
    if _is_blank(span_id):
        raise gr.Error("No span selected.")
    try:
        return retriever.get_span_by_id(span_id=span_id)
    except Exception as e:
        raise gr.Error(f"Failed to retrieve span annotation: {e}") from e