Update README.md
Browse files
README.md
CHANGED
@@ -75,28 +75,25 @@ This will install an editable version of the repo, allowing you to make changes to t
|
|
75 |
## Image and Text Feature extraction with a Trained Model
|
76 |
```python
|
77 |
import torch
|
78 |
-
from core.vision_encoder.factory import create_model_and_transforms, get_tokenizer
|
79 |
from PIL import Image
|
|
|
|
|
80 |
|
81 |
-
|
82 |
-
|
83 |
|
84 |
-
model
|
85 |
-
model_name,
|
86 |
-
pretrained=pretrained,
|
87 |
-
)
|
88 |
model = model.cuda()
|
89 |
-
tokenizer = get_tokenizer(model_name)
|
90 |
|
91 |
-
|
|
|
|
|
|
|
92 |
text = tokenizer(["a diagram", "a dog", "a cat"]).cuda()
|
93 |
|
94 |
with torch.no_grad(), torch.autocast("cuda"):
|
95 |
-
image_features = model
|
96 |
-
|
97 |
-
image_features /= image_features.norm(dim=-1, keepdim=True)
|
98 |
-
text_features /= text_features.norm(dim=-1, keepdim=True)
|
99 |
-
text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
|
100 |
|
101 |
print("Label probs:", text_probs) # prints: [[0.0, 0.0, 1.0]]
|
102 |
```
|
|
|
75 |
## Image and Text Feature extraction with a Trained Model
|
76 |
```python
|
77 |
import torch
|
|
|
78 |
from PIL import Image
|
79 |
+
import core.vision_encoder.pe as pe
|
80 |
+
import core.vision_encoder.transforms as transforms
|
81 |
|
82 |
+
print("CLIP configs:", pe.CLIP.available_configs())
|
83 |
+
# CLIP configs: ['PE-Core-G14-448', 'PE-Core-L14-336', 'PE-Core-B16-224']
|
84 |
|
85 |
+
model = pe.CLIP.from_config("PE-Core-G14-448", pretrained=True) # Downloads from Hugging Face (HF)
|
|
|
|
|
|
|
86 |
model = model.cuda()
|
|
|
87 |
|
88 |
+
preprocess = transforms.get_image_transform(model.image_size)
|
89 |
+
tokenizer = transforms.get_text_tokenizer(model.context_length)
|
90 |
+
|
91 |
+
image = preprocess(Image.open("docs/assets/cat.png")).unsqueeze(0).cuda()
|
92 |
text = tokenizer(["a diagram", "a dog", "a cat"]).cuda()
|
93 |
|
94 |
with torch.no_grad(), torch.autocast("cuda"):
|
95 |
+
image_features, text_features, logit_scale = model(image, text)
|
96 |
+
text_probs = (logit_scale * image_features @ text_features.T).softmax(dim=-1)
|
|
|
|
|
|
|
97 |
|
98 |
print("Label probs:", text_probs) # prints: [[0.0, 0.0, 1.0]]
|
99 |
```
|