ArneBinder's picture
update from https://github.com/ArneBinder/pie-document-level/pull/397
ced4316 verified
raw
history blame
11.3 kB
import logging
from typing import Dict, Optional, Sequence, Tuple, Union
import gradio as gr
import pandas as pd
from pytorch_ie import Annotation
from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
from typing_extensions import Protocol
from src.langchain_modules import DocumentAwareSpanRetriever
from src.langchain_modules.span_retriever import DocumentAwareSpanRetrieverWithRelations
from src.utils import parse_config
logger = logging.getLogger(__name__)
def get_document_as_dict(retriever: DocumentAwareSpanRetriever, doc_id: str) -> Dict:
    """Look up a document by id and return the docstore's dict form of it.

    Args:
        retriever: Retriever providing ``get_document`` and a ``docstore``
            that offers ``as_dict``.
        doc_id: Id of the document to look up.

    Returns:
        The result of ``retriever.docstore.as_dict`` for the looked-up document.
    """
    stored_document = retriever.get_document(doc_id=doc_id)
    convert_to_dict = retriever.docstore.as_dict
    return convert_to_dict(stored_document)
def load_retriever(
    config_str: str,
    config_format: str,
    device: str = "cpu",
    previous_retriever: Optional[DocumentAwareSpanRetrieverWithRelations] = None,
) -> DocumentAwareSpanRetrieverWithRelations:
    """Instantiate a retriever from a serialized config, optionally carrying over content.

    Args:
        config_str: Serialized retriever configuration, handed to ``parse_config``.
        config_format: Format identifier forwarded to ``parse_config`` as ``format``.
        device: Value written to the ``device`` entry of the embedding pipeline kwargs.
        previous_retriever: If given, all entries of its docstore and its vectorstore
            are copied into the newly created retriever.

    Returns:
        The newly instantiated retriever.

    Raises:
        gr.Error: If any step fails. The original exception is chained as the cause.
    """
    try:
        retriever_config = parse_config(config_str, format=config_format)
        # set device for the embeddings pipeline
        retriever_config["vectorstore"]["embedding"]["pipeline_kwargs"]["device"] = device
        result = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config(retriever_config)
        # if a previous retriever is provided, load all documents and vectors from the previous retriever
        if previous_retriever is not None:
            # documents
            all_doc_ids = list(previous_retriever.docstore.yield_keys())
            gr.Info(f"Storing {len(all_doc_ids)} documents from previous retriever...")
            all_docs = previous_retriever.docstore.mget(all_doc_ids)
            result.docstore.mset([(doc.id, doc) for doc in all_docs])
            # spans (with vectors)
            all_span_ids = list(previous_retriever.vectorstore.yield_keys())
            all_spans = previous_retriever.vectorstore.mget(all_span_ids)
            result.vectorstore.mset([(span.id, span) for span in all_spans])
        gr.Info("Retriever loaded successfully.")
        return result
    except Exception as e:
        # chain explicitly so the root cause stays attached to the user-facing error
        raise gr.Error(f"Failed to load retriever: {e}") from e
def retrieve_similar_spans(
    retriever: DocumentAwareSpanRetriever,
    query_span_id: str,
    **kwargs,
) -> pd.DataFrame:
    """Retrieve the spans similar to a query span and return them as a table.

    Args:
        retriever: Retriever that is invoked with the query span id.
        query_span_id: Id of the span used as query.
        **kwargs: Forwarded unchanged to ``retriever.invoke``.

    Returns:
        A DataFrame with the columns ``doc_id``, ``score``, ``label``, ``text`` and
        ``span_id``, sorted by ``score`` in descending order and rounded to 3 decimals.

    Raises:
        gr.Error: If no query span is given (``None``, empty or whitespace only), or if
            the retrieval fails; in the latter case the original exception is chained.
    """
    # also covers None, which would otherwise fail with an AttributeError on .strip()
    if not query_span_id or not query_span_id.strip():
        raise gr.Error("No query span selected.")
    try:
        retrieval_result = retriever.invoke(input=query_span_id, **kwargs)
        records = []
        for similar_span_doc in retrieval_result:
            pie_doc, metadata = retriever.docstore.unwrap_with_metadata(similar_span_doc)
            span_ann = metadata["attached_span"]
            records.append(
                {
                    "doc_id": pie_doc.id,
                    "span_id": similar_span_doc.id,
                    "score": metadata["relevance_score"],
                    "label": span_ann.label,
                    "text": str(span_ann),
                }
            )
        # explicit columns fix the column order and keep the frame well-formed without records
        return (
            pd.DataFrame(records, columns=["doc_id", "score", "label", "text", "span_id"])
            .sort_values(by="score", ascending=False)
            .round(3)
        )
    except Exception as e:
        # chain explicitly so the root cause stays attached to the user-facing error
        raise gr.Error(f"Failed to retrieve similar ADUs: {e}") from e
def retrieve_relevant_spans(
    retriever: DocumentAwareSpanRetriever,
    query_span_id: str,
    relation_label_mapping: Optional[dict[str, str]] = None,
    **kwargs,
) -> pd.DataFrame:
    """Retrieve spans related to the matches of a query span and return them as a table.

    The retriever is invoked with ``return_related=True``. Each result contributes the
    attached tail span (``label``, ``text``, ``span_id``) together with the reference
    span it is attached to (``ref_label``, ``ref_text``, ``ref_span_id``, ``ref_score``).

    Args:
        retriever: Retriever that is invoked with the query span id.
        query_span_id: Id of the span used as query.
        relation_label_mapping: Optional mapping from relation labels to display
            labels. Labels without an entry are used as they are.
        **kwargs: Forwarded unchanged to ``retriever.invoke``.

    Returns:
        A DataFrame with the columns ``type``, ``ref_score``, ``label``, ``text``,
        ``ref_label``, ``ref_text``, ``doc_id``, ``span_id`` and ``ref_span_id``, sorted
        by ``ref_score`` in descending order and rounded to 3 decimals.

    Raises:
        gr.Error: If no query span is given (``None``, empty or whitespace only), or if
            the retrieval fails; in the latter case the original exception is chained.
    """
    # also covers None, which would otherwise fail with an AttributeError on .strip()
    if not query_span_id or not query_span_id.strip():
        raise gr.Error("No query span selected.")
    try:
        relation_label_mapping = relation_label_mapping or {}
        retrieval_result = retriever.invoke(input=query_span_id, return_related=True, **kwargs)
        records = []
        for relevant_span_doc in retrieval_result:
            pie_doc, metadata = retriever.docstore.unwrap_with_metadata(relevant_span_doc)
            span_ann = metadata["attached_span"]
            tail_span_ann = metadata["attached_tail_span"]
            # fall back to the raw relation label if no mapping entry exists
            mapped_relation_label = relation_label_mapping.get(
                metadata["relation_label"], metadata["relation_label"]
            )
            records.append(
                {
                    "doc_id": pie_doc.id,
                    "type": mapped_relation_label,
                    "rel_score": metadata["relation_score"],
                    "text": str(tail_span_ann),
                    "span_id": relevant_span_doc.id,
                    "label": tail_span_ann.label,
                    "ref_score": metadata["relevance_score"],
                    "ref_label": span_ann.label,
                    "ref_text": str(span_ann),
                    "ref_span_id": metadata["head_id"],
                }
            )
        return (
            pd.DataFrame(
                records,
                columns=[
                    "type",
                    # omitted for now, we get no valid relation scores for the generative model
                    # "rel_score",
                    "ref_score",
                    "label",
                    "text",
                    "ref_label",
                    "ref_text",
                    "doc_id",
                    "span_id",
                    "ref_span_id",
                ],
            )
            .sort_values(by=["ref_score"], ascending=False)
            .round(3)
        )
    except Exception as e:
        # chain explicitly so the root cause stays attached to the user-facing error
        raise gr.Error(f"Failed to retrieve relevant ADUs: {e}") from e
class RetrieverCallable(Protocol):
    """Call signature of a per-span retrieval function.

    This is the type expected for the ``retrieve_func`` parameter of
    ``_retrieve_for_all_spans``.
    """

    def __call__(
        self,
        retriever: DocumentAwareSpanRetriever,
        query_span_id: str,
        **kwargs,
    ) -> Optional[pd.DataFrame]:
        """Retrieve results for the span ``query_span_id`` using ``retriever``."""
        ...
def _retrieve_for_all_spans(
    retriever: DocumentAwareSpanRetriever,
    query_doc_id: str,
    retrieve_func: RetrieverCallable,
    query_span_id_column: str = "query_span_id",
    query_span_text_column: Optional[str] = None,
    **kwargs,
) -> Optional[pd.DataFrame]:
    """Run ``retrieve_func`` for every span of a document and combine the results.

    Args:
        retriever: Retriever used to enumerate the spans of the document and passed on
            to ``retrieve_func``.
        query_doc_id: Id of the document whose spans are used as queries.
        retrieve_func: Function called once per span id.
        query_span_id_column: Name of the column added to each non-empty result that
            holds the id of the query span.
        query_span_text_column: If given, name of an additional column that holds the
            string form of the query span.
        **kwargs: Forwarded unchanged to ``retrieve_func``.

    Returns:
        The concatenation of all non-empty per-span results, or ``None`` if no span
        produced a result.

    Raises:
        gr.Error: If no document is given (``None``, empty or whitespace only), or if
            anything fails while retrieving; in the latter case the original exception
            is chained.
    """
    # also covers None, which would otherwise fail with an AttributeError on .strip()
    if not query_doc_id or not query_doc_id.strip():
        raise gr.Error("No query document selected.")
    try:
        span_id2idx = retriever.get_span_id2idx_from_doc(query_doc_id)
        gr.Info(f"Retrieving results for {len(span_id2idx)} ADUs in document {query_doc_id}...")
        span_results = {
            query_span_id: retrieve_func(
                retriever=retriever,
                query_span_id=query_span_id,
                **kwargs,
            )
            for query_span_id in span_id2idx
        }
        span_results_not_empty = {
            query_span_id: df
            for query_span_id, df in span_results.items()
            if df is not None and not df.empty
        }
        # add column with query_span_id
        for query_span_id, query_span_result in span_results_not_empty.items():
            query_span_result[query_span_id_column] = query_span_id
            if query_span_text_column is not None:
                query_span = retriever.get_span_by_id(span_id=query_span_id)
                query_span_result[query_span_text_column] = str(query_span)
        if not span_results_not_empty:
            gr.Info(f"No results found for any ADU in document {query_doc_id}.")
            return None
        result = pd.concat(span_results_not_empty.values(), ignore_index=True)
        gr.Info(f"Retrieved {len(result)} ADUs for document {query_doc_id}.")
        return result
    except Exception as e:
        # chain explicitly so the root cause stays attached to the user-facing error
        raise gr.Error(
            f'Failed to retrieve results for all ADUs in document "{query_doc_id}": {e}'
        ) from e
def retrieve_all_similar_spans(
    retriever: DocumentAwareSpanRetriever,
    query_doc_id: str,
    **kwargs,
) -> Optional[pd.DataFrame]:
    """Retrieve similar spans for every span of the document ``query_doc_id``.

    Delegates to ``_retrieve_for_all_spans`` with ``retrieve_similar_spans`` as the
    per-span retrieval function; ``kwargs`` are passed through unchanged.
    """
    similar_spans_for_document = _retrieve_for_all_spans(
        retriever=retriever,
        query_doc_id=query_doc_id,
        retrieve_func=retrieve_similar_spans,
        **kwargs,
    )
    return similar_spans_for_document
def retrieve_all_relevant_spans(
    retriever: DocumentAwareSpanRetriever,
    query_doc_id: str,
    **kwargs,
) -> Optional[pd.DataFrame]:
    """Retrieve relevant spans for every span of the document ``query_doc_id``.

    Delegates to ``_retrieve_for_all_spans`` with ``retrieve_relevant_spans`` as the
    per-span retrieval function; ``kwargs`` are passed through unchanged.
    """
    relevant_spans_for_document = _retrieve_for_all_spans(
        retriever=retriever,
        query_doc_id=query_doc_id,
        retrieve_func=retrieve_relevant_spans,
        **kwargs,
    )
    return relevant_spans_for_document
class RetrieverForAllSpansCallable(Protocol):
    """Call signature of a per-document retrieval function.

    This is the type expected for the ``retrieve_func`` parameter of
    ``_retrieve_for_all_documents``.
    """

    def __call__(
        self,
        retriever: DocumentAwareSpanRetriever,
        query_doc_id: str,
        **kwargs,
    ) -> Optional[pd.DataFrame]:
        """Retrieve results for the document ``query_doc_id`` using ``retriever``."""
        ...
def _retrieve_for_all_documents(
    retriever: DocumentAwareSpanRetriever,
    retrieve_func: RetrieverForAllSpansCallable,
    query_doc_id_column: str = "query_doc_id",
    **kwargs,
) -> Optional[pd.DataFrame]:
    """Run ``retrieve_func`` for every document in the docstore and combine the results.

    Args:
        retriever: Retriever whose docstore keys are used as document ids and which is
            passed on to ``retrieve_func``.
        retrieve_func: Function called once per document id.
        query_doc_id_column: Name of the column added to each non-empty result that
            holds the id of the query document.
        **kwargs: Forwarded unchanged to ``retrieve_func``.

    Returns:
        The concatenation of all non-empty per-document results, or ``None`` if no
        document produced a result.

    Raises:
        gr.Error: If anything fails. The original exception is chained as the cause.
    """
    try:
        all_doc_ids = list(retriever.docstore.yield_keys())
        gr.Info(f"Retrieving results for {len(all_doc_ids)} documents...")
        doc_results = {
            doc_id: retrieve_func(retriever=retriever, query_doc_id=doc_id, **kwargs)
            for doc_id in all_doc_ids
        }
        doc_results_not_empty = {
            doc_id: df for doc_id, df in doc_results.items() if df is not None and not df.empty
        }
        # add column with query_doc_id
        for doc_id, doc_result in doc_results_not_empty.items():
            doc_result[query_doc_id_column] = doc_id
        if not doc_results_not_empty:
            gr.Info("No results found for any document.")
            return None
        # pass only the frames (as _retrieve_for_all_spans does): the dict keys would be
        # discarded by ignore_index=True anyway
        result = pd.concat(doc_results_not_empty.values(), ignore_index=True)
        gr.Info(f"Retrieved {len(result)} ADUs for all documents.")
        return result
    except Exception as e:
        # chain explicitly so the root cause stays attached to the user-facing error
        raise gr.Error(f"Failed to retrieve results for all documents: {e}") from e
def retrieve_all_similar_spans_for_all_documents(
    retriever: DocumentAwareSpanRetriever,
    **kwargs,
) -> Optional[pd.DataFrame]:
    """Retrieve similar spans for every span of every document.

    Delegates to ``_retrieve_for_all_documents`` with ``retrieve_all_similar_spans``
    as the per-document retrieval function; ``kwargs`` are passed through unchanged.
    """
    similar_spans_for_all_documents = _retrieve_for_all_documents(
        retriever=retriever,
        retrieve_func=retrieve_all_similar_spans,
        **kwargs,
    )
    return similar_spans_for_all_documents
def retrieve_all_relevant_spans_for_all_documents(
    retriever: DocumentAwareSpanRetriever,
    **kwargs,
) -> Optional[pd.DataFrame]:
    """Retrieve relevant spans for every span of every document.

    Delegates to ``_retrieve_for_all_documents`` with ``retrieve_all_relevant_spans``
    as the per-document retrieval function; ``kwargs`` are passed through unchanged.
    """
    relevant_spans_for_all_documents = _retrieve_for_all_documents(
        retriever=retriever,
        retrieve_func=retrieve_all_relevant_spans,
        **kwargs,
    )
    return relevant_spans_for_all_documents
def get_text_spans_and_relations_from_document(
    retriever: DocumentAwareSpanRetrieverWithRelations, document_id: str
) -> Tuple[
    str,
    Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
    Dict[str, int],
    Sequence[BinaryRelation],
]:
    """Collect the text, spans, span index mapping and relations of one document.

    Args:
        retriever: Retriever used to look up and unwrap the document and to read its
            base (span) layer and relation layer.
        document_id: Id of the document to look up.

    Returns:
        A tuple ``(text, spans, span_id2idx, relations)``: the text of the unwrapped
        document, the result of ``get_base_layer``, the result of
        ``get_span_id2idx_from_doc`` and the result of ``get_relation_layer``.
    """
    stored_document = retriever.get_document(doc_id=document_id)
    unwrapped_document = retriever.docstore.unwrap(stored_document)
    # both layers are read with the same settings
    layer_kwargs = {
        "pie_document": unwrapped_document,
        "use_predicted_annotations": retriever.use_predicted_annotations(stored_document),
    }
    base_layer = retriever.get_base_layer(**layer_kwargs)
    relation_layer = retriever.get_relation_layer(**layer_kwargs)
    id_to_index = retriever.get_span_id2idx_from_doc(stored_document)
    return unwrapped_document.text, base_layer, id_to_index, relation_layer
def get_span_annotation(
    retriever: DocumentAwareSpanRetriever,
    span_id: str,
) -> Annotation:
    """Return the span annotation stored under ``span_id``.

    Args:
        retriever: Retriever queried via ``get_span_by_id``.
        span_id: Id of the span to look up.

    Returns:
        The annotation returned by ``retriever.get_span_by_id``.

    Raises:
        gr.Error: If no span is given (``None``, empty or whitespace only), or if the
            lookup fails; in the latter case the original exception is chained.
    """
    # also covers None, which would otherwise fail with an AttributeError on .strip()
    if not span_id or not span_id.strip():
        raise gr.Error("No span selected.")
    try:
        return retriever.get_span_by_id(span_id=span_id)
    except Exception as e:
        # chain explicitly so the root cause stays attached to the user-facing error
        raise gr.Error(f"Failed to retrieve span annotation: {e}") from e