import pyrootutils

root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)
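
# pyrootutils.setup_root searches upwards from this file for the ".project-root"
# indicator file, adds the project root to PYTHONPATH (pythonpath=True), and
# loads environment variables from a ".env" file in the root if one exists
# (dotenv=True). This makes the script runnable from anywhere in the repository
# without installing the project as a package.
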
import os
from collections.abc import Iterable, Sequence
from typing import Any, Dict, Optional, Tuple, Union

import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf
from pie_datasets import DatasetDict

# The wildcard imports below make all model and taskmodule implementations
# available, so that they can be resolved by name when a pretrained pipeline is
# loaded or when objects are instantiated from the Hydra config.
from pie_modules.models import *
from pie_modules.taskmodules import *
from pytorch_ie import Document, Pipeline
from pytorch_ie.models import *
from pytorch_ie.taskmodules import *

from src import utils
from src.models import *
from src.serializer.interface import DocumentSerializer
from src.taskmodules import *

log = utils.get_pylogger(__name__)


def document_batch_iter(
    dataset: Union[Sequence[Document], Iterable[Document]], batch_size: int
) -> Iterable[Sequence[Document]]:
    """Yield batches of at most batch_size documents from the dataset."""
    if isinstance(dataset, Sequence):
        for i in range(0, len(dataset), batch_size):
            yield dataset[i : i + batch_size]
    elif isinstance(dataset, Iterable):
        docs = []
        for doc in dataset:
            docs.append(doc)
            if len(docs) == batch_size:
                yield docs
                docs = []
        # yield the final, possibly smaller, batch
        if docs:
            yield docs
    else:
        raise ValueError(f"Unsupported dataset type: {type(dataset)}")
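
# For illustration, batching five documents with batch_size=2 yields batches of
# sizes 2, 2, and 1:
#   list(document_batch_iter([d1, d2, d3, d4, d5], batch_size=2))
#   -> [[d1, d2], [d3, d4], [d5]]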


@utils.task_wrapper
def predict(cfg: DictConfig) -> Tuple[dict, dict]:
    """Minimal example of a prediction pipeline: uses a pretrained model to annotate
    documents from a dataset and serializes the results.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        Tuple[dict, dict]: A dict with the prediction results and a dict with all
        instantiated objects (config, dataset, pipeline, and serializer).
    """

    # Set seed for random number generators in pytorch, numpy and python.random.
    if cfg.get("seed"):
        pl.seed_everything(cfg.seed, workers=True)

    log.info(f"Instantiating dataset <{cfg.dataset._target_}>")
    dataset: DatasetDict = hydra.utils.instantiate(cfg.dataset, _convert_="partial")

    pipeline: Optional[Pipeline] = None
    if cfg.get("pipeline") and cfg.pipeline.get("_target_"):
        log.info(f"Instantiating pipeline <{cfg.pipeline._target_}> from {cfg.model_name_or_path}")
        pipeline = hydra.utils.instantiate(cfg.pipeline, _convert_="partial")

        # Per default, the pipeline model is loaded with its pretrained weights. If a
        # checkpoint path is provided, load the model weights from there instead. Note
        # that this must happen inside this block, because it requires a pipeline.
        if cfg.ckpt_path is not None:
            log.info(f"Loading model weights from checkpoint: {cfg.ckpt_path}")
            pipeline.model = (
                type(pipeline.model)
                .load_from_checkpoint(checkpoint_path=cfg.ckpt_path)
                .to(pipeline.device)
                .to(dtype=pipeline.model.dtype)
            )

        # Convert the documents to the type expected by the taskmodule.
        dataset = pipeline.taskmodule.convert_dataset(dataset)

    serializer: Optional[DocumentSerializer] = None
    if cfg.get("serializer") and cfg.serializer.get("_target_"):
        log.info(f"Instantiating serializer <{cfg.serializer._target_}>")
        serializer = hydra.utils.instantiate(cfg.serializer, _convert_="partial")

    # Select the dataset split to predict on.
    dataset_predict = dataset[cfg.dataset_split]

    object_dict = {
        "cfg": cfg,
        "dataset": dataset,
        "pipeline": pipeline,
        "serializer": serializer,
    }

    result: Dict[str, Any] = utils.predict_and_serialize(
        pipeline=pipeline,
        serializer=serializer,
        dataset=dataset_predict,
        document_batch_size=cfg.get("document_batch_size", None),
    )

    # Save the resolved config alongside the predictions, if requested.
    if cfg.get("config_out_path"):
        config_out_dir = os.path.dirname(cfg.config_out_path)
        os.makedirs(config_out_dir, exist_ok=True)
        OmegaConf.save(config=cfg, f=cfg.config_out_path)
        result["config"] = cfg.config_out_path

    return result, object_dict


@hydra.main(version_base="1.2", config_path=str(root / "configs"), config_name="predict.yaml")
def main(cfg: DictConfig) -> dict:
    result_dict, _ = predict(cfg)
    return result_dict
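
# Example invocation with Hydra overrides (the values are illustrative and
# depend on the config files available in this repository):
#   python src/predict.py dataset_split=test ckpt_path=/path/to/model.ckpt
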
if __name__ == "__main__":
    utils.replace_sys_args_with_values_from_files()
    utils.prepare_omegaconf()
    main()