# mmdet/models/backbones/st_nobackbone.py
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from collections import OrderedDict
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, build_dropout
from mmengine.logging import MMLogger
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import (constant_init, trunc_normal_,
                                        trunc_normal_init)
from mmengine.runner.checkpoint import CheckpointLoader
from mmengine.utils import to_2tuple
from typing import Optional, Sequence, Tuple, Union
from mmdet.registry import MODELS
from mmdet.utils import OptConfigType, OptMultiConfig
from torch import Tensor
from ..layers import AdaptivePadding, PatchEmbed, PatchMerging


def expand_tensor_along_second_dim(x, num):
    """Tile ``x`` along dim 1 (channels) until it has ``num`` channels."""
    orig_channels = x.size(1)
    assert orig_channels <= num
    # Repeat the whole tensor as many times as it fits into ``num`` channels.
    repeat_times = num // orig_channels
    x = x.repeat(1, repeat_times, 1, 1)
    # If ``num`` is not an exact multiple, append the first remaining channels.
    remainder = num - repeat_times * orig_channels
    if remainder > 0:
        x = torch.cat([x, x[:, :remainder]], dim=1)
    return x


def extract_tensor_along_second_dim(x, m):
    """Pick ``m`` evenly spaced channels from ``x`` along dim 1."""
    # Evenly spaced channel indices over the original channel range.
    idx = torch.linspace(0, x.size(1) - 1, m).long().to(x.device)
    # Gather the selected channels along the second dimension.
    x = torch.index_select(x, 1, idx)
    return x
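

def _channel_adapter_sketch():
    """Illustrative sketch, not used by the model: exercises the two channel
    helpers above. The function name is an addition for demonstration only."""
    x = torch.randn(2, 3, 8, 8)
    widened = expand_tensor_along_second_dim(x, 7)     # tile 3 -> 7 channels
    narrowed = extract_tensor_along_second_dim(x, 2)   # sample 3 -> 2 channels
    assert widened.shape == (2, 7, 8, 8)
    assert narrowed.shape == (2, 2, 8, 8)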


@MODELS.register_module()
class No_backbone_ST(BaseModule):
    """A minimal stem that replaces a real backbone: a 1x1 conv plus a norm
    layer mapping the input to ``embed_dims`` channels at full resolution,
    optionally followed by a detached global-average second level."""

    def __init__(self,
                 in_channels=3,
                 embed_dims=96,
                 strides=(1, 2, 2, 4),
                 patch_size=(1, 2, 2, 4),
                 patch_norm=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 pretrained=None,
                 num_levels=2,
                 init_cfg=None):
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be specified at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            # Convert the deprecated argument into an init_cfg and let
            # BaseModule.__init__ store it, instead of assigning self.init_cfg
            # before super().__init__() (which would overwrite it).
            init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is not None:
            raise TypeError('pretrained must be a str or None')
        super(No_backbone_ST, self).__init__(init_cfg=init_cfg)
        assert strides[0] == patch_size[0], 'Use non-overlapping patch embed.'
        self.embed_dims = embed_dims
        self.in_channels = in_channels
        # Kept for compatibility with the Swin-style interface; the forward
        # pass below does not use the patch embedding.
        self.patch_embed = PatchEmbed(
            in_channels=in_channels,
            embed_dims=embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size[0],
            stride=strides[0],
            norm_cfg=norm_cfg if patch_norm else None,
            init_cfg=None)
        self.num_levels = num_levels
        # 1x1 projection from the raw input channels to the embedding width.
        self.conv = nn.Conv2d(in_channels, embed_dims, kernel_size=1)
        # Per-pixel MLP alternative to the 1x1 conv (only used by the
        # commented-out variant in forward()).
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, embed_dims),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(embed_dims, embed_dims),
            nn.LeakyReLU(negative_slope=0.2))
        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
            # self.norm = build_norm_layer(norm_cfg, 128 * 128)[1]
        else:
            # forward() always calls self.norm, so fall back to an identity.
            self.norm = nn.Identity()

    def train(self, mode=True):
        """Switch the model to training mode; no layers are frozen here."""
        super(No_backbone_ST, self).train(mode)

    def forward(self, x):
        # Patch-embedding variant (disabled):
        # x, hw_shape = self.patch_embed(x)
        # out = x.view(-1, *hw_shape, self.embed_dims).permute(0, 3, 1, 2).contiguous()
        # If the input has more channels than expected, subsample them evenly.
        if self.in_channels < x.size(1):
            x = extract_tensor_along_second_dim(x, self.in_channels)
        outs = []
        # out = expand_tensor_along_second_dim(x, self.embed_dims)
        # 1x1 conv to embed_dims, then LayerNorm over the channel dimension:
        # (B, C, H, W) -> (B, H*W, C) so the norm acts per spatial position.
        out = self.conv(x)
        out = self.norm(out.flatten(2).transpose(1, 2))
        # BN variant: out = self.norm(out.flatten(2)).transpose(1, 2)
        # MLP variant:
        # y = x.reshape(x.size(0), x.size(1), -1).permute(0, 2, 1)
        # out = self.mlp(y)
        # Back to (B, embed_dims, H, W).
        out = out.permute(0, 2, 1).reshape(
            x.size(0), self.embed_dims, x.size(2), x.size(3)).contiguous()
        outs.append(out)
        if self.num_levels > 1:
            # Second level: detached global average of the first level.
            mean = outs[0].mean(dim=(2, 3), keepdim=True).detach()
            outs.append(mean)
        return outs
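

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original file). Assumes
    # torch, mmcv, mmengine and mmdet are installed so the imports above
    # resolve; the config dict below mirrors the constructor defaults.
    cfg = dict(
        type='No_backbone_ST',
        in_channels=3,
        embed_dims=96,
        num_levels=2,
        norm_cfg=dict(type='LN'))
    backbone = MODELS.build(cfg)
    backbone.eval()
    with torch.no_grad():
        feats = backbone(torch.randn(2, 3, 64, 64))
    # One full-resolution feature map plus its detached global-average summary.
    print([tuple(f.shape) for f in feats])  # [(2, 96, 64, 64), (2, 96, 1, 1)]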