from mmengine.optim.scheduler.lr_scheduler import LRSchedulerMixin
from mmengine.optim.scheduler.momentum_scheduler import MomentumSchedulerMixin
from mmengine.optim.scheduler.param_scheduler import INF, _ParamScheduler
from torch.optim import Optimizer

from mmdet.registry import PARAM_SCHEDULERS


@PARAM_SCHEDULERS.register_module()
class QuadraticWarmupParamScheduler(_ParamScheduler):
r"""Warm up the parameter value of each parameter group by quadratic |
|
formula: |
|
|
|
.. math:: |
|
|
|
X_{t} = X_{t-1} + \frac{2t+1}{{(end-begin)}^{2}} \times X_{base} |
|
|
|
Args: |
|
optimizer (Optimizer): Wrapped optimizer. |
|
param_name (str): Name of the parameter to be adjusted, such as |
|
``lr``, ``momentum``. |
|
begin (int): Step at which to start updating the parameters. |
|
Defaults to 0. |
|
end (int): Step at which to stop updating the parameters. |
|
Defaults to INF. |
|
last_step (int): The index of last step. Used for resume without |
|
state dict. Defaults to -1. |
|
by_epoch (bool): Whether the scheduled parameters are updated by |
|
epochs. Defaults to True. |
|
verbose (bool): Whether to print the value for each update. |
|
Defaults to False. |
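
    Examples:
        A minimal usage sketch. The toy model and the SGD settings below
        are illustrative assumptions, not requirements of this scheduler:

        >>> import torch
        >>> from torch.optim import SGD
        >>> model = torch.nn.Linear(2, 2)
        >>> optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
        >>> # Warm up ``lr`` quadratically over the first 10 iterations.
        >>> scheduler = QuadraticWarmupParamScheduler(
        ...     optimizer, param_name='lr', begin=0, end=10, by_epoch=False)
        >>> for _ in range(10):
        ...     optimizer.step()
        ...     scheduler.step()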
    """

    def __init__(self,
                 optimizer: Optimizer,
                 param_name: str,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        if end >= INF:
            raise ValueError('``end`` must be less than infinity. '
                             'Please set the ``end`` parameter of '
                             '``QuadraticWarmupParamScheduler`` to the '
                             'step at which warmup ends.')
        self.total_iters = end - begin
        super().__init__(
            optimizer=optimizer,
            param_name=param_name,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)

    @classmethod
    def build_iter_from_epoch(cls,
                              *args,
                              begin=0,
                              end=INF,
                              by_epoch=True,
                              epoch_length=None,
                              **kwargs):
        """Build an iter-based instance of this scheduler from an epoch-based
        config."""
        assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
                         'be converted to iter-based.'
        assert epoch_length is not None and epoch_length > 0, \
            f'`epoch_length` must be a positive integer, ' \
            f'but got {epoch_length}.'
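        # Illustrative arithmetic (the numbers are assumptions, not defaults):
        # with ``epoch_length=100``, an epoch-based ``begin=0, end=5`` becomes
        # ``begin=0, end=500`` measured in iterations.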
        by_epoch = False
        begin = begin * epoch_length
        if end != INF:
            end = end * epoch_length
        return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs)

    def _get_value(self):
        """Compute value using chainable form of the scheduler."""
        if self.last_step == 0:
            return [
                base_value * (2 * self.last_step + 1) / self.total_iters**2
                for base_value in self.base_values
            ]

        return [
            group[self.param_name] + base_value *
            (2 * self.last_step + 1) / self.total_iters**2
            for base_value, group in zip(self.base_values,
                                         self.optimizer.param_groups)
        ]


@PARAM_SCHEDULERS.register_module()
class QuadraticWarmupLR(LRSchedulerMixin, QuadraticWarmupParamScheduler):
    """Warm up the learning rate of each parameter group by a quadratic
    formula.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of the last step. Used for resuming
            without a state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
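
    Examples:
        An illustrative config snippet; the epoch numbers below are
        assumptions for a short warmup, not defaults of this class:

        >>> param_scheduler = [
        ...     dict(
        ...         type='QuadraticWarmupLR',
        ...         by_epoch=True,
        ...         begin=0,
        ...         end=5,
        ...         convert_to_iter_based=True)
        ... ]
        >>> # ``convert_to_iter_based=True`` asks the runner to turn this
        >>> # epoch-based config into an iter-based scheduler via
        >>> # ``build_iter_from_epoch``.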
    """


@PARAM_SCHEDULERS.register_module()
class QuadraticWarmupMomentum(MomentumSchedulerMixin,
                              QuadraticWarmupParamScheduler):
    """Warm up the momentum value of each parameter group by a quadratic
    formula.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of the last step. Used for resuming
            without a state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """