|
|
|
|
|
from typing import List, Tuple |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from mmcv.ops import point_sample, rel_roi_point_to_rel_img_point |
|
from torch import Tensor |
|
|
|
from mmdet.registry import MODELS |
|
from mmdet.structures.bbox import bbox2roi |
|
from mmdet.utils import ConfigType, InstanceList |
|
from ..task_modules.samplers import SamplingResult |
|
from ..utils import empty_instances |
|
from .standard_roi_head import StandardRoIHead |
|
|
|
|
|
@MODELS.register_module()
class PointRendRoIHead(StandardRoIHead):
    """`PointRend <https://arxiv.org/abs/1912.08193>`_.

    Extends :class:`StandardRoIHead` with a point head that refines mask
    predictions at selected points. At each point it combines fine-grained
    features sampled from the feature maps with coarse features sampled
    from the mask predictions.

    Args:
        point_head (ConfigType): Config used to build the point head with
            ``MODELS.build``.
        *args, **kwargs: Forwarded to :class:`StandardRoIHead`. Both the
            bbox branch and the mask branch must be configured.
    """

    def __init__(self, point_head: ConfigType, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Point refinement operates on masks predicted for detected boxes,
        # so both the bbox and the mask branches are required.
        assert self.with_bbox and self.with_mask
        self.init_point_head(point_head)

    def init_point_head(self, point_head: ConfigType) -> None:
        """Initialize ``point_head``."""
        self.point_head = MODELS.build(point_head)

    def mask_loss(self, x: Tuple[Tensor],
                  sampling_results: List[SamplingResult], bbox_feats: Tensor,
                  batch_gt_instances: InstanceList) -> dict:
        """Run forward function and calculate loss for mask head and point head
        in training.

        Args:
            x (tuple[Tensor]): Feature maps of all scale level.
            sampling_results (list[:obj:`SamplingResult`]): Sampling results
                of each image.
            bbox_feats (Tensor): RoI features, forwarded to the parent
                ``mask_loss``.
            batch_gt_instances (list[:obj:`InstanceData`]): Ground truth
                instances of each image.

        Returns:
            dict: The result of the parent ``mask_loss``, with the point
            head's ``loss_point`` added to its ``loss_mask`` dict.
        """
        mask_results = super().mask_loss(
            x=x,
            sampling_results=sampling_results,
            bbox_feats=bbox_feats,
            batch_gt_instances=batch_gt_instances)

        mask_point_results = self._mask_point_loss(
            x=x,
            sampling_results=sampling_results,
            mask_preds=mask_results['mask_preds'],
            batch_gt_instances=batch_gt_instances)
        mask_results['loss_mask'].update(
            loss_point=mask_point_results['loss_point'])

        return mask_results

    def _mask_point_loss(self, x: Tuple[Tensor],
                         sampling_results: List[SamplingResult],
                         mask_preds: Tensor,
                         batch_gt_instances: InstanceList) -> dict:
        """Run forward function and calculate loss for point head in
        training.

        Args:
            x (tuple[Tensor]): Feature maps of all scale level.
            sampling_results (list[:obj:`SamplingResult`]): Sampling results
                of each image; only positive samples are used.
            mask_preds (Tensor): Coarse mask predictions of the positive
                RoIs.
            batch_gt_instances (list[:obj:`InstanceData`]): Ground truth
                instances of each image.

        Returns:
            dict: Output of ``self.point_head.loss_and_target``; the caller
            reads its ``loss_point`` entry.
        """
        pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
        # Points are selected by the point head from the coarse predictions
        # (training-time sampling strategy comes from ``self.train_cfg``).
        rel_roi_points = self.point_head.get_roi_rel_points_train(
            mask_preds, pos_labels, cfg=self.train_cfg)
        rois = bbox2roi([res.pos_bboxes for res in sampling_results])

        fine_grained_point_feats = self._get_fine_grained_point_feats(
            x, rois, rel_roi_points)
        coarse_point_feats = point_sample(mask_preds, rel_roi_points)
        mask_point_pred = self.point_head(fine_grained_point_feats,
                                          coarse_point_feats)

        loss_and_target = self.point_head.loss_and_target(
            point_pred=mask_point_pred,
            rel_roi_points=rel_roi_points,
            sampling_results=sampling_results,
            batch_gt_instances=batch_gt_instances,
            cfg=self.train_cfg)

        return loss_and_target

    def _mask_point_forward_test(self, x: Tuple[Tensor], rois: Tensor,
                                 label_preds: Tensor,
                                 mask_preds: Tensor) -> Tensor:
        """Mask refining process with point head in testing.

        The coarse masks are repeatedly upsampled by
        ``test_cfg.scale_factor`` (``test_cfg.subdivision_steps`` times).
        After each upsampling, the point head re-predicts the values at the
        points chosen by ``point_head.get_roi_rel_points_test``, and those
        values are written back into the upsampled mask.

        Args:
            x (tuple[Tensor]): Feature maps of all scale level.
            rois (Tensor): shape (num_rois, 5). Must be non-empty, because
                ``_get_fine_grained_point_feats`` asserts on empty RoIs.
            label_preds (Tensor): The predicted class for each RoI.
            mask_preds (Tensor): The predicted coarse masks of
                shape (num_rois, num_classes, small_size, small_size).

        Returns:
            Tensor: The refined masks of shape (num_rois, num_classes,
            large_size, large_size).
        """
        refined_mask_pred = mask_preds.clone()
        for subdivision_step in range(self.test_cfg.subdivision_steps):
            refined_mask_pred = F.interpolate(
                refined_mask_pred,
                scale_factor=self.test_cfg.scale_factor,
                mode='bilinear',
                align_corners=False)
            num_rois, channels, mask_height, mask_width = \
                refined_mask_pred.shape
            # If ``subdivision_num_points`` is at least the number of pixels
            # at the next step's resolution, the next step refines every
            # pixel anyway. This step's work would be overwritten, so skip
            # it (the final step is never skipped).
            if (self.test_cfg.subdivision_num_points >=
                    self.test_cfg.scale_factor**2 * mask_height * mask_width
                    and
                    subdivision_step < self.test_cfg.subdivision_steps - 1):
                continue
            point_indices, rel_roi_points = \
                self.point_head.get_roi_rel_points_test(
                    refined_mask_pred, label_preds, cfg=self.test_cfg)

            fine_grained_point_feats = self._get_fine_grained_point_feats(
                x=x, rois=rois, rel_roi_points=rel_roi_points)
            # Coarse features are always sampled from the original
            # (un-upsampled) predictions.
            coarse_point_feats = point_sample(mask_preds, rel_roi_points)
            mask_point_pred = self.point_head(fine_grained_point_feats,
                                              coarse_point_feats)

            # Write the point predictions back into the flattened
            # (H * W) mask at the selected indices, for every channel.
            point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1)
            refined_mask_pred = refined_mask_pred.reshape(
                num_rois, channels, mask_height * mask_width)
            refined_mask_pred = refined_mask_pred.scatter_(
                2, point_indices, mask_point_pred)
            refined_mask_pred = refined_mask_pred.view(num_rois, channels,
                                                       mask_height, mask_width)

        return refined_mask_pred

    def _get_fine_grained_point_feats(self, x: Tuple[Tensor], rois: Tensor,
                                      rel_roi_points: Tensor) -> Tensor:
        """Sample fine grained feats from each level feature map and
        concatenate them together.

        Args:
            x (tuple[Tensor]): Feature maps of all scale level.
            rois (Tensor): shape (num_rois, 5); column 0 is the batch index.
            rel_roi_points (Tensor): A tensor of shape (num_rois, num_points,
                2) that contains [0, 1] x [0, 1] normalized coordinates of the
                most uncertain points from the [mask_height, mask_width] grid.

        Returns:
            Tensor: The fine grained features for each points,
            has shape (num_rois, feats_channels, num_points).

        Note:
            Features are gathered image by image (ascending batch index)
            and concatenated in that order. Output rows therefore line up
            with ``rois`` only if ``rois`` is grouped by ascending batch
            index — presumably true for RoIs built with ``bbox2roi``;
            confirm for other callers.
        """
        assert rois.shape[0] > 0, 'RoI is a empty tensor.'
        num_imgs = x[0].shape[0]
        fine_grained_feats = []
        for idx in range(self.mask_roi_extractor.num_inputs):
            feats = x[idx]
            spatial_scale = 1. / float(
                self.mask_roi_extractor.featmap_strides[idx])
            point_feats = []
            for batch_ind in range(num_imgs):
                # Unravel the batch dim: sample each image's points only
                # from that image's feature map.
                feat = feats[batch_ind].unsqueeze(0)
                inds = (rois[:, 0].long() == batch_ind)
                if inds.any():
                    rel_img_points = rel_roi_point_to_rel_img_point(
                        rois=rois[inds],
                        rel_roi_points=rel_roi_points[inds],
                        img=feat.shape[2:],
                        spatial_scale=spatial_scale).unsqueeze(0)
                    point_feat = point_sample(feat, rel_img_points)
                    point_feat = point_feat.squeeze(0).transpose(0, 1)
                    point_feats.append(point_feat)
            fine_grained_feats.append(torch.cat(point_feats, dim=0))
        return torch.cat(fine_grained_feats, dim=1)

    def predict_mask(self,
                     x: Tuple[Tensor],
                     batch_img_metas: List[dict],
                     results_list: InstanceList,
                     rescale: bool = False) -> InstanceList:
        """Perform forward propagation of the mask head and predict detection
        results on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Feature maps of all scale level.
            batch_img_metas (list[dict]): List of image information.
            results_list (list[:obj:`InstanceData`]): Detection results of
                each image.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image
            after the post process.
            Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
            - masks (Tensor): Has a shape (num_instances, H, W).
        """
        bboxes = [res.bboxes for res in results_list]
        mask_rois = bbox2roi(bboxes)
        if mask_rois.shape[0] == 0:
            # No detections in the whole batch.
            results_list = empty_instances(
                batch_img_metas,
                mask_rois.device,
                task_type='mask',
                instance_results=results_list,
                mask_thr_binary=self.test_cfg.mask_thr_binary)
            return results_list

        mask_results = self._mask_forward(x, mask_rois)
        mask_preds = mask_results['mask_preds']

        # Split the batched predictions and RoIs back into per-image chunks.
        num_mask_rois_per_img = [len(res) for res in results_list]
        mask_preds = mask_preds.split(num_mask_rois_per_img, 0)
        mask_rois = mask_rois.split(num_mask_rois_per_img, 0)

        mask_preds_refined = []
        for i in range(len(batch_img_metas)):
            mask_pred_i = mask_preds[i]
            if mask_rois[i].shape[0] == 0:
                # This image has no detections but others in the batch do.
                # Point refinement asserts on empty RoIs, so pass the
                # (empty) coarse prediction through unchanged.
                # NOTE(review): assumes ``mask_head.predict_by_feat``
                # handles images with no boxes without using their mask
                # predictions — confirm.
                mask_preds_refined.append(mask_pred_i)
                continue
            labels = results_list[i].labels
            # Keep the batch dim so each level is a batch of one image.
            x_i = tuple(feats[[i]] for feats in x)
            # The RoIs now index into the single-image batch ``x_i``, so
            # their batch index must be 0. Clone to avoid mutating a view
            # of the split RoI tensor.
            mask_rois_i = mask_rois[i].clone()
            mask_rois_i[:, 0] = 0
            mask_pred_i = self._mask_point_forward_test(
                x_i, mask_rois_i, labels, mask_pred_i)
            mask_preds_refined.append(mask_pred_i)

        results_list = self.mask_head.predict_by_feat(
            mask_preds=mask_preds_refined,
            results_list=results_list,
            batch_img_metas=batch_img_metas,
            rcnn_test_cfg=self.test_cfg,
            rescale=rescale)
        return results_list
|
|