File size: 4,635 Bytes
bdb955e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""
DataLoader used to train the segmentation network used for the prediction of extremities.
"""

import json
import os
import time
from argparse import ArgumentParser

import cv2 as cv
import numpy as np
from torch.utils.data import Dataset
from tqdm import tqdm

from SoccerNet.Evaluation.utils_calibration import SoccerPitch


class SoccerNetDataset(Dataset):
    """Dataset of SoccerNet frames paired with pitch-line segmentation masks.

    Each sample is an ``(img, mask)`` pair: ``img`` is a normalized float32
    array in CHW layout, ``mask`` is a uint8 HW array where pixel value
    ``k + 1`` marks line class ``k`` of ``SoccerPitch.lines_classes``
    (0 is background).
    """

    def __init__(self,
                 datasetpath,
                 split="test",
                 width=640,
                 height=360,
                 mean="../resources/mean.npy",
                 std="../resources/std.npy"):
        """Index every annotated frame of one split.

        :param datasetpath: root folder containing one subfolder per split
        :param split: name of the split subfolder to load
        :param width: target image width after resizing
        :param height: target image height after resizing
        :param mean: path to the ``.npy`` file holding the per-channel mean
        :param std: path to the ``.npy`` file holding the per-channel std
        :raises FileNotFoundError: if the split folder does not exist
        """
        self.mean = np.load(mean)
        self.std = np.load(std)
        self.width = width
        self.height = height

        dataset_dir = os.path.join(datasetpath, split)
        if not os.path.exists(dataset_dir):
            # Raise instead of print + exit(-1): killing the interpreter from
            # a library constructor hides the error from callers.
            raise FileNotFoundError(f"Invalid dataset path: {dataset_dir}")

        # endswith avoids matching names that merely contain ".jpg"
        # (e.g. "frame.jpg.bak"), which the old substring test accepted.
        frames = [f for f in os.listdir(dataset_dir) if f.endswith(".jpg")]

        self.data = []
        # NOTE(review): n_samples is never incremented anywhere visible;
        # kept only so existing external readers of the attribute still work.
        self.n_samples = 0
        for frame in frames:
            # Annotations live next to the frame as "<frame_index>.json".
            # splitext is robust to dots inside the stem, unlike split(".").
            frame_index = os.path.splitext(frame)[0]
            annotation_file = os.path.join(dataset_dir, f"{frame_index}.json")
            if not os.path.exists(annotation_file):
                continue
            with open(annotation_file, "r") as f:
                groundtruth_lines = json.load(f)
            img_path = os.path.join(dataset_dir, frame)
            # Skip frames whose annotation file exists but holds no lines.
            if groundtruth_lines:
                self.data.append({
                    "image_path": img_path,
                    "annotations": groundtruth_lines,
                })

    def __len__(self):
        """Return the number of annotated frames in the split."""
        return len(self.data)

    def __getitem__(self, index):
        """Load and normalize one sample.

        :param index: position of the sample in the dataset
        :return: ``(img, mask)`` — float32 CHW normalized image and
            uint8 HW class-index mask
        :raises ValueError: if the image file cannot be decoded
        """
        item = self.data[index]

        img = cv.imread(item["image_path"])
        if img is None:
            # cv.imread returns None on failure; fail with a clear message
            # instead of an opaque TypeError inside cv.resize.
            raise ValueError(f"Could not read image: {item['image_path']}")
        img = cv.resize(img, (self.width, self.height), interpolation=cv.INTER_LINEAR)

        mask = np.zeros(img.shape[:-1], dtype=np.uint8)
        img = np.asarray(img, np.float32) / 255.
        img -= self.mean
        img /= self.std
        img = img.transpose((2, 0, 1))  # HWC -> CHW for the network
        for class_number, class_ in enumerate(SoccerPitch.lines_classes):
            if class_ in item["annotations"]:
                line = item["annotations"][class_]
                prev_point = line[0]
                # Rasterize the polyline segment by segment. Annotation
                # coordinates are normalized to [0, 1], so scale by the mask
                # size; class ids are shifted by one to keep 0 as background.
                for next_point in line[1:]:
                    cv.line(mask,
                            (int(prev_point["x"] * mask.shape[1]), int(prev_point["y"] * mask.shape[0])),
                            (int(next_point["x"] * mask.shape[1]), int(next_point["y"] * mask.shape[0])),
                            class_number + 1,
                            2)
                    prev_point = next_point
        return img, mask


if __name__ == "__main__":

    # CLI entry point: iterate a split and display each (image, mask) pair.
    parser = ArgumentParser(description='dataloader')

    parser.add_argument('--SoccerNet_path', default="./annotations/", type=str,
                        help='Path to the SoccerNet-V3 dataset folder')
    # NOTE(review): the options below are declared but not used by this
    # script; kept so existing invocations passing them keep working.
    parser.add_argument('--tiny', required=False, type=int, default=None, help='Select a subset of x games')
    parser.add_argument('--split', required=False, type=str, default="test", help='Select the split of data')
    parser.add_argument('--num_workers', required=False, type=int, default=4,
                        help='number of workers for the dataloader')
    parser.add_argument('--resolution_width', required=False, type=int, default=1920,
                        help='width resolution of the images')
    parser.add_argument('--resolution_height', required=False, type=int, default=1080,
                        help='height resolution of the images')
    parser.add_argument('--preload_images', action='store_true',
                        help="Preload the images when constructing the dataset")
    parser.add_argument('--zipped_images', action='store_true', help="Read images from zipped folder")

    args = parser.parse_args()

    start_time = time.time()
    soccernet = SoccerNetDataset(args.SoccerNet_path, split=args.split)
    with tqdm(enumerate(soccernet), total=len(soccernet), ncols=160) as t:
        for i, (img, mask) in t:
            # Reuse the sample already yielded by the iterator: the original
            # indexed soccernet[i] twice more, re-reading and re-normalizing
            # the same image three times per iteration.
            img = img.astype(np.uint8).transpose((1, 2, 0))  # CHW -> HWC for display
            print(img.shape)
            print(img.dtype)
            cv.imshow("Normalized image", img)
            cv.waitKey(0)
            cv.destroyAllWindows()
            print(mask.shape)
            cv.imshow("Mask", mask.astype(np.uint8))
            cv.waitKey(0)
            cv.destroyAllWindows()
    end_time = time.time()
    print(end_time - start_time)