import logging
from importlib.util import find_spec
from typing import List, Optional, Union

from omegaconf import DictConfig, OmegaConf
from pie_modules.models.interface import RequiresTaskmoduleConfig
from pytorch_ie import PyTorchIEModel, TaskModule
from pytorch_lightning.loggers import Logger
from pytorch_lightning.utilities import rank_zero_only


def get_pylogger(name: str = __name__) -> logging.Logger:
    """Initializes a multi-GPU-friendly Python command line logger."""
    logger = logging.getLogger(name)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
    for level in logging_levels:
        setattr(logger, level, rank_zero_only(getattr(logger, level)))

    return logger


log = get_pylogger(__name__)
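
# Example (sketch): a module-level, rank-zero-safe logger as created above. Because the
# standard logging methods are wrapped with `rank_zero_only`, calls like the following are
# only emitted on global rank 0 in a multi-GPU (e.g. DDP) setup:
#
#     log = get_pylogger(__name__)
#     log.info("starting training ...")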


@rank_zero_only
def log_hyperparameters(
    logger: Optional[List[Logger]] = None,
    config: Optional[Union[dict, DictConfig]] = None,
    model: Optional[PyTorchIEModel] = None,
    taskmodule: Optional[TaskModule] = None,
    key_prefix: str = "_",
    **kwargs,
) -> None:
    """Controls which parts of the config are saved by the lightning loggers.

    Additionally saves:
    - the number of model parameters (total, trainable, and non-trainable)
    """

    hparams = {}

    if not logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    # this is just for backwards compatibility: usually, the taskmodule_config should be passed
    # to the model and, thus, be logged there automatically
    if model is not None and not isinstance(model, RequiresTaskmoduleConfig):
        if taskmodule is None:
            raise ValueError(
                "If the model is not an instance of RequiresTaskmoduleConfig, a taskmodule must be passed!"
            )
        # here we use the taskmodule config as it is after preparation/initialization
        hparams["taskmodule_config"] = taskmodule.config

    if model is not None:
        # save the number of model parameters
        hparams[f"{key_prefix}num_params/total"] = sum(p.numel() for p in model.parameters())
        hparams[f"{key_prefix}num_params/trainable"] = sum(
            p.numel() for p in model.parameters() if p.requires_grad
        )
        hparams[f"{key_prefix}num_params/non_trainable"] = sum(
            p.numel() for p in model.parameters() if not p.requires_grad
        )

    if config is not None:
        hparams[f"{key_prefix}config"] = (
            OmegaConf.to_container(config, resolve=True) if OmegaConf.is_config(config) else config
        )

    # add additional hparams
    for k, v in kwargs.items():
        hparams[f"{key_prefix}{k}"] = v

    # send hparams to all loggers
    for current_logger in logger:
        current_logger.log_hyperparams(hparams)
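
# Example (sketch): how a Hydra-style training script might call log_hyperparameters() after
# instantiating its components. The names `cfg`, `model`, `taskmodule`, and `loggers` are
# hypothetical and only illustrate the expected call signature; extra keyword arguments are
# logged under the same key prefix:
#
#     log_hyperparameters(
#         logger=loggers,          # list of pytorch_lightning loggers
#         config=cfg,              # plain dict or omegaconf.DictConfig
#         model=model,             # PyTorchIEModel instance
#         taskmodule=taskmodule,   # TaskModule instance
#         monitor_metric="val/f1", # any extra kwargs end up in the logged hparams
#     )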


def close_loggers() -> None:
    """Makes sure all loggers are closed properly (prevents logging failure during multirun)."""
    log.info("Closing loggers...")

    if find_spec("wandb"):  # check if wandb is installed
        import wandb

        if wandb.run:
            log.info("Closing wandb!")
            wandb.finish()
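
# Example (sketch): in a Hydra multirun, close_loggers() would typically be called in a
# `finally` block after each run so that a dangling wandb run cannot break the next job.
# The `train(cfg)` call below is hypothetical:
#
#     try:
#         metric_value = train(cfg)
#     finally:
#         close_loggers()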