""" U-TAE Implementation Author: Vivien Sainte Fare Garnot (github/VSainteuf) License: MIT """ import torch import torch.nn as nn from src.backbones.convlstm import ConvLSTM, BConvLSTM from src.backbones.ltae import LTAE2d, LTAE2dtiny # function to normalize gradient magnitudes, # evoke via e.g. scale_gradients(out) at every forward pass def scale_gradients(params): def hook_norm(grad): # get norm of parameter p's gradients #grad_norm = p.grad.detach().data.norm(2) # get the gradient's L2 norm grad_norm = grad.detach().data.norm(2) # return normalized gradient return grad/(grad_norm+1e-9) # see https://pytorch.org/docs/stable/generated/torch.Tensor.register_hook.html params.register_hook(hook_norm) class UNet(nn.Module): def __init__( self, input_dim, encoder_widths=[64, 64, 64, 128], decoder_widths=[32, 32, 64, 128], out_conv=[13], out_nonlin_mean=False, out_nonlin_var='relu', str_conv_k=4, str_conv_s=2, str_conv_p=1, encoder_norm="group", norm_skip="batch", norm_up="batch", decoder_norm="batch", encoder=False, return_maps=False, pad_value=0, padding_mode="reflect", ): """ U-Net architecture for spatial pre-training of UTAE on mono-temporal data, excluding LTAE temporal encoder. Args: input_dim (int): Number of channels in the input images. encoder_widths (List[int]): List giving the number of channels of the successive encoder_widths of the convolutional encoder. This argument also defines the number of encoder_widths (i.e. the number of downsampling steps +1) in the architecture. The number of channels are given from top to bottom, i.e. from the highest to the lowest resolution. decoder_widths (List[int], optional): Same as encoder_widths but for the decoder. The order in which the number of channels should be given is also from top to bottom. If this argument is not specified the decoder will have the same configuration as the encoder. out_conv (List[int]): Number of channels of the successive convolutions for the str_conv_k (int): Kernel size of the strided up and down convolutions. str_conv_s (int): Stride of the strided up and down convolutions. str_conv_p (int): Padding of the strided up and down convolutions. agg_mode (str): Aggregation mode for the skip connections. Can either be: - att_group (default) : Attention weighted temporal average, using the same channel grouping strategy as in the LTAE. The attention masks are bilinearly resampled to the resolution of the skipped feature maps. - att_mean : Attention weighted temporal average, using the average attention scores across heads for each date. - mean : Temporal average excluding padded dates. encoder_norm (str): Type of normalisation layer to use in the encoding branch. Can either be: - group : GroupNorm (default) - batch : BatchNorm - instance : InstanceNorm - none: apply no normalization norm_skip (str): similar to encoder_norm, just controlling the normalization after convolving skipped maps norm_up (str): similar to encoder_norm, just controlling the normalization after transposed convolution decoder_norm (str): similar to encoder_norm n_head (int): Number of heads in LTAE. d_model (int): Parameter of LTAE d_k (int): Key-Query space dimension encoder (bool): If true, the feature maps instead of the class scores are returned (default False) return_maps (bool): If true, the feature maps instead of the class scores are returned (default False) pad_value (float): Value used by the dataloader for temporal padding. padding_mode (str): Spatial padding strategy for convolutional layers (passed to nn.Conv2d). positional_encoding (bool): If False, no positional encoding is used (default True). """ super(UNet, self).__init__() self.n_stages = len(encoder_widths) self.return_maps = return_maps self.encoder_widths = encoder_widths self.decoder_widths = decoder_widths self.enc_dim = ( decoder_widths[0] if decoder_widths is not None else encoder_widths[0] ) self.stack_dim = ( sum(decoder_widths) if decoder_widths is not None else sum(encoder_widths) ) self.pad_value = pad_value self.encoder = encoder if encoder: self.return_maps = True if decoder_widths is not None: assert len(encoder_widths) == len(decoder_widths) assert encoder_widths[-1] == decoder_widths[-1] else: decoder_widths = encoder_widths # ENCODER self.in_conv = ConvBlock( nkernels=[input_dim] + [encoder_widths[0]], k=1, s=1, p=0, pad_value=pad_value, norm=encoder_norm, padding_mode=padding_mode, ) self.down_blocks = nn.ModuleList( DownConvBlock( d_in=encoder_widths[i], d_out=encoder_widths[i + 1], k=str_conv_k, s=str_conv_s, p=str_conv_p, pad_value=pad_value, norm=encoder_norm, padding_mode=padding_mode, ) for i in range(self.n_stages - 1) ) # DECODER self.up_blocks = nn.ModuleList( UpConvBlock( d_in=decoder_widths[i], d_out=decoder_widths[i - 1], d_skip=encoder_widths[i - 1], k=str_conv_k, s=str_conv_s, p=str_conv_p, norm_skip=norm_skip, #'batch' norm_up=norm_up, # 'batch' norm=decoder_norm, #"batch", padding_mode=padding_mode, ) for i in range(self.n_stages - 1, 0, -1) ) # note: not including normalization layer and ReLU nonlinearity into the final ConvBlock, # if inserting >1 layers into out_conv then consider treating normalizations separately self.out_dims = out_conv[-1] self.out_conv = ConvBlock(nkernels=[decoder_widths[0]] + out_conv, k=1, s=1, p=0, padding_mode=padding_mode, norm='none', last_relu=False) if out_nonlin_mean: self.out_mean = nn.Sigmoid() # this is for predicting mean values in [0, 1] else: self.out_mean = nn.Identity() # just keep the mean estimates, without applying a nonlinearity if out_nonlin_var=='relu': self.out_var = nn.ReLU() # this is for predicting var values > 0 elif out_nonlin_var=='softplus': self.out_var = nn.Softplus(beta=1, threshold=20) # a smooth approximation to the ReLU function elif out_nonlin_var=='elu': self.out_var = lambda vars: nn.ELU()(vars) + 1 + 1e-8 else: # just keep the variance estimates, self.out_var = nn.Identity() # just keep the variance estimates, without applying a nonlinearity def forward(self, input, batch_positions=None, return_att=False): # SPATIAL ENCODER # collect feature maps in list 'feature_maps' out = self.in_conv.smart_forward(input) feature_maps = [out] for i in range(self.n_stages - 1): out = self.down_blocks[i].smart_forward(feature_maps[-1]) feature_maps.append(out) # SPATIAL DECODER if self.return_maps: maps = [out] out = out[:,0,...] # note: we index to reduce the temporal dummy dimension of size 1 for i in range(self.n_stages - 1): # skip-connect features between paired encoder/decoder blocks skip = feature_maps[-(i + 2)] # upconv the features, concatenating current 'out' and paired 'skip' out = self.up_blocks[i](out, skip[:,0,...]) # note: we index to reduce the temporal dummy dimension of size 1 if self.return_maps: maps.append(out) if self.encoder: return out, maps else: out = self.out_conv(out) # append a singelton temporal dimension such that outputs are [B x T=1 x C x H x W] out = out.unsqueeze(1) # optionally apply an output nonlinearity out_mean = self.out_mean(out[:,:,:13,...]) # mean predictions out_std = self.out_var(out[:,:,13:,...]) # var predictions > 0 out = torch.cat((out_mean, out_std), dim=2) # stack mean and var predictions if return_att: return out, None if self.return_maps: return out, maps else: return out class UTAE(nn.Module): def __init__( self, input_dim, encoder_widths=[64, 64, 64, 128], decoder_widths=[32, 32, 64, 128], out_conv=[13], out_nonlin_mean=False, out_nonlin_var='relu', str_conv_k=4, str_conv_s=2, str_conv_p=1, agg_mode="att_group", encoder_norm="group", norm_skip='batch', norm_up="batch", decoder_norm="batch", n_head=16, d_model=256, d_k=4, encoder=False, return_maps=False, pad_value=0, padding_mode="reflect", positional_encoding=True, scale_by=1 ): """ U-TAE architecture for spatio-temporal encoding of satellite image time series. Args: input_dim (int): Number of channels in the input images. encoder_widths (List[int]): List giving the number of channels of the successive encoder_widths of the convolutional encoder. This argument also defines the number of encoder_widths (i.e. the number of downsampling steps +1) in the architecture. The number of channels are given from top to bottom, i.e. from the highest to the lowest resolution. decoder_widths (List[int], optional): Same as encoder_widths but for the decoder. The order in which the number of channels should be given is also from top to bottom. If this argument is not specified the decoder will have the same configuration as the encoder. out_conv (List[int]): Number of channels of the successive convolutions for the str_conv_k (int): Kernel size of the strided up and down convolutions. str_conv_s (int): Stride of the strided up and down convolutions. str_conv_p (int): Padding of the strided up and down convolutions. agg_mode (str): Aggregation mode for the skip connections. Can either be: - att_group (default) : Attention weighted temporal average, using the same channel grouping strategy as in the LTAE. The attention masks are bilinearly resampled to the resolution of the skipped feature maps. - att_mean : Attention weighted temporal average, using the average attention scores across heads for each date. - mean : Temporal average excluding padded dates. encoder_norm (str): Type of normalisation layer to use in the encoding branch. Can either be: - group : GroupNorm (default) - batch : BatchNorm - instance : InstanceNorm - none: apply no normalization norm_skip (str): similar to encoder_norm, just controlling the normalization after convolving skipped maps norm_up (str): similar to encoder_norm, just controlling the normalization after transposed convolution decoder_norm (str): similar to encoder_norm n_head (int): Number of heads in LTAE. d_model (int): Parameter of LTAE d_k (int): Key-Query space dimension encoder (bool): If true, the feature maps instead of the class scores are returned (default False) return_maps (bool): If true, the feature maps instead of the class scores are returned (default False) pad_value (float): Value used by the dataloader for temporal padding. padding_mode (str): Spatial padding strategy for convolutional layers (passed to nn.Conv2d). positional_encoding (bool): If False, no positional encoding is used (default True). """ super(UTAE, self).__init__() self.n_stages = len(encoder_widths) self.return_maps = return_maps self.encoder_widths = encoder_widths self.decoder_widths = decoder_widths self.enc_dim = ( decoder_widths[0] if decoder_widths is not None else encoder_widths[0] ) self.stack_dim = ( sum(decoder_widths) if decoder_widths is not None else sum(encoder_widths) ) self.pad_value = pad_value self.encoder = encoder self.scale_by = scale_by if encoder: self.return_maps = True if decoder_widths is not None: assert len(encoder_widths) == len(decoder_widths) assert encoder_widths[-1] == decoder_widths[-1] else: decoder_widths = encoder_widths # ENCODER self.in_conv = ConvBlock( nkernels=[input_dim] + [encoder_widths[0]], k=1, s=1, p=0, pad_value=pad_value, norm=encoder_norm, padding_mode=padding_mode, ) self.down_blocks = nn.ModuleList( DownConvBlock( d_in=encoder_widths[i], d_out=encoder_widths[i + 1], k=str_conv_k, s=str_conv_s, p=str_conv_p, pad_value=pad_value, norm=encoder_norm, padding_mode=padding_mode, ) for i in range(self.n_stages - 1) ) # DECODER self.up_blocks = nn.ModuleList( UpConvBlock( d_in=decoder_widths[i], d_out=decoder_widths[i - 1], d_skip=encoder_widths[i - 1], k=str_conv_k, s=str_conv_s, p=str_conv_p, norm_skip=norm_skip, # 'batch' norm_up=norm_up, # 'batch' norm=decoder_norm, #"batch", padding_mode=padding_mode, ) for i in range(self.n_stages - 1, 0, -1) ) # LTAE self.temporal_encoder = LTAE2d( in_channels=encoder_widths[-1], d_model=d_model, n_head=n_head, mlp=[d_model, encoder_widths[-1]], return_att=True, d_k=d_k, positional_encoding=positional_encoding, ) self.temporal_aggregator = Temporal_Aggregator(mode=agg_mode) # note: not including normalization layer and ReLU nonlinearity into the final ConvBlock # if inserting >1 layers into out_conv then consider treating normalizations separately self.out_dims = out_conv[-1] self.out_conv = ConvBlock(nkernels=[decoder_widths[0]] + out_conv, k=1, s=1, p=0, padding_mode=padding_mode, norm='none', last_relu=False) if out_nonlin_mean: self.out_mean = lambda vars: self.scale_by * nn.Sigmoid()(vars) # this is for predicting mean values in [0, 1] else: self.out_mean = lambda vars: nn.Identity()(vars) # just keep the mean estimates, without applying a nonlinearity if out_nonlin_var=='relu': self.out_var = nn.ReLU() # this is for predicting var values > 0 elif out_nonlin_var=='softplus': self.out_var = nn.Softplus(beta=1, threshold=20) # a smooth approximation to the ReLU function elif out_nonlin_var=='elu': self.out_var = lambda vars: nn.ELU()(vars) + 1 + 1e-8 else: # just keep the variance estimates, self.out_var = nn.Identity() # just keep the variance estimates, without applying a nonlinearity def forward(self, input, batch_positions=None, return_att=False): pad_mask = ( (input == self.pad_value).all(dim=-1).all(dim=-1).all(dim=-1) ) # BxT pad mask # SPATIAL ENCODER # collect feature maps in list 'feature_maps' out = self.in_conv.smart_forward(input) feature_maps = [out] for i in range(self.n_stages - 1): out = self.down_blocks[i].smart_forward(feature_maps[-1]) feature_maps.append(out) # TEMPORAL ENCODER # feature_maps[-1].shape is torch.Size([B, T, 128, 32, 32]) # -> every attention pixel has an 8x8 receptive field # att.shape is torch.Size([h, B, T, 32, 32]) # out.shape is torch.Size([B, 128, 32, 32]), in self-attention class it's Size([B*32*32*h=32768, 1, 16] out, att = self.temporal_encoder( feature_maps[-1], batch_positions=batch_positions, pad_mask=pad_mask ) # SPATIAL DECODER if self.return_maps: maps = [out] for i in range(self.n_stages - 1): skip = self.temporal_aggregator( feature_maps[-(i + 2)], pad_mask=pad_mask, attn_mask=att ) out = self.up_blocks[i](out, skip) if self.return_maps: maps.append(out) if self.encoder: return out, maps else: out = self.out_conv(out) # append a singelton temporal dimension such that outputs are [B x T=1 x C x H x W] out = out.unsqueeze(1) # optionally apply an output nonlinearity out_mean = self.out_mean(out[:,:,:13,...]) # mean predictions out_std = self.out_var(out[:,:,13:,...]) # var predictions > 0 out = torch.cat((out_mean, out_std), dim=2) # stack mean and var predictions if return_att: return out, att if self.return_maps: return out, maps else: return out class TemporallySharedBlock(nn.Module): """ Helper module for convolutional encoding blocks that are shared across a sequence. This module adds the self.smart_forward() method the the block. smart_forward will combine the batch and temporal dimension of an input tensor if it is 5-D and apply the shared convolutions to all the (batch x temp) positions. """ def __init__(self, pad_value=None): super(TemporallySharedBlock, self).__init__() self.out_shape = None self.pad_value = pad_value def smart_forward(self, input): if len(input.shape) == 4: return self.forward(input) else: b, t, c, h, w = input.shape if self.pad_value is not None: dummy = torch.zeros(input.shape, device=input.device).float() self.out_shape = self.forward(dummy.view(b * t, c, h, w)).shape out = input.view(b * t, c, h, w) if self.pad_value is not None: pad_mask = (out == self.pad_value).all(dim=-1).all(dim=-1).all(dim=-1) if pad_mask.any(): temp = ( torch.ones( self.out_shape, device=input.device, requires_grad=False ) * self.pad_value ) temp[~pad_mask] = self.forward(out[~pad_mask]) out = temp else: out = self.forward(out) else: out = self.forward(out) _, c, h, w = out.shape out = out.view(b, t, c, h, w) return out class ConvLayer(nn.Module): def __init__( self, nkernels, norm="batch", k=3, s=1, p=1, n_groups=4, last_relu=True, padding_mode="reflect", ): super(ConvLayer, self).__init__() layers = [] if norm == "batch": nl = nn.BatchNorm2d elif norm == "instance": nl = nn.InstanceNorm2d elif norm == "group": nl = lambda num_feats: nn.GroupNorm( num_channels=num_feats, num_groups=n_groups, ) else: nl = None for i in range(len(nkernels) - 1): layers.append( nn.Conv2d( in_channels=nkernels[i], out_channels=nkernels[i + 1], kernel_size=k, padding=p, stride=s, padding_mode=padding_mode, ) ) if nl is not None: layers.append(nl(nkernels[i + 1])) if last_relu: # append a ReLU after the current CONV layer layers.append(nn.ReLU()) elif i < len(nkernels) - 2: # only append ReLU if not last layer layers.append(nn.ReLU()) self.conv = nn.Sequential(*layers) def forward(self, input): return self.conv(input) class ConvBlock(TemporallySharedBlock): def __init__( self, nkernels, pad_value=None, norm="batch", last_relu=True, k=3, s=1, p=1, padding_mode="reflect", ): super(ConvBlock, self).__init__(pad_value=pad_value) self.conv = ConvLayer( nkernels=nkernels, norm=norm, last_relu=last_relu, k=k, s=s, p=p, padding_mode=padding_mode, ) def forward(self, input): return self.conv(input) class DownConvBlock(TemporallySharedBlock): def __init__( self, d_in, d_out, k, s, p, pad_value=None, norm="batch", padding_mode="reflect", ): super(DownConvBlock, self).__init__(pad_value=pad_value) self.down = ConvLayer( nkernels=[d_in, d_in], norm=norm, k=k, s=s, p=p, padding_mode=padding_mode, ) self.conv1 = ConvLayer( nkernels=[d_in, d_out], norm=norm, padding_mode=padding_mode, ) self.conv2 = ConvLayer( nkernels=[d_out, d_out], norm=norm, padding_mode=padding_mode, last_relu=False # note: removing last ReLU in DownConvBlock because it adds onto residual connection ) def forward(self, input): out = self.down(input) out = self.conv1(out) out = out + self.conv2(out) return out def get_norm_layer(out_channels, num_feats, n_groups=4, layer_type='BatchNorm'): if layer_type == 'batch': return nn.BatchNorm2d(out_channels) elif layer_type == 'instance': return nn.InstanceNorm2d(out_channels) elif layer_type == 'group': return nn.GroupNorm(num_channels=num_feats, num_groups=n_groups) class UpConvBlock(nn.Module): def __init__(self, d_in, d_out, k, s, p, norm_skip="batch", norm_up ="batch", norm="batch", n_groups=4, d_skip=None, padding_mode="reflect"): super(UpConvBlock, self).__init__() d = d_out if d_skip is None else d_skip # apply another CONV and norm to the skipped paired map """" self.skip_conv = nn.Sequential( nn.Conv2d(in_channels=d, out_channels=d, kernel_size=1), nn.BatchNorm2d(d), nn.ReLU(), ) """ if norm_skip in ['group', 'batch', 'instance']: self.skip_conv = nn.Sequential( nn.Conv2d(in_channels=d, out_channels=d, kernel_size=1), get_norm_layer(d, d, n_groups, norm_skip), #nn.BatchNorm2d(d), nn.ReLU()) else: self.skip_conv = nn.Sequential( nn.Conv2d(in_channels=d, out_channels=d, kernel_size=1), nn.ReLU()) # transposed CONV layer to perform upsampling """ self.up = nn.Sequential( nn.ConvTranspose2d( in_channels=d_in, out_channels=d_out, kernel_size=k, stride=s, padding=p ), nn.BatchNorm2d(d_out), nn.ReLU(), ) """ if norm_up in ['group', 'batch', 'instance']: self.up = nn.Sequential( nn.ConvTranspose2d(in_channels=d_in, out_channels=d_out, kernel_size=k, stride=s, padding=p), get_norm_layer(d_out, d_out, n_groups, norm_up), #nn.BatchNorm2d(d_out), nn.ReLU()) else: self.up = nn.Sequential( nn.ConvTranspose2d(in_channels=d_in, out_channels=d_out, kernel_size=k, stride=s, padding=p), nn.ReLU()) self.conv1 = ConvLayer( nkernels=[d_out + d, d_out], norm=norm, padding_mode=padding_mode, # removing downsampling relu in UpConvBlock because of MobileNet2 ) self.conv2 = ConvLayer( nkernels=[d_out, d_out], norm=norm, padding_mode=padding_mode, last_relu=False # removing last relu in UpConvBlock because it adds onto residual connection ) def forward(self, input, skip): out = self.up(input) # transposed CONV on previous layer # apply another CONV and norm to the skipped input --> paired encoder map out = torch.cat([out, self.skip_conv(skip)], dim=1) # concat '' with paired encoder map out = self.conv1(out) # CONV again out = out + self.conv2(out) # conv with residual return out class Temporal_Aggregator(nn.Module): def __init__(self, mode="mean"): super(Temporal_Aggregator, self).__init__() self.mode = mode def forward(self, x, pad_mask=None, attn_mask=None): if pad_mask is not None and pad_mask.any(): if self.mode == "att_group": n_heads, b, t, h, w = attn_mask.shape attn = attn_mask.view(n_heads * b, t, h, w) if x.shape[-2] > w: attn = nn.Upsample( size=x.shape[-2:], mode="bilinear", align_corners=False )(attn) else: attn = nn.AvgPool2d(kernel_size=w // x.shape[-2])(attn) attn = attn.view(n_heads, b, t, *x.shape[-2:]) attn = attn * (~pad_mask).float()[None, :, :, None, None] out = torch.stack(x.chunk(n_heads, dim=2)) # hxBxTxC/hxHxW out = attn[:, :, :, None, :, :] * out out = out.sum(dim=2) # sum on temporal dim -> hxBxC/hxHxW out = torch.cat([group for group in out], dim=1) # -> BxCxHxW return out elif self.mode == "att_mean": attn = attn_mask.mean(dim=0) # average over heads -> BxTxHxW attn = nn.Upsample( size=x.shape[-2:], mode="bilinear", align_corners=False )(attn) attn = attn * (~pad_mask).float()[:, :, None, None] out = (x * attn[:, :, None, :, :]).sum(dim=1) return out elif self.mode == "mean": out = x * (~pad_mask).float()[:, :, None, None, None] out = out.sum(dim=1) / (~pad_mask).sum(dim=1)[:, None, None, None] return out else: if self.mode == "att_group": n_heads, b, t, h, w = attn_mask.shape attn = attn_mask.view(n_heads * b, t, h, w) if x.shape[-2] > w: attn = nn.Upsample( size=x.shape[-2:], mode="bilinear", align_corners=False )(attn) else: attn = nn.AvgPool2d(kernel_size=w // x.shape[-2])(attn) attn = attn.view(n_heads, b, t, *x.shape[-2:]) out = torch.stack(x.chunk(n_heads, dim=2)) # hxBxTxC/hxHxW out = attn[:, :, :, None, :, :] * out out = out.sum(dim=2) # sum on temporal dim -> hxBxC/hxHxW out = torch.cat([group for group in out], dim=1) # -> BxCxHxW return out elif self.mode == "att_mean": attn = attn_mask.mean(dim=0) # average over heads -> BxTxHxW attn = nn.Upsample( size=x.shape[-2:], mode="bilinear", align_corners=False )(attn) out = (x * attn[:, :, None, :, :]).sum(dim=1) return out elif self.mode == "mean": return x.mean(dim=1) class RecUNet(nn.Module): """Recurrent U-Net architecture. Similar to the U-TAE architecture but the L-TAE is replaced by a recurrent network and temporal averages are computed for the skip connections.""" def __init__( self, input_dim, encoder_widths=[64, 64, 64, 128], decoder_widths=[32, 32, 64, 128], out_conv=[13], str_conv_k=4, str_conv_s=2, str_conv_p=1, temporal="lstm", input_size=128, encoder_norm="group", hidden_dim=128, encoder=False, padding_mode="reflect", pad_value=0, ): super(RecUNet, self).__init__() self.n_stages = len(encoder_widths) self.temporal = temporal self.encoder_widths = encoder_widths self.decoder_widths = decoder_widths self.enc_dim = ( decoder_widths[0] if decoder_widths is not None else encoder_widths[0] ) self.stack_dim = ( sum(decoder_widths) if decoder_widths is not None else sum(encoder_widths) ) self.pad_value = pad_value self.encoder = encoder if encoder: self.return_maps = True else: self.return_maps = False if decoder_widths is not None: assert len(encoder_widths) == len(decoder_widths) assert encoder_widths[-1] == decoder_widths[-1] else: decoder_widths = encoder_widths self.in_conv = ConvBlock( nkernels=[input_dim] + [encoder_widths[0], encoder_widths[0]], pad_value=pad_value, norm=encoder_norm, ) self.down_blocks = nn.ModuleList( DownConvBlock( d_in=encoder_widths[i], d_out=encoder_widths[i + 1], k=str_conv_k, s=str_conv_s, p=str_conv_p, pad_value=pad_value, norm=encoder_norm, padding_mode=padding_mode, ) for i in range(self.n_stages - 1) ) self.up_blocks = nn.ModuleList( UpConvBlock( d_in=decoder_widths[i], d_out=decoder_widths[i - 1], d_skip=encoder_widths[i - 1], k=str_conv_k, s=str_conv_s, p=str_conv_p, norm=encoder_norm, padding_mode=padding_mode, ) for i in range(self.n_stages - 1, 0, -1) ) self.temporal_aggregator = Temporal_Aggregator(mode="mean") if temporal == "mean": self.temporal_encoder = Temporal_Aggregator(mode="mean") elif temporal == "lstm": size = int(input_size / str_conv_s ** (self.n_stages - 1)) self.temporal_encoder = ConvLSTM( input_dim=encoder_widths[-1], input_size=(size, size), hidden_dim=hidden_dim, kernel_size=(3, 3), ) self.out_convlstm = nn.Conv2d( in_channels=hidden_dim, out_channels=encoder_widths[-1], kernel_size=3, padding=1, ) elif temporal == "blstm": size = int(input_size / str_conv_s ** (self.n_stages - 1)) self.temporal_encoder = BConvLSTM( input_dim=encoder_widths[-1], input_size=(size, size), hidden_dim=hidden_dim, kernel_size=(3, 3), ) self.out_convlstm = nn.Conv2d( in_channels=2 * hidden_dim, out_channels=encoder_widths[-1], kernel_size=3, padding=1, ) elif temporal == "mono": self.temporal_encoder = None self.out_conv = ConvBlock(nkernels=[decoder_widths[0]] + out_conv, k=1, s=1, p=0, padding_mode=padding_mode) def forward(self, input, batch_positions=None): pad_mask = ( (input == self.pad_value).all(dim=-1).all(dim=-1).all(dim=-1) ) # BxT pad mask out = self.in_conv.smart_forward(input) feature_maps = [out] # ENCODER for i in range(self.n_stages - 1): out = self.down_blocks[i].smart_forward(feature_maps[-1]) feature_maps.append(out) # Temporal encoder if self.temporal == "mean": out = self.temporal_encoder(feature_maps[-1], pad_mask=pad_mask) elif self.temporal == "lstm": _, out = self.temporal_encoder(feature_maps[-1], pad_mask=pad_mask) out = out[0][1] # take last cell state as embedding out = self.out_convlstm(out) elif self.temporal == "blstm": out = self.temporal_encoder(feature_maps[-1], pad_mask=pad_mask) out = self.out_convlstm(out) elif self.temporal == "mono": out = feature_maps[-1] if self.return_maps: maps = [out] for i in range(self.n_stages - 1): if self.temporal != "mono": skip = self.temporal_aggregator( feature_maps[-(i + 2)], pad_mask=pad_mask ) else: skip = feature_maps[-(i + 2)] out = self.up_blocks[i](out, skip) if self.return_maps: maps.append(out) if self.encoder: return out, maps else: out = self.out_conv(out) if self.return_maps: return out, maps else: return out