import argparse
import copy
import itertools
import json
import os.path
import random
from collections import deque
from pathlib import Path

from pytorch_lightning import seed_everything

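# Seed Python, NumPy and torch RNGs so that the random disk sampling in synthesize_mask
# (and therefore the extracted extremities) is reproducible across runs.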
seed_everything(seed=10, workers=True)

import cv2 as cv
import numpy as np
import torch
import torch.backends.cudnn
import torch.nn as nn
import torchvision.transforms as T

from PIL import Image
from torchvision.models.segmentation import deeplabv3_resnet101
from tqdm import tqdm

from SoccerNet.Evaluation.utils_calibration import SoccerPitch


def generate_class_synthesis(semantic_mask, radius):
    """
    Selects, for each class present in the semantic mask, a set of circles that cover most of the
    semantic class blobs.
    :param semantic_mask: an image containing the segmentation predictions
    :param radius: circle radius
    :return: a dictionary associating each detected class with a list of points (the circle centers)
    """
    buckets = dict()
    kernel = np.ones((5, 5), np.uint8)
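    # Erode the prediction to thin the line blobs, then cover each class blob with disks (see synthesize_mask).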
    semantic_mask = cv.erode(semantic_mask, kernel, iterations=1)
    for k, class_name in enumerate(SoccerPitch.lines_classes):
        mask = semantic_mask == k + 1
        if mask.sum() > 0:
            disk_list = synthesize_mask(mask, radius)
            if len(disk_list):
                buckets[class_name] = disk_list

    return buckets


def join_points(point_list, maxdist):
    """
    Given a list of points extracted from the blobs of a same semantic class, this function creates
    polylines by linking close points together when their distance is below the maxdist threshold.
    :param point_list: list of points of the same line class
    :param maxdist: maximal distance between two points linked within the same polyline
    :return: a list of polylines
    """
    polylines = []

    if not len(point_list):
        return polylines
    head = point_list[0]
    tail = point_list[0]
    polyline = deque()
    polyline.append(point_list[0])
    remaining_points = copy.deepcopy(point_list[1:])

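    # Greedily grow the current polyline from both ends: at each step, attach the remaining point closest
    # to either the head or the tail, provided it lies within maxdist; otherwise start a new polyline.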
    while len(remaining_points) > 0:
        min_dist_tail = 1000
        min_dist_head = 1000
        best_head = -1
        best_tail = -1
        for j, point in enumerate(remaining_points):
            dist_tail = np.sqrt(np.sum(np.square(point - tail)))
            dist_head = np.sqrt(np.sum(np.square(point - head)))
            if dist_tail < min_dist_tail:
                min_dist_tail = dist_tail
                best_tail = j
            if dist_head < min_dist_head:
                min_dist_head = dist_head
                best_head = j

        if min_dist_head <= min_dist_tail and min_dist_head < maxdist:
            polyline.appendleft(remaining_points[best_head])
            head = polyline[0]
            remaining_points.pop(best_head)
        elif min_dist_tail < min_dist_head and min_dist_tail < maxdist:
            polyline.append(remaining_points[best_tail])
            tail = polyline[-1]
            remaining_points.pop(best_tail)
        else:
            polylines.append(list(polyline))
            head = remaining_points[0]
            tail = remaining_points[0]
            polyline = deque()
            polyline.append(head)
            remaining_points.pop(0)
    polylines.append(list(polyline))
    return polylines


def get_line_extremities(buckets, maxdist, width, height, num_points_lines, num_points_circles):
    """
    Given the dictionary {line_class: points}, finds plausible extremities of each line, i.e. the extremities
    of the longest polyline that can be built on the class blobs, and normalizes their coordinates
    by the image size.
    :param buckets: dictionary associating each line class with the set of circle centers that best cover the class
    prediction blobs in the segmentation mask
    :param maxdist: maximal distance between two circle centers belonging to the same blob (heuristic)
    :param width: image width
    :param height: image height
    :param num_points_lines: number of keypoints extracted for a line segment
    :param num_points_circles: number of keypoints extracted for a circle segment
    :return: a dictionary associating each class with its extremities
    """
    extremities = dict()
    for class_name, disks_list in buckets.items():
        polyline_list = join_points(disks_list, maxdist)
        max_len = 0
        longest_polyline = []
        for polyline in polyline_list:
            if len(polyline) > max_len:
                max_len = len(polyline)
                longest_polyline = polyline
        extremities[class_name] = [
            {'x': longest_polyline[0][1] / width, 'y': longest_polyline[0][0] / height},
            {'x': longest_polyline[-1][1] / width, 'y': longest_polyline[-1][0] / height},
        ]
        num_points = num_points_lines
        if "Circle" in class_name:
            num_points = num_points_circles
        if num_points > 2:
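            # When more than two keypoints are requested, sample evenly spaced intermediate points
            # along the longest polyline and insert them between the two extremities.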
            for i in range(1, num_points - 1):
                idx = i * int(len(longest_polyline) / num_points)
                extremities[class_name].insert(
                    len(extremities[class_name]) - 1,
                    {'x': longest_polyline[idx][1] / width,
                     'y': longest_polyline[idx][0] / height}
                )

    return extremities


def get_support_center(mask, start, disk_radius, min_support=0.1):
    """
    Returns the barycenter of the True pixels under the area of the mask delimited by the circle of center start
    and radius disk_radius pixels.
    :param mask: boolean mask
    :param start: a point located on a True pixel of the mask
    :param disk_radius: the radius of the circles
    :param min_support: proportion of the circle area that should be True in order to get enough support
    :return: a boolean indicating whether there is enough support in the circle area, and the barycenter of the
    True pixels under the circle
    """
    x = int(start[0])
    y = int(start[1])
    support_pixels = 1
    result = [x, y]
    # Clamp the bounding box of the disk to the mask boundaries.
    xstart = max(x - disk_radius, 0)
    xend = min(x + disk_radius, mask.shape[0] - 1)
    ystart = max(y - disk_radius, 0)
    yend = min(y + disk_radius, mask.shape[1] - 1)

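    # Scan the bounding box and accumulate the coordinates of foreground pixels lying inside the disk;
    # the barycenter is their mean, and the disk is kept only if enough of its area is covered.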
    for i in range(xstart, xend + 1):
        for j in range(ystart, yend + 1):
            dist = np.sqrt(np.square(x - i) + np.square(y - j))
            if dist < disk_radius and mask[i, j] > 0:
                support_pixels += 1
                result[0] += i
                result[1] += j
    support = True
    if support_pixels < min_support * np.square(disk_radius) * np.pi:
        support = False

    result = np.array(result)
    result = np.true_divide(result, support_pixels)

    return support, result


def synthesize_mask(semantic_mask, disk_radius):
    """
    Fits circles on the True pixels of the mask and returns those with enough support, meaning that the
    proportion of the circle area covering True pixels is higher than a certain threshold, in order to avoid
    fitting circles on isolated pixels.
    :param semantic_mask: boolean mask
    :param disk_radius: radius of the circles
    :return: a list of disk centers that have enough support
    """
    mask = semantic_mask.copy().astype(np.uint8)
    points = np.transpose(np.nonzero(mask))
    disks = []
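    # Mean-shift-like covering: pick a random foreground pixel, repeatedly move the disk to the barycenter
    # of the foreground pixels it covers until it converges (moves by less than one pixel), then erase the
    # covered area from the mask and repeat until no foreground pixel remains.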
    while len(points):
        start = random.choice(points)
        dist = 10.
        success = True
        while dist > 1.:
            enough_support, center = get_support_center(mask, start, disk_radius)
            if not enough_support:
                bad_point = np.round(center).astype(np.int32)
                cv.circle(mask, (bad_point[1], bad_point[0]), disk_radius, 0, -1)
                success = False
            dist = np.sqrt(np.sum(np.square(center - start)))
            start = center
        if success:
            disks.append(np.round(start).astype(np.int32))
            cv.circle(mask, (disks[-1][1], disks[-1][0]), disk_radius, 0, -1)
        points = np.transpose(np.nonzero(mask))

    return disks


class CustomNetwork:

    def __init__(self, checkpoint):
        print("Loading model " + checkpoint)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = deeplabv3_resnet101(num_classes=len(SoccerPitch.lines_classes) + 1, aux_loss=True)
        self.model.load_state_dict(torch.load(checkpoint)["model"], strict=False)
        self.model.to(self.device)
        self.model.eval()
        print("using", self.device)

    def forward(self, img):
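        # Standard ImageNet preprocessing: resize the shorter side to 256 px, convert to tensor and normalize.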
        trf = T.Compose(
            [
                T.Resize(256),
                T.ToTensor(),
                T.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ]
        )
        img = trf(img).unsqueeze(0).to(self.device)
        result = self.model(img)["out"].detach().squeeze(0).argmax(0)
        result = result.cpu().numpy().astype(np.uint8)

        return result

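
# Example usage (hypothetical checkpoint and image paths):
#   model = CustomNetwork("path/to/checkpoint.pth")
#   class_map = model.forward(Image.open("path/to/frame.jpg"))  # HxW uint8 map of line-class indices
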
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Test')

    parser.add_argument('-s', '--soccernet', default="/nfs/data/soccernet/calibration/", type=str,
                        help='Path to the SoccerNet-V3 dataset folder')
    parser.add_argument('-p', '--prediction', default="sn-calib-test_endpoints", required=False, type=str,
                        help="Path to the prediction folder")
    parser.add_argument('--split', required=False, type=str, default="challenge", help='Select the split of data')
    parser.add_argument('--masks', required=False, action='store_true', help='Save masks in prediction directory')
    parser.add_argument('--resolution_width', required=False, type=int, default=455,
                        help='Width resolution of the images')
    parser.add_argument('--resolution_height', required=False, type=int, default=256,
                        help='Height resolution of the images')
    parser.add_argument('--checkpoint', required=False, type=str, help="Path to the custom model checkpoint.")
    parser.add_argument('--pp_radius', required=False, type=int, default=4,
                        help='Post processing: radius of the circles that cover each segment.')
    parser.add_argument('--pp_maxdists', required=False, type=int, default=30,
                        help='Post processing: maximum distance between circle centers within one segment.')
    parser.add_argument('--num_points_lines', required=False, type=int, default=2, choices=range(2, 10),
                        help='Post processing: number of keypoints that represent a line segment')
    parser.add_argument('--num_points_circles', required=False, type=int, default=2, choices=range(2, 10),
                        help='Post processing: number of keypoints that represent a circle segment')
    args = parser.parse_args()

    lines_palette = [0, 0, 0]
    for line_class in SoccerPitch.lines_classes:
        lines_palette.extend(SoccerPitch.palette[line_class])

    model = CustomNetwork(args.checkpoint)

    dataset_dir = os.path.join(args.soccernet, args.split)
    if not os.path.exists(dataset_dir):
        print("Invalid dataset path!")
        exit(-1)

    radius = args.pp_radius
    maxdists = args.pp_maxdists

    frames = [f for f in os.listdir(dataset_dir) if ".jpg" in f]
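    # For every frame: run the segmentation model, optionally save the colored mask, cover each predicted
    # line class with disks, join the disks into polylines and export the keypoints of the longest one as JSON.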
    with tqdm(enumerate(frames), total=len(frames), ncols=160) as t:
        for i, frame in t:

            output_prediction_folder = os.path.join(
                str(args.prediction),
                f"np{args.num_points_lines}_nc{args.num_points_circles}_r{radius}_md{maxdists}",
                args.split)
            if not os.path.exists(output_prediction_folder):
                os.makedirs(output_prediction_folder)
            prediction = dict()
            count = 0

            frame_path = os.path.join(dataset_dir, frame)

            frame_index = frame.split(".")[0]

            image = Image.open(frame_path)

            semlines = model.forward(image)

            if args.masks:
                mask = Image.fromarray(semlines.astype(np.uint8)).convert('P')
                mask.putpalette(lines_palette)
                mask_file = os.path.join(output_prediction_folder, frame)
                mask.convert("RGB").save(mask_file)
            skeletons = generate_class_synthesis(semlines, radius)

            extremities = get_line_extremities(skeletons, maxdists, args.resolution_width, args.resolution_height,
                                               args.num_points_lines, args.num_points_circles)

            prediction = extremities
            count += 1

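            # The output JSON maps each detected line class to its list of normalized keypoints, e.g.
            # {"<line class>": [{"x": 0.12, "y": 0.34}, {"x": 0.56, "y": 0.78}], ...}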
            prediction_file = os.path.join(output_prediction_folder, f"extremities_{frame_index}.json")
            with open(prediction_file, "w") as f:
                json.dump(prediction, f, indent=4)