Spaces:
TYH71 committed · Commit 2bbb2d2
1 Parent(s): 33396c7
fix: .gitignore issue
Files changed:
- .gitignore +1 -1
- src/model/__init__.py +0 -0
- src/model/clip.py +73 -0
.gitignore CHANGED
@@ -159,4 +159,4 @@ cython_debug/
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
 
-
+model_dir/
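Note: the new model_dir/ entry matches the cache_dir="model_dir/" that src/model/clip.py (added below) passes to open_clip, so downloaded model weights stay out of version control.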
src/model/__init__.py ADDED
(empty file)
src/model/clip.py ADDED
@@ -0,0 +1,73 @@
"""CLIP model for zero-shot classification, running on a CPU machine."""
from typing import List, Dict
from PIL import Image
import torch
import open_clip
from open_clip import tokenizer

# project modules
from src.core.singleton import SingletonMeta
from src.core.logger import logger


class CLIP_Model(metaclass=SingletonMeta):
    def __init__(self,
                 model_name: str = "ViT-B/32",
                 pretrained: str = "laion2b_s34b_b79k",
                 jit: bool = False
                 ):
        logger.debug("creating CLIP model object")
        self.config = dict(
            model_name=model_name,
            pretrained=pretrained,
            precision="bf16",
            device=torch.device("cpu"),
            jit=jit,
            cache_dir="model_dir/"
        )
        self.model, self.preprocess = open_clip.create_model_from_pretrained(**self.config)
        self.model.eval()
        logger.info(f"{self.config.get('model_name')} {self.config.get('pretrained')} initialized")

    def __call__(self, image: Image.Image, text: List[str]) -> Dict[str, float]:
        """Inference pipeline: zero-shot classification of an image against a list of labels."""
        with torch.inference_mode(), torch.cpu.amp.autocast():
            # compute image features
            image_input = self.preprocess_image(image)
            image_features = self.get_image_features(image_input)
            logger.info("image features computed")

            # compute text features
            text_input = self.preprocess_text(text)
            text_features = self.get_text_features(text_input)
            logger.info("text features computed")

            # zero-shot classification: scaled cosine similarity, softmaxed over labels
            text_probs = self.matmul_and_softmax(image_features, text_features)
            logger.debug("text_probs: %s", text_probs)
            return dict(zip(text, text_probs))

    def preprocess_image(self, image: Image.Image) -> torch.Tensor:
        """Apply the model's preprocessing transform and add a batch dimension."""
        return self.preprocess(image).unsqueeze(0)

    @staticmethod
    def preprocess_text(text: List[str]) -> torch.Tensor:
        """Tokenize the input label strings."""
        return tokenizer.tokenize(text)

    def get_image_features(self, image_input: torch.Tensor) -> torch.Tensor:
        """Encode the image and L2-normalize the resulting feature vector."""
        image_features = self.model.encode_image(image_input)
        image_features /= image_features.norm(dim=-1, keepdim=True)  # normalize to unit length
        return image_features

    def get_text_features(self, text_input: torch.Tensor) -> torch.Tensor:
        """Encode the text and L2-normalize the resulting feature vectors."""
        text_features = self.model.encode_text(text_input)
        text_features /= text_features.norm(dim=-1, keepdim=True)  # normalize to unit length
        return text_features

    @staticmethod
    def matmul_and_softmax(image_features: torch.Tensor, text_features: torch.Tensor) -> List[float]:
        """Scale the cosine similarities by 100 and softmax them into per-label probabilities."""
        return (100.0 * image_features @ text_features.T).softmax(dim=-1).squeeze(0).tolist()
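The 100.0 factor in matmul_and_softmax matches CLIP's learned logit scale, which is capped at 100 during training; it sharpens the softmax so the closest label dominates. A toy demonstration of just this scoring step, with made-up feature vectors:

```python
import torch

# toy version of matmul_and_softmax: one unit-norm "image" vector vs two "text" vectors
image_features = torch.tensor([[1.0, 0.0]])
text_features = torch.tensor([[0.9, 0.436], [0.0, 1.0]])
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

# cosine similarities are ~0.9 and 0.0; scaling by 100 turns a modest gap
# into near-certainty after the softmax
probs = (100.0 * image_features @ text_features.T).softmax(dim=-1).squeeze(0)
print(probs.tolist())  # ~[1.0, 8e-40]
```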
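For reference, a minimal usage sketch of the class above; the image path and label strings are placeholders, not part of this commit:

```python
from PIL import Image

from src.model.clip import CLIP_Model

# hypothetical inputs: any RGB image plus free-form candidate labels
image = Image.open("sample.jpg").convert("RGB")
labels = ["a photo of a cat", "a photo of a dog", "a diagram"]

model = CLIP_Model()          # first call loads the model; later calls reuse the instance
probs = model(image, labels)  # {label: probability}, values sum to ~1.0
print(max(probs, key=probs.get))
```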
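src/core/singleton.py and src/core/logger.py are imported by clip.py but are not part of this commit. A sketch of what SingletonMeta conventionally looks like, assuming the standard thread-safe metaclass pattern (the repo's actual implementation may differ):

```python
import threading


class SingletonMeta(type):
    """One cached instance per class -- assumed implementation, not from this commit."""
    _instances = {}
    _lock = threading.Lock()

    def __call__(cls, *args, **kwargs):
        # double-checked locking: concurrent first calls still create only one instance
        if cls not in cls._instances:
            with cls._lock:
                if cls not in cls._instances:
                    cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]
```

Under this pattern, CLIP_Model() can be constructed cheaply anywhere in the app: the expensive open_clip load happens only once.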