import logging
from importlib.util import find_spec
from typing import List, Optional, Union

from omegaconf import DictConfig, OmegaConf
from pie_modules.models.interface import RequiresTaskmoduleConfig
from pytorch_ie import PyTorchIEModel, TaskModule
from pytorch_lightning.loggers import Logger
from pytorch_lightning.utilities import rank_zero_only


def get_pylogger(name=__name__) -> logging.Logger:
    """Initializes a multi-GPU-friendly python command line logger."""

    logger = logging.getLogger(name)

    # this ensures all logging levels get marked with the rank zero decorator,
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
    for level in logging_levels:
        setattr(logger, level, rank_zero_only(getattr(logger, level)))

    return logger


log = get_pylogger(__name__)


@rank_zero_only
def log_hyperparameters(
    logger: Optional[List[Logger]] = None,
    config: Optional[Union[dict, DictConfig]] = None,
    model: Optional[PyTorchIEModel] = None,
    taskmodule: Optional[TaskModule] = None,
    key_prefix: str = "_",
    **kwargs,
) -> None:
    """Controls which config parts are saved by lightning loggers.

    Additionally saves:
    - Number of model parameters
    """

    hparams = {}

    if not logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    if model is not None and not isinstance(model, RequiresTaskmoduleConfig):
        if taskmodule is None:
            raise ValueError(
                "If the model is not an instance of RequiresTaskmoduleConfig, a taskmodule must be passed!"
            )
        # the model does not require the taskmodule config, so log it explicitly
        hparams["taskmodule_config"] = taskmodule.config

    if model is not None:
        # save number of model parameters
        hparams[f"{key_prefix}num_params/total"] = sum(p.numel() for p in model.parameters())
        hparams[f"{key_prefix}num_params/trainable"] = sum(
            p.numel() for p in model.parameters() if p.requires_grad
        )
        hparams[f"{key_prefix}num_params/non_trainable"] = sum(
            p.numel() for p in model.parameters() if not p.requires_grad
        )

    if config is not None:
        # resolve the omegaconf config to a plain container before logging it
        hparams[f"{key_prefix}config"] = (
            OmegaConf.to_container(config, resolve=True) if OmegaConf.is_config(config) else config
        )

    # any additional keyword arguments are logged as-is
    for k, v in kwargs.items():
        hparams[f"{key_prefix}{k}"] = v

    # send hparams to all loggers
    for current_logger in logger:
        current_logger.log_hyperparams(hparams)


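# Example usage (a minimal sketch; `loggers`, `model`, `taskmodule` and the resolved hydra
# `config` are assumed to have been instantiated by a hypothetical training script):
#
#   log_hyperparameters(
#       logger=loggers,
#       config=config,
#       model=model,
#       taskmodule=taskmodule,
#   )

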
def close_loggers() -> None:
    """Makes sure all loggers are closed properly (prevents logging failure during multirun)."""

    log.info("Closing loggers...")

    if find_spec("wandb"):  # check if wandb is installed
        import wandb

        if wandb.run:
            log.info("Closing wandb!")
            wandb.finish()
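

# Example usage (a minimal sketch; the surrounding `train(cfg)` entry point is hypothetical):
# call close_loggers() at the very end of a run so that a failure in one hydra multirun job
# does not break logging in the following ones, e.g.
#
#   try:
#       train(cfg)
#   finally:
#       close_loggers()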
|