import json import logging from typing import Iterable, Optional, Sequence, Union import gradio as gr import pandas as pd from pie_datasets import Dataset, IterableDataset, load_dataset from pie_modules.document.processing import RegexPartitioner, SpansViaRelationMerger from pytorch_ie import Pipeline from pytorch_ie.annotations import LabeledSpan from pytorch_ie.auto import AutoPipeline from pytorch_ie.documents import ( TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, ) from typing_extensions import Protocol from src.langchain_modules import DocumentAwareSpanRetriever from src.langchain_modules.span_retriever import ( DocumentAwareSpanRetrieverWithRelations, _parse_config, ) logger = logging.getLogger(__name__) def annotate_document( document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, argumentation_model: Pipeline, handle_parts_of_same: bool = False, ) -> Union[ TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions, ]: """Annotate a document with the provided pipeline. Args: document: The document to annotate. argumentation_model: The pipeline to use for annotation. handle_parts_of_same: Whether to merge spans that are part of the same entity into a single multi span. """ # execute prediction pipeline argumentation_model(document) if handle_parts_of_same: merger = SpansViaRelationMerger( relation_layer="binary_relations", link_relation_label="parts_of_same", create_multi_spans=True, result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions, result_field_mapping={ "labeled_spans": "labeled_multi_spans", "binary_relations": "binary_relations", "labeled_partitions": "labeled_partitions", }, ) document = merger(document) return document def create_document( text: str, doc_id: str, split_regex: Optional[str] = None ) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions: """Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the provided text. Parameters: text: The text to process. doc_id: The ID of the document. split_regex: A regular expression pattern to use for splitting the text into partitions. Returns: The processed document. """ document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions( id=doc_id, text=text, metadata={} ) if split_regex is not None: partitioner = RegexPartitioner( pattern=split_regex, partition_layer_name="labeled_partitions" ) document = partitioner(document) else: # add single partition from the whole text (the model only considers text in partitions) document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text")) return document def add_annotated_pie_documents( retriever: DocumentAwareSpanRetriever, pie_documents: Sequence[TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions], use_predicted_annotations: bool, verbose: bool = False, ) -> None: if verbose: gr.Info(f"Create span embeddings for {len(pie_documents)} documents...") num_docs_before = len(retriever.docstore) retriever.add_pie_documents(pie_documents, use_predicted_annotations=use_predicted_annotations) # number of documents that were overwritten num_overwritten_docs = num_docs_before + len(pie_documents) - len(retriever.docstore) # warn if documents were overwritten if num_overwritten_docs > 0: gr.Warning(f"{num_overwritten_docs} documents were overwritten.") def process_texts( texts: Iterable[str], doc_ids: Iterable[str], argumentation_model: Pipeline, retriever: DocumentAwareSpanRetriever, split_regex_escaped: Optional[str], handle_parts_of_same: bool = False, verbose: bool = False, ) -> None: # check that doc_ids are unique if len(set(doc_ids)) != len(list(doc_ids)): raise gr.Error("Document IDs must be unique.") pie_documents = [ create_document(text=text, doc_id=doc_id, split_regex=split_regex_escaped) for text, doc_id in zip(texts, doc_ids) ] if verbose: gr.Info(f"Annotate {len(pie_documents)} documents...") pie_documents = [ annotate_document( document=pie_document, argumentation_model=argumentation_model, handle_parts_of_same=handle_parts_of_same, ) for pie_document in pie_documents ] add_annotated_pie_documents( retriever=retriever, pie_documents=pie_documents, use_predicted_annotations=True, verbose=verbose, ) def add_annotated_pie_documents_from_dataset( retriever: DocumentAwareSpanRetriever, verbose: bool = False, **load_dataset_kwargs ) -> None: try: gr.Info( "Loading PIE dataset with parameters:\n" + json.dumps(load_dataset_kwargs, indent=2) ) dataset = load_dataset(**load_dataset_kwargs) if not isinstance(dataset, (Dataset, IterableDataset)): raise gr.Error("Loaded dataset is not of type PIE (Iterable)Dataset.") dataset_converted = dataset.to_document_type( TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions ) add_annotated_pie_documents( retriever=retriever, pie_documents=dataset_converted, use_predicted_annotations=False, verbose=verbose, ) except Exception as e: raise gr.Error(f"Failed to load dataset: {e}") def load_argumentation_model( model_name: str, revision: Optional[str] = None, device: str = "cpu", ) -> Pipeline: try: # the Pipeline class expects an integer for the device if device == "cuda": pipeline_device = 0 elif device.startswith("cuda:"): pipeline_device = int(device.split(":")[1]) elif device == "cpu": pipeline_device = -1 else: raise gr.Error(f"Invalid device: {device}") model = AutoPipeline.from_pretrained( model_name, device=pipeline_device, num_workers=0, taskmodule_kwargs=dict(revision=revision), model_kwargs=dict(revision=revision), ) gr.Info( f"Loaded argumentation model: model_name={model_name}, revision={revision}, device={device}" ) except Exception as e: raise gr.Error(f"Failed to load argumentation model: {e}") return model def load_retriever( retriever_config: str, config_format: str, device: str = "cpu", previous_retriever: Optional[DocumentAwareSpanRetrieverWithRelations] = None, ) -> DocumentAwareSpanRetrieverWithRelations: try: retriever_config = _parse_config(retriever_config, format=config_format) # set device for the embeddings pipeline retriever_config["vectorstore"]["embedding"]["pipeline_kwargs"]["device"] = device result = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config(retriever_config) # if a previous retriever is provided, load all documents and vectors from the previous retriever if previous_retriever is not None: # documents all_doc_ids = list(previous_retriever.docstore.yield_keys()) gr.Info(f"Storing {len(all_doc_ids)} documents from previous retriever...") all_docs = previous_retriever.docstore.mget(all_doc_ids) result.docstore.mset([(doc.id, doc) for doc in all_docs]) # spans (with vectors) all_span_ids = list(previous_retriever.vectorstore.yield_keys()) all_spans = previous_retriever.vectorstore.mget(all_span_ids) result.vectorstore.mset([(span.id, span) for span in all_spans]) gr.Info("Retriever loaded successfully.") return result except Exception as e: raise gr.Error(f"Failed to load retriever: {e}") def retrieve_similar_spans( retriever: DocumentAwareSpanRetriever, query_span_id: str, **kwargs, ) -> pd.DataFrame: if not query_span_id.strip(): raise gr.Error("No query span selected.") try: retrieval_result = retriever.invoke(input=query_span_id, **kwargs) records = [] for similar_span_doc in retrieval_result: pie_doc, metadata = retriever.docstore.unwrap_with_metadata(similar_span_doc) span_ann = metadata["attached_span"] records.append( { "doc_id": pie_doc.id, "span_id": similar_span_doc.id, "score": metadata["relevance_score"], "label": span_ann.label, "text": str(span_ann), } ) return ( pd.DataFrame(records, columns=["doc_id", "score", "label", "text", "span_id"]) .sort_values(by="score", ascending=False) .round(3) ) except Exception as e: raise gr.Error(f"Failed to retrieve similar ADUs: {e}") def retrieve_relevant_spans( retriever: DocumentAwareSpanRetriever, query_span_id: str, relation_label_mapping: Optional[dict[str, str]] = None, **kwargs, ) -> pd.DataFrame: if not query_span_id.strip(): raise gr.Error("No query span selected.") try: relation_label_mapping = relation_label_mapping or {} retrieval_result = retriever.invoke(input=query_span_id, return_related=True, **kwargs) records = [] for relevant_span_doc in retrieval_result: pie_doc, metadata = retriever.docstore.unwrap_with_metadata(relevant_span_doc) span_ann = metadata["attached_span"] tail_span_ann = metadata["attached_tail_span"] mapped_relation_label = relation_label_mapping.get( metadata["relation_label"], metadata["relation_label"] ) records.append( { "doc_id": pie_doc.id, "type": mapped_relation_label, "rel_score": metadata["relation_score"], "text": str(tail_span_ann), "span_id": relevant_span_doc.id, "label": tail_span_ann.label, "ref_score": metadata["relevance_score"], "ref_label": span_ann.label, "ref_text": str(span_ann), "ref_span_id": metadata["head_id"], } ) return ( pd.DataFrame( records, columns=[ "type", # omitted for now, we get no valid relation scores for the generative model # "rel_score", "ref_score", "label", "text", "ref_label", "ref_text", "doc_id", "span_id", "ref_span_id", ], ) .sort_values(by=["ref_score"], ascending=False) .round(3) ) except Exception as e: raise gr.Error(f"Failed to retrieve relevant ADUs: {e}") class RetrieverCallable(Protocol): def __call__( self, retriever: DocumentAwareSpanRetriever, query_span_id: str, **kwargs, ) -> Optional[pd.DataFrame]: pass def _retrieve_for_all_spans( retriever: DocumentAwareSpanRetriever, query_doc_id: str, retrieve_func: RetrieverCallable, query_span_id_column: str = "query_span_id", **kwargs, ) -> Optional[pd.DataFrame]: if not query_doc_id.strip(): raise gr.Error("No query document selected.") try: span_id2idx = retriever.get_span_id2idx_from_doc(query_doc_id) gr.Info(f"Retrieving results for {len(span_id2idx)} ADUs in document {query_doc_id}...") span_results = { query_span_id: retrieve_func( retriever=retriever, query_span_id=query_span_id, **kwargs, ) for query_span_id in span_id2idx.keys() } span_results_not_empty = { query_span_id: df for query_span_id, df in span_results.items() if df is not None and not df.empty } # add column with query_span_id for query_span_id, query_span_result in span_results_not_empty.items(): query_span_result[query_span_id_column] = query_span_id if len(span_results_not_empty) == 0: gr.Info(f"No results found for any ADU in document {query_doc_id}.") return None else: result = pd.concat(span_results_not_empty.values(), ignore_index=True) gr.Info(f"Retrieved {len(result)} ADUs for document {query_doc_id}.") return result except Exception as e: raise gr.Error( f'Failed to retrieve results for all ADUs in document "{query_doc_id}": {e}' ) def retrieve_all_similar_spans( retriever: DocumentAwareSpanRetriever, query_doc_id: str, **kwargs, ) -> Optional[pd.DataFrame]: return _retrieve_for_all_spans( retriever=retriever, query_doc_id=query_doc_id, retrieve_func=retrieve_similar_spans, **kwargs, ) def retrieve_all_relevant_spans( retriever: DocumentAwareSpanRetriever, query_doc_id: str, **kwargs, ) -> Optional[pd.DataFrame]: return _retrieve_for_all_spans( retriever=retriever, query_doc_id=query_doc_id, retrieve_func=retrieve_relevant_spans, **kwargs, ) class RetrieverForAllSpansCallable(Protocol): def __call__( self, retriever: DocumentAwareSpanRetriever, query_doc_id: str, **kwargs, ) -> Optional[pd.DataFrame]: pass def _retrieve_for_all_documents( retriever: DocumentAwareSpanRetriever, retrieve_func: RetrieverForAllSpansCallable, query_doc_id_column: str = "query_doc_id", **kwargs, ) -> Optional[pd.DataFrame]: try: all_doc_ids = list(retriever.docstore.yield_keys()) gr.Info(f"Retrieving results for {len(all_doc_ids)} documents...") doc_results = { doc_id: retrieve_func(retriever=retriever, query_doc_id=doc_id, **kwargs) for doc_id in all_doc_ids } doc_results_not_empty = { doc_id: df for doc_id, df in doc_results.items() if df is not None and not df.empty } # add column with query_doc_id for doc_id, doc_result in doc_results_not_empty.items(): doc_result[query_doc_id_column] = doc_id if len(doc_results_not_empty) == 0: gr.Info("No results found for any document.") return None else: result = pd.concat(doc_results_not_empty, ignore_index=True) gr.Info(f"Retrieved {len(result)} ADUs for all documents.") return result except Exception as e: raise gr.Error(f"Failed to retrieve results for all documents: {e}") def retrieve_all_similar_spans_for_all_documents( retriever: DocumentAwareSpanRetriever, **kwargs, ) -> Optional[pd.DataFrame]: return _retrieve_for_all_documents( retriever=retriever, retrieve_func=retrieve_all_similar_spans, **kwargs, ) def retrieve_all_relevant_spans_for_all_documents( retriever: DocumentAwareSpanRetriever, **kwargs, ) -> Optional[pd.DataFrame]: return _retrieve_for_all_documents( retriever=retriever, retrieve_func=retrieve_all_relevant_spans, **kwargs, )