"""This is a slightly modified version of https://github.com/ufoym/imbalanced-dataset-sampler.""" from typing import Callable, List, Optional import pandas as pd import torch import torch.utils.data class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler): """Samples elements randomly from a given list of indices for imbalanced dataset. Arguments: indices: a list of indices num_samples: number of samples to draw callback_get_label: a callback-like function which takes one argument - the dataset """ def __init__( self, dataset, labels: Optional[List] = None, indices: Optional[List] = None, num_samples: Optional[int] = None, callback_get_label: Optional[Callable] = None, ): # if indices is not provided, all elements in the dataset will be considered self.indices = list(range(len(dataset))) if indices is None else indices # define custom callback self.callback_get_label = callback_get_label # if num_samples is not provided, draw `len(indices)` samples in each iteration self.num_samples = len(self.indices) if num_samples is None else num_samples # distribution of classes in the dataset df = pd.DataFrame() df["label"] = self._get_labels(dataset) if labels is None else labels df.index = self.indices df = df.sort_index() label_to_count = df["label"].value_counts() weights = 1.0 / label_to_count[df["label"]] self.weights = torch.DoubleTensor(weights.to_list()) def _get_labels(self, dataset): if self.callback_get_label: return self.callback_get_label(dataset) elif isinstance(dataset, torch.utils.data.TensorDataset): return dataset.tensors[1] elif isinstance(dataset, torch.utils.data.Subset): return dataset.dataset.imgs[:][1] elif isinstance(dataset, torch.utils.data.Dataset): return dataset.get_labels() else: raise NotImplementedError def __iter__(self): return ( self.indices[i] for i in torch.multinomial(self.weights, self.num_samples, replacement=True) ) def __len__(self): return self.num_samples