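# Quick benchmark of different ways to feed a streaming HF dataset (Dolma)
# through a PyTorch DataLoader: pre-tokenizing via ds.map, tokenizing on the
# fly in the collate_fn, and wrapping the dataset in a custom IterableDataset.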
import os
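# Set before the HF imports below (hence the noqa: E402 on them): run datasets
# offline and silence the tokenizers fork-parallelism warning in DataLoader workers.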
os.environ["HF_DATASETS_OFFLINE"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from collections import defaultdict # noqa: E402
from datasets import load_dataset # noqa: E402
from torch.utils.data import DataLoader, IterableDataset # noqa: E402
from tqdm.auto import tqdm # noqa: E402
from transformers import AutoTokenizer # noqa: E402
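
# Stream Dolma via its local loading script (rows are read lazily); keep only
# the text and source columns, and shuffle with a 10k-example buffer.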
ds = load_dataset("sdlm/data/dolma/dolma_dataset.py", streaming=True)
text_column_name = "text"
ds = ds.select_columns([text_column_name, "source"])
ds["train"] = ds["train"].shuffle(seed=42, buffer_size=10_000)
tokenizer = AutoTokenizer.from_pretrained(
"mistralai/Mistral-7B-v0.1",
revision="26bca36bde8333b5d7f72e9ed20ccda6a618af24",
use_fast=True,
)
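# Mistral's tokenizer has no pad token by default, so add one for padding="max_length".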
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
def tokenize_function(examples):
"""
from sdlm/data/data_utils.py (`tokenize_data_new`)
"""
# Remove empty lines
examples[text_column_name] = [
line
for line in examples[text_column_name]
if len(line) > 0 and not line.isspace()
]
return tokenizer(
examples[text_column_name],
# hard coded
padding="max_length",
truncation=True,
# hard coded
max_length=512,
return_special_tokens_mask=True,
)
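
# On a streaming dataset .map is lazy: tokenize_function only runs as batches
# are pulled during iteration (this is what token_dataloader_v1 measures).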
tokenized_datasets = ds.map(
tokenize_function,
batched=True,
remove_columns=[text_column_name],
)
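
# Collate functions for the different dataloader variants below.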
def simple_collate_fn(xs):
"""simple collate fn that collects key-values from dict"""
result = defaultdict(list)
for x in xs:
for key, value in x.items():
result[key].append(value)
return result
def source_collate_fn(xs):
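    """collate fn that keeps only the 'source' field of each example"""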
result = simple_collate_fn(xs)
return result["source"]
def tokenize_collate_fn(xs):
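    """collate fn that gathers the raw text fields and tokenizes them on the fly"""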
result = simple_collate_fn(xs)
return tokenize_function(result)
# from https://github.com/huggingface/datasets/issues/6279
# related https://discuss.huggingface.co/t/slow-dataloader-with-big-batch-size/57224
class Dataset2Iterable(IterableDataset):
"""
    Wrapper to use an HF dataset as a PyTorch IterableDataset to speed up data loading.
"""
def __init__(self, dataset, batch_size=1, shuffle=True):
super().__init__()
self.dataset = dataset
self.batch_size = batch_size
self.shuffle = shuffle
def __iter__(self):
if self.shuffle:
            # shuffle() returns a new dataset (it is not in-place), so reassign
            self.dataset = self.dataset.shuffle()
return self.dataset.iter(batch_size=self.batch_size)
# returns source information
source_dataloader = DataLoader(
ds["train"],
batch_size=8,
num_workers=64,
    collate_fn=source_collate_fn,
persistent_workers=True,
prefetch_factor=2,
)
# returns tokens; tokenization happens through the lazy ds.map above (very slow)
# also freezes if num_workers is too large (> 1)
token_dataloader_v1 = DataLoader(
tokenized_datasets["train"],
batch_size=8,
num_workers=32,
)
# returns tokens; grab text and tokenize in collate_fn on the fly
token_dataloader_v2 = DataLoader(
ds["train"],
batch_size=8,
num_workers=32,
collate_fn=tokenize_collate_fn,
)
token_dataloader_v3 = DataLoader(
Dataset2Iterable(tokenized_datasets["train"]),
batch_size=8,
num_workers=0, # required
)
# change params to test
stop_iter = 1_000
dataloader_to_test = token_dataloader_v3
for i, x in enumerate(tqdm(dataloader_to_test)):
if i == stop_iter:
break
# just check iteration speed
tqdm.write(str(i))
# check content (for source)
# print(i, x)
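
# Optional extra: a minimal wall-clock timing sketch over the same loader
# (assumes the loader can simply be iterated a second time; tqdm above already
# reports an it/s rate, this just prints an explicit total).
import time  # noqa: E402

start = time.perf_counter()
for i, _ in enumerate(dataloader_to_test):
    if i == stop_iter:
        break
elapsed = time.perf_counter() - start
print(f"{stop_iter} batches in {elapsed:.1f}s ({stop_iter / elapsed:.1f} batches/s)")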