|
""" |
|
DataLoader used to train the segmentation network used for the prediction of extremities. |
|
""" |
|
|
|
import json |
|
import os |
|
import time |
|
from argparse import ArgumentParser |
|
|
|
import cv2 as cv |
|
import numpy as np |
|
from torch.utils.data import Dataset |
|
from tqdm import tqdm |
|
|
|
from SoccerNet.Evaluation.utils_calibration import SoccerPitch |
|
|
|
|
|
class SoccerNetDataset(Dataset): |
|
def __init__(self, |
|
datasetpath, |
|
split="test", |
|
width=640, |
|
height=360, |
|
mean="../resources/mean.npy", |
|
std="../resources/std.npy"): |
|
self.mean = np.load(mean) |
|
self.std = np.load(std) |
|
self.width = width |
|
self.height = height |
|
|
|
dataset_dir = os.path.join(datasetpath, split) |
|
if not os.path.exists(dataset_dir): |
|
print("Invalid dataset path !") |
|
exit(-1) |
|
|
|
frames = [f for f in os.listdir(dataset_dir) if ".jpg" in f] |
|
|
|
self.data = [] |
|
self.n_samples = 0 |
|
for frame in frames: |
|
|
|
frame_index = frame.split(".")[0] |
|
annotation_file = os.path.join(dataset_dir, f"{frame_index}.json") |
|
if not os.path.exists(annotation_file): |
|
continue |
|
with open(annotation_file, "r") as f: |
|
groundtruth_lines = json.load(f) |
|
img_path = os.path.join(dataset_dir, frame) |
|
if groundtruth_lines: |
|
self.data.append({ |
|
"image_path": img_path, |
|
"annotations": groundtruth_lines, |
|
}) |
|
|
|
def __len__(self): |
|
return len(self.data) |
|
|
|
def __getitem__(self, index): |
|
item = self.data[index] |
|
|
|
img = cv.imread(item["image_path"]) |
|
img = cv.resize(img, (self.width, self.height), interpolation=cv.INTER_LINEAR) |
|
|
|
mask = np.zeros(img.shape[:-1], dtype=np.uint8) |
|
img = np.asarray(img, np.float32) / 255. |
|
img -= self.mean |
|
img /= self.std |
|
img = img.transpose((2, 0, 1)) |
|
for class_number, class_ in enumerate(SoccerPitch.lines_classes): |
|
if class_ in item["annotations"].keys(): |
|
key = class_ |
|
line = item["annotations"][key] |
|
prev_point = line[0] |
|
for i in range(1, len(line)): |
|
next_point = line[i] |
|
cv.line(mask, |
|
(int(prev_point["x"] * mask.shape[1]), int(prev_point["y"] * mask.shape[0])), |
|
(int(next_point["x"] * mask.shape[1]), int(next_point["y"] * mask.shape[0])), |
|
class_number + 1, |
|
2) |
|
prev_point = next_point |
|
return img, mask |
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
parser = ArgumentParser(description='dataloader') |
|
|
|
parser.add_argument('--SoccerNet_path', default="./annotations/", type=str, |
|
help='Path to the SoccerNet-V3 dataset folder') |
|
parser.add_argument('--tiny', required=False, type=int, default=None, help='Select a subset of x games') |
|
parser.add_argument('--split', required=False, type=str, default="test", help='Select the split of data') |
|
parser.add_argument('--num_workers', required=False, type=int, default=4, |
|
help='number of workers for the dataloader') |
|
parser.add_argument('--resolution_width', required=False, type=int, default=1920, |
|
help='width resolution of the images') |
|
parser.add_argument('--resolution_height', required=False, type=int, default=1080, |
|
help='height resolution of the images') |
|
parser.add_argument('--preload_images', action='store_true', |
|
help="Preload the images when constructing the dataset") |
|
parser.add_argument('--zipped_images', action='store_true', help="Read images from zipped folder") |
|
|
|
args = parser.parse_args() |
|
|
|
start_time = time.time() |
|
soccernet = SoccerNetDataset(args.SoccerNet_path, split=args.split) |
|
with tqdm(enumerate(soccernet), total=len(soccernet), ncols=160) as t: |
|
for i, data in t: |
|
img = soccernet[i][0].astype(np.uint8).transpose((1, 2, 0)) |
|
print(img.shape) |
|
print(img.dtype) |
|
cv.imshow("Normalized image", img) |
|
cv.waitKey(0) |
|
cv.destroyAllWindows() |
|
print(data[1].shape) |
|
cv.imshow("Mask", soccernet[i][1].astype(np.uint8)) |
|
cv.waitKey(0) |
|
cv.destroyAllWindows() |
|
continue |
|
end_time = time.time() |
|
print(end_time - start_time) |
|
|