# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmengine.model import BaseModule

from mmdet.registry import MODELS

from ..layers import PatchEmbed


def expand_tensor_along_second_dim(x, num):
    """Tile ``x`` (N, C, H, W) along dim 1 until it has exactly ``num`` channels."""
    assert x.size(1) <= num
    # Number of channels still missing after whole repeats of the original tensor.
    remainder = num % x.size(1)
    # Repeat the tensor along the channel dimension as many whole times as fit.
    repeat_times = num // x.size(1)
    x = x.repeat(1, repeat_times, 1, 1)
    # If num is not an integer multiple of the original channel count,
    # append the first `remainder` channels to reach exactly `num`.
    if remainder != 0:
        x = torch.cat([x, x[:, :remainder]], dim=1)
    return x


def extract_tensor_along_second_dim(x, m):
    """Select ``m`` evenly spaced channels from ``x`` along dim 1."""
    # Compute evenly spaced channel indices.
    idx = torch.linspace(0, x.size(1) - 1, m).long().to(x.device)
    # Gather the selected channels along the channel dimension.
    x = torch.index_select(x, 1, idx)
    return x

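# Usage sketch for the two channel helpers above (hypothetical shapes, not part
# of the original module):
#   x = torch.randn(2, 10, 64, 64)
#   expand_tensor_along_second_dim(x, 16).shape   # -> torch.Size([2, 16, 64, 64])
#   extract_tensor_along_second_dim(x, 3).shape   # -> torch.Size([2, 3, 64, 64])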

@MODELS.register_module()
class No_backbone_ST(BaseModule):
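    """A "no backbone" stem that maps an input image directly to feature maps.

    The input is projected to ``embed_dims`` with a 1x1 convolution followed by
    a normalization layer; when ``num_levels > 1`` a detached global-average
    (1x1) map is appended as a second output level. ``patch_embed`` and ``mlp``
    are instantiated but not used by the current ``forward``.
    """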
    def __init__(self,
                 in_channels=3,
                 embed_dims=96,
                 strides=(1, 2, 2, 4),
                 patch_size=(1, 2, 2, 4),
                 patch_norm=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 pretrained=None,
                 num_levels=2,
                 init_cfg=None):
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be specified at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            # Build init_cfg locally; BaseModule.__init__ would otherwise
            # overwrite an attribute assigned before super().__init__().
            init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is not None:
            raise TypeError('pretrained must be a str or None')

        super(No_backbone_ST, self).__init__(init_cfg=init_cfg)
        assert strides[0] == patch_size[0], 'Use non-overlapping patch embed.'
        self.embed_dims = embed_dims
        self.in_channels = in_channels

        self.patch_embed = PatchEmbed(
            in_channels=in_channels,
            embed_dims=embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size[0],
            stride=strides[0],
            norm_cfg=norm_cfg if patch_norm else None,
            init_cfg=None)
        self.num_levels = num_levels
        self.conv = nn.Conv2d(in_channels, embed_dims, kernel_size=1)
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, embed_dims),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(embed_dims, embed_dims),
            nn.LeakyReLU(negative_slope=0.2)
        )
        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
            # self.norm = build_norm_layer(norm_cfg, 128*128)[1]
        else:
            # forward() always applies self.norm, so fall back to an identity.
            self.norm = nn.Identity()

    def train(self, mode=True):
        """Convert the model into training mode while keeping layers frozen."""
        super(No_backbone_ST, self).train(mode)

    def forward(self, x):
        # Unused patch-embed path, kept for reference:
        # x, hw_shape = self.patch_embed(x)
        # outs = []
        # out = x.view(-1, *hw_shape, self.embed_dims).permute(0, 3, 1, 2).contiguous()

        # If the input has more channels than expected, subsample them evenly.
        if self.in_channels < x.size(1):
            x = extract_tensor_along_second_dim(x, self.in_channels)
        outs = []
        # Alternative: out = expand_tensor_along_second_dim(x, self.embed_dims)
        # Project to embed_dims with a 1x1 conv, then normalize over the
        # flattened tokens.
        out = self.conv(x)
        out = self.norm(out.flatten(2).transpose(1, 2))
        # BN variant:
        # out = self.norm(out.flatten(2)).transpose(1, 2)
        # MLP variant:
        # y = x.reshape(x.size(0), x.size(1), -1).permute(0, 2, 1)
        # out = self.mlp(y)
        # Restore the (B, C, H, W) layout expected by downstream necks/heads.
        out = out.permute(0, 2, 1).reshape(
            x.size(0), self.embed_dims, x.size(2), x.size(3)).contiguous()
        outs.append(out)
        if self.num_levels > 1:
            # Append a detached global-average (1x1) map as a coarser level.
            mean = outs[0].mean(dim=(2, 3), keepdim=True).detach()
            outs.append(mean)
        return outs
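

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original module): it assumes a
    # working mmdet/mmengine installation and only checks the output shapes of
    # the backbone on a random image-sized tensor.
    model = No_backbone_ST(in_channels=3, embed_dims=96, num_levels=2)
    model.eval()
    dummy = torch.randn(2, 3, 128, 128)
    with torch.no_grad():
        feats = model(dummy)
    # Expected for these settings: [(2, 96, 128, 128), (2, 96, 1, 1)].
    print([tuple(f.shape) for f in feats])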