teguhteja commited on
Commit
bac1e0b
·
verified ·
1 Parent(s): 261c924

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +100 -0
README.md ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ datasets:
4
+ - uoft-cs/cifar10
5
+ metrics:
6
+ - accuracy
7
+ base_model:
8
+ - microsoft/resnet-18
9
+ ---
10
+ # ResNet18 Fine-Tuned on CIFAR-10
11
+
12
+ This model is a fine-tuned version of **ResNet18** (originally pretrained on ImageNet) on the [CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html). It achieves the following results on the validation/test set:
13
+
14
+ - **Validation Accuracy**: `88.60%`
15
+
16
+ ---
17
+
18
+ ## Model description
19
+
20
+ - **Architecture**: ResNet18 with the final fully-connected layer replaced by a 10-class output layer for CIFAR-10 (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck).
21
+ - **Pretrained Weights**: ImageNet1K
22
+ - **Fine-Tuning**: The model was fine-tuned on CIFAR-10 images resized to 128×128 pixels.
23
+ - **Data Augmentation**: Random horizontal flip, random rotation, normalization to mean=0.5 and std=0.5.
24
+
25
+ > **Intended uses & limitations**
26
+ > - **Intended use**: Educational/demo purposes or as a starting point for further fine-tuning on similar image classification tasks.
27
+ > - **Not intended for**: Production-critical tasks without further evaluation, as CIFAR-10 is relatively small-scale, and the model may not generalize to non-CIFAR data without additional fine-tuning.
28
+
29
+ ---
30
+
31
+ ## Training procedure
32
+
33
+ **Hyperparameters** (approximate):
34
+ - **optimizer**: Adam
35
+ - **learning_rate**: 1e-3
36
+ - **batch_size**: 32
37
+ - **num_epochs**: 15
38
+
39
+ **GPU/CPU**:
40
+ - This model was trained on a single GPU (`torch.device("cuda")`) if available, otherwise CPU.
41
+
42
+ **Training logs** (for each epoch on the training set):
43
+
44
+ | Epoch | Training Loss | Training Accuracy | Validation Accuracy |
45
+ |------:|--------------:|------------------:|--------------------:|
46
+ | 1 | 0.7013 | 76.52% | - |
47
+ | 2 | 0.4248 | 85.64% | - |
48
+ | 3 | 0.3185 | 89.07% | - |
49
+ | 4 | 0.2341 | 92.06% | - |
50
+ | 5 | 0.1762 | 93.86% | - |
51
+ | 6 | 0.1302 | 95.55% | - |
52
+ | 7 | 0.1085 | 96.31% | - |
53
+ | 8 | 0.0925 | 96.82% | - |
54
+ | 9 | 0.0765 | 97.37% | - |
55
+ | 10 | 0.0683 | 97.68% | - |
56
+ | 11 | 0.0655 | 97.83% | - |
57
+ | 12 | 0.0548 | 98.18% | - |
58
+ | 13 | 0.0513 | 98.27% | - |
59
+ | 14 | 0.0461 | 98.49% | - |
60
+ | 15 | 0.0470 | 98.41% | **88.60%** |
61
+
62
+ > **Note**: Validation accuracy was computed at the end of training (final epoch).
63
+
64
+ ---
65
+
66
+ ## Usage
67
+
68
+ Below is a sample usage snippet in Python. **Replace** `username/model_repo_name` with the actual model repo id on Hugging Face.
69
+
70
+ ```python
71
+ import torch
72
+ import torch.nn as nn
73
+ from torchvision import models, transforms
74
+ from huggingface_hub import hf_hub_download
75
+ from PIL import Image
76
+
77
+ # Download the weights from the Hugging Face Hub
78
+ ckpt_path = hf_hub_download(repo_id="username/model_repo_name", filename="cnn_model.pth")
79
+
80
+ # Define the same model architecture
81
+ model = models.resnet18(pretrained=False)
82
+ model.fc = nn.Linear(model.fc.in_features, 10) # for CIFAR-10
83
+ model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
84
+ model.eval()
85
+
86
+ # Define transforms
87
+ transform = transforms.Compose([
88
+ transforms.Resize((128, 128)),
89
+ transforms.ToTensor(),
90
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
91
+ ])
92
+
93
+ # Example inference
94
+ image = Image.open("your_image.jpg").convert("RGB")
95
+ input_tensor = transform(image).unsqueeze(0) # add batch dimension
96
+ with torch.no_grad():
97
+ logits = model(input_tensor)
98
+ predicted_class = logits.argmax(dim=1).item()
99
+
100
+ print("Predicted class ID:", predicted_class)