|
|
|
from typing import Dict, List, Sequence, Tuple |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
from mmcv.cnn import ConvModule |
|
from mmcv.ops import DeformConv2d |
|
from mmengine.config import ConfigDict |
|
from mmengine.structures import InstanceData |
|
from torch import Tensor |
|
|
|
from mmdet.registry import MODELS, TASK_UTILS |
|
from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptInstanceList |
|
from ..task_modules.prior_generators import MlvlPointGenerator |
|
from ..task_modules.samplers import PseudoSampler |
|
from ..utils import (filter_scores_and_topk, images_to_levels, multi_apply, |
|
unmap) |
|
from .anchor_free_head import AnchorFreeHead |
|
|
|
|
|
@MODELS.register_module() |
|
class RepPointsHead(AnchorFreeHead): |
|
"""RepPoint head. |
|
|
|
Args: |
|
num_classes (int): Number of categories excluding the background |
|
category. |
|
in_channels (int): Number of channels in the input feature map. |
|
point_feat_channels (int): Number of channels of points features. |
|
num_points (int): Number of points. |
|
        gradient_mul (float): The multiplier applied to the gradients from

            points refinement and recognition.

        point_strides (Sequence[int]): Strides of the points at each level.

        point_base_scale (int): Scale of the base bbox used for assigning

            labels.
|
loss_cls (:obj:`ConfigDict` or dict): Config of classification loss. |
|
loss_bbox_init (:obj:`ConfigDict` or dict): Config of initial points |
|
loss. |
|
loss_bbox_refine (:obj:`ConfigDict` or dict): Config of points loss in |
|
refinement. |
|
        use_grid_points (bool): Whether to use the bounding box

            representation, in which the RepPoints are represented as grid

            points on the bounding box.

        center_init (bool): Whether to use center point assignment.

        transform_method (str): The method used to transform RepPoints to a

            bbox.

        moment_mul (float): The multiplier that controls how fast the

            learnable moment transfer is updated when ``transform_method``

            is 'moment'.
|
init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ |
|
dict]): Initialization config dict. |
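
    Example:

        A minimal forward sketch. The constructor arguments and feature

        shapes below are illustrative assumptions for this example, not

        values taken from a released config.

        >>> import torch

        >>> self = RepPointsHead(num_classes=4, in_channels=32)

        >>> feats = [torch.rand(1, 32, s, s) for s in (64, 32, 16, 8, 4)]

        >>> cls_outs, pts_init, pts_refine = self.forward(feats)

        >>> assert len(cls_outs) == len(self.point_strides)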
|
""" |
|
|
|
def __init__(self, |
|
num_classes: int, |
|
in_channels: int, |
|
point_feat_channels: int = 256, |
|
num_points: int = 9, |
|
gradient_mul: float = 0.1, |
|
point_strides: Sequence[int] = [8, 16, 32, 64, 128], |
|
point_base_scale: int = 4, |
|
loss_cls: ConfigType = dict( |
|
type='FocalLoss', |
|
use_sigmoid=True, |
|
gamma=2.0, |
|
alpha=0.25, |
|
loss_weight=1.0), |
|
loss_bbox_init: ConfigType = dict( |
|
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.5), |
|
loss_bbox_refine: ConfigType = dict( |
|
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0), |
|
use_grid_points: bool = False, |
|
center_init: bool = True, |
|
transform_method: str = 'moment', |
|
moment_mul: float = 0.01, |
|
init_cfg: MultiConfig = dict( |
|
type='Normal', |
|
layer='Conv2d', |
|
std=0.01, |
|
override=dict( |
|
type='Normal', |
|
name='reppoints_cls_out', |
|
std=0.01, |
|
bias_prob=0.01)), |
|
**kwargs) -> None: |
|
self.num_points = num_points |
|
self.point_feat_channels = point_feat_channels |
|
self.use_grid_points = use_grid_points |
|
self.center_init = center_init |
|
|
|
|
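        # A deformable conv extracts the point features, so num_points must

        # be an odd square (e.g. 9 -> a 3x3 DCN kernel) that maps onto a

        # square kernel grid.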
|
self.dcn_kernel = int(np.sqrt(num_points)) |
|
self.dcn_pad = int((self.dcn_kernel - 1) / 2) |
|
        assert self.dcn_kernel * self.dcn_kernel == num_points, \

            'The number of points should be a square number.'

        assert self.dcn_kernel % 2 == 1, \

            'The number of points should be an odd square number.'
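
        # For the default num_points=9, dcn_base_offset enumerates the

        # regular 3x3 grid in (y, x) order: (-1, -1), (-1, 0), ..., (1, 1).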
|
dcn_base = np.arange(-self.dcn_pad, |
|
self.dcn_pad + 1).astype(np.float64) |
|
dcn_base_y = np.repeat(dcn_base, self.dcn_kernel) |
|
dcn_base_x = np.tile(dcn_base, self.dcn_kernel) |
|
dcn_base_offset = np.stack([dcn_base_y, dcn_base_x], axis=1).reshape( |
|
(-1)) |
|
self.dcn_base_offset = torch.tensor(dcn_base_offset).view(1, -1, 1, 1) |
|
|
|
super().__init__( |
|
num_classes=num_classes, |
|
in_channels=in_channels, |
|
loss_cls=loss_cls, |
|
init_cfg=init_cfg, |
|
**kwargs) |
|
|
|
self.gradient_mul = gradient_mul |
|
self.point_base_scale = point_base_scale |
|
self.point_strides = point_strides |
|
self.prior_generator = MlvlPointGenerator( |
|
self.point_strides, offset=0.) |
|
|
|
if self.train_cfg: |
|
self.init_assigner = TASK_UTILS.build( |
|
self.train_cfg['init']['assigner']) |
|
self.refine_assigner = TASK_UTILS.build( |
|
self.train_cfg['refine']['assigner']) |
|
|
|
if self.train_cfg.get('sampler', None) is not None: |
|
self.sampler = TASK_UTILS.build( |
|
self.train_cfg['sampler'], default_args=dict(context=self)) |
|
else: |
|
self.sampler = PseudoSampler(context=self) |
|
|
|
self.transform_method = transform_method |
|
if self.transform_method == 'moment': |
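            # The moment transfer is a learnable log-scale for the box width

            # and height, shared across locations; moment_mul damps the

            # gradient flowing into it (see points2bbox).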
|
self.moment_transfer = nn.Parameter( |
|
data=torch.zeros(2), requires_grad=True) |
|
self.moment_mul = moment_mul |
|
|
|
self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False) |
|
if self.use_sigmoid_cls: |
|
self.cls_out_channels = self.num_classes |
|
else: |
|
self.cls_out_channels = self.num_classes + 1 |
|
self.loss_bbox_init = MODELS.build(loss_bbox_init) |
|
self.loss_bbox_refine = MODELS.build(loss_bbox_refine) |
|
|
|
def _init_layers(self) -> None: |
|
"""Initialize layers of the head.""" |
|
self.relu = nn.ReLU(inplace=True) |
|
self.cls_convs = nn.ModuleList() |
|
self.reg_convs = nn.ModuleList() |
|
for i in range(self.stacked_convs): |
|
chn = self.in_channels if i == 0 else self.feat_channels |
|
self.cls_convs.append( |
|
ConvModule( |
|
chn, |
|
self.feat_channels, |
|
3, |
|
stride=1, |
|
padding=1, |
|
conv_cfg=self.conv_cfg, |
|
norm_cfg=self.norm_cfg)) |
|
self.reg_convs.append( |
|
ConvModule( |
|
chn, |
|
self.feat_channels, |
|
3, |
|
stride=1, |
|
padding=1, |
|
conv_cfg=self.conv_cfg, |
|
norm_cfg=self.norm_cfg)) |
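
        # With grid points, regress a 4-d box per location; otherwise

        # regress 2 * num_points (y, x) offsets.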
|
pts_out_dim = 4 if self.use_grid_points else 2 * self.num_points |
|
self.reppoints_cls_conv = DeformConv2d(self.feat_channels, |
|
self.point_feat_channels, |
|
self.dcn_kernel, 1, |
|
self.dcn_pad) |
|
self.reppoints_cls_out = nn.Conv2d(self.point_feat_channels, |
|
self.cls_out_channels, 1, 1, 0) |
|
self.reppoints_pts_init_conv = nn.Conv2d(self.feat_channels, |
|
self.point_feat_channels, 3, |
|
1, 1) |
|
self.reppoints_pts_init_out = nn.Conv2d(self.point_feat_channels, |
|
pts_out_dim, 1, 1, 0) |
|
self.reppoints_pts_refine_conv = DeformConv2d(self.feat_channels, |
|
self.point_feat_channels, |
|
self.dcn_kernel, 1, |
|
self.dcn_pad) |
|
self.reppoints_pts_refine_out = nn.Conv2d(self.point_feat_channels, |
|
pts_out_dim, 1, 1, 0) |
|
|
|
def points2bbox(self, pts: Tensor, y_first: bool = True) -> Tensor: |
|
"""Converting the points set into bounding box. |
|
|
|
Args: |
|
pts (Tensor): the input points sets (fields), each points |
|
set (fields) is represented as 2n scalar. |
|
y_first (bool): if y_first=True, the point set is |
|
represented as [y1, x1, y2, x2 ... yn, xn], otherwise |
|
the point set is represented as |
|
[x1, y1, x2, y2 ... xn, yn]. Defaults to True. |
|
|
|
Returns: |
|
Tensor: each points set is converting to a bbox [x1, y1, x2, y2]. |
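
        Example:

            A shape-only sketch with illustrative constructor arguments:

            nine y-first points per location on a 2x2 map reduce to one

            box per location.

            >>> import torch

            >>> self = RepPointsHead(num_classes=4, in_channels=32)

            >>> pts = torch.rand(1, 18, 2, 2)

            >>> assert self.points2bbox(pts).shape == (1, 4, 2, 2)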
|
""" |
|
pts_reshape = pts.view(pts.shape[0], -1, 2, *pts.shape[2:]) |
|
        if y_first:

            pts_y = pts_reshape[:, :, 0, ...]

            pts_x = pts_reshape[:, :, 1, ...]

        else:

            pts_y = pts_reshape[:, :, 1, ...]

            pts_x = pts_reshape[:, :, 0, ...]
|
if self.transform_method == 'minmax': |
|
bbox_left = pts_x.min(dim=1, keepdim=True)[0] |
|
bbox_right = pts_x.max(dim=1, keepdim=True)[0] |
|
bbox_up = pts_y.min(dim=1, keepdim=True)[0] |
|
bbox_bottom = pts_y.max(dim=1, keepdim=True)[0] |
|
bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom], |
|
dim=1) |
|
elif self.transform_method == 'partial_minmax': |
|
pts_y = pts_y[:, :4, ...] |
|
pts_x = pts_x[:, :4, ...] |
|
bbox_left = pts_x.min(dim=1, keepdim=True)[0] |
|
bbox_right = pts_x.max(dim=1, keepdim=True)[0] |
|
bbox_up = pts_y.min(dim=1, keepdim=True)[0] |
|
bbox_bottom = pts_y.max(dim=1, keepdim=True)[0] |
|
bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom], |
|
dim=1) |
|
elif self.transform_method == 'moment': |
|
pts_y_mean = pts_y.mean(dim=1, keepdim=True) |
|
pts_x_mean = pts_x.mean(dim=1, keepdim=True) |
|
pts_y_std = torch.std(pts_y - pts_y_mean, dim=1, keepdim=True) |
|
pts_x_std = torch.std(pts_x - pts_x_mean, dim=1, keepdim=True) |
|
moment_transfer = (self.moment_transfer * self.moment_mul) + ( |
|
self.moment_transfer.detach() * (1 - self.moment_mul)) |
|
moment_width_transfer = moment_transfer[0] |
|
moment_height_transfer = moment_transfer[1] |
|
half_width = pts_x_std * torch.exp(moment_width_transfer) |
|
half_height = pts_y_std * torch.exp(moment_height_transfer) |
|
bbox = torch.cat([ |
|
pts_x_mean - half_width, pts_y_mean - half_height, |
|
pts_x_mean + half_width, pts_y_mean + half_height |
|
], |
|
dim=1) |
|
else: |
|
raise NotImplementedError |
|
return bbox |
|
|
|
def gen_grid_from_reg(self, reg: Tensor, |
|
previous_boxes: Tensor) -> Tuple[Tensor]: |
|
"""Base on the previous bboxes and regression values, we compute the |
|
regressed bboxes and generate the grids on the bboxes. |
|
|
|
Args: |
|
reg (Tensor): the regression value to previous bboxes. |
|
previous_boxes (Tensor): previous bboxes. |
|
|
|
Returns: |
|
Tuple[Tensor]: generate grids on the regressed bboxes. |
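
        Example:

            A shape-only sketch with illustrative values, assuming the

            default 9-point head (a 3x3 grid per location):

            >>> import torch

            >>> self = RepPointsHead(num_classes=4, in_channels=32)

            >>> reg = torch.zeros(1, 4, 5, 5)

            >>> prev = torch.tensor([-1., -1., 1., 1.]).view(1, 4, 1, 1)

            >>> prev = prev.expand(1, 4, 5, 5)

            >>> grid_yx, boxes = self.gen_grid_from_reg(reg, prev)

            >>> assert grid_yx.shape == (1, 18, 5, 5)

            >>> assert boxes.shape == (1, 4, 5, 5)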
|
""" |
|
b, _, h, w = reg.shape |
|
bxy = (previous_boxes[:, :2, ...] + previous_boxes[:, 2:, ...]) / 2. |
|
bwh = (previous_boxes[:, 2:, ...] - |
|
previous_boxes[:, :2, ...]).clamp(min=1e-6) |
|
grid_topleft = bxy + bwh * reg[:, :2, ...] - 0.5 * bwh * torch.exp( |
|
reg[:, 2:, ...]) |
|
grid_wh = bwh * torch.exp(reg[:, 2:, ...]) |
|
grid_left = grid_topleft[:, [0], ...] |
|
grid_top = grid_topleft[:, [1], ...] |
|
grid_width = grid_wh[:, [0], ...] |
|
grid_height = grid_wh[:, [1], ...] |
|
        interval = torch.linspace(0., 1., self.dcn_kernel).view(

            1, self.dcn_kernel, 1, 1).type_as(reg)

        grid_x = grid_left + grid_width * interval

        grid_x = grid_x.unsqueeze(1).repeat(1, self.dcn_kernel, 1, 1, 1)

        grid_x = grid_x.view(b, -1, h, w)

        grid_y = grid_top + grid_height * interval
|
grid_y = grid_y.unsqueeze(2).repeat(1, 1, self.dcn_kernel, 1, 1) |
|
grid_y = grid_y.view(b, -1, h, w) |
|
grid_yx = torch.stack([grid_y, grid_x], dim=2) |
|
grid_yx = grid_yx.view(b, -1, h, w) |
|
regressed_bbox = torch.cat([ |
|
grid_left, grid_top, grid_left + grid_width, grid_top + grid_height |
|
], 1) |
|
return grid_yx, regressed_bbox |
|
|
|
    def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor]:

        """Forward features from the upstream network.

        Args:

            feats (tuple[Tensor]): Features from the upstream network, each

                being a 4D tensor.

        Returns:

            tuple: Outputs of ``forward_single`` for each FPN level.

        """

        return multi_apply(self.forward_single, feats)
|
|
|
def forward_single(self, x: Tensor) -> Tuple[Tensor]: |
|
"""Forward feature map of a single FPN level.""" |
|
dcn_base_offset = self.dcn_base_offset.type_as(x) |
|
|
|
|
|
|
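        # If center_init is used, the initial RepPoints are zero offsets

        # around each center; with the grid (bbox) representation they form

        # a regular grid placed on a pre-defined base bbox.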
|
if self.use_grid_points or not self.center_init: |
|
scale = self.point_base_scale / 2 |
|
points_init = dcn_base_offset / dcn_base_offset.max() * scale |
|
bbox_init = x.new_tensor([-scale, -scale, scale, |
|
scale]).view(1, 4, 1, 1) |
|
else: |
|
points_init = 0 |
|
cls_feat = x |
|
pts_feat = x |
|
for cls_conv in self.cls_convs: |
|
cls_feat = cls_conv(cls_feat) |
|
for reg_conv in self.reg_convs: |
|
pts_feat = reg_conv(pts_feat) |
|
|
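        # Stage 1: predict the initial point offsets from the regression

        # branch.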
|
pts_out_init = self.reppoints_pts_init_out( |
|
self.relu(self.reppoints_pts_init_conv(pts_feat))) |
|
if self.use_grid_points: |
|
pts_out_init, bbox_out_init = self.gen_grid_from_reg( |
|
pts_out_init, bbox_init.detach()) |
|
else: |
|
pts_out_init = pts_out_init + points_init |
|
|
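        # Stage 2: classify and refine the RepPoints; gradient_mul damps the

        # gradient that flows back into the initial points through the DCN

        # offsets.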
|
        pts_out_init_grad_mul = (

            (1 - self.gradient_mul) * pts_out_init.detach() +

            self.gradient_mul * pts_out_init)
|
dcn_offset = pts_out_init_grad_mul - dcn_base_offset |
|
cls_out = self.reppoints_cls_out( |
|
self.relu(self.reppoints_cls_conv(cls_feat, dcn_offset))) |
|
pts_out_refine = self.reppoints_pts_refine_out( |
|
self.relu(self.reppoints_pts_refine_conv(pts_feat, dcn_offset))) |
|
if self.use_grid_points: |
|
pts_out_refine, bbox_out_refine = self.gen_grid_from_reg( |
|
pts_out_refine, bbox_out_init.detach()) |
|
else: |
|
pts_out_refine = pts_out_refine + pts_out_init.detach() |
|
|
|
if self.training: |
|
return cls_out, pts_out_init, pts_out_refine |
|
else: |
|
return cls_out, self.points2bbox(pts_out_refine) |
|
|
|
def get_points(self, featmap_sizes: List[Tuple[int]], |
|
batch_img_metas: List[dict], device: str) -> tuple: |
|
"""Get points according to feature map sizes. |
|
|
|
Args: |
|
            featmap_sizes (list[tuple]): Multi-level feature map sizes.

            batch_img_metas (list[dict]): Image meta info.

            device (str): Device on which the points are created.

        Returns:

            tuple: Points of each image, valid flags of each image.
|
""" |
|
num_imgs = len(batch_img_metas) |
|
|
|
|
|
|
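        # Since all images in a batch share the same feature map sizes, the

        # multi-level point centers are computed once and cloned per image.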
|
multi_level_points = self.prior_generator.grid_priors( |
|
featmap_sizes, device=device, with_stride=True) |
|
points_list = [[point.clone() for point in multi_level_points] |
|
for _ in range(num_imgs)] |
|
|
|
|
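        # For each image, compute the valid flags of the multi-level grids.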
|
valid_flag_list = [] |
|
for img_id, img_meta in enumerate(batch_img_metas): |
|
multi_level_flags = self.prior_generator.valid_flags( |
|
featmap_sizes, img_meta['pad_shape'], device=device) |
|
valid_flag_list.append(multi_level_flags) |
|
|
|
return points_list, valid_flag_list |
|
|
|
def centers_to_bboxes(self, point_list: List[Tensor]) -> List[Tensor]: |
|
"""Get bboxes according to center points. |
|
|
|
Only used in :class:`MaxIoUAssigner`. |
|
""" |
|
bbox_list = [] |
|
for i_img, point in enumerate(point_list): |
|
bbox = [] |
|
for i_lvl in range(len(self.point_strides)): |
|
scale = self.point_base_scale * self.point_strides[i_lvl] * 0.5 |
|
bbox_shift = torch.Tensor([-scale, -scale, scale, |
|
scale]).view(1, 4).type_as(point[0]) |
|
bbox_center = torch.cat( |
|
[point[i_lvl][:, :2], point[i_lvl][:, :2]], dim=1) |
|
bbox.append(bbox_center + bbox_shift) |
|
bbox_list.append(bbox) |
|
return bbox_list |
|
|
|
def offset_to_pts(self, center_list: List[Tensor], |
|
pred_list: List[Tensor]) -> List[Tensor]: |
|
"""Change from point offset to point coordinate.""" |
|
pts_list = [] |
|
for i_lvl in range(len(self.point_strides)): |
|
pts_lvl = [] |
|
for i_img in range(len(center_list)): |
|
pts_center = center_list[i_img][i_lvl][:, :2].repeat( |
|
1, self.num_points) |
|
pts_shift = pred_list[i_lvl][i_img] |
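                # Predicted offsets are y-first; reorder them into (x, y)

                # pairs before scaling by the point stride.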
|
yx_pts_shift = pts_shift.permute(1, 2, 0).view( |
|
-1, 2 * self.num_points) |
|
y_pts_shift = yx_pts_shift[..., 0::2] |
|
x_pts_shift = yx_pts_shift[..., 1::2] |
|
xy_pts_shift = torch.stack([x_pts_shift, y_pts_shift], -1) |
|
xy_pts_shift = xy_pts_shift.view(*yx_pts_shift.shape[:-1], -1) |
|
pts = xy_pts_shift * self.point_strides[i_lvl] + pts_center |
|
pts_lvl.append(pts) |
|
pts_lvl = torch.stack(pts_lvl, 0) |
|
pts_list.append(pts_lvl) |
|
return pts_list |
|
|
|
def _get_targets_single(self, |
|
flat_proposals: Tensor, |
|
valid_flags: Tensor, |
|
gt_instances: InstanceData, |
|
gt_instances_ignore: InstanceData, |
|
stage: str = 'init', |
|
unmap_outputs: bool = True) -> tuple: |
|
"""Compute corresponding GT box and classification targets for |
|
proposals. |
|
|
|
Args: |
|
            flat_proposals (Tensor): Multi level points of an image.

            valid_flags (Tensor): Multi level valid flags of an image.
|
gt_instances (InstanceData): It usually includes ``bboxes`` and |
|
``labels`` attributes. |
|
gt_instances_ignore (InstanceData): It includes ``bboxes`` |
|
attribute data that is ignored during training and testing. |
|
stage (str): 'init' or 'refine'. Generate target for |
|
init stage or refine stage. Defaults to 'init'. |
|
            unmap_outputs (bool): Whether to map outputs back to the

                original set of proposals. Defaults to True.
|
|
|
Returns: |
|
tuple: |
|
|
|
- labels (Tensor): Labels of each level. |
|
- label_weights (Tensor): Label weights of each level. |
|
- bbox_targets (Tensor): BBox targets of each level. |
|
- bbox_weights (Tensor): BBox weights of each level. |
|
            - pos_inds (Tensor): Indices of positive samples.

            - neg_inds (Tensor): Indices of negative samples.
|
- sampling_result (:obj:`SamplingResult`): Sampling results. |
|
""" |
|
inside_flags = valid_flags |
|
if not inside_flags.any(): |
|
raise ValueError( |
|
'There is no valid proposal inside the image boundary. Please ' |
|
'check the image size.') |
|
|
|
proposals = flat_proposals[inside_flags, :] |
|
pred_instances = InstanceData(priors=proposals) |
|
|
|
if stage == 'init': |
|
assigner = self.init_assigner |
|
pos_weight = self.train_cfg['init']['pos_weight'] |
|
else: |
|
assigner = self.refine_assigner |
|
pos_weight = self.train_cfg['refine']['pos_weight'] |
|
|
|
assign_result = assigner.assign(pred_instances, gt_instances, |
|
gt_instances_ignore) |
|
sampling_result = self.sampler.sample(assign_result, pred_instances, |
|
gt_instances) |
|
|
|
num_valid_proposals = proposals.shape[0] |
|
bbox_gt = proposals.new_zeros([num_valid_proposals, 4]) |
|
pos_proposals = torch.zeros_like(proposals) |
|
proposals_weights = proposals.new_zeros([num_valid_proposals, 4]) |
|
labels = proposals.new_full((num_valid_proposals, ), |
|
self.num_classes, |
|
dtype=torch.long) |
|
label_weights = proposals.new_zeros( |
|
num_valid_proposals, dtype=torch.float) |
|
|
|
pos_inds = sampling_result.pos_inds |
|
neg_inds = sampling_result.neg_inds |
|
if len(pos_inds) > 0: |
|
bbox_gt[pos_inds, :] = sampling_result.pos_gt_bboxes |
|
pos_proposals[pos_inds, :] = proposals[pos_inds, :] |
|
proposals_weights[pos_inds, :] = 1.0 |
|
|
|
labels[pos_inds] = sampling_result.pos_gt_labels |
|
if pos_weight <= 0: |
|
label_weights[pos_inds] = 1.0 |
|
else: |
|
label_weights[pos_inds] = pos_weight |
|
if len(neg_inds) > 0: |
|
label_weights[neg_inds] = 1.0 |
|
|
|
|
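        # Map up to the original set of proposals.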
|
if unmap_outputs: |
|
num_total_proposals = flat_proposals.size(0) |
|
labels = unmap( |
|
labels, |
|
num_total_proposals, |
|
inside_flags, |
|
fill=self.num_classes) |
|
label_weights = unmap(label_weights, num_total_proposals, |
|
inside_flags) |
|
bbox_gt = unmap(bbox_gt, num_total_proposals, inside_flags) |
|
pos_proposals = unmap(pos_proposals, num_total_proposals, |
|
inside_flags) |
|
proposals_weights = unmap(proposals_weights, num_total_proposals, |
|
inside_flags) |
|
|
|
return (labels, label_weights, bbox_gt, pos_proposals, |
|
proposals_weights, pos_inds, neg_inds, sampling_result) |
|
|
|
def get_targets(self, |
|
proposals_list: List[Tensor], |
|
valid_flag_list: List[Tensor], |
|
batch_gt_instances: InstanceList, |
|
batch_img_metas: List[dict], |
|
batch_gt_instances_ignore: OptInstanceList = None, |
|
stage: str = 'init', |
|
unmap_outputs: bool = True, |
|
return_sampling_results: bool = False) -> tuple: |
|
"""Compute corresponding GT box and classification targets for |
|
proposals. |
|
|
|
Args: |
|
proposals_list (list[Tensor]): Multi level points/bboxes of each |
|
image. |
|
valid_flag_list (list[Tensor]): Multi level valid flags of each |
|
image. |
|
batch_gt_instances (list[:obj:`InstanceData`]): Batch of |
|
gt_instance. It usually includes ``bboxes`` and ``labels`` |
|
attributes. |
|
batch_img_metas (list[dict]): Meta information of each image, e.g., |
|
image size, scaling factor, etc. |
|
batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): |
|
Batch of gt_instances_ignore. It includes ``bboxes`` attribute |
|
data that is ignored during training and testing. |
|
Defaults to None. |
|
            stage (str): 'init' or 'refine'. Generate targets for the init

                stage or the refine stage. Defaults to 'init'.

            unmap_outputs (bool): Whether to map outputs back to the

                original set of proposals. Defaults to True.
|
return_sampling_results (bool): Whether to return the sampling |
|
results. Defaults to False. |
|
|
|
Returns: |
|
tuple: |
|
|
|
- labels_list (list[Tensor]): Labels of each level. |
|
- label_weights_list (list[Tensor]): Label weights of each |
|
level. |
|
- bbox_gt_list (list[Tensor]): Ground truth bbox of each level. |
|
- proposals_list (list[Tensor]): Proposals(points/bboxes) of |
|
each level. |
|
- proposal_weights_list (list[Tensor]): Proposal weights of |
|
each level. |
|
- avg_factor (int): Average factor that is used to average |
|
the loss. When using sampling method, avg_factor is usually |
|
the sum of positive and negative priors. When using |
|
`PseudoSampler`, `avg_factor` is usually equal to the number |
|
of positive priors. |
|
""" |
|
assert stage in ['init', 'refine'] |
|
num_imgs = len(batch_img_metas) |
|
assert len(proposals_list) == len(valid_flag_list) == num_imgs |
|
|
|
|
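        # Number of proposals at each level (identical across images).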
|
num_level_proposals = [points.size(0) for points in proposals_list[0]] |
|
|
|
|
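        # Concat all level points and flags of each image into single

        # tensors.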
|
for i in range(num_imgs): |
|
assert len(proposals_list[i]) == len(valid_flag_list[i]) |
|
proposals_list[i] = torch.cat(proposals_list[i]) |
|
valid_flag_list[i] = torch.cat(valid_flag_list[i]) |
|
|
|
if batch_gt_instances_ignore is None: |
|
batch_gt_instances_ignore = [None] * num_imgs |
|
|
|
(all_labels, all_label_weights, all_bbox_gt, all_proposals, |
|
all_proposal_weights, pos_inds_list, neg_inds_list, |
|
sampling_results_list) = multi_apply( |
|
self._get_targets_single, |
|
proposals_list, |
|
valid_flag_list, |
|
batch_gt_instances, |
|
batch_gt_instances_ignore, |
|
stage=stage, |
|
unmap_outputs=unmap_outputs) |
|
|
|
|
|
        avg_factor = sum(

            [results.avg_factor for results in sampling_results_list])
|
labels_list = images_to_levels(all_labels, num_level_proposals) |
|
label_weights_list = images_to_levels(all_label_weights, |
|
num_level_proposals) |
|
bbox_gt_list = images_to_levels(all_bbox_gt, num_level_proposals) |
|
proposals_list = images_to_levels(all_proposals, num_level_proposals) |
|
proposal_weights_list = images_to_levels(all_proposal_weights, |
|
num_level_proposals) |
|
        res = (labels_list, label_weights_list, bbox_gt_list, proposals_list,

               proposal_weights_list, avg_factor)
|
if return_sampling_results: |
|
res = res + (sampling_results_list, ) |
|
|
|
return res |
|
|
|
def loss_by_feat_single(self, cls_score: Tensor, pts_pred_init: Tensor, |
|
pts_pred_refine: Tensor, labels: Tensor, |
|
                            label_weights: Tensor, bbox_gt_init: Tensor,
|
bbox_weights_init: Tensor, bbox_gt_refine: Tensor, |
|
bbox_weights_refine: Tensor, stride: int, |
|
avg_factor_init: int, |
|
avg_factor_refine: int) -> Tuple[Tensor]: |
|
"""Calculate the loss of a single scale level based on the features |
|
extracted by the detection head. |
|
|
|
Args: |
|
cls_score (Tensor): Box scores for each scale level |
|
Has shape (N, num_classes, h_i, w_i). |
|
pts_pred_init (Tensor): Points of shape |
|
(batch_size, h_i * w_i, num_points * 2). |
|
pts_pred_refine (Tensor): Points refined of shape |
|
(batch_size, h_i * w_i, num_points * 2). |
|
labels (Tensor): Ground truth class indices with shape |
|
(batch_size, h_i * w_i). |
|
label_weights (Tensor): Label weights of shape |
|
(batch_size, h_i * w_i). |
|
bbox_gt_init (Tensor): BBox regression targets in the init stage |
|
of shape (batch_size, h_i * w_i, 4). |
|
bbox_weights_init (Tensor): BBox regression loss weights in the |
|
init stage of shape (batch_size, h_i * w_i, 4). |
|
bbox_gt_refine (Tensor): BBox regression targets in the refine |
|
stage of shape (batch_size, h_i * w_i, 4). |
|
bbox_weights_refine (Tensor): BBox regression loss weights in the |
|
refine stage of shape (batch_size, h_i * w_i, 4). |
|
stride (int): Point stride. |
|
avg_factor_init (int): Average factor that is used to average |
|
the loss in the init stage. |
|
avg_factor_refine (int): Average factor that is used to average |
|
the loss in the refine stage. |
|
|
|
Returns: |
|
Tuple[Tensor]: loss components. |
|
""" |
|
|
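        # Classification loss.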
|
labels = labels.reshape(-1) |
|
label_weights = label_weights.reshape(-1) |
|
cls_score = cls_score.permute(0, 2, 3, |
|
1).reshape(-1, self.cls_out_channels) |
|
cls_score = cls_score.contiguous() |
|
loss_cls = self.loss_cls( |
|
cls_score, labels, label_weights, avg_factor=avg_factor_refine) |
|
|
|
|
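        # Points (bbox) losses for the init and refine stages; predictions

        # and targets are normalized by the base scale at this stride.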
|
bbox_gt_init = bbox_gt_init.reshape(-1, 4) |
|
bbox_weights_init = bbox_weights_init.reshape(-1, 4) |
|
bbox_pred_init = self.points2bbox( |
|
pts_pred_init.reshape(-1, 2 * self.num_points), y_first=False) |
|
bbox_gt_refine = bbox_gt_refine.reshape(-1, 4) |
|
bbox_weights_refine = bbox_weights_refine.reshape(-1, 4) |
|
bbox_pred_refine = self.points2bbox( |
|
pts_pred_refine.reshape(-1, 2 * self.num_points), y_first=False) |
|
normalize_term = self.point_base_scale * stride |
|
loss_pts_init = self.loss_bbox_init( |
|
bbox_pred_init / normalize_term, |
|
bbox_gt_init / normalize_term, |
|
bbox_weights_init, |
|
avg_factor=avg_factor_init) |
|
loss_pts_refine = self.loss_bbox_refine( |
|
bbox_pred_refine / normalize_term, |
|
bbox_gt_refine / normalize_term, |
|
bbox_weights_refine, |
|
avg_factor=avg_factor_refine) |
|
return loss_cls, loss_pts_init, loss_pts_refine |
|
|
|
def loss_by_feat( |
|
self, |
|
cls_scores: List[Tensor], |
|
pts_preds_init: List[Tensor], |
|
pts_preds_refine: List[Tensor], |
|
batch_gt_instances: InstanceList, |
|
batch_img_metas: List[dict], |
|
batch_gt_instances_ignore: OptInstanceList = None |
|
) -> Dict[str, Tensor]: |
|
"""Calculate the loss based on the features extracted by the detection |
|
head. |
|
|
|
Args: |
|
cls_scores (list[Tensor]): Box scores for each scale level, |
|
each is a 4D-tensor, of shape (batch_size, num_classes, h, w). |
|
pts_preds_init (list[Tensor]): Points for each scale level, each is |
|
a 3D-tensor, of shape (batch_size, h_i * w_i, num_points * 2). |
|
pts_preds_refine (list[Tensor]): Points refined for each scale |
|
level, each is a 3D-tensor, of shape |
|
(batch_size, h_i * w_i, num_points * 2). |
|
batch_gt_instances (list[:obj:`InstanceData`]): Batch of |
|
gt_instance. It usually includes ``bboxes`` and ``labels`` |
|
attributes. |
|
batch_img_metas (list[dict]): Meta information of each image, e.g., |
|
image size, scaling factor, etc. |
|
batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): |
|
Batch of gt_instances_ignore. It includes ``bboxes`` attribute |
|
data that is ignored during training and testing. |
|
Defaults to None. |
|
|
|
Returns: |
|
dict[str, Tensor]: A dictionary of loss components. |
|
""" |
|
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] |
|
device = cls_scores[0].device |
|
|
|
|
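        # Compute targets for the init stage.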
|
center_list, valid_flag_list = self.get_points(featmap_sizes, |
|
batch_img_metas, device) |
|
pts_coordinate_preds_init = self.offset_to_pts(center_list, |
|
pts_preds_init) |
|
if self.train_cfg['init']['assigner']['type'] == 'PointAssigner': |
|
|
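            # Assign targets directly on the point centers.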
|
candidate_list = center_list |
|
else: |
|
|
|
|
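            # Transform the center list into a bbox list and assign targets

            # on the bboxes.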
|
bbox_list = self.centers_to_bboxes(center_list) |
|
candidate_list = bbox_list |
|
cls_reg_targets_init = self.get_targets( |
|
proposals_list=candidate_list, |
|
valid_flag_list=valid_flag_list, |
|
batch_gt_instances=batch_gt_instances, |
|
batch_img_metas=batch_img_metas, |
|
batch_gt_instances_ignore=batch_gt_instances_ignore, |
|
stage='init', |
|
return_sampling_results=False) |
|
(*_, bbox_gt_list_init, candidate_list_init, bbox_weights_list_init, |
|
avg_factor_init) = cls_reg_targets_init |
|
|
|
|
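        # Compute targets for the refine stage; the proposals are bboxes

        # decoded from the init-stage points.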
|
center_list, valid_flag_list = self.get_points(featmap_sizes, |
|
batch_img_metas, device) |
|
pts_coordinate_preds_refine = self.offset_to_pts( |
|
center_list, pts_preds_refine) |
|
bbox_list = [] |
|
for i_img, center in enumerate(center_list): |
|
bbox = [] |
|
for i_lvl in range(len(pts_preds_refine)): |
|
bbox_preds_init = self.points2bbox( |
|
pts_preds_init[i_lvl].detach()) |
|
bbox_shift = bbox_preds_init * self.point_strides[i_lvl] |
|
bbox_center = torch.cat( |
|
[center[i_lvl][:, :2], center[i_lvl][:, :2]], dim=1) |
|
bbox.append(bbox_center + |
|
bbox_shift[i_img].permute(1, 2, 0).reshape(-1, 4)) |
|
bbox_list.append(bbox) |
|
cls_reg_targets_refine = self.get_targets( |
|
proposals_list=bbox_list, |
|
valid_flag_list=valid_flag_list, |
|
batch_gt_instances=batch_gt_instances, |
|
batch_img_metas=batch_img_metas, |
|
batch_gt_instances_ignore=batch_gt_instances_ignore, |
|
stage='refine', |
|
return_sampling_results=False) |
|
(labels_list, label_weights_list, bbox_gt_list_refine, |
|
candidate_list_refine, bbox_weights_list_refine, |
|
avg_factor_refine) = cls_reg_targets_refine |
|
|
|
|
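        # Compute the losses of all scale levels.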
|
losses_cls, losses_pts_init, losses_pts_refine = multi_apply( |
|
self.loss_by_feat_single, |
|
cls_scores, |
|
pts_coordinate_preds_init, |
|
pts_coordinate_preds_refine, |
|
labels_list, |
|
label_weights_list, |
|
bbox_gt_list_init, |
|
bbox_weights_list_init, |
|
bbox_gt_list_refine, |
|
bbox_weights_list_refine, |
|
self.point_strides, |
|
avg_factor_init=avg_factor_init, |
|
avg_factor_refine=avg_factor_refine) |
|
loss_dict_all = { |
|
'loss_cls': losses_cls, |
|
'loss_pts_init': losses_pts_init, |
|
'loss_pts_refine': losses_pts_refine |
|
} |
|
return loss_dict_all |
|
|
|
|
|
def _predict_by_feat_single(self, |
|
cls_score_list: List[Tensor], |
|
bbox_pred_list: List[Tensor], |
|
score_factor_list: List[Tensor], |
|
mlvl_priors: List[Tensor], |
|
img_meta: dict, |
|
cfg: ConfigDict, |
|
rescale: bool = False, |
|
with_nms: bool = True) -> InstanceData: |
|
"""Transform outputs of a single image into bbox predictions. |
|
|
|
Args: |
|
cls_score_list (list[Tensor]): Box scores from all scale |
|
levels of a single image, each item has shape |
|
(num_priors * num_classes, H, W). |
|
bbox_pred_list (list[Tensor]): Box energies / deltas from |
|
all scale levels of a single image, each item has shape |
|
(num_priors * 4, H, W). |
|
score_factor_list (list[Tensor]): Score factor from all scale |
|
levels of a single image. RepPoints head does not need |
|
this value. |
|
mlvl_priors (list[Tensor]): Each element in the list is |
|
the priors of a single level in feature pyramid, has shape |
|
(num_priors, 2). |
|
img_meta (dict): Image meta info. |
|
cfg (:obj:`ConfigDict`): Test / postprocessing configuration, |
|
if None, test_cfg would be used. |
|
rescale (bool): If True, return boxes in original image space. |
|
Defaults to False. |
|
with_nms (bool): If True, do nms before return boxes. |
|
Defaults to True. |
|
|
|
Returns: |
|
:obj:`InstanceData`: Detection results of each image |
|
after the post process. |
|
Each item usually contains following keys. |
|
|
|
- scores (Tensor): Classification scores, has a shape |
|
(num_instance, ) |
|
- labels (Tensor): Labels of bboxes, has a shape |
|
(num_instances, ). |
|
- bboxes (Tensor): Has a shape (num_instances, 4), |
|
the last dimension 4 arrange as (x1, y1, x2, y2). |
|
""" |
|
cfg = self.test_cfg if cfg is None else cfg |
|
assert len(cls_score_list) == len(bbox_pred_list) |
|
img_shape = img_meta['img_shape'] |
|
nms_pre = cfg.get('nms_pre', -1) |
|
|
|
mlvl_bboxes = [] |
|
mlvl_scores = [] |
|
mlvl_labels = [] |
|
for level_idx, (cls_score, bbox_pred, priors) in enumerate( |
|
zip(cls_score_list, bbox_pred_list, mlvl_priors)): |
|
assert cls_score.size()[-2:] == bbox_pred.size()[-2:] |
|
bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) |
|
|
|
cls_score = cls_score.permute(1, 2, |
|
0).reshape(-1, self.cls_out_channels) |
|
if self.use_sigmoid_cls: |
|
scores = cls_score.sigmoid() |
|
else: |
|
scores = cls_score.softmax(-1)[:, :-1] |
|
|
|
|
|
|
|
|
|
|
|
|
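            # Filter out low-score predictions and keep at most `nms_pre`

            # top-scoring candidates per level before decoding and NMS.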
|
results = filter_scores_and_topk( |
|
scores, cfg.score_thr, nms_pre, |
|
dict(bbox_pred=bbox_pred, priors=priors)) |
|
scores, labels, _, filtered_results = results |
|
|
|
bbox_pred = filtered_results['bbox_pred'] |
|
priors = filtered_results['priors'] |
|
|
|
bboxes = self._bbox_decode(priors, bbox_pred, |
|
self.point_strides[level_idx], |
|
img_shape) |
|
|
|
mlvl_bboxes.append(bboxes) |
|
mlvl_scores.append(scores) |
|
mlvl_labels.append(labels) |
|
|
|
results = InstanceData() |
|
results.bboxes = torch.cat(mlvl_bboxes) |
|
results.scores = torch.cat(mlvl_scores) |
|
results.labels = torch.cat(mlvl_labels) |
|
|
|
return self._bbox_post_process( |
|
results=results, |
|
cfg=cfg, |
|
rescale=rescale, |
|
with_nms=with_nms, |
|
img_meta=img_meta) |
|
|
|
def _bbox_decode(self, points: Tensor, bbox_pred: Tensor, stride: int, |
|
max_shape: Tuple[int, int]) -> Tensor: |
|
"""Decode the prediction to bounding box. |
|
|
|
Args: |
|
points (Tensor): shape (h_i * w_i, 2). |
|
bbox_pred (Tensor): shape (h_i * w_i, 4). |
|
            stride (int): Stride of ``bbox_pred`` at this level.

            max_shape (Tuple[int, int]): The (h, w) shape of the image, used

                to clip the decoded boxes.
|
|
|
Returns: |
|
Tensor: Bounding boxes decoded. |
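
        Example:

            A minimal sketch with illustrative values: offsets are scaled

            by the stride, shifted to the point centers, and clipped to

            the image shape.

            >>> import torch

            >>> self = RepPointsHead(num_classes=4, in_channels=32)

            >>> points = torch.zeros(3, 2)

            >>> bbox_pred = torch.tensor([[-1., -1., 1., 1.]]).repeat(3, 1)

            >>> boxes = self._bbox_decode(points, bbox_pred, 8, (32, 32))

            >>> assert (boxes == torch.tensor([0., 0., 8., 8.])).all()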
|
""" |
|
bbox_pos_center = torch.cat([points[:, :2], points[:, :2]], dim=1) |
|
bboxes = bbox_pred * stride + bbox_pos_center |
|
x1 = bboxes[:, 0].clamp(min=0, max=max_shape[1]) |
|
y1 = bboxes[:, 1].clamp(min=0, max=max_shape[0]) |
|
x2 = bboxes[:, 2].clamp(min=0, max=max_shape[1]) |
|
y2 = bboxes[:, 3].clamp(min=0, max=max_shape[0]) |
|
decoded_bboxes = torch.stack([x1, y1, x2, y2], dim=-1) |
|
return decoded_bboxes |
|
|