from typing import Callable, Type, Union | |
from pie_datasets import Dataset, DatasetDict | |
from pytorch_ie import Document | |
from pytorch_ie.utils.hydra import resolve_optional_document_type, resolve_target | |
# TODO: simply use DatasetDict.map() with set_batch_size_to_split_size=True and
# batched=True instead once https://github.com/ArneBinder/pie-datasets/pull/155 is merged
def apply_func_to_splits(
    dataset: DatasetDict,
    function: Union[str, Callable],
    result_document_type: Type[Document],
    **kwargs
):
    """Apply a batched function to every split of a dataset dict.

    Each split is processed by its own ``map`` call with ``batched=True`` and
    ``batch_size`` set to the number of entries in that split. The intent is
    that ``function`` receives the whole split as a single batch.

    Args:
        dataset: The dataset dict whose splits are processed.
        function: The function to apply, or a reference that
            ``resolve_target`` turns into one.
        result_document_type: Resolved via ``resolve_optional_document_type``
            and passed to each ``map`` call as ``result_document_type``.
        **kwargs: Forwarded unchanged to every split's ``map`` call.

    Returns:
        A new ``DatasetDict`` that maps each original split name to its
        processed split.
    """
    func = resolve_target(function)
    doc_type = resolve_optional_document_type(document_type=result_document_type)

    def _convert(ds):
        # Use the split's own length as the batch size so that the split forms
        # one batch (workaround until DatasetDict.map supports this directly).
        return ds.map(
            function=func,
            batched=True,
            batch_size=len(ds),
            result_document_type=doc_type,
            **kwargs
        )

    return DatasetDict({name: _convert(ds) for name, ds in dataset.items()})