import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import models
from sentence_transformers import SentenceTransformer

class ClipWithFrozenSBert(nn.Module):
    """CLIP-style dual encoder: a trainable ViT-L/16 image tower and a
    frozen SBERT text tower, each projected into a shared embedding space."""

    def __init__(self, embed_dim=512):
        super().__init__()

        # Image tower: ImageNet-pretrained ViT-L/16 (the weights the legacy
        # pretrained=True flag mapped to), with the classification head
        # replaced by Identity so the model emits pooled features.
        vit = models.vit_l_16(weights=models.ViT_L_16_Weights.IMAGENET1K_V1)
        vit.heads = nn.Identity()
        self.image_encoder = vit
        image_encoder_embed_dim = vit.hidden_dim  # 1024 for ViT-L/16

        # Text tower: a SentenceTransformer that already pools token states
        # into a single sentence embedding, so it needs no head surgery.
        self.text_encoder = SentenceTransformer('all-MiniLM-L6-v2')
        text_encoder_embed_dim = self.text_encoder.get_sentence_embedding_dimension()

        # Linear projections map both towers into the shared embed_dim space.
        self.image_proj = nn.Linear(image_encoder_embed_dim, embed_dim)
        self.text_proj = nn.Linear(text_encoder_embed_dim, embed_dim)

        # Fine-tune the image tower; keep the text tower frozen.
        for p in self.image_encoder.parameters():
            p.requires_grad = True
        for p in self.text_encoder.parameters():
            p.requires_grad = False

    def forward(self, images, texts):
        image_features = self.image_encoder(images)

        # The text tower is frozen, so skip autograd and move its
        # embeddings onto the same device as the image batch.
        with torch.no_grad():
            text_features = self.text_encoder.encode(
                texts, convert_to_tensor=True
            ).to(images.device)

        image_embeds = self.image_proj(image_features)
        text_embeds = self.text_proj(text_features)

        # L2-normalize so cosine similarity reduces to a dot product.
        image_embeds = F.normalize(image_embeds, dim=-1)
        text_embeds = F.normalize(text_embeds, dim=-1)

        return image_embeds, text_embeds
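

# A minimal training sketch, not part of the original class: one step of a
# symmetric InfoNCE / CLIP-style contrastive loss over matched image-text
# pairs. The fixed 0.07 temperature, AdamW settings, and dummy batch below
# are illustrative assumptions, not taken from the code above.
if __name__ == "__main__":
    model = ClipWithFrozenSBert(embed_dim=512)
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad], lr=1e-5
    )

    images = torch.randn(4, 3, 224, 224)            # dummy image batch
    texts = ["a photo of a cat", "a red car",
             "two dogs playing", "a bowl of soup"]  # matched captions

    image_embeds, text_embeds = model(images, texts)

    # Pairwise cosine similarities scaled by temperature; the diagonal holds
    # the matching pairs, which serve as the classification targets.
    logits = image_embeds @ text_embeds.t() / 0.07
    targets = torch.arange(len(texts))
    loss = (F.cross_entropy(logits, targets)
            + F.cross_entropy(logits.t(), targets)) / 2

    loss.backward()
    optimizer.step()
    print(f"contrastive loss: {loss.item():.4f}")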