|
import pyrootutils

# Locate the project root (the directory containing a ".project-root" marker
# file, searched upwards from this file), add it to sys.path
# (pythonpath=True) and load environment variables from a .env file if one
# exists (dotenv=True). This makes the `src.*` imports below work no matter
# which working directory the script is launched from.
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import timeit |
|
from collections.abc import Iterable, Sequence |
|
from typing import Any, Dict, Optional, Tuple, Union |
|
|
|
import hydra |
|
import pytorch_lightning as pl |
|
from omegaconf import DictConfig, OmegaConf |
|
from pie_datasets import Dataset, DatasetDict |
|
from pie_modules.models import * |
|
from pie_modules.taskmodules import * |
|
from pytorch_ie import Document, Pipeline |
|
from pytorch_ie.models import * |
|
from pytorch_ie.taskmodules import * |
|
|
|
from src import utils |
|
from src.models import * |
|
from src.serializer.interface import DocumentSerializer |
|
from src.taskmodules import * |
|
|
|
# module-level logger, created via the project's logging helper
log = utils.get_pylogger(__name__)
|
|
|
|
|
def document_batch_iter(
    dataset: Union[Sequence[Document], Iterable[Document]], batch_size: int
) -> Iterable[Sequence[Document]]:
    """Yield successive batches of at most ``batch_size`` documents from ``dataset``.

    A ``Sequence`` is sliced directly (each batch is a slice of the original),
    while any other ``Iterable`` is consumed lazily and buffered into lists.
    The final batch may contain fewer than ``batch_size`` documents.

    Args:
        dataset: The documents to batch, either as a sequence or a plain iterable.
        batch_size: Maximum number of documents per yielded batch.

    Raises:
        ValueError: If ``dataset`` is neither a ``Sequence`` nor an ``Iterable``.
    """
    if isinstance(dataset, Sequence):
        # slice directly; the last slice is simply shorter when len(dataset)
        # is not a multiple of batch_size
        yield from (
            dataset[start : start + batch_size]
            for start in range(0, len(dataset), batch_size)
        )
    elif isinstance(dataset, Iterable):
        # consume lazily, flushing the buffer whenever it reaches batch_size
        buffer = []
        for document in dataset:
            buffer.append(document)
            if len(buffer) == batch_size:
                yield buffer
                buffer = []
        if buffer:
            yield buffer
    else:
        raise ValueError(f"Unsupported dataset type: {type(dataset)}")
|
|
|
|
|
@utils.task_wrapper
def predict(cfg: DictConfig) -> Tuple[dict, dict]:
    """Contains minimal example of the prediction pipeline. Uses a pretrained model to annotate
    documents from a dataset and serializes them.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        Tuple[dict, dict]: A result dict (may contain the serializer result under
            "serializer", the accumulated "prediction_time" and the saved "config"
            path) and an object dict holding the instantiated cfg, dataset,
            pipeline and serializer.
    """

    # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        pl.seed_everything(cfg.seed, workers=True)

    # instantiate the dataset to annotate
    log.info(f"Instantiating dataset <{cfg.dataset._target_}>")
    dataset: DatasetDict = hydra.utils.instantiate(cfg.dataset, _convert_="partial")

    # instantiate the prediction pipeline, if one is configured; otherwise
    # inference is skipped further below and only serialization takes place
    pipeline: Optional[Pipeline] = None
    if cfg.get("pipeline") and cfg.pipeline.get("_target_"):
        log.info(f"Instantiating pipeline <{cfg.pipeline._target_}> from {cfg.model_name_or_path}")
        pipeline = hydra.utils.instantiate(cfg.pipeline, _convert_="partial")

        # if a checkpoint path is given, overwrite the pipeline's model weights
        # with the ones from that checkpoint and move the model to the
        # pipeline's device
        if cfg.ckpt_path is not None:
            pipeline.model = pipeline.model.load_from_checkpoint(checkpoint_path=cfg.ckpt_path).to(
                pipeline.device
            )

        # let the taskmodule convert the dataset (presumably to the document
        # type it expects — see the taskmodule's convert_dataset)
        dataset = pipeline.taskmodule.convert_dataset(dataset)

    # instantiate the serializer that persists the (annotated) documents, if configured
    serializer: Optional[DocumentSerializer] = None
    if cfg.get("serializer") and cfg.serializer.get("_target_"):
        log.info(f"Instantiating serializer <{cfg.serializer._target_}>")
        serializer = hydra.utils.instantiate(cfg.serializer, _convert_="partial")

    # select the dataset split to predict on
    dataset_predict = dataset[cfg.dataset_split]

    object_dict = {
        "cfg": cfg,
        "dataset": dataset,
        "pipeline": pipeline,
        "serializer": serializer,
    }
    result: Dict[str, Any] = {}
    if pipeline is not None:
        log.info("Starting inference!")
        prediction_time = 0.0
    else:
        log.warning("No prediction pipeline is defined, skip inference!")
        # None marks "no inference happened" so no timing is reported below
        prediction_time = None
    # process the documents in batches of document_batch_size if set,
    # otherwise hand the whole split to the pipeline / serializer at once
    document_batch_size = cfg.get("document_batch_size", None)
    for docs_batch in (
        document_batch_iter(dataset_predict, document_batch_size)
        if document_batch_size
        else [dataset_predict]
    ):
        if pipeline is not None:
            # time only the pipeline call itself, accumulated over all batches
            t_start = timeit.default_timer()
            docs_batch = pipeline(docs_batch, inplace=False)
            prediction_time += timeit.default_timer() - t_start

        # serialize the current batch of documents
        if serializer is not None:
            serializer_result = serializer(docs_batch)
            # only the serializer result of the last batch is kept; warn if it
            # differs from a previous batch's result
            if "serializer" in result and result["serializer"] != serializer_result:
                log.warning(
                    f"serializer result changed from {result['serializer']} to {serializer_result}"
                    " during prediction. Only the last result is returned."
                )
            result["serializer"] = serializer_result

    if prediction_time is not None:
        result["prediction_time"] = prediction_time

    # serialize the config to config_out_path, if requested
    if cfg.get("config_out_path"):
        config_out_dir = os.path.dirname(cfg.config_out_path)
        os.makedirs(config_out_dir, exist_ok=True)
        OmegaConf.save(config=cfg, f=cfg.config_out_path)
        result["config"] = cfg.config_out_path

    return result, object_dict
|
|
|
|
|
@hydra.main(version_base="1.2", config_path=str(root / "configs"), config_name="predict.yaml")
def main(cfg: DictConfig) -> Dict[str, Any]:
    """Hydra entry point: run the prediction task.

    Args:
        cfg (DictConfig): Configuration composed by Hydra from configs/predict.yaml.

    Returns:
        Dict[str, Any]: The result dict produced by ``predict`` (the object dict
            is discarded). Returning it makes the value available to Hydra,
            e.g. for sweepers.
    """
    # fix: the function was annotated `-> None` although it returns result_dict
    result_dict, _ = predict(cfg)
    return result_dict
|
|
|
|
|
if __name__ == "__main__":
    # NOTE(review): project helpers — judging by their names, the first replaces
    # sys.argv entries with values read from files and the second registers
    # OmegaConf resolvers; confirm against src.utils. Both must run before
    # Hydra parses the command line in main().
    utils.replace_sys_args_with_values_from_files()
    utils.prepare_omegaconf()
    main()
|
|