|
|
|
|
|
|
|
from typing import Dict, List, Optional, Tuple, Union
|
|
|
import torch |
|
from torch import Tensor, nn |
|
import torch.nn.functional as F |
|
from torch.nn.init import normal_ |
|
from mmdet.registry import MODELS |
|
from mmdet.structures import OptSampleList, SampleList |
|
from mmdet.utils import OptConfigType |
|
|
|
|
|
from ..layers import SinePositionalEncoding |
|
from ..layers.transformer.dino_layers import (CdnQueryGenerator,
                                              DeformableDetrTransformerEncoder,
                                              DinoTransformerDecoder)
|
from .deformable_detr import DeformableDETR, MultiScaleDeformableAttention |
|
|
|
|
|
|
|
|
|
@MODELS.register_module() |
|
class DINO(DeformableDETR): |
|
r"""Implementation of `DINO: DETR with Improved DeNoising Anchor Boxes |
|
for End-to-End Object Detection <https://arxiv.org/abs/2203.03605>`_ |
|
|
|
Code is modified from the `official github repo |
|
<https://github.com/IDEA-Research/DINO>`_. |
|
|
|
Args: |
|
dn_cfg (:obj:`ConfigDict` or dict, optional): Config of denoising |
|
query generator. Defaults to `None`. |
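        candidate_bboxes_size (float): Base normalized (w, h) size of the
            candidate proposal boxes generated at the first feature level in
            `gen_encoder_output_proposals`; the size is doubled at each
            subsequent level. Defaults to 0.05.
        scale_gt_bboxes_size (float): If greater than 0, each ground-truth
            box is shrunk by this many units per side before computing the
            loss (see `rescale_gt_bboxes`). Defaults to 0.
        htd_2s (bool): Whether to use a looser validity range for the
            encoder output proposals. Defaults to False.

    A typical ``dn_cfg`` could look as follows; the keys are forwarded to
    :class:`CdnQueryGenerator` and the values here are only illustrative:

    .. code:: python

        dn_cfg = dict(
            label_noise_scale=0.5,
            box_noise_scale=1.0,
            group_cfg=dict(dynamic=True, num_groups=None, num_dn_queries=100))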
|
""" |
|
|
|
def __init__(self, *args, dn_cfg: OptConfigType = None, |
|
candidate_bboxes_size: float = 0.05, |
|
scale_gt_bboxes_size: float = 0, |
|
                 htd_2s: bool = False,
|
**kwargs) -> None: |
|
super().__init__(*args, **kwargs) |
|
assert self.as_two_stage, 'as_two_stage must be True for DINO' |
|
assert self.with_box_refine, 'with_box_refine must be True for DINO' |
|
|
|
if dn_cfg is not None: |
|
assert 'num_classes' not in dn_cfg and \ |
|
'num_queries' not in dn_cfg and \ |
|
'hidden_dim' not in dn_cfg, \ |
|
'The three keyword args `num_classes`, `embed_dims`, and ' \ |
|
'`num_matching_queries` are set in `detector.__init__()`, ' \ |
|
'users should not set them in `dn_cfg` config.' |
|
dn_cfg['num_classes'] = self.bbox_head.num_classes |
|
dn_cfg['embed_dims'] = self.embed_dims |
|
dn_cfg['num_matching_queries'] = self.num_queries |
|
self.dn_query_generator = CdnQueryGenerator(**dn_cfg) |
|
self.scale_gt_bboxes_size = scale_gt_bboxes_size |
|
self.candidate_bboxes_size = candidate_bboxes_size |
|
self.htd_2s = htd_2s |
|
def _init_layers(self) -> None: |
|
"""Initialize layers except for backbone, neck and bbox_head.""" |
|
self.positional_encoding = SinePositionalEncoding( |
|
**self.positional_encoding) |
|
self.encoder = DeformableDetrTransformerEncoder(**self.encoder) |
|
self.decoder = DinoTransformerDecoder(**self.decoder) |
|
self.embed_dims = self.encoder.embed_dims |
|
self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims) |
|
|
|
|
|
|
|
|
|
|
|
num_feats = self.positional_encoding.num_feats |
|
assert num_feats * 2 == self.embed_dims, \ |
|
f'embed_dims should be exactly 2 times of num_feats. ' \ |
|
f'Found {self.embed_dims} and {num_feats}.' |
|
|
|
self.level_embed = nn.Parameter( |
|
torch.Tensor(self.num_feature_levels, self.embed_dims)) |
|
self.memory_trans_fc = nn.Linear(self.embed_dims, self.embed_dims) |
|
self.memory_trans_norm = nn.LayerNorm(self.embed_dims) |
|
|
|
def init_weights(self) -> None: |
|
"""Initialize weights for Transformer and other components.""" |
|
super(DeformableDETR, self).init_weights() |
|
for coder in self.encoder, self.decoder: |
|
for p in coder.parameters(): |
|
if p.dim() > 1: |
|
nn.init.xavier_uniform_(p) |
|
for m in self.modules(): |
|
if isinstance(m, MultiScaleDeformableAttention): |
|
m.init_weights() |
|
nn.init.xavier_uniform_(self.memory_trans_fc.weight) |
|
nn.init.xavier_uniform_(self.query_embedding.weight) |
|
normal_(self.level_embed) |
|
|
|
|
|
def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]: |
|
"""Extract features. |
|
|
|
Args: |
|
batch_inputs (Tensor): Image tensor, has shape (bs, dim, H, W). |
|
|
|
Returns: |
|
tuple[Tensor]: Tuple of feature maps from neck. Each feature map |
|
has shape (bs, dim, H, W). |
|
""" |
|
x = self.backbone(batch_inputs) |
|
if self.with_neck: |
|
x = self.neck(x) |
|
return x |
|
|
|
def loss(self, batch_inputs: Tensor, |
|
batch_data_samples: SampleList) -> Union[dict, list]: |
|
"""Calculate losses from a batch of inputs and data samples. |
|
|
|
Args: |
|
batch_inputs (Tensor): Input images of shape (bs, dim, H, W). |
|
These should usually be mean centered and std scaled. |
|
batch_data_samples (List[:obj:`DetDataSample`]): The batch |
|
data samples. It usually includes information such |
|
as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. |
|
|
|
Returns: |
|
dict: A dictionary of loss components |
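
        When ``scale_gt_bboxes_size > 0`` the ground-truth boxes are shrunk
        before the loss is computed (see ``rescale_gt_bboxes``). A minimal
        illustration of that preprocessing step (values are made up):

        .. code:: python

            import torch
            box = torch.tensor([[10., 10., 50., 30.]])  # (x1, y1, x2, y2)
            s = 2.0
            box[:, :2] += s   # move the top-left corner inwards
            box[:, 2:] -= s   # move the bottom-right corner inwards
            # box -> tensor([[12., 12., 48., 28.]])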
|
""" |
|
|
|
        if self.scale_gt_bboxes_size > 0:
            batch_data_samples = self.rescale_gt_bboxes(
                batch_data_samples, self.scale_gt_bboxes_size)
|
|
|
img_feats = self.extract_feat(batch_inputs) |
|
head_inputs_dict = self.forward_transformer(img_feats, |
|
batch_data_samples) |
|
losses = self.bbox_head.loss( |
|
**head_inputs_dict, batch_data_samples=batch_data_samples) |
|
|
|
return losses |
|
|
|
|
|
def predict(self, |
|
batch_inputs: Tensor, |
|
batch_data_samples: SampleList, |
|
rescale: bool = True) -> SampleList: |
|
"""Predict results from a batch of inputs and data samples with post- |
|
processing. |
|
|
|
Args: |
|
batch_inputs (Tensor): Inputs, has shape (bs, dim, H, W). |
|
batch_data_samples (List[:obj:`DetDataSample`]): The batch |
|
data samples. It usually includes information such |
|
as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. |
|
rescale (bool): Whether to rescale the results. |
|
Defaults to True. |
|
|
|
Returns: |
|
list[:obj:`DetDataSample`]: Detection results of the input images. |
|
            Each DetDataSample usually contains 'pred_instances'. And the
            `pred_instances` usually contains the following keys.
|
|
|
- scores (Tensor): Classification scores, has a shape |
|
(num_instance, ) |
|
- labels (Tensor): Labels of bboxes, has a shape |
|
(num_instances, ). |
|
- bboxes (Tensor): Has a shape (num_instances, 4), |
|
              the last dimension 4 arranged as (x1, y1, x2, y2).
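
        A minimal sketch of reading the returned results (attribute names
        follow `DetDataSample`/`InstanceData` in mmdet 3.x; the `detector`
        and the input tensors below are placeholders):

        .. code:: python

            results = detector.predict(batch_inputs, batch_data_samples)
            pred = results[0].pred_instances
            bboxes = pred.bboxes   # (num_instances, 4), (x1, y1, x2, y2)
            scores = pred.scores   # (num_instances, )
            labels = pred.labels   # (num_instances, )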
|
""" |
|
img_feats = self.extract_feat(batch_inputs) |
|
head_inputs_dict = self.forward_transformer(img_feats, |
|
batch_data_samples) |
|
results_list = self.bbox_head.predict( |
|
**head_inputs_dict, |
|
rescale=rescale, |
|
batch_data_samples=batch_data_samples) |
|
batch_data_samples = self.add_pred_to_datasample( |
|
batch_data_samples, results_list) |
|
return batch_data_samples |
|
|
|
def _forward(self, |
|
batch_inputs: Tensor, |
|
batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]: |
|
"""Network forward process. Usually includes backbone, neck and head |
|
forward without any post-processing. |
|
|
|
Args: |
|
batch_inputs (Tensor): Inputs, has shape (bs, dim, H, W). |
|
batch_data_samples (List[:obj:`DetDataSample`], optional): The |
|
batch data samples. It usually includes information such |
|
as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. |
|
Defaults to None. |
|
|
|
Returns: |
|
tuple[Tensor]: A tuple of features from ``bbox_head`` forward. |
|
""" |
|
img_feats = self.extract_feat(batch_inputs) |
|
head_inputs_dict = self.forward_transformer(img_feats, |
|
batch_data_samples) |
|
results = self.bbox_head.forward(**head_inputs_dict) |
|
return results |
|
|
|
def forward_transformer( |
|
self, |
|
img_feats: Tuple[Tensor], |
|
batch_data_samples: OptSampleList = None, |
|
) -> Dict: |
|
"""Forward process of Transformer. |
|
|
|
The forward procedure of the transformer is defined as: |
|
'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' |
|
More details can be found at `TransformerDetector.forward_transformer` |
|
in `mmdet/detector/base_detr.py`. |
|
The difference is that the ground truth in `batch_data_samples` is |
|
required for the `pre_decoder` to prepare the query of DINO. |
|
Additionally, DINO inherits the `pre_transformer` method and the |
|
`forward_encoder` method of DeformableDETR. More details about the |
|
two methods can be found in `mmdet/detector/deformable_detr.py`. |
|
|
|
Args: |
|
img_feats (tuple[Tensor]): Tuple of feature maps from neck. Each |
|
feature map has shape (bs, dim, H, W). |
|
batch_data_samples (list[:obj:`DetDataSample`]): The batch |
|
data samples. It usually includes information such |
|
as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. |
|
Defaults to None. |
|
|
|
Returns: |
|
dict: The dictionary of bbox_head function inputs, which always |
|
includes the `hidden_states` of the decoder output and may contain |
|
`references` including the initial and intermediate references. |
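
        A sketch of the call chain implemented in this method, shown as it
        is invoked inside the detector (only a summary of the data flow):

        .. code:: python

            enc_in, dec_in = self.pre_transformer(img_feats, batch_data_samples)
            enc_out = self.forward_encoder(**enc_in)
            tmp_dec_in, head_in = self.pre_decoder(
                **enc_out, batch_data_samples=batch_data_samples)
            dec_in.update(tmp_dec_in)
            head_in.update(self.forward_decoder(**dec_in))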
|
""" |
|
encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer( |
|
img_feats, batch_data_samples) |
|
|
|
encoder_outputs_dict = self.forward_encoder(**encoder_inputs_dict) |
|
|
|
tmp_dec_in, head_inputs_dict = self.pre_decoder( |
|
**encoder_outputs_dict, batch_data_samples=batch_data_samples) |
|
decoder_inputs_dict.update(tmp_dec_in) |
|
|
|
decoder_outputs_dict = self.forward_decoder(**decoder_inputs_dict) |
|
head_inputs_dict.update(decoder_outputs_dict) |
|
return head_inputs_dict |
|
|
|
def pre_transformer( |
|
self, |
|
mlvl_feats: Tuple[Tensor], |
|
batch_data_samples: OptSampleList = None) -> Tuple[Dict]: |
|
"""Process image features before feeding them to the transformer. |
|
|
|
The forward procedure of the transformer is defined as: |
|
'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' |
|
More details can be found at `TransformerDetector.forward_transformer` |
|
in `mmdet/detector/base_detr.py`. |
|
|
|
Args: |
|
mlvl_feats (tuple[Tensor]): Multi-level features that may have |
|
different resolutions, output from neck. Each feature has |
|
shape (bs, dim, h_lvl, w_lvl), where 'lvl' means 'layer'. |
|
batch_data_samples (list[:obj:`DetDataSample`], optional): The |
|
batch data samples. It usually includes information such |
|
as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. |
|
Defaults to None. |
|
|
|
Returns: |
|
tuple[dict]: The first dict contains the inputs of encoder and the |
|
second dict contains the inputs of decoder. |
|
|
|
- encoder_inputs_dict (dict): The keyword args dictionary of |
|
`self.forward_encoder()`, which includes 'feat', 'feat_mask', |
|
and 'feat_pos'. |
|
- decoder_inputs_dict (dict): The keyword args dictionary of |
|
`self.forward_decoder()`, which includes 'memory_mask'. |
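
        For reference, the `level_start_index` passed to the decoder is the
        flattened offset of each level; a tiny worked example with made-up
        spatial shapes:

        .. code:: python

            import torch
            spatial_shapes = torch.tensor([[100, 150], [50, 75], [25, 38]])
            level_start_index = torch.cat((
                spatial_shapes.new_zeros((1, )),
                spatial_shapes.prod(1).cumsum(0)[:-1]))
            # -> tensor([    0, 15000, 18750])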
|
""" |
|
batch_size = mlvl_feats[0].size(0) |
|
|
|
|
|
assert batch_data_samples is not None |
|
batch_input_shape = batch_data_samples[0].batch_input_shape |
|
img_shape_list = [sample.img_shape for sample in batch_data_samples] |
|
input_img_h, input_img_w = batch_input_shape |
|
masks = mlvl_feats[0].new_ones((batch_size, input_img_h, input_img_w)) |
|
for img_id in range(batch_size): |
|
img_h, img_w = img_shape_list[img_id] |
|
masks[img_id, :img_h, :img_w] = 0 |
|
|
|
|
|
|
|
mlvl_masks = [] |
|
mlvl_pos_embeds = [] |
|
for feat in mlvl_feats: |
|
mlvl_masks.append( |
|
F.interpolate(masks[None], |
|
size=feat.shape[-2:]).to(torch.bool).squeeze(0)) |
|
mlvl_pos_embeds.append(self.positional_encoding(mlvl_masks[-1])) |
|
|
|
feat_flatten = [] |
|
lvl_pos_embed_flatten = [] |
|
mask_flatten = [] |
|
spatial_shapes = [] |
|
for lvl, (feat, mask, pos_embed) in enumerate( |
|
zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)): |
|
batch_size, c, h, w = feat.shape |
|
|
|
feat = feat.view(batch_size, c, -1).permute(0, 2, 1) |
|
pos_embed = pos_embed.view(batch_size, c, -1).permute(0, 2, 1) |
|
lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) |
|
|
|
mask = mask.flatten(1) |
|
spatial_shape = (h, w) |
|
|
|
feat_flatten.append(feat) |
|
lvl_pos_embed_flatten.append(lvl_pos_embed) |
|
mask_flatten.append(mask) |
|
spatial_shapes.append(spatial_shape) |
|
|
|
|
|
feat_flatten = torch.cat(feat_flatten, 1) |
|
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) |
|
|
|
mask_flatten = torch.cat(mask_flatten, 1) |
|
|
|
spatial_shapes = torch.as_tensor( |
|
spatial_shapes, |
|
dtype=torch.long, |
|
device=feat_flatten.device) |
|
level_start_index = torch.cat(( |
|
spatial_shapes.new_zeros((1, )), |
|
spatial_shapes.prod(1).cumsum(0)[:-1])) |
|
valid_ratios = torch.stack( |
|
[self.get_valid_ratio(m) for m in mlvl_masks], 1) |
|
|
|
encoder_inputs_dict = dict( |
|
feat=feat_flatten, |
|
feat_mask=mask_flatten, |
|
feat_pos=lvl_pos_embed_flatten, |
|
spatial_shapes=spatial_shapes, |
|
level_start_index=level_start_index, |
|
valid_ratios=valid_ratios) |
|
decoder_inputs_dict = dict( |
|
memory_mask=mask_flatten, |
|
spatial_shapes=spatial_shapes, |
|
level_start_index=level_start_index, |
|
valid_ratios=valid_ratios) |
|
return encoder_inputs_dict, decoder_inputs_dict |
|
|
|
def forward_encoder(self, feat: Tensor, feat_mask: Tensor, |
|
feat_pos: Tensor, spatial_shapes: Tensor, |
|
level_start_index: Tensor, |
|
valid_ratios: Tensor) -> Dict: |
|
"""Forward with Transformer encoder. |
|
|
|
The forward procedure of the transformer is defined as: |
|
'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' |
|
More details can be found at `TransformerDetector.forward_transformer` |
|
in `mmdet/detector/base_detr.py`. |
|
|
|
Args: |
|
feat (Tensor): Sequential features, has shape (bs, num_feat_points, |
|
dim). |
|
feat_mask (Tensor): ByteTensor, the padding mask of the features, |
|
has shape (bs, num_feat_points). |
|
feat_pos (Tensor): The positional embeddings of the features, has |
|
shape (bs, num_feat_points, dim). |
|
spatial_shapes (Tensor): Spatial shapes of features in all levels, |
|
has shape (num_levels, 2), last dimension represents (h, w). |
|
level_start_index (Tensor): The start index of each level. |
|
A tensor has shape (num_levels, ) and can be represented |
|
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. |
|
valid_ratios (Tensor): The ratios of the valid width and the valid |
|
height relative to the width and the height of features in all |
|
levels, has shape (bs, num_levels, 2). |
|
|
|
Returns: |
|
dict: The dictionary of encoder outputs, which includes the |
|
`memory` of the encoder output. |
|
""" |
|
memory = self.encoder( |
|
query=feat, |
|
query_pos=feat_pos, |
|
key_padding_mask=feat_mask, |
|
spatial_shapes=spatial_shapes, |
|
level_start_index=level_start_index, |
|
valid_ratios=valid_ratios) |
|
encoder_outputs_dict = dict( |
|
memory=memory, |
|
memory_mask=feat_mask, |
|
spatial_shapes=spatial_shapes) |
|
return encoder_outputs_dict |
|
|
|
def pre_decoder( |
|
self, |
|
memory: Tensor, |
|
memory_mask: Tensor, |
|
spatial_shapes: Tensor, |
|
batch_data_samples: OptSampleList = None, |
|
) -> Tuple[Dict]: |
|
"""Prepare intermediate variables before entering Transformer decoder, |
|
such as `query`, `query_pos`, and `reference_points`. |
|
|
|
Args: |
|
memory (Tensor): The output embeddings of the Transformer encoder, |
|
has shape (bs, num_feat_points, dim). |
|
memory_mask (Tensor): ByteTensor, the padding mask of the memory, |
|
has shape (bs, num_feat_points). Will only be used when |
|
`as_two_stage` is `True`. |
|
spatial_shapes (Tensor): Spatial shapes of features in all levels. |
|
With shape (num_levels, 2), last dimension represents (h, w). |
|
Will only be used when `as_two_stage` is `True`. |
|
batch_data_samples (list[:obj:`DetDataSample`]): The batch |
|
data samples. It usually includes information such |
|
as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. |
|
Defaults to None. |
|
|
|
Returns: |
|
tuple[dict]: The decoder_inputs_dict and head_inputs_dict. |
|
|
|
- decoder_inputs_dict (dict): The keyword dictionary args of |
|
`self.forward_decoder()`, which includes 'query', 'memory', |
|
`reference_points`, and `dn_mask`. The reference points of |
|
decoder input here are 4D boxes, although it has `points` |
|
in its name. |
|
- head_inputs_dict (dict): The keyword dictionary args of the |
|
bbox_head functions, which includes `topk_score`, `topk_coords`, |
|
and `dn_meta` when `self.training` is `True`, else is empty. |
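
        A tiny, self-contained illustration of the top-k query selection
        used below (sizes are made up; ``k`` plays the role of
        `self.num_queries`):

        .. code:: python

            import torch
            bs, num_feat_points, num_classes, k = 2, 6, 3, 4
            enc_outputs_class = torch.randn(bs, num_feat_points, num_classes)
            enc_outputs_coord_unact = torch.randn(bs, num_feat_points, 4)
            topk_indices = torch.topk(
                enc_outputs_class.max(-1)[0], k=k, dim=1)[1]
            topk_score = torch.gather(
                enc_outputs_class, 1,
                topk_indices.unsqueeze(-1).repeat(1, 1, num_classes))
            topk_coords_unact = torch.gather(
                enc_outputs_coord_unact, 1,
                topk_indices.unsqueeze(-1).repeat(1, 1, 4))
            # topk_score: (2, 4, 3), topk_coords_unact: (2, 4, 4)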
|
""" |
|
bs, _, c = memory.shape |
|
cls_out_features = self.bbox_head.cls_branches[ |
|
self.decoder.num_layers].out_features |
|
|
|
output_memory, output_proposals = self.gen_encoder_output_proposals( |
|
memory, memory_mask, spatial_shapes) |
|
|
|
        # NOTE Drop the last feature point before scoring the proposals.
        output_memory = output_memory[:, :-1, :]
        output_proposals = output_proposals[:, :-1, :]
|
|
|
enc_outputs_class = self.bbox_head.cls_branches[ |
|
self.decoder.num_layers]( |
|
output_memory) |
|
enc_outputs_coord_unact = self.bbox_head.reg_branches[ |
|
self.decoder.num_layers](output_memory) + output_proposals |
|
|
|
|
|
|
|
|
|
|
|
topk_indices = torch.topk( |
|
enc_outputs_class.max(-1)[0], k=self.num_queries, dim=1)[1] |
|
topk_score = torch.gather( |
|
enc_outputs_class, 1, |
|
topk_indices.unsqueeze(-1).repeat(1, 1, cls_out_features)) |
|
topk_coords_unact = torch.gather( |
|
enc_outputs_coord_unact, 1, |
|
topk_indices.unsqueeze(-1).repeat(1, 1, 4)) |
|
topk_coords = topk_coords_unact.sigmoid() |
|
topk_coords_unact = topk_coords_unact.detach() |
|
|
|
query = self.query_embedding.weight[:, None, :] |
|
query = query.repeat(1, bs, 1).transpose(0, 1) |
|
if self.training: |
|
dn_label_query, dn_bbox_query, dn_mask, dn_meta = \ |
|
self.dn_query_generator(batch_data_samples) |
|
query = torch.cat([dn_label_query, query], dim=1) |
|
reference_points = torch.cat([dn_bbox_query, topk_coords_unact], |
|
dim=1) |
|
else: |
|
reference_points = topk_coords_unact |
|
dn_mask, dn_meta = None, None |
|
reference_points = reference_points.sigmoid() |
|
|
|
decoder_inputs_dict = dict( |
|
query=query, |
|
memory=memory, |
|
reference_points=reference_points, |
|
dn_mask=dn_mask) |
|
|
|
|
|
|
|
head_inputs_dict = dict( |
|
enc_outputs_class=topk_score, |
|
enc_outputs_coord=topk_coords, |
|
dn_meta=dn_meta) if self.training else dict() |
|
return decoder_inputs_dict, head_inputs_dict |
|
|
|
def forward_decoder(self, |
|
query: Tensor, |
|
memory: Tensor, |
|
memory_mask: Tensor, |
|
reference_points: Tensor, |
|
spatial_shapes: Tensor, |
|
level_start_index: Tensor, |
|
valid_ratios: Tensor, |
|
dn_mask: Optional[Tensor] = None) -> Dict: |
|
"""Forward with Transformer decoder. |
|
|
|
The forward procedure of the transformer is defined as: |
|
'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' |
|
More details can be found at `TransformerDetector.forward_transformer` |
|
in `mmdet/detector/base_detr.py`. |
|
|
|
Args: |
|
query (Tensor): The queries of decoder inputs, has shape |
|
(bs, num_queries_total, dim), where `num_queries_total` is the |
|
sum of `num_denoising_queries` and `num_matching_queries` when |
|
`self.training` is `True`, else `num_matching_queries`. |
|
memory (Tensor): The output embeddings of the Transformer encoder, |
|
has shape (bs, num_feat_points, dim). |
|
memory_mask (Tensor): ByteTensor, the padding mask of the memory, |
|
has shape (bs, num_feat_points). |
|
reference_points (Tensor): The initial reference, has shape |
|
(bs, num_queries_total, 4) with the last dimension arranged as |
|
(cx, cy, w, h). |
|
spatial_shapes (Tensor): Spatial shapes of features in all levels, |
|
has shape (num_levels, 2), last dimension represents (h, w). |
|
level_start_index (Tensor): The start index of each level. |
|
A tensor has shape (num_levels, ) and can be represented |
|
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. |
|
valid_ratios (Tensor): The ratios of the valid width and the valid |
|
height relative to the width and the height of features in all |
|
levels, has shape (bs, num_levels, 2). |
|
dn_mask (Tensor, optional): The attention mask to prevent |
|
information leakage from different denoising groups and |
|
matching parts, will be used as `self_attn_mask` of the |
|
`self.decoder`, has shape (num_queries_total, |
|
num_queries_total). |
|
It is `None` when `self.training` is `False`. |
|
|
|
Returns: |
|
dict: The dictionary of decoder outputs, which includes the |
|
`hidden_states` of the decoder output and `references` including |
|
the initial and intermediate reference_points. |
|
""" |
|
inter_states, references = self.decoder( |
|
query=query, |
|
value=memory, |
|
key_padding_mask=memory_mask, |
|
self_attn_mask=dn_mask, |
|
reference_points=reference_points, |
|
spatial_shapes=spatial_shapes, |
|
level_start_index=level_start_index, |
|
valid_ratios=valid_ratios, |
|
reg_branches=self.bbox_head.reg_branches) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        if len(query) == self.num_queries:
            # NOTE: keep the denoising label embedding in the computation
            # graph (with zero effect) even when there are no denoising
            # queries, so that distributed training does not complain about
            # unused parameters.
            inter_states[0] += \
                self.dn_query_generator.label_embedding.weight[0, 0] * 0.0
|
|
|
decoder_outputs_dict = dict( |
|
hidden_states=inter_states, references=list(references)) |
|
return decoder_outputs_dict |
|
|
|
|
|
@staticmethod |
|
def get_valid_ratio(mask: Tensor) -> Tensor: |
|
"""Get the valid radios of feature map in a level. |
|
|
|
.. code:: text |
|
|
|
|---> valid_W <---| |
|
---+-----------------+-----+--- |
|
A | | | A |
|
| | | | | |
|
| | | | | |
|
valid_H | | | | |
|
| | | | H |
|
| | | | | |
|
V | | | | |
|
---+-----------------+ | | |
|
| | V |
|
+-----------------------+--- |
|
|---------> W <---------| |
|
|
|
The valid_ratios are defined as: |
|
r_h = valid_H / H, r_w = valid_W / W |
|
They are the factors to re-normalize the relative coordinates of the |
|
image to the relative coordinates of the current level feature map. |
|
|
|
Args: |
|
mask (Tensor): Binary mask of a feature map, has shape (bs, H, W). |
|
|
|
Returns: |
|
            Tensor: valid ratios [r_w, r_h] of a feature map, has shape (bs, 2).
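
        A tiny, self-contained example (the mask values are made up):

        .. code:: python

            import torch
            mask = torch.ones(1, 4, 6, dtype=torch.bool)
            mask[:, :3, :4] = False                  # valid region is 3 x 4
            valid_H = torch.sum(~mask[:, :, 0], 1)   # tensor([3])
            valid_W = torch.sum(~mask[:, 0, :], 1)   # tensor([4])
            # valid ratios [r_w, r_h] -> [4/6, 3/4] = [0.6667, 0.7500]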
|
""" |
|
_, H, W = mask.shape |
|
valid_H = torch.sum(~mask[:, :, 0], 1) |
|
valid_W = torch.sum(~mask[:, 0, :], 1) |
|
valid_ratio_h = valid_H.float() / H |
|
valid_ratio_w = valid_W.float() / W |
|
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) |
|
return valid_ratio |
|
|
|
def gen_encoder_output_proposals( |
|
self, memory: Tensor, memory_mask: Tensor, |
|
spatial_shapes: Tensor) -> Tuple[Tensor, Tensor]: |
|
"""Generate proposals from encoded memory. The function will only be |
|
used when `as_two_stage` is `True`. |
|
|
|
Args: |
|
memory (Tensor): The output embeddings of the Transformer encoder, |
|
has shape (bs, num_feat_points, dim). |
|
memory_mask (Tensor): ByteTensor, the padding mask of the memory, |
|
has shape (bs, num_feat_points). |
|
spatial_shapes (Tensor): Spatial shapes of features in all levels, |
|
has shape (num_levels, 2), last dimension represents (h, w). |
|
|
|
Returns: |
|
tuple: A tuple of transformed memory and proposals. |
|
|
|
- output_memory (Tensor): The transformed memory for obtaining |
|
top-k proposals, has shape (bs, num_feat_points, dim). |
|
- output_proposals (Tensor): The inverse-normalized proposal, has |
|
shape (batch_size, num_keys, 4) with the last dimension arranged |
|
as (cx, cy, w, h). |
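
        A minimal sketch of the proposal "inverse-normalization" performed
        below (pure torch; the values are made up):

        .. code:: python

            import torch
            p = torch.tensor([0.5, 0.25, 0.05, 0.05])   # normalized (cx, cy, w, h)
            p_unact = torch.log(p / (1 - p))            # inverse sigmoid
            assert torch.allclose(p_unact.sigmoid(), p)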
|
""" |
|
|
|
bs = memory.size(0) |
|
proposals = [] |
|
|
|
_cur = 0 |
|
for lvl, (H, W) in enumerate(spatial_shapes): |
|
mask_flatten_ = memory_mask[:, |
|
_cur:(_cur + H * W)].view(bs, H, W, 1) |
|
valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1).unsqueeze(-1) |
|
valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1).unsqueeze(-1) |
|
|
|
grid_y, grid_x = torch.meshgrid( |
|
torch.linspace( |
|
0, H - 1, H, dtype=torch.float32, device=memory.device), |
|
torch.linspace( |
|
0, W - 1, W, dtype=torch.float32, device=memory.device)) |
|
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) |
|
|
|
scale = torch.cat([valid_W, valid_H], 1).view(bs, 1, 1, 2) |
|
grid = (grid.unsqueeze(0).expand(bs, -1, -1, -1) + 0.5) / scale |
|
            # proposal size: base `candidate_bboxes_size`, doubled per level
            wh = torch.ones_like(grid) * self.candidate_bboxes_size * (2.0**lvl)
|
|
|
proposal = torch.cat((grid, wh), -1).view(bs, -1, 4) |
|
proposals.append(proposal) |
|
_cur += (H * W) |
|
output_proposals = torch.cat(proposals, 1) |
|
output_proposals_valid = ((output_proposals > 0.01) & |
|
(output_proposals < 0.99)).all( |
|
-1, keepdim=True) |
|
|
|
        if self.htd_2s:
            # use a looser validity range for the proposals when `htd_2s` is set
            output_proposals_valid = ((output_proposals > 0.0001) &
                                      (output_proposals < 0.9999)).all(
                                          -1, keepdim=True)
|
|
|
output_proposals = torch.log(output_proposals / (1 - output_proposals)) |
|
output_proposals = output_proposals.masked_fill( |
|
memory_mask.unsqueeze(-1), float('inf')) |
|
output_proposals = output_proposals.masked_fill( |
|
~output_proposals_valid, float('inf')) |
|
output_memory = memory |
|
output_memory = output_memory.masked_fill( |
|
memory_mask.unsqueeze(-1), float(0)) |
|
output_memory = output_memory.masked_fill(~output_proposals_valid, |
|
float(0)) |
|
output_memory = self.memory_trans_fc(output_memory) |
|
output_memory = self.memory_trans_norm(output_memory) |
|
|
|
return output_memory, output_proposals |
|
|
|
|
|
    @staticmethod
    def rescale_gt_bboxes(batch_data_samples: OptSampleList,
                          scale_gt_bboxes_size: float = 0.25) -> OptSampleList:
        """Shrink each ground-truth box by ``scale_gt_bboxes_size`` units per
        side, in (x1, y1, x2, y2) format, and return the data samples."""
        for data_sample in batch_data_samples:
            gt_bboxes = data_sample.gt_instances.bboxes
            gt_bboxes[:, :2] = gt_bboxes[:, :2] + scale_gt_bboxes_size
            gt_bboxes[:, 2:] = gt_bboxes[:, 2:] - scale_gt_bboxes_size
        return batch_data_samples
|
|
|
|
|
|