RamziBm's picture
init
bdb955e
raw
history blame
775 Bytes
import warnings
from pathlib import Path
from typing import Union

import torch
from torchvision.models.segmentation import deeplabv3_resnet101
from SoccerNet.Evaluation.utils_calibration import SoccerPitch
class InferenceSegmentationModel:
    """DeepLabV3-ResNet101 wrapper that predicts per-pixel pitch-line labels.

    The network is built with one output channel per entry of
    ``SoccerPitch.lines_classes`` plus one extra channel (presumably the
    background class -- confirm against the training code).
    """

    def __init__(
        self, checkpoint: Union[str, Path], device: Union[str, torch.device]
    ) -> None:
        """Build the network, load trained weights and switch to eval mode.

        Args:
            checkpoint: Path to a ``torch.save`` file holding a mapping whose
                ``"model"`` entry is the network's ``state_dict``.
            device: Device the checkpoint is mapped to and the model runs on.

        Raises:
            KeyError: If the loaded checkpoint has no ``"model"`` entry.

        Warns:
            UserWarning: If the checkpoint's keys do not match the network
                (missing or unexpected parameters). Loading still proceeds.
        """
        self.device = device
        self.model = deeplabv3_resnet101(
            num_classes=len(SoccerPitch.lines_classes) + 1,
            aux_loss=True,
            # All weights come from the checkpoint below; torchvision would
            # otherwise download ImageNet backbone weights only to have them
            # overwritten (and fail outright when offline).
            weights_backbone=None,
        )
        # SECURITY: weights_only=False unpickles arbitrary Python objects and
        # can execute code embedded in the file -- only load trusted checkpoints.
        checkpoint_data = torch.load(
            checkpoint, map_location=self.device, weights_only=False
        )
        # strict=False tolerates key mismatches, but a wrong checkpoint (e.g.
        # keys carrying a "module." prefix) would then load nothing at all, so
        # surface any mismatch instead of discarding load_state_dict's result.
        incompatible = self.model.load_state_dict(
            checkpoint_data["model"], strict=False
        )
        if incompatible.missing_keys or incompatible.unexpected_keys:
            warnings.warn(
                f"Checkpoint {checkpoint!s} does not fully match the model: "
                f"{len(incompatible.missing_keys)} missing key(s) "
                f"{incompatible.missing_keys[:5]}, "
                f"{len(incompatible.unexpected_keys)} unexpected key(s) "
                f"{incompatible.unexpected_keys[:5]}.",
                stacklevel=2,
            )
        self.model.to(self.device)
        self.model.eval()

    @torch.no_grad()
    def inference(self, img_batch: torch.Tensor) -> torch.Tensor:
        """Return the per-pixel argmax class index for a batch of images.

        Runs under ``torch.no_grad()`` so no autograd graph (and none of the
        backbone's intermediate activations) is kept alive per call.

        Args:
            img_batch: Batch of images. This method does not move it, so the
                caller must place it on ``self.device``. Presumably normalised
                ``(N, 3, H, W)`` float images -- confirm with the caller.

        Returns:
            Integer tensor of class indices: the model's ``"out"`` logits with
            dim 1 (the class channel) reduced by ``argmax``.
        """
        return self.model(img_batch)["out"].argmax(1)