|
from collections.abc import Iterable |
|
|
|
import pyrootutils |
|
from pytorch_ie import Document |
|
|
|
# Locate the project root: starting from this file's location, search for the directory that
# contains a ".project-root" indicator file. With pythonpath=True / dotenv=True, pyrootutils
# presumably also puts that root on sys.path and loads a .env file from it — which would explain
# why this runs before the project-local `src...` import further down.
# NOTE(review): confirm the pythonpath/dotenv semantics against the pyrootutils docs.
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)
|
|
|
import argparse |
|
from functools import partial |
|
from typing import Callable, List, Optional, Union |
|
|
|
import pandas as pd |
|
from pie_datasets import Dataset, load_dataset |
|
from pie_datasets.builders.brat import BratDocument, BratDocumentWithMergedSpans |
|
from pie_modules.document.processing import RelationArgumentSorter, SpansViaRelationMerger |
|
from pytorch_ie.metrics import F1Metric |
|
|
|
from src.document.processing import align_predicted_span_annotations |
|
|
|
|
|
def add_annotations_as_predictions(document: BratDocument, other: BratDocument) -> BratDocument:
    """Return a copy of `document` whose prediction layers hold the annotations of `other`.

    The spans of `other` are moved into `spans.predictions` and its relations into
    `relations.predictions` of the returned document, while the gold layers keep the
    annotations of `document`. Both inputs are copied (`.copy()`) before any layer is
    cleared, so the work happens on copies.

    If an argument (head/tail) of a predicted relation is equal to a gold span of `document`,
    the copied relation points to that gold span object; otherwise it keeps the predicted span.
    NOTE(review): presumably this is so that matching relation arguments reference annotations
    already present in the document — confirm.
    """
    result = document.copy()
    source = other.copy()

    # move all spans of `other` into the span predictions of the result
    moved_spans = source.spans.clear()
    result.spans.predictions.extend(moved_spans)

    # lookup from a predicted span to the equal gold span (if any), else to itself
    gold_span_lookup = {gold_span: gold_span for gold_span in result.spans}
    resolve_span = {
        predicted_span: gold_span_lookup.get(predicted_span, predicted_span)
        for predicted_span in result.spans.predictions
    }

    # re-create the relations of `other` with their arguments resolved via the lookup
    relocated_relations = []
    for relation in source.relations.clear():
        relocated_relations.append(
            relation.copy(head=resolve_span[relation.head], tail=resolve_span[relation.tail])
        )
    result.relations.predictions.extend(relocated_relations)

    return result
|
|
|
|
|
def remove_annotations_existing_in_other(
    document: BratDocumentWithMergedSpans, other: BratDocumentWithMergedSpans
) -> BratDocumentWithMergedSpans:
    """Return a copy of `document` without the spans/relations that also occur in `other`.

    Annotations are compared by value (hash / equality), not by identity. The result is created
    via `document.copy(with_annotations=False)` and only its `spans` and `relations` layers are
    filled. Duplicate annotations of `document` are kept only once. The remaining annotations
    keep the order in which they appear in `document`, so the output is deterministic across
    runs. A plain set difference would iterate in hash order, which varies between processes
    for string-labelled annotations.

    NOTE(review): a relation that is not in `other` is kept even if its head/tail spans were
    removed because they are in `other`; such relations then reference spans that are not in
    the result's `spans` layer. Confirm this is intended.
    """
    result = document.copy(with_annotations=False)
    # work on copies: `clear()` detaches the annotations so they can be added to `result`
    document = document.copy()
    other = other.copy()

    own_spans = document.spans.clear()
    other_spans = set(other.spans.clear())
    own_relations = document.relations.clear()
    other_relations = set(other.relations.clear())

    # dict.fromkeys de-duplicates like a set but keeps first-seen order
    spans = [span for span in dict.fromkeys(own_spans) if span not in other_spans]
    relations = [
        relation for relation in dict.fromkeys(own_relations) if relation not in other_relations
    ]

    result.spans.extend(spans)
    result.relations.extend(relations)

    return result
|
|
|
|
|
def unnest_dict(d):
    """Flatten a nested dict into a single-level dict keyed by key-path tuples.

    Every non-dict leaf value ends up under the tuple of keys leading to it, e.g.
    ``{"a": {"b": 1}, "c": 2}`` becomes ``{("a", "b"): 1, ("c",): 2}``. Entry order follows
    a depth-first traversal of the input. Empty nested dicts contribute no entries.
    """
    flat = {}
    for outer_key, inner in d.items():
        if not isinstance(inner, dict):
            flat[(outer_key,)] = inner
            continue
        # prefix every key path of the flattened sub-dict with the current key
        flat.update(
            ((outer_key, *sub_path), leaf) for sub_path, leaf in unnest_dict(inner).items()
        )
    return flat
|
|
|
|
|
def _load_brat_dataset(annotation_dir: str) -> "Dataset":
    """Load the "train" split of the BRAT annotations in `annotation_dir` (fragmented spans merged)."""
    return load_dataset(
        "pie/brat",
        name="merge_fragmented_spans",
        base_dataset_kwargs=dict(data_dir=annotation_dir),
        split="train",
    )


def _check_same_document_ids(all_docs: List["Dataset"], annotator_dirs: List[str]) -> None:
    """Raise a ValueError unless every annotation dir contains exactly the same document ids."""
    reference_dir = annotator_dirs[0]
    reference_ids = {doc.id for doc in all_docs[0]}
    for annotation_dir, docs in zip(annotator_dirs[1:], all_docs[1:]):
        doc_ids = {doc.id for doc in docs}
        if doc_ids != reference_ids:
            raise ValueError(
                f"annotation dirs {reference_dir} and {annotation_dir} do not contain the same "
                f"documents (only in {reference_dir}: {sorted(reference_ids - doc_ids)}, "
                f"only in {annotation_dir}: {sorted(doc_ids - reference_ids)})"
            )


def calc_brat_iaas(
    annotator_dirs: List[str],
    ignore_annotation_dir: Optional[str] = None,
    combine_fragmented_spans_via_relation: Optional[str] = None,
    sort_arguments_of_relations: Optional[List[str]] = None,
    align_spans: bool = False,
    show_results: bool = False,
    per_file: bool = False,
) -> Union[pd.Series, List[pd.Series]]:
    """Calculate inter-annotator agreement (F1) between two or more BRAT annotation dirs.

    Args:
        annotator_dirs: BRAT annotation directories, one per annotator (at least two). All
            of them must contain the same documents (by id).
        ignore_annotation_dir: if set, spans/relations that also occur in the document with
            the same id in this directory are removed before scoring.
        combine_fragmented_spans_via_relation: if set, spans linked via relations with this
            label are merged (via `SpansViaRelationMerger`) before scoring.
        sort_arguments_of_relations: if non-empty, the arguments of relations with these
            labels are sorted (via `RelationArgumentSorter`) before scoring.
        align_spans: if True, predicted spans are aligned with
            `align_predicted_span_annotations` before scoring.
        show_results: if True, print the scores (averaged over annotator pairs).
        per_file: if True, score each document separately.

    Returns:
        The scores as returned by `calc_brat_iaas_for_docs`, or, if `per_file` is set, a list
        with one such result per document (in the document order of the first annotation dir).

    Raises:
        ValueError: if fewer than two annotation dirs are given, if the annotation dirs do not
            contain the same document ids, or if `ignore_annotation_dir` lacks any of them.
    """
    if len(annotator_dirs) < 2:
        raise ValueError("At least two annotation dirs must be provided")

    span_aligner = None
    if align_spans:
        span_aligner = partial(align_predicted_span_annotations, span_layer="spans")

    merger = None
    if combine_fragmented_spans_via_relation is not None:
        print(f"Combine fragmented spans via {combine_fragmented_spans_via_relation} relations")
        merger = SpansViaRelationMerger(
            relation_layer="relations",
            link_relation_label=combine_fragmented_spans_via_relation,
            create_multi_spans=True,
            result_document_type=BratDocument,
            result_field_mapping={"spans": "spans", "relations": "relations"},
            combine_scores_method="product",
        )

    relation_argument_sorter = None
    if sort_arguments_of_relations:
        print(f"Sort arguments of relations with labels {sort_arguments_of_relations}")
        relation_argument_sorter = RelationArgumentSorter(
            relation_layer="relations",
            label_whitelist=sort_arguments_of_relations,
        )

    all_docs = [
        _load_brat_dataset(annotation_dir).map(lambda doc: doc.deduplicate_annotations())
        for annotation_dir in annotator_dirs
    ]
    # fail early with a readable message instead of a KeyError deep inside a `map` call
    _check_same_document_ids(all_docs, annotator_dirs)

    if ignore_annotation_dir is not None:
        print(f"Ignoring annotations loaded from {ignore_annotation_dir}")
        ignore_annotation_docs_dict = {
            doc.id: doc for doc in _load_brat_dataset(ignore_annotation_dir)
        }
        missing_ids = sorted({doc.id for doc in all_docs[0]} - ignore_annotation_docs_dict.keys())
        if missing_ids:
            raise ValueError(
                f"documents missing from ignore annotation dir {ignore_annotation_dir}: "
                f"{missing_ids}"
            )
        all_docs = [
            docs.map(
                lambda doc: remove_annotations_existing_in_other(
                    doc, other=ignore_annotation_docs_dict[doc.id]
                )
            )
            for docs in all_docs
        ]

    if relation_argument_sorter is not None:
        all_docs = [docs.map(relation_argument_sorter) for docs in all_docs]

    if not per_file:
        return calc_brat_iaas_for_docs(
            all_docs, span_aligner=span_aligner, merger=merger, show_results=show_results
        )

    # Pair the documents by id, not by position: zipping the datasets would silently mix up
    # documents if the dirs list them in a different order.
    docs_by_id = [{doc.id: doc for doc in docs} for docs in all_docs]
    results_per_doc = []
    for doc_id in docs_by_id[0]:
        if show_results:
            print(f"\ncalculate scores for document id={doc_id} ...")
        single_doc_datasets = [Dataset.from_documents([id2doc[doc_id]]) for id2doc in docs_by_id]
        results_per_doc.append(
            calc_brat_iaas_for_docs(
                single_doc_datasets,
                span_aligner=span_aligner,
                merger=merger,
                show_results=show_results,
            )
        )
    return results_per_doc
|
|
|
|
|
def calc_brat_iaas_for_docs(
    all_docs: List[Dataset],
    span_aligner: Optional[Callable] = None,
    merger: Optional[Callable] = None,
    show_results: bool = False,
) -> pd.Series:
    """Score every ordered pair of annotators against each other with span/relation F1.

    For each ordered pair (gold, predicted) of distinct indices into `all_docs`, the documents
    of the gold annotator get the annotations of the predicted annotator (matched by document
    id) as predictions. Each such document is optionally passed through `span_aligner` and
    then `merger` before being fed into the span and relation `F1Metric`s.

    Args:
        all_docs: one dataset of documents per annotator.
        span_aligner: optional callable applied to each paired document first.
        merger: optional callable applied to each paired document after `span_aligner`.
        show_results: if True, print the scores averaged over all annotator pairs as markdown.

    Returns:
        A Series whose (multi-)index is built from the key paths
        ("gold:<i>,predicted:<j>", "spans" | "relations", ...) of the computed metric values.
    """

    def with_predictions_from(gold_docs, predicted_docs_by_id):
        # the predicted lookup is a parameter here, so each `map` gets its own binding
        return gold_docs.map(
            lambda doc: add_annotations_as_predictions(doc, other=predicted_docs_by_id[doc.id])
        )

    docs_by_id_per_annotator = [{doc.id: doc for doc in docs} for docs in all_docs]
    annotator_indices = range(len(all_docs))
    paired_docs = {
        (gold_idx, predicted_idx): with_predictions_from(
            all_docs[gold_idx], docs_by_id_per_annotator[predicted_idx]
        )
        for gold_idx in annotator_indices
        for predicted_idx in annotator_indices
        if gold_idx != predicted_idx
    }

    spans_metric = F1Metric(layer="spans", labels="INFERRED", show_as_markdown=True)
    relations_metric = F1Metric(layer="relations", labels="INFERRED", show_as_markdown=True)
    # optional per-document preprocessing, applied in this order
    transforms = [step for step in (span_aligner, merger) if step is not None]

    metric_values = {}
    for (gold_idx, predicted_idx), docs in paired_docs.items():
        print(f"calculate scores for annotations {gold_idx} -> {predicted_idx}")
        for doc in docs:
            for transform in transforms:
                doc = transform(doc)
            spans_metric(doc)
            relations_metric(doc)
        metric_values[f"gold:{gold_idx},predicted:{predicted_idx}"] = {
            "spans": spans_metric.compute(reset=True),
            "relations": relations_metric.compute(reset=True),
        }

    result = pd.Series(unnest_dict(metric_values))
    if not show_results:
        return result

    # average each metric over all annotator pairs, then tabulate per layer
    mean_over_pairs = result.unstack(0).mean(axis=1)
    tables = {layer: mean_over_pairs.xs(layer).unstack() for layer in ("relations", "spans")}

    print("\nspans:")
    print(tables["spans"].round(decimals=3).to_markdown())

    print("\nrelations:")
    print(tables["relations"].round(decimals=3).to_markdown())

    return result
|
|
|
|
|
if __name__ == "__main__":
    # The usage example lives in the epilog so that it is shown by `--help`; the raw formatter
    # keeps its line breaks (and those of the description) intact.
    parser = argparse.ArgumentParser(
        description=(
            "Calculate inter-annotator agreement for spans and relations in terms of F1\n"
            "(exact match, i.e. offsets / arguments and labels must match) for two or more\n"
            "BRAT annotation directories."
        ),
        epilog=(
            "example call:\n"
            "  python %(prog)s \\\n"
            "    --annotation-dirs annotations/sciarg/v0.9/with_abstracts_rin"
            " annotations/sciarg/v0.9/with_abstracts_alisa \\\n"
            "    --ignore-annotation-dir annotations/sciarg/v0.9/original"
        ),
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--annotation-dirs",
        type=str,
        required=True,
        nargs="+",
        help="List of annotation directories. At least two directories must be provided.",
    )
    parser.add_argument(
        "--ignore-annotation-dir",
        type=str,
        default=None,
        help="If set, ignore annotations loaded from this directory.",
    )
    parser.add_argument(
        "--combine-fragmented-spans-via-relation",
        type=str,
        default=None,
        help="If set, combine fragmented spans via this relation type.",
    )
    parser.add_argument(
        "--sort-arguments-of-relations",
        type=str,
        default=None,
        nargs="+",
        help="If set, sort the arguments of the relations with the given labels.",
    )
    parser.add_argument(
        "--align-spans",
        action="store_true",
        help="If set, align the spans of the predicted annotations to the gold annotations.",
    )
    parser.add_argument(
        "--per-file",
        action="store_true",
        help="If set, calculate IAA per file.",
    )
    args = parser.parse_args()

    metric_values_series = calc_brat_iaas(
        annotator_dirs=args.annotation_dirs,
        ignore_annotation_dir=args.ignore_annotation_dir,
        combine_fragmented_spans_via_relation=args.combine_fragmented_spans_via_relation,
        sort_arguments_of_relations=args.sort_arguments_of_relations,
        align_spans=args.align_spans,
        per_file=args.per_file,
        show_results=True,
    )
|
|