|
|
|
from typing import List, Optional, Union |
|
|
|
from mmcv.cnn import ConvModule |
|
from torch import Tensor |
|
|
|
from mmdet.registry import MODELS |
|
from .fcn_mask_head import FCNMaskHead |
|
|
|
|
|
@MODELS.register_module()
class HTCMaskHead(FCNMaskHead):
    """Mask head for HTC.

    Extends :class:`FCNMaskHead` with an optional 1x1 ``conv_res`` layer
    that transforms an incoming residual feature before it is added to the
    input feature map.

    Args:
        with_conv_res (bool): Whether to add a conv layer for ``res_feat``.
            Defaults to True.
        *args: Positional arguments forwarded to :class:`FCNMaskHead`.
        **kwargs: Keyword arguments forwarded to :class:`FCNMaskHead`.
    """

    def __init__(self, with_conv_res: bool = True, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.with_conv_res = with_conv_res
        if self.with_conv_res:
            # 1x1 conv that keeps the channel count (conv_out_channels ->
            # conv_out_channels), so its output can be summed with the input
            # feature map in ``forward``.
            self.conv_res = ConvModule(
                self.conv_out_channels,
                self.conv_out_channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg)

    def forward(self,
                x: Tensor,
                res_feat: Optional[Tensor] = None,
                return_logits: bool = True,
                return_feat: bool = True) -> Union[Tensor, List[Tensor]]:
        """Run the mask head, optionally fusing a residual feature first.

        Args:
            x (Tensor): Feature map.
            res_feat (Tensor, optional): Feature for residual connection.
                It is passed through ``conv_res`` and added to ``x`` before
                the conv stack, so the result must be shape-compatible with
                ``x``. Requires ``with_conv_res=True``. Defaults to None.
            return_logits (bool): Whether to return mask logits.
                Defaults to True.
            return_feat (bool): Whether to return the feature map produced
                by the conv stack (taken before upsampling).
                Defaults to True.

        Returns:
            Union[Tensor, List[Tensor]]: ``[logits, res_feat]`` when both
            flags are True, otherwise the single requested tensor
            (``logits`` or ``res_feat``).

        Raises:
            AssertionError: If both ``return_logits`` and ``return_feat`` are
                False, or if ``res_feat`` is given while ``with_conv_res`` is
                False (assertions are skipped under ``python -O``).
        """
        assert return_logits or return_feat, (
            'At least one of `return_logits` and `return_feat` must be True.')
        if res_feat is not None:
            assert self.with_conv_res, (
                '`res_feat` was given but `with_conv_res` is False, so there '
                'is no `conv_res` layer to apply to it.')
            x = x + self.conv_res(res_feat)
        for conv in self.convs:
            x = conv(x)
        # Feature after the conv stack; this (not the upsampled map) is what
        # gets returned when ``return_feat`` is True.
        res_feat = x
        outs = []
        if return_logits:
            # The parent FCNMaskHead can leave ``self.upsample`` as None when
            # no upsampling is configured; guard the call the same way the
            # parent's forward does instead of calling None.
            if self.upsample is not None:
                x = self.upsample(x)
                if self.upsample_method == 'deconv':
                    x = self.relu(x)
            mask_preds = self.conv_logits(x)
            outs.append(mask_preds)
        if return_feat:
            outs.append(res_feat)
        return outs if len(outs) > 1 else outs[0]
|
|