Upload 45 files
- linea/configs/linea/include/dataset.py +10 -0
- linea/configs/linea/include/linea.py +62 -0
- linea/configs/linea/include/optimizer.py +9 -0
- linea/configs/linea/linea_hgnetv2_l.py +56 -0
- linea/configs/linea/linea_hgnetv2_m.py +63 -0
- linea/configs/linea/linea_hgnetv2_n.py +63 -0
- linea/configs/linea/linea_hgnetv2_s.py +64 -0
- linea/models/__init__.py +8 -0
- linea/models/__pycache__/__init__.cpython-311.pyc +0 -0
- linea/models/__pycache__/registry.cpython-311.pyc +0 -0
- linea/models/linea/__init__.py +11 -0
- linea/models/linea/__pycache__/__init__.cpython-311.pyc +0 -0
- linea/models/linea/__pycache__/attention_mechanism.cpython-311.pyc +0 -0
- linea/models/linea/__pycache__/criterion.cpython-311.pyc +0 -0
- linea/models/linea/__pycache__/decoder.cpython-311.pyc +0 -0
- linea/models/linea/__pycache__/dn_components.cpython-311.pyc +0 -0
- linea/models/linea/__pycache__/hgnetv2.cpython-311.pyc +0 -0
- linea/models/linea/__pycache__/hybrid_encoder_asymmetric_conv.cpython-311.pyc +0 -0
- linea/models/linea/__pycache__/linea.cpython-311.pyc +0 -0
- linea/models/linea/__pycache__/linea_utils.cpython-311.pyc +0 -0
- linea/models/linea/__pycache__/matcher.cpython-311.pyc +0 -0
- linea/models/linea/__pycache__/utils.cpython-311.pyc +0 -0
- linea/models/linea/attention_mechanism.py +593 -0
- linea/models/linea/criterion.py +517 -0
- linea/models/linea/decoder.py +551 -0
- linea/models/linea/dn_components.py +178 -0
- linea/models/linea/hgnetv2.py +595 -0
- linea/models/linea/hybrid_encoder.py +471 -0
- linea/models/linea/hybrid_encoder_asymmetric_conv.py +549 -0
- linea/models/linea/linea.py +156 -0
- linea/models/linea/linea_utils.py +165 -0
- linea/models/linea/matcher.py +180 -0
- linea/models/linea/new_dn_components.py +163 -0
- linea/models/linea/position_encoding.py +150 -0
- linea/models/linea/utils.py +139 -0
- linea/models/registry.py +58 -0
- linea/requirements.txt +9 -0
- linea/util/__init__.py +1 -0
- linea/util/__pycache__/__init__.cpython-311.pyc +0 -0
- linea/util/__pycache__/misc.cpython-311.pyc +0 -0
- linea/util/__pycache__/slconfig.cpython-311.pyc +0 -0
- linea/util/get_param_dicts.py +35 -0
- linea/util/misc.py +275 -0
- linea/util/profiler.py +21 -0
- linea/util/slconfig.py +440 -0
linea/configs/linea/include/dataset.py
ADDED
@@ -0,0 +1,10 @@
data_aug_scales = [(640, 640)]
data_aug_max_size = 1333
data_aug_scales2_resize = [400, 500, 600]
data_aug_scales2_crop = [384, 600]


data_aug_scale_overlap = None
batch_size_train = 8
batch_size_val = 64
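Note (not part of the upload): DETR-style data pipelines usually consume settings like these by drawing a random target scale per image and capping the longer side at data_aug_max_size. A minimal sketch under that assumption; the function below is illustrative and not a helper from this repo:

import random
import torchvision.transforms.functional as TF

def random_resize(img, scales=((640, 640),), max_size=1333):
    # pick one target (h, w) from the configured scale list
    h, w = random.choice(scales)
    # cap the longer side at max_size while keeping the requested aspect ratio
    if max(h, w) > max_size:
        ratio = max_size / max(h, w)
        h, w = int(h * ratio), int(w * ratio)
    return TF.resize(img, [h, w])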
linea/configs/linea/include/linea.py
ADDED
@@ -0,0 +1,62 @@
# model
modelname = 'LINEA'
eval_spatial_size = (640, 640)
eval_idx = 5  # 6 decoder layers
num_classes = 2

## backbone
pretrained = True
use_checkpoint = False
return_interm_indices = [1, 2, 3]
freeze_norm = True
freeze_stem_only = True

## encoder
hybrid_encoder = 'hybrid_encoder_asymmetric_conv'
in_channels_encoder = [512, 1024, 2048]
pe_temperatureH = 20
pe_temperatureW = 20

## encoder
transformer_activation = 'relu'
batch_norm_type = 'FrozenBatchNorm2d'
masks = False
aux_loss = True

## decoder
num_queries = 1100
query_dim = 4
num_feature_levels = 3
dec_n_points = [4, 1, 1]
dropout = 0.0
pre_norm = False

# denoise
use_dn = True
dn_number = 300
dn_line_noise_scale = 1.0
dn_label_noise_ratio = 0.5
embed_init_tgt = True
dn_labelbook_size = 2
match_unstable_error = True

# matcher
set_cost_class = 2.0
set_cost_lines = 5.0

# criterion
criterionname = 'LINEACRITERION'
criterion_type = 'default'
weight_dict = {'loss_logits': 1, 'loss_line': 5}
losses = ['labels', 'lines']
focal_alpha = 0.1

matcher_type = 'HungarianMatcher'  # or SimpleMinsumMatcher
nms_iou_threshold = -1

# for ema
use_ema = False
ema_decay = 0.9997
ema_epoch = 0
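Note (not part of the upload): these config files are plain Python modules whose top-level variables are the configuration. The repo ships linea/util/slconfig.py for loading them; as a rough illustration of the general idea (not the actual SLConfig API), a module's top-level names can be collected into a dict like this:

import importlib.util

def load_py_config(path):
    # execute the config file as a module and keep its top-level variables
    spec = importlib.util.spec_from_file_location("cfg", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return {k: v for k, v in vars(module).items() if not k.startswith('__')}

# cfg = load_py_config('linea/configs/linea/include/linea.py')
# cfg['num_queries']  # -> 1100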
linea/configs/linea/include/optimizer.py
ADDED
@@ -0,0 +1,9 @@
lr = 0.00025
weight_decay = 0.000125
betas = [0.9, 0.999]

epochs = 12
lr_drop_list = [11]
clip_max_norm = 0.1

save_checkpoint_interval = 1
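Note (not part of the upload): these values map naturally onto AdamW plus a step schedule. A hedged sketch of how they are typically wired up; the repo's actual training loop may differ:

import torch

def build_optimizer(model, lr=0.00025, betas=(0.9, 0.999), weight_decay=0.000125,
                    lr_drop_list=(11,)):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=betas,
                                  weight_decay=weight_decay)
    # drop the learning rate by 10x at the epochs listed in lr_drop_list
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=list(lr_drop_list))
    return optimizer, scheduler

# during training, gradients are usually clipped with clip_max_norm:
# torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)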
linea/configs/linea/linea_hgnetv2_l.py
ADDED
@@ -0,0 +1,56 @@
_base_ = [
    './include/dataset.py',
    './include/optimizer.py',
    './include/linea.py'
]

output_dir = 'output/line_hgnetv2_l'

# backbone
backbone = 'HGNetv2_B4'
param_dict_type = backbone.lower()
use_lab = False


# transformer
feat_strides = [8, 16, 32]
hidden_dim = 256
dim_feedforward = 1024
nheads = 8
use_lmap = False

## encoder
hybrid_encoder = 'hybrid_encoder_asymmetric_conv'
in_channels_encoder = [512, 1024, 2048]
pe_temperatureH = 20
pe_temperatureW = 20
expansion = 0.5
depth_mult = 1.0

## decoder
feat_channels_decoder = [256, 256, 256]
dec_layers = 6
num_queries = 1100
num_select = 300
reg_max = 16
reg_scale = 4

# criterion
weight_dict = {'loss_logits': 4, 'loss_line': 5}
use_warmup = False

# optimizer params
model_parameters = [
    {
        'params': '^(?=.*backbone)(?!.*norm|bn).*$',
        'lr': 0.0000125
    },

    {
        'params': '^(?=.*(?:encoder|decoder))(?=.*(?:norm|bn)).*$',
        'weight_decay': 0.
    }
]
lr = 0.00025
betas = [0.9, 0.999]
weight_decay = 0.000125
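Note (not part of the upload): the 'params' entries above are regular expressions over parameter names (the look-aheads select, e.g., backbone weights that are not norm/bn layers). The repo's linea/util/get_param_dicts.py presumably resolves them into optimizer param groups; a minimal sketch of that idea, with a hypothetical helper name and first-match-wins semantics assumed:

import re

def build_param_groups(model, model_parameters, lr, weight_decay):
    named = dict(model.named_parameters())
    groups, claimed = [], set()
    for rule in model_parameters:
        pattern = re.compile(rule['params'])
        names = [n for n, p in named.items()
                 if p.requires_grad and pattern.match(n) and n not in claimed]
        claimed.update(names)
        # per-group overrides (lr, weight_decay) come straight from the config entry
        group = {k: v for k, v in rule.items() if k != 'params'}
        group['params'] = [named[n] for n in names]
        groups.append(group)
    # everything not matched by any rule falls back to the global lr / weight_decay
    rest = [p for n, p in named.items() if p.requires_grad and n not in claimed]
    groups.append({'params': rest, 'lr': lr, 'weight_decay': weight_decay})
    return groups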
linea/configs/linea/linea_hgnetv2_m.py
ADDED
@@ -0,0 +1,63 @@
_base_ = [
    './include/dataset.py',
    './include/optimizer.py',
    './include/linea.py'
]

output_dir = 'output/line_hgnetv2_m'

# backbone
backbone = 'HGNetv2_B2'
use_lab = True
freeze_norm = False
freeze_stem_only = True

# transformer
feat_strides = [8, 16, 32]
hidden_dim = 256
dim_feedforward = 512
nheads = 8
use_lmap = False

## encoder
hybrid_encoder = 'hybrid_encoder_asymmetric_conv'
in_channels_encoder = [384, 768, 1536]
pe_temperatureH = 20
pe_temperatureW = 20
expansion = 0.34
depth_mult = 1.0

## decoder
feat_channels_decoder = [hidden_dim, hidden_dim, hidden_dim]
dec_layers = 4
num_queries = 1100
num_select = 300
reg_max = 16
reg_scale = 4
eval_idx = 3

# criterion
epochs = 24
lr_drop_list = [20]
weight_dict = {'loss_logits': 2, 'loss_line': 5}
use_warmup = False

# optimizer params
model_parameters = [
    {
        'params': '^(?=.*backbone)(?!.*norm|bn).*$',
        'lr': 0.00002
    },
    {
        'params': '^(?=.*backbone)(?=.*norm|bn).*$',
        'lr': 0.00002,
        'weight_decay': 0.
    },
    {
        'params': '^(?=.*(?:encoder|decoder))(?=.*(?:norm|bn|bias)).*$',
        'weight_decay': 0.
    }
]
lr = 0.0002
betas = [0.9, 0.999]
weight_decay = 0.0001
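Note (not part of the upload): as in the L config, `_base_` pulls in the shared include files and any value defined afterwards overrides them. Roughly (the real merge is done by SLConfig and may differ in details; the loader argument can be the load_py_config sketch shown earlier):

import os

def load_with_base(path, loader):
    # loader turns a plain .py config into a dict of its top-level names
    cfg = loader(path)
    merged = {}
    for rel in cfg.pop('_base_', []):
        # include files are resolved relative to the child config ...
        merged.update(load_with_base(os.path.join(os.path.dirname(path), rel), loader))
    # ... and keys defined in the child win over the includes
    merged.update(cfg)
    return merged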
linea/configs/linea/linea_hgnetv2_n.py
ADDED
@@ -0,0 +1,63 @@
_base_ = [
    './include/dataset.py',
    './include/optimizer.py',
    './include/linea.py'
]

output_dir = 'output/line_hgnetv2_n'

# backbone
backbone = 'HGNetv2_B0'
use_lab = True
freeze_norm = False
freeze_stem_only = True

# transformer
feat_strides = [8, 16, 32]
hidden_dim = 128
dim_feedforward = 512
nheads = 8
use_lmap = False

## encoder
hybrid_encoder = 'hybrid_encoder_asymmetric_conv'
in_channels_encoder = [256, 512, 1024]
pe_temperatureH = 20
pe_temperatureW = 20
expansion = 0.34
depth_mult = 0.5

## decoder
feat_channels_decoder = [hidden_dim, hidden_dim, hidden_dim]
dec_layers = 3
num_queries = 1100
num_select = 300
reg_max = 16
reg_scale = 4
eval_idx = 2

# criterion
epochs = 72
lr_drop_list = [60]
weight_dict = {'loss_logits': 2, 'loss_line': 5}
use_warmup = False

# optimizer params
model_parameters = [
    {
        'params': '^(?=.*backbone)(?!.*norm|bn).*$',
        'lr': 0.0004
    },
    {
        'params': '^(?=.*backbone)(?=.*norm|bn).*$',
        'lr': 0.0004,
        'weight_decay': 0.
    },
    {
        'params': '^(?=.*(?:encoder|decoder))(?=.*(?:norm|bn|bias)).*$',
        'weight_decay': 0.
    }
]
lr = 0.0008
betas = [0.9, 0.999]
weight_decay = 0.0001
linea/configs/linea/linea_hgnetv2_s.py
ADDED
@@ -0,0 +1,64 @@
_base_ = [
    './include/dataset.py',
    './include/optimizer.py',
    './include/linea.py'
]

output_dir = 'output/line_hgnetv2_s'

# backbone
backbone = 'HGNetv2_B1'
use_lab = True
freeze_norm = False
freeze_stem_only = True

# transformer
feat_strides = [8, 16, 32]
hidden_dim = 256
dim_feedforward = 512
nheads = 8
use_lmap = False

## encoder
hybrid_encoder = 'hybrid_encoder_asymmetric_conv'
in_channels_encoder = [256, 512, 1024]
pe_temperatureH = 20
pe_temperatureW = 20
expansion = 0.34
depth_mult = 0.5

## decoder
feat_channels_decoder = [hidden_dim, hidden_dim, hidden_dim]
dec_layers = 3
num_queries = 1100
num_select = 300
reg_max = 16
reg_scale = 4
eval_idx = 2

# criterion
epochs = 36
lr_drop_list = [25]
weight_dict = {'loss_logits': 2, 'loss_line': 5}
use_warmup = True
warmup_iters = 625 * 5

# optimizer params
model_parameters = [
    {
        'params': '^(?=.*backbone)(?!.*norm|bn).*$',
        'lr': 0.0001
    },
    {
        'params': '^(?=.*backbone)(?=.*norm|bn).*$',
        'lr': 0.0001,
        'weight_decay': 0.
    },
    {
        'params': '^(?=.*(?:encoder|decoder))(?=.*(?:norm|bn|bias)).*$',
        'weight_decay': 0.
    }
]
lr = 0.0002
betas = [0.9, 0.999]
weight_decay = 0.0001
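Note (not part of the upload): pulling the numbers out of the four variant configs above, the main knobs that change across N/S/M/L are summarized below; all values are copied from the diffs (L inherits eval_idx = 5 and epochs = 12 from the include files). Only the S config enables warm-up (use_warmup = True, warmup_iters = 625 * 5 iterations).

# (backbone, hidden_dim, dim_feedforward, dec_layers, eval_idx, epochs)
LINEA_VARIANTS = {
    'n': ('HGNetv2_B0', 128, 512, 3, 2, 72),
    's': ('HGNetv2_B1', 256, 512, 3, 2, 36),
    'm': ('HGNetv2_B2', 256, 512, 4, 3, 24),
    'l': ('HGNetv2_B4', 256, 1024, 6, 5, 12),
}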
linea/models/__init__.py
ADDED
@@ -0,0 +1,8 @@
# ------------------------------------------------------------------------
# DINO
# Copyright (c) 2022 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .linea import build_linea
linea/models/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (210 Bytes)
linea/models/__pycache__/registry.cpython-311.pyc
ADDED
Binary file (3.14 kB)
linea/models/linea/__init__.py
ADDED
@@ -0,0 +1,11 @@
# ------------------------------------------------------------------------
# Conditional DETR
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Copied from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------

from .linea import build_linea
from .criterion import build_criterion
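Note (not part of the upload): criterion.py below imports MODULE_BUILD_FUNCS from linea/models/registry.py, which suggests a registry of build functions keyed by config name. The registry's actual API is not shown in this upload; a generic sketch of the pattern, with illustrative names:

# a generic build-function registry; the real one lives in linea/models/registry.py
class Registry(dict):
    def register(self, name):
        def wrapper(fn):
            self[name] = fn
            return fn
        return wrapper

MODULE_BUILD_FUNCS = Registry()

@MODULE_BUILD_FUNCS.register('LINEA')
def build_linea_from_cfg(cfg):
    ...  # construct the model/criterion from the merged config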
linea/models/linea/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (281 Bytes)
linea/models/linea/__pycache__/attention_mechanism.cpython-311.pyc
ADDED
Binary file (17.4 kB)
linea/models/linea/__pycache__/criterion.cpython-311.pyc
ADDED
Binary file (40.7 kB)
linea/models/linea/__pycache__/decoder.cpython-311.pyc
ADDED
Binary file (31.3 kB)
linea/models/linea/__pycache__/dn_components.cpython-311.pyc
ADDED
Binary file (11.9 kB)
linea/models/linea/__pycache__/hgnetv2.cpython-311.pyc
ADDED
Binary file (25.1 kB)
linea/models/linea/__pycache__/hybrid_encoder_asymmetric_conv.cpython-311.pyc
ADDED
Binary file (35 kB)
linea/models/linea/__pycache__/linea.cpython-311.pyc
ADDED
Binary file (7.82 kB)
linea/models/linea/__pycache__/linea_utils.cpython-311.pyc
ADDED
Binary file (10.1 kB)
linea/models/linea/__pycache__/matcher.cpython-311.pyc
ADDED
Binary file (10.6 kB)
linea/models/linea/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (9.26 kB)
linea/models/linea/attention_mechanism.py
ADDED
@@ -0,0 +1,593 @@
import torch
import torch.nn.functional as F

from torch import nn
from torch.nn.init import xavier_uniform_, constant_

import math
import warnings  # used by the power-of-2 check below; missing from the original imports

def _is_power_of_2(n):
    if (not isinstance(n, int)) or (n < 0):
        raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
    return (n & (n-1) == 0) and n != 0

def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights, total_num_points):
    # for debug and test only,
    # need to use cuda version instead
    N_, S_, M_, D_ = value.shape
    _, Lq_, M_, P_, _ = sampling_locations[0].shape
    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
    # sampling_grids = 2 * sampling_locations - 1
    sampling_value_list = []
    for lid_, (H_, W_) in enumerate(value_spatial_shapes):
        # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
        value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
        # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
        sampling_grid_l_ = (2 * sampling_locations[lid_] - 1).transpose(1, 2).flatten(0, 1)
        # N_*M_, D_, Lq_, P_
        sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
                                          mode='bilinear', padding_mode='zeros', align_corners=False)
        sampling_value_list.append(sampling_value_l_)
    # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
    attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, total_num_points)
    output = (torch.cat(sampling_value_list, dim=-1) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
    return output.transpose(1, 2).contiguous()

def ms_deform_attn_core_pytorchv2(value, value_spatial_shapes, sampling_locations, attention_weights, num_points_list):
    # for debug and test only,
    # need to use cuda version instead
    _, D_, _ = value[0].shape
    N_, Lq_, M_, _, _ = sampling_locations.shape

    sampling_grids = 2 * sampling_locations - 1
    sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1)
    sampling_locations_list = sampling_grids.split(num_points_list, dim=-2)

    sampling_value_list = []
    for lid_, (H_, W_) in enumerate(value_spatial_shapes):
        # N_*M_, D_, H_*W_ -> N_*M_, D_, H_, W_
        value_l_ = value[lid_].unflatten(2, (H_, W_))
        # N_*M_, Lq_, P_, 2
        sampling_grid_l_ = sampling_locations_list[lid_]
        # N_*M_, D_, Lq_, P_
        sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
                                          mode='bilinear', padding_mode='zeros', align_corners=False)
        sampling_value_list.append(sampling_value_l_)
    # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
    attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, sum(num_points_list))
    output = (torch.cat(sampling_value_list, dim=-1) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
    return output.transpose(1, 2).contiguous()


class MSDeformAttn(nn.Module):
    def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
        """
        Multi-Scale Deformable Attention Module
        :param d_model      hidden dimension
        :param n_levels     number of feature levels
        :param n_heads      number of attention heads
        :param n_points     number of sampling points per attention head per feature level
        """
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
        _d_per_head = d_model // n_heads
        # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
        if not _is_power_of_2(_d_per_head):
            warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
                          "which is more efficient in our CUDA implementation.")

        self.d_model = d_model
        self.n_levels = n_levels
        self.n_heads = n_heads
        self.n_points = n_points

        self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
        self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
        self.value_proj = nn.Linear(d_model, d_model)
        self.output_proj = nn.Linear(d_model, d_model)

        self._reset_parameters()

    def _reset_parameters(self):
        constant_(self.sampling_offsets.weight.data, 0.)
        thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
        for i in range(self.n_points):
            grid_init[:, :, i, :] *= i + 1
        with torch.no_grad():
            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
        constant_(self.attention_weights.weight.data, 0.)
        constant_(self.attention_weights.bias.data, 0.)
        xavier_uniform_(self.value_proj.weight.data)
        constant_(self.value_proj.bias.data, 0.)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.)

    def forward(self, query, reference_points, input_flatten, input_spatial_shapes):
        """
        :param query                       (N, Length_{query}, C)
        :param reference_points           (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
                                          or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
        :param input_flatten              (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
        :param input_spatial_shapes       (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]

        :return output                    (N, Length_{query}, C)
        """
        N, Len_q, _ = query.shape
        N, Len_in, _ = input_flatten.shape

        value = self.value_proj(input_flatten)
        value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
        sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
        attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
        attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
        # N, Len_q, n_heads, n_levels, n_points, 2
        if reference_points.shape[-1] == 2:
            offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
            sampling_locations = reference_points[:, :, None, :, None, :] \
                                 + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
        elif reference_points.shape[-1] == 4:
            sampling_locations = reference_points[:, :, None, :, None, :2] \
                                 + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
        else:
            raise ValueError(
                'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
        # NOTE: ms_deform_attn_core_pytorch now expects per-level sampling-location lists plus a
        # total point count (see MSDeformLineAttn below); this legacy call site predates that signature.
        output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)
        output = self.output_proj(output)
        return output

class MSDeformLineAttn(nn.Module):
    def __init__(
        self,
        d_model=256,
        n_levels=4,
        n_heads=8,
        n_points=4
    ):
        """
        This version is inspired by DFine. We removed the following layers:
            - value_proj
            - output_proj
        """
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
        _d_per_head = d_model // n_heads

        self.d_model = d_model
        self.n_levels = n_levels
        self.n_heads = n_heads

        if isinstance(n_points, list):
            assert len(n_points) == n_levels, ''
            num_points_list = n_points
        else:
            num_points_list = [n_points for _ in range(n_levels)]
        self.num_points_list = num_points_list
        self.total_num_points = sum(num_points_list)

        num_points_scale = [1/n for n in num_points_list for _ in range(n)]
        self.register_buffer('num_points_scale', torch.tensor(num_points_scale, dtype=torch.float32).reshape(-1, 1))

        self.sampling_ratios = nn.Linear(d_model, n_heads * sum(num_points_list))
        self.attention_weights = nn.Linear(d_model, n_heads * sum(num_points_list))

        self._reset_parameters()

    def _reset_parameters(self):
        constant_(self.sampling_ratios.weight.data, 0.)
        with torch.no_grad():
            self.sampling_ratios.bias = nn.Parameter(torch.linspace(-1, 1, self.n_heads * self.total_num_points))

        constant_(self.attention_weights.weight.data, 0.)
        constant_(self.attention_weights.bias.data, 0.)

    def forward(self, query, reference_points, value, value_spatial_shapes):
        """
        :param query                       (N, Length_{query}, C)
        :param reference_points           (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
        :param value                      (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
        :param value_spatial_shapes       (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
        :param input_level_start_index    (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]

        :return output                    (N, Length_{query}, C)

        ####################################################################
        # Difference with respect to MSDeformAttn
        # The query already stores the line's junctions
        # :param reference_points is not needed. We keep it to make both
          MSDeformAttn and MSDeformLineAttn interchangeable
          between different frameworks
        # MSDeformLineAttn does not generate offsets. Instead, it samples
          n_points equally-spaced points from the line segment
        ####################################################################
        """
        N, Len_q, _ = query.shape

        sampling_ratios = self.sampling_ratios(query).view(N, Len_q, self.n_heads, self.total_num_points, 1)
        attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.total_num_points)
        attention_weights = F.softmax(attention_weights, -1)

        num_points_scale = self.num_points_scale.to(dtype=query.dtype)

        vector = reference_points[:, :, None, :, :2] - reference_points[:, :, None, :, 2:]
        center = 0.5 * (reference_points[:, :, None, :, :2] + reference_points[:, :, None, :, 2:])

        sampling_locations = center + sampling_ratios * num_points_scale * vector * 0.5

        output = ms_deform_attn_core_pytorchv2(
            value,
            value_spatial_shapes,
            sampling_locations,
            attention_weights,
            self.num_points_list
        )
        return output


#######################
## Previous versions ##
#######################

# class MSDeformLineAttn(nn.Module):
#     def __init__(
#         self,
#         d_model=256,
#         n_levels=4,
#         n_heads=8,
#         n_points=4
#     ):
#         """
#         Multi-Scale Deformable Attention Module
#         :param d_model      hidden dimension
#         :param n_levels     number of feature levels
#         :param n_heads      number of attention heads
#         :param n_points     number of sampling points per attention head per feature level
#         """
#         super().__init__()
#         if d_model % n_heads != 0:
#             raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
#         _d_per_head = d_model // n_heads
#         # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
#         if not _is_power_of_2(_d_per_head):
#             warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
#                           "which is more efficient in our CUDA implementation.")

#         self.d_model = d_model
#         self.n_levels = n_levels
#         self.n_heads = n_heads

#         if isinstance(n_points, list):
#             assert len(n_points) == n_levels, ''
#             num_points_list = n_points
#         else:
#             num_points_list = [n_points for _ in range(n_levels)]
#         self.num_points_list = num_points_list
#         self.total_num_points = sum(num_points_list)

#         num_points_scale = [1/n for n in num_points_list]
#         self.register_buffer('num_points_scale', torch.tensor(num_points_scale, dtype=torch.float32).reshape(-1, 1, 1))

#         self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * 4)

#         self.attention_weights = nn.Linear(d_model, n_heads * sum(num_points_list))
#         self.value_proj = nn.Linear(d_model, d_model)
#         self.output_proj = nn.Linear(d_model, d_model)

#         for i in range(len(num_points_list)):
#             if num_points_list[i] == 1:
#                 lambda_ = torch.linspace(0.5, 0.5, num_points_list[i])[:, None]
#             else:
#                 lambda_ = torch.linspace(0, 1, num_points_list[i])[:, None]
#             self.register_buffer(f"lambda_{i}", lambda_)

#         self._reset_parameters()

#     def _reset_parameters(self):
#         constant_(self.sampling_offsets.weight.data, 0.)
#         thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
#         grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
#         grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, 2, 1)
#         for i in range(1):
#             grid_init[:, :, 2*i, :] *= i + 1
#             grid_init[:, :, 2*i+1, :] *= i + 1
#         with torch.no_grad():
#             self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
#         constant_(self.attention_weights.weight.data, 0.)
#         constant_(self.attention_weights.bias.data, 0.)
#         xavier_uniform_(self.value_proj.weight.data)
#         constant_(self.value_proj.bias.data, 0.)
#         xavier_uniform_(self.output_proj.weight.data)
#         constant_(self.output_proj.bias.data, 0.)

#     def forward(self, query, reference_points, input_flatten, input_spatial_shapes):
#         """
#         :param query                       (N, Length_{query}, C)
#         :param reference_points           (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
#         :param input_flatten              (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
#         :param input_spatial_shapes       (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
#         :param input_level_start_index    (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]

#         :return output                    (N, Length_{query}, C)

#         ####################################################################
#         # Difference respect to MSDeformAttn
#         # The query already stores the line's junctions
#         # :param reference_points is not needed. We keep it to make both
#           MSDeformAttn and MSDeformLineAttn interchangebale
#           between different frameworks
#         # MSDeformLineAttn does not generate offsets. Instead, it samples
#           n_points equally-spaced points from the line segment
#         ####################################################################
#         """
#         N, Len_q, _ = query.shape
#         N, Len_in, _ = input_flatten.shape

#         value = self.value_proj(input_flatten)

#         value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
#         sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, 1, 4)
#         attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.total_num_points)
#         attention_weights = F.softmax(attention_weights, -1)

#         num_points_scale = self.num_points_scale.to(dtype=query.dtype)

#         wh = reference_points[:, :, None, :, None, :2] - reference_points[:, :, None, :, None, 2:]
#         center = 0.5 * (reference_points[:, :, None, :, None, :2] + reference_points[:, :, None, :, None, 2:])

#         sampling_junctions = torch.cat((center, center), dim=-1) \
#             + sampling_offsets * num_points_scale * torch.cat([wh, wh], -1) * 0.5

#         sampling_locations = []

#         # sampling_junctions_level = torch.split(sampling_junctions, self.num_points_list, dim=-2)
#         for i in range(len(self.num_points_list)):
#             lambda_ = getattr(self, f'lambda_{i}')
#             junctions = sampling_junctions[:, :, :, i]
#             locations = junctions[..., :2] * lambda_ + junctions[..., 2:] * (1 - lambda_)
#             sampling_locations.append(locations)

#         output = ms_deform_attn_core_pytorch(
#             value,
#             input_spatial_shapes,
#             sampling_locations,
#             attention_weights,
#             self.total_num_points
#         )
#         output = self.output_proj(output)
#         return output


# class MSDeformLineAttnV2(nn.Module):
#     def __init__(
#         self,
#         d_model=256,
#         n_levels=4,
#         n_heads=8,
#         n_points=4
#     ):
#         """
#         Multi-Scale Deformable Attention Module
#         :param d_model      hidden dimension
#         :param n_levels     number of feature levels
#         :param n_heads      number of attention heads
#         :param n_points     number of sampling points per attention head per feature level
#         """
#         super().__init__()
#         if d_model % n_heads != 0:
#             raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
#         _d_per_head = d_model // n_heads
#         # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
#         if not _is_power_of_2(_d_per_head):
#             warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
#                           "which is more efficient in our CUDA implementation.")

#         self.d_model = d_model
#         self.n_levels = n_levels
#         self.n_heads = n_heads

#         if isinstance(n_points, list):
#             assert len(n_points) == n_levels, ''
#             num_points_list = n_points
#         else:
#             num_points_list = [n_points for _ in range(n_levels)]
#         self.num_points_list = num_points_list
#         self.total_num_points = sum(num_points_list)

#         num_points_scale = [1/n for n in num_points_list]
#         self.register_buffer('num_points_scale', torch.tensor(num_points_scale, dtype=torch.float32).reshape(-1, 1, 1))

#         self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * 4)
#         self.sampling_ratios = nn.Linear(d_model, n_heads * sum(num_points_list))

#         self.attention_weights = nn.Linear(d_model, n_heads * sum(num_points_list))
#         self.value_proj = nn.Linear(d_model, d_model)
#         self.output_proj = nn.Linear(d_model, d_model)

#         self._reset_parameters()

#     def _reset_parameters(self):
#         constant_(self.sampling_offsets.weight.data, 0.)
#         thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
#         grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
#         grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, 2, 1)
#         for i in range(1):
#             grid_init[:, :, 2*i, :] *= i + 1
#             grid_init[:, :, 2*i+1, :] *= i + 1
#         with torch.no_grad():
#             self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
#         constant_(self.attention_weights.weight.data, 0.)
#         constant_(self.attention_weights.bias.data, 0.)
#         xavier_uniform_(self.value_proj.weight.data)
#         constant_(self.value_proj.bias.data, 0.)
#         xavier_uniform_(self.output_proj.weight.data)
#         constant_(self.output_proj.bias.data, 0.)

#     def forward(self, query, reference_points, input_flatten, input_spatial_shapes):
#         """
#         :param query                       (N, Length_{query}, C)
#         :param reference_points           (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
#         :param input_flatten              (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
#         :param input_spatial_shapes       (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
#         :param input_level_start_index    (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]

#         :return output                    (N, Length_{query}, C)

#         ####################################################################
#         # Difference respect to MSDeformAttn
#         # The query already stores the line's junctions
#         # :param reference_points is not needed. We keep it to make both
#           MSDeformAttn and MSDeformLineAttn interchangebale
#           between different frameworks
#         # MSDeformLineAttn does not generate offsets. Instead, it samples
#           n_points equally-spaced points from the line segment
#         ####################################################################
#         """
#         N, Len_q, _ = query.shape
#         N, Len_in, _ = input_flatten.shape

#         value = self.value_proj(input_flatten)

#         value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
#         sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, 1, 4)
#         sampling_ratios = self.sampling_ratios(query).view(N, Len_q, self.n_heads, self.total_num_points).sigmoid()
#         attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.total_num_points)
#         attention_weights = F.softmax(attention_weights, -1)

#         num_points_scale = self.num_points_scale.to(dtype=query.dtype)

#         wh = reference_points[:, :, None, :, None, :2] - reference_points[:, :, None, :, None, 2:]
#         center = 0.5 * (reference_points[:, :, None, :, None, :2] + reference_points[:, :, None, :, None, 2:])

#         sampling_junctions = torch.cat((center, center), dim=-1) \
#             + sampling_offsets * num_points_scale * torch.cat([wh, wh], -1) * 0.5

#         sampling_locations = []

#         for i, lambda_ in enumerate(torch.split(sampling_ratios, self.num_points_list, dim=-1)):
#             lambda_ = lambda_[..., None]
#             junctions = sampling_junctions[:, :, :, i]
#             locations = junctions[..., :2] * lambda_ + junctions[..., 2:] * (1 - lambda_)
#             sampling_locations.append(locations)

#         output = ms_deform_attn_core_pytorch(
#             value,
#             input_spatial_shapes,
#             sampling_locations,
#             attention_weights,
#             self.total_num_points
#         )
#         output = self.output_proj(output)
#         return output


# class MSDeformLineAttnV3(nn.Module):
#     def __init__(
#         self,
#         d_model=256,
#         n_levels=4,
#         n_heads=8,
#         n_points=4
#     ):
#         """
#         This version is inspired from DFine. We removed the following layers:
#             - value_proj
#             - output_proj
#         """
#         super().__init__()
#         if d_model % n_heads != 0:
#             raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
#         _d_per_head = d_model // n_heads
#         # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
#         if not _is_power_of_2(_d_per_head):
#             warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
#                           "which is more efficient in our CUDA implementation.")

#         self.d_model = d_model
#         self.n_levels = n_levels
#         self.n_heads = n_heads

#         if isinstance(n_points, list):
#             assert len(n_points) == n_levels, ''
#             num_points_list = n_points
#         else:
#             num_points_list = [n_points for _ in range(n_levels)]
#         self.num_points_list = num_points_list
#         self.total_num_points = sum(num_points_list)

#         num_points_scale = [1/n for n in num_points_list]
#         self.register_buffer('num_points_scale', torch.tensor(num_points_scale, dtype=torch.float32).reshape(-1, 1, 1))

#         self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * 4)
#         self.sampling_ratios = nn.Linear(d_model, n_heads * sum(num_points_list))

#         self.attention_weights = nn.Linear(d_model, n_heads * sum(num_points_list))

#         self._reset_parameters()

#     def _reset_parameters(self):
#         constant_(self.sampling_offsets.weight.data, 0.)
#         thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
#         grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
#         grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, 2, 1)
#         for i in range(1):
#             grid_init[:, :, 2*i, :] *= i + 1
#             grid_init[:, :, 2*i+1, :] *= i + 1
#         with torch.no_grad():
#             self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
#         constant_(self.attention_weights.weight.data, 0.)
#         constant_(self.attention_weights.bias.data, 0.)

#     def forward(self, query, reference_points, value, value_spatial_shapes):
#         """
#         :param query                       (N, Length_{query}, C)
#         :param reference_points           (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
#         :param value                      (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
#         :param value_spatial_shapes       (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
#         :param input_level_start_index    (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]

#         :return output                    (N, Length_{query}, C)

#         ####################################################################
#         # Difference respect to MSDeformAttn
#         # The query already stores the line's junctions
#         # :param reference_points is not needed. We keep it to make both
#           MSDeformAttn and MSDeformLineAttn interchangebale
#           between different frameworks
#         # MSDeformLineAttn does not generate offsets. Instead, it samples
#           n_points equally-spaced points from the line segment
#         ####################################################################
#         """
#         N, Len_q, _ = query.shape

#         sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, 1, 4)
#         sampling_ratios = self.sampling_ratios(query).view(N, Len_q, self.n_heads, self.total_num_points).sigmoid()
#         attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.total_num_points)
#         attention_weights = F.softmax(attention_weights, -1)

#         num_points_scale = self.num_points_scale.to(dtype=query.dtype)

#         wh = reference_points[:, :, None, :, None, :2] - reference_points[:, :, None, :, None, 2:]
#         center = 0.5 * (reference_points[:, :, None, :, None, :2] + reference_points[:, :, None, :, None, 2:])

#         sampling_junctions = torch.cat((center, center), dim=-1) \
#             + sampling_offsets * num_points_scale * torch.cat([wh, wh], -1) * 0.5

#         sampling_locations = []

#         for i, lambda_ in enumerate(torch.split(sampling_ratios, self.num_points_list, dim=-1)):
#             lambda_ = lambda_[..., None]
#             junctions = sampling_junctions[:, :, :, i]
#             locations = junctions[..., :2] * lambda_ + junctions[..., 2:] * (1 - lambda_)
#             sampling_locations.append(locations)

#         output = ms_deform_attn_core_pytorchv2(
#             value,
#             value_spatial_shapes,
#             sampling_locations,
#             attention_weights,
#             self.total_num_points
#         )
#         return output
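Note (not part of the upload): a small smoke test of the active MSDeformLineAttn above, with shapes inferred from its forward pass. In particular, value appears to be expected pre-split per level as (batch * n_heads, head_dim, H*W), and the level axis of reference_points must broadcast against heads and points; treat this as an illustration, not the training-time call signature used by the decoder:

import torch
from linea.models.linea.attention_mechanism import MSDeformLineAttn

attn = MSDeformLineAttn(d_model=256, n_levels=3, n_heads=8, n_points=[4, 1, 1])

N, Lq = 2, 1100                          # batch size, number of line queries
shapes = [(80, 80), (40, 40), (20, 20)]  # per-level feature map sizes

query = torch.rand(N, Lq, 256)
# each reference is a line (x1, y1, x2, y2) in [0, 1]; the size-1 level axis broadcasts
reference_points = torch.rand(N, Lq, 1, 4)
# per-level value tensors, already reshaped to (N * n_heads, head_dim, H * W)
value = [torch.rand(N * 8, 256 // 8, h * w) for h, w in shapes]

out = attn(query, reference_points, value, shapes)
print(out.shape)  # torch.Size([2, 1100, 256])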
linea/models/linea/criterion.py
ADDED
@@ -0,0 +1,517 @@
import torch
import torch.nn.functional as F
from torch import nn
from torchvision.transforms.functional import resize

from .utils import sigmoid_focal_loss

from .matcher import build_matcher

from .linea_utils import weighting_function, bbox2distance

from ..registry import MODULE_BUILD_FUNCS

# TODO. Quick solution to make the model run on GoogleColab
import os, sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))
from util.misc import get_world_size, is_dist_avail_and_initialized


class LINEACriterion(nn.Module):
    """ This class computes the loss for Conditional DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """
    def __init__(self, num_classes, matcher, weight_dict, focal_alpha, losses):
        """ Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            focal_alpha: alpha in Focal Loss
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.focal_alpha = focal_alpha

    def loss_labels(self, outputs, targets, indices, num_boxes):
        """Classification loss (Binary focal loss)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']
        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2]+1],
                                            dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)
        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
        target_classes_onehot = target_classes_onehot[:,:,:-1]

        loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1]
        losses = {'loss_logits': loss_ce}

        return losses

    def loss_lines(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
        targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
        The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        assert 'pred_lines' in outputs
        idx = self._get_src_permutation_idx(indices)
        src_lines = outputs['pred_lines'][idx]
        target_lines = torch.cat([t['lines'][i] for t, (_, i) in zip(targets, indices)], dim=0)

        loss_line = F.l1_loss(src_lines, target_lines, reduction='none')

        losses = {}
        losses['loss_line'] = loss_line.sum() / num_boxes

        return losses

    def loss_lmap(self, outputs, targets, indices, num_boxes):
        losses = {}
        if 'aux_lmap' in outputs:
            src_lmap = outputs['aux_lmap']
            size = src_lmap[0].size(2)
            target_lmap = []
            for t in targets:
                lmaps_flatten = []
                for lmap, downsampling in zip(t['lmap'], [1, 2, 4]):
                    lmap_ = resize(lmap, (size//downsampling, size//downsampling))
                    lmaps_flatten.append(lmap_.flatten(1))
                target_lmap.append(torch.cat(lmaps_flatten, dim=1))
            target_lmap = torch.cat(target_lmap, dim=0)

            src_lmap = torch.cat([lmap_.flatten(1) for lmap_ in src_lmap], dim=1)

            loss_lmap = F.binary_cross_entropy_with_logits(src_lmap, target_lmap, reduction='mean')

            losses['loss_lmap'] = loss_lmap

        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        loss_map = {
            'labels': self.loss_labels,
            'lines': self.loss_lines,
            'lmap': self.loss_lmap,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    def forward(self, outputs, targets, return_indices=False):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc

             return_indices: used for vis. if True, the layer0-5 indices will be returned as well.

        """
        outputs_without_aux = {k: v for k, v in outputs.items() if 'aux' not in k}
        device = next(iter(outputs.values())).device
        indices = self.matcher(outputs_without_aux, targets)
        if return_indices:
            return indices

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}

        for loss in self.losses:
            indices_in = indices
            num_boxes_in = num_boxes
            l_dict = self.get_loss(loss, outputs, targets, indices_in, num_boxes_in)
            l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
            losses.update(l_dict)

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
            for idx, aux_outputs in enumerate(outputs['aux_outputs']):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes)
                    l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
                    l_dict = {k + f'_{idx}': v for k, v in l_dict.items()}
                    losses.update(l_dict)

        # interm_outputs loss
        if 'aux_interm_outputs' in outputs:
            interm_outputs = outputs['aux_interm_outputs']
            indices = self.matcher(interm_outputs, targets)
            for loss in self.losses:
                l_dict = self.get_loss(loss, interm_outputs, targets, indices, num_boxes)
                l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
                l_dict = {k + f'_interm': v for k, v in l_dict.items()}
                losses.update(l_dict)

        # pre output loss
        if 'aux_pre_outputs' in outputs:
            pre_outputs = outputs['aux_pre_outputs']
            indices = self.matcher(pre_outputs, targets)
            for loss in self.losses:
                l_dict = self.get_loss(loss, pre_outputs, targets, indices, num_boxes)
                l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
                l_dict = {k + f'_pre': v for k, v in l_dict.items()}
                losses.update(l_dict)

        # prepare for dn loss
        dn_meta = outputs['dn_meta']

        if self.training and dn_meta and 'aux_denoise' in outputs:
            single_pad, scalar = self.prep_for_dn(dn_meta)
            dn_pos_idx = []
            dn_neg_idx = []
            for i in range(len(targets)):
                if len(targets[i]['labels']) > 0:
                    t = torch.arange(len(targets[i]['labels'])).long().cuda()
                    t = t.unsqueeze(0).repeat(scalar, 1)
                    tgt_idx = t.flatten()
                    output_idx = (torch.tensor(range(scalar)) * single_pad).long().cuda().unsqueeze(1) + t
                    output_idx = output_idx.flatten()
                else:
                    output_idx = tgt_idx = torch.tensor([]).long().cuda()

                dn_pos_idx.append((output_idx, tgt_idx))
                dn_neg_idx.append((output_idx + single_pad // 2, tgt_idx))

            dn_outputs = outputs['aux_denoise']

            # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
            if 'aux_outputs' in dn_outputs:
                for idx, aux_outputs in enumerate(dn_outputs['aux_outputs']):
                    for loss in self.losses:
                        l_dict = self.get_loss(loss, aux_outputs, targets, dn_pos_idx, num_boxes*scalar)
|
214 |
+
l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
|
215 |
+
l_dict = {k + f'_dn_{idx}': v for k, v in l_dict.items()}
|
216 |
+
losses.update(l_dict)
|
217 |
+
|
218 |
+
if 'aux_pre_outputs' in dn_outputs:
|
219 |
+
aux_outputs_known = dn_outputs['aux_pre_outputs']
|
220 |
+
l_dict={}
|
221 |
+
for loss in self.losses:
|
222 |
+
l_dict.update(self.get_loss(loss, aux_outputs_known, targets, dn_pos_idx, num_boxes*scalar))
|
223 |
+
l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
|
224 |
+
l_dict = {k + f'_pre_dn': v for k, v in l_dict.items()}
|
225 |
+
losses.update(l_dict)
|
226 |
+
|
227 |
+
losses = {k: v for k, v in sorted(losses.items(), key=lambda item: item[0])}
|
228 |
+
|
229 |
+
return losses
|
230 |
+
|
231 |
+
def prep_for_dn(self,dn_meta):
|
232 |
+
# output_known_lbs_lines = dn_meta['output_known_lbs_lines']
|
233 |
+
num_dn_groups, pad_size=dn_meta['num_dn_group'],dn_meta['pad_size']
|
234 |
+
assert pad_size % num_dn_groups==0
|
235 |
+
single_pad=pad_size//num_dn_groups
|
236 |
+
|
237 |
+
return single_pad,num_dn_groups
|
238 |
+
|
239 |
+
class DFINESetCriterion(LINEACriterion):
|
240 |
+
def __init__(self, num_classes, matcher, weight_dict, focal_alpha, reg_max, losses):
|
241 |
+
super().__init__(num_classes, matcher, weight_dict, focal_alpha, losses)
|
242 |
+
self.reg_max = reg_max
|
243 |
+
|
244 |
+
def loss_local(self, outputs, targets, indices, num_boxes, T=5):
|
245 |
+
losses = {}
|
246 |
+
if 'pred_corners' in outputs:
|
247 |
+
idx = self._get_src_permutation_idx(indices)
|
248 |
+
target_lines = torch.cat([t['lines'][i] for t, (_, i) in zip(targets, indices)], dim=0)
|
249 |
+
|
250 |
+
pred_corners = outputs['pred_corners'][idx].reshape(-1, (self.reg_max+1))
|
251 |
+
ref_points = outputs['ref_points'][idx].detach()
|
252 |
+
|
253 |
+
with torch.no_grad():
|
254 |
+
if self.fgl_targets_dn is None and 'is_dn' in outputs:
|
255 |
+
self.fgl_targets_dn= bbox2distance(ref_points, target_lines,
|
256 |
+
self.reg_max, outputs['reg_scale'],
|
257 |
+
outputs['up'])
|
258 |
+
if self.fgl_targets is None and 'is_dn' not in outputs:
|
259 |
+
self.fgl_targets = bbox2distance(ref_points, target_lines,
|
260 |
+
self.reg_max, outputs['reg_scale'],
|
261 |
+
outputs['up'])
|
262 |
+
|
263 |
+
target_corners, weight_right, weight_left = self.fgl_targets_dn if 'is_dn' in outputs else self.fgl_targets
|
264 |
+
|
265 |
+
losses['loss_fgl'] = self.unimodal_distribution_focal_loss(
|
266 |
+
pred_corners, target_corners, weight_right, weight_left, None, avg_factor=num_boxes)
|
267 |
+
|
268 |
+
if 'teacher_corners' in outputs:
|
269 |
+
pred_corners = outputs['pred_corners'].reshape(-1, (self.reg_max+1))
|
270 |
+
target_corners = outputs['teacher_corners'].reshape(-1, (self.reg_max+1))
|
271 |
+
if torch.equal(pred_corners, target_corners):
|
272 |
+
losses['loss_ddf'] = pred_corners.sum() * 0
|
273 |
+
else:
|
274 |
+
weight_targets_local = outputs['teacher_logits'].sigmoid().max(dim=-1)[0]
|
275 |
+
|
276 |
+
mask = torch.zeros_like(weight_targets_local, dtype=torch.bool)
|
277 |
+
mask[idx] = True
|
278 |
+
mask = mask.unsqueeze(-1).repeat(1, 1, 4).reshape(-1)
|
279 |
+
|
280 |
+
weight_targets_local = weight_targets_local.unsqueeze(-1).repeat(1, 1, 4).reshape(-1).detach()
|
281 |
+
|
282 |
+
loss_match_local = weight_targets_local * (T ** 2) * (nn.KLDivLoss(reduction='none')
|
283 |
+
(F.log_softmax(pred_corners / T, dim=1), F.softmax(target_corners.detach() / T, dim=1))).sum(-1)
|
284 |
+
if 'is_dn' not in outputs:
|
285 |
+
batch_scale = 8 / outputs['pred_lines'].shape[0] # Avoid the influence of batch size per GPU
|
286 |
+
self.num_pos, self.num_neg = (mask.sum() * batch_scale) ** 0.5, ((~mask).sum() * batch_scale) ** 0.5
|
287 |
+
loss_match_local1 = loss_match_local[mask].mean() if mask.any() else 0
|
288 |
+
loss_match_local2 = loss_match_local[~mask].mean() if (~mask).any() else 0
|
289 |
+
losses['loss_ddf'] = (loss_match_local1 * self.num_pos + loss_match_local2 * self.num_neg) / (self.num_pos + self.num_neg)
|
290 |
+
|
291 |
+
return losses
|
292 |
+
|
293 |
+
def _clear_cache(self):
|
294 |
+
self.fgl_targets, self.fgl_targets_dn = None, None
|
295 |
+
self.own_targets, self.own_targets_dn = None, None
|
296 |
+
self.num_pos, self.num_neg = None, None
|
297 |
+
|
298 |
+
def unimodal_distribution_focal_loss(self, pred, label, weight_right, weight_left, weight=None, reduction='sum', avg_factor=None):
|
299 |
+
dis_left = label.long()
|
300 |
+
dis_right = dis_left + 1
|
301 |
+
|
302 |
+
loss = F.cross_entropy(pred, dis_left, reduction='none') * weight_left.reshape(-1) \
|
303 |
+
+ F.cross_entropy(pred, dis_right, reduction='none') * weight_right.reshape(-1)
|
304 |
+
|
305 |
+
if weight is not None:
|
306 |
+
weight = weight.float()
|
307 |
+
loss = loss * weight
|
308 |
+
|
309 |
+
if avg_factor is not None:
|
310 |
+
loss = loss.sum() / avg_factor
|
311 |
+
elif reduction == 'mean':
|
312 |
+
loss = loss.mean()
|
313 |
+
elif reduction == 'sum':
|
314 |
+
loss = loss.sum()
|
315 |
+
|
316 |
+
return loss
|
317 |
+
|
318 |
+
def _get_go_indices(self, indices, indices_aux_list):
|
319 |
+
"""Get a matching union set across all decoder layers. """
|
320 |
+
results = []
|
321 |
+
for indices_aux in indices_aux_list:
|
322 |
+
indices = [(torch.cat([idx1[0], idx2[0]]), torch.cat([idx1[1], idx2[1]]))
|
323 |
+
for idx1, idx2 in zip(indices.copy(), indices_aux.copy())]
|
324 |
+
|
325 |
+
for ind in [torch.cat([idx[0][:, None], idx[1][:, None]], 1) for idx in indices]:
|
326 |
+
unique, counts = torch.unique(ind, return_counts=True, dim=0)
|
327 |
+
count_sort_indices = torch.argsort(counts, descending=True)
|
328 |
+
unique_sorted = unique[count_sort_indices]
|
329 |
+
column_to_row = {}
|
330 |
+
for idx in unique_sorted:
|
331 |
+
row_idx, col_idx = idx[0].item(), idx[1].item()
|
332 |
+
if row_idx not in column_to_row:
|
333 |
+
column_to_row[row_idx] = col_idx
|
334 |
+
final_rows = torch.tensor(list(column_to_row.keys()), device=ind.device)
|
335 |
+
final_cols = torch.tensor(list(column_to_row.values()), device=ind.device)
|
336 |
+
results.append((final_rows.long(), final_cols.long()))
|
337 |
+
return results
|
338 |
+
|
339 |
+
def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
|
340 |
+
loss_map = {
|
341 |
+
'labels': self.loss_labels,
|
342 |
+
'lines': self.loss_lines,
|
343 |
+
'lmap': self.loss_lmap,
|
344 |
+
'local': self.loss_local,
|
345 |
+
}
|
346 |
+
assert loss in loss_map, f'do you really want to compute {loss} loss?'
|
347 |
+
return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
|
348 |
+
|
349 |
+
def forward(self, outputs, targets):
|
350 |
+
""" This performs the loss computation.
|
351 |
+
Parameters:
|
352 |
+
outputs: dict of tensors, see the output specification of the model for the format
|
353 |
+
targets: list of dicts, such that len(targets) == batch_size.
|
354 |
+
The expected keys in each dict depends on the losses applied, see each loss' doc
|
355 |
+
|
356 |
+
return_indices: used for vis. if True, the layer0-5 indices will be returned as well.
|
357 |
+
|
358 |
+
"""
|
359 |
+
outputs_without_aux = {k: v for k, v in outputs.items() if 'aux' not in k}
|
360 |
+
device = next(iter(outputs.values())).device
|
361 |
+
indices = self.matcher(outputs_without_aux, targets)
|
362 |
+
|
363 |
+
self._clear_cache()
|
364 |
+
|
365 |
+
# Get the matching union set across all decoder layers.
|
366 |
+
if 'aux_outputs' in outputs:
|
367 |
+
indices_aux_list, cached_indices, cached_indices_enc = [], [], []
|
368 |
+
for i, aux_outputs in enumerate(outputs['aux_outputs'] + [outputs['aux_pre_outputs']]):
|
369 |
+
indices_aux = self.matcher(aux_outputs, targets)
|
370 |
+
cached_indices.append(indices_aux)
|
371 |
+
indices_aux_list.append(indices_aux)
|
372 |
+
for i, aux_outputs in enumerate([outputs['aux_interm_outputs']]):
|
373 |
+
indices_enc = self.matcher(aux_outputs, targets)
|
374 |
+
cached_indices_enc.append(indices_enc)
|
375 |
+
indices_aux_list.append(indices_enc)
|
376 |
+
indices_go = self._get_go_indices(indices, indices_aux_list)
|
377 |
+
|
378 |
+
num_boxes_go = sum(len(x[0]) for x in indices_go)
|
379 |
+
num_boxes_go = torch.as_tensor([num_boxes_go], dtype=torch.float, device=next(iter(outputs.values())).device)
|
380 |
+
if is_dist_avail_and_initialized():
|
381 |
+
torch.distributed.all_reduce(num_boxes_go)
|
382 |
+
num_boxes_go = torch.clamp(num_boxes_go / get_world_size(), min=1).item()
|
383 |
+
else:
|
384 |
+
# assert 'aux_outputs' in outputs, ''
|
385 |
+
indices_go = indices
|
386 |
+
|
387 |
+
num_boxes_go = sum(len(x[0]) for x in indices_go)
|
388 |
+
num_boxes_go = torch.as_tensor([num_boxes_go], dtype=torch.float, device=next(iter(outputs.values())).device)
|
389 |
+
if is_dist_avail_and_initialized():
|
390 |
+
torch.distributed.all_reduce(num_boxes_go)
|
391 |
+
num_boxes_go = torch.clamp(num_boxes_go / get_world_size(), min=1).item()
|
392 |
+
|
393 |
+
# Compute the average number of target boxes accross all nodes, for normalization purposes
|
394 |
+
num_boxes = sum(len(t["labels"]) for t in targets)
|
395 |
+
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=device)
|
396 |
+
if is_dist_avail_and_initialized():
|
397 |
+
torch.distributed.all_reduce(num_boxes)
|
398 |
+
num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
|
399 |
+
|
400 |
+
# Compute all the requested losses
|
401 |
+
losses = {}
|
402 |
+
|
403 |
+
for loss in self.losses:
|
404 |
+
indices_in = indices_go if loss in ['lines', 'local'] else indices
|
405 |
+
num_boxes_in = num_boxes_go if loss in ['lines', 'local'] else num_boxes
|
406 |
+
l_dict = self.get_loss(loss, outputs, targets, indices_in, num_boxes_in)
|
407 |
+
l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
|
408 |
+
losses.update(l_dict)
|
409 |
+
|
410 |
+
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
|
411 |
+
if 'aux_outputs' in outputs:
|
412 |
+
for idx, aux_outputs in enumerate(outputs['aux_outputs']):
|
413 |
+
aux_outputs['up'], aux_outputs['reg_scale'] = outputs['up'], outputs['reg_scale']
|
414 |
+
# indices = self.matcher(aux_outputs, targets)
|
415 |
+
for loss in self.losses:
|
416 |
+
indices_in = indices_go if loss in ['lines', 'local'] else cached_indices[idx]
|
417 |
+
num_boxes_in = num_boxes_go if loss in ['lines', 'local'] else num_boxes
|
418 |
+
l_dict = self.get_loss(loss, aux_outputs, targets, indices_in, num_boxes_in)
|
419 |
+
l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
|
420 |
+
l_dict = {k + f'_{idx}': v for k, v in l_dict.items()}
|
421 |
+
losses.update(l_dict)
|
422 |
+
|
423 |
+
# interm_outputs loss
|
424 |
+
if 'aux_interm_outputs' in outputs:
|
425 |
+
interm_outputs = outputs['aux_interm_outputs']
|
426 |
+
# indices = self.matcher(interm_outputs, targets)
|
427 |
+
for loss in self.losses:
|
428 |
+
indices_in = indices_go if loss in ['lines', 'local'] else cached_indices_enc[0]
|
429 |
+
num_boxes_in = num_boxes_go if loss in ['lines', 'local'] else num_boxes
|
430 |
+
l_dict = self.get_loss(loss, interm_outputs, targets, indices_in, num_boxes_in)
|
431 |
+
l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
|
432 |
+
l_dict = {k + f'_interm': v for k, v in l_dict.items()}
|
433 |
+
losses.update(l_dict)
|
434 |
+
|
435 |
+
# pre output loss
|
436 |
+
if 'aux_pre_outputs' in outputs:
|
437 |
+
pre_outputs = outputs['aux_pre_outputs']
|
438 |
+
# indices = self.matcher(pre_outputs, targets)
|
439 |
+
for loss in self.losses:
|
440 |
+
indices_in = indices_go if loss in ['lines', 'local'] else cached_indices[-1]
|
441 |
+
num_boxes_in = num_boxes_go if loss in ['lines', 'local'] else num_boxes
|
442 |
+
l_dict = self.get_loss(loss, pre_outputs, targets, indices_in, num_boxes_in)
|
443 |
+
l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
|
444 |
+
l_dict = {k + f'_pre': v for k, v in l_dict.items()}
|
445 |
+
losses.update(l_dict)
|
446 |
+
|
447 |
+
|
448 |
+
# prepare for dn loss
|
449 |
+
dn_meta = outputs['dn_meta']
|
450 |
+
|
451 |
+
if self.training and dn_meta and 'aux_denoise' in outputs:
|
452 |
+
single_pad, scalar = self.prep_for_dn(dn_meta)
|
453 |
+
dn_pos_idx = []
|
454 |
+
dn_neg_idx = []
|
455 |
+
for i in range(len(targets)):
|
456 |
+
if len(targets[i]['labels']) > 0:
|
457 |
+
t = torch.arange(len(targets[i]['labels'])).long().cuda()
|
458 |
+
t = t.unsqueeze(0).repeat(scalar, 1)
|
459 |
+
tgt_idx = t.flatten()
|
460 |
+
output_idx = (torch.tensor(range(scalar)) * single_pad).long().cuda().unsqueeze(1) + t
|
461 |
+
output_idx = output_idx.flatten()
|
462 |
+
else:
|
463 |
+
output_idx = tgt_idx = torch.tensor([]).long().cuda()
|
464 |
+
|
465 |
+
dn_pos_idx.append((output_idx, tgt_idx))
|
466 |
+
dn_neg_idx.append((output_idx + single_pad // 2, tgt_idx))
|
467 |
+
|
468 |
+
dn_outputs = outputs['aux_denoise']
|
469 |
+
|
470 |
+
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
|
471 |
+
if 'aux_outputs' in dn_outputs:
|
472 |
+
for idx, aux_outputs in enumerate(dn_outputs['aux_outputs']):
|
473 |
+
aux_outputs['is_dn'] = True
|
474 |
+
aux_outputs['reg_scale'] = outputs['reg_scale']
|
475 |
+
aux_outputs['up'] = outputs['up']
|
476 |
+
# indices = self.matcher(aux_outputs, targets)
|
477 |
+
for loss in self.losses:
|
478 |
+
l_dict = self.get_loss(loss, aux_outputs, targets, dn_pos_idx, num_boxes*scalar)
|
479 |
+
l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
|
480 |
+
l_dict = {k + f'_dn_{idx}': v for k, v in l_dict.items()}
|
481 |
+
losses.update(l_dict)
|
482 |
+
|
483 |
+
if 'aux_pre_outputs' in dn_outputs:
|
484 |
+
aux_outputs_known = dn_outputs['aux_pre_outputs']
|
485 |
+
l_dict={}
|
486 |
+
for loss in self.losses:
|
487 |
+
l_dict.update(self.get_loss(loss, aux_outputs_known, targets, dn_pos_idx, num_boxes*scalar))
|
488 |
+
l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
|
489 |
+
l_dict = {k + f'_pre_dn': v for k, v in l_dict.items()}
|
490 |
+
losses.update(l_dict)
|
491 |
+
|
492 |
+
if 'aux_lmap' in outputs:
|
493 |
+
l_dict = self.get_loss('lmap', outputs, targets, indices, num_boxes, **kwargs)
|
494 |
+
l_dict = {k: v for k, v in l_dict.items()}
|
495 |
+
l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
|
496 |
+
losses.update(l_dict)
|
497 |
+
|
498 |
+
losses = {k: v for k, v in sorted(losses.items(), key=lambda item: item[0])}
|
499 |
+
|
500 |
+
return losses
|
501 |
+
|
502 |
+
@MODULE_BUILD_FUNCS.registe_with_name(module_name='LINEACRITERION')
|
503 |
+
def build_criterion(args):
|
504 |
+
num_classes = args.num_classes
|
505 |
+
|
506 |
+
matcher = build_matcher(args)
|
507 |
+
|
508 |
+
if args.criterion_type == 'default':
|
509 |
+
criterion = LINEACriterion(num_classes, matcher=matcher, weight_dict=args.weight_dict,
|
510 |
+
focal_alpha=args.focal_alpha, losses=args.losses)
|
511 |
+
elif args.criterion_type == 'dfine':
|
512 |
+
criterion = DFINESetCriterion(num_classes, matcher=matcher, weight_dict=args.weight_dict,
|
513 |
+
focal_alpha=args.focal_alpha, reg_max=args.reg_max, losses=args.losses)
|
514 |
+
else:
|
515 |
+
raise Exception(f"Criterion type: {args.criterion_type}.We only support two classes: 'default' and 'dfine'. ")
|
516 |
+
|
517 |
+
return criterion
|
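The unimodal distribution focal loss above spreads each continuous regression target over its two neighbouring discrete bins. A minimal sketch of that interpolation on toy tensors; the left/right bin weights are computed by hand here, whereas in the model they come from `bbox2distance`, which is defined elsewhere:

import torch
import torch.nn.functional as F

reg_max = 4
label = torch.tensor([2.3])                            # continuous bin position of one target edge
dis_left, dis_right = label.long(), label.long() + 1   # neighbouring discrete bins 2 and 3
weight_left, weight_right = dis_right - label, label - dis_left   # 0.7 and 0.3

pred = torch.randn(1, reg_max + 1)                     # logits over the reg_max + 1 bins
loss = F.cross_entropy(pred, dis_left, reduction='none') * weight_left \
     + F.cross_entropy(pred, dis_right, reduction='none') * weight_right
print(loss)   # minimized when the prediction splits probability mass 0.7/0.3 over bins 2 and 3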
linea/models/linea/decoder.py
ADDED
@@ -0,0 +1,551 @@
"""
Modified from D-FINE (https://github.com/Peterande/D-FINE)
Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
---------------------------------------------------------------------------------
Modified from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright (c) 2023 lyuwenyu. All Rights Reserved.
"""

import copy
import math
from typing import Optional
from collections import OrderedDict

import torch
from torch import nn, Tensor
import torch.nn.functional as F
import torch.nn.init as init

from .utils import gen_encoder_output_proposals, MLP, _get_activation_fn, gen_sineembed_for_position
from .attention_mechanism import MSDeformAttn
from .attention_mechanism import MSDeformLineAttn

from .dn_components import prepare_for_cdn, dn_post_process
from .linea_utils import weighting_function, distance2bbox, inverse_sigmoid


def _get_clones(module, N, layer_share=False):
    if layer_share:
        return nn.ModuleList([module for i in range(N)])
    else:
        return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


class DeformableTransformerDecoderLayer(nn.Module):
    def __init__(self, d_model=256, d_ffn=1024,
                 dropout=0.1, activation="relu",
                 n_levels=4, n_heads=8, n_points=4,
                 ):
        super().__init__()
        # cross attention
        self.cross_attn = MSDeformLineAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # self attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation, d_model=d_ffn, batch_dim=1)
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)

    def rm_self_attn_modules(self):
        self.self_attn = None
        self.dropout2 = None
        self.norm2 = None

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward(self,
                # for tgt
                tgt: Optional[Tensor],  # nq, bs, d_model
                tgt_query_pos: Optional[Tensor] = None,  # pos for query. MLP(Sine(pos))
                tgt_query_sine_embed: Optional[Tensor] = None,  # pos for query. Sine(pos)
                tgt_key_padding_mask: Optional[Tensor] = None,
                tgt_reference_points: Optional[Tensor] = None,  # nq, bs, 4

                # for memory
                memory: Optional[Tensor] = None,  # hw, bs, d_model
                memory_spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
                memory_pos: Optional[Tensor] = None,  # pos for memory

                # sa
                self_attn_mask: Optional[Tensor] = None,  # mask used for self-attention
                cross_attn_mask: Optional[Tensor] = None,  # mask used for cross-attention
                ):
        # self attention
        q = k = self.with_pos_embed(tgt, tgt_query_pos)
        tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)

        # cross attention
        tgt2 = self.cross_attn(self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
                               tgt_reference_points.transpose(0, 1).contiguous(),
                               memory,  # .transpose(0, 1),
                               memory_spatial_shapes,
                               ).transpose(0, 1)
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # feed forward network
        tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)

        return tgt


class Integral(nn.Module):
    """
    A static layer that calculates integral results from a distribution.

    This layer computes the target location using the formula: `sum{Pr(n) * W(n)}`,
    where Pr(n) is the softmax probability vector representing the discrete
    distribution, and W(n) is the non-uniform Weighting Function.

    Args:
        reg_max (int): Max number of the discrete bins. Default is 32.
                       It can be adjusted based on the dataset or task requirements.
    """

    def __init__(self, reg_max=32):
        super(Integral, self).__init__()
        self.reg_max = reg_max

    def forward(self, x, project):
        shape = x.shape
        x = F.softmax(x.reshape(-1, self.reg_max + 1), dim=1)
        x = F.linear(x, project.to(x.device)).reshape(-1, 4)
        return x.reshape(list(shape[:-1]) + [-1])


class LQE(nn.Module):
    def __init__(self, k, hidden_dim, num_layers, reg_max):
        super(LQE, self).__init__()
        self.k = k
        self.reg_max = reg_max
        self.reg_conf = MLP(4 * (k + 1), hidden_dim, 1, num_layers)
        init.constant_(self.reg_conf.layers[-1].bias, 0)
        init.constant_(self.reg_conf.layers[-1].weight, 0)

    def forward(self, scores, pred_corners):
        B, L, _ = pred_corners.size()
        prob = F.softmax(pred_corners.reshape(B, L, 4, self.reg_max + 1), dim=-1)
        prob_topk, _ = prob.topk(self.k, dim=-1)
        stat = torch.cat([prob_topk, prob_topk.mean(dim=-1, keepdim=True)], dim=-1)
        quality_score = self.reg_conf(stat.reshape(B, L, -1))
        return scores + quality_score


class TransformerDecoder(nn.Module):
    def __init__(
        self,
        decoder_layer,
        num_layers,
        norm=None,
        d_model=256,
        query_dim=4,
        num_feature_levels=1,
        aux_loss=False,
        eval_idx=5,
        # from D-FINE
        reg_max=32,
        reg_scale=4,
    ):
        super().__init__()
        if num_layers > 0:
            self.layers = _get_clones(decoder_layer, num_layers)
        else:
            self.layers = []
        self.num_layers = num_layers
        # self.norm = norm
        self.query_dim = query_dim
        self.num_feature_levels = num_feature_levels

        self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)

        self.reg_max = reg_max
        self.up = nn.Parameter(torch.tensor([0.5]), requires_grad=False)
        self.reg_scale = nn.Parameter(torch.tensor([reg_scale]), requires_grad=False)
        self.d_model = d_model

        # prediction layers
        _class_embed = nn.Linear(d_model, 2)
        _enc_bbox_embed = MLP(d_model, d_model, 4, 3)
        # init the two embed layers
        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        _class_embed.bias.data = torch.ones(2) * bias_value
        nn.init.constant_(_enc_bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(_enc_bbox_embed.layers[-1].bias.data, 0)

        _bbox_embed = MLP(d_model, d_model, 4 * (self.reg_max + 1), 3)

        self.bbox_embed = nn.ModuleList([copy.deepcopy(_bbox_embed) for i in range(num_layers)])
        self.class_embed = nn.ModuleList([copy.deepcopy(_class_embed) for i in range(num_layers)])
        self.lqe_layers = nn.ModuleList([copy.deepcopy(LQE(4, 64, 2, reg_max)) for _ in range(num_layers)])
        self.integral = Integral(self.reg_max)

        # two stage
        self.enc_out_bbox_embed = copy.deepcopy(_enc_bbox_embed)
        # self.enc_out_class_embed = copy.deepcopy(_class_embed)
        self.aux_loss = aux_loss

        # inference
        self.eval_idx = eval_idx

    def forward(self,
                tgt,
                memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                refpoints_unsigmoid: Optional[Tensor] = None,  # num_queries, bs, 2
                # for memory
                spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
                ):
        """
        Input:
            - tgt: nq, bs, d_model
            - memory: hw, bs, d_model
            - pos: hw, bs, d_model
            - refpoints_unsigmoid: nq, bs, 2/4
        """
        output = tgt
        output_detach = pred_corners_undetach = 0

        intermediate = []
        ref_points_detach = refpoints_unsigmoid.sigmoid()

        dec_out_bboxes = []
        dec_out_logits = []
        dec_out_corners = []
        dec_out_refs = []

        if not hasattr(self, 'project'):
            project = weighting_function(self.reg_max, self.up, self.reg_scale)
        else:
            project = self.project

        for layer_id, layer in enumerate(self.layers):
            ref_points_input = ref_points_detach[:, :, None]  # nq, bs, nlevel, 4

            query_sine_embed = gen_sineembed_for_position(ref_points_input[:, :, 0, :], self.d_model)  # nq, bs, 256*2

            query_pos = self.ref_point_head(query_sine_embed)  # nq, bs, 256

            output = layer(
                tgt=output,
                tgt_query_pos=query_pos,
                tgt_query_sine_embed=query_sine_embed,
                tgt_key_padding_mask=tgt_key_padding_mask,
                tgt_reference_points=ref_points_input,

                memory=memory,
                memory_spatial_shapes=spatial_shapes,
                memory_pos=pos,

                self_attn_mask=tgt_mask,
                cross_attn_mask=memory_mask
            )

            if layer_id == 0:
                pre_bboxes = torch.sigmoid(self.enc_out_bbox_embed(output) + inverse_sigmoid(ref_points_detach))
                pre_scores = self.class_embed[0](output)
                ref_points_initial = pre_bboxes.detach()

            pred_corners = self.bbox_embed[layer_id](output + output_detach) + pred_corners_undetach
            inter_ref_bbox = distance2bbox(ref_points_initial, self.integral(pred_corners, project), self.reg_scale)

            if self.training or layer_id == self.eval_idx:
                scores = self.class_embed[layer_id](output)
                scores = self.lqe_layers[layer_id](scores, pred_corners)
                dec_out_logits.append(scores)
                dec_out_bboxes.append(inter_ref_bbox)
                dec_out_corners.append(pred_corners)
                dec_out_refs.append(ref_points_initial)

            pred_corners_undetach = pred_corners
            if self.training:
                ref_points_detach = inter_ref_bbox.detach()
                output_detach = output.detach()
            else:
                ref_points_detach = inter_ref_bbox
                output_detach = output

        return torch.stack(dec_out_bboxes).permute(0, 2, 1, 3), torch.stack(dec_out_logits).permute(0, 2, 1, 3), \
               pre_bboxes, pre_scores


class LINEATransformer(nn.Module):
    def __init__(
        self,
        feat_channels=[256, 256, 256],
        feat_strides=[8, 16, 32],
        d_model=256,
        num_classes=2,
        nhead=8,
        num_queries=300,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.0,
        activation="relu",
        normalize_before=False,
        query_dim=4,
        aux_loss=False,
        # for deformable encoder
        num_feature_levels=1,
        dec_n_points=4,
        # from D-FINE
        reg_max=32,
        reg_scale=4,
        # denoising
        dn_number=100,
        dn_label_noise_ratio=0.5,
        dn_line_noise_scale=0.5,
        # for inference
        eval_spatial_size=None,
        eval_idx=5
    ):
        super().__init__()

        # init learnable queries
        self.tgt_embed = nn.Embedding(num_queries, d_model)
        nn.init.normal_(self.tgt_embed.weight.data)

        # line segment detection parameters
        self.num_classes = num_classes
        self.num_queries = num_queries

        # anchor selection at the output of encoder
        self.enc_output = nn.Linear(d_model, d_model)
        self.enc_output_norm = nn.LayerNorm(d_model)
        self._reset_parameters()

        # prediction layers
        _class_embed = nn.Linear(d_model, num_classes)
        _bbox_embed = MLP(d_model, d_model, 4, 3)

        # init the two embed layers
        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        _class_embed.bias.data = torch.ones(2) * bias_value
        nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)
        self.enc_out_bbox_embed = copy.deepcopy(_bbox_embed)
        self.enc_out_class_embed = copy.deepcopy(_class_embed)

        # decoder parameters
        self.d_model = d_model
        self.n_heads = nhead
        decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward,
                                                          dropout, activation,
                                                          num_feature_levels, nhead, dec_n_points)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          d_model=d_model, query_dim=query_dim,
                                          num_feature_levels=num_feature_levels,
                                          eval_idx=eval_idx, aux_loss=aux_loss,
                                          reg_max=reg_max, reg_scale=reg_scale)

        # for inference mode
        self.eval_spatial_size = eval_spatial_size
        if eval_spatial_size is not None:
            spatial_shapes = [[int(self.eval_spatial_size[0] / s), int(self.eval_spatial_size[1] / s)]
                              for s in feat_strides
                              ]
            output_proposals, output_proposals_valid = self.generate_anchors(spatial_shapes)
            self.register_buffer('output_proposals', output_proposals)
            self.register_buffer('output_proposals_mask', ~output_proposals_valid)

        # denoising parameters
        self.dn_number = dn_number
        self.dn_label_noise_ratio = dn_label_noise_ratio
        self.dn_line_noise_scale = dn_line_noise_scale
        self.label_enc = nn.Embedding(90 + 1, d_model)

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformAttn):  # or isinstance(m, MSDeformLineAttn):
                m._reset_parameters()

    def generate_anchors(self, spatial_shapes):
        proposals = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):

            grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32),
                                            torch.linspace(0, W_ - 1, W_, dtype=torch.float32), indexing='ij')
            grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)  # H_, W_, 2

            scale = torch.tensor([W_, H_], dtype=torch.float32,).view(1, 1, 1, 2)
            grid = (grid.unsqueeze(0) + 0.5) / scale

            proposal = torch.cat((grid, grid), -1).view(1, -1, 4)
            proposals.append(proposal)
        output_proposals = torch.cat(proposals, 1)
        output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
        output_proposals = torch.log(output_proposals / (1 - output_proposals))
        output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))

        return output_proposals, output_proposals_valid

    def forward(self, feats, targets):
        # flatten feature maps
        memory = []
        spatial_shapes = []
        split_sizes = []
        for feat in feats:
            bs, c, h, w = feat.shape
            memory.append(feat.flatten(2).permute(0, 2, 1))
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)
            split_sizes.append(h * w)

        # spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=feats[0].device)
        memory = torch.cat(memory, 1)  # bs, \sum{hxw}, c

        # two-stage
        if self.training:
            output_memory, output_proposals = gen_encoder_output_proposals(memory, spatial_shapes)
        else:
            output_proposals = self.output_proposals.repeat(bs, 1, 1)
            output_memory = memory.masked_fill(self.output_proposals_mask, float(0))

        output_memory = self.enc_output_norm(self.enc_output(output_memory))

        enc_outputs_class_unselected = self.enc_out_class_embed(output_memory)
        enc_outputs_coord_unselected = self.enc_out_bbox_embed(output_memory) + output_proposals  # (bs, \sum{hw}, 4) unsigmoid
        topk = self.num_queries
        topk_proposals = torch.topk(enc_outputs_class_unselected.max(-1)[0], topk, dim=1)[1]  # bs, nq

        # gather boxes
        refpoint_embed_undetach = torch.gather(enc_outputs_coord_unselected, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4))  # unsigmoid
        refpoint_embed = refpoint_embed_undetach.detach()
        init_box_proposal = torch.gather(output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)).sigmoid()  # sigmoid

        # gather tgt
        tgt_undetach = torch.gather(output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model))
        tgt = self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)  # nq, bs, d_model

        # denoise (only for training)
        if self.training and targets is not None:
            dn_tgt, dn_refpoint_embed, dn_attn_mask, dn_meta = \
                prepare_for_cdn(dn_args=(targets, self.dn_number, self.dn_label_noise_ratio, self.dn_line_noise_scale),
                                training=self.training, num_queries=self.num_queries, num_classes=self.num_classes,
                                hidden_dim=self.d_model, label_enc=self.label_enc)
            tgt = torch.cat([dn_tgt, tgt], dim=1)
            refpoint_embed = torch.cat([dn_refpoint_embed, refpoint_embed], dim=1)
        else:
            dn_attn_mask = dn_meta = None

        # preprocess memory for MSDeformableLineAttention
        value = memory.unflatten(2, (self.n_heads, -1))  # (bs, \sum{hxw}, n_heads, d_model//n_heads)
        value = value.permute(0, 2, 3, 1).flatten(0, 1).split(split_sizes, dim=-1)
        out_coords, out_class, pre_coords, pre_class = self.decoder(
            tgt=tgt.transpose(0, 1),
            memory=value,  # memory.transpose(0, 1),
            pos=None,
            refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
            spatial_shapes=spatial_shapes,
            tgt_mask=dn_attn_mask)

        # output
        if self.training:
            pre_coords = pre_coords.permute(1, 0, 2)
            pre_class = pre_class.permute(1, 0, 2)
            if dn_meta is not None:
                dn_out_coords, out_coords = torch.split(out_coords, [dn_meta['pad_size'], self.num_queries], dim=2)
                dn_out_class, out_class = torch.split(out_class, [dn_meta['pad_size'], self.num_queries], dim=2)

                dn_pre_coords, pre_coords = torch.split(pre_coords, [dn_meta['pad_size'], self.num_queries], dim=1)
                dn_pre_class, pre_class = torch.split(pre_class, [dn_meta['pad_size'], self.num_queries], dim=1)

            out = {'pred_logits': out_class[-1], 'pred_lines': out_coords[-1]}

            out['aux_pre_outputs'] = {'pred_logits': pre_class, 'pred_lines': pre_coords}

            if self.decoder.aux_loss:
                out['aux_outputs'] = self._set_aux_loss(out_class[:-1], out_coords[:-1])

            # for encoder output
            out_coords_enc = refpoint_embed_undetach.sigmoid()
            out_class_enc = self.enc_out_class_embed(tgt_undetach)
            out['aux_interm_outputs'] = {'pred_logits': out_class_enc, 'pred_lines': out_coords_enc}

            if dn_meta is not None:
                dn_out = {}
                dn_out['aux_outputs'] = self._set_aux_loss(dn_out_class, dn_out_coords)
                dn_out['aux_pre_outputs'] = {'pred_logits': dn_pre_class, 'pred_lines': dn_pre_coords}
                out['aux_denoise'] = dn_out
        else:
            out = {'pred_logits': out_class[0], 'pred_lines': out_coords[0]}

        out['dn_meta'] = dn_meta

        return out

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{'pred_logits': a, 'pred_lines': b}
                for a, b in zip(outputs_class, outputs_coord)]

    @torch.jit.unused
    def _set_aux_loss2(self, outputs_class, outputs_coord, outputs_corners, outputs_ref,
                       teacher_corners=None, teacher_class=None):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{'pred_logits': a, 'pred_lines': b, 'pred_corners': c, 'ref_points': d,
                 'teacher_corners': teacher_corners, 'teacher_logits': teacher_class}
                for a, b, c, d in zip(outputs_class, outputs_coord, outputs_corners, outputs_ref)]


def build_decoder(args):
    return LINEATransformer(
        feat_channels=args.feat_channels_decoder,
        feat_strides=args.feat_strides,
        num_classes=args.num_classes,
        d_model=args.hidden_dim,
        nhead=args.nheads,
        num_queries=args.num_queries,
        num_decoder_layers=args.dec_layers,
        dim_feedforward=args.dim_feedforward,
        dropout=args.dropout,
        activation=args.transformer_activation,
        normalize_before=args.pre_norm,
        query_dim=args.query_dim,
        aux_loss=True,
        # for deformable encoder
        num_feature_levels=args.num_feature_levels,
        dec_n_points=args.dec_n_points,
        # for D-FINE layers
        reg_max=args.reg_max,
        reg_scale=args.reg_scale,
        # for inference
        eval_spatial_size=args.eval_spatial_size,
        eval_idx=args.eval_idx,
        # for denoising
        dn_number=args.dn_number,
        dn_label_noise_ratio=args.dn_label_noise_ratio,
        dn_line_noise_scale=args.dn_line_noise_scale,
    )
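The `Integral` layer above turns each group of `reg_max + 1` logits into an expected edge distance by weighting softmax probabilities with a projection vector. A small stand-alone sketch of that step; a uniform `project` vector is assumed here, whereas the model builds it with `weighting_function`:

import torch
import torch.nn.functional as F

reg_max = 8
project = torch.linspace(0, 1, reg_max + 1)        # stand-in for weighting_function(reg_max, up, reg_scale)
x = torch.randn(2, 3, 4 * (reg_max + 1))           # (batch, queries, four distributions per line)
prob = F.softmax(x.reshape(-1, reg_max + 1), dim=1)
dist = F.linear(prob, project.unsqueeze(0)).reshape(2, 3, 4)   # expected value per edge
print(dist.shape)                                  # torch.Size([2, 3, 4])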
linea/models/linea/dn_components.py
ADDED
@@ -0,0 +1,178 @@
# ------------------------------------------------------------------------
# DINO
# Copyright (c) 2022 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# DN-DETR
# Copyright (c) 2022 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]


import torch
from .linea_utils import inverse_sigmoid
import torch.nn.functional as F


def prepare_for_cdn(dn_args, training, num_queries, num_classes, hidden_dim, label_enc):
    """
    A major difference of DINO from DN-DETR is that the authors process the pattern embedding in the detector
    forward function and use a learnable tgt embedding, so this function is changed a little bit.
    :param dn_args: targets, dn_number, label_noise_ratio, box_noise_scale
    :param training: whether it is training or inference
    :param num_queries: number of queries
    :param num_classes: number of classes
    :param hidden_dim: transformer hidden dim
    :param label_enc: encode labels in dn
    :return:
    """
    if training:
        targets, dn_number, label_noise_ratio, box_noise_scale = dn_args
        # positive and negative dn queries
        dn_number = dn_number * 2
        known = [(torch.ones_like(t['labels'])).cuda() for t in targets]
        batch_size = len(known)
        known_num = [sum(k) for k in known]

        if int(max(known_num)) == 0:
            dn_number = 1
        else:
            if dn_number >= 100:
                dn_number = dn_number // (int(max(known_num) * 2))
            elif dn_number < 1:
                dn_number = 1
        if dn_number == 0:
            dn_number = 1

        unmask_bbox = unmask_label = torch.cat(known)
        labels = torch.cat([t['labels'] for t in targets])
        lines = torch.cat([t['lines'] for t in targets])
        batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)])

        known_indice = torch.nonzero(unmask_label + unmask_bbox)
        known_indice = known_indice.view(-1)

        known_indice = known_indice.repeat(2 * dn_number, 1).view(-1)
        known_labels = labels.repeat(2 * dn_number, 1).view(-1)
        known_bid = batch_idx.repeat(2 * dn_number, 1).view(-1)
        known_lines = lines.repeat(2 * dn_number, 1)

        known_labels_expanded = known_labels.clone()
        known_lines_expand = known_lines.clone()

        if label_noise_ratio > 0:
            p = torch.rand_like(known_labels_expanded.float())
            chosen_indice = torch.nonzero(p < (label_noise_ratio * 0.5)).view(-1)  # half of bbox prob
            new_label = torch.randint_like(chosen_indice, 0, num_classes)  # randomly put a new one here
            known_labels_expanded.scatter_(0, chosen_indice, new_label)

        single_pad = int(max(known_num))

        pad_size = int(single_pad * 2 * dn_number)
        positive_idx = torch.tensor(range(len(lines))).long().cuda().unsqueeze(0).repeat(dn_number, 1)
        positive_idx += (torch.tensor(range(dn_number)) * len(lines) * 2).long().cuda().unsqueeze(1)
        positive_idx = positive_idx.flatten()
        negative_idx = positive_idx + len(lines)

        known_lines_ = known_lines.clone()
        known_lines_[:, :2] = (known_lines[:, :2] - known_lines[:, 2:]) / 2
        known_lines_[:, 2:] = (known_lines[:, :2] + known_lines[:, 2:]) / 2

        centers = torch.zeros_like(known_lines)
        centers[:, :2] = (known_lines_[:, :2] + known_lines_[:, 2:]) / 2
        centers[:, 2:] = (known_lines_[:, :2] + known_lines_[:, 2:]) / 2

        # Noisy length
        diff = torch.zeros_like(known_lines)
        diff[:, :2] = (known_lines[:, 2:] - known_lines[:, :2]) / 2
        diff[:, 2:] = (known_lines[:, 2:] - known_lines[:, :2]) / 2

        rand_sign = torch.randint(low=0, high=2, size=(known_lines.shape[0], 2), dtype=torch.float32, device=known_lines.device) * 2.0 - 1.0
        rand_part = torch.rand(size=(known_lines.shape[0], 2), device=known_lines.device)
        rand_part[negative_idx] += 1.2
        rand_part *= rand_sign

        known_lines_ = centers + torch.mul(rand_part.repeat_interleave(2, 1),
                                           diff).cuda() * box_noise_scale

        known_lines_expand = known_lines_.clamp(min=0.0, max=1.0)

        # order: top point > bottom point
        # if same y coordinate, right point > left point

        idx = torch.logical_or(known_lines_expand[..., 0] > known_lines_expand[..., 2],
                               torch.logical_or(
                                   known_lines_expand[..., 0] == known_lines_expand[..., 2],
                                   known_lines_expand[..., 1] < known_lines_expand[..., 3]
                               )
                               )

        known_lines_expand[idx] = known_lines_expand[idx][:, [2, 3, 0, 1]]

        m = known_labels_expanded.long().to('cuda')
        input_label_embed = label_enc(m)
        input_lines_embed = inverse_sigmoid(known_lines_expand)

        padding_label = torch.zeros(pad_size, hidden_dim).cuda()
        padding_lines = torch.zeros(pad_size, 4).cuda()

        input_query_label = padding_label.repeat(batch_size, 1, 1)
        input_query_lines = padding_lines.repeat(batch_size, 1, 1)

        map_known_indice = torch.tensor([]).to('cuda')
        if len(known_num):
            map_known_indice = torch.cat([torch.tensor(range(num)) for num in known_num])  # [1,2, 1,2,3]
            map_known_indice = torch.cat([map_known_indice + single_pad * i for i in range(2 * dn_number)]).long()

        if len(known_bid):
            input_query_label[(known_bid.long(), map_known_indice)] = input_label_embed
            input_query_lines[(known_bid.long(), map_known_indice)] = input_lines_embed

        tgt_size = pad_size + num_queries
        attn_mask = torch.ones(tgt_size, tgt_size).to('cuda') < 0
        # match query cannot see the reconstruct
        attn_mask[pad_size:, :pad_size] = True
        # reconstruct cannot see each other
        for i in range(dn_number):
            if i == 0:
                attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), single_pad * 2 * (i + 1):pad_size] = True
            if i == dn_number - 1:
                attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), :single_pad * i * 2] = True
            else:
                attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), single_pad * 2 * (i + 1):pad_size] = True
                attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), :single_pad * 2 * i] = True

        dn_meta = {
            'pad_size': pad_size,
            'num_dn_group': dn_number,
        }
    else:
        input_query_label = None
        input_query_lines = None
        attn_mask = None
        dn_meta = None

    return input_query_label, input_query_lines, attn_mask, dn_meta


def dn_post_process(outputs_class, outputs_coord, dn_meta, aux_loss, _set_aux_loss):
    """
    post process of dn after output from the transformer
    put the dn part in the dn_meta
    """
    if dn_meta and dn_meta['pad_size'] > 0:
        output_known_class = outputs_class[:, :, :dn_meta['pad_size'], :]
        output_known_coord = outputs_coord[:, :, :dn_meta['pad_size'], :]
        outputs_class = outputs_class[:, :, dn_meta['pad_size']:, :]
        outputs_coord = outputs_coord[:, :, dn_meta['pad_size']:, :]
        out = {'pred_logits': output_known_class[-1], 'pred_lines': output_known_coord[-1]}
        if aux_loss:
            out['aux_outputs'] = _set_aux_loss(output_known_class[1:], output_known_coord[1:])
            out['pre_outputs'] = {'pred_logits': output_known_class[0], 'pred_lines': output_known_coord[0]}
        dn_meta['output_known_lbs_lines'] = out
    return outputs_class, outputs_coord
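`prepare_for_cdn` above re-orders the endpoints of every noised line so each segment is fed to the decoder with a canonical direction. A toy check of that swap rule, copying the condition used in the function:

import torch

lines = torch.tensor([[0.2, 0.8, 0.7, 0.1],    # condition is False -> kept as-is
                      [0.9, 0.3, 0.4, 0.6]])   # x1 > x2 -> endpoints swapped
idx = torch.logical_or(lines[..., 0] > lines[..., 2],
                       torch.logical_or(lines[..., 0] == lines[..., 2],
                                        lines[..., 1] < lines[..., 3]))
lines[idx] = lines[idx][:, [2, 3, 0, 1]]
print(lines)   # second row becomes [0.4, 0.6, 0.9, 0.3]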
linea/models/linea/hgnetv2.py
ADDED
@@ -0,0 +1,595 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import os
|
5 |
+
import logging
|
6 |
+
|
7 |
+
# Constants for initialization
|
8 |
+
kaiming_normal_ = nn.init.kaiming_normal_
|
9 |
+
zeros_ = nn.init.zeros_
|
10 |
+
ones_ = nn.init.ones_
|
11 |
+
|
12 |
+
|
13 |
+
class FrozenBatchNorm2d(nn.Module):
|
14 |
+
"""copy and modified from https://github.com/facebookresearch/detr/blob/master/models/backbone.py
|
15 |
+
BatchNorm2d where the batch statistics and the affine parameters are fixed.
|
16 |
+
Copy-paste from torchvision.misc.ops with added eps before rsqrt,
|
17 |
+
without which any other models than torchvision.models.resnet[18,34,50,101]
|
18 |
+
produce nans.
|
19 |
+
"""
|
20 |
+
def __init__(self, num_features, eps=1e-5):
|
21 |
+
super(FrozenBatchNorm2d, self).__init__()
|
22 |
+
n = num_features
|
23 |
+
self.register_buffer("weight", torch.ones(n))
|
24 |
+
self.register_buffer("bias", torch.zeros(n))
|
25 |
+
self.register_buffer("running_mean", torch.zeros(n))
|
26 |
+
self.register_buffer("running_var", torch.ones(n))
|
27 |
+
self.eps = eps
|
28 |
+
self.num_features = n
|
29 |
+
|
30 |
+
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
|
31 |
+
missing_keys, unexpected_keys, error_msgs):
|
32 |
+
num_batches_tracked_key = prefix + 'num_batches_tracked'
|
33 |
+
if num_batches_tracked_key in state_dict:
|
34 |
+
del state_dict[num_batches_tracked_key]
|
35 |
+
|
36 |
+
super(FrozenBatchNorm2d, self)._load_from_state_dict(
|
37 |
+
state_dict, prefix, local_metadata, strict,
|
38 |
+
missing_keys, unexpected_keys, error_msgs)
|
39 |
+
|
40 |
+
def forward(self, x):
|
41 |
+
# move reshapes to the beginning
|
42 |
+
# to make it fuser-friendly
|
43 |
+
w = self.weight.reshape(1, -1, 1, 1)
|
44 |
+
b = self.bias.reshape(1, -1, 1, 1)
|
45 |
+
rv = self.running_var.reshape(1, -1, 1, 1)
|
46 |
+
rm = self.running_mean.reshape(1, -1, 1, 1)
|
47 |
+
scale = w * (rv + self.eps).rsqrt()
|
48 |
+
bias = b - rm * scale
|
49 |
+
return x * scale + bias
|
50 |
+
|
51 |
+
def extra_repr(self):
|
52 |
+
return (
|
53 |
+
"{num_features}, eps={eps}".format(**self.__dict__)
|
54 |
+
)
|
55 |
+
|
56 |
+
|
57 |
+
class LearnableAffineBlock(nn.Module):
|
58 |
+
def __init__(
|
59 |
+
self,
|
60 |
+
scale_value=1.0,
|
61 |
+
bias_value=0.0
|
62 |
+
):
|
63 |
+
super().__init__()
|
64 |
+
self.scale = nn.Parameter(torch.tensor([scale_value]), requires_grad=True)
|
65 |
+
self.bias = nn.Parameter(torch.tensor([bias_value]), requires_grad=True)
|
66 |
+
|
67 |
+
def forward(self, x):
|
68 |
+
return self.scale * x + self.bias
|
69 |
+
|
70 |
+
|
71 |
+
class ConvBNAct(nn.Module):
|
72 |
+
def __init__(
|
73 |
+
self,
|
74 |
+
in_chs,
|
75 |
+
out_chs,
|
76 |
+
kernel_size,
|
77 |
+
stride=1,
|
78 |
+
groups=1,
|
79 |
+
padding='',
|
80 |
+
use_act=True,
|
81 |
+
use_lab=False
|
82 |
+
):
|
83 |
+
super().__init__()
|
84 |
+
self.use_act = use_act
|
85 |
+
self.use_lab = use_lab
|
86 |
+
if padding == 'same':
|
87 |
+
self.conv = nn.Sequential(
|
88 |
+
nn.ZeroPad2d([0, 1, 0, 1]),
|
89 |
+
nn.Conv2d(
|
90 |
+
in_chs,
|
91 |
+
out_chs,
|
92 |
+
kernel_size,
|
93 |
+
stride,
|
94 |
+
groups=groups,
|
95 |
+
bias=False
|
96 |
+
)
|
97 |
+
)
|
98 |
+
else:
|
99 |
+
self.conv = nn.Conv2d(
|
100 |
+
in_chs,
|
101 |
+
out_chs,
|
102 |
+
kernel_size,
|
103 |
+
stride,
|
104 |
+
padding=(kernel_size - 1) // 2,
|
105 |
+
groups=groups,
|
106 |
+
bias=False
|
107 |
+
)
|
108 |
+
self.bn = nn.BatchNorm2d(out_chs)
|
109 |
+
if self.use_act:
|
110 |
+
self.act = nn.ReLU()
|
111 |
+
else:
|
112 |
+
self.act = nn.Identity()
|
113 |
+
if self.use_act and self.use_lab:
|
114 |
+
self.lab = LearnableAffineBlock()
|
115 |
+
else:
|
116 |
+
self.lab = nn.Identity()
|
117 |
+
|
118 |
+
def forward(self, x):
|
119 |
+
x = self.conv(x)
|
120 |
+
x = self.bn(x)
|
121 |
+
x = self.act(x)
|
122 |
+
x = self.lab(x)
|
123 |
+
return x
|
124 |
+
|
125 |
+
|
126 |
+
class LightConvBNAct(nn.Module):
|
127 |
+
def __init__(
|
128 |
+
self,
|
129 |
+
in_chs,
|
130 |
+
out_chs,
|
131 |
+
kernel_size,
|
132 |
+
groups=1,
|
133 |
+
use_lab=False,
|
134 |
+
):
|
135 |
+
super().__init__()
|
136 |
+
self.conv1 = ConvBNAct(
|
137 |
+
in_chs,
|
138 |
+
out_chs,
|
139 |
+
kernel_size=1,
|
140 |
+
use_act=False,
|
141 |
+
use_lab=use_lab,
|
142 |
+
)
|
143 |
+
self.conv2 = ConvBNAct(
|
144 |
+
out_chs,
|
145 |
+
out_chs,
|
146 |
+
kernel_size=kernel_size,
|
147 |
+
groups=out_chs,
|
148 |
+
use_act=True,
|
149 |
+
use_lab=use_lab,
|
150 |
+
)
|
151 |
+
|
152 |
+
def forward(self, x):
|
153 |
+
x = self.conv1(x)
|
154 |
+
x = self.conv2(x)
|
155 |
+
return x
|
156 |
+
|
157 |
+
|
158 |
+
class StemBlock(nn.Module):
|
159 |
+
# for HGNetv2
|
160 |
+
def __init__(self, in_chs, mid_chs, out_chs, use_lab=False):
|
161 |
+
super().__init__()
|
162 |
+
self.stem1 = ConvBNAct(
|
163 |
+
in_chs,
|
164 |
+
mid_chs,
|
165 |
+
kernel_size=3,
|
166 |
+
stride=2,
|
167 |
+
use_lab=use_lab,
|
168 |
+
)
|
169 |
+
self.stem2a = ConvBNAct(
|
170 |
+
mid_chs,
|
171 |
+
mid_chs // 2,
|
172 |
+
kernel_size=2,
|
173 |
+
stride=1,
|
174 |
+
use_lab=use_lab,
|
175 |
+
)
|
176 |
+
self.stem2b = ConvBNAct(
|
177 |
+
mid_chs // 2,
|
178 |
+
mid_chs,
|
179 |
+
kernel_size=2,
|
180 |
+
stride=1,
|
181 |
+
use_lab=use_lab,
|
182 |
+
)
|
183 |
+
self.stem3 = ConvBNAct(
|
184 |
+
mid_chs * 2,
|
185 |
+
mid_chs,
|
186 |
+
kernel_size=3,
|
187 |
+
stride=2,
|
188 |
+
use_lab=use_lab,
|
189 |
+
)
|
190 |
+
self.stem4 = ConvBNAct(
|
191 |
+
mid_chs,
|
192 |
+
out_chs,
|
193 |
+
kernel_size=1,
|
194 |
+
stride=1,
|
195 |
+
use_lab=use_lab,
|
196 |
+
)
|
197 |
+
self.pool = nn.MaxPool2d(kernel_size=2, stride=1, ceil_mode=True)
|
198 |
+
|
199 |
+
def forward(self, x):
|
200 |
+
x = self.stem1(x)
|
201 |
+
x = F.pad(x, (0, 1, 0, 1))
|
202 |
+
x2 = self.stem2a(x)
|
203 |
+
x2 = F.pad(x2, (0, 1, 0, 1))
|
204 |
+
x2 = self.stem2b(x2)
|
205 |
+
x1 = self.pool(x)
|
206 |
+
x = torch.cat([x1, x2], dim=1)
|
207 |
+
x = self.stem3(x)
|
208 |
+
x = self.stem4(x)
|
209 |
+
return x
|
210 |
+
|
211 |
+
|
212 |
+
class EseModule(nn.Module):
|
213 |
+
def __init__(self, chs):
|
214 |
+
super().__init__()
|
215 |
+
self.conv = nn.Conv2d(
|
216 |
+
chs,
|
217 |
+
chs,
|
218 |
+
kernel_size=1,
|
219 |
+
stride=1,
|
220 |
+
padding=0,
|
221 |
+
)
|
222 |
+
self.sigmoid = nn.Sigmoid()
|
223 |
+
|
224 |
+
def forward(self, x):
|
225 |
+
identity = x
|
226 |
+
x = x.mean((2, 3), keepdim=True)
|
227 |
+
x = self.conv(x)
|
228 |
+
x = self.sigmoid(x)
|
229 |
+
return torch.mul(identity, x)
|
230 |
+
|
231 |
+
|
232 |
+
class HG_Block(nn.Module):
|
233 |
+
def __init__(
|
234 |
+
self,
|
235 |
+
in_chs,
|
236 |
+
mid_chs,
|
237 |
+
out_chs,
|
238 |
+
layer_num,
|
239 |
+
kernel_size=3,
|
240 |
+
residual=False,
|
241 |
+
light_block=False,
|
242 |
+
use_lab=False,
|
243 |
+
agg='ese',
|
244 |
+
drop_path=0.,
|
245 |
+
):
|
246 |
+
super().__init__()
|
247 |
+
self.residual = residual
|
248 |
+
|
249 |
+
self.layers = nn.ModuleList()
|
250 |
+
for i in range(layer_num):
|
251 |
+
if light_block:
|
252 |
+
self.layers.append(
|
253 |
+
LightConvBNAct(
|
254 |
+
in_chs if i == 0 else mid_chs,
|
255 |
+
mid_chs,
|
256 |
+
kernel_size=kernel_size,
|
257 |
+
use_lab=use_lab,
|
258 |
+
)
|
259 |
+
)
|
260 |
+
else:
|
261 |
+
self.layers.append(
|
262 |
+
ConvBNAct(
|
263 |
+
in_chs if i == 0 else mid_chs,
|
264 |
+
mid_chs,
|
265 |
+
kernel_size=kernel_size,
|
266 |
+
stride=1,
|
267 |
+
use_lab=use_lab,
|
268 |
+
)
|
269 |
+
)
|
270 |
+
|
271 |
+
# feature aggregation
|
272 |
+
total_chs = in_chs + layer_num * mid_chs
|
273 |
+
if agg == 'se':
|
274 |
+
aggregation_squeeze_conv = ConvBNAct(
|
275 |
+
total_chs,
|
276 |
+
out_chs // 2,
|
277 |
+
kernel_size=1,
|
278 |
+
stride=1,
|
279 |
+
use_lab=use_lab,
|
280 |
+
)
|
281 |
+
aggregation_excitation_conv = ConvBNAct(
|
282 |
+
out_chs // 2,
|
283 |
+
out_chs,
|
284 |
+
kernel_size=1,
|
285 |
+
stride=1,
|
286 |
+
use_lab=use_lab,
|
287 |
+
)
|
288 |
+
self.aggregation = nn.Sequential(
|
289 |
+
aggregation_squeeze_conv,
|
290 |
+
aggregation_excitation_conv,
|
291 |
+
)
|
292 |
+
else:
|
293 |
+
aggregation_conv = ConvBNAct(
|
294 |
+
total_chs,
|
295 |
+
out_chs,
|
296 |
+
kernel_size=1,
|
297 |
+
stride=1,
|
298 |
+
use_lab=use_lab,
|
299 |
+
)
|
300 |
+
att = EseModule(out_chs)
|
301 |
+
self.aggregation = nn.Sequential(
|
302 |
+
aggregation_conv,
|
303 |
+
att,
|
304 |
+
)
|
305 |
+
|
306 |
+
self.drop_path = nn.Dropout(drop_path) if drop_path else nn.Identity()
|
307 |
+
|
308 |
+
def forward(self, x):
|
309 |
+
identity = x
|
310 |
+
output = [x]
|
311 |
+
for layer in self.layers:
|
312 |
+
x = layer(x)
|
313 |
+
output.append(x)
|
314 |
+
x = torch.cat(output, dim=1)
|
315 |
+
x = self.aggregation(x)
|
316 |
+
if self.residual:
|
317 |
+
x = self.drop_path(x) + identity
|
318 |
+
return x
|
319 |
+
|
320 |
+
|
321 |
+
class HG_Stage(nn.Module):
|
322 |
+
def __init__(
|
323 |
+
self,
|
324 |
+
in_chs,
|
325 |
+
mid_chs,
|
326 |
+
out_chs,
|
327 |
+
block_num,
|
328 |
+
layer_num,
|
329 |
+
downsample=True,
|
330 |
+
light_block=False,
|
331 |
+
kernel_size=3,
|
332 |
+
use_lab=False,
|
333 |
+
agg='se',
|
334 |
+
drop_path=0.,
|
335 |
+
):
|
336 |
+
super().__init__()
|
337 |
+
self.downsample = downsample
|
338 |
+
if downsample:
|
339 |
+
self.downsample = ConvBNAct(
|
340 |
+
in_chs,
|
341 |
+
in_chs,
|
342 |
+
kernel_size=3,
|
343 |
+
stride=2,
|
344 |
+
groups=in_chs,
|
345 |
+
use_act=False,
|
346 |
+
use_lab=use_lab,
|
347 |
+
)
|
348 |
+
else:
|
349 |
+
self.downsample = nn.Identity()
|
350 |
+
|
351 |
+
blocks_list = []
|
352 |
+
for i in range(block_num):
|
353 |
+
blocks_list.append(
|
354 |
+
HG_Block(
|
355 |
+
in_chs if i == 0 else out_chs,
|
356 |
+
mid_chs,
|
357 |
+
out_chs,
|
358 |
+
layer_num,
|
359 |
+
residual=False if i == 0 else True,
|
360 |
+
kernel_size=kernel_size,
|
361 |
+
light_block=light_block,
|
362 |
+
use_lab=use_lab,
|
363 |
+
agg=agg,
|
364 |
+
drop_path=drop_path[i] if isinstance(drop_path, (list, tuple)) else drop_path,
|
365 |
+
)
|
366 |
+
)
|
367 |
+
self.blocks = nn.Sequential(*blocks_list)
|
368 |
+
|
369 |
+
def forward(self, x):
|
370 |
+
x = self.downsample(x)
|
371 |
+
x = self.blocks(x)
|
372 |
+
return x
|
373 |
+
|
374 |
+
|
375 |
+
class HGNetv2(nn.Module):
|
376 |
+
"""
|
377 |
+
HGNetV2
|
378 |
+
Args:
|
379 |
+
stem_channels: list. Number of channels for the stem block.
|
380 |
+
stage_type: str. The stage configuration of HGNet, such as the number of channels, stride, etc.
|
381 |
+
use_lab: boolean. Whether to use LearnableAffineBlock in network.
|
382 |
+
lr_mult_list: list. Control the learning rate of different stages.
|
383 |
+
Returns:
|
384 |
+
model: nn.Module. The specific HGNetV2 model determined by the args.
|
385 |
+
"""
|
386 |
+
|
387 |
+
arch_configs = {
|
388 |
+
'B0': {
|
389 |
+
'stem_channels': [3, 16, 16],
|
390 |
+
'stage_config': {
|
391 |
+
# in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
|
392 |
+
"stage1": [16, 16, 64, 1, False, False, 3, 3],
|
393 |
+
"stage2": [64, 32, 256, 1, True, False, 3, 3],
|
394 |
+
"stage3": [256, 64, 512, 2, True, True, 5, 3],
|
395 |
+
"stage4": [512, 128, 1024, 1, True, True, 5, 3],
|
396 |
+
},
|
397 |
+
'url': 'https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B0_stage1.pth'
|
398 |
+
},
|
399 |
+
'B1': {
|
400 |
+
'stem_channels': [3, 24, 32],
|
401 |
+
'stage_config': {
|
402 |
+
# in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
|
403 |
+
"stage1": [32, 32, 64, 1, False, False, 3, 3],
|
404 |
+
"stage2": [64, 48, 256, 1, True, False, 3, 3],
|
405 |
+
"stage3": [256, 96, 512, 2, True, True, 5, 3],
|
406 |
+
"stage4": [512, 192, 1024, 1, True, True, 5, 3],
|
407 |
+
},
|
408 |
+
'url': 'https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B1_stage1.pth'
|
409 |
+
},
|
410 |
+
'B2': {
|
411 |
+
'stem_channels': [3, 24, 32],
|
412 |
+
'stage_config': {
|
413 |
+
# in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
|
414 |
+
"stage1": [32, 32, 96, 1, False, False, 3, 4],
|
415 |
+
"stage2": [96, 64, 384, 1, True, False, 3, 4],
|
416 |
+
"stage3": [384, 128, 768, 3, True, True, 5, 4],
|
417 |
+
"stage4": [768, 256, 1536, 1, True, True, 5, 4],
|
418 |
+
},
|
419 |
+
'url': 'https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B2_stage1.pth'
|
420 |
+
},
|
421 |
+
'B3': {
|
422 |
+
'stem_channels': [3, 24, 32],
|
423 |
+
'stage_config': {
|
424 |
+
# in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
|
425 |
+
"stage1": [32, 32, 128, 1, False, False, 3, 5],
|
426 |
+
"stage2": [128, 64, 512, 1, True, False, 3, 5],
|
427 |
+
"stage3": [512, 128, 1024, 3, True, True, 5, 5],
|
428 |
+
"stage4": [1024, 256, 2048, 1, True, True, 5, 5],
|
429 |
+
},
|
430 |
+
'url': 'https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B3_stage1.pth'
|
431 |
+
},
|
432 |
+
'B4': {
|
433 |
+
'stem_channels': [3, 32, 48],
|
434 |
+
'stage_config': {
|
435 |
+
# in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
|
436 |
+
"stage1": [48, 48, 128, 1, False, False, 3, 6],
|
437 |
+
"stage2": [128, 96, 512, 1, True, False, 3, 6],
|
438 |
+
"stage3": [512, 192, 1024, 3, True, True, 5, 6],
|
439 |
+
"stage4": [1024, 384, 2048, 1, True, True, 5, 6],
|
440 |
+
},
|
441 |
+
'url': 'https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B4_stage1.pth'
|
442 |
+
},
|
443 |
+
'B5': {
|
444 |
+
'stem_channels': [3, 32, 64],
|
445 |
+
'stage_config': {
|
446 |
+
# in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
|
447 |
+
"stage1": [64, 64, 128, 1, False, False, 3, 6],
|
448 |
+
"stage2": [128, 128, 512, 2, True, False, 3, 6],
|
449 |
+
"stage3": [512, 256, 1024, 5, True, True, 5, 6],
|
450 |
+
"stage4": [1024, 512, 2048, 2, True, True, 5, 6],
|
451 |
+
},
|
452 |
+
'url': 'https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B5_stage1.pth'
|
453 |
+
},
|
454 |
+
'B6': {
|
455 |
+
'stem_channels': [3, 48, 96],
|
456 |
+
'stage_config': {
|
457 |
+
# in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
|
458 |
+
"stage1": [96, 96, 192, 2, False, False, 3, 6],
|
459 |
+
"stage2": [192, 192, 512, 3, True, False, 3, 6],
|
460 |
+
"stage3": [512, 384, 1024, 6, True, True, 5, 6],
|
461 |
+
"stage4": [1024, 768, 2048, 3, True, True, 5, 6],
|
462 |
+
},
|
463 |
+
'url': 'https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B6_stage1.pth'
|
464 |
+
},
|
465 |
+
}
|
466 |
+
|
467 |
+
def __init__(self,
|
468 |
+
name,
|
469 |
+
use_lab=False,
|
470 |
+
return_idx=[1, 2, 3],
|
471 |
+
freeze_stem_only=True,
|
472 |
+
freeze_at=0,
|
473 |
+
freeze_norm=True,
|
474 |
+
pretrained=True,
|
475 |
+
local_model_dir='weight/hgnetv2/',
|
476 |
+
for_pgi=False):
|
477 |
+
super().__init__()
|
478 |
+
self.use_lab = use_lab
|
479 |
+
self.return_idx = return_idx
|
480 |
+
|
481 |
+
stem_channels = self.arch_configs[name]['stem_channels']
|
482 |
+
stage_config = self.arch_configs[name]['stage_config']
|
483 |
+
download_url = self.arch_configs[name]['url']
|
484 |
+
|
485 |
+
self._out_strides = [4, 8, 16, 32]
|
486 |
+
self._out_channels = [stage_config[k][2] for k in stage_config]
|
487 |
+
|
488 |
+
self.num_channels = self._out_channels[4 - len(return_idx):]
|
489 |
+
|
490 |
+
# stem
|
491 |
+
self.stem = StemBlock(
|
492 |
+
in_chs=stem_channels[0],
|
493 |
+
mid_chs=stem_channels[1],
|
494 |
+
out_chs=stem_channels[2],
|
495 |
+
use_lab=use_lab)
|
496 |
+
|
497 |
+
# stages
|
498 |
+
self.stages = nn.ModuleList()
|
499 |
+
for i, k in enumerate(stage_config):
|
500 |
+
in_channels, mid_channels, out_channels, block_num, downsample, light_block, kernel_size, layer_num = stage_config[k]
|
501 |
+
self.stages.append(
|
502 |
+
HG_Stage(
|
503 |
+
in_channels,
|
504 |
+
mid_channels,
|
505 |
+
out_channels,
|
506 |
+
block_num,
|
507 |
+
layer_num,
|
508 |
+
downsample,
|
509 |
+
light_block,
|
510 |
+
kernel_size,
|
511 |
+
use_lab))
|
512 |
+
|
513 |
+
if freeze_at >= 0:
|
514 |
+
self._freeze_parameters(self.stem)
|
515 |
+
if not freeze_stem_only:
|
516 |
+
for i in range(min(freeze_at + 1, len(self.stages))):
|
517 |
+
self._freeze_parameters(self.stages[i])
|
518 |
+
|
519 |
+
if freeze_norm:
|
520 |
+
self._freeze_norm(self)
|
521 |
+
|
522 |
+
if pretrained:
|
523 |
+
RED, GREEN, RESET = "\033[91m", "\033[92m", "\033[0m"
|
524 |
+
try:
|
525 |
+
model_path = local_model_dir + 'PPHGNetV2_' + name + '_stage1.pth'
|
526 |
+
if os.path.exists(model_path):
|
527 |
+
print("Loading stage1")
|
528 |
+
state = torch.load(model_path, map_location='cpu')
|
529 |
+
print(f"Loaded stage1 {name} HGNetV2 from local file.")
|
530 |
+
else:
|
531 |
+
# If the file doesn't exist locally, download from the URL
|
532 |
+
if torch.distributed.get_rank() == 0:
|
533 |
+
print(GREEN + "If the pretrained HGNetV2 can't be downloaded automatically. Please check your network connection." + RESET)
|
534 |
+
print(GREEN + "Please check your network connection. Or download the model manually from " + RESET + f"{download_url}" + GREEN + " to " + RESET + f"{local_model_dir}." + RESET)
|
535 |
+
state = torch.hub.load_state_dict_from_url(download_url, map_location='cpu', model_dir=local_model_dir)
|
536 |
+
torch.distributed.barrier()
|
537 |
+
else:
|
538 |
+
torch.distributed.barrier()
|
539 |
+
state = torch.load(model_path, map_location='cpu')  # non-zero ranks load the file rank 0 just downloaded
|
540 |
+
|
541 |
+
print(f"Loaded stage1 {name} HGNetV2 from URL.")
|
542 |
+
|
543 |
+
self.load_state_dict(state)
|
544 |
+
|
545 |
+
except (Exception, KeyboardInterrupt) as e:
|
546 |
+
if torch.distributed.get_rank() == 0:
|
547 |
+
print(f"{str(e)}")
|
548 |
+
logging.error(RED + "CRITICAL WARNING: Failed to load pretrained HGNetV2 model" + RESET)
|
549 |
+
logging.error(GREEN + "Please check your network connection. Or download the model manually from " \
|
550 |
+
+ RESET + f"{download_url}" + GREEN + " to " + RESET + f"{local_model_dir}." + RESET)
|
551 |
+
exit()
|
552 |
+
|
553 |
+
|
554 |
+
def _freeze_norm(self, m: nn.Module):
|
555 |
+
if isinstance(m, nn.BatchNorm2d):
|
556 |
+
m = FrozenBatchNorm2d(m.num_features)
|
557 |
+
else:
|
558 |
+
for name, child in m.named_children():
|
559 |
+
_child = self._freeze_norm(child)
|
560 |
+
if _child is not child:
|
561 |
+
setattr(m, name, _child)
|
562 |
+
return m
|
563 |
+
|
564 |
+
def _freeze_parameters(self, m: nn.Module):
|
565 |
+
for p in m.parameters():
|
566 |
+
p.requires_grad = False
|
567 |
+
|
568 |
+
def forward(self, x):
|
569 |
+
x = self.stem(x)
|
570 |
+
outs = []
|
571 |
+
for idx, stage in enumerate(self.stages):
|
572 |
+
x = stage(x)
|
573 |
+
if idx in self.return_idx:
|
574 |
+
outs.append(x)
|
575 |
+
return outs
|
576 |
+
|
577 |
+
def build_hgnetv2(args):
|
578 |
+
name = {
|
579 |
+
'HGNetv2_B0': 'B0',
|
580 |
+
'HGNetv2_B1': 'B1',
|
581 |
+
'HGNetv2_B2': 'B2',
|
582 |
+
'HGNetv2_B3': 'B3',
|
583 |
+
'HGNetv2_B4': 'B4',
|
584 |
+
'HGNetv2_B5': 'B5',
|
585 |
+
'HGNetv2_B6': 'B6'
|
586 |
+
}
|
587 |
+
return HGNetv2(
|
588 |
+
name[args.backbone],
|
589 |
+
return_idx=args.return_interm_indices,
|
590 |
+
freeze_at=-1,
|
591 |
+
freeze_norm=args.freeze_norm,
|
592 |
+
freeze_stem_only= args.freeze_stem_only,
|
593 |
+
use_lab = args.use_lab,
|
594 |
+
pretrained = args.pretrained,
|
595 |
+
)
|
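A quick way to sanity-check the backbone defined above is to instantiate it directly and look at the multi-scale outputs. The sketch below is illustrative only: the import path is an assumption about this repository layout, pretrained=False keeps construction from hitting the checkpoint download (and the torch.distributed rank query), and the expected shapes follow from the B4 stage_config together with the stride-4 stem.

```python
import torch
from linea.models.linea.hgnetv2 import HGNetv2   # assumed import path

backbone = HGNetv2('B4', return_idx=[1, 2, 3], freeze_at=-1, pretrained=False).eval()

x = torch.randn(1, 3, 640, 640)                   # one 640x640 RGB image
with torch.no_grad():
    feats = backbone(x)

print([tuple(f.shape) for f in feats])
# expected from the B4 table: [(1, 512, 80, 80), (1, 1024, 40, 40), (1, 2048, 20, 20)],
# i.e. the stage-2/3/4 features at strides 8, 16 and 32
```

build_hgnetv2(args) at the end of the file wraps this same constructor, mapping args.backbone names such as 'HGNetv2_B4' to these variants.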
linea/models/linea/hybrid_encoder.py
ADDED
@@ -0,0 +1,471 @@
1 |
+
"""
|
2 |
+
Modified from D-FINE (https://github.com/Peterande/D-FINE)
|
3 |
+
Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
|
4 |
+
---------------------------------------------------------------------------------
|
5 |
+
Modified from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
|
6 |
+
Copyright (c) 2023 lyuwenyu. All Rights Reserved.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import copy
|
10 |
+
from typing import Optional
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from torch import nn, Tensor
|
14 |
+
import torch.nn.functional as F
|
15 |
+
|
16 |
+
|
17 |
+
def get_activation(act: str, inpace: bool=True):
|
18 |
+
"""get activation
|
19 |
+
"""
|
20 |
+
if act is None:
|
21 |
+
return nn.Identity()
|
22 |
+
|
23 |
+
elif isinstance(act, nn.Module):
|
24 |
+
return act
|
25 |
+
|
26 |
+
act = act.lower()
|
27 |
+
|
28 |
+
if act == 'silu' or act == 'swish':
|
29 |
+
m = nn.SiLU()
|
30 |
+
|
31 |
+
elif act == 'relu':
|
32 |
+
m = nn.ReLU()
|
33 |
+
|
34 |
+
elif act == 'leaky_relu':
|
35 |
+
m = nn.LeakyReLU()
|
36 |
+
|
37 |
+
elif act == 'silu':
|
38 |
+
m = nn.SiLU()
|
39 |
+
|
40 |
+
elif act == 'gelu':
|
41 |
+
m = nn.GELU()
|
42 |
+
|
43 |
+
elif act == 'hardsigmoid':
|
44 |
+
m = nn.Hardsigmoid()
|
45 |
+
|
46 |
+
else:
|
47 |
+
raise RuntimeError(f'Unsupported activation: {act}')
|
48 |
+
|
49 |
+
if hasattr(m, 'inplace'):
|
50 |
+
m.inplace = inpace
|
51 |
+
|
52 |
+
return m
|
53 |
+
|
54 |
+
class ConvNormLayer_fuse(nn.Module):
|
55 |
+
def __init__(self, ch_in, ch_out, kernel_size, stride, g=1, padding=None, bias=False, act=None):
|
56 |
+
super().__init__()
|
57 |
+
padding = (kernel_size-1)//2 if padding is None else padding
|
58 |
+
self.conv = nn.Conv2d(
|
59 |
+
ch_in,
|
60 |
+
ch_out,
|
61 |
+
kernel_size,
|
62 |
+
stride,
|
63 |
+
groups=g,
|
64 |
+
padding=padding,
|
65 |
+
bias=bias)
|
66 |
+
self.norm = nn.BatchNorm2d(ch_out)
|
67 |
+
self.act = nn.Identity() if act is None else get_activation(act)
|
68 |
+
self.ch_in, self.ch_out, self.kernel_size, self.stride, self.g, self.padding, self.bias = \
|
69 |
+
ch_in, ch_out, kernel_size, stride, g, padding, bias
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
if hasattr(self, 'conv_bn_fused'):
|
73 |
+
y = self.conv_bn_fused(x)
|
74 |
+
else:
|
75 |
+
y = self.norm(self.conv(x))
|
76 |
+
return self.act(y)
|
77 |
+
|
78 |
+
def convert_to_deploy(self):
|
79 |
+
if not hasattr(self, 'conv_bn_fused'):
|
80 |
+
self.conv_bn_fused = nn.Conv2d(
|
81 |
+
self.ch_in,
|
82 |
+
self.ch_out,
|
83 |
+
self.kernel_size,
|
84 |
+
self.stride,
|
85 |
+
groups=self.g,
|
86 |
+
padding=self.padding,
|
87 |
+
bias=True)
|
88 |
+
|
89 |
+
kernel, bias = self.get_equivalent_kernel_bias()
|
90 |
+
self.conv_bn_fused.weight.data = kernel
|
91 |
+
self.conv_bn_fused.bias.data = bias
|
92 |
+
self.__delattr__('conv')
|
93 |
+
self.__delattr__('norm')
|
94 |
+
|
95 |
+
def get_equivalent_kernel_bias(self):
|
96 |
+
kernel3x3, bias3x3 = self._fuse_bn_tensor()
|
97 |
+
|
98 |
+
return kernel3x3, bias3x3
|
99 |
+
|
100 |
+
def _fuse_bn_tensor(self):
|
101 |
+
kernel = self.conv.weight
|
102 |
+
running_mean = self.norm.running_mean
|
103 |
+
running_var = self.norm.running_var
|
104 |
+
gamma = self.norm.weight
|
105 |
+
beta = self.norm.bias
|
106 |
+
eps = self.norm.eps
|
107 |
+
std = (running_var + eps).sqrt()
|
108 |
+
t = (gamma / std).reshape(-1, 1, 1, 1)
|
109 |
+
return kernel * t, beta - running_mean * gamma / std
|
110 |
+
|
111 |
+
|
112 |
+
class ConvNormLayer(nn.Module):
|
113 |
+
def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None):
|
114 |
+
super().__init__()
|
115 |
+
self.conv = nn.Conv2d(
|
116 |
+
ch_in,
|
117 |
+
ch_out,
|
118 |
+
kernel_size,
|
119 |
+
stride,
|
120 |
+
padding=(kernel_size-1)//2 if padding is None else padding,
|
121 |
+
bias=bias)
|
122 |
+
self.norm = nn.BatchNorm2d(ch_out)
|
123 |
+
self.act = nn.Identity() if act is None else get_activation(act)
|
124 |
+
|
125 |
+
def forward(self, x):
|
126 |
+
return self.act(self.norm(self.conv(x)))
|
127 |
+
|
128 |
+
class SCDown(nn.Module):
|
129 |
+
def __init__(self, c1, c2, k, s):
|
130 |
+
super().__init__()
|
131 |
+
self.cv1 = ConvNormLayer_fuse(c1, c2, 1, 1)
|
132 |
+
self.cv2 = ConvNormLayer_fuse(c2, c2, k, s, c2)
|
133 |
+
|
134 |
+
def forward(self, x):
|
135 |
+
return self.cv2(self.cv1(x))
|
136 |
+
|
137 |
+
class VGGBlock(nn.Module):
|
138 |
+
def __init__(self, ch_in, ch_out, act='relu'):
|
139 |
+
super().__init__()
|
140 |
+
self.ch_in = ch_in
|
141 |
+
self.ch_out = ch_out
|
142 |
+
assert ch_out % 2 == 0
|
143 |
+
self.conv1 = ConvNormLayer(ch_in, ch_out, 3, 1, padding=1, act=None)
|
144 |
+
self.conv2 = ConvNormLayer(ch_in, ch_out, 1, 1, padding=0, act=None)
|
145 |
+
# self.conv1H = ConvNormLayer(ch_in, ch_out//2 , (3, 1), 1, padding=(1, 0), act=None)
|
146 |
+
# self.conv1W = ConvNormLayer(ch_in, ch_out//2, (1, 3), 1, padding=(0, 1), act=None)
|
147 |
+
# self.conv2H = ConvNormLayer(ch_in, ch_out//2, 1, 1, padding=0, act=None)
|
148 |
+
# self.conv2W = ConvNormLayer(ch_in, ch_out//2, 1, 1, padding=0, act=None)
|
149 |
+
# self.conv3 = ConvNormLayer(ch_out, ch_out, 1, 1, padding=0, act=None)
|
150 |
+
self.act = nn.Identity() if act is None else get_activation(act)
|
151 |
+
|
152 |
+
def forward(self, x):
|
153 |
+
if hasattr(self, 'conv'):
|
154 |
+
y = self.conv(x)
|
155 |
+
else:
|
156 |
+
y = self.conv1(x) + self.conv2(x)
|
157 |
+
|
158 |
+
return self.act(y)
|
159 |
+
|
160 |
+
def convert_to_deploy(self):
|
161 |
+
if not hasattr(self, 'conv'):
|
162 |
+
self.conv = nn.Conv2d(self.ch_in, self.ch_out, 3, 1, padding=1)
|
163 |
+
|
164 |
+
kernel, bias = self.get_equivalent_kernel_bias()
|
165 |
+
self.conv.weight.data = kernel
|
166 |
+
self.conv.bias.data = bias
|
167 |
+
# self.__delattr__('conv1')
|
168 |
+
# self.__delattr__('conv2')
|
169 |
+
|
170 |
+
def get_equivalent_kernel_bias(self):
|
171 |
+
kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
|
172 |
+
kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
|
173 |
+
|
174 |
+
return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1), bias3x3 + bias1x1
|
175 |
+
|
176 |
+
def _pad_1x1_to_3x3_tensor(self, kernel1x1):
|
177 |
+
if kernel1x1 is None:
|
178 |
+
return 0
|
179 |
+
else:
|
180 |
+
return F.pad(kernel1x1, [1, 1, 1, 1])
|
181 |
+
|
182 |
+
def _fuse_bn_tensor(self, branch: ConvNormLayer):
|
183 |
+
if branch is None:
|
184 |
+
return 0, 0
|
185 |
+
kernel = branch.conv.weight
|
186 |
+
running_mean = branch.norm.running_mean
|
187 |
+
running_var = branch.norm.running_var
|
188 |
+
gamma = branch.norm.weight
|
189 |
+
beta = branch.norm.bias
|
190 |
+
eps = branch.norm.eps
|
191 |
+
std = (running_var + eps).sqrt()
|
192 |
+
t = (gamma / std).reshape(-1, 1, 1, 1)
|
193 |
+
return kernel * t, beta - running_mean * gamma / std
|
194 |
+
|
195 |
+
|
196 |
+
class RepNCSPELAN4(nn.Module):
|
197 |
+
# csp-elan
|
198 |
+
def __init__(self, c1, c2, c3, c4, n=3,
|
199 |
+
bias=False,
|
200 |
+
act="silu"):
|
201 |
+
super().__init__()
|
202 |
+
self.c = c3//2
|
203 |
+
self.cv1 = ConvNormLayer_fuse(c1, c3, 1, 1, bias=bias, act=act)
|
204 |
+
self.cv2 = nn.Sequential(CSPLayer(c3//2, c4, n, 1, bias=bias, act=act, bottletype=VGGBlock), ConvNormLayer_fuse(c4, c4, 3, 1, bias=bias, act=act))
|
205 |
+
self.cv3 = nn.Sequential(CSPLayer(c4, c4, n, 1, bias=bias, act=act, bottletype=VGGBlock), ConvNormLayer_fuse(c4, c4, 3, 1, bias=bias, act=act))
|
206 |
+
self.cv4 = ConvNormLayer_fuse(c3+(2*c4), c2, 1, 1, bias=bias, act=act)
|
207 |
+
|
208 |
+
def forward_chunk(self, x):
|
209 |
+
y = list(self.cv1(x).chunk(2, 1))
|
210 |
+
y.extend((m(y[-1])) for m in [self.cv2, self.cv3])
|
211 |
+
return self.cv4(torch.cat(y, 1))
|
212 |
+
|
213 |
+
def forward(self, x):
|
214 |
+
y = list(self.cv1(x).split((self.c, self.c), 1))
|
215 |
+
y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
|
216 |
+
return self.cv4(torch.cat(y, 1))
|
217 |
+
|
218 |
+
|
219 |
+
class CSPLayer(nn.Module):
|
220 |
+
def __init__(self,
|
221 |
+
in_channels,
|
222 |
+
out_channels,
|
223 |
+
num_blocks=3,
|
224 |
+
expansion=1.0,
|
225 |
+
bias=None,
|
226 |
+
act="silu",
|
227 |
+
bottletype=VGGBlock):
|
228 |
+
super(CSPLayer, self).__init__()
|
229 |
+
hidden_channels = int(out_channels * expansion)
|
230 |
+
self.conv1 = ConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act)
|
231 |
+
self.conv2 = ConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act)
|
232 |
+
self.bottlenecks = nn.Sequential(*[
|
233 |
+
bottletype(hidden_channels, hidden_channels, act=act) for _ in range(num_blocks)
|
234 |
+
])
|
235 |
+
if hidden_channels != out_channels:
|
236 |
+
self.conv3 = ConvNormLayer(hidden_channels, out_channels, 1, 1, bias=bias, act=act)
|
237 |
+
else:
|
238 |
+
self.conv3 = nn.Identity()
|
239 |
+
|
240 |
+
def forward(self, x):
|
241 |
+
x_1 = self.conv1(x)
|
242 |
+
x_1 = self.bottlenecks(x_1)
|
243 |
+
x_2 = self.conv2(x)
|
244 |
+
return self.conv3(x_1 + x_2)
|
245 |
+
|
246 |
+
|
247 |
+
# transformer
|
248 |
+
class TransformerEncoderLayer(nn.Module):
|
249 |
+
def __init__(self,
|
250 |
+
d_model,
|
251 |
+
nhead,
|
252 |
+
dim_feedforward=2048,
|
253 |
+
dropout=0.1,
|
254 |
+
activation="relu",
|
255 |
+
normalize_before=False):
|
256 |
+
super().__init__()
|
257 |
+
self.normalize_before = normalize_before
|
258 |
+
|
259 |
+
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout, batch_first=True)
|
260 |
+
|
261 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
262 |
+
self.dropout = nn.Dropout(dropout)
|
263 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
264 |
+
|
265 |
+
self.norm1 = nn.LayerNorm(d_model)
|
266 |
+
self.norm2 = nn.LayerNorm(d_model)
|
267 |
+
self.dropout1 = nn.Dropout(dropout)
|
268 |
+
self.dropout2 = nn.Dropout(dropout)
|
269 |
+
self.activation = get_activation(activation)
|
270 |
+
|
271 |
+
@staticmethod
|
272 |
+
def with_pos_embed(tensor, pos_embed):
|
273 |
+
return tensor if pos_embed is None else tensor + pos_embed
|
274 |
+
|
275 |
+
def forward(self,
|
276 |
+
src,
|
277 |
+
src_mask=None,
|
278 |
+
src_key_padding_mask=None,
|
279 |
+
pos_embed=None) -> torch.Tensor:
|
280 |
+
residual = src
|
281 |
+
if self.normalize_before:
|
282 |
+
src = self.norm1(src)
|
283 |
+
q = k = self.with_pos_embed(src, pos_embed)
|
284 |
+
src, _ = self.self_attn(q, k,
|
285 |
+
value=src,
|
286 |
+
attn_mask=src_mask,
|
287 |
+
key_padding_mask=src_key_padding_mask)
|
288 |
+
|
289 |
+
src = residual + self.dropout1(src)
|
290 |
+
if not self.normalize_before:
|
291 |
+
src = self.norm1(src)
|
292 |
+
|
293 |
+
residual = src
|
294 |
+
if self.normalize_before:
|
295 |
+
src = self.norm2(src)
|
296 |
+
src = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
297 |
+
src = residual + self.dropout2(src)
|
298 |
+
if not self.normalize_before:
|
299 |
+
src = self.norm2(src)
|
300 |
+
return src
|
301 |
+
|
302 |
+
|
303 |
+
class TransformerEncoder(nn.Module):
|
304 |
+
def __init__(self, encoder_layer, num_layers, norm=None):
|
305 |
+
super(TransformerEncoder, self).__init__()
|
306 |
+
self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])
|
307 |
+
self.num_layers = num_layers
|
308 |
+
self.norm = norm
|
309 |
+
|
310 |
+
def forward(self,
|
311 |
+
src,
|
312 |
+
src_mask=None,
|
313 |
+
src_key_padding_mask=None,
|
314 |
+
pos_embed=None) -> torch.Tensor:
|
315 |
+
output = src
|
316 |
+
for layer in self.layers:
|
317 |
+
output = layer(output,
|
318 |
+
src_mask=src_mask,
|
319 |
+
src_key_padding_mask=src_key_padding_mask,
|
320 |
+
pos_embed=pos_embed)
|
321 |
+
|
322 |
+
if self.norm is not None:
|
323 |
+
output = self.norm(output)
|
324 |
+
|
325 |
+
return output
|
326 |
+
|
327 |
+
|
328 |
+
class HybridEncoder(nn.Module):
|
329 |
+
def __init__(self,
|
330 |
+
n_levels=3,
|
331 |
+
hidden_dim=256,
|
332 |
+
nhead=8,
|
333 |
+
dim_feedforward = 1024,
|
334 |
+
dropout=0.0,
|
335 |
+
enc_act='gelu',
|
336 |
+
use_encoder_idx=[2],
|
337 |
+
num_encoder_layers=1,
|
338 |
+
expansion=1.0,
|
339 |
+
depth_mult=1.0,
|
340 |
+
act='silu',
|
341 |
+
eval_spatial_size=None
|
342 |
+
):
|
343 |
+
super().__init__()
|
344 |
+
self.n_levels = n_levels
|
345 |
+
self.hidden_dim = hidden_dim
|
346 |
+
self.use_encoder_idx = use_encoder_idx
|
347 |
+
self.num_encoder_layers = num_encoder_layers
|
348 |
+
self.eval_spatial_size = eval_spatial_size
|
349 |
+
|
350 |
+
# encoder transformer
|
351 |
+
encoder_layer = TransformerEncoderLayer(
|
352 |
+
hidden_dim,
|
353 |
+
nhead=nhead,
|
354 |
+
dim_feedforward=dim_feedforward,
|
355 |
+
dropout=dropout,
|
356 |
+
activation=enc_act)
|
357 |
+
|
358 |
+
self.encoder = TransformerEncoder(copy.deepcopy(encoder_layer), num_encoder_layers)
|
359 |
+
|
360 |
+
# top-down fpn
|
361 |
+
self.lateral_convs = nn.ModuleList()
|
362 |
+
self.fpn_blocks = nn.ModuleList()
|
363 |
+
for _ in range(n_levels - 1, 0, -1):
|
364 |
+
self.lateral_convs.append(ConvNormLayer(hidden_dim, hidden_dim, 1, 1, act=act))
|
365 |
+
self.fpn_blocks.append(
|
366 |
+
RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(expansion * hidden_dim // 2), round(3 * depth_mult))
|
367 |
+
# CSPRepLayer(hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion)
|
368 |
+
)
|
369 |
+
|
370 |
+
# bottom-up pan
|
371 |
+
self.downsample_convs = nn.ModuleList()
|
372 |
+
self.pan_blocks = nn.ModuleList()
|
373 |
+
for _ in range(n_levels - 1):
|
374 |
+
self.downsample_convs.append(nn.Sequential(
|
375 |
+
SCDown(hidden_dim, hidden_dim, 3, 2),
|
376 |
+
)
|
377 |
+
)
|
378 |
+
self.pan_blocks.append(
|
379 |
+
RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(expansion * hidden_dim // 2), round(3 * depth_mult))
|
380 |
+
# CSPRepLayer(hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion)
|
381 |
+
)
|
382 |
+
|
383 |
+
# self._reset_parameters()
|
384 |
+
|
385 |
+
# def _reset_parameters(self):
|
386 |
+
# if self.eval_spatial_size:
|
387 |
+
# for idx in self.use_encoder_idx:
|
388 |
+
# stride = self.feat_strides[idx]
|
389 |
+
# pos_embed = self.build_2d_sincos_position_embedding(
|
390 |
+
# self.eval_spatial_size[1] // stride, self.eval_spatial_size[0] // stride,
|
391 |
+
# self.hidden_dim, self.pe_temperature)
|
392 |
+
# setattr(self, f'pos_embed{idx}', pos_embed)
|
393 |
+
# # self.register_buffer(f'pos_embed{idx}', pos_embed)
|
394 |
+
|
395 |
+
def forward(self,
|
396 |
+
src: Tensor,
|
397 |
+
pos: Tensor,
|
398 |
+
spatial_shapes: Tensor,
|
399 |
+
level_start_index: Tensor,
|
400 |
+
valid_ratios: Tensor,
|
401 |
+
key_padding_mask: Tensor,
|
402 |
+
ref_token_index: Optional[Tensor]=None,
|
403 |
+
ref_token_coord: Optional[Tensor]=None
|
404 |
+
):
|
405 |
+
"""
|
406 |
+
Input:
|
407 |
+
- src: [bs, sum(hi*wi), 256]
|
408 |
+
- pos: pos embed for src. [bs, sum(hi*wi), 256]
|
409 |
+
- spatial_shapes: h,w of each level [num_level, 2]
|
410 |
+
- level_start_index: [num_level] start point of level in sum(hi*wi).
|
411 |
+
- valid_ratios: [bs, num_level, 2]
|
412 |
+
- key_padding_mask: [bs, sum(hi*wi)]
|
413 |
+
|
414 |
+
- ref_token_index: bs, nq
|
415 |
+
- ref_token_coord: bs, nq, 4
|
416 |
+
Intermediate:
|
417 |
+
- reference_points: [bs, sum(hi*wi), num_level, 2]
|
418 |
+
Outputs:
|
419 |
+
- output: [bs, sum(hi*wi), 256]
|
420 |
+
"""
|
421 |
+
src_list = src.split([H_ * W_ for H_, W_ in spatial_shapes], dim=1)
|
422 |
+
pos_ = pos[:, level_start_index[-1]:]
|
423 |
+
key_padding_mask_ = key_padding_mask[:, level_start_index[-1]:]
|
424 |
+
|
425 |
+
memory = self.encoder(src_list[-1], pos_embed=pos_, src_key_padding_mask=key_padding_mask_)
|
426 |
+
|
427 |
+
c = src.size(2)
|
428 |
+
proj_feats = []
|
429 |
+
for i, (H, W) in enumerate(spatial_shapes):
|
430 |
+
if i == len(spatial_shapes) - 1:
|
431 |
+
proj_feats.append(memory.reshape(-1, H, W, c).permute(0, 3, 1, 2))
|
432 |
+
continue
|
433 |
+
proj_feats.append(src_list[i].reshape(-1, H, W, c).permute(0, 3, 1, 2))
|
434 |
+
|
435 |
+
# broadcasting and fusion
|
436 |
+
inner_outs = [proj_feats[-1]]
|
437 |
+
for idx in range(self.n_levels - 1, 0, -1):
|
438 |
+
feat_high = inner_outs[0]
|
439 |
+
feat_low = proj_feats[idx - 1]
|
440 |
+
feat_high = self.lateral_convs[self.n_levels - 1 - idx](feat_high)
|
441 |
+
inner_outs[0] = feat_high
|
442 |
+
upsample_feat = F.interpolate(feat_high, scale_factor=2., mode='nearest')
|
443 |
+
inner_out = self.fpn_blocks[self.n_levels-1-idx](torch.concat([upsample_feat, feat_low], dim=1))
|
444 |
+
inner_outs.insert(0, inner_out)
|
445 |
+
|
446 |
+
outs = [inner_outs[0]]
|
447 |
+
for idx in range(self.n_levels - 1):
|
448 |
+
feat_low = outs[-1]
|
449 |
+
feat_high = inner_outs[idx + 1]
|
450 |
+
downsample_feat = self.downsample_convs[idx](feat_low)
|
451 |
+
out = self.pan_blocks[idx](torch.concat([downsample_feat, feat_high], dim=1))
|
452 |
+
outs.append(out)
|
453 |
+
|
454 |
+
for i in range(len(outs)):
|
455 |
+
outs[i] = outs[i].flatten(2).permute(0, 2, 1)
|
456 |
+
|
457 |
+
return torch.cat(outs, dim=1), None, None
|
458 |
+
|
459 |
+
def build_hybrid_encoder(args):
|
460 |
+
return HybridEncoder(
|
461 |
+
n_levels=args.num_feature_levels,
|
462 |
+
hidden_dim=args.hidden_dim,
|
463 |
+
nhead=args.nheads,
|
464 |
+
dim_feedforward = args.dim_feedforward,
|
465 |
+
dropout=args.dropout,
|
466 |
+
enc_act='gelu',
|
467 |
+
# pe_temperature=10000,
|
468 |
+
expansion=args.expansion,
|
469 |
+
depth_mult=args.depth_mult,
|
470 |
+
act='silu',
|
471 |
+
)
|
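One detail worth calling out in the encoder file above is the re-parameterization pattern shared by ConvNormLayer_fuse and VGGBlock: during training a block runs parallel conv+BN branches, and convert_to_deploy folds them into a single convolution via get_equivalent_kernel_bias. The following is a minimal sketch of how one might check that the fusion is numerically equivalent, using the classes as defined in hybrid_encoder.py (the import path is an assumption; eval mode so BatchNorm uses its running statistics).

```python
import torch
from linea.models.linea.hybrid_encoder import VGGBlock   # assumed import path

block = VGGBlock(ch_in=64, ch_out=64, act='relu').eval()  # eval(): BN uses running stats

x = torch.randn(2, 64, 32, 32)
with torch.no_grad():
    y_branches = block(x)       # conv1 (3x3) + conv2 (1x1), then ReLU
    block.convert_to_deploy()   # folds both conv+BN branches into a single block.conv
    y_fused = block(x)

print(torch.allclose(y_branches, y_fused, atol=1e-5))      # expected: True
```

The asymmetric variant in the next file extends the same idea with additional 3x1 and 1x3 branches, which are padded to 3x3 kernels and folded in the same way.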
linea/models/linea/hybrid_encoder_asymmetric_conv.py
ADDED
@@ -0,0 +1,549 @@
1 |
+
"""
|
2 |
+
Modified from D-FINE (https://github.com/Peterande/D-FINE)
|
3 |
+
Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
|
4 |
+
---------------------------------------------------------------------------------
|
5 |
+
Modified from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
|
6 |
+
Copyright (c) 2023 lyuwenyu. All Rights Reserved.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import copy
|
10 |
+
import math
|
11 |
+
from typing import Optional
|
12 |
+
|
13 |
+
import torch
|
14 |
+
from torch import nn, Tensor
|
15 |
+
import torch.nn.functional as F
|
16 |
+
|
17 |
+
|
18 |
+
def get_activation(act: str, inpace: bool=True):
|
19 |
+
"""get activation
|
20 |
+
"""
|
21 |
+
if act is None:
|
22 |
+
return nn.Identity()
|
23 |
+
|
24 |
+
elif isinstance(act, nn.Module):
|
25 |
+
return act
|
26 |
+
|
27 |
+
act = act.lower()
|
28 |
+
|
29 |
+
if act == 'silu' or act == 'swish':
|
30 |
+
m = nn.SiLU()
|
31 |
+
|
32 |
+
elif act == 'relu':
|
33 |
+
m = nn.ReLU()
|
34 |
+
|
35 |
+
elif act == 'leaky_relu':
|
36 |
+
m = nn.LeakyReLU()
|
37 |
+
|
38 |
+
elif act == 'silu':
|
39 |
+
m = nn.SiLU()
|
40 |
+
|
41 |
+
elif act == 'gelu':
|
42 |
+
m = nn.GELU()
|
43 |
+
|
44 |
+
elif act == 'hardsigmoid':
|
45 |
+
m = nn.Hardsigmoid()
|
46 |
+
|
47 |
+
else:
|
48 |
+
raise RuntimeError(f'Unsupported activation: {act}')
|
49 |
+
|
50 |
+
if hasattr(m, 'inplace'):
|
51 |
+
m.inplace = inpace
|
52 |
+
|
53 |
+
return m
|
54 |
+
|
55 |
+
class CBLinear(nn.Module):
|
56 |
+
def __init__(self, c1, c2s, k=1, s=1, p=None, g=1): # ch_in, ch_outs, kernel, stride, padding, groups
|
57 |
+
super(CBLinear, self).__init__()
|
58 |
+
self.c2s = c2s
|
59 |
+
self.conv = nn.Conv2d(c1, sum(c2s), k, s, 0, groups=g, bias=True)
|
60 |
+
|
61 |
+
def forward(self, x):
|
62 |
+
outs = self.conv(x).split(self.c2s, dim=1)
|
63 |
+
return outs
|
64 |
+
|
65 |
+
class CBFuse(nn.Module):
|
66 |
+
def __init__(self, idx):
|
67 |
+
super(CBFuse, self).__init__()
|
68 |
+
self.idx = idx
|
69 |
+
|
70 |
+
def forward(self, xs):
|
71 |
+
target_size = xs[-1].shape[2:]
|
72 |
+
res = [F.interpolate(x[self.idx[i]], size=target_size, mode='nearest') for i, x in enumerate(xs[:-1])]
|
73 |
+
out = torch.sum(torch.stack(res + xs[-1:]), dim=0)
|
74 |
+
return out
|
75 |
+
|
76 |
+
class ConvNormLayer_fuse(nn.Module):
|
77 |
+
def __init__(self, ch_in, ch_out, kernel_size, stride, g=1, padding=None, bias=False, act=None):
|
78 |
+
super().__init__()
|
79 |
+
padding = (kernel_size-1)//2 if padding is None else padding
|
80 |
+
self.conv = nn.Conv2d(
|
81 |
+
ch_in,
|
82 |
+
ch_out,
|
83 |
+
kernel_size,
|
84 |
+
stride,
|
85 |
+
groups=g,
|
86 |
+
padding=padding,
|
87 |
+
bias=bias)
|
88 |
+
self.norm = nn.BatchNorm2d(ch_out)
|
89 |
+
self.act = nn.Identity() if act is None else get_activation(act)
|
90 |
+
self.ch_in, self.ch_out, self.kernel_size, self.stride, self.g, self.padding, self.bias = \
|
91 |
+
ch_in, ch_out, kernel_size, stride, g, padding, bias
|
92 |
+
|
93 |
+
def forward(self, x):
|
94 |
+
if hasattr(self, 'conv_bn_fused'):
|
95 |
+
y = self.conv_bn_fused(x)
|
96 |
+
else:
|
97 |
+
y = self.norm(self.conv(x))
|
98 |
+
return self.act(y)
|
99 |
+
|
100 |
+
def convert_to_deploy(self):
|
101 |
+
if not hasattr(self, 'conv_bn_fused'):
|
102 |
+
self.conv_bn_fused = nn.Conv2d(
|
103 |
+
self.ch_in,
|
104 |
+
self.ch_out,
|
105 |
+
self.kernel_size,
|
106 |
+
self.stride,
|
107 |
+
groups=self.g,
|
108 |
+
padding=self.padding,
|
109 |
+
bias=True)
|
110 |
+
|
111 |
+
kernel, bias = self.get_equivalent_kernel_bias()
|
112 |
+
self.conv_bn_fused.weight.data = kernel
|
113 |
+
self.conv_bn_fused.bias.data = bias
|
114 |
+
self.__delattr__('conv')
|
115 |
+
self.__delattr__('norm')
|
116 |
+
|
117 |
+
def get_equivalent_kernel_bias(self):
|
118 |
+
kernel3x3, bias3x3 = self._fuse_bn_tensor()
|
119 |
+
|
120 |
+
return kernel3x3, bias3x3
|
121 |
+
|
122 |
+
def _fuse_bn_tensor(self):
|
123 |
+
kernel = self.conv.weight
|
124 |
+
running_mean = self.norm.running_mean
|
125 |
+
running_var = self.norm.running_var
|
126 |
+
gamma = self.norm.weight
|
127 |
+
beta = self.norm.bias
|
128 |
+
eps = self.norm.eps
|
129 |
+
std = (running_var + eps).sqrt()
|
130 |
+
t = (gamma / std).reshape(-1, 1, 1, 1)
|
131 |
+
return kernel * t, beta - running_mean * gamma / std
|
132 |
+
|
133 |
+
|
134 |
+
class ConvNormLayer(nn.Module):
|
135 |
+
def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None):
|
136 |
+
super().__init__()
|
137 |
+
self.conv = nn.Conv2d(
|
138 |
+
ch_in,
|
139 |
+
ch_out,
|
140 |
+
kernel_size,
|
141 |
+
stride,
|
142 |
+
padding=(kernel_size-1)//2 if padding is None else padding,
|
143 |
+
bias=bias)
|
144 |
+
self.norm = nn.BatchNorm2d(ch_out)
|
145 |
+
self.act = nn.Identity() if act is None else get_activation(act)
|
146 |
+
|
147 |
+
def forward(self, x):
|
148 |
+
return self.act(self.norm(self.conv(x)))
|
149 |
+
|
150 |
+
class SCDown(nn.Module):
|
151 |
+
def __init__(self, c1, c2, k, s):
|
152 |
+
super().__init__()
|
153 |
+
self.cv1 = ConvNormLayer_fuse(c1, c2, 1, 1)
|
154 |
+
self.cv2 = ConvNormLayer_fuse(c2, c2, k, s, c2)
|
155 |
+
|
156 |
+
def forward(self, x):
|
157 |
+
return self.cv2(self.cv1(x))
|
158 |
+
|
159 |
+
class VGGBlock(nn.Module):
|
160 |
+
def __init__(self, ch_in, ch_out, act='relu'):
|
161 |
+
super().__init__()
|
162 |
+
self.ch_in = ch_in
|
163 |
+
self.ch_out = ch_out
|
164 |
+
self.convH = ConvNormLayer(ch_in, ch_out, (3, 1), 1, padding=(1, 0), act=None)
|
165 |
+
self.convW = ConvNormLayer(ch_in, ch_out, (1, 3), 1, padding=(0, 1), act=None)
|
166 |
+
self.conv1 = ConvNormLayer(ch_in, ch_out, 3, 1, padding=1, act=None)
|
167 |
+
self.conv2 = ConvNormLayer(ch_in, ch_out, 1, 1, padding=0, act=None)
|
168 |
+
self.act = nn.Identity() if act is None else get_activation(act)
|
169 |
+
|
170 |
+
def forward(self, x):
|
171 |
+
if hasattr(self, 'conv'):
|
172 |
+
y = self.conv(x)
|
173 |
+
else:
|
174 |
+
y_vertical = self.convH(x)
|
175 |
+
y_horizontal = self.convW(x)
|
176 |
+
y = self.conv1(x) + self.conv2(x) + y_horizontal + y_vertical
|
177 |
+
|
178 |
+
return self.act(y)
|
179 |
+
|
180 |
+
def convert_to_deploy(self):
|
181 |
+
if not hasattr(self, 'conv'):
|
182 |
+
self.conv = nn.Conv2d(self.ch_in, self.ch_out, 3, 1, padding=1)
|
183 |
+
|
184 |
+
kernel, bias = self.get_equivalent_kernel_bias()
|
185 |
+
self.conv.weight.data = kernel
|
186 |
+
self.conv.bias.data = bias
|
187 |
+
self.__delattr__('conv1')
|
188 |
+
self.__delattr__('conv2')
|
189 |
+
self.__delattr__('convH')
|
190 |
+
self.__delattr__('convW')
|
191 |
+
|
192 |
+
def get_equivalent_kernel_bias(self):
|
193 |
+
kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
|
194 |
+
kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
|
195 |
+
kernel3x1, bias3x1 = self._fuse_bn_tensor(self.convH)
|
196 |
+
kernel1x3, bias1x3 = self._fuse_bn_tensor(self.convW)
|
197 |
+
|
198 |
+
kernel = kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + self._pad_1x1_to_3x3_tensor(kernel3x1, 'Vertical') + self._pad_1x1_to_3x3_tensor(kernel1x3, 'Horizontal')
|
199 |
+
bias = bias3x3 + bias1x1 + bias3x1 + bias1x3
|
200 |
+
return kernel, bias
|
201 |
+
|
202 |
+
def _pad_1x1_to_3x3_tensor(self, kernel1x1, assymetric='None'):
|
203 |
+
if kernel1x1 is None:
|
204 |
+
return 0
|
205 |
+
else:
|
206 |
+
if assymetric == 'None':
|
207 |
+
return F.pad(kernel1x1, [1, 1, 1, 1])
|
208 |
+
elif assymetric == 'Vertical':
|
209 |
+
return F.pad(kernel1x1, [1, 1, 0, 0])
|
210 |
+
elif assymetric == 'Horizontal':
|
211 |
+
return F.pad(kernel1x1, [0, 0, 1, 1])
|
212 |
+
|
213 |
+
def _fuse_bn_tensor(self, branch: ConvNormLayer):
|
214 |
+
if branch is None:
|
215 |
+
return 0, 0
|
216 |
+
kernel = branch.conv.weight
|
217 |
+
running_mean = branch.norm.running_mean
|
218 |
+
running_var = branch.norm.running_var
|
219 |
+
gamma = branch.norm.weight
|
220 |
+
beta = branch.norm.bias
|
221 |
+
eps = branch.norm.eps
|
222 |
+
std = (running_var + eps).sqrt()
|
223 |
+
t = (gamma / std).reshape(-1, 1, 1, 1)
|
224 |
+
return kernel * t, beta - running_mean * gamma / std
|
225 |
+
|
226 |
+
|
227 |
+
class RepNCSPELAN4(nn.Module):
|
228 |
+
# csp-elan
|
229 |
+
def __init__(self, c1, c2, c3, c4, n=3,
|
230 |
+
bias=False,
|
231 |
+
act="silu"):
|
232 |
+
super().__init__()
|
233 |
+
self.c = c3//2
|
234 |
+
self.cv1 = ConvNormLayer_fuse(c1, c3, 1, 1, bias=bias, act=act)
|
235 |
+
self.cv2 = nn.Sequential(CSPLayer(c3//2, c4, n, 1, bias=bias, act=act, bottletype=VGGBlock), ConvNormLayer_fuse(c4, c4, 3, 1, bias=bias, act=act))
|
236 |
+
self.cv3 = nn.Sequential(CSPLayer(c4, c4, n, 1, bias=bias, act=act, bottletype=VGGBlock), ConvNormLayer_fuse(c4, c4, 3, 1, bias=bias, act=act))
|
237 |
+
self.cv4 = ConvNormLayer_fuse(c3+(2*c4), c2, 1, 1, bias=bias, act=act)
|
238 |
+
|
239 |
+
def forward_chunk(self, x):
|
240 |
+
y = list(self.cv1(x).chunk(2, 1))
|
241 |
+
y.extend((m(y[-1])) for m in [self.cv2, self.cv3])
|
242 |
+
return self.cv4(torch.cat(y, 1))
|
243 |
+
|
244 |
+
def forward(self, x):
|
245 |
+
y = list(self.cv1(x).split((self.c, self.c), 1))
|
246 |
+
y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
|
247 |
+
return self.cv4(torch.cat(y, 1))
|
248 |
+
|
249 |
+
|
250 |
+
class CSPLayer(nn.Module):
|
251 |
+
def __init__(self,
|
252 |
+
in_channels,
|
253 |
+
out_channels,
|
254 |
+
num_blocks=3,
|
255 |
+
expansion=1.0,
|
256 |
+
bias=None,
|
257 |
+
act="silu",
|
258 |
+
bottletype=VGGBlock):
|
259 |
+
super(CSPLayer, self).__init__()
|
260 |
+
hidden_channels = int(out_channels * expansion)
|
261 |
+
self.conv1 = ConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act)
|
262 |
+
self.conv2 = ConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act)
|
263 |
+
self.bottlenecks = nn.Sequential(*[
|
264 |
+
bottletype(hidden_channels, hidden_channels, act=act) for _ in range(num_blocks)
|
265 |
+
])
|
266 |
+
if hidden_channels != out_channels:
|
267 |
+
self.conv3 = ConvNormLayer(hidden_channels, out_channels, 1, 1, bias=bias, act=act)
|
268 |
+
else:
|
269 |
+
self.conv3 = nn.Identity()
|
270 |
+
|
271 |
+
def forward(self, x):
|
272 |
+
x_1 = self.conv1(x)
|
273 |
+
x_1 = self.bottlenecks(x_1)
|
274 |
+
x_2 = self.conv2(x)
|
275 |
+
return self.conv3(x_1 + x_2)
|
276 |
+
|
277 |
+
|
278 |
+
# transformer
|
279 |
+
class TransformerEncoderLayer(nn.Module):
|
280 |
+
def __init__(self,
|
281 |
+
d_model,
|
282 |
+
nhead,
|
283 |
+
dim_feedforward=2048,
|
284 |
+
dropout=0.1,
|
285 |
+
activation="relu",
|
286 |
+
normalize_before=False):
|
287 |
+
super().__init__()
|
288 |
+
self.normalize_before = normalize_before
|
289 |
+
|
290 |
+
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout, batch_first=True)
|
291 |
+
|
292 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
293 |
+
self.dropout = nn.Dropout(dropout)
|
294 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
295 |
+
|
296 |
+
self.norm1 = nn.LayerNorm(d_model)
|
297 |
+
self.norm2 = nn.LayerNorm(d_model)
|
298 |
+
self.dropout1 = nn.Dropout(dropout)
|
299 |
+
self.dropout2 = nn.Dropout(dropout)
|
300 |
+
self.activation = get_activation(activation)
|
301 |
+
|
302 |
+
@staticmethod
|
303 |
+
def with_pos_embed(tensor, pos_embed):
|
304 |
+
return tensor if pos_embed is None else tensor + pos_embed
|
305 |
+
|
306 |
+
def forward(self,
|
307 |
+
src,
|
308 |
+
src_mask=None,
|
309 |
+
src_key_padding_mask=None,
|
310 |
+
pos_embed=None) -> torch.Tensor:
|
311 |
+
residual = src
|
312 |
+
if self.normalize_before:
|
313 |
+
src = self.norm1(src)
|
314 |
+
q = k = self.with_pos_embed(src, pos_embed)
|
315 |
+
src, _ = self.self_attn(q, k,
|
316 |
+
value=src,
|
317 |
+
attn_mask=src_mask,
|
318 |
+
key_padding_mask=src_key_padding_mask)
|
319 |
+
|
320 |
+
src = residual + self.dropout1(src)
|
321 |
+
if not self.normalize_before:
|
322 |
+
src = self.norm1(src)
|
323 |
+
|
324 |
+
residual = src
|
325 |
+
if self.normalize_before:
|
326 |
+
src = self.norm2(src)
|
327 |
+
src = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
328 |
+
src = residual + self.dropout2(src)
|
329 |
+
if not self.normalize_before:
|
330 |
+
src = self.norm2(src)
|
331 |
+
return src
|
332 |
+
|
333 |
+
|
334 |
+
class TransformerEncoder(nn.Module):
|
335 |
+
def __init__(self, encoder_layer, num_layers, norm=None):
|
336 |
+
super(TransformerEncoder, self).__init__()
|
337 |
+
self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])
|
338 |
+
self.num_layers = num_layers
|
339 |
+
        self.norm = norm

    def forward(self,
                src,
                src_mask=None,
                src_key_padding_mask=None,
                pos_embed=None) -> torch.Tensor:
        output = src
        for layer in self.layers:
            output = layer(output,
                           src_mask=src_mask,
                           src_key_padding_mask=src_key_padding_mask,
                           pos_embed=pos_embed)

        if self.norm is not None:
            output = self.norm(output)

        return output


class HybridEncoderAsymConv(nn.Module):
    def __init__(
        self,
        in_channels=[512, 1024, 2048],
        feat_strides=[8, 16, 32],
        n_levels=3,
        hidden_dim=256,
        nhead=8,
        dim_feedforward=1024,
        dropout=0.0,
        enc_act='gelu',
        use_encoder_idx=[2],
        num_encoder_layers=1,
        expansion=1.0,
        depth_mult=1.0,
        act='silu',
        eval_spatial_size=None,
        # position embedding
        temperatureH=20,
        temperatureW=20,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.feat_strides = feat_strides
        self.n_levels = n_levels
        self.hidden_dim = hidden_dim
        self.use_encoder_idx = use_encoder_idx
        self.num_encoder_layers = num_encoder_layers
        self.eval_spatial_size = eval_spatial_size

        self.temperatureW = temperatureW
        self.temperatureH = temperatureH

        # channel projection
        input_proj_list = []
        for in_channel in in_channels:
            input_proj_list.append(
                nn.Sequential(
                    nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False),
                    nn.GroupNorm(32, hidden_dim)
                )
            )
        self.input_proj = nn.ModuleList(input_proj_list)

        # encoder transformer
        encoder_layer = TransformerEncoderLayer(
            hidden_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=enc_act)

        self.encoder = nn.ModuleList([
            TransformerEncoder(copy.deepcopy(encoder_layer), num_encoder_layers) for _ in range(len(use_encoder_idx))
        ])

        # top-down fpn
        self.lateral_convs = nn.ModuleList()
        self.fpn_blocks = nn.ModuleList()
        for _ in range(n_levels - 1, 0, -1):
            self.lateral_convs.append(ConvNormLayer(hidden_dim, hidden_dim, 1, 1, act=act))
            self.fpn_blocks.append(
                RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(expansion * hidden_dim // 2), round(3 * depth_mult))
                # CSPRepLayer(hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion)
            )

        # bottom-up pan
        self.downsample_convs = nn.ModuleList()
        self.pan_blocks = nn.ModuleList()
        for _ in range(n_levels - 1):
            self.downsample_convs.append(nn.Sequential(
                SCDown(hidden_dim, hidden_dim, 3, 2),
                )
            )
            self.pan_blocks.append(
                RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(expansion * hidden_dim // 2), round(3 * depth_mult))
                # CSPRepLayer(hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion)
            )

        self._reset_parameters()

    def _reset_parameters(self):
        # init input_proj
        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            # nn.init.constant_(proj[0].bias, 0)

        if self.eval_spatial_size is not None:
            for idx in self.use_encoder_idx:
                stride = self.feat_strides[idx]
                pos_embed = self.create_sinehw_position_embedding(
                    self.eval_spatial_size[1] // stride,
                    self.eval_spatial_size[0] // stride,
                    self.hidden_dim // 2
                )
                setattr(self, f'pos_embed{idx}', pos_embed)

    def create_sinehw_position_embedding(self, w, h, hidden_dim, scale=None, device='cpu'):
        """Build a sine/cosine position embedding of shape [1, h*w, 2*hidden_dim]."""
        grid_w = torch.arange(1, int(w)+1, dtype=torch.float32, device=device)
        grid_h = torch.arange(1, int(h)+1, dtype=torch.float32, device=device)

        grid_h, grid_w = torch.meshgrid(grid_h, grid_w, indexing='ij')

        if scale is None:
            scale = 2 * math.pi

        eps = 1e-6
        grid_w = grid_w / (int(w) + eps) * scale
        grid_h = grid_h / (int(h) + eps) * scale

        dim_tx = torch.arange(hidden_dim, dtype=torch.float32, device=device)
        dim_tx = self.temperatureW ** (2 * (dim_tx // 2) / hidden_dim)
        pos_x = grid_w[..., None] / dim_tx

        dim_ty = torch.arange(hidden_dim, dtype=torch.float32, device=device)
        dim_ty = self.temperatureH ** (2 * (dim_ty // 2) / hidden_dim)
        pos_y = grid_h[..., None] / dim_ty

        pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
        pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)

        pos = torch.cat((pos_y, pos_x), dim=2).permute(2, 0, 1)
        pos = pos[None].flatten(2).permute(0, 2, 1).contiguous()

        return pos

    def forward(self, feats):
        """
        Input:
            - feats: List of features from the backbone
        Outputs:
            - output: List of enhanced features
        """
        assert len(feats) == len(self.in_channels)
        proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]

        # encoder
        for i, enc_idx in enumerate(self.use_encoder_idx):
            N_, C_, H_, W_ = proj_feats[enc_idx].shape
            # flatten [B, C, H, W] to [B, HxW, C]
            src_flatten = proj_feats[enc_idx].flatten(2).permute(0, 2, 1)
            if self.training or self.eval_spatial_size is None:
                pos_embed = self.create_sinehw_position_embedding(
                    H_, W_, self.hidden_dim//2, device=src_flatten.device)
            else:
                pos_embed = getattr(self, f'pos_embed{enc_idx}', None).to(src_flatten.device)

            proj_feats[enc_idx] = self.encoder[i](
                src_flatten,
                pos_embed=pos_embed
            ).permute(0, 2, 1).reshape(N_, C_, H_, W_).contiguous()

        # broadcasting and fusion
        inner_outs = [proj_feats[-1]]
        for idx in range(self.n_levels - 1, 0, -1):
            feat_high = inner_outs[0]
            feat_low = proj_feats[idx - 1]
            feat_high = self.lateral_convs[self.n_levels - 1 - idx](feat_high)
            inner_outs[0] = feat_high
            upsample_feat = F.interpolate(feat_high, scale_factor=2., mode='nearest')
            inner_out = self.fpn_blocks[self.n_levels-1-idx](torch.concat([upsample_feat, feat_low], dim=1))
            inner_outs.insert(0, inner_out)

        outs = [inner_outs[0]]
        for idx in range(self.n_levels - 1):
            feat_low = outs[-1]
            feat_high = inner_outs[idx + 1]
            downsample_feat = self.downsample_convs[idx](feat_low)
            out = self.pan_blocks[idx](torch.concat([downsample_feat, feat_high], dim=1))
            outs.append(out)
        return outs


def build_hybrid_encoder_with_asymmetric_conv(args):
    return HybridEncoderAsymConv(
        in_channels=args.in_channels_encoder,
        feat_strides=args.feat_strides,
        n_levels=args.num_feature_levels,
        hidden_dim=args.hidden_dim,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        dropout=args.dropout,
        enc_act='gelu',
        expansion=args.expansion,
        depth_mult=args.depth_mult,
        act='silu',
        temperatureH=args.pe_temperatureH,
        temperatureW=args.pe_temperatureW,
        eval_spatial_size=args.eval_spatial_size,
    )
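
A minimal shape-check sketch for the encoder above. It assumes hybrid_encoder_asymmetric_conv.py and its helper blocks (ConvNormLayer, RepNCSPELAN4, SCDown, TransformerEncoderLayer) import cleanly from this upload; the 640x640 input and strides 8/16/32 follow the defaults, and the printed shapes are the expected result, not verified output.

    # sketch: feed dummy backbone features at strides 8/16/32 for a 640x640 image
    import torch
    from linea.models.linea.hybrid_encoder_asymmetric_conv import HybridEncoderAsymConv

    enc = HybridEncoderAsymConv(in_channels=[512, 1024, 2048], eval_spatial_size=(640, 640)).eval()
    feats = [torch.randn(1, 512, 80, 80),    # stride 8
             torch.randn(1, 1024, 40, 40),   # stride 16
             torch.randn(1, 2048, 20, 20)]   # stride 32
    with torch.no_grad():
        outs = enc(feats)
    print([o.shape for o in outs])  # expected: three [1, 256, H, W] maps at the same three resolutions
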
linea/models/linea/linea.py
ADDED
@@ -0,0 +1,156 @@
# ------------------------------------------------------------------------
# DINO
# Copyright (c) 2022 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR model and criterion classes.
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------
import copy
import math
from typing import List
import torch
from torch import nn
from torchvision.transforms.functional import resize

import numpy as np

from .utils import sigmoid_focal_loss, MLP

from ..registry import MODULE_BUILD_FUNCS

from .hgnetv2 import build_hgnetv2
from .hybrid_encoder_asymmetric_conv import build_hybrid_encoder_with_asymmetric_conv
from .decoder import build_decoder

from .linea_utils import *

class LINEA(nn.Module):
    """ This is the Cross-Attention Detector module that performs object detection """
    def __init__(self,
                 backbone,
                 encoder,
                 decoder,
                 # multiscale = None,
                 use_lmap=False
                 ):
        """ Initializes the model.
        Parameters:
            backbone: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            num_classes: number of object classes
            num_queries: number of object queries, i.e. detection slots. This is the maximal number of objects
                         Conditional DETR can detect in a single image. For COCO, we recommend 100 queries.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
        """
        super().__init__()
        self.backbone = backbone
        self.encoder = encoder
        self.decoder = decoder

        # for auxiliary branch
        if use_lmap:
            self.aux_branch = nn.ModuleList()
            hidden_dim = encoder.hidden_dim
            for i in range(3):
                n = 2 ** i
                self.aux_branch.append(nn.Conv2d(hidden_dim, 1, 1))

    def forward(self, samples, targets:List=None):
        """ The forward expects a NestedTensor, which consists of:
               - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
               - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels

            It returns a dict with the following elements:
               - "pred_logits": the classification logits (including no-object) for all queries.
                                Shape= [batch_size x num_queries x num_classes]
               - "pred_boxes": The normalized boxes coordinates for all queries, represented as
                               (center_x, center_y, width, height). These values are normalized in [0, 1],
                               relative to the size of each individual image (disregarding possible padding).
                               See PostProcess for information on how to retrieve the unnormalized bounding box.
               - "aux_outputs": Optional, only returned when auxiliary losses are activated. It is a list of
                                dictionaries containing the two above keys for each decoder layer.
        """
        features = self.backbone(samples)

        features = self.encoder(features)

        out = self.decoder(features, targets)

        if self.training and hasattr(self, 'aux_branch'):
            lmaps = []
            for feat, convs in zip(features, self.aux_branch):
                lmap = convs(feat)
                lmaps.append(lmap)
            # lmaps = torch.cat(lmaps, dim=1)
            out['aux_lmap'] = lmaps

        return out

    def deploy(self, ):
        self.eval()
        for m in self.modules():
            if hasattr(m, 'convert_to_deploy'):
                m.convert_to_deploy()
        return self


class PostProcess(nn.Module):
    """ This module converts the model's output into the format expected by the coco api"""
    def __init__(self) -> None:
        super().__init__()
        self.deploy_mode = False

    @torch.no_grad()
    def forward(self, outputs, target_sizes):
        """ Perform the computation
        Parameters:
            outputs: raw outputs of the model
            target_sizes: tensor of dimension [batch_size x 2] containing the size of each image of the batch
                          For evaluation, this must be the original image size (before any data augmentation)
                          For visualization, this should be the image size after data augment, but before padding
        """
        out_logits, out_line = outputs['pred_logits'], outputs['pred_lines']

        scores = out_logits[..., 0].sigmoid()

        # convert to [x0, y0, x1, y1] format
        lines = out_line * target_sizes.repeat(1, 2).unsqueeze(1)

        if self.deploy_mode:
            return lines, scores

        results = [{'lines': l, 'scores': s} for s, l in zip(scores, lines)]

        return results

    def deploy(self, ):
        self.eval()
        self.deploy_mode = True
        return self

@MODULE_BUILD_FUNCS.registe_with_name(module_name='LINEA')
def build_linea(args):
    num_classes = args.num_classes

    backbone = build_hgnetv2(args)
    encoder = build_hybrid_encoder_with_asymmetric_conv(args)
    decoder = build_decoder(args)

    model = LINEA(
        backbone,
        encoder,
        decoder,
        use_lmap=args.use_lmap
    )

    postprocessors = PostProcess()

    return model, postprocessors
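
A hedged end-to-end sketch of build_linea and PostProcess. The SLConfig call, the config path, and the (width, height) convention for target_sizes are assumptions drawn from the rest of this upload, not verified here; the real entry point may prepare args differently.

    # sketch: build the model from one of the uploaded configs and post-process one image
    import torch
    from linea.util.slconfig import SLConfig
    from linea.models.linea.linea import build_linea

    args = SLConfig.fromfile('linea/configs/linea/linea_hgnetv2_l.py')  # assumed config path
    model, postprocessor = build_linea(args)
    model.eval()
    with torch.no_grad():
        outputs = model(torch.randn(1, 3, 640, 640))
    orig_sizes = torch.tensor([[640.0, 640.0]])          # per-image size; check the dataset's (w, h) order
    results = postprocessor(outputs, orig_sizes)
    print(results[0]['lines'].shape, results[0]['scores'].shape)
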
linea/models/linea/linea_utils.py
ADDED
@@ -0,0 +1,165 @@
import torch


def weighting_function(reg_max, up, reg_scale, deploy=False):
    """
    Generates the non-uniform Weighting Function W(n) for bounding box regression.

    Args:
        reg_max (int): Max number of the discrete bins.
        up (Tensor): Controls upper bounds of the sequence,
                     where maximum offset is ±up * H / W.
        reg_scale (float): Controls the curvature of the Weighting Function.
                           Larger values result in flatter weights near the central axis W(reg_max/2)=0
                           and steeper weights at both ends.
        deploy (bool): If True, uses deployment mode settings.

    Returns:
        Tensor: Sequence of Weighting Function.
    """
    if deploy:
        upper_bound1 = (abs(up[0]) * abs(reg_scale)).item()
        upper_bound2 = (abs(up[0]) * abs(reg_scale) * 2).item()
        step = (upper_bound1 + 1) ** (2 / (reg_max - 2))
        left_values = [-(step) ** i + 1 for i in range(reg_max // 2 - 1, 0, -1)]
        right_values = [(step) ** i - 1 for i in range(1, reg_max // 2)]
        values = [-upper_bound2] + left_values + [torch.zeros_like(up[0][None])] + right_values + [upper_bound2]
        return torch.tensor(values, dtype=up.dtype, device=up.device)
    else:
        upper_bound1 = abs(up[0]) * abs(reg_scale)
        upper_bound2 = abs(up[0]) * abs(reg_scale) * 2
        step = (upper_bound1 + 1) ** (2 / (reg_max - 2))
        left_values = [-(step) ** i + 1 for i in range(reg_max // 2 - 1, 0, -1)]
        right_values = [(step) ** i - 1 for i in range(1, reg_max // 2)]
        values = [-upper_bound2] + left_values + [torch.zeros_like(up[0][None])] + right_values + [upper_bound2]
        return torch.cat(values, 0)


def translate_gt(gt, reg_max, reg_scale, up):
    """
    Decodes bounding box ground truth (GT) values into distribution-based GT representations.

    This function maps continuous GT values into discrete distribution bins, which can be used
    for regression tasks in object detection models. It calculates the indices of the closest
    bins to each GT value and assigns interpolation weights to these bins based on their proximity
    to the GT value.

    Args:
        gt (Tensor): Ground truth bounding box values, shape (N, ).
        reg_max (int): Maximum number of discrete bins for the distribution.
        reg_scale (float): Controls the curvature of the Weighting Function.
        up (Tensor): Controls the upper bounds of the Weighting Function.

    Returns:
        Tuple[Tensor, Tensor, Tensor]:
            - indices (Tensor): Index of the left bin closest to each GT value, shape (N, ).
            - weight_right (Tensor): Weight assigned to the right bin, shape (N, ).
            - weight_left (Tensor): Weight assigned to the left bin, shape (N, ).
    """
    gt = gt.reshape(-1)
    function_values = weighting_function(reg_max, up, reg_scale)

    # Find the closest left-side indices for each value
    diffs = function_values.unsqueeze(0) - gt.unsqueeze(1)
    mask = diffs <= 0
    closest_left_indices = torch.sum(mask, dim=1) - 1

    # Calculate the weights for the interpolation
    indices = closest_left_indices.float()

    weight_right = torch.zeros_like(indices)
    weight_left = torch.zeros_like(indices)

    valid_idx_mask = (indices >= 0) & (indices < reg_max)
    valid_indices = indices[valid_idx_mask].long()

    # Obtain distances
    left_values = function_values[valid_indices]
    right_values = function_values[valid_indices + 1]

    left_diffs = torch.abs(gt[valid_idx_mask] - left_values)
    right_diffs = torch.abs(right_values - gt[valid_idx_mask])

    # Valid weights
    weight_right[valid_idx_mask] = left_diffs / (left_diffs + right_diffs)
    weight_left[valid_idx_mask] = 1.0 - weight_right[valid_idx_mask]

    # Invalid weights (out of range)
    invalid_idx_mask_neg = (indices < 0)
    weight_right[invalid_idx_mask_neg] = 0.0
    weight_left[invalid_idx_mask_neg] = 1.0
    indices[invalid_idx_mask_neg] = 0.0

    invalid_idx_mask_pos = (indices >= reg_max)
    weight_right[invalid_idx_mask_pos] = 1.0
    weight_left[invalid_idx_mask_pos] = 0.0
    indices[invalid_idx_mask_pos] = reg_max - 0.1

    return indices, weight_right, weight_left


def bbox2distance(points, bbox, reg_max, reg_scale, up, eps=0.1):
    """
    Converts bounding box coordinates to distances from a reference point.

    Args:
        points (Tensor): (n, 4) [x, y, w, h], where (x, y) is the center.
        bbox (Tensor): (n, 4) bounding boxes in "xyxy" format.
        reg_max (float): Maximum bin value.
        reg_scale (float): Controls the curvature of W(n).
        up (Tensor): Controls the upper bounds of W(n).
        eps (float): Small value to ensure target < reg_max.

    Returns:
        Tensor: Decoded distances.
    """
    reg_scale = abs(reg_scale)

    Dx = torch.abs(points[..., 0] - points[..., 2])
    Dy = torch.abs(points[..., 1] - points[..., 3])

    left = (points[:, 0] - bbox[:, 0]) / (Dx / reg_scale + 1e-16) - 0.5 * reg_scale
    top = (points[:, 1] - bbox[:, 1]) / (Dy / reg_scale + 1e-16) - 0.5 * reg_scale
    right = (points[:, 2] - bbox[:, 2]) / (Dx / reg_scale + 1e-16) - 0.5 * reg_scale
    bottom = (points[:, 3] - bbox[:, 3]) / (Dy / reg_scale + 1e-16) - 0.5 * reg_scale
    four_lens = torch.stack([left, top, right, bottom], -1)
    four_lens, weight_right, weight_left = translate_gt(four_lens, reg_max, reg_scale, up)
    if reg_max is not None:
        four_lens = four_lens.clamp(min=0, max=reg_max-eps)
    return four_lens.reshape(-1).detach(), weight_right.detach(), weight_left.detach()


def distance2bbox(points, distance, reg_scale):
    """
    Decodes edge-distances into bounding box coordinates.

    Args:
        points (Tensor): (B, N, 4) or (N, 4) format, representing [x, y, w, h],
                         where (x, y) is the center and (w, h) are width and height.
        distance (Tensor): (B, N, 4) or (N, 4), representing distances from the
                           point to the left, top, right, and bottom boundaries.

        reg_scale (float): Controls the curvature of the Weighting Function.

    Returns:
        Tensor: Bounding boxes in (N, 4) or (B, N, 4) format [cx, cy, w, h].
    """
    reg_scale = abs(reg_scale)

    Dx = torch.abs(points[..., 0] - points[..., 2])
    Dy = torch.abs(points[..., 1] - points[..., 3])

    x1 = points[..., 0] + (0.5 * reg_scale + distance[..., 0]) * (Dx / reg_scale)
    y1 = points[..., 1] + (0.5 * reg_scale + distance[..., 1]) * (Dy / reg_scale)
    x2 = points[..., 2] + (0.5 * reg_scale + distance[..., 2]) * (Dx / reg_scale)
    y2 = points[..., 3] + (0.5 * reg_scale + distance[..., 3]) * (Dy / reg_scale)

    bboxes = torch.stack([x1, y1, x2, y2], -1)

    return bboxes

def inverse_sigmoid(x, eps=1e-3):
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1/x2)
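
A quick numerical check of the inverse_sigmoid helper above, which the denoising code uses to move normalized line coordinates into logit space; the values are illustrative only.

    # sketch: sigmoid(inverse_sigmoid(x)) recovers x, up to the eps clamping at the boundaries
    import torch
    from linea.models.linea.linea_utils import inverse_sigmoid

    x = torch.tensor([0.0, 0.25, 0.5, 0.99])
    print(torch.sigmoid(inverse_sigmoid(x)))  # ~[1e-3, 0.25, 0.5, 0.99]
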
linea/models/linea/matcher.py
ADDED
@@ -0,0 +1,180 @@
# ------------------------------------------------------------------------
# DINO
# Copyright (c) 2022 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modules to compute the matching cost and solve the corresponding LSAP.
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------


import torch, os
from torch import nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment


class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network
    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self, cost_class: float = 1, cost_bbox: float = 1, focal_alpha = 0.25):
        """Creates the matcher
        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_line = cost_bbox
        assert cost_class != 0 or cost_bbox != 0, "all costs can't be 0"

        self.focal_alpha = focal_alpha

    @torch.no_grad()
    def forward(self, outputs, targets):
        """ Performs the matching
        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                           objects in the target) containing the class labels
                 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """

        bs, num_queries = outputs["pred_logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()  # [batch_size * num_queries, num_classes]
        out_line = outputs["pred_lines"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Also concat the target labels and lines
        tgt_ids = torch.cat([v["labels"] for v in targets])
        tgt_line = torch.cat([v["lines"] for v in targets])

        # Compute the classification cost.
        alpha = self.focal_alpha
        gamma = 2.0
        neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
        pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
        cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]

        # Compute the L1 cost between boxes
        cost_line = torch.cdist(out_line, tgt_line, p=1)

        # Final cost matrix
        C = self.cost_line * cost_line + self.cost_class * cost_class
        C = C.view(bs, num_queries, -1).cpu()

        sizes = [len(v["lines"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]


class SimpleMinsumMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network
    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self, cost_class: float = 1, cost_bbox: float = 1, focal_alpha = 0.25):
        """Creates the matcher
        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_line = cost_bbox
        assert cost_class != 0 or cost_bbox != 0, "all costs can't be 0"

        self.focal_alpha = focal_alpha

    @torch.no_grad()
    def forward(self, outputs, targets):
        """ Performs the matching
        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                           objects in the target) containing the class labels
                 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """

        bs, num_queries = outputs["pred_logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()  # [batch_size * num_queries, num_classes]
        out_line = outputs["pred_lines"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Also concat the target labels and boxes
        tgt_ids = torch.cat([v["labels"] for v in targets])
        tgt_line = torch.cat([v["lines"] for v in targets])

        # Compute the classification cost.
        alpha = self.focal_alpha
        gamma = 2.0
        neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
        pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
        cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]

        # Compute the L1 cost between boxes
        cost_line = torch.cdist(out_line, tgt_line, p=1)

        # Final cost matrix
        C = self.cost_line * cost_line + self.cost_class * cost_class
        C = C.view(bs, num_queries, -1)

        sizes = [len(v["lines"]) for v in targets]
        indices = []
        device = C.device
        for i, (c, _size) in enumerate(zip(C.split(sizes, -1), sizes)):
            weight_mat = c[i]
            idx_i = weight_mat.min(0)[1]
            idx_j = torch.arange(_size).to(device)
            indices.append((idx_i, idx_j))

        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]


def build_matcher(args):
    assert args.matcher_type in ['HungarianMatcher', 'SimpleMinsumMatcher'], "Unknown args.matcher_type: {}".format(args.matcher_type)
    if args.matcher_type == 'HungarianMatcher':
        return HungarianMatcher(
            cost_class=args.set_cost_class, cost_bbox=args.set_cost_lines, focal_alpha=args.focal_alpha
        )
    elif args.matcher_type == 'SimpleMinsumMatcher':
        return SimpleMinsumMatcher(
            cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, focal_alpha=args.focal_alpha
        )
    else:
        raise NotImplementedError("Unknown args.matcher_type: {}".format(args.matcher_type))
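
A self-contained sketch of the matcher's input/output contract, using random predictions and two dummy targets; the cost weights are illustrative, not the values used in the uploaded configs.

    # sketch: one (pred_idx, tgt_idx) pair per image; len(pred_idx) == number of GT lines in that image
    import torch
    from linea.models.linea.matcher import HungarianMatcher

    matcher = HungarianMatcher(cost_class=2.0, cost_bbox=5.0, focal_alpha=0.25)
    outputs = {'pred_logits': torch.randn(2, 10, 2), 'pred_lines': torch.rand(2, 10, 4)}
    targets = [
        {'labels': torch.zeros(3, dtype=torch.int64), 'lines': torch.rand(3, 4)},
        {'labels': torch.zeros(1, dtype=torch.int64), 'lines': torch.rand(1, 4)},
    ]
    indices = matcher(outputs, targets)
    print([(i.tolist(), j.tolist()) for i, j in indices])
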
linea/models/linea/new_dn_components.py
ADDED
@@ -0,0 +1,163 @@
# ------------------------------------------------------------------------
# DINO
# Copyright (c) 2022 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# DN-DETR
# Copyright (c) 2022 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]


import torch
from .linea_utils import inverse_sigmoid
import torch.nn.functional as F

def prepare_for_cdn(dn_args, training, num_queries, num_classes, hidden_dim, label_enc):
    """
    A major difference of DINO from DN-DETR is that the authors process the pattern embedding in the
    detector's forward function and use a learnable tgt embedding, so we change this function a little bit.
    :param dn_args: targets, dn_number, label_noise_ratio, box_noise_scale
    :param training: if it is training or inference
    :param num_queries: number of queries
    :param num_classes: number of classes
    :param hidden_dim: transformer hidden dim
    :param label_enc: encode labels in dn
    :return:
    """
    if training:
        targets, dn_number, label_noise_ratio, box_noise_scale = dn_args
        # positive and negative dn queries
        dn_number = dn_number * 2
        known = [(torch.ones_like(t['labels'])).cuda() for t in targets]
        batch_size = len(known)
        known_num = [sum(k) for k in known]

        if int(max(known_num)) == 0:
            dn_number = 1
        else:
            if dn_number >= 100:
                dn_number = dn_number // (int(max(known_num) * 2))
            elif dn_number < 1:
                dn_number = 1
        if dn_number == 0:
            dn_number = 1

        unmask_bbox = unmask_label = torch.cat(known)
        labels = torch.cat([t['labels'] for t in targets])
        lines = torch.cat([t['lines'] for t in targets])
        batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)])

        known_indice = torch.nonzero(unmask_label + unmask_bbox)
        known_indice = known_indice.view(-1)

        known_indice = known_indice.repeat(2 * dn_number, 1).view(-1)
        known_labels = labels.repeat(2 * dn_number, 1).view(-1)
        known_bid = batch_idx.repeat(2 * dn_number, 1).view(-1)
        known_lines = lines.repeat(2 * dn_number, 1)

        known_labels_expaned = known_labels.clone()
        known_lines_expand = known_lines.clone()

        if label_noise_ratio > 0:
            p = torch.rand_like(known_labels_expaned.float())
            chosen_indice = torch.nonzero(p < (label_noise_ratio * 0.5)).view(-1)  # half of bbox prob
            new_label = torch.randint_like(chosen_indice, 0, num_classes)  # randomly put a new one here
            known_labels_expaned.scatter_(0, chosen_indice, new_label)

        single_pad = int(max(known_num))

        pad_size = int(single_pad * 2 * dn_number)
        positive_idx = torch.tensor(range(len(lines))).long().cuda().unsqueeze(0).repeat(dn_number, 1)
        positive_idx += (torch.tensor(range(dn_number)) * len(lines) * 2).long().cuda().unsqueeze(1)
        positive_idx = positive_idx.flatten()
        negative_idx = positive_idx + len(lines)

        known_lines_ = known_lines.clone().unflatten(-1, (2, 2))
        offsets = F.normalize(2 * torch.rand_like(known_lines_) - 1, dim=-1)
        rand_part = torch.rand(size=(known_lines.shape[0], 2, 1), device=known_lines.device, dtype=offsets.dtype)

        rand_part[positive_idx] *= 0.005
        rand_part[negative_idx] *= 0.0645
        rand_part[negative_idx] += 0.0055

        known_lines_ = known_lines_ + offsets * rand_part
        known_lines_ = known_lines_.flatten(-2)

        known_lines_expand = known_lines_.clamp(min=0.0, max=1.0)

        # # order: top point > bottom point
        # # if same y coordinate, right point > left point

        # idx = torch.logical_or(known_lines_expand[..., 0] > known_lines_expand[..., 2],
        #                        torch.logical_or(
        #                            known_lines_expand[..., 0] == known_lines_expand[..., 2],
        #                            known_lines_expand[..., 1] < known_lines_expand[..., 3]
        #                        )
        #                        )

        # known_lines_expand[idx] = known_lines_expand[idx][:, [2, 3, 0, 1]]

        m = known_labels_expaned.long().to('cuda')
        input_label_embed = label_enc(m)
        input_lines_embed = inverse_sigmoid(known_lines_expand)

        padding_label = torch.zeros(pad_size, hidden_dim).cuda()
        padding_lines = torch.zeros(pad_size, 4).cuda()

        input_query_label = padding_label.repeat(batch_size, 1, 1)
        input_query_lines = padding_lines.repeat(batch_size, 1, 1)

        map_known_indice = torch.tensor([]).to('cuda')
        if len(known_num):
            map_known_indice = torch.cat([torch.tensor(range(num)) for num in known_num])  # [1,2, 1,2,3]
            map_known_indice = torch.cat([map_known_indice + single_pad * i for i in range(2 * dn_number)]).long()

        if len(known_bid):
            input_query_label[(known_bid.long(), map_known_indice)] = input_label_embed
            input_query_lines[(known_bid.long(), map_known_indice)] = input_lines_embed

        tgt_size = pad_size + num_queries
        attn_mask = torch.ones(tgt_size, tgt_size).to('cuda') < 0
        # match query cannot see the reconstruct
        attn_mask[pad_size:, :pad_size] = True
        # reconstruct cannot see each other
        for i in range(dn_number):
            if i == 0:
                attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), single_pad * 2 * (i + 1):pad_size] = True
            if i == dn_number - 1:
                attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), :single_pad * i * 2] = True
            else:
                attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), single_pad * 2 * (i + 1):pad_size] = True
                attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), :single_pad * 2 * i] = True

        dn_meta = {
            'pad_size': pad_size,
            'num_dn_group': dn_number,
        }
    else:

        input_query_label = None
        input_query_lines = None
        attn_mask = None
        dn_meta = None

    return input_query_label, input_query_lines, attn_mask, dn_meta


def dn_post_process(outputs_class, outputs_coord, dn_meta, aux_loss, _set_aux_loss):
    """
    post process of dn after output from the transformer
    put the dn part in the dn_meta
    """
    if dn_meta and dn_meta['pad_size'] > 0:
        output_known_class = outputs_class[:, :, :dn_meta['pad_size'], :]
        output_known_coord = outputs_coord[:, :, :dn_meta['pad_size'], :]
        outputs_class = outputs_class[:, :, dn_meta['pad_size']:, :]
        outputs_coord = outputs_coord[:, :, dn_meta['pad_size']:, :]
        out = {'pred_logits': output_known_class[-1], 'pred_lines': output_known_coord[-1]}
        if aux_loss:
            out['aux_outputs'] = _set_aux_loss(output_known_class, output_known_coord)
        dn_meta['output_known_lbs_lines'] = out
    return outputs_class, outputs_coord
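
A hedged sketch of the prepare_for_cdn contract: the helper calls .cuda() internally, so this only runs on a GPU machine, and label_enc here is a plain nn.Embedding standing in for whatever the decoder actually passes; dn_number, noise ratios, and target sizes are illustrative.

    # sketch: one image with 3 GT lines; the returned tensors pad noised (positive/negative) query groups
    import torch
    from torch import nn
    from linea.models.linea.new_dn_components import prepare_for_cdn

    num_queries, num_classes, hidden_dim = 100, 2, 256
    label_enc = nn.Embedding(num_classes, hidden_dim).cuda()
    targets = [{'labels': torch.zeros(3, dtype=torch.int64).cuda(), 'lines': torch.rand(3, 4).cuda()}]
    dn_args = (targets, 100, 0.5, 1.0)  # targets, dn_number, label_noise_ratio, box_noise_scale
    q_label, q_lines, attn_mask, dn_meta = prepare_for_cdn(
        dn_args, True, num_queries, num_classes, hidden_dim, label_enc)
    # q_label: [bs, pad_size, hidden_dim]; q_lines: [bs, pad_size, 4] in inverse-sigmoid space;
    # attn_mask keeps matching queries from attending to the noised groups
    print(q_label.shape, q_lines.shape, attn_mask.shape, dn_meta['pad_size'])
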
linea/models/linea/position_encoding.py
ADDED
@@ -0,0 +1,150 @@
# ------------------------------------------------------------------------
# DINO
# Copyright (c) 2022 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Copied from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------

"""
Various positional encodings for the transformer.
"""
import math
import torch
from torch import nn

from util.misc import NestedTensor


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos

class PositionEmbeddingSineHW(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperatureH = temperatureH
        self.temperatureW = temperatureW
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)

        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_tx = self.temperatureW ** (2 * (dim_tx // 2) / self.num_pos_feats)
        pos_x = x_embed[:, :, :, None] / dim_tx

        dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_ty = self.temperatureH ** (2 * (dim_ty // 2) / self.num_pos_feats)
        pos_y = y_embed[:, :, :, None] / dim_ty

        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos

class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.
    """
    def __init__(self, num_pos_feats=256):
        super().__init__()
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.row_embed.weight)
        nn.init.uniform_(self.col_embed.weight)

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        h, w = x.shape[-2:]
        i = torch.arange(w, device=x.device)
        j = torch.arange(h, device=x.device)
        x_emb = self.col_embed(i)
        y_emb = self.row_embed(j)
        pos = torch.cat([
            x_emb.unsqueeze(0).repeat(h, 1, 1),
            y_emb.unsqueeze(1).repeat(1, w, 1),
        ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
        return pos


def build_position_encoding(args):
    N_steps = args.hidden_dim // 2
    if args.position_embedding in ('v2', 'sine'):
        # TODO find a better way of exposing other arguments
        position_embedding = PositionEmbeddingSineHW(
            N_steps,
            temperatureH=args.pe_temperatureH,
            temperatureW=args.pe_temperatureW,
            normalize=True
        )
    elif args.position_embedding in ('v3', 'learned'):
        position_embedding = PositionEmbeddingLearned(N_steps)
    else:
        raise ValueError(f"not supported {args.position_embedding}")

    return position_embedding
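
A short sketch of the HW sine embedding's output shape, assuming util.misc.NestedTensor is importable the same way position_encoding.py imports it (a tensors + mask pair); the 20x20 feature map and zero mask are illustrative.

    # sketch: 128 y-features concatenated with 128 x-features -> 256 channels
    import torch
    from util.misc import NestedTensor
    from linea.models.linea.position_encoding import PositionEmbeddingSineHW

    pe = PositionEmbeddingSineHW(num_pos_feats=128, temperatureH=20, temperatureW=20, normalize=True)
    x = torch.randn(1, 256, 20, 20)
    mask = torch.zeros(1, 20, 20, dtype=torch.bool)  # no padded pixels
    pos = pe(NestedTensor(x, mask))
    print(pos.shape)  # expected [1, 256, 20, 20]
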
linea/models/linea/utils.py
ADDED
@@ -0,0 +1,139 @@
# ------------------------------------------------------------------------
# DINO
# Copyright (c) 2022 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------

import torch
from torch import nn, Tensor

import math
import torch.nn.functional as F
from torch import nn


def gen_encoder_output_proposals(memory:Tensor, spatial_shapes:Tensor):
    """
    Input:
        - memory: bs, \sum{hw}, d_model
        - memory_padding_mask: bs, \sum{hw}
        - spatial_shapes: nlevel, 2
        - learnedwh: 2
    Output:
        - output_memory: bs, \sum{hw}, d_model
        - output_proposals: bs, \sum{hw}, 4
    """
    N_, S_, C_ = memory.shape
    base_scale = 4.0
    proposals = []

    for lvl, (H_, W_) in enumerate(spatial_shapes):

        grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
                                        torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device))
        grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)  # H_, W_, 2

        scale = torch.tensor([W_, H_], dtype=torch.float32, device=memory.device).view(1, 1, 1, 2).repeat(N_, 1, 1, 1)
        grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale

        # wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)

        # proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
        proposal = torch.cat((grid, grid), -1).view(N_, -1, 4)
        proposals.append(proposal)

    output_proposals = torch.cat(proposals, 1)
    output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
    output_proposals = torch.log(output_proposals / (1 - output_proposals))  # unsigmoid
    output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))

    output_memory = memory
    output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))

    return output_memory, output_proposals


def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
               positive vs negative examples. Default = -1 (no weighting).
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
    Returns:
        Loss tensor
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss.mean(1).sum() / num_boxes


class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


def _get_activation_fn(activation, d_model=256, batch_dim=0):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    if activation == "prelu":
        return nn.PReLU()
    if activation == "selu":
        return F.selu

    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")


def gen_sineembed_for_position(pos_tensor, hidden_dim):
    # n_query, bs, _ = pos_tensor.size()
    # sineembed_tensor = torch.zeros(n_query, bs, 256)
    hidden_dim_ = hidden_dim // 2
    scale = 2 * math.pi
    dim_t = torch.arange(hidden_dim_, dtype=torch.float32, device=pos_tensor.device)
    dim_t = 10000 ** (2 * (dim_t // 2) / hidden_dim_)
    x_embed = pos_tensor[:, :, 0] * scale
    y_embed = pos_tensor[:, :, 1] * scale
    pos_x = x_embed[:, :, None] / dim_t
    pos_y = y_embed[:, :, None] / dim_t
    pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
    pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)

    w_embed = pos_tensor[:, :, 2] * scale
    pos_w = w_embed[:, :, None] / dim_t
    pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)

    h_embed = pos_tensor[:, :, 3] * scale
    pos_h = h_embed[:, :, None] / dim_t
    pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)

    pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    return pos
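
A tiny check of the MLP and sigmoid_focal_loss helpers above; shapes and the num_boxes normalizer are arbitrary.

    # sketch: MLP maps [..., 256] -> [..., 4]; focal loss reduces to a scalar
    import torch
    from linea.models.linea.utils import MLP, sigmoid_focal_loss

    mlp = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
    print(mlp(torch.randn(2, 100, 256)).shape)  # [2, 100, 4]

    logits = torch.randn(2, 100, 2)
    targets = torch.zeros_like(logits)
    print(sigmoid_focal_loss(logits, targets, num_boxes=5))  # scalar, averaged over the class dim
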
linea/models/registry.py
ADDED
@@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-
# @Author: Yihao Chen
# @Date:   2021-08-16 16:03:17
# @Last Modified by:   Shilong Liu
# @Last Modified time: 2022-01-23 15:26
# modified from mmcv

import inspect
from functools import partial


class Registry(object):

    def __init__(self, name):
        self._name = name
        self._module_dict = dict()

    def __repr__(self):
        format_str = self.__class__.__name__ + '(name={}, items={})'.format(
            self._name, list(self._module_dict.keys()))
        return format_str

    def __len__(self):
        return len(self._module_dict)

    @property
    def name(self):
        return self._name

    @property
    def module_dict(self):
        return self._module_dict

    def get(self, key):
        return self._module_dict.get(key, None)

    def registe_with_name(self, module_name=None, force=False):
        return partial(self.register, module_name=module_name, force=force)

    def register(self, module_build_function, module_name=None, force=False):
        """Register a module build function.
        Args:
            module (:obj:`nn.Module`): Module to be registered.
        """
        if not inspect.isfunction(module_build_function):
            raise TypeError('module_build_function must be a function, but got {}'.format(
                type(module_build_function)))
        if module_name is None:
            module_name = module_build_function.__name__
        if not force and module_name in self._module_dict:
            raise KeyError('{} is already registered in {}'.format(
                module_name, self.name))
        self._module_dict[module_name] = module_build_function

        return module_build_function

MODULE_BUILD_FUNCS = Registry('model build functions')
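
A minimal sketch of how this registry is used elsewhere in the upload: build functions decorate themselves onto MODULE_BUILD_FUNCS (note the spelling registe_with_name) and are later looked up by name. The toy build function below is hypothetical; the real ones, like build_linea, return (model, postprocessors).

    # sketch: register a build function, then retrieve it by name
    from linea.models.registry import MODULE_BUILD_FUNCS

    @MODULE_BUILD_FUNCS.registe_with_name(module_name='toy_model')
    def build_toy_model(args):
        return object()  # placeholder

    build_fn = MODULE_BUILD_FUNCS.get('toy_model')
    model = build_fn(args=None)
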
linea/requirements.txt
ADDED
@@ -0,0 +1,9 @@
torch>=2.0.1
torchvision>=0.15.2
scipy
calflops
transformers
tensorboardx
addict
yapf
pycocotools
linea/util/__init__.py
ADDED
@@ -0,0 +1 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
linea/util/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (151 Bytes)
linea/util/__pycache__/misc.cpython-311.pyc
ADDED
Binary file (15.2 kB)
linea/util/__pycache__/slconfig.cpython-311.pyc
ADDED
Binary file (24.6 kB)
linea/util/get_param_dicts.py
ADDED
@@ -0,0 +1,35 @@
import json
import torch
import torch.nn as nn

import re


def get_optim_params(cfg: list, model: nn.Module):
    """
    E.g.:
        ^(?=.*a)(?=.*b).*$   means including a and b
        ^(?=.*(?:a|b)).*$    means including a or b
        ^(?=.*a)(?!.*b).*$   means including a, but not b
    """

    param_groups = []
    visited = []
    for pg in cfg:
        pattern = pg['params']
        params = {k: v for k, v in model.named_parameters() if v.requires_grad and len(re.findall(pattern, k)) > 0}
        pg['params'] = params.values()
        param_groups.append(pg)
        visited.extend(list(params.keys()))

    names = [k for k, v in model.named_parameters() if v.requires_grad]

    if len(visited) < len(names):
        unseen = set(names) - set(visited)
        params = {k: v for k, v in model.named_parameters() if v.requires_grad and k in unseen}
        param_groups.append({'params': params.values()})
        visited.extend(list(params.keys()))

    assert len(visited) == len(names), ''

    return param_groups
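A minimal usage sketch of get_optim_params; the model, regex pattern, and learning rates below are made up for illustration:

# Illustrative sketch: give backbone parameters a lower learning rate and let
# everything else fall through to the catch-all group appended at the end.
import torch
import torch.nn as nn
from linea.util.get_param_dicts import get_optim_params

model = nn.ModuleDict({'backbone': nn.Linear(4, 4), 'decoder': nn.Linear(4, 2)})

cfg = [
    {'params': '^(?=.*backbone).*$', 'lr': 1e-5},  # matches 'backbone.weight', 'backbone.bias'
]
param_groups = get_optim_params(cfg, model)        # decoder params land in the default group
optimizer = torch.optim.AdamW(param_groups, lr=1e-4, weight_decay=1e-4)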
linea/util/misc.py
ADDED
@@ -0,0 +1,275 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Misc functions, including distributed helpers.

Mostly copy-paste from torchvision references.
"""
import os
import time
from collections import defaultdict, deque
import datetime
from typing import Optional, List

import torch
import torch.distributed as dist
from torch import Tensor


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        if d.shape[0] == 0:
            return 0
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that all processes
    have the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            if meter.count > 0:
                loss_str.append(
                    "{}: {}".format(name, str(meter))
                )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None, logger=None):
        if logger is None:
            print_func = print
        else:
            print_func = logger.info

        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        if torch.cuda.is_available():
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}',
                'max mem: {memory:.0f}'
            ])
        else:
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}'
            ])
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj

            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print_func(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print_func(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print_func('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    try:
        # https://pytorch.org/docs/stable/elastic/run.html
        RANK = int(os.getenv('RANK', -1))
        args.gpu = LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))
        WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))

        torch.distributed.init_process_group(init_method='env://')
        torch.distributed.barrier()

        rank = torch.distributed.get_rank()
        torch.cuda.set_device(rank)
        torch.cuda.empty_cache()
        args.distributed = True
        setup_for_distributed(get_rank() == 0)
        print('Initialized distributed mode...')
    except:
        print('Not using distributed mode')
        args.distributed = False
        args.world_size = 1
        args.rank = 0
        args.local_rank = 0
        return
linea/util/profiler.py
ADDED
@@ -0,0 +1,21 @@
import copy
from calflops import calculate_flops
from typing import Tuple


def stats(model, args,
          input_shape: Tuple = (1, 3, 640, 640)) -> dict:

    base_size = args.eval_spatial_size[0]
    input_shape = (1, 3, base_size, base_size)

    model_for_info = copy.deepcopy(model).deploy()

    flops, macs, _ = calculate_flops(model=model_for_info,
                                     input_shape=input_shape,
                                     output_as_string=True,
                                     output_precision=4,
                                     print_detailed=False)
    params = sum(p.numel() for p in model_for_info.parameters())
    del model_for_info
    return {'flops': flops, 'macs': macs, 'params': params}
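Usage sketch, assuming a model that exposes .deploy() and an args/config object carrying eval_spatial_size, as the function requires:

# Illustrative call; the printed formatting is arbitrary.
info = stats(model, args)
print("FLOPs: {}, MACs: {}, params: {:.1f}M".format(
    info['flops'], info['macs'], info['params'] / 1e6))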
linea/util/slconfig.py
ADDED
@@ -0,0 +1,440 @@
# ==========================================================
# Modified from mmcv
# ==========================================================
import os, sys
import os.path as osp
import ast
import tempfile
import shutil
from importlib import import_module

from argparse import Action

from addict import Dict
from yapf.yapflib.yapf_api import FormatCode

import platform
MACOS, LINUX, WINDOWS = (platform.system() == x for x in ['Darwin', 'Linux', 'Windows'])  # environment booleans

BASE_KEY = '_base_'
DELETE_KEY = '_delete_'
RESERVED_KEYS = ['filename', 'text', 'pretty_text', 'get', 'dump', 'merge_from_dict']


def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
    if not osp.isfile(filename):
        raise FileNotFoundError(msg_tmpl.format(filename))


class ConfigDict(Dict):

    def __missing__(self, name):
        raise KeyError(name)

    def __getattr__(self, name):
        try:
            value = super(ConfigDict, self).__getattr__(name)
        except KeyError:
            ex = AttributeError(f"'{self.__class__.__name__}' object has no "
                                f"attribute '{name}'")
        except Exception as e:
            ex = e
        else:
            return value
        raise ex


class SLConfig(object):
    """
    config files.
    only support .py file as config now.

    ref: mmcv.utils.config

    Example:
        >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
        >>> cfg.a
        1
        >>> cfg.b
        {'b1': [0, 1]}
        >>> cfg.b.b1
        [0, 1]
        >>> cfg = Config.fromfile('tests/data/config/a.py')
        >>> cfg.filename
        "/home/kchen/projects/mmcv/tests/data/config/a.py"
        >>> cfg.item4
        'test'
        >>> cfg
        "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
        "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
    """
    @staticmethod
    def _validate_py_syntax(filename):
        with open(filename) as f:
            content = f.read()
        try:
            ast.parse(content)
        except SyntaxError:
            raise SyntaxError('There are syntax errors in config '
                              f'file {filename}')

    @staticmethod
    def _file2dict(filename):
        filename = osp.abspath(osp.expanduser(filename))
        check_file_exist(filename)
        if filename.lower().endswith('.py'):
            with tempfile.TemporaryDirectory() as temp_config_dir:
                temp_config_file = tempfile.NamedTemporaryFile(
                    dir=temp_config_dir, suffix='.py')
                temp_config_name = osp.basename(temp_config_file.name)
                if WINDOWS:
                    temp_config_file.close()
                shutil.copyfile(filename,
                                osp.join(temp_config_dir, temp_config_name))
                temp_module_name = osp.splitext(temp_config_name)[0]
                sys.path.insert(0, temp_config_dir)
                SLConfig._validate_py_syntax(filename)
                mod = import_module(temp_module_name)
                sys.path.pop(0)
                cfg_dict = {
                    name: value
                    for name, value in mod.__dict__.items()
                    if not name.startswith('__')
                }
                # delete imported module
                del sys.modules[temp_module_name]
                # close temp file
                temp_config_file.close()
        elif filename.lower().endswith(('.yml', '.yaml', '.json')):
            from .slio import slload
            cfg_dict = slload(filename)
        else:
            raise IOError('Only py/yml/yaml/json type are supported now!')

        cfg_text = filename + '\n'
        with open(filename, 'r') as f:
            cfg_text += f.read()

        # parse the base file
        if BASE_KEY in cfg_dict:
            cfg_dir = osp.dirname(filename)
            base_filename = cfg_dict.pop(BASE_KEY)
            base_filename = base_filename if isinstance(
                base_filename, list) else [base_filename]

            cfg_dict_list = list()
            cfg_text_list = list()
            for f in base_filename:
                _cfg_dict, _cfg_text = SLConfig._file2dict(osp.join(cfg_dir, f))
                cfg_dict_list.append(_cfg_dict)
                cfg_text_list.append(_cfg_text)

            base_cfg_dict = dict()
            for c in cfg_dict_list:
                if len(base_cfg_dict.keys() & c.keys()) > 0:
                    raise KeyError('Duplicate key is not allowed among bases')
                    # TODO Allow the duplicate key while warning the user
                base_cfg_dict.update(c)

            base_cfg_dict = SLConfig._merge_a_into_b(cfg_dict, base_cfg_dict)
            cfg_dict = base_cfg_dict

            # merge cfg_text
            cfg_text_list.append(cfg_text)
            cfg_text = '\n'.join(cfg_text_list)

        return cfg_dict, cfg_text

    @staticmethod
    def _merge_a_into_b(a, b):
        """merge dict `a` into dict `b` (non-inplace).
        values in `a` will overwrite `b`.
        copy first to avoid inplace modification

        Args:
            a (dict): dict whose values take precedence.
            b (dict): base dict to merge into.

        Returns:
            dict: the merged dict.
        """

        if not isinstance(a, dict):
            return a

        b = b.copy()
        for k, v in a.items():
            if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):

                if not isinstance(b[k], dict) and not isinstance(b[k], list):
                    raise TypeError(
                        f'{k}={v} in child config cannot inherit from base '
                        f'because {k} is a dict in the child config but is of '
                        f'type {type(b[k])} in base config. You may set '
                        f'`{DELETE_KEY}=True` to ignore the base config')
                b[k] = SLConfig._merge_a_into_b(v, b[k])
            elif isinstance(b, list):
                try:
                    _ = int(k)
                except:
                    raise TypeError(
                        f'b is a list, '
                        f'index {k} should be an int when input but {type(k)}'
                    )
                b[int(k)] = SLConfig._merge_a_into_b(v, b[int(k)])
            else:
                b[k] = v

        return b

    @staticmethod
    def fromfile(filename):
        cfg_dict, cfg_text = SLConfig._file2dict(filename)
        return SLConfig(cfg_dict, cfg_text=cfg_text, filename=filename)

    def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
        if cfg_dict is None:
            cfg_dict = dict()
        elif not isinstance(cfg_dict, dict):
            raise TypeError('cfg_dict must be a dict, but '
                            f'got {type(cfg_dict)}')
        for key in cfg_dict:
            if key in RESERVED_KEYS:
                raise KeyError(f'{key} is reserved for config file')

        super(SLConfig, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
        super(SLConfig, self).__setattr__('_filename', filename)
        if cfg_text:
            text = cfg_text
        elif filename:
            with open(filename, 'r') as f:
                text = f.read()
        else:
            text = ''
        super(SLConfig, self).__setattr__('_text', text)

    @property
    def filename(self):
        return self._filename

    @property
    def text(self):
        return self._text

    @property
    def pretty_text(self):

        indent = 4

        def _indent(s_, num_spaces):
            s = s_.split('\n')
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(num_spaces * ' ') + line for line in s]
            s = '\n'.join(s)
            s = first + '\n' + s
            return s

        def _format_basic_types(k, v, use_mapping=False):
            if isinstance(v, str):
                v_str = f"'{v}'"
            else:
                v_str = str(v)

            if use_mapping:
                k_str = f"'{k}'" if isinstance(k, str) else str(k)
                attr_str = f'{k_str}: {v_str}'
            else:
                attr_str = f'{str(k)}={v_str}'
            attr_str = _indent(attr_str, indent)

            return attr_str

        def _format_list(k, v, use_mapping=False):
            # check if all items in the list are dict
            if all(isinstance(_, dict) for _ in v):
                v_str = '[\n'
                v_str += '\n'.join(
                    f'dict({_indent(_format_dict(v_), indent)}),'
                    for v_ in v).rstrip(',')
                if use_mapping:
                    k_str = f"'{k}'" if isinstance(k, str) else str(k)
                    attr_str = f'{k_str}: {v_str}'
                else:
                    attr_str = f'{str(k)}={v_str}'
                attr_str = _indent(attr_str, indent) + ']'
            else:
                attr_str = _format_basic_types(k, v, use_mapping)
            return attr_str

        def _contain_invalid_identifier(dict_str):
            contain_invalid_identifier = False
            for key_name in dict_str:
                contain_invalid_identifier |= \
                    (not str(key_name).isidentifier())
            return contain_invalid_identifier

        def _format_dict(input_dict, outest_level=False):
            r = ''
            s = []

            use_mapping = _contain_invalid_identifier(input_dict)
            if use_mapping:
                r += '{'
            for idx, (k, v) in enumerate(input_dict.items()):
                is_last = idx >= len(input_dict) - 1
                end = '' if outest_level or is_last else ','
                if isinstance(v, dict):
                    v_str = '\n' + _format_dict(v)
                    if use_mapping:
                        k_str = f"'{k}'" if isinstance(k, str) else str(k)
                        attr_str = f'{k_str}: dict({v_str}'
                    else:
                        attr_str = f'{str(k)}=dict({v_str}'
                    attr_str = _indent(attr_str, indent) + ')' + end
                elif isinstance(v, list):
                    attr_str = _format_list(k, v, use_mapping) + end
                else:
                    attr_str = _format_basic_types(k, v, use_mapping) + end

                s.append(attr_str)
            r += '\n'.join(s)
            if use_mapping:
                r += '}'
            return r

        cfg_dict = self._cfg_dict.to_dict()
        text = _format_dict(cfg_dict, outest_level=True)
        # copied from setup.cfg
        yapf_style = dict(
            based_on_style='pep8',
            blank_line_before_nested_class_or_def=True,
            split_before_expression_after_opening_paren=True)
        text, _ = FormatCode(text, style_config=yapf_style)  # , verify=True)

        return text

    def __repr__(self):
        return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'

    def __len__(self):
        return len(self._cfg_dict)

    def __getattr__(self, name):
        return getattr(self._cfg_dict, name)

    def __getitem__(self, name):
        return self._cfg_dict.__getitem__(name)

    def __setattr__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setattr__(name, value)

    def __setitem__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setitem__(name, value)

    def __iter__(self):
        return iter(self._cfg_dict)

    def dump(self, file=None):

        if file is None:
            return self.pretty_text
        else:
            with open(file, 'w') as f:
                f.write(self.pretty_text)

    def merge_from_dict(self, options):
        """Merge list into cfg_dict

        Merge the dict parsed by MultipleKVAction into this cfg.

        Examples:
            >>> options = {'model.backbone.depth': 50,
            ...            'model.backbone.with_cp': True}
            >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
            >>> cfg.merge_from_dict(options)
            >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
            >>> assert cfg_dict == dict(
            ...     model=dict(backbone=dict(depth=50, with_cp=True)))

        Args:
            options (dict): dict of configs to merge from.
        """
        option_cfg_dict = {}
        for full_key, v in options.items():
            d = option_cfg_dict
            key_list = full_key.split('.')
            for subkey in key_list[:-1]:
                d.setdefault(subkey, ConfigDict())
                d = d[subkey]
            subkey = key_list[-1]
            d[subkey] = v

        cfg_dict = super(SLConfig, self).__getattribute__('_cfg_dict')
        super(SLConfig, self).__setattr__(
            '_cfg_dict', SLConfig._merge_a_into_b(option_cfg_dict, cfg_dict))

    # for multiprocess
    def __setstate__(self, state):
        self.__init__(state)

    def copy(self):
        return SLConfig(self._cfg_dict.copy())

    def deepcopy(self):
        return SLConfig(self._cfg_dict.deepcopy())


class DictAction(Action):
    """
    argparse action to split an argument into KEY=VALUE form
    on the first = and append to a dictionary. List options should
    be passed as comma separated values, i.e KEY=V1,V2,V3
    """

    @staticmethod
    def _parse_int_float_bool(val):
        try:
            return int(val)
        except ValueError:
            pass
        try:
            return float(val)
        except ValueError:
            pass
        if val.lower() in ['true', 'false']:
            return True if val.lower() == 'true' else False
        if val.lower() in ['none', 'null']:
            return None
        return val

    def __call__(self, parser, namespace, values, option_string=None):
        options = {}
        for kv in values:
            key, val = kv.split('=', maxsplit=1)
            val = [self._parse_int_float_bool(v) for v in val.split(',')]
            if len(val) == 1:
                val = val[0]
            options[key] = val
        setattr(namespace, self.dest, options)
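A minimal sketch of loading a config with SLConfig and overriding entries from the command line via DictAction; the config path is one of the files added in this commit, and the override keys shown in the help text are illustrative:

# Illustrative sketch: parse a .py config, then apply key=value overrides.
import argparse
from linea.util.slconfig import SLConfig, DictAction

parser = argparse.ArgumentParser()
parser.add_argument('--config', default='linea/configs/linea/linea_hgnetv2_l.py')
parser.add_argument('--options', nargs='+', action=DictAction,
                    help='override config entries as key=value pairs; dotted keys create nested dicts')
args = parser.parse_args()

cfg = SLConfig.fromfile(args.config)   # resolves any _base_ includes recursively
if args.options is not None:
    cfg.merge_from_dict(args.options)
print(cfg.filename)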