# Copyright (c) OpenMMLab. All rights reserved.
# written by lzx
import copy
import os.path as osp
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

from mmengine.fileio import get_local_path
from mmdet.registry import DATASETS
from mmdet.datasets.api_wrappers import COCO
from mmdet.datasets.base_det_dataset import BaseDetDataset
from mmdet.datasets.coco import CocoDataset
from mmengine.utils import is_abs


@DATASETS.register_module()
class HSIDataset(CocoDataset):
    """COCO-format detection dataset with extra per-image auxiliary files.

    Annotations are read from a COCO-style json exactly as in
    :class:`CocoDataset`. In addition, every sample may reference two
    auxiliary files that share the image's ``file_name`` but live under
    separate roots:

    * ``seg_path``: ``seg_prefix['img']`` joined with the image file name,
      with a trailing ``.npy`` replaced by ``.png``.
    * ``abu_path``: ``abu_prefix['img']`` joined with the image file name,
      with a trailing ``.npy`` replaced by ``.mat``.

    NOTE(review): the class name suggests hyperspectral imagery stored as
    ``.npy`` arrays; confirm against the data pipeline.

    Args:
        *args: Positional arguments forwarded to :class:`CocoDataset`.
        seg_prefix (dict, optional): Directories for the ``seg_path``
            files. Only the ``'img'`` key is read when building paths.
            Relative directories are joined with ``data_root``.
            Defaults to None (``seg_path`` is None for every sample).
        abu_prefix (dict, optional): Directories for the ``abu_path``
            files, handled like ``seg_prefix``. Defaults to None.
        **kwargs: Keyword arguments forwarded to :class:`CocoDataset`.
    """

    METAINFO = {
        'classes':
        ('CB', 'MP', 'VO', 'ZO', 'TO', 'FG', 'GS', 'IP', 'IS', 'NP', 'LO',
         'NO', 'NC', 'NF', 'K_N', 'K_O', 'P_P', 'P_O', 'V_Y_W', 'C_Y_W',
         'BlueTrap', 'BrownTrap', 'Airport', 'Brown', 'DarkGreen',
         'PeaGreen', 'FauxVineyardGreen'),
        # One distinct RGB tuple per class, in the same order as
        # ``classes``; used for visualization. mmdet's palette helper needs
        # at least as many colors as there are classes.
        'palette':
        [(220, 20, 60), (119, 11, 32), (0, 0, 230), (106, 0, 228),
         (0, 60, 100), (0, 0, 70), (250, 170, 30), (100, 170, 30),
         (220, 220, 0), (175, 116, 175), (250, 0, 30), (165, 42, 42),
         (255, 77, 255), (0, 226, 252), (182, 182, 255), (0, 82, 0),
         (120, 166, 157), (110, 76, 0), (174, 57, 255), (199, 100, 0),
         (0, 0, 255), (139, 69, 19), (128, 128, 128), (150, 100, 50),
         (0, 100, 0), (147, 197, 114), (0, 125, 92)]
    }
    COCOAPI = COCO
    # ann_id is unique in coco dataset.
    ANN_ID_UNIQUE = True

    def __init__(self,
                 *args,
                 seg_prefix: Optional[dict] = None,
                 abu_prefix: Optional[dict] = None,
                 **kwargs) -> None:
        # These must be set before ``super().__init__`` because the parent
        # constructor calls ``_join_prefix``, which reads them. They are
        # shallow-copied so that joining ``data_root`` does not rewrite the
        # caller's config dicts in place. Otherwise, building the dataset
        # twice from one config would prepend ``data_root`` twice.
        self.seg_prefix = copy.copy(seg_prefix)
        self.abu_prefix = copy.copy(abu_prefix)
        super().__init__(*args, **kwargs)

    def _join_prefix(self):
        """Join ``self.data_root`` with the annotation file and all prefixes.

        Relative paths in ``self.ann_file``, ``self.data_prefix``,
        ``self.seg_prefix`` and ``self.abu_prefix`` are joined with
        ``self.data_root`` via :func:`os.path.join`. Absolute paths are
        kept unchanged. Nothing is joined when ``data_root`` is empty or
        None.

        Examples:
            >>> # self.data_prefix contains relative paths
            >>> self.data_root = 'a/b/c'
            >>> self.data_prefix = dict(img='d/e/')
            >>> self.ann_file = 'f'
            >>> self._join_prefix()
            >>> self.data_prefix
            dict(img='a/b/c/d/e/')
            >>> self.ann_file
            'a/b/c/f'
            >>> # self.data_prefix contains absolute paths
            >>> self.data_root = 'a/b/c'
            >>> self.data_prefix = dict(img='/d/e/')
            >>> self.ann_file = 'f'
            >>> self._join_prefix()
            >>> self.data_prefix
            dict(img='/d/e/')
            >>> self.ann_file
            'a/b/c/f'

        Raises:
            TypeError: If any prefix value is not a string.
        """
        if self.ann_file and self.data_root and not is_abs(self.ann_file):
            self.ann_file = osp.join(self.data_root, self.ann_file)

        for prefixes in (self.data_prefix, self.seg_prefix, self.abu_prefix):
            if prefixes is not None:
                self._join_root(prefixes)

    def _join_root(self, prefixes: dict) -> None:
        """Prepend ``self.data_root`` to every relative path in ``prefixes``.

        The dict is updated in place.

        Raises:
            TypeError: If any value in ``prefixes`` is not a string.
        """
        for key, prefix in prefixes.items():
            if not isinstance(prefix, str):
                raise TypeError('prefix should be a string, but got '
                                f'{type(prefix)}')
            if self.data_root and not is_abs(prefix):
                # Assigning to an existing key does not change the dict
                # size, so updating it while iterating is safe.
                prefixes[key] = osp.join(self.data_root, prefix)

    def load_data_list(self) -> List[dict]:
        """Load annotations from the COCO json named by ``self.ann_file``.

        Also builds ``self.cat_ids``, ``self.cat2label`` and
        ``self.cat_img_map``, which ``parse_data_info`` and ``filter_data``
        rely on.

        Returns:
            List[dict]: One parsed data info per image in the annotation file.
        """  # noqa: E501
        with get_local_path(
                self.ann_file, backend_args=self.backend_args) as local_path:
            self.coco = self.COCOAPI(local_path)
        # The order of returned `cat_ids` will not
        # change with the order of the `classes`
        self.cat_ids = self.coco.get_cat_ids(
            cat_names=self.metainfo['classes'])
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.cat_img_map = copy.deepcopy(self.coco.cat_img_map)

        data_list = []
        total_ann_ids = []
        for img_id in self.coco.get_img_ids():
            raw_img_info = self.coco.load_imgs([img_id])[0]
            raw_img_info['img_id'] = img_id

            ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
            raw_ann_info = self.coco.load_anns(ann_ids)
            total_ann_ids.extend(ann_ids)

            data_list.append(
                self.parse_data_info({
                    'raw_ann_info': raw_ann_info,
                    'raw_img_info': raw_img_info
                }))
        if self.ANN_ID_UNIQUE:
            assert len(set(total_ann_ids)) == len(
                total_ann_ids
            ), f"Annotation ids in '{self.ann_file}' are not unique!"

        # The COCO api object is no longer needed; release its memory.
        del self.coco

        return data_list

    @staticmethod
    def _swap_suffix(file_name: str, old: str, new: str) -> str:
        """Replace a trailing ``old`` suffix of ``file_name`` with ``new``.

        Names that do not end with ``old`` are returned unchanged. Only the
        end of the name is touched, never directory components.
        """
        if file_name.endswith(old):
            return file_name[:-len(old)] + new
        return file_name

    def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
        """Parse raw COCO image/annotation info into mmdet's data format.

        Args:
            raw_data_info (dict): Raw data information loaded from
                ``ann_file``, with keys ``'raw_img_info'`` and
                ``'raw_ann_info'``.

        Returns:
            Union[dict, List[dict]]: Parsed annotation with paths
            (``img_path``, ``seg_map_path``, ``seg_path``, ``abu_path``),
            ``img_id``, ``height``, ``width`` and ``instances``.
        """
        img_info = raw_data_info['raw_img_info']
        ann_info = raw_data_info['raw_ann_info']
        file_name = img_info['file_name']
        img_w = img_info['width']
        img_h = img_info['height']

        data_info = {}

        # TODO: need to change data_prefix['img'] to data_prefix['img_path']
        img_path = osp.join(self.data_prefix['img'], file_name)
        if self.data_prefix.get('seg', None):
            seg_map_path = osp.join(
                self.data_prefix['seg'],
                file_name.rsplit('.', 1)[0] + self.seg_map_suffix)
        else:
            seg_map_path = None

        # Auxiliary files share the image's base name. Only a trailing
        # '.npy' extension is swapped, so directories containing '.npy' in
        # their names are left intact.
        if self.seg_prefix is not None:
            seg_path = osp.join(self.seg_prefix['img'],
                                self._swap_suffix(file_name, '.npy', '.png'))
        else:
            seg_path = None
        if self.abu_prefix is not None:
            abu_path = osp.join(self.abu_prefix['img'],
                                self._swap_suffix(file_name, '.npy', '.mat'))
        else:
            abu_path = None

        data_info['img_path'] = img_path
        data_info['img_id'] = img_info['img_id']
        data_info['seg_map_path'] = seg_map_path
        data_info['seg_path'] = seg_path
        data_info['abu_path'] = abu_path
        data_info['height'] = img_h
        data_info['width'] = img_w

        instances = []
        for ann in ann_info:
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            # Drop boxes that do not overlap the image at all.
            inter_w = max(0, min(x1 + w, img_w) - max(x1, 0))
            inter_h = max(0, min(y1 + h, img_h) - max(y1, 0))
            if inter_w * inter_h == 0:
                continue
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            # cat2label's keys are exactly self.cat_ids (O(1) lookup).
            if ann['category_id'] not in self.cat2label:
                continue

            instance = {
                # Crowd regions are kept but flagged as ignored.
                'ignore_flag': 1 if ann.get('iscrowd', False) else 0,
                # COCO xywh -> xyxy.
                'bbox': [x1, y1, x1 + w, y1 + h],
                'bbox_label': self.cat2label[ann['category_id']],
            }
            if ann.get('segmentation', None):
                instance['mask'] = ann['segmentation']
            instances.append(instance)
        data_info['instances'] = instances
        return data_info

    def filter_data(self) -> List[dict]:
        """Filter ``self.data_list`` according to ``self.filter_cfg``.

        No filtering happens in test mode or when ``filter_cfg`` is None.
        Otherwise:

        * with ``filter_empty_gt=True``, images that do not appear in
          ``cat_img_map`` for any of the configured classes are dropped;
        * images whose shorter side is below ``min_size`` are dropped.

        Returns:
            List[dict]: Filtered results.
        """
        if self.test_mode or self.filter_cfg is None:
            return self.data_list

        filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False)
        min_size = self.filter_cfg.get('min_size', 0)

        # obtain images that contain annotation
        ids_with_ann = {data_info['img_id'] for data_info in self.data_list}
        # obtain images that contain annotations of the required categories
        ids_in_cat = set()
        for class_id in self.cat_ids:
            ids_in_cat.update(self.cat_img_map[class_id])
        # merge the image id sets of the two conditions and use the merged set
        # to filter out images if filter_empty_gt=True
        ids_in_cat &= ids_with_ann

        valid_data_infos = []
        for data_info in self.data_list:
            if filter_empty_gt and data_info['img_id'] not in ids_in_cat:
                continue
            if min(data_info['width'], data_info['height']) >= min_size:
                valid_data_infos.append(data_info)

        return valid_data_infos


# @DATASETS.register_module()
# class HSIDataset16(HSIDataset):
#     """Dataset for COCO."""
#
#     METAINFO = {
#         'classes':
#         ('TO', 'FG', 'GS', 'IP', 'IS', 'NP', 'LO', 'NO', 'NC', 'NF', 'K_N',
#          'K_O', 'P_P', 'P_O', 'V_Y_W', 'C_Y_W'),
#         # palette is a list of color tuples, which is used for visualization.
#         'palette':
#         [(0, 60, 100), (0, 0, 70), (250, 170, 30),
#          (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30),
#          (165, 42, 42), (255, 77, 255), (0, 226, 252), (182, 182, 255),
#          (0, 82, 0), (120, 166, 157), (110, 76, 0), (174, 57, 255),
#          (199, 100, 0)]
#     }
#     COCOAPI = COCO
#     # ann_id is unique in coco dataset.
#     ANN_ID_UNIQUE = True