from typing import List

import torch
import torch.nn as nn
from mmengine.config import ConfigDict
from torch import Tensor

from mmdet.models.task_modules import SamplingResult
from mmdet.registry import MODELS
from mmdet.utils import ConfigType, InstanceList, OptConfigType, reduce_mean
from .fcn_mask_head import FCNMaskHead


@MODELS.register_module()
class DynamicMaskHead(FCNMaskHead):
    r"""Dynamic Mask Head for
    `Instances as Queries <http://arxiv.org/abs/2105.01928>`_

    Args:
        num_convs (int): Number of convolution layers.
            Defaults to 4.
        roi_feat_size (int): The output size of the RoI extractor.
            Defaults to 14.
        in_channels (int): Input feature channels.
            Defaults to 256.
        conv_kernel_size (int): Kernel size of the convolution layers.
            Defaults to 3.
        conv_out_channels (int): Output channels of the convolution layers.
            Defaults to 256.
        num_classes (int): Number of classes.
            Defaults to 80.
        class_agnostic (bool): Whether to generate class-agnostic
            predictions. Must be False for this head. Defaults to False.
        upsample_cfg (:obj:`ConfigDict` or dict): The config for the
            upsample layer.
        conv_cfg (:obj:`ConfigDict` or dict, optional): The convolution
            layer config.
        norm_cfg (:obj:`ConfigDict` or dict, optional): The norm layer config.
        dynamic_conv_cfg (:obj:`ConfigDict` or dict): The dynamic convolution
            layer config.
        loss_mask (:obj:`ConfigDict` or dict): The config for the mask loss.
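
    Example:
        A minimal shape sketch with the default config (illustrative only;
        it assumes mmdet's registries are populated, e.g. via
        ``import mmdet.models``):

        >>> import torch
        >>> head = DynamicMaskHead()
        >>> roi_feat = torch.rand(2, 256, 14, 14)
        >>> proposal_feat = torch.rand(2, 256)
        >>> mask_preds = head(roi_feat, proposal_feat)
        >>> tuple(mask_preds.shape)  # (N, num_classes, 2*14, 2*14)
        (2, 80, 28, 28)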
    """

    def __init__(self,
                 num_convs: int = 4,
                 roi_feat_size: int = 14,
                 in_channels: int = 256,
                 conv_kernel_size: int = 3,
                 conv_out_channels: int = 256,
                 num_classes: int = 80,
                 class_agnostic: bool = False,
                 upsample_cfg: ConfigType = dict(
                     type='deconv', scale_factor=2),
                 conv_cfg: OptConfigType = None,
                 norm_cfg: OptConfigType = None,
                 dynamic_conv_cfg: ConfigType = dict(
                     type='DynamicConv',
                     in_channels=256,
                     feat_channels=64,
                     out_channels=256,
                     input_feat_shape=14,
                     with_proj=False,
                     act_cfg=dict(type='ReLU', inplace=True),
                     norm_cfg=dict(type='LN')),
                 loss_mask: ConfigType = dict(
                     type='DiceLoss', loss_weight=8.0),
                 **kwargs) -> None:
        super().__init__(
            num_convs=num_convs,
            roi_feat_size=roi_feat_size,
            in_channels=in_channels,
            conv_kernel_size=conv_kernel_size,
            conv_out_channels=conv_out_channels,
            num_classes=num_classes,
            class_agnostic=class_agnostic,
            upsample_cfg=upsample_cfg,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            loss_mask=loss_mask,
            **kwargs)
        assert class_agnostic is False, \
            'DynamicMaskHead only supports class_agnostic=False'
        self.fp16_enabled = False

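        # Build the per-proposal dynamic interaction module. With the default
        # config this is mmdet's `DynamicConv`, which generates convolution
        # parameters from `proposal_feat` and applies them to `roi_feat`.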
        self.instance_interactive_conv = MODELS.build(dynamic_conv_cfg)

    def init_weights(self) -> None:
        """Use Xavier initialization for all weight parameters and zero the
        bias of the mask prediction layer ``conv_logits``."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        nn.init.constant_(self.conv_logits.bias, 0.)

    def forward(self, roi_feat: Tensor, proposal_feat: Tensor) -> Tensor:
        """Forward function of DynamicMaskHead.

        Args:
            roi_feat (Tensor): RoI-pooled features with shape
                (batch_size*num_proposals, feature_dimensions,
                pooling_h, pooling_w).
            proposal_feat (Tensor): Intermediate feature obtained from
                the DIIHead in the last stage, with shape
                (batch_size*num_proposals, feature_dimensions).

        Returns:
            mask_preds (Tensor): Predicted foreground masks with shape
            (batch_size*num_proposals, num_classes, pooling_h*2, pooling_w*2).
        """
        proposal_feat = proposal_feat.reshape(-1, self.in_channels)
        proposal_feat_iic = self.instance_interactive_conv(
            proposal_feat, roi_feat)
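
        # With the default `with_proj=False`, the interacted features come
        # back as (N, H*W, C); restore the (N, C, H, W) layout expected by
        # the convolution tower below.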
        x = proposal_feat_iic.permute(0, 2, 1).reshape(roi_feat.size())

        for conv in self.convs:
            x = conv(x)
        if self.upsample is not None:
            x = self.upsample(x)
            if self.upsample_method == 'deconv':
                x = self.relu(x)
        mask_preds = self.conv_logits(x)
        return mask_preds

    def loss_and_target(self, mask_preds: Tensor,
                        sampling_results: List[SamplingResult],
                        batch_gt_instances: InstanceList,
                        rcnn_train_cfg: ConfigDict) -> dict:
        """Calculate the loss based on the features extracted by the mask
        head.

        Args:
            mask_preds (Tensor): Predicted foreground masks with shape
                (num_pos, num_classes, h, w).
            sampling_results (list[:obj:`SamplingResult`]): Sampling results
                of all images in a batch.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``labels``, and
                ``masks`` attributes.
            rcnn_train_cfg (:obj:`ConfigDict`): ``train_cfg`` of RCNN.

        Returns:
            dict: A dictionary of loss and targets components.
        """
        mask_targets = self.get_targets(
            sampling_results=sampling_results,
            batch_gt_instances=batch_gt_instances,
            rcnn_train_cfg=rcnn_train_cfg)
        pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
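
        # Count positive samples and average the count across distributed
        # workers (`reduce_mean`); clamp so the normalizer never hits zero.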
        num_pos = pos_labels.new_ones(pos_labels.size()).float().sum()
        avg_factor = torch.clamp(reduce_mean(num_pos), min=1.).item()
        loss = dict()
        if mask_preds.size(0) == 0:
            loss_mask = mask_preds.sum()
        else:
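            # For each positive sample, pick the predicted mask channel of
            # its ground-truth class before applying the sigmoid.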
            loss_mask = self.loss_mask(
                mask_preds[torch.arange(num_pos).long(), pos_labels,
                           ...].sigmoid(),
                mask_targets,
                avg_factor=avg_factor)
        loss['loss_mask'] = loss_mask
        return dict(loss_mask=loss, mask_targets=mask_targets)