TYH71 commited on
Commit
2bbb2d2
·
1 Parent(s): 33396c7

fix: .gitignore issue

Browse files
Files changed (3) hide show
  1. .gitignore +1 -1
  2. src/model/__init__.py +0 -0
  3. src/model/clip.py +73 -0
.gitignore CHANGED
@@ -159,4 +159,4 @@ cython_debug/
159
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
  #.idea/
161
 
162
- model/
 
159
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
  #.idea/
161
 
162
+ model_dir/
src/model/__init__.py ADDED
File without changes
src/model/clip.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CLIP model for zero-shot classification; running on CPU machine"""
2
+ from typing import List, Dict
3
+ from PIL import Image
4
+ import torch
5
+ import open_clip
6
+ from open_clip import tokenizer
7
+
8
+ # modules
9
+ from src.core.singleton import SingletonMeta
10
+ from src.core.logger import logger
11
+
12
+ class CLIP_Model(metaclass=SingletonMeta):
13
+ def __init__(self,
14
+ model_name: str = "ViT-B/32",
15
+ pretrained: str = "laion2b_s34b_b79k",
16
+ jit: bool = False
17
+ ):
18
+ logger.debug("creating CLIP Model Object")
19
+ self.config = dict(
20
+ model_name=model_name,
21
+ pretrained=pretrained,
22
+ precision="bf16",
23
+ device=torch.device("cpu"),
24
+ jit=jit,
25
+ cache_dir="model_dir/"
26
+ )
27
+ self.model, self.preprocess = open_clip.create_model_from_pretrained(**self.config)
28
+ self.model.eval()
29
+ logger.info(f"{self.config.get('model_name')} {self.config.get('pretrained')} initialized")
30
+
31
+ def __call__(self, image: Image.Image, text: List[str]) -> Dict[str, float]:
32
+ """inference pipeline for CLIP model"""
33
+ with torch.inference_mode(), torch.cpu.amp.autocast():
34
+ # compute image features
35
+ image_input = self.preprocess_image(image)
36
+ image_features = self.get_image_features(image_input)
37
+ logger.info("image features computed")
38
+
39
+ # compute text features
40
+ text_input = self.preprocess_text(text)
41
+ text_features = self.get_text_features(text_input)
42
+ logger.info("text features computed")
43
+
44
+ # zero-shot classification
45
+ text_probs = self.matmul_and_softmax(image_features, text_features)
46
+ logger.debug("text_probs: %s", text_probs)
47
+ return dict(zip(text, text_probs))
48
+
49
+ def preprocess_image(self, image: Image.Image) -> torch.Tensor:
50
+ """function to preprocess the input image"""
51
+ return self.preprocess(image).unsqueeze(0)
52
+
53
+ @staticmethod
54
+ def preprocess_text(text: List[str]) -> torch.Tensor:
55
+ """function to preprocess the input text"""
56
+ return tokenizer.tokenize(text)
57
+
58
+ def get_image_features(self, image_input: torch.Tensor) -> torch.Tensor:
59
+ """function to get the image features"""
60
+ image_features = self.model.encode_image(image_input)
61
+ image_features /= image_features.norm(dim=-1, keepdim=True) # normalize vector prior
62
+ return image_features
63
+
64
+ def get_text_features(self, text_input: torch.Tensor) -> torch.Tensor:
65
+ """function to get the text features"""
66
+ text_features = self.model.encode_text(text_input)
67
+ text_features /= text_features.norm(dim=-1, keepdim=True) # normalize vector prior
68
+ return text_features
69
+
70
+ @staticmethod
71
+ def matmul_and_softmax(image_features: torch.Tensor, text_features: torch.Tensor) -> List[float]:
72
+ """compute matmul and softmax"""
73
+ return (100.0 * image_features @ text_features.T).softmax(dim=-1).squeeze(0).tolist()