import argparse
import copy
import json
import os.path
import random
from collections import deque
from pathlib import Path

import cv2 as cv
import numpy as np
import torch
import torch.backends.cudnn
import torch.nn as nn
from PIL import Image
from torchvision.models.segmentation import deeplabv3_resnet50
from tqdm import tqdm

from SoccerNet.Evaluation.utils_calibration import SoccerPitch


def generate_class_synthesis(semantic_mask, radius):
    """
    Selects, for each class present in the semantic mask, a set of circles that covers most of the
    semantic class blobs.
    :param semantic_mask: an image containing the segmentation predictions
    :param radius: circle radius
    :return: a dictionary associating each detected class with a list of points (the circle centers)
    """
    buckets = dict()
    kernel = np.ones((5, 5), np.uint8)
    semantic_mask = cv.erode(semantic_mask, kernel, iterations=1)
    for k, class_name in enumerate(SoccerPitch.lines_classes):
        mask = semantic_mask == k + 1
        if mask.sum() > 0:
            disk_list = synthesize_mask(mask, radius)
            if len(disk_list):
                buckets[class_name] = disk_list
    return buckets


def join_points(point_list, maxdist):
    """
    Given a list of points extracted from the blobs of a single semantic class, builds polylines by
    linking close points together when their distance is below the maxdist threshold.
    :param point_list: list of points of the same line class
    :param maxdist: maximal distance between two points for them to be linked into the same polyline
    :return: a list of polylines
    """
    polylines = []
    if not len(point_list):
        return polylines
    head = point_list[0]
    tail = point_list[0]
    polyline = deque()
    polyline.append(point_list[0])
    remaining_points = copy.deepcopy(point_list[1:])

    while len(remaining_points) > 0:
        min_dist_tail = 1000
        min_dist_head = 1000
        best_head = -1
        best_tail = -1
        # find the remaining point closest to either end of the current polyline
        for j, point in enumerate(remaining_points):
            dist_tail = np.sqrt(np.sum(np.square(point - tail)))
            dist_head = np.sqrt(np.sum(np.square(point - head)))
            if dist_tail < min_dist_tail:
                min_dist_tail = dist_tail
                best_tail = j
            if dist_head < min_dist_head:
                min_dist_head = dist_head
                best_head = j

        if min_dist_head <= min_dist_tail and min_dist_head < maxdist:
            polyline.appendleft(remaining_points[best_head])
            head = polyline[0]
            remaining_points.pop(best_head)
        elif min_dist_tail < min_dist_head and min_dist_tail < maxdist:
            polyline.append(remaining_points[best_tail])
            tail = polyline[-1]
            remaining_points.pop(best_tail)
        else:
            # no remaining point is close enough: close this polyline and start a new one
            polylines.append(list(polyline.copy()))
            head = remaining_points[0]
            tail = remaining_points[0]
            polyline = deque()
            polyline.append(head)
            remaining_points.pop(0)
    polylines.append(list(polyline))
    return polylines
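# Minimal sketch of join_points on hypothetical data (not part of the original script):
# three collinear points 10 px apart are chained into a single polyline when maxdist
# exceeds their spacing, and end up as three singleton polylines when it does not.
def _demo_join_points():
    pts = [np.array([10, 10]), np.array([10, 20]), np.array([10, 30])]
    assert len(join_points(pts, maxdist=15)) == 1  # all three points linked together
    assert len(join_points(pts, maxdist=5)) == 3  # no pair is close enough to link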
def get_line_extremities(buckets, maxdist, width, height):
    """
    Given the dictionary {lines_class: points}, finds plausible extremities of each line, i.e. the
    extremities of the longest polyline that can be built on the class blobs, and normalizes their
    coordinates by the image size.
    :param buckets: the dictionary associating line classes with the set of circle centers that
    best covers the class prediction blobs in the segmentation mask
    :param maxdist: the maximal distance between two circle centers belonging to the same blob
    (heuristic)
    :param width: image width
    :param height: image height
    :return: a dictionary associating each class with its extremities
    """
    extremities = dict()
    for class_name, disks_list in buckets.items():
        polyline_list = join_points(disks_list, maxdist)
        max_len = 0
        longest_polyline = []
        for polyline in polyline_list:
            if len(polyline) > max_len:
                max_len = len(polyline)
                longest_polyline = polyline
        extremities[class_name] = [
            {'x': longest_polyline[0][1] / width, 'y': longest_polyline[0][0] / height},
            {'x': longest_polyline[-1][1] / width, 'y': longest_polyline[-1][0] / height}
        ]
    return extremities


def get_support_center(mask, start, disk_radius, min_support=0.1):
    """
    Returns the barycenter of the True pixels of the mask that lie under the circle of center
    start and radius disk_radius pixels.
    :param mask: boolean mask
    :param start: a point located on a True pixel of the mask
    :param disk_radius: the radius of the circles
    :param min_support: proportion of the circle area that should be True in order to have enough
    support
    :return: a boolean indicating whether there is enough support in the circle area, and the
    barycenter of the True pixels under the circle
    """
    x = int(start[0])
    y = int(start[1])
    support_pixels = 1  # start at 1 (the start point) so the division below is always defined
    result = [x, y]
    # clamp the disk's bounding box to the mask bounds
    xstart = x - disk_radius
    if xstart < 0:
        xstart = 0
    xend = x + disk_radius
    if xend >= mask.shape[0]:
        xend = mask.shape[0] - 1
    ystart = y - disk_radius
    if ystart < 0:
        ystart = 0
    yend = y + disk_radius
    if yend >= mask.shape[1]:
        yend = mask.shape[1] - 1

    for i in range(xstart, xend + 1):
        for j in range(ystart, yend + 1):
            dist = np.sqrt(np.square(x - i) + np.square(y - j))
            if dist < disk_radius and mask[i, j] > 0:
                support_pixels += 1
                result[0] += i
                result[1] += j
    support = True
    if support_pixels < min_support * np.square(disk_radius) * np.pi:
        support = False
    result = np.array(result)
    result = np.true_divide(result, support_pixels)
    return support, result
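# Minimal sketch of get_support_center on hypothetical data (not part of the original
# script): inside a filled 8x8 block, a disk of radius 4 finds far more than the 10%
# minimum support, and its barycenter stays within the block.
def _demo_get_support_center():
    toy_mask = np.zeros((32, 32), dtype=bool)
    toy_mask[:8, :8] = True
    support, center = get_support_center(toy_mask, start=(4, 4), disk_radius=4)
    assert support  # most of the disk area covers True pixels
    assert 0 <= center[0] < 8 and 0 <= center[1] < 8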
def synthesize_mask(semantic_mask, disk_radius):
    """
    Fits circles on the True pixels of the mask and returns those which have enough support,
    meaning that the proportion of the circle area covering True pixels is higher than a certain
    threshold, in order to avoid fitting circles on isolated pixels.
    :param semantic_mask: boolean mask
    :param disk_radius: radius of the circles
    :return: a list of disk centers that have enough support
    """
    mask = semantic_mask.copy().astype(np.uint8)
    points = np.transpose(np.nonzero(mask))
    disks = []
    while len(points):
        start = random.choice(points)
        dist = 10.
        success = True
        # mean-shift style refinement: move the disk center to the barycenter of its support
        # until it converges (moves by less than one pixel)
        while dist > 1.:
            enough_support, center = get_support_center(mask, start, disk_radius)
            if not enough_support:
                bad_point = np.round(center).astype(np.int32)
                cv.circle(mask, (bad_point[1], bad_point[0]), disk_radius, 0, -1)
                success = False
            dist = np.sqrt(np.sum(np.square(center - start)))
            start = center
        if success:
            disks.append(np.round(start).astype(np.int32))
            # erase the covered area so the next random start lands on an uncovered blob
            cv.circle(mask, (disks[-1][1], disks[-1][0]), disk_radius, 0, -1)
        points = np.transpose(np.nonzero(mask))

    return disks
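# Minimal sketch of synthesize_mask on hypothetical data (not part of the original
# script): a thick horizontal bar gets covered by a handful of disks, and every returned
# center lies on the bar, since the barycenter of the True pixels under a disk is always
# inside the bar.
def _demo_synthesize_mask():
    toy_mask = np.zeros((40, 120), dtype=bool)
    toy_mask[18:23, 10:110] = True  # a 5-pixel-thick horizontal line segment
    disks = synthesize_mask(toy_mask, disk_radius=4)
    assert len(disks) > 0
    assert all(toy_mask[d[0], d[1]] for d in disks)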
class SegmentationNetwork:
    def __init__(self, model_file, mean_file, std_file, num_classes=29, width=640, height=360):
        file_path = Path(model_file).resolve()
        model = nn.DataParallel(deeplabv3_resnet50(pretrained=False, num_classes=num_classes))
        self.init_weight(model, nn.init.kaiming_normal_, nn.BatchNorm2d, 1e-3, 0.1, mode='fan_in')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        checkpoint = torch.load(str(file_path), map_location=self.device)
        model.load_state_dict(checkpoint["model"])
        model.eval()
        self.model = model.to(self.device)

        file_path = Path(mean_file).resolve()
        self.mean = np.load(str(file_path))
        file_path = Path(std_file).resolve()
        self.std = np.load(str(file_path))
        self.width = width
        self.height = height

    def init_weight(self, feature, conv_init, norm_layer, bn_eps, bn_momentum, **kwargs):
        for name, m in feature.named_modules():
            if isinstance(m, (nn.Conv2d, nn.Conv3d)):
                conv_init(m.weight, **kwargs)
            elif isinstance(m, norm_layer):
                m.eps = bn_eps
                m.momentum = bn_momentum
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def analyse_image(self, image):
        """
        Processes the image, performs inference and returns the mask of detected classes.
        :param image: BGR image
        :return: predicted class mask
        """
        img = cv.resize(image, (self.width, self.height), interpolation=cv.INTER_LINEAR)
        img = np.asarray(img, np.float32) / 255.
        img = (img - self.mean) / self.std
        img = img.transpose((2, 0, 1))
        img = torch.from_numpy(img).to(self.device).unsqueeze(0)
        cuda_result = self.model.forward(img.float())
        output = cuda_result['out'].data[0].cpu().numpy()
        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
        return output


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Test')

    parser.add_argument('-s', '--soccernet', default="./annotations/", type=str,
                        help='Path to the SoccerNet-V3 dataset folder')
    parser.add_argument('-p', '--prediction', default="./results_bis", required=False, type=str,
                        help='Path to the prediction folder')
    parser.add_argument('--split', required=False, type=str, default="test",
                        help='Select the split of data')
    parser.add_argument('--masks', required=False, action='store_true',
                        help='Save masks in the prediction directory')
    parser.add_argument('--resolution_width', required=False, type=int, default=455,
                        help='Width resolution of the images')
    parser.add_argument('--resolution_height', required=False, type=int, default=256,
                        help='Height resolution of the images')
    parser.add_argument('--checkpoint_dir', default="resources")
    args = parser.parse_args()

    lines_palette = [0, 0, 0]
    for line_class in SoccerPitch.lines_classes:
        lines_palette.extend(SoccerPitch.palette[line_class])

    calib_net = SegmentationNetwork(
        os.path.join(args.checkpoint_dir, "soccer_pitch_segmentation.pth"),
        os.path.join(args.checkpoint_dir, "mean.npy"),
        os.path.join(args.checkpoint_dir, "std.npy"))

    dataset_dir = os.path.join(args.soccernet, args.split)
    if not os.path.exists(dataset_dir):
        print("Invalid dataset path!")
        exit(-1)

    frames = [f for f in os.listdir(dataset_dir) if ".jpg" in f]
    with tqdm(enumerate(frames), total=len(frames), ncols=160) as t:
        for i, frame in t:
            output_prediction_folder = os.path.join(args.prediction, args.split)
            if not os.path.exists(output_prediction_folder):
                os.makedirs(output_prediction_folder)

            frame_path = os.path.join(dataset_dir, frame)
            frame_index = frame.split(".")[0]

            image = cv.imread(frame_path)
            semlines = calib_net.analyse_image(image)

            if args.masks:
                mask = Image.fromarray(semlines.astype(np.uint8)).convert('P')
                mask.putpalette(lines_palette)
                mask_file = os.path.join(output_prediction_folder, frame)
                mask.convert("RGB").save(mask_file)

            skeletons = generate_class_synthesis(semlines, 6)
            prediction = get_line_extremities(skeletons, 40, args.resolution_width,
                                              args.resolution_height)

            prediction_file = os.path.join(output_prediction_folder,
                                           f"extremities_{frame_index}.json")
            with open(prediction_file, "w") as f:
                json.dump(prediction, f, indent=4)
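# Illustrative end-to-end sketch of the post-processing on hypothetical data (no network
# or checkpoint needed, not part of the original script): paint one line class into a
# fake prediction mask, then recover its two extremities. An example command line for the
# full pipeline (script name and paths are hypothetical):
#   python detect_extremities.py -s /path/to/SoccerNet -p ./results_bis --split test --masks
def _demo_postprocessing():
    fake_mask = np.zeros((256, 455), dtype=np.uint8)
    fake_mask[100:108, 50:400] = 1  # value 1 maps to SoccerPitch.lines_classes[0]
    skeletons = generate_class_synthesis(fake_mask, 6)
    extremities = get_line_extremities(skeletons, 40, 455, 256)
    # one entry per detected class, each holding the two normalized endpoints, e.g.
    # {<class name>: [{'x': 0.12, 'y': 0.40}, {'x': 0.87, 'y': 0.40}]}
    print(extremities)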