|
|
|
import copy |
|
from typing import Dict, List, Optional, Tuple |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from mmcv.cnn import ConvModule, Scale |
|
from mmengine.config import ConfigDict |
|
from mmengine.model import BaseModule, kaiming_init |
|
from mmengine.structures import InstanceData |
|
from torch import Tensor |
|
|
|
from mmdet.registry import MODELS |
|
from mmdet.structures.bbox import cat_boxes |
|
from mmdet.utils import (ConfigType, InstanceList, MultiConfig, OptConfigType, |
|
OptInstanceList, reduce_mean) |
|
from ..task_modules.prior_generators import MlvlPointGenerator |
|
from ..utils import (aligned_bilinear, filter_scores_and_topk, multi_apply, |
|
relative_coordinate_maps, select_single_mlvl) |
|
from ..utils.misc import empty_instances |
|
from .base_mask_head import BaseMaskHead |
|
from .fcos_head import FCOSHead |
|
|
|
INF = 1e8 |
|
|
|
|
|
@MODELS.register_module() |
|
class CondInstBboxHead(FCOSHead): |
|
"""CondInst box head used in https://arxiv.org/abs/1904.02689. |
|
|
|
Note that CondInst Bbox Head is a extension of FCOS head. |
|
Two differences are described as follows: |
|
|
|
1. CondInst box head predicts a set of params for each instance. |
|
2. CondInst box head return the pos_gt_inds and pos_inds. |
|
|
|
Args: |
|
        num_params (int): Number of params for the instance-specific
            dynamic convs, predicted by the controller. Defaults to 169.
|
""" |
|
|
|
def __init__(self, *args, num_params: int = 169, **kwargs) -> None: |
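        # `num_params` must be set before `super().__init__()`, which
        # calls `_init_layers()` and needs it to build the controller.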
|
self.num_params = num_params |
|
super().__init__(*args, **kwargs) |
|
|
|
def _init_layers(self) -> None: |
|
"""Initialize layers of the head.""" |
|
super()._init_layers() |
|
self.controller = nn.Conv2d( |
|
self.feat_channels, self.num_params, 3, padding=1) |
|
|
|
def forward_single(self, x: Tensor, scale: Scale, |
|
stride: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]: |
|
"""Forward features of a single scale level. |
|
|
|
Args: |
|
x (Tensor): FPN feature maps of the specified stride. |
|
scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize |
|
the bbox prediction. |
|
stride (int): The corresponding stride for feature maps, only |
|
used to normalize the bbox prediction when self.norm_on_bbox |
|
is True. |
|
|
|
Returns: |
|
tuple: scores for each class, bbox predictions, centerness |
|
predictions and param predictions of input feature maps. |
|
""" |
|
cls_score, bbox_pred, cls_feat, reg_feat = \ |
|
super(FCOSHead, self).forward_single(x) |
|
if self.centerness_on_reg: |
|
centerness = self.conv_centerness(reg_feat) |
|
else: |
|
centerness = self.conv_centerness(cls_feat) |
|
|
|
|
|
bbox_pred = scale(bbox_pred).float() |
|
if self.norm_on_bbox: |
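            # use clamp(min=0) instead of F.relu: the in-place-modified
            # bbox_pred from F.relu can break gradient computation on
            # some PyTorch versions (same workaround as in FCOSHead)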
|
|
|
|
|
|
|
bbox_pred = bbox_pred.clamp(min=0) |
|
if not self.training: |
|
bbox_pred *= stride |
|
else: |
|
bbox_pred = bbox_pred.exp() |
|
param_pred = self.controller(reg_feat) |
|
return cls_score, bbox_pred, centerness, param_pred |
|
|
|
def loss_by_feat( |
|
self, |
|
cls_scores: List[Tensor], |
|
bbox_preds: List[Tensor], |
|
centernesses: List[Tensor], |
|
param_preds: List[Tensor], |
|
batch_gt_instances: InstanceList, |
|
batch_img_metas: List[dict], |
|
batch_gt_instances_ignore: OptInstanceList = None |
|
) -> Dict[str, Tensor]: |
|
"""Calculate the loss based on the features extracted by the detection |
|
head. |
|
|
|
Args: |
|
cls_scores (list[Tensor]): Box scores for each scale level, |
|
each is a 4D-tensor, the channel number is |
|
num_points * num_classes. |
|
bbox_preds (list[Tensor]): Box energies / deltas for each scale |
|
level, each is a 4D-tensor, the channel number is |
|
num_points * 4. |
|
            centernesses (list[Tensor]): Centerness for each scale level,
                each is a 4D-tensor, the channel number is num_points * 1.
|
param_preds (List[Tensor]): param_pred for each scale level, each |
|
is a 4D-tensor, the channel number is num_params. |
|
batch_gt_instances (list[:obj:`InstanceData`]): Batch of |
|
gt_instance. It usually includes ``bboxes`` and ``labels`` |
|
attributes. |
|
batch_img_metas (list[dict]): Meta information of each image, e.g., |
|
image size, scaling factor, etc. |
|
batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): |
|
Batch of gt_instances_ignore. It includes ``bboxes`` attribute |
|
data that is ignored during training and testing. |
|
Defaults to None. |
|
|
|
Returns: |
|
dict[str, Tensor]: A dictionary of loss components. |
|
""" |
|
assert len(cls_scores) == len(bbox_preds) == len(centernesses) |
|
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] |
|
|
|
all_level_points_strides = self.prior_generator.grid_priors( |
|
featmap_sizes, |
|
dtype=bbox_preds[0].dtype, |
|
device=bbox_preds[0].device, |
|
with_stride=True) |
|
all_level_points = [i[:, :2] for i in all_level_points_strides] |
|
all_level_strides = [i[:, 2] for i in all_level_points_strides] |
|
labels, bbox_targets, pos_inds_list, pos_gt_inds_list = \ |
|
self.get_targets(all_level_points, batch_gt_instances) |
|
|
|
num_imgs = cls_scores[0].size(0) |
|
|
|
flatten_cls_scores = [ |
|
cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels) |
|
for cls_score in cls_scores |
|
] |
|
flatten_bbox_preds = [ |
|
bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4) |
|
for bbox_pred in bbox_preds |
|
] |
|
flatten_centerness = [ |
|
centerness.permute(0, 2, 3, 1).reshape(-1) |
|
for centerness in centernesses |
|
] |
|
flatten_cls_scores = torch.cat(flatten_cls_scores) |
|
flatten_bbox_preds = torch.cat(flatten_bbox_preds) |
|
flatten_centerness = torch.cat(flatten_centerness) |
|
flatten_labels = torch.cat(labels) |
|
flatten_bbox_targets = torch.cat(bbox_targets) |
|
|
|
flatten_points = torch.cat( |
|
[points.repeat(num_imgs, 1) for points in all_level_points]) |
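        # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes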
|
|
|
|
|
bg_class_ind = self.num_classes |
|
pos_inds = ((flatten_labels >= 0) |
|
& (flatten_labels < bg_class_ind)).nonzero().reshape(-1) |
|
num_pos = torch.tensor( |
|
len(pos_inds), dtype=torch.float, device=bbox_preds[0].device) |
|
num_pos = max(reduce_mean(num_pos), 1.0) |
|
loss_cls = self.loss_cls( |
|
flatten_cls_scores, flatten_labels, avg_factor=num_pos) |
|
|
|
pos_bbox_preds = flatten_bbox_preds[pos_inds] |
|
pos_centerness = flatten_centerness[pos_inds] |
|
pos_bbox_targets = flatten_bbox_targets[pos_inds] |
|
pos_centerness_targets = self.centerness_target(pos_bbox_targets) |
|
|
|
centerness_denorm = max( |
|
reduce_mean(pos_centerness_targets.sum().detach()), 1e-6) |
|
|
|
if len(pos_inds) > 0: |
|
pos_points = flatten_points[pos_inds] |
|
pos_decoded_bbox_preds = self.bbox_coder.decode( |
|
pos_points, pos_bbox_preds) |
|
pos_decoded_target_preds = self.bbox_coder.decode( |
|
pos_points, pos_bbox_targets) |
|
loss_bbox = self.loss_bbox( |
|
pos_decoded_bbox_preds, |
|
pos_decoded_target_preds, |
|
weight=pos_centerness_targets, |
|
avg_factor=centerness_denorm) |
|
loss_centerness = self.loss_centerness( |
|
pos_centerness, pos_centerness_targets, avg_factor=num_pos) |
|
else: |
|
loss_bbox = pos_bbox_preds.sum() |
|
loss_centerness = pos_centerness.sum() |
|
|
|
self._raw_positive_infos.update(cls_scores=cls_scores) |
|
self._raw_positive_infos.update(centernesses=centernesses) |
|
self._raw_positive_infos.update(param_preds=param_preds) |
|
self._raw_positive_infos.update(all_level_points=all_level_points) |
|
self._raw_positive_infos.update(all_level_strides=all_level_strides) |
|
self._raw_positive_infos.update(pos_gt_inds_list=pos_gt_inds_list) |
|
self._raw_positive_infos.update(pos_inds_list=pos_inds_list) |
|
|
|
return dict( |
|
loss_cls=loss_cls, |
|
loss_bbox=loss_bbox, |
|
loss_centerness=loss_centerness) |
|
|
|
def get_targets( |
|
self, points: List[Tensor], batch_gt_instances: InstanceList |
|
) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]: |
|
"""Compute regression, classification and centerness targets for points |
|
in multiple images. |
|
|
|
Args: |
|
points (list[Tensor]): Points of each fpn level, each has shape |
|
(num_points, 2). |
|
batch_gt_instances (list[:obj:`InstanceData`]): Batch of |
|
gt_instance. It usually includes ``bboxes`` and ``labels`` |
|
attributes. |
|
|
|
Returns: |
|
tuple: Targets of each level. |
|
|
|
- concat_lvl_labels (list[Tensor]): Labels of each level. |
|
- concat_lvl_bbox_targets (list[Tensor]): BBox targets of each \ |
|
level. |
|
- pos_inds_list (list[Tensor]): pos_inds of each image. |
|
- pos_gt_inds_list (List[Tensor]): pos_gt_inds of each image. |
|
""" |
|
assert len(points) == len(self.regress_ranges) |
|
num_levels = len(points) |
|
|
|
expanded_regress_ranges = [ |
|
points[i].new_tensor(self.regress_ranges[i])[None].expand_as( |
|
points[i]) for i in range(num_levels) |
|
] |
|
|
|
concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0) |
|
concat_points = torch.cat(points, dim=0) |
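        # the number of points per img, per lvl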
|
|
|
|
|
num_points = [center.size(0) for center in points] |
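        # get labels and bbox_targets of each image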
|
|
|
|
|
labels_list, bbox_targets_list, pos_inds_list, pos_gt_inds_list = \ |
|
multi_apply( |
|
self._get_targets_single, |
|
batch_gt_instances, |
|
points=concat_points, |
|
regress_ranges=concat_regress_ranges, |
|
num_points_per_lvl=num_points) |
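        # split the learning targets to per img, per lvl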
|
|
|
|
|
labels_list = [labels.split(num_points, 0) for labels in labels_list] |
|
bbox_targets_list = [ |
|
bbox_targets.split(num_points, 0) |
|
for bbox_targets in bbox_targets_list |
|
] |
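        # concat the per-level targets of all images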
|
|
|
|
|
concat_lvl_labels = [] |
|
concat_lvl_bbox_targets = [] |
|
for i in range(num_levels): |
|
concat_lvl_labels.append( |
|
torch.cat([labels[i] for labels in labels_list])) |
|
bbox_targets = torch.cat( |
|
[bbox_targets[i] for bbox_targets in bbox_targets_list]) |
|
if self.norm_on_bbox: |
|
bbox_targets = bbox_targets / self.strides[i] |
|
concat_lvl_bbox_targets.append(bbox_targets) |
|
return (concat_lvl_labels, concat_lvl_bbox_targets, pos_inds_list, |
|
pos_gt_inds_list) |
|
|
|
def _get_targets_single( |
|
self, gt_instances: InstanceData, points: Tensor, |
|
regress_ranges: Tensor, num_points_per_lvl: List[int] |
|
) -> Tuple[Tensor, Tensor, Tensor, Tensor]: |
|
"""Compute regression and classification targets for a single image.""" |
|
num_points = points.size(0) |
|
num_gts = len(gt_instances) |
|
gt_bboxes = gt_instances.bboxes |
|
gt_labels = gt_instances.labels |
|
gt_masks = gt_instances.get('masks', None) |
|
|
|
if num_gts == 0: |
|
return gt_labels.new_full((num_points,), self.num_classes), \ |
|
gt_bboxes.new_zeros((num_points, 4)), \ |
|
gt_bboxes.new_zeros((0,), dtype=torch.int64), \ |
|
gt_bboxes.new_zeros((0,), dtype=torch.int64) |
|
|
|
areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * ( |
|
gt_bboxes[:, 3] - gt_bboxes[:, 1]) |
|
|
|
|
|
areas = areas[None].repeat(num_points, 1) |
|
regress_ranges = regress_ranges[:, None, :].expand( |
|
num_points, num_gts, 2) |
|
gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4) |
|
xs, ys = points[:, 0], points[:, 1] |
|
xs = xs[:, None].expand(num_points, num_gts) |
|
ys = ys[:, None].expand(num_points, num_gts) |
|
|
|
left = xs - gt_bboxes[..., 0] |
|
right = gt_bboxes[..., 2] - xs |
|
top = ys - gt_bboxes[..., 1] |
|
bottom = gt_bboxes[..., 3] - ys |
|
bbox_targets = torch.stack((left, top, right, bottom), -1) |
|
|
|
if self.center_sampling: |
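            # condition1: inside a `center bbox`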
|
|
|
radius = self.center_sample_radius |
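            # if gt_masks are available, use the mass center of the mask
            # instead of the bbox center to define the center region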
|
|
|
|
|
if gt_masks is None: |
|
center_xs = (gt_bboxes[..., 0] + gt_bboxes[..., 2]) / 2 |
|
center_ys = (gt_bboxes[..., 1] + gt_bboxes[..., 3]) / 2 |
|
else: |
|
h, w = gt_masks.height, gt_masks.width |
|
masks = gt_masks.to_tensor( |
|
dtype=torch.bool, device=gt_bboxes.device) |
|
yys = torch.arange( |
|
0, h, dtype=torch.float32, device=masks.device) |
|
xxs = torch.arange( |
|
0, w, dtype=torch.float32, device=masks.device) |
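                # m00/m10/m01 are the zeroth/first-order mask moments;
                # the centroid is (m10 / m00, m01 / m00)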
|
|
|
|
|
m00 = masks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6) |
|
m10 = (masks * xxs).sum(dim=-1).sum(dim=-1) |
|
m01 = (masks * yys[:, None]).sum(dim=-1).sum(dim=-1) |
|
center_xs = m10 / m00 |
|
center_ys = m01 / m00 |
|
|
|
center_xs = center_xs[None].expand(num_points, num_gts) |
|
center_ys = center_ys[None].expand(num_points, num_gts) |
|
center_gts = torch.zeros_like(gt_bboxes) |
|
stride = center_xs.new_zeros(center_xs.shape) |
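            # the sampling radius is defined in units of feature stride,
            # so scale it to image coordinates per level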
|
|
|
|
|
lvl_begin = 0 |
|
for lvl_idx, num_points_lvl in enumerate(num_points_per_lvl): |
|
lvl_end = lvl_begin + num_points_lvl |
|
stride[lvl_begin:lvl_end] = self.strides[lvl_idx] * radius |
|
lvl_begin = lvl_end |
|
|
|
x_mins = center_xs - stride |
|
y_mins = center_ys - stride |
|
x_maxs = center_xs + stride |
|
y_maxs = center_ys + stride |
|
center_gts[..., 0] = torch.where(x_mins > gt_bboxes[..., 0], |
|
x_mins, gt_bboxes[..., 0]) |
|
center_gts[..., 1] = torch.where(y_mins > gt_bboxes[..., 1], |
|
y_mins, gt_bboxes[..., 1]) |
|
center_gts[..., 2] = torch.where(x_maxs > gt_bboxes[..., 2], |
|
gt_bboxes[..., 2], x_maxs) |
|
center_gts[..., 3] = torch.where(y_maxs > gt_bboxes[..., 3], |
|
gt_bboxes[..., 3], y_maxs) |
|
|
|
cb_dist_left = xs - center_gts[..., 0] |
|
cb_dist_right = center_gts[..., 2] - xs |
|
cb_dist_top = ys - center_gts[..., 1] |
|
cb_dist_bottom = center_gts[..., 3] - ys |
|
center_bbox = torch.stack( |
|
(cb_dist_left, cb_dist_top, cb_dist_right, cb_dist_bottom), -1) |
|
inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0 |
|
else: |
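            # condition1: inside a gt bbox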
|
|
|
inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0 |
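        # condition2: limit the regression range for each location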
|
|
|
|
|
max_regress_distance = bbox_targets.max(-1)[0] |
|
inside_regress_range = ( |
|
(max_regress_distance >= regress_ranges[..., 0]) |
|
& (max_regress_distance <= regress_ranges[..., 1])) |
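        # if there are still more than one objects for a location,
        # we keep the one with minimal area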
|
|
|
|
|
|
|
areas[inside_gt_bbox_mask == 0] = INF |
|
areas[inside_regress_range == 0] = INF |
|
min_area, min_area_inds = areas.min(dim=1) |
|
|
|
labels = gt_labels[min_area_inds] |
|
labels[min_area == INF] = self.num_classes |
|
bbox_targets = bbox_targets[range(num_points), min_area_inds] |
|
|
|
|
|
bg_class_ind = self.num_classes |
|
pos_inds = ((labels >= 0) |
|
& (labels < bg_class_ind)).nonzero().reshape(-1) |
|
pos_gt_inds = min_area_inds[labels < self.num_classes] |
|
return labels, bbox_targets, pos_inds, pos_gt_inds |
|
|
|
def get_positive_infos(self) -> InstanceList: |
|
"""Get positive information from sampling results. |
|
|
|
Returns: |
|
list[:obj:`InstanceData`]: Positive information of each image, |
|
usually including positive bboxes, positive labels, positive |
|
priors, etc. |
|
""" |
|
assert len(self._raw_positive_infos) > 0 |
|
|
|
pos_gt_inds_list = self._raw_positive_infos['pos_gt_inds_list'] |
|
pos_inds_list = self._raw_positive_infos['pos_inds_list'] |
|
num_imgs = len(pos_gt_inds_list) |
|
|
|
cls_score_list = [] |
|
centerness_list = [] |
|
param_pred_list = [] |
|
point_list = [] |
|
stride_list = [] |
|
for cls_score_per_lvl, centerness_per_lvl, param_pred_per_lvl,\ |
|
point_per_lvl, stride_per_lvl in \ |
|
zip(self._raw_positive_infos['cls_scores'], |
|
self._raw_positive_infos['centernesses'], |
|
self._raw_positive_infos['param_preds'], |
|
self._raw_positive_infos['all_level_points'], |
|
self._raw_positive_infos['all_level_strides']): |
|
cls_score_per_lvl = \ |
|
cls_score_per_lvl.permute( |
|
0, 2, 3, 1).reshape(num_imgs, -1, self.num_classes) |
|
centerness_per_lvl = \ |
|
centerness_per_lvl.permute( |
|
0, 2, 3, 1).reshape(num_imgs, -1, 1) |
|
param_pred_per_lvl = \ |
|
param_pred_per_lvl.permute( |
|
0, 2, 3, 1).reshape(num_imgs, -1, self.num_params) |
|
point_per_lvl = point_per_lvl.unsqueeze(0).repeat(num_imgs, 1, 1) |
|
stride_per_lvl = stride_per_lvl.unsqueeze(0).repeat(num_imgs, 1) |
|
|
|
cls_score_list.append(cls_score_per_lvl) |
|
centerness_list.append(centerness_per_lvl) |
|
param_pred_list.append(param_pred_per_lvl) |
|
point_list.append(point_per_lvl) |
|
stride_list.append(stride_per_lvl) |
|
cls_scores = torch.cat(cls_score_list, dim=1) |
|
centernesses = torch.cat(centerness_list, dim=1) |
|
param_preds = torch.cat(param_pred_list, dim=1) |
|
all_points = torch.cat(point_list, dim=1) |
|
all_strides = torch.cat(stride_list, dim=1) |
|
|
|
positive_infos = [] |
|
for i, (pos_gt_inds, |
|
pos_inds) in enumerate(zip(pos_gt_inds_list, pos_inds_list)): |
|
pos_info = InstanceData() |
|
pos_info.points = all_points[i][pos_inds] |
|
pos_info.strides = all_strides[i][pos_inds] |
|
pos_info.scores = cls_scores[i][pos_inds] |
|
pos_info.centernesses = centernesses[i][pos_inds] |
|
pos_info.param_preds = param_preds[i][pos_inds] |
|
pos_info.pos_assigned_gt_inds = pos_gt_inds |
|
pos_info.pos_inds = pos_inds |
|
positive_infos.append(pos_info) |
|
return positive_infos |
|
|
|
def predict_by_feat(self, |
|
cls_scores: List[Tensor], |
|
bbox_preds: List[Tensor], |
|
score_factors: Optional[List[Tensor]] = None, |
|
param_preds: Optional[List[Tensor]] = None, |
|
batch_img_metas: Optional[List[dict]] = None, |
|
cfg: Optional[ConfigDict] = None, |
|
rescale: bool = False, |
|
with_nms: bool = True) -> InstanceList: |
|
"""Transform a batch of output features extracted from the head into |
|
bbox results. |
|
|
|
        Note: When score_factors is not None, the cls_scores are
        usually multiplied by it to obtain the real scores used in NMS,
        e.g. centerness in FCOS or the IoU branch in ATSS.
|
|
|
Args: |
|
cls_scores (list[Tensor]): Classification scores for all |
|
scale levels, each is a 4D-tensor, has shape |
|
(batch_size, num_priors * num_classes, H, W). |
|
bbox_preds (list[Tensor]): Box energies / deltas for all |
|
scale levels, each is a 4D-tensor, has shape |
|
(batch_size, num_priors * 4, H, W). |
|
            score_factors (list[Tensor], optional): Score factors for
                all scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 1, H, W). Defaults to None.
            param_preds (list[Tensor], optional): Params for all scale
                levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_params, H, W).
|
batch_img_metas (list[dict], Optional): Batch image meta info. |
|
Defaults to None. |
|
cfg (ConfigDict, optional): Test / postprocessing |
|
configuration, if None, test_cfg would be used. |
|
Defaults to None. |
|
rescale (bool): If True, return boxes in original image space. |
|
Defaults to False. |
|
with_nms (bool): If True, do nms before return boxes. |
|
Defaults to True. |
|
|
|
Returns: |
|
list[:obj:`InstanceData`]: Object detection results of each image |
|
after the post process. Each item usually contains following keys. |
|
|
|
- scores (Tensor): Classification scores, has a shape |
|
(num_instance, ) |
|
- labels (Tensor): Labels of bboxes, has a shape |
|
(num_instances, ). |
|
- bboxes (Tensor): Has a shape (num_instances, 4), |
|
the last dimension 4 arrange as (x1, y1, x2, y2). |
|
""" |
|
assert len(cls_scores) == len(bbox_preds) |
|
|
|
if score_factors is None: |
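            # e.g. Retina, FreeAnchor, Foveabox, etc.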
|
|
|
with_score_factors = False |
|
else: |
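            # e.g. FCOS, PAA, ATSS, AutoAssign, etc.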
|
|
|
with_score_factors = True |
|
assert len(cls_scores) == len(score_factors) |
|
|
|
num_levels = len(cls_scores) |
|
|
|
featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] |
|
all_level_points_strides = self.prior_generator.grid_priors( |
|
featmap_sizes, |
|
dtype=bbox_preds[0].dtype, |
|
device=bbox_preds[0].device, |
|
with_stride=True) |
|
all_level_points = [i[:, :2] for i in all_level_points_strides] |
|
all_level_strides = [i[:, 2] for i in all_level_points_strides] |
|
|
|
result_list = [] |
|
|
|
for img_id in range(len(batch_img_metas)): |
|
img_meta = batch_img_metas[img_id] |
|
cls_score_list = select_single_mlvl( |
|
cls_scores, img_id, detach=True) |
|
bbox_pred_list = select_single_mlvl( |
|
bbox_preds, img_id, detach=True) |
|
if with_score_factors: |
|
score_factor_list = select_single_mlvl( |
|
score_factors, img_id, detach=True) |
|
else: |
|
score_factor_list = [None for _ in range(num_levels)] |
|
param_pred_list = select_single_mlvl( |
|
param_preds, img_id, detach=True) |
|
|
|
results = self._predict_by_feat_single( |
|
cls_score_list=cls_score_list, |
|
bbox_pred_list=bbox_pred_list, |
|
score_factor_list=score_factor_list, |
|
param_pred_list=param_pred_list, |
|
mlvl_points=all_level_points, |
|
mlvl_strides=all_level_strides, |
|
img_meta=img_meta, |
|
cfg=cfg, |
|
rescale=rescale, |
|
with_nms=with_nms) |
|
result_list.append(results) |
|
return result_list |
|
|
|
def _predict_by_feat_single(self, |
|
cls_score_list: List[Tensor], |
|
bbox_pred_list: List[Tensor], |
|
score_factor_list: List[Tensor], |
|
param_pred_list: List[Tensor], |
|
mlvl_points: List[Tensor], |
|
mlvl_strides: List[Tensor], |
|
img_meta: dict, |
|
cfg: ConfigDict, |
|
rescale: bool = False, |
|
with_nms: bool = True) -> InstanceData: |
|
"""Transform a single image's features extracted from the head into |
|
bbox results. |
|
|
|
Args: |
|
cls_score_list (list[Tensor]): Box scores from all scale |
|
levels of a single image, each item has shape |
|
(num_priors * num_classes, H, W). |
|
bbox_pred_list (list[Tensor]): Box energies / deltas from |
|
all scale levels of a single image, each item has shape |
|
(num_priors * 4, H, W). |
|
score_factor_list (list[Tensor]): Score factor from all scale |
|
levels of a single image, each item has shape |
|
(num_priors * 1, H, W). |
|
            param_pred_list (List[Tensor]): Param prediction from all scale
|
levels of a single image, each item has shape |
|
(num_priors * num_params, H, W). |
|
mlvl_points (list[Tensor]): Each element in the list is |
|
the priors of a single level in feature pyramid. |
|
It has shape (num_priors, 2) |
|
mlvl_strides (List[Tensor]): Each element in the list is |
|
the stride of a single level in feature pyramid. |
|
It has shape (num_priors, 1) |
|
img_meta (dict): Image meta info. |
|
cfg (mmengine.Config): Test / postprocessing configuration, |
|
if None, test_cfg would be used. |
|
rescale (bool): If True, return boxes in original image space. |
|
Defaults to False. |
|
with_nms (bool): If True, do nms before return boxes. |
|
Defaults to True. |
|
|
|
Returns: |
|
:obj:`InstanceData`: Detection results of each image |
|
after the post process. |
|
Each item usually contains following keys. |
|
|
|
- scores (Tensor): Classification scores, has a shape |
|
(num_instance, ) |
|
- labels (Tensor): Labels of bboxes, has a shape |
|
(num_instances, ). |
|
- bboxes (Tensor): Has a shape (num_instances, 4), |
|
the last dimension 4 arrange as (x1, y1, x2, y2). |
|
""" |
|
if score_factor_list[0] is None: |
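            # e.g. Retina, FreeAnchor, etc.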
|
|
|
with_score_factors = False |
|
else: |
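            # e.g. FCOS, PAA, ATSS, etc.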
|
|
|
with_score_factors = True |
|
|
|
cfg = self.test_cfg if cfg is None else cfg |
|
cfg = copy.deepcopy(cfg) |
|
img_shape = img_meta['img_shape'] |
|
nms_pre = cfg.get('nms_pre', -1) |
|
|
|
mlvl_bbox_preds = [] |
|
mlvl_param_preds = [] |
|
mlvl_valid_points = [] |
|
mlvl_valid_strides = [] |
|
mlvl_scores = [] |
|
mlvl_labels = [] |
|
if with_score_factors: |
|
mlvl_score_factors = [] |
|
else: |
|
mlvl_score_factors = None |
|
for level_idx, (cls_score, bbox_pred, score_factor, |
|
param_pred, points, strides) in \ |
|
enumerate(zip(cls_score_list, bbox_pred_list, |
|
score_factor_list, param_pred_list, |
|
mlvl_points, mlvl_strides)): |
|
|
|
assert cls_score.size()[-2:] == bbox_pred.size()[-2:] |
|
|
|
dim = self.bbox_coder.encode_size |
|
bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, dim) |
|
if with_score_factors: |
|
score_factor = score_factor.permute(1, 2, |
|
0).reshape(-1).sigmoid() |
|
cls_score = cls_score.permute(1, 2, |
|
0).reshape(-1, self.cls_out_channels) |
|
if self.use_sigmoid_cls: |
|
scores = cls_score.sigmoid() |
|
else: |
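                # remind that we set FG labels to [0, num_class-1]
                # since mmdet v2.0, and BG cat_id is num_class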
|
|
|
|
|
|
|
scores = cls_score.softmax(-1)[:, :-1] |
|
|
|
param_pred = param_pred.permute(1, 2, |
|
0).reshape(-1, self.num_params) |
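            # `filter_scores_and_topk` thresholds the scores before the
            # top-k selection, so it may keep fewer candidates than a
            # plain top-k under the same `nms_pre`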
|
|
|
|
|
|
|
|
|
|
|
|
|
score_thr = cfg.get('score_thr', 0) |
|
|
|
results = filter_scores_and_topk( |
|
scores, score_thr, nms_pre, |
|
dict( |
|
bbox_pred=bbox_pred, |
|
param_pred=param_pred, |
|
points=points, |
|
strides=strides)) |
|
scores, labels, keep_idxs, filtered_results = results |
|
|
|
bbox_pred = filtered_results['bbox_pred'] |
|
param_pred = filtered_results['param_pred'] |
|
points = filtered_results['points'] |
|
strides = filtered_results['strides'] |
|
|
|
if with_score_factors: |
|
score_factor = score_factor[keep_idxs] |
|
|
|
mlvl_bbox_preds.append(bbox_pred) |
|
mlvl_param_preds.append(param_pred) |
|
mlvl_valid_points.append(points) |
|
mlvl_valid_strides.append(strides) |
|
mlvl_scores.append(scores) |
|
mlvl_labels.append(labels) |
|
|
|
if with_score_factors: |
|
mlvl_score_factors.append(score_factor) |
|
|
|
bbox_pred = torch.cat(mlvl_bbox_preds) |
|
priors = cat_boxes(mlvl_valid_points) |
|
bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape) |
|
|
|
results = InstanceData() |
|
results.bboxes = bboxes |
|
results.scores = torch.cat(mlvl_scores) |
|
results.labels = torch.cat(mlvl_labels) |
|
results.param_preds = torch.cat(mlvl_param_preds) |
|
results.points = torch.cat(mlvl_valid_points) |
|
results.strides = torch.cat(mlvl_valid_strides) |
|
if with_score_factors: |
|
results.score_factors = torch.cat(mlvl_score_factors) |
|
|
|
return self._bbox_post_process( |
|
results=results, |
|
cfg=cfg, |
|
rescale=rescale, |
|
with_nms=with_nms, |
|
img_meta=img_meta) |
|
|
|
|
|
class MaskFeatModule(BaseModule): |
|
"""CondInst mask feature map branch used in \ |
|
https://arxiv.org/abs/1904.02689. |
|
|
|
Args: |
|
in_channels (int): Number of channels in the input feature map. |
|
feat_channels (int): Number of hidden channels of the mask feature |
|
map branch. |
|
        start_level (int): The starting feature map level from the FPN
            that will be used to predict the mask feature map.
        end_level (int): The ending feature map level from the FPN that
            will be used to predict the mask feature map.
        out_channels (int): Number of output channels of the mask feature
            map branch. This is the channel count of the mask feature map
            that will be dynamically convolved with the predicted kernels.
|
mask_stride (int): Downsample factor of the mask feature map output. |
|
Defaults to 4. |
|
num_stacked_convs (int): Number of convs in mask feature branch. |
|
        conv_cfg (dict): Config dict for convolution layer.
            Defaults to None.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to None.
|
init_cfg (dict or list[dict], optional): Initialization config dict. |
|
""" |
|
|
|
def __init__(self, |
|
in_channels: int, |
|
feat_channels: int, |
|
start_level: int, |
|
end_level: int, |
|
out_channels: int, |
|
mask_stride: int = 4, |
|
num_stacked_convs: int = 4, |
|
conv_cfg: OptConfigType = None, |
|
norm_cfg: OptConfigType = None, |
|
init_cfg: MultiConfig = [ |
|
dict(type='Normal', layer='Conv2d', std=0.01) |
|
], |
|
**kwargs) -> None: |
|
super().__init__(init_cfg=init_cfg) |
|
self.in_channels = in_channels |
|
self.feat_channels = feat_channels |
|
self.start_level = start_level |
|
self.end_level = end_level |
|
self.mask_stride = mask_stride |
|
self.num_stacked_convs = num_stacked_convs |
|
assert start_level >= 0 and end_level >= start_level |
|
self.out_channels = out_channels |
|
self.conv_cfg = conv_cfg |
|
self.norm_cfg = norm_cfg |
|
self._init_layers() |
|
|
|
def _init_layers(self) -> None: |
|
"""Initialize layers of the head.""" |
|
self.convs_all_levels = nn.ModuleList() |
|
for i in range(self.start_level, self.end_level + 1): |
|
convs_per_level = nn.Sequential() |
|
convs_per_level.add_module( |
|
f'conv{i}', |
|
ConvModule( |
|
self.in_channels, |
|
self.feat_channels, |
|
3, |
|
padding=1, |
|
conv_cfg=self.conv_cfg, |
|
norm_cfg=self.norm_cfg, |
|
inplace=False, |
|
bias=False)) |
|
self.convs_all_levels.append(convs_per_level) |
|
|
|
conv_branch = [] |
|
for _ in range(self.num_stacked_convs): |
|
conv_branch.append( |
|
ConvModule( |
|
self.feat_channels, |
|
self.feat_channels, |
|
3, |
|
padding=1, |
|
conv_cfg=self.conv_cfg, |
|
norm_cfg=self.norm_cfg, |
|
bias=False)) |
|
self.conv_branch = nn.Sequential(*conv_branch) |
|
|
|
self.conv_pred = nn.Conv2d( |
|
self.feat_channels, self.out_channels, 1, stride=1) |
|
|
|
def init_weights(self) -> None: |
|
"""Initialize weights of the head.""" |
|
super().init_weights() |
|
kaiming_init(self.convs_all_levels, a=1, distribution='uniform') |
|
kaiming_init(self.conv_branch, a=1, distribution='uniform') |
|
kaiming_init(self.conv_pred, a=1, distribution='uniform') |
|
|
|
def forward(self, x: Tuple[Tensor]) -> Tensor: |
|
"""Forward features from the upstream network. |
|
|
|
Args: |
|
x (tuple[Tensor]): Features from the upstream network, each is |
|
a 4D-tensor. |
|
|
|
Returns: |
|
Tensor: The predicted mask feature map. |
|
""" |
|
inputs = x[self.start_level:self.end_level + 1] |
|
assert len(inputs) == (self.end_level - self.start_level + 1) |
|
feature_add_all_level = self.convs_all_levels[0](inputs[0]) |
|
target_h, target_w = feature_add_all_level.size()[2:] |
|
for i in range(1, len(inputs)): |
|
input_p = inputs[i] |
|
x_p = self.convs_all_levels[i](input_p) |
|
h, w = x_p.size()[2:] |
|
factor_h = target_h // h |
|
factor_w = target_w // w |
|
assert factor_h == factor_w |
|
feature_per_level = aligned_bilinear(x_p, factor_h) |
|
feature_add_all_level = feature_add_all_level + \ |
|
feature_per_level |
|
|
|
feature_add_all_level = self.conv_branch(feature_add_all_level) |
|
feature_pred = self.conv_pred(feature_add_all_level) |
|
return feature_pred |
|
|
|
|
|
@MODELS.register_module() |
|
class CondInstMaskHead(BaseMaskHead): |
|
"""CondInst mask head used in https://arxiv.org/abs/1904.02689. |
|
|
|
This head outputs the mask for CondInst. |
|
|
|
Args: |
|
mask_feature_head (dict): Config of CondInstMaskFeatHead. |
|
num_layers (int): Number of dynamic conv layers. |
|
feat_channels (int): Number of channels in the dynamic conv. |
|
        mask_out_stride (int): The stride of the output masks.
        size_of_interest (int): The size of the region used in rel coord.
        max_masks_to_train (int): Maximum number of masks to train for
            each image. Defaults to -1 (no limit).
        topk_masks_per_img (int): If not -1, only the top-scoring masks
            per image (distributed evenly over its GT instances) are kept
            for training. Defaults to -1.
        loss_mask (:obj:`ConfigDict` or dict, optional): Config of
            mask loss.
|
train_cfg (:obj:`ConfigDict` or dict, optional): Training config |
|
of head. |
|
test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of |
|
head. |
|
""" |
|
|
|
def __init__(self, |
|
mask_feature_head: ConfigType, |
|
num_layers: int = 3, |
|
feat_channels: int = 8, |
|
mask_out_stride: int = 4, |
|
size_of_interest: int = 8, |
|
max_masks_to_train: int = -1, |
|
topk_masks_per_img: int = -1, |
|
loss_mask: ConfigType = None, |
|
train_cfg: OptConfigType = None, |
|
test_cfg: OptConfigType = None) -> None: |
|
super().__init__() |
|
self.mask_feature_head = MaskFeatModule(**mask_feature_head) |
|
self.mask_feat_stride = self.mask_feature_head.mask_stride |
|
self.in_channels = self.mask_feature_head.out_channels |
|
self.num_layers = num_layers |
|
self.feat_channels = feat_channels |
|
self.size_of_interest = size_of_interest |
|
self.mask_out_stride = mask_out_stride |
|
self.max_masks_to_train = max_masks_to_train |
|
self.topk_masks_per_img = topk_masks_per_img |
|
self.prior_generator = MlvlPointGenerator([self.mask_feat_stride]) |
|
|
|
self.train_cfg = train_cfg |
|
self.test_cfg = test_cfg |
|
self.loss_mask = MODELS.build(loss_mask) |
|
self._init_layers() |
|
|
|
def _init_layers(self) -> None: |
|
"""Initialize layers of the head.""" |
|
weight_nums, bias_nums = [], [] |
|
for i in range(self.num_layers): |
|
if i == 0: |
|
weight_nums.append((self.in_channels + 2) * self.feat_channels) |
|
bias_nums.append(self.feat_channels) |
|
elif i == self.num_layers - 1: |
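                # the last layer predicts a single-channel mask logit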
|
weight_nums.append(self.feat_channels * 1) |
|
bias_nums.append(1) |
|
else: |
|
weight_nums.append(self.feat_channels * self.feat_channels) |
|
bias_nums.append(self.feat_channels) |
|
|
|
self.weight_nums = weight_nums |
|
self.bias_nums = bias_nums |
|
self.num_params = sum(weight_nums) + sum(bias_nums) |
|
|
|
def parse_dynamic_params( |
|
self, params: Tensor) -> Tuple[List[Tensor], List[Tensor]]: |
|
"""parse the dynamic params for dynamic conv.""" |
|
num_insts = params.size(0) |
|
params_splits = list( |
|
torch.split_with_sizes( |
|
params, self.weight_nums + self.bias_nums, dim=1)) |
|
weight_splits = params_splits[:self.num_layers] |
|
bias_splits = params_splits[self.num_layers:] |
|
        for i in range(self.num_layers):
            if i < self.num_layers - 1:
                # intermediate layers output `feat_channels` channels per
                # instance (note: not `in_channels`; the two only coincide
                # in the default config)
                weight_splits[i] = weight_splits[i].reshape(
                    num_insts * self.feat_channels, -1, 1, 1)
                bias_splits[i] = bias_splits[i].reshape(num_insts *
                                                        self.feat_channels)
|
else: |
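                # the last layer has one output channel per instance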
|
|
|
weight_splits[i] = weight_splits[i].reshape( |
|
num_insts * 1, -1, 1, 1) |
|
bias_splits[i] = bias_splits[i].reshape(num_insts) |
|
|
|
return weight_splits, bias_splits |
|
|
|
def dynamic_conv_forward(self, features: Tensor, weights: List[Tensor], |
|
biases: List[Tensor], num_insts: int) -> Tensor: |
|
"""dynamic forward, each layer follow a relu.""" |
|
n_layers = len(weights) |
|
x = features |
|
for i, (w, b) in enumerate(zip(weights, biases)): |
|
x = F.conv2d(x, w, bias=b, stride=1, padding=0, groups=num_insts) |
|
if i < n_layers - 1: |
|
x = F.relu(x) |
|
return x |
|
|
|
def forward(self, x: tuple, positive_infos: InstanceList) -> tuple: |
|
"""Forward feature from the upstream network to get prototypes and |
|
linearly combine the prototypes, using masks coefficients, into |
|
instance masks. Finally, crop the instance masks with given bboxes. |
|
|
|
Args: |
|
x (Tuple[Tensor]): Feature from the upstream network, which is |
|
a 4D-tensor. |
|
positive_infos (List[:obj:``InstanceData``]): Positive information |
|
that calculate from detect head. |
|
|
|
Returns: |
|
tuple: Predicted instance segmentation masks |
|
""" |
|
mask_feats = self.mask_feature_head(x) |
|
return multi_apply(self.forward_single, mask_feats, positive_infos) |
|
|
|
def forward_single(self, mask_feat: Tensor, |
|
positive_info: InstanceData) -> Tensor: |
|
"""Forward features of a each image.""" |
|
pos_param_preds = positive_info.get('param_preds') |
|
pos_points = positive_info.get('points') |
|
pos_strides = positive_info.get('strides') |
|
|
|
num_inst = pos_param_preds.shape[0] |
|
mask_feat = mask_feat[None].repeat(num_inst, 1, 1, 1) |
|
_, _, H, W = mask_feat.size() |
|
if num_inst == 0: |
|
return (pos_param_preds.new_zeros((0, 1, H, W)), ) |
|
|
|
locations = self.prior_generator.single_level_grid_priors( |
|
mask_feat.size()[2:], 0, device=mask_feat.device) |
|
|
|
rel_coords = relative_coordinate_maps(locations, pos_points, |
|
pos_strides, |
|
self.size_of_interest, |
|
mask_feat.size()[2:]) |
|
mask_head_inputs = torch.cat([rel_coords, mask_feat], dim=1) |
|
mask_head_inputs = mask_head_inputs.reshape(1, -1, H, W) |
|
|
|
weights, biases = self.parse_dynamic_params(pos_param_preds) |
|
mask_preds = self.dynamic_conv_forward(mask_head_inputs, weights, |
|
biases, num_inst) |
|
mask_preds = mask_preds.reshape(-1, H, W) |
|
mask_preds = aligned_bilinear( |
|
mask_preds.unsqueeze(0), |
|
int(self.mask_feat_stride / self.mask_out_stride)).squeeze(0) |
|
|
|
return (mask_preds, ) |
|
|
|
def loss_by_feat(self, mask_preds: List[Tensor], |
|
batch_gt_instances: InstanceList, |
|
batch_img_metas: List[dict], positive_infos: InstanceList, |
|
**kwargs) -> dict: |
|
"""Calculate the loss based on the features extracted by the mask head. |
|
|
|
Args: |
|
            mask_preds (list[Tensor]): List of predicted masks, each has
                shape (num_pos, mask_h, mask_w).
|
batch_gt_instances (list[:obj:`InstanceData`]): Batch of |
|
gt_instance. It usually includes ``bboxes``, ``masks``, |
|
and ``labels`` attributes. |
|
batch_img_metas (list[dict]): Meta information of multiple images. |
|
positive_infos (List[:obj:``InstanceData``]): Information of |
|
positive samples of each image that are assigned in detection |
|
head. |
|
|
|
Returns: |
|
dict[str, Tensor]: A dictionary of loss components. |
|
""" |
|
assert positive_infos is not None, \ |
|
'positive_infos should not be None in `CondInstMaskHead`' |
|
losses = dict() |
|
|
|
loss_mask = 0. |
|
num_imgs = len(mask_preds) |
|
total_pos = 0 |
|
|
|
for idx in range(num_imgs): |
|
(mask_pred, pos_mask_targets, num_pos) = \ |
|
self._get_targets_single( |
|
mask_preds[idx], batch_gt_instances[idx], |
|
positive_infos[idx]) |
|
|
|
total_pos += num_pos |
|
if num_pos == 0 or pos_mask_targets is None: |
|
loss = mask_pred.new_zeros(1).mean() |
|
else: |
|
loss = self.loss_mask( |
|
mask_pred, pos_mask_targets, |
|
reduction_override='none').sum() |
|
loss_mask += loss |
|
|
|
if total_pos == 0: |
|
total_pos += 1 |
|
loss_mask = loss_mask / total_pos |
|
losses.update(loss_mask=loss_mask) |
|
return losses |
|
|
|
def _get_targets_single(self, mask_preds: Tensor, |
|
gt_instances: InstanceData, |
|
positive_info: InstanceData): |
|
"""Compute targets for predictions of single image. |
|
|
|
Args: |
|
mask_preds (Tensor): Predicted prototypes with shape |
|
(num_classes, H, W). |
|
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes``, ``labels``,
                and ``masks`` attributes.
            positive_info (:obj:`InstanceData`): Information of positive
                samples that are assigned by the detection head. It
                usually contains the following keys.

            - pos_assigned_gt_inds (Tensor): Assigned GT indexes of
                positive proposals, has shape (num_pos, ).
            - pos_inds (Tensor): Positive indexes of the image, has
                shape (num_pos, ).
            - param_pred (Tensor): Positive param predictions
                with shape (num_pos, num_params).
|
|
|
Returns: |
|
tuple: Usually returns a tuple containing learning targets. |
|
|
|
- mask_preds (Tensor): Positive predicted mask with shape |
|
(num_pos, mask_h, mask_w). |
|
- pos_mask_targets (Tensor): Positive mask targets with shape |
|
(num_pos, mask_h, mask_w). |
|
- num_pos (int): Positive numbers. |
|
""" |
|
gt_bboxes = gt_instances.bboxes |
|
device = gt_bboxes.device |
|
gt_masks = gt_instances.masks.to_tensor( |
|
dtype=torch.bool, device=device).float() |
|
|
|
|
|
pos_assigned_gt_inds = positive_info.get('pos_assigned_gt_inds') |
|
scores = positive_info.get('scores') |
|
centernesses = positive_info.get('centernesses') |
|
num_pos = pos_assigned_gt_inds.size(0) |
|
|
|
if gt_masks.size(0) == 0 or num_pos == 0: |
|
return mask_preds, None, 0 |
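        # producing (near) full-image masks for every positive sample is
        # memory-hungry, so optionally train on only a subset of them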
|
|
|
|
|
|
|
if (self.max_masks_to_train != -1) and \ |
|
(num_pos > self.max_masks_to_train): |
|
perm = torch.randperm(num_pos) |
|
select = perm[:self.max_masks_to_train] |
|
mask_preds = mask_preds[select] |
|
pos_assigned_gt_inds = pos_assigned_gt_inds[select] |
|
num_pos = self.max_masks_to_train |
|
elif self.topk_masks_per_img != -1: |
|
unique_gt_inds = pos_assigned_gt_inds.unique() |
|
num_inst_per_gt = max( |
|
int(self.topk_masks_per_img / len(unique_gt_inds)), 1) |
|
|
|
keep_mask_preds = [] |
|
keep_pos_assigned_gt_inds = [] |
|
for gt_ind in unique_gt_inds: |
|
per_inst_pos_inds = (pos_assigned_gt_inds == gt_ind) |
|
mask_preds_per_inst = mask_preds[per_inst_pos_inds] |
|
gt_inds_per_inst = pos_assigned_gt_inds[per_inst_pos_inds] |
|
if sum(per_inst_pos_inds) > num_inst_per_gt: |
|
per_inst_scores = scores[per_inst_pos_inds].sigmoid().max( |
|
dim=1)[0] |
|
per_inst_centerness = centernesses[ |
|
per_inst_pos_inds].sigmoid().reshape(-1, ) |
|
select = (per_inst_scores * per_inst_centerness).topk( |
|
k=num_inst_per_gt, dim=0)[1] |
|
mask_preds_per_inst = mask_preds_per_inst[select] |
|
gt_inds_per_inst = gt_inds_per_inst[select] |
|
keep_mask_preds.append(mask_preds_per_inst) |
|
keep_pos_assigned_gt_inds.append(gt_inds_per_inst) |
|
mask_preds = torch.cat(keep_mask_preds) |
|
pos_assigned_gt_inds = torch.cat(keep_pos_assigned_gt_inds) |
|
num_pos = pos_assigned_gt_inds.size(0) |
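        # downsample the GT masks to the mask prediction resolution,
        # sampling at the stride centers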
|
|
|
|
|
start = int(self.mask_out_stride // 2) |
|
gt_masks = gt_masks[:, start::self.mask_out_stride, |
|
start::self.mask_out_stride] |
|
gt_masks = gt_masks.gt(0.5).float() |
|
pos_mask_targets = gt_masks[pos_assigned_gt_inds] |
|
|
|
return (mask_preds, pos_mask_targets, num_pos) |
|
|
|
def predict_by_feat(self, |
|
mask_preds: List[Tensor], |
|
results_list: InstanceList, |
|
batch_img_metas: List[dict], |
|
rescale: bool = True, |
|
**kwargs) -> InstanceList: |
|
"""Transform a batch of output features extracted from the head into |
|
mask results. |
|
|
|
Args: |
|
            mask_preds (list[Tensor]): Predicted masks of each image,
                each has shape (num_insts, H, W).
            results_list (List[:obj:`InstanceData`]): BBoxHead results.
            batch_img_metas (list[dict]): Meta information of all images.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to True.
|
|
|
Returns: |
|
            list[:obj:`InstanceData`]: Processed results of multiple
            images. Each :obj:`InstanceData` usually contains the
            following keys.
|
|
|
- scores (Tensor): Classification scores, has shape |
|
(num_instance,). |
|
- labels (Tensor): Has shape (num_instances,). |
|
- masks (Tensor): Processed mask results, has |
|
shape (num_instances, h, w). |
|
""" |
|
assert len(mask_preds) == len(results_list) == len(batch_img_metas) |
|
|
|
for img_id in range(len(batch_img_metas)): |
|
img_meta = batch_img_metas[img_id] |
|
results = results_list[img_id] |
|
bboxes = results.bboxes |
|
mask_pred = mask_preds[img_id] |
|
if bboxes.shape[0] == 0 or mask_pred.shape[0] == 0: |
|
results_list[img_id] = empty_instances( |
|
[img_meta], |
|
bboxes.device, |
|
task_type='mask', |
|
instance_results=[results])[0] |
|
else: |
|
im_mask = self._predict_by_feat_single( |
|
mask_preds=mask_pred, |
|
bboxes=bboxes, |
|
img_meta=img_meta, |
|
rescale=rescale) |
|
results.masks = im_mask |
|
return results_list |
|
|
|
def _predict_by_feat_single(self, |
|
mask_preds: Tensor, |
|
bboxes: Tensor, |
|
img_meta: dict, |
|
rescale: bool, |
|
cfg: OptConfigType = None): |
|
"""Transform a single image's features extracted from the head into |
|
mask results. |
|
|
|
Args: |
|
            mask_preds (Tensor): Predicted masks of a single image,
                has shape (num_insts, H, W).
            bboxes (Tensor): Predicted bboxes of a single image,
                has shape (num_insts, 4).
|
img_meta (dict): Meta information of each image, e.g., |
|
image size, scaling factor, etc. |
|
rescale (bool): If rescale is False, then returned masks will |
|
fit the scale of imgs[0]. |
|
cfg (dict, optional): Config used in test phase. |
|
Defaults to None. |
|
|
|
        Returns:
            Tensor: Binary masks of a single image after sigmoid,
            thresholding and resizing, has shape (num_instances, h, w).
|
""" |
|
        cfg = self.test_cfg if cfg is None else cfg
|
img_h, img_w = img_meta['img_shape'][:2] |
|
ori_h, ori_w = img_meta['ori_shape'][:2] |
|
|
|
mask_preds = mask_preds.sigmoid().unsqueeze(0) |
|
mask_preds = aligned_bilinear(mask_preds, self.mask_out_stride) |
|
mask_preds = mask_preds[:, :, :img_h, :img_w] |
|
if rescale: |
|
scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat( |
|
(1, 2)) |
|
bboxes /= scale_factor |
|
|
|
masks = F.interpolate( |
|
mask_preds, (ori_h, ori_w), |
|
mode='bilinear', |
|
align_corners=False).squeeze(0) > cfg.mask_thr |
|
else: |
|
masks = mask_preds.squeeze(0) > cfg.mask_thr |
|
|
|
return masks |
|
|