# ScientificArgumentRecommender / src/data/calc_iaa_for_brat.py
# Uploaded by ArneBinder from https://github.com/ArneBinder/pie-document-level/pull/452
# (commit e7eaeed, verified)
from collections.abc import Iterable
import pyrootutils
from pytorch_ie import Document
# Resolve the project root by searching upwards from this file for the ".project-root"
# marker, add it to the Python path (pythonpath=True) and load its .env file (dotenv=True).
# NOTE: this intentionally runs before the remaining imports below — presumably so that the
# project-local `src.*` package (imported further down) is resolvable; keep the order.
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)
import argparse
from functools import partial
from typing import Callable, List, Optional, Union
import pandas as pd
from pie_datasets import Dataset, load_dataset
from pie_datasets.builders.brat import BratDocument, BratDocumentWithMergedSpans
from pie_modules.document.processing import RelationArgumentSorter, SpansViaRelationMerger
from pytorch_ie.metrics import F1Metric
from src.document.processing import align_predicted_span_annotations
def add_annotations_as_predictions(document: BratDocument, other: BratDocument) -> BratDocument:
    """Return a copy of ``document`` whose prediction layers hold the annotations of ``other``.

    Works on copies of both inputs. All spans of ``other`` are moved into
    ``spans.predictions`` of the result. Every relation of ``other`` is re-created in
    ``relations.predictions``; each of its arguments (head / tail) is replaced by the
    equal gold span of ``document`` if one exists, otherwise by the equal predicted span.

    Args:
        document: the document whose annotations act as gold annotations.
        other: the document whose annotations act as predictions.

    Returns:
        The new document with gold annotations from ``document`` and predictions from ``other``.
    """
    gold_doc = document.copy()
    predicted_doc = other.copy()

    # clear() empties the layer of the copy and hands back its spans
    gold_doc.spans.predictions.extend(predicted_doc.spans.clear())

    # equal-valued annotations compare equal, so a dict lookup finds the gold twin of a span
    gold_span_lookup = {span: span for span in gold_doc.spans}
    # predicted span -> equal gold span if there is one, otherwise the predicted span itself
    argument_target = {
        span: gold_span_lookup.get(span, span) for span in gold_doc.spans.predictions
    }

    gold_doc.relations.predictions.extend(
        [
            relation.copy(
                head=argument_target[relation.head],
                tail=argument_target[relation.tail],
            )
            for relation in predicted_doc.relations.clear()
        ]
    )
    return gold_doc
def remove_annotations_existing_in_other(
    document: BratDocumentWithMergedSpans, other: BratDocumentWithMergedSpans
) -> BratDocumentWithMergedSpans:
    """Return a copy of ``document`` without the spans and relations that also occur in ``other``.

    Annotations are compared by value (equality of the annotation objects). Duplicate
    annotations in ``document`` are collapsed to one. The surviving annotations keep the
    order in which they appear in ``document``, so the result is deterministic across
    interpreter runs. A set difference would iterate in hash order instead, and ``str``
    hashes are randomized per process.

    Args:
        document: the document to filter.
        other: the document whose annotations should be removed from ``document``.

    Returns:
        A new document with the same base content as ``document`` and the filtered
        ``spans`` and ``relations`` layers.
    """
    result = document.copy(with_annotations=False)
    # work on copies: clear() empties a layer and returns its annotations
    document = document.copy()
    other = other.copy()

    document_spans = document.spans.clear()
    other_spans = set(other.spans.clear())
    document_relations = document.relations.clear()
    other_relations = set(other.relations.clear())

    # dict.fromkeys de-duplicates like the former set difference but preserves order
    spans = [span for span in dict.fromkeys(document_spans) if span not in other_spans]
    # NOTE(review): a kept relation may still point at a span that was removed because it
    # exists in `other`; confirm that downstream processing tolerates such relations.
    relations = [
        relation
        for relation in dict.fromkeys(document_relations)
        if relation not in other_relations
    ]

    result.spans.extend(spans)
    result.relations.extend(relations)
    return result
def unnest_dict(d):
    """Flatten a nested dict into a single-level dict keyed by key paths.

    Each non-dict value reached via ``d[k1][k2]...[kn]`` is stored under the tuple
    ``(k1, k2, ..., kn)``. Nested dicts that are empty contribute no entry. Entries
    appear in depth-first order, following the insertion order of every level.

    Example: ``{"a": {"b": 1}, "c": 2}`` -> ``{("a", "b"): 1, ("c",): 2}``
    """

    def _leaves(mapping, path):
        # yield (key path, leaf value) pairs, descending into nested dicts first
        for name, item in mapping.items():
            item_path = path + (name,)
            if isinstance(item, dict):
                yield from _leaves(item, item_path)
            else:
                yield item_path, item

    return dict(_leaves(d, ()))
def _load_brat_dataset(data_dir: str):
    """Load the ``train`` split of the BRAT annotations in ``data_dir`` via the ``pie/brat``
    dataset builder, using its ``merge_fragmented_spans`` configuration."""
    return load_dataset(
        "pie/brat",
        name="merge_fragmented_spans",
        base_dataset_kwargs=dict(data_dir=data_dir),
        split="train",
    )


def calc_brat_iaas(
    annotator_dirs: List[str],
    ignore_annotation_dir: Optional[str] = None,
    combine_fragmented_spans_via_relation: Optional[str] = None,
    sort_arguments_of_relations: Optional[List[str]] = None,
    align_spans: bool = False,
    show_results: bool = False,
    per_file: bool = False,
) -> Union[pd.Series, List[pd.Series]]:
    """Calculate pairwise inter-annotator agreement (F1) for BRAT annotation directories.

    Args:
        annotator_dirs: one BRAT annotation directory per annotator. At least two are required.
        ignore_annotation_dir: if set, spans and relations that also exist in this directory
            are removed from every annotator's documents before comparing.
        combine_fragmented_spans_via_relation: if set, spans linked via relations with this
            label are merged into multi-spans before scoring.
        sort_arguments_of_relations: labels of relations whose arguments get sorted, so that
            their argument order does not affect matching.
        align_spans: if True, predicted spans are aligned to the gold spans before scoring.
        show_results: if True, print the mean scores.
        per_file: if True, compute a separate result per document.

    Returns:
        The result of ``calc_brat_iaas_for_docs``. In ``per_file`` mode, a list with one such
        result per document, in the document order of the first annotator directory.

    Raises:
        ValueError: if fewer than two annotation directories are given.
        KeyError: in ``per_file`` mode, if a document of the first annotator is missing for
            another annotator.
    """
    if len(annotator_dirs) < 2:
        raise ValueError("At least two annotation dirs must be provided")

    span_aligner = None
    if align_spans:
        span_aligner = partial(align_predicted_span_annotations, span_layer="spans")

    merger = None
    if combine_fragmented_spans_via_relation is not None:
        print(f"Combine fragmented spans via {combine_fragmented_spans_via_relation} relations")
        merger = SpansViaRelationMerger(
            relation_layer="relations",
            link_relation_label=combine_fragmented_spans_via_relation,
            create_multi_spans=True,
            result_document_type=BratDocument,
            result_field_mapping={"spans": "spans", "relations": "relations"},
            combine_scores_method="product",
        )

    relation_argument_sorter = None
    if sort_arguments_of_relations:
        print(f"Sort arguments of relations with labels {sort_arguments_of_relations}")
        relation_argument_sorter = RelationArgumentSorter(
            relation_layer="relations",
            # e.g. ["parts_of_same", "semantically_same", "contradicts"]
            label_whitelist=sort_arguments_of_relations,
        )

    all_docs = [
        _load_brat_dataset(annotation_dir).map(lambda doc: doc.deduplicate_annotations())
        for annotation_dir in annotator_dirs
    ]

    if ignore_annotation_dir is not None:
        print(f"Ignoring annotations loaded from {ignore_annotation_dir}")
        ignore_annotation_docs_dict = {
            doc.id: doc for doc in _load_brat_dataset(ignore_annotation_dir)
        }
        all_docs = [
            docs.map(
                lambda doc: remove_annotations_existing_in_other(
                    doc, other=ignore_annotation_docs_dict[doc.id]
                )
            )
            for docs in all_docs
        ]

    if relation_argument_sorter is not None:
        all_docs = [docs.map(relation_argument_sorter) for docs in all_docs]

    if not per_file:
        return calc_brat_iaas_for_docs(
            all_docs, span_aligner=span_aligner, merger=merger, show_results=show_results
        )

    # Pair the documents of the annotators by id, not by position. The annotator
    # directories are not guaranteed to list their documents in the same order, and zip()
    # would silently drop documents if the counts differ. A missing id raises KeyError,
    # like the id-based lookup in calc_brat_iaas_for_docs does.
    docs_by_id_per_annotator = [{doc.id: doc for doc in docs} for docs in all_docs]
    results_per_doc = []
    for doc_id in docs_by_id_per_annotator[0]:
        if show_results:
            print(f"\ncalculate scores for document id={doc_id} ...")
        docs = [
            Dataset.from_documents([docs_by_id[doc_id]])
            for docs_by_id in docs_by_id_per_annotator
        ]
        result_per_doc = calc_brat_iaas_for_docs(
            docs, span_aligner=span_aligner, merger=merger, show_results=show_results
        )
        results_per_doc.append(result_per_doc)
    return results_per_doc
def _add_predictions_from(predicted_docs_by_id):
    """Return a map function that adds the annotations of the same-id document in
    ``predicted_docs_by_id`` as predictions to a gold document."""

    def _add(doc):
        return add_annotations_as_predictions(doc, other=predicted_docs_by_id[doc.id])

    return _add


def calc_brat_iaas_for_docs(
    all_docs: List[Dataset],
    span_aligner: Optional[Callable] = None,
    merger: Optional[Callable] = None,
    show_results: bool = False,
) -> pd.Series:
    """Calculate F1 agreement for every ordered pair of annotators.

    For each pair (gold, predicted) of distinct annotator indices, the annotations of the
    "predicted" annotator are added as predictions to the documents of the "gold"
    annotator (matched by document id). They are then scored with ``F1Metric`` on the
    ``spans`` and ``relations`` layers.

    Args:
        all_docs: one dataset of documents per annotator.
        span_aligner: optional callable applied to each document before scoring.
        merger: optional callable applied to each document after ``span_aligner``.
        show_results: if True, print the mean over all annotator pairs as markdown tables.

    Returns:
        A series whose index tuples start with ``"gold:<i>,predicted:<j>"`` and the layer name
        (``"spans"`` / ``"relations"``). The remaining index levels come from the nested dict
        returned by ``F1Metric.compute()``.

    Raises:
        KeyError: if a gold document id is missing in another annotator's documents.
    """
    num_annotators = len(all_docs)
    all_docs_dict = [{doc.id: doc for doc in docs} for docs in all_docs]

    gold_predicted = {}
    for gold_annotator_idx in range(num_annotators):
        gold = all_docs[gold_annotator_idx]
        for predicted_annotator_idx in range(num_annotators):
            if gold_annotator_idx == predicted_annotator_idx:
                continue
            # Bind the predicted annotator's documents explicitly. A lambda closing over the
            # loop variable would see only its last value if map() is evaluated lazily.
            gold_predicted[(gold_annotator_idx, predicted_annotator_idx)] = gold.map(
                _add_predictions_from(all_docs_dict[predicted_annotator_idx])
            )

    spans_metric = F1Metric(layer="spans", labels="INFERRED", show_as_markdown=True)
    relations_metric = F1Metric(layer="relations", labels="INFERRED", show_as_markdown=True)
    metric_values = {}
    for (gold_annotator_idx, predicted_annotator_idx), docs in gold_predicted.items():
        print(
            f"calculate scores for annotations {gold_annotator_idx} -> {predicted_annotator_idx}"
        )
        for doc in docs:
            if span_aligner is not None:
                doc = span_aligner(doc)
            if merger is not None:
                doc = merger(doc)
            spans_metric(doc)
            relations_metric(doc)
        metric_id = f"gold:{gold_annotator_idx},predicted:{predicted_annotator_idx}"
        # reset=True so the next pair starts from empty metric state
        metric_values[metric_id] = {
            "spans": spans_metric.compute(reset=True),
            "relations": relations_metric.compute(reset=True),
        }

    result = pd.Series(unnest_dict(metric_values))
    if show_results:
        # move the annotator-pair level into columns and average over all pairs
        metric_values_series_mean = result.unstack(0).mean(axis=1)
        metric_values_relations = metric_values_series_mean.xs("relations").unstack()
        metric_values_spans = metric_values_series_mean.xs("spans").unstack()
        print("\nspans:")
        print(metric_values_spans.round(decimals=3).to_markdown())
        print("\nrelations:")
        print(metric_values_relations.round(decimals=3).to_markdown())
    return result
if __name__ == "__main__":
    # Command-line entry point. Example call:
    #   python calc_iaa_for_brat.py \
    #       --annotation-dirs annotations/sciarg/v0.9/with_abstracts_rin annotations/sciarg/v0.9/with_abstracts_alisa \
    #       --ignore-annotation-dir annotations/sciarg/v0.9/original
    arg_parser = argparse.ArgumentParser(
        description="Calculate inter-annotator agreement for spans and relations in means of F1 "
        "(exact match, i.e. offsets / arguments and labels must match) for two or more BRAT "
        "annotation directories."
    )
    # (flag, add_argument keyword arguments) in the order they appear in --help
    cli_options = [
        (
            "--annotation-dirs",
            dict(
                type=str,
                required=True,
                nargs="+",
                help="List of annotation directories. At least two directories must be provided.",
            ),
        ),
        (
            "--ignore-annotation-dir",
            dict(
                type=str,
                default=None,
                help="If set, ignore annotations loaded from this directory.",
            ),
        ),
        (
            "--combine-fragmented-spans-via-relation",
            dict(
                type=str,
                default=None,
                help="If set, combine fragmented spans via this relation type.",
            ),
        ),
        (
            "--sort-arguments-of-relations",
            dict(
                type=str,
                default=None,
                nargs="+",
                help="If set, sort the arguments of the relations with the given labels.",
            ),
        ),
        (
            "--align-spans",
            dict(
                action="store_true",
                help="If set, align the spans of the predicted annotations to the gold annotations.",
            ),
        ),
        (
            "--per-file",
            dict(
                action="store_true",
                help="If set, calculate IAA per file.",
            ),
        ),
    ]
    for flag, option_kwargs in cli_options:
        arg_parser.add_argument(flag, **option_kwargs)

    cli_args = arg_parser.parse_args()
    # results are printed via show_results=True; the returned value is not used further
    calc_brat_iaas(
        annotator_dirs=cli_args.annotation_dirs,
        ignore_annotation_dir=cli_args.ignore_annotation_dir,
        combine_fragmented_spans_via_relation=cli_args.combine_fragmented_spans_via_relation,
        sort_arguments_of_relations=cli_args.sort_arguments_of_relations,
        align_spans=cli_args.align_spans,
        per_file=cli_args.per_file,
        show_results=True,
    )