import json
import logging
from typing import Iterable, Optional, Sequence, Union

import gradio as gr
from hydra.utils import instantiate
from pie_modules.document.processing import RegexPartitioner, SpansViaRelationMerger

# this is required to dynamically load the PIE models
from pie_modules.models import *  # noqa: F403
from pie_modules.taskmodules import *  # noqa: F403
from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
from pytorch_ie import Pipeline
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import (
    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
    TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
)

# this is required to dynamically load the PIE models
from pytorch_ie.models import *  # noqa: F403
from pytorch_ie.taskmodules import *  # noqa: F403

from src.utils import parse_config

logger = logging.getLogger(__name__)

def get_merger() -> SpansViaRelationMerger:
    """Create a merger that combines spans linked via "parts_of_same" relations into
    single multi spans, converting the document type accordingly."""
    return SpansViaRelationMerger(
        relation_layer="binary_relations",
        link_relation_label="parts_of_same",
        create_multi_spans=True,
        result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
        result_field_mapping={
            "labeled_spans": "labeled_multi_spans",
            "binary_relations": "binary_relations",
            "labeled_partitions": "labeled_partitions",
        },
    )
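
# Illustrative sketch of what the merger does (the offsets below are made up, not
# taken from any real model output): if the model predicts two spans, e.g. (0, 10)
# and (25, 40), and connects them with a "parts_of_same" relation, the merger
# replaces them with a single LabeledMultiSpan with slices ((0, 10), (25, 40)) in
# the "labeled_multi_spans" layer and re-attaches all other binary relations to the
# merged spans.
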
def annotate_document(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    argumentation_model: Pipeline,
    handle_parts_of_same: bool = False,
) -> Union[
    TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
]:
    """Annotate a document with the provided pipeline.

    Args:
        document: The document to annotate.
        argumentation_model: The pipeline to use for annotation.
        handle_parts_of_same: Whether to merge spans that are part of the same entity into a
            single multi span.
    """
    # execute prediction pipeline
    result: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions = argumentation_model(
        document, inplace=True
    )
    if handle_parts_of_same:
        merger = get_merger()
        result = merger(result)
    return result
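
# A minimal usage sketch (assuming `pipeline` was loaded via load_argumentation_model;
# the text and document id are illustrative only):
#
#     doc = create_document(text="Some argumentative text.", doc_id="doc-1")
#     annotated = annotate_document(doc, pipeline, handle_parts_of_same=True)
#     print(annotated.binary_relations.predictions)
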
def annotate_documents(
    documents: Sequence[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions],
    argumentation_model: Pipeline,
    handle_parts_of_same: bool = False,
) -> Union[
    Sequence[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions],
    Sequence[TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions],
]:
    """Annotate a sequence of documents with the provided pipeline.

    Args:
        documents: The documents to annotate.
        argumentation_model: The pipeline to use for annotation.
        handle_parts_of_same: Whether to merge spans that are part of the same entity into a
            single multi span.
    """
    # execute prediction pipeline
    result = argumentation_model(documents, inplace=True)
    if handle_parts_of_same:
        merger = get_merger()
        result = [merger(document) for document in result]
    return result

def create_document(
    text: str, doc_id: str, split_regex: Optional[str] = None
) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
    """Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the
    provided text.

    Args:
        text: The text to process.
        doc_id: The ID of the document.
        split_regex: A regular expression pattern to use for splitting the text into partitions.

    Returns:
        The processed document.
    """
    document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
        id=doc_id, text=text, metadata={}
    )
    if split_regex is not None:
        partitioner = RegexPartitioner(
            pattern=split_regex, partition_layer_name="labeled_partitions"
        )
        document = partitioner(document)
    else:
        # add a single partition spanning the whole text (the model only considers text in partitions)
        document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
    return document
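
# A minimal usage sketch (illustrative values only): splitting on blank lines so
# that each paragraph becomes its own partition. The regex is an assumption for
# illustration, not a fixed convention of this module:
#
#     doc = create_document(
#         text="First paragraph.\n\nSecond paragraph.",
#         doc_id="doc-1",
#         split_regex=r"\n\n",
#     )
#     print(doc.labeled_partitions)
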
def create_documents(
    texts: Iterable[str], doc_ids: Iterable[str], split_regex: Optional[str] = None
) -> Sequence[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]:
    """Create a sequence of TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
    from the provided texts.

    Args:
        texts: The texts to process.
        doc_ids: The IDs of the documents.
        split_regex: A regular expression pattern to use for splitting the texts into partitions.

    Returns:
        The processed documents.
    """
    return [
        create_document(text=text, doc_id=doc_id, split_regex=split_regex)
        for text, doc_id in zip(texts, doc_ids)
    ]

def load_argumentation_model(config_str: str, **kwargs) -> Pipeline:
    """Instantiate an argumentation pipeline from a YAML config string via Hydra.

    Args:
        config_str: The pipeline configuration as a YAML string.
        **kwargs: Additional keyword arguments passed to the instantiated target.

    Returns:
        The loaded pipeline.
    """
    try:
        config = parse_config(config_str, format="yaml")
        # for the PIE AutoPipeline, we need to handle the revision separately for
        # the taskmodule and the model
        if (
            config.get("_target_") == "pytorch_ie.auto.AutoPipeline.from_pretrained"
            and "revision" in config
        ):
            revision = config.pop("revision")
            if "taskmodule_kwargs" not in config:
                config["taskmodule_kwargs"] = {}
            config["taskmodule_kwargs"]["revision"] = revision
            if "model_kwargs" not in config:
                config["model_kwargs"] = {}
            config["model_kwargs"]["revision"] = revision
        model = instantiate(config, **kwargs)
        gr.Info(f"Loaded argumentation model: {json.dumps({**config, **kwargs})}")
    except Exception as e:
        raise gr.Error(f"Failed to load argumentation model: {e}") from e
    return model
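
# A minimal sketch of a config_str this function accepts (the repository id and
# revision are hypothetical placeholders, not real artifacts):
#
#     _target_: pytorch_ie.auto.AutoPipeline.from_pretrained
#     pretrained_model_name_or_path: some-org/some-argumentation-model
#     revision: main
#
# The top-level `revision` entry is moved into both `taskmodule_kwargs` and
# `model_kwargs` before instantiation, so the taskmodule and the model are pinned
# to the same model-hub revision.
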
def set_relation_types(
    argumentation_model: Pipeline,
    default: Optional[Sequence[str]] = None,
) -> gr.Dropdown:
    """Create a dropdown with the relation types supported by the given pipeline.

    Args:
        argumentation_model: The pipeline whose taskmodule defines the relation types.
        default: The relation types to pre-select.

    Returns:
        A multi-select dropdown with the available relation types.
    """
    if isinstance(argumentation_model.taskmodule, PointerNetworkTaskModuleForEnd2EndRE):
        relation_types = argumentation_model.taskmodule.labels_per_layer["binary_relations"]
    else:
        raise gr.Error("Unsupported taskmodule for relation types")
    return gr.Dropdown(
        choices=relation_types,
        label="Argumentative Relation Types",
        value=default,
        multiselect=True,
    )
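
# A minimal sketch of wiring the dropdown into a Gradio app (the config string,
# the `device` kwarg, and the "supports" label are assumptions for illustration;
# `device` is only valid if the instantiated target accepts it):
#
#     pipeline = load_argumentation_model(config_str, device=-1)
#     with gr.Blocks() as demo:
#         relation_types_dropdown = set_relation_types(pipeline, default=["supports"])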