# Provenance: Hugging Face file viewer residue (uploader RamziBm, commit
# "init" bdb955e, ~13 kB) — kept as a comment so the file stays valid Python.
import argparse
import copy
import itertools
import json
import os.path
import random
from collections import deque
from pathlib import Path
from pytorch_lightning import seed_everything
seed_everything(seed=10, workers=True)
import cv2 as cv
import numpy as np
import torch
import torch.backends.cudnn
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from torchvision.models.segmentation import deeplabv3_resnet101
from tqdm import tqdm
from SoccerNet.Evaluation.utils_calibration import SoccerPitch
def generate_class_synthesis(semantic_mask, radius):
    """For every line class present in ``semantic_mask``, compute a set of
    disk centers that cover most of that class's blobs.

    :param semantic_mask: image of per-pixel segmentation class ids
    :param radius: disk radius handed to :func:`synthesize_mask`
    :return: dict mapping each detected class name to its list of disk centers
    """
    # Erode first so thin spurious pixel runs do not produce disks.
    eroded = cv.erode(semantic_mask, np.ones((5, 5), np.uint8), iterations=1)
    buckets = dict()
    # Class ids start at 1; 0 is the background channel.
    for class_id, class_name in enumerate(SoccerPitch.lines_classes, start=1):
        class_mask = eroded == class_id
        if not class_mask.any():
            continue
        disk_list = synthesize_mask(class_mask, radius)
        if len(disk_list):
            buckets[class_name] = disk_list
    return buckets
def join_points(point_list, maxdist):
    """Chain close points of one semantic class into polylines.

    Starting from the first point, the remaining point closest to either end
    of the current polyline is greedily attached to that end, as long as it
    lies within ``maxdist``; otherwise the polyline is closed and a fresh one
    is started from the next pending point.

    :param point_list: list of (row, col) points of the same line class
    :param maxdist: maximum gap allowed between consecutive polyline points
    :return: a list of polylines (each a list of points)
    """
    polylines = []
    if not len(point_list):
        return polylines
    pending = copy.deepcopy(point_list[1:])
    polyline = deque([point_list[0]])
    head = point_list[0]
    tail = point_list[0]
    while len(pending) > 0:
        # Find the pending point nearest to each extremity of the polyline.
        nearest_head, nearest_tail = -1, -1
        d_head_min = d_tail_min = 1000
        for idx, candidate in enumerate(pending):
            d_tail = np.sqrt(np.sum(np.square(candidate - tail)))
            d_head = np.sqrt(np.sum(np.square(candidate - head)))
            if d_tail < d_tail_min:
                d_tail_min = d_tail
                nearest_tail = idx
            if d_head < d_head_min:
                d_head_min = d_head
                nearest_head = idx
        if d_head_min <= d_tail_min and d_head_min < maxdist:
            # Grow at the head end.
            polyline.appendleft(pending.pop(nearest_head))
            head = polyline[0]
        elif d_tail_min < d_head_min and d_tail_min < maxdist:
            # Grow at the tail end.
            polyline.append(pending.pop(nearest_tail))
            tail = polyline[-1]
        else:
            # Nearest remaining point is too far: close this polyline and
            # seed a new one from the next pending point.
            polylines.append(list(polyline))
            head = tail = pending.pop(0)
            polyline = deque([head])
    polylines.append(list(polyline))
    return polylines
def get_line_extremities(buckets, maxdist, width, height, num_points_lines, num_points_circles):
    """Find plausible extremities for each line class.

    For every class the longest polyline that can be built from its disk
    centers is selected, and its two end points — plus optional evenly
    spaced interior points — are returned normalized by the image size.

    :param buckets: dict {line class: disk centers covering the class blobs}
    :param maxdist: maximal distance between two centers of the same blob
    :param width: image width used for x normalization
    :param height: image height used for y normalization
    :param num_points_lines: number of keypoints per straight-line class
    :param num_points_circles: number of keypoints per circle class
    :return: dict mapping each class to its list of {'x', 'y'} keypoints
    """
    extremities = dict()
    for class_name, disks_list in buckets.items():
        candidates = join_points(disks_list, maxdist)
        # First polyline of maximal length; default mirrors the original
        # empty-list fallback.
        longest = max(candidates, key=len, default=[])
        first, last = longest[0], longest[-1]
        extremities[class_name] = [
            {'x': first[1] / width, 'y': first[0] / height},
            {'x': last[1] / width, 'y': last[0] / height},
        ]
        num_points = num_points_circles if "Circle" in class_name else num_points_lines
        if num_points > 2:
            # Insert evenly spaced interior points, keeping the two
            # extremities at the outer positions of the list.
            step = int(len(longest) / num_points)
            for i in range(1, num_points - 1):
                point = longest[i * step]
                extremities[class_name].insert(
                    len(extremities[class_name]) - 1,
                    {'x': point[1] / width, 'y': point[0] / height},
                )
    return extremities
def get_support_center(mask, start, disk_radius, min_support=0.1):
    """Return the barycenter of the True pixels of ``mask`` that lie inside
    the disk of center ``start`` and radius ``disk_radius``.

    :param mask: 2D boolean (or 0/1) mask
    :param start: a (row, col) point located on a true pixel of the mask
    :param disk_radius: radius of the probing disk, in pixels
    :param min_support: minimum fraction of the disk area that must be True
        for the disk to count as supported
    :return: (enough_support, barycenter) — a bool and a float array (row, col)
    """
    x = int(start[0])
    y = int(start[1])
    # The starting pixel itself is counted as support.
    support_pixels = 1
    result = [x, y]
    # Clamp the disk's bounding box to valid indices. Bounds are inclusive
    # (the loops run to xend/yend), so the last valid index is shape - 1.
    # Bug fix: the original only clamped when xend > shape[0], so
    # xend == shape[0] slipped through and indexed out of range.
    xstart = max(x - disk_radius, 0)
    xend = min(x + disk_radius, mask.shape[0] - 1)
    ystart = max(y - disk_radius, 0)
    yend = min(y + disk_radius, mask.shape[1] - 1)
    for i in range(xstart, xend + 1):
        for j in range(ystart, yend + 1):
            dist = np.sqrt(np.square(x - i) + np.square(y - j))
            if dist < disk_radius and mask[i, j] > 0:
                support_pixels += 1
                result[0] += i
                result[1] += j
    # Supported iff the covered True pixels reach min_support of the disk area.
    support = bool(support_pixels >= min_support * np.square(disk_radius) * np.pi)
    result = np.true_divide(np.array(result), support_pixels)
    return support, result
def synthesize_mask(semantic_mask, disk_radius):
    """Cover the True pixels of a mask with disks of a fixed radius.

    Repeatedly picks a random remaining True pixel, lets the disk center
    drift toward the local barycenter of True pixels until it moves by at
    most one pixel, keeps the disk when it has enough support, and erases
    the covered pixels before continuing.

    :param semantic_mask: boolean mask
    :param disk_radius: radius of the fitted disks
    :return: list of (row, col) disk centers that have enough support
    """
    work = semantic_mask.copy().astype(np.uint8)
    disks = []
    remaining = np.transpose(np.nonzero(work))
    while len(remaining):
        center = random.choice(remaining)
        shift = 10.
        keep = True
        # Drift until the center stabilises (moves <= 1 pixel).
        while shift > 1.:
            enough_support, new_center = get_support_center(work, center, disk_radius)
            if not enough_support:
                # Too few pixels under the disk: wipe the area so those
                # isolated pixels are not picked again.
                bad = np.round(new_center).astype(np.int32)
                cv.circle(work, (bad[1], bad[0]), disk_radius, (0), -1)
                keep = False
            shift = np.sqrt(np.sum(np.square(new_center - center)))
            center = new_center
        if keep:
            kept = np.round(center).astype(np.int32)
            disks.append(kept)
            cv.circle(work, (kept[1], kept[0]), disk_radius, 0, -1)
        remaining = np.transpose(np.nonzero(work))
    return disks
class CustomNetwork:
    """DeepLabV3-ResNet101 wrapper for soccer-pitch line segmentation."""

    def __init__(self, checkpoint):
        """Load the segmentation model weights from ``checkpoint``.

        :param checkpoint: path to a torch checkpoint containing a "model"
            state-dict entry
        """
        # Bug fix: original `"Loading model" + checkpoint` printed without a
        # separator and raised TypeError when checkpoint was None.
        print(f"Loading model {checkpoint}")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # One output channel per line class, plus one for the background.
        self.model = deeplabv3_resnet101(num_classes=len(SoccerPitch.lines_classes) + 1, aux_loss=True)
        # map_location lets a GPU-saved checkpoint load on CPU-only hosts.
        state = torch.load(checkpoint, map_location=self.device)
        self.model.load_state_dict(state["model"], strict=False)
        self.model.to(self.device)
        self.model.eval()
        print("using", self.device)

    def forward(self, img):
        """Segment ``img`` (a PIL image) and return a HxW uint8 class map."""
        trf = T.Compose(
            [
                T.Resize(256),
                T.ToTensor(),
                # ImageNet normalization, matching the pretrained backbone.
                T.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]
        )
        batch = trf(img).unsqueeze(0).to(self.device)
        # Inference only: no_grad skips autograd bookkeeping entirely.
        with torch.no_grad():
            logits = self.model(batch)["out"]
        result = logits.squeeze(0).argmax(0)
        return result.cpu().numpy().astype(np.uint8)
def _str2bool(value):
    """Parse a CLI boolean flag value.

    Bug fix: argparse's ``type=bool`` treats ANY non-empty string as True,
    so ``--masks False`` used to enable mask saving. ``--masks True`` keeps
    working as before.
    """
    return str(value).lower() in ("1", "true", "yes", "y")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Test')
    parser.add_argument('-s', '--soccernet', default="/nfs/data/soccernet/calibration/", type=str,
                        help='Path to the SoccerNet-V3 dataset folder')
    parser.add_argument('-p', '--prediction', default="sn-calib-test_endpoints", required=False, type=str,
                        help="Path to the prediction folder")
    parser.add_argument('--split', required=False, type=str, default="challenge", help='Select the split of data')
    parser.add_argument('--masks', required=False, type=_str2bool, default=False,
                        help='Save masks in prediction directory')
    parser.add_argument('--resolution_width', required=False, type=int, default=455,
                        help='width resolution of the images')
    parser.add_argument('--resolution_height', required=False, type=int, default=256,
                        help='height resolution of the images')
    parser.add_argument('--checkpoint', required=False, type=str, help="Path to the custom model checkpoint.")
    parser.add_argument('--pp_radius', required=False, type=int, default=4,
                        help='Post processing: Radius of circles that cover each segment.')
    parser.add_argument('--pp_maxdists', required=False, type=int, default=30,
                        help='Post processing: Maximum distance of circles that are allowed within one segment.')
    parser.add_argument('--num_points_lines', required=False, type=int, default=2, choices=range(2, 10),
                        help='Post processing: Number of keypoints that represent a line segment')
    parser.add_argument('--num_points_circles', required=False, type=int, default=2, choices=range(2, 10),
                        help='Post processing: Number of keypoints that represent a circle segment')
    args = parser.parse_args()

    # Palette: background (0, 0, 0) followed by one RGB triple per line class.
    lines_palette = [0, 0, 0]
    for line_class in SoccerPitch.lines_classes:
        lines_palette.extend(SoccerPitch.palette[line_class])

    model = CustomNetwork(args.checkpoint)

    dataset_dir = os.path.join(args.soccernet, args.split)
    if not os.path.exists(dataset_dir):
        print("Invalid dataset path !")
        exit(-1)

    radius = args.pp_radius
    maxdists = args.pp_maxdists

    frames = [f for f in os.listdir(dataset_dir) if ".jpg" in f]

    # The output folder is identical for every frame: create it once up
    # front instead of re-checking inside the loop (original behavior was
    # per-iteration makedirs).
    output_prediction_folder = os.path.join(
        str(args.prediction),
        f"np{args.num_points_lines}_nc{args.num_points_circles}_r{radius}_md{maxdists}",
        args.split,
    )
    os.makedirs(output_prediction_folder, exist_ok=True)

    with tqdm(enumerate(frames), total=len(frames), ncols=160) as t:
        for i, frame in t:
            frame_path = os.path.join(dataset_dir, frame)
            frame_index = frame.split(".")[0]
            image = Image.open(frame_path)
            # Per-pixel class-id map predicted by the network.
            semlines = model.forward(image)
            if args.masks:
                # Optionally dump the raw segmentation as a paletted image.
                mask = Image.fromarray(semlines.astype(np.uint8)).convert('P')
                mask.putpalette(lines_palette)
                mask_file = os.path.join(output_prediction_folder, frame)
                mask.convert("RGB").save(mask_file)
            # Cover each class's blobs with disks, then extract keypoints.
            skeletons = generate_class_synthesis(semlines, radius)
            extremities = get_line_extremities(skeletons, maxdists, args.resolution_width,
                                               args.resolution_height, args.num_points_lines,
                                               args.num_points_circles)
            prediction_file = os.path.join(output_prediction_folder, f"extremities_{frame_index}.json")
            with open(prediction_file, "w") as f:
                json.dump(extremities, f, indent=4)