|
import logging |
|
from typing import Dict, Optional, Sequence, Tuple, Union |
|
|
|
import gradio as gr |
|
import pandas as pd |
|
from pytorch_ie import Annotation |
|
from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan |
|
from typing_extensions import Protocol |
|
|
|
from src.langchain_modules import DocumentAwareSpanRetriever |
|
from src.langchain_modules.span_retriever import DocumentAwareSpanRetrieverWithRelations |
|
from src.utils import parse_config |
|
|
|
# Module-level logger, named after this module per the standard ``logging`` convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
def get_document_as_dict(retriever: DocumentAwareSpanRetriever, doc_id: str) -> Dict:
    """Look up a document by id and return the docstore's dict form of it.

    Args:
        retriever: Retriever that holds the document and its docstore.
        doc_id: Id of the document to fetch.

    Returns:
        Whatever ``retriever.docstore.as_dict`` produces for the fetched document.
    """
    fetched = retriever.get_document(doc_id=doc_id)
    to_dict = retriever.docstore.as_dict
    return to_dict(fetched)
|
|
|
|
|
def load_retriever(
    config_str: str,
    config_format: str,
    device: str = "cpu",
    previous_retriever: Optional[DocumentAwareSpanRetrieverWithRelations] = None,
) -> DocumentAwareSpanRetrieverWithRelations:
    """Instantiate a span retriever from a config string, optionally copying old data.

    Args:
        config_str: Serialized retriever configuration.
        config_format: Format of ``config_str``; forwarded to ``parse_config`` as ``format``.
        device: Written into the parsed config at
            ``["vectorstore"]["embedding"]["pipeline_kwargs"]["device"]`` before
            the retriever is instantiated.
        previous_retriever: If given, every entry of its ``docstore`` and
            ``vectorstore`` is copied into the new retriever via ``mget``/``mset``.

    Returns:
        The newly instantiated retriever.

    Raises:
        gr.Error: If parsing, instantiation or copying fails. The underlying
            exception is chained as ``__cause__`` so its traceback is preserved.
    """
    try:
        retriever_config = parse_config(config_str, format=config_format)

        # NOTE: requires the config to contain vectorstore.embedding.pipeline_kwargs;
        # a missing key surfaces as a gr.Error below.
        retriever_config["vectorstore"]["embedding"]["pipeline_kwargs"]["device"] = device
        result = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config(retriever_config)

        if previous_retriever is not None:
            # Carry over all stored documents from the previous retriever.
            all_doc_ids = list(previous_retriever.docstore.yield_keys())
            gr.Info(f"Storing {len(all_doc_ids)} documents from previous retriever...")
            all_docs = previous_retriever.docstore.mget(all_doc_ids)
            result.docstore.mset([(doc.id, doc) for doc in all_docs])

            # Carry over all stored spans (vectorstore entries) as well.
            all_span_ids = list(previous_retriever.vectorstore.yield_keys())
            all_spans = previous_retriever.vectorstore.mget(all_span_ids)
            result.vectorstore.mset([(span.id, span) for span in all_spans])

        gr.Info("Retriever loaded successfully.")
        return result
    except Exception as e:
        # Chain the cause so the original traceback is not lost behind the UI error.
        raise gr.Error(f"Failed to load retriever: {e}") from e
|
|
|
|
|
def retrieve_similar_spans(
    retriever: DocumentAwareSpanRetriever,
    query_span_id: str,
    **kwargs,
) -> pd.DataFrame:
    """Retrieve spans similar to a query span and return them as a table.

    Calls ``retriever.invoke(input=query_span_id, **kwargs)`` and unwraps every hit
    with ``retriever.docstore.unwrap_with_metadata``.

    Args:
        retriever: Retriever used for the lookup.
        query_span_id: Id of the query span; must be non-empty after stripping.
        **kwargs: Forwarded to ``retriever.invoke``.

    Returns:
        DataFrame with columns ``doc_id``, ``score``, ``label``, ``text`` and
        ``span_id``, sorted by ``score`` descending, numeric values rounded to 3 decimals.

    Raises:
        gr.Error: If no query span is given (empty, whitespace-only or ``None``) or
            the retrieval fails; in the latter case the cause is chained.
    """
    # Guard against None as well as blank strings (e.g. an empty UI selection).
    if not query_span_id or not query_span_id.strip():
        raise gr.Error("No query span selected.")
    try:
        retrieval_result = retriever.invoke(input=query_span_id, **kwargs)
        records = []
        for similar_span_doc in retrieval_result:
            pie_doc, metadata = retriever.docstore.unwrap_with_metadata(similar_span_doc)
            span_ann = metadata["attached_span"]
            records.append(
                {
                    "doc_id": pie_doc.id,
                    "span_id": similar_span_doc.id,
                    "score": metadata["relevance_score"],
                    "label": span_ann.label,
                    "text": str(span_ann),
                }
            )
        return (
            pd.DataFrame(records, columns=["doc_id", "score", "label", "text", "span_id"])
            .sort_values(by="score", ascending=False)
            .round(3)
        )
    except Exception as e:
        raise gr.Error(f"Failed to retrieve similar ADUs: {e}") from e
|
|
|
|
|
def retrieve_relevant_spans(
    retriever: DocumentAwareSpanRetriever,
    query_span_id: str,
    relation_label_mapping: Optional[dict[str, str]] = None,
    **kwargs,
) -> pd.DataFrame:
    """Retrieve spans related to spans similar to the query span.

    Calls ``retriever.invoke(input=query_span_id, return_related=True, **kwargs)``.
    For each hit, the ``attached_span`` metadata entry is the reference span and
    ``attached_tail_span`` is the related span that is reported.

    Args:
        retriever: Retriever used for the lookup.
        query_span_id: Id of the query span; must be non-empty after stripping.
        relation_label_mapping: Optional mapping to rename relation labels; labels
            not in the mapping are kept as-is.
        **kwargs: Forwarded to ``retriever.invoke``.

    Returns:
        DataFrame with columns ``type``, ``rel_score``, ``ref_score``, ``label``,
        ``text``, ``ref_label``, ``ref_text``, ``doc_id``, ``span_id`` and
        ``ref_span_id``, sorted by ``ref_score`` descending, numeric values rounded
        to 3 decimals.

    Raises:
        gr.Error: If no query span is given (empty, whitespace-only or ``None``) or
            the retrieval fails; in the latter case the cause is chained.
    """
    # Guard against None as well as blank strings (e.g. an empty UI selection).
    if not query_span_id or not query_span_id.strip():
        raise gr.Error("No query span selected.")
    try:
        relation_label_mapping = relation_label_mapping or {}
        retrieval_result = retriever.invoke(input=query_span_id, return_related=True, **kwargs)
        records = []
        for relevant_span_doc in retrieval_result:
            pie_doc, metadata = retriever.docstore.unwrap_with_metadata(relevant_span_doc)
            span_ann = metadata["attached_span"]
            tail_span_ann = metadata["attached_tail_span"]
            mapped_relation_label = relation_label_mapping.get(
                metadata["relation_label"], metadata["relation_label"]
            )
            records.append(
                {
                    "doc_id": pie_doc.id,
                    "type": mapped_relation_label,
                    "rel_score": metadata["relation_score"],
                    "text": str(tail_span_ann),
                    "span_id": relevant_span_doc.id,
                    "label": tail_span_ann.label,
                    "ref_score": metadata["relevance_score"],
                    "ref_label": span_ann.label,
                    "ref_text": str(span_ann),
                    "ref_span_id": metadata["head_id"],
                }
            )
        return (
            pd.DataFrame(
                records,
                # Every record key must be listed here: pandas silently drops keys
                # that are missing from ``columns`` (this is why ``rel_score`` is listed).
                columns=[
                    "type",
                    "rel_score",
                    "ref_score",
                    "label",
                    "text",
                    "ref_label",
                    "ref_text",
                    "doc_id",
                    "span_id",
                    "ref_span_id",
                ],
            )
            .sort_values(by=["ref_score"], ascending=False)
            .round(3)
        )
    except Exception as e:
        raise gr.Error(f"Failed to retrieve relevant ADUs: {e}") from e
|
|
|
|
|
class RetrieverCallable(Protocol):
    """Structural type for per-span retrieval functions.

    A conforming callable accepts a retriever, a query span id and arbitrary
    keyword arguments, and returns a DataFrame of results or ``None``.
    """

    def __call__(
        self,
        retriever: DocumentAwareSpanRetriever,
        query_span_id: str,
        **kwargs,
    ) -> Optional[pd.DataFrame]:
        ...
|
|
|
|
|
def _retrieve_for_all_spans(
    retriever: DocumentAwareSpanRetriever,
    query_doc_id: str,
    retrieve_func: RetrieverCallable,
    query_span_id_column: str = "query_span_id",
    query_span_text_column: Optional[str] = None,
    **kwargs,
) -> Optional[pd.DataFrame]:
    """Run ``retrieve_func`` for every span of one document and stack the results.

    Span ids are taken from ``retriever.get_span_id2idx_from_doc(query_doc_id)``.
    Results that are ``None`` or empty are dropped; each remaining frame is tagged
    with the query span id (and, if requested, the query span text) before all
    frames are concatenated.

    Args:
        retriever: Retriever holding the document and its spans.
        query_doc_id: Id of the document whose spans are used as queries.
        retrieve_func: Per-span retrieval function, called with ``retriever``,
            ``query_span_id`` and ``**kwargs``.
        query_span_id_column: Name of the column that receives the query span id.
        query_span_text_column: If not ``None``, name of a column that receives
            ``str()`` of the query span (looked up via ``retriever.get_span_by_id``).
        **kwargs: Forwarded to ``retrieve_func``.

    Returns:
        The concatenated results with a fresh integer index, or ``None`` if no span
        produced any result.

    Raises:
        gr.Error: If no document id is given (empty, whitespace-only or ``None``), or
            if anything fails, including a ``gr.Error`` raised by ``retrieve_func``,
            which is re-wrapped. The cause is chained.
    """
    # Guard against None as well as blank strings (e.g. an empty UI selection).
    if not query_doc_id or not query_doc_id.strip():
        raise gr.Error("No query document selected.")
    try:
        span_id2idx = retriever.get_span_id2idx_from_doc(query_doc_id)
        gr.Info(f"Retrieving results for {len(span_id2idx)} ADUs in document {query_doc_id}...")
        span_results = {
            query_span_id: retrieve_func(
                retriever=retriever,
                query_span_id=query_span_id,
                **kwargs,
            )
            for query_span_id in span_id2idx.keys()
        }
        span_results_not_empty = {
            query_span_id: df
            for query_span_id, df in span_results.items()
            if df is not None and not df.empty
        }

        # Tag each per-span frame so rows stay attributable after concatenation.
        for query_span_id, query_span_result in span_results_not_empty.items():
            query_span_result[query_span_id_column] = query_span_id
            if query_span_text_column is not None:
                query_span = retriever.get_span_by_id(span_id=query_span_id)
                query_span_result[query_span_text_column] = str(query_span)

        if not span_results_not_empty:
            gr.Info(f"No results found for any ADU in document {query_doc_id}.")
            return None

        result = pd.concat(span_results_not_empty.values(), ignore_index=True)
        gr.Info(f"Retrieved {len(result)} ADUs for document {query_doc_id}.")
        return result
    except Exception as e:
        raise gr.Error(
            f'Failed to retrieve results for all ADUs in document "{query_doc_id}": {e}'
        ) from e
|
|
|
|
|
def retrieve_all_similar_spans(
    retriever: DocumentAwareSpanRetriever,
    query_doc_id: str,
    **kwargs,
) -> Optional[pd.DataFrame]:
    """Use every span of one document as a query for ``retrieve_similar_spans``.

    Delegates to ``_retrieve_for_all_spans``. Keyword arguments are passed through,
    and the ones it does not consume reach ``retrieve_similar_spans``.
    """
    per_span_func = retrieve_similar_spans
    return _retrieve_for_all_spans(
        retrieve_func=per_span_func,
        retriever=retriever,
        query_doc_id=query_doc_id,
        **kwargs,
    )
|
|
|
|
|
def retrieve_all_relevant_spans(
    retriever: DocumentAwareSpanRetriever,
    query_doc_id: str,
    **kwargs,
) -> Optional[pd.DataFrame]:
    """Use every span of one document as a query for ``retrieve_relevant_spans``.

    Delegates to ``_retrieve_for_all_spans``. Keyword arguments are passed through,
    and the ones it does not consume reach ``retrieve_relevant_spans``.
    """
    per_span_func = retrieve_relevant_spans
    return _retrieve_for_all_spans(
        retrieve_func=per_span_func,
        retriever=retriever,
        query_doc_id=query_doc_id,
        **kwargs,
    )
|
|
|
|
|
class RetrieverForAllSpansCallable(Protocol):
    """Structural type for per-document retrieval functions.

    A conforming callable accepts a retriever, a query document id and arbitrary
    keyword arguments, and returns a DataFrame of results or ``None``.
    """

    def __call__(
        self,
        retriever: DocumentAwareSpanRetriever,
        query_doc_id: str,
        **kwargs,
    ) -> Optional[pd.DataFrame]:
        ...
|
|
|
|
|
def _retrieve_for_all_documents(
    retriever: DocumentAwareSpanRetriever,
    retrieve_func: RetrieverForAllSpansCallable,
    query_doc_id_column: str = "query_doc_id",
    **kwargs,
) -> Optional[pd.DataFrame]:
    """Run ``retrieve_func`` for every document in the docstore and stack the results.

    Document ids come from ``retriever.docstore.yield_keys()``. Results that are
    ``None`` or empty are dropped, and each remaining frame is tagged with its
    document id before concatenation.

    Args:
        retriever: Retriever whose docstore is enumerated.
        retrieve_func: Per-document retrieval function, called with ``retriever``,
            ``query_doc_id`` and ``**kwargs``.
        query_doc_id_column: Name of the column that receives the query document id.
        **kwargs: Forwarded to ``retrieve_func``.

    Returns:
        The concatenated results with a fresh integer index, or ``None`` if no
        document produced any result.

    Raises:
        gr.Error: If anything fails, including a ``gr.Error`` raised by
            ``retrieve_func``, which is re-wrapped. The cause is chained.
    """
    try:
        all_doc_ids = list(retriever.docstore.yield_keys())
        gr.Info(f"Retrieving results for {len(all_doc_ids)} documents...")
        doc_results = {
            doc_id: retrieve_func(retriever=retriever, query_doc_id=doc_id, **kwargs)
            for doc_id in all_doc_ids
        }
        doc_results_not_empty = {
            doc_id: df for doc_id, df in doc_results.items() if df is not None and not df.empty
        }

        # Tag each per-document frame so rows stay attributable after concatenation.
        for doc_id, doc_result in doc_results_not_empty.items():
            doc_result[query_doc_id_column] = doc_id

        if not doc_results_not_empty:
            gr.Info("No results found for any document.")
            return None

        # Pass the frames explicitly. Handing pd.concat the dict itself makes pandas
        # use its keys as ``keys=``, which differs from how _retrieve_for_all_spans
        # concatenates.
        result = pd.concat(doc_results_not_empty.values(), ignore_index=True)
        gr.Info(f"Retrieved {len(result)} ADUs for all documents.")
        return result
    except Exception as e:
        raise gr.Error(f"Failed to retrieve results for all documents: {e}") from e
|
|
|
|
|
def retrieve_all_similar_spans_for_all_documents(
    retriever: DocumentAwareSpanRetriever,
    **kwargs,
) -> Optional[pd.DataFrame]:
    """Run ``retrieve_all_similar_spans`` for each document in the docstore.

    Delegates to ``_retrieve_for_all_documents``, which passes keyword arguments through.
    """
    per_doc_func = retrieve_all_similar_spans
    return _retrieve_for_all_documents(
        retrieve_func=per_doc_func,
        retriever=retriever,
        **kwargs,
    )
|
|
|
|
|
def retrieve_all_relevant_spans_for_all_documents(
    retriever: DocumentAwareSpanRetriever,
    **kwargs,
) -> Optional[pd.DataFrame]:
    """Run ``retrieve_all_relevant_spans`` for each document in the docstore.

    Delegates to ``_retrieve_for_all_documents``, which passes keyword arguments through.
    """
    per_doc_func = retrieve_all_relevant_spans
    return _retrieve_for_all_documents(
        retrieve_func=per_doc_func,
        retriever=retriever,
        **kwargs,
    )
|
|
|
|
|
def get_text_spans_and_relations_from_document(
    retriever: DocumentAwareSpanRetrieverWithRelations, document_id: str
) -> Tuple[
    str,
    Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
    Dict[str, int],
    Sequence[BinaryRelation],
]:
    """Gather the text, span layer, span-id index and relation layer of a document.

    The span and relation layers are read from the unwrapped document, using the
    annotation source chosen by ``retriever.use_predicted_annotations``.

    Args:
        retriever: Retriever holding the document.
        document_id: Id of the document to read.

    Returns:
        A tuple ``(text, spans, span_id2idx, relations)``.
    """
    wrapped_doc = retriever.get_document(doc_id=document_id)
    unwrapped_doc = retriever.docstore.unwrap(wrapped_doc)
    predicted = retriever.use_predicted_annotations(wrapped_doc)

    base_layer = retriever.get_base_layer(
        pie_document=unwrapped_doc, use_predicted_annotations=predicted
    )
    relation_layer = retriever.get_relation_layer(
        pie_document=unwrapped_doc, use_predicted_annotations=predicted
    )
    # NOTE(review): this passes the wrapped document object, whereas
    # _retrieve_for_all_spans passes a document id string to the same method.
    # Confirm that get_span_id2idx_from_doc accepts both.
    id2idx = retriever.get_span_id2idx_from_doc(wrapped_doc)

    text = unwrapped_doc.text
    return text, base_layer, id2idx, relation_layer
|
|
|
|
|
def get_span_annotation(
    retriever: DocumentAwareSpanRetriever,
    span_id: str,
) -> Annotation:
    """Return the span annotation stored under ``span_id``.

    Args:
        retriever: Retriever to query via ``get_span_by_id``.
        span_id: Id of the span; must be non-empty after stripping.

    Returns:
        Whatever ``retriever.get_span_by_id`` returns for ``span_id``.

    Raises:
        gr.Error: If no span id is given (empty, whitespace-only or ``None``) or
            the lookup fails; in the latter case the cause is chained.
    """
    # Same guard style as the other helpers; also covers None.
    if not span_id or not span_id.strip():
        raise gr.Error("No span selected.")
    try:
        return retriever.get_span_by_id(span_id=span_id)
    except Exception as e:
        raise gr.Error(f"Failed to retrieve span annotation: {e}") from e
|
|