Commit d59bc91 (verified) · committed by fepegar · Parent: ed6e71c

Update README.md

Files changed (1): README.md (+68 -4)
README.md CHANGED
@@ -58,6 +58,8 @@ Underlying biases of the training datasets may not be well characterized.

## Getting started

+ ### Get some data
+
Let us first write an auxiliary function to download a chest X-ray.

```python
@@ -73,6 +75,8 @@ Let us first write an auxiliary function to download a chest X-ray.
...
```

+ ### Load the model
+
Now let us download the model and encode an image.

```python
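The body of `download_sample_image` is elided in the hunks above (the `...` context line), since this commit does not touch it. For readers following along, here is a hypothetical sketch of such a helper, assuming `requests` and `Pillow` are installed; the URL is a placeholder, not the one in the README.

```python
import io

import requests
from PIL import Image


def download_sample_image() -> Image.Image:
    """Download a sample chest X-ray and return it as a PIL image."""
    url = "https://example.com/sample_chest_xray.jpg"  # placeholder URL
    response = requests.get(url, timeout=30)
    response.raise_for_status()  # fail loudly on HTTP errors
    return Image.open(io.BytesIO(response.content))
```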
@@ -82,13 +86,17 @@ Now let us download the model and encode an image.
>>>
>>> # Download the model
>>> repo = "microsoft/rad-dino"
- >>> model = AutoModel.from_pretrained(repo)
+ >>> rad_dino = AutoModel.from_pretrained(repo)
>>>
>>> # The processor takes a PIL image, performs resizing, center-cropping, and
>>> # intensity normalization using stats from MIMIC-CXR, and returns a
>>> # dictionary with a PyTorch tensor ready for the encoder
>>> processor = AutoImageProcessor.from_pretrained(repo)
- >>>
+ ```
+
+ ### Encode an image
+
+ ```python
>>> # Download and preprocess a chest X-ray
>>> image = download_sample_image()
>>> image.size # (width, height)
@@ -97,7 +105,7 @@ Now let us download the model and encode an image.
>>>
>>> # Encode the image!
>>> with torch.inference_mode():
- >>>     outputs = model(**inputs)
+ >>>     outputs = rad_dino(**inputs)
>>>
>>> # Look at the CLS embeddings
>>> cls_embeddings = outputs.pooler_output
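The preprocessing call itself sits in unchanged lines and is not shown in the diff. As a quick sanity check under stated assumptions (a single input image, a 518 × 518 center crop consistent with the 37 × 37 patch grid shown in the next hunk, and 768-dimensional ViT-B features):

```python
# Assumed shapes, inferred from the torch.Size([1, 768, 37, 37]) grid
# reported below; not copied from the model card itself.
inputs = processor(images=image, return_tensors="pt")
print(inputs["pixel_values"].shape)  # expected: torch.Size([1, 3, 518, 518])
print(cls_embeddings.shape)          # expected: torch.Size([1, 768])
```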
@@ -124,6 +132,62 @@ We will use [`einops`](https://einops.rocks/) (install with `pip install einops`)
torch.Size([1, 768, 37, 37])
```

+ ### Weights for fine-tuning
+
+ We have released a checkpoint compatible with
+ [the original DINOv2 code](https://github.com/facebookresearch/dinov2) to help
+ researchers fine-tune our model.
+
+ First, let us write code to load a
+ [`safetensors` checkpoint](https://huggingface.co/docs/safetensors).
+
+ ```python
+ >>> from safetensors import safe_open
+ >>> def safetensors_to_state_dict(checkpoint_path: str) -> dict[str, torch.Tensor]:
+ ...     state_dict = {}
+ ...     with safe_open(checkpoint_path, framework="pt") as ckpt_file:
+ ...         for key in ckpt_file.keys():
+ ...             state_dict[key] = ckpt_file.get_tensor(key)
+ ...     return state_dict
+ ...
+ ```
+
+ We can now use the hub model and load the RAD-DINO weights.
+ Let's clone the DINOv2 repository so we can import the code for the head.
+
+ ```shell
+ git clone https://github.com/facebookresearch/dinov2.git
+ cd dinov2
+ ```
+
+ ```python
+ >>> import torch
+ >>> rad_dino_gh = torch.hub.load(".", "dinov2_vitb14", source="local")
+ >>> backbone_state_dict = safetensors_to_state_dict("backbone_compatible.safetensors")
+ >>> rad_dino_gh.load_state_dict(backbone_state_dict, strict=True)
+ <All keys matched successfully>
+ ```
+
+ The weights of the head are also released:
+
+ ```python
+ >>> from dinov2.layers import DINOHead
+ >>> rad_dino_head_gh = DINOHead(
+ ...     in_dim=768,
+ ...     out_dim=65536,
+ ...     hidden_dim=2048,
+ ...     bottleneck_dim=256,
+ ...     nlayers=3,
+ ... )
+ >>> head_state_dict = safetensors_to_state_dict("dino_head.safetensors")
+ >>> rad_dino_head_gh.load_state_dict(head_state_dict, strict=True)
+ <All keys matched successfully>
+ ```
+
+ ### Configs and augmentation
+
+ The configuration files [`ssl_default_config.yaml`](./ssl_default_config.yaml) and [`vitb14_cxr.yaml`](./vitb14_cxr.yaml), and the [`augmentations`](./augmentations.py) module are also available in the repository to help researchers reproduce the training procedure with our hyperparameters.
+
## Training details

### Training data
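The `torch.Size([1, 768, 37, 37])` context at the top of this hunk comes from reshaping the flat patch tokens with `einops`; the reshaping code itself is unchanged, so the diff omits it. A minimal sketch of the idea, assuming `outputs.last_hidden_state` holds one CLS token followed by 37 × 37 = 1369 patch tokens:

```python
from einops import rearrange

# Drop the leading CLS token, keeping only the patch tokens.
flat_patch_tokens = outputs.last_hidden_state[:, 1:]  # (1, 1369, 768)
# Rearrange the flat token sequence into a spatial feature map.
patch_grid = rearrange(flat_patch_tokens, "b (h w) c -> b c h w", h=37, w=37)
print(patch_grid.shape)  # torch.Size([1, 768, 37, 37])
```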
@@ -225,4 +289,4 @@ We used [SimpleITK](https://simpleitk.org/) and [Pydicom](https://pydicom.github.io/)

## Model card contact

- Fernando Pérez-García ([`[email protected]`](mailto:[email protected])).
+ Fernando Pérez-García ([`[email protected]`](mailto:[email protected])).
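The *Configs and augmentation* subsection added in this commit only points at the YAML files. As a minimal sketch for inspecting the released hyperparameters, assuming the files have been downloaded locally (the DINOv2 training code reads its configuration with OmegaConf):

```python
from omegaconf import OmegaConf

# Load the defaults, then apply the chest X-ray overrides on top.
default_config = OmegaConf.load("ssl_default_config.yaml")
cxr_config = OmegaConf.load("vitb14_cxr.yaml")
config = OmegaConf.merge(default_config, cxr_config)
print(OmegaConf.to_yaml(config))
```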