Upload folder using huggingface_hub
- README.md +1 -3
- model.py +52 -0
- model_batch_32_vit_L_clip.pth +3 -0
- model_index.json +5 -0
README.md
CHANGED
@@ -1,3 +1 @@
----
-license: apache-2.0
----
+Later update
model.py
ADDED
@@ -0,0 +1,52 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import AutoModel
+from torchvision import models
+from sentence_transformers import SentenceTransformer
+
+
+class ClipWithFrozenSBert(nn.Module):
+    def __init__(self, embed_dim=512):
+        super().__init__()
+
+        # Load pretrained ViT-L/16 backbone
+        vit = models.vit_l_16(pretrained=True)
+
+        # Remove classification head
+        vit.heads = nn.Identity()
+
+        self.image_encoder = vit
+        image_encoder_embed_dim = vit.hidden_dim  # 1024 for vit_l_16
+
+        # Load text encoder
+        self.text_encoder = SentenceTransformer('all-MiniLM-L6-v2')
+        self.text_encoder.heads = nn.Identity()
+        text_encoder_embed_dim = self.text_encoder.get_sentence_embedding_dimension()
+
+        # Projection layers into the shared embedding space
+        self.image_proj = nn.Linear(image_encoder_embed_dim, embed_dim)
+        self.text_proj = nn.Linear(text_encoder_embed_dim, embed_dim)
+
+        # Fine-tune image encoder, freeze text encoder
+        for p in self.image_encoder.parameters():
+            p.requires_grad = True
+        for p in self.text_encoder.parameters():
+            p.requires_grad = False
+
+    def forward(self, images, texts):
+        # Extract image features
+        image_features = self.image_encoder(images)
+        # Extract text features (frozen encoder, no gradients)
+        with torch.no_grad():
+            text_features = self.text_encoder.encode(texts, convert_to_tensor=True).to(images.device)
+
+        # Project to shared embedding space
+        image_embeds = self.image_proj(image_features)
+        text_embeds = self.text_proj(text_features)
+
+        # Normalize embeddings
+        image_embeds = F.normalize(image_embeds, dim=-1)
+        text_embeds = F.normalize(text_embeds, dim=-1)
+
+        return image_embeds, text_embeds
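
The class above returns L2-normalized image and text embeddings but no loss; training code is not part of this commit. A minimal sketch of the symmetric CLIP-style contrastive objective that such paired embeddings are typically trained with follows; the temperature value is an assumption, not taken from this repo.

import torch
import torch.nn.functional as F

def clip_contrastive_loss(image_embeds, text_embeds, temperature=0.07):
    # image_embeds, text_embeds: (batch, embed_dim), already L2-normalized
    # by ClipWithFrozenSBert.forward, so the matrix product is cosine similarity.
    logits = image_embeds @ text_embeds.t() / temperature  # (batch, batch)
    targets = torch.arange(logits.size(0), device=logits.device)
    # Symmetric cross-entropy: match each image to its caption and vice versa.
    loss_i2t = F.cross_entropy(logits, targets)
    loss_t2i = F.cross_entropy(logits.t(), targets)
    return (loss_i2t + loss_t2i) / 2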
model_batch_32_vit_L_clip.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:148875b7f47f3f76ad3032ceeb75de46facad511483c80812edc5581965702f8
+size 1307128266
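
The file above is a git-lfs pointer; the roughly 1.3 GB checkpoint itself lives in LFS storage. A minimal loading sketch, assuming the repository id below is a placeholder and that the .pth file holds a plain state_dict saved from ClipWithFrozenSBert:

import torch
from huggingface_hub import hf_hub_download
from model import ClipWithFrozenSBert  # model.py from this repo must be importable

# "user/repo" is a placeholder for the actual repository id.
ckpt_path = hf_hub_download(
    repo_id="user/repo",
    filename="model_batch_32_vit_L_clip.pth",
)

model = ClipWithFrozenSBert(embed_dim=512)
state_dict = torch.load(ckpt_path, map_location="cpu")
# If the file stores a full training checkpoint rather than a bare state_dict,
# the weights may sit under a key such as "model_state_dict".
model.load_state_dict(state_dict)
model.eval()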
model_index.json
ADDED
@@ -0,0 +1,5 @@
+{
+  "auto_map": {
+    "AutoModel": "model.ClipWithFrozenSBert"
+  }
+}
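
model_index.json maps AutoModel to the ClipWithFrozenSBert class defined in model.py. Whether transformers can resolve this mapping via trust_remote_code is not guaranteed by these files alone (that mechanism normally expects a config.json with an auto_map and a PreTrainedModel subclass), so a direct-instantiation inference sketch with dummy inputs is shown instead; the image size and captions are placeholders.

import torch
from model import ClipWithFrozenSBert

# In practice, load the checkpoint weights first as in the sketch above;
# here the model is freshly initialized purely to show the call pattern.
model = ClipWithFrozenSBert(embed_dim=512).eval()

# Dummy batch: two 224x224 RGB images and two captions.
images = torch.randn(2, 3, 224, 224)
texts = ["a photo of a dog", "a photo of a cat"]

with torch.no_grad():
    image_embeds, text_embeds = model(images, texts)

# Embeddings are already L2-normalized, so this is a cosine-similarity matrix.
similarity = image_embeds @ text_embeds.t()  # shape (2, 2)
print(similarity)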