""" | |
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) | |
Copyright(c) 2023 lyuwenyu. All Rights Reserved. | |
""" | |
import copy | |
import re | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
from torch.utils.data import DataLoader | |
from ._config import BaseConfig | |
from .workspace import create | |
from .yaml_utils import load_config, merge_config, merge_dict | |
class YAMLConfig(BaseConfig):
    def __init__(self, cfg_path: str, **kwargs) -> None:
        super().__init__()

        cfg = load_config(cfg_path)
        cfg = merge_dict(cfg, kwargs)

        self.yaml_cfg = copy.deepcopy(cfg)

        for k in super().__dict__:
            if not k.startswith("_") and k in cfg:
                self.__dict__[k] = cfg[k]
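
    # Most of the properties below follow the same lazy-build pattern: on first
    # access they construct the component from the YAML config via `create(...)`,
    # cache it on the corresponding BaseConfig backing field (e.g. self._model),
    # and then defer to the BaseConfig getter.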
    @property
    def global_cfg(self):
        return merge_config(self.yaml_cfg, inplace=False, overwrite=False)

    @property
    def model(self) -> torch.nn.Module:
        if self._model is None and "model" in self.yaml_cfg:
            self._model = create(self.yaml_cfg["model"], self.global_cfg)
        return super().model

    @property
    def postprocessor(self) -> torch.nn.Module:
        if self._postprocessor is None and "postprocessor" in self.yaml_cfg:
            self._postprocessor = create(self.yaml_cfg["postprocessor"], self.global_cfg)
        return super().postprocessor

    @property
    def criterion(self) -> torch.nn.Module:
        if self._criterion is None and "criterion" in self.yaml_cfg:
            self._criterion = create(self.yaml_cfg["criterion"], self.global_cfg)
        return super().criterion

    @property
    def optimizer(self) -> optim.Optimizer:
        if self._optimizer is None and "optimizer" in self.yaml_cfg:
            params = self.get_optim_params(self.yaml_cfg["optimizer"], self.model)
            self._optimizer = create("optimizer", self.global_cfg, params=params)
        return super().optimizer

    @property
    def lr_scheduler(self) -> optim.lr_scheduler.LRScheduler:
        if self._lr_scheduler is None and "lr_scheduler" in self.yaml_cfg:
            self._lr_scheduler = create("lr_scheduler", self.global_cfg, optimizer=self.optimizer)
            print(f"Initial lr: {self._lr_scheduler.get_last_lr()}")
        return super().lr_scheduler

    @property
    def lr_warmup_scheduler(self) -> optim.lr_scheduler.LRScheduler:
        if self._lr_warmup_scheduler is None and "lr_warmup_scheduler" in self.yaml_cfg:
            self._lr_warmup_scheduler = create(
                "lr_warmup_scheduler", self.global_cfg, lr_scheduler=self.lr_scheduler
            )
        return super().lr_warmup_scheduler

    @property
    def train_dataloader(self) -> DataLoader:
        if self._train_dataloader is None and "train_dataloader" in self.yaml_cfg:
            self._train_dataloader = self.build_dataloader("train_dataloader")
        return super().train_dataloader

    @property
    def val_dataloader(self) -> DataLoader:
        if self._val_dataloader is None and "val_dataloader" in self.yaml_cfg:
            self._val_dataloader = self.build_dataloader("val_dataloader")
        return super().val_dataloader

    @property
    def ema(self) -> torch.nn.Module:
        if self._ema is None and self.yaml_cfg.get("use_ema", False):
            self._ema = create("ema", self.global_cfg, model=self.model)
        return super().ema

    @property
    def scaler(self):
        if self._scaler is None and self.yaml_cfg.get("use_amp", False):
            self._scaler = create("scaler", self.global_cfg)
        return super().scaler

    @property
    def evaluator(self):
        if self._evaluator is None and "evaluator" in self.yaml_cfg:
            if self.yaml_cfg["evaluator"]["type"] == "CocoEvaluator":
                from ..data import get_coco_api_from_dataset

                base_ds = get_coco_api_from_dataset(self.val_dataloader.dataset)
                self._evaluator = create("evaluator", self.global_cfg, coco_gt=base_ds)
            else:
                raise NotImplementedError(f"{self.yaml_cfg['evaluator']['type']}")
        return super().evaluator

    @property
    def use_wandb(self) -> bool:
        return self.yaml_cfg.get("use_wandb", False)
    @staticmethod
    def get_optim_params(cfg: dict, model: nn.Module):
        """
        E.g.:
            ^(?=.*a)(?=.*b).*$   means including a and b
            ^(?=.*(?:a|b)).*$    means including a or b
            ^(?=.*a)(?!.*b).*$   means including a, but not b
        """
assert "type" in cfg, "" | |
cfg = copy.deepcopy(cfg) | |
if "params" not in cfg: | |
return model.parameters() | |
assert isinstance(cfg["params"], list), "" | |
param_groups = [] | |
visited = [] | |
for pg in cfg["params"]: | |
pattern = pg["params"] | |
params = { | |
k: v | |
for k, v in model.named_parameters() | |
if v.requires_grad and len(re.findall(pattern, k)) > 0 | |
} | |
pg["params"] = params.values() | |
param_groups.append(pg) | |
visited.extend(list(params.keys())) | |
# print(params.keys()) | |
names = [k for k, v in model.named_parameters() if v.requires_grad] | |
if len(visited) < len(names): | |
unseen = set(names) - set(visited) | |
params = {k: v for k, v in model.named_parameters() if v.requires_grad and k in unseen} | |
param_groups.append({"params": params.values()}) | |
visited.extend(list(params.keys())) | |
# print(params.keys()) | |
assert len(visited) == len(names), "" | |
return param_groups | |
    @staticmethod
    def get_rank_batch_size(cfg):
        """Compute the per-rank batch size when `total_batch_size` is provided."""
        assert ("total_batch_size" in cfg or "batch_size" in cfg) and not (
            "total_batch_size" in cfg and "batch_size" in cfg
        ), "Specify exactly one of `batch_size` or `total_batch_size`"
        total_batch_size = cfg.get("total_batch_size", None)
        if total_batch_size is None:
            bs = cfg.get("batch_size")
        else:
            from ..misc import dist_utils

            assert (
                total_batch_size % dist_utils.get_world_size() == 0
            ), "total_batch_size should be divisible by world size"
            bs = total_batch_size // dist_utils.get_world_size()
        return bs
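
    # For example, a (hypothetical) total_batch_size of 32 on 4 ranks yields a
    # per-rank batch size of 8; when `batch_size` is given directly, it is used as-is.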
    def build_dataloader(self, name: str):
        bs = self.get_rank_batch_size(self.yaml_cfg[name])
        global_cfg = self.global_cfg
        if "total_batch_size" in global_cfg[name]:
            # pop unexpected key for dataloader init
            _ = global_cfg[name].pop("total_batch_size")
        print(f"building {name} with batch_size={bs}...")
        loader = create(name, global_cfg, batch_size=bs)
        loader.shuffle = self.yaml_cfg[name].get("shuffle", False)
        return loader
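

# --- Usage sketch (not part of the original RT-DETR file) --------------------
# A minimal, hypothetical example of driving YAMLConfig from a training script.
# The config path and the override kwargs below are assumptions for illustration.
if __name__ == "__main__":
    cfg = YAMLConfig("configs/rtdetr/rtdetr_r50vd_6x_coco.yml", use_amp=True, use_ema=True)

    model = cfg.model                  # built lazily from the YAML "model" node
    criterion = cfg.criterion          # loss, built from the "criterion" node
    optimizer = cfg.optimizer          # param groups resolved via the regex patterns above
    lr_scheduler = cfg.lr_scheduler    # prints the initial learning rate on creation
    train_loader = cfg.train_dataloader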