|
import pyrootutils |
|
|
|
# Resolve the project root before any project-local import below.
# The search starts at this file and looks for a ".project-root" marker.
# NOTE(review): `pythonpath=True` presumably puts the root on sys.path (so
# `from src import ...` resolves) and `dotenv=True` presumably loads a .env
# file -- confirm against the pyrootutils documentation.
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    dotenv=True,
    pythonpath=True,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Tuple |
|
|
|
import hydra |
|
import pytorch_lightning as pl |
|
from omegaconf import DictConfig |
|
from pie_datasets import DatasetDict |
|
from pytorch_ie.core import DocumentMetric |
|
from pytorch_ie.metrics import * |
|
|
|
from src import utils |
|
from src.metrics import * |
|
|
|
# Module-level logger, named after this module, obtained from the project's
# logging helper.
log = utils.get_pylogger(__name__)
|
|
|
|
|
@utils.task_wrapper
def evaluate_documents(cfg: DictConfig) -> Tuple[dict, dict]:
    """Evaluates serialized PIE documents.

    This method is wrapped in the ``@utils.task_wrapper`` decorator, which
    applies extra utilities before and after the call.

    The function performs the following steps:

    1. Seed the random number generators if ``cfg.seed`` is set.
    2. Instantiate the dataset from ``cfg.dataset`` and the metric from
       ``cfg.metric``.
    3. Pass the dataset through ``metric.convert_dataset``.
    4. Instantiate the loggers configured under ``logger`` and send the
       configuration to each of them as hyperparameters.
    5. Optionally restrict the dataset to the splits named in ``cfg.splits``.
    6. Call the metric on the resulting documents.

    Args:
        cfg (DictConfig): Configuration composed by Hydra. The keys
            ``dataset`` and ``metric`` are required; ``seed``, ``logger``
            and ``splits`` are optional.

    Returns:
        Tuple[dict, dict]: Dict with metrics and dict with all instantiated
        objects (keys ``"cfg"``, ``"dataset"``, ``"metric"``, ``"logger"``).
    """
    # Compare against None instead of relying on truthiness: 0 is a valid
    # seed and must not silently disable seeding.
    seed = cfg.get("seed")
    if seed is not None:
        pl.seed_everything(seed, workers=True)

    log.info(f"Instantiating dataset <{cfg.dataset._target_}>")
    dataset: DatasetDict = hydra.utils.instantiate(cfg.dataset, _convert_="partial")

    log.info(f"Instantiating metric <{cfg.metric._target_}>")
    metric: DocumentMetric = hydra.utils.instantiate(cfg.metric, _convert_="partial")

    # Let the metric convert the dataset first.
    # NOTE(review): presumably this casts the documents to the document type
    # the metric expects -- confirm against DocumentMetric.convert_dataset.
    dataset = metric.convert_dataset(dataset)

    loggers = utils.instantiate_dict_entries(cfg, "logger")

    # Everything instantiated here is handed back to the caller.
    object_dict = {
        "cfg": cfg,
        "dataset": dataset,
        "metric": metric,
        "logger": loggers,
    }

    if loggers:
        log.info("Logging hyperparameters!")
        # Send the full configuration to every configured logger.
        for logger in loggers:
            logger.log_hyperparams(cfg)

    # Without a `splits` entry the whole dataset is evaluated; otherwise only
    # the listed splits are kept, in a container of the same type as the
    # dataset. Names in `splits` that are not in the dataset are ignored.
    splits = cfg.get("splits", None)
    if splits is None:
        documents = dataset
    else:
        documents = type(dataset)({k: v for k, v in dataset.items() if k in splits})

    metric_dict = metric(documents)

    return metric_dict, object_dict
|
|
|
|
|
@hydra.main(
    version_base="1.2",
    config_path=str(root / "configs"),
    config_name="evaluate_documents.yaml",
)
def main(cfg: DictConfig) -> Any:
    """Hydra entry point: run the document evaluation and return its metrics.

    Args:
        cfg (DictConfig): Configuration composed by Hydra from
            ``evaluate_documents.yaml`` in the project's ``configs`` directory.

    Returns:
        Any: The metric dict returned by ``evaluate_documents``; the dict of
        instantiated objects is discarded.
    """
    metrics, _objects = evaluate_documents(cfg)
    return metrics
|
|
|
|
|
if __name__ == "__main__":
    # Script entry: both preparation helpers run before Hydra takes over.
    # NOTE(review): judging by its name, the first helper rewrites sys.argv
    # entries from files, so it has to run before `main()` parses the
    # command line -- confirm in src.utils.
    utils.replace_sys_args_with_values_from_files()
    # NOTE(review): presumably registers OmegaConf resolvers or similar
    # setup -- confirm in src.utils.
    utils.prepare_omegaconf()
    main()
|
|