|
import pyrootutils |
|
|
|
# Locate the project root (marked by a ".project-root" file) and put it on
# sys.path, so the `from src...` import below resolves regardless of the
# current working directory this script is launched from.
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
)
|
|
|
import argparse |
|
import logging |
|
import os |
|
from collections import defaultdict |
|
from typing import List, Optional, Sequence, Tuple, TypeVar |
|
|
|
import pandas as pd |
|
from pie_datasets import load_dataset |
|
from pie_datasets.builders.brat import BratDocument, BratDocumentWithMergedSpans |
|
from pytorch_ie.annotations import LabeledMultiSpan |
|
from pytorch_ie.documents import ( |
|
TextBasedDocument, |
|
TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions, |
|
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, |
|
) |
|
|
|
from src.document.processing import replace_substrings_in_text_with_spaces |
|
|
|
# Module-level logger, following the standard `logging` convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
def multi_span_is_in_span(multi_span: LabeledMultiSpan, range_span: Tuple[int, int]) -> bool:
    """Return True iff every slice of ``multi_span`` lies inside ``range_span``.

    The multi-span is contained when its left-most slice start and its
    right-most slice end both fall within the window.
    """
    window_start, window_end = range_span
    slice_starts = [begin for begin, _ in multi_span.slices]
    slice_ends = [stop for _, stop in multi_span.slices]
    return window_start <= min(slice_starts) and max(slice_ends) <= window_end
|
|
|
|
|
def filter_multi_spans(
    multi_spans: Sequence[LabeledMultiSpan], filter_span: Tuple[int, int]
) -> List[LabeledMultiSpan]:
    """Select the multi-spans that lie completely inside ``filter_span``."""
    contained: List[LabeledMultiSpan] = []
    for candidate in multi_spans:
        if multi_span_is_in_span(multi_span=candidate, range_span=filter_span):
            contained.append(candidate)
    return contained
|
|
|
|
|
def shift_multi_span_slices(
    slices: Sequence[Tuple[int, int]], shift: int
) -> List[Tuple[int, int]]:
    """Return a copy of ``slices`` with both offsets of every pair moved by ``shift``."""
    shifted: List[Tuple[int, int]] = []
    for begin, stop in slices:
        shifted.append((begin + shift, stop + shift))
    return shifted
|
|
|
|
|
def construct_gold_retrievals(
    doc: TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
    symmetric_relations: Optional[List[str]] = None,
    relation_label_whitelist: Optional[List[str]] = None,
) -> Optional[pd.DataFrame]:
    """Construct gold retrieval rows for a single document.

    The document is split at its (single) "abstract" partition into the
    abstract and the remaining text. For every "semantically_same" relation
    whose head lies in the abstract and whose tail lies in the remaining
    text, one row is emitted per relation attached to the tail span, with
    span offsets shifted to be relative to the start of their section.

    Args:
        doc: Document with labeled multi-spans, binary relations and
            labeled partitions.
        symmetric_relations: Relation labels whose reversed direction keeps
            the same label (others get a "_reversed" suffix).
        relation_label_whitelist: If given, only emit rows whose (possibly
            reversed) relation label is in this list.

    Returns:
        A DataFrame with one row per gold retrieval candidate, or None when
        the document lacks a unique abstract partition or yields no rows.
    """
    # The partition labeled "abstract" (case-insensitive) marks the query section.
    abstract_annotations = [
        span for span in doc.labeled_partitions if span.label.lower().strip() == "abstract"
    ]
    if len(abstract_annotations) != 1:
        logger.warning(
            f"Expected exactly one abstract annotation, found {len(abstract_annotations)}"
        )
        return None
    abstract_annotation = abstract_annotations[0]
    # Character windows: the abstract itself and everything after it.
    span_abstract = (abstract_annotation.start, abstract_annotation.end)
    span_remaining = (abstract_annotation.end, len(doc.text))
    labeled_multi_spans = list(doc.labeled_multi_spans)
    # Bucket annotation spans by the window that fully contains them.
    spans_in_abstract = set(
        span for span in labeled_multi_spans if multi_span_is_in_span(span, span_abstract)
    )
    spans_in_remaining = set(
        span for span in labeled_multi_spans if multi_span_is_in_span(span, span_remaining)
    )
    # Spans in neither bucket (e.g. straddling the boundary or preceding the
    # abstract) are silently ignored below, hence the warning.
    spans_not_covered = set(labeled_multi_spans) - spans_in_abstract - spans_in_remaining
    if len(spans_not_covered) > 0:
        logger.warning(
            f"Found {len(spans_not_covered)} spans not covered by abstract or remaining text"
        )

    # Map each span to all (other span, label) pairs it participates in, in
    # both directions; the reverse direction keeps its label only for
    # symmetric relations, otherwise it is suffixed with "_reversed".
    rel_arg_and_label2other = defaultdict(list)
    for rel in doc.binary_relations:
        rel_arg_and_label2other[rel.head].append((rel.tail, rel.label))
        if symmetric_relations is not None and rel.label in symmetric_relations:
            label_reversed = rel.label
        else:
            label_reversed = f"{rel.label}_reversed"
        rel_arg_and_label2other[rel.tail].append((rel.head, label_reversed))

    result_rows = []
    for rel in doc.binary_relations:
        # Only "semantically_same" links from an abstract span (the query)
        # to a remaining-text span (the reference) seed retrieval rows.
        if rel.label == "semantically_same":
            if rel.head in spans_in_abstract and rel.tail in spans_in_remaining:
                # Every span related to the reference span is a gold candidate.
                candidate_spans_with_label = rel_arg_and_label2other[rel.tail]
                for candidate_span, rel_label in candidate_spans_with_label:
                    if (
                        relation_label_whitelist is not None
                        and rel_label not in relation_label_whitelist
                    ):
                        continue
                    # NOTE(review): candidate_span is not checked to lie in the
                    # remaining window, yet its offsets are shifted by the
                    # remaining-window start — confirm candidates can never
                    # come from the abstract section.
                    # NOTE(review): "doc_id" encodes only the window start while
                    # "query_doc_id" encodes start_end — presumably this mirrors
                    # the retrieval-result file naming; verify against consumers.
                    result_row = {
                        "doc_id": f"{doc.id}.remaining.{span_remaining[0]}.txt",
                        "query_doc_id": f"{doc.id}.abstract.{span_abstract[0]}_{span_abstract[1]}.txt",
                        "span": shift_multi_span_slices(candidate_span.slices, -span_remaining[0]),
                        "query_span": shift_multi_span_slices(rel.head.slices, -span_abstract[0]),
                        "ref_span": shift_multi_span_slices(rel.tail.slices, -span_remaining[0]),
                        "type": rel_label,
                        "label": candidate_span.label,
                        "ref_label": rel.tail.label,
                    }
                    result_rows.append(result_row)

    # Signal "no rows" with None rather than an empty DataFrame.
    if len(result_rows) > 0:
        return pd.DataFrame(result_rows)
    else:
        return None
|
|
|
|
|
# Generic text-document type: clean_doc returns the same document type it receives.
D_text = TypeVar("D_text", bound=TextBasedDocument)
|
|
|
|
|
def clean_doc(doc: D_text) -> D_text:
    """Blank out SciArg markup tags in the document text.

    Delegates to ``replace_substrings_in_text_with_spaces``, which (per its
    name) overwrites each tag occurrence with spaces so character offsets of
    existing annotations stay valid.
    """
    markup_tags = [
        "</H2>",
        "<H3>",
        "</Document>",
        "<H1>",
        "<H2>",
        "</H3>",
        "</H1>",
        "<Abstract>",
        "</Abstract>",
    ]
    return replace_substrings_in_text_with_spaces(doc, substrings=markup_tags)
|
|
|
|
|
def main(
    data_dir: str,
    out_path: str,
    doc_id_whitelist: Optional[List[str]] = None,
    symmetric_relations: Optional[List[str]] = None,
    relation_label_whitelist: Optional[List[str]] = None,
) -> None:
    """Create gold retrieval results for SciArg (abstract vs. remaining text).

    Loads the pinned SciArg revision from ``data_dir``, converts it to the
    span layout matching its document type, blanks out markup tags, builds
    gold retrieval rows per document, and writes the concatenated, sorted
    result as JSON to ``out_path``.

    Args:
        data_dir: Directory with the SciArg Brat annotations.
        out_path: Output JSON file path.
        doc_id_whitelist: If given, only keep documents with these ids.
        symmetric_relations: Passed through to ``construct_gold_retrievals``.
        relation_label_whitelist: Passed through to ``construct_gold_retrievals``.

    Raises:
        ValueError: If the loaded dataset has an unsupported document type.
    """
    logger.info(f"Loading dataset from {data_dir}")
    sciarg_with_abstracts = load_dataset(
        "pie/sciarg",
        # Pin the dataset-builder revision for reproducibility.
        revision="171478ce3c13cc484be5d7c9bc8f66d7d2f1c210",
        base_dataset_kwargs={"data_dir": data_dir, "split_paths": None},
        name="resolve_parts_of_same",
        split="train",
    )
    # Convert to the matching "LabeledSpans"/"LabeledMultiSpans" document layout.
    if issubclass(sciarg_with_abstracts.document_type, BratDocument):
        ds_converted = sciarg_with_abstracts.to_document_type(
            TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
        )
    elif issubclass(sciarg_with_abstracts.document_type, BratDocumentWithMergedSpans):
        ds_converted = sciarg_with_abstracts.to_document_type(
            TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
        )
    else:
        raise ValueError(f"Unsupported document type {sciarg_with_abstracts.document_type}")

    ds_clean = ds_converted.map(clean_doc)
    if doc_id_whitelist is not None:
        num_before = len(ds_clean)
        # Set lookup: O(1) membership instead of scanning the list per document.
        allowed_ids = set(doc_id_whitelist)
        ds_clean = [doc for doc in ds_clean if doc.id in allowed_ids]
        logger.info(
            f"Filtered dataset from {num_before} to {len(ds_clean)} documents based on doc_id_whitelist"
        )

    results_per_doc = [
        construct_gold_retrievals(
            doc,
            symmetric_relations=symmetric_relations,
            relation_label_whitelist=relation_label_whitelist,
        )
        for doc in ds_clean
    ]
    # construct_gold_retrievals returns None for documents without results.
    results_per_doc_not_empty = [doc for doc in results_per_doc if doc is not None]
    if len(results_per_doc_not_empty) > 0:
        results = pd.concat(results_per_doc_not_empty, ignore_index=True)
        # Sort by all columns (stringified, since some cells hold lists) to
        # make the output deterministic across runs.
        results = results.sort_values(
            by=results.columns.tolist(), ignore_index=True, key=lambda s: s.apply(str)
        )
        # BUGFIX: os.path.dirname() returns "" for a bare filename and
        # os.makedirs("") raises FileNotFoundError — only create the
        # directory when there actually is one.
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        logger.info(f"Saving result ({len(results)}) to {out_path}")
        results.to_json(out_path)
    else:
        logger.warning("No results found")
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Create gold retrievals for SciArg-abstracts-remaining in the same format as the retrieval results"
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default="data/annotations/sciarg-with-abstracts-and-cross-section-rels",
        help="Path to the sciarg data directory",
    )
    parser.add_argument(
        "--out_path",
        type=str,
        default="data/retrieval_results/sciarg-with-abstracts-and-cross-section-rels/gold.json",
        help="Path to save the results",
    )
    # FIX: main() supports doc_id_whitelist, but the option was missing from
    # the CLI, so it could never be set from the command line.
    parser.add_argument(
        "--doc_id_whitelist",
        type=str,
        nargs="+",
        default=None,
        help="Only consider documents with these ids",
    )
    parser.add_argument(
        "--symmetric_relations",
        type=str,
        nargs="+",
        default=None,
        help="Relations that are symmetric, i.e., if A is related to B, then B is related to A",
    )
    parser.add_argument(
        "--relation_label_whitelist",
        type=str,
        nargs="+",
        default=None,
        help="Only consider relations with these labels",
    )

    logging.basicConfig(level=logging.INFO)

    # Argument names match main()'s keyword parameters exactly.
    kwargs = vars(parser.parse_args())
    main(**kwargs)
    logger.info("Done")
|
|