File size: 4,473 Bytes
ced4316 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import pyrootutils

# Locate the project root (the directory containing a ".project-root" marker
# file), add it to sys.path so that absolute imports like
# `src.document.processing` resolve, and load environment variables from a
# .env file at the root if present. Must run before the `src.*` import below.
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)
import argparse
import logging
import os
from typing import List, Optional, TypeVar
from pie_datasets import load_dataset
from pie_datasets.builders.brat import BratDocument, BratDocumentWithMergedSpans
from pytorch_ie.documents import (
TextBasedDocument,
TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
TextDocumentWithLabeledPartitions,
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
)
from src.document.processing import replace_substrings_in_text_with_spaces
logger = logging.getLogger(__name__)
def save_abstract_and_remaining_text(
    doc: TextDocumentWithLabeledPartitions, base_path: str
) -> None:
    """Split the document text at its abstract and write both parts to disk.

    Two UTF-8 text files are created:
    ``{base_path}.abstract.{start}_{end}.txt`` holding the abstract text, and
    ``{base_path}.remaining.{end}.txt`` holding everything after the abstract
    (the character offsets are encoded in the file names so the pieces can be
    mapped back to the original text).

    If the document does not contain exactly one partition labeled "abstract"
    (case-insensitive, surrounding whitespace ignored), a warning is logged
    and no files are written.

    Args:
        doc: document whose ``labeled_partitions`` mark the abstract section.
        base_path: output path prefix (typically ``out_dir/doc_id``).
    """
    abstract_annotations = [
        span for span in doc.labeled_partitions if span.label.lower().strip() == "abstract"
    ]
    if len(abstract_annotations) != 1:
        logger.warning(
            f"Expected exactly one abstract annotation, found {len(abstract_annotations)}"
        )
        return
    abstract_annotation = abstract_annotations[0]
    text_abstract = doc.text[abstract_annotation.start : abstract_annotation.end]
    # Note: text *before* the abstract (e.g. the title) is intentionally dropped.
    text_remaining = doc.text[abstract_annotation.end :]
    # Fix: write with an explicit encoding so the output does not depend on the
    # platform's default codec (previously locale-dependent, e.g. cp1252 on Windows).
    with open(
        f"{base_path}.abstract.{abstract_annotation.start}_{abstract_annotation.end}.txt",
        "w",
        encoding="utf-8",
    ) as f:
        f.write(text_abstract)
    with open(
        f"{base_path}.remaining.{abstract_annotation.end}.txt", "w", encoding="utf-8"
    ) as f:
        f.write(text_remaining)
D_text = TypeVar("D_text", bound=TextBasedDocument)

# Markup tags that occur verbatim in the raw SciArg document texts and should
# be blanked out (replaced by spaces, so character offsets stay valid).
_TAGS_TO_BLANK = [
    "</H2>",
    "<H3>",
    "</Document>",
    "<H1>",
    "<H2>",
    "</H3>",
    "</H1>",
    "<Abstract>",
    "</Abstract>",
]


def clean_doc(doc: D_text) -> D_text:
    """Return a copy of ``doc`` with all known XML tags blanked out of the text.

    Unlike the preprocessing pipeline (see configs/dataset/sciarg_cleaned.yaml),
    the Abstract tags are removed here as well, because that pipeline drops the
    abstracts entirely whereas this script keeps them.
    """
    return replace_substrings_in_text_with_spaces(doc, substrings=_TAGS_TO_BLANK)
def main(out_dir: str, doc_id_whitelist: Optional[List[str]] = None) -> None:
    """Load the SciArg dataset, clean the texts, and export each document's
    abstract and remaining text as separate files under ``out_dir``.

    Args:
        out_dir: target directory for the per-document text files
            (created if it does not exist).
        doc_id_whitelist: if given, only documents whose id appears in this
            list are exported; ``None`` exports every document.

    Raises:
        ValueError: if the loaded dataset has an unsupported document type.
    """
    logger.info("Loading dataset from pie/sciarg")
    # Pin the dataset revision for reproducibility.
    sciarg_with_abstracts = load_dataset(
        "pie/sciarg",
        revision="171478ce3c13cc484be5d7c9bc8f66d7d2f1c210",
        split="train",
    )
    # Convert to a plain-text document type with labeled partitions; which
    # target type applies depends on whether the builder produced multi-span
    # or merged-span documents.
    if issubclass(sciarg_with_abstracts.document_type, BratDocument):
        ds_converted = sciarg_with_abstracts.to_document_type(
            TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
        )
    elif issubclass(sciarg_with_abstracts.document_type, BratDocumentWithMergedSpans):
        ds_converted = sciarg_with_abstracts.to_document_type(
            TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
        )
    else:
        raise ValueError(f"Unsupported document type {sciarg_with_abstracts.document_type}")
    ds_clean = ds_converted.map(clean_doc)
    if doc_id_whitelist is not None:
        num_before = len(ds_clean)
        # Build a set once so membership testing is O(1) per document instead
        # of scanning the whitelist list for every document.
        whitelisted_ids = set(doc_id_whitelist)
        ds_clean = [doc for doc in ds_clean if doc.id in whitelisted_ids]
        logger.info(
            f"Filtered dataset from {num_before} to {len(ds_clean)} documents based on doc_id_whitelist"
        )
    os.makedirs(out_dir, exist_ok=True)
    logger.info(f"Saving dataset to {out_dir}")
    for doc in ds_clean:
        save_abstract_and_remaining_text(doc, os.path.join(out_dir, doc.id))
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    arg_parser = argparse.ArgumentParser(
        description="Split SciArg dataset into abstract and remaining text"
    )
    arg_parser.add_argument(
        "--out_dir",
        type=str,
        default="data/datasets/sciarg/abstracts_and_remaining_text",
        help="Path to save the split data",
    )
    arg_parser.add_argument(
        "--doc_id_whitelist",
        type=str,
        nargs="+",
        default=["A32", "A33", "A34", "A35", "A36", "A37", "A38", "A39", "A40"],
        help="List of document ids to include in the split",
    )
    args = arg_parser.parse_args()

    doc_id_whitelist: Optional[List[str]] = args.doc_id_whitelist
    # A single whitelist entry "all" (case-insensitive) disables filtering.
    if [entry.lower() for entry in doc_id_whitelist] == ["all"]:
        doc_id_whitelist = None

    main(out_dir=args.out_dir, doc_id_whitelist=doc_id_whitelist)
    logger.info("Done")
|