"""VMamba: Visual State Space Model configuration model""" |
|
|
|
import math |
|
import torch |
|
import warnings |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.utils.checkpoint as checkpoint |
|
from timm.models.layers import DropPath, trunc_normal_ |
|
from functools import partial |
|
from typing import Optional, Callable, Any, Union |
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss |
|
from transformers.modeling_outputs import ImageClassifierOutput |
|
|
|
from transformers.utils import logging |
|
from transformers.modeling_utils import PreTrainedModel |
|
|
|
from .configuration_vmamba import VMambaConfig |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
_CONFIG_FOR_DOC = "VMambaConfig" |
|
|
|
WITH_TRITON = True |
|
|
|
try: |
|
import triton |
|
import triton.language as tl |
|
except ImportError:

    WITH_TRITON = False

    warnings.warn("Triton is not installed; falling back to the PyTorch implementation.")
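
# Fallback shim (added so the module still imports when Triton is missing): the
# `@triton.jit`-decorated kernel below is defined unconditionally, so we provide
# no-op stand-ins for `triton` / `tl`. The Triton code path is never executed in
# that case; the pure-PyTorch autograd functions are used instead.
if not WITH_TRITON:
    class _TritonStub:
        tensor = None
        constexpr = None

        @staticmethod
        def jit(fn):
            # Leave the decorated kernel untouched; it is never launched without Triton.
            return fn

        @staticmethod
        def cdiv(a, b):
            return -(-a // b)

    triton = _TritonStub()
    tl = _TritonStub()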
|
|
|
|
|
if WITH_TRITON: |
|
try: |
|
from functools import cached_property |
|
    except ImportError:
        warnings.warn("If you are running Python 3.7, add this line to functools.py: "
                      "cached_property = lambda func: property(lru_cache()(func))")
|
|
|
|
|
def cross_scan_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0): |
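    # Unfold a 2-D feature map into K=4 one-dimensional scan sequences.
    #   scans == 0: cross-scan   -> row-major, column-major, and the reverse of each
    #   scans == 1: unidirectional -> the row-major scan repeated four times
    #   scans == 2: bidirectional  -> the row-major scan twice, then its reverse twice
    #   scans == 3: rotations      -> the map rotated by 0 / 90 / 180 / 270 degrees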
|
if in_channel_first: |
|
B, C, H, W = x.shape |
|
if scans == 0: |
|
y = x.new_empty((B, 4, C, H * W)) |
|
y[:, 0, :, :] = x.flatten(2, 3) |
|
y[:, 1, :, :] = x.transpose(dim0=2, dim1=3).flatten(2, 3) |
|
y[:, 2:4, :, :] = torch.flip(y[:, 0:2, :, :], dims=[-1]) |
|
elif scans == 1: |
|
y = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1) |
|
elif scans == 2: |
|
y = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1) |
|
y = torch.cat([y, y.flip(dims=[-1])], dim=1) |
|
elif scans == 3: |
|
y = x.new_empty((B, 4, C, H * W)) |
|
y[:, 0, :, :] = x.flatten(2, 3) |
|
y[:, 1, :, :] = torch.rot90(x, 1, dims=(2, 3)).flatten(2, 3) |
|
y[:, 2, :, :] = torch.rot90(x, 2, dims=(2, 3)).flatten(2, 3) |
|
y[:, 3, :, :] = torch.rot90(x, 3, dims=(2, 3)).flatten(2, 3) |
|
else: |
|
B, H, W, C = x.shape |
|
if scans == 0: |
|
y = x.new_empty((B, H * W, 4, C)) |
|
y[:, :, 0, :] = x.flatten(1, 2) |
|
y[:, :, 1, :] = x.transpose(dim0=1, dim1=2).flatten(1, 2) |
|
y[:, :, 2:4, :] = torch.flip(y[:, :, 0:2, :], dims=[1]) |
|
elif scans == 1: |
|
y = x.view(B, H * W, 1, C).repeat(1, 1, 4, 1) |
|
elif scans == 2: |
|
y = x.view(B, H * W, 1, C).repeat(1, 1, 2, 1) |
|
y = torch.cat([y, y.flip(dims=[1])], dim=2) |
|
elif scans == 3: |
|
y = x.new_empty((B, H * W, 4, C)) |
|
y[:, :, 0, :] = x.flatten(1, 2) |
|
y[:, :, 1, :] = torch.rot90(x, 1, dims=(1, 2)).flatten(1, 2) |
|
y[:, :, 2, :] = torch.rot90(x, 2, dims=(1, 2)).flatten(1, 2) |
|
y[:, :, 3, :] = torch.rot90(x, 3, dims=(1, 2)).flatten(1, 2) |
|
|
|
if in_channel_first and (not out_channel_first): |
|
y = y.permute(0, 3, 1, 2).contiguous() |
|
elif (not in_channel_first) and out_channel_first: |
|
y = y.permute(0, 2, 3, 1).contiguous() |
|
|
|
return y |
|
|
|
|
|
def cross_merge_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0): |
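    # Inverse of `cross_scan_fwd`: map each of the K=4 scan sequences back to
    # row-major order and sum them into a single merged feature sequence.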
|
if out_channel_first: |
|
B, K, D, H, W = y.shape |
|
y = y.view(B, K, D, -1) |
|
if scans == 0: |
|
y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1) |
|
y = y[:, 0] + y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1) |
|
elif scans == 1: |
|
y = y.sum(1) |
|
elif scans == 2: |
|
y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1) |
|
y = y.sum(1) |
|
elif scans == 3: |
|
oy = y[:, 0, :, :].contiguous().view(B, D, -1) |
|
oy = oy + torch.rot90(y.view(B, K, D, W, H)[:, 1, :, :, :], -1, dims=(2, 3)).flatten(2, 3) |
|
oy = oy + torch.rot90(y.view(B, K, D, H, W)[:, 2, :, :, :], -2, dims=(2, 3)).flatten(2, 3) |
|
oy = oy + torch.rot90(y.view(B, K, D, W, H)[:, 3, :, :, :], -3, dims=(2, 3)).flatten(2, 3) |
|
y = oy |
|
else: |
|
B, H, W, K, D = y.shape |
|
y = y.view(B, -1, K, D) |
|
if scans == 0: |
|
y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D) |
|
y = y[:, :, 0] + y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).contiguous().view(B, -1, D) |
|
elif scans == 1: |
|
y = y.sum(2) |
|
elif scans == 2: |
|
y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D) |
|
y = y.sum(2) |
|
elif scans == 3: |
|
oy = y[:, :, 0, :].contiguous().view(B, -1, D) |
|
oy = oy + torch.rot90(y.view(B, W, H, K, D)[:, :, :, 1, :], -1, dims=(1, 2)).flatten(1, 2) |
|
oy = oy + torch.rot90(y.view(B, H, W, K, D)[:, :, :, 2, :], -2, dims=(1, 2)).flatten(1, 2) |
|
oy = oy + torch.rot90(y.view(B, W, H, K, D)[:, :, :, 3, :], -3, dims=(1, 2)).flatten(1, 2) |
|
y = oy |
|
|
|
if in_channel_first and (not out_channel_first): |
|
y = y.permute(0, 2, 1).contiguous() |
|
elif (not in_channel_first) and out_channel_first: |
|
y = y.permute(0, 2, 1).contiguous() |
|
|
|
return y |
|
|
|
|
|
def cross_scan1b1_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0): |
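    # Same scan routes as `cross_scan_fwd`, but "one by one": the input already carries
    # a K=4 axis and each of the four slices is scanned along its own route.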
|
if in_channel_first: |
|
B, _, C, H, W = x.shape |
|
if scans == 0: |
|
y = torch.stack([ |
|
x[:, 0].flatten(2, 3), |
|
x[:, 1].transpose(dim0=2, dim1=3).flatten(2, 3), |
|
torch.flip(x[:, 2].flatten(2, 3), dims=[-1]), |
|
torch.flip(x[:, 3].transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]), |
|
], dim=1) |
|
elif scans == 1: |
|
y = x.flatten(2, 3) |
|
elif scans == 2: |
|
y = torch.stack([ |
|
x[:, 0].flatten(2, 3), |
|
x[:, 1].flatten(2, 3), |
|
torch.flip(x[:, 2].flatten(2, 3), dims=[-1]), |
|
torch.flip(x[:, 3].flatten(2, 3), dims=[-1]), |
|
], dim=1) |
|
elif scans == 3: |
|
y = torch.stack([ |
|
x[:, 0, :, :, :].flatten(2, 3), |
|
torch.rot90(x[:, 1, :, :, :], 1, dims=(2, 3)).flatten(2, 3), |
|
torch.rot90(x[:, 2, :, :, :], 2, dims=(2, 3)).flatten(2, 3), |
|
torch.rot90(x[:, 3, :, :, :], 3, dims=(2, 3)).flatten(2, 3), |
|
], dim=1) |
|
|
|
else: |
|
B, H, W, _, C = x.shape |
|
if scans == 0: |
|
y = torch.stack([ |
|
x[:, :, :, 0].flatten(1, 2), |
|
x[:, :, :, 1].transpose(dim0=1, dim1=2).flatten(1, 2), |
|
torch.flip(x[:, :, :, 2].flatten(1, 2), dims=[1]), |
|
torch.flip(x[:, :, :, 3].transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]), |
|
], dim=2) |
|
elif scans == 1: |
|
y = x.flatten(1, 2) |
|
elif scans == 2: |
|
            y = torch.stack([
                x[:, :, :, 0].flatten(1, 2),
                x[:, :, :, 1].flatten(1, 2),
                torch.flip(x[:, :, :, 2].flatten(1, 2), dims=[1]),
                torch.flip(x[:, :, :, 3].flatten(1, 2), dims=[1]),
            ], dim=2)
|
elif scans == 3: |
|
y = torch.stack([ |
|
x[:, :, :, 0, :].flatten(1, 2), |
|
torch.rot90(x[:, :, :, 1, :], 1, dims=(1, 2)).flatten(1, 2), |
|
torch.rot90(x[:, :, :, 2, :], 2, dims=(1, 2)).flatten(1, 2), |
|
torch.rot90(x[:, :, :, 3, :], 3, dims=(1, 2)).flatten(1, 2), |
|
            ], dim=2)
|
|
|
if in_channel_first and (not out_channel_first): |
|
y = y.permute(0, 3, 1, 2).contiguous() |
|
elif (not in_channel_first) and out_channel_first: |
|
y = y.permute(0, 2, 3, 1).contiguous() |
|
|
|
return y |
|
|
|
|
|
def cross_merge1b1_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0): |
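    # Inverse of `cross_scan1b1_fwd`: restore row-major order for every route while
    # keeping the K=4 axis (no summation across routes).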
|
if out_channel_first: |
|
B, K, D, H, W = y.shape |
|
y = y.view(B, K, D, -1) |
|
if scans == 0: |
|
y = torch.stack([ |
|
y[:, 0], |
|
y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3), |
|
torch.flip(y[:, 2], dims=[-1]), |
|
torch.flip(y[:, 3].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]), |
|
], dim=1) |
|
elif scans == 1: |
|
y = y |
|
elif scans == 2: |
|
y = torch.stack([ |
|
y[:, 0], |
|
y[:, 1], |
|
torch.flip(y[:, 2], dims=[-1]), |
|
torch.flip(y[:, 3], dims=[-1]), |
|
], dim=1) |
|
elif scans == 3: |
|
y = torch.stack([ |
|
y[:, 0, :, :].contiguous().view(B, D, -1), |
|
torch.rot90(y.view(B, K, D, W, H)[:, 1, :, :, :], -1, dims=(2, 3)).flatten(2, 3), |
|
torch.rot90(y.view(B, K, D, H, W)[:, 2, :, :, :], -2, dims=(2, 3)).flatten(2, 3), |
|
torch.rot90(y.view(B, K, D, W, H)[:, 3, :, :, :], -3, dims=(2, 3)).flatten(2, 3), |
|
], dim=1) |
|
else: |
|
B, H, W, K, D = y.shape |
|
y = y.view(B, -1, K, D) |
|
if scans == 0: |
|
y = torch.stack([ |
|
y[:, :, 0], |
|
y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2), |
|
torch.flip(y[:, :, 2], dims=[1]), |
|
torch.flip(y[:, :, 3].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]), |
|
], dim=2) |
|
elif scans == 1: |
|
y = y |
|
elif scans == 2: |
|
y = torch.stack([ |
|
y[:, :, 0], |
|
y[:, :, 1], |
|
torch.flip(y[:, :, 2], dims=[1]), |
|
torch.flip(y[:, :, 3], dims=[1]), |
|
], dim=2) |
|
elif scans == 3: |
|
y = torch.stack([ |
|
y[:, :, 0, :].contiguous().view(B, -1, D), |
|
torch.rot90(y.view(B, W, H, K, D)[:, :, :, 1, :], -1, dims=(1, 2)).flatten(1, 2), |
|
torch.rot90(y.view(B, H, W, K, D)[:, :, :, 2, :], -2, dims=(1, 2)).flatten(1, 2), |
|
torch.rot90(y.view(B, W, H, K, D)[:, :, :, 3, :], -3, dims=(1, 2)).flatten(1, 2), |
|
], dim=2) |
|
|
|
if out_channel_first and (not in_channel_first): |
|
y = y.permute(0, 3, 1, 2).contiguous() |
|
elif (not out_channel_first) and in_channel_first: |
|
y = y.permute(0, 2, 3, 1).contiguous() |
|
|
|
return y |
|
|
|
|
|
class CrossScanF(torch.autograd.Function): |
|
@staticmethod |
|
def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0): |
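        # x: (B, C, H, W) or (B, H, W, C), with an extra K=4 axis just before the
        #    channel axis when one_by_one=True
        # y: (B, 4, C, H * W) if out_channel_first else (B, H * W, 4, C)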
|
|
|
|
|
ctx.in_channel_first = in_channel_first |
|
ctx.out_channel_first = out_channel_first |
|
ctx.one_by_one = one_by_one |
|
ctx.scans = scans |
|
|
|
if one_by_one: |
|
B, K, C, H, W = x.shape |
|
if not in_channel_first: |
|
B, H, W, K, C = x.shape |
|
else: |
|
B, C, H, W = x.shape |
|
if not in_channel_first: |
|
B, H, W, C = x.shape |
|
ctx.shape = (B, C, H, W) |
|
|
|
_fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd |
|
y = _fn(x, in_channel_first, out_channel_first, scans) |
|
|
|
return y |
|
|
|
@staticmethod |
|
def backward(ctx, ys: torch.Tensor): |
|
|
|
in_channel_first = ctx.in_channel_first |
|
out_channel_first = ctx.out_channel_first |
|
one_by_one = ctx.one_by_one |
|
scans = ctx.scans |
|
B, C, H, W = ctx.shape |
|
|
|
ys = ys.view(B, -1, C, H, W) if out_channel_first else ys.view(B, H, W, -1, C) |
|
_fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd |
|
y = _fn(ys, in_channel_first, out_channel_first, scans) |
|
|
|
if one_by_one: |
|
y = y.view(B, 4, -1, H, W) if in_channel_first else y.view(B, H, W, 4, -1) |
|
else: |
|
y = y.view(B, -1, H, W) if in_channel_first else y.view(B, H, W, -1) |
|
|
|
return y, None, None, None, None |
|
|
|
|
|
class CrossMergeF(torch.autograd.Function): |
|
@staticmethod |
|
def forward(ctx, ys: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0): |
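        # ys:  (B, 4, C, H, W) if out_channel_first else (B, H, W, 4, C)
        # out: (B, C, H * W) if in_channel_first else (B, H * W, C)
        #      (the K=4 axis is kept when one_by_one=True)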
|
|
|
|
|
ctx.in_channel_first = in_channel_first |
|
ctx.out_channel_first = out_channel_first |
|
ctx.one_by_one = one_by_one |
|
ctx.scans = scans |
|
|
|
B, K, C, H, W = ys.shape |
|
if not out_channel_first: |
|
B, H, W, K, C = ys.shape |
|
ctx.shape = (B, C, H, W) |
|
|
|
_fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd |
|
y = _fn(ys, in_channel_first, out_channel_first, scans) |
|
|
|
return y |
|
|
|
@staticmethod |
|
def backward(ctx, x: torch.Tensor): |
|
|
|
|
|
in_channel_first = ctx.in_channel_first |
|
out_channel_first = ctx.out_channel_first |
|
one_by_one = ctx.one_by_one |
|
scans = ctx.scans |
|
B, C, H, W = ctx.shape |
|
|
|
if not one_by_one: |
|
if in_channel_first: |
|
x = x.view(B, C, H, W) |
|
else: |
|
x = x.view(B, H, W, C) |
|
else: |
|
if in_channel_first: |
|
x = x.view(B, 4, C, H, W) |
|
else: |
|
x = x.view(B, H, W, 4, C) |
|
|
|
_fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd |
|
x = _fn(x, in_channel_first, out_channel_first, scans) |
|
x = x.view(B, 4, C, H, W) if out_channel_first else x.view(B, H, W, 4, C) |
|
|
|
return x, None, None, None, None |
|
|
|
|
|
|
|
|
|
@triton.jit |
|
def triton_cross_scan_flex( |
|
x: tl.tensor, |
|
y: tl.tensor, |
|
x_layout: tl.constexpr, |
|
y_layout: tl.constexpr, |
|
operation: tl.constexpr, |
|
onebyone: tl.constexpr, |
|
scans: tl.constexpr, |
|
BC: tl.constexpr, |
|
BH: tl.constexpr, |
|
BW: tl.constexpr, |
|
DC: tl.constexpr, |
|
DH: tl.constexpr, |
|
DW: tl.constexpr, |
|
NH: tl.constexpr, |
|
NW: tl.constexpr, |
|
): |
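    # x_layout / y_layout: 0 -> channel-first (B, C, H, W), 1 -> channel-last (B, H, W, C)
    # operation:           0 -> cross-scan (x -> y), 1 -> cross-merge (y -> x)
    # onebyone:            0 -> shared input, 1 -> one input slice per scan route
    # scans:               0 cross-scan, 1 unidirectional, 2 bidirectional, 3 rotations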
|
|
|
|
|
|
|
|
|
|
|
|
|
i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) |
|
i_h, i_w = (i_hw // NW), (i_hw % NW) |
|
_mask_h = (i_h * BH + tl.arange(0, BH)) < DH |
|
_mask_w = (i_w * BW + tl.arange(0, BW)) < DW |
|
_mask_hw = _mask_h[:, None] & _mask_w[None, :] |
|
_for_C = min(DC - i_c * BC, BC) |
|
|
|
pos_h = (i_h * BH + tl.arange(0, BH)[:, None]) |
|
pos_w = (i_w * BW + tl.arange(0, BW)[None, :]) |
|
neg_h = (DH - i_h * BH - 1 - tl.arange(0, BH)[:, None]) |
|
neg_w = (DW - i_w * BW - 1 - tl.arange(0, BW)[None, :]) |
|
if scans == 0: |
|
|
|
HWRoute0 = pos_h * DW + pos_w |
|
HWRoute1 = pos_w * DH + pos_h |
|
HWRoute2 = neg_h * DW + neg_w |
|
HWRoute3 = neg_w * DH + neg_h |
|
elif scans == 1: |
|
|
|
HWRoute0 = pos_h * DW + pos_w |
|
HWRoute1 = HWRoute0 |
|
HWRoute2 = HWRoute0 |
|
HWRoute3 = HWRoute0 |
|
elif scans == 2: |
|
|
|
HWRoute0 = pos_h * DW + pos_w |
|
HWRoute1 = HWRoute0 |
|
HWRoute2 = neg_h * DW + neg_w |
|
HWRoute3 = HWRoute2 |
|
elif scans == 3: |
|
|
|
HWRoute0 = pos_h * DW + pos_w |
|
HWRoute1 = neg_w * DH + pos_h |
|
HWRoute2 = neg_h * DW + neg_w |
|
HWRoute3 = pos_w * DH + neg_h |
|
|
|
_tmp1 = DC * DH * DW |
|
|
|
y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC) |
|
if y_layout == 0: |
|
p_y1 = y_ptr_base + HWRoute0 |
|
p_y2 = y_ptr_base + _tmp1 + HWRoute1 |
|
p_y3 = y_ptr_base + 2 * _tmp1 + HWRoute2 |
|
p_y4 = y_ptr_base + 3 * _tmp1 + HWRoute3 |
|
else: |
|
p_y1 = y_ptr_base + HWRoute0 * 4 * DC |
|
p_y2 = y_ptr_base + DC + HWRoute1 * 4 * DC |
|
p_y3 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC |
|
p_y4 = y_ptr_base + 3 * DC + HWRoute3 * 4 * DC |
|
|
|
if onebyone == 0: |
|
x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC) |
|
if x_layout == 0: |
|
p_x = x_ptr_base + HWRoute0 |
|
else: |
|
p_x = x_ptr_base + HWRoute0 * DC |
|
|
|
if operation == 0: |
|
for idxc in range(_for_C): |
|
_idx_x = idxc * DH * DW if x_layout == 0 else idxc |
|
_idx_y = idxc * DH * DW if y_layout == 0 else idxc |
|
_x = tl.load(p_x + _idx_x, mask=_mask_hw) |
|
tl.store(p_y1 + _idx_y, _x, mask=_mask_hw) |
|
tl.store(p_y2 + _idx_y, _x, mask=_mask_hw) |
|
tl.store(p_y3 + _idx_y, _x, mask=_mask_hw) |
|
tl.store(p_y4 + _idx_y, _x, mask=_mask_hw) |
|
elif operation == 1: |
|
for idxc in range(_for_C): |
|
_idx_x = idxc * DH * DW if x_layout == 0 else idxc |
|
_idx_y = idxc * DH * DW if y_layout == 0 else idxc |
|
_y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw) |
|
_y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw) |
|
_y3 = tl.load(p_y3 + _idx_y, mask=_mask_hw) |
|
_y4 = tl.load(p_y4 + _idx_y, mask=_mask_hw) |
|
tl.store(p_x + _idx_x, _y1 + _y2 + _y3 + _y4, mask=_mask_hw) |
|
|
|
else: |
|
x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC) |
|
if x_layout == 0: |
|
p_x1 = x_ptr_base + HWRoute0 |
|
p_x2 = p_x1 + _tmp1 |
|
p_x3 = p_x2 + _tmp1 |
|
p_x4 = p_x3 + _tmp1 |
|
else: |
|
p_x1 = x_ptr_base + HWRoute0 * 4 * DC |
|
p_x2 = p_x1 + DC |
|
p_x3 = p_x2 + DC |
|
p_x4 = p_x3 + DC |
|
|
|
if operation == 0: |
|
for idxc in range(_for_C): |
|
_idx_x = idxc * DH * DW if x_layout == 0 else idxc |
|
_idx_y = idxc * DH * DW if y_layout == 0 else idxc |
|
tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw) |
|
tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw) |
|
tl.store(p_y3 + _idx_y, tl.load(p_x3 + _idx_x, mask=_mask_hw), mask=_mask_hw) |
|
tl.store(p_y4 + _idx_y, tl.load(p_x4 + _idx_x, mask=_mask_hw), mask=_mask_hw) |
|
else: |
|
for idxc in range(_for_C): |
|
_idx_x = idxc * DH * DW if x_layout == 0 else idxc |
|
_idx_y = idxc * DH * DW if y_layout == 0 else idxc |
|
tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y), mask=_mask_hw) |
|
tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y), mask=_mask_hw) |
|
tl.store(p_x3 + _idx_x, tl.load(p_y3 + _idx_y), mask=_mask_hw) |
|
tl.store(p_x4 + _idx_x, tl.load(p_y4 + _idx_y), mask=_mask_hw) |
|
|
|
|
|
class CrossScanTritonF(torch.autograd.Function): |
|
@staticmethod |
|
def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0): |
|
if one_by_one: |
|
if in_channel_first: |
|
B, _, C, H, W = x.shape |
|
else: |
|
B, H, W, _, C = x.shape |
|
else: |
|
if in_channel_first: |
|
B, C, H, W = x.shape |
|
else: |
|
B, H, W, C = x.shape |
|
B, C, H, W = int(B), int(C), int(H), int(W) |
|
BC, BH, BW = 1, 32, 32 |
|
NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC) |
|
|
|
ctx.in_channel_first = in_channel_first |
|
ctx.out_channel_first = out_channel_first |
|
ctx.one_by_one = one_by_one |
|
ctx.scans = scans |
|
ctx.shape = (B, C, H, W) |
|
ctx.triton_shape = (BC, BH, BW, NC, NH, NW) |
|
|
|
y = x.new_empty((B, 4, C, H * W)) if out_channel_first else x.new_empty((B, H * W, 4, C)) |
|
triton_cross_scan_flex[(NH * NW, NC, B)]( |
|
x.contiguous(), y, |
|
(0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans, |
|
BC, BH, BW, C, H, W, NH, NW |
|
) |
|
return y |
|
|
|
@staticmethod |
|
def backward(ctx, y: torch.Tensor): |
|
in_channel_first = ctx.in_channel_first |
|
out_channel_first = ctx.out_channel_first |
|
one_by_one = ctx.one_by_one |
|
scans = ctx.scans |
|
B, C, H, W = ctx.shape |
|
BC, BH, BW, NC, NH, NW = ctx.triton_shape |
|
if one_by_one: |
|
x = y.new_empty((B, 4, C, H, W)) if in_channel_first else y.new_empty((B, H, W, 4, C)) |
|
else: |
|
x = y.new_empty((B, C, H, W)) if in_channel_first else y.new_empty((B, H, W, C)) |
|
|
|
triton_cross_scan_flex[(NH * NW, NC, B)]( |
|
x, y.contiguous(), |
|
(0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans, |
|
BC, BH, BW, C, H, W, NH, NW |
|
) |
|
return x, None, None, None, None |
|
|
|
|
|
class CrossMergeTritonF(torch.autograd.Function): |
|
@staticmethod |
|
def forward(ctx, y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0): |
|
if out_channel_first: |
|
B, _, C, H, W = y.shape |
|
else: |
|
B, H, W, _, C = y.shape |
|
B, C, H, W = int(B), int(C), int(H), int(W) |
|
BC, BH, BW = 1, 32, 32 |
|
NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC) |
|
ctx.in_channel_first = in_channel_first |
|
ctx.out_channel_first = out_channel_first |
|
ctx.one_by_one = one_by_one |
|
ctx.scans = scans |
|
ctx.shape = (B, C, H, W) |
|
ctx.triton_shape = (BC, BH, BW, NC, NH, NW) |
|
if one_by_one: |
|
x = y.new_empty((B, 4, C, H * W)) if in_channel_first else y.new_empty((B, H * W, 4, C)) |
|
else: |
|
x = y.new_empty((B, C, H * W)) if in_channel_first else y.new_empty((B, H * W, C)) |
|
triton_cross_scan_flex[(NH * NW, NC, B)]( |
|
x, y.contiguous(), |
|
(0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans, |
|
BC, BH, BW, C, H, W, NH, NW |
|
) |
|
return x |
|
|
|
@staticmethod |
|
def backward(ctx, x: torch.Tensor): |
|
in_channel_first = ctx.in_channel_first |
|
out_channel_first = ctx.out_channel_first |
|
one_by_one = ctx.one_by_one |
|
scans = ctx.scans |
|
B, C, H, W = ctx.shape |
|
BC, BH, BW, NC, NH, NW = ctx.triton_shape |
|
y = x.new_empty((B, 4, C, H, W)) if out_channel_first else x.new_empty((B, H, W, 4, C)) |
|
triton_cross_scan_flex[(NH * NW, NC, B)]( |
|
x.contiguous(), y, |
|
(0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans, |
|
BC, BH, BW, C, H, W, NH, NW |
|
) |
|
        return y, None, None, None, None
|
|
|
|
|
|
|
def cross_scan_fn(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False): |
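    # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
    # returns y: (B, 4, C, L) | (B, L, 4, C), with L = H * W
    # Dispatches to the Triton kernel for CUDA tensors when available (unless
    # force_torch=True); otherwise falls back to the pure-PyTorch autograd function.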
|
|
|
|
|
|
|
CSF = CrossScanTritonF if WITH_TRITON and x.is_cuda and (not force_torch) else CrossScanF |
|
if x.is_cuda: |
|
with torch.cuda.device(x.device): |
|
return CSF.apply(x, in_channel_first, out_channel_first, one_by_one, scans) |
|
else: |
|
return CrossScanF.apply(x, in_channel_first, out_channel_first, one_by_one, scans) |
|
|
|
|
|
|
|
def cross_merge_fn(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False): |
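    # y: (B, 4, C, H, W) | (B, H, W, 4, C)
    # returns: (B, C, L) | (B, L, C)  (the K=4 axis is kept when one_by_one=True)
    # Same Triton / PyTorch dispatch rule as `cross_scan_fn`.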
|
|
|
|
|
|
|
CMF = CrossMergeTritonF if WITH_TRITON and y.is_cuda and (not force_torch) else CrossMergeF |
|
if y.is_cuda: |
|
with torch.cuda.device(y.device): |
|
return CMF.apply(y, in_channel_first, out_channel_first, one_by_one, scans) |
|
else: |
|
return CrossMergeF.apply(y, in_channel_first, out_channel_first, one_by_one, scans) |
|
|
|
|
|
|
|
|
|
|
|
|
|
WITH_SELECTIVESCAN_MAMBA = True

try:
    import selective_scan_cuda
except ImportError:
    WITH_SELECTIVESCAN_MAMBA = False

# Optional "oflex" selective-scan kernel from the original VMamba repository; it is
# only needed when `backend="oflex"` is requested explicitly below.
try:
    import selective_scan_cuda_oflex
except ImportError:
    selective_scan_cuda_oflex = None
|
|
|
|
|
def selective_scan_torch( |
|
u: torch.Tensor, |
|
delta: torch.Tensor, |
|
A: torch.Tensor, |
|
B: torch.Tensor, |
|
C: torch.Tensor, |
|
D: torch.Tensor = None, |
|
delta_bias: torch.Tensor = None, |
|
delta_softplus=True, |
|
oflex=True, |
|
*args, |
|
**kwargs |
|
): |
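    # Reference (sequential) implementation of the selective scan:
    #   h_t = exp(delta_t * A) * h_{t-1} + delta_t * B_t * u_t
    #   y_t = C_t . h_t      (a D * u_t skip term is added afterwards)
    # Shapes: u, delta: (B, K*C, L); A: (K*C, N); B, C: (B, K, N, L); D: (K*C,)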
|
dtype_in = u.dtype |
|
Batch, K, N, L = B.shape |
|
KCdim = u.shape[1] |
|
Cdim = int(KCdim / K) |
|
assert u.shape == (Batch, KCdim, L) |
|
assert delta.shape == (Batch, KCdim, L) |
|
assert A.shape == (KCdim, N) |
|
assert C.shape == B.shape |
|
|
|
if delta_bias is not None: |
|
delta = delta + delta_bias[..., None] |
|
if delta_softplus: |
|
delta = torch.nn.functional.softplus(delta) |
|
|
|
u, delta, A, B, C = u.float(), delta.float(), A.float(), B.float(), C.float() |
|
B = B.view(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L) |
|
C = C.view(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L) |
|
deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) |
|
deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) |
|
|
|
if True: |
|
x = A.new_zeros((Batch, KCdim, N)) |
|
ys = [] |
|
for i in range(L): |
|
x = deltaA[:, :, i, :] * x + deltaB_u[:, :, i, :] |
|
y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) |
|
ys.append(y) |
|
y = torch.stack(ys, dim=2) |
|
|
|
out = y if D is None else y + u * D.unsqueeze(-1) |
|
return out if oflex else out.to(dtype=dtype_in) |
|
|
|
|
|
class SelectiveScanCuda(torch.autograd.Function): |
|
@staticmethod |
|
@torch.cuda.amp.custom_fwd |
|
def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, oflex=True, backend=None): |
|
ctx.delta_softplus = delta_softplus |
|
|
|
|
|
backend = "mamba" if WITH_SELECTIVESCAN_MAMBA and (backend is None) else backend |
|
ctx.backend = backend |
|
if backend == "oflex": |
|
out, x, *rest = selective_scan_cuda_oflex.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1, oflex) |
|
elif backend == "mamba": |
|
out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus) |
|
ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) |
|
return out |
|
|
|
@staticmethod |
|
@torch.cuda.amp.custom_bwd |
|
def backward(ctx, dout, *args): |
|
u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors |
|
backend = ctx.backend |
|
if dout.stride(-1) != 1: |
|
dout = dout.contiguous() |
|
if backend == "oflex": |
|
du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_oflex.bwd( |
|
u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1 |
|
) |
|
elif backend == "mamba": |
|
du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( |
|
u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus, |
|
False |
|
) |
|
return du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None |
|
|
|
|
|
def selective_scan_fn( |
|
u: torch.Tensor, |
|
delta: torch.Tensor, |
|
A: torch.Tensor, |
|
B: torch.Tensor, |
|
C: torch.Tensor, |
|
D: torch.Tensor = None, |
|
delta_bias: torch.Tensor = None, |
|
delta_softplus=True, |
|
oflex=True, |
|
backend=None, |
|
): |
|
fn = selective_scan_torch if backend == "torch" or (not WITH_SELECTIVESCAN_MAMBA) else SelectiveScanCuda.apply |
|
return fn(u, delta, A, B, C, D, delta_bias, delta_softplus, oflex, backend) |
|
|
|
|
|
|
|
|
|
|
|
class VMambaLinear2d(nn.Linear): |
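    """
    An `nn.Linear` that also accepts channel-first feature maps: the weight is applied as a
    1x1 `conv2d` on (B, C, H, W) inputs and as a kernel-size-1 `conv1d` on (B, C, L) inputs.
    """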
|
def __init__(self, *args, groups=1, **kwargs): |
|
nn.Linear.__init__(self, *args, **kwargs) |
|
self.groups = groups |
|
|
|
def forward(self, x: torch.Tensor): |
|
if len(x.shape) == 4: |
|
return F.conv2d(x, self.weight[:, :, None, None], self.bias, groups=self.groups) |
|
elif len(x.shape) == 3: |
|
return F.conv1d(x, self.weight[:, :, None], self.bias, groups=self.groups) |
|
|
|
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): |
|
self_state_dict = self.state_dict() |
|
load_state_dict_keys = list(state_dict.keys()) |
|
if prefix + "weight" in load_state_dict_keys: |
|
state_dict[prefix + "weight"] = state_dict[prefix + "weight"].view_as(self_state_dict["weight"]) |
|
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) |
|
|
|
|
|
class VMambaLayerNorm2d(nn.LayerNorm): |
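    """LayerNorm over the channel dimension of (B, C, H, W) tensors (permute to channels-last and back)."""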
|
def __init__(self, *args, **kwargs): |
|
nn.LayerNorm.__init__(self, *args, **kwargs) |
|
|
|
def forward(self, x: torch.Tensor): |
|
x = x.permute(0, 2, 3, 1) |
|
x = nn.LayerNorm.forward(self, x) |
|
x = x.permute(0, 3, 1, 2) |
|
return x |
|
|
|
|
|
class VMambaPatchEmbeddings(nn.Module): |
|
""" |
|
This class turns `input_values` into the initial `hidden_states` (patch embeddings) of shape `(batch_size, |
|
seq_length, hidden_size)` to be consumed by a State-space model. |
|
""" |
|
|
|
def __init__(self, num_channels=3,patch_size=4,embed_dim=96): |
|
super().__init__() |
|
|
|
stride = patch_size // 2 |
|
kernel_size = stride + 1 |
|
padding = 1 |
|
|
|
self.projection = nn.Sequential( |
|
nn.Conv2d(num_channels, embed_dim // 2, kernel_size=kernel_size, stride=stride, padding=padding), |
|
VMambaLayerNorm2d(embed_dim // 2), |
|
nn.GELU(), |
|
nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding), |
|
VMambaLayerNorm2d(embed_dim), |
|
) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
x = self.projection(x) |
|
return x |
|
|
|
|
|
class VMambaDownsample(nn.Module):
|
""" |
|
This class downsamples the input tensor using a convolutional layer followed by a layer normalization. |
|
""" |
|
def __init__(self, dim, out_dim, use_norm=True): |
|
super().__init__() |
|
self.down = nn.Conv2d(dim, out_dim, kernel_size=3, stride=2, padding=1) |
|
self.norm = VMambaLayerNorm2d(out_dim) if use_norm else nn.Identity() |
|
|
|
def forward(self, x): |
|
x = self.down(x) |
|
x = self.norm(x) |
|
return x |
|
|
|
|
|
class VMambaMlp(nn.Module): |
|
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): |
|
super().__init__() |
|
out_features = out_features or in_features |
|
hidden_features = hidden_features or in_features |
|
self.fc1 = VMambaLinear2d(in_features, hidden_features) |
|
self.act = act_layer() |
|
self.fc2 = VMambaLinear2d(hidden_features, out_features) |
|
self.drop = nn.Dropout(drop) |
|
|
|
def forward(self, x): |
|
x = self.fc1(x) |
|
x = self.act(x) |
|
x = self.drop(x) |
|
x = self.fc2(x) |
|
x = self.drop(x) |
|
return x |
|
|
|
|
|
class SS2D(nn.Module): |
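    """
    2D-Selective-Scan block: input projection, optional depthwise convolution, a
    four-direction selective scan (cross-scan -> S6 -> cross-merge), and output projection.
    """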
|
def __init__( |
|
self, |
|
|
|
d_model=96, |
|
d_state=16, |
|
ssm_ratio=2.0, |
|
dt_rank="auto", |
|
act_layer=nn.SiLU, |
|
|
|
d_conv=3, |
|
conv_bias=True, |
|
|
|
dropout=0.0, |
|
bias=False, |
|
|
|
dt_min=0.001, |
|
dt_max=0.1, |
|
dt_init="random", |
|
dt_scale=1.0, |
|
dt_init_floor=1e-4, |
|
|
|
|
|
**kwargs, |
|
): |
|
super().__init__() |
|
self.k_group = 4 |
|
self.d_model = int(d_model) |
|
self.d_state = int(d_state) |
|
self.d_inner = int(ssm_ratio * d_model) |
|
self.dt_rank = int(math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank) |
|
self.forward_core = partial(self.forward_corev2, force_fp32=False, no_einsum=True) |
|
self.with_dconv = d_conv > 1 |
|
|
|
|
|
self.in_proj = VMambaLinear2d(self.d_model, self.d_inner, bias=bias) |
|
self.act: nn.Module = act_layer() |
|
|
|
|
|
if self.with_dconv: |
|
self.conv2d = nn.Conv2d( |
|
in_channels=self.d_inner, |
|
out_channels=self.d_inner, |
|
groups=self.d_inner, |
|
bias=conv_bias, |
|
kernel_size=d_conv, |
|
padding=(d_conv - 1) // 2, |
|
) |
|
|
|
|
|
self.x_proj = VMambaLinear2d(self.d_inner, self.k_group * (self.dt_rank + self.d_state * 2), groups=self.k_group, bias=False) |
|
self.dt_projs = VMambaLinear2d(self.dt_rank, self.k_group * self.d_inner, groups=self.k_group, bias=False) |
|
|
|
|
|
self.out_proj = VMambaLinear2d(self.d_inner, self.d_model, bias=bias) |
|
self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity() |
|
|
|
|
|
self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = self.init_dt_A_D( |
|
self.d_state, self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=self.k_group, |
|
) |
|
self.dt_projs.weight.data = self.dt_projs_weight.data.view(self.dt_projs.weight.shape) |
|
|
|
del self.dt_projs_weight |
|
|
|
|
|
self.out_norm = VMambaLayerNorm2d(self.d_inner) |
|
|
|
@staticmethod |
|
def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4): |
|
dt_proj = nn.Linear(dt_rank, d_inner, bias=True) |
|
|
|
dt_init_std = dt_rank**-0.5 * dt_scale |
|
if dt_init == "constant": |
|
nn.init.constant_(dt_proj.weight, dt_init_std) |
|
elif dt_init == "random": |
|
nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) |
|
else: |
|
raise NotImplementedError |
|
|
|
dt = torch.exp( |
|
torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min)) |
|
+ math.log(dt_min) |
|
).clamp(min=dt_init_floor) |
|
|
|
inv_dt = dt + torch.log(-torch.expm1(-dt)) |
|
with torch.no_grad(): |
|
dt_proj.bias.copy_(inv_dt) |
|
|
|
return dt_proj |
|
|
|
@staticmethod |
|
def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True): |
|
A = torch.arange(1, d_state + 1, dtype=torch.float32, device=device).view(1, -1).repeat(d_inner, 1).contiguous() |
|
A_log = torch.log(A) |
|
if copies > 0: |
|
A_log = A_log[None].repeat(copies, 1, 1).contiguous() |
|
if merge: |
|
A_log = A_log.flatten(0, 1) |
|
A_log = nn.Parameter(A_log) |
|
A_log._no_weight_decay = True |
|
return A_log |
|
|
|
@staticmethod |
|
def D_init(d_inner, copies=-1, device=None, merge=True): |
|
D = torch.ones(d_inner, device=device) |
|
if copies > 0: |
|
D = D[None].repeat(copies, 1).contiguous() |
|
if merge: |
|
D = D.flatten(0, 1) |
|
D = nn.Parameter(D) |
|
D._no_weight_decay = True |
|
return D |
|
|
|
@classmethod |
|
def init_dt_A_D(cls, d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4): |
|
dt_projs = [ |
|
cls.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor) |
|
for _ in range(k_group) |
|
] |
|
dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in dt_projs], dim=0)) |
|
dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in dt_projs], dim=0)) |
|
del dt_projs |
|
|
|
A_logs = cls.A_log_init(d_state, d_inner, copies=k_group, merge=True) |
|
Ds = cls.D_init(d_inner, copies=k_group, merge=True) |
|
return A_logs, Ds, dt_projs_weight, dt_projs_bias |
|
|
|
def forward_corev2( |
|
self, |
|
x: torch.Tensor, |
|
force_fp32=False, |
|
no_einsum=True, |
|
): |
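        # 1) unfold the feature map into four scan directions,
        # 2) project to (dt, B, C) per direction and run the selective scan,
        # 3) fold the four outputs back onto the 2-D grid and normalize.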
|
B, D, H, W = x.shape |
|
N = self.d_state |
|
L = H * W |
|
|
|
xs = cross_scan_fn(x, in_channel_first=True, out_channel_first=True) |
|
x_dbl = self.x_proj(xs.view(B, -1, L)) |
|
dts, Bs, Cs = torch.split(x_dbl.view(B, self.k_group, -1, L), [self.dt_rank, N, N], dim=2) |
|
dts = dts.contiguous().view(B, -1, L) |
|
dts = self.dt_projs(dts) |
|
|
|
xs = xs.view(B, -1, L) |
|
dts = dts.contiguous().view(B, -1, L) |
|
As = -self.A_logs.to(torch.float32).exp() |
|
Ds = self.Ds.to(torch.float32) |
|
Bs = Bs.contiguous().view(B, self.k_group, N, L) |
|
Cs = Cs.contiguous().view(B, self.k_group, N, L) |
|
delta_bias = self.dt_projs_bias.view(-1).to(torch.float32) |
|
|
|
ys = selective_scan_fn( |
|
xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus=True, backend="mamba" |
|
).view(B, self.k_group, -1, H, W) |
|
|
|
y = cross_merge_fn(ys, in_channel_first=True, out_channel_first=True) |
|
y = y.view(B, -1, H, W) |
|
y = self.out_norm(y) |
|
return y.to(x.dtype) |
|
|
|
def forward(self, x: torch.Tensor): |
|
x = self.in_proj(x) |
|
        if self.with_dconv:
            x = self.conv2d(x)
|
|
|
x = self.act(x) |
|
y = self.forward_core(x) |
|
|
|
out = self.dropout(self.out_proj(y)) |
|
return out |
|
|
|
|
|
class VSSBlock(nn.Module): |
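    """
    Residual VMamba block: an SS2D branch and an optional MLP branch, each wrapped with
    LayerNorm (pre-norm by default, post-norm if `post_norm=True`) and DropPath.
    """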
|
def __init__( |
|
self, |
|
hidden_dim: int = 0, |
|
drop_path: float = 0, |
|
ssm_d_state: int = 1, |
|
ssm_ratio=1.0, |
|
ssm_dt_rank: Any = "auto", |
|
ssm_act_layer=nn.SiLU, |
|
ssm_conv: int = 3, |
|
ssm_conv_bias=False, |
|
ssm_drop_rate: float = 0, |
|
mlp_ratio=4.0, |
|
mlp_act_layer=nn.GELU, |
|
mlp_drop_rate: float = 0.0, |
|
use_checkpoint: bool = False, |
|
post_norm: bool = False, |
|
**kwargs, |
|
): |
|
super().__init__() |
|
self.ssm_branch = ssm_ratio > 0 |
|
self.mlp_branch = mlp_ratio > 0 |
|
self.use_checkpoint = use_checkpoint |
|
self.post_norm = post_norm |
|
|
|
if self.ssm_branch: |
|
self.norm = VMambaLayerNorm2d(hidden_dim) |
|
self.op = SS2D( |
|
d_model=hidden_dim, |
|
d_state=ssm_d_state, |
|
ssm_ratio=ssm_ratio, |
|
dt_rank=ssm_dt_rank, |
|
act_layer=ssm_act_layer, |
|
d_conv=ssm_conv, |
|
conv_bias=ssm_conv_bias, |
|
dropout=ssm_drop_rate, |
|
) |
|
|
|
self.drop_path = DropPath(drop_path) |
|
|
|
if self.mlp_branch: |
|
self.norm2 = VMambaLayerNorm2d(hidden_dim) |
|
mlp_hidden_dim = int(hidden_dim * mlp_ratio) |
|
self.mlp = VMambaMlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=mlp_act_layer, drop=mlp_drop_rate) |
|
|
|
def _forward(self, input: torch.Tensor): |
|
x = input |
|
if self.ssm_branch: |
|
if self.post_norm: |
|
x = x + self.drop_path(self.norm(self.op(x))) |
|
else: |
|
x = x + self.drop_path(self.op(self.norm(x))) |
|
if self.mlp_branch: |
|
if self.post_norm: |
|
x = x + self.drop_path(self.norm2(self.mlp(x))) |
|
else: |
|
x = x + self.drop_path(self.mlp(self.norm2(x))) |
|
return x |
|
|
|
def forward(self, input: torch.Tensor): |
|
if self.use_checkpoint: |
|
return checkpoint.checkpoint(self._forward, input) |
|
else: |
|
return self._forward(input) |
|
|
|
class VMambaLayer(nn.Module): |
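    """One VMamba stage: `depth` consecutive VSSBlocks followed by an optional downsampling layer."""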
|
|
|
def __init__( |
|
self, |
|
input_dim, |
|
depth, |
|
drop_path=0.0, |
|
norm_layer=VMambaLayerNorm2d, |
|
downsample=nn.Identity(), |
|
use_checkpoint=False, |
|
**kwargs, |
|
): |
|
super().__init__() |
|
self.input_dim = input_dim |
|
self.use_checkpoint = use_checkpoint |
|
|
|
self.blocks = nn.ModuleList() |
|
for i in range(depth): |
|
self.blocks.append( |
|
VSSBlock(hidden_dim=input_dim, |
|
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, |
|
norm_layer=norm_layer,use_checkpoint=use_checkpoint,**kwargs, |
|
) |
|
) |
|
|
|
self.downsample = downsample |
|
|
|
def forward(self, x): |
|
for block in self.blocks: |
|
x = block(x) |
|
|
|
x = self.downsample(x) |
|
return x |
|
|
|
class VMambaPreTrainedModel(PreTrainedModel): |
|
""" |
|
An abstract class to handle weights initialization and |
|
a simple interface for downloading and loading pretrained models. |
|
""" |
|
|
|
config_class = VMambaConfig |
|
base_model_prefix = "vmamba" |
|
supports_gradient_checkpointing = False |
|
|
|
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: |
|
"""Initialize the weights""" |
|
if isinstance(module, nn.Linear): |
|
trunc_normal_(module.weight, std=0.02) |
|
if isinstance(module, nn.Linear) and module.bias is not None: |
|
nn.init.constant_(module.bias, 0) |
|
elif isinstance(module, nn.LayerNorm): |
|
nn.init.constant_(module.bias, 0) |
|
nn.init.constant_(module.weight, 1.0) |
|
|
|
|
|
class VMambaModel(VMambaPreTrainedModel): |
|
def __init__(self, config): |
|
super().__init__(config) |
|
self.config = config |
|
|
|
        self.num_layers = len(config.depths)

        dims = config.dims
        if isinstance(dims, int):
            dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
        self.dims = dims

        self.patch_embeddings = VMambaPatchEmbeddings(patch_size=config.patch_size,
                                                      embed_dim=dims[0])

        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
        self.num_features = dims[-1]
|
|
|
self.layers = nn.ModuleList() |
|
for i in range(self.num_layers): |
|
layer = VMambaLayer( |
|
input_dim=self.dims[i], |
|
depth=config.depths[i], |
|
drop_path=dpr[sum(config.depths[:i]):sum(config.depths[:i+1])], |
|
                downsample=VMambaDownsample(self.dims[i], self.dims[i + 1]) if i < self.num_layers - 1 else nn.Identity(),
|
use_checkpoint=config.use_checkpoint, |
|
) |
|
self.layers.append(layer) |
|
|
|
self.norm = VMambaLayerNorm2d(self.num_features) |
|
self.avgpool = nn.AdaptiveAvgPool2d(1) |
|
|
|
def get_input_embeddings(self) -> VMambaPatchEmbeddings: |
|
return self.patch_embeddings |
|
|
|
def forward(self, input_values: torch.Tensor): |
|
x = self.patch_embeddings(input_values) |
|
for layer in self.layers: |
|
x = layer(x) |
|
x = self.norm(x) |
|
x = self.avgpool(x).flatten(1) |
|
return x |
|
|
|
|
|
class VMambaForImageClassification(VMambaPreTrainedModel): |
|
def __init__(self, config): |
|
super().__init__(config) |
|
|
|
self.num_classes = config.num_classes |
|
self.vmamba = VMambaModel(config) |
|
self.head = nn.Linear(self.vmamba.num_features, self.num_classes) if self.num_classes > 0 else nn.Identity() |
|
|
|
|
|
self.post_init() |
|
|
|
def forward( |
|
self, |
|
pixel_values: Optional[torch.Tensor] = None, |
|
labels: Optional[torch.Tensor] = None, |
|
return_dict: Optional[bool] = None, |
|
): |
|
|
|
outputs = self.vmamba( |
|
pixel_values, |
|
) |
|
|
|
logits = self.head(outputs) |
|
|
|
        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.loss_type == "ce":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
            elif self.config.loss_type == "bce":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if not return_dict:
            output = (logits,) + (outputs,)
            return ((loss,) + output) if loss is not None else output
|
|
|
return ImageClassifierOutput( |
|
loss=loss, |
|
logits=logits, |
|
hidden_states=outputs, |
|
) |
|
|
|
__all__ = [ |
|
"VMambaModel", |
|
"VMambaPreTrainedModel", |
|
"VMambaForImageClassification", |
|
] |
|
|