Fix typo
Browse files
README.md
CHANGED
@@ -26,18 +26,18 @@ fine-tuned versions on a task that interests you.
|
|
26 |
|
27 |
### How to use
|
28 |
|
29 |
-
Since this model is a distilled ViT model, you can plug it into DeiTModel or
|
30 |
|
31 |
Here is how to use this model to classify an image of the COCO 2017 dataset into one of the 1,000 ImageNet classes:
|
32 |
|
33 |
```python
|
34 |
-
from transformers import AutoFeatureExtractor,
|
35 |
from PIL import Image
|
36 |
import requests
|
37 |
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
38 |
image = Image.open(requests.get(url, stream=True).raw)
|
39 |
feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/deit-base-distilled-patch16-224')
|
40 |
-
model =
|
41 |
inputs = feature_extractor(images=image, return_tensors="pt")
|
42 |
outputs = model(**inputs)
|
43 |
logits = outputs.logits
|
|
|
26 |
|
27 |
### How to use
|
28 |
|
29 |
+
Since this model is a distilled ViT model, you can plug it into DeiTModel, DeiTForImageClassification or DeiTForImageClassificationWithTeacher. Note that the model expects the data to be prepared using DeiTFeatureExtractor. Here we use AutoFeatureExtractor, which will automatically use the appropriate feature extractor given the model name.
|
30 |
|
31 |
Here is how to use this model to classify an image of the COCO 2017 dataset into one of the 1,000 ImageNet classes:
|
32 |
|
33 |
```python
|
34 |
+
from transformers import AutoFeatureExtractor, DeiTForImageClassificationWithTeacher
|
35 |
from PIL import Image
|
36 |
import requests
|
37 |
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
38 |
image = Image.open(requests.get(url, stream=True).raw)
|
39 |
feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/deit-base-distilled-patch16-224')
|
40 |
+
model = DeiTForImageClassificationWithTeacher.from_pretrained('facebook/deit-base-distilled-patch16-224')
|
41 |
inputs = feature_extractor(images=image, return_tensors="pt")
|
42 |
outputs = model(**inputs)
|
43 |
logits = outputs.logits
|