# Source: D-FINE/src/optim/ema.py
# (Hosting-page residue kept for provenance: uploaded by developer0hye,
#  "Upload 76 files", commit e85fecb, verified, 3.89 kB.)
"""
D-FINE: Redefine Regression Task of DETRs as Fine-grained Distribution Refinement
Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
---------------------------------------------------------------------------------
Modified from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright (c) 2023 lyuwenyu. All Rights Reserved.
"""
import math
from copy import deepcopy
import torch
import torch.nn as nn
from ..core import register
from ..misc import dist_utils
# Wildcard-import surface of this module: only ModelEMA is listed, so
# ExponentialMovingAverage has to be imported by name.
__all__ = ["ModelEMA"]
@register()
class ModelEMA(object):
    """Exponential moving average (EMA) of a model's ``state_dict``.

    Adapted from https://github.com/rwightman/pytorch-image-models.

    A deep copy of the model is kept in ``self.module``. On every effective
    :meth:`update` call each floating-point entry of that copy's
    ``state_dict`` (parameters and buffers) is blended towards the live
    model::

        ema = d * ema + (1 - d) * model

    Non-floating-point entries are left at the values they had when the copy
    was made. This mirrors functionality such as
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

    A smoothed version of the weights is necessary for some training schemes
    to perform well. This class is sensitive to where it is initialized in the
    sequence of model init, GPU assignment and distributed training wrappers.

    Args:
        model: Model to track. It is passed through ``dist_utils.de_parallel``
            (presumably strips DataParallel/DDP wrappers -- confirm) before
            being deep-copied and switched to eval mode.
        decay: Upper bound of the EMA decay factor ``d``.
        warmups: Time constant of the exponential ramp
            ``d = decay * (1 - exp(-updates / warmups))``. ``0`` disables the
            ramp and uses ``decay`` directly.
        start: Number of initial :meth:`update` calls that are skipped
            entirely before averaging begins.
    """

    def __init__(
        self, model: nn.Module, decay: float = 0.9999, warmups: int = 1000, start: int = 0
    ):
        super().__init__()

        self.module = deepcopy(dist_utils.de_parallel(model)).eval()
        # if next(model.parameters()).device.type != 'cpu':
        #     self.module.half()  # FP16 EMA

        self.decay = decay
        self.warmups = warmups
        self.before_start = 0  # update() calls skipped so far
        self.start = start
        self.updates = 0  # number of EMA updates actually applied

        if warmups == 0:
            # No ramp: avoids dividing by zero in the exponential below.
            self.decay_fn = lambda x: decay
        else:
            # Exponential ramp of the decay (to help early epochs).
            self.decay_fn = lambda x: decay * (1 - math.exp(-x / warmups))

        # The EMA copy is never trained directly.
        for p in self.module.parameters():
            p.requires_grad_(False)

    def update(self, model: nn.Module) -> None:
        """Blend the current weights of ``model`` into the EMA copy.

        The first ``start`` calls only advance ``before_start`` and return
        without touching the EMA weights or the ``updates`` counter.
        """
        if self.before_start < self.start:
            self.before_start += 1
            return

        # Update EMA parameters
        with torch.no_grad():
            self.updates += 1
            d = self.decay_fn(self.updates)
            msd = dist_utils.de_parallel(model).state_dict()
            # The in-place ops below write through the state_dict tensors
            # into self.module's parameters and buffers.
            for k, v in self.module.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1 - d) * msd[k].detach()

    def to(self, *args, **kwargs) -> "ModelEMA":
        """Forward ``*args``/``**kwargs`` to ``self.module.to`` and return ``self``."""
        self.module = self.module.to(*args, **kwargs)
        return self

    def state_dict(self) -> dict:
        """Return ``{"module": <EMA weights>, "updates": <update count>}``.

        NOTE(review): ``before_start`` is not included, so the ``start`` delay
        begins again after a resume -- confirm this is intended.
        """
        return dict(module=self.module.state_dict(), updates=self.updates)

    def load_state_dict(self, state, strict=True) -> None:
        """Restore the EMA weights and, when present, the update counter.

        Args:
            state: Mapping with a ``"module"`` entry and an optional
                ``"updates"`` entry, as produced by :meth:`state_dict`.
            strict: Passed through to ``self.module.load_state_dict``.
        """
        self.module.load_state_dict(state["module"], strict=strict)
        if "updates" in state:
            self.updates = state["updates"]

    def forward(self):
        """Always raise ``RuntimeError``; this wrapper has no forward pass."""
        raise RuntimeError("ema...")

    # Backward-compatible alias for the original, misspelled method name.
    forwad = forward

    def extra_repr(self) -> str:
        """Return a short description of the decay configuration."""
        return f"decay={self.decay}, warmups={self.warmups}"
class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
    """Maintains moving averages of model parameters using an exponential decay.

    ``ema_avg = decay * avg_model_param + (1 - decay) * model_param``

    `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
    is used to compute the EMA.

    The effective decay is ramped with the number of averaged models ``n``::

        decay_eff = decay * (1 - exp(-n / warmups))

    Args:
        model: Model whose parameters are averaged.
        decay: Upper bound of the decay factor.
        device: Passed to ``AveragedModel`` as its ``device`` argument.
        use_buffers: Passed to ``AveragedModel`` as ``use_buffers``.
        warmups: Time constant of the decay ramp. Defaults to 2000, the value
            that was previously hard-coded. ``0`` disables the ramp and uses
            ``decay`` directly.
    """

    def __init__(self, model, decay, device="cpu", use_buffers=True, warmups=2000):
        def ema_avg(avg_model_param, model_param, num_averaged):
            # Looked up on self at call time, so decay_fn may be replaced
            # after construction.
            current_decay = self.decay_fn(num_averaged)
            return current_decay * avg_model_param + (1 - current_decay) * model_param

        super().__init__(model, device, ema_avg, use_buffers=use_buffers)

        if warmups == 0:
            # No ramp: avoids dividing by zero in the exponential below.
            self.decay_fn = lambda x: decay
        else:
            self.decay_fn = lambda x: decay * (1 - math.exp(-x / warmups))