arHang committed
Commit 0f1aae9 · verified · 1 Parent(s): aeb6c2d

Upload folder using huggingface_hub

Files changed (4)
  1. README.md +1 -3
  2. model.py +52 -0
  3. model_batch_32_vit_L_clip.pth +3 -0
  4. model_index.json +5 -0
README.md CHANGED
@@ -1,3 +1 @@
- ---
- license: apache-2.0
- ---
+ Later update
 
model.py ADDED
@@ -0,0 +1,50 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchvision import models
+ from sentence_transformers import SentenceTransformer
+
+
+ class ClipWithFrozenSBert(nn.Module):
+     def __init__(self, embed_dim=512):
+         super().__init__()
+
+         # Load ViT-L/16 backbone (the deprecated pretrained=True becomes weights=...)
+         vit = models.vit_l_16(weights=models.ViT_L_16_Weights.DEFAULT)
+
+         # Remove the classification head
+         vit.heads = nn.Identity()
+
+         self.image_encoder = vit
+         image_encoder_embed_dim = vit.hidden_dim  # 1024 for vit_l_16
+
+         # Load the text encoder
+         self.text_encoder = SentenceTransformer('all-MiniLM-L6-v2')
+         text_encoder_embed_dim = self.text_encoder.get_sentence_embedding_dimension()
+
+         # Projection layers into the shared embedding space
+         self.image_proj = nn.Linear(image_encoder_embed_dim, embed_dim)
+         self.text_proj = nn.Linear(text_encoder_embed_dim, embed_dim)
+
+         # Fine-tune the image encoder, freeze the text encoder
+         for p in self.image_encoder.parameters():
+             p.requires_grad = True
+         for p in self.text_encoder.parameters():
+             p.requires_grad = False
+
+     def forward(self, images, texts):
+         # Extract image features
+         image_features = self.image_encoder(images)
+         # Extract text features (no gradients; the text encoder stays frozen)
+         with torch.no_grad():
+             text_features = self.text_encoder.encode(texts, convert_to_tensor=True).to(images.device)
+
+         # Project to the shared embedding space
+         image_embeds = self.image_proj(image_features)
+         text_embeds = self.text_proj(text_features)
+
+         # Normalize embeddings so dot products give cosine similarities
+         image_embeds = F.normalize(image_embeds, dim=-1)
+         text_embeds = F.normalize(text_embeds, dim=-1)
+
+         return image_embeds, text_embeds
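
For reference, a minimal usage sketch of the class above (not part of the commit): the 224x224 input size follows the vit_l_16 default, and the captions are placeholder strings.

import torch
from model import ClipWithFrozenSBert

model = ClipWithFrozenSBert(embed_dim=512).eval()

# Dummy batch: two 224x224 RGB images and two captions
images = torch.randn(2, 3, 224, 224)
texts = ["a photo of a cat", "a photo of a dog"]

with torch.no_grad():
    image_embeds, text_embeds = model(images, texts)  # each (2, 512)

# Embeddings are L2-normalized, so dot products are cosine similarities
similarity = image_embeds @ text_embeds.T  # (2, 2) image-text similarity matrix

Because both outputs are unit-normalized, this similarity matrix is what a CLIP-style contrastive loss would be computed over during training.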
model_batch_32_vit_L_clip.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:148875b7f47f3f76ad3032ceeb75de46facad511483c80812edc5581965702f8
+ size 1307128266
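
Assuming this checkpoint holds a plain state_dict for ClipWithFrozenSBert (the saving code is not part of this commit, so the format is an assumption), restoring it would look roughly like:

import torch
from model import ClipWithFrozenSBert

model = ClipWithFrozenSBert(embed_dim=512)

# Assumption: the .pth file is a state_dict, not a pickled full model
state_dict = torch.load("model_batch_32_vit_L_clip.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()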
model_index.json ADDED
@@ -0,0 +1,5 @@
+ {
+     "auto_map": {
+         "AutoModel": "model.ClipWithFrozenSBert"
+     }
+ }
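
The auto_map entry points AutoModel at the custom class defined in model.py. One caveat: transformers normally reads auto_map from config.json (model_index.json is a diffusers convention), so the sketch below assumes the mapping is also exposed where the auto classes look for it; the repo id is a placeholder.

from transformers import AutoModel

# Hypothetical repo id; trust_remote_code lets transformers import
# model.ClipWithFrozenSBert from the repository's model.py
model = AutoModel.from_pretrained("arHang/model_batch_32_vit_L_clip", trust_remote_code=True)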