D-FINE/src/core/yaml_config.py
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
import copy
import re

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from ._config import BaseConfig
from .workspace import create
from .yaml_utils import load_config, merge_config, merge_dict


class YAMLConfig(BaseConfig):
    def __init__(self, cfg_path: str, **kwargs) -> None:
        super().__init__()

        cfg = load_config(cfg_path)
        cfg = merge_dict(cfg, kwargs)

        self.yaml_cfg = copy.deepcopy(cfg)

        # Copy matching top-level config entries onto the public attributes declared
        # by BaseConfig; "_"-prefixed attributes stay lazily built by the properties below.
        for k in super().__dict__:
            if not k.startswith("_") and k in cfg:
                self.__dict__[k] = cfg[k]

    @property
    def global_cfg(self):
        return merge_config(self.yaml_cfg, inplace=False, overwrite=False)
    @property
    def model(self) -> torch.nn.Module:
        # Components are created lazily, on first access, from their config entries.
        if self._model is None and "model" in self.yaml_cfg:
            self._model = create(self.yaml_cfg["model"], self.global_cfg)
        return super().model

    @property
    def postprocessor(self) -> torch.nn.Module:
        if self._postprocessor is None and "postprocessor" in self.yaml_cfg:
            self._postprocessor = create(self.yaml_cfg["postprocessor"], self.global_cfg)
        return super().postprocessor

    @property
    def criterion(self) -> torch.nn.Module:
        if self._criterion is None and "criterion" in self.yaml_cfg:
            self._criterion = create(self.yaml_cfg["criterion"], self.global_cfg)
        return super().criterion
    @property
    def optimizer(self) -> optim.Optimizer:
        if self._optimizer is None and "optimizer" in self.yaml_cfg:
            params = self.get_optim_params(self.yaml_cfg["optimizer"], self.model)
            self._optimizer = create("optimizer", self.global_cfg, params=params)
        return super().optimizer

    @property
    def lr_scheduler(self) -> optim.lr_scheduler.LRScheduler:
        if self._lr_scheduler is None and "lr_scheduler" in self.yaml_cfg:
            self._lr_scheduler = create("lr_scheduler", self.global_cfg, optimizer=self.optimizer)
            print(f"Initial lr: {self._lr_scheduler.get_last_lr()}")
        return super().lr_scheduler

    @property
    def lr_warmup_scheduler(self) -> optim.lr_scheduler.LRScheduler:
        if self._lr_warmup_scheduler is None and "lr_warmup_scheduler" in self.yaml_cfg:
            self._lr_warmup_scheduler = create(
                "lr_warmup_scheduler", self.global_cfg, lr_scheduler=self.lr_scheduler
            )
        return super().lr_warmup_scheduler
    @property
    def train_dataloader(self) -> DataLoader:
        if self._train_dataloader is None and "train_dataloader" in self.yaml_cfg:
            self._train_dataloader = self.build_dataloader("train_dataloader")
        return super().train_dataloader

    @property
    def val_dataloader(self) -> DataLoader:
        if self._val_dataloader is None and "val_dataloader" in self.yaml_cfg:
            self._val_dataloader = self.build_dataloader("val_dataloader")
        return super().val_dataloader

    @property
    def ema(self) -> torch.nn.Module:
        if self._ema is None and self.yaml_cfg.get("use_ema", False):
            self._ema = create("ema", self.global_cfg, model=self.model)
        return super().ema

    @property
    def scaler(self):
        if self._scaler is None and self.yaml_cfg.get("use_amp", False):
            self._scaler = create("scaler", self.global_cfg)
        return super().scaler
    @property
    def evaluator(self):
        if self._evaluator is None and "evaluator" in self.yaml_cfg:
            if self.yaml_cfg["evaluator"]["type"] == "CocoEvaluator":
                from ..data import get_coco_api_from_dataset

                base_ds = get_coco_api_from_dataset(self.val_dataloader.dataset)
                self._evaluator = create("evaluator", self.global_cfg, coco_gt=base_ds)
            else:
                raise NotImplementedError(
                    f"Unsupported evaluator type: {self.yaml_cfg['evaluator']['type']}"
                )
        return super().evaluator

    @property
    def use_wandb(self) -> bool:
        return self.yaml_cfg.get("use_wandb", False)
    @staticmethod
    def get_optim_params(cfg: dict, model: nn.Module):
        """Split the model's trainable parameters into optimizer param groups.

        Each entry in cfg["params"] selects parameters whose names match a regex, e.g.:
            ^(?=.*a)(?=.*b).*$   matches names containing both a and b
            ^(?=.*(?:a|b)).*$    matches names containing a or b
            ^(?=.*a)(?!.*b).*$   matches names containing a but not b
        Parameters not captured by any pattern are collected into a default group.
        """
        assert "type" in cfg, "optimizer config must define a `type` entry"
        cfg = copy.deepcopy(cfg)

        if "params" not in cfg:
            return model.parameters()

        assert isinstance(cfg["params"], list), "`params` must be a list of param-group dicts"

        param_groups = []
        visited = []
        for pg in cfg["params"]:
            pattern = pg["params"]
            params = {
                k: v
                for k, v in model.named_parameters()
                if v.requires_grad and len(re.findall(pattern, k)) > 0
            }
            pg["params"] = params.values()
            param_groups.append(pg)
            visited.extend(list(params.keys()))
            # print(params.keys())

        names = [k for k, v in model.named_parameters() if v.requires_grad]

        # Collect any trainable parameters not matched by the patterns above into a
        # default group so nothing is silently left out of the optimizer.
        if len(visited) < len(names):
            unseen = set(names) - set(visited)
            params = {k: v for k, v in model.named_parameters() if v.requires_grad and k in unseen}
            param_groups.append({"params": params.values()})
            visited.extend(list(params.keys()))
            # print(params.keys())

        assert len(visited) == len(names), "every trainable parameter must land in exactly one group"

        return param_groups
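    # Illustrative config sketch (the learning rates and the "backbone" patterns below are
    # assumptions for demonstration, not values read from this file): an optimizer entry
    # whose "params" list uses name regexes like those documented in the docstring above
    # to give backbone parameters their own groups.
    #
    #   optimizer:
    #     type: AdamW
    #     lr: 0.0001
    #     weight_decay: 0.0001
    #     params:
    #       - params: '^(?=.*backbone)(?!.*norm).*$'   # backbone weights, excluding norm layers
    #         lr: 0.00001
    #       - params: '^(?=.*backbone)(?=.*norm).*$'   # backbone norm layers
    #         lr: 0.00001
    #         weight_decay: 0.
    #
    # Any remaining trainable parameters fall into the default group appended above.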
    @staticmethod
    def get_rank_batch_size(cfg):
        """Compute the per-rank batch size when `total_batch_size` is provided."""
        assert ("total_batch_size" in cfg or "batch_size" in cfg) and not (
            "total_batch_size" in cfg and "batch_size" in cfg
        ), "exactly one of `batch_size` or `total_batch_size` must be specified"

        total_batch_size = cfg.get("total_batch_size", None)
        if total_batch_size is None:
            bs = cfg.get("batch_size")
        else:
            from ..misc import dist_utils

            assert (
                total_batch_size % dist_utils.get_world_size() == 0
            ), "total_batch_size should be divisible by world size"
            bs = total_batch_size // dist_utils.get_world_size()
        return bs
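    # Worked example (the world size of 4 is an assumption for illustration): with
    # {"total_batch_size": 32} and dist_utils.get_world_size() == 4 this returns
    # 32 // 4 = 8 samples per rank, whereas {"batch_size": 8} is treated as already
    # per-rank and returned unchanged.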
    def build_dataloader(self, name: str):
        bs = self.get_rank_batch_size(self.yaml_cfg[name])
        global_cfg = self.global_cfg
        if "total_batch_size" in global_cfg[name]:
            # Pop the key so the dataloader constructor does not receive an unexpected kwarg.
            _ = global_cfg[name].pop("total_batch_size")

        print(f"building {name} with batch_size={bs}...")
        loader = create(name, global_cfg, batch_size=bs)
        loader.shuffle = self.yaml_cfg[name].get("shuffle", False)
        return loader
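
# Minimal usage sketch (the config path and keyword overrides below are hypothetical and
# only illustrate the call pattern; every component is built lazily on first property access):
#
#   cfg = YAMLConfig("configs/dfine/dfine_hgnetv2_l_coco.yml", use_amp=True, use_ema=True)
#   model = cfg.model                    # created from the "model" entry
#   optimizer = cfg.optimizer            # param groups resolved via get_optim_params
#   train_loader = cfg.train_dataloader  # per-rank batch size via get_rank_batch_size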