""" | |
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) | |
Copyright(c) 2023 lyuwenyu. All Rights Reserved. | |
""" | |
from pathlib import Path
from typing import Callable, Optional

import torch
import torch.nn as nn
from torch.cuda.amp.grad_scaler import GradScaler
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
__all__ = [
    "BaseConfig",
]
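

# BaseConfig centralizes everything a training run touches: model, criterion,
# postprocessor, optimizer, schedulers, dataloaders, EMA, AMP scaler and
# TensorBoard writer. Setters type-check what you assign; getters fall back to
# lazily built defaults (a DataLoader from the dataset attributes, a GradScaler
# when AMP is enabled, and so on) when nothing was set explicitly.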
class BaseConfig(object):
    def __init__(self) -> None:
        super().__init__()

        self.task: Optional[str] = None

        # instance / function
        self._model: Optional[nn.Module] = None
        self._postprocessor: Optional[nn.Module] = None
        self._criterion: Optional[nn.Module] = None
        self._optimizer: Optional[Optimizer] = None
        self._lr_scheduler: Optional[LRScheduler] = None
        self._lr_warmup_scheduler: Optional[LRScheduler] = None
        self._train_dataloader: Optional[DataLoader] = None
        self._val_dataloader: Optional[DataLoader] = None
        self._ema: Optional[nn.Module] = None
        self._scaler: Optional[GradScaler] = None
        self._train_dataset: Optional[Dataset] = None
        self._val_dataset: Optional[Dataset] = None
        self._collate_fn: Optional[Callable] = None
        # called as evaluator(model, dataloader, device)
        self._evaluator: Optional[Callable] = None
        self._writer: Optional[SummaryWriter] = None

        # dataset
        self.num_workers: int = 0
        self.batch_size: Optional[int] = None
        self._train_batch_size: Optional[int] = None
        self._val_batch_size: Optional[int] = None
        self._train_shuffle: Optional[bool] = None
        self._val_shuffle: Optional[bool] = None

        # runtime
        self.resume: Optional[str] = None
        self.tuning: Optional[str] = None

        self.epochs: Optional[int] = None
        self.last_epoch: int = -1

        self.use_amp: bool = False
        self.use_ema: bool = False
        self.ema_decay: float = 0.9999
        self.ema_warmups: int = 2000
        self.sync_bn: bool = False
        self.clip_max_norm: float = 0.0
        self.find_unused_parameters: Optional[bool] = None

        self.seed: Optional[int] = None
        self.print_freq: Optional[int] = None
        self.checkpoint_freq: int = 1
        self.output_dir: Optional[str] = None
        self.summary_dir: Optional[str] = None
        self.device: str = ""

    @property
    def model(self) -> nn.Module:
        return self._model

    @model.setter
    def model(self, m):
        assert isinstance(m, nn.Module), f"{type(m)} != nn.Module, please check your model class"
        self._model = m

    @property
    def postprocessor(self) -> nn.Module:
        return self._postprocessor

    @postprocessor.setter
    def postprocessor(self, m):
        assert isinstance(m, nn.Module), f"{type(m)} != nn.Module, please check your postprocessor class"
        self._postprocessor = m

    @property
    def criterion(self) -> nn.Module:
        return self._criterion

    @criterion.setter
    def criterion(self, m):
        assert isinstance(m, nn.Module), f"{type(m)} != nn.Module, please check your criterion class"
        self._criterion = m

    @property
    def optimizer(self) -> Optimizer:
        return self._optimizer

    @optimizer.setter
    def optimizer(self, m):
        assert isinstance(
            m, Optimizer
        ), f"{type(m)} != optim.Optimizer, please check your optimizer"
        self._optimizer = m

    @property
    def lr_scheduler(self) -> LRScheduler:
        return self._lr_scheduler

    @lr_scheduler.setter
    def lr_scheduler(self, m):
        assert isinstance(
            m, LRScheduler
        ), f"{type(m)} != LRScheduler, please check your lr scheduler"
        self._lr_scheduler = m

    @property
    def lr_warmup_scheduler(self) -> LRScheduler:
        return self._lr_warmup_scheduler

    @lr_warmup_scheduler.setter
    def lr_warmup_scheduler(self, m):
        self._lr_warmup_scheduler = m

    @property
    def train_dataloader(self) -> DataLoader:
        if self._train_dataloader is None and self.train_dataset is not None:
            loader = DataLoader(
                self.train_dataset,
                batch_size=self.train_batch_size,
                num_workers=self.num_workers,
                collate_fn=self.collate_fn,
                shuffle=self.train_shuffle,
            )
            # stash the shuffle flag; DataLoader does not keep it, and
            # downstream code may want to read it back when re-wrapping
            loader.shuffle = self.train_shuffle
            self._train_dataloader = loader
        return self._train_dataloader

    @train_dataloader.setter
    def train_dataloader(self, loader):
        self._train_dataloader = loader

    @property
    def val_dataloader(self) -> DataLoader:
        if self._val_dataloader is None and self.val_dataset is not None:
            loader = DataLoader(
                self.val_dataset,
                batch_size=self.val_batch_size,
                num_workers=self.num_workers,
                drop_last=False,
                collate_fn=self.collate_fn,
                shuffle=self.val_shuffle,
                # persistent_workers requires num_workers > 0; passing True
                # with num_workers=0 makes DataLoader raise a ValueError
                persistent_workers=self.num_workers > 0,
            )
            loader.shuffle = self.val_shuffle
            self._val_dataloader = loader
        return self._val_dataloader

    @val_dataloader.setter
    def val_dataloader(self, loader):
        self._val_dataloader = loader

    @property
    def ema(self) -> nn.Module:
        if self._ema is None and self.use_ema and self.model is not None:
            from ..optim import ModelEMA

            self._ema = ModelEMA(self.model, self.ema_decay, self.ema_warmups)
        return self._ema

    @ema.setter
    def ema(self, obj):
        self._ema = obj
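
    # EMA usage sketch (assumes ModelEMA exposes an `update(model)` hook, as
    # RT-DETR's training loop calls it after each optimizer step):
    #
    #   if cfg.ema is not None:
    #       cfg.ema.update(cfg.model)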

    @property
    def scaler(self) -> GradScaler:
        if self._scaler is None and self.use_amp and torch.cuda.is_available():
            self._scaler = GradScaler()
        return self._scaler

    @scaler.setter
    def scaler(self, obj: GradScaler):
        self._scaler = obj
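
    # Typical AMP step built from these pieces (a sketch, not RT-DETR's exact
    # loop; `images`/`targets` come from cfg.train_dataloader, and the
    # criterion is assumed to return a dict of losses):
    #
    #   with torch.autocast("cuda", enabled=cfg.use_amp):
    #       outputs = cfg.model(images)
    #       loss = sum(cfg.criterion(outputs, targets).values())
    #   cfg.scaler.scale(loss).backward()
    #   cfg.scaler.step(cfg.optimizer)
    #   cfg.scaler.update()
    #   cfg.optimizer.zero_grad()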

    @property
    def val_shuffle(self) -> bool:
        if self._val_shuffle is None:
            print("warning: set default val_shuffle=False")
            return False
        return self._val_shuffle

    @val_shuffle.setter
    def val_shuffle(self, shuffle):
        assert isinstance(shuffle, bool), "shuffle must be bool"
        self._val_shuffle = shuffle

    @property
    def train_shuffle(self) -> bool:
        if self._train_shuffle is None:
            print("warning: set default train_shuffle=True")
            return True
        return self._train_shuffle

    @train_shuffle.setter
    def train_shuffle(self, shuffle):
        assert isinstance(shuffle, bool), "shuffle must be bool"
        self._train_shuffle = shuffle

    @property
    def train_batch_size(self) -> int:
        if self._train_batch_size is None and isinstance(self.batch_size, int):
            print(f"warning: set train_batch_size=batch_size={self.batch_size}")
            return self.batch_size
        return self._train_batch_size

    @train_batch_size.setter
    def train_batch_size(self, batch_size):
        assert isinstance(batch_size, int), "batch_size must be int"
        self._train_batch_size = batch_size

    @property
    def val_batch_size(self) -> int:
        if self._val_batch_size is None:
            print(f"warning: set val_batch_size=batch_size={self.batch_size}")
            return self.batch_size
        return self._val_batch_size

    @val_batch_size.setter
    def val_batch_size(self, batch_size):
        assert isinstance(batch_size, int), "batch_size must be int"
        self._val_batch_size = batch_size

    @property
    def train_dataset(self) -> Dataset:
        return self._train_dataset

    @train_dataset.setter
    def train_dataset(self, dataset):
        assert isinstance(dataset, Dataset), f"{type(dataset)} must be Dataset"
        self._train_dataset = dataset

    @property
    def val_dataset(self) -> Dataset:
        return self._val_dataset

    @val_dataset.setter
    def val_dataset(self, dataset):
        assert isinstance(dataset, Dataset), f"{type(dataset)} must be Dataset"
        self._val_dataset = dataset

    @property
    def collate_fn(self) -> Callable:
        return self._collate_fn

    @collate_fn.setter
    def collate_fn(self, fn):
        assert isinstance(fn, Callable), f"{type(fn)} must be Callable"
        self._collate_fn = fn

    @property
    def evaluator(self) -> Callable:
        return self._evaluator

    @evaluator.setter
    def evaluator(self, fn):
        assert isinstance(fn, Callable), f"{type(fn)} must be Callable"
        self._evaluator = fn

    @property
    def writer(self) -> SummaryWriter:
        if self._writer is None:
            if self.summary_dir:
                self._writer = SummaryWriter(self.summary_dir)
            elif self.output_dir:
                self._writer = SummaryWriter(Path(self.output_dir) / "summary")
        return self._writer

    @writer.setter
    def writer(self, m):
        assert isinstance(m, SummaryWriter), f"{type(m)} must be SummaryWriter"
        self._writer = m

    def __repr__(self):
        s = ""
        for k, v in self.__dict__.items():
            if not k.startswith("_"):
                s += f"{k}: {v}\n"
        return s
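

# Minimal usage sketch / smoke test. The Linear model and TensorDataset are
# stand-ins for illustration, not part of RT-DETR.
if __name__ == "__main__":
    from torch.utils.data import TensorDataset

    cfg = BaseConfig()
    cfg.model = nn.Linear(4, 2)  # validated by the setter
    cfg.train_dataset = TensorDataset(torch.randn(8, 4))
    cfg.batch_size = 4  # train/val batch sizes fall back to this

    # first access builds the DataLoader lazily from the attributes above
    for (batch,) in cfg.train_dataloader:
        print(batch.shape)

    print(cfg)  # __repr__ lists the public (non-underscore) attributes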