# coding=utf-8
# @Author : Saurabhchand Bhati
# @Affiliation : Massachusetts Institute of Technology
# VMamba backbone is from https://github.com/MzeroMiko/VMamba/blob/main/vmamba.py
# VMambaLayer, VMambaModel, VMambaForImageClassification are implemented based on VMamba
# SS2Dv0, SS2Dv1, SS2Dv2 are merged into one class and initialization is limited to v05_noz,
# patch embedding is limited to v2 and downsampling is limited to v3.
# MIT License
# Copyright (c) 2024 MzeroMiko, Saurabhchand Bhati
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""VMamba: Visual State Space Model configuration model"""
import math
import torch
import warnings
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, trunc_normal_
from functools import partial
from typing import Optional, Callable, Any, Union
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
from transformers.modeling_outputs import ImageClassifierOutput
from transformers.utils import logging
from transformers.modeling_utils import PreTrainedModel
from .configuration_vmamba import VMambaConfig
logger = logging.get_logger(__name__)
# General docstring
_CONFIG_FOR_DOC = "VMambaConfig"
WITH_TRITON = True
# WITH_TRITON = False
try:
    import triton
    import triton.language as tl
except ImportError:
    WITH_TRITON = False
    warnings.warn("Triton is not installed; falling back to the PyTorch implementation.")
    # minimal stand-ins so the @triton.jit kernel below still parses without triton
    triton = type("triton", (), {"jit": staticmethod(lambda fn: fn)})
    tl = type("tl", (), {"tensor": None, "constexpr": None})
# make sure cached_property is available for triton
if WITH_TRITON:
    try:
        from functools import cached_property
    except ImportError:
        warnings.warn("If you are using Python 3.7, add this line to functools.py: "
                      "cached_property = lambda func: property(lru_cache()(func))")
# torch implementation ========================================
def cross_scan_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
if in_channel_first:
B, C, H, W = x.shape
if scans == 0:
y = x.new_empty((B, 4, C, H * W))
y[:, 0, :, :] = x.flatten(2, 3)
y[:, 1, :, :] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
y[:, 2:4, :, :] = torch.flip(y[:, 0:2, :, :], dims=[-1])
elif scans == 1:
y = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1)
elif scans == 2:
y = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
y = torch.cat([y, y.flip(dims=[-1])], dim=1)
elif scans == 3:
y = x.new_empty((B, 4, C, H * W))
y[:, 0, :, :] = x.flatten(2, 3)
y[:, 1, :, :] = torch.rot90(x, 1, dims=(2, 3)).flatten(2, 3)
y[:, 2, :, :] = torch.rot90(x, 2, dims=(2, 3)).flatten(2, 3)
y[:, 3, :, :] = torch.rot90(x, 3, dims=(2, 3)).flatten(2, 3)
else:
B, H, W, C = x.shape
if scans == 0:
y = x.new_empty((B, H * W, 4, C))
y[:, :, 0, :] = x.flatten(1, 2)
y[:, :, 1, :] = x.transpose(dim0=1, dim1=2).flatten(1, 2)
y[:, :, 2:4, :] = torch.flip(y[:, :, 0:2, :], dims=[1])
elif scans == 1:
y = x.view(B, H * W, 1, C).repeat(1, 1, 4, 1)
elif scans == 2:
y = x.view(B, H * W, 1, C).repeat(1, 1, 2, 1)
y = torch.cat([y, y.flip(dims=[1])], dim=2)
elif scans == 3:
y = x.new_empty((B, H * W, 4, C))
y[:, :, 0, :] = x.flatten(1, 2)
y[:, :, 1, :] = torch.rot90(x, 1, dims=(1, 2)).flatten(1, 2)
y[:, :, 2, :] = torch.rot90(x, 2, dims=(1, 2)).flatten(1, 2)
y[:, :, 3, :] = torch.rot90(x, 3, dims=(1, 2)).flatten(1, 2)
if in_channel_first and (not out_channel_first):
y = y.permute(0, 3, 1, 2).contiguous()
elif (not in_channel_first) and out_channel_first:
y = y.permute(0, 2, 3, 1).contiguous()
return y
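# A minimal shape sketch of the torch cross-scan (hedged, illustrative only).
# With in_channel_first and scans=0 the four routes are: raster order, the
# transposed (column-major) order, and the reverses of both:
#   >>> x = torch.arange(16.).view(1, 1, 4, 4)
#   >>> y = cross_scan_fwd(x, in_channel_first=True, out_channel_first=True, scans=0)
#   >>> y.shape
#   torch.Size([1, 4, 1, 16])
#   >>> torch.equal(y[0, 2, 0], x.flatten().flip(0))  # route 2 reverses route 0
#   True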
def cross_merge_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
if out_channel_first:
B, K, D, H, W = y.shape
y = y.view(B, K, D, -1)
if scans == 0:
y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
y = y[:, 0] + y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
elif scans == 1:
y = y.sum(1)
elif scans == 2:
y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
y = y.sum(1)
elif scans == 3:
oy = y[:, 0, :, :].contiguous().view(B, D, -1)
oy = oy + torch.rot90(y.view(B, K, D, W, H)[:, 1, :, :, :], -1, dims=(2, 3)).flatten(2, 3)
oy = oy + torch.rot90(y.view(B, K, D, H, W)[:, 2, :, :, :], -2, dims=(2, 3)).flatten(2, 3)
oy = oy + torch.rot90(y.view(B, K, D, W, H)[:, 3, :, :, :], -3, dims=(2, 3)).flatten(2, 3)
y = oy
else:
B, H, W, K, D = y.shape
y = y.view(B, -1, K, D)
if scans == 0:
y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D)
y = y[:, :, 0] + y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).contiguous().view(B, -1, D)
elif scans == 1:
y = y.sum(2)
elif scans == 2:
y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D)
y = y.sum(2)
elif scans == 3:
oy = y[:, :, 0, :].contiguous().view(B, -1, D)
oy = oy + torch.rot90(y.view(B, W, H, K, D)[:, :, :, 1, :], -1, dims=(1, 2)).flatten(1, 2)
oy = oy + torch.rot90(y.view(B, H, W, K, D)[:, :, :, 2, :], -2, dims=(1, 2)).flatten(1, 2)
oy = oy + torch.rot90(y.view(B, W, H, K, D)[:, :, :, 3, :], -3, dims=(1, 2)).flatten(1, 2)
y = oy
if in_channel_first and (not out_channel_first):
y = y.permute(0, 2, 1).contiguous()
elif (not in_channel_first) and out_channel_first:
y = y.permute(0, 2, 1).contiguous()
return y
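# cross_merge_fwd is the linear adjoint of cross_scan_fwd: each route is mapped
# back to raster order and the four routes are summed, so a scan followed by a
# merge returns 4x the input (a quick illustrative check):
#   >>> x = torch.randn(2, 3, 4, 4)
#   >>> ys = cross_scan_fwd(x, scans=0).view(2, 4, 3, 4, 4)
#   >>> torch.allclose(cross_merge_fwd(ys, scans=0), 4 * x.flatten(2))
#   True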
def cross_scan1b1_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
if in_channel_first:
B, _, C, H, W = x.shape
if scans == 0:
y = torch.stack([
x[:, 0].flatten(2, 3),
x[:, 1].transpose(dim0=2, dim1=3).flatten(2, 3),
torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
torch.flip(x[:, 3].transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]),
], dim=1)
elif scans == 1:
y = x.flatten(2, 3)
elif scans == 2:
y = torch.stack([
x[:, 0].flatten(2, 3),
x[:, 1].flatten(2, 3),
torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
torch.flip(x[:, 3].flatten(2, 3), dims=[-1]),
], dim=1)
elif scans == 3:
y = torch.stack([
x[:, 0, :, :, :].flatten(2, 3),
torch.rot90(x[:, 1, :, :, :], 1, dims=(2, 3)).flatten(2, 3),
torch.rot90(x[:, 2, :, :, :], 2, dims=(2, 3)).flatten(2, 3),
torch.rot90(x[:, 3, :, :, :], 3, dims=(2, 3)).flatten(2, 3),
], dim=1)
else:
B, H, W, _, C = x.shape
if scans == 0:
y = torch.stack([
x[:, :, :, 0].flatten(1, 2),
x[:, :, :, 1].transpose(dim0=1, dim1=2).flatten(1, 2),
torch.flip(x[:, :, :, 2].flatten(1, 2), dims=[1]),
torch.flip(x[:, :, :, 3].transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]),
], dim=2)
elif scans == 1:
y = x.flatten(1, 2)
elif scans == 2:
y = torch.stack([
x[:, 0].flatten(1, 2),
x[:, 1].flatten(1, 2),
torch.flip(x[:, 2].flatten(1, 2), dims=[-1]),
torch.flip(x[:, 3].flatten(1, 2), dims=[-1]),
], dim=2)
elif scans == 3:
y = torch.stack([
x[:, :, :, 0, :].flatten(1, 2),
torch.rot90(x[:, :, :, 1, :], 1, dims=(1, 2)).flatten(1, 2),
torch.rot90(x[:, :, :, 2, :], 2, dims=(1, 2)).flatten(1, 2),
torch.rot90(x[:, :, :, 3, :], 3, dims=(1, 2)).flatten(1, 2),
            ], dim=2)
if in_channel_first and (not out_channel_first):
y = y.permute(0, 3, 1, 2).contiguous()
elif (not in_channel_first) and out_channel_first:
y = y.permute(0, 2, 3, 1).contiguous()
return y
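# cross_scan1b1_fwd is the "one-by-one" variant: it takes one tensor per scan
# direction instead of broadcasting a single input to all four routes:
#   >>> x = torch.randn(1, 4, 3, 4, 4)  # (B, 4 directions, C, H, W)
#   >>> cross_scan1b1_fwd(x, scans=0).shape
#   torch.Size([1, 4, 3, 16])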
def cross_merge1b1_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
if out_channel_first:
B, K, D, H, W = y.shape
y = y.view(B, K, D, -1)
if scans == 0:
y = torch.stack([
y[:, 0],
y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3),
torch.flip(y[:, 2], dims=[-1]),
torch.flip(y[:, 3].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]),
], dim=1)
elif scans == 1:
y = y
elif scans == 2:
y = torch.stack([
y[:, 0],
y[:, 1],
torch.flip(y[:, 2], dims=[-1]),
torch.flip(y[:, 3], dims=[-1]),
], dim=1)
elif scans == 3:
y = torch.stack([
y[:, 0, :, :].contiguous().view(B, D, -1),
torch.rot90(y.view(B, K, D, W, H)[:, 1, :, :, :], -1, dims=(2, 3)).flatten(2, 3),
torch.rot90(y.view(B, K, D, H, W)[:, 2, :, :, :], -2, dims=(2, 3)).flatten(2, 3),
torch.rot90(y.view(B, K, D, W, H)[:, 3, :, :, :], -3, dims=(2, 3)).flatten(2, 3),
], dim=1)
else:
B, H, W, K, D = y.shape
y = y.view(B, -1, K, D)
if scans == 0:
y = torch.stack([
y[:, :, 0],
y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2),
torch.flip(y[:, :, 2], dims=[1]),
torch.flip(y[:, :, 3].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]),
], dim=2)
elif scans == 1:
y = y
elif scans == 2:
y = torch.stack([
y[:, :, 0],
y[:, :, 1],
torch.flip(y[:, :, 2], dims=[1]),
torch.flip(y[:, :, 3], dims=[1]),
], dim=2)
elif scans == 3:
y = torch.stack([
y[:, :, 0, :].contiguous().view(B, -1, D),
torch.rot90(y.view(B, W, H, K, D)[:, :, :, 1, :], -1, dims=(1, 2)).flatten(1, 2),
torch.rot90(y.view(B, H, W, K, D)[:, :, :, 2, :], -2, dims=(1, 2)).flatten(1, 2),
torch.rot90(y.view(B, W, H, K, D)[:, :, :, 3, :], -3, dims=(1, 2)).flatten(1, 2),
], dim=2)
if out_channel_first and (not in_channel_first):
y = y.permute(0, 3, 1, 2).contiguous()
elif (not out_channel_first) and in_channel_first:
y = y.permute(0, 2, 3, 1).contiguous()
return y
class CrossScanF(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
# x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
# y: (B, 4, C, H * W) | (B, H * W, 4, C)
ctx.in_channel_first = in_channel_first
ctx.out_channel_first = out_channel_first
ctx.one_by_one = one_by_one
ctx.scans = scans
if one_by_one:
B, K, C, H, W = x.shape
if not in_channel_first:
B, H, W, K, C = x.shape
else:
B, C, H, W = x.shape
if not in_channel_first:
B, H, W, C = x.shape
ctx.shape = (B, C, H, W)
_fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd
y = _fn(x, in_channel_first, out_channel_first, scans)
return y
@staticmethod
def backward(ctx, ys: torch.Tensor):
# out: (b, k, d, l)
in_channel_first = ctx.in_channel_first
out_channel_first = ctx.out_channel_first
one_by_one = ctx.one_by_one
scans = ctx.scans
B, C, H, W = ctx.shape
ys = ys.view(B, -1, C, H, W) if out_channel_first else ys.view(B, H, W, -1, C)
_fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd
y = _fn(ys, in_channel_first, out_channel_first, scans)
if one_by_one:
y = y.view(B, 4, -1, H, W) if in_channel_first else y.view(B, H, W, 4, -1)
else:
y = y.view(B, -1, H, W) if in_channel_first else y.view(B, H, W, -1)
return y, None, None, None, None
class CrossMergeF(torch.autograd.Function):
@staticmethod
def forward(ctx, ys: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
        # ys: (B, 4, C, H, W) | (B, H, W, 4, C)
        # y: (B, C, H * W) | (B, H * W, C)
ctx.in_channel_first = in_channel_first
ctx.out_channel_first = out_channel_first
ctx.one_by_one = one_by_one
ctx.scans = scans
B, K, C, H, W = ys.shape
if not out_channel_first:
B, H, W, K, C = ys.shape
ctx.shape = (B, C, H, W)
_fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd
y = _fn(ys, in_channel_first, out_channel_first, scans)
return y
@staticmethod
def backward(ctx, x: torch.Tensor):
# B, D, L = x.shape
# out: (b, k, d, h, w)
in_channel_first = ctx.in_channel_first
out_channel_first = ctx.out_channel_first
one_by_one = ctx.one_by_one
scans = ctx.scans
B, C, H, W = ctx.shape
if not one_by_one:
if in_channel_first:
x = x.view(B, C, H, W)
else:
x = x.view(B, H, W, C)
else:
if in_channel_first:
x = x.view(B, 4, C, H, W)
else:
x = x.view(B, H, W, 4, C)
_fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd
x = _fn(x, in_channel_first, out_channel_first, scans)
x = x.view(B, 4, C, H, W) if out_channel_first else x.view(B, H, W, 4, C)
return x, None, None, None, None
# triton implementation ========================================
@triton.jit
def triton_cross_scan_flex(
x: tl.tensor, # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
y: tl.tensor, # (B, 4, C, H, W) | (B, H, W, 4, C)
x_layout: tl.constexpr,
y_layout: tl.constexpr,
operation: tl.constexpr,
onebyone: tl.constexpr,
scans: tl.constexpr,
BC: tl.constexpr,
BH: tl.constexpr,
BW: tl.constexpr,
DC: tl.constexpr,
DH: tl.constexpr,
DW: tl.constexpr,
NH: tl.constexpr,
NW: tl.constexpr,
):
# x_layout = 0
# y_layout = 1 # 0 BCHW, 1 BHWC
# operation = 0 # 0 scan, 1 merge
# onebyone = 0 # 0 false, 1 true
    # scans = 0 # 0 cross scan; 1 unidirectional; 2 bidirectional; 3 rotated (rot90)
i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_h, i_w = (i_hw // NW), (i_hw % NW)
_mask_h = (i_h * BH + tl.arange(0, BH)) < DH
_mask_w = (i_w * BW + tl.arange(0, BW)) < DW
_mask_hw = _mask_h[:, None] & _mask_w[None, :]
_for_C = min(DC - i_c * BC, BC)
pos_h = (i_h * BH + tl.arange(0, BH)[:, None])
pos_w = (i_w * BW + tl.arange(0, BW)[None, :])
neg_h = (DH - i_h * BH - 1 - tl.arange(0, BH)[:, None])
neg_w = (DW - i_w * BW - 1 - tl.arange(0, BW)[None, :])
if scans == 0:
# none; trans; flip; trans + flip;
HWRoute0 = pos_h * DW + pos_w
HWRoute1 = pos_w * DH + pos_h # trans
HWRoute2 = neg_h * DW + neg_w # flip
HWRoute3 = neg_w * DH + neg_h # trans + flip
elif scans == 1:
# none; none; none; none;
HWRoute0 = pos_h * DW + pos_w
HWRoute1 = HWRoute0
HWRoute2 = HWRoute0
HWRoute3 = HWRoute0
elif scans == 2:
# none; none; flip; flip;
HWRoute0 = pos_h * DW + pos_w
HWRoute1 = HWRoute0
HWRoute2 = neg_h * DW + neg_w # flip
HWRoute3 = HWRoute2
elif scans == 3:
# none; rot90; rot180==flip; rot270;
HWRoute0 = pos_h * DW + pos_w
HWRoute1 = neg_w * DH + pos_h
HWRoute2 = neg_h * DW + neg_w
HWRoute3 = pos_w * DH + neg_h
_tmp1 = DC * DH * DW
y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC)
if y_layout == 0:
p_y1 = y_ptr_base + HWRoute0
p_y2 = y_ptr_base + _tmp1 + HWRoute1
p_y3 = y_ptr_base + 2 * _tmp1 + HWRoute2
p_y4 = y_ptr_base + 3 * _tmp1 + HWRoute3
else:
p_y1 = y_ptr_base + HWRoute0 * 4 * DC
p_y2 = y_ptr_base + DC + HWRoute1 * 4 * DC
p_y3 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC
p_y4 = y_ptr_base + 3 * DC + HWRoute3 * 4 * DC
if onebyone == 0:
x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
if x_layout == 0:
p_x = x_ptr_base + HWRoute0
else:
p_x = x_ptr_base + HWRoute0 * DC
if operation == 0:
for idxc in range(_for_C):
_idx_x = idxc * DH * DW if x_layout == 0 else idxc
_idx_y = idxc * DH * DW if y_layout == 0 else idxc
_x = tl.load(p_x + _idx_x, mask=_mask_hw)
tl.store(p_y1 + _idx_y, _x, mask=_mask_hw)
tl.store(p_y2 + _idx_y, _x, mask=_mask_hw)
tl.store(p_y3 + _idx_y, _x, mask=_mask_hw)
tl.store(p_y4 + _idx_y, _x, mask=_mask_hw)
elif operation == 1:
for idxc in range(_for_C):
_idx_x = idxc * DH * DW if x_layout == 0 else idxc
_idx_y = idxc * DH * DW if y_layout == 0 else idxc
_y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw)
_y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw)
_y3 = tl.load(p_y3 + _idx_y, mask=_mask_hw)
_y4 = tl.load(p_y4 + _idx_y, mask=_mask_hw)
tl.store(p_x + _idx_x, _y1 + _y2 + _y3 + _y4, mask=_mask_hw)
else:
x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
if x_layout == 0:
p_x1 = x_ptr_base + HWRoute0
p_x2 = p_x1 + _tmp1
p_x3 = p_x2 + _tmp1
p_x4 = p_x3 + _tmp1
else:
p_x1 = x_ptr_base + HWRoute0 * 4 * DC
p_x2 = p_x1 + DC
p_x3 = p_x2 + DC
p_x4 = p_x3 + DC
if operation == 0:
for idxc in range(_for_C):
_idx_x = idxc * DH * DW if x_layout == 0 else idxc
_idx_y = idxc * DH * DW if y_layout == 0 else idxc
tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw)
tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw)
tl.store(p_y3 + _idx_y, tl.load(p_x3 + _idx_x, mask=_mask_hw), mask=_mask_hw)
tl.store(p_y4 + _idx_y, tl.load(p_x4 + _idx_x, mask=_mask_hw), mask=_mask_hw)
else:
for idxc in range(_for_C):
_idx_x = idxc * DH * DW if x_layout == 0 else idxc
_idx_y = idxc * DH * DW if y_layout == 0 else idxc
                tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y, mask=_mask_hw), mask=_mask_hw)
                tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y, mask=_mask_hw), mask=_mask_hw)
                tl.store(p_x3 + _idx_x, tl.load(p_y3 + _idx_y, mask=_mask_hw), mask=_mask_hw)
                tl.store(p_x4 + _idx_x, tl.load(p_y4 + _idx_y, mask=_mask_hw), mask=_mask_hw)
class CrossScanTritonF(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
if one_by_one:
if in_channel_first:
B, _, C, H, W = x.shape
else:
B, H, W, _, C = x.shape
else:
if in_channel_first:
B, C, H, W = x.shape
else:
B, H, W, C = x.shape
B, C, H, W = int(B), int(C), int(H), int(W)
BC, BH, BW = 1, 32, 32
NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)
ctx.in_channel_first = in_channel_first
ctx.out_channel_first = out_channel_first
ctx.one_by_one = one_by_one
ctx.scans = scans
ctx.shape = (B, C, H, W)
ctx.triton_shape = (BC, BH, BW, NC, NH, NW)
y = x.new_empty((B, 4, C, H * W)) if out_channel_first else x.new_empty((B, H * W, 4, C))
triton_cross_scan_flex[(NH * NW, NC, B)](
x.contiguous(), y,
(0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,
BC, BH, BW, C, H, W, NH, NW
)
return y
@staticmethod
def backward(ctx, y: torch.Tensor):
in_channel_first = ctx.in_channel_first
out_channel_first = ctx.out_channel_first
one_by_one = ctx.one_by_one
scans = ctx.scans
B, C, H, W = ctx.shape
BC, BH, BW, NC, NH, NW = ctx.triton_shape
if one_by_one:
x = y.new_empty((B, 4, C, H, W)) if in_channel_first else y.new_empty((B, H, W, 4, C))
else:
x = y.new_empty((B, C, H, W)) if in_channel_first else y.new_empty((B, H, W, C))
triton_cross_scan_flex[(NH * NW, NC, B)](
x, y.contiguous(),
(0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,
BC, BH, BW, C, H, W, NH, NW
)
return x, None, None, None, None
class CrossMergeTritonF(torch.autograd.Function):
@staticmethod
def forward(ctx, y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
if out_channel_first:
B, _, C, H, W = y.shape
else:
B, H, W, _, C = y.shape
B, C, H, W = int(B), int(C), int(H), int(W)
BC, BH, BW = 1, 32, 32
NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)
ctx.in_channel_first = in_channel_first
ctx.out_channel_first = out_channel_first
ctx.one_by_one = one_by_one
ctx.scans = scans
ctx.shape = (B, C, H, W)
ctx.triton_shape = (BC, BH, BW, NC, NH, NW)
if one_by_one:
x = y.new_empty((B, 4, C, H * W)) if in_channel_first else y.new_empty((B, H * W, 4, C))
else:
x = y.new_empty((B, C, H * W)) if in_channel_first else y.new_empty((B, H * W, C))
triton_cross_scan_flex[(NH * NW, NC, B)](
x, y.contiguous(),
(0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,
BC, BH, BW, C, H, W, NH, NW
)
return x
@staticmethod
def backward(ctx, x: torch.Tensor):
in_channel_first = ctx.in_channel_first
out_channel_first = ctx.out_channel_first
one_by_one = ctx.one_by_one
scans = ctx.scans
B, C, H, W = ctx.shape
BC, BH, BW, NC, NH, NW = ctx.triton_shape
y = x.new_empty((B, 4, C, H, W)) if out_channel_first else x.new_empty((B, H, W, 4, C))
triton_cross_scan_flex[(NH * NW, NC, B)](
x.contiguous(), y,
(0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,
BC, BH, BW, C, H, W, NH, NW
)
        return y, None, None, None, None
# @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def cross_scan_fn(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):
# x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
# y: (B, 4, C, L) | (B, L, 4, C)
    # scans: 0 cross scan; 1 unidirectional; 2 bidirectional; 3 rotated (rot90)
CSF = CrossScanTritonF if WITH_TRITON and x.is_cuda and (not force_torch) else CrossScanF
if x.is_cuda:
with torch.cuda.device(x.device):
return CSF.apply(x, in_channel_first, out_channel_first, one_by_one, scans)
else:
return CrossScanF.apply(x, in_channel_first, out_channel_first, one_by_one, scans)
# @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def cross_merge_fn(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):
# y: (B, 4, C, L) | (B, L, 4, C)
# x: (B, C, H * W) | (B, H * W, C) | (B, 4, C, H * W) | (B, H * W, 4, C)
    # scans: 0 cross scan; 1 unidirectional; 2 bidirectional; 3 rotated (rot90)
CMF = CrossMergeTritonF if WITH_TRITON and y.is_cuda and (not force_torch) else CrossMergeF
if y.is_cuda:
with torch.cuda.device(y.device):
return CMF.apply(y, in_channel_first, out_channel_first, one_by_one, scans)
else:
return CrossMergeF.apply(y, in_channel_first, out_channel_first, one_by_one, scans)
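# A hedged usage sketch of the dispatchers: on CUDA tensors they prefer the
# Triton kernels when available, otherwise (or with force_torch=True, and always
# on CPU) they take the autograd-enabled torch path:
#   >>> x = torch.randn(2, 96, 14, 14, requires_grad=True)
#   >>> ys = cross_scan_fn(x, force_torch=True)
#   >>> ys.shape
#   torch.Size([2, 4, 96, 196])
#   >>> cross_merge_fn(ys.view(2, 4, 96, 14, 14), force_torch=True).shape
#   torch.Size([2, 96, 196])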
##########################################################
# csms6s.py
##########################################################
WITH_SELECTIVESCAN_MAMBA = True
try:
import selective_scan_cuda
except ImportError:
WITH_SELECTIVESCAN_MAMBA = False
def selective_scan_torch(
u: torch.Tensor, # (B, K * C, L)
delta: torch.Tensor, # (B, K * C, L)
A: torch.Tensor, # (K * C, N)
B: torch.Tensor, # (B, K, N, L)
C: torch.Tensor, # (B, K, N, L)
D: torch.Tensor = None, # (K * C)
delta_bias: torch.Tensor = None, # (K * C)
delta_softplus=True,
oflex=True,
*args,
**kwargs
):
dtype_in = u.dtype
Batch, K, N, L = B.shape
KCdim = u.shape[1]
Cdim = int(KCdim / K)
assert u.shape == (Batch, KCdim, L)
assert delta.shape == (Batch, KCdim, L)
assert A.shape == (KCdim, N)
assert C.shape == B.shape
if delta_bias is not None:
delta = delta + delta_bias[..., None]
if delta_softplus:
delta = torch.nn.functional.softplus(delta)
u, delta, A, B, C = u.float(), delta.float(), A.float(), B.float(), C.float()
B = B.view(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L)
C = C.view(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L)
deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
    x = A.new_zeros((Batch, KCdim, N))
    ys = []
    for i in range(L):
        x = deltaA[:, :, i, :] * x + deltaB_u[:, :, i, :]
        y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
        ys.append(y)
    y = torch.stack(ys, dim=2)  # (B, KCdim, L)
out = y if D is None else y + u * D.unsqueeze(-1)
return out if oflex else out.to(dtype=dtype_in)
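# The reference scan above implements, per channel d and (diagonal) state n, the
# discretized state-space recurrence that the fused CUDA kernel computes:
#   h_t = exp(delta_t * A) * h_{t-1} + delta_t * B_t * u_t,   y_t = <C_t, h_t> + D * u_t
# where delta is first shifted by delta_bias and passed through softplus. A shape
# sketch (illustrative values only):
#   >>> Bb, K, N, L, Cd = 2, 4, 8, 16, 6
#   >>> u, delta = torch.randn(Bb, K * Cd, L), torch.randn(Bb, K * Cd, L)
#   >>> A = -torch.rand(K * Cd, N)
#   >>> Bs, Cs = torch.randn(Bb, K, N, L), torch.randn(Bb, K, N, L)
#   >>> selective_scan_torch(u, delta, A, Bs, Cs).shape
#   torch.Size([2, 24, 16])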
class SelectiveScanCuda(torch.autograd.Function):
@staticmethod
@torch.cuda.amp.custom_fwd
def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, oflex=True, backend=None):
ctx.delta_softplus = delta_softplus
# backend = "oflex" if WITH_SELECTIVESCAN_OFLEX and (backend is None) else backend
# backend = "core" if WITH_SELECTIVESCAN_CORE and (backend is None) else backend
backend = "mamba" if WITH_SELECTIVESCAN_MAMBA and (backend is None) else backend
ctx.backend = backend
if backend == "oflex":
out, x, *rest = selective_scan_cuda_oflex.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1, oflex)
elif backend == "mamba":
out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus)
ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
return out
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, dout, *args):
u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
backend = ctx.backend
if dout.stride(-1) != 1:
dout = dout.contiguous()
if backend == "oflex":
du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_oflex.bwd(
u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
)
elif backend == "mamba":
du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus,
False
)
return du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None
def selective_scan_fn(
u: torch.Tensor, # (B, K * C, L)
delta: torch.Tensor, # (B, K * C, L)
A: torch.Tensor, # (K * C, N)
B: torch.Tensor, # (B, K, N, L)
C: torch.Tensor, # (B, K, N, L)
D: torch.Tensor = None, # (K * C)
delta_bias: torch.Tensor = None, # (K * C)
delta_softplus=True,
oflex=True,
backend=None,
):
fn = selective_scan_torch if backend == "torch" or (not WITH_SELECTIVESCAN_MAMBA) else SelectiveScanCuda.apply
return fn(u, delta, A, B, C, D, delta_bias, delta_softplus, oflex, backend)
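# selective_scan_fn mirrors the cross-scan dispatchers: the fused CUDA kernel is
# used when selective_scan_cuda is importable, and backend="torch" forces the
# reference loop (which is what runs on CPU in any case):
#   >>> out = selective_scan_fn(u, delta, A, Bs, Cs, backend="torch")  # names from the sketch above
#   >>> out.shape
#   torch.Size([2, 24, 16])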
##########################################################
############## HuggingFace modeling file #################
##########################################################
class VMambaLinear2d(nn.Linear):
def __init__(self, *args, groups=1, **kwargs):
nn.Linear.__init__(self, *args, **kwargs)
self.groups = groups
def forward(self, x: torch.Tensor):
if len(x.shape) == 4:
return F.conv2d(x, self.weight[:, :, None, None], self.bias, groups=self.groups)
elif len(x.shape) == 3:
return F.conv1d(x, self.weight[:, :, None], self.bias, groups=self.groups)
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
self_state_dict = self.state_dict()
load_state_dict_keys = list(state_dict.keys())
if prefix + "weight" in load_state_dict_keys:
state_dict[prefix + "weight"] = state_dict[prefix + "weight"].view_as(self_state_dict["weight"])
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
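# VMambaLinear2d applies the weights of an nn.Linear as a pointwise (1x1)
# convolution, so the same parameters serve channels-first 4D maps and 3D
# sequences alike; a shape sketch:
#   >>> lin = VMambaLinear2d(96, 192)
#   >>> lin(torch.randn(2, 96, 14, 14)).shape
#   torch.Size([2, 192, 14, 14])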
class VMambaLayerNorm2d(nn.LayerNorm):
def __init__(self, *args, **kwargs):
nn.LayerNorm.__init__(self, *args, **kwargs)
def forward(self, x: torch.Tensor):
x = x.permute(0, 2, 3, 1)
x = nn.LayerNorm.forward(self, x)
x = x.permute(0, 3, 1, 2)
return x
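# VMambaLayerNorm2d normalizes over the channel dimension of NCHW inputs by
# permuting to NHWC, applying the standard LayerNorm, and permuting back:
#   >>> norm = VMambaLayerNorm2d(96)
#   >>> norm(torch.randn(2, 96, 14, 14)).shape
#   torch.Size([2, 96, 14, 14])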
class VMambaPatchEmbeddings(nn.Module):
"""
This class turns `input_values` into the initial `hidden_states` (patch embeddings) of shape `(batch_size,
seq_length, hidden_size)` to be consumed by a State-space model.
"""
    def __init__(self, num_channels=3, patch_size=4, embed_dim=96):
super().__init__()
stride = patch_size // 2
kernel_size = stride + 1
padding = 1
self.projection = nn.Sequential(
nn.Conv2d(num_channels, embed_dim // 2, kernel_size=kernel_size, stride=stride, padding=padding),
VMambaLayerNorm2d(embed_dim // 2),
nn.GELU(),
nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding),
VMambaLayerNorm2d(embed_dim),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.projection(x)
return x
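# The stem stacks two stride-2 convolutions (kernel 3, padding 1) with a norm and
# GELU in between, for an overall stride of patch_size=4; e.g. at ImageNet size:
#   >>> embed = VMambaPatchEmbeddings(num_channels=3, patch_size=4, embed_dim=96)
#   >>> embed(torch.randn(1, 3, 224, 224)).shape
#   torch.Size([1, 96, 56, 56])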
class VMambaDownsample(nn.Module):
"""
This class downsamples the input tensor using a convolutional layer followed by a layer normalization.
"""
def __init__(self, dim, out_dim, use_norm=True):
super().__init__()
self.down = nn.Conv2d(dim, out_dim, kernel_size=3, stride=2, padding=1)
self.norm = VMambaLayerNorm2d(out_dim) if use_norm else nn.Identity()
def forward(self, x):
x = self.down(x)
x = self.norm(x)
return x
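# Each stage transition halves the spatial resolution with a strided 3x3 conv
# while (typically) doubling the channel width:
#   >>> down = VMambaDownsample(96, 192)
#   >>> down(torch.randn(1, 96, 56, 56)).shape
#   torch.Size([1, 192, 28, 28])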
class VMambaMlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = VMambaLinear2d(in_features, hidden_features)
self.act = act_layer()
self.fc2 = VMambaLinear2d(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class SS2D(nn.Module):
def __init__(
self,
# basic dims ===========
d_model=96,
d_state=16,
ssm_ratio=2.0,
dt_rank="auto",
act_layer=nn.SiLU,
# dwconv ===============
d_conv=3,
conv_bias=True,
# ======================
dropout=0.0,
bias=False,
# dt init ==============
dt_min=0.001,
dt_max=0.1,
dt_init="random",
dt_scale=1.0,
dt_init_floor=1e-4,
# forward_type="v05_noz" is always used
# ======================
**kwargs,
):
super().__init__()
self.k_group = 4
self.d_model = int(d_model)
self.d_state = int(d_state)
self.d_inner = int(ssm_ratio * d_model)
self.dt_rank = int(math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank)
self.forward_core = partial(self.forward_corev2, force_fp32=False, no_einsum=True)
self.with_dconv = d_conv > 1
# In projection
self.in_proj = VMambaLinear2d(self.d_model, self.d_inner, bias=bias)
self.act: nn.Module = act_layer()
# Convolution
if self.with_dconv:
self.conv2d = nn.Conv2d(
in_channels=self.d_inner,
out_channels=self.d_inner,
groups=self.d_inner,
bias=conv_bias,
kernel_size=d_conv,
padding=(d_conv - 1) // 2,
)
# x_proj and dt_proj
self.x_proj = VMambaLinear2d(self.d_inner, self.k_group * (self.dt_rank + self.d_state * 2), groups=self.k_group, bias=False)
self.dt_projs = VMambaLinear2d(self.dt_rank, self.k_group * self.d_inner, groups=self.k_group, bias=False)
# out projection
self.out_proj = VMambaLinear2d(self.d_inner, self.d_model, bias=bias)
self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
# Initialization
self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = self.init_dt_A_D(
self.d_state, self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=self.k_group,
)
self.dt_projs.weight.data = self.dt_projs_weight.data.view(self.dt_projs.weight.shape)
# self.dt_projs.bias.data = self.dt_projs_bias.data.view(self.dt_projs.bias.shape)
del self.dt_projs_weight
# del self.dt_projs_bias
# Define out_norm directly with "LN2D"
self.out_norm = VMambaLayerNorm2d(self.d_inner)
@staticmethod
def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4):
dt_proj = nn.Linear(dt_rank, d_inner, bias=True)
dt_init_std = dt_rank**-0.5 * dt_scale
if dt_init == "constant":
nn.init.constant_(dt_proj.weight, dt_init_std)
elif dt_init == "random":
nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
else:
raise NotImplementedError
dt = torch.exp(
torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
inv_dt = dt + torch.log(-torch.expm1(-dt))
with torch.no_grad():
dt_proj.bias.copy_(inv_dt)
return dt_proj
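    # inv_dt above is the inverse of softplus (softplus(inv_dt) == dt), so once
    # delta_bias is added and softplus is applied inside the scan, the effective
    # initial timestep lands back in [dt_min, dt_max]. A quick check (a sketch):
    #   >>> dt = torch.rand(5).clamp(min=1e-4)
    #   >>> inv = dt + torch.log(-torch.expm1(-dt))
    #   >>> torch.allclose(F.softplus(inv), dt, atol=1e-6)
    #   True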
@staticmethod
def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True):
A = torch.arange(1, d_state + 1, dtype=torch.float32, device=device).view(1, -1).repeat(d_inner, 1).contiguous()
A_log = torch.log(A)
if copies > 0:
A_log = A_log[None].repeat(copies, 1, 1).contiguous()
if merge:
A_log = A_log.flatten(0, 1)
A_log = nn.Parameter(A_log)
A_log._no_weight_decay = True
return A_log
@staticmethod
def D_init(d_inner, copies=-1, device=None, merge=True):
D = torch.ones(d_inner, device=device)
if copies > 0:
D = D[None].repeat(copies, 1).contiguous()
if merge:
D = D.flatten(0, 1)
D = nn.Parameter(D)
D._no_weight_decay = True
return D
@classmethod
def init_dt_A_D(cls, d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4):
dt_projs = [
cls.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor)
for _ in range(k_group)
]
dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in dt_projs], dim=0))
dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in dt_projs], dim=0))
del dt_projs
A_logs = cls.A_log_init(d_state, d_inner, copies=k_group, merge=True)
Ds = cls.D_init(d_inner, copies=k_group, merge=True)
return A_logs, Ds, dt_projs_weight, dt_projs_bias
def forward_corev2(
self,
x: torch.Tensor,
force_fp32=False,
no_einsum=True,
):
B, D, H, W = x.shape
N = self.d_state
L = H * W
xs = cross_scan_fn(x, in_channel_first=True, out_channel_first=True)
x_dbl = self.x_proj(xs.view(B, -1, L))
dts, Bs, Cs = torch.split(x_dbl.view(B, self.k_group, -1, L), [self.dt_rank, N, N], dim=2)
dts = dts.contiguous().view(B, -1, L)
dts = self.dt_projs(dts)
xs = xs.view(B, -1, L)
dts = dts.contiguous().view(B, -1, L)
As = -self.A_logs.to(torch.float32).exp()
Ds = self.Ds.to(torch.float32)
Bs = Bs.contiguous().view(B, self.k_group, N, L)
Cs = Cs.contiguous().view(B, self.k_group, N, L)
delta_bias = self.dt_projs_bias.view(-1).to(torch.float32)
ys = selective_scan_fn(
xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus=True, backend="mamba"
).view(B, self.k_group, -1, H, W)
y = cross_merge_fn(ys, in_channel_first=True, out_channel_first=True)
y = y.view(B, -1, H, W)
y = self.out_norm(y)
return y.to(x.dtype)
    def forward(self, x: torch.Tensor):
        x = self.in_proj(x)
        if self.with_dconv:
            x = self.conv2d(x)
        x = self.act(x)
        y = self.forward_core(x)
        out = self.dropout(self.out_proj(y))
        return out
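# SS2D dataflow for d_model=96, ssm_ratio=2.0 (a sketch): in_proj expands
# (B, 96, H, W) -> (B, 192, H, W); depthwise 3x3 conv + SiLU; cross_scan_fn
# unrolls to (B, 4, 192, L); x_proj/dt_projs produce per-route dt, B, C;
# selective_scan_fn runs the recurrence over L; cross_merge_fn sums the routes
# back to (B, 192, H, W); LN2D and out_proj return (B, 96, H, W). On machines
# without the CUDA kernel the reference torch scan is used:
#   >>> ss2d = SS2D(d_model=96, d_state=1, ssm_ratio=2.0)
#   >>> ss2d(torch.randn(1, 96, 14, 14)).shape
#   torch.Size([1, 96, 14, 14])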
class VSSBlock(nn.Module):
def __init__(
self,
hidden_dim: int = 0,
drop_path: float = 0,
ssm_d_state: int = 1,
ssm_ratio=1.0,
ssm_dt_rank: Any = "auto",
ssm_act_layer=nn.SiLU,
ssm_conv: int = 3,
ssm_conv_bias=False,
ssm_drop_rate: float = 0,
mlp_ratio=4.0,
mlp_act_layer=nn.GELU,
mlp_drop_rate: float = 0.0,
use_checkpoint: bool = False,
post_norm: bool = False,
**kwargs,
):
super().__init__()
self.ssm_branch = ssm_ratio > 0
self.mlp_branch = mlp_ratio > 0
self.use_checkpoint = use_checkpoint
self.post_norm = post_norm
if self.ssm_branch:
self.norm = VMambaLayerNorm2d(hidden_dim)
self.op = SS2D(
d_model=hidden_dim,
d_state=ssm_d_state,
ssm_ratio=ssm_ratio,
dt_rank=ssm_dt_rank,
act_layer=ssm_act_layer,
d_conv=ssm_conv,
conv_bias=ssm_conv_bias,
dropout=ssm_drop_rate,
)
self.drop_path = DropPath(drop_path)
if self.mlp_branch:
self.norm2 = VMambaLayerNorm2d(hidden_dim)
mlp_hidden_dim = int(hidden_dim * mlp_ratio)
self.mlp = VMambaMlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=mlp_act_layer, drop=mlp_drop_rate)
def _forward(self, input: torch.Tensor):
x = input
if self.ssm_branch:
if self.post_norm:
x = x + self.drop_path(self.norm(self.op(x)))
else:
x = x + self.drop_path(self.op(self.norm(x)))
if self.mlp_branch:
if self.post_norm:
x = x + self.drop_path(self.norm2(self.mlp(x)))
else:
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
def forward(self, input: torch.Tensor):
if self.use_checkpoint:
return checkpoint.checkpoint(self._forward, input)
else:
return self._forward(input)
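# Each VSSBlock is a pre-norm residual block (post_norm=False by default):
#   x = x + DropPath(SS2D(LN2D(x))), then x = x + DropPath(MLP(LN2D(x))).
# Same CPU caveat as the SS2D sketch above:
#   >>> blk = VSSBlock(hidden_dim=96, ssm_d_state=1, ssm_ratio=2.0, mlp_ratio=4.0)
#   >>> blk(torch.randn(1, 96, 14, 14)).shape
#   torch.Size([1, 96, 14, 14])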
class VMambaLayer(nn.Module):
def __init__(
self,
input_dim,
depth,
drop_path=0.0,
norm_layer=VMambaLayerNorm2d,
downsample=nn.Identity(),
use_checkpoint=False,
**kwargs,
):
super().__init__()
self.input_dim = input_dim
self.use_checkpoint = use_checkpoint
self.blocks = nn.ModuleList()
for i in range(depth):
self.blocks.append(
VSSBlock(hidden_dim=input_dim,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,use_checkpoint=use_checkpoint,**kwargs,
)
)
self.downsample = downsample
def forward(self, x):
for block in self.blocks:
x = block(x)
x = self.downsample(x)
return x
class VMambaPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
config_class = VMambaConfig
base_model_prefix = "vmamba"
supports_gradient_checkpointing = False
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
if isinstance(module, nn.Linear):
trunc_normal_(module.weight, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
nn.init.constant_(module.bias, 0)
elif isinstance(module, nn.LayerNorm):
nn.init.constant_(module.bias, 0)
nn.init.constant_(module.weight, 1.0)
class VMambaModel(VMambaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.config = config
        self.num_layers = len(config.depths)
        dims = config.dims
        if isinstance(dims, int):
            dims = [int(dims * 2**i_layer) for i_layer in range(self.num_layers)]
        self.dims = dims
        self.patch_embeddings = VMambaPatchEmbeddings(patch_size=config.patch_size,
                                                      embed_dim=dims[0])
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
self.num_features = dims[-1]
self.layers = nn.ModuleList()
for i in range(self.num_layers):
layer = VMambaLayer(
input_dim=self.dims[i],
depth=config.depths[i],
drop_path=dpr[sum(config.depths[:i]):sum(config.depths[:i+1])],
                downsample=VMambaDownsample(self.dims[i], self.dims[i+1]) if i < self.num_layers - 1 else nn.Identity(),
use_checkpoint=config.use_checkpoint,
)
self.layers.append(layer)
self.norm = VMambaLayerNorm2d(self.num_features)
self.avgpool = nn.AdaptiveAvgPool2d(1)
def get_input_embeddings(self) -> VMambaPatchEmbeddings:
return self.patch_embeddings
def forward(self, input_values: torch.Tensor):
x = self.patch_embeddings(input_values)
for layer in self.layers:
x = layer(x)
x = self.norm(x)
x = self.avgpool(x).flatten(1)
return x
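# Feature shapes through the backbone (a sketch for a hypothetical config with
# dims=[96, 192, 384, 768] and a 224x224 input): patch embeddings give
# (B, 96, 56, 56); the stage downsamples yield (B, 192, 28, 28), (B, 384, 14, 14),
# (B, 768, 7, 7); LN2D plus global average pooling produce the (B, 768) feature
# vector returned above.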
class VMambaForImageClassification(VMambaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_classes = config.num_classes
self.vmamba = VMambaModel(config)
self.head = nn.Linear(self.vmamba.num_features, self.num_classes) if self.num_classes > 0 else nn.Identity()
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
return_dict: Optional[bool] = None,
):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.vmamba(
pixel_values,
)
logits = self.head(outputs)
loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.loss_type == "ce":
loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
            elif self.config.loss_type == "bce":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits, outputs)
            return ((loss,) + output) if loss is not None else output
return ImageClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs,
)
__all__ = [
"VMambaModel",
"VMambaPreTrainedModel",
"VMambaForImageClassification",
]
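# A minimal usage sketch (hedged): loading the checkpoint through transformers
# with remote code. The repo id and the 1000-class head match the ImageNet
# checkpoint this file ships with; both are assumptions, not guarantees:
#   >>> from transformers import AutoModelForImageClassification
#   >>> model = AutoModelForImageClassification.from_pretrained(
#   ...     "saurabhati/VMamba_ImageNet_83.6", trust_remote_code=True)
#   >>> logits = model(torch.randn(1, 3, 224, 224)).logits
#   >>> logits.shape
#   torch.Size([1, 1000])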