import math
import warnings
from typing import Union, Optional, Callable, Tuple, List, Sequence

import torch
from einops.layers.torch import Rearrange
from torch import Tensor, nn, Size
from torch.nn import Conv3d, ModuleList
from torch.nn import functional as F

Shape = Union[Size, List[int], Tuple[int, ...]]
ModuleFactory = Union[Callable[[], nn.Module], Callable[[int], nn.Module]]


class PatchEmbedding3d(nn.Module):

    def __init__(self, input_size: Shape, patch_size: Union[int, Shape], embedding: int,
                 strides: Optional[Union[int, Shape]] = None,
                 build_normalization: Optional[ModuleFactory] = None
                 ):
        super().__init__()
        # channel, time, height, width
        c, t, h, w = input_size
        # patch_time, patch_height, patch_width
        pt, ph, pw = (patch_size, patch_size, patch_size) if type(patch_size) is int else patch_size

        # configure the strides for conv3d
        if strides is None:
            # not specified means no overlap and no gap between patches
            strides = (pt, ph, pw)
        elif type(strides) is int:
            # transform the side length of strides to 3D
            strides = (strides, strides, strides)

        self.projection = Conv3d(c, embedding, kernel_size=(pt, ph, pw), stride=strides)
        self.has_norm = build_normalization is not None
        if self.has_norm:
            self.normalization = build_normalization()
        self.rearrange = Rearrange("b d nt nh nw -> b (nt nh nw) d")

    def forward(self, x: Tensor) -> Tensor:
        x = self.projection(x)
        x = self.rearrange(x)
        if self.has_norm:
            x = self.normalization(x)
        return x


class Linear(nn.Module):

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 build_activation: Optional[ModuleFactory] = None,
                 build_normalization: Optional[ModuleFactory] = None,
                 normalization_after_activation: bool = False,
                 dropout_rate: float = 0.
                 ):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias)

        self.has_act = build_activation is not None
        if self.has_act:
            self.activation = build_activation()
        else:
            self.activation = None

        self.has_norm = build_normalization is not None
        if self.has_norm:
            self.normalization = build_normalization()
            self.norm_after_act = normalization_after_activation
        else:
            self.normalization = None

        self.has_dropout = dropout_rate > 0
        if self.has_dropout:
            self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x: Tensor) -> Tensor:
        x = self.linear(x)
        if self.has_act and self.has_norm:
            if self.norm_after_act:
                x = self.activation(x)
                x = self.normalization(x)
            else:
                x = self.normalization(x)
                x = self.activation(x)
        elif self.has_act and not self.has_norm:
            x = self.activation(x)
        elif not self.has_act and self.has_norm:
            x = self.normalization(x)

        if self.has_dropout:
            x = self.dropout(x)
        return x
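
# Illustrative usage sketch (not part of the original module): tokenise a
# 16-frame RGB clip with 2x16x16 tubelet patches. All shapes and sizes below
# are assumptions chosen for the example, not values fixed by this file.
def _example_patch_embedding() -> Tensor:
    patch_embed = PatchEmbedding3d(
        input_size=(3, 16, 224, 224),      # (channels, frames, height, width), assumed
        patch_size=(2, 16, 16),            # tubelet of 2 frames x 16 x 16 pixels, assumed
        embedding=768,
        build_normalization=lambda: nn.LayerNorm(768),
    )
    clip = torch.rand(1, 3, 16, 224, 224)
    tokens = patch_embed(clip)             # -> (1, 8 * 14 * 14, 768) = (1, 1568, 768)
    return tokens
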
class MLP(nn.Module):

    def __init__(self, neurons: Sequence[int],
                 build_activation: Optional[ModuleFactory] = None, dropout_rate: float = 0.
                 ):
        super().__init__()
        n_features = neurons[1:]
        # hidden layers carry the activation and dropout, the output layer is a plain linear
        self.layers: ModuleList[Linear] = ModuleList(
            [Linear(neurons[i], neurons[i + 1], True, build_activation, None,
                    False, dropout_rate
                    ) for i in range(len(n_features) - 1)
             ] + [
                Linear(neurons[-2], neurons[-1], True)
            ])

    def forward(self, x: Tensor) -> Tensor:
        for layer in self.layers:
            x = layer(x)
        return x


class Attention(nn.Module):

    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., attn_head_dim=None
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            # the key projection stays bias-free, so its bias slot is filled with zeros
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 attn_head_dim=None
                 ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(
            neurons=[dim, mlp_hidden_dim, dim],
            build_activation=act_layer,
            dropout_rate=drop
        )

        # LayerScale is only enabled when init_values is given; comparing the
        # default None with 0 directly would raise a TypeError.
        if init_values is not None and init_values > 0:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x):
        if self.gamma_1 is None:
            x = x + self.attn(self.norm1(x))
            x = x + self.mlp(self.norm2(x))
        else:
            x = x + (self.gamma_1 * self.attn(self.norm1(x)))
            x = x + (self.gamma_2 * self.mlp(self.norm2(x)))
        return x
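
# Illustrative usage sketch (not part of the original module): run one
# transformer block over the 1568 tokens produced above. dim, num_heads and
# init_values are assumed example values, not defaults required by this file.
def _example_block() -> Tensor:
    block = Block(dim=768, num_heads=12, qkv_bias=True, init_values=1e-5)
    tokens = torch.rand(1, 1568, 768)
    out = block(tokens)                    # residual attention + MLP, same shape as the input
    return out
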
def no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor
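
# Illustrative usage sketch (not part of the original module): initialise a
# projection weight with a truncated normal, clipped to [-2, 2] following the
# common trunc_normal_(std=0.02) convention; std and bounds are assumptions.
def _example_trunc_normal_init() -> Tensor:
    weight = torch.empty(768, 768)
    no_grad_trunc_normal_(weight, mean=0., std=0.02, a=-2., b=2.)
    return weight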