import logging
from typing import Any, Dict, Generic, Optional, Sequence, TypeVar, Union

from pytorch_ie.core import Document
from pytorch_ie.core.taskmodule import (
    IterableTaskEncodingDataset,
    TaskEncoding,
    TaskEncodingDataset,
    TaskModule,
)
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Sampler
from typing_extensions import TypeAlias

from .components.sampler import ImbalancedDatasetSampler

DocumentType = TypeVar("DocumentType", bound=Document)
InputEncoding = TypeVar("InputEncoding")
TargetEncoding = TypeVar("TargetEncoding")

DatasetType: TypeAlias = Union[
    TaskEncodingDataset[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]],
    IterableTaskEncodingDataset[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]],
]

logger = logging.getLogger(__name__)


class PieDataModule(LightningDataModule, Generic[DocumentType, InputEncoding, TargetEncoding]):
    """A simple LightningDataModule for PIE document datasets.

    A DataModule implements 5 key methods:
        - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode)
        - setup (things to do on every accelerator in distributed mode)
        - train_dataloader (the training dataloader)
        - val_dataloader (the validation dataloader(s))
        - test_dataloader (the test dataloader(s))

    This allows you to share a full dataset without explaining how to download,
    split, transform, and process the data.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
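
    Example (a minimal sketch; ``my_taskmodule`` and ``my_dataset`` are assumed to be
    a prepared taskmodule and a dict with "train" and "validation" document splits,
    and ``batch_size`` is simply passed through to the dataloaders):

        datamodule = PieDataModule(
            taskmodule=my_taskmodule,
            dataset=my_dataset,
            batch_size=32,
        )
        datamodule.setup(stage="fit")
        train_dataloader = datamodule.train_dataloader()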
"""

    def __init__(
        self,
        taskmodule: TaskModule[DocumentType, InputEncoding, TargetEncoding, Any, Any, Any],
        dataset: Dict[str, Sequence[DocumentType]],
        data_config_path: Optional[str] = None,
        train_split: Optional[str] = "train",
        val_split: Optional[str] = "validation",
        test_split: Optional[str] = "test",
        show_progress_for_encode: bool = False,
        train_sampler: Optional[str] = None,
        dont_shuffle_train: bool = False,
        **dataloader_kwargs,
    ):
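        """
        Args:
            taskmodule: the taskmodule used to encode the documents and collate batches.
            dataset: a mapping from split names to sequences of documents.
            data_config_path: optional path to a data config; it is just stored at
                ``self.config_path``.
            train_split: the name of the split used by train_dataloader().
            val_split: the name of the split used by val_dataloader().
            test_split: the name of the split used by test_dataloader().
            show_progress_for_encode: whether to show a progress bar when encoding.
            train_sampler: optional name of a sampler for the train dataloader; currently,
                only "imbalanced_dataset" is supported.
            dont_shuffle_train: if True, never shuffle the train dataloader.
            **dataloader_kwargs: additional keyword arguments passed to all created
                DataLoaders (e.g. batch_size, num_workers).
        """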
        super().__init__()

        self.taskmodule = taskmodule
        self.config_path = data_config_path
        self.dataset = dataset
        self.train_split = train_split
        self.val_split = val_split
        self.test_split = test_split
        self.show_progress_for_encode = show_progress_for_encode
        self.train_sampler_name = train_sampler
        self.dataloader_kwargs = dataloader_kwargs
        self.dont_shuffle_train = dont_shuffle_train

        self._data: Dict[str, DatasetType] = {}

    @property
    def num_train(self) -> int:
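        """The number of entries in the encoded train dataset. This requires that setup()
        was already called and is not available for iterable datasets."""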
        if self.train_split is None:
            raise ValueError("no train_split assigned")
        data_train = self._data.get(self.train_split, None)
        if data_train is None:
            raise ValueError("cannot get train size if setup() was not yet called")
        if isinstance(data_train, IterableTaskEncodingDataset):
            raise TypeError("IterableTaskEncodingDataset has no length")
        return len(data_train)

    def setup(self, stage: str):
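        """Encode the documents of all splits that are relevant for the given stage ("fit",
        "validate", or "test") and store the resulting task encoding datasets."""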
if stage == "fit":
split_names = [self.train_split, self.val_split]
elif stage == "validate":
split_names = [self.val_split]
elif stage == "test":
split_names = [self.test_split]
else:
raise NotImplementedError(f"not implemented for stage={stage} ")
for split in split_names:
if split is None or split not in self.dataset:
continue
task_encoding_dataset = self.taskmodule.encode(
self.dataset[split],
encode_target=True,
as_dataset=True,
show_progress=self.show_progress_for_encode,
)
if not isinstance(
task_encoding_dataset,
(TaskEncodingDataset, IterableTaskEncodingDataset),
):
raise TypeError(
f"taskmodule.encode did not return a (Iterable)TaskEncodingDataset, but: {type(task_encoding_dataset)}"
)
self._data[split] = task_encoding_dataset

    def data_split(self, split: Optional[str] = None) -> DatasetType:
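        """Return the encoded dataset for the given split.

        Raises a ValueError if the data is not available, e.g. because setup() was not
        yet called for the respective stage."""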
        if split is None or split not in self._data:
            raise ValueError(f"data for split={split} not available")
        return self._data[split]

    def get_train_sampler(
        self,
        sampler_name: str,
        dataset: DatasetType,
    ) -> Sampler:
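        """Create a sampler instance for the train dataloader. Currently, only
        "imbalanced_dataset" is supported, which uses the first target entry as the
        label of each task encoding."""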
if sampler_name == "imbalanced_dataset":
# for now, this work only with targets that have a single entry
return ImbalancedDatasetSampler(
dataset, callback_get_label=lambda ds: [x.targets[0] for x in ds]
)
else:
raise ValueError(f"unknown sampler name: {sampler_name}")

    def train_dataloader(self):
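        """Return the train dataloader. The data is shuffled unless a sampler is used,
        the dataset is iterable, or dont_shuffle_train is set."""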
        ds = self.data_split(self.train_split)
        if self.train_sampler_name is not None:
            sampler = self.get_train_sampler(sampler_name=self.train_sampler_name, dataset=ds)
        else:
            sampler = None
        # don't shuffle if the dataset is streamed (has no length), if we use a sampler,
        # or if dont_shuffle_train was explicitly set
        shuffle = not self.dont_shuffle_train and not (
            isinstance(ds, IterableTaskEncodingDataset) or sampler is not None
        )
        if not shuffle:
            logger.warning("not shuffling train dataloader")
        return DataLoader(
            dataset=ds,
            sampler=sampler,
            collate_fn=self.taskmodule.collate,
            shuffle=shuffle,
            **self.dataloader_kwargs,
        )

    def val_dataloader(self):
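        """Return the validation dataloader (not shuffled)."""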
        return DataLoader(
            dataset=self.data_split(self.val_split),
            collate_fn=self.taskmodule.collate,
            shuffle=False,
            **self.dataloader_kwargs,
        )

    def test_dataloader(self):
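        """Return the test dataloader (not shuffled)."""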
        return DataLoader(
            dataset=self.data_split(self.test_split),
            collate_fn=self.taskmodule.collate,
            shuffle=False,
            **self.dataloader_kwargs,
        )