from typing import Dict, Tuple

from mmengine.model import uniform_init
from torch import Tensor, nn

from mmdet.registry import MODELS
from ..layers import SinePositionalEncoding
from ..layers.transformer import (DABDetrTransformerDecoder,
                                  DABDetrTransformerEncoder, inverse_sigmoid)
from .detr import DETR


@MODELS.register_module()
class DABDETR(DETR):
r"""Implementation of `DAB-DETR: |
|
Dynamic Anchor Boxes are Better Queries for DETR. |
|
|
|
<https://arxiv.org/abs/2201.12329>`_. |
|
|
|
Code is modified from the `official github repo |
|
<https://github.com/IDEA-Research/DAB-DETR>`_. |
|
|
|
Args: |
|
with_random_refpoints (bool): Whether to randomly initialize query |
|
embeddings and not update them during training. |
|
Defaults to False. |
|
num_patterns (int): Inspired by Anchor-DETR. Defaults to 0. |
|
""" |

    def __init__(self,
                 *args,
                 with_random_refpoints: bool = False,
                 num_patterns: int = 0,
                 **kwargs) -> None:
        self.with_random_refpoints = with_random_refpoints
        assert isinstance(num_patterns, int), \
            f'num_patterns should be int but got {num_patterns}.'
        self.num_patterns = num_patterns

        super().__init__(*args, **kwargs)

    def _init_layers(self) -> None:
        """Initialize layers except for backbone, neck and bbox_head."""
        self.positional_encoding = SinePositionalEncoding(
            **self.positional_encoding)
        self.encoder = DABDetrTransformerEncoder(**self.encoder)
        self.decoder = DABDetrTransformerDecoder(**self.decoder)
        self.embed_dims = self.encoder.embed_dims
        self.query_dim = self.decoder.query_dim
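        # In DAB-DETR the learnable queries are `query_dim`-dimensional
        # anchors (typically (x, y, w, h) boxes) rather than
        # `embed_dims`-wide content embeddings, hence the narrower
        # embedding below.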
        self.query_embedding = nn.Embedding(self.num_queries, self.query_dim)
        if self.num_patterns > 0:
            self.patterns = nn.Embedding(self.num_patterns, self.embed_dims)
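
        # SinePositionalEncoding produces `num_feats` channels per spatial
        # axis and concatenates the x- and y-axis embeddings, so the model
        # width must be exactly 2 * num_feats.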
        num_feats = self.positional_encoding.num_feats
        assert num_feats * 2 == self.embed_dims, \
            f'embed_dims should be exactly 2 times num_feats. ' \
            f'Found {self.embed_dims} and {num_feats}.'

    def init_weights(self) -> None:
        """Initialize weights for Transformer and other components."""
        super(DABDETR, self).init_weights()
        if self.with_random_refpoints:
            uniform_init(self.query_embedding)
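            # The first two dims of each query act as the (x, y) reference
            # point; map them to logit space so the decoder's later
            # `sigmoid` recovers the randomly drawn points, then mark the
            # slice as `requires_grad = False` so the reference points are
            # not meant to be updated during training.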
            self.query_embedding.weight.data[:, :2] = \
                inverse_sigmoid(self.query_embedding.weight.data[:, :2])
            self.query_embedding.weight.data[:, :2].requires_grad = False

    def pre_decoder(self, memory: Tensor) -> Tuple[Dict, Dict]:
        """Prepare intermediate variables before entering the Transformer
        decoder, such as `query` and `query_pos`.

        Args:
            memory (Tensor): The output embeddings of the Transformer
                encoder, has shape (bs, num_feat_points, dim).

        Returns:
            tuple[dict, dict]: The first dict contains the inputs of the
            decoder and the second dict contains the inputs of the
            bbox_head function.

            - decoder_inputs_dict (dict): The keyword args dictionary of
              `self.forward_decoder()`, which includes 'query', 'query_pos'
              and 'memory'.
            - head_inputs_dict (dict): The keyword args dictionary of the
              bbox_head function, which is usually empty, or includes
              `enc_outputs_class` and `enc_outputs_coord` when the detector
              supports 'two stage' or 'query selection' strategies.
        """
        batch_size = memory.size(0)
        query_pos = self.query_embedding.weight
        query_pos = query_pos.unsqueeze(0).repeat(batch_size, 1, 1)
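        # With no patterns, the content queries start as zeros and only the
        # learned anchor boxes in `query_pos` drive the decoder. With
        # `num_patterns > 0`, every positional query is paired with each
        # pattern embedding, giving num_queries * num_patterns queries.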
        if self.num_patterns == 0:
            query = query_pos.new_zeros(batch_size, self.num_queries,
                                        self.embed_dims)
        else:
            query = self.patterns.weight[:, None, None, :]\
                .repeat(1, self.num_queries, batch_size, 1)\
                .view(-1, batch_size, self.embed_dims)\
                .permute(1, 0, 2)
            query_pos = query_pos.repeat(1, self.num_patterns, 1)
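
        # This implementation does not use the 'two stage' or 'query
        # selection' strategies, so `head_inputs_dict` stays empty.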
        decoder_inputs_dict = dict(
            query_pos=query_pos, query=query, memory=memory)
        head_inputs_dict = dict()
        return decoder_inputs_dict, head_inputs_dict

    def forward_decoder(self, query: Tensor, query_pos: Tensor,
                        memory: Tensor, memory_mask: Tensor,
                        memory_pos: Tensor) -> Dict:
        """Forward with Transformer decoder.

        Args:
            query (Tensor): The queries of decoder inputs, has shape
                (bs, num_queries, dim).
            query_pos (Tensor): The positional queries of decoder inputs,
                has shape (bs, num_queries, dim).
            memory (Tensor): The output embeddings of the Transformer
                encoder, has shape (bs, num_feat_points, dim).
            memory_mask (Tensor): ByteTensor, the padding mask of the
                memory, has shape (bs, num_feat_points).
            memory_pos (Tensor): The positional embeddings of memory, has
                shape (bs, num_feat_points, dim).

        Returns:
            dict: The dictionary of decoder outputs, which includes the
            `hidden_states` and `references` of the decoder output.
        """
        hidden_states, references = self.decoder(
            query=query,
            key=memory,
            query_pos=query_pos,
            key_pos=memory_pos,
            key_padding_mask=memory_mask,
            reg_branches=self.bbox_head.fc_reg)
        head_inputs_dict = dict(
            hidden_states=hidden_states, references=references)
        return head_inputs_dict