"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

from torch.optim.lr_scheduler import LRScheduler

from ..core import register


class Warmup:
    """Base class for learning-rate warmup layered on top of an ``LRScheduler``.

    While ``last_step < warmup_duration``, each call to :meth:`step` sets every
    param group of ``lr_scheduler.optimizer`` to ``factor * end_lr``, where
    ``end_lr`` is that group's learning rate captured at construction time and
    ``factor`` comes from :meth:`get_warmup_factor`. Once ``last_step`` reaches
    ``warmup_duration``, :meth:`step` leaves the learning rates untouched.

    This class only reads ``lr_scheduler.optimizer``; it never calls
    ``lr_scheduler.step()`` itself.

    Subclasses must implement :meth:`get_warmup_factor`.
    """

    def __init__(
        self, lr_scheduler: LRScheduler, warmup_duration: int, last_step: int = -1
    ) -> None:
        """
        Args:
            lr_scheduler: Scheduler whose optimizer's param groups are warmed up.
            warmup_duration: Number of step indices (``0 .. warmup_duration - 1``)
                for which the warmup factor is applied.
            last_step: Index of the last completed step; ``-1`` starts a fresh
                warmup. The constructor immediately advances it via :meth:`step`.
        """
        self.lr_scheduler = lr_scheduler
        # Target learning rates: whatever each param group holds right now.
        self.warmup_end_values = [pg["lr"] for pg in lr_scheduler.optimizer.param_groups]
        self.last_step = last_step
        self.warmup_duration = warmup_duration
        # Advance to the next step index and apply its factor immediately
        # (a no-op if that index is already past the warmup window).
        self.step()

    def state_dict(self) -> dict:
        """Return the warmup state, excluding the ``lr_scheduler`` reference."""
        return {k: v for k, v in self.__dict__.items() if k != "lr_scheduler"}

    def load_state_dict(self, state_dict: dict) -> None:
        """Restore attributes from a dict produced by :meth:`state_dict`."""
        self.__dict__.update(state_dict)

    def get_warmup_factor(self, step: int, **kwargs) -> float:
        """Return the multiplier applied to each target learning rate at ``step``.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def step(self) -> None:
        """Advance one step and, while still warming up, rescale each group's lr."""
        self.last_step += 1
        if self.last_step >= self.warmup_duration:
            return
        factor = self.get_warmup_factor(self.last_step)
        # Scale from the captured target values, not the current lr, so the
        # factors do not compound across steps.
        for i, pg in enumerate(self.lr_scheduler.optimizer.param_groups):
            pg["lr"] = factor * self.warmup_end_values[i]

    def finished(self) -> bool:
        """Return True once ``last_step`` has reached ``warmup_duration``.

        From that point on, :meth:`step` no longer modifies learning rates.
        """
        return self.last_step >= self.warmup_duration


@register()
class LinearWarmup(Warmup):
    """Warmup whose factor grows linearly as ``(step + 1) / warmup_duration``, capped at 1."""

    def __init__(
        self, lr_scheduler: LRScheduler, warmup_duration: int, last_step: int = -1
    ) -> None:
        super().__init__(lr_scheduler, warmup_duration, last_step)

    def get_warmup_factor(self, step: int, **kwargs) -> float:
        """Return ``min(1.0, (step + 1) / warmup_duration)``.

        Step index ``warmup_duration - 1`` yields exactly 1.0, i.e. the full
        target learning rate. ``**kwargs`` is accepted (and ignored) so the
        signature matches :meth:`Warmup.get_warmup_factor`.
        """
        return min(1.0, (step + 1) / self.warmup_duration)