import argparse
import copy
import itertools
import json
import os.path
import random
from collections import deque
from pathlib import Path
from pytorch_lightning import seed_everything
seed_everything(seed=10, workers=True)
import cv2 as cv
import numpy as np
import torch
import torch.backends.cudnn
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from torchvision.models.segmentation import deeplabv3_resnet101
from tqdm import tqdm
from SoccerNet.Evaluation.utils_calibration import SoccerPitch
def generate_class_synthesis(semantic_mask, radius):
    """
    Selects, for each class present in the semantic mask, a set of circles that cover most of the blobs
    belonging to that class.
    :param semantic_mask: an image containing the segmentation predictions
    :param radius: circle radius
    :return: a dictionary associating each detected class with a list of points (the circle centers)
    """
    buckets = dict()
    kernel = np.ones((5, 5), np.uint8)
    semantic_mask = cv.erode(semantic_mask, kernel, iterations=1)
    for k, class_name in enumerate(SoccerPitch.lines_classes):
        mask = semantic_mask == k + 1
        if mask.sum() > 0:
            disk_list = synthesize_mask(mask, radius)
            if len(disk_list):
                buckets[class_name] = disk_list

    return buckets
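
# Minimal usage sketch (hypothetical file name; assumes a uint8 mask whose non-zero
# values index SoccerPitch.lines_classes, with 0 reserved for the background):
#
#   semantic_mask = cv.imread("prediction_mask.png", cv.IMREAD_GRAYSCALE)
#   buckets = generate_class_synthesis(semantic_mask, radius=4)
#   # buckets maps each detected class name to a list of (row, col) circle centers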
def join_points(point_list, maxdist):
    """
    Given a list of points extracted from the blobs of the same semantic class, this function builds
    polylines by linking points together as long as their distance is below the maxdist threshold.
    :param point_list: list of points of the same line class
    :param maxdist: maximum distance between a point and the head or tail of a polyline for the point to be linked to it
    :return: a list of polylines
    """
    polylines = []

    if not len(point_list):
        return polylines
    head = point_list[0]
    tail = point_list[0]
    polyline = deque()
    polyline.append(point_list[0])
    remaining_points = copy.deepcopy(point_list[1:])

    while len(remaining_points) > 0:
        min_dist_tail = 1000
        min_dist_head = 1000
        best_head = -1
        best_tail = -1
        for j, point in enumerate(remaining_points):
            dist_tail = np.sqrt(np.sum(np.square(point - tail)))
            dist_head = np.sqrt(np.sum(np.square(point - head)))
            if dist_tail < min_dist_tail:
                min_dist_tail = dist_tail
                best_tail = j
            if dist_head < min_dist_head:
                min_dist_head = dist_head
                best_head = j

        if min_dist_head <= min_dist_tail and min_dist_head < maxdist:
            # extend the polyline at its head
            polyline.appendleft(remaining_points[best_head])
            head = polyline[0]
            remaining_points.pop(best_head)
        elif min_dist_tail < min_dist_head and min_dist_tail < maxdist:
            # extend the polyline at its tail
            polyline.append(remaining_points[best_tail])
            tail = polyline[-1]
            remaining_points.pop(best_tail)
        else:
            # no remaining point is close enough: close the current polyline and start a new one
            polylines.append(list(polyline.copy()))
            head = remaining_points[0]
            tail = remaining_points[0]
            polyline = deque()
            polyline.append(head)
            remaining_points.pop(0)
    polylines.append(list(polyline))
    return polylines
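
# Illustrative sketch (hypothetical coordinates): with maxdist=30, the three nearby
# points below are chained into a single polyline (possibly in reversed order) while
# the far-away point ends up in a polyline of its own.
#
#   points = [np.array(p) for p in [(0, 0), (0, 10), (0, 20), (200, 200)]]
#   polylines = join_points(points, maxdist=30)
#   # -> two polylines: one with the three nearby points, one with [(200, 200)]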
def get_line_extremities(buckets, maxdist, width, height, num_points_lines, num_points_circles):
    """
    Given the dictionary {line_class: points}, finds plausible extremities of each line, i.e. the extremities
    of the longest polyline that can be built on the class blobs, and normalizes their coordinates
    by the image size.
    :param buckets: the dictionary associating line classes to the set of circle centers that best covers the class
    prediction blobs in the segmentation mask
    :param maxdist: the maximal distance between two circle centers belonging to the same blob (heuristic)
    :param width: image width
    :param height: image height
    :param num_points_lines: number of keypoints extracted for a line segment class
    :param num_points_circles: number of keypoints extracted for a circle class
    :return: a dictionary associating each class with its extremities
    """
    extremities = dict()
    for class_name, disks_list in buckets.items():
        polyline_list = join_points(disks_list, maxdist)
        max_len = 0
        longest_polyline = []
        for polyline in polyline_list:
            if len(polyline) > max_len:
                max_len = len(polyline)
                longest_polyline = polyline
        extremities[class_name] = [
            {'x': longest_polyline[0][1] / width, 'y': longest_polyline[0][0] / height},
            {'x': longest_polyline[-1][1] / width, 'y': longest_polyline[-1][0] / height},
        ]
        num_points = num_points_lines
        if "Circle" in class_name:
            num_points = num_points_circles
        if num_points > 2:
            # sample equally spaced points along the longest polyline,
            # skipping the first and last ones as they already exist
            for i in range(1, num_points - 1):
                extremities[class_name].insert(
                    len(extremities[class_name]) - 1,
                    {
                        'x': longest_polyline[i * int(len(longest_polyline) / num_points)][1] / width,
                        'y': longest_polyline[i * int(len(longest_polyline) / num_points)][0] / height,
                    }
                )

    return extremities
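
# The returned dictionary is what gets dumped as extremities_<frame>.json further below;
# illustrative content (values made up, class names come from SoccerPitch.lines_classes):
#
#   {
#       "Circle central": [
#           {"x": 0.41, "y": 0.52},
#           {"x": 0.63, "y": 0.48}
#       ]
#   }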
def get_support_center(mask, start, disk_radius, min_support=0.1):
    """
    Returns the barycenter of the True pixels of the mask lying inside the circle of center start and
    radius disk_radius pixels.
    :param mask: boolean mask
    :param start: a point located on a True pixel of the mask
    :param disk_radius: the radius of the circles
    :param min_support: proportion of the circle area that should cover True pixels in order to get enough support
    :return: a boolean indicating whether there is enough support in the circle area, and the barycenter of the True
    pixels under the circle
    """
    x = int(start[0])
    y = int(start[1])
    support_pixels = 1
    result = [x, y]
    xstart = x - disk_radius
    if xstart < 0:
        xstart = 0
    xend = x + disk_radius
    if xend >= mask.shape[0]:
        # clamp to the last valid row index, as the loop below is inclusive of xend
        xend = mask.shape[0] - 1
    ystart = y - disk_radius
    if ystart < 0:
        ystart = 0
    yend = y + disk_radius
    if yend >= mask.shape[1]:
        yend = mask.shape[1] - 1

    for i in range(xstart, xend + 1):
        for j in range(ystart, yend + 1):
            dist = np.sqrt(np.square(x - i) + np.square(y - j))
            if dist < disk_radius and mask[i, j] > 0:
                support_pixels += 1
                result[0] += i
                result[1] += j
    support = True
    if support_pixels < min_support * np.square(disk_radius) * np.pi:
        support = False

    result = np.array(result)
    result = np.true_divide(result, support_pixels)

    return support, result
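
# Support threshold sketch: with the default pp_radius of 4 and min_support=0.1,
# a circle needs roughly 0.1 * pi * 4**2 ≈ 5 supporting pixels to be accepted.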
def synthesize_mask(semantic_mask, disk_radius):
    """
    Fits circles on the True pixels of the mask and returns those which have enough support, meaning that the
    proportion of the circle area covering True pixels is higher than a certain threshold, in order to avoid
    fitting circles on isolated pixels.
    :param semantic_mask: boolean mask
    :param disk_radius: radius of the circles
    :return: a list of disk centers that have enough support
    """
    mask = semantic_mask.copy().astype(np.uint8)
    points = np.transpose(np.nonzero(mask))
    disks = []
    while len(points):
        start = random.choice(points)
        dist = 10.
        success = True
        # iteratively move the circle center towards the barycenter of its support until it stabilizes
        while dist > 1.:
            enough_support, center = get_support_center(mask, start, disk_radius)
            if not enough_support:
                bad_point = np.round(center).astype(np.int32)
                cv.circle(mask, (bad_point[1], bad_point[0]), disk_radius, (0), -1)
                success = False
            dist = np.sqrt(np.sum(np.square(center - start)))
            start = center
        if success:
            disks.append(np.round(start).astype(np.int32))
            cv.circle(mask, (disks[-1][1], disks[-1][0]), disk_radius, 0, -1)
        points = np.transpose(np.nonzero(mask))

    return disks
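
# synthesize_mask is normally reached through generate_class_synthesis; a direct call
# on a single boolean class mask would look like (hypothetical variable name):
#
#   centers = synthesize_mask(binary_line_mask, disk_radius=4)
#   # each entry is an np.int32 array holding the (row, col) coordinates of a kept circle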
class CustomNetwork:

    def __init__(self, checkpoint):
        print(f"Loading model {checkpoint}")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = deeplabv3_resnet101(num_classes=len(SoccerPitch.lines_classes) + 1, aux_loss=True)
        self.model.load_state_dict(torch.load(checkpoint)["model"], strict=False)
        self.model.to(self.device)
        self.model.eval()
        print("using", self.device)

    def forward(self, img):
        trf = T.Compose(
            [
                T.Resize(256),
                T.ToTensor(),
                T.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ]
        )
        img = trf(img).unsqueeze(0).to(self.device)
        result = self.model(img)["out"].detach().squeeze(0).argmax(0)
        result = result.cpu().numpy().astype(np.uint8)
        return result
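
# Minimal inference sketch (hypothetical paths; assumes a checkpoint saved as a dict
# exposing a "model" state_dict, which is what the loading code above expects):
#
#   model = CustomNetwork("checkpoints/segmentation.pth")
#   image = Image.open("frame_0001.jpg")
#   semlines = model.forward(image)  # uint8 array of class indices, 0 = background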
if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Test')

    parser.add_argument('-s', '--soccernet', default="/nfs/data/soccernet/calibration/", type=str,
                        help='Path to the SoccerNet-V3 dataset folder')
    parser.add_argument('-p', '--prediction', default="sn-calib-test_endpoints", required=False, type=str,
                        help="Path to the prediction folder")
    parser.add_argument('--split', required=False, type=str, default="challenge", help='Select the split of data')
    parser.add_argument('--masks', required=False, action='store_true', help='Save masks in prediction directory')
    parser.add_argument('--resolution_width', required=False, type=int, default=455,
                        help='width resolution of the images')
    parser.add_argument('--resolution_height', required=False, type=int, default=256,
                        help='height resolution of the images')
    parser.add_argument('--checkpoint', required=False, type=str, help="Path to the custom model checkpoint.")
    parser.add_argument('--pp_radius', required=False, type=int, default=4,
                        help='Post processing: radius of the circles that cover each segment.')
    parser.add_argument('--pp_maxdists', required=False, type=int, default=30,
                        help='Post processing: maximum distance between circle centers allowed within one segment.')
    parser.add_argument('--num_points_lines', required=False, type=int, default=2, choices=range(2, 10),
                        help='Post processing: number of keypoints that represent a line segment')
    parser.add_argument('--num_points_circles', required=False, type=int, default=2, choices=range(2, 10),
                        help='Post processing: number of keypoints that represent a circle segment')
    args = parser.parse_args()
    # build a palette that maps each line class index to its SoccerPitch colour (index 0 stays black)
    lines_palette = [0, 0, 0]
    for line_class in SoccerPitch.lines_classes:
        lines_palette.extend(SoccerPitch.palette[line_class])

    model = CustomNetwork(args.checkpoint)

    dataset_dir = os.path.join(args.soccernet, args.split)
    if not os.path.exists(dataset_dir):
        print("Invalid dataset path!")
        exit(-1)

    radius = args.pp_radius
    maxdists = args.pp_maxdists

    frames = [f for f in os.listdir(dataset_dir) if ".jpg" in f]
    with tqdm(enumerate(frames), total=len(frames), ncols=160) as t:
        for i, frame in t:
            output_prediction_folder = os.path.join(
                str(args.prediction),
                f"np{args.num_points_lines}_nc{args.num_points_circles}_r{radius}_md{maxdists}",
                args.split)
            if not os.path.exists(output_prediction_folder):
                os.makedirs(output_prediction_folder)

            frame_path = os.path.join(dataset_dir, frame)
            frame_index = frame.split(".")[0]

            image = Image.open(frame_path)
            semlines = model.forward(image)

            if args.masks:
                # save the predicted segmentation as a paletted image next to the extremities files
                mask = Image.fromarray(semlines.astype(np.uint8)).convert('P')
                mask.putpalette(lines_palette)
                mask_file = os.path.join(output_prediction_folder, frame)
                mask.convert("RGB").save(mask_file)

            skeletons = generate_class_synthesis(semlines, radius)
            extremities = get_line_extremities(skeletons, maxdists, args.resolution_width, args.resolution_height,
                                               args.num_points_lines, args.num_points_circles)

            prediction_file = os.path.join(output_prediction_folder, f"extremities_{frame_index}.json")
            with open(prediction_file, "w") as f:
                json.dump(extremities, f, indent=4)
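
# Example invocation (hypothetical paths; substitute the actual script file name):
#   python this_script.py --soccernet /path/to/calibration --split test \
#       --checkpoint checkpoints/segmentation.pth --masks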