import pyrootutils

root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)

# ------------------------------------------------------------------------------------ #
# `pyrootutils.setup_root(...)` is an optional line at the top of each entry file
# that helps to make the environment more robust and convenient
#
# the main advantages are:
# - allows you to keep all entry files in "src/" without installing project as a package
# - makes paths and scripts always work no matter what your current working directory is
# - automatically loads environment variables from ".env" file if exists
#
# how it works:
# - the line above recursively searches for the ".project-root" indicator file in the present
# and parent dirs, to determine the project root dir
# - adds root dir to the PYTHONPATH (if `pythonpath=True`), so this file can be run from
# any place without installing project as a package
# - sets PROJECT_ROOT environment variable which is used in "configs/paths/default.yaml"
# to make all paths always relative to the project root
# - loads environment variables from ".env" file in root dir (if `dotenv=True`)
#
# you can remove `pyrootutils.setup_root(...)` if you:
# 1. either install project as a package or move each entry file to the project root dir
# 2. simply remove PROJECT_ROOT variable from paths in "configs/paths/default.yaml"
# 3. always run entry files from the project root dir
#
# https://github.com/ashleve/pyrootutils
# ------------------------------------------------------------------------------------ #
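# for illustration only: with PROJECT_ROOT set, "configs/paths/default.yaml" can resolve paths
# via OmegaConf's environment variable resolver, e.g. (the exact entries depend on this
# project's configs):
#   root_dir: ${oc.env:PROJECT_ROOT}
#   data_dir: ${paths.root_dir}/data/
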
import os.path
from typing import Any, Dict, List, Optional, Tuple

import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig
from pie_datasets import DatasetDict
from pie_modules.models import *  # noqa: F403
from pie_modules.models import SimpleGenerativeModel
from pie_modules.models.interface import RequiresTaskmoduleConfig
from pie_modules.taskmodules import *  # noqa: F403
from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
from pytorch_ie.core import PyTorchIEModel, TaskModule
from pytorch_ie.models import *  # noqa: F403
from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
from pytorch_ie.taskmodules import *  # noqa: F403
from pytorch_ie.taskmodules.interface import ChangesTokenizerVocabSize
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.loggers import Logger

from src import utils
from src.datamodules import PieDataModule
from src.models import *  # noqa: F403
from src.taskmodules import *  # noqa: F403

log = utils.get_pylogger(__name__)


def get_metric_value(metric_dict: dict, metric_name: str) -> Optional[float]:
    """Safely retrieves value of the metric logged in LightningModule."""

    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        raise Exception(
            f"Metric value not found! <metric_name={metric_name}>\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )

    metric_value = metric_dict[metric_name].item()
    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")

    return metric_value
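

# illustrative usage of get_metric_value: with callback metrics like
# {"val/f1": torch.tensor(0.87)}, get_metric_value(metric_dict, "val/f1") returns 0.87
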

@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[dict, dict]:
    """Trains the model. Can additionally evaluate on a test set, using the best weights obtained
    during training.

    This method is wrapped in an optional @task_wrapper decorator, which applies extra utilities
    before and after the call.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
    """
    # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        pl.seed_everything(cfg.seed, workers=True)

    # Init pytorch-ie taskmodule
    log.info(f"Instantiating taskmodule <{cfg.taskmodule._target_}>")
    taskmodule: TaskModule = hydra.utils.instantiate(cfg.taskmodule, _convert_="partial")

    # Init pytorch-ie dataset
    log.info(f"Instantiating dataset <{cfg.dataset._target_}>")
    dataset: DatasetDict = hydra.utils.instantiate(
        cfg.dataset,
        _convert_="partial",
    )

    # auto-convert the dataset if the taskmodule specifies a document type
    dataset = taskmodule.convert_dataset(dataset)

    # Init pytorch-ie datamodule
    log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
    datamodule: PieDataModule = hydra.utils.instantiate(
        cfg.datamodule, dataset=dataset, taskmodule=taskmodule, _convert_="partial"
    )

    # Use the train dataset split to prepare the taskmodule
    taskmodule.prepare(dataset[datamodule.train_split])
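    # Note: prepare() derives any dataset-dependent taskmodule parameters (e.g. the label set)
    # from the train documents, so it has to run before the model is instantiated below, where
    # values such as taskmodule.label_to_id may be read from the taskmodule.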

    # Init the pytorch-ie model
    log.info(f"Instantiating model <{cfg.model._target_}>")

    # get additional model arguments
    additional_model_kwargs: Dict[str, Any] = {}
    model_cls = hydra.utils.get_class(cfg.model["_target_"])
    # NOTE: MODIFY THE additional_model_kwargs IF YOUR MODEL REQUIRES ANY MORE PARAMETERS FROM THE TASKMODULE!
    # SEE EXAMPLES BELOW.
    if issubclass(model_cls, RequiresNumClasses):
        additional_model_kwargs["num_classes"] = len(taskmodule.label_to_id)
    if issubclass(model_cls, RequiresModelNameOrPath):
        if "model_name_or_path" not in cfg.model:
            raise Exception(
                f"Please specify model_name_or_path in the model config for {model_cls.__name__}."
            )
    if isinstance(taskmodule, ChangesTokenizerVocabSize):
        additional_model_kwargs["tokenizer_vocab_size"] = len(taskmodule.tokenizer)

    pooler_config = cfg["model"].get("pooler")
    if pooler_config is not None:
        if isinstance(pooler_config, str):
            pooler_config = {"type": pooler_config}
        pooler_config = dict(pooler_config)
        if pooler_config["type"] in ["start_tokens", "mention_pooling"]:
            # NOTE: This is very hacky; we should create a new interface class, e.g. RequiresPoolerNumIndices
            if hasattr(taskmodule, "argument_role2idx"):
                pooler_config["num_indices"] = len(taskmodule.argument_role2idx)
            else:
                pooler_config["num_indices"] = 1
        elif pooler_config["type"] == "cls_token":
            pass
        else:
            raise Exception(
                f"unknown pooler type: {pooler_config['type']}. Please adjust the train.py script for that type."
            )
        additional_model_kwargs["pooler"] = pooler_config

    if issubclass(model_cls, RequiresTaskmoduleConfig):
        additional_model_kwargs["taskmodule_config"] = taskmodule.config

    if model_cls == SimpleGenerativeModel:
        # There may already be some base_model_config entries in the model config. We also need to
        # convert the base_model_config to a dict, because it is an OmegaConf object which does not
        # accept additional entries.
        base_model_config = (
            dict(cfg.model.base_model_config) if "base_model_config" in cfg.model else {}
        )
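        # For the pointer network taskmodule, align the base model's special token ids and
        # target token ids with the taskmodule, so that generation operates on the same target
        # vocabulary that the taskmodule uses for encoding and decoding.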
        if isinstance(taskmodule, PointerNetworkTaskModuleForEnd2EndRE):
            base_model_config.update(
                dict(
                    bos_token_id=taskmodule.bos_id,
                    eos_token_id=taskmodule.eos_id,
                    pad_token_id=taskmodule.eos_id,
                    target_token_ids=taskmodule.target_token_ids,
                    embedding_weight_mapping=taskmodule.label_embedding_weight_mapping,
                )
            )
        additional_model_kwargs["base_model_config"] = base_model_config

    # initialize the model
    model: PyTorchIEModel = hydra.utils.instantiate(
        cfg.model, _convert_="partial", **additional_model_kwargs
    )
log.info("Instantiating callbacks...")
callbacks: List[Callback] = utils.instantiate_dict_entries(cfg, key="callbacks")
log.info("Instantiating loggers...")
logger: List[Logger] = utils.instantiate_dict_entries(cfg, key="logger")
log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)
object_dict = {
"cfg": cfg,
"dataset": dataset,
"taskmodule": taskmodule,
"model": model,
"callbacks": callbacks,
"logger": logger,
"trainer": trainer,
}
if logger:
log.info("Logging hyperparameters!")
utils.log_hyperparameters(logger=logger, model=model, taskmodule=taskmodule, config=cfg)
if cfg.model_save_dir is not None:
log.info(f"Save taskmodule to {cfg.model_save_dir} [push_to_hub={cfg.push_to_hub}]")
taskmodule.save_pretrained(save_directory=cfg.model_save_dir, push_to_hub=cfg.push_to_hub)
else:
log.warning("the taskmodule is not saved because no save_dir is specified")
if cfg.get("train"):
log.info("Starting training!")
trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
train_metrics = trainer.callback_metrics
best_ckpt_path = trainer.checkpoint_callback.best_model_path
best_epoch = None
if best_ckpt_path != "":
log.info(f"Best ckpt path: {best_ckpt_path}")
best_checkpoint_file = os.path.basename(best_ckpt_path)
utils.log_hyperparameters(
logger=logger,
best_checkpoint=best_checkpoint_file,
checkpoint_dir=trainer.checkpoint_callback.dirpath,
)
# get epoch from best_checkpoint_file (e.g. "epoch_078.ckpt")
try:
best_epoch = int(os.path.splitext(best_checkpoint_file)[0].split("_")[-1])
except Exception as e:
log.warning(
f'Could not retrieve epoch from best checkpoint file name: "{e}". '
f"Expected format: " + '"epoch_{best_epoch}.ckpt"'
)
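
    # Note: with fast_dev_run enabled, Lightning disables checkpointing, so there is no best
    # checkpoint that could be reloaded; the model is therefore only saved in regular runs.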
    if not cfg.trainer.get("fast_dev_run"):
        if cfg.model_save_dir is not None:
            if best_ckpt_path == "":
                log.warning("Best ckpt not found! Using current weights for saving...")
            else:
                model = type(model).load_from_checkpoint(best_ckpt_path)
            log.info(f"Save model to {cfg.model_save_dir} [push_to_hub={cfg.push_to_hub}]")
            model.save_pretrained(save_directory=cfg.model_save_dir, push_to_hub=cfg.push_to_hub)
        else:
            log.warning("the model is not saved because no save_dir is specified")
if cfg.get("validate"):
log.info("Starting validation!")
if best_ckpt_path == "":
log.warning("Best ckpt not found! Using current weights for validation...")
trainer.validate(model=model, datamodule=datamodule, ckpt_path=best_ckpt_path or None)
elif cfg.get("train"):
log.warning(
"Validation after training is skipped! That means, the finally reported validation scores are "
"the values from the *last* checkpoint, not from the *best* checkpoint (which is saved)!"
)
if cfg.get("test"):
log.info("Starting testing!")
if best_ckpt_path == "":
log.warning("Best ckpt not found! Using current weights for testing...")
trainer.test(model=model, datamodule=datamodule, ckpt_path=best_ckpt_path or None)
test_metrics = trainer.callback_metrics
test_metrics["best_epoch"] = best_epoch
# merge train and test metrics
metric_dict = {**train_metrics, **test_metrics}
# add model_save_dir to the result so that it gets dumped to job_return_value.json
# if we use hydra_callbacks.SaveJobReturnValueCallback
if cfg.get("model_save_dir") is not None:
metric_dict["model_save_dir"] = cfg.model_save_dir
return metric_dict, object_dict


@hydra.main(version_base="1.2", config_path=str(root / "configs"), config_name="train.yaml")
def main(cfg: DictConfig) -> Optional[float]:
    # train the model
    metric_dict, _ = train(cfg)

    # safely retrieve metric value for hydra-based hyperparameter optimization
    if cfg.get("optimized_metric") is not None:
        metric_value = get_metric_value(
            metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
        )
        # return optimized metric
        return metric_value
    else:
        return metric_dict
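

# Example invocation (illustrative; the available overrides depend on this project's configs):
#   python src/train.py seed=42 trainer.max_epochs=20
# Hydra composes the configuration from "configs/train.yaml" plus any command line overrides.
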
if __name__ == "__main__":
    utils.replace_sys_args_with_values_from_files()
    utils.prepare_omegaconf()
    main()