import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from sentence_transformers import SentenceTransformer


class ClipWithFrozenSBert(nn.Module):
    def __init__(self, embed_dim=512):
        super().__init__()
        # Load a pretrained ViT-B/16 image encoder
        vit = models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT)
        # Remove the classification head so the encoder returns pooled features
        vit.heads = nn.Identity()
        self.image_encoder = vit
        image_encoder_embed_dim = vit.hidden_dim  # 768 for vit_b_16

        # Load the sentence-embedding text encoder (kept frozen)
        self.text_encoder = SentenceTransformer('all-MiniLM-L6-v2')
        text_encoder_embed_dim = self.text_encoder.get_sentence_embedding_dimension()

        # Projection layers into the shared embedding space
        self.image_proj = nn.Linear(image_encoder_embed_dim, embed_dim)
        self.text_proj = nn.Linear(text_encoder_embed_dim, embed_dim)

        # Fine-tune the image encoder, freeze the text encoder
        for p in self.image_encoder.parameters():
            p.requires_grad = True
        for p in self.text_encoder.parameters():
            p.requires_grad = False

    def forward(self, images, texts):
        # Extract image features
        image_features = self.image_encoder(images)

        # Extract text features; the text encoder is frozen, so no autograd needed here
        with torch.no_grad():
            text_features = self.text_encoder.encode(
                texts, convert_to_tensor=True
            ).to(images.device)

        # Project both modalities into the shared embedding space
        image_embeds = self.image_proj(image_features)
        text_embeds = self.text_proj(text_features)

        # L2-normalize so dot products are cosine similarities
        image_embeds = F.normalize(image_embeds, dim=-1)
        text_embeds = F.normalize(text_embeds, dim=-1)
        return image_embeds, text_embeds
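

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal example of how this model could be trained with a symmetric
# CLIP-style contrastive (InfoNCE) loss. The dummy batch, caption list,
# fixed temperature, and optimizer settings below are assumptions made for
# illustration; a real setup would use a DataLoader and tuned hyperparameters.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ClipWithFrozenSBert(embed_dim=512).to(device)

    # Only the image encoder and the two projection heads are trainable.
    optimizer = torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad), lr=1e-5
    )

    # Dummy batch: 4 images (224x224, as expected by vit_b_16) and 4 captions.
    images = torch.randn(4, 3, 224, 224, device=device)
    texts = [
        "a dog playing in the park",
        "a red car on the street",
        "a plate of pasta",
        "a mountain at sunset",
    ]

    image_embeds, text_embeds = model(images, texts)

    # Symmetric contrastive loss: matching (image, text) pairs lie on the diagonal.
    temperature = 0.07  # assumed fixed here; CLIP learns it as a parameter
    logits = image_embeds @ text_embeds.t() / temperature
    targets = torch.arange(logits.size(0), device=device)
    loss = (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets)) / 2

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"contrastive loss: {loss.item():.4f}")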