import pyrootutils

root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)

# ------------------------------------------------------------------------------------ #
# `pyrootutils.setup_root(...)` is an optional line at the top of each entry file
# that helps to make the environment more robust and convenient
#
# the main advantages are:
# - allows you to keep all entry files in "src/" without installing project as a package
# - makes paths and scripts always work no matter what your current working dir is
# - automatically loads environment variables from ".env" file if it exists
#
# how it works:
# - the line above recursively searches for the ".project-root" indicator file in the
#   present and parent dirs, to determine the project root dir
# - adds root dir to the PYTHONPATH (if `pythonpath=True`), so this file can be run from
#   any place without installing project as a package
# - sets PROJECT_ROOT environment variable which is used in "configs/paths/default.yaml"
#   to make all paths always relative to the project root
# - loads environment variables from ".env" file in root dir (if `dotenv=True`)
#
# you can remove `pyrootutils.setup_root(...)` if you:
# 1. either install the project as a package or move each entry file to the project root dir
# 2. remove the PROJECT_ROOT variable from the paths in "configs/paths/default.yaml"
# 3. always run entry files from the project root dir
#
# https://github.com/ashleve/pyrootutils
# ------------------------------------------------------------------------------------ #

import os.path
from typing import Any, Dict, List, Optional, Tuple

import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf
from pie_datasets import DatasetDict
from pie_modules.models import *  # noqa: F403
from pie_modules.models import SimpleGenerativeModel
from pie_modules.models.interface import RequiresTaskmoduleConfig
from pie_modules.taskmodules import *  # noqa: F403
from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
from pytorch_ie import Pipeline
from pytorch_ie.core import PyTorchIEModel, TaskModule
from pytorch_ie.models import *  # noqa: F403
from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
from pytorch_ie.taskmodules import *  # noqa: F403
from pytorch_ie.taskmodules.interface import ChangesTokenizerVocabSize
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.loggers import Logger

from src import utils
from src.datamodules import PieDataModule
from src.models import *  # noqa: F403
from src.serializer.interface import DocumentSerializer
from src.taskmodules import *  # noqa: F403

log = utils.get_pylogger(__name__)


def get_metric_value(metric_dict: dict, metric_name: str) -> Optional[float]:
    """Safely retrieves value of the metric logged in LightningModule."""

    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        raise Exception(
            f"Metric value not found! \n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )

    metric_value = metric_dict[metric_name].item()
    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")

    return metric_value


def flatten_nested_dict(d: Dict[str, Any], parent_key: str = "", sep: str = ".") -> Dict[str, Any]:
    """Flatten a nested dictionary.

    Args:
        d (Dict[str, Any]): The dictionary to flatten.
        parent_key (str): The parent key.
        sep (str): The separator.

    Returns:
        Dict[str, Any]: The flattened dictionary.
    """
    items: List[Tuple[str, Any]] = []
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_nested_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)
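

# Illustrative example of `flatten_nested_dict` (added for clarity; the values are made up):
#   flatten_nested_dict({"metrics": {"val": {"f1": 0.8}}}, sep="/")
#   returns {"metrics/val/f1": 0.8}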


@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[dict, dict]:
    """Trains the model. Can additionally evaluate on a testset, using best weights obtained
    during training.

    This method is wrapped in optional @task_wrapper decorator which applies extra utilities
    before and after the call.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
    """
    # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        pl.seed_everything(cfg.seed, workers=True)

    # Init pytorch-ie taskmodule
    log.info(f"Instantiating taskmodule <{cfg.taskmodule._target_}>")
    taskmodule: TaskModule = hydra.utils.instantiate(cfg.taskmodule, _convert_="partial")

    # Init pytorch-ie dataset
    log.info(f"Instantiating dataset <{cfg.dataset._target_}>")
    dataset: DatasetDict = hydra.utils.instantiate(
        cfg.dataset,
        _convert_="partial",
    )

    # auto-convert the dataset if the taskmodule specifies a document type
    dataset = taskmodule.convert_dataset(dataset)

    # Init pytorch-ie datamodule
    log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
    datamodule: PieDataModule = hydra.utils.instantiate(
        cfg.datamodule, dataset=dataset, taskmodule=taskmodule, _convert_="partial"
    )
    # Use the train dataset split to prepare the taskmodule
    taskmodule.prepare(dataset[datamodule.train_split])

    # Init the pytorch-ie model
    log.info(f"Instantiating model <{cfg.model._target_}>")
    # get additional model arguments
    additional_model_kwargs: Dict[str, Any] = {}
    model_cls = hydra.utils.get_class(cfg.model["_target_"])
    # NOTE: MODIFY THE additional_model_kwargs IF YOUR MODEL REQUIRES ANY MORE PARAMETERS
    # FROM THE TASKMODULE! SEE EXAMPLES BELOW.
    if issubclass(model_cls, RequiresNumClasses):
        additional_model_kwargs["num_classes"] = len(taskmodule.label_to_id)
    if issubclass(model_cls, RequiresModelNameOrPath):
        if "model_name_or_path" not in cfg.model:
            raise Exception(
                f"Please specify model_name_or_path in the model config for {model_cls.__name__}."
            )
    if isinstance(taskmodule, ChangesTokenizerVocabSize):
        additional_model_kwargs["tokenizer_vocab_size"] = len(taskmodule.tokenizer)

    pooler_config = cfg["model"].get("pooler")
    if pooler_config is not None:
        if isinstance(pooler_config, str):
            pooler_config = {"type": pooler_config}
        pooler_config = dict(pooler_config)
        if pooler_config["type"] in ["start_tokens", "mention_pooling"]:
            # NOTE: This is very hacky, we should create a new interface class,
            # e.g. RequiresPoolerNumIndices
            if hasattr(taskmodule, "argument_role2idx"):
                pooler_config["num_indices"] = len(taskmodule.argument_role2idx)
            else:
                pooler_config["num_indices"] = 1
        elif pooler_config["type"] == "cls_token":
            pass
        else:
            raise Exception(
                f"unknown pooler type: {pooler_config['type']}. "
                "Please adjust the train.py script for that type."
            )
        additional_model_kwargs["pooler"] = pooler_config

    if issubclass(model_cls, RequiresTaskmoduleConfig):
        additional_model_kwargs["taskmodule_config"] = taskmodule.config
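
    # A hypothetical sketch of wiring in one more taskmodule-derived parameter, following the
    # NOTE above (MyCustomModelClass and the "labels" kwarg are made-up names, not part of this
    # repository):
    #     if issubclass(model_cls, MyCustomModelClass):
    #         additional_model_kwargs["labels"] = list(taskmodule.label_to_id)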

    if model_cls == SimpleGenerativeModel:
        # There may already be some base_model_config entries in the model config. We also need
        # to convert the base_model_config to a dict, because it is an OmegaConf object which
        # does not accept additional entries.
        base_model_config = (
            dict(cfg.model.base_model_config) if "base_model_config" in cfg.model else {}
        )
        if isinstance(taskmodule, PointerNetworkTaskModuleForEnd2EndRE):
            base_model_config.update(
                dict(
                    bos_token_id=taskmodule.bos_id,
                    eos_token_id=taskmodule.eos_id,
                    pad_token_id=taskmodule.eos_id,
                    target_token_ids=taskmodule.target_token_ids,
                    embedding_weight_mapping=taskmodule.label_embedding_weight_mapping,
                )
            )
        additional_model_kwargs["base_model_config"] = base_model_config

    if issubclass(model_cls, SimpleSequenceClassificationModelWithInputTypeIds):  # noqa: F405
        # add the number of input type ids to the model:
        # 2 for B- and I-labels for each entity type, 1 for O labels, 1 for padding
        additional_model_kwargs["num_token_type_ids"] = len(taskmodule.entity_labels) * 2 + 1 + 1

    # initialize the model
    model: PyTorchIEModel = hydra.utils.instantiate(
        cfg.model, _convert_="partial", **additional_model_kwargs
    )

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = utils.instantiate_dict_entries(cfg, key="callbacks")

    log.info("Instantiating loggers...")
    logger: List[Logger] = utils.instantiate_dict_entries(cfg, key="logger")

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)

    object_dict = {
        "cfg": cfg,
        "dataset": dataset,
        "taskmodule": taskmodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(logger=logger, model=model, taskmodule=taskmodule, config=cfg)

    if cfg.paths.model_save_dir is not None:
        log.info(f"Save taskmodule to {cfg.paths.model_save_dir} [push_to_hub={cfg.push_to_hub}]")
        taskmodule.save_pretrained(
            save_directory=cfg.paths.model_save_dir, push_to_hub=cfg.push_to_hub
        )
    else:
        log.warning("the taskmodule is not saved because no save_dir is specified")

    if cfg.get("train"):
        log.info("Starting training!")
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))

    train_metrics = trainer.callback_metrics

    best_ckpt_path = trainer.checkpoint_callback.best_model_path
    best_epoch = None
    if best_ckpt_path != "":
        log.info(f"Best ckpt path: {best_ckpt_path}")
        best_checkpoint_file = os.path.basename(best_ckpt_path)
        utils.log_hyperparameters(
            logger=logger,
            best_checkpoint=best_checkpoint_file,
            checkpoint_dir=trainer.checkpoint_callback.dirpath,
        )
        # get epoch from best_checkpoint_file (e.g. "epoch_078.ckpt")
        try:
            best_epoch = int(os.path.splitext(best_checkpoint_file)[0].split("_")[-1])
        except Exception as e:
            log.warning(
                f'Could not retrieve epoch from best checkpoint file name: "{e}". '
                f"Expected format: " + '"epoch_{best_epoch}.ckpt"'
            )

    if not cfg.trainer.get("fast_dev_run") or cfg.get("predict", False):
        if cfg.paths.model_save_dir is not None:
            if best_ckpt_path == "":
                log.warning("Best ckpt not found! Using current weights for saving...")
            else:
                model = type(model).load_from_checkpoint(best_ckpt_path)
            log.info(f"Save model to {cfg.paths.model_save_dir} [push_to_hub={cfg.push_to_hub}]")
            model.save_pretrained(
                save_directory=cfg.paths.model_save_dir, push_to_hub=cfg.push_to_hub
            )
        else:
            log.warning("the model is not saved because no save_dir is specified")
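
    # NOTE (added comment): `ckpt_path=best_ckpt_path or None` below lets the trainer evaluate
    # the best checkpoint when one exists and fall back to the current in-memory weights
    # otherwise.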

    if cfg.get("validate"):
        log.info("Starting validation!")
        if best_ckpt_path == "":
            log.warning("Best ckpt not found! Using current weights for validation...")
        trainer.validate(model=model, datamodule=datamodule, ckpt_path=best_ckpt_path or None)
    elif cfg.get("train"):
        log.warning(
            "Validation after training is skipped! That means the finally reported validation "
            "scores are the values from the *last* checkpoint, not from the *best* checkpoint "
            "(which is the one that gets saved)!"
        )

    if cfg.get("test"):
        log.info("Starting testing!")
        if best_ckpt_path == "":
            log.warning("Best ckpt not found! Using current weights for testing...")
        trainer.test(model=model, datamodule=datamodule, ckpt_path=best_ckpt_path or None)

    test_metrics = trainer.callback_metrics
    test_metrics["best_epoch"] = best_epoch

    # merge train and test metrics
    metric_dict = {**train_metrics, **test_metrics}

    # add model_save_dir to the result so that it gets dumped to job_return_value.json
    # if we use hydra_callbacks.SaveJobReturnValueCallback
    if cfg.paths.get("model_save_dir") is not None:
        metric_dict["model_save_dir"] = cfg.paths.model_save_dir

    if cfg.get("predict"):
        # Init the inference pipeline
        pipeline: Optional[Pipeline] = None
        if cfg.get("pipeline") and cfg.pipeline.get("_target_"):
            log.info(f"Instantiating inference pipeline <{cfg.pipeline._target_}>")
            pipeline = hydra.utils.instantiate(cfg.pipeline, _convert_="partial")

        # Init the serializer
        serializer: Optional[DocumentSerializer] = None
        if cfg.get("serializer") and cfg.serializer.get("_target_"):
            log.info(f"Instantiating serializer <{cfg.serializer._target_}>")
            serializer = hydra.utils.instantiate(cfg.serializer, _convert_="partial")

        # predict and serialize
        predict_metrics: Dict[str, Any] = utils.predict_and_serialize(
            pipeline=pipeline,
            serializer=serializer,
            dataset=dataset[cfg.dataset_split],
            document_batch_size=cfg.get("document_batch_size", None),
        )
        # flatten the predict_metrics dict
        predict_metrics_flat = flatten_nested_dict(predict_metrics, sep="/")
        metric_dict.update(predict_metrics_flat)

    if cfg.get("delete_model_dir"):
        import shutil

        log.info(f"Deleting model directory {cfg.paths.model_save_dir}")
        shutil.rmtree(cfg.paths.model_save_dir)

    return metric_dict, object_dict


@hydra.main(version_base="1.2", config_path=str(root / "configs"), config_name="train.yaml")
def main(cfg: DictConfig) -> Optional[float]:
    # train the model
    metric_dict, _ = train(cfg)

    # safely retrieve metric value for hydra-based hyperparameter optimization
    if cfg.get("optimized_metric") is not None:
        metric_value = get_metric_value(
            metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
        )
        # return optimized metric
        return metric_value
    else:
        return metric_dict


if __name__ == "__main__":
    utils.replace_sys_args_with_values_from_files()
    utils.prepare_omegaconf()
    OmegaConf.register_new_resolver("eval", eval)
    main()
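

# Illustrative invocations (the config/override names below are examples only and must match
# the entries under "configs/"):
#   python src/train.py trainer.max_epochs=10 seed=42
#   python src/train.py -m hparams_search=optuna optimized_metric=metric/val/f1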