File size: 46,983 Bytes
22b5634 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 |
# coding=utf-8
# @Author : Saurabhchand Bhati
# @Affiliation : Massachusetts Institute of Technology
# VMamba backbone is from https://github.com/MzeroMiko/VMamba/blob/main/vmamba.py
# VMambaLayer, VMambaModel, VMambaForImageClassification are implemented based on VMamba
# SS2Dv0, SS2Dv1, SS2S are merged into one class and initialization is limited to v05_noz,
# patch embeddings is limited to v2 and downsample is limited to v3.
# MIT License
# Copyright (c) 2024 MzeroMiko, Saurabhchand Bhati
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""VMamba: Visual State Space Model configuration model"""
import math
import torch
import warnings
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, trunc_normal_
from functools import partial
from typing import Optional, Callable, Any, Union
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
from transformers.modeling_outputs import ImageClassifierOutput
from transformers.utils import logging
from transformers.modeling_utils import PreTrainedModel
from .configuration_vmamba import VMambaConfig
# Module-level logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)
# General docstring
# Name of the configuration class as a string; presumably referenced when
# building docstrings elsewhere in the file -- TODO confirm against usage.
_CONFIG_FOR_DOC = "VMambaConfig"
# Triton is an optional dependency. WITH_TRITON records whether it could be
# imported; set it to False by hand to force the pure-PyTorch fallback.
WITH_TRITON = True
try:
    import triton
    import triton.language as tl
except Exception:
    # Best-effort import: any failure (missing package, broken install) selects
    # the PyTorch path, but KeyboardInterrupt/SystemExit are no longer swallowed.
    WITH_TRITON = False
    warnings.warn("Triton not installed, fall back to pytorch implements.")

# to make sure cached_property can be loaded for triton
if WITH_TRITON:
    try:
        from functools import cached_property
    except ImportError:
        warnings.warn("if you are using py37, add this line to functools.py: "
                      "cached_property = lambda func: property(lru_cache()(func))")
# torch implementation ========================================
def cross_scan_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
    """Expand one feature map into four flattened scan routes (pure PyTorch).

    Args:
        x: ``(B, C, H, W)`` when ``in_channel_first`` else ``(B, H, W, C)``.
        in_channel_first: layout of ``x``.
        out_channel_first: layout of the result.
        scans: 0 = row-major, column-major and their reversals;
            1 = four identical row-major copies;
            2 = two row-major copies followed by two reversed copies;
            3 = rotations by 0, 90, 180 and 270 degrees (``torch.rot90``).

    Returns:
        ``(B, 4, C, H * W)`` when ``out_channel_first`` else ``(B, H * W, 4, C)``.
    """
    if in_channel_first:
        B, C, H, W = x.shape
        L = H * W
        if scans == 0:
            y = x.new_empty((B, 4, C, L))
            y[:, 0] = x.flatten(2, 3)
            y[:, 1] = x.transpose(2, 3).flatten(2, 3)
            # routes 2 and 3 are routes 0 and 1 read backwards
            y[:, 2:4] = y[:, 0:2].flip(dims=[-1])
        elif scans == 1:
            y = x.view(B, 1, C, L).repeat(1, 4, 1, 1)
        elif scans == 2:
            forward_pair = x.view(B, 1, C, L).repeat(1, 2, 1, 1)
            y = torch.cat([forward_pair, forward_pair.flip(dims=[-1])], dim=1)
        elif scans == 3:
            y = x.new_empty((B, 4, C, L))
            for k in range(4):
                y[:, k] = torch.rot90(x, k, dims=(2, 3)).flatten(2, 3)
    else:
        B, H, W, C = x.shape
        L = H * W
        if scans == 0:
            y = x.new_empty((B, L, 4, C))
            y[:, :, 0] = x.flatten(1, 2)
            y[:, :, 1] = x.transpose(1, 2).flatten(1, 2)
            y[:, :, 2:4] = y[:, :, 0:2].flip(dims=[1])
        elif scans == 1:
            y = x.view(B, L, 1, C).repeat(1, 1, 4, 1)
        elif scans == 2:
            forward_pair = x.view(B, L, 1, C).repeat(1, 1, 2, 1)
            y = torch.cat([forward_pair, forward_pair.flip(dims=[1])], dim=2)
        elif scans == 3:
            y = x.new_empty((B, L, 4, C))
            for k in range(4):
                y[:, :, k] = torch.rot90(x, k, dims=(1, 2)).flatten(1, 2)

    # Convert between layouts only when input and output layouts differ.
    if in_channel_first and (not out_channel_first):
        y = y.permute(0, 3, 1, 2).contiguous()  # (B, 4, C, L) -> (B, L, 4, C)
    elif (not in_channel_first) and out_channel_first:
        y = y.permute(0, 2, 3, 1).contiguous()  # (B, L, 4, C) -> (B, 4, C, L)
    return y
def cross_merge_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
    """Sum four scan routes back into one flattened feature map (pure PyTorch).

    Each route is first mapped back to row-major order using the inverse of the
    permutation selected by ``scans`` (see ``cross_scan_fwd``), then the four
    routes are added.

    Args:
        y: ``(B, 4, D, H, W)`` when ``out_channel_first`` else ``(B, H, W, 4, D)``.
        in_channel_first: layout of the merged result.
        out_channel_first: layout of ``y``.
        scans: scan pattern id, 0..3.

    Returns:
        ``(B, D, H * W)`` when ``in_channel_first`` else ``(B, H * W, D)``.
    """
    if out_channel_first:
        B, K, D, H, W = y.shape
        y = y.view(B, K, D, -1)
        if scans == 0:
            # undo the reversal of routes 2/3, then undo the transpose of route 1
            folded = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
            col_major = folded[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3)
            y = folded[:, 0] + col_major.contiguous().view(B, D, -1)
        elif scans == 1:
            y = y.sum(1)
        elif scans == 2:
            folded = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
            y = folded.sum(1)
        elif scans == 3:
            as_wh = y.view(B, K, D, W, H)
            as_hw = y.view(B, K, D, H, W)
            merged = y[:, 0, :, :].contiguous().view(B, D, -1)
            merged = merged + torch.rot90(as_wh[:, 1, :, :, :], -1, dims=(2, 3)).flatten(2, 3)
            merged = merged + torch.rot90(as_hw[:, 2, :, :, :], -2, dims=(2, 3)).flatten(2, 3)
            merged = merged + torch.rot90(as_wh[:, 3, :, :, :], -3, dims=(2, 3)).flatten(2, 3)
            y = merged
    else:
        B, H, W, K, D = y.shape
        y = y.view(B, -1, K, D)
        if scans == 0:
            folded = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D)
            col_major = folded[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2)
            y = folded[:, :, 0] + col_major.contiguous().view(B, -1, D)
        elif scans == 1:
            y = y.sum(2)
        elif scans == 2:
            folded = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D)
            y = folded.sum(2)
        elif scans == 3:
            as_wh = y.view(B, W, H, K, D)
            as_hw = y.view(B, H, W, K, D)
            merged = y[:, :, 0, :].contiguous().view(B, -1, D)
            merged = merged + torch.rot90(as_wh[:, :, :, 1, :], -1, dims=(1, 2)).flatten(1, 2)
            merged = merged + torch.rot90(as_hw[:, :, :, 2, :], -2, dims=(1, 2)).flatten(1, 2)
            merged = merged + torch.rot90(as_wh[:, :, :, 3, :], -3, dims=(1, 2)).flatten(1, 2)
            y = merged

    # The merged tensor is 3-D, so both layout conversions are the same swap.
    if bool(in_channel_first) != bool(out_channel_first):
        y = y.permute(0, 2, 1).contiguous()
    return y
def cross_scan1b1_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
    """Apply one scan route to each of four already-separated feature maps.

    Unlike ``cross_scan_fwd``, the input already carries four maps; route ``k``
    is applied to map ``k`` only ("one by one").

    Args:
        x: ``(B, 4, C, H, W)`` when ``in_channel_first`` else ``(B, H, W, 4, C)``.
        in_channel_first: layout of ``x``.
        out_channel_first: layout of the result.
        scans: 0 = row-major / column-major / reversed row-major / reversed
            column-major; 1 = all row-major; 2 = two row-major then two
            reversed; 3 = rotations by 0, 90, 180, 270 degrees.

    Returns:
        ``(B, 4, C, H * W)`` when ``out_channel_first`` else ``(B, H * W, 4, C)``.
    """
    if in_channel_first:
        B, _, C, H, W = x.shape
        if scans == 0:
            y = torch.stack([
                x[:, 0].flatten(2, 3),
                x[:, 1].transpose(dim0=2, dim1=3).flatten(2, 3),
                torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
                torch.flip(x[:, 3].transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]),
            ], dim=1)
        elif scans == 1:
            # x is 5-D here, so the spatial dims are 3 and 4 (flattening 2..3
            # would merge C with H and yield the wrong shape).
            y = x.flatten(3, 4)
        elif scans == 2:
            y = torch.stack([
                x[:, 0].flatten(2, 3),
                x[:, 1].flatten(2, 3),
                torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
                torch.flip(x[:, 3].flatten(2, 3), dims=[-1]),
            ], dim=1)
        elif scans == 3:
            y = torch.stack([
                x[:, 0, :, :, :].flatten(2, 3),
                torch.rot90(x[:, 1, :, :, :], 1, dims=(2, 3)).flatten(2, 3),
                torch.rot90(x[:, 2, :, :, :], 2, dims=(2, 3)).flatten(2, 3),
                torch.rot90(x[:, 3, :, :, :], 3, dims=(2, 3)).flatten(2, 3),
            ], dim=1)
    else:
        B, H, W, _, C = x.shape
        if scans == 0:
            y = torch.stack([
                x[:, :, :, 0].flatten(1, 2),
                x[:, :, :, 1].transpose(dim0=1, dim1=2).flatten(1, 2),
                torch.flip(x[:, :, :, 2].flatten(1, 2), dims=[1]),
                torch.flip(x[:, :, :, 3].transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]),
            ], dim=2)
        elif scans == 1:
            y = x.flatten(1, 2)
        elif scans == 2:
            # The route index lives on dim 3 in channel-last layout and the
            # sequence axis is dim 1, so select with x[:, :, :, k] and flip dim 1.
            y = torch.stack([
                x[:, :, :, 0].flatten(1, 2),
                x[:, :, :, 1].flatten(1, 2),
                torch.flip(x[:, :, :, 2].flatten(1, 2), dims=[1]),
                torch.flip(x[:, :, :, 3].flatten(1, 2), dims=[1]),
            ], dim=2)
        elif scans == 3:
            # Stack on dim 2 so the result is (B, L, 4, C) like the other branches.
            y = torch.stack([
                x[:, :, :, 0, :].flatten(1, 2),
                torch.rot90(x[:, :, :, 1, :], 1, dims=(1, 2)).flatten(1, 2),
                torch.rot90(x[:, :, :, 2, :], 2, dims=(1, 2)).flatten(1, 2),
                torch.rot90(x[:, :, :, 3, :], 3, dims=(1, 2)).flatten(1, 2),
            ], dim=2)

    # Convert between layouts only when input and output layouts differ.
    if in_channel_first and (not out_channel_first):
        y = y.permute(0, 3, 1, 2).contiguous()  # (B, 4, C, L) -> (B, L, 4, C)
    elif (not in_channel_first) and out_channel_first:
        y = y.permute(0, 2, 3, 1).contiguous()  # (B, L, 4, C) -> (B, 4, C, L)
    return y
def cross_merge1b1_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
    """Map each of four scan routes back to row-major order without summing.

    This is the per-route inverse of ``cross_scan1b1_fwd``: the four routes stay
    separate in the result.

    Args:
        y: ``(B, 4, D, H, W)`` when ``out_channel_first`` else ``(B, H, W, 4, D)``.
        in_channel_first: layout of the result.
        out_channel_first: layout of ``y``.
        scans: scan pattern id, 0..3.

    Returns:
        ``(B, 4, D, H * W)`` when ``in_channel_first`` else ``(B, H * W, 4, D)``.
    """
    if out_channel_first:
        B, K, D, H, W = y.shape
        y = y.view(B, K, D, -1)
        if scans == 0:
            route1 = y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3)
            route3 = y[:, 3].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3)
            y = torch.stack(
                [y[:, 0], route1, torch.flip(y[:, 2], dims=[-1]), torch.flip(route3, dims=[-1])],
                dim=1,
            )
        elif scans == 1:
            pass  # every route is already row-major
        elif scans == 2:
            y = torch.stack(
                [y[:, 0], y[:, 1], torch.flip(y[:, 2], dims=[-1]), torch.flip(y[:, 3], dims=[-1])],
                dim=1,
            )
        elif scans == 3:
            as_wh = y.view(B, K, D, W, H)
            as_hw = y.view(B, K, D, H, W)
            y = torch.stack([
                y[:, 0, :, :].contiguous().view(B, D, -1),
                torch.rot90(as_wh[:, 1, :, :, :], -1, dims=(2, 3)).flatten(2, 3),
                torch.rot90(as_hw[:, 2, :, :, :], -2, dims=(2, 3)).flatten(2, 3),
                torch.rot90(as_wh[:, 3, :, :, :], -3, dims=(2, 3)).flatten(2, 3),
            ], dim=1)
    else:
        B, H, W, K, D = y.shape
        y = y.view(B, -1, K, D)
        if scans == 0:
            route1 = y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2)
            route3 = y[:, :, 3].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2)
            y = torch.stack(
                [y[:, :, 0], route1, torch.flip(y[:, :, 2], dims=[1]), torch.flip(route3, dims=[1])],
                dim=2,
            )
        elif scans == 1:
            pass  # every route is already row-major
        elif scans == 2:
            y = torch.stack(
                [y[:, :, 0], y[:, :, 1], torch.flip(y[:, :, 2], dims=[1]), torch.flip(y[:, :, 3], dims=[1])],
                dim=2,
            )
        elif scans == 3:
            as_wh = y.view(B, W, H, K, D)
            as_hw = y.view(B, H, W, K, D)
            y = torch.stack([
                y[:, :, 0, :].contiguous().view(B, -1, D),
                torch.rot90(as_wh[:, :, :, 1, :], -1, dims=(1, 2)).flatten(1, 2),
                torch.rot90(as_hw[:, :, :, 2, :], -2, dims=(1, 2)).flatten(1, 2),
                torch.rot90(as_wh[:, :, :, 3, :], -3, dims=(1, 2)).flatten(1, 2),
            ], dim=2)

    # Convert between layouts only when the two layouts differ.
    if out_channel_first and (not in_channel_first):
        y = y.permute(0, 3, 1, 2).contiguous()  # (B, 4, D, L) -> (B, L, 4, D)
    elif (not out_channel_first) and in_channel_first:
        y = y.permute(0, 2, 3, 1).contiguous()  # (B, L, 4, D) -> (B, 4, D, L)
    return y
class CrossScanF(torch.autograd.Function):
    """Autograd wrapper around the PyTorch cross-scan; its backward is the cross-merge."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
        """Scan ``x`` into four routes.

        x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
        returns: (B, 4, C, H * W) | (B, H * W, 4, C)
        """
        ctx.in_channel_first = in_channel_first
        ctx.out_channel_first = out_channel_first
        ctx.one_by_one = one_by_one
        ctx.scans = scans

        # Recover (B, C, H, W) from whichever of the four input layouts is in use.
        if one_by_one:
            if in_channel_first:
                B, _, C, H, W = x.shape
            else:
                B, H, W, _, C = x.shape
        elif in_channel_first:
            B, C, H, W = x.shape
        else:
            B, H, W, C = x.shape
        ctx.shape = (B, C, H, W)

        scan = cross_scan1b1_fwd if one_by_one else cross_scan_fwd
        return scan(x, in_channel_first, out_channel_first, scans)

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        """Merge the route gradients back into the layout of the forward input."""
        B, C, H, W = ctx.shape
        in_cf = ctx.in_channel_first
        out_cf = ctx.out_channel_first

        # The merge helpers expect the spatial dims unflattened.
        if out_cf:
            ys = ys.view(B, -1, C, H, W)
        else:
            ys = ys.view(B, H, W, -1, C)

        merge = cross_merge1b1_fwd if ctx.one_by_one else cross_merge_fwd
        grad = merge(ys, in_cf, out_cf, ctx.scans)

        if ctx.one_by_one:
            grad = grad.view(B, 4, -1, H, W) if in_cf else grad.view(B, H, W, 4, -1)
        else:
            grad = grad.view(B, -1, H, W) if in_cf else grad.view(B, H, W, -1)
        # Only the tensor input receives a gradient; the four flags do not.
        return grad, None, None, None, None
class CrossMergeF(torch.autograd.Function):
    """Autograd wrapper around the PyTorch cross-merge; its backward is the cross-scan."""

    @staticmethod
    def forward(ctx, ys: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
        """Merge four scan routes.

        ys: (B, 4, C, H, W) when ``out_channel_first`` else (B, H, W, 4, C)
        returns: whatever ``cross_merge_fwd`` / ``cross_merge1b1_fwd`` produce
        for the given flags.
        """
        ctx.in_channel_first = in_channel_first
        ctx.out_channel_first = out_channel_first
        ctx.one_by_one = one_by_one
        ctx.scans = scans

        if out_channel_first:
            B, _, C, H, W = ys.shape
        else:
            B, H, W, _, C = ys.shape
        ctx.shape = (B, C, H, W)

        merge = cross_merge1b1_fwd if one_by_one else cross_merge_fwd
        return merge(ys, in_channel_first, out_channel_first, scans)

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        """Scan the incoming gradient back into the four-route layout of ``ys``."""
        B, C, H, W = ctx.shape
        in_cf = ctx.in_channel_first
        out_cf = ctx.out_channel_first

        # Restore the spatial dims the scan helpers expect.
        if ctx.one_by_one:
            x = x.view(B, 4, C, H, W) if in_cf else x.view(B, H, W, 4, C)
        else:
            x = x.view(B, C, H, W) if in_cf else x.view(B, H, W, C)

        scan = cross_scan1b1_fwd if ctx.one_by_one else cross_scan_fwd
        grad = scan(x, in_cf, out_cf, ctx.scans)
        grad = grad.view(B, 4, C, H, W) if out_cf else grad.view(B, H, W, 4, C)
        # Only the tensor input receives a gradient; the four flags do not.
        return grad, None, None, None, None
# triton implements ========================================
# The kernel is defined only when Triton imported successfully. Both the
# decorator and the ``tl.*`` annotations are evaluated at definition time, so an
# unconditional definition would raise NameError and defeat the PyTorch
# fallback. The Triton autograd Functions that launch this kernel are only
# selected when WITH_TRITON is true.
if WITH_TRITON:
    @triton.jit
    def triton_cross_scan_flex(
        x: tl.tensor,  # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
        y: tl.tensor,  # (B, 4, C, H, W) | (B, H, W, 4, C)
        x_layout: tl.constexpr,
        y_layout: tl.constexpr,
        operation: tl.constexpr,
        onebyone: tl.constexpr,
        scans: tl.constexpr,
        BC: tl.constexpr,
        BH: tl.constexpr,
        BW: tl.constexpr,
        DC: tl.constexpr,
        DH: tl.constexpr,
        DW: tl.constexpr,
        NH: tl.constexpr,
        NW: tl.constexpr,
    ):
        # x_layout / y_layout: 0 BCHW, 1 BHWC
        # operation: 0 scan (read x, write y), 1 merge (read y, write x)
        # onebyone: 0 false (x holds one map), 1 true (x holds four maps)
        # scans: 0 cross scan, 1 unidirectional, 2 bidirectional, 3 rotations
        # BC/BH/BW: tile sizes; DC/DH/DW: tensor sizes; NH/NW: tiles per axis.
        # Grid: axis 0 = spatial tile, axis 1 = channel tile, axis 2 = batch.
        i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
        i_h, i_w = (i_hw // NW), (i_hw % NW)
        # Mask off the part of the tile that falls outside the H x W extent.
        _mask_h = (i_h * BH + tl.arange(0, BH)) < DH
        _mask_w = (i_w * BW + tl.arange(0, BW)) < DW
        _mask_hw = _mask_h[:, None] & _mask_w[None, :]
        # Number of channels this program handles (the last tile may be short).
        _for_C = min(DC - i_c * BC, BC)

        # Forward and mirrored coordinates of every element of the tile.
        pos_h = (i_h * BH + tl.arange(0, BH)[:, None])
        pos_w = (i_w * BW + tl.arange(0, BW)[None, :])
        neg_h = (DH - i_h * BH - 1 - tl.arange(0, BH)[:, None])
        neg_w = (DW - i_w * BW - 1 - tl.arange(0, BW)[None, :])
        # HWRoute<k>: flat sequence position of each tile element on route k.
        if scans == 0:
            # none; trans; flip; trans + flip;
            HWRoute0 = pos_h * DW + pos_w
            HWRoute1 = pos_w * DH + pos_h  # trans
            HWRoute2 = neg_h * DW + neg_w  # flip
            HWRoute3 = neg_w * DH + neg_h  # trans + flip
        elif scans == 1:
            # none; none; none; none;
            HWRoute0 = pos_h * DW + pos_w
            HWRoute1 = HWRoute0
            HWRoute2 = HWRoute0
            HWRoute3 = HWRoute0
        elif scans == 2:
            # none; none; flip; flip;
            HWRoute0 = pos_h * DW + pos_w
            HWRoute1 = HWRoute0
            HWRoute2 = neg_h * DW + neg_w  # flip
            HWRoute3 = HWRoute2
        elif scans == 3:
            # none; rot90; rot180==flip; rot270;
            HWRoute0 = pos_h * DW + pos_w
            HWRoute1 = neg_w * DH + pos_h
            HWRoute2 = neg_h * DW + neg_w
            HWRoute3 = pos_w * DH + neg_h

        _tmp1 = DC * DH * DW  # elements in one (C, H, W) map

        # Pointers to the four routes of y for the first channel of this tile.
        y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC)
        if y_layout == 0:
            p_y1 = y_ptr_base + HWRoute0
            p_y2 = y_ptr_base + _tmp1 + HWRoute1
            p_y3 = y_ptr_base + 2 * _tmp1 + HWRoute2
            p_y4 = y_ptr_base + 3 * _tmp1 + HWRoute3
        else:
            p_y1 = y_ptr_base + HWRoute0 * 4 * DC
            p_y2 = y_ptr_base + DC + HWRoute1 * 4 * DC
            p_y3 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC
            p_y4 = y_ptr_base + 3 * DC + HWRoute3 * 4 * DC

        if onebyone == 0:
            # x holds a single map: scan broadcasts it, merge sums the routes.
            x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
            if x_layout == 0:
                p_x = x_ptr_base + HWRoute0
            else:
                p_x = x_ptr_base + HWRoute0 * DC

            if operation == 0:
                for idxc in range(_for_C):
                    _idx_x = idxc * DH * DW if x_layout == 0 else idxc
                    _idx_y = idxc * DH * DW if y_layout == 0 else idxc
                    _x = tl.load(p_x + _idx_x, mask=_mask_hw)
                    tl.store(p_y1 + _idx_y, _x, mask=_mask_hw)
                    tl.store(p_y2 + _idx_y, _x, mask=_mask_hw)
                    tl.store(p_y3 + _idx_y, _x, mask=_mask_hw)
                    tl.store(p_y4 + _idx_y, _x, mask=_mask_hw)
            elif operation == 1:
                for idxc in range(_for_C):
                    _idx_x = idxc * DH * DW if x_layout == 0 else idxc
                    _idx_y = idxc * DH * DW if y_layout == 0 else idxc
                    _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw)
                    _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw)
                    _y3 = tl.load(p_y3 + _idx_y, mask=_mask_hw)
                    _y4 = tl.load(p_y4 + _idx_y, mask=_mask_hw)
                    tl.store(p_x + _idx_x, _y1 + _y2 + _y3 + _y4, mask=_mask_hw)
        else:
            # x holds four maps: route k is copied to / from map k only.
            x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
            if x_layout == 0:
                p_x1 = x_ptr_base + HWRoute0
                p_x2 = p_x1 + _tmp1
                p_x3 = p_x2 + _tmp1
                p_x4 = p_x3 + _tmp1
            else:
                p_x1 = x_ptr_base + HWRoute0 * 4 * DC
                p_x2 = p_x1 + DC
                p_x3 = p_x2 + DC
                p_x4 = p_x3 + DC

            if operation == 0:
                for idxc in range(_for_C):
                    _idx_x = idxc * DH * DW if x_layout == 0 else idxc
                    _idx_y = idxc * DH * DW if y_layout == 0 else idxc
                    tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw)
                    tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw)
                    tl.store(p_y3 + _idx_y, tl.load(p_x3 + _idx_x, mask=_mask_hw), mask=_mask_hw)
                    tl.store(p_y4 + _idx_y, tl.load(p_x4 + _idx_x, mask=_mask_hw), mask=_mask_hw)
            else:
                for idxc in range(_for_C):
                    _idx_x = idxc * DH * DW if x_layout == 0 else idxc
                    _idx_y = idxc * DH * DW if y_layout == 0 else idxc
                    # Loads are masked like every other access in this kernel;
                    # an unmasked load would read outside y on partial tiles.
                    tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y, mask=_mask_hw), mask=_mask_hw)
                    tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y, mask=_mask_hw), mask=_mask_hw)
                    tl.store(p_x3 + _idx_x, tl.load(p_y3 + _idx_y, mask=_mask_hw), mask=_mask_hw)
                    tl.store(p_x4 + _idx_x, tl.load(p_y4 + _idx_y, mask=_mask_hw), mask=_mask_hw)
class CrossScanTritonF(torch.autograd.Function):
    """Autograd wrapper that runs cross-scan (and its merge backward) through the Triton kernel."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
        """Launch the kernel in scan mode and return the four-route tensor.

        Returns a tensor of shape (B, 4, C, H * W) when ``out_channel_first``
        else (B, H * W, 4, C).
        """
        if one_by_one and in_channel_first:
            B, _, C, H, W = x.shape
        elif one_by_one:
            B, H, W, _, C = x.shape
        elif in_channel_first:
            B, C, H, W = x.shape
        else:
            B, H, W, C = x.shape
        B, C, H, W = int(B), int(C), int(H), int(W)

        # Tile sizes and the resulting number of tiles along each axis.
        BC, BH, BW = 1, 32, 32
        NH = triton.cdiv(H, BH)
        NW = triton.cdiv(W, BW)
        NC = triton.cdiv(C, BC)

        ctx.in_channel_first = in_channel_first
        ctx.out_channel_first = out_channel_first
        ctx.one_by_one = one_by_one
        ctx.scans = scans
        ctx.shape = (B, C, H, W)
        ctx.triton_shape = (BC, BH, BW, NC, NH, NW)

        if out_channel_first:
            y = x.new_empty((B, 4, C, H * W))
        else:
            y = x.new_empty((B, H * W, 4, C))

        x_layout = 0 if in_channel_first else 1
        y_layout = 0 if out_channel_first else 1
        onebyone = 0 if not one_by_one else 1
        triton_cross_scan_flex[(NH * NW, NC, B)](
            x.contiguous(), y,
            x_layout, y_layout, 0, onebyone, scans,
            BC, BH, BW, C, H, W, NH, NW
        )
        return y

    @staticmethod
    def backward(ctx, y: torch.Tensor):
        """Launch the kernel in merge mode to build the gradient of the forward input."""
        B, C, H, W = ctx.shape
        BC, BH, BW, NC, NH, NW = ctx.triton_shape
        in_cf = ctx.in_channel_first

        # Allocate the gradient in the same layout the forward input had.
        if ctx.one_by_one:
            grad_shape = (B, 4, C, H, W) if in_cf else (B, H, W, 4, C)
        else:
            grad_shape = (B, C, H, W) if in_cf else (B, H, W, C)
        x = y.new_empty(grad_shape)

        x_layout = 0 if in_cf else 1
        y_layout = 0 if ctx.out_channel_first else 1
        onebyone = 0 if not ctx.one_by_one else 1
        triton_cross_scan_flex[(NH * NW, NC, B)](
            x, y.contiguous(),
            x_layout, y_layout, 1, onebyone, ctx.scans,
            BC, BH, BW, C, H, W, NH, NW
        )
        # Only the tensor input receives a gradient; the four flags do not.
        return x, None, None, None, None
class CrossMergeTritonF(torch.autograd.Function):
    """Autograd wrapper that runs cross-merge (and its scan backward) through the Triton kernel."""

    @staticmethod
    def forward(ctx, y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
        """Launch the kernel in merge mode.

        Args:
            y: (B, 4, C, H, W) when ``out_channel_first`` else (B, H, W, 4, C).
            in_channel_first: layout of the merged result.
            out_channel_first: layout of ``y``.
            one_by_one: keep the four routes separate instead of summing them.
            scans: scan pattern id forwarded to the kernel.

        Returns:
            (B, C, H * W) | (B, H * W, C), or with ``one_by_one``
            (B, 4, C, H * W) | (B, H * W, 4, C).
        """
        if out_channel_first:
            B, _, C, H, W = y.shape
        else:
            B, H, W, _, C = y.shape
        B, C, H, W = int(B), int(C), int(H), int(W)
        # Tile sizes and the resulting number of tiles along each axis.
        BC, BH, BW = 1, 32, 32
        NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)
        ctx.in_channel_first = in_channel_first
        ctx.out_channel_first = out_channel_first
        ctx.one_by_one = one_by_one
        ctx.scans = scans
        ctx.shape = (B, C, H, W)
        ctx.triton_shape = (BC, BH, BW, NC, NH, NW)
        if one_by_one:
            x = y.new_empty((B, 4, C, H * W)) if in_channel_first else y.new_empty((B, H * W, 4, C))
        else:
            x = y.new_empty((B, C, H * W)) if in_channel_first else y.new_empty((B, H * W, C))
        triton_cross_scan_flex[(NH * NW, NC, B)](
            x, y.contiguous(),
            (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,
            BC, BH, BW, C, H, W, NH, NW
        )
        return x

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        """Launch the kernel in scan mode to build the gradient of ``y``."""
        in_channel_first = ctx.in_channel_first
        out_channel_first = ctx.out_channel_first
        one_by_one = ctx.one_by_one
        scans = ctx.scans
        B, C, H, W = ctx.shape
        BC, BH, BW, NC, NH, NW = ctx.triton_shape
        y = x.new_empty((B, 4, C, H, W)) if out_channel_first else x.new_empty((B, H, W, 4, C))
        triton_cross_scan_flex[(NH * NW, NC, B)](
            x.contiguous(), y,
            (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,
            BC, BH, BW, C, H, W, NH, NW
        )
        # One gradient slot per forward input: the tensor plus four flags.
        return y, None, None, None, None
# @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def cross_scan_fn(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):
    """Cross-scan ``x`` with the Triton kernel when possible, else with PyTorch.

    x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
    returns: (B, 4, C, L) | (B, L, 4, C)
    scans: 0: cross scan; 1 unidirectional; 2: bidirectional;

    The Triton path is used only when Triton imported, ``x`` is a CUDA tensor
    and ``force_torch`` is false.
    """
    use_triton = WITH_TRITON and x.is_cuda and (not force_torch)
    impl = CrossScanTritonF if use_triton else CrossScanF
    if not x.is_cuda:
        return impl.apply(x, in_channel_first, out_channel_first, one_by_one, scans)
    # Make the tensor's device current for the duration of the call.
    with torch.cuda.device(x.device):
        return impl.apply(x, in_channel_first, out_channel_first, one_by_one, scans)
# @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def cross_merge_fn(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):
    """Cross-merge ``y`` with the Triton kernel when possible, else with PyTorch.

    y: (B, 4, C, L) | (B, L, 4, C)
    returns: (B, C, H * W) | (B, H * W, C) | (B, 4, C, H * W) | (B, H * W, 4, C)
    scans: 0: cross scan; 1 unidirectional; 2: bidirectional;

    The Triton path is used only when Triton imported, ``y`` is a CUDA tensor
    and ``force_torch`` is false.
    """
    use_triton = WITH_TRITON and y.is_cuda and (not force_torch)
    impl = CrossMergeTritonF if use_triton else CrossMergeF
    if not y.is_cuda:
        return impl.apply(y, in_channel_first, out_channel_first, one_by_one, scans)
    # Make the tensor's device current for the duration of the call.
    with torch.cuda.device(y.device):
        return impl.apply(y, in_channel_first, out_channel_first, one_by_one, scans)
##########################################################
# csms6s.py
##########################################################
# The compiled selective-scan extension is optional; the flag records whether
# it could be imported so callers can pick the PyTorch fallback otherwise.
try:
    import selective_scan_cuda
except ImportError:
    WITH_SELECTIVESCAN_MAMBA = False
else:
    WITH_SELECTIVESCAN_MAMBA = True
def selective_scan_torch(
    u: torch.Tensor,  # (B, K * C, L)
    delta: torch.Tensor,  # (B, K * C, L)
    A: torch.Tensor,  # (K * C, N)
    B: torch.Tensor,  # (B, K, N, L)
    C: torch.Tensor,  # (B, K, N, L)
    D: torch.Tensor = None,  # (K * C)
    delta_bias: torch.Tensor = None,  # (K * C)
    delta_softplus=True,
    oflex=True,
    *args,
    **kwargs
):
    """Reference selective scan implemented with plain PyTorch ops.

    For every step ``i`` along the last axis the state is updated as
    ``x = exp(delta_i * A) * x + delta_i * B_i * u_i`` and the output is
    ``y_i = sum_n x[n] * C_i[n]``; ``u * D`` is added when ``D`` is given.

    Args:
        u, delta: inputs and step sizes, shape (Batch, K * Cdim, L).
        A: state matrix, shape (K * Cdim, N).
        B, C: per-group projections, shape (Batch, K, N, L); each group is
            shared by ``Cdim`` consecutive channels.
        D: optional skip weight, shape (K * Cdim,).
        delta_bias: optional bias added to ``delta`` before the softplus.
        delta_softplus: apply softplus to ``delta``.
        oflex: when true the float32 result is returned as is; otherwise it is
            cast back to the dtype of ``u``.
        *args, **kwargs: accepted and ignored so the signature is call-compatible
            with the CUDA implementation (e.g. a trailing ``backend``).

    Returns:
        Tensor of shape (Batch, K * Cdim, L).
    """
    dtype_in = u.dtype
    Batch, K, N, L = B.shape
    KCdim = u.shape[1]
    Cdim = KCdim // K  # channels per group
    assert u.shape == (Batch, KCdim, L)
    assert delta.shape == (Batch, KCdim, L)
    assert A.shape == (KCdim, N)
    assert C.shape == B.shape

    if delta_bias is not None:
        delta = delta + delta_bias[..., None]
    if delta_softplus:
        delta = torch.nn.functional.softplus(delta)

    # All arithmetic is carried out in float32.
    u, delta, A, B, C = u.float(), delta.float(), A.float(), B.float(), C.float()
    # Broadcast each group's B/C to the Cdim channels of that group
    # (reshape rather than view so non-contiguous inputs are accepted).
    B = B.reshape(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L)
    C = C.reshape(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L)
    deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
    deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)

    # Sequential recurrence over the L steps.
    x = A.new_zeros((Batch, KCdim, N))
    ys = []
    for i in range(L):
        x = deltaA[:, :, i, :] * x + deltaB_u[:, :, i, :]
        ys.append(torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]))
    y = torch.stack(ys, dim=2)  # (B, C, L)

    out = y if D is None else y + u * D.unsqueeze(-1)
    return out if oflex else out.to(dtype=dtype_in)
class SelectiveScanCuda(torch.autograd.Function):
    """Autograd wrapper around the compiled selective-scan kernels.

    ``backend`` selects the extension module: ``"mamba"`` uses
    ``selective_scan_cuda``; ``"oflex"`` uses ``selective_scan_cuda_oflex``,
    which must be importable in this module's namespace. Any other value now
    raises ``ValueError`` instead of failing later on an unbound local.
    """

    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, oflex=True, backend=None):
        """Run the forward kernel and stash what the backward kernel needs.

        Raises:
            ValueError: if ``backend`` resolves to neither "oflex" nor "mamba".
        """
        ctx.delta_softplus = delta_softplus
        # Default to the mamba extension when it is available and nothing was requested.
        if backend is None and WITH_SELECTIVESCAN_MAMBA:
            backend = "mamba"
        ctx.backend = backend
        if backend == "oflex":
            out, x, *rest = selective_scan_cuda_oflex.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1, oflex)
        elif backend == "mamba":
            out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus)
        else:
            raise ValueError(f"Unsupported selective-scan backend: {backend!r}")
        ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
        return out

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dout, *args):
        """Run the backward kernel matching the backend used in ``forward``."""
        u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
        backend = ctx.backend
        # The kernels are called with a gradient that is dense along the last axis.
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        if backend == "oflex":
            du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_oflex.bwd(
                u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
            )
        elif backend == "mamba":
            du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
                u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus,
                False
            )
        else:
            raise ValueError(f"Unsupported selective-scan backend: {backend!r}")
        # Gradients for the seven tensor inputs; none for the three flags.
        return du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None
def selective_scan_fn(
    u: torch.Tensor,  # (B, K * C, L)
    delta: torch.Tensor,  # (B, K * C, L)
    A: torch.Tensor,  # (K * C, N)
    B: torch.Tensor,  # (B, K, N, L)
    C: torch.Tensor,  # (B, K, N, L)
    D: torch.Tensor = None,  # (K * C)
    delta_bias: torch.Tensor = None,  # (K * C)
    delta_softplus=True,
    oflex=True,
    backend=None,
):
    """Dispatch a selective scan to the CUDA extension or the PyTorch reference.

    The PyTorch implementation is used when ``backend == "torch"``, when the
    CUDA extension could not be imported, or when ``u`` is not a CUDA tensor
    (the extension cannot process CPU tensors). Otherwise the call goes to
    ``SelectiveScanCuda``. All arguments are forwarded unchanged.
    """
    use_torch = backend == "torch" or (not WITH_SELECTIVESCAN_MAMBA) or (not u.is_cuda)
    fn = selective_scan_torch if use_torch else SelectiveScanCuda.apply
    return fn(u, delta, A, B, C, D, delta_bias, delta_softplus, oflex, backend)
##########################################################
############## HuggingFace modeling file #################
##########################################################
class VMambaLinear2d(nn.Linear):
    """Linear layer applied over the channel axis of channel-first tensors.

    The ``nn.Linear`` weight of shape ``(out_features, in_features)`` is used as
    a 1x1 convolution kernel, so the layer accepts ``(B, C, L)`` and
    ``(B, C, H, W)`` inputs. ``groups`` is forwarded to the convolution.
    """

    def __init__(self, *args, groups=1, **kwargs):
        nn.Linear.__init__(self, *args, **kwargs)
        self.groups = groups

    def forward(self, x: torch.Tensor):
        """Apply the projection as a pointwise convolution.

        Raises:
            ValueError: if ``x`` is neither 3-D nor 4-D (previously this
                silently returned ``None``).
        """
        if x.dim() == 4:
            return F.conv2d(x, self.weight[:, :, None, None], self.bias, groups=self.groups)
        if x.dim() == 3:
            return F.conv1d(x, self.weight[:, :, None], self.bias, groups=self.groups)
        raise ValueError(f"VMambaLinear2d expects a 3-D or 4-D input, got shape {tuple(x.shape)}")

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        # Checkpoints may store the weight with trailing singleton conv dims;
        # reshape it to this layer's 2-D weight before the regular loading.
        weight_key = prefix + "weight"
        if weight_key in state_dict:
            state_dict[weight_key] = state_dict[weight_key].view_as(self.weight)
        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
class VMambaLayerNorm2d(nn.LayerNorm):
    """LayerNorm over the channel axis of a channel-first ``(B, C, H, W)`` tensor."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor):
        # nn.LayerNorm normalizes the trailing dim, so move channels last,
        # normalize, and move them back.
        channels_last = x.permute(0, 2, 3, 1)
        normed = super().forward(channels_last)
        return normed.permute(0, 3, 1, 2)
class VMambaPatchEmbeddings(nn.Module):
    """
    Turns an image batch into the initial `hidden_states` (patch embeddings) consumed by the
    state-space model. Two strided convolutions, each followed by a channel-wise layer norm,
    produce a channel-first feature map with `embed_dim` channels.
    """

    def __init__(self, num_channels=3,patch_size=4,embed_dim=96):
        super().__init__()
        half_dim = embed_dim // 2
        # Each of the two convolutions covers half of the patch stride.
        stride = patch_size // 2
        conv_kwargs = dict(kernel_size=stride + 1, stride=stride, padding=1)
        self.projection = nn.Sequential(
            nn.Conv2d(num_channels, half_dim, **conv_kwargs),
            VMambaLayerNorm2d(half_dim),
            nn.GELU(),
            nn.Conv2d(half_dim, embed_dim, **conv_kwargs),
            VMambaLayerNorm2d(embed_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.projection(x)
class VMambaDowsample(nn.Module):
    """
    Downsamples a channel-first feature map with a stride-2 3x3 convolution, optionally
    followed by a channel-wise layer normalization.
    """

    def __init__(self, dim, out_dim, use_norm=True):
        super().__init__()
        self.down = nn.Conv2d(dim, out_dim, kernel_size=3, stride=2, padding=1)
        if use_norm:
            self.norm = VMambaLayerNorm2d(out_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        return self.norm(self.down(x))
class VMambaMlp(nn.Module):
    """
    Two-layer feed-forward block: fc1 -> activation -> dropout -> fc2 -> dropout.

    `hidden_features` and `out_features` fall back to `in_features` when left unset (or falsy). The same
    dropout module is applied after the activation and after the second projection.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        width_out = out_features if out_features else in_features
        width_hidden = hidden_features if hidden_features else in_features
        self.fc1 = VMambaLinear2d(in_features, width_hidden)
        self.act = act_layer()
        self.fc2 = VMambaLinear2d(width_hidden, width_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        for stage in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = stage(x)
        return x
class SS2D(nn.Module):
    """
    2-D selective-scan (state-space) mixing block.

    `forward` runs: in_proj -> depthwise conv (only when `d_conv > 1`) -> activation -> `forward_core`
    (cross scan, selective scan, cross merge, output norm) -> out_proj -> dropout.

    Args:
        d_model: Number of input/output channels.
        d_state: State size per channel (`N` in the scan).
        ssm_ratio: Expansion factor; the inner width is `int(ssm_ratio * d_model)`.
        dt_rank: Rank of the dt projection, or `"auto"` for `ceil(d_model / 16)`.
        act_layer: Activation class instantiated without arguments.
        d_conv: Kernel size of the depthwise convolution; the convolution is created only when `d_conv > 1`.
        conv_bias: Whether the depthwise convolution has a bias.
        dropout: Dropout probability applied to the output (`nn.Identity` when not positive).
        bias: Whether `in_proj` and `out_proj` have a bias.
        dt_min, dt_max, dt_init, dt_scale, dt_init_floor: Passed to `dt_init` (see there).
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(
        self,
        # basic dims ===========
        d_model=96,
        d_state=16,
        ssm_ratio=2.0,
        dt_rank="auto",
        act_layer=nn.SiLU,
        # dwconv ===============
        d_conv=3,
        conv_bias=True,
        # ======================
        dropout=0.0,
        bias=False,
        # dt init ==============
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        # forward_type="v05_noz" is always used
        # ======================
        **kwargs,
    ):
        super().__init__()
        self.k_group = 4
        self.d_model = int(d_model)
        self.d_state = int(d_state)
        self.d_inner = int(ssm_ratio * d_model)
        self.dt_rank = int(math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank)
        self.forward_core = partial(self.forward_corev2, force_fp32=False, no_einsum=True)
        self.with_dconv = d_conv > 1

        # In projection
        self.in_proj = VMambaLinear2d(self.d_model, self.d_inner, bias=bias)
        self.act: nn.Module = act_layer()

        # Depthwise convolution (one group per channel); absent when d_conv <= 1.
        if self.with_dconv:
            self.conv2d = nn.Conv2d(
                in_channels=self.d_inner,
                out_channels=self.d_inner,
                groups=self.d_inner,
                bias=conv_bias,
                kernel_size=d_conv,
                padding=(d_conv - 1) // 2,
            )

        # x_proj and dt_proj: grouped projections, one group per scan direction.
        self.x_proj = VMambaLinear2d(
            self.d_inner, self.k_group * (self.dt_rank + self.d_state * 2), groups=self.k_group, bias=False
        )
        self.dt_projs = VMambaLinear2d(self.dt_rank, self.k_group * self.d_inner, groups=self.k_group, bias=False)

        # out projection
        self.out_proj = VMambaLinear2d(self.d_inner, self.d_model, bias=bias)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()

        # Initialization. The dt weight is only a source of initial values for `dt_projs`,
        # so it is kept local instead of being registered on the module and deleted again.
        self.A_logs, self.Ds, dt_projs_weight, self.dt_projs_bias = self.init_dt_A_D(
            self.d_state, self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
            k_group=self.k_group,
        )
        self.dt_projs.weight.data = dt_projs_weight.data.view(self.dt_projs.weight.shape)

        # Output norm over the inner channels.
        self.out_norm = VMambaLayerNorm2d(self.d_inner)

    @staticmethod
    def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4):
        """
        Build an `nn.Linear(dt_rank, d_inner)` whose weight and bias hold the dt initialization.

        The weight is constant (`"constant"`) or uniform in `[-std, std]` (`"random"`) with
        `std = dt_rank**-0.5 * dt_scale`. The bias is the inverse softplus of a value sampled log-uniformly
        in `[dt_min, dt_max]` and clamped below at `dt_init_floor`, so `softplus(bias)` recovers that value.

        Raises:
            NotImplementedError: If `dt_init` is neither `"constant"` nor `"random"`.
        """
        dt_proj = nn.Linear(dt_rank, d_inner, bias=True)

        dt_init_std = dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        dt = torch.exp(
            torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: dt + log(1 - exp(-dt)) == log(exp(dt) - 1).
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            dt_proj.bias.copy_(inv_dt)
        return dt_proj

    @staticmethod
    def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True):
        """
        Return a parameter holding `log(1..d_state)` in every row.

        Shape is `(d_inner, d_state)`; with `copies > 0` it is repeated to `(copies, d_inner, d_state)` and,
        when `merge` is true, flattened to `(copies * d_inner, d_state)`. The parameter is tagged with
        `_no_weight_decay = True`.
        """
        A = torch.arange(1, d_state + 1, dtype=torch.float32, device=device).view(1, -1).repeat(d_inner, 1).contiguous()
        A_log = torch.log(A)
        if copies > 0:
            A_log = A_log[None].repeat(copies, 1, 1).contiguous()
            if merge:
                A_log = A_log.flatten(0, 1)
        A_log = nn.Parameter(A_log)
        A_log._no_weight_decay = True
        return A_log

    @staticmethod
    def D_init(d_inner, copies=-1, device=None, merge=True):
        """
        Return a parameter of ones of shape `(d_inner,)`.

        With `copies > 0` it is repeated to `(copies, d_inner)` and, when `merge` is true, flattened to
        `(copies * d_inner,)`. The parameter is tagged with `_no_weight_decay = True`.
        """
        D = torch.ones(d_inner, device=device)
        if copies > 0:
            D = D[None].repeat(copies, 1).contiguous()
            if merge:
                D = D.flatten(0, 1)
        D = nn.Parameter(D)
        D._no_weight_decay = True
        return D

    @classmethod
    def init_dt_A_D(cls, d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4):
        """
        Create the initial `A_logs`, `Ds`, dt weight and dt bias parameters for `k_group` scan directions.

        Returns:
            `(A_logs, Ds, dt_projs_weight, dt_projs_bias)` with shapes `(k_group * d_inner, d_state)`,
            `(k_group * d_inner,)`, `(k_group, d_inner, dt_rank)` and `(k_group, d_inner)`.
        """
        dt_projs = [
            cls.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor)
            for _ in range(k_group)
        ]
        dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in dt_projs], dim=0))
        dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in dt_projs], dim=0))
        del dt_projs

        A_logs = cls.A_log_init(d_state, d_inner, copies=k_group, merge=True)
        Ds = cls.D_init(d_inner, copies=k_group, merge=True)
        return A_logs, Ds, dt_projs_weight, dt_projs_bias

    def forward_corev2(
        self,
        x: torch.Tensor,
        force_fp32=False,
        no_einsum=True,
    ):
        """
        Core scan over a 4-D input `x` of shape `(B, D, H, W)`.

        `force_fp32` and `no_einsum` are accepted but not used by this implementation. The result is
        passed through `out_norm` and cast back to the dtype of `x`.
        """
        B, D, H, W = x.shape
        N = self.d_state
        L = H * W

        # NOTE(review): `cross_scan_fn` is defined elsewhere; presumably it produces the
        # `k_group` scan orderings of the H*W positions - confirm against its definition.
        xs = cross_scan_fn(x, in_channel_first=True, out_channel_first=True)
        x_dbl = self.x_proj(xs.view(B, -1, L))
        # Per group, the projection holds dt (dt_rank), B (N) and C (N) stacked on dim 2.
        dts, Bs, Cs = torch.split(x_dbl.view(B, self.k_group, -1, L), [self.dt_rank, N, N], dim=2)
        dts = dts.contiguous().view(B, -1, L)
        dts = self.dt_projs(dts)

        xs = xs.view(B, -1, L)
        dts = dts.contiguous().view(B, -1, L)
        As = -self.A_logs.to(torch.float32).exp()
        Ds = self.Ds.to(torch.float32)
        Bs = Bs.contiguous().view(B, self.k_group, N, L)
        Cs = Cs.contiguous().view(B, self.k_group, N, L)
        delta_bias = self.dt_projs_bias.view(-1).to(torch.float32)

        ys = selective_scan_fn(
            xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus=True, backend="mamba"
        ).view(B, self.k_group, -1, H, W)

        y = cross_merge_fn(ys, in_channel_first=True, out_channel_first=True)
        y = y.view(B, -1, H, W)
        y = self.out_norm(y)
        return y.to(x.dtype)

    def forward(self, x: torch.Tensor):
        """Apply the full block to `x` and return the projected, dropout-regularized output."""
        x = self.in_proj(x)
        # `conv2d` only exists when d_conv > 1; skip it otherwise instead of raising AttributeError.
        if self.with_dconv:
            x = self.conv2d(x)
        x = self.act(x)
        y = self.forward_core(x)
        out = self.dropout(self.out_proj(y))
        return out
class VSSBlock(nn.Module):
def __init__(
self,
hidden_dim: int = 0,
drop_path: float = 0,
ssm_d_state: int = 1,
ssm_ratio=1.0,
ssm_dt_rank: Any = "auto",
ssm_act_layer=nn.SiLU,
ssm_conv: int = 3,
ssm_conv_bias=False,
ssm_drop_rate: float = 0,
mlp_ratio=4.0,
mlp_act_layer=nn.GELU,
mlp_drop_rate: float = 0.0,
use_checkpoint: bool = False,
post_norm: bool = False,
**kwargs,
):
super().__init__()
self.ssm_branch = ssm_ratio > 0
self.mlp_branch = mlp_ratio > 0
self.use_checkpoint = use_checkpoint
self.post_norm = post_norm
if self.ssm_branch:
self.norm = VMambaLayerNorm2d(hidden_dim)
self.op = SS2D(
d_model=hidden_dim,
d_state=ssm_d_state,
ssm_ratio=ssm_ratio,
dt_rank=ssm_dt_rank,
act_layer=ssm_act_layer,
d_conv=ssm_conv,
conv_bias=ssm_conv_bias,
dropout=ssm_drop_rate,
)
self.drop_path = DropPath(drop_path)
if self.mlp_branch:
self.norm2 = VMambaLayerNorm2d(hidden_dim)
mlp_hidden_dim = int(hidden_dim * mlp_ratio)
self.mlp = VMambaMlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=mlp_act_layer, drop=mlp_drop_rate)
def _forward(self, input: torch.Tensor):
x = input
if self.ssm_branch:
if self.post_norm:
x = x + self.drop_path(self.norm(self.op(x)))
else:
x = x + self.drop_path(self.op(self.norm(x)))
if self.mlp_branch:
if self.post_norm:
x = x + self.drop_path(self.norm2(self.mlp(x)))
else:
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
def forward(self, input: torch.Tensor):
if self.use_checkpoint:
return checkpoint.checkpoint(self._forward, input)
else:
return self._forward(input)
class VMambaLayer(nn.Module):
    """
    One stage of the network: `depth` `VSSBlock`s followed by a downsampling module.

    Args:
        input_dim: Channel width passed to every block as `hidden_dim`.
        depth: Number of blocks in the stage.
        drop_path: A single rate used for all blocks, or a list/tuple indexed per block.
        norm_layer: Forwarded to each `VSSBlock` as `norm_layer`.
        downsample: Module applied after the blocks; `None` means no downsampling (`nn.Identity`).
        use_checkpoint: Forwarded to each `VSSBlock`.
        **kwargs: Forwarded to each `VSSBlock`.
    """

    def __init__(
        self,
        input_dim,
        depth,
        drop_path=0.0,
        norm_layer=VMambaLayerNorm2d,
        downsample=None,
        use_checkpoint=False,
        **kwargs,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.use_checkpoint = use_checkpoint

        # Accept tuples as well as lists for per-block rates.
        per_block_rates = isinstance(drop_path, (list, tuple))
        self.blocks = nn.ModuleList(
            [
                VSSBlock(
                    hidden_dim=input_dim,
                    drop_path=drop_path[i] if per_block_rates else drop_path,
                    norm_layer=norm_layer,
                    use_checkpoint=use_checkpoint,
                    **kwargs,
                )
                for i in range(depth)
            ]
        )

        # Build the fallback per instance: a module used as a default argument value is
        # created once at definition time and would be shared by every layer.
        self.downsample = downsample if downsample is not None else nn.Identity()

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return self.downsample(x)
class VMambaPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    config_class = VMambaConfig
    base_model_prefix = "vmamba"
    supports_gradient_checkpointing = False

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """
        Initialize the weights of a single module.

        `nn.Linear` (and subclasses): truncated-normal weight with std 0.02, zero bias when a bias exists.
        `nn.LayerNorm` (and subclasses): weight 1 and bias 0, each only when present. Any other module,
        including `nn.Conv2d`, is left untouched.
        """
        if isinstance(module, nn.Linear):
            trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.LayerNorm):
            # Weight/bias are None when the norm is built without affine parameters (or without bias).
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
            if module.weight is not None:
                nn.init.constant_(module.weight, 1.0)
class VMambaModel(VMambaPreTrainedModel):
    """
    VMamba backbone: patch embedding, a stack of `VMambaLayer` stages, a final norm and global average
    pooling.

    Reads from `config`: `dims` (an int base width, doubled per stage, or a per-stage sequence), `depths`
    (blocks per stage), `patch_size`, `drop_path_rate` and `use_checkpoint`.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Must be known before `dims` is expanded from a single base width.
        self.num_layers = len(config.depths)

        dims = config.dims
        if isinstance(dims, int):
            dims = [int(dims * 2**i_layer) for i_layer in range(self.num_layers)]
        self.dims = dims

        self.patch_embeddings = VMambaPatchEmbeddings(patch_size=config.patch_size,
                                                      embed_dim=dims[0])

        # Drop-path rate grows linearly from 0 to `drop_path_rate` over all blocks of all stages.
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
        self.num_features = dims[-1]

        self.layers = nn.ModuleList()
        for i in range(self.num_layers):
            first_block = sum(config.depths[:i])
            last_block = sum(config.depths[:i + 1])
            is_last_stage = i >= self.num_layers - 1
            layer = VMambaLayer(
                input_dim=self.dims[i],
                depth=config.depths[i],
                drop_path=dpr[first_block:last_block],
                # Every stage but the last hands over to the next stage's width.
                downsample=nn.Identity() if is_last_stage else VMambaDowsample(self.dims[i], self.dims[i + 1]),
                use_checkpoint=config.use_checkpoint,
            )
            self.layers.append(layer)

        self.norm = VMambaLayerNorm2d(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

    def get_input_embeddings(self) -> VMambaPatchEmbeddings:
        return self.patch_embeddings

    def forward(self, input_values: torch.Tensor):
        """Return pooled features of shape `(batch_size, num_features)`."""
        x = self.patch_embeddings(input_values)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        x = self.avgpool(x).flatten(1)
        return x
class VMambaForImageClassification(VMambaPreTrainedModel):
    """
    VMamba backbone with a linear classification head on the pooled features.

    The head is `nn.Identity` when `config.num_classes` is not positive.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_classes = config.num_classes
        self.vmamba = VMambaModel(config)
        self.head = nn.Linear(self.vmamba.num_features, self.num_classes) if self.num_classes > 0 else nn.Identity()
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Classify `pixel_values` and optionally compute a loss.

        Args:
            pixel_values: Image batch passed to the backbone.
            labels: Optional targets. With `config.loss_type == "ce"` a cross-entropy loss is computed,
                with `"bce"` a binary cross-entropy-with-logits loss; otherwise no loss is computed.
            return_dict: `True` returns an `ImageClassifierOutput`, `False` a plain tuple
                `(loss?, logits, pooled_features)`. `None` falls back to `config.use_return_dict`
                (an `ImageClassifierOutput` if the config does not define it).

        Returns:
            `ImageClassifierOutput` or tuple, as selected by `return_dict`.
        """
        if return_dict is None:
            return_dict = getattr(self.config, "use_return_dict", True)

        pooled = self.vmamba(
            pixel_values,
        )
        logits = self.head(pooled)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss_type = self.config.loss_type
            if loss_type == "ce":
                loss_fct = CrossEntropyLoss()
                # The head width is `num_classes`; this class defines no `num_labels` attribute.
                loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
            # Both losses are selected by `loss_type`; the former `problem_type == "bce"`
            # spelling is still honoured so existing configs keep working.
            elif loss_type == "bce" or self.config.problem_type == "bce":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + (pooled,)
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            # NOTE(review): this is the pooled feature tensor, not a tuple of per-layer states.
            hidden_states=pooled,
        )
# Names exported by `from <this module> import *`.
__all__ = [
"VMambaModel",
"VMambaPreTrainedModel",
"VMambaForImageClassification",
]
|