SebasJanampa committed on
Commit 51b2bf9 · verified · Parent(s): 67b2aee

Upload 45 files

Files changed (45)
  1. linea/configs/linea/include/dataset.py +10 -0
  2. linea/configs/linea/include/linea.py +62 -0
  3. linea/configs/linea/include/optimizer.py +9 -0
  4. linea/configs/linea/linea_hgnetv2_l.py +56 -0
  5. linea/configs/linea/linea_hgnetv2_m.py +63 -0
  6. linea/configs/linea/linea_hgnetv2_n.py +63 -0
  7. linea/configs/linea/linea_hgnetv2_s.py +64 -0
  8. linea/models/__init__.py +8 -0
  9. linea/models/__pycache__/__init__.cpython-311.pyc +0 -0
  10. linea/models/__pycache__/registry.cpython-311.pyc +0 -0
  11. linea/models/linea/__init__.py +11 -0
  12. linea/models/linea/__pycache__/__init__.cpython-311.pyc +0 -0
  13. linea/models/linea/__pycache__/attention_mechanism.cpython-311.pyc +0 -0
  14. linea/models/linea/__pycache__/criterion.cpython-311.pyc +0 -0
  15. linea/models/linea/__pycache__/decoder.cpython-311.pyc +0 -0
  16. linea/models/linea/__pycache__/dn_components.cpython-311.pyc +0 -0
  17. linea/models/linea/__pycache__/hgnetv2.cpython-311.pyc +0 -0
  18. linea/models/linea/__pycache__/hybrid_encoder_asymmetric_conv.cpython-311.pyc +0 -0
  19. linea/models/linea/__pycache__/linea.cpython-311.pyc +0 -0
  20. linea/models/linea/__pycache__/linea_utils.cpython-311.pyc +0 -0
  21. linea/models/linea/__pycache__/matcher.cpython-311.pyc +0 -0
  22. linea/models/linea/__pycache__/utils.cpython-311.pyc +0 -0
  23. linea/models/linea/attention_mechanism.py +593 -0
  24. linea/models/linea/criterion.py +517 -0
  25. linea/models/linea/decoder.py +551 -0
  26. linea/models/linea/dn_components.py +178 -0
  27. linea/models/linea/hgnetv2.py +595 -0
  28. linea/models/linea/hybrid_encoder.py +471 -0
  29. linea/models/linea/hybrid_encoder_asymmetric_conv.py +549 -0
  30. linea/models/linea/linea.py +156 -0
  31. linea/models/linea/linea_utils.py +165 -0
  32. linea/models/linea/matcher.py +180 -0
  33. linea/models/linea/new_dn_components.py +163 -0
  34. linea/models/linea/position_encoding.py +150 -0
  35. linea/models/linea/utils.py +139 -0
  36. linea/models/registry.py +58 -0
  37. linea/requirements.txt +9 -0
  38. linea/util/__init__.py +1 -0
  39. linea/util/__pycache__/__init__.cpython-311.pyc +0 -0
  40. linea/util/__pycache__/misc.cpython-311.pyc +0 -0
  41. linea/util/__pycache__/slconfig.cpython-311.pyc +0 -0
  42. linea/util/get_param_dicts.py +35 -0
  43. linea/util/misc.py +275 -0
  44. linea/util/profiler.py +21 -0
  45. linea/util/slconfig.py +440 -0
linea/configs/linea/include/dataset.py ADDED
@@ -0,0 +1,10 @@
+ data_aug_scales = [(640, 640)]
+ data_aug_max_size = 1333
+ data_aug_scales2_resize = [400, 500, 600]
+ data_aug_scales2_crop = [384, 600]
+
+
+ data_aug_scale_overlap = None
+ batch_size_train = 8
+ batch_size_val = 64
+
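The scale values above follow the DETR-style augmentation convention (multi-scale resize with an optional resize-then-crop branch), and the last two lines set the train/val batch sizes. A minimal sketch of how the batch sizes are typically consumed is shown below; the dataset objects and collate function are assumptions, not part of this commit.

# Sketch only: the LINEA training script is not part of this excerpt.
from torch.utils.data import DataLoader

def build_loaders(train_set, val_set, collate_fn=None):
    # batch_size_train = 8 and batch_size_val = 64 come from the config above.
    train_loader = DataLoader(train_set, batch_size=8, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_set, batch_size=64, shuffle=False, collate_fn=collate_fn)
    return train_loader, val_loader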
linea/configs/linea/include/linea.py ADDED
@@ -0,0 +1,62 @@
+ # model
+ modelname = 'LINEA'
+ eval_spatial_size = (640, 640)
+ eval_idx = 5 # 6 decoder layers
+ num_classes = 2
+
+ ## backbone
+ pretrained = True
+ use_checkpoint = False
+ return_interm_indices = [1, 2, 3]
+ freeze_norm = True
+ freeze_stem_only = True
+
+ ## encoder
+ hybrid_encoder = 'hybrid_encoder_asymmetric_conv'
+ in_channels_encoder = [512, 1024, 2048]
+ pe_temperatureH = 20
+ pe_temperatureW = 20
+
+ ## encoder
+ transformer_activation = 'relu'
+ batch_norm_type = 'FrozenBatchNorm2d'
+ masks = False
+ aux_loss = True
+
+ ## decoder
+ num_queries = 1100
+ query_dim = 4
+ num_feature_levels = 3
+ dec_n_points = [4, 1, 1]
+ dropout = 0.0
+ pre_norm = False
+
+ # denoise
+ use_dn = True
+ dn_number = 300
+ dn_line_noise_scale = 1.0
+ dn_label_noise_ratio = 0.5
+ embed_init_tgt = True
+ dn_labelbook_size = 2
+ match_unstable_error = True
+
+ # matcher
+ set_cost_class = 2.0
+ set_cost_lines = 5.0
+
+ # criterion
+ criterionname = 'LINEACRITERION'
+ criterion_type = 'default'
+ weight_dict = {'loss_logits': 1, 'loss_line': 5}
+ losses = ['labels', 'lines']
+ focal_alpha = 0.1
+
+ matcher_type = 'HungarianMatcher' # or SimpleMinsumMatcher
+ nms_iou_threshold = -1
+
+ # for ema
+ use_ema = False
+ ema_decay = 0.9997
+ ema_epoch = 0
+
+
linea/configs/linea/include/optimizer.py ADDED
@@ -0,0 +1,9 @@
+ lr = 0.00025
+ weight_decay = 0.000125
+ betas = [0.9, 0.999]
+
+ epochs = 12
+ lr_drop_list = [11]
+ clip_max_norm = 0.1
+
+ save_checkpoint_interval = 1
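These hyperparameters map onto a standard AdamW optimizer with a multi-step learning-rate drop and gradient clipping. A minimal sketch, assuming a conventional PyTorch training setup (the actual trainer is not included in this excerpt):

# Sketch only: standard PyTorch wiring for the values above.
import torch

def build_optimizer(model, lr=0.00025, betas=(0.9, 0.999), weight_decay=0.000125,
                    lr_drop_list=(11,)):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=betas,
                                  weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=list(lr_drop_list))
    return optimizer, scheduler

# After each backward pass, gradients would typically be clipped with
# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)  # clip_max_norm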
linea/configs/linea/linea_hgnetv2_l.py ADDED
@@ -0,0 +1,56 @@
+ _base_ = [
+     './include/dataset.py',
+     './include/optimizer.py',
+     './include/linea.py'
+ ]
+
+ output_dir = 'output/line_hgnetv2_l'
+
+ # backbone
+ backbone = 'HGNetv2_B4'
+ param_dict_type = backbone.lower()
+ use_lab = False
+
+
+ # transformer
+ feat_strides = [8, 16, 32]
+ hidden_dim = 256
+ dim_feedforward = 1024
+ nheads = 8
+ use_lmap = False
+
+ ## encoder
+ hybrid_encoder = 'hybrid_encoder_asymmetric_conv'
+ in_channels_encoder = [512, 1024, 2048]
+ pe_temperatureH = 20
+ pe_temperatureW = 20
+ expansion = 0.5
+ depth_mult = 1.0
+
+ ## decoder
+ feat_channels_decoder = [256, 256, 256]
+ dec_layers = 6
+ num_queries = 1100
+ num_select = 300
+ reg_max = 16
+ reg_scale = 4
+
+ # criterion
+ weight_dict = {'loss_logits': 4, 'loss_line': 5}
+ use_warmup = False
+
+ # optimizer params
+ model_parameters = [
+     {
+         'params': '^(?=.*backbone)(?!.*norm|bn).*$',
+         'lr': 0.0000125
+     },
+
+     {
+         'params': '^(?=.*(?:encoder|decoder))(?=.*(?:norm|bn)).*$',
+         'weight_decay': 0.
+     }
+ ]
+ lr = 0.00025
+ betas = [0.9, 0.999]
+ weight_decay = 0.000125
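Each entry in `model_parameters` pairs a regular expression over parameter names with per-group overrides: backbone weights (excluding norm/bn) train at a lower learning rate, while encoder/decoder normalization weights get zero weight decay. A minimal sketch of how such patterns could be expanded into optimizer parameter groups is shown below; `get_param_groups` is an illustrative assumption, and the repository's own logic presumably lives in linea/util/get_param_dicts.py and may differ.

# Sketch only: one plausible expansion of the regex rules into param groups.
import re

def get_param_groups(model, model_parameters, base_lr, base_weight_decay):
    groups, used = [], set()
    named = dict(model.named_parameters())
    for rule in model_parameters:
        pattern = re.compile(rule['params'])
        names = [n for n, p in named.items()
                 if p.requires_grad and n not in used and pattern.match(n)]
        used.update(names)
        groups.append({'params': [named[n] for n in names],
                       'lr': rule.get('lr', base_lr),
                       'weight_decay': rule.get('weight_decay', base_weight_decay)})
    # Parameters not captured by any rule fall back to the optimizer defaults.
    rest = [p for n, p in named.items() if p.requires_grad and n not in used]
    groups.append({'params': rest})
    return groups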
linea/configs/linea/linea_hgnetv2_m.py ADDED
@@ -0,0 +1,63 @@
+ _base_ = [
+     './include/dataset.py',
+     './include/optimizer.py',
+     './include/linea.py'
+ ]
+
+ output_dir = 'output/line_hgnetv2_m'
+
+ # backbone
+ backbone = 'HGNetv2_B2'
+ use_lab = True
+ freeze_norm = False
+ freeze_stem_only = True
+
+ # transformer
+ feat_strides = [8, 16, 32]
+ hidden_dim = 256
+ dim_feedforward = 512
+ nheads = 8
+ use_lmap = False
+
+ ## encoder
+ hybrid_encoder = 'hybrid_encoder_asymmetric_conv'
+ in_channels_encoder = [384, 768, 1536]
+ pe_temperatureH = 20
+ pe_temperatureW = 20
+ expansion = 0.34
+ depth_mult = 1.0
+
+ ## decoder
+ feat_channels_decoder = [hidden_dim, hidden_dim, hidden_dim]
+ dec_layers = 4
+ num_queries = 1100
+ num_select = 300
+ reg_max = 16
+ reg_scale = 4
+ eval_idx = 3
+
+ # criterion
+ epochs = 24
+ lr_drop_list = [20]
+ weight_dict = {'loss_logits': 2, 'loss_line': 5}
+ use_warmup = False
+
+ # optimizer params
+ model_parameters = [
+     {
+         'params': '^(?=.*backbone)(?!.*norm|bn).*$',
+         'lr': 0.00002
+     },
+     {
+         'params': '^(?=.*backbone)(?=.*norm|bn).*$',
+         'lr': 0.00002,
+         'weight_decay': 0.
+     },
+     {
+         'params': '^(?=.*(?:encoder|decoder))(?=.*(?:norm|bn|bias)).*$',
+         'weight_decay': 0.
+     }
+ ]
+ lr = 0.0002
+ betas = [0.9, 0.999]
+ weight_decay = 0.0001
linea/configs/linea/linea_hgnetv2_n.py ADDED
@@ -0,0 +1,63 @@
+ _base_ = [
+     './include/dataset.py',
+     './include/optimizer.py',
+     './include/linea.py'
+ ]
+
+ output_dir = 'output/line_hgnetv2_n'
+
+ # backbone
+ backbone = 'HGNetv2_B0'
+ use_lab = True
+ freeze_norm = False
+ freeze_stem_only = True
+
+ # transformer
+ feat_strides = [8, 16, 32]
+ hidden_dim = 128
+ dim_feedforward = 512
+ nheads = 8
+ use_lmap = False
+
+ ## encoder
+ hybrid_encoder = 'hybrid_encoder_asymmetric_conv'
+ in_channels_encoder = [256, 512, 1024]
+ pe_temperatureH = 20
+ pe_temperatureW = 20
+ expansion = 0.34
+ depth_mult = 0.5
+
+ ## decoder
+ feat_channels_decoder = [hidden_dim, hidden_dim, hidden_dim]
+ dec_layers = 3
+ num_queries = 1100
+ num_select = 300
+ reg_max = 16
+ reg_scale = 4
+ eval_idx = 2
+
+ # criterion
+ epochs = 72
+ lr_drop_list = [60]
+ weight_dict = {'loss_logits': 2, 'loss_line': 5}
+ use_warmup = False
+
+ # optimizer params
+ model_parameters = [
+     {
+         'params': '^(?=.*backbone)(?!.*norm|bn).*$',
+         'lr': 0.0004
+     },
+     {
+         'params': '^(?=.*backbone)(?=.*norm|bn).*$',
+         'lr': 0.0004,
+         'weight_decay': 0.
+     },
+     {
+         'params': '^(?=.*(?:encoder|decoder))(?=.*(?:norm|bn|bias)).*$',
+         'weight_decay': 0.
+     }
+ ]
+ lr = 0.0008
+ betas = [0.9, 0.999]
+ weight_decay = 0.0001
linea/configs/linea/linea_hgnetv2_s.py ADDED
@@ -0,0 +1,64 @@
+ _base_ = [
+     './include/dataset.py',
+     './include/optimizer.py',
+     './include/linea.py'
+ ]
+
+ output_dir = 'output/line_hgnetv2_s'
+
+ # backbone
+ backbone = 'HGNetv2_B1'
+ use_lab = True
+ freeze_norm = False
+ freeze_stem_only = True
+
+ # transformer
+ feat_strides = [8, 16, 32]
+ hidden_dim = 256
+ dim_feedforward = 512
+ nheads = 8
+ use_lmap = False
+
+ ## encoder
+ hybrid_encoder = 'hybrid_encoder_asymmetric_conv'
+ in_channels_encoder = [256, 512, 1024]
+ pe_temperatureH = 20
+ pe_temperatureW = 20
+ expansion = 0.34
+ depth_mult = 0.5
+
+ ## decoder
+ feat_channels_decoder = [hidden_dim, hidden_dim, hidden_dim]
+ dec_layers = 3
+ num_queries = 1100
+ num_select = 300
+ reg_max = 16
+ reg_scale = 4
+ eval_idx = 2
+
+ # criterion
+ epochs = 36
+ lr_drop_list = [25]
+ weight_dict = {'loss_logits': 2, 'loss_line': 5}
+ use_warmup = True
+ warmup_iters = 625 * 5
+
+ # optimizer params
+ model_parameters = [
+     {
+         'params': '^(?=.*backbone)(?!.*norm|bn).*$',
+         'lr': 0.0001
+     },
+     {
+         'params': '^(?=.*backbone)(?=.*norm|bn).*$',
+         'lr': 0.0001,
+         'weight_decay': 0.
+     },
+     {
+         'params': '^(?=.*(?:encoder|decoder))(?=.*(?:norm|bn|bias)).*$',
+         'weight_decay': 0.
+     }
+ ]
+ lr = 0.0002
+ betas = [0.9, 0.999]
+ weight_decay = 0.0001
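All four variants inherit the shared includes through `_base_` and then override the backbone, width/depth, and schedule settings. A minimal loading sketch, assuming linea/util/slconfig.py provides an mmcv-style `SLConfig.fromfile` that merges the `_base_` files before applying this file's overrides (the exact API may differ):

# Sketch only: illustrative config loading; attribute access is assumed.
from linea.util.slconfig import SLConfig

cfg = SLConfig.fromfile('linea/configs/linea/linea_hgnetv2_s.py')
print(cfg.backbone)          # 'HGNetv2_B1' (set in this file)
print(cfg.batch_size_train)  # 8 (inherited from include/dataset.py)
print(cfg.warmup_iters)      # 3125 (written as 625 * 5 above)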
linea/models/__init__.py ADDED
@@ -0,0 +1,8 @@
+ # ------------------------------------------------------------------------
+ # DINO
+ # Copyright (c) 2022 IDEA. All Rights Reserved.
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ # ------------------------------------------------------------------------
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+ from .linea import build_linea
+
linea/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (210 Bytes).

linea/models/__pycache__/registry.cpython-311.pyc ADDED
Binary file (3.14 kB).
 
linea/models/linea/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # ------------------------------------------------------------------------
+ # Conditional DETR
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ # ------------------------------------------------------------------------
+ # Copied from DETR (https://github.com/facebookresearch/detr)
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+ # ------------------------------------------------------------------------
+
+ from .linea import build_linea
+ from .criterion import build_criterion
linea/models/linea/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (281 Bytes).

linea/models/linea/__pycache__/attention_mechanism.cpython-311.pyc ADDED
Binary file (17.4 kB).

linea/models/linea/__pycache__/criterion.cpython-311.pyc ADDED
Binary file (40.7 kB).

linea/models/linea/__pycache__/decoder.cpython-311.pyc ADDED
Binary file (31.3 kB).

linea/models/linea/__pycache__/dn_components.cpython-311.pyc ADDED
Binary file (11.9 kB).

linea/models/linea/__pycache__/hgnetv2.cpython-311.pyc ADDED
Binary file (25.1 kB).

linea/models/linea/__pycache__/hybrid_encoder_asymmetric_conv.cpython-311.pyc ADDED
Binary file (35 kB).

linea/models/linea/__pycache__/linea.cpython-311.pyc ADDED
Binary file (7.82 kB).

linea/models/linea/__pycache__/linea_utils.cpython-311.pyc ADDED
Binary file (10.1 kB).

linea/models/linea/__pycache__/matcher.cpython-311.pyc ADDED
Binary file (10.6 kB).

linea/models/linea/__pycache__/utils.cpython-311.pyc ADDED
Binary file (9.26 kB).
 
linea/models/linea/attention_mechanism.py ADDED
@@ -0,0 +1,593 @@
+ import torch
+ import torch.nn.functional as F
+
+ from torch import nn
+ from torch.nn.init import xavier_uniform_, constant_
+
+ import math
+
+ def _is_power_of_2(n):
+     if (not isinstance(n, int)) or (n < 0):
+         raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
+     return (n & (n-1) == 0) and n != 0
+
+ def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights, total_num_points):
+     # for debug and test only,
+     # need to use cuda version instead
+     N_, S_, M_, D_ = value.shape
+     _, Lq_, M_, P_, _ = sampling_locations[0].shape
+     value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
+     # sampling_grids = 2 * sampling_locations - 1
+     sampling_value_list = []
+     for lid_, (H_, W_) in enumerate(value_spatial_shapes):
+         # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
+         value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
+         # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
+         sampling_grid_l_ = (2 * sampling_locations[lid_] - 1).transpose(1, 2).flatten(0, 1)
+         # N_*M_, D_, Lq_, P_
+         sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
+                                           mode='bilinear', padding_mode='zeros', align_corners=False)
+         sampling_value_list.append(sampling_value_l_)
+     # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
+     attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, total_num_points)
+     output = (torch.cat(sampling_value_list, dim=-1) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
+     return output.transpose(1, 2).contiguous()
+
+ def ms_deform_attn_core_pytorchv2(value, value_spatial_shapes, sampling_locations, attention_weights, num_points_list):
+     # for debug and test only,
+     # need to use cuda version instead
+     _, D_ , _= value[0].shape
+     N_, Lq_, M_, _, _ = sampling_locations.shape
+
+     sampling_grids = 2 * sampling_locations - 1
+     sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1)
+     sampling_locations_list = sampling_grids.split(num_points_list, dim=-2)
+
+     sampling_value_list = []
+     for lid_, (H_, W_) in enumerate(value_spatial_shapes):
+         # N_* M_, D_, H_*W_ -> N_*M_, D_, H_, W_
+         value_l_ = value[lid_].unflatten(2, (H_, W_))
+         # N_*M_, Lq_, P_, 2
+         sampling_grid_l_ = sampling_locations_list[lid_]
+         # N_*M_, D_, Lq_, P_
+         sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
+                                           mode='bilinear', padding_mode='zeros', align_corners=False)
+         sampling_value_list.append(sampling_value_l_)
+     # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
+     attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, sum(num_points_list))
+     output = (torch.cat(sampling_value_list, dim=-1) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
+     return output.transpose(1, 2).contiguous()
+
+
+ class MSDeformAttn(nn.Module):
+     def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
+         """
+         Multi-Scale Deformable Attention Module
+         :param d_model hidden dimension
+         :param n_levels number of feature levels
+         :param n_heads number of attention heads
+         :param n_points number of sampling points per attention head per feature level
+         """
+         super().__init__()
+         if d_model % n_heads != 0:
+             raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
+         _d_per_head = d_model // n_heads
+         # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
+         if not _is_power_of_2(_d_per_head):
+             warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
+                           "which is more efficient in our CUDA implementation.")
+
+         self.d_model = d_model
+         self.n_levels = n_levels
+         self.n_heads = n_heads
+         self.n_points = n_points
+
+         self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
+         self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
+         self.value_proj = nn.Linear(d_model, d_model)
+         self.output_proj = nn.Linear(d_model, d_model)
+
+         self._reset_parameters()
+
+     def _reset_parameters(self):
+         constant_(self.sampling_offsets.weight.data, 0.)
+         thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
+         grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+         grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
+         for i in range(self.n_points):
+             grid_init[:, :, i, :] *= i + 1
+         with torch.no_grad():
+             self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
+         constant_(self.attention_weights.weight.data, 0.)
+         constant_(self.attention_weights.bias.data, 0.)
+         xavier_uniform_(self.value_proj.weight.data)
+         constant_(self.value_proj.bias.data, 0.)
+         xavier_uniform_(self.output_proj.weight.data)
+         constant_(self.output_proj.bias.data, 0.)
+
+     def forward(self, query, reference_points, input_flatten, input_spatial_shapes):
+         """
+         :param query (N, Length_{query}, C)
+         :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
+                                or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
+         :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
+         :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
+
+         :return output (N, Length_{query}, C)
+         """
+         N, Len_q, _ = query.shape
+         N, Len_in, _ = input_flatten.shape
+
+         value = self.value_proj(input_flatten)
+         value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
+         sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
+         attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
+         attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
+         # N, Len_q, n_heads, n_levels, n_points, 2
+         if reference_points.shape[-1] == 2:
+             offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
+             sampling_locations = reference_points[:, :, None, :, None, :] \
+                 + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
+         elif reference_points.shape[-1] == 4:
+             sampling_locations = reference_points[:, :, None, :, None, :2] \
+                 + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
+         else:
+             raise ValueError(
+                 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
+         output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)
+         output = self.output_proj(output)
+         return output
+
+ class MSDeformLineAttn(nn.Module):
+     def __init__(
+         self,
+         d_model=256,
+         n_levels=4,
+         n_heads=8,
+         n_points=4
+     ):
+         """
+         This version is inspired from DFine. We removed the following layers:
+             - value_proj
+             - output_proj
+         """
+         super().__init__()
+         if d_model % n_heads != 0:
+             raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
+         _d_per_head = d_model // n_heads
+
+         self.d_model = d_model
+         self.n_levels = n_levels
+         self.n_heads = n_heads
+
+         if isinstance(n_points, list):
+             assert len(n_points) == n_levels, ''
+             num_points_list = n_points
+         else:
+             num_points_list = [n_points for _ in range(n_levels)]
+         self.num_points_list = num_points_list
+         self.total_num_points = sum(num_points_list)
+
+         num_points_scale = [1/n for n in num_points_list for _ in range(n)]
+         self.register_buffer('num_points_scale', torch.tensor(num_points_scale, dtype=torch.float32).reshape(-1, 1))
+
+         self.sampling_ratios = nn.Linear(d_model, n_heads * sum(num_points_list))
+         self.attention_weights = nn.Linear(d_model, n_heads * sum(num_points_list))
+
+         self._reset_parameters()
+
+     def _reset_parameters(self):
+         constant_(self.sampling_ratios.weight.data, 0.)
+         with torch.no_grad():
+             self.sampling_ratios.bias = nn.Parameter(torch.linspace(-1, 1, self.n_heads * self.total_num_points))
+
+         constant_(self.attention_weights.weight.data, 0.)
+         constant_(self.attention_weights.bias.data, 0.)
+
+     def forward(self, query, reference_points, value, value_spatial_shapes):
+         """
+         :param query (N, Length_{query}, C)
+         :param reference_points (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
+         :param value (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
+         :param value_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
+         :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
+
+         :return output (N, Length_{query}, C)
+
+         ####################################################################
+         # Difference respect to MSDeformAttn
+         # The query already stores the line's junctions
+         # :param reference_points is not needed. We keep it to make both
+             MSDeformAttn and MSDeformLineAttn interchangebale
+             between different frameworks
+         # MSDeformLineAttn does not generate offsets. Instead, it samples
+             n_points equally-spaced points from the line segment
+         ####################################################################
+         """
+         N, Len_q, _ = query.shape
+
+         sampling_ratios = self.sampling_ratios(query).view(N, Len_q, self.n_heads, self.total_num_points, 1)
+         attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.total_num_points)
+         attention_weights = F.softmax(attention_weights, -1)
+
+         num_points_scale = self.num_points_scale.to(dtype=query.dtype)
+
+         vector = reference_points[:, :, None, :, :2] - reference_points[:, :, None, :, 2:]
+         center = 0.5 * (reference_points[:, :, None, :, :2] + reference_points[:, :, None, :, 2:])
+
+         sampling_locations = center + sampling_ratios * num_points_scale * vector * 0.5
+
+         output = ms_deform_attn_core_pytorchv2(
+             value,
+             value_spatial_shapes,
+             sampling_locations,
+             attention_weights,
+             self.num_points_list
+         )
+         return output
+
+
+ #######################
+ ## Previous versions ##
+ #######################
+
+ # class MSDeformLineAttn(nn.Module):
+ # def __init__(
+ # self,
+ # d_model=256,
+ # n_levels=4,
+ # n_heads=8,
+ # n_points=4
+ # ):
+ # """
+ # Multi-Scale Deformable Attention Module
+ # :param d_model hidden dimension
+ # :param n_levels number of feature levels
+ # :param n_heads number of attention heads
+ # :param n_points number of sampling points per attention head per feature level
+ # """
+ # super().__init__()
+ # if d_model % n_heads != 0:
+ # raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
+ # _d_per_head = d_model // n_heads
+ # # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
+ # if not _is_power_of_2(_d_per_head):
+ # warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
+ # "which is more efficient in our CUDA implementation.")
+
+ # self.d_model = d_model
+ # self.n_levels = n_levels
+ # self.n_heads = n_heads
+
+ # if isinstance(n_points, list):
+ # assert len(n_points) == n_levels, ''
+ # num_points_list = n_points
+ # else:
+ # num_points_list = [n_points for _ in range(n_levels)]
+ # self.num_points_list = num_points_list
+ # self.total_num_points = sum(num_points_list)
+
+ # num_points_scale = [1/n for n in num_points_list]
+ # self.register_buffer('num_points_scale', torch.tensor(num_points_scale, dtype=torch.float32).reshape(-1, 1, 1))
+
+ # self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * 4)
+
+ # self.attention_weights = nn.Linear(d_model, n_heads * sum(num_points_list))
+ # self.value_proj = nn.Linear(d_model, d_model)
+ # self.output_proj = nn.Linear(d_model, d_model)
+
+ # for i in range(len(num_points_list)):
+ # if num_points_list[i] == 1:
+ # lambda_ = torch.linspace(0.5, 0.5, num_points_list[i])[:, None]
+ # else:
+ # lambda_ = torch.linspace(0, 1, num_points_list[i])[:, None]
+ # self.register_buffer(f"lambda_{i}", lambda_)
+
+ # self._reset_parameters()
+
+ # def _reset_parameters(self):
+ # constant_(self.sampling_offsets.weight.data, 0.)
+ # thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
+ # grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+ # grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, 2, 1)
+ # for i in range(1):
+ # grid_init[:, :, 2*i, :] *= i + 1
+ # grid_init[:, :, 2*i+1, :] *= i + 1
+ # with torch.no_grad():
+ # self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
+ # constant_(self.attention_weights.weight.data, 0.)
+ # constant_(self.attention_weights.bias.data, 0.)
+ # xavier_uniform_(self.value_proj.weight.data)
+ # constant_(self.value_proj.bias.data, 0.)
+ # xavier_uniform_(self.output_proj.weight.data)
+ # constant_(self.output_proj.bias.data, 0.)
+
+ # def forward(self, query, reference_points, input_flatten, input_spatial_shapes):
+ # """
+ # :param query (N, Length_{query}, C)
+ # :param reference_points (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
+ # :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
+ # :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
+ # :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
+
+ # :return output (N, Length_{query}, C)
+
+ # ####################################################################
+ # # Difference respect to MSDeformAttn
+ # # The query already stores the line's junctions
+ # # :param reference_points is not needed. We keep it to make both
+ # MSDeformAttn and MSDeformLineAttn interchangebale
+ # between different frameworks
+ # # MSDeformLineAttn does not generate offsets. Instead, it samples
+ # n_points equally-spaced points from the line segment
+ # ####################################################################
+ # """
+ # N, Len_q, _ = query.shape
+ # N, Len_in, _ = input_flatten.shape
+
+ # value = self.value_proj(input_flatten)
+
+ # value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
+ # sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, 1, 4)
+ # attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.total_num_points)
+ # attention_weights = F.softmax(attention_weights, -1)
+
+ # num_points_scale = self.num_points_scale.to(dtype=query.dtype)
+
+ # wh = reference_points[:, :, None, :, None, :2] - reference_points[:, :, None, :, None, 2:]
+ # center = 0.5 * (reference_points[:, :, None, :, None, :2] + reference_points[:, :, None, :, None, 2:])
+
+ # sampling_junctions = torch.cat((center, center), dim=-1) \
+ # + sampling_offsets * num_points_scale * torch.cat([wh, wh], -1) * 0.5
+
+ # sampling_locations = []
+
+ # # sampling_junctions_level = torch.split(sampling_junctions, self.num_points_list, dim=-2)
+ # for i in range(len(self.num_points_list)):
+ # lambda_ = getattr(self, f'lambda_{i}')
+ # junctions = sampling_junctions[:, :, :, i]
+ # locations = junctions[..., :2] * lambda_ + junctions[..., 2:] * (1 - lambda_)
+ # sampling_locations.append(locations)
+
+ # output = ms_deform_attn_core_pytorch(
+ # value,
+ # input_spatial_shapes,
+ # sampling_locations,
+ # attention_weights,
+ # self.total_num_points
+ # )
+ # output = self.output_proj(output)
+ # return output
+
+
+ # class MSDeformLineAttnV2(nn.Module):
+ # def __init__(
+ # self,
+ # d_model=256,
+ # n_levels=4,
+ # n_heads=8,
+ # n_points=4
+ # ):
+ # """
+ # Multi-Scale Deformable Attention Module
+ # :param d_model hidden dimension
+ # :param n_levels number of feature levels
+ # :param n_heads number of attention heads
+ # :param n_points number of sampling points per attention head per feature level
+ # """
+ # super().__init__()
+ # if d_model % n_heads != 0:
+ # raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
+ # _d_per_head = d_model // n_heads
+ # # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
+ # if not _is_power_of_2(_d_per_head):
+ # warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
+ # "which is more efficient in our CUDA implementation.")
+
+ # self.d_model = d_model
+ # self.n_levels = n_levels
+ # self.n_heads = n_heads
+
+ # if isinstance(n_points, list):
+ # assert len(n_points) == n_levels, ''
+ # num_points_list = n_points
+ # else:
+ # num_points_list = [n_points for _ in range(n_levels)]
+ # self.num_points_list = num_points_list
+ # self.total_num_points = sum(num_points_list)
+
+ # num_points_scale = [1/n for n in num_points_list]
+ # self.register_buffer('num_points_scale', torch.tensor(num_points_scale, dtype=torch.float32).reshape(-1, 1, 1))
+
+ # self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * 4)
+ # self.sampling_ratios = nn.Linear(d_model, n_heads * sum(num_points_list))
+
+ # self.attention_weights = nn.Linear(d_model, n_heads * sum(num_points_list))
+ # self.value_proj = nn.Linear(d_model, d_model)
+ # self.output_proj = nn.Linear(d_model, d_model)
+
+ # self._reset_parameters()
+
+ # def _reset_parameters(self):
+ # constant_(self.sampling_offsets.weight.data, 0.)
+ # thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
+ # grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+ # grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, 2, 1)
+ # for i in range(1):
+ # grid_init[:, :, 2*i, :] *= i + 1
+ # grid_init[:, :, 2*i+1, :] *= i + 1
+ # with torch.no_grad():
+ # self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
+ # constant_(self.attention_weights.weight.data, 0.)
+ # constant_(self.attention_weights.bias.data, 0.)
+ # xavier_uniform_(self.value_proj.weight.data)
+ # constant_(self.value_proj.bias.data, 0.)
+ # xavier_uniform_(self.output_proj.weight.data)
+ # constant_(self.output_proj.bias.data, 0.)
+
+ # def forward(self, query, reference_points, input_flatten, input_spatial_shapes):
+ # """
+ # :param query (N, Length_{query}, C)
+ # :param reference_points (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
+ # :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
+ # :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
+ # :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
+
+ # :return output (N, Length_{query}, C)
+
+ # ####################################################################
+ # # Difference respect to MSDeformAttn
+ # # The query already stores the line's junctions
+ # # :param reference_points is not needed. We keep it to make both
+ # MSDeformAttn and MSDeformLineAttn interchangebale
+ # between different frameworks
+ # # MSDeformLineAttn does not generate offsets. Instead, it samples
+ # n_points equally-spaced points from the line segment
+ # ####################################################################
+ # """
+ # N, Len_q, _ = query.shape
+ # N, Len_in, _ = input_flatten.shape
+
+ # value = self.value_proj(input_flatten)
+
+ # value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
+ # sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, 1, 4)
+ # sampling_ratios = self.sampling_ratios(query).view(N, Len_q, self.n_heads, self.total_num_points).sigmoid()
+ # attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.total_num_points)
+ # attention_weights = F.softmax(attention_weights, -1)
+
+ # num_points_scale = self.num_points_scale.to(dtype=query.dtype)
+
+ # wh = reference_points[:, :, None, :, None, :2] - reference_points[:, :, None, :, None, 2:]
+ # center = 0.5 * (reference_points[:, :, None, :, None, :2] + reference_points[:, :, None, :, None, 2:])
+
+ # sampling_junctions = torch.cat((center, center), dim=-1) \
+ # + sampling_offsets * num_points_scale * torch.cat([wh, wh], -1) * 0.5
+
+ # sampling_locations = []
+
+ # for i, lambda_ in enumerate(torch.split(sampling_ratios, self.num_points_list, dim=-1)):
+ # lambda_ = lambda_[..., None]
+ # junctions = sampling_junctions[:, :, :, i]
+ # locations = junctions[..., :2] * lambda_ + junctions[..., 2:] * (1 - lambda_)
+ # sampling_locations.append(locations)
+
+ # output = ms_deform_attn_core_pytorch(
+ # value,
+ # input_spatial_shapes,
+ # sampling_locations,
+ # attention_weights,
+ # self.total_num_points
+ # )
+ # output = self.output_proj(output)
+ # return output
+
+
+ # class MSDeformLineAttnV3(nn.Module):
+ # def __init__(
+ # self,
+ # d_model=256,
+ # n_levels=4,
+ # n_heads=8,
+ # n_points=4
+ # ):
+ # """
+ # This version is inspired from DFine. We removed the following layers:
+ # - value_proj
+ # - output_proj
+ # """
+ # super().__init__()
+ # if d_model % n_heads != 0:
+ # raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
+ # _d_per_head = d_model // n_heads
+ # # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
+ # if not _is_power_of_2(_d_per_head):
+ # warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
+ # "which is more efficient in our CUDA implementation.")
+
+ # self.d_model = d_model
+ # self.n_levels = n_levels
+ # self.n_heads = n_heads
+
+ # if isinstance(n_points, list):
+ # assert len(n_points) == n_levels, ''
+ # num_points_list = n_points
+ # else:
+ # num_points_list = [n_points for _ in range(n_levels)]
+ # self.num_points_list = num_points_list
+ # self.total_num_points = sum(num_points_list)
+
+ # num_points_scale = [1/n for n in num_points_list]
+ # self.register_buffer('num_points_scale', torch.tensor(num_points_scale, dtype=torch.float32).reshape(-1, 1, 1))
+
+ # self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * 4)
+ # self.sampling_ratios = nn.Linear(d_model, n_heads * sum(num_points_list))
+
+ # self.attention_weights = nn.Linear(d_model, n_heads * sum(num_points_list))
+
+ # self._reset_parameters()
+
+ # def _reset_parameters(self):
+ # constant_(self.sampling_offsets.weight.data, 0.)
+ # thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
+ # grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+ # grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, 2, 1)
+ # for i in range(1):
+ # grid_init[:, :, 2*i, :] *= i + 1
+ # grid_init[:, :, 2*i+1, :] *= i + 1
+ # with torch.no_grad():
+ # self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
+ # constant_(self.attention_weights.weight.data, 0.)
+ # constant_(self.attention_weights.bias.data, 0.)
+
+ # def forward(self, query, reference_points, value, value_spatial_shapes):
+ # """
+ # :param query (N, Length_{query}, C)
+ # :param reference_points (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
+ # :param value (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
+ # :param value_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
+ # :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
+
+ # :return output (N, Length_{query}, C)
+
+ # ####################################################################
+ # # Difference respect to MSDeformAttn
+ # # The query already stores the line's junctions
+ # # :param reference_points is not needed. We keep it to make both
+ # MSDeformAttn and MSDeformLineAttn interchangebale
+ # between different frameworks
+ # # MSDeformLineAttn does not generate offsets. Instead, it samples
+ # n_points equally-spaced points from the line segment
+ # ####################################################################
+ # """
+ # N, Len_q, _ = query.shape
+
+ # sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, 1, 4)
+ # sampling_ratios = self.sampling_ratios(query).view(N, Len_q, self.n_heads, self.total_num_points).sigmoid()
+ # attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.total_num_points)
+ # attention_weights = F.softmax(attention_weights, -1)
+
+ # num_points_scale = self.num_points_scale.to(dtype=query.dtype)
+
+ # wh = reference_points[:, :, None, :, None, :2] - reference_points[:, :, None, :, None, 2:]
+ # center = 0.5 * (reference_points[:, :, None, :, None, :2] + reference_points[:, :, None, :, None, 2:])
+
+ # sampling_junctions = torch.cat((center, center), dim=-1) \
+ # + sampling_offsets * num_points_scale * torch.cat([wh, wh], -1) * 0.5
+
+ # sampling_locations = []
+
+ # for i, lambda_ in enumerate(torch.split(sampling_ratios, self.num_points_list, dim=-1)):
+ # lambda_ = lambda_[..., None]
+ # junctions = sampling_junctions[:, :, :, i]
+ # locations = junctions[..., :2] * lambda_ + junctions[..., 2:] * (1 - lambda_)
+ # sampling_locations.append(locations)
+
+ # output = ms_deform_attn_core_pytorchv2(
+ # value,
+ # value_spatial_shapes,
+ # sampling_locations,
+ # attention_weights,
+ # self.total_num_points
+ # )
+ # return output
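As the docstring above notes, MSDeformLineAttn does not predict free-form offsets; it samples points along each line segment (scaled by `num_points_scale`) and weights them with a per-head softmax. Below is a minimal, self-contained shape check using the `dec_n_points = [4, 1, 1]` setting from the configs; the import path is assumed from the repository layout, the random tensors stand in for decoder inputs, and passing one (x1, y1, x2, y2) segment per query is an assumption about how the decoder calls it.

# Sketch only: verifies input/output shapes of MSDeformLineAttn.
import torch
from linea.models.linea.attention_mechanism import MSDeformLineAttn

N, Lq, d_model, n_heads = 2, 5, 256, 8
attn = MSDeformLineAttn(d_model=d_model, n_levels=3, n_heads=n_heads, n_points=[4, 1, 1])

query = torch.rand(N, Lq, d_model)
reference_points = torch.rand(N, Lq, 1, 4)    # one line segment per query, normalized to [0, 1]
spatial_shapes = [(8, 8), (4, 4), (2, 2)]     # per-level (H, W)
value = [torch.rand(N * n_heads, d_model // n_heads, h * w) for h, w in spatial_shapes]

out = attn(query, reference_points, value, spatial_shapes)
print(out.shape)  # torch.Size([2, 5, 256])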
linea/models/linea/criterion.py ADDED
@@ -0,0 +1,517 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+ from torchvision.transforms.functional import resize
5
+
6
+ from .utils import sigmoid_focal_loss
7
+
8
+ from .matcher import build_matcher
9
+
10
+ from .linea_utils import weighting_function, bbox2distance
11
+
12
+ from ..registry import MODULE_BUILD_FUNCS
13
+
14
+ # TODO. Quick solution to make the model run on GoogleColab
15
+ import os, sys
16
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))
17
+ from util.misc import get_world_size, is_dist_avail_and_initialized
18
+
19
+
20
+ class LINEACriterion(nn.Module):
21
+ """ This class computes the loss for Conditional DETR.
22
+ The process happens in two steps:
23
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
24
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
25
+ """
26
+ def __init__(self, num_classes, matcher, weight_dict, focal_alpha, losses):
27
+ """ Create the criterion.
28
+ Parameters:
29
+ num_classes: number of object categories, omitting the special no-object category
30
+ matcher: module able to compute a matching between targets and proposals
31
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
32
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
33
+ focal_alpha: alpha in Focal Loss
34
+ """
35
+ super().__init__()
36
+ self.num_classes = num_classes
37
+ self.matcher = matcher
38
+ self.weight_dict = weight_dict
39
+ self.losses = losses
40
+ self.focal_alpha = focal_alpha
41
+
42
+ def loss_labels(self, outputs, targets, indices, num_boxes):
43
+ """Classification loss (Binary focal loss)
44
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
45
+ """
46
+ assert 'pred_logits' in outputs
47
+ src_logits = outputs['pred_logits']
48
+ idx = self._get_src_permutation_idx(indices)
49
+ target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
50
+ target_classes = torch.full(src_logits.shape[:2], self.num_classes,
51
+ dtype=torch.int64, device=src_logits.device)
52
+ target_classes[idx] = target_classes_o
53
+
54
+ target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2]+1],
55
+ dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)
56
+ target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
57
+ target_classes_onehot = target_classes_onehot[:,:,:-1]
58
+
59
+ loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1]
60
+ losses = {'loss_logits': loss_ce}
61
+
62
+ return losses
63
+
64
+ def loss_lines(self, outputs, targets, indices, num_boxes):
65
+ """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
66
+ targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
67
+ The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
68
+ """
69
+ assert 'pred_lines' in outputs
70
+ idx = self._get_src_permutation_idx(indices)
71
+ src_lines = outputs['pred_lines'][idx]
72
+ target_lines = torch.cat([t['lines'][i] for t, (_, i) in zip(targets, indices)], dim=0)
73
+
74
+ loss_line = F.l1_loss(src_lines, target_lines, reduction='none')
75
+
76
+ losses = {}
77
+ losses['loss_line'] = loss_line.sum() / num_boxes
78
+
79
+ return losses
80
+
81
+ def loss_lmap(self, outputs, targets, indices, num_boxes):
82
+ losses = {}
83
+ if 'aux_lmap' in outputs:
84
+ src_lmap = outputs['aux_lmap']
85
+ size = src_lmap[0].size(2)
86
+ target_lmap = []
87
+ for t in targets:
88
+ lmaps_flatten = []
89
+ for lmap, downsampling in zip(t['lmap'], [1, 2, 4]):
90
+ lmap_ = resize(lmap, (size//downsampling, size//downsampling))
91
+ lmaps_flatten.append(lmap_.flatten(1))
92
+ target_lmap.append(torch.cat(lmaps_flatten, dim=1))
93
+ target_lmap = torch.cat(target_lmap, dim=0)
94
+
95
+ src_lmap = torch.cat([lmap_.flatten(1) for lmap_ in src_lmap], dim=1)
96
+
97
+ loss_lmap = F.binary_cross_entropy_with_logits(src_lmap, target_lmap, reduction='mean')
98
+
99
+ losses['loss_lmap'] = loss_lmap
100
+
101
+ return losses
102
+
103
+ def _get_src_permutation_idx(self, indices):
104
+ # permute predictions following indices
105
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
106
+ src_idx = torch.cat([src for (src, _) in indices])
107
+ return batch_idx, src_idx
108
+
109
+ def _get_tgt_permutation_idx(self, indices):
110
+ # permute targets following indices
111
+ batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
112
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
113
+ return batch_idx, tgt_idx
114
+
115
+ def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
116
+ loss_map = {
117
+ 'labels': self.loss_labels,
118
+ 'lines': self.loss_lines,
119
+ 'lmap': self.loss_lmap,
120
+ }
121
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
122
+ return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
123
+
124
+ def forward(self, outputs, targets, return_indices=False):
125
+ """ This performs the loss computation.
126
+ Parameters:
127
+ outputs: dict of tensors, see the output specification of the model for the format
128
+ targets: list of dicts, such that len(targets) == batch_size.
129
+ The expected keys in each dict depends on the losses applied, see each loss' doc
130
+
131
+ return_indices: used for vis. if True, the layer0-5 indices will be returned as well.
132
+
133
+ """
134
+ outputs_without_aux = {k: v for k, v in outputs.items() if 'aux' not in k}
135
+ device = next(iter(outputs.values())).device
136
+ indices = self.matcher(outputs_without_aux, targets)
137
+ if return_indices:
138
+ return indices
139
+
140
+ # Compute the average number of target boxes accross all nodes, for normalization purposes
141
+ num_boxes = sum(len(t["labels"]) for t in targets)
142
+ num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=device)
143
+ if is_dist_avail_and_initialized():
144
+ torch.distributed.all_reduce(num_boxes)
145
+ num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
146
+
147
+ # Compute all the requested losses
148
+ losses = {}
149
+
150
+ for loss in self.losses:
151
+ indices_in = indices
152
+ num_boxes_in = num_boxes
153
+ l_dict = self.get_loss(loss, outputs, targets, indices_in, num_boxes_in)
154
+ l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
155
+ losses.update(l_dict)
156
+
157
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
158
+ if 'aux_outputs' in outputs:
159
+ for idx, aux_outputs in enumerate(outputs['aux_outputs']):
160
+ indices = self.matcher(aux_outputs, targets)
161
+ for loss in self.losses:
162
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes)
163
+ l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
164
+ l_dict = {k + f'_{idx}': v for k, v in l_dict.items()}
165
+ losses.update(l_dict)
166
+
167
+ # interm_outputs loss
168
+ if 'aux_interm_outputs' in outputs:
169
+ interm_outputs = outputs['aux_interm_outputs']
170
+ indices = self.matcher(interm_outputs, targets)
171
+ for loss in self.losses:
172
+ l_dict = self.get_loss(loss, interm_outputs, targets, indices, num_boxes)
173
+ l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
174
+ l_dict = {k + f'_interm': v for k, v in l_dict.items()}
175
+ losses.update(l_dict)
176
+
177
+ # pre output loss
178
+ if 'aux_pre_outputs' in outputs:
179
+ pre_outputs = outputs['aux_pre_outputs']
180
+ indices = self.matcher(pre_outputs, targets)
181
+ for loss in self.losses:
182
+ l_dict = self.get_loss(loss, pre_outputs, targets, indices, num_boxes)
183
+ l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
184
+ l_dict = {k + f'_pre': v for k, v in l_dict.items()}
185
+ losses.update(l_dict)
186
+
187
+ # prepare for dn loss
188
+ dn_meta = outputs['dn_meta']
189
+
190
+ if self.training and dn_meta and 'aux_denoise' in outputs:
191
+ single_pad, scalar = self.prep_for_dn(dn_meta)
192
+ dn_pos_idx = []
193
+ dn_neg_idx = []
194
+ for i in range(len(targets)):
195
+ if len(targets[i]['labels']) > 0:
196
+ t = torch.arange(len(targets[i]['labels'])).long().cuda()
197
+ t = t.unsqueeze(0).repeat(scalar, 1)
198
+ tgt_idx = t.flatten()
199
+ output_idx = (torch.tensor(range(scalar)) * single_pad).long().cuda().unsqueeze(1) + t
200
+ output_idx = output_idx.flatten()
201
+ else:
202
+ output_idx = tgt_idx = torch.tensor([]).long().cuda()
203
+
204
+ dn_pos_idx.append((output_idx, tgt_idx))
205
+ dn_neg_idx.append((output_idx + single_pad // 2, tgt_idx))
206
+
207
+ dn_outputs = outputs['aux_denoise']
208
+
209
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
210
+ if 'aux_outputs' in dn_outputs:
211
+ for idx, aux_outputs in enumerate(dn_outputs['aux_outputs']):
212
+ for loss in self.losses:
213
+ l_dict = self.get_loss(loss, aux_outputs, targets, dn_pos_idx, num_boxes*scalar)
214
+ l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
215
+ l_dict = {k + f'_dn_{idx}': v for k, v in l_dict.items()}
216
+ losses.update(l_dict)
217
+
218
+ if 'aux_pre_outputs' in dn_outputs:
219
+ aux_outputs_known = dn_outputs['aux_pre_outputs']
220
+ l_dict={}
221
+ for loss in self.losses:
222
+ l_dict.update(self.get_loss(loss, aux_outputs_known, targets, dn_pos_idx, num_boxes*scalar))
223
+ l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
224
+ l_dict = {k + f'_pre_dn': v for k, v in l_dict.items()}
225
+ losses.update(l_dict)
226
+
227
+ losses = {k: v for k, v in sorted(losses.items(), key=lambda item: item[0])}
228
+
229
+ return losses
230
+
231
+ def prep_for_dn(self,dn_meta):
232
+ # output_known_lbs_lines = dn_meta['output_known_lbs_lines']
233
+ num_dn_groups, pad_size=dn_meta['num_dn_group'],dn_meta['pad_size']
234
+ assert pad_size % num_dn_groups==0
235
+ single_pad=pad_size//num_dn_groups
236
+
237
+ return single_pad,num_dn_groups
238
+
239
+ class DFINESetCriterion(LINEACriterion):
240
+ def __init__(self, num_classes, matcher, weight_dict, focal_alpha, reg_max, losses):
241
+ super().__init__(num_classes, matcher, weight_dict, focal_alpha, losses)
242
+ self.reg_max = reg_max
243
+
244
+ def loss_local(self, outputs, targets, indices, num_boxes, T=5):
245
+ losses = {}
246
+ if 'pred_corners' in outputs:
247
+ idx = self._get_src_permutation_idx(indices)
248
+ target_lines = torch.cat([t['lines'][i] for t, (_, i) in zip(targets, indices)], dim=0)
249
+
250
+ pred_corners = outputs['pred_corners'][idx].reshape(-1, (self.reg_max+1))
251
+ ref_points = outputs['ref_points'][idx].detach()
252
+
253
+ with torch.no_grad():
254
+ if self.fgl_targets_dn is None and 'is_dn' in outputs:
255
+ self.fgl_targets_dn= bbox2distance(ref_points, target_lines,
256
+ self.reg_max, outputs['reg_scale'],
257
+ outputs['up'])
258
+ if self.fgl_targets is None and 'is_dn' not in outputs:
259
+ self.fgl_targets = bbox2distance(ref_points, target_lines,
260
+ self.reg_max, outputs['reg_scale'],
261
+ outputs['up'])
262
+
263
+ target_corners, weight_right, weight_left = self.fgl_targets_dn if 'is_dn' in outputs else self.fgl_targets
264
+
265
+ losses['loss_fgl'] = self.unimodal_distribution_focal_loss(
266
+ pred_corners, target_corners, weight_right, weight_left, None, avg_factor=num_boxes)
267
+
268
+ if 'teacher_corners' in outputs:
269
+ pred_corners = outputs['pred_corners'].reshape(-1, (self.reg_max+1))
270
+ target_corners = outputs['teacher_corners'].reshape(-1, (self.reg_max+1))
271
+ if torch.equal(pred_corners, target_corners):
272
+ losses['loss_ddf'] = pred_corners.sum() * 0
273
+ else:
274
+ weight_targets_local = outputs['teacher_logits'].sigmoid().max(dim=-1)[0]
275
+
276
+ mask = torch.zeros_like(weight_targets_local, dtype=torch.bool)
277
+ mask[idx] = True
278
+ mask = mask.unsqueeze(-1).repeat(1, 1, 4).reshape(-1)
279
+
280
+ weight_targets_local = weight_targets_local.unsqueeze(-1).repeat(1, 1, 4).reshape(-1).detach()
281
+
282
+ loss_match_local = weight_targets_local * (T ** 2) * (nn.KLDivLoss(reduction='none')
283
+ (F.log_softmax(pred_corners / T, dim=1), F.softmax(target_corners.detach() / T, dim=1))).sum(-1)
284
+ if 'is_dn' not in outputs:
285
+ batch_scale = 8 / outputs['pred_lines'].shape[0] # Avoid the influence of batch size per GPU
286
+ self.num_pos, self.num_neg = (mask.sum() * batch_scale) ** 0.5, ((~mask).sum() * batch_scale) ** 0.5
287
+ loss_match_local1 = loss_match_local[mask].mean() if mask.any() else 0
288
+ loss_match_local2 = loss_match_local[~mask].mean() if (~mask).any() else 0
289
+ losses['loss_ddf'] = (loss_match_local1 * self.num_pos + loss_match_local2 * self.num_neg) / (self.num_pos + self.num_neg)
290
+
291
+ return losses
292
+
293
+ def _clear_cache(self):
294
+ self.fgl_targets, self.fgl_targets_dn = None, None
295
+ self.own_targets, self.own_targets_dn = None, None
296
+ self.num_pos, self.num_neg = None, None
297
+
298
+ def unimodal_distribution_focal_loss(self, pred, label, weight_right, weight_left, weight=None, reduction='sum', avg_factor=None):
299
+ dis_left = label.long()
300
+ dis_right = dis_left + 1
301
+
302
+ loss = F.cross_entropy(pred, dis_left, reduction='none') * weight_left.reshape(-1) \
303
+ + F.cross_entropy(pred, dis_right, reduction='none') * weight_right.reshape(-1)
304
+
305
+ if weight is not None:
306
+ weight = weight.float()
307
+ loss = loss * weight
308
+
309
+ if avg_factor is not None:
310
+ loss = loss.sum() / avg_factor
311
+ elif reduction == 'mean':
312
+ loss = loss.mean()
313
+ elif reduction == 'sum':
314
+ loss = loss.sum()
315
+
316
+ return loss
317
+
318
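The unimodal distribution focal loss above spreads a continuous bin target over its two neighbouring integer bins, weighting the left and right cross-entropy terms by how close the target sits to each bin. A minimal standalone sketch (toy logits; in the model the targets and weights come from bbox2distance, so the values below are illustrative assumptions only):

    import torch
    import torch.nn.functional as F

    # one prediction over reg_max + 1 = 5 bins, continuous target 2.3
    pred = torch.tensor([[0.1, 0.2, 2.0, 1.5, 0.1]])  # logits over the bins
    label = torch.tensor([2.3])                        # fractional bin index
    dis_left = label.long()                            # bin 2
    dis_right = dis_left + 1                           # bin 3
    weight_right = label - dis_left.float()            # 0.3
    weight_left = 1.0 - weight_right                   # 0.7

    loss = (F.cross_entropy(pred, dis_left, reduction='none') * weight_left
            + F.cross_entropy(pred, dis_right, reduction='none') * weight_right)
    print(loss)  # single positive scalar; smallest when the predicted mass sits between bins 2 and 3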
+ def _get_go_indices(self, indices, indices_aux_list):
319
+ """Get a matching union set across all decoder layers. """
320
+ results = []
321
+ for indices_aux in indices_aux_list:
322
+ indices = [(torch.cat([idx1[0], idx2[0]]), torch.cat([idx1[1], idx2[1]]))
323
+ for idx1, idx2 in zip(indices.copy(), indices_aux.copy())]
324
+
325
+ for ind in [torch.cat([idx[0][:, None], idx[1][:, None]], 1) for idx in indices]:
326
+ unique, counts = torch.unique(ind, return_counts=True, dim=0)
327
+ count_sort_indices = torch.argsort(counts, descending=True)
328
+ unique_sorted = unique[count_sort_indices]
329
+ column_to_row = {}
330
+ for idx in unique_sorted:
331
+ row_idx, col_idx = idx[0].item(), idx[1].item()
332
+ if row_idx not in column_to_row:
333
+ column_to_row[row_idx] = col_idx
334
+ final_rows = torch.tensor(list(column_to_row.keys()), device=ind.device)
335
+ final_cols = torch.tensor(list(column_to_row.values()), device=ind.device)
336
+ results.append((final_rows.long(), final_cols.long()))
337
+ return results
338
+
339
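The union-matching helper above merges the Hungarian assignments from every decoder layer and, for each query, keeps the target it was matched to most often. A small illustration of the de-duplication step (the index values are made up for this sketch):

    import torch

    # matches from the last layer and one auxiliary layer for a single image
    final_idx = (torch.tensor([0, 1]), torch.tensor([5, 7]))  # (query, target) pairs
    aux_idx = (torch.tensor([0, 2]), torch.tensor([5, 7]))

    merged = (torch.cat([final_idx[0], aux_idx[0]]), torch.cat([final_idx[1], aux_idx[1]]))
    pairs = torch.stack(merged, dim=1)                        # rows of (query, target)
    unique, counts = torch.unique(pairs, return_counts=True, dim=0)
    order = torch.argsort(counts, descending=True)
    kept = {}
    for q, t in unique[order].tolist():
        kept.setdefault(q, t)                                 # most frequent target wins per query
    print(kept)  # e.g. {0: 5, 1: 7, 2: 7}; queries 0, 1 and 2 all enter the 'go' supervision set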
+ def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
340
+ loss_map = {
341
+ 'labels': self.loss_labels,
342
+ 'lines': self.loss_lines,
343
+ 'lmap': self.loss_lmap,
344
+ 'local': self.loss_local,
345
+ }
346
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
347
+ return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
348
+
349
+ def forward(self, outputs, targets):
350
+ """ This performs the loss computation.
351
+ Parameters:
352
+ outputs: dict of tensors, see the output specification of the model for the format
353
+ targets: list of dicts, such that len(targets) == batch_size.
354
+ The expected keys in each dict depend on the losses applied; see each loss's doc
355
+
356
357
+
358
+ """
359
+ outputs_without_aux = {k: v for k, v in outputs.items() if 'aux' not in k}
360
+ device = next(iter(outputs.values())).device
361
+ indices = self.matcher(outputs_without_aux, targets)
362
+
363
+ self._clear_cache()
364
+
365
+ # Get the matching union set across all decoder layers.
366
+ if 'aux_outputs' in outputs:
367
+ indices_aux_list, cached_indices, cached_indices_enc = [], [], []
368
+ for i, aux_outputs in enumerate(outputs['aux_outputs'] + [outputs['aux_pre_outputs']]):
369
+ indices_aux = self.matcher(aux_outputs, targets)
370
+ cached_indices.append(indices_aux)
371
+ indices_aux_list.append(indices_aux)
372
+ for i, aux_outputs in enumerate([outputs['aux_interm_outputs']]):
373
+ indices_enc = self.matcher(aux_outputs, targets)
374
+ cached_indices_enc.append(indices_enc)
375
+ indices_aux_list.append(indices_enc)
376
+ indices_go = self._get_go_indices(indices, indices_aux_list)
377
+
378
+ num_boxes_go = sum(len(x[0]) for x in indices_go)
379
+ num_boxes_go = torch.as_tensor([num_boxes_go], dtype=torch.float, device=next(iter(outputs.values())).device)
380
+ if is_dist_avail_and_initialized():
381
+ torch.distributed.all_reduce(num_boxes_go)
382
+ num_boxes_go = torch.clamp(num_boxes_go / get_world_size(), min=1).item()
383
+ else:
384
+ # assert 'aux_outputs' in outputs, ''
385
+ indices_go = indices
386
+
387
+ num_boxes_go = sum(len(x[0]) for x in indices_go)
388
+ num_boxes_go = torch.as_tensor([num_boxes_go], dtype=torch.float, device=next(iter(outputs.values())).device)
389
+ if is_dist_avail_and_initialized():
390
+ torch.distributed.all_reduce(num_boxes_go)
391
+ num_boxes_go = torch.clamp(num_boxes_go / get_world_size(), min=1).item()
392
+
393
+ # Compute the average number of target boxes across all nodes, for normalization purposes
394
+ num_boxes = sum(len(t["labels"]) for t in targets)
395
+ num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=device)
396
+ if is_dist_avail_and_initialized():
397
+ torch.distributed.all_reduce(num_boxes)
398
+ num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
399
+
400
+ # Compute all the requested losses
401
+ losses = {}
402
+
403
+ for loss in self.losses:
404
+ indices_in = indices_go if loss in ['lines', 'local'] else indices
405
+ num_boxes_in = num_boxes_go if loss in ['lines', 'local'] else num_boxes
406
+ l_dict = self.get_loss(loss, outputs, targets, indices_in, num_boxes_in)
407
+ l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
408
+ losses.update(l_dict)
409
+
410
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
411
+ if 'aux_outputs' in outputs:
412
+ for idx, aux_outputs in enumerate(outputs['aux_outputs']):
413
+ aux_outputs['up'], aux_outputs['reg_scale'] = outputs['up'], outputs['reg_scale']
414
+ # indices = self.matcher(aux_outputs, targets)
415
+ for loss in self.losses:
416
+ indices_in = indices_go if loss in ['lines', 'local'] else cached_indices[idx]
417
+ num_boxes_in = num_boxes_go if loss in ['lines', 'local'] else num_boxes
418
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices_in, num_boxes_in)
419
+ l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
420
+ l_dict = {k + f'_{idx}': v for k, v in l_dict.items()}
421
+ losses.update(l_dict)
422
+
423
+ # interm_outputs loss
424
+ if 'aux_interm_outputs' in outputs:
425
+ interm_outputs = outputs['aux_interm_outputs']
426
+ # indices = self.matcher(interm_outputs, targets)
427
+ for loss in self.losses:
428
+ indices_in = indices_go if loss in ['lines', 'local'] else cached_indices_enc[0]
429
+ num_boxes_in = num_boxes_go if loss in ['lines', 'local'] else num_boxes
430
+ l_dict = self.get_loss(loss, interm_outputs, targets, indices_in, num_boxes_in)
431
+ l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
432
+ l_dict = {k + f'_interm': v for k, v in l_dict.items()}
433
+ losses.update(l_dict)
434
+
435
+ # pre output loss
436
+ if 'aux_pre_outputs' in outputs:
437
+ pre_outputs = outputs['aux_pre_outputs']
438
+ # indices = self.matcher(pre_outputs, targets)
439
+ for loss in self.losses:
440
+ indices_in = indices_go if loss in ['lines', 'local'] else cached_indices[-1]
441
+ num_boxes_in = num_boxes_go if loss in ['lines', 'local'] else num_boxes
442
+ l_dict = self.get_loss(loss, pre_outputs, targets, indices_in, num_boxes_in)
443
+ l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
444
+ l_dict = {k + f'_pre': v for k, v in l_dict.items()}
445
+ losses.update(l_dict)
446
+
447
+
448
+ # prepare for dn loss
449
+ dn_meta = outputs['dn_meta']
450
+
451
+ if self.training and dn_meta and 'aux_denoise' in outputs:
452
+ single_pad, scalar = self.prep_for_dn(dn_meta)
453
+ dn_pos_idx = []
454
+ dn_neg_idx = []
455
+ for i in range(len(targets)):
456
+ if len(targets[i]['labels']) > 0:
457
+ t = torch.arange(len(targets[i]['labels'])).long().cuda()
458
+ t = t.unsqueeze(0).repeat(scalar, 1)
459
+ tgt_idx = t.flatten()
460
+ output_idx = (torch.tensor(range(scalar)) * single_pad).long().cuda().unsqueeze(1) + t
461
+ output_idx = output_idx.flatten()
462
+ else:
463
+ output_idx = tgt_idx = torch.tensor([]).long().cuda()
464
+
465
+ dn_pos_idx.append((output_idx, tgt_idx))
466
+ dn_neg_idx.append((output_idx + single_pad // 2, tgt_idx))
467
+
468
+ dn_outputs = outputs['aux_denoise']
469
+
470
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
471
+ if 'aux_outputs' in dn_outputs:
472
+ for idx, aux_outputs in enumerate(dn_outputs['aux_outputs']):
473
+ aux_outputs['is_dn'] = True
474
+ aux_outputs['reg_scale'] = outputs['reg_scale']
475
+ aux_outputs['up'] = outputs['up']
476
+ # indices = self.matcher(aux_outputs, targets)
477
+ for loss in self.losses:
478
+ l_dict = self.get_loss(loss, aux_outputs, targets, dn_pos_idx, num_boxes*scalar)
479
+ l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
480
+ l_dict = {k + f'_dn_{idx}': v for k, v in l_dict.items()}
481
+ losses.update(l_dict)
482
+
483
+ if 'aux_pre_outputs' in dn_outputs:
484
+ aux_outputs_known = dn_outputs['aux_pre_outputs']
485
+ l_dict={}
486
+ for loss in self.losses:
487
+ l_dict.update(self.get_loss(loss, aux_outputs_known, targets, dn_pos_idx, num_boxes*scalar))
488
+ l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
489
+ l_dict = {k + f'_pre_dn': v for k, v in l_dict.items()}
490
+ losses.update(l_dict)
491
+
492
+ if 'aux_lmap' in outputs:
493
+ l_dict = self.get_loss('lmap', outputs, targets, indices, num_boxes)
494
+ l_dict = {k: v for k, v in l_dict.items()}
495
+ l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
496
+ losses.update(l_dict)
497
+
498
+ losses = {k: v for k, v in sorted(losses.items(), key=lambda item: item[0])}
499
+
500
+ return losses
501
+
502
+ @MODULE_BUILD_FUNCS.registe_with_name(module_name='LINEACRITERION')
503
+ def build_criterion(args):
504
+ num_classes = args.num_classes
505
+
506
+ matcher = build_matcher(args)
507
+
508
+ if args.criterion_type == 'default':
509
+ criterion = LINEACriterion(num_classes, matcher=matcher, weight_dict=args.weight_dict,
510
+ focal_alpha=args.focal_alpha, losses=args.losses)
511
+ elif args.criterion_type == 'dfine':
512
+ criterion = DFINESetCriterion(num_classes, matcher=matcher, weight_dict=args.weight_dict,
513
+ focal_alpha=args.focal_alpha, reg_max=args.reg_max, losses=args.losses)
514
+ else:
515
+ raise Exception(f"Criterion type: {args.criterion_type}.We only support two classes: 'default' and 'dfine'. ")
516
+
517
+ return criterion
linea/models/linea/decoder.py ADDED
@@ -0,0 +1,551 @@
1
+ """
2
+ Modified from D-FINE (https://github.com/Peterande/D-FINE)
3
+ Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
4
+ ---------------------------------------------------------------------------------
5
+ Modified from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
6
+ Copyright (c) 2023 lyuwenyu. All Rights Reserved.
7
+ """
8
+
9
+ import copy
10
+ import math
11
+ from typing import Optional
12
+ from collections import OrderedDict
13
+
14
+ import torch
15
+ from torch import nn, Tensor
16
+ import torch.nn.functional as F
17
+ import torch.nn.init as init
18
+
19
+ from .utils import gen_encoder_output_proposals, MLP, _get_activation_fn, gen_sineembed_for_position
20
+ from .attention_mechanism import MSDeformAttn
21
+ from .attention_mechanism import MSDeformLineAttn
22
+
23
+ from .dn_components import prepare_for_cdn, dn_post_process
24
+ from .linea_utils import weighting_function, distance2bbox, inverse_sigmoid
25
+
26
+
27
+ def _get_clones(module, N, layer_share=False):
28
+ if layer_share:
29
+ return nn.ModuleList([module for i in range(N)])
30
+ else:
31
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
32
+
33
+
34
+ class DeformableTransformerDecoderLayer(nn.Module):
35
+ def __init__(self, d_model=256, d_ffn=1024,
36
+ dropout=0.1, activation="relu",
37
+ n_levels=4, n_heads=8, n_points=4,
38
+ ):
39
+ super().__init__()
40
+ # cross attention
41
+ self.cross_attn = MSDeformLineAttn(d_model, n_levels, n_heads, n_points)
42
+ self.dropout1 = nn.Dropout(dropout)
43
+ self.norm1 = nn.LayerNorm(d_model)
44
+
45
+ # self attention
46
+ self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
47
+ self.dropout2 = nn.Dropout(dropout)
48
+ self.norm2 = nn.LayerNorm(d_model)
49
+
50
+ # ffn
51
+ self.linear1 = nn.Linear(d_model, d_ffn)
52
+ self.activation = _get_activation_fn(activation, d_model=d_ffn, batch_dim=1)
53
+ self.dropout3 = nn.Dropout(dropout)
54
+ self.linear2 = nn.Linear(d_ffn, d_model)
55
+ self.dropout4 = nn.Dropout(dropout)
56
+ self.norm3 = nn.LayerNorm(d_model)
57
+
58
+ def rm_self_attn_modules(self):
59
+ self.self_attn = None
60
+ self.dropout2 = None
61
+ self.norm2 = None
62
+
63
+ @staticmethod
64
+ def with_pos_embed(tensor, pos):
65
+ return tensor if pos is None else tensor + pos
66
+
67
+ def forward(self,
68
+ # for tgt
69
+ tgt: Optional[Tensor], # nq, bs, d_model
70
+ tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
71
+ tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos)
72
+ tgt_key_padding_mask: Optional[Tensor] = None,
73
+ tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4
74
+
75
+ # for memory
76
+ memory: Optional[Tensor] = None, # hw, bs, d_model
77
+ memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
78
+ memory_pos: Optional[Tensor] = None, # pos for memory
79
+
80
+ # sa
81
+ self_attn_mask: Optional[Tensor] = None, # mask used for self-attention
82
+ cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention
83
+ ):
84
+ # self attention
85
+ q = k = self.with_pos_embed(tgt, tgt_query_pos)
86
+ tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
87
+ tgt = tgt + self.dropout2(tgt2)
88
+ tgt = self.norm2(tgt)
89
+
90
+ # cross attention
91
+ tgt2 = self.cross_attn(self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
92
+ tgt_reference_points.transpose(0, 1).contiguous(),
93
+ memory, #.transpose(0, 1),
94
+ memory_spatial_shapes,
95
+ ).transpose(0, 1)
96
+ tgt = tgt + self.dropout1(tgt2)
97
+ tgt = self.norm1(tgt)
98
+
99
+ # feed forward network
100
+ tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
101
+ tgt = tgt + self.dropout4(tgt2)
102
+ tgt = self.norm3(tgt)
103
+
104
+ return tgt
105
+
106
+
107
+ class Integral(nn.Module):
108
+ """
109
+ A static layer that calculates integral results from a distribution.
110
+
111
+ This layer computes the target location using the formula: `sum{Pr(n) * W(n)}`,
112
+ where Pr(n) is the softmax probability vector representing the discrete
113
+ distribution, and W(n) is the non-uniform Weighting Function.
114
+
115
+ Args:
116
+ reg_max (int): Max number of the discrete bins. Default is 32.
117
+ It can be adjusted based on the dataset or task requirements.
118
+ """
119
+
120
+ def __init__(self, reg_max=32):
121
+ super(Integral, self).__init__()
122
+ self.reg_max = reg_max
123
+
124
+ def forward(self, x, project):
125
+ shape = x.shape
126
+ x = F.softmax(x.reshape(-1, self.reg_max + 1), dim=1)
127
+ x = F.linear(x, project.to(x.device)).reshape(-1, 4)
128
+ return x.reshape(list(shape[:-1]) + [-1])
129
+
130
+
131
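The Integral layer reads each 4-edge distribution out as an expectation: softmax over the reg_max + 1 bins, then a dot product with the projection vector W(n). A minimal sketch with a uniform projection (the real projection comes from weighting_function and is non-uniform; the values below are assumptions for illustration):

    import torch
    import torch.nn.functional as F

    reg_max = 4
    project = torch.arange(reg_max + 1, dtype=torch.float32)      # assumed W(n) = n
    logits = torch.zeros(1, 10, 4 * (reg_max + 1))                # (batch, queries, 4 * (reg_max + 1))

    prob = F.softmax(logits.reshape(-1, reg_max + 1), dim=1)      # Pr(n) per edge
    dist = F.linear(prob, project.unsqueeze(0)).reshape(1, 10, 4) # sum_n Pr(n) * W(n)
    print(dist.shape, dist[0, 0])  # torch.Size([1, 10, 4]); uniform logits give 2.0 per edge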
+ class LQE(nn.Module):
132
+ def __init__(self, k, hidden_dim, num_layers, reg_max):
133
+ super(LQE, self).__init__()
134
+ self.k = k
135
+ self.reg_max = reg_max
136
+ self.reg_conf = MLP(4 * (k + 1), hidden_dim, 1, num_layers)
137
+ init.constant_(self.reg_conf.layers[-1].bias, 0)
138
+ init.constant_(self.reg_conf.layers[-1].weight, 0)
139
+
140
+ def forward(self, scores, pred_corners):
141
+ B, L, _ = pred_corners.size()
142
+ prob = F.softmax(pred_corners.reshape(B, L, 4, self.reg_max+1), dim=-1)
143
+ prob_topk, _ = prob.topk(self.k, dim=-1)
144
+ stat = torch.cat([prob_topk, prob_topk.mean(dim=-1, keepdim=True)], dim=-1)
145
+ quality_score = self.reg_conf(stat.reshape(B, L, -1))
146
+ return scores + quality_score
147
+
148
+
149
+ class TransformerDecoder(nn.Module):
150
+ def __init__(
151
+ self,
152
+ decoder_layer,
153
+ num_layers,
154
+ norm=None,
155
+ d_model=256,
156
+ query_dim=4,
157
+ num_feature_levels=1,
158
+ aux_loss=False,
159
+ eval_idx=5,
160
+ # from D-FINE
161
+ reg_max=32,
162
+ reg_scale=4,
163
+ ):
164
+ super().__init__()
165
+ if num_layers > 0:
166
+ self.layers = _get_clones(decoder_layer, num_layers)
167
+ else:
168
+ self.layers = []
169
+ self.num_layers = num_layers
170
+ # self.norm = norm
171
+ self.query_dim = query_dim
172
+ self.num_feature_levels = num_feature_levels
173
+
174
+ self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)
175
+
176
+ self.reg_max = reg_max
177
+ self.up = nn.Parameter(torch.tensor([0.5]), requires_grad=False)
178
+ self.reg_scale = nn.Parameter(torch.tensor([reg_scale]), requires_grad=False)
179
+ self.d_model = d_model
180
+
181
+ # prediction layers
182
+ _class_embed = nn.Linear(d_model, 2)
183
+ _enc_bbox_embed = MLP(d_model, d_model, 4, 3)
184
+ # init the two embed layers
185
+ prior_prob = 0.01
186
+ bias_value = -math.log((1 - prior_prob) / prior_prob)
187
+ _class_embed.bias.data = torch.ones(2) * bias_value
188
+ nn.init.constant_(_enc_bbox_embed.layers[-1].weight.data, 0)
189
+ nn.init.constant_(_enc_bbox_embed.layers[-1].bias.data, 0)
190
+
191
+ _bbox_embed = MLP(d_model, d_model, 4 * (self.reg_max + 1), 3)
192
+
193
+ self.bbox_embed = nn.ModuleList([copy.deepcopy(_bbox_embed) for i in range(num_layers)])
194
+ self.class_embed = nn.ModuleList([copy.deepcopy(_class_embed) for i in range(num_layers)])
195
+ self.lqe_layers = nn.ModuleList([copy.deepcopy(LQE(4, 64, 2, reg_max)) for _ in range(num_layers)])
196
+ self.integral = Integral(self.reg_max)
197
+
198
+ # two stage
199
+ self.enc_out_bbox_embed = copy.deepcopy(_enc_bbox_embed)
200
+ # self.enc_out_class_embed = copy.deepcopy(_class_embed)
201
+ self.aux_loss = aux_loss
202
+
203
+ # inference
204
+ self.eval_idx = eval_idx
205
+
206
+ def forward(self,
207
+ tgt,
208
+ memory,
209
+ tgt_mask: Optional[Tensor] = None,
210
+ memory_mask: Optional[Tensor] = None,
211
+ tgt_key_padding_mask: Optional[Tensor] = None,
212
+ memory_key_padding_mask: Optional[Tensor] = None,
213
+ pos: Optional[Tensor] = None,
214
+ refpoints_unsigmoid: Optional[Tensor] = None, # num_queries, bs, 2
215
+ # for memory
216
+ spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
217
+ ):
218
+ """
219
+ Input:
220
+ - tgt: nq, bs, d_model
221
+ - memory: hw, bs, d_model
222
+ - pos: hw, bs, d_model
223
+ - refpoints_unsigmoid: nq, bs, 2/4
224
+ """
225
+ output = tgt
226
+ output_detach = pred_corners_undetach = 0
227
+
228
+ intermediate = []
229
+ ref_points_detach = refpoints_unsigmoid.sigmoid()
230
+
231
+ dec_out_bboxes = []
232
+ dec_out_logits = []
233
+ dec_out_corners = []
234
+ dec_out_refs = []
235
+
236
+ if not hasattr(self, 'project'):
237
+ project = weighting_function(self.reg_max, self.up, self.reg_scale)
238
+ else:
239
+ project = self.project
240
+
241
+ for layer_id, layer in enumerate(self.layers):
242
+ ref_points_input = ref_points_detach[:, :, None] # nq, bs, nlevel, 4
243
+
244
+ query_sine_embed = gen_sineembed_for_position(ref_points_input[:, :, 0, :], self.d_model) # nq, bs, 256*2
245
+
246
+ query_pos = self.ref_point_head(query_sine_embed) # nq, bs, 256
247
+
248
+ output = layer(
249
+ tgt = output,
250
+ tgt_query_pos = query_pos,
251
+ tgt_query_sine_embed = query_sine_embed,
252
+ tgt_key_padding_mask = tgt_key_padding_mask,
253
+ tgt_reference_points = ref_points_input,
254
+
255
+ memory = memory,
256
+ memory_spatial_shapes = spatial_shapes,
257
+ memory_pos = pos,
258
+
259
+ self_attn_mask = tgt_mask,
260
+ cross_attn_mask = memory_mask
261
+ )
262
+
263
+ if layer_id == 0:
264
+ pre_bboxes = torch.sigmoid(self.enc_out_bbox_embed(output) + inverse_sigmoid(ref_points_detach))
265
+ pre_scores = self.class_embed[0](output)
266
+ ref_points_initial = pre_bboxes.detach()
267
+
268
+ pred_corners = self.bbox_embed[layer_id](output + output_detach) + pred_corners_undetach
269
+ inter_ref_bbox = distance2bbox(ref_points_initial, self.integral(pred_corners, project), self.reg_scale)
270
+
271
+ if self.training or layer_id == self.eval_idx:
272
+ scores = self.class_embed[layer_id](output)
273
+ scores = self.lqe_layers[layer_id](scores, pred_corners)
274
+ dec_out_logits.append(scores)
275
+ dec_out_bboxes.append(inter_ref_bbox)
276
+ dec_out_corners.append(pred_corners)
277
+ dec_out_refs.append(ref_points_initial)
278
+
279
+ pred_corners_undetach = pred_corners
280
+ if self.training:
281
+ ref_points_detach = inter_ref_bbox.detach()
282
+ output_detach = output.detach()
283
+ else:
284
+ ref_points_detach = inter_ref_bbox
285
+ output_detach = output
286
+
287
+ return torch.stack(dec_out_bboxes).permute(0, 2, 1, 3), torch.stack(dec_out_logits).permute(0, 2, 1, 3), \
288
+ pre_bboxes, pre_scores
289
+
290
+
291
+ class LINEATransformer(nn.Module):
292
+ def __init__(
293
+ self,
294
+ feat_channels=[256, 256, 256],
295
+ feat_strides=[8, 16, 32],
296
+ d_model=256,
297
+ num_classes=2,
298
+ nhead=8,
299
+ num_queries=300,
300
+ num_decoder_layers=6,
301
+ dim_feedforward=2048,
302
+ dropout=0.0,
303
+ activation="relu",
304
+ normalize_before=False,
305
+ query_dim=4,
306
+ aux_loss=False,
307
+ # for deformable encoder
308
+ num_feature_levels=1,
309
+ dec_n_points=4,
310
+ # from D-FINE
311
+ reg_max=32,
312
+ reg_scale=4,
313
+ # denoising
314
+ dn_number=100,
315
+ dn_label_noise_ratio=0.5,
316
+ dn_line_noise_scale=0.5,
317
+ # for inference
318
+ eval_spatial_size=None,
319
+ eval_idx=5
320
+ ):
321
+ super().__init__()
322
+
323
+ # init learnable queries
324
+ self.tgt_embed = nn.Embedding(num_queries, d_model)
325
+ nn.init.normal_(self.tgt_embed.weight.data)
326
+
327
+ # line segment detection parameters
328
+ self.num_classes = num_classes
329
+ self.num_queries = num_queries
330
+
331
+ # anchor selection at the output of encoder
332
+ self.enc_output = nn.Linear(d_model, d_model)
333
+ self.enc_output_norm = nn.LayerNorm(d_model)
334
+ self._reset_parameters()
335
+
336
+ # prediction layers
337
+ _class_embed = nn.Linear(d_model, num_classes)
338
+ _bbox_embed = MLP(d_model, d_model, 4, 3)
339
+
340
+ # init the two embed layers
341
+ prior_prob = 0.01
342
+ bias_value = -math.log((1 - prior_prob) / prior_prob)
343
+ _class_embed.bias.data = torch.ones(num_classes) * bias_value
344
+ nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
345
+ nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)
346
+ self.enc_out_bbox_embed = copy.deepcopy(_bbox_embed)
347
+ self.enc_out_class_embed = copy.deepcopy(_class_embed)
348
+
349
+ # decoder parameters
350
+ self.d_model = d_model
351
+ self.n_heads = nhead
352
+ decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward,
353
+ dropout, activation,
354
+ num_feature_levels, nhead, dec_n_points)
355
+ decoder_norm = nn.LayerNorm(d_model)
356
+ self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
357
+ d_model=d_model, query_dim=query_dim,
358
+ num_feature_levels=num_feature_levels,
359
+ eval_idx=eval_idx, aux_loss=aux_loss,
360
+ reg_max=reg_max, reg_scale=reg_scale)
361
+
362
+ # for inference mode
363
+ self.eval_spatial_size = eval_spatial_size
364
+ if eval_spatial_size is not None:
365
+ spatial_shapes = [[int(self.eval_spatial_size[0] / s), int(self.eval_spatial_size[1] / s)]
366
+ for s in feat_strides
367
+ ]
368
+ output_proposals, output_proposals_valid = self.generate_anchors(spatial_shapes)
369
+ self.register_buffer('output_proposals', output_proposals)
370
+ self.register_buffer('output_proposals_mask', ~output_proposals_valid)
371
+
372
+ # denoising parameters
373
+ self.dn_number = dn_number
374
+ self.dn_label_noise_ratio = dn_label_noise_ratio
375
+ self.dn_line_noise_scale = dn_line_noise_scale
376
+ self.label_enc = nn.Embedding(90 + 1, d_model)
377
+
378
+ def _reset_parameters(self):
379
+ for p in self.parameters():
380
+ if p.dim() > 1:
381
+ nn.init.xavier_uniform_(p)
382
+ for m in self.modules():
383
+ if isinstance(m, MSDeformAttn): # or isinstance(m, MSDeformLineAttn):
384
+ m._reset_parameters()
385
+
386
+ def generate_anchors(self, spatial_shapes):
387
+ proposals = []
388
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
389
+
390
+ grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32),
391
+ torch.linspace(0, W_ - 1, W_, dtype=torch.float32), indexing='ij')
392
+ grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2
393
+
394
+ scale = torch.tensor([W_, H_], dtype=torch.float32,).view(1, 1, 1, 2)
395
+ grid = (grid.unsqueeze(0) + 0.5) / scale
396
+
397
+ proposal = torch.cat((grid, grid), -1).view(1, -1, 4)
398
+ proposals.append(proposal)
399
+ output_proposals = torch.cat(proposals, 1)
400
+ output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
401
+ output_proposals = torch.log(output_proposals / (1 - output_proposals))
402
+ output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))
403
+
404
+ return output_proposals, output_proposals_valid
405
+
406
+
407
+ def forward(self, feats, targets):
408
+ # flatten feature maps
409
+ memory = []
410
+ spatial_shapes = []
411
+ split_sizes = []
412
+ for feat in feats:
413
+ bs, c, h, w = feat.shape
414
+ memory.append(feat.flatten(2).permute(0, 2, 1))
415
+ spatial_shape = (h, w)
416
+ spatial_shapes.append(spatial_shape)
417
+ split_sizes.append(h*w)
418
+
419
+ # spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=feats[0].device)
420
+ memory = torch.cat(memory, 1) # bs, \sum{hxw}, c
421
+
422
+ # two-stage
423
+ if self.training:
424
+ output_memory, output_proposals = gen_encoder_output_proposals(memory, spatial_shapes)
425
+ else:
426
+ output_proposals = self.output_proposals.repeat(bs, 1, 1)
427
+ output_memory = memory.masked_fill(self.output_proposals_mask, float(0))
428
+
429
+ output_memory = self.enc_output_norm(self.enc_output(output_memory))
430
+
431
+ enc_outputs_class_unselected = self.enc_out_class_embed(output_memory)
432
+ enc_outputs_coord_unselected = self.enc_out_bbox_embed(output_memory) + output_proposals # (bs, \sum{hw}, 4) unsigmoid
433
+ topk = self.num_queries
434
+ topk_proposals = torch.topk(enc_outputs_class_unselected.max(-1)[0], topk, dim=1)[1] # bs, nq
435
+
436
+ # gather boxes
437
+ refpoint_embed_undetach = torch.gather(enc_outputs_coord_unselected, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4) ) # unsigmoid
438
+ refpoint_embed = refpoint_embed_undetach.detach()
439
+ init_box_proposal = torch.gather(output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)).sigmoid() # sigmoid
440
+
441
+ # gather tgt
442
+ tgt_undetach = torch.gather(output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model))
443
+ tgt = self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1) # nq, bs, d_model
444
+
445
+ # denoise (only for training)
446
+ if self.training and targets is not None:
447
+ dn_tgt, dn_refpoint_embed, dn_attn_mask, dn_meta =\
448
+ prepare_for_cdn(dn_args=(targets, self.dn_number, self.dn_label_noise_ratio, self.dn_line_noise_scale),
449
+ training=self.training,num_queries=self.num_queries, num_classes=self.num_classes,
450
+ hidden_dim=self.d_model, label_enc=self.label_enc)
451
+ tgt = torch.cat([dn_tgt, tgt], dim=1)
452
+ refpoint_embed = torch.cat([dn_refpoint_embed, refpoint_embed], dim=1)
453
+ else:
454
+ dn_attn_mask = dn_meta = None
455
+
456
+ # preprocess memory for MSDeformableLineAttention
457
+ value = memory.unflatten(2, (self.n_heads, -1)) # (bs, \sum{hxw}, n_heads, d_model//n_heads)
458
+ value = value.permute(0, 2, 3, 1).flatten(0, 1).split(split_sizes, dim=-1)
459
+ out_coords, out_class , pre_coords, pre_class = self.decoder(
460
+ tgt=tgt.transpose(0, 1),
461
+ memory=value, #memory.transpose(0, 1),
462
+ pos=None,
463
+ refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
464
+ spatial_shapes=spatial_shapes,
465
+ tgt_mask=dn_attn_mask)
466
+
467
+ # output
468
+ if self.training:
469
+ pre_coords = pre_coords.permute(1, 0, 2)
470
+ pre_class = pre_class.permute(1, 0, 2)
471
+ if dn_meta is not None:
472
+ dn_out_coords, out_coords = torch.split(out_coords, [dn_meta['pad_size'], self.num_queries], dim=2)
473
+ dn_out_class, out_class = torch.split(out_class, [dn_meta['pad_size'], self.num_queries], dim=2)
474
+
475
+ dn_pre_coords, pre_coords = torch.split(pre_coords, [dn_meta['pad_size'], self.num_queries], dim=1)
476
+ dn_pre_class, pre_class = torch.split(pre_class, [dn_meta['pad_size'], self.num_queries], dim=1)
477
+
478
+ out = {'pred_logits': out_class[-1], 'pred_lines': out_coords[-1]}
479
+
480
+ out['aux_pre_outputs'] = {'pred_logits': pre_class, 'pred_lines': pre_coords}
481
+
482
+ if self.decoder.aux_loss:
483
+ out['aux_outputs'] = self._set_aux_loss(out_class[:-1], out_coords[:-1])
484
+
485
+ # for encoder output
486
+ out_coords_enc = refpoint_embed_undetach.sigmoid()
487
+ out_class_enc = self.enc_out_class_embed(tgt_undetach)
488
+ out['aux_interm_outputs'] = {'pred_logits': out_class_enc, 'pred_lines': out_coords_enc}
489
+
490
+ if dn_meta is not None:
491
+ dn_out = {}
492
+ dn_out['aux_outputs'] = self._set_aux_loss(dn_out_class, dn_out_coords)
493
+ dn_out['aux_pre_outputs'] = {'pred_logits': dn_pre_class, 'pred_lines': dn_pre_coords}
494
+ out['aux_denoise'] = dn_out
495
+ else:
496
+ out = {'pred_logits': out_class[0], 'pred_lines': out_coords[0]}
497
+
498
+ out['dn_meta'] = dn_meta
499
+
500
+ return out
501
+
502
+ @torch.jit.unused
503
+ def _set_aux_loss(self, outputs_class, outputs_coord):
504
+ # this is a workaround to make torchscript happy, as torchscript
505
+ # doesn't support dictionary with non-homogeneous values, such
506
+ # as a dict having both a Tensor and a list.
507
+ return [{'pred_logits': a, 'pred_lines': b}
508
+ for a, b in zip(outputs_class, outputs_coord)]
509
+
510
+ @torch.jit.unused
511
+ def _set_aux_loss2(self, outputs_class, outputs_coord, outputs_corners, outputs_ref,
512
+ teacher_corners=None, teacher_class=None):
513
+ # this is a workaround to make torchscript happy, as torchscript
514
+ # doesn't support dictionary with non-homogeneous values, such
515
+ # as a dict having both a Tensor and a list.
516
+ return [{'pred_logits': a, 'pred_lines': b, 'pred_corners': c, 'ref_points': d,
517
+ 'teacher_corners': teacher_corners, 'teacher_logits': teacher_class}
518
+ for a, b, c, d in zip(outputs_class, outputs_coord, outputs_corners, outputs_ref)]
519
+
520
+
521
+ def build_decoder(args):
522
+ return LINEATransformer(
523
+ feat_channels = args.feat_channels_decoder,
524
+ feat_strides=args.feat_strides,
525
+ num_classes=args.num_classes,
526
+ d_model=args.hidden_dim,
527
+ nhead=args.nheads,
528
+ num_queries=args.num_queries,
529
+ num_decoder_layers=args.dec_layers,
530
+ dim_feedforward=args.dim_feedforward,
531
+ dropout=args.dropout,
532
+ activation=args.transformer_activation,
533
+ normalize_before=args.pre_norm,
534
+ query_dim=args.query_dim,
535
+ aux_loss=True,
536
+ # for deformable encoder
537
+ num_feature_levels=args.num_feature_levels,
538
+ dec_n_points=args.dec_n_points,
539
+ # for D-FINE layers
540
+ reg_max=args.reg_max,
541
+ reg_scale=args.reg_scale,
542
+ # for inference
543
+ eval_spatial_size=args.eval_spatial_size,
544
+ eval_idx=args.eval_idx,
545
+ # for denoising
546
+ dn_number=args.dn_number,
547
+ dn_label_noise_ratio=args.dn_label_noise_ratio,
548
+ dn_line_noise_scale=args.dn_line_noise_scale,
549
+ )
550
+
551
+
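The two-stage query selection in LINEATransformer.forward scores every encoder token, keeps the top num_queries of them, and gathers their coordinates and features as the initial decoder proposals. A stripped-down sketch of that gather step (all shapes and tensors below are illustrative assumptions):

    import torch

    bs, num_tokens, d_model, num_queries = 2, 100, 256, 10
    enc_logits = torch.randn(bs, num_tokens, 2)     # per-token class scores
    enc_coords = torch.randn(bs, num_tokens, 4)     # per-token line proposals (unsigmoided)
    memory = torch.randn(bs, num_tokens, d_model)   # encoder features

    topk_idx = torch.topk(enc_logits.max(-1)[0], num_queries, dim=1)[1]             # (bs, nq)
    ref_lines = torch.gather(enc_coords, 1, topk_idx.unsqueeze(-1).repeat(1, 1, 4))
    query_feats = torch.gather(memory, 1, topk_idx.unsqueeze(-1).repeat(1, 1, d_model))
    print(ref_lines.shape, query_feats.shape)  # (2, 10, 4) and (2, 10, 256)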
linea/models/linea/dn_components.py ADDED
@@ -0,0 +1,178 @@
1
+ # ------------------------------------------------------------------------
2
+ # DINO
3
+ # Copyright (c) 2022 IDEA. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # DN-DETR
7
+ # Copyright (c) 2022 IDEA. All Rights Reserved.
8
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
9
+
10
+
11
+ import torch
12
+ from .linea_utils import inverse_sigmoid
13
+ import torch.nn.functional as F
14
+
15
+ def prepare_for_cdn(dn_args, training, num_queries, num_classes, hidden_dim, label_enc):
16
+ """
17
+ A major difference of DINO from DN-DETR is that the author process pattern embedding pattern embedding in its detector
18
+ forward function and use learnable tgt embedding, so we change this function a little bit.
19
+ :param dn_args: targets, dn_number, label_noise_ratio, box_noise_scale
20
+ :param training: if it is training or inference
21
+ :param num_queries: number of queires
22
+ :param num_classes: number of classes
23
+ :param hidden_dim: transformer hidden dim
24
+ :param label_enc: encode labels in dn
25
+ :return:
26
+ """
27
+ if training:
28
+ targets, dn_number, label_noise_ratio, box_noise_scale = dn_args
29
+ # positive and negative dn queries
30
+ dn_number = dn_number * 2
31
+ known = [(torch.ones_like(t['labels'])).cuda() for t in targets]
32
+ batch_size = len(known)
33
+ known_num = [sum(k) for k in known]
34
+
35
+ if int(max(known_num)) == 0:
36
+ dn_number = 1
37
+ else:
38
+ if dn_number >= 100:
39
+ dn_number = dn_number // (int(max(known_num) * 2))
40
+ elif dn_number < 1:
41
+ dn_number = 1
42
+ if dn_number == 0:
43
+ dn_number = 1
44
+
45
+ unmask_bbox = unmask_label = torch.cat(known)
46
+ labels = torch.cat([t['labels'] for t in targets])
47
+ lines = torch.cat([t['lines'] for t in targets])
48
+ batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)])
49
+
50
+ known_indice = torch.nonzero(unmask_label + unmask_bbox)
51
+ known_indice = known_indice.view(-1)
52
+
53
+ known_indice = known_indice.repeat(2 * dn_number, 1).view(-1)
54
+ known_labels = labels.repeat(2 * dn_number, 1).view(-1)
55
+ known_bid = batch_idx.repeat(2 * dn_number, 1).view(-1)
56
+ known_lines = lines.repeat(2 * dn_number, 1)
57
+
58
+ known_labels_expaned = known_labels.clone()
59
+ known_lines_expand = known_lines.clone()
60
+
61
+ if label_noise_ratio > 0:
62
+ p = torch.rand_like(known_labels_expaned.float())
63
+ chosen_indice = torch.nonzero(p < (label_noise_ratio * 0.5)).view(-1) # half of bbox prob
64
+ new_label = torch.randint_like(chosen_indice, 0, num_classes) # randomly put a new one here
65
+ known_labels_expaned.scatter_(0, chosen_indice, new_label)
66
+
67
+ single_pad = int(max(known_num))
68
+
69
+ pad_size = int(single_pad * 2 * dn_number)
70
+ positive_idx = torch.tensor(range(len(lines))).long().cuda().unsqueeze(0).repeat(dn_number, 1)
71
+ positive_idx += (torch.tensor(range(dn_number)) * len(lines) * 2).long().cuda().unsqueeze(1)
72
+ positive_idx = positive_idx.flatten()
73
+ negative_idx = positive_idx + len(lines)
74
+
75
+
76
+ known_lines_ = known_lines.clone()
77
+ known_lines_[:, :2] = (known_lines[:, :2] - known_lines[:, 2:]) / 2
78
+ known_lines_[:, 2:] = (known_lines[:, :2] + known_lines[:, 2:]) / 2
79
+
80
+ centers = torch.zeros_like(known_lines)
81
+ centers[:, :2] = (known_lines_[:, :2] + known_lines_[:, 2:]) / 2
82
+ centers[:, 2:] = (known_lines_[:, :2] + known_lines_[:, 2:]) / 2
83
+
84
+ # Noisy length
85
+ diff = torch.zeros_like(known_lines)
86
+ diff[:, :2] = (known_lines[:, 2:] - known_lines[:, :2]) / 2
87
+ diff[:, 2:] = (known_lines[:, 2:] - known_lines[:, :2]) / 2
88
+
89
+ rand_sign = torch.randint(low=0, high=2, size=(known_lines.shape[0], 2), dtype=torch.float32, device=known_lines.device) * 2.0 - 1.0
90
+ rand_part = torch.rand(size=(known_lines.shape[0], 2), device=known_lines.device)
91
+ rand_part[negative_idx] += 1.2
92
+ rand_part *= rand_sign
93
+
94
+ known_lines_ = centers + torch.mul(rand_part.repeat_interleave(2, 1),
95
+ diff).cuda() * box_noise_scale
96
+
97
+ known_lines_expand = known_lines_.clamp(min=0.0, max=1.0)
98
+
99
+ # order: top point > bottom point
100
+ # if same y coordinate, right point > left point
101
+
102
+ idx = torch.logical_or(known_lines_expand[..., 0] > known_lines_expand[..., 2],
103
+ torch.logical_or(
104
+ known_lines_expand[..., 0] == known_lines_expand[..., 2],
105
+ known_lines_expand[..., 1] < known_lines_expand[..., 3]
106
+ )
107
+ )
108
+
109
+ known_lines_expand[idx] = known_lines_expand[idx][:, [2, 3, 0, 1]]
110
+
111
+ m = known_labels_expaned.long().to('cuda')
112
+ input_label_embed = label_enc(m)
113
+ input_lines_embed = inverse_sigmoid(known_lines_expand)
114
+
115
+ padding_label = torch.zeros(pad_size, hidden_dim).cuda()
116
+ padding_lines = torch.zeros(pad_size, 4).cuda()
117
+
118
+ input_query_label = padding_label.repeat(batch_size, 1, 1)
119
+ input_query_lines = padding_lines.repeat(batch_size, 1, 1)
120
+
121
+ map_known_indice = torch.tensor([]).to('cuda')
122
+ if len(known_num):
123
+ map_known_indice = torch.cat([torch.tensor(range(num)) for num in known_num]) # [1,2, 1,2,3]
124
+ map_known_indice = torch.cat([map_known_indice + single_pad * i for i in range(2 * dn_number)]).long()
125
+
126
+ if len(known_bid):
127
+ input_query_label[(known_bid.long(), map_known_indice)] = input_label_embed
128
+ input_query_lines[(known_bid.long(), map_known_indice)] = input_lines_embed
129
+
130
+ tgt_size = pad_size + num_queries
131
+ attn_mask = torch.ones(tgt_size, tgt_size).to('cuda') < 0
132
+ # match query cannot see the reconstruct
133
+ attn_mask[pad_size:, :pad_size] = True
134
+ # reconstruct cannot see each other
135
+ for i in range(dn_number):
136
+ if i == 0:
137
+ attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), single_pad * 2 * (i + 1):pad_size] = True
138
+ if i == dn_number - 1:
139
+ attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), :single_pad * i * 2] = True
140
+ else:
141
+ attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), single_pad * 2 * (i + 1):pad_size] = True
142
+ attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), :single_pad * 2 * i] = True
143
+
144
+ dn_meta = {
145
+ 'pad_size': pad_size,
146
+ 'num_dn_group': dn_number,
147
+ }
148
+ else:
149
+
150
+ input_query_label = None
151
+ input_query_lines = None
152
+ attn_mask = None
153
+ dn_meta = None
154
+
155
+ return input_query_label, input_query_lines, attn_mask, dn_meta
156
+
157
+
158
+ def dn_post_process(outputs_class, outputs_coord, dn_meta, aux_loss, _set_aux_loss):
159
+ """
160
+ post process of dn after output from the transformer
161
+ put the dn part in the dn_meta
162
+ """
163
+ if dn_meta and dn_meta['pad_size'] > 0:
164
+ output_known_class = outputs_class[:, :, :dn_meta['pad_size'], :]
165
+ output_known_coord = outputs_coord[:, :, :dn_meta['pad_size'], :]
166
+ outputs_class = outputs_class[:, :, dn_meta['pad_size']:, :]
167
+ outputs_coord = outputs_coord[:, :, dn_meta['pad_size']:, :]
168
+ # print(output_known_class.shape, outputs_class.shape)
169
+ # quit()
170
+ out = {'pred_logits': output_known_class[-1], 'pred_lines': output_known_coord[-1]}
171
+ if aux_loss:
172
+ out['aux_outputs'] = _set_aux_loss(output_known_class[1:], output_known_coord[1:])
173
+ out['pre_outputs'] = {'pred_logits':output_known_class[0], 'pred_lines': output_known_coord[0]}
174
+ dn_meta['output_known_lbs_lines'] = out
175
+ return outputs_class, outputs_coord
176
+
177
+
178
+
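prepare_for_cdn builds a block-structured attention mask so that matching queries never attend to denoising queries and denoising groups never attend to each other. A toy sketch of that mask for two groups (the sizes are made-up assumptions; the construction mirrors the loop above):

    import torch

    single_pad, dn_number, num_queries = 2, 2, 3
    pad_size = single_pad * 2 * dn_number            # total number of denoising queries
    tgt_size = pad_size + num_queries

    attn_mask = torch.zeros(tgt_size, tgt_size, dtype=torch.bool)
    attn_mask[pad_size:, :pad_size] = True           # matching queries cannot see dn queries
    for i in range(dn_number):                       # dn groups cannot see each other
        start, end = single_pad * 2 * i, single_pad * 2 * (i + 1)
        attn_mask[start:end, :start] = True
        attn_mask[start:end, end:pad_size] = True
    print(attn_mask.int())                           # 1 marks blocked attention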
linea/models/linea/hgnetv2.py ADDED
@@ -0,0 +1,595 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import os
5
+ import logging
6
+
7
+ # Constants for initialization
8
+ kaiming_normal_ = nn.init.kaiming_normal_
9
+ zeros_ = nn.init.zeros_
10
+ ones_ = nn.init.ones_
11
+
12
+
13
+ class FrozenBatchNorm2d(nn.Module):
14
+ """copy and modified from https://github.com/facebookresearch/detr/blob/master/models/backbone.py
15
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
16
+ Copy-paste from torchvision.misc.ops with added eps before rqsrt,
17
+ without which any other models than torchvision.models.resnet[18,34,50,101]
18
+ produce nans.
19
+ """
20
+ def __init__(self, num_features, eps=1e-5):
21
+ super(FrozenBatchNorm2d, self).__init__()
22
+ n = num_features
23
+ self.register_buffer("weight", torch.ones(n))
24
+ self.register_buffer("bias", torch.zeros(n))
25
+ self.register_buffer("running_mean", torch.zeros(n))
26
+ self.register_buffer("running_var", torch.ones(n))
27
+ self.eps = eps
28
+ self.num_features = n
29
+
30
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
31
+ missing_keys, unexpected_keys, error_msgs):
32
+ num_batches_tracked_key = prefix + 'num_batches_tracked'
33
+ if num_batches_tracked_key in state_dict:
34
+ del state_dict[num_batches_tracked_key]
35
+
36
+ super(FrozenBatchNorm2d, self)._load_from_state_dict(
37
+ state_dict, prefix, local_metadata, strict,
38
+ missing_keys, unexpected_keys, error_msgs)
39
+
40
+ def forward(self, x):
41
+ # move reshapes to the beginning
42
+ # to make it fuser-friendly
43
+ w = self.weight.reshape(1, -1, 1, 1)
44
+ b = self.bias.reshape(1, -1, 1, 1)
45
+ rv = self.running_var.reshape(1, -1, 1, 1)
46
+ rm = self.running_mean.reshape(1, -1, 1, 1)
47
+ scale = w * (rv + self.eps).rsqrt()
48
+ bias = b - rm * scale
49
+ return x * scale + bias
50
+
51
+ def extra_repr(self):
52
+ return (
53
+ "{num_features}, eps={eps}".format(**self.__dict__)
54
+ )
55
+
56
+
57
+ class LearnableAffineBlock(nn.Module):
58
+ def __init__(
59
+ self,
60
+ scale_value=1.0,
61
+ bias_value=0.0
62
+ ):
63
+ super().__init__()
64
+ self.scale = nn.Parameter(torch.tensor([scale_value]), requires_grad=True)
65
+ self.bias = nn.Parameter(torch.tensor([bias_value]), requires_grad=True)
66
+
67
+ def forward(self, x):
68
+ return self.scale * x + self.bias
69
+
70
+
71
+ class ConvBNAct(nn.Module):
72
+ def __init__(
73
+ self,
74
+ in_chs,
75
+ out_chs,
76
+ kernel_size,
77
+ stride=1,
78
+ groups=1,
79
+ padding='',
80
+ use_act=True,
81
+ use_lab=False
82
+ ):
83
+ super().__init__()
84
+ self.use_act = use_act
85
+ self.use_lab = use_lab
86
+ if padding == 'same':
87
+ self.conv = nn.Sequential(
88
+ nn.ZeroPad2d([0, 1, 0, 1]),
89
+ nn.Conv2d(
90
+ in_chs,
91
+ out_chs,
92
+ kernel_size,
93
+ stride,
94
+ groups=groups,
95
+ bias=False
96
+ )
97
+ )
98
+ else:
99
+ self.conv = nn.Conv2d(
100
+ in_chs,
101
+ out_chs,
102
+ kernel_size,
103
+ stride,
104
+ padding=(kernel_size - 1) // 2,
105
+ groups=groups,
106
+ bias=False
107
+ )
108
+ self.bn = nn.BatchNorm2d(out_chs)
109
+ if self.use_act:
110
+ self.act = nn.ReLU()
111
+ else:
112
+ self.act = nn.Identity()
113
+ if self.use_act and self.use_lab:
114
+ self.lab = LearnableAffineBlock()
115
+ else:
116
+ self.lab = nn.Identity()
117
+
118
+ def forward(self, x):
119
+ x = self.conv(x)
120
+ x = self.bn(x)
121
+ x = self.act(x)
122
+ x = self.lab(x)
123
+ return x
124
+
125
+
126
+ class LightConvBNAct(nn.Module):
127
+ def __init__(
128
+ self,
129
+ in_chs,
130
+ out_chs,
131
+ kernel_size,
132
+ groups=1,
133
+ use_lab=False,
134
+ ):
135
+ super().__init__()
136
+ self.conv1 = ConvBNAct(
137
+ in_chs,
138
+ out_chs,
139
+ kernel_size=1,
140
+ use_act=False,
141
+ use_lab=use_lab,
142
+ )
143
+ self.conv2 = ConvBNAct(
144
+ out_chs,
145
+ out_chs,
146
+ kernel_size=kernel_size,
147
+ groups=out_chs,
148
+ use_act=True,
149
+ use_lab=use_lab,
150
+ )
151
+
152
+ def forward(self, x):
153
+ x = self.conv1(x)
154
+ x = self.conv2(x)
155
+ return x
156
+
157
+
158
+ class StemBlock(nn.Module):
159
+ # for HGNetv2
160
+ def __init__(self, in_chs, mid_chs, out_chs, use_lab=False):
161
+ super().__init__()
162
+ self.stem1 = ConvBNAct(
163
+ in_chs,
164
+ mid_chs,
165
+ kernel_size=3,
166
+ stride=2,
167
+ use_lab=use_lab,
168
+ )
169
+ self.stem2a = ConvBNAct(
170
+ mid_chs,
171
+ mid_chs // 2,
172
+ kernel_size=2,
173
+ stride=1,
174
+ use_lab=use_lab,
175
+ )
176
+ self.stem2b = ConvBNAct(
177
+ mid_chs // 2,
178
+ mid_chs,
179
+ kernel_size=2,
180
+ stride=1,
181
+ use_lab=use_lab,
182
+ )
183
+ self.stem3 = ConvBNAct(
184
+ mid_chs * 2,
185
+ mid_chs,
186
+ kernel_size=3,
187
+ stride=2,
188
+ use_lab=use_lab,
189
+ )
190
+ self.stem4 = ConvBNAct(
191
+ mid_chs,
192
+ out_chs,
193
+ kernel_size=1,
194
+ stride=1,
195
+ use_lab=use_lab,
196
+ )
197
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=1, ceil_mode=True)
198
+
199
+ def forward(self, x):
200
+ x = self.stem1(x)
201
+ x = F.pad(x, (0, 1, 0, 1))
202
+ x2 = self.stem2a(x)
203
+ x2 = F.pad(x2, (0, 1, 0, 1))
204
+ x2 = self.stem2b(x2)
205
+ x1 = self.pool(x)
206
+ x = torch.cat([x1, x2], dim=1)
207
+ x = self.stem3(x)
208
+ x = self.stem4(x)
209
+ return x
210
+
211
+
212
+ class EseModule(nn.Module):
213
+ def __init__(self, chs):
214
+ super().__init__()
215
+ self.conv = nn.Conv2d(
216
+ chs,
217
+ chs,
218
+ kernel_size=1,
219
+ stride=1,
220
+ padding=0,
221
+ )
222
+ self.sigmoid = nn.Sigmoid()
223
+
224
+ def forward(self, x):
225
+ identity = x
226
+ x = x.mean((2, 3), keepdim=True)
227
+ x = self.conv(x)
228
+ x = self.sigmoid(x)
229
+ return torch.mul(identity, x)
230
+
231
+
232
+ class HG_Block(nn.Module):
233
+ def __init__(
234
+ self,
235
+ in_chs,
236
+ mid_chs,
237
+ out_chs,
238
+ layer_num,
239
+ kernel_size=3,
240
+ residual=False,
241
+ light_block=False,
242
+ use_lab=False,
243
+ agg='ese',
244
+ drop_path=0.,
245
+ ):
246
+ super().__init__()
247
+ self.residual = residual
248
+
249
+ self.layers = nn.ModuleList()
250
+ for i in range(layer_num):
251
+ if light_block:
252
+ self.layers.append(
253
+ LightConvBNAct(
254
+ in_chs if i == 0 else mid_chs,
255
+ mid_chs,
256
+ kernel_size=kernel_size,
257
+ use_lab=use_lab,
258
+ )
259
+ )
260
+ else:
261
+ self.layers.append(
262
+ ConvBNAct(
263
+ in_chs if i == 0 else mid_chs,
264
+ mid_chs,
265
+ kernel_size=kernel_size,
266
+ stride=1,
267
+ use_lab=use_lab,
268
+ )
269
+ )
270
+
271
+ # feature aggregation
272
+ total_chs = in_chs + layer_num * mid_chs
273
+ if agg == 'se':
274
+ aggregation_squeeze_conv = ConvBNAct(
275
+ total_chs,
276
+ out_chs // 2,
277
+ kernel_size=1,
278
+ stride=1,
279
+ use_lab=use_lab,
280
+ )
281
+ aggregation_excitation_conv = ConvBNAct(
282
+ out_chs // 2,
283
+ out_chs,
284
+ kernel_size=1,
285
+ stride=1,
286
+ use_lab=use_lab,
287
+ )
288
+ self.aggregation = nn.Sequential(
289
+ aggregation_squeeze_conv,
290
+ aggregation_excitation_conv,
291
+ )
292
+ else:
293
+ aggregation_conv = ConvBNAct(
294
+ total_chs,
295
+ out_chs,
296
+ kernel_size=1,
297
+ stride=1,
298
+ use_lab=use_lab,
299
+ )
300
+ att = EseModule(out_chs)
301
+ self.aggregation = nn.Sequential(
302
+ aggregation_conv,
303
+ att,
304
+ )
305
+
306
+ self.drop_path = nn.Dropout(drop_path) if drop_path else nn.Identity()
307
+
308
+ def forward(self, x):
309
+ identity = x
310
+ output = [x]
311
+ for layer in self.layers:
312
+ x = layer(x)
313
+ output.append(x)
314
+ x = torch.cat(output, dim=1)
315
+ x = self.aggregation(x)
316
+ if self.residual:
317
+ x = self.drop_path(x) + identity
318
+ return x
319
+
320
+
321
+ class HG_Stage(nn.Module):
322
+ def __init__(
323
+ self,
324
+ in_chs,
325
+ mid_chs,
326
+ out_chs,
327
+ block_num,
328
+ layer_num,
329
+ downsample=True,
330
+ light_block=False,
331
+ kernel_size=3,
332
+ use_lab=False,
333
+ agg='se',
334
+ drop_path=0.,
335
+ ):
336
+ super().__init__()
337
+ self.downsample = downsample
338
+ if downsample:
339
+ self.downsample = ConvBNAct(
340
+ in_chs,
341
+ in_chs,
342
+ kernel_size=3,
343
+ stride=2,
344
+ groups=in_chs,
345
+ use_act=False,
346
+ use_lab=use_lab,
347
+ )
348
+ else:
349
+ self.downsample = nn.Identity()
350
+
351
+ blocks_list = []
352
+ for i in range(block_num):
353
+ blocks_list.append(
354
+ HG_Block(
355
+ in_chs if i == 0 else out_chs,
356
+ mid_chs,
357
+ out_chs,
358
+ layer_num,
359
+ residual=False if i == 0 else True,
360
+ kernel_size=kernel_size,
361
+ light_block=light_block,
362
+ use_lab=use_lab,
363
+ agg=agg,
364
+ drop_path=drop_path[i] if isinstance(drop_path, (list, tuple)) else drop_path,
365
+ )
366
+ )
367
+ self.blocks = nn.Sequential(*blocks_list)
368
+
369
+ def forward(self, x):
370
+ x = self.downsample(x)
371
+ x = self.blocks(x)
372
+ return x
373
+
374
+
375
+ class HGNetv2(nn.Module):
376
+ """
377
+ HGNetV2
378
+ Args:
379
+ stem_channels: list. Number of channels for the stem block.
380
+ stage_type: str. The stage configuration of HGNet. such as the number of channels, stride, etc.
381
+ use_lab: boolean. Whether to use LearnableAffineBlock in network.
382
+ lr_mult_list: list. Control the learning rate of different stages.
383
+ Returns:
384
+ model: nn.Layer. Specific HGNetV2 model depends on args.
385
+ """
386
+
387
+ arch_configs = {
388
+ 'B0': {
389
+ 'stem_channels': [3, 16, 16],
390
+ 'stage_config': {
391
+ # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
392
+ "stage1": [16, 16, 64, 1, False, False, 3, 3],
393
+ "stage2": [64, 32, 256, 1, True, False, 3, 3],
394
+ "stage3": [256, 64, 512, 2, True, True, 5, 3],
395
+ "stage4": [512, 128, 1024, 1, True, True, 5, 3],
396
+ },
397
+ 'url': 'https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B0_stage1.pth'
398
+ },
399
+ 'B1': {
400
+ 'stem_channels': [3, 24, 32],
401
+ 'stage_config': {
402
+ # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
403
+ "stage1": [32, 32, 64, 1, False, False, 3, 3],
404
+ "stage2": [64, 48, 256, 1, True, False, 3, 3],
405
+ "stage3": [256, 96, 512, 2, True, True, 5, 3],
406
+ "stage4": [512, 192, 1024, 1, True, True, 5, 3],
407
+ },
408
+ 'url': 'https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B1_stage1.pth'
409
+ },
410
+ 'B2': {
411
+ 'stem_channels': [3, 24, 32],
412
+ 'stage_config': {
413
+ # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
414
+ "stage1": [32, 32, 96, 1, False, False, 3, 4],
415
+ "stage2": [96, 64, 384, 1, True, False, 3, 4],
416
+ "stage3": [384, 128, 768, 3, True, True, 5, 4],
417
+ "stage4": [768, 256, 1536, 1, True, True, 5, 4],
418
+ },
419
+ 'url': 'https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B2_stage1.pth'
420
+ },
421
+ 'B3': {
422
+ 'stem_channels': [3, 24, 32],
423
+ 'stage_config': {
424
+ # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
425
+ "stage1": [32, 32, 128, 1, False, False, 3, 5],
426
+ "stage2": [128, 64, 512, 1, True, False, 3, 5],
427
+ "stage3": [512, 128, 1024, 3, True, True, 5, 5],
428
+ "stage4": [1024, 256, 2048, 1, True, True, 5, 5],
429
+ },
430
+ 'url': 'https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B3_stage1.pth'
431
+ },
432
+ 'B4': {
433
+ 'stem_channels': [3, 32, 48],
434
+ 'stage_config': {
435
+ # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
436
+ "stage1": [48, 48, 128, 1, False, False, 3, 6],
437
+ "stage2": [128, 96, 512, 1, True, False, 3, 6],
438
+ "stage3": [512, 192, 1024, 3, True, True, 5, 6],
439
+ "stage4": [1024, 384, 2048, 1, True, True, 5, 6],
440
+ },
441
+ 'url': 'https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B4_stage1.pth'
442
+ },
443
+ 'B5': {
444
+ 'stem_channels': [3, 32, 64],
445
+ 'stage_config': {
446
+ # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
447
+ "stage1": [64, 64, 128, 1, False, False, 3, 6],
448
+ "stage2": [128, 128, 512, 2, True, False, 3, 6],
449
+ "stage3": [512, 256, 1024, 5, True, True, 5, 6],
450
+ "stage4": [1024, 512, 2048, 2, True, True, 5, 6],
451
+ },
452
+ 'url': 'https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B5_stage1.pth'
453
+ },
454
+ 'B6': {
455
+ 'stem_channels': [3, 48, 96],
456
+ 'stage_config': {
457
+ # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
458
+ "stage1": [96, 96, 192, 2, False, False, 3, 6],
459
+ "stage2": [192, 192, 512, 3, True, False, 3, 6],
460
+ "stage3": [512, 384, 1024, 6, True, True, 5, 6],
461
+ "stage4": [1024, 768, 2048, 3, True, True, 5, 6],
462
+ },
463
+ 'url': 'https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B6_stage1.pth'
464
+ },
465
+ }
466
+
467
+ def __init__(self,
468
+ name,
469
+ use_lab=False,
470
+ return_idx=[1, 2, 3],
471
+ freeze_stem_only=True,
472
+ freeze_at=0,
473
+ freeze_norm=True,
474
+ pretrained=True,
475
+ local_model_dir='weight/hgnetv2/',
476
+ for_pgi=False):
477
+ super().__init__()
478
+ self.use_lab = use_lab
479
+ self.return_idx = return_idx
480
+
481
+ stem_channels = self.arch_configs[name]['stem_channels']
482
+ stage_config = self.arch_configs[name]['stage_config']
483
+ download_url = self.arch_configs[name]['url']
484
+
485
+ self._out_strides = [4, 8, 16, 32]
486
+ self._out_channels = [stage_config[k][2] for k in stage_config]
487
+
488
+ self.num_channels = self._out_channels[4 - len(return_idx):]
489
+
490
+ # stem
491
+ self.stem = StemBlock(
492
+ in_chs=stem_channels[0],
493
+ mid_chs=stem_channels[1],
494
+ out_chs=stem_channels[2],
495
+ use_lab=use_lab)
496
+
497
+ # stages
498
+ self.stages = nn.ModuleList()
499
+ for i, k in enumerate(stage_config):
500
+ in_channels, mid_channels, out_channels, block_num, downsample, light_block, kernel_size, layer_num = stage_config[k]
501
+ self.stages.append(
502
+ HG_Stage(
503
+ in_channels,
504
+ mid_channels,
505
+ out_channels,
506
+ block_num,
507
+ layer_num,
508
+ downsample,
509
+ light_block,
510
+ kernel_size,
511
+ use_lab))
512
+
513
+ if freeze_at >= 0:
514
+ self._freeze_parameters(self.stem)
515
+ if not freeze_stem_only:
516
+ for i in range(min(freeze_at + 1, len(self.stages))):
517
+ self._freeze_parameters(self.stages[i])
518
+
519
+ if freeze_norm:
520
+ self._freeze_norm(self)
521
+
522
+ if pretrained:
523
+ RED, GREEN, RESET = "\033[91m", "\033[92m", "\033[0m"
524
+ try:
525
+ model_path = local_model_dir + 'PPHGNetV2_' + name + '_stage1.pth'
526
+ if os.path.exists(model_path):
527
+ print("Loading stage1")
528
+ state = torch.load(model_path, map_location='cpu')
529
+ print(f"Loaded stage1 {name} HGNetV2 from local file.")
530
+ else:
531
+ # If the file doesn't exist locally, download from the URL
532
+ if torch.distributed.get_rank() == 0:
533
+ print(GREEN + "If the pretrained HGNetV2 can't be downloaded automatically, please check your network connection." + RESET)
534
+ print(GREEN + "Alternatively, download the model manually from " + RESET + f"{download_url}" + GREEN + " to " + RESET + f"{local_model_dir}." + RESET)
535
+ state = torch.hub.load_state_dict_from_url(download_url, map_location='cpu', model_dir=local_model_dir)
536
+ torch.distributed.barrier()
537
+ else:
538
+ torch.distributed.barrier()
539
+ state = torch.load(local_model_dir + 'PPHGNetV2_' + name + '_stage1.pth', map_location='cpu')
540
+
541
+ print(f"Loaded stage1 {name} HGNetV2 from URL.")
542
+
543
+ self.load_state_dict(state)
544
+
545
+ except (Exception, KeyboardInterrupt) as e:
546
+ if torch.distributed.get_rank() == 0:
547
+ print(f"{str(e)}")
548
+ logging.error(RED + "CRITICAL WARNING: Failed to load pretrained HGNetV2 model" + RESET)
549
+ logging.error(GREEN + "Please check your network connection. Or download the model manually from " \
550
+ + RESET + f"{download_url}" + GREEN + " to " + RESET + f"{local_model_dir}." + RESET)
551
+ exit()
552
+
553
+
554
+ def _freeze_norm(self, m: nn.Module):
555
+ if isinstance(m, nn.BatchNorm2d):
556
+ m = FrozenBatchNorm2d(m.num_features)
557
+ else:
558
+ for name, child in m.named_children():
559
+ _child = self._freeze_norm(child)
560
+ if _child is not child:
561
+ setattr(m, name, _child)
562
+ return m
563
+
564
+ def _freeze_parameters(self, m: nn.Module):
565
+ for p in m.parameters():
566
+ p.requires_grad = False
567
+
568
+ def forward(self, x):
569
+ x = self.stem(x)
570
+ outs = []
571
+ for idx, stage in enumerate(self.stages):
572
+ x = stage(x)
573
+ if idx in self.return_idx:
574
+ outs.append(x)
575
+ return outs
576
+
577
+ def build_hgnetv2(args):
578
+ name = {
579
+ 'HGNetv2_B0': 'B0',
580
+ 'HGNetv2_B1': 'B1',
581
+ 'HGNetv2_B2': 'B2',
582
+ 'HGNetv2_B3': 'B3',
583
+ 'HGNetv2_B4': 'B4',
584
+ 'HGNetv2_B5': 'B5',
585
+ 'HGNetv2_B6': 'B6'
586
+ }
587
+ return HGNetv2(
588
+ name[args.backbone],
589
+ return_idx=args.return_interm_indices,
590
+ freeze_at=-1,
591
+ freeze_norm=args.freeze_norm,
592
+ freeze_stem_only= args.freeze_stem_only,
593
+ use_lab = args.use_lab,
594
+ pretrained = args.pretrained,
595
+ )
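A minimal smoke test for the backbone defined above, as a sketch only: the import path simply mirrors this upload's directory layout (an assumption), `pretrained=False` skips the weight download, and the expected shapes follow from the B4 stage configuration with `return_idx=[1, 2, 3]`.

import torch
from linea.models.linea.hgnetv2 import HGNetv2

# Build a B4 backbone without pretrained weights or stage freezing.
backbone = HGNetv2('B4', return_idx=[1, 2, 3], freeze_at=-1, pretrained=False)
backbone.eval()

x = torch.randn(1, 3, 640, 640)
with torch.no_grad():
    feats = backbone(x)

# Expected (per the B4 stage_config above): three maps with channels [512, 1024, 2048]
# at strides 8 / 16 / 32, i.e. 80x80, 40x40 and 20x20 for a 640x640 input.
for f in feats:
    print(tuple(f.shape))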
linea/models/linea/hybrid_encoder.py ADDED
@@ -0,0 +1,471 @@
1
+ """
2
+ Modified from D-FINE (https://github.com/Peterande/D-FINE)
3
+ Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
4
+ ---------------------------------------------------------------------------------
5
+ Modified from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
6
+ Copyright (c) 2023 lyuwenyu. All Rights Reserved.
7
+ """
8
+
9
+ import copy
10
+ from typing import Optional
11
+
12
+ import torch
13
+ from torch import nn, Tensor
14
+ import torch.nn.functional as F
15
+
16
+
17
+ def get_activation(act: str, inplace: bool=True):
18
+ """get activation
19
+ """
20
+ if act is None:
21
+ return nn.Identity()
22
+
23
+ elif isinstance(act, nn.Module):
24
+ return act
25
+
26
+ act = act.lower()
27
+
28
+ if act == 'silu' or act == 'swish':
29
+ m = nn.SiLU()
30
+
31
+ elif act == 'relu':
32
+ m = nn.ReLU()
33
+
34
+ elif act == 'leaky_relu':
35
+ m = nn.LeakyReLU()
36
+
37
40
+ elif act == 'gelu':
41
+ m = nn.GELU()
42
+
43
+ elif act == 'hardsigmoid':
44
+ m = nn.Hardsigmoid()
45
+
46
+ else:
47
+ raise RuntimeError(f'Unsupported activation: {act}')
48
+
49
+ if hasattr(m, 'inplace'):
50
+ m.inplace = inplace
51
+
52
+ return m
53
+
54
+ class ConvNormLayer_fuse(nn.Module):
55
+ def __init__(self, ch_in, ch_out, kernel_size, stride, g=1, padding=None, bias=False, act=None):
56
+ super().__init__()
57
+ padding = (kernel_size-1)//2 if padding is None else padding
58
+ self.conv = nn.Conv2d(
59
+ ch_in,
60
+ ch_out,
61
+ kernel_size,
62
+ stride,
63
+ groups=g,
64
+ padding=padding,
65
+ bias=bias)
66
+ self.norm = nn.BatchNorm2d(ch_out)
67
+ self.act = nn.Identity() if act is None else get_activation(act)
68
+ self.ch_in, self.ch_out, self.kernel_size, self.stride, self.g, self.padding, self.bias = \
69
+ ch_in, ch_out, kernel_size, stride, g, padding, bias
70
+
71
+ def forward(self, x):
72
+ if hasattr(self, 'conv_bn_fused'):
73
+ y = self.conv_bn_fused(x)
74
+ else:
75
+ y = self.norm(self.conv(x))
76
+ return self.act(y)
77
+
78
+ def convert_to_deploy(self):
79
+ if not hasattr(self, 'conv_bn_fused'):
80
+ self.conv_bn_fused = nn.Conv2d(
81
+ self.ch_in,
82
+ self.ch_out,
83
+ self.kernel_size,
84
+ self.stride,
85
+ groups=self.g,
86
+ padding=self.padding,
87
+ bias=True)
88
+
89
+ kernel, bias = self.get_equivalent_kernel_bias()
90
+ self.conv_bn_fused.weight.data = kernel
91
+ self.conv_bn_fused.bias.data = bias
92
+ self.__delattr__('conv')
93
+ self.__delattr__('norm')
94
+
95
+ def get_equivalent_kernel_bias(self):
96
+ kernel3x3, bias3x3 = self._fuse_bn_tensor()
97
+
98
+ return kernel3x3, bias3x3
99
+
100
+ def _fuse_bn_tensor(self):
101
+ kernel = self.conv.weight
102
+ running_mean = self.norm.running_mean
103
+ running_var = self.norm.running_var
104
+ gamma = self.norm.weight
105
+ beta = self.norm.bias
106
+ eps = self.norm.eps
107
+ std = (running_var + eps).sqrt()
108
+ t = (gamma / std).reshape(-1, 1, 1, 1)
109
+ return kernel * t, beta - running_mean * gamma / std
110
+
111
+
112
+ class ConvNormLayer(nn.Module):
113
+ def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None):
114
+ super().__init__()
115
+ self.conv = nn.Conv2d(
116
+ ch_in,
117
+ ch_out,
118
+ kernel_size,
119
+ stride,
120
+ padding=(kernel_size-1)//2 if padding is None else padding,
121
+ bias=bias)
122
+ self.norm = nn.BatchNorm2d(ch_out)
123
+ self.act = nn.Identity() if act is None else get_activation(act)
124
+
125
+ def forward(self, x):
126
+ return self.act(self.norm(self.conv(x)))
127
+
128
+ class SCDown(nn.Module):
129
+ def __init__(self, c1, c2, k, s):
130
+ super().__init__()
131
+ self.cv1 = ConvNormLayer_fuse(c1, c2, 1, 1)
132
+ self.cv2 = ConvNormLayer_fuse(c2, c2, k, s, c2)
133
+
134
+ def forward(self, x):
135
+ return self.cv2(self.cv1(x))
136
+
137
+ class VGGBlock(nn.Module):
138
+ def __init__(self, ch_in, ch_out, act='relu'):
139
+ super().__init__()
140
+ self.ch_in = ch_in
141
+ self.ch_out = ch_out
142
+ assert ch_out % 2 == 0
143
+ self.conv1 = ConvNormLayer(ch_in, ch_out, 3, 1, padding=1, act=None)
144
+ self.conv2 = ConvNormLayer(ch_in, ch_out, 1, 1, padding=0, act=None)
145
+ # self.conv1H = ConvNormLayer(ch_in, ch_out//2 , (3, 1), 1, padding=(1, 0), act=None)
146
+ # self.conv1W = ConvNormLayer(ch_in, ch_out//2, (1, 3), 1, padding=(0, 1), act=None)
147
+ # self.conv2H = ConvNormLayer(ch_in, ch_out//2, 1, 1, padding=0, act=None)
148
+ # self.conv2W = ConvNormLayer(ch_in, ch_out//2, 1, 1, padding=0, act=None)
149
+ # self.conv3 = ConvNormLayer(ch_out, ch_out, 1, 1, padding=0, act=None)
150
+ self.act = nn.Identity() if act is None else get_activation(act)
151
+
152
+ def forward(self, x):
153
+ if hasattr(self, 'conv'):
154
+ y = self.conv(x)
155
+ else:
156
+ y = self.conv1(x) + self.conv2(x)
157
+
158
+ return self.act(y)
159
+
160
+ def convert_to_deploy(self):
161
+ if not hasattr(self, 'conv'):
162
+ self.conv = nn.Conv2d(self.ch_in, self.ch_out, 3, 1, padding=1)
163
+
164
+ kernel, bias = self.get_equivalent_kernel_bias()
165
+ self.conv.weight.data = kernel
166
+ self.conv.bias.data = bias
167
+ # self.__delattr__('conv1')
168
+ # self.__delattr__('conv2')
169
+
170
+ def get_equivalent_kernel_bias(self):
171
+ kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
172
+ kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
173
+
174
+ return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1), bias3x3 + bias1x1
175
+
176
+ def _pad_1x1_to_3x3_tensor(self, kernel1x1):
177
+ if kernel1x1 is None:
178
+ return 0
179
+ else:
180
+ return F.pad(kernel1x1, [1, 1, 1, 1])
181
+
182
+ def _fuse_bn_tensor(self, branch: ConvNormLayer):
183
+ if branch is None:
184
+ return 0, 0
185
+ kernel = branch.conv.weight
186
+ running_mean = branch.norm.running_mean
187
+ running_var = branch.norm.running_var
188
+ gamma = branch.norm.weight
189
+ beta = branch.norm.bias
190
+ eps = branch.norm.eps
191
+ std = (running_var + eps).sqrt()
192
+ t = (gamma / std).reshape(-1, 1, 1, 1)
193
+ return kernel * t, beta - running_mean * gamma / std
194
+
195
+
196
+ class RepNCSPELAN4(nn.Module):
197
+ # csp-elan
198
+ def __init__(self, c1, c2, c3, c4, n=3,
199
+ bias=False,
200
+ act="silu"):
201
+ super().__init__()
202
+ self.c = c3//2
203
+ self.cv1 = ConvNormLayer_fuse(c1, c3, 1, 1, bias=bias, act=act)
204
+ self.cv2 = nn.Sequential(CSPLayer(c3//2, c4, n, 1, bias=bias, act=act, bottletype=VGGBlock), ConvNormLayer_fuse(c4, c4, 3, 1, bias=bias, act=act))
205
+ self.cv3 = nn.Sequential(CSPLayer(c4, c4, n, 1, bias=bias, act=act, bottletype=VGGBlock), ConvNormLayer_fuse(c4, c4, 3, 1, bias=bias, act=act))
206
+ self.cv4 = ConvNormLayer_fuse(c3+(2*c4), c2, 1, 1, bias=bias, act=act)
207
+
208
+ def forward_chunk(self, x):
209
+ y = list(self.cv1(x).chunk(2, 1))
210
+ y.extend((m(y[-1])) for m in [self.cv2, self.cv3])
211
+ return self.cv4(torch.cat(y, 1))
212
+
213
+ def forward(self, x):
214
+ y = list(self.cv1(x).split((self.c, self.c), 1))
215
+ y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
216
+ return self.cv4(torch.cat(y, 1))
217
+
218
+
219
+ class CSPLayer(nn.Module):
220
+ def __init__(self,
221
+ in_channels,
222
+ out_channels,
223
+ num_blocks=3,
224
+ expansion=1.0,
225
+ bias=None,
226
+ act="silu",
227
+ bottletype=VGGBlock):
228
+ super(CSPLayer, self).__init__()
229
+ hidden_channels = int(out_channels * expansion)
230
+ self.conv1 = ConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act)
231
+ self.conv2 = ConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act)
232
+ self.bottlenecks = nn.Sequential(*[
233
+ bottletype(hidden_channels, hidden_channels, act=act) for _ in range(num_blocks)
234
+ ])
235
+ if hidden_channels != out_channels:
236
+ self.conv3 = ConvNormLayer(hidden_channels, out_channels, 1, 1, bias=bias, act=act)
237
+ else:
238
+ self.conv3 = nn.Identity()
239
+
240
+ def forward(self, x):
241
+ x_1 = self.conv1(x)
242
+ x_1 = self.bottlenecks(x_1)
243
+ x_2 = self.conv2(x)
244
+ return self.conv3(x_1 + x_2)
245
+
246
+
247
+ # transformer
248
+ class TransformerEncoderLayer(nn.Module):
249
+ def __init__(self,
250
+ d_model,
251
+ nhead,
252
+ dim_feedforward=2048,
253
+ dropout=0.1,
254
+ activation="relu",
255
+ normalize_before=False):
256
+ super().__init__()
257
+ self.normalize_before = normalize_before
258
+
259
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout, batch_first=True)
260
+
261
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
262
+ self.dropout = nn.Dropout(dropout)
263
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
264
+
265
+ self.norm1 = nn.LayerNorm(d_model)
266
+ self.norm2 = nn.LayerNorm(d_model)
267
+ self.dropout1 = nn.Dropout(dropout)
268
+ self.dropout2 = nn.Dropout(dropout)
269
+ self.activation = get_activation(activation)
270
+
271
+ @staticmethod
272
+ def with_pos_embed(tensor, pos_embed):
273
+ return tensor if pos_embed is None else tensor + pos_embed
274
+
275
+ def forward(self,
276
+ src,
277
+ src_mask=None,
278
+ src_key_padding_mask=None,
279
+ pos_embed=None) -> torch.Tensor:
280
+ residual = src
281
+ if self.normalize_before:
282
+ src = self.norm1(src)
283
+ q = k = self.with_pos_embed(src, pos_embed)
284
+ src, _ = self.self_attn(q, k,
285
+ value=src,
286
+ attn_mask=src_mask,
287
+ key_padding_mask=src_key_padding_mask)
288
+
289
+ src = residual + self.dropout1(src)
290
+ if not self.normalize_before:
291
+ src = self.norm1(src)
292
+
293
+ residual = src
294
+ if self.normalize_before:
295
+ src = self.norm2(src)
296
+ src = self.linear2(self.dropout(self.activation(self.linear1(src))))
297
+ src = residual + self.dropout2(src)
298
+ if not self.normalize_before:
299
+ src = self.norm2(src)
300
+ return src
301
+
302
+
303
+ class TransformerEncoder(nn.Module):
304
+ def __init__(self, encoder_layer, num_layers, norm=None):
305
+ super(TransformerEncoder, self).__init__()
306
+ self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])
307
+ self.num_layers = num_layers
308
+ self.norm = norm
309
+
310
+ def forward(self,
311
+ src,
312
+ src_mask=None,
313
+ src_key_padding_mask=None,
314
+ pos_embed=None) -> torch.Tensor:
315
+ output = src
316
+ for layer in self.layers:
317
+ output = layer(output,
318
+ src_mask=src_mask,
319
+ src_key_padding_mask=src_key_padding_mask,
320
+ pos_embed=pos_embed)
321
+
322
+ if self.norm is not None:
323
+ output = self.norm(output)
324
+
325
+ return output
326
+
327
+
328
+ class HybridEncoder(nn.Module):
329
+ def __init__(self,
330
+ n_levels=3,
331
+ hidden_dim=256,
332
+ nhead=8,
333
+ dim_feedforward = 1024,
334
+ dropout=0.0,
335
+ enc_act='gelu',
336
+ use_encoder_idx=[2],
337
+ num_encoder_layers=1,
338
+ expansion=1.0,
339
+ depth_mult=1.0,
340
+ act='silu',
341
+ eval_spatial_size=None
342
+ ):
343
+ super().__init__()
344
+ self.n_levels = n_levels
345
+ self.hidden_dim = hidden_dim
346
+ self.use_encoder_idx = use_encoder_idx
347
+ self.num_encoder_layers = num_encoder_layers
348
+ self.eval_spatial_size = eval_spatial_size
349
+
350
+ # encoder transformer
351
+ encoder_layer = TransformerEncoderLayer(
352
+ hidden_dim,
353
+ nhead=nhead,
354
+ dim_feedforward=dim_feedforward,
355
+ dropout=dropout,
356
+ activation=enc_act)
357
+
358
+ self.encoder = TransformerEncoder(copy.deepcopy(encoder_layer), num_encoder_layers)
359
+
360
+ # top-down fpn
361
+ self.lateral_convs = nn.ModuleList()
362
+ self.fpn_blocks = nn.ModuleList()
363
+ for _ in range(n_levels - 1, 0, -1):
364
+ self.lateral_convs.append(ConvNormLayer(hidden_dim, hidden_dim, 1, 1, act=act))
365
+ self.fpn_blocks.append(
366
+ RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(expansion * hidden_dim // 2), round(3 * depth_mult))
367
+ # CSPRepLayer(hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion)
368
+ )
369
+
370
+ # bottom-up pan
371
+ self.downsample_convs = nn.ModuleList()
372
+ self.pan_blocks = nn.ModuleList()
373
+ for _ in range(n_levels - 1):
374
+ self.downsample_convs.append(nn.Sequential(
375
+ SCDown(hidden_dim, hidden_dim, 3, 2),
376
+ )
377
+ )
378
+ self.pan_blocks.append(
379
+ RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(expansion * hidden_dim // 2), round(3 * depth_mult))
380
+ # CSPRepLayer(hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion)
381
+ )
382
+
383
+ # self._reset_parameters()
384
+
385
+ # def _reset_parameters(self):
386
+ # if self.eval_spatial_size:
387
+ # for idx in self.use_encoder_idx:
388
+ # stride = self.feat_strides[idx]
389
+ # pos_embed = self.build_2d_sincos_position_embedding(
390
+ # self.eval_spatial_size[1] // stride, self.eval_spatial_size[0] // stride,
391
+ # self.hidden_dim, self.pe_temperature)
392
+ # setattr(self, f'pos_embed{idx}', pos_embed)
393
+ # # self.register_buffer(f'pos_embed{idx}', pos_embed)
394
+
395
+ def forward(self,
396
+ src: Tensor,
397
+ pos: Tensor,
398
+ spatial_shapes: Tensor,
399
+ level_start_index: Tensor,
400
+ valid_ratios: Tensor,
401
+ key_padding_mask: Tensor,
402
+ ref_token_index: Optional[Tensor]=None,
403
+ ref_token_coord: Optional[Tensor]=None
404
+ ):
405
+ """
406
+ Input:
407
+ - src: [bs, sum(hi*wi), 256]
408
+ - pos: pos embed for src. [bs, sum(hi*wi), 256]
409
+ - spatial_shapes: h,w of each level [num_level, 2]
410
+ - level_start_index: [num_level] start point of level in sum(hi*wi).
411
+ - valid_ratios: [bs, num_level, 2]
412
+ - key_padding_mask: [bs, sum(hi*wi)]
413
+
414
+ - ref_token_index: bs, nq
415
+ - ref_token_coord: bs, nq, 4
416
+ Intermediate:
417
+ - reference_points: [bs, sum(hi*wi), num_level, 2]
418
+ Outputs:
419
+ - output: [bs, sum(hi*wi), 256]
420
+ """
421
+ src_list = src.split([H_ * W_ for H_, W_ in spatial_shapes], dim=1)
422
+ pos_ = pos[:, level_start_index[-1]:]
423
+ key_padding_mask_ = key_padding_mask[:, level_start_index[-1]:]
424
+
425
+ memory = self.encoder(src_list[-1], pos_embed=pos_, src_key_padding_mask=key_padding_mask_)
426
+
427
+ c = src.size(2)
428
+ proj_feats = []
429
+ for i, (H, W) in enumerate(spatial_shapes):
430
+ if i == len(spatial_shapes) - 1:
431
+ proj_feats.append(memory.reshape(-1, H, W, c).permute(0, 3, 1, 2))
432
+ continue
433
+ proj_feats.append(src_list[i].reshape(-1, H, W, c).permute(0, 3, 1, 2))
434
+
435
+ # broadcasting and fusion
436
+ inner_outs = [proj_feats[-1]]
437
+ for idx in range(self.n_levels - 1, 0, -1):
438
+ feat_high = inner_outs[0]
439
+ feat_low = proj_feats[idx - 1]
440
+ feat_high = self.lateral_convs[self.n_levels - 1 - idx](feat_high)
441
+ inner_outs[0] = feat_high
442
+ upsample_feat = F.interpolate(feat_high, scale_factor=2., mode='nearest')
443
+ inner_out = self.fpn_blocks[self.n_levels-1-idx](torch.concat([upsample_feat, feat_low], dim=1))
444
+ inner_outs.insert(0, inner_out)
445
+
446
+ outs = [inner_outs[0]]
447
+ for idx in range(self.n_levels - 1):
448
+ feat_low = outs[-1]
449
+ feat_high = inner_outs[idx + 1]
450
+ downsample_feat = self.downsample_convs[idx](feat_low)
451
+ out = self.pan_blocks[idx](torch.concat([downsample_feat, feat_high], dim=1))
452
+ outs.append(out)
453
+
454
+ for i in range(len(outs)):
455
+ outs[i] = outs[i].flatten(2).permute(0, 2, 1)
456
+
457
+ return torch.cat(outs, dim=1), None, None
458
+
459
+ def build_hybrid_encoder(args):
460
+ return HybridEncoder(
461
+ n_levels=args.num_feature_levels,
462
+ hidden_dim=args.hidden_dim,
463
+ nhead=args.nheads,
464
+ dim_feedforward = args.dim_feedforward,
465
+ dropout=args.dropout,
466
+ enc_act='gelu',
467
+ # pe_temperature=10000,
468
+ expansion=args.expansion,
469
+ depth_mult=args.depth_mult,
470
+ act='silu',
471
+ )
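A quick numerical check of the Conv+BatchNorm fusion implemented by `ConvNormLayer_fuse.convert_to_deploy()` above. This is only a sketch, and the import path mirrors this upload's layout (an assumption); in `eval()` mode the fused single convolution should reproduce the two-step Conv -> BN -> activation output.

import torch
from linea.models.linea.hybrid_encoder import ConvNormLayer_fuse

layer = ConvNormLayer_fuse(16, 32, kernel_size=3, stride=1, act='silu').eval()
x = torch.randn(2, 16, 40, 40)

with torch.no_grad():
    y_two_step = layer(x)        # Conv2d -> BatchNorm2d -> SiLU
    layer.convert_to_deploy()    # folds the BN statistics into a single biased Conv2d
    y_fused = layer(x)

print(torch.allclose(y_two_step, y_fused, atol=1e-5))  # expected: True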
linea/models/linea/hybrid_encoder_asymmetric_conv.py ADDED
@@ -0,0 +1,549 @@
1
+ """
2
+ Modified from D-FINE (https://github.com/Peterande/D-FINE)
3
+ Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
4
+ ---------------------------------------------------------------------------------
5
+ Modified from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
6
+ Copyright (c) 2023 lyuwenyu. All Rights Reserved.
7
+ """
8
+
9
+ import copy
10
+ import math
11
+ from typing import Optional
12
+
13
+ import torch
14
+ from torch import nn, Tensor
15
+ import torch.nn.functional as F
16
+
17
+
18
+ def get_activation(act: str, inplace: bool=True):
19
+ """get activation
20
+ """
21
+ if act is None:
22
+ return nn.Identity()
23
+
24
+ elif isinstance(act, nn.Module):
25
+ return act
26
+
27
+ act = act.lower()
28
+
29
+ if act == 'silu' or act == 'swish':
30
+ m = nn.SiLU()
31
+
32
+ elif act == 'relu':
33
+ m = nn.ReLU()
34
+
35
+ elif act == 'leaky_relu':
36
+ m = nn.LeakyReLU()
37
+
38
41
+ elif act == 'gelu':
42
+ m = nn.GELU()
43
+
44
+ elif act == 'hardsigmoid':
45
+ m = nn.Hardsigmoid()
46
+
47
+ else:
48
+ raise RuntimeError(f'Unsupported activation: {act}')
49
+
50
+ if hasattr(m, 'inplace'):
51
+ m.inplace = inplace
52
+
53
+ return m
54
+
55
+ class CBLinear(nn.Module):
56
+ def __init__(self, c1, c2s, k=1, s=1, p=None, g=1): # ch_in, ch_outs, kernel, stride, padding, groups
57
+ super(CBLinear, self).__init__()
58
+ self.c2s = c2s
59
+ self.conv = nn.Conv2d(c1, sum(c2s), k, s, 0, groups=g, bias=True)
60
+
61
+ def forward(self, x):
62
+ outs = self.conv(x).split(self.c2s, dim=1)
63
+ return outs
64
+
65
+ class CBFuse(nn.Module):
66
+ def __init__(self, idx):
67
+ super(CBFuse, self).__init__()
68
+ self.idx = idx
69
+
70
+ def forward(self, xs):
71
+ target_size = xs[-1].shape[2:]
72
+ res = [F.interpolate(x[self.idx[i]], size=target_size, mode='nearest') for i, x in enumerate(xs[:-1])]
73
+ out = torch.sum(torch.stack(res + xs[-1:]), dim=0)
74
+ return out
75
+
76
+ class ConvNormLayer_fuse(nn.Module):
77
+ def __init__(self, ch_in, ch_out, kernel_size, stride, g=1, padding=None, bias=False, act=None):
78
+ super().__init__()
79
+ padding = (kernel_size-1)//2 if padding is None else padding
80
+ self.conv = nn.Conv2d(
81
+ ch_in,
82
+ ch_out,
83
+ kernel_size,
84
+ stride,
85
+ groups=g,
86
+ padding=padding,
87
+ bias=bias)
88
+ self.norm = nn.BatchNorm2d(ch_out)
89
+ self.act = nn.Identity() if act is None else get_activation(act)
90
+ self.ch_in, self.ch_out, self.kernel_size, self.stride, self.g, self.padding, self.bias = \
91
+ ch_in, ch_out, kernel_size, stride, g, padding, bias
92
+
93
+ def forward(self, x):
94
+ if hasattr(self, 'conv_bn_fused'):
95
+ y = self.conv_bn_fused(x)
96
+ else:
97
+ y = self.norm(self.conv(x))
98
+ return self.act(y)
99
+
100
+ def convert_to_deploy(self):
101
+ if not hasattr(self, 'conv_bn_fused'):
102
+ self.conv_bn_fused = nn.Conv2d(
103
+ self.ch_in,
104
+ self.ch_out,
105
+ self.kernel_size,
106
+ self.stride,
107
+ groups=self.g,
108
+ padding=self.padding,
109
+ bias=True)
110
+
111
+ kernel, bias = self.get_equivalent_kernel_bias()
112
+ self.conv_bn_fused.weight.data = kernel
113
+ self.conv_bn_fused.bias.data = bias
114
+ self.__delattr__('conv')
115
+ self.__delattr__('norm')
116
+
117
+ def get_equivalent_kernel_bias(self):
118
+ kernel3x3, bias3x3 = self._fuse_bn_tensor()
119
+
120
+ return kernel3x3, bias3x3
121
+
122
+ def _fuse_bn_tensor(self):
123
+ kernel = self.conv.weight
124
+ running_mean = self.norm.running_mean
125
+ running_var = self.norm.running_var
126
+ gamma = self.norm.weight
127
+ beta = self.norm.bias
128
+ eps = self.norm.eps
129
+ std = (running_var + eps).sqrt()
130
+ t = (gamma / std).reshape(-1, 1, 1, 1)
131
+ return kernel * t, beta - running_mean * gamma / std
132
+
133
+
134
+ class ConvNormLayer(nn.Module):
135
+ def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None):
136
+ super().__init__()
137
+ self.conv = nn.Conv2d(
138
+ ch_in,
139
+ ch_out,
140
+ kernel_size,
141
+ stride,
142
+ padding=(kernel_size-1)//2 if padding is None else padding,
143
+ bias=bias)
144
+ self.norm = nn.BatchNorm2d(ch_out)
145
+ self.act = nn.Identity() if act is None else get_activation(act)
146
+
147
+ def forward(self, x):
148
+ return self.act(self.norm(self.conv(x)))
149
+
150
+ class SCDown(nn.Module):
151
+ def __init__(self, c1, c2, k, s):
152
+ super().__init__()
153
+ self.cv1 = ConvNormLayer_fuse(c1, c2, 1, 1)
154
+ self.cv2 = ConvNormLayer_fuse(c2, c2, k, s, c2)
155
+
156
+ def forward(self, x):
157
+ return self.cv2(self.cv1(x))
158
+
159
+ class VGGBlock(nn.Module):
160
+ def __init__(self, ch_in, ch_out, act='relu'):
161
+ super().__init__()
162
+ self.ch_in = ch_in
163
+ self.ch_out = ch_out
164
+ self.convH = ConvNormLayer(ch_in, ch_out, (3, 1), 1, padding=(1, 0), act=None)
165
+ self.convW = ConvNormLayer(ch_in, ch_out, (1, 3), 1, padding=(0, 1), act=None)
166
+ self.conv1 = ConvNormLayer(ch_in, ch_out, 3, 1, padding=1, act=None)
167
+ self.conv2 = ConvNormLayer(ch_in, ch_out, 1, 1, padding=0, act=None)
168
+ self.act = nn.Identity() if act is None else get_activation(act)
169
+
170
+ def forward(self, x):
171
+ if hasattr(self, 'conv'):
172
+ y = self.conv(x)
173
+ else:
174
+ y_vertical = self.convH(x)
175
+ y_horizontal = self.convW(x)
176
+ y = self.conv1(x) + self.conv2(x) + y_horizontal + y_vertical
177
+
178
+ return self.act(y)
179
+
180
+ def convert_to_deploy(self):
181
+ if not hasattr(self, 'conv'):
182
+ self.conv = nn.Conv2d(self.ch_in, self.ch_out, 3, 1, padding=1)
183
+
184
+ kernel, bias = self.get_equivalent_kernel_bias()
185
+ self.conv.weight.data = kernel
186
+ self.conv.bias.data = bias
187
+ self.__delattr__('conv1')
188
+ self.__delattr__('conv2')
189
+ self.__delattr__('convH')
190
+ self.__delattr__('convW')
191
+
192
+ def get_equivalent_kernel_bias(self):
193
+ kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
194
+ kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
195
+ kernel3x1, bias3x1 = self._fuse_bn_tensor(self.convH)
196
+ kernel1x3, bias1x3 = self._fuse_bn_tensor(self.convW)
197
+
198
+ kernel = kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + self._pad_1x1_to_3x3_tensor(kernel3x1, 'Vertical') + self._pad_1x1_to_3x3_tensor(kernel1x3, 'Horizontal')
199
+ bias = bias3x3 + bias1x1 + bias3x1 + bias1x3
200
+ return kernel, bias
201
+
202
+ def _pad_1x1_to_3x3_tensor(self, kernel1x1, asymmetric='None'):
203
+ if kernel1x1 is None:
204
+ return 0
205
+ else:
206
+ if asymmetric == 'None':
207
+ return F.pad(kernel1x1, [1, 1, 1, 1])
208
+ elif asymmetric == 'Vertical':
209
+ return F.pad(kernel1x1, [1, 1, 0, 0])
210
+ elif asymmetric == 'Horizontal':
211
+ return F.pad(kernel1x1, [0, 0, 1, 1])
212
+
213
+ def _fuse_bn_tensor(self, branch: ConvNormLayer):
214
+ if branch is None:
215
+ return 0, 0
216
+ kernel = branch.conv.weight
217
+ running_mean = branch.norm.running_mean
218
+ running_var = branch.norm.running_var
219
+ gamma = branch.norm.weight
220
+ beta = branch.norm.bias
221
+ eps = branch.norm.eps
222
+ std = (running_var + eps).sqrt()
223
+ t = (gamma / std).reshape(-1, 1, 1, 1)
224
+ return kernel * t, beta - running_mean * gamma / std
225
+
226
+
227
+ class RepNCSPELAN4(nn.Module):
228
+ # csp-elan
229
+ def __init__(self, c1, c2, c3, c4, n=3,
230
+ bias=False,
231
+ act="silu"):
232
+ super().__init__()
233
+ self.c = c3//2
234
+ self.cv1 = ConvNormLayer_fuse(c1, c3, 1, 1, bias=bias, act=act)
235
+ self.cv2 = nn.Sequential(CSPLayer(c3//2, c4, n, 1, bias=bias, act=act, bottletype=VGGBlock), ConvNormLayer_fuse(c4, c4, 3, 1, bias=bias, act=act))
236
+ self.cv3 = nn.Sequential(CSPLayer(c4, c4, n, 1, bias=bias, act=act, bottletype=VGGBlock), ConvNormLayer_fuse(c4, c4, 3, 1, bias=bias, act=act))
237
+ self.cv4 = ConvNormLayer_fuse(c3+(2*c4), c2, 1, 1, bias=bias, act=act)
238
+
239
+ def forward_chunk(self, x):
240
+ y = list(self.cv1(x).chunk(2, 1))
241
+ y.extend((m(y[-1])) for m in [self.cv2, self.cv3])
242
+ return self.cv4(torch.cat(y, 1))
243
+
244
+ def forward(self, x):
245
+ y = list(self.cv1(x).split((self.c, self.c), 1))
246
+ y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
247
+ return self.cv4(torch.cat(y, 1))
248
+
249
+
250
+ class CSPLayer(nn.Module):
251
+ def __init__(self,
252
+ in_channels,
253
+ out_channels,
254
+ num_blocks=3,
255
+ expansion=1.0,
256
+ bias=None,
257
+ act="silu",
258
+ bottletype=VGGBlock):
259
+ super(CSPLayer, self).__init__()
260
+ hidden_channels = int(out_channels * expansion)
261
+ self.conv1 = ConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act)
262
+ self.conv2 = ConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act)
263
+ self.bottlenecks = nn.Sequential(*[
264
+ bottletype(hidden_channels, hidden_channels, act=act) for _ in range(num_blocks)
265
+ ])
266
+ if hidden_channels != out_channels:
267
+ self.conv3 = ConvNormLayer(hidden_channels, out_channels, 1, 1, bias=bias, act=act)
268
+ else:
269
+ self.conv3 = nn.Identity()
270
+
271
+ def forward(self, x):
272
+ x_1 = self.conv1(x)
273
+ x_1 = self.bottlenecks(x_1)
274
+ x_2 = self.conv2(x)
275
+ return self.conv3(x_1 + x_2)
276
+
277
+
278
+ # transformer
279
+ class TransformerEncoderLayer(nn.Module):
280
+ def __init__(self,
281
+ d_model,
282
+ nhead,
283
+ dim_feedforward=2048,
284
+ dropout=0.1,
285
+ activation="relu",
286
+ normalize_before=False):
287
+ super().__init__()
288
+ self.normalize_before = normalize_before
289
+
290
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout, batch_first=True)
291
+
292
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
293
+ self.dropout = nn.Dropout(dropout)
294
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
295
+
296
+ self.norm1 = nn.LayerNorm(d_model)
297
+ self.norm2 = nn.LayerNorm(d_model)
298
+ self.dropout1 = nn.Dropout(dropout)
299
+ self.dropout2 = nn.Dropout(dropout)
300
+ self.activation = get_activation(activation)
301
+
302
+ @staticmethod
303
+ def with_pos_embed(tensor, pos_embed):
304
+ return tensor if pos_embed is None else tensor + pos_embed
305
+
306
+ def forward(self,
307
+ src,
308
+ src_mask=None,
309
+ src_key_padding_mask=None,
310
+ pos_embed=None) -> torch.Tensor:
311
+ residual = src
312
+ if self.normalize_before:
313
+ src = self.norm1(src)
314
+ q = k = self.with_pos_embed(src, pos_embed)
315
+ src, _ = self.self_attn(q, k,
316
+ value=src,
317
+ attn_mask=src_mask,
318
+ key_padding_mask=src_key_padding_mask)
319
+
320
+ src = residual + self.dropout1(src)
321
+ if not self.normalize_before:
322
+ src = self.norm1(src)
323
+
324
+ residual = src
325
+ if self.normalize_before:
326
+ src = self.norm2(src)
327
+ src = self.linear2(self.dropout(self.activation(self.linear1(src))))
328
+ src = residual + self.dropout2(src)
329
+ if not self.normalize_before:
330
+ src = self.norm2(src)
331
+ return src
332
+
333
+
334
+ class TransformerEncoder(nn.Module):
335
+ def __init__(self, encoder_layer, num_layers, norm=None):
336
+ super(TransformerEncoder, self).__init__()
337
+ self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])
338
+ self.num_layers = num_layers
339
+ self.norm = norm
340
+
341
+ def forward(self,
342
+ src,
343
+ src_mask=None,
344
+ src_key_padding_mask=None,
345
+ pos_embed=None) -> torch.Tensor:
346
+ output = src
347
+ for layer in self.layers:
348
+ output = layer(output,
349
+ src_mask=src_mask,
350
+ src_key_padding_mask=src_key_padding_mask,
351
+ pos_embed=pos_embed)
352
+
353
+ if self.norm is not None:
354
+ output = self.norm(output)
355
+
356
+ return output
357
+
358
+
359
+ class HybridEncoderAsymConv(nn.Module):
360
+ def __init__(
361
+ self,
362
+ in_channels=[512, 1024, 2048],
363
+ feat_strides=[8, 16, 32],
364
+ n_levels=3,
365
+ hidden_dim=256,
366
+ nhead=8,
367
+ dim_feedforward = 1024,
368
+ dropout=0.0,
369
+ enc_act='gelu',
370
+ use_encoder_idx=[2],
371
+ num_encoder_layers=1,
372
+ expansion=1.0,
373
+ depth_mult=1.0,
374
+ act='silu',
375
+ eval_spatial_size=None,
376
+ # position embedding
377
+ temperatureH=20,
378
+ temperatureW=20,
379
+ ):
380
+ super().__init__()
381
+ self.in_channels = in_channels
382
+ self.feat_strides = feat_strides
383
+ self.n_levels = n_levels
384
+ self.hidden_dim = hidden_dim
385
+ self.use_encoder_idx = use_encoder_idx
386
+ self.num_encoder_layers = num_encoder_layers
387
+ self.eval_spatial_size = eval_spatial_size
388
+
389
+ self.temperatureW = temperatureW
390
+ self.temperatureH = temperatureH
391
+
392
+ # channel projection
393
+ input_proj_list = []
394
+ for in_channel in in_channels:
395
+ input_proj_list.append(
396
+ nn.Sequential(
397
+ nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False),
398
+ nn.GroupNorm(32, hidden_dim)
399
+ )
400
+ )
401
+ self.input_proj = nn.ModuleList(input_proj_list)
402
+
403
+ # encoder transformer
404
+ encoder_layer = TransformerEncoderLayer(
405
+ hidden_dim,
406
+ nhead=nhead,
407
+ dim_feedforward=dim_feedforward,
408
+ dropout=dropout,
409
+ activation=enc_act)
410
+
411
+ self.encoder = nn.ModuleList([
412
+ TransformerEncoder(copy.deepcopy(encoder_layer), num_encoder_layers) for _ in range(len(use_encoder_idx))
413
+ ])
414
+
415
+ # top-down fpn
416
+ self.lateral_convs = nn.ModuleList()
417
+ self.fpn_blocks = nn.ModuleList()
418
+ for _ in range(n_levels - 1, 0, -1):
419
+ self.lateral_convs.append(ConvNormLayer(hidden_dim, hidden_dim, 1, 1, act=act))
420
+ self.fpn_blocks.append(
421
+ RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(expansion * hidden_dim // 2), round(3 * depth_mult))
422
+ # CSPRepLayer(hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion)
423
+ )
424
+
425
+ # bottom-up pan
426
+ self.downsample_convs = nn.ModuleList()
427
+ self.pan_blocks = nn.ModuleList()
428
+ for _ in range(n_levels - 1):
429
+ self.downsample_convs.append(nn.Sequential(
430
+ SCDown(hidden_dim, hidden_dim, 3, 2),
431
+ )
432
+ )
433
+ self.pan_blocks.append(
434
+ RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(expansion * hidden_dim // 2), round(3 * depth_mult))
435
+ # CSPRepLayer(hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion)
436
+ )
437
+
438
+ self._reset_parameters()
439
+
440
+ def _reset_parameters(self):
441
+ # init input_proj
442
+ for proj in self.input_proj:
443
+ nn.init.xavier_uniform_(proj[0].weight, gain=1)
444
+ # nn.init.constant_(proj[0].bias, 0)
445
+
446
+ if self.eval_spatial_size is not None:
447
+ for idx in self.use_encoder_idx:
448
+ stride = self.feat_strides[idx]
449
+ pos_embed = self.create_sinehw_position_embedding(
450
+ self.eval_spatial_size[1] // stride,
451
+ self.eval_spatial_size[0] // stride,
452
+ self.hidden_dim // 2
453
+ )
454
+ setattr(self, f'pos_embed{idx}', pos_embed)
455
+
456
+ def create_sinehw_position_embedding(self, w, h, hidden_dim, scale=None, device='cpu'):
457
+ """
+ Build a 2D sine-cosine position embedding for an (h, w) feature map; returns a tensor of shape [1, h*w, 2*hidden_dim].
+ """
459
+ grid_w = torch.arange(1, int(w)+1, dtype=torch.float32, device=device)
460
+ grid_h = torch.arange(1, int(h)+1, dtype=torch.float32, device=device)
461
+
462
+ grid_h, grid_w = torch.meshgrid(grid_h, grid_w, indexing='ij')
463
+
464
+ if scale is None:
465
+ scale = 2 * math.pi
466
+
467
+ eps = 1e-6
468
+ grid_w = grid_w / (int(w) + eps) * scale
469
+ grid_h = grid_h / (int(h) + eps) * scale
470
+
471
+ dim_tx = torch.arange(hidden_dim, dtype=torch.float32, device=device)
472
+ dim_tx = self.temperatureW ** (2 * (dim_tx // 2) / hidden_dim)
473
+ pos_x = grid_w[..., None] / dim_tx
474
+
475
+ dim_ty = torch.arange(hidden_dim, dtype=torch.float32, device=device)
476
+ dim_ty = self.temperatureH ** (2 * (dim_ty // 2) / hidden_dim)
477
+ pos_y = grid_h[..., None] / dim_ty
478
+
479
+ pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
480
+ pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
481
+
482
+ pos = torch.cat((pos_y, pos_x), dim=2).permute(2, 0, 1)
483
+ pos = pos[None].flatten(2).permute(0, 2, 1).contiguous()
484
+
485
+ return pos
486
+
487
+ def forward(self, feats):
488
+ """
489
+ Input:
490
+ - feats: List of features from the backbone
491
+ Outputs:
492
+ - output: List of enhanced features
493
+ """
494
+ assert len(feats) == len(self.in_channels)
495
+ proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]
496
+
497
+ # encoder
498
+ for i, enc_idx in enumerate(self.use_encoder_idx):
499
+ N_, C_, H_, W_ = proj_feats[enc_idx].shape
500
+ # flatten [B, C, H, W] to [B, HxW, C]
501
+ src_flatten = proj_feats[enc_idx].flatten(2).permute(0, 2, 1)
502
+ if self.training or self.eval_spatial_size is None:
503
+ pos_embed = self.create_sinehw_position_embedding(
504
+ H_, W_, self.hidden_dim//2, device=src_flatten.device)
505
+ else:
506
+ pos_embed = getattr(self, f'pos_embed{enc_idx}', None).to(src_flatten.device)
507
+
508
+ proj_feats[enc_idx] = self.encoder[i](
509
+ src_flatten,
510
+ pos_embed=pos_embed
511
+ ).permute(0, 2, 1).reshape(N_, C_, H_, W_).contiguous()
512
+
513
+ # broadcasting and fusion
514
+ inner_outs = [proj_feats[-1]]
515
+ for idx in range(self.n_levels - 1, 0, -1):
516
+ feat_high = inner_outs[0]
517
+ feat_low = proj_feats[idx - 1]
518
+ feat_high = self.lateral_convs[self.n_levels - 1 - idx](feat_high)
519
+ inner_outs[0] = feat_high
520
+ upsample_feat = F.interpolate(feat_high, scale_factor=2., mode='nearest')
521
+ inner_out = self.fpn_blocks[self.n_levels-1-idx](torch.concat([upsample_feat, feat_low], dim=1))
522
+ inner_outs.insert(0, inner_out)
523
+
524
+ outs = [inner_outs[0]]
525
+ for idx in range(self.n_levels - 1):
526
+ feat_low = outs[-1]
527
+ feat_high = inner_outs[idx + 1]
528
+ downsample_feat = self.downsample_convs[idx](feat_low)
529
+ out = self.pan_blocks[idx](torch.concat([downsample_feat, feat_high], dim=1))
530
+ outs.append(out)
531
+ return outs
532
+
533
+ def build_hybrid_encoder_with_asymmetric_conv(args):
534
+ return HybridEncoderAsymConv(
535
+ in_channels=args.in_channels_encoder,
536
+ feat_strides=args.feat_strides,
537
+ n_levels=args.num_feature_levels,
538
+ hidden_dim=args.hidden_dim,
539
+ nhead=args.nheads,
540
+ dim_feedforward = args.dim_feedforward,
541
+ dropout=args.dropout,
542
+ enc_act='gelu',
543
+ expansion=args.expansion,
544
+ depth_mult=args.depth_mult,
545
+ act='silu',
546
+ temperatureH=args.pe_temperatureH,
547
+ temperatureW=args.pe_temperatureW,
548
+ eval_spatial_size= args.eval_spatial_size,
549
+ )
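A sketch verifying the re-parameterization of the asymmetric `VGGBlock` above: at deploy time the four training branches (3x3, 1x1, 3x1 and 1x3 convolutions, each followed by BatchNorm) are folded into a single 3x3 convolution. The import path mirrors this upload's layout and is an assumption.

import torch
from linea.models.linea.hybrid_encoder_asymmetric_conv import VGGBlock

block = VGGBlock(32, 32, act='relu').eval()
x = torch.randn(1, 32, 64, 64)

with torch.no_grad():
    y_branches = block(x)        # conv1 + conv2 + convH + convW, then ReLU
    block.convert_to_deploy()    # pads the 1x1 / 3x1 / 1x3 kernels to 3x3 and sums them
    y_fused = block(x)

print(torch.allclose(y_branches, y_fused, atol=1e-5))  # expected: True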
linea/models/linea/linea.py ADDED
@@ -0,0 +1,156 @@
1
+ # ------------------------------------------------------------------------
2
+ # DINO
3
+ # Copyright (c) 2022 IDEA. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Conditional DETR model and criterion classes.
7
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
8
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
9
+ # ------------------------------------------------------------------------
10
+ # Modified from DETR (https://github.com/facebookresearch/detr)
11
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
12
+ # ------------------------------------------------------------------------
13
+ # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
14
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
15
+ # ------------------------------------------------------------------------
16
+ import copy
17
+ import math
18
+ from typing import List
19
+ import torch
20
+ from torch import nn
21
+ from torchvision.transforms.functional import resize
22
+
23
+ import numpy as np
24
+
25
+ from .utils import sigmoid_focal_loss, MLP
26
+
27
+ from ..registry import MODULE_BUILD_FUNCS
28
+
29
+ from .hgnetv2 import build_hgnetv2
30
+ from .hybrid_encoder_asymmetric_conv import build_hybrid_encoder_with_asymmetric_conv
31
+ from .decoder import build_decoder
32
+
33
+ from .linea_utils import *
34
+
35
+ class LINEA(nn.Module):
36
+ """ This is the LINEA module that performs line segment detection """
37
+ def __init__(self,
38
+ backbone,
39
+ encoder,
40
+ decoder,
41
+ # multiscale = None,
42
+ use_lmap = False
43
+ ):
44
+ """ Initializes the model.
45
+ Parameters:
46
+ backbone: torch module of the backbone to be used. See hgnetv2.py
+ encoder: torch module of the hybrid encoder. See hybrid_encoder_asymmetric_conv.py
+ decoder: torch module of the decoder. See decoder.py
+ use_lmap: True if the auxiliary line-map branch (one 1x1 conv per feature level, training only) is to be used.
52
+ """
53
+ super().__init__()
54
+ self.backbone = backbone
55
+ self.encoder = encoder
56
+ self.decoder = decoder
57
+
58
+ # for auxiliary branch
59
+ if use_lmap:
60
+ self.aux_branch = nn.ModuleList()
61
+ hidden_dim = encoder.hidden_dim
62
+ for i in range(3):
63
+ n = 2 ** i
64
+ self.aux_branch.append(nn.Conv2d(hidden_dim, 1, 1))
65
+
66
+ def forward(self, samples, targets:List=None):
67
+ """ The forward expects:
+ - samples: batched images, of shape [batch_size x 3 x H x W]
+ - targets: optional list of per-image target dicts, only used during training
70
+
71
+ It returns a dict with the following elements:
72
+ - "pred_logits": the classification logits (including no-object) for all queries.
73
+ Shape= [batch_size x num_queries x num_classes]
74
+ - "pred_lines": the normalized line endpoints for all queries, represented as
+ (x1, y1, x2, y2). These values are normalized in [0, 1],
+ relative to the size of each individual image (disregarding possible padding).
+ See PostProcess for information on how to retrieve the unnormalized lines.
78
+ - "aux_outputs": Optional, only returned when auxiliary losses are activated. It is a list of
+ dictionaries containing the two above keys for each decoder layer.
80
+ """
81
+ features = self.backbone(samples)
82
+
83
+ features = self.encoder(features)
84
+
85
+ out = self.decoder(features, targets)
86
+
87
+ if self.training and hasattr(self, 'aux_branch'):
88
+ lmaps = []
89
+ for feat, convs in zip(features, self.aux_branch):
90
+ lmap = convs(feat)
91
+ lmaps.append(lmap)
92
+ # lmaps = torch.cat(lmaps, dim=1)
93
+ out['aux_lmap'] = lmaps
94
+
95
+ return out
96
+
97
+ def deploy(self, ):
98
+ self.eval()
99
+ for m in self.modules():
100
+ if hasattr(m, 'convert_to_deploy'):
101
+ m.convert_to_deploy()
102
+ return self
103
+
104
+
105
+ class PostProcess(nn.Module):
106
+ """ This module converts the model's output into per-image line segments and confidence scores """
107
+ def __init__(self) -> None:
108
+ super().__init__()
109
+ self.deploy_mode = False
110
+
111
+ @torch.no_grad()
112
+ def forward(self, outputs, target_sizes):
113
+ """ Perform the computation
114
+ Parameters:
115
+ outputs: raw outputs of the model
116
+ target_sizes: tensor of dimension [batch_size x 2] containing the size of each image of the batch
117
+ For evaluation, this must be the original image size (before any data augmentation)
118
+ For visualization, this should be the image size after data augment, but before padding
119
+ """
120
+ out_logits, out_line = outputs['pred_logits'], outputs['pred_lines']
121
+
122
+ scores = out_logits[..., 0].sigmoid()
123
+
124
+ # scale the normalized (x1, y1, x2, y2) endpoints to absolute image coordinates
125
+ lines = out_line * target_sizes.repeat(1, 2).unsqueeze(1)
126
+
127
+ if self.deploy_mode:
128
+ return lines, scores
129
+
130
+ results = [{'lines': l, 'scores': s} for s, l in zip(scores, lines)]
131
+
132
+ return results
133
+
134
+ def deploy(self, ):
135
+ self.eval()
136
+ self.deploy_mode = True
137
+ return self
138
+
139
+ @MODULE_BUILD_FUNCS.registe_with_name(module_name='LINEA')
140
+ def build_linea(args):
141
+ num_classes = args.num_classes
142
+
143
+ backbone = build_hgnetv2(args)
144
+ encoder = build_hybrid_encoder_with_asymmetric_conv(args)
145
+ decoder = build_decoder(args)
146
+
147
+ model = LINEA(
148
+ backbone,
149
+ encoder,
150
+ decoder,
151
+ use_lmap=args.use_lmap
152
+ )
153
+
154
+ postprocessors = PostProcess()
155
+
156
+ return model, postprocessors
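A sketch of how the `PostProcess` module above is used at inference time. The tensors are random stand-ins with illustrative shapes; the import path mirrors this upload's layout, and the (width, height) ordering of `target_sizes` is an assumption inferred from the (x1, y1, x2, y2) line layout.

import torch
from linea.models.linea.linea import PostProcess

postprocessor = PostProcess()
outputs = {
    'pred_logits': torch.randn(2, 300, 2),  # [batch, num_queries, num_classes]
    'pred_lines': torch.rand(2, 300, 4),    # normalized (x1, y1, x2, y2) in [0, 1]
}
target_sizes = torch.tensor([[640.0, 640.0], [768.0, 512.0]])  # per-image (width, height), assumed order

results = postprocessor(outputs, target_sizes)
print(results[0]['lines'].shape, results[0]['scores'].shape)  # torch.Size([300, 4]) torch.Size([300])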
linea/models/linea/linea_utils.py ADDED
@@ -0,0 +1,165 @@
1
+ import torch
2
+
3
+
4
+ def weighting_function(reg_max, up, reg_scale, deploy=False):
5
+ """
6
+ Generates the non-uniform Weighting Function W(n) for bounding box regression.
7
+
8
+ Args:
9
+ reg_max (int): Max number of the discrete bins.
10
+ up (Tensor): Controls upper bounds of the sequence,
11
+ where maximum offset is ±up * H / W.
12
+ reg_scale (float): Controls the curvature of the Weighting Function.
13
+ Larger values result in flatter weights near the central axis W(reg_max/2)=0
14
+ and steeper weights at both ends.
15
+ deploy (bool): If True, uses deployment mode settings.
16
+
17
+ Returns:
18
+ Tensor: Sequence of Weighting Function.
19
+ """
20
+ if deploy:
21
+ upper_bound1 = (abs(up[0]) * abs(reg_scale)).item()
22
+ upper_bound2 = (abs(up[0]) * abs(reg_scale) * 2).item()
23
+ step = (upper_bound1 + 1) ** (2 / (reg_max - 2))
24
+ left_values = [-(step) ** i + 1 for i in range(reg_max // 2 - 1, 0, -1)]
25
+ right_values = [(step) ** i - 1 for i in range(1, reg_max // 2)]
26
+ values = [-upper_bound2] + left_values + [torch.zeros_like(up[0][None])] + right_values + [upper_bound2]
27
+ return torch.tensor(values, dtype=up.dtype, device=up.device)
28
+ else:
29
+ upper_bound1 = abs(up[0]) * abs(reg_scale)
30
+ upper_bound2 = abs(up[0]) * abs(reg_scale) * 2
31
+ step = (upper_bound1 + 1) ** (2 / (reg_max - 2))
32
+ left_values = [-(step) ** i + 1 for i in range(reg_max // 2 - 1, 0, -1)]
33
+ right_values = [(step) ** i - 1 for i in range(1, reg_max // 2)]
34
+ values = [-upper_bound2] + left_values + [torch.zeros_like(up[0][None])] + right_values + [upper_bound2]
35
+ return torch.cat(values, 0)
36
+
37
+
38
+ def translate_gt(gt, reg_max, reg_scale, up):
39
+ """
40
+ Decodes bounding box ground truth (GT) values into distribution-based GT representations.
41
+
42
+ This function maps continuous GT values into discrete distribution bins, which can be used
43
+ for regression tasks in object detection models. It calculates the indices of the closest
44
+ bins to each GT value and assigns interpolation weights to these bins based on their proximity
45
+ to the GT value.
46
+
47
+ Args:
48
+ gt (Tensor): Ground truth bounding box values, shape (N, ).
49
+ reg_max (int): Maximum number of discrete bins for the distribution.
50
+ reg_scale (float): Controls the curvature of the Weighting Function.
51
+ up (Tensor): Controls the upper bounds of the Weighting Function.
52
+
53
+ Returns:
54
+ Tuple[Tensor, Tensor, Tensor]:
55
+ - indices (Tensor): Index of the left bin closest to each GT value, shape (N, ).
56
+ - weight_right (Tensor): Weight assigned to the right bin, shape (N, ).
57
+ - weight_left (Tensor): Weight assigned to the left bin, shape (N, ).
58
+ """
59
+ gt = gt.reshape(-1)
60
+ function_values = weighting_function(reg_max, up, reg_scale)
61
+
62
+ # Find the closest left-side indices for each value
63
+ diffs = function_values.unsqueeze(0) - gt.unsqueeze(1)
64
+ mask = diffs <= 0
65
+ closest_left_indices = torch.sum(mask, dim=1) - 1
66
+
67
+ # Calculate the weights for the interpolation
68
+ indices = closest_left_indices.float()
69
+
70
+ weight_right = torch.zeros_like(indices)
71
+ weight_left = torch.zeros_like(indices)
72
+
73
+ valid_idx_mask = (indices >= 0) & (indices < reg_max)
74
+ valid_indices = indices[valid_idx_mask].long()
75
+
76
+ # Obtain distances
77
+ left_values = function_values[valid_indices]
78
+ right_values = function_values[valid_indices + 1]
79
+
80
+ left_diffs = torch.abs(gt[valid_idx_mask] - left_values)
81
+ right_diffs = torch.abs(right_values - gt[valid_idx_mask])
82
+
83
+ # Valid weights
84
+ weight_right[valid_idx_mask] = left_diffs / (left_diffs + right_diffs)
85
+ weight_left[valid_idx_mask] = 1.0 - weight_right[valid_idx_mask]
86
+
87
+ # Invalid weights (out of range)
88
+ invalid_idx_mask_neg = (indices < 0)
89
+ weight_right[invalid_idx_mask_neg] = 0.0
90
+ weight_left[invalid_idx_mask_neg] = 1.0
91
+ indices[invalid_idx_mask_neg] = 0.0
92
+
93
+ invalid_idx_mask_pos = (indices >= reg_max)
94
+ weight_right[invalid_idx_mask_pos] = 1.0
95
+ weight_left[invalid_idx_mask_pos] = 0.0
96
+ indices[invalid_idx_mask_pos] = reg_max - 0.1
97
+
98
+ return indices, weight_right, weight_left
99
+
100
+
101
+ def bbox2distance(points, bbox, reg_max, reg_scale, up, eps=0.1):
102
+ """
103
+ Converts bounding box coordinates to distances from a reference point.
104
+
105
+ Args:
106
+ points (Tensor): (n, 4) [x, y, w, h], where (x, y) is the center.
107
+ bbox (Tensor): (n, 4) bounding boxes in "xyxy" format.
108
+ reg_max (float): Maximum bin value.
109
+ reg_scale (float): Controls the curvature of W(n).
+ up (Tensor): Controls the upper bounds of W(n).
111
+ eps (float): Small value to ensure target < reg_max.
112
+
113
+ Returns:
114
+ Tensor: Decoded distances.
115
+ """
116
+ reg_scale = abs(reg_scale)
117
+
118
+ Dx = torch.abs(points[..., 0] - points[..., 2])
119
+ Dy = torch.abs(points[..., 1] - points[..., 3])
120
+
121
+ left = (points[:, 0] - bbox[:, 0]) / (Dx / reg_scale + 1e-16) - 0.5 * reg_scale
122
+ top = (points[:, 1] - bbox[:, 1]) / (Dy / reg_scale + 1e-16) - 0.5 * reg_scale
123
+ right = (points[:, 2] - bbox[:, 2]) / (Dx / reg_scale + 1e-16) - 0.5 * reg_scale
124
+ bottom = (points[:, 3] - bbox[:, 3]) / (Dy / reg_scale + 1e-16) - 0.5 * reg_scale
125
+ four_lens = torch.stack([left, top, right, bottom], -1)
126
+ four_lens, weight_right, weight_left = translate_gt(four_lens, reg_max, reg_scale, up)
127
+ if reg_max is not None:
128
+ four_lens = four_lens.clamp(min=0, max=reg_max-eps)
129
+ return four_lens.reshape(-1).detach(), weight_right.detach(), weight_left.detach()
130
+
131
+
132
+ def distance2bbox(points, distance, reg_scale):
133
+ """
134
+ Decodes edge-distances into bounding box coordinates.
135
+
136
+ Args:
137
+ points (Tensor): (B, N, 4) or (N, 4) format, representing [x, y, w, h],
138
+ where (x, y) is the center and (w, h) are width and height.
139
+ distance (Tensor): (B, N, 4) or (N, 4), representing distances from the
140
+ point to the left, top, right, and bottom boundaries.
141
+
142
+ reg_scale (float): Controls the curvature of the Weighting Function.
143
+
144
+ Returns:
145
+ Tensor: Bounding boxes in (N, 4) or (B, N, 4) format [cx, cy, w, h].
146
+ """
147
+ reg_scale = abs(reg_scale)
148
+
149
+ Dx = torch.abs(points[..., 0] - points[..., 2])
150
+ Dy = torch.abs(points[..., 1] - points[..., 3])
151
+
152
+ x1 = points[..., 0] + (0.5 * reg_scale + distance[..., 0]) * (Dx / reg_scale)
153
+ y1 = points[..., 1] + (0.5 * reg_scale + distance[..., 1]) * (Dy / reg_scale)
154
+ x2 = points[..., 2] + (0.5 * reg_scale + distance[..., 2]) * (Dx / reg_scale)
155
+ y2 = points[..., 3] + (0.5 * reg_scale + distance[..., 3]) * (Dy / reg_scale)
156
+
157
+ bboxes = torch.stack([x1, y1, x2, y2], -1)
158
+
159
+ return bboxes
160
+
161
+ def inverse_sigmoid(x, eps=1e-3):
162
+ x = x.clamp(min=0, max=1)
163
+ x1 = x.clamp(min=eps)
164
+ x2 = (1 - x).clamp(min=eps)
165
+ return torch.log(x1/x2)
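A small sketch of the distribution bins used above: `weighting_function` produces `reg_max + 1` monotonically increasing values centred on 0, and `translate_gt` projects a continuous offset onto its two neighbouring bins so that interpolating them with (weight_left, weight_right) recovers the offset. `reg_scale` and `up` are passed here as 1-element tensors, matching how the decoder is assumed to store them; the import path mirrors this upload's layout.

import torch
from linea.models.linea.linea_utils import weighting_function, translate_gt

reg_max = 32
up = torch.tensor([0.5])
reg_scale = torch.tensor([4.0])

bins = weighting_function(reg_max, up, reg_scale)   # shape: (reg_max + 1,)
print(bins.shape, bins[0].item(), bins[-1].item())  # endpoints at -/+ 2 * up * reg_scale

gt = torch.tensor([0.07, -0.30])
idx, w_right, w_left = translate_gt(gt, reg_max, reg_scale, up)
idx = idx.long()
# Interpolating the two neighbouring bins approximately reconstructs the offsets.
recon = w_left * bins[idx] + w_right * bins[(idx + 1).clamp(max=reg_max)]
print(recon)  # approximately [0.07, -0.30]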
linea/models/linea/matcher.py ADDED
@@ -0,0 +1,180 @@
1
+ # ------------------------------------------------------------------------
2
+ # DINO
3
+ # Copyright (c) 2022 IDEA. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modules to compute the matching cost and solve the corresponding LSAP.
7
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
8
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
9
+ # ------------------------------------------------------------------------
10
+ # Modified from DETR (https://github.com/facebookresearch/detr)
11
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
12
+ # ------------------------------------------------------------------------
13
+ # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
14
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
15
+ # ------------------------------------------------------------------------
16
+
17
+
18
+ import torch, os
19
+ from torch import nn
20
+ import torch.nn.functional as F
21
+ from scipy.optimize import linear_sum_assignment
22
+
23
+
24
+ class HungarianMatcher(nn.Module):
25
+ """This class computes an assignment between the targets and the predictions of the network
26
+ For efficiency reasons, the targets don't include the no_object. Because of this, in general,
27
+ there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
28
+ while the others are un-matched (and thus treated as non-objects).
29
+ """
30
+
31
+ def __init__(self, cost_class: float = 1, cost_bbox: float = 1, focal_alpha = 0.25):
32
+ """Creates the matcher
33
+ Params:
34
+ cost_class: This is the relative weight of the classification error in the matching cost
35
+ cost_bbox: This is the relative weight of the L1 error of the line endpoint coordinates in the matching cost
37
+ """
38
+ super().__init__()
39
+ self.cost_class = cost_class
40
+ self.cost_line = cost_bbox
41
+ assert cost_class != 0 or cost_bbox != 0, "all costs can't be 0"
42
+
43
+ self.focal_alpha = focal_alpha
44
+
45
+ @torch.no_grad()
46
+ def forward(self, outputs, targets):
47
+ """ Performs the matching
48
+ Params:
49
+ outputs: This is a dict that contains at least these entries:
50
+ "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
51
+ "pred_lines": Tensor of dim [batch_size, num_queries, 4] with the predicted line endpoints
52
+ targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
53
+ "labels": Tensor of dim [num_target_lines] (where num_target_lines is the number of ground-truth
+ lines in the target) containing the class labels
+ "lines": Tensor of dim [num_target_lines, 4] containing the target line endpoints
56
+ Returns:
57
+ A list of size batch_size, containing tuples of (index_i, index_j) where:
58
+ - index_i is the indices of the selected predictions (in order)
59
+ - index_j is the indices of the corresponding selected targets (in order)
60
+ For each batch element, it holds:
61
+ len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
62
+ """
63
+
64
+ bs, num_queries = outputs["pred_logits"].shape[:2]
65
+
66
+ # We flatten to compute the cost matrices in a batch
67
+ out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes]
68
+ out_line = outputs["pred_lines"].flatten(0, 1) # [batch_size * num_queries, 4]
69
+
70
+ # Also concat the target labels and lines
71
+ tgt_ids = torch.cat([v["labels"] for v in targets])
72
+ tgt_line = torch.cat([v["lines"] for v in targets])
73
+
74
+ # Compute the classification cost.
75
+ alpha = self.focal_alpha
76
+ gamma = 2.0
77
+ neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
78
+ pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
79
+ cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
80
+
81
+ # Compute the L1 cost between boxes
82
+ cost_line = torch.cdist(out_line, tgt_line, p=1)
83
+
84
+ # Final cost matrix
85
+ C = self.cost_line * cost_line + self.cost_class * cost_class
86
+ C = C.view(bs, num_queries, -1).cpu()
87
+
88
+ sizes = [len(v["lines"]) for v in targets]
89
+ indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
90
+ return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
91
+
92
+
93
+ class SimpleMinsumMatcher(nn.Module):
94
+ """This class computes an assignment between the targets and the predictions of the network
95
+ For efficiency reasons, the targets don't include the no_object. Because of this, in general,
96
+ there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
97
+ while the others are un-matched (and thus treated as non-objects).
98
+ """
99
+
100
+ def __init__(self, cost_class: float = 1, cost_bbox: float = 1, focal_alpha = 0.25):
101
+ """Creates the matcher
102
+ Params:
103
+ cost_class: This is the relative weight of the classification error in the matching cost
104
+ cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
105
+ focal_alpha: Alpha used in the focal-loss-style classification cost (defaults to 0.25)
106
+ """
107
+ super().__init__()
108
+ self.cost_class = cost_class
109
+ self.cost_line = cost_bbox
110
+ assert cost_class != 0 or cost_bbox != 0, "all costs can't be 0"
111
+
112
+ self.focal_alpha = focal_alpha
113
+
114
+ @torch.no_grad()
115
+ def forward(self, outputs, targets):
116
+ """ Performs the matching
117
+ Params:
118
+ outputs: This is a dict that contains at least these entries:
119
+ "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
120
+ "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
121
+ targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
122
+ "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
123
+ objects in the target) containing the class labels
124
+ "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
125
+ Returns:
126
+ A list of size batch_size, containing tuples of (index_i, index_j) where:
127
+ - index_i is the indices of the selected predictions (in order)
128
+ - index_j is the indices of the corresponding selected targets (in order)
129
+ For each batch element, it holds:
130
+ len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
131
+ """
132
+
133
+ bs, num_queries = outputs["pred_logits"].shape[:2]
134
+
135
+ # We flatten to compute the cost matrices in a batch
136
+ out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes]
137
+ out_line = outputs["pred_lines"].flatten(0, 1) # [batch_size * num_queries, 4]
138
+
139
+ # Also concat the target labels and boxes
140
+ tgt_ids = torch.cat([v["labels"] for v in targets])
141
+ tgt_line = torch.cat([v["lines"] for v in targets])
142
+
143
+ # Compute the classification cost.
144
+ alpha = self.focal_alpha
145
+ gamma = 2.0
146
+ neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
147
+ pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
148
+ cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
149
+
150
+ # Compute the L1 cost between boxes
151
+ cost_line = torch.cdist(out_line, tgt_line, p=1)
152
+
153
+ # Final cost matrix
154
+ C = self.cost_line * cost_line + self.cost_class * cost_class
155
+ C = C.view(bs, num_queries, -1)
156
+
157
+ sizes = [len(v["lines"]) for v in targets]
158
+ indices = []
159
+ device = C.device
160
+ for i, (c, _size) in enumerate(zip(C.split(sizes, -1), sizes)):
161
+ weight_mat = c[i]
162
+ idx_i = weight_mat.min(0)[1]
163
+ idx_j = torch.arange(_size).to(device)
164
+ indices.append((idx_i, idx_j))
165
+
166
+ return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
167
+
168
+
169
+ def build_matcher(args):
170
+ assert args.matcher_type in ['HungarianMatcher', 'SimpleMinsumMatcher'], "Unknown args.matcher_type: {}".format(args.matcher_type)
171
+ if args.matcher_type == 'HungarianMatcher':
172
+ return HungarianMatcher(
173
+ cost_class=args.set_cost_class, cost_bbox=args.set_cost_lines, focal_alpha=args.focal_alpha
174
+ )
175
+ elif args.matcher_type == 'SimpleMinsumMatcher':
176
+ return SimpleMinsumMatcher(
177
+ cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, focal_alpha=args.focal_alpha
178
+ )
179
+ else:
180
+ raise NotImplementedError("Unknown args.matcher_type: {}".format(args.matcher_type))
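For reference, a minimal usage sketch of the matcher above (illustrative only, not part of the upload); the shapes follow the docstrings, the class logits and line values are random dummies, and line coordinates are assumed to be normalized to [0, 1]:

import torch
# HungarianMatcher as defined in linea/models/linea/matcher.py above
matcher = HungarianMatcher(cost_class=2.0, cost_bbox=5.0, focal_alpha=0.25)  # cost weights made up for the sketch
outputs = {
    'pred_logits': torch.randn(2, 100, 2),   # [batch_size, num_queries, num_classes]
    'pred_lines': torch.rand(2, 100, 4),     # [batch_size, num_queries, 4]
}
targets = [
    {'labels': torch.zeros(5, dtype=torch.long), 'lines': torch.rand(5, 4)},
    {'labels': torch.zeros(3, dtype=torch.long), 'lines': torch.rand(3, 4)},
]
indices = matcher(outputs, targets)
# indices[b] = (pred_idx, tgt_idx): the one-to-one assignment for image b
print([(i.tolist(), j.tolist()) for i, j in indices])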
linea/models/linea/new_dn_components.py ADDED
@@ -0,0 +1,163 @@
1
+ # ------------------------------------------------------------------------
2
+ # DINO
3
+ # Copyright (c) 2022 IDEA. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # DN-DETR
7
+ # Copyright (c) 2022 IDEA. All Rights Reserved.
8
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
9
+
10
+
11
+ import torch
12
+ from .linea_utils import inverse_sigmoid
13
+ import torch.nn.functional as F
14
+
15
+ def prepare_for_cdn(dn_args, training, num_queries, num_classes, hidden_dim, label_enc):
16
+ """
17
+ A major difference between DINO and DN-DETR is that DINO processes the pattern embedding in its detector
18
+ forward function and uses a learnable tgt embedding, so we change this function a little bit.
19
+ :param dn_args: targets, dn_number, label_noise_ratio, box_noise_scale
20
+ :param training: if it is training or inference
21
+ :param num_queries: number of queries
22
+ :param num_classes: number of classes
23
+ :param hidden_dim: transformer hidden dim
24
+ :param label_enc: encode labels in dn
25
+ :return:
26
+ """
27
+ if training:
28
+ targets, dn_number, label_noise_ratio, box_noise_scale = dn_args
29
+ # positive and negative dn queries
30
+ dn_number = dn_number * 2
31
+ known = [(torch.ones_like(t['labels'])).cuda() for t in targets]
32
+ batch_size = len(known)
33
+ known_num = [sum(k) for k in known]
34
+
35
+ if int(max(known_num)) == 0:
36
+ dn_number = 1
37
+ else:
38
+ if dn_number >= 100:
39
+ dn_number = dn_number // (int(max(known_num) * 2))
40
+ elif dn_number < 1:
41
+ dn_number = 1
42
+ if dn_number == 0:
43
+ dn_number = 1
44
+
45
+ unmask_bbox = unmask_label = torch.cat(known)
46
+ labels = torch.cat([t['labels'] for t in targets])
47
+ lines = torch.cat([t['lines'] for t in targets])
48
+ batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)])
49
+
50
+ known_indice = torch.nonzero(unmask_label + unmask_bbox)
51
+ known_indice = known_indice.view(-1)
52
+
53
+ known_indice = known_indice.repeat(2 * dn_number, 1).view(-1)
54
+ known_labels = labels.repeat(2 * dn_number, 1).view(-1)
55
+ known_bid = batch_idx.repeat(2 * dn_number, 1).view(-1)
56
+ known_lines = lines.repeat(2 * dn_number, 1)
57
+
58
+ known_labels_expaned = known_labels.clone()
59
+ known_lines_expand = known_lines.clone()
60
+
61
+ if label_noise_ratio > 0:
62
+ p = torch.rand_like(known_labels_expaned.float())
63
+ chosen_indice = torch.nonzero(p < (label_noise_ratio * 0.5)).view(-1) # half of bbox prob
64
+ new_label = torch.randint_like(chosen_indice, 0, num_classes) # randomly put a new one here
65
+ known_labels_expaned.scatter_(0, chosen_indice, new_label)
66
+
67
+ single_pad = int(max(known_num))
68
+
69
+ pad_size = int(single_pad * 2 * dn_number)
70
+ positive_idx = torch.tensor(range(len(lines))).long().cuda().unsqueeze(0).repeat(dn_number, 1)
71
+ positive_idx += (torch.tensor(range(dn_number)) * len(lines) * 2).long().cuda().unsqueeze(1)
72
+ positive_idx = positive_idx.flatten()
73
+ negative_idx = positive_idx + len(lines)
74
+
75
+ known_lines_ = known_lines.clone().unflatten(-1, (2, 2))
76
+ offsets = F.normalize(2 * torch.rand_like(known_lines_) - 1, dim=-1)
77
+ rand_part = torch.rand(size=(known_lines.shape[0], 2, 1), device=known_lines.device, dtype=offsets.dtype)
78
+
79
+ rand_part[positive_idx] *= 0.005
80
+ rand_part[negative_idx] *= 0.0645
81
+ rand_part[negative_idx] += 0.0055
82
+
83
+ known_lines_ = known_lines_ + offsets * rand_part
84
+ known_lines_ = known_lines_.flatten(-2)
85
+
86
+ known_lines_expand = known_lines_.clamp(min=0.0, max=1.0)
87
+
88
+ # # order: top point > bottom point
89
+ # # if same y coordinate, right point > left point
90
+
91
+ # idx = torch.logical_or(known_lines_expand[..., 0] > known_lines_expand[..., 2],
92
+ # torch.logical_or(
93
+ # known_lines_expand[..., 0] == known_lines_expand[..., 2],
94
+ # known_lines_expand[..., 1] < known_lines_expand[..., 3]
95
+ # )
96
+ # )
97
+
98
+ # known_lines_expand[idx] = known_lines_expand[idx][:, [2, 3, 0, 1]]
99
+
100
+ m = known_labels_expaned.long().to('cuda')
101
+ input_label_embed = label_enc(m)
102
+ input_lines_embed = inverse_sigmoid(known_lines_expand)
103
+
104
+ padding_label = torch.zeros(pad_size, hidden_dim).cuda()
105
+ padding_lines = torch.zeros(pad_size, 4).cuda()
106
+
107
+ input_query_label = padding_label.repeat(batch_size, 1, 1)
108
+ input_query_lines = padding_lines.repeat(batch_size, 1, 1)
109
+
110
+ map_known_indice = torch.tensor([]).to('cuda')
111
+ if len(known_num):
112
+ map_known_indice = torch.cat([torch.tensor(range(num)) for num in known_num]) # [1,2, 1,2,3]
113
+ map_known_indice = torch.cat([map_known_indice + single_pad * i for i in range(2 * dn_number)]).long()
114
+
115
+ if len(known_bid):
116
+ input_query_label[(known_bid.long(), map_known_indice)] = input_label_embed
117
+ input_query_lines[(known_bid.long(), map_known_indice)] = input_lines_embed
118
+
119
+ tgt_size = pad_size + num_queries
120
+ attn_mask = torch.ones(tgt_size, tgt_size).to('cuda') < 0
121
+ # match query cannot see the reconstruct
122
+ attn_mask[pad_size:, :pad_size] = True
123
+ # reconstruct cannot see each other
124
+ for i in range(dn_number):
125
+ if i == 0:
126
+ attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), single_pad * 2 * (i + 1):pad_size] = True
127
+ if i == dn_number - 1:
128
+ attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), :single_pad * i * 2] = True
129
+ else:
130
+ attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), single_pad * 2 * (i + 1):pad_size] = True
131
+ attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), :single_pad * 2 * i] = True
132
+
133
+ dn_meta = {
134
+ 'pad_size': pad_size,
135
+ 'num_dn_group': dn_number,
136
+ }
137
+ else:
138
+
139
+ input_query_label = None
140
+ input_query_lines = None
141
+ attn_mask = None
142
+ dn_meta = None
143
+
144
+ return input_query_label, input_query_lines, attn_mask, dn_meta
145
+
146
+
147
+ def dn_post_process(outputs_class, outputs_coord, dn_meta, aux_loss, _set_aux_loss):
148
+ """
149
+ post process of dn after output from the transformer
150
+ put the dn part in the dn_meta
151
+ """
152
+ if dn_meta and dn_meta['pad_size'] > 0:
153
+ output_known_class = outputs_class[:, :, :dn_meta['pad_size'], :]
154
+ output_known_coord = outputs_coord[:, :, :dn_meta['pad_size'], :]
155
+ outputs_class = outputs_class[:, :, dn_meta['pad_size']:, :]
156
+ outputs_coord = outputs_coord[:, :, dn_meta['pad_size']:, :]
157
+ out = {'pred_logits': output_known_class[-1], 'pred_lines': output_known_coord[-1]}
158
+ if aux_loss:
159
+ out['aux_outputs'] = _set_aux_loss(output_known_class, output_known_coord)
160
+ dn_meta['output_known_lbs_lines'] = out
161
+ return outputs_class, outputs_coord
162
+
163
+
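To make the masking logic above easier to follow, here is a standalone sketch (toy sizes, assumed only for illustration) that rebuilds the same denoising attention mask: matching queries are blocked from attending to the denoising part, and each positive/negative denoising group is blocked from seeing the other groups.

import torch

single_pad, dn_number, num_queries = 2, 3, 4          # toy values, not the training defaults
pad_size = single_pad * 2 * dn_number                  # positive + negative queries over all groups
tgt_size = pad_size + num_queries

attn_mask = torch.zeros(tgt_size, tgt_size, dtype=torch.bool)   # True = attention blocked
attn_mask[pad_size:, :pad_size] = True                 # matching queries cannot see dn queries
for i in range(dn_number):                             # dn groups cannot see each other
    lo, hi = single_pad * 2 * i, single_pad * 2 * (i + 1)
    attn_mask[lo:hi, hi:pad_size] = True
    attn_mask[lo:hi, :lo] = True
print(attn_mask.int())                                 # block-diagonal over the dn part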
linea/models/linea/position_encoding.py ADDED
@@ -0,0 +1,150 @@
1
+ # ------------------------------------------------------------------------
2
+ # DINO
3
+ # Copyright (c) 2022 IDEA. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Conditional DETR
7
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
8
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
9
+ # ------------------------------------------------------------------------
10
+ # Copied from DETR (https://github.com/facebookresearch/detr)
11
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
12
+ # ------------------------------------------------------------------------
13
+
14
+ """
15
+ Various positional encodings for the transformer.
16
+ """
17
+ import math
18
+ import torch
19
+ from torch import nn
20
+
21
+ from util.misc import NestedTensor
22
+
23
+
24
+ class PositionEmbeddingSine(nn.Module):
25
+ """
26
+ This is a more standard version of the position embedding, very similar to the one
27
+ used by the Attention is all you need paper, generalized to work on images.
28
+ """
29
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
30
+ super().__init__()
31
+ self.num_pos_feats = num_pos_feats
32
+ self.temperature = temperature
33
+ self.normalize = normalize
34
+ if scale is not None and normalize is False:
35
+ raise ValueError("normalize should be True if scale is passed")
36
+ if scale is None:
37
+ scale = 2 * math.pi
38
+ self.scale = scale
39
+
40
+ def forward(self, tensor_list: NestedTensor):
41
+ x = tensor_list.tensors
42
+ mask = tensor_list.mask
43
+ assert mask is not None
44
+ not_mask = ~mask
45
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
46
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
47
+ if self.normalize:
48
+ eps = 1e-6
49
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
50
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
51
+
52
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
53
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
54
+
55
+ pos_x = x_embed[:, :, :, None] / dim_t
56
+ pos_y = y_embed[:, :, :, None] / dim_t
57
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
58
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
59
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
60
+ return pos
61
+
62
+ class PositionEmbeddingSineHW(nn.Module):
63
+ """
64
+ This is a more standard version of the position embedding, very similar to the one
65
+ used by the Attention is all you need paper, generalized to work on images.
66
+ """
67
+ def __init__(self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None):
68
+ super().__init__()
69
+ self.num_pos_feats = num_pos_feats
70
+ self.temperatureH = temperatureH
71
+ self.temperatureW = temperatureW
72
+ self.normalize = normalize
73
+ if scale is not None and normalize is False:
74
+ raise ValueError("normalize should be True if scale is passed")
75
+ if scale is None:
76
+ scale = 2 * math.pi
77
+ self.scale = scale
78
+
79
+ def forward(self, tensor_list: NestedTensor):
80
+ x = tensor_list.tensors
81
+ mask = tensor_list.mask
82
+ assert mask is not None
83
+ not_mask = ~mask
84
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
85
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
86
+
87
+
88
+
89
+ if self.normalize:
90
+ eps = 1e-6
91
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
92
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
93
+
94
+ dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
95
+ dim_tx = self.temperatureW ** (2 * (dim_tx // 2) / self.num_pos_feats)
96
+ pos_x = x_embed[:, :, :, None] / dim_tx
97
+
98
+ dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
99
+ dim_ty = self.temperatureH ** (2 * (dim_ty // 2) / self.num_pos_feats)
100
+ pos_y = y_embed[:, :, :, None] / dim_ty
101
+
102
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
103
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
104
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
105
+ return pos
106
+
107
+ class PositionEmbeddingLearned(nn.Module):
108
+ """
109
+ Absolute pos embedding, learned.
110
+ """
111
+ def __init__(self, num_pos_feats=256):
112
+ super().__init__()
113
+ self.row_embed = nn.Embedding(50, num_pos_feats)
114
+ self.col_embed = nn.Embedding(50, num_pos_feats)
115
+ self.reset_parameters()
116
+
117
+ def reset_parameters(self):
118
+ nn.init.uniform_(self.row_embed.weight)
119
+ nn.init.uniform_(self.col_embed.weight)
120
+
121
+ def forward(self, tensor_list: NestedTensor):
122
+ x = tensor_list.tensors
123
+ h, w = x.shape[-2:]
124
+ i = torch.arange(w, device=x.device)
125
+ j = torch.arange(h, device=x.device)
126
+ x_emb = self.col_embed(i)
127
+ y_emb = self.row_embed(j)
128
+ pos = torch.cat([
129
+ x_emb.unsqueeze(0).repeat(h, 1, 1),
130
+ y_emb.unsqueeze(1).repeat(1, w, 1),
131
+ ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
132
+ return pos
133
+
134
+
135
+ def build_position_encoding(args):
136
+ N_steps = args.hidden_dim // 2
137
+ if args.position_embedding in ('v2', 'sine'):
138
+ # TODO find a better way of exposing other arguments
139
+ position_embedding = PositionEmbeddingSineHW(
140
+ N_steps,
141
+ temperatureH=args.pe_temperatureH,
142
+ temperatureW=args.pe_temperatureW,
143
+ normalize=True
144
+ )
145
+ elif args.position_embedding in ('v3', 'learned'):
146
+ position_embedding = PositionEmbeddingLearned(N_steps)
147
+ else:
148
+ raise ValueError(f"not supported {args.position_embedding}")
149
+
150
+ return position_embedding
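A quick sketch of the HW-temperature sine embedding in use (illustrative; a tiny stand-in object with .tensors/.mask replaces util.misc.NestedTensor, and the sizes are arbitrary):

import torch
from types import SimpleNamespace

feats = torch.randn(2, 256, 20, 20)                    # dummy feature map
mask = torch.zeros(2, 20, 20, dtype=torch.bool)        # no padding anywhere
dummy = SimpleNamespace(tensors=feats, mask=mask)      # stands in for NestedTensor

pos_embed = PositionEmbeddingSineHW(num_pos_feats=128, temperatureH=20, temperatureW=20, normalize=True)
pos = pos_embed(dummy)
print(pos.shape)                                       # torch.Size([2, 256, 20, 20]), i.e. 2 * num_pos_feats channels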
linea/models/linea/utils.py ADDED
@@ -0,0 +1,139 @@
1
+ # ------------------------------------------------------------------------
2
+ # DINO
3
+ # Copyright (c) 2022 IDEA. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+
7
+ import torch
8
+ from torch import nn, Tensor
9
+
10
+ import math
11
+ import torch.nn.functional as F
12
+ from torch import nn
13
+
14
+
15
+ def gen_encoder_output_proposals(memory:Tensor, spatial_shapes:Tensor):
16
+ """
17
+ Input:
18
+ - memory: bs, \sum{hw}, d_model
19
+ - memory_padding_mask: bs, \sum{hw}
20
+ - spatial_shapes: nlevel, 2
21
+ - learnedwh: 2
22
+ Output:
23
+ - output_memory: bs, \sum{hw}, d_model
24
+ - output_proposals: bs, \sum{hw}, 4
25
+ """
26
+ N_, S_, C_ = memory.shape
27
+ base_scale = 4.0
28
+ proposals = []
29
+
30
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
31
+
32
+ grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
33
+ torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device))
34
+ grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2
35
+
36
+ scale = torch.tensor([W_, H_], dtype=torch.float32, device=memory.device).view(1, 1, 1, 2).repeat(N_, 1, 1, 1)
37
+ grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
38
+
39
+ # wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)
40
+
41
+ # proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
42
+ proposal = torch.cat((grid, grid), -1).view(N_, -1, 4)
43
+ proposals.append(proposal)
44
+
45
+ output_proposals = torch.cat(proposals, 1)
46
+ output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
47
+ output_proposals = torch.log(output_proposals / (1 - output_proposals)) # unsigmoid
48
+ output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))
49
+
50
+ output_memory = memory
51
+ output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
52
+
53
+ return output_memory, output_proposals
54
+
55
+
56
+ def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
57
+ """
58
+ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
59
+ Args:
60
+ inputs: A float tensor of arbitrary shape.
61
+ The predictions for each example.
62
+ targets: A float tensor with the same shape as inputs. Stores the binary
63
+ classification label for each element in inputs
64
+ (0 for the negative class and 1 for the positive class).
65
+ alpha: (optional) Weighting factor in range (0,1) to balance
66
+ positive vs negative examples. Default = 0.25.
67
+ gamma: Exponent of the modulating factor (1 - p_t) to
68
+ balance easy vs hard examples.
69
+ Returns:
70
+ Loss tensor
71
+ """
72
+ prob = inputs.sigmoid()
73
+ ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
74
+ p_t = prob * targets + (1 - prob) * (1 - targets)
75
+ loss = ce_loss * ((1 - p_t) ** gamma)
76
+
77
+ if alpha >= 0:
78
+ alpha_t = alpha * targets + (1-alpha) * (1 - targets)
79
+ loss = alpha_t * loss
80
+
81
+ return loss.mean(1).sum() / num_boxes
82
+
83
+
84
+ class MLP(nn.Module):
85
+ """ Very simple multi-layer perceptron (also called FFN)"""
86
+
87
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
88
+ super().__init__()
89
+ self.num_layers = num_layers
90
+ h = [hidden_dim] * (num_layers - 1)
91
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
92
+
93
+ def forward(self, x):
94
+ for i, layer in enumerate(self.layers):
95
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
96
+ return x
97
+
98
+
99
+ def _get_activation_fn(activation, d_model=256, batch_dim=0):
100
+ """Return an activation function given a string"""
101
+ if activation == "relu":
102
+ return F.relu
103
+ if activation == "gelu":
104
+ return F.gelu
105
+ if activation == "glu":
106
+ return F.glu
107
+ if activation == "prelu":
108
+ return nn.PReLU()
109
+ if activation == "selu":
110
+ return F.selu
111
+
112
+ raise RuntimeError(f"activation should be relu/gelu/glu/prelu/selu, not {activation}.")
113
+
114
+
115
+ def gen_sineembed_for_position(pos_tensor, hidden_dim):
116
+ # n_query, bs, _ = pos_tensor.size()
117
+ # sineembed_tensor = torch.zeros(n_query, bs, 256)
118
+ hidden_dim_ = hidden_dim // 2
119
+ scale = 2 * math.pi
120
+ dim_t = torch.arange(hidden_dim_, dtype=torch.float32, device=pos_tensor.device)
121
+ dim_t = 10000 ** (2 * (dim_t // 2) / hidden_dim_)
122
+ x_embed = pos_tensor[:, :, 0] * scale
123
+ y_embed = pos_tensor[:, :, 1] * scale
124
+ pos_x = x_embed[:, :, None] / dim_t
125
+ pos_y = y_embed[:, :, None] / dim_t
126
+ pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
127
+ pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
128
+
129
+ w_embed = pos_tensor[:, :, 2] * scale
130
+ pos_w = w_embed[:, :, None] / dim_t
131
+ pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
132
+
133
+ h_embed = pos_tensor[:, :, 3] * scale
134
+ pos_h = h_embed[:, :, None] / dim_t
135
+ pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
136
+
137
+ pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
138
+ return pos
139
+
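Two small shape checks for the helpers above (dimensions assumed only for illustration): the MLP head maps per-query features to 4 line coordinates, and gen_sineembed_for_position turns normalized reference lines into sine query embeddings.

import torch

head = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
queries = torch.randn(2, 300, 256)                 # [batch, num_queries, hidden_dim]
print(head(queries).shape)                         # torch.Size([2, 300, 4])

ref_lines = torch.rand(300, 2, 4)                  # [num_queries, batch, 4], values in [0, 1]
query_pos = gen_sineembed_for_position(ref_lines, hidden_dim=256)
print(query_pos.shape)                             # torch.Size([300, 2, 512]), i.e. 2 * hidden_dim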
linea/models/registry.py ADDED
@@ -0,0 +1,58 @@
1
+ # -*- coding: utf-8 -*-
2
+ # @Author: Yihao Chen
3
+ # @Date: 2021-08-16 16:03:17
4
+ # @Last Modified by: Shilong Liu
5
+ # @Last Modified time: 2022-01-23 15:26
6
+ # modified from mmcv
7
+
8
+ import inspect
9
+ from functools import partial
10
+
11
+
12
+ class Registry(object):
13
+
14
+ def __init__(self, name):
15
+ self._name = name
16
+ self._module_dict = dict()
17
+
18
+ def __repr__(self):
19
+ format_str = self.__class__.__name__ + '(name={}, items={})'.format(
20
+ self._name, list(self._module_dict.keys()))
21
+ return format_str
22
+
23
+ def __len__(self):
24
+ return len(self._module_dict)
25
+
26
+ @property
27
+ def name(self):
28
+ return self._name
29
+
30
+ @property
31
+ def module_dict(self):
32
+ return self._module_dict
33
+
34
+ def get(self, key):
35
+ return self._module_dict.get(key, None)
36
+
37
+ def registe_with_name(self, module_name=None, force=False):
38
+ return partial(self.register, module_name=module_name, force=force)
39
+
40
+ def register(self, module_build_function, module_name=None, force=False):
41
+ """Register a module build function.
42
+ Args:
43
+ module (:obj:`nn.Module`): Module to be registered.
44
+ """
45
+ if not inspect.isfunction(module_build_function):
46
+ raise TypeError('module_build_function must be a function, but got {}'.format(
47
+ type(module_build_function)))
48
+ if module_name is None:
49
+ module_name = module_build_function.__name__
50
+ if not force and module_name in self._module_dict:
51
+ raise KeyError('{} is already registered in {}'.format(
52
+ module_name, self.name))
53
+ self._module_dict[module_name] = module_build_function
54
+
55
+ return module_build_function
56
+
57
+ MODULE_BUILD_FUNCS = Registry('model build functions')
58
+
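A usage sketch for the registry (the model name 'linea' and the build function body are placeholders; the decorator name mirrors the registe_with_name helper defined above):

from linea.models.registry import MODULE_BUILD_FUNCS   # import path assumed from this upload

@MODULE_BUILD_FUNCS.registe_with_name(module_name='linea')
def build_linea(args):
    ...   # placeholder: build and return the model (and criterion) from the config args

assert MODULE_BUILD_FUNCS.get('linea') is build_linea
print(MODULE_BUILD_FUNCS)   # Registry(name=model build functions, items=['linea'])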
linea/requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ torch>=2.0.1
2
+ torchvision>=0.15.2
3
+ scipy
4
+ calflops
5
+ transformers
6
+ tensorboardx
7
+ addict
8
+ yapf
9
+ pycocotools
linea/util/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
linea/util/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (151 Bytes). View file
 
linea/util/__pycache__/misc.cpython-311.pyc ADDED
Binary file (15.2 kB). View file
 
linea/util/__pycache__/slconfig.cpython-311.pyc ADDED
Binary file (24.6 kB). View file
 
linea/util/get_param_dicts.py ADDED
@@ -0,0 +1,35 @@
1
+ import json
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ import re
6
+
7
+
8
+ def get_optim_params(cfg: list, model: nn.Module):
9
+ """
10
+ E.g.:
11
+ ^(?=.*a)(?=.*b).*$ means including a and b
12
+ ^(?=.*(?:a|b)).*$ means including a or b
13
+ ^(?=.*a)(?!.*b).*$ means including a, but not b
14
+ """
15
+
16
+ param_groups = []
17
+ visited = []
18
+ for pg in cfg:
19
+ pattern = pg['params']
20
+ params = {k: v for k, v in model.named_parameters() if v.requires_grad and len(re.findall(pattern, k)) > 0}
21
+ pg['params'] = params.values()
22
+ param_groups.append(pg)
23
+ visited.extend(list(params.keys()))
24
+
25
+ names = [k for k, v in model.named_parameters() if v.requires_grad]
26
+
27
+ if len(visited) < len(names):
28
+ unseen = set(names) - set(visited)
29
+ params = {k: v for k, v in model.named_parameters() if v.requires_grad and k in unseen}
30
+ param_groups.append({'params': params.values()})
31
+ visited.extend(list(params.keys()))
32
+
33
+ assert len(visited) == len(names), 'some trainable parameters were not assigned to an optimizer group'
34
+
35
+ return param_groups
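A sketch of how get_optim_params can be driven by regex-keyed parameter groups (the pattern, learning rates, and toy model below are illustrative, not the values used by the LINEA training configs):

import torch
import torch.nn as nn

model = nn.Sequential()                                  # stand-in for the real detector
model.add_module('backbone', nn.Linear(8, 8))
model.add_module('head', nn.Linear(8, 4))

cfg = [{'params': '^(?=.*backbone).*$', 'lr': 1e-5}]     # backbone params get a lower lr
param_groups = get_optim_params(cfg, model)               # remaining params fall into a default group
optimizer = torch.optim.AdamW(param_groups, lr=1e-4, weight_decay=1e-4)
print([len(g['params']) for g in optimizer.param_groups])   # [2, 2]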
linea/util/misc.py ADDED
@@ -0,0 +1,275 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Misc functions, including distributed helpers.
4
+
5
+ Mostly copy-paste from torchvision references.
6
+ """
7
+ import os
8
+ import time
9
+ from collections import defaultdict, deque
10
+ import datetime
11
+ from typing import Optional, List
12
+
13
+ import torch
14
+ import torch.distributed as dist
15
+ from torch import Tensor
16
+
17
+
18
+ class SmoothedValue(object):
19
+ """Track a series of values and provide access to smoothed values over a
20
+ window or the global series average.
21
+ """
22
+
23
+ def __init__(self, window_size=20, fmt=None):
24
+ if fmt is None:
25
+ fmt = "{median:.4f} ({global_avg:.4f})"
26
+ self.deque = deque(maxlen=window_size)
27
+ self.total = 0.0
28
+ self.count = 0
29
+ self.fmt = fmt
30
+
31
+ def update(self, value, n=1):
32
+ self.deque.append(value)
33
+ self.count += n
34
+ self.total += value * n
35
+
36
+ def synchronize_between_processes(self):
37
+ """
38
+ Warning: does not synchronize the deque!
39
+ """
40
+ if not is_dist_avail_and_initialized():
41
+ return
42
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
43
+ dist.barrier()
44
+ dist.all_reduce(t)
45
+ t = t.tolist()
46
+ self.count = int(t[0])
47
+ self.total = t[1]
48
+
49
+ @property
50
+ def median(self):
51
+ d = torch.tensor(list(self.deque))
52
+ if d.shape[0] == 0:
53
+ return 0
54
+ return d.median().item()
55
+
56
+ @property
57
+ def avg(self):
58
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
59
+ return d.mean().item()
60
+
61
+ @property
62
+ def global_avg(self):
63
+ return self.total / self.count
64
+
65
+ @property
66
+ def max(self):
67
+ return max(self.deque)
68
+
69
+ @property
70
+ def value(self):
71
+ return self.deque[-1]
72
+
73
+ def __str__(self):
74
+ return self.fmt.format(
75
+ median=self.median,
76
+ avg=self.avg,
77
+ global_avg=self.global_avg,
78
+ max=self.max,
79
+ value=self.value)
80
+
81
+
82
+ def reduce_dict(input_dict, average=True):
83
+ """
84
+ Args:
85
+ input_dict (dict): all the values will be reduced
86
+ average (bool): whether to do average or sum
87
+ Reduce the values in the dictionary from all processes so that all processes
88
+ have the averaged results. Returns a dict with the same fields as
89
+ input_dict, after reduction.
90
+ """
91
+ world_size = get_world_size()
92
+ if world_size < 2:
93
+ return input_dict
94
+ with torch.no_grad():
95
+ names = []
96
+ values = []
97
+ # sort the keys so that they are consistent across processes
98
+ for k in sorted(input_dict.keys()):
99
+ names.append(k)
100
+ values.append(input_dict[k])
101
+ values = torch.stack(values, dim=0)
102
+ dist.all_reduce(values)
103
+ if average:
104
+ values /= world_size
105
+ reduced_dict = {k: v for k, v in zip(names, values)}
106
+ return reduced_dict
107
+
108
+
109
+ class MetricLogger(object):
110
+ def __init__(self, delimiter="\t"):
111
+ self.meters = defaultdict(SmoothedValue)
112
+ self.delimiter = delimiter
113
+
114
+ def update(self, **kwargs):
115
+ for k, v in kwargs.items():
116
+ if isinstance(v, torch.Tensor):
117
+ v = v.item()
118
+ assert isinstance(v, (float, int))
119
+ self.meters[k].update(v)
120
+
121
+ def __getattr__(self, attr):
122
+ if attr in self.meters:
123
+ return self.meters[attr]
124
+ if attr in self.__dict__:
125
+ return self.__dict__[attr]
126
+ raise AttributeError("'{}' object has no attribute '{}'".format(
127
+ type(self).__name__, attr))
128
+
129
+ def __str__(self):
130
+ loss_str = []
131
+ for name, meter in self.meters.items():
132
+ # print(name, str(meter))
133
+ # import ipdb;ipdb.set_trace()
134
+ if meter.count > 0:
135
+ loss_str.append(
136
+ "{}: {}".format(name, str(meter))
137
+ )
138
+ return self.delimiter.join(loss_str)
139
+
140
+ def synchronize_between_processes(self):
141
+ for meter in self.meters.values():
142
+ meter.synchronize_between_processes()
143
+
144
+ def add_meter(self, name, meter):
145
+ self.meters[name] = meter
146
+
147
+ def log_every(self, iterable, print_freq, header=None, logger=None):
148
+ if logger is None:
149
+ print_func = print
150
+ else:
151
+ print_func = logger.info
152
+
153
+ i = 0
154
+ if not header:
155
+ header = ''
156
+ start_time = time.time()
157
+ end = time.time()
158
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
159
+ data_time = SmoothedValue(fmt='{avg:.4f}')
160
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
161
+ if torch.cuda.is_available():
162
+ log_msg = self.delimiter.join([
163
+ header,
164
+ '[{0' + space_fmt + '}/{1}]',
165
+ 'eta: {eta}',
166
+ '{meters}',
167
+ 'time: {time}',
168
+ 'data: {data}',
169
+ 'max mem: {memory:.0f}'
170
+ ])
171
+ else:
172
+ log_msg = self.delimiter.join([
173
+ header,
174
+ '[{0' + space_fmt + '}/{1}]',
175
+ 'eta: {eta}',
176
+ '{meters}',
177
+ 'time: {time}',
178
+ 'data: {data}'
179
+ ])
180
+ MB = 1024.0 * 1024.0
181
+ for obj in iterable:
182
+ data_time.update(time.time() - end)
183
+ yield obj
184
+
185
+ iter_time.update(time.time() - end)
186
+ if i % print_freq == 0 or i == len(iterable) - 1:
187
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
188
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
189
+ if torch.cuda.is_available():
190
+ print_func(log_msg.format(
191
+ i, len(iterable), eta=eta_string,
192
+ meters=str(self),
193
+ time=str(iter_time), data=str(data_time),
194
+ memory=torch.cuda.max_memory_allocated() / MB))
195
+ else:
196
+ print_func(log_msg.format(
197
+ i, len(iterable), eta=eta_string,
198
+ meters=str(self),
199
+ time=str(iter_time), data=str(data_time)))
200
+ i += 1
201
+ end = time.time()
202
+ total_time = time.time() - start_time
203
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
204
+ print_func('{} Total time: {} ({:.4f} s / it)'.format(
205
+ header, total_time_str, total_time / len(iterable)))
206
+
207
+
208
+ def setup_for_distributed(is_master):
209
+ """
210
+ This function disables printing when not in master process
211
+ """
212
+ import builtins as __builtin__
213
+ builtin_print = __builtin__.print
214
+
215
+ def print(*args, **kwargs):
216
+ force = kwargs.pop('force', False)
217
+ if is_master or force:
218
+ builtin_print(*args, **kwargs)
219
+
220
+ __builtin__.print = print
221
+
222
+
223
+ def is_dist_avail_and_initialized():
224
+ if not dist.is_available():
225
+ return False
226
+ if not dist.is_initialized():
227
+ return False
228
+ return True
229
+
230
+
231
+ def get_world_size():
232
+ if not is_dist_avail_and_initialized():
233
+ return 1
234
+ return dist.get_world_size()
235
+
236
+
237
+ def get_rank():
238
+ if not is_dist_avail_and_initialized():
239
+ return 0
240
+ return dist.get_rank()
241
+
242
+
243
+ def is_main_process():
244
+ return get_rank() == 0
245
+
246
+
247
+ def save_on_master(*args, **kwargs):
248
+ if is_main_process():
249
+ torch.save(*args, **kwargs)
250
+
251
+
252
+ def init_distributed_mode(args):
253
+ try:
254
+ # https://pytorch.org/docs/stable/elastic/run.html
255
+ RANK = int(os.getenv('RANK', -1))
256
+ args.gpu = LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))
257
+ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
258
+
259
+ torch.distributed.init_process_group(init_method='env://')
260
+ torch.distributed.barrier()
261
+
262
+ rank = torch.distributed.get_rank()
263
+ torch.cuda.set_device(rank)
264
+ torch.cuda.empty_cache()
265
+ args.distributed = True
266
+ setup_for_distributed(get_rank() == 0)
267
+ print('Initialized distributed mode...')
268
+ except:
269
+ print('Not using distributed mode')
270
+ args.distributed = False
271
+ args.world_size = 1
272
+ args.rank = 0
273
+ args.local_rank = 0
274
+ return
275
+
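A small single-process sketch of the logging helpers above: SmoothedValue keeps a windowed median/average of a scalar, and MetricLogger aggregates several of them per training iteration (the values below are dummies).

logger = MetricLogger(delimiter='  ')
for step in range(5):
    logger.update(loss=0.5 / (step + 1), class_error=1.0)   # dummy scalars
print(str(logger))                        # e.g. "loss: 0.1667 (0.2283)  class_error: 1.0000 (1.0000)"
print(logger.meters['loss'].global_avg)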
linea/util/profiler.py ADDED
@@ -0,0 +1,21 @@
1
+ import copy
2
+ from calflops import calculate_flops
3
+ from typing import Tuple
4
+
5
+ def stats(
6
+ model, args,
7
+ input_shape: Tuple=(1, 3, 640, 640), ) -> Tuple[int, dict]:
8
+
9
+ base_size = args.eval_spatial_size[0]
10
+ input_shape = (1, 3, base_size, base_size)
11
+
12
+ model_for_info = copy.deepcopy(model).deploy()
13
+
14
+ flops, macs, _ = calculate_flops(model=model_for_info,
15
+ input_shape=input_shape,
16
+ output_as_string=True,
17
+ output_precision=4,
18
+ print_detailed=False)
19
+ params = sum(p.numel() for p in model_for_info.parameters())
20
+ del model_for_info
21
+ return {'flops': flops, 'macs': macs, 'params': params}
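A hedged sketch of stats() on a toy module (requires calflops to be installed; the real call passes the built LINEA model, whose classes expose a deploy() method, together with the config/argparse namespace):

import torch
from types import SimpleNamespace

class TinyNet(torch.nn.Module):                 # stand-in exposing the deploy() hook stats() expects
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3, padding=1)
    def forward(self, x):
        return self.conv(x)
    def deploy(self):
        return self

args = SimpleNamespace(eval_spatial_size=(640, 640))
info = stats(TinyNet(), args)
print(info['flops'], info['macs'], info['params'])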
linea/util/slconfig.py ADDED
@@ -0,0 +1,440 @@
1
+ # ==========================================================
2
+ # Modified from mmcv
3
+ # ==========================================================
4
+ import os, sys
5
+ import os.path as osp
6
+ import ast
7
+ import tempfile
8
+ import shutil
9
+ from importlib import import_module
10
+
11
+ from argparse import Action
12
+
13
+ from addict import Dict
14
+ from yapf.yapflib.yapf_api import FormatCode
15
+
16
+ import platform
17
+ MACOS, LINUX, WINDOWS = (platform.system() == x for x in ['Darwin', 'Linux', 'Windows']) # environment booleans
18
+
19
+ BASE_KEY = '_base_'
20
+ DELETE_KEY = '_delete_'
21
+ RESERVED_KEYS = ['filename', 'text', 'pretty_text', 'get', 'dump', 'merge_from_dict']
22
+
23
+
24
+ def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
25
+ if not osp.isfile(filename):
26
+ raise FileNotFoundError(msg_tmpl.format(filename))
27
+
28
+ class ConfigDict(Dict):
29
+
30
+ def __missing__(self, name):
31
+ raise KeyError(name)
32
+
33
+ def __getattr__(self, name):
34
+ try:
35
+ value = super(ConfigDict, self).__getattr__(name)
36
+ except KeyError:
37
+ ex = AttributeError(f"'{self.__class__.__name__}' object has no "
38
+ f"attribute '{name}'")
39
+ except Exception as e:
40
+ ex = e
41
+ else:
42
+ return value
43
+ raise ex
44
+
45
+
46
+ class SLConfig(object):
47
+ """
48
+ config files.
49
+ only support .py file as config now.
50
+
51
+ ref: mmcv.utils.config
52
+
53
+ Example:
54
+ >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
55
+ >>> cfg.a
56
+ 1
57
+ >>> cfg.b
58
+ {'b1': [0, 1]}
59
+ >>> cfg.b.b1
60
+ [0, 1]
61
+ >>> cfg = Config.fromfile('tests/data/config/a.py')
62
+ >>> cfg.filename
63
+ "/home/kchen/projects/mmcv/tests/data/config/a.py"
64
+ >>> cfg.item4
65
+ 'test'
66
+ >>> cfg
67
+ "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
68
+ "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
69
+ """
70
+ @staticmethod
71
+ def _validate_py_syntax(filename):
72
+ with open(filename) as f:
73
+ content = f.read()
74
+ try:
75
+ ast.parse(content)
76
+ except SyntaxError:
77
+ raise SyntaxError('There are syntax errors in config '
78
+ f'file {filename}')
79
+
80
+ @staticmethod
81
+ def _file2dict(filename):
82
+ filename = osp.abspath(osp.expanduser(filename))
83
+ check_file_exist(filename)
84
+ if filename.lower().endswith('.py'):
85
+ with tempfile.TemporaryDirectory() as temp_config_dir:
86
+ temp_config_file = tempfile.NamedTemporaryFile(
87
+ dir=temp_config_dir, suffix='.py')
88
+ temp_config_name = osp.basename(temp_config_file.name)
89
+ if WINDOWS:
90
+ temp_config_file.close()
91
+ shutil.copyfile(filename,
92
+ osp.join(temp_config_dir, temp_config_name))
93
+ temp_module_name = osp.splitext(temp_config_name)[0]
94
+ sys.path.insert(0, temp_config_dir)
95
+ SLConfig._validate_py_syntax(filename)
96
+ mod = import_module(temp_module_name)
97
+ sys.path.pop(0)
98
+ cfg_dict = {
99
+ name: value
100
+ for name, value in mod.__dict__.items()
101
+ if not name.startswith('__')
102
+ }
103
+ # delete imported module
104
+ del sys.modules[temp_module_name]
105
+ # close temp file
106
+ temp_config_file.close()
107
+ elif filename.lower().endswith(('.yml', '.yaml', '.json')):
108
+ from .slio import slload
109
+ cfg_dict = slload(filename)
110
+ else:
111
+ raise IOError('Only py/yml/yaml/json type are supported now!')
112
+
113
+ cfg_text = filename + '\n'
114
+ with open(filename, 'r') as f:
115
+ cfg_text += f.read()
116
+
117
+ # parse the base file
118
+ if BASE_KEY in cfg_dict:
119
+ cfg_dir = osp.dirname(filename)
120
+ base_filename = cfg_dict.pop(BASE_KEY)
121
+ base_filename = base_filename if isinstance(
122
+ base_filename, list) else [base_filename]
123
+
124
+ cfg_dict_list = list()
125
+ cfg_text_list = list()
126
+ for f in base_filename:
127
+ _cfg_dict, _cfg_text = SLConfig._file2dict(osp.join(cfg_dir, f))
128
+ cfg_dict_list.append(_cfg_dict)
129
+ cfg_text_list.append(_cfg_text)
130
+
131
+ base_cfg_dict = dict()
132
+ for c in cfg_dict_list:
133
+ if len(base_cfg_dict.keys() & c.keys()) > 0:
134
+ raise KeyError('Duplicate key is not allowed among bases')
135
+ # TODO Allow the duplicate key while warnning user
136
+ base_cfg_dict.update(c)
137
+
138
+ base_cfg_dict = SLConfig._merge_a_into_b(cfg_dict, base_cfg_dict)
139
+ cfg_dict = base_cfg_dict
140
+
141
+ # merge cfg_text
142
+ cfg_text_list.append(cfg_text)
143
+ cfg_text = '\n'.join(cfg_text_list)
144
+
145
+ return cfg_dict, cfg_text
146
+
147
+ @staticmethod
148
+ def _merge_a_into_b(a, b):
149
+ """merge dict `a` into dict `b` (non-inplace).
150
+ values in `a` will overwrite `b`.
151
+ copy first to avoid inplace modification
152
+
153
+ Args:
154
+ a ([type]): [description]
155
+ b ([type]): [description]
156
+
157
+ Returns:
158
+ [dict]: [description]
159
+ """
160
+
161
+ if not isinstance(a, dict):
162
+ return a
163
+
164
+ b = b.copy()
165
+ for k, v in a.items():
166
+ if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
167
+
168
+ if not isinstance(b[k], dict) and not isinstance(b[k], list):
169
+ # if :
170
+
171
+ raise TypeError(
172
+ f'{k}={v} in child config cannot inherit from base '
173
+ f'because {k} is a dict in the child config but is of '
174
+ f'type {type(b[k])} in base config. You may set '
175
+ f'`{DELETE_KEY}=True` to ignore the base config')
176
+ b[k] = SLConfig._merge_a_into_b(v, b[k])
177
+ elif isinstance(b, list):
178
+ try:
179
+ _ = int(k)
180
+ except:
181
+ raise TypeError(
182
+ f'b is a list, '
183
+ f'index {k} should be an int when input but {type(k)}'
184
+ )
185
+ b[int(k)] = SLConfig._merge_a_into_b(v, b[int(k)])
186
+ else:
187
+ b[k] = v
188
+
189
+ return b
190
+
191
+ @staticmethod
192
+ def fromfile(filename):
193
+ cfg_dict, cfg_text = SLConfig._file2dict(filename)
194
+ return SLConfig(cfg_dict, cfg_text=cfg_text, filename=filename)
195
+
196
+
197
+ def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
198
+ if cfg_dict is None:
199
+ cfg_dict = dict()
200
+ elif not isinstance(cfg_dict, dict):
201
+ raise TypeError('cfg_dict must be a dict, but '
202
+ f'got {type(cfg_dict)}')
203
+ for key in cfg_dict:
204
+ if key in RESERVED_KEYS:
205
+ raise KeyError(f'{key} is reserved for config file')
206
+
207
+ super(SLConfig, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
208
+ super(SLConfig, self).__setattr__('_filename', filename)
209
+ if cfg_text:
210
+ text = cfg_text
211
+ elif filename:
212
+ with open(filename, 'r') as f:
213
+ text = f.read()
214
+ else:
215
+ text = ''
216
+ super(SLConfig, self).__setattr__('_text', text)
217
+
218
+
219
+ @property
220
+ def filename(self):
221
+ return self._filename
222
+
223
+ @property
224
+ def text(self):
225
+ return self._text
226
+
227
+ @property
228
+ def pretty_text(self):
229
+
230
+ indent = 4
231
+
232
+ def _indent(s_, num_spaces):
233
+ s = s_.split('\n')
234
+ if len(s) == 1:
235
+ return s_
236
+ first = s.pop(0)
237
+ s = [(num_spaces * ' ') + line for line in s]
238
+ s = '\n'.join(s)
239
+ s = first + '\n' + s
240
+ return s
241
+
242
+ def _format_basic_types(k, v, use_mapping=False):
243
+ if isinstance(v, str):
244
+ v_str = f"'{v}'"
245
+ else:
246
+ v_str = str(v)
247
+
248
+ if use_mapping:
249
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
250
+ attr_str = f'{k_str}: {v_str}'
251
+ else:
252
+ attr_str = f'{str(k)}={v_str}'
253
+ attr_str = _indent(attr_str, indent)
254
+
255
+ return attr_str
256
+
257
+ def _format_list(k, v, use_mapping=False):
258
+ # check if all items in the list are dict
259
+ if all(isinstance(_, dict) for _ in v):
260
+ v_str = '[\n'
261
+ v_str += '\n'.join(
262
+ f'dict({_indent(_format_dict(v_), indent)}),'
263
+ for v_ in v).rstrip(',')
264
+ if use_mapping:
265
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
266
+ attr_str = f'{k_str}: {v_str}'
267
+ else:
268
+ attr_str = f'{str(k)}={v_str}'
269
+ attr_str = _indent(attr_str, indent) + ']'
270
+ else:
271
+ attr_str = _format_basic_types(k, v, use_mapping)
272
+ return attr_str
273
+
274
+ def _contain_invalid_identifier(dict_str):
275
+ contain_invalid_identifier = False
276
+ for key_name in dict_str:
277
+ contain_invalid_identifier |= \
278
+ (not str(key_name).isidentifier())
279
+ return contain_invalid_identifier
280
+
281
+ def _format_dict(input_dict, outest_level=False):
282
+ r = ''
283
+ s = []
284
+
285
+ use_mapping = _contain_invalid_identifier(input_dict)
286
+ if use_mapping:
287
+ r += '{'
288
+ for idx, (k, v) in enumerate(input_dict.items()):
289
+ is_last = idx >= len(input_dict) - 1
290
+ end = '' if outest_level or is_last else ','
291
+ if isinstance(v, dict):
292
+ v_str = '\n' + _format_dict(v)
293
+ if use_mapping:
294
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
295
+ attr_str = f'{k_str}: dict({v_str}'
296
+ else:
297
+ attr_str = f'{str(k)}=dict({v_str}'
298
+ attr_str = _indent(attr_str, indent) + ')' + end
299
+ elif isinstance(v, list):
300
+ attr_str = _format_list(k, v, use_mapping) + end
301
+ else:
302
+ attr_str = _format_basic_types(k, v, use_mapping) + end
303
+
304
+ s.append(attr_str)
305
+ r += '\n'.join(s)
306
+ if use_mapping:
307
+ r += '}'
308
+ return r
309
+
310
+ cfg_dict = self._cfg_dict.to_dict()
311
+ text = _format_dict(cfg_dict, outest_level=True)
312
+ # copied from setup.cfg
313
+ yapf_style = dict(
314
+ based_on_style='pep8',
315
+ blank_line_before_nested_class_or_def=True,
316
+ split_before_expression_after_opening_paren=True)
317
+ text, _ = FormatCode(text, style_config=yapf_style)#, verify=True)
318
+
319
+ return text
320
+
321
+
322
+ def __repr__(self):
323
+ return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'
324
+
325
+ def __len__(self):
326
+ return len(self._cfg_dict)
327
+
328
+ def __getattr__(self, name):
329
+ # # debug
330
+ # print('+'*15)
331
+ # print('name=%s' % name)
332
+ # print("addr:", id(self))
333
+ # # print('type(self):', type(self))
334
+ # print(self.__dict__)
335
+ # print('+'*15)
336
+ # if self.__dict__ == {}:
337
+ # raise ValueError
338
+
339
+ return getattr(self._cfg_dict, name)
340
+
341
+ def __getitem__(self, name):
342
+ return self._cfg_dict.__getitem__(name)
343
+
344
+ def __setattr__(self, name, value):
345
+ if isinstance(value, dict):
346
+ value = ConfigDict(value)
347
+ self._cfg_dict.__setattr__(name, value)
348
+
349
+ def __setitem__(self, name, value):
350
+ if isinstance(value, dict):
351
+ value = ConfigDict(value)
352
+ self._cfg_dict.__setitem__(name, value)
353
+
354
+ def __iter__(self):
355
+ return iter(self._cfg_dict)
356
+
357
+ def dump(self, file=None):
358
+
359
+ if file is None:
360
+ return self.pretty_text
361
+ else:
362
+ with open(file, 'w') as f:
363
+ f.write(self.pretty_text)
364
+
365
+ def merge_from_dict(self, options):
366
+ """Merge list into cfg_dict
367
+
368
+ Merge the dict parsed by MultipleKVAction into this cfg.
369
+
370
+ Examples:
371
+ >>> options = {'model.backbone.depth': 50,
372
+ ... 'model.backbone.with_cp':True}
373
+ >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
374
+ >>> cfg.merge_from_dict(options)
375
+ >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
376
+ >>> assert cfg_dict == dict(
377
+ ... model=dict(backbone=dict(depth=50, with_cp=True)))
378
+
379
+ Args:
380
+ options (dict): dict of configs to merge from.
381
+ """
382
+ option_cfg_dict = {}
383
+ for full_key, v in options.items():
384
+ d = option_cfg_dict
385
+ key_list = full_key.split('.')
386
+ for subkey in key_list[:-1]:
387
+ d.setdefault(subkey, ConfigDict())
388
+ d = d[subkey]
389
+ subkey = key_list[-1]
390
+ d[subkey] = v
391
+
392
+ cfg_dict = super(SLConfig, self).__getattribute__('_cfg_dict')
393
+ super(SLConfig, self).__setattr__(
394
+ '_cfg_dict', SLConfig._merge_a_into_b(option_cfg_dict, cfg_dict))
395
+
396
+ # for multiprocess
397
+ def __setstate__(self, state):
398
+ self.__init__(state)
399
+
400
+
401
+ def copy(self):
402
+ return SLConfig(self._cfg_dict.copy())
403
+
404
+ def deepcopy(self):
405
+ return SLConfig(self._cfg_dict.deepcopy())
406
+
407
+
408
+ class DictAction(Action):
409
+ """
410
+ argparse action to split an argument into KEY=VALUE form
411
+ on the first = and append to a dictionary. List options should
412
+ be passed as comma separated values, i.e KEY=V1,V2,V3
413
+ """
414
+
415
+ @staticmethod
416
+ def _parse_int_float_bool(val):
417
+ try:
418
+ return int(val)
419
+ except ValueError:
420
+ pass
421
+ try:
422
+ return float(val)
423
+ except ValueError:
424
+ pass
425
+ if val.lower() in ['true', 'false']:
426
+ return True if val.lower() == 'true' else False
427
+ if val.lower() in ['none', 'null']:
428
+ return None
429
+ return val
430
+
431
+ def __call__(self, parser, namespace, values, option_string=None):
432
+ options = {}
433
+ for kv in values:
434
+ key, val = kv.split('=', maxsplit=1)
435
+ val = [self._parse_int_float_bool(v) for v in val.split(',')]
436
+ if len(val) == 1:
437
+ val = val[0]
438
+ options[key] = val
439
+ setattr(namespace, self.dest, options)
440
+
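Finally, a sketch of how SLConfig and DictAction are typically combined (the config path and the override key are illustrative; the path is assumed to be resolved relative to the linea/ directory):

import argparse
from util.slconfig import SLConfig, DictAction   # import style used elsewhere in this repo

parser = argparse.ArgumentParser()
parser.add_argument('--config_file', default='configs/linea/linea_hgnetv2_l.py')
parser.add_argument('--options', nargs='+', action=DictAction,
                    help='override config entries, e.g. --options batch_size_train=4')
args = parser.parse_args(['--options', 'batch_size_train=4'])

cfg = SLConfig.fromfile(args.config_file)        # parses the .py config and any _base_ files
if args.options is not None:
    cfg.merge_from_dict(args.options)
print(cfg.batch_size_train)                      # -> 4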