import os

os.environ["HF_DATASETS_OFFLINE"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from collections import defaultdict  # noqa: E402
from datasets import load_dataset  # noqa: E402
from torch.utils.data import DataLoader, IterableDataset  # noqa: E402
from tqdm.auto import tqdm  # noqa: E402
from transformers import AutoTokenizer  # noqa: E402

ds = load_dataset("sdlm/data/dolma/dolma_dataset.py", streaming=True)
text_column_name = "text"
ds = ds.select_columns([text_column_name, "source"])
ds["train"] = ds["train"].shuffle(seed=42, buffer_size=10_000)

tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    revision="26bca36bde8333b5d7f72e9ed20ccda6a618af24",
    use_fast=True,
)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
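# Note (added): "[PAD]" is a new token for this tokenizer, so any model loaded
# alongside it would need its embeddings resized, e.g.
# model.resize_token_embeddings(len(tokenizer)); this script only benchmarks
# dataloaders, so no model is involved.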


def tokenize_function(examples):
    """
    from sdlm/data/data_utils.py (`tokenize_data_new`)
    """
    # Remove empty lines
    examples[text_column_name] = [
        line
        for line in examples[text_column_name]
        if len(line) > 0 and not line.isspace()
    ]
    return tokenizer(
        examples[text_column_name],
        # hard coded
        padding="max_length",
        truncation=True,
        # hard coded
        max_length=512,
        return_special_tokens_mask=True,
    )


tokenized_datasets = ds.map(
    tokenize_function,
    batched=True,
    remove_columns=[text_column_name],
)
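# Note (added): with streaming=True the splits are IterableDatasets, so the
# .map above is applied lazily; tokenization actually runs while the
# dataloaders below iterate, not up front.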


def simple_collate_fn(xs):
    """Simple collate_fn that collects per-key values from a list of dicts."""
    result = defaultdict(list)
    for x in xs:
        for key, value in x.items():
            result[key].append(value)
    return result


def source_collate_fn(xs):
    result = simple_collate_fn(xs)
    return result["source"]


def tokenize_collate_fn(xs):
    result = simple_collate_fn(xs)
    return tokenize_function(result)
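

# Sketch (added, not part of the original benchmark): a dynamic-padding variant
# that pads each batch only to its longest example instead of a fixed 512,
# which can reduce per-batch work when most documents are short. The name
# `dynamic_tokenize_collate_fn` is illustrative.
def dynamic_tokenize_collate_fn(xs):
    result = simple_collate_fn(xs)
    return tokenizer(
        result[text_column_name],
        padding="longest",  # pad to the longest sequence in this batch
        truncation=True,
        max_length=512,
        return_special_tokens_mask=True,
    )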


# from https://github.com/huggingface/datasets/issues/6279
# related https://discuss.huggingface.co/t/slow-dataloader-with-big-batch-size/57224
class Dataset2Iterable(IterableDataset):
    """
    Wrapper to use a HF dataset as a PyTorch IterableDataset to speed up data loading.
    """

    def __init__(self, dataset, batch_size=1, shuffle=True):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        if self.shuffle:
            # .shuffle() returns a new dataset instead of shuffling in place,
            # so the result must be reassigned for it to take effect
            self.dataset = self.dataset.shuffle()
        return self.dataset.iter(batch_size=self.batch_size)
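
# Caveat (added note): this wrapper does no worker sharding via
# torch.utils.data.get_worker_info(), so with num_workers > 0 each worker would
# replay the same underlying stream; presumably why num_workers=0 is required
# for token_dataloader_v3 below.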


# returns source information
source_dataloader = DataLoader(
    ds["train"],
    batch_size=8,
    num_workers=64,
    collate_fn=source_collate_fn,
    persistent_workers=True,
    prefetch_factor=2,
)

# returns tokens; current method via ds.map (very slow)
# also freezes if num_workers is too big (> 1)
token_dataloader_v1 = DataLoader(
    tokenized_datasets["train"],
    batch_size=8,
    num_workers=32,
)

# returns tokens; grab text and tokenize in collate_fn on the fly
token_dataloader_v2 = DataLoader(
    ds["train"],
    batch_size=8,
    num_workers=32,
    collate_fn=tokenize_collate_fn,
)

token_dataloader_v3 = DataLoader(
    Dataset2Iterable(tokenized_datasets["train"]),
    batch_size=8,
    num_workers=0,  # required
)

# change params to test
stop_iter = 1_000
dataloader_to_test = token_dataloader_v3

for i, x in enumerate(tqdm(dataloader_to_test)):
    if i == stop_iter:
        break
    # just check iteration speed
    tqdm.write(str(i))
    # check content (for source)
    # print(i, x)
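

# Optional sketch (added, not part of the original script): a rough throughput
# number in batches per second for the loader selected above, using only the
# standard library.
import time  # noqa: E402

start = time.perf_counter()
n_batches = 0
for i, x in enumerate(dataloader_to_test):
    if i == stop_iter:
        break
    n_batches += 1
elapsed = time.perf_counter() - start
print(f"{n_batches} batches in {elapsed:.2f}s ({n_batches / max(elapsed, 1e-9):.2f} batches/s)")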