from typing import Dict

import torch.nn as nn
from torch import Tensor

from mmdet.registry import MODELS
from ..layers import (ConditionalDetrTransformerDecoder,
                      DetrTransformerEncoder, SinePositionalEncoding)
from .detr import DETR


@MODELS.register_module()
class ConditionalDETR(DETR):
    r"""Implementation of `Conditional DETR for Fast Training Convergence
    <https://arxiv.org/abs/2108.06152>`_.

    Code is modified from the `official github repo
    <https://github.com/Atten4Vis/ConditionalDETR>`_.
    """

    def _init_layers(self) -> None:
        """Initialize layers except for backbone, neck and bbox_head."""
        self.positional_encoding = SinePositionalEncoding(
            **self.positional_encoding)
        self.encoder = DetrTransformerEncoder(**self.encoder)
        self.decoder = ConditionalDetrTransformerDecoder(**self.decoder)
        self.embed_dims = self.encoder.embed_dims
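        # NOTE: `embed_dims` is read back from the encoder; it is typically
        # propagated outwards from the self-attention config of the first
        # encoder layer. Each of the `num_queries` object queries then gets
        # a learnable embedding of this size.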
        self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims)

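        # SinePositionalEncoding embeds the x and y coordinates with
        # `num_feats` channels each and concatenates them, so the encoding
        # only matches `embed_dims` when embed_dims == 2 * num_feats.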
        num_feats = self.positional_encoding.num_feats
        assert num_feats * 2 == self.embed_dims, \
            'embed_dims should be exactly twice num_feats. ' \
            f'Found {self.embed_dims} and {num_feats}.'

    def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor,
                        memory_mask: Tensor, memory_pos: Tensor) -> Dict:
        """Forward with Transformer decoder.

        Args:
            query (Tensor): The queries of decoder inputs, has shape
                (bs, num_queries, dim).
            query_pos (Tensor): The positional queries of decoder inputs,
                has shape (bs, num_queries, dim).
            memory (Tensor): The output embeddings of the Transformer
                encoder, has shape (bs, num_feat_points, dim).
            memory_mask (Tensor): ByteTensor, the padding mask of the
                memory, has shape (bs, num_feat_points).
            memory_pos (Tensor): The positional embeddings of memory, has
                shape (bs, num_feat_points, dim).

        Returns:
            dict: The dictionary of decoder outputs, containing the
            `hidden_states` and `references` of the decoder.

            - hidden_states (Tensor): Has shape
              (num_decoder_layers, bs, num_queries, dim).
            - references (Tensor): Has shape (bs, num_queries, 2).
        """
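        # Unlike the vanilla DETR decoder, the conditional decoder also
        # returns the 2-d reference points derived from the positional
        # queries, which the bbox head uses when predicting boxes.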
        hidden_states, references = self.decoder(
            query=query,
            key=memory,
            query_pos=query_pos,
            key_pos=memory_pos,
            key_padding_mask=memory_mask)
        head_inputs_dict = dict(
            hidden_states=hidden_states, references=references)
        return head_inputs_dict