"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

import random
from functools import partial

import torch
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
import torchvision.transforms.v2 as VT
from torch.utils.data import default_collate
from torchvision.transforms.v2 import InterpolationMode
from torchvision.transforms.v2 import functional as VF

from ..core import register

torchvision.disable_beta_transforms_warning()


__all__ = [
    "DataLoader",
    "BaseCollateFunction",
    "BatchImageCollateFunction",
    "batch_image_collate_fn",
]


@register()
class DataLoader(data.DataLoader):
    """``torch.utils.data.DataLoader`` with epoch tracking and a readable ``repr``.

    ``__inject__`` names the constructor arguments (``dataset``, ``collate_fn``)
    that the ``register`` machinery from ``..core`` fills in.
    NOTE(review): presumably those objects are built from config — confirm in ``..core``.
    """

    __inject__ = ["dataset", "collate_fn"]

    def __repr__(self) -> str:
        """Return the class name followed by one ``name: value`` line per key setting."""
        shown = ["dataset", "batch_size", "num_workers", "drop_last", "collate_fn"]
        body = [" {0}: {1}".format(n, getattr(self, n)) for n in shown]
        return "\n".join([self.__class__.__name__ + "("] + body + [")"])

    def set_epoch(self, epoch):
        """Record ``epoch`` and forward it to the dataset and the collate function.

        Forwarding is skipped for any of them that has no callable
        ``set_epoch``. This covers plain-function collators such as
        ``batch_image_collate_fn`` or ``default_collate``, which would
        otherwise raise ``AttributeError`` here.
        """
        self._epoch = epoch
        for consumer in (self.dataset, self.collate_fn):
            forward = getattr(consumer, "set_epoch", None)
            if callable(forward):
                forward(epoch)

    @property
    def epoch(self):
        """Last epoch passed to ``set_epoch``, or -1 if it was never called."""
        return getattr(self, "_epoch", -1)

    @property
    def shuffle(self):
        """Shuffle flag; only available after it has been assigned explicitly."""
        return self._shuffle

    @shuffle.setter
    def shuffle(self, shuffle):
        assert isinstance(shuffle, bool), "shuffle must be a boolean"
        self._shuffle = shuffle


@register()
def batch_image_collate_fn(items):
    """Batch only the images; targets are returned as a plain list.

    Each item is indexed as ``(image, target)``. Images gain a leading
    dimension and are concatenated along it, so they must share a shape.
    """
    return torch.cat([x[0][None] for x in items], dim=0), [x[1] for x in items]


class BaseCollateFunction(object):
    """Base class for collate callables that are aware of the current epoch."""

    def set_epoch(self, epoch):
        """Record the current epoch (read back via ``epoch``)."""
        self._epoch = epoch

    @property
    def epoch(self):
        """Last epoch passed to ``set_epoch``, or -1 if it was never called."""
        return getattr(self, "_epoch", -1)

    def __call__(self, items):
        raise NotImplementedError("")


def generate_scales(base_size, base_size_repeat):
    """Build the list of sizes sampled for multi-scale training.

    The list has three parts:

    * an ascending ramp in steps of 32, starting at
      ``int(base_size * 0.75 / 32) * 32`` and stopping before ``base_size``;
    * ``base_size`` repeated ``base_size_repeat`` times;
    * a descending ramp of the same length, in steps of 32, starting at
      ``int(base_size * 1.25 / 32) * 32``.

    For ``base_size=640`` the ramps are 480..608 and 800..672.
    """
    low = int(base_size * 0.75 / 32) * 32
    high = int(base_size * 1.25 / 32) * 32
    # Both ramps use the step count derived from the low side.
    scale_repeat = (base_size - low) // 32
    scales = [low + i * 32 for i in range(scale_repeat)]
    scales += [base_size] * base_size_repeat
    scales += [high - i * 32 for i in range(scale_repeat)]
    return scales


@register()
class BatchImageCollateFunction(BaseCollateFunction):
    """Batch images and optionally resize the whole batch to a random scale.

    Args:
        stop_epoch: random resizing is applied only while ``self.epoch`` is
            below this value; ``None`` maps to a very large sentinel, so
            resizing never stops.
        ema_restart_decay: stored on the instance; not read by this class.
        base_size: centre size passed to ``generate_scales``.
        base_size_repeat: repeat count for ``base_size`` in the scale list;
            ``None`` disables random resizing entirely.
    """

    def __init__(
        self,
        stop_epoch=None,
        ema_restart_decay=0.9999,
        base_size=640,
        base_size_repeat=None,
    ) -> None:
        super().__init__()
        self.base_size = base_size
        self.scales = (
            generate_scales(base_size, base_size_repeat) if base_size_repeat is not None else None
        )
        self.stop_epoch = stop_epoch if stop_epoch is not None else 100000000
        self.ema_restart_decay = ema_restart_decay

    def __call__(self, items):
        """Collate ``(image, target)`` items into ``(images, targets)``.

        Raises:
            NotImplementedError: when random resizing is active and the
                targets carry ``"masks"``; resizing masks is not supported.
        """
        images = torch.cat([x[0][None] for x in items], dim=0)
        targets = [x[1] for x in items]

        if self.scales is not None and self.epoch < self.stop_epoch:
            sz = random.choice(self.scales)
            if "masks" in targets[0]:
                # Fail before doing any work, so the caller's target dicts
                # are not left partially modified.
                raise NotImplementedError(
                    "BatchImageCollateFunction: multi-scale resizing of 'masks' targets is not supported"
                )
            # Default F.interpolate mode ("nearest") is used for the images.
            # NOTE(review): only images are resized; other target fields are
            # left untouched — presumably they are scale-independent; confirm.
            images = F.interpolate(images, size=sz)

        return images, targets