# PyTorch for deep learning operations
import torch
import torch.nn as nn
# PyTorch data loading and utilities
import torch.multiprocessing
# Pretrained text/vision models and processors
from transformers import BertModel, BertTokenizer, AutoModel, AutoImageProcessor

from configs import CFG
from text_image import OneEncoder as TextImageEncoder


class AlignmentLayer(nn.Module):
    """Two-layer MLP that projects encoder features into the shared embedding space.

    Maps (..., input_dim) -> (..., projection_dim) via
    Linear -> GELU -> Linear -> Dropout -> LayerNorm.
    """

    def __init__(self, input_dim=768, projection_dim=CFG.projection_dim,
                 dropout_rate=CFG.dropout_rate, *args, **kwargs):
        super(AlignmentLayer, self).__init__(*args, **kwargs)
        # Attributes
        self.input_dim = input_dim
        self.projection_dim = projection_dim
        self.dropout_rate = dropout_rate
        # Layers
        self.linear_layer1 = nn.Linear(self.input_dim, self.projection_dim)
        self.gelu = nn.GELU()
        self.linear_layer2 = nn.Linear(self.projection_dim, self.projection_dim)
        self.dropout = nn.Dropout(self.dropout_rate)
        self.normalization_layer = nn.LayerNorm(self.projection_dim)

    def forward(self, inputs):
        """Project `inputs` into the shared space.

        :param inputs: tensor of shape (..., input_dim)
        :return: tensor of shape (..., projection_dim)
        """
        x = self.linear_layer1(inputs)
        x = self.gelu(x)
        x = self.linear_layer2(x)
        x = self.dropout(x)
        x = self.normalization_layer(x)
        return x

    # NOTE: no __call__ override — nn.Module.__call__ already dispatches to
    # forward() and runs registered hooks; overriding it bypassed those hooks.


class RadioEncoder(nn.Module):
    """Frozen pretrained radiology image encoder followed by an AlignmentLayer.

    The backbone (e.g. rad-dino) is loaded via AutoModel and frozen unless
    `trainable=True`; only the alignment projection is trained by default.
    """

    def __init__(self, model_name=CFG.radio_name, projection_dim=CFG.projection_dim,
                 trainable=False, dropout_rate=CFG.dropout_rate, *args, **kwargs):
        super(RadioEncoder, self).__init__(*args, **kwargs)
        # Attributes
        self.model_name = model_name
        self.projection_dim = projection_dim
        self.dropout_rate = dropout_rate
        self.trainable = trainable
        # Models
        self.pretrained_encoder = AutoModel.from_pretrained(self.model_name)
        self.alignment_layer = AlignmentLayer(
            input_dim=self.pretrained_encoder.config.hidden_size,
            projection_dim=self.projection_dim,
            dropout_rate=self.dropout_rate)
        # Freeze the pretrained radio backbone (kept trainable only on request)
        for parameter in self.pretrained_encoder.parameters():
            parameter.requires_grad = self.trainable

    def forward(self, inputs):
        """Encode preprocessed pixel values and project them.

        :param inputs: pixel-value tensor accepted by the pretrained encoder
        :return: aligned features of shape (batch, seq_len, projection_dim)
        """
        x = self.pretrained_encoder(inputs).last_hidden_state
        x = self.alignment_layer(x)
        return x


class ModalityTokenEncoder(nn.Module):
    """Learnable modality token(s) identifying the radio modality.

    Holds a (token_size, projection_dim) parameter initialized from a normal
    distribution whose std is drawn uniformly from [0.1, 0.6).
    """

    def __init__(self, projection_dim=CFG.projection_dim, token_size=CFG.token_size,
                 device='cpu', *args, **kwargs):
        super(ModalityTokenEncoder, self).__init__(*args, **kwargs)
        # Attributes
        self.projection_dim = projection_dim
        self.device = device
        self.token_size = token_size
        # Models
        radio_variance = torch.rand(1) * 0.5 + 0.1
        self.radio_token = nn.Parameter(
            torch.normal(mean=0, std=radio_variance.item(),
                         size=(self.token_size, self.projection_dim)).to(self.device))

    def forward(self):
        """:return: the learnable radio modality token parameter."""
        return self.radio_token


class OneEncoder(nn.Module):
    """Wraps a frozen text-image OneEncoder and aligns a radio encoder to it."""

    def __init__(self, device='cpu', modality_token_encoder=None,
                 checkpoint="bilalfaye/OneEncoder-text-image",
                 radio_processor=None, sample_rate=CFG.sample_rate,
                 radio_encoder=None, *args, **kwargs):
        super(OneEncoder, self).__init__(*args, **kwargs)
        self.device = device
        self.checkpoint = checkpoint
        # Build heavyweight defaults lazily: constructing them in the signature
        # (as the original did) evaluates them once at import time — triggering
        # model downloads on import and sharing a single instance across every
        # OneEncoder (mutable-default-argument bug).
        if modality_token_encoder is None:
            modality_token_encoder = ModalityTokenEncoder()
        self.modality_token_encoder = modality_token_encoder
        self.modality_token_encoder.device = self.device
        self.text_image_encoder = TextImageEncoder(device=self.device)
        self.text_image_encoder.from_pretrained(self.checkpoint)
        if radio_processor is None:
            radio_processor = AutoImageProcessor.from_pretrained("microsoft/rad-dino")
        self.radio_processor = radio_processor
        self.sample_rate = sample_rate
        self.radio_encoder = radio_encoder if radio_encoder is not None else RadioEncoder()
        self.temperature = nn.Parameter(torch.tensor(0.07).to(self.device))
        # Freeze the pretrained text-image encoder; only the radio branch trains
        for parameter in self.text_image_encoder.parameters():
            parameter.requires_grad = False

    def encode_radio(self, pil_radios=None, radios=None):
        """Encode radiology images into the shared embedding space.

        :param pil_radios: list of pillow images (preprocessed here if given)
        :param radios: already-preprocessed pixel-value tensor (used when
            `pil_radios` is None)
        :return: tensor — last hidden state of the universal projection encoder
        """
        if pil_radios is not None:
            tensors = self.radio_processor(
                pil_radios, return_tensors="pt")["pixel_values"].to(self.device)
        else:
            tensors = radios.to(self.device)
        features = self.radio_encoder(tensors)
        radio_token = self.modality_token_encoder()
        outputs = self.text_image_encoder.universal_projection_encoder(
            [features, radio_token]).last_hidden_state
        return outputs