import logging
from typing import Any, Dict, Generic, Optional, Sequence, TypeVar, Union

from pytorch_ie.core import Document
from pytorch_ie.core.taskmodule import (
    IterableTaskEncodingDataset,
    TaskEncoding,
    TaskEncodingDataset,
    TaskModule,
)
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Sampler
from typing_extensions import TypeAlias

from .components.sampler import ImbalancedDatasetSampler

DocumentType = TypeVar("DocumentType", bound=Document)
InputEncoding = TypeVar("InputEncoding")
TargetEncoding = TypeVar("TargetEncoding")
DatasetType: TypeAlias = Union[
    TaskEncodingDataset[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]],
    IterableTaskEncodingDataset[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]],
]

logger = logging.getLogger(__name__)


class PieDataModule(LightningDataModule, Generic[DocumentType, InputEncoding, TargetEncoding]):
    """A simple LightningDataModule for PIE document datasets.

    A DataModule implements 5 key methods:
        - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode)
        - setup (things to do on every accelerator in distributed mode)
        - train_dataloader (the training dataloader)
        - val_dataloader (the validation dataloader(s))
        - test_dataloader (the test dataloader(s))

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
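
    Example (a minimal usage sketch; ``MyTaskModule``, ``my_train_docs``,
    ``my_val_docs`` and the ``batch_size`` value are placeholders, not part of
    this module)::

        taskmodule = MyTaskModule(...)
        datamodule = PieDataModule(
            taskmodule=taskmodule,
            dataset={"train": my_train_docs, "validation": my_val_docs},
            batch_size=32,  # forwarded to the DataLoader via **dataloader_kwargs
        )
        datamodule.setup(stage="fit")
        train_loader = datamodule.train_dataloader()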
    """

    def __init__(
        self,
        taskmodule: TaskModule[DocumentType, InputEncoding, TargetEncoding, Any, Any, Any],
        dataset: Dict[str, Sequence[DocumentType]],
        data_config_path: Optional[str] = None,
        train_split: Optional[str] = "train",
        val_split: Optional[str] = "validation",
        test_split: Optional[str] = "test",
        show_progress_for_encode: bool = False,
        train_sampler: Optional[str] = None,
        dont_shuffle_train: bool = False,
        **dataloader_kwargs,
    ):
        super().__init__()

        self.taskmodule = taskmodule
        self.config_path = data_config_path
        self.dataset = dataset
        self.train_split = train_split
        self.val_split = val_split
        self.test_split = test_split
        self.show_progress_for_encode = show_progress_for_encode
        self.train_sampler_name = train_sampler
        self.dataloader_kwargs = dataloader_kwargs
        self.dont_shuffle_train = dont_shuffle_train

        self._data: Dict[str, DatasetType] = {}

    @property
    def num_train(self) -> int:
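        """Number of task encodings in the train split.

        Requires that setup() has been called.
        """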
        if self.train_split is None:
            raise ValueError("no train_split assigned")
        data_train = self._data.get(self.train_split, None)
        if data_train is None:
            raise ValueError("can not get train size if setup() was not yet called")
        if isinstance(data_train, IterableTaskEncodingDataset):
            raise TypeError("IterableTaskEncodingDataset has no length")
        return len(data_train)

    def setup(self, stage: str):
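        """Encode the document splits required for the given stage."""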
        if stage == "fit":
            split_names = [self.train_split, self.val_split]
        elif stage == "validate":
            split_names = [self.val_split]
        elif stage == "test":
            split_names = [self.test_split]
        else:
            raise NotImplementedError(f"not implemented for stage={stage} ")

        for split in split_names:
            if split is None or split not in self.dataset:
                continue
            task_encoding_dataset = self.taskmodule.encode(
                self.dataset[split],
                encode_target=True,
                as_dataset=True,
                show_progress=self.show_progress_for_encode,
            )
            if not isinstance(
                task_encoding_dataset,
                (TaskEncodingDataset, IterableTaskEncodingDataset),
            ):
                raise TypeError(
                    f"taskmodule.encode did not return a (Iterable)TaskEncodingDataset, but: {type(task_encoding_dataset)}"
                )
            self._data[split] = task_encoding_dataset

    def data_split(self, split: Optional[str] = None) -> DatasetType:
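        """Return the encoded dataset for the given split.

        Raises a ValueError if the split is not available.
        """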
        if split is None or split not in self._data:
            raise ValueError(f"data for split={split} not available")
        return self._data[split]

    def get_train_sampler(
        self,
        sampler_name: str,
        dataset: DatasetType,
    ) -> Sampler:
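        """Create the train sampler with the given name for the dataset.

        Currently, only "imbalanced_dataset" is supported.
        """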
        if sampler_name == "imbalanced_dataset":
            # for now, this works only with targets that have a single entry
            return ImbalancedDatasetSampler(
                dataset, callback_get_label=lambda ds: [x.targets[0] for x in ds]
            )
        else:
            raise ValueError(f"unknown sampler name: {sampler_name}")

    def train_dataloader(self):
        ds = self.data_split(self.train_split)
        if self.train_sampler_name is not None:
            sampler = self.get_train_sampler(sampler_name=self.train_sampler_name, dataset=ds)
        else:
            sampler = None
        # don't shuffle if the dataset is streamed (iterable), if a sampler is used, or if dont_shuffle_train is set
        shuffle = not self.dont_shuffle_train and not (
            isinstance(ds, IterableTaskEncodingDataset) or sampler is not None
        )
        if not shuffle:
            logger.warning("not shuffling train dataloader")
        return DataLoader(
            dataset=ds,
            sampler=sampler,
            collate_fn=self.taskmodule.collate,
            shuffle=shuffle,
            **self.dataloader_kwargs,
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.data_split(self.val_split),
            collate_fn=self.taskmodule.collate,
            shuffle=False,
            **self.dataloader_kwargs,
        )

    def test_dataloader(self):
        return DataLoader(
            dataset=self.data_split(self.test_split),
            collate_fn=self.taskmodule.collate,
            shuffle=False,
            **self.dataloader_kwargs,
        )