from typing import Callable, Optional, Type, Union

from pie_datasets import Dataset, DatasetDict
from pytorch_ie import Document
from pytorch_ie.utils.hydra import resolve_optional_document_type, resolve_target


# TODO: simply use DatasetDict.map() with set_batch_size_to_split_size=True and
#  batched=True instead when https://github.com/ArneBinder/pie-datasets/pull/155 is merged
def apply_func_to_splits(
    dataset: DatasetDict,
    function: Union[str, Callable],
    result_document_type: Optional[Union[str, Type[Document]]],
    **kwargs,
) -> DatasetDict:
    """Apply ``function`` to every split of ``dataset``, handing over each split as one batch.

    Each split is processed with ``Dataset.map(batched=True, batch_size=len(split))``,
    so ``function`` is called with all documents of a split at once rather than with
    fixed-size chunks.

    Args:
        dataset: The dataset dict whose splits are processed.
        function: The callable to apply, or a reference that ``resolve_target``
            resolves to a callable (presumably a dotted import path as used in
            Hydra configs).
        result_document_type: The document type produced by ``function`` (a type or
            a reference to one); it is passed through ``resolve_optional_document_type``
            before being forwarded to ``Dataset.map``.
        **kwargs: Additional keyword arguments forwarded unchanged to ``Dataset.map``.

    Returns:
        A new ``DatasetDict`` with the same split names as ``dataset``, each holding
        the mapped split.
    """
    resolved_func = resolve_target(function)
    resolved_document_type = resolve_optional_document_type(document_type=result_document_type)

    result_dict = {}
    split: Dataset
    for split_name, split in dataset.items():
        # batch_size == len(split) makes the whole split a single batch, so the function
        # sees all documents of the split together (see TODO above for the planned
        # replacement). NOTE(review): for an empty split this passes batch_size=0 —
        # confirm Dataset.map treats that as "whole dataset" as HF datasets does.
        result_dict[split_name] = split.map(
            function=resolved_func,
            batched=True,
            batch_size=len(split),
            result_document_type=resolved_document_type,
            **kwargs,
        )
    return DatasetDict(result_dict)