import pyrootutils
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)
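# setup_root searches upwards from this file for the ".project-root" indicator file,
# prepends the project root to sys.path (pythonpath=True) so that the `src.` import
# below resolves, and loads environment variables from a `.env` file (dotenv=True).
# This must run before the project-local imports, hence the import order.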
import argparse
import logging
import os
from typing import List, Optional, TypeVar
from pie_datasets import load_dataset
from pie_datasets.builders.brat import BratDocument, BratDocumentWithMergedSpans
from pytorch_ie.documents import (
    TextBasedDocument,
    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
    TextDocumentWithLabeledPartitions,
    TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
)
from src.document.processing import replace_substrings_in_text_with_spaces
logger = logging.getLogger(__name__)


def save_abstract_and_remaining_text(
    doc: TextDocumentWithLabeledPartitions, base_path: str
) -> None:
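    """Write the abstract and the text after the abstract of `doc` to two separate files.

    The files are named `{base_path}.abstract.{start}_{end}.txt` and
    `{base_path}.remaining.{end}.txt`, where `start` and `end` are the character offsets
    of the abstract partition in the document text. Documents that do not have exactly
    one "abstract" partition are skipped with a warning.
    """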
    abstract_annotations = [
        span for span in doc.labeled_partitions if span.label.lower().strip() == "abstract"
    ]
    if len(abstract_annotations) != 1:
        logger.warning(
            f"{doc.id}: expected exactly one abstract annotation, "
            f"found {len(abstract_annotations)}"
        )
        return
    abstract_annotation = abstract_annotations[0]
    text_abstract = doc.text[abstract_annotation.start : abstract_annotation.end]
    text_remaining = doc.text[abstract_annotation.end :]
    with open(
        f"{base_path}.abstract.{abstract_annotation.start}_{abstract_annotation.end}.txt",
        "w",
        encoding="utf-8",
    ) as f:
        f.write(text_abstract)
    with open(
        f"{base_path}.remaining.{abstract_annotation.end}.txt", "w", encoding="utf-8"
    ) as f:
        f.write(text_remaining)


D_text = TypeVar("D_text", bound=TextBasedDocument)
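

# clean_doc is generic over the concrete document type: the D_text TypeVar (bound to
# TextBasedDocument) tells type checkers that the function returns the same document
# type it receives, so `.map(clean_doc)` preserves the dataset's document type.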
def clean_doc(doc: D_text) -> D_text:
    # Remove XML markup tags. Note that we also remove the Abstract tags, in contrast to
    # the preprocessing pipeline (see configs/dataset/sciarg_cleaned.yaml), because there
    # the abstracts are removed completely. The tag set below is assumed from the XML
    # markup that occurs in the raw SciArg text.
    doc = replace_substrings_in_text_with_spaces(
        doc,
        substrings=[
            '<?xml version="1.0" encoding="UTF-8" standalone="no"?>',
            "<Abstract>\n",
            "</Abstract>",
            "<H1>",
            "</H1>",
            "<H2>\n",
            "</H2>",
            "<H3>",
            "</H3>",
        ],
    )
    return doc


def main(out_dir: str, doc_id_whitelist: Optional[List[str]] = None) -> None:
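    """Load the SciArg dataset, strip XML markup, and save each document's abstract and
    remaining text as separate files under `out_dir`.

    If `doc_id_whitelist` is given, only documents with those ids are saved.
    """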
    logger.info("Loading dataset from pie/sciarg")
    sciarg_with_abstracts = load_dataset(
        "pie/sciarg",
        revision="171478ce3c13cc484be5d7c9bc8f66d7d2f1c210",
        split="train",
    )
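    # The pie/sciarg builder provides two document variants: BratDocument keeps
    # discontinuous (multi-) spans, while BratDocumentWithMergedSpans merges the
    # fragments of each span into a single span. Convert to the matching text-document
    # type so that `labeled_partitions` is available downstream.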
    if issubclass(sciarg_with_abstracts.document_type, BratDocument):
        ds_converted = sciarg_with_abstracts.to_document_type(
            TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
        )
    elif issubclass(sciarg_with_abstracts.document_type, BratDocumentWithMergedSpans):
        ds_converted = sciarg_with_abstracts.to_document_type(
            TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
        )
    else:
        raise ValueError(f"Unsupported document type {sciarg_with_abstracts.document_type}")

    ds_clean = ds_converted.map(clean_doc)
    if doc_id_whitelist is not None:
        num_before = len(ds_clean)
        whitelist = set(doc_id_whitelist)  # set membership checks are O(1)
        ds_clean = [doc for doc in ds_clean if doc.id in whitelist]
        logger.info(
            f"Filtered dataset from {num_before} to {len(ds_clean)} documents "
            f"based on doc_id_whitelist"
        )

    os.makedirs(out_dir, exist_ok=True)
    logger.info(f"Saving dataset to {out_dir}")
    for doc in ds_clean:
        save_abstract_and_remaining_text(doc, os.path.join(out_dir, doc.id))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Split the SciArg dataset into abstract and remaining text"
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        default="data/datasets/sciarg/abstracts_and_remaining_text",
        help="path to the directory where the split files are saved",
    )
    parser.add_argument(
        "--doc_id_whitelist",
        type=str,
        nargs="+",
        default=["A32", "A33", "A34", "A35", "A36", "A37", "A38", "A39", "A40"],
        help="list of document ids to include in the split; pass the single value 'all' "
        "to include all documents",
    )
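    # Example invocation (the script path below is a placeholder for this file's
    # location in the repo):
    #   python path/to/this_script.py --doc_id_whitelist all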
    logging.basicConfig(level=logging.INFO)

    kwargs = vars(parser.parse_args())
    # allow "all" to include all documents
    if len(kwargs["doc_id_whitelist"]) == 1 and kwargs["doc_id_whitelist"][0].lower() == "all":
        kwargs["doc_id_whitelist"] = None
    main(**kwargs)

    logger.info("Done")