import json
import os
import pickle
from argparse import ArgumentParser
from collections import defaultdict
from pathlib import Path

import cv2
import h5py
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset
from tqdm.auto import tqdm


lines_classes = [
    "Big rect. left bottom",
    "Big rect. left main",
    "Big rect. left top",
    "Big rect. right bottom",
    "Big rect. right main",
    "Big rect. right top",
    "Circle central",
    "Circle left",
    "Circle right",
    "Goal left crossbar",
    "Goal left post left",
    "Goal left post right",
    "Goal right crossbar",
    "Goal right post left",
    "Goal right post right",
    "Goal unknown",
    "Line unknown",
    "Middle line",
    "Side line bottom",
    "Side line left",
    "Side line right",
    "Side line top",
    "Small rect. left bottom",
    "Small rect. left main",
    "Small rect. left top",
    "Small rect. right bottom",
    "Small rect. right main",
    "Small rect. right top",
]


palette = {
    "Big rect. left bottom": (127, 0, 0),
    "Big rect. left main": (102, 102, 102),
    "Big rect. left top": (0, 0, 127),
    "Big rect. right bottom": (86, 32, 39),
    "Big rect. right main": (48, 77, 0),
    "Big rect. right top": (14, 97, 100),
    "Circle central": (0, 0, 255),
    "Circle left": (255, 127, 0),
    "Circle right": (0, 255, 255),
    "Goal unknown": (0, 0, 0),
    "Line unknown": (0, 0, 0),
    "Middle line": (255, 255, 0),
    "Side line bottom": (255, 0, 255),
    "Side line left": (0, 255, 150),
    "Side line right": (0, 230, 0),
    "Side line top": (230, 0, 0),
    "Small rect. left bottom": (0, 150, 255),
    "Small rect. left main": (254, 173, 225),
    "Small rect. left top": (87, 72, 39),
    "Small rect. right bottom": (122, 0, 255),
    "Small rect. right main": (255, 255, 255),
    "Small rect. right top": (153, 23, 153),
}
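

# Sketch of how `palette` could be used (this script itself only stores it on
# the dataset): colorize a HxW map of class indices, where 0 is background and
# i + 1 corresponds to classes[i], into an RGB image for visual inspection.
def colorize_mask(mask, classes=lines_classes, palette=palette):
    rgb = np.zeros((*mask.shape, 3), dtype=np.uint8)
    for i, cls in enumerate(classes):
        # Fall back to white for classes without a palette entry.
        rgb[mask == i + 1] = palette.get(cls, (255, 255, 255))
    return rgb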


def create_target_from_annotation(width, height, annotation, classes, linewidth=4):
    """Draw one-hot encoded segments according to the annotation.

    Creates a target that matches the image size ((C+1)xHxW); channel 0 is
    the implicit background.
    """
    # Convert the normalized point coordinates to absolute pixel coordinates.
    annotation_abs = defaultdict(list)
    for k in annotation.keys():
        if k not in classes:
            continue
        for annotation_point in annotation[k]:
            tup = annotation_point.copy()
            tup["x"] = int(tup["x"] * width)
            tup["y"] = int(tup["y"] * height)
            annotation_abs[k].append(tup)

    # Rasterize each class into its own channel by drawing line segments
    # between consecutive extremity points.
    classes_segments = np.zeros(shape=(len(classes) + 1, height, width))
    for cls, points in annotation_abs.items():
        class_segments = np.zeros(shape=(height, width, 3))
        for start, end in zip(points, points[1:]):
            startxy = (start["x"], start["y"])
            endxy = (end["x"], end["y"])
            class_segments = cv2.line(
                class_segments, startxy, endxy, (1, 1, 1), linewidth
            )
        # cv2.line wrote the value 1 into all three channels; one is enough.
        classes_segments[classes.index(cls) + 1] = class_segments[:, :, 1]

    return torch.Tensor(classes_segments)
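
# Example sketch: a toy annotation with a single vertical line through the
# pitch centre (extremity coordinates are normalized to [0, 1]):
#
#   toy = {"Middle line": [{"x": 0.5, "y": 0.0}, {"x": 0.5, "y": 1.0}]}
#   target = create_target_from_annotation(1280, 720, toy, lines_classes)
#   # target.shape == (len(lines_classes) + 1, 720, 1280)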


class ExtremitiesDataset(Dataset):
    def __init__(
        self,
        root,
        split,
        annotations,
        filter_cam=None,
        extremities_prefix="",
        classes=lines_classes,
        palette=palette,
    ):
        self.data_root = Path(root)
        self.split = split
        self.annotations_path = Path(annotations)

        if filter_cam is None:
            # Use every image/annotation pair found in the split directory.
            files = os.listdir(self.data_root / self.split)
            self.annotations = sorted(fn for fn in files if fn.endswith(".json"))
            self.images = sorted(fn for fn in files if fn.endswith(".jpg"))
        else:
            # Keep only the frames that the camera ground truth assigns to
            # the requested camera.
            df = pd.read_json(self.data_root / self.split / "match_info_cam_gt.json").T
            df = df.loc[df.camera == filter_cam]
            assert len(df.index) > 0, f"No images found for camera '{filter_cam}'"
            df["image_file"] = df.index
            df = df.sort_values(by=["image_file"])
            df["annotation_file"] = df["image_file"].apply(
                lambda s: extremities_prefix + s.split(".jpg")[0] + ".json"
            )
            self.annotations = df["annotation_file"].tolist()
            self.images = df["image_file"].tolist()

        self.classes = classes
        self.palette = palette

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        impath = self.data_root / self.split / self.images[idx]
        annotation_path = self.annotations_path / self.annotations[idx]
        with open(annotation_path, "r") as f:
            annotation = json.load(f)

        img = Image.open(impath)
        trf = T.Compose(
            [
                T.ToTensor(),
                # ImageNet normalization statistics.
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )
        img = trf(img)

        # Frames are assumed to be 1280x720; the annotation stores coordinates
        # normalized to that size.
        target = create_target_from_annotation(1280, 720, annotation, self.classes)
        # Collapse the one-hot channels into a single class-index map ...
        target = target.long().argmax(dim=0).unsqueeze(0)
        # ... and binarize it: only line vs. background matters here.
        target = target.bool().float()
        # Downscale to the working resolution of the exported edge maps.
        target = T.Resize((180, 320))(target)
        target = (target * 255.0).to(torch.uint8)

        return img, target, impath.name
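

# Usage sketch (paths here are hypothetical): load one split and inspect a sample.
#
#   dataset = ExtremitiesDataset(
#       root="datasets/soccernet", split="valid",
#       annotations=Path("datasets/soccernet/valid"),
#   )
#   img, target, name = dataset[0]
#   # img: 3x720x1280 normalized float tensor
#   # target: 1x180x320 uint8 edge map (0 or 255)
#   # name: the image file name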


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--data_dir", type=Path, required=True)
    parser.add_argument("--annotations", type=Path, required=True)
    parser.add_argument("--output_dir", type=Path, required=True)
    parser.add_argument("--extremities_prefix", type=str, default="")
    args = parser.parse_args()

    # --data_dir points at a split directory, i.e. <root>/<split>.
    data_dir = args.data_dir.parent
    split = args.data_dir.name
    output_dir = args.output_dir
    if not output_dir.exists():
        raise FileNotFoundError(f"Output directory does not exist: {output_dir}")

    dataset = ExtremitiesDataset(
        data_dir,
        split,
        args.annotations,
        filter_cam="Main camera center",
        extremities_prefix=args.extremities_prefix,
    )

    # Export all edge maps as one (N, 1, 180, 320) uint8 array together with
    # the image file names in the same order.
    image_src = []
    edge_maps = np.zeros((len(dataset), 1, 180, 320), dtype=np.uint8)
    for i, (_, edge_map, img_id) in enumerate(tqdm(dataset)):
        edge_maps[i] = edge_map.numpy()
        image_src.append(img_id)

    with h5py.File(output_dir / "seg_edge_maps.h5", "w") as f:
        f.create_dataset("edge_map", data=edge_maps)

    with open(output_dir / "seg_image_paths.pkl", "wb") as f:
        pickle.dump(image_src, f)
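
    # Read-back sketch for downstream consumers:
    #
    #   with h5py.File(output_dir / "seg_edge_maps.h5", "r") as f:
    #       edge_maps = f["edge_map"][:]  # (N, 1, 180, 320) uint8
    #   with open(output_dir / "seg_image_paths.pkl", "rb") as f:
    #       image_src = pickle.load(f)  # list of N image file names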