from typing import Union
from pathlib import Path

import torch
from torchvision.models.segmentation import deeplabv3_resnet101

from SoccerNet.Evaluation.utils_calibration import SoccerPitch


class InferenceSegmentationModel:
    """Inference-only wrapper around a DeepLabV3-ResNet101 segmentation network.

    The network is built with ``len(SoccerPitch.lines_classes) + 1`` output
    classes (the extra class is presumably the background -- confirm against
    the training code), its weights are restored from a checkpoint file, and
    it is placed on ``device`` in evaluation mode.
    """

    def __init__(self, checkpoint: Union[str, Path], device) -> None:
        """Build the network and restore its weights.

        Args:
            checkpoint: Path to a file readable by ``torch.load`` whose
                top-level object has a ``"model"`` entry holding the state
                dict.
            device: Device the checkpoint is mapped to and the model is moved
                to; forwarded unchanged to ``torch.load(map_location=...)``
                and ``Module.to``.
        """
        self.device = device
        self.model = deeplabv3_resnet101(
            num_classes=len(SoccerPitch.lines_classes) + 1,
            aux_loss=True,
        )

        # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
        # Python objects, which can execute code. Only load checkpoints that
        # come from a trusted source.
        checkpoint_data = torch.load(
            checkpoint, map_location=self.device, weights_only=False
        )

        # NOTE(review): strict=False silently ignores missing / unexpected
        # keys, so a checkpoint that does not match this architecture loads
        # without any error. Kept as-is because it may be deliberate.
        self.model.load_state_dict(checkpoint_data["model"], strict=False)
        self.model.to(self.device)
        self.model.eval()

    def inference(self, img_batch: torch.Tensor) -> torch.Tensor:
        """Return the predicted class index for each position of the batch.

        Args:
            img_batch: Input batch passed directly to the model. It is not
                moved to ``self.device`` here, so the caller is expected to
                provide it on the right device already.

        Returns:
            Integer tensor obtained by taking ``argmax`` over dimension 1 of
            the model's ``"out"`` head (the class dimension, assuming the
            usual ``(N, C, H, W)`` torchvision layout -- confirm).
        """
        # Gradients are never needed for prediction; disabling autograd keeps
        # the forward pass from retaining every intermediate activation.
        with torch.no_grad():
            logits = self.model(img_batch)["out"]
            return logits.argmax(1)