File size: 4,473 Bytes
ced4316
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import pyrootutils

root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)

import argparse
import logging
import os
from typing import List, Optional, TypeVar

from pie_datasets import load_dataset
from pie_datasets.builders.brat import BratDocument, BratDocumentWithMergedSpans
from pytorch_ie.documents import (
    TextBasedDocument,
    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
    TextDocumentWithLabeledPartitions,
    TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
)

from src.document.processing import replace_substrings_in_text_with_spaces

logger = logging.getLogger(__name__)


def save_abstract_and_remaining_text(
    doc: TextDocumentWithLabeledPartitions, base_path: str
) -> None:
    """Split the document text at its abstract and write both parts to disk.

    Expects exactly one labeled partition whose label is "abstract"
    (case-insensitive, surrounding whitespace ignored); otherwise a warning
    is logged and nothing is written.

    Two files are created next to ``base_path``:
      - ``<base_path>.abstract.<start>_<end>.txt`` — the abstract text
      - ``<base_path>.remaining.<end>.txt`` — everything after the abstract

    :param doc: document with labeled partitions (e.g. section annotations)
    :param base_path: path prefix for the two output files
    """
    abstract_annotations = [
        span for span in doc.labeled_partitions if span.label.lower().strip() == "abstract"
    ]
    if len(abstract_annotations) != 1:
        logger.warning(
            f"Expected exactly one abstract annotation, found {len(abstract_annotations)}"
        )
        return
    abstract_annotation = abstract_annotations[0]
    text_abstract = doc.text[abstract_annotation.start : abstract_annotation.end]
    text_remaining = doc.text[abstract_annotation.end :]
    # write with an explicit encoding so the output does not depend on the
    # platform/locale default (previously unspecified)
    with open(
        f"{base_path}.abstract.{abstract_annotation.start}_{abstract_annotation.end}.txt",
        "w",
        encoding="utf-8",
    ) as f:
        f.write(text_abstract)
    with open(f"{base_path}.remaining.{abstract_annotation.end}.txt", "w", encoding="utf-8") as f:
        f.write(text_remaining)


D_text = TypeVar("D_text", bound=TextBasedDocument)


def clean_doc(doc: D_text) -> D_text:
    """Blank out the XML markup tags in the document text.

    Note that the Abstract tags are removed here as well, in contrast to the
    preprocessing pipeline (see configs/dataset/sciarg_cleaned.yaml), because
    there the abstracts are removed entirely.
    """
    xml_tags = [
        "</H2>",
        "<H3>",
        "</Document>",
        "<H1>",
        "<H2>",
        "</H3>",
        "</H1>",
        "<Abstract>",
        "</Abstract>",
    ]
    return replace_substrings_in_text_with_spaces(doc, substrings=xml_tags)


def main(out_dir: str, doc_id_whitelist: Optional[List[str]] = None) -> None:
    """Load the SciArg dataset, strip XML tags, and write each document's
    abstract and remaining text to plain-text files in ``out_dir``.

    :param out_dir: directory for the output files (created if missing)
    :param doc_id_whitelist: if given, only documents with these ids are kept
    """
    logger.info("Loading dataset from pie/sciarg")
    sciarg_with_abstracts = load_dataset(
        "pie/sciarg",
        revision="171478ce3c13cc484be5d7c9bc8f66d7d2f1c210",
        split="train",
    )
    # choose the target document type depending on the loaded document type
    if issubclass(sciarg_with_abstracts.document_type, BratDocument):
        ds_converted = sciarg_with_abstracts.to_document_type(
            TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
        )
    elif issubclass(sciarg_with_abstracts.document_type, BratDocumentWithMergedSpans):
        ds_converted = sciarg_with_abstracts.to_document_type(
            TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
        )
    else:
        raise ValueError(f"Unsupported document type {sciarg_with_abstracts.document_type}")

    ds_clean = ds_converted.map(clean_doc)
    if doc_id_whitelist is not None:
        num_before = len(ds_clean)
        # build a set once for O(1) membership tests instead of scanning the
        # whitelist list for every document
        allowed_ids = set(doc_id_whitelist)
        ds_clean = [doc for doc in ds_clean if doc.id in allowed_ids]
        logger.info(
            f"Filtered dataset from {num_before} to {len(ds_clean)} documents based on doc_id_whitelist"
        )

    os.makedirs(out_dir, exist_ok=True)
    logger.info(f"Saving dataset to {out_dir}")
    for doc in ds_clean:
        save_abstract_and_remaining_text(doc, os.path.join(out_dir, doc.id))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Split SciArg dataset into abstract and remaining text"
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        default="data/datasets/sciarg/abstracts_and_remaining_text",
        help="Path to save the split data",
    )
    parser.add_argument(
        "--doc_id_whitelist",
        type=str,
        nargs="+",
        default=["A32", "A33", "A34", "A35", "A36", "A37", "A38", "A39", "A40"],
        help="List of document ids to include in the split",
    )

    logging.basicConfig(level=logging.INFO)

    args = parser.parse_args()
    doc_id_whitelist = args.doc_id_whitelist
    # the special value "all" disables filtering and includes every document
    if len(doc_id_whitelist) == 1 and doc_id_whitelist[0].lower() == "all":
        doc_id_whitelist = None
    main(out_dir=args.out_dir, doc_id_whitelist=doc_id_whitelist)
    logger.info("Done")