|
import json |
|
import logging |
|
from typing import Iterable, Optional, Sequence, Union |
|
|
|
import gradio as gr |
|
from hydra.utils import instantiate |
|
from pie_modules.document.processing import RegexPartitioner, SpansViaRelationMerger |
|
|
|
|
|
from pie_modules.models import * |
|
from pie_modules.taskmodules import * |
|
from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE |
|
from pytorch_ie import Pipeline |
|
from pytorch_ie.annotations import LabeledSpan |
|
from pytorch_ie.documents import ( |
|
TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions, |
|
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, |
|
) |
|
|
|
|
|
from pytorch_ie.models import * |
|
from pytorch_ie.taskmodules import * |
|
|
|
from src.utils import parse_config |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def get_merger() -> SpansViaRelationMerger:
    """Build the merger that fuses spans linked by 'parts_of_same' relations
    into multi spans, producing a multi-span document type."""
    # map each source layer to its counterpart in the result document type
    field_mapping = {
        "labeled_spans": "labeled_multi_spans",
        "binary_relations": "binary_relations",
        "labeled_partitions": "labeled_partitions",
    }
    return SpansViaRelationMerger(
        relation_layer="binary_relations",
        link_relation_label="parts_of_same",
        create_multi_spans=True,
        result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
        result_field_mapping=field_mapping,
    )
|
|
|
|
|
def annotate_document(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    argumentation_model: Pipeline,
    handle_parts_of_same: bool = False,
) -> Union[
    TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
]:
    """Annotate a single document with the provided pipeline.

    Args:
        document: The document to annotate.
        argumentation_model: The pipeline to use for annotation.
        handle_parts_of_same: Whether to merge spans that are part of the same
            entity into a single multi span.

    Returns:
        The annotated document; a multi-span document type when
        ``handle_parts_of_same`` is enabled.
    """
    # the pipeline annotates the document in place and returns it
    annotated: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions = (
        argumentation_model(document, inplace=True)
    )

    if not handle_parts_of_same:
        return annotated

    return get_merger()(annotated)
|
|
|
|
|
def annotate_documents(
    documents: Sequence[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions],
    argumentation_model: Pipeline,
    handle_parts_of_same: bool = False,
) -> Union[
    Sequence[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions],
    Sequence[TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions],
]:
    """Annotate a sequence of documents with the provided pipeline.

    Args:
        documents: The documents to annotate.
        argumentation_model: The pipeline to use for annotation.
        handle_parts_of_same: Whether to merge spans that are part of the same
            entity into a single multi span.

    Returns:
        The annotated documents; multi-span document types when
        ``handle_parts_of_same`` is enabled.
    """
    # the pipeline annotates the documents in place and returns them
    annotated = argumentation_model(documents, inplace=True)

    if not handle_parts_of_same:
        return annotated

    merger = get_merger()
    merged = []
    for doc in annotated:
        merged.append(merger(doc))
    return merged
|
|
|
|
|
def create_document(
    text: str, doc_id: str, split_regex: Optional[str] = None
) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
    """Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
    from the provided text.

    Parameters:
        text: The text to process.
        doc_id: The ID of the document.
        split_regex: A regular expression pattern to use for splitting the text
            into partitions.

    Returns:
        The processed document.
    """
    new_doc = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
        id=doc_id, text=text, metadata={}
    )

    if split_regex is None:
        # no splitting requested: add a single partition spanning the whole text
        new_doc.labeled_partitions.append(
            LabeledSpan(start=0, end=len(text), label="text")
        )
        return new_doc

    partitioner = RegexPartitioner(
        pattern=split_regex, partition_layer_name="labeled_partitions"
    )
    return partitioner(new_doc)
|
|
|
|
|
def create_documents(
    texts: Iterable[str], doc_ids: Iterable[str], split_regex: Optional[str] = None
) -> Sequence[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]:
    """Create a sequence of TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
    from the provided texts.

    Parameters:
        texts: The texts to process.
        doc_ids: The IDs of the documents.
        split_regex: A regular expression pattern to use for splitting the text
            into partitions.

    Returns:
        The processed documents.
    """
    documents = []
    # pair each text with its id; extra entries in the longer iterable are dropped
    for current_text, current_id in zip(texts, doc_ids):
        documents.append(
            create_document(text=current_text, doc_id=current_id, split_regex=split_regex)
        )
    return documents
|
|
|
|
|
def load_argumentation_model(config_str: str, **kwargs) -> Pipeline:
    """Instantiate an argumentation ``Pipeline`` from a YAML config string.

    Args:
        config_str: YAML string describing the pipeline in Hydra ``instantiate``
            format (i.e. containing a ``_target_`` entry).
        **kwargs: Additional keyword arguments forwarded to
            ``hydra.utils.instantiate``.

    Returns:
        The instantiated pipeline.

    Raises:
        gr.Error: If parsing the config or instantiating the pipeline fails.
    """
    try:
        config = parse_config(config_str, format="yaml")

        # The AutoPipeline does not accept a top-level "revision" argument, so
        # propagate it into the taskmodule and model kwargs instead.
        if (
            config.get("_target_") == "pytorch_ie.auto.AutoPipeline.from_pretrained"
            and "revision" in config
        ):
            revision = config.pop("revision")
            if "taskmodule_kwargs" not in config:
                config["taskmodule_kwargs"] = {}
            config["taskmodule_kwargs"]["revision"] = revision
            if "model_kwargs" not in config:
                config["model_kwargs"] = {}
            config["model_kwargs"]["revision"] = revision
        model = instantiate(config, **kwargs)
        gr.Info(f"Loaded argumentation model: {json.dumps({**config, **kwargs})}")
    except Exception as e:
        # log the full traceback server-side; the UI only gets the short message
        logger.exception("Failed to load argumentation model")
        # chain the original exception so the cause is preserved for debugging
        raise gr.Error(f"Failed to load argumentation model: {e}") from e

    return model
|
|
|
|
|
def set_relation_types(
    argumentation_model: Pipeline,
    default: Optional[Sequence[str]] = None,
) -> gr.Dropdown:
    """Build a multi-select dropdown listing the relation types known to the
    model's taskmodule, pre-selecting ``default`` if given."""
    taskmodule = argumentation_model.taskmodule
    # only the pointer-network taskmodule exposes labels_per_layer as used here
    if not isinstance(taskmodule, PointerNetworkTaskModuleForEnd2EndRE):
        raise gr.Error("Unsupported taskmodule for relation types")

    relation_types = taskmodule.labels_per_layer["binary_relations"]
    return gr.Dropdown(
        choices=relation_types,
        label="Argumentative Relation Types",
        value=default,
        multiselect=True,
    )
|
|