import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from sentence_transformers import SentenceTransformer
class ClipWithFrozenSBert(nn.Module):
    def __init__(self, embed_dim=512):
        super().__init__()
        # Load a ViT-L/16 backbone pretrained on ImageNet
        vit = models.vit_l_16(weights=models.ViT_L_16_Weights.DEFAULT)
        # Remove the classification head so the encoder returns pooled features
        vit.heads = nn.Identity()
        self.image_encoder = vit
        image_encoder_embed_dim = vit.hidden_dim  # 1024 for vit_l_16
        # Load the text encoder (SBERT); it has no classification head to strip
        self.text_encoder = SentenceTransformer('all-MiniLM-L6-v2')
        text_encoder_embed_dim = self.text_encoder.get_sentence_embedding_dimension()
        # Projection layers map both modalities into the shared embedding space
        self.image_proj = nn.Linear(image_encoder_embed_dim, embed_dim)
        self.text_proj = nn.Linear(text_encoder_embed_dim, embed_dim)
        # Fine-tune the image encoder, freeze the text encoder
        for p in self.image_encoder.parameters():
            p.requires_grad = True
        for p in self.text_encoder.parameters():
            p.requires_grad = False

    def forward(self, images, texts):
        # Extract image features (trainable path)
        image_features = self.image_encoder(images)
        # Extract text features; the SBERT encoder is frozen, so no gradients are needed
        with torch.no_grad():
            text_features = self.text_encoder.encode(texts, convert_to_tensor=True).to(images.device)
        # Project both modalities into the shared embedding space
        image_embeds = self.image_proj(image_features)
        text_embeds = self.text_proj(text_features)
        # L2-normalize so similarities reduce to dot products
        image_embeds = F.normalize(image_embeds, dim=-1)
        text_embeds = F.normalize(text_embeds, dim=-1)
        return image_embeds, text_embeds
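

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file): runs a dummy batch
    # through the model and computes a CLIP-style symmetric InfoNCE loss over
    # the returned embeddings. The batch size, image resolution, captions, and
    # fixed temperature below are assumptions for illustration; CLIP itself
    # learns the temperature as a parameter.
    model = ClipWithFrozenSBert(embed_dim=512)
    images = torch.randn(4, 3, 224, 224)  # vit_l_16 expects 224x224 RGB input
    texts = [
        "a photo of a cat",
        "a photo of a dog",
        "a red sports car",
        "a bowl of fruit",
    ]

    image_embeds, text_embeds = model(images, texts)

    # Pairwise cosine similarities scaled by a fixed temperature
    temperature = 0.07
    logits = image_embeds @ text_embeds.t() / temperature
    targets = torch.arange(logits.size(0))

    # Symmetric cross-entropy over image->text and text->image directions
    loss = (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets)) / 2
    print(f"embeddings: {image_embeds.shape}, loss: {loss.item():.4f}")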