import argparse
import copy
import json
import os.path
import random
from collections import deque
from pathlib import Path
import cv2 as cv
import numpy as np
import torch
import torch.backends.cudnn
import torch.nn as nn
from PIL import Image
from torchvision.models.segmentation import deeplabv3_resnet50
from tqdm import tqdm
from SoccerNet.Evaluation.utils_calibration import SoccerPitch
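# Pipeline overview: segment the pitch lines with DeepLabV3, summarize each class mask
# with disks, link the disks into polylines, and export the extremities of the longest
# polyline per class as coordinates normalized by the image size.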
def generate_class_synthesis(semantic_mask, radius):
"""
This function selects for each class present in the semantic mask, a set of circles that cover most of the semantic
class blobs.
:param semantic_mask: a image containing the segmentation predictions
:param radius: circle radius
:return: a dictionary which associates with each class detected a list of points ( the circles centers)
"""
buckets = dict()
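    # Erode the mask with a 5x5 kernel to shrink the predicted blobs and drop spurious
    # thin detections before fitting disks on each class.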
kernel = np.ones((5, 5), np.uint8)
semantic_mask = cv.erode(semantic_mask, kernel, iterations=1)
for k, class_name in enumerate(SoccerPitch.lines_classes):
mask = semantic_mask == k + 1
if mask.sum() > 0:
disk_list = synthesize_mask(mask, radius)
if len(disk_list):
buckets[class_name] = disk_list
return buckets
def join_points(point_list, maxdist):
"""
Given a list of points that were extracted from the blobs belonging to a same semantic class, this function creates
polylines by linking close points together if their distance is below the maxdist threshold.
:param point_list: List of points of the same line class
:param maxdist: minimal distance between two polylines.
:return: a list of polylines
"""
polylines = []
if not len(point_list):
return polylines
head = point_list[0]
tail = point_list[0]
polyline = deque()
polyline.append(point_list[0])
remaining_points = copy.deepcopy(point_list[1:])
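    # Greedy polyline construction: repeatedly attach the remaining point closest to the
    # current head or tail if it lies within maxdist; otherwise close the polyline and
    # start a new one from the next remaining point.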
while len(remaining_points) > 0:
        min_dist_tail = np.inf
        min_dist_head = np.inf
best_head = -1
best_tail = -1
for j, point in enumerate(remaining_points):
dist_tail = np.sqrt(np.sum(np.square(point - tail)))
dist_head = np.sqrt(np.sum(np.square(point - head)))
if dist_tail < min_dist_tail:
min_dist_tail = dist_tail
best_tail = j
if dist_head < min_dist_head:
min_dist_head = dist_head
best_head = j
if min_dist_head <= min_dist_tail and min_dist_head < maxdist:
polyline.appendleft(remaining_points[best_head])
head = polyline[0]
remaining_points.pop(best_head)
elif min_dist_tail < min_dist_head and min_dist_tail < maxdist:
polyline.append(remaining_points[best_tail])
tail = polyline[-1]
remaining_points.pop(best_tail)
else:
polylines.append(list(polyline.copy()))
head = remaining_points[0]
tail = remaining_points[0]
polyline = deque()
polyline.append(head)
remaining_points.pop(0)
polylines.append(list(polyline))
return polylines
def get_line_extremities(buckets, maxdist, width, height):
"""
Given the dictionary {lines_class: points}, finds plausible extremities of each line, i.e the extremities
of the longest polyline that can be built on the class blobs, and normalize its coordinates
by the image size.
:param buckets: The dictionary associating line classes to the set of circle centers that covers best the class
prediction blobs in the segmentation mask
:param maxdist: the maximal distance between two circle centers belonging to the same blob (heuristic)
:param width: image width
:param height: image height
:return: a dictionary associating to each class its extremities
"""
extremities = dict()
for class_name, disks_list in buckets.items():
polyline_list = join_points(disks_list, maxdist)
max_len = 0
longest_polyline = []
for polyline in polyline_list:
if len(polyline) > max_len:
max_len = len(polyline)
longest_polyline = polyline
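        # Disk centers are stored as (row, col): index 1 is the x coordinate, index 0 the y coordinate.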
extremities[class_name] = [
{'x': longest_polyline[0][1] / width, 'y': longest_polyline[0][0] / height},
{'x': longest_polyline[-1][1] / width, 'y': longest_polyline[-1][0] / height}
]
return extremities
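# Example output entry for one class (illustrative values):
#   {"Side line top": [{"x": 0.12, "y": 0.43}, {"x": 0.87, "y": 0.40}]}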
def get_support_center(mask, start, disk_radius, min_support=0.1):
"""
Returns the barycenter of the True pixels under the area of the mask delimited by the circle of center start and
radius of disk_radius pixels.
:param mask: Boolean mask
:param start: A point located on a true pixel of the mask
:param disk_radius: the radius of the circles
:param min_support: proportion of the area under the circle area that should be True in order to get enough support
:return: A boolean indicating if there is enough support in the circle area, the barycenter of the True pixels under
the circle
"""
x = int(start[0])
y = int(start[1])
support_pixels = 1
result = [x, y]
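    # Clamp the search window to the mask bounds; pixels inside the window but outside
    # the disk itself are rejected by the distance test below.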
xstart = x - disk_radius
if xstart < 0:
xstart = 0
xend = x + disk_radius
    if xend >= mask.shape[0]:
        xend = mask.shape[0] - 1
ystart = y - disk_radius
if ystart < 0:
ystart = 0
yend = y + disk_radius
    if yend >= mask.shape[1]:
        yend = mask.shape[1] - 1
for i in range(xstart, xend + 1):
for j in range(ystart, yend + 1):
dist = np.sqrt(np.square(x - i) + np.square(y - j))
if dist < disk_radius and mask[i, j] > 0:
support_pixels += 1
result[0] += i
result[1] += j
support = True
if support_pixels < min_support * np.square(disk_radius) * np.pi:
support = False
result = np.array(result)
result = np.true_divide(result, support_pixels)
return support, result
def synthesize_mask(semantic_mask, disk_radius):
"""
Fits circles on the True pixels of the mask and returns those which have enough support : meaning that the
proportion of the area of the circle covering True pixels is higher that a certain threshold in order to avoid
fitting circles on alone pixels.
:param semantic_mask: boolean mask
:param disk_radius: radius of the circles
:return: a list of disk centers, that have enough support
"""
mask = semantic_mask.copy().astype(np.uint8)
points = np.transpose(np.nonzero(mask))
disks = []
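    # Mean-shift-like loop: pick a random foreground pixel, move the disk center to the
    # barycenter of the pixels it covers until convergence, then erase the covered area
    # from the mask so the next disk summarizes a new part of the blob.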
while len(points):
start = random.choice(points)
dist = 10.
success = True
while dist > 1.:
enough_support, center = get_support_center(mask, start, disk_radius)
if not enough_support:
bad_point = np.round(center).astype(np.int32)
                cv.circle(mask, (bad_point[1], bad_point[0]), disk_radius, 0, -1)
success = False
dist = np.sqrt(np.sum(np.square(center - start)))
start = center
if success:
disks.append(np.round(start).astype(np.int32))
cv.circle(mask, (disks[-1][1], disks[-1][0]), disk_radius, 0, -1)
points = np.transpose(np.nonzero(mask))
return disks
class SegmentationNetwork:
def __init__(self, model_file, mean_file, std_file, num_classes=29, width=640, height=360):
file_path = Path(model_file).resolve()
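        # The model is wrapped in DataParallel, presumably so that the checkpoint's
        # "module."-prefixed state dict keys match when loading below.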
model = nn.DataParallel(deeplabv3_resnet50(pretrained=False, num_classes=num_classes))
self.init_weight(model, nn.init.kaiming_normal_,
nn.BatchNorm2d, 1e-3, 0.1,
mode='fan_in')
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
checkpoint = torch.load(str(file_path), map_location=self.device)
model.load_state_dict(checkpoint["model"])
model.eval()
self.model = model.to(self.device)
file_path = Path(mean_file).resolve()
self.mean = np.load(str(file_path))
file_path = Path(std_file).resolve()
self.std = np.load(str(file_path))
self.width = width
self.height = height
def init_weight(self, feature, conv_init, norm_layer, bn_eps, bn_momentum,
**kwargs):
for name, m in feature.named_modules():
if isinstance(m, (nn.Conv2d, nn.Conv3d)):
conv_init(m.weight, **kwargs)
elif isinstance(m, norm_layer):
m.eps = bn_eps
m.momentum = bn_momentum
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def analyse_image(self, image):
"""
Process image and perform inference, returns mask of detected classes
:param image: BGR image
:return: predicted classes mask
"""
img = cv.resize(image, (self.width, self.height), interpolation=cv.INTER_LINEAR)
img = np.asarray(img, np.float32) / 255.
img = (img - self.mean) / self.std
img = img.transpose((2, 0, 1))
img = torch.from_numpy(img).to(self.device).unsqueeze(0)
        with torch.no_grad():
            cuda_result = self.model(img.float())
        output = cuda_result['out'][0].cpu().numpy()
output = output.transpose(1, 2, 0)
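        # Per-pixel argmax over the class logits; class 0 is the background.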
output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
return output
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Soccer pitch line segmentation and extremities extraction')
parser.add_argument('-s', '--soccernet', default="./annotations/", type=str,
help='Path to the SoccerNet-V3 dataset folder')
parser.add_argument('-p', '--prediction', default="./results_bis", required=False, type=str,
help="Path to the prediction folder")
parser.add_argument('--split', required=False, type=str, default="test", help='Select the split of data')
    parser.add_argument('--masks', required=False, action='store_true',
                        help='Save masks in prediction directory')
parser.add_argument('--resolution_width', required=False, type=int, default=455,
help='width resolution of the images')
parser.add_argument('--resolution_height', required=False, type=int, default=256,
help='height resolution of the images')
parser.add_argument('--checkpoint_dir', default="resources")
args = parser.parse_args()
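    # Build a PIL palette: index 0 (background) is black, followed by one RGB colour per line class.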
lines_palette = [0, 0, 0]
for line_class in SoccerPitch.lines_classes:
lines_palette.extend(SoccerPitch.palette[line_class])
calib_net = SegmentationNetwork(
os.path.join(args.checkpoint_dir, "soccer_pitch_segmentation.pth"),
os.path.join(args.checkpoint_dir, "mean.npy"),
os.path.join(args.checkpoint_dir, "std.npy")
)
dataset_dir = os.path.join(args.soccernet, args.split)
if not os.path.exists(dataset_dir):
print("Invalid dataset path !")
exit(-1)
    frames = [f for f in os.listdir(dataset_dir) if f.endswith(".jpg")]
    output_prediction_folder = os.path.join(args.prediction, args.split)
    if not os.path.exists(output_prediction_folder):
        os.makedirs(output_prediction_folder)
    with tqdm(enumerate(frames), total=len(frames), ncols=160) as t:
        for i, frame in t:
frame_path = os.path.join(dataset_dir, frame)
frame_index = frame.split(".")[0]
image = cv.imread(frame_path)
semlines = calib_net.analyse_image(image)
if args.masks:
mask = Image.fromarray(semlines.astype(np.uint8)).convert('P')
mask.putpalette(lines_palette)
mask_file = os.path.join(output_prediction_folder, frame)
mask.convert("RGB").save(mask_file)
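            # Summarize each class mask with disks of radius 6, link them into polylines
            # (40-pixel linking threshold) and keep the extremities of the longest one.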
skeletons = generate_class_synthesis(semlines, 6)
extremities = get_line_extremities(skeletons, 40, args.resolution_width, args.resolution_height)
            prediction = extremities
prediction_file = os.path.join(output_prediction_folder, f"extremities_{frame_index}.json")
with open(prediction_file, "w") as f:
json.dump(prediction, f, indent=4)
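# Example invocation (script name and paths are illustrative):
#   python detect_extremities.py -s ./annotations/ -p ./results_bis --split test --masks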