"""Calculate inter-annotator agreement for spans and relations of BRAT annotations.

Agreement is measured as F1 with exact matching: the annotations of one annotator are
treated as gold, the annotations of another annotator as predictions, for every ordered
pair of annotators.
"""

from collections.abc import Iterable

import pyrootutils
from pytorch_ie import Document

# NOTE(review): this has to run before the `src...` import below; presumably
# `pythonpath=True` is what makes the project root importable - confirm.
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)

import argparse
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import pandas as pd
from pie_datasets import Dataset, load_dataset
from pie_datasets.builders.brat import BratDocument, BratDocumentWithMergedSpans
from pie_modules.document.processing import RelationArgumentSorter, SpansViaRelationMerger
from pytorch_ie.metrics import F1Metric

from src.document.processing import align_predicted_span_annotations


def add_annotations_as_predictions(document: BratDocument, other: BratDocument) -> BratDocument:
    """Return a copy of `document` with the annotations of `other` added as predictions.

    Both inputs are copied first, so neither is modified. The spans and relations of
    `other` are moved into the `predictions` of the respective layers of the copy.

    Args:
        document: The document whose own annotations act as gold annotations.
        other: The document whose annotations are added as predictions.

    Returns:
        The copy of `document` with spans and relations of `other` as predictions.
    """
    document = document.copy()
    other = other.copy()
    # clear() detaches the annotations from `other`, which allows to attach them here.
    document.spans.predictions.extend(other.spans.clear())

    # If a predicted span equals a gold span, let the predicted relations point to the
    # gold span object; otherwise they keep pointing to the predicted span.
    gold2gold_span_mapping = {span: span for span in document.spans}
    predicted2maybe_gold_span = {
        span: gold2gold_span_mapping.get(span, span) for span in document.spans.predictions
    }

    predicted_relations = [
        rel.copy(
            head=predicted2maybe_gold_span[rel.head], tail=predicted2maybe_gold_span[rel.tail]
        )
        for rel in other.relations.clear()
    ]
    document.relations.predictions.extend(predicted_relations)
    return document


def remove_annotations_existing_in_other(
    document: BratDocumentWithMergedSpans, other: BratDocumentWithMergedSpans
) -> BratDocumentWithMergedSpans:
    """Return a copy of `document` without the annotations that also exist in `other`.

    Both inputs are copied first, so neither is modified. The remaining spans and
    relations are computed as set differences, so their original order is not kept.

    Args:
        document: The document to take the annotations from.
        other: The document whose spans and relations should be dropped.

    Returns:
        A copy of `document` holding only the spans and relations missing in `other`.
    """
    result = document.copy(with_annotations=False)
    document = document.copy()
    other = other.copy()
    spans = set(document.spans.clear()) - set(other.spans.clear())
    relations = set(document.relations.clear()) - set(other.relations.clear())
    result.spans.extend(spans)
    result.relations.extend(relations)
    return result


def unnest_dict(d: Dict[Any, Any]) -> Dict[Tuple[Any, ...], Any]:
    """Flatten a nested dict into a dict that is keyed by tuples of the nested keys.

    Example: `{"a": {"b": 1}, "c": 2}` becomes `{("a", "b"): 1, ("c",): 2}`.

    Args:
        d: The (possibly nested) dict. Only values that are `dict` instances are descended.

    Returns:
        A flat dict mapping key paths (as tuples) to the leaf values.
    """
    result = {}
    for key, value in d.items():
        if isinstance(value, dict):
            for sub_key, sub_value in unnest_dict(value).items():
                result[(key,) + sub_key] = sub_value
        else:
            result[(key,)] = value
    return result


def _load_brat_documents(data_dir: str) -> Dataset:
    """Load the train split of the BRAT data in `data_dir` with merged span fragments."""
    return load_dataset(
        "pie/brat",
        name="merge_fragmented_spans",
        base_dataset_kwargs=dict(data_dir=data_dir),
        split="train",
    )


def calc_brat_iaas(
    annotator_dirs: List[str],
    ignore_annotation_dir: Optional[str] = None,
    combine_fragmented_spans_via_relation: Optional[str] = None,
    sort_arguments_of_relations: Optional[List[str]] = None,
    align_spans: bool = False,
    show_results: bool = False,
    per_file: bool = False,
) -> Union[pd.Series, List[pd.Series]]:
    """Calculate inter-annotator agreement (F1) for two or more BRAT annotation directories.

    Args:
        annotator_dirs: The annotation directories, one per annotator.
        ignore_annotation_dir: If set, annotations that also exist in the documents loaded
            from this directory are removed before scoring.
        combine_fragmented_spans_via_relation: If set, a `SpansViaRelationMerger` with this
            link relation label is applied to each document before scoring.
        sort_arguments_of_relations: If set and not empty, a `RelationArgumentSorter` with
            these labels as whitelist is applied to all documents.
        align_spans: If set, `align_predicted_span_annotations` is applied to each document
            before scoring.
        show_results: If set, print the averaged scores.
        per_file: If set, calculate the scores for each document individually. Documents
            are paired by their id; documents that are not available for every annotator
            are skipped with a warning.

    Returns:
        The scores as a series, or a list of such series (one per document) if `per_file`
        is set.

    Raises:
        ValueError: If less than two annotation directories are provided.
    """
    if len(annotator_dirs) < 2:
        raise ValueError("At least two annotation dirs must be provided")

    span_aligner = None
    if align_spans:
        span_aligner = partial(align_predicted_span_annotations, span_layer="spans")

    merger = None
    if combine_fragmented_spans_via_relation is not None:
        print(f"Combine fragmented spans via {combine_fragmented_spans_via_relation} relations")
        merger = SpansViaRelationMerger(
            relation_layer="relations",
            link_relation_label=combine_fragmented_spans_via_relation,
            create_multi_spans=True,
            result_document_type=BratDocument,
            result_field_mapping={"spans": "spans", "relations": "relations"},
            combine_scores_method="product",
        )

    relation_argument_sorter = None
    if sort_arguments_of_relations:
        print(f"Sort arguments of relations with labels {sort_arguments_of_relations}")
        relation_argument_sorter = RelationArgumentSorter(
            relation_layer="relations",
            # e.g. ["parts_of_same", "semantically_same", "contradicts"]
            label_whitelist=sort_arguments_of_relations,
        )

    all_docs = [
        _load_brat_documents(annotation_dir).map(lambda doc: doc.deduplicate_annotations())
        for annotation_dir in annotator_dirs
    ]

    if ignore_annotation_dir is not None:
        print(f"Ignoring annotations loaded from {ignore_annotation_dir}")
        ignore_annotation_docs = _load_brat_documents(ignore_annotation_dir)
        ignore_annotation_docs_dict = {doc.id: doc for doc in ignore_annotation_docs}
        all_docs = [
            docs.map(
                lambda doc: remove_annotations_existing_in_other(
                    doc, other=ignore_annotation_docs_dict[doc.id]
                )
            )
            for docs in all_docs
        ]

    if relation_argument_sorter is not None:
        all_docs = [docs.map(relation_argument_sorter) for docs in all_docs]

    if not per_file:
        return calc_brat_iaas_for_docs(
            all_docs, span_aligner=span_aligner, merger=merger, show_results=show_results
        )

    # Pair the documents by id, not by position: zipping the datasets would silently
    # truncate to the shortest one and would combine different documents if the
    # directories do not contain exactly the same files in the same order.
    docs_by_id_per_annotator = [{doc.id: doc for doc in docs} for docs in all_docs]
    first_annotator_docs, *other_annotators_docs = docs_by_id_per_annotator
    shared_doc_ids = [
        doc_id
        for doc_id in first_annotator_docs
        if all(doc_id in docs_by_id for docs_by_id in other_annotators_docs)
    ]
    skipped_doc_ids = set().union(*docs_by_id_per_annotator) - set(shared_doc_ids)
    if skipped_doc_ids:
        print(
            "WARNING: skipping documents that are not available for all annotators: "
            f"{sorted(skipped_doc_ids, key=str)}"
        )

    results_per_doc = []
    for doc_id in shared_doc_ids:
        if show_results:
            print(f"\ncalculate scores for document id={doc_id} ...")
        docs = [
            Dataset.from_documents([docs_by_id[doc_id]])
            for docs_by_id in docs_by_id_per_annotator
        ]
        result_per_doc = calc_brat_iaas_for_docs(
            docs, span_aligner=span_aligner, merger=merger, show_results=show_results
        )
        results_per_doc.append(result_per_doc)
    return results_per_doc


def calc_brat_iaas_for_docs(
    all_docs: List[Dataset],
    span_aligner: Optional[Callable] = None,
    merger: Optional[Callable] = None,
    show_results: bool = False,
) -> pd.Series:
    """Calculate span and relation F1 scores for every ordered pair of annotators.

    For each pair (gold annotator, predicted annotator), the annotations of the predicted
    annotator are added as predictions to the documents of the gold annotator. Documents
    are matched by their id.

    Args:
        all_docs: The documents, one dataset per annotator.
        span_aligner: If set, applied to each document before scoring.
        merger: If set, applied to each document (after `span_aligner`) before scoring.
        show_results: If set, print the scores averaged over all annotator pairs as
            markdown tables.

    Returns:
        A series with the metric values, indexed by tuples that start with the annotator
        pair id ("gold:<i>,predicted:<j>") and the layer name ("spans" / "relations"),
        followed by the keys of the nested metric result.
    """
    num_annotators = len(all_docs)
    all_docs_dict = [{doc.id: doc for doc in docs} for docs in all_docs]

    gold_predicted = {}
    for gold_annotator_idx in range(num_annotators):
        gold = all_docs[gold_annotator_idx]
        for predicted_annotator_idx in range(num_annotators):
            if gold_annotator_idx == predicted_annotator_idx:
                continue
            predicted_dict = all_docs_dict[predicted_annotator_idx]
            # Bind predicted_dict as default argument: the name is re-assigned in each
            # iteration, so the lambda must not look it up only when it gets called.
            # The lambda is intentionally kept without annotations.
            gold_predicted[(gold_annotator_idx, predicted_annotator_idx)] = gold.map(
                lambda doc, predicted_dict=predicted_dict: add_annotations_as_predictions(
                    doc, other=predicted_dict[doc.id]
                )
            )

    spans_metric = F1Metric(layer="spans", labels="INFERRED", show_as_markdown=True)
    relations_metric = F1Metric(layer="relations", labels="INFERRED", show_as_markdown=True)
    metric_values = {}
    for (gold_annotator_idx, predicted_annotator_idx), docs in gold_predicted.items():
        print(
            f"calculate scores for annotations {gold_annotator_idx} -> {predicted_annotator_idx}"
        )
        for doc in docs:
            if span_aligner is not None:
                doc = span_aligner(doc)
            if merger is not None:
                doc = merger(doc)
            spans_metric(doc)
            relations_metric(doc)

        metric_id = f"gold:{gold_annotator_idx},predicted:{predicted_annotator_idx}"
        # reset=True: the same metric objects are re-used for the next annotator pair.
        metric_values[metric_id] = {
            "spans": spans_metric.compute(reset=True),
            "relations": relations_metric.compute(reset=True),
        }

    result = pd.Series(unnest_dict(metric_values))
    if show_results:
        # Move the annotator pair id (first index level) to the columns and average over it.
        metric_values_series_mean = result.unstack(0).mean(axis=1)
        metric_values_relations = metric_values_series_mean.xs("relations").unstack()
        metric_values_spans = metric_values_series_mean.xs("spans").unstack()

        print("\nspans:")
        print(metric_values_spans.round(decimals=3).to_markdown())
        print("\nrelations:")
        print(metric_values_relations.round(decimals=3).to_markdown())

    return result


if __name__ == "__main__":
    # example call:
    #   python calc_iaa_for_brat.py \
    #     --annotation-dirs annotations/sciarg/v0.9/with_abstracts_rin annotations/sciarg/v0.9/with_abstracts_alisa \
    #     --ignore-annotation-dir annotations/sciarg/v0.9/original

    parser = argparse.ArgumentParser(
        description="Calculate inter-annotator agreement for spans and relations in means of F1 "
        "(exact match, i.e. offsets / arguments and labels must match) for two or more BRAT "
        "annotation directories."
    )
    parser.add_argument(
        "--annotation-dirs",
        type=str,
        required=True,
        nargs="+",
        help="List of annotation directories. At least two directories must be provided.",
    )
    parser.add_argument(
        "--ignore-annotation-dir",
        type=str,
        default=None,
        help="If set, ignore annotations loaded from this directory.",
    )
    parser.add_argument(
        "--combine-fragmented-spans-via-relation",
        type=str,
        default=None,
        help="If set, combine fragmented spans via this relation type.",
    )
    parser.add_argument(
        "--sort-arguments-of-relations",
        type=str,
        default=None,
        nargs="+",
        help="If set, sort the arguments of the relations with the given labels.",
    )
    parser.add_argument(
        "--align-spans",
        action="store_true",
        help="If set, align the spans of the predicted annotations to the gold annotations.",
    )
    parser.add_argument(
        "--per-file",
        action="store_true",
        help="If set, calculate IAA per file.",
    )

    args = parser.parse_args()

    metric_values_series = calc_brat_iaas(
        annotator_dirs=args.annotation_dirs,
        ignore_annotation_dir=args.ignore_annotation_dir,
        combine_fragmented_spans_via_relation=args.combine_fragmented_spans_via_relation,
        sort_arguments_of_relations=args.sort_arguments_of_relations,
        align_spans=args.align_spans,
        per_file=args.per_file,
        show_results=True,
    )