|
|
|
import json |
|
import logging |
|
import os.path as osp |
|
import warnings |
|
from typing import List, Union |
|
|
|
import mmcv |
|
from mmengine.dist import get_rank |
|
from mmengine.fileio import dump, get, get_text, load |
|
from mmengine.logging import print_log |
|
from mmengine.utils import ProgressBar |
|
|
|
from mmdet.registry import DATASETS |
|
from .base_det_dataset import BaseDetDataset |
|
|
|
|
|
@DATASETS.register_module()
class CrowdHumanDataset(BaseDetDataset):
    r"""Dataset for CrowdHuman.

    Every non-empty line of ``ann_file`` is parsed as a standalone JSON
    object describing one image. Image height and width are taken from an
    "extra" annotation file, which is a JSON mapping from image ID to
    ``[height, width]``. When no such file exists, the sizes are collected by
    decoding every image once. They are then written to that file (by rank 0)
    so later runs can skip the decoding.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        ann_file (str): Annotation file path.
        extra_ann_file (str, optional): The path of the extra image metas
            (image ID -> height/width) for CrowdHuman. It can be created by
            CrowdHumanDataset automatically or by
            tools/misc/get_crowdhuman_id_hw.py manually. When None, the path
            is derived from ``data_root`` and the name of ``ann_file``.
            Defaults to None.
    """

    METAINFO = {
        'classes': ('person', ),
        'palette': [(220, 20, 60)]
    }

    def __init__(self, data_root, ann_file, extra_ann_file=None, **kwargs):
        # The extra annotation file caches the (height, width) of every image
        # so the images do not have to be decoded on every dataset build.
        if extra_ann_file is not None:
            # An explicitly supplied cache file is loaded directly.
            self.extra_ann_file = extra_ann_file
            self.extra_ann_exist = True
            self.extra_anns = load(extra_ann_file)
        else:
            self.extra_ann_file = self._default_extra_ann_file(
                data_root, ann_file)
            if osp.isfile(self.extra_ann_file):
                self.extra_ann_exist = True
                self.extra_anns = load(self.extra_ann_file)
            else:
                print_log(
                    'extra_ann_file does not exist, prepare to collect '
                    'image height and width...',
                    level=logging.INFO)
                # Filled in by ``parse_data_info`` and dumped afterwards by
                # ``load_data_list``.
                self.extra_ann_exist = False
                self.extra_anns = {}
        super().__init__(data_root=data_root, ann_file=ann_file, **kwargs)

    @staticmethod
    def _default_extra_ann_file(data_root, ann_file):
        """Return the default image-size cache path for ``ann_file``."""
        ann_file_name = osp.basename(ann_file)
        if 'train' in ann_file_name:
            cache_name = 'id_hw_train.json'
        elif 'val' in ann_file_name:
            cache_name = 'id_hw_val.json'
        else:
            # Neither split name matched. Instead of leaving the attribute
            # unset (which crashed with AttributeError), derive a name from
            # the annotation file's stem.
            stem = osp.splitext(ann_file_name)[0]
            cache_name = f'id_hw_{stem}.json'
        return osp.join(data_root, cache_name)

    def load_data_list(self) -> List[dict]:
        """Load annotations from an annotation file named as ``self.ann_file``.

        Each non-empty line is decoded with ``json.loads`` and converted by
        :meth:`parse_data_info`. If image sizes had to be collected because
        no cache file existed, rank 0 tries to dump them to
        ``self.extra_ann_file``. Afterwards ``self.extra_anns`` is deleted.

        Returns:
            List[dict]: A list of annotation.
        """
        text = get_text(self.ann_file, backend_args=self.backend_args)
        # Skip blank lines; ``json.loads('')`` would raise on them.
        anno_strs = [line for line in text.split('\n') if line.strip()]
        print_log('loading CrowdHuman annotation...', level=logging.INFO)
        data_list = []
        prog_bar = ProgressBar(len(anno_strs))
        for anno_str in anno_strs:
            anno_dict = json.loads(anno_str)
            data_list.append(self.parse_data_info(anno_dict))
            prog_bar.update()
        if not self.extra_ann_exist and get_rank() == 0:
            self._save_extra_anns()
        # The size cache is only needed while parsing; release it.
        del self.extra_anns
        print_log('\nDone', level=logging.INFO)
        return data_list

    def _save_extra_anns(self):
        """Best-effort dump of the collected image sizes to the cache file."""
        try:
            dump(self.extra_anns, self.extra_ann_file, file_format='json')
        except Exception as e:
            # Saving the cache is optional: warn instead of failing the run.
            warnings.warn(
                'Cache files can not be saved automatically! To speed up '
                'loading the dataset, please manually generate the cache '
                'file by tools/misc/get_crowdhuman_id_hw.py '
                f'(reason: {e!r})')
        else:
            print_log(
                f'\nsave extra_ann_file in {self.extra_ann_file}',
                level=logging.INFO)

    @staticmethod
    def _xywh_to_xyxy(box):
        """Convert an ``[x, y, w, h]`` box to ``[x1, y1, x2, y2]``."""
        return [box[0], box[1], box[0] + box[2], box[1] + box[3]]

    def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
        """Parse raw annotation to target format.

        Args:
            raw_data_info (dict): Raw data information load from
                ``ann_file``. It is read through the keys ``'ID'`` and
                ``'gtboxes'``. Each gtbox is read through ``'tag'``,
                ``'fbox'``, ``'hbox'``, ``'vbox'`` and, optionally,
                ``'extra'`` -> ``'ignore'``.

        Returns:
            Union[dict, List[dict]]: Parsed annotation.
        """
        img_id = raw_data_info['ID']
        img_path = osp.join(self.data_prefix['img'], f'{img_id}.jpg')
        data_info = {'img_path': img_path, 'img_id': img_id}

        if self.extra_ann_exist:
            data_info['height'], data_info['width'] = self.extra_anns[img_id]
        else:
            # No size cache: decode the image once to obtain its size and
            # record it so ``load_data_list`` can dump the cache afterwards.
            img_bytes = get(img_path, backend_args=self.backend_args)
            img = mmcv.imfrombytes(img_bytes, backend='cv2')
            data_info['height'], data_info['width'] = img.shape[:2]
            self.extra_anns[img_id] = img.shape[:2]
            del img, img_bytes

        # ``metainfo`` may return a fresh copy on every access, so fetch the
        # classes once per image instead of once or twice per box.
        classes = self.metainfo['classes']
        instances = []
        for ann in raw_data_info['gtboxes']:
            tag = ann['tag']
            # A box is ignored when its tag is not a known class, or when it
            # carries a non-zero ``extra.ignore`` flag.
            ignored = (tag not in classes
                       or ann.get('extra', {}).get('ignore', 0) != 0)
            instance = {
                'bbox_label': -1 if ignored else classes.index(tag),
                'ignore_flag': 1 if ignored else 0,
            }
            bbox = self._xywh_to_xyxy(ann['fbox'])
            instance['bbox'] = bbox

            # Record the full bbox (fbox), head bbox (hbox) and visible bbox
            # (vbox) as additional information for custom pipelines. fbox is
            # a separate list so in-place edits of ``bbox`` do not alias it.
            instance['fbox'] = list(bbox)
            instance['hbox'] = self._xywh_to_xyxy(ann['hbox'])
            instance['vbox'] = self._xywh_to_xyxy(ann['vbox'])
            instances.append(instance)

        data_info['instances'] = instances
        return data_info
|
|