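"""Construct gold retrieval results for SciArg in the abstracts-vs-remaining-text setup,
in the same format as the retrieval results produced by the pipeline.

Example invocation (the paths are the argparse defaults defined at the bottom of this file):

    python src/data/construct_sciarg_abstracts_remaining_gold_retrieval.py \
        --data_dir data/annotations/sciarg-with-abstracts-and-cross-section-rels \
        --out_path data/retrieval_results/sciarg-with-abstracts-and-cross-section-rels/gold.json
"""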
import pyrootutils
root = pyrootutils.setup_root(
search_from=__file__,
indicator=[".project-root"],
pythonpath=True,
# dotenv=True,
)
import argparse
import logging
import os
from collections import defaultdict
from typing import List, Optional, Sequence, Tuple, TypeVar
import pandas as pd
from pie_datasets import load_dataset
from pie_datasets.builders.brat import BratDocument, BratDocumentWithMergedSpans
from pytorch_ie.annotations import LabeledMultiSpan
from pytorch_ie.documents import (
TextBasedDocument,
TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
)
from src.document.processing import replace_substrings_in_text_with_spaces
logger = logging.getLogger(__name__)
def multi_span_is_in_span(multi_span: LabeledMultiSpan, range_span: Tuple[int, int]) -> bool:
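    """Check whether all slices of a multi-span lie within the given character range.

    A minimal sketch of the expected behavior (the slice values are hypothetical):

        >>> span = LabeledMultiSpan(slices=((5, 10), (20, 25)), label="claim")
        >>> multi_span_is_in_span(span, (0, 30))
        True
        >>> multi_span_is_in_span(span, (0, 15))
        False
    """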
start, end = range_span
starts, ends = zip(*multi_span.slices)
return start <= min(starts) and max(ends) <= end
def filter_multi_spans(
multi_spans: Sequence[LabeledMultiSpan], filter_span: Tuple[int, int]
) -> List[LabeledMultiSpan]:
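    """Keep only the multi-spans whose slices are fully contained in filter_span."""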
return [
span
for span in multi_spans
if multi_span_is_in_span(multi_span=span, range_span=filter_span)
]
def shift_multi_span_slices(
slices: Sequence[Tuple[int, int]], shift: int
) -> List[Tuple[int, int]]:
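    """Shift all slice boundaries by a (possibly negative) offset.

    Hypothetical example: shifting [(30, 35)] by -20 yields [(10, 15)], i.e. it maps
    document-level offsets to offsets relative to a partition that starts at 20.
    """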
return [(start + shift, end + shift) for start, end in slices]
def construct_gold_retrievals(
doc: TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
symmetric_relations: Optional[List[str]] = None,
relation_label_whitelist: Optional[List[str]] = None,
) -> Optional[pd.DataFrame]:
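    """Construct gold retrieval rows for a single document.

    For every semantically_same relation pointing from an abstract span (head) to a span
    in the remaining text (tail), all spans related to the tail are emitted as gold
    retrieval targets for the abstract span; candidates can be restricted via
    relation_label_whitelist. Returns a DataFrame with one row per target (columns:
    doc_id, query_doc_id, span, query_span, ref_span, type, label, ref_label), or None
    if the document does not have exactly one abstract annotation or no rows remain.
    """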
abstract_annotations = [
span for span in doc.labeled_partitions if span.label.lower().strip() == "abstract"
]
if len(abstract_annotations) != 1:
logger.warning(
f"Expected exactly one abstract annotation, found {len(abstract_annotations)}"
)
return None
abstract_annotation = abstract_annotations[0]
span_abstract = (abstract_annotation.start, abstract_annotation.end)
span_remaining = (abstract_annotation.end, len(doc.text))
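    # The abstract serves as the query document; everything after it ("remaining") is the
    # document that retrieval targets are drawn from.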
labeled_multi_spans = list(doc.labeled_multi_spans)
    spans_in_abstract = {
        span for span in labeled_multi_spans if multi_span_is_in_span(span, span_abstract)
    }
    spans_in_remaining = {
        span for span in labeled_multi_spans if multi_span_is_in_span(span, span_remaining)
    }
spans_not_covered = set(labeled_multi_spans) - spans_in_abstract - spans_in_remaining
if len(spans_not_covered) > 0:
logger.warning(
f"Found {len(spans_not_covered)} spans not covered by abstract or remaining text"
)
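    # Map each relation argument to all (other argument, label) pairs. Reverse edges keep
    # the original label if it is declared symmetric and get the suffix "_reversed"
    # otherwise, so the map can be queried from either side. Hypothetical example: a
    # supports(A, B) relation yields A -> (B, "supports") and B -> (A, "supports_reversed").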
rel_arg_and_label2other = defaultdict(list)
for rel in doc.binary_relations:
rel_arg_and_label2other[rel.head].append((rel.tail, rel.label))
if symmetric_relations is not None and rel.label in symmetric_relations:
label_reversed = rel.label
else:
label_reversed = f"{rel.label}_reversed"
rel_arg_and_label2other[rel.tail].append((rel.head, label_reversed))
result_rows = []
for rel in doc.binary_relations:
        # We check all semantically_same relations that point from the abstract (head)
        # to the remaining text (tail) ...
        if rel.label == "semantically_same":
            if rel.head in spans_in_abstract and rel.tail in spans_in_remaining:
                # ... and collect, as retrieval candidates, all spans related to the
                # tail, i.e. to the counterpart of the query span in the remaining text.
                candidate_spans_with_label = rel_arg_and_label2other[rel.tail]
for candidate_span, rel_label in candidate_spans_with_label:
if (
relation_label_whitelist is not None
and rel_label not in relation_label_whitelist
):
continue
result_row = {
"doc_id": f"{doc.id}.remaining.{span_remaining[0]}.txt",
"query_doc_id": f"{doc.id}.abstract.{span_abstract[0]}_{span_abstract[1]}.txt",
"span": shift_multi_span_slices(candidate_span.slices, -span_remaining[0]),
"query_span": shift_multi_span_slices(rel.head.slices, -span_abstract[0]),
"ref_span": shift_multi_span_slices(rel.tail.slices, -span_remaining[0]),
"type": rel_label,
"label": candidate_span.label,
"ref_label": rel.tail.label,
}
result_rows.append(result_row)
if len(result_rows) > 0:
return pd.DataFrame(result_rows)
else:
return None
D_text = TypeVar("D_text", bound=TextBasedDocument)
def clean_doc(doc: D_text) -> D_text:
    # Remove XML tags. Note that, in contrast to the preprocessing pipeline
    # (see configs/dataset/sciarg_cleaned.yaml), we also remove the Abstract tags. This is
    # because there, the abstracts are removed completely.
doc = replace_substrings_in_text_with_spaces(
doc,
substrings=[
"</H2>",
"<H3>",
"</Document>",
"<H1>",
"<H2>",
"</H3>",
"</H1>",
"<Abstract>",
"</Abstract>",
],
)
return doc
def main(
data_dir: str,
out_path: str,
doc_id_whitelist: Optional[List[str]] = None,
symmetric_relations: Optional[List[str]] = None,
relation_label_whitelist: Optional[List[str]] = None,
) -> None:
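    """Load SciArg, clean and filter the documents, and write gold retrievals to out_path.

    The pipeline: load the pie/sciarg dataset at a pinned revision, convert it to a
    document type with labeled (multi-)spans, binary relations and partitions, strip XML
    markup, optionally filter by doc_id_whitelist, construct per-document gold retrievals,
    sort them for determinism, and save the concatenated result as JSON.
    """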
logger.info(f"Loading dataset from {data_dir}")
sciarg_with_abstracts = load_dataset(
"pie/sciarg",
revision="171478ce3c13cc484be5d7c9bc8f66d7d2f1c210",
base_dataset_kwargs={"data_dir": data_dir, "split_paths": None},
name="resolve_parts_of_same",
split="train",
)
if issubclass(sciarg_with_abstracts.document_type, BratDocument):
ds_converted = sciarg_with_abstracts.to_document_type(
TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
)
elif issubclass(sciarg_with_abstracts.document_type, BratDocumentWithMergedSpans):
ds_converted = sciarg_with_abstracts.to_document_type(
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
)
else:
raise ValueError(f"Unsupported document type {sciarg_with_abstracts.document_type}")
ds_clean = ds_converted.map(clean_doc)
if doc_id_whitelist is not None:
num_before = len(ds_clean)
ds_clean = [doc for doc in ds_clean if doc.id in doc_id_whitelist]
logger.info(
f"Filtered dataset from {num_before} to {len(ds_clean)} documents based on doc_id_whitelist"
)
results_per_doc = [
construct_gold_retrievals(
doc,
symmetric_relations=symmetric_relations,
relation_label_whitelist=relation_label_whitelist,
)
for doc in ds_clean
]
results_per_doc_not_empty = [doc for doc in results_per_doc if doc is not None]
if len(results_per_doc_not_empty) > 0:
results = pd.concat(results_per_doc_not_empty, ignore_index=True)
# sort to make the output deterministic
results = results.sort_values(
by=results.columns.tolist(), ignore_index=True, key=lambda s: s.apply(str)
)
os.makedirs(os.path.dirname(out_path), exist_ok=True)
logger.info(f"Saving result ({len(results)}) to {out_path}")
results.to_json(out_path)
else:
logger.warning("No results found")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Create gold retrievals for SciArg-abstracts-remaining in the same format as the retrieval results"
)
parser.add_argument(
"--data_dir",
type=str,
default="data/annotations/sciarg-with-abstracts-and-cross-section-rels",
help="Path to the sciarg data directory",
)
parser.add_argument(
"--out_path",
type=str,
default="data/retrieval_results/sciarg-with-abstracts-and-cross-section-rels/gold.json",
help="Path to save the results",
)
parser.add_argument(
"--symmetric_relations",
type=str,
nargs="+",
default=None,
help="Relations that are symmetric, i.e., if A is related to B, then B is related to A",
)
parser.add_argument(
"--relation_label_whitelist",
type=str,
nargs="+",
default=None,
help="Only consider relations with these labels",
)
logging.basicConfig(level=logging.INFO)
kwargs = vars(parser.parse_args())
main(**kwargs)
logger.info("Done")