import logging
from importlib.util import find_spec
from typing import List, Optional, Union

from omegaconf import DictConfig, OmegaConf
from pie_modules.models.interface import RequiresTaskmoduleConfig
from pytorch_ie import PyTorchIEModel, TaskModule
from pytorch_lightning.loggers import Logger
from pytorch_lightning.utilities import rank_zero_only


def get_pylogger(name=__name__) -> logging.Logger:
    """Initializes a multi-GPU-friendly python command line logger."""
    logger = logging.getLogger(name)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
    for level in logging_levels:
        setattr(logger, level, rank_zero_only(getattr(logger, level)))

    return logger


log = get_pylogger(__name__)
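
# Illustrative usage sketch (commented out, not part of this module): any other module in the
# project can create its own rank-zero-safe logger the same way this file does right above.
# The module path `src.utils.logging_utils` is an assumption about where this file lives.
#
#     from src.utils.logging_utils import get_pylogger
#
#     log = get_pylogger(__name__)
#     log.info("Emitted only once, on the rank-zero process, in multi-GPU runs.")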


@rank_zero_only
def log_hyperparameters(
    logger: Optional[List[Logger]] = None,
    config: Optional[Union[dict, DictConfig]] = None,
    model: Optional[PyTorchIEModel] = None,
    taskmodule: Optional[TaskModule] = None,
    key_prefix: str = "_",
    **kwargs,
) -> None:
    """Controls which config parts are saved by lightning loggers.

    Additional saves:
    - Number of model parameters
    """

    hparams = {}

    if not logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    # this is just for backwards compatibility: usually, the taskmodule_config should be passed
    # to the model and, thus, be logged there automatically
    if model is not None and not isinstance(model, RequiresTaskmoduleConfig):
        if taskmodule is None:
            raise ValueError(
                "If model is not an instance of RequiresTaskmoduleConfig, taskmodule must be passed!"
            )
        # here we use the taskmodule/model config as it is after preparation/initialization
        hparams["taskmodule_config"] = taskmodule.config

    if model is not None:
        # save number of model parameters
        hparams[f"{key_prefix}num_params/total"] = sum(p.numel() for p in model.parameters())
        hparams[f"{key_prefix}num_params/trainable"] = sum(
            p.numel() for p in model.parameters() if p.requires_grad
        )
        hparams[f"{key_prefix}num_params/non_trainable"] = sum(
            p.numel() for p in model.parameters() if not p.requires_grad
        )

    if config is not None:
        hparams[f"{key_prefix}config"] = (
            OmegaConf.to_container(config, resolve=True) if OmegaConf.is_config(config) else config
        )

    # add additional hparams
    for k, v in kwargs.items():
        hparams[f"{key_prefix}{k}"] = v

    # send hparams to all loggers
    for current_logger in logger:
        current_logger.log_hyperparams(hparams)
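
# Illustrative sketch of how this could be called from a training entry point (commented out,
# not part of this module): `cfg`, `model`, `taskmodule` and `trainer` are assumed to be created
# by a Hydra-based train() function; `trainer.loggers` is the list of Lightning loggers attached
# to the Trainer, and `cfg.dataset_name` is a hypothetical extra value.
#
#     log_hyperparameters(
#         logger=trainer.loggers,
#         config=cfg,
#         model=model,
#         taskmodule=taskmodule,
#         dataset_name=cfg.dataset_name,  # any extra kwargs are logged as f"{key_prefix}{key}"
#     )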


def close_loggers() -> None:
    """Makes sure all loggers are closed properly (prevents logging failure during multirun)."""

    log.info("Closing loggers...")

    if find_spec("wandb"):  # if wandb is installed
        import wandb

        if wandb.run:
            log.info("Closing wandb!")
            wandb.finish()
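
# Illustrative sketch of the assumed calling convention (commented out, not part of this module):
# a multirun entry point would typically call close_loggers() in a `finally` block so that e.g.
# the wandb run is finished even if training raises; `train(cfg)` is a hypothetical function.
#
#     try:
#         metric_dict = train(cfg)
#     finally:
#         close_loggers()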