import pyrootutils

root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)

import argparse
import logging
import os
from typing import List, Optional, TypeVar

from pie_datasets import load_dataset
from pie_datasets.builders.brat import BratDocument, BratDocumentWithMergedSpans
from pytorch_ie.documents import (
    TextBasedDocument,
    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
    TextDocumentWithLabeledPartitions,
    TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
)

from src.document.processing import replace_substrings_in_text_with_spaces

logger = logging.getLogger(__name__)


def save_abstract_and_remaining_text(
    doc: TextDocumentWithLabeledPartitions, base_path: str
) -> None:
    """Split the document text at its abstract and write both parts to disk.

    Two files are written next to ``base_path``:

    - ``<base_path>.abstract.<start>_<end>.txt``: the abstract text
    - ``<base_path>.remaining.<end>.txt``: everything after the abstract

    The span offsets are encoded into the file names so the split can be
    reversed later. If the document does not contain exactly one partition
    labeled "abstract" (case-insensitive), a warning is logged and nothing
    is written.

    :param doc: document with labeled partitions, one of which marks the abstract
    :param base_path: path prefix (typically ``out_dir/doc_id``) for the output files
    """
    abstract_annotations = [
        span for span in doc.labeled_partitions if span.label.lower().strip() == "abstract"
    ]
    if len(abstract_annotations) != 1:
        logger.warning(
            f"Expected exactly one abstract annotation, found {len(abstract_annotations)}"
        )
        return
    abstract_annotation = abstract_annotations[0]
    text_abstract = doc.text[abstract_annotation.start : abstract_annotation.end]
    text_remaining = doc.text[abstract_annotation.end :]
    # explicit encoding: the corpus text is UTF-8; without it the platform
    # default could mangle non-ASCII characters in the output files
    with open(
        f"{base_path}.abstract.{abstract_annotation.start}_{abstract_annotation.end}.txt",
        "w",
        encoding="utf-8",
    ) as f:
        f.write(text_abstract)
    with open(
        f"{base_path}.remaining.{abstract_annotation.end}.txt", "w", encoding="utf-8"
    ) as f:
        f.write(text_remaining)


D_text = TypeVar("D_text", bound=TextBasedDocument)


def clean_doc(doc: D_text) -> D_text:
    """Blank out XML markup in the document text (annotation offsets are preserved,
    because the substrings are replaced with spaces rather than removed).

    Note that we also remove the Abstract tag, in contrast to the preprocessing
    pipeline (see configs/dataset/sciarg_cleaned.yaml). This is because there, the
    abstracts are removed completely.
    """
    # NOTE(review): the tag string literals in this list were corrupted in the
    # version under review (the XML tags inside the quotes were stripped by a
    # tag-eating extraction step, leaving empty/newline-only strings). The
    # entries below are a reconstruction matching the observed empty/multi-line
    # pattern of the garble — confirm them against the repository history
    # before relying on this script.
    doc = replace_substrings_in_text_with_spaces(
        doc,
        substrings=[
            '<?xml version="1.0" encoding="UTF-8"?>',
            "\n<Title>\n",
            "</Title>",
            "\n<Abstract>\n",
            "\n</Abstract>\n",
            "\n<H1>\n",
            "</H1>",
            "<H2>",
            "</H2>",
        ],
    )
    return doc


def main(out_dir: str, doc_id_whitelist: Optional[List[str]] = None) -> None:
    """Load the SciArg dataset, clean its text, and write per-document
    abstract / remaining-text files.

    :param out_dir: directory where the per-document text files are written
        (created if missing)
    :param doc_id_whitelist: if given, only documents whose id is contained in
        this list are saved; ``None`` keeps all documents
    """
    logger.info("Loading dataset from pie/sciarg")
    sciarg_with_abstracts = load_dataset(
        "pie/sciarg",
        revision="171478ce3c13cc484be5d7c9bc8f66d7d2f1c210",
        split="train",
    )
    # convert to a document type that exposes labeled partitions (needed to
    # locate the abstract); the target type depends on the loaded builder
    if issubclass(sciarg_with_abstracts.document_type, BratDocument):
        ds_converted = sciarg_with_abstracts.to_document_type(
            TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
        )
    elif issubclass(sciarg_with_abstracts.document_type, BratDocumentWithMergedSpans):
        ds_converted = sciarg_with_abstracts.to_document_type(
            TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
        )
    else:
        raise ValueError(f"Unsupported document type {sciarg_with_abstracts.document_type}")

    ds_clean = ds_converted.map(clean_doc)
    if doc_id_whitelist is not None:
        num_before = len(ds_clean)
        # set for O(1) membership tests instead of scanning the list per document
        allowed_ids = set(doc_id_whitelist)
        ds_clean = [doc for doc in ds_clean if doc.id in allowed_ids]
        logger.info(
            f"Filtered dataset from {num_before} to {len(ds_clean)} documents based on doc_id_whitelist"
        )

    os.makedirs(out_dir, exist_ok=True)
    logger.info(f"Saving dataset to {out_dir}")
    for doc in ds_clean:
        save_abstract_and_remaining_text(doc, os.path.join(out_dir, doc.id))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Split SciArg dataset into abstract and remaining text"
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        default="data/datasets/sciarg/abstracts_and_remaining_text",
        help="Path to save the split data",
    )
    parser.add_argument(
        "--doc_id_whitelist",
        type=str,
        nargs="+",
        default=["A32", "A33", "A34", "A35", "A36", "A37", "A38", "A39", "A40"],
        help="List of document ids to include in the split",
    )

    logging.basicConfig(level=logging.INFO)

    kwargs = vars(parser.parse_args())
    # allow for "all" to include all documents
    if len(kwargs["doc_id_whitelist"]) == 1 and kwargs["doc_id_whitelist"][0].lower() == "all":
        kwargs["doc_id_whitelist"] = None
    main(**kwargs)

    logger.info("Done")