ArneBinder's picture
https://github.com/ArneBinder/pie-document-level/pull/312
3133b5e verified
raw
history blame contribute delete
1.07 kB
from typing import Callable, Type, Union
from pie_datasets import Dataset, DatasetDict
from pytorch_ie import Document
from pytorch_ie.utils.hydra import resolve_optional_document_type, resolve_target
# TODO: simply use DatasetDict.map() with set_batch_size_to_split_size=True and
# batched=True instead once https://github.com/ArneBinder/pie-datasets/pull/155 is merged
def apply_func_to_splits(
    dataset: DatasetDict,
    function: Union[str, Callable],
    result_document_type: Type[Document],
    **kwargs
):
    """Map ``function`` over every split of ``dataset``, one batch per split.

    Each split is passed through its own ``map`` call in batched mode with the
    batch size set to the length of that split, so ``function`` receives the
    complete split in a single batch.

    Args:
        dataset: The splits to process, keyed by split name.
        function: The batch function, or a string that ``resolve_target``
            turns into one (presumably a dotted import path -- confirm
            against ``pytorch_ie.utils.hydra``).
        result_document_type: Document type of the mapped splits. It is
            passed through ``resolve_optional_document_type`` before use.
        **kwargs: Forwarded unchanged to every ``split.map`` call.

    Returns:
        A new ``DatasetDict`` holding the mapped split for each split name.
    """
    func = resolve_target(function)
    doc_type = resolve_optional_document_type(document_type=result_document_type)

    # batch_size=len(split) makes the whole split arrive as one batch.
    mapped_splits = {
        name: split.map(
            function=func,
            batched=True,
            batch_size=len(split),
            result_document_type=doc_type,
            **kwargs
        )
        for name, split in dataset.items()
    }
    return DatasetDict(mapped_splits)