Pingsz committed on
Commit eaaf4f2 · verified · 1 Parent(s): d235e29

Upload 5 files

Browse files
Kaggle Notebook.ipynb ADDED
@@ -0,0 +1,821 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "metadata": {
7
+ "execution": {
8
+ "iopub.execute_input": "2025-04-04T12:54:49.334286Z",
9
+ "iopub.status.busy": "2025-04-04T12:54:49.333816Z",
10
+ "iopub.status.idle": "2025-04-04T12:54:49.973500Z",
11
+ "shell.execute_reply": "2025-04-04T12:54:49.972489Z",
12
+ "shell.execute_reply.started": "2025-04-04T12:54:49.334258Z"
13
+ },
14
+ "trusted": true
15
+ },
16
+ "outputs": [
17
+ {
18
+ "name": "stdout",
19
+ "output_type": "stream",
20
+ "text": [
21
+ "Path to dataset files: /kaggle/input/cifake-real-and-ai-generated-synthetic-images\n"
22
+ ]
23
+ }
24
+ ],
25
+ "source": [
26
+ "import kagglehub\n",
27
+ "path = kagglehub.dataset_download(\"birdy654/cifake-real-and-ai-generated-synthetic-images\")\n",
28
+ "print(\"Path to dataset files:\", path)"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": 3,
34
+ "metadata": {
35
+ "execution": {
36
+ "iopub.execute_input": "2025-04-04T12:54:51.497912Z",
37
+ "iopub.status.busy": "2025-04-04T12:54:51.497571Z",
38
+ "iopub.status.idle": "2025-04-04T12:54:51.503484Z",
39
+ "shell.execute_reply": "2025-04-04T12:54:51.502691Z",
40
+ "shell.execute_reply.started": "2025-04-04T12:54:51.497884Z"
41
+ },
42
+ "trusted": true
43
+ },
44
+ "outputs": [
45
+ {
46
+ "data": {
47
+ "text/plain": [
48
+ "'/kaggle/input/cifake-real-and-ai-generated-synthetic-images'"
49
+ ]
50
+ },
51
+ "execution_count": 3,
52
+ "metadata": {},
53
+ "output_type": "execute_result"
54
+ }
55
+ ],
56
+ "source": [
57
+ "path"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": 4,
63
+ "metadata": {
64
+ "execution": {
65
+ "iopub.execute_input": "2025-04-04T12:54:52.950084Z",
66
+ "iopub.status.busy": "2025-04-04T12:54:52.949766Z",
67
+ "iopub.status.idle": "2025-04-04T12:54:53.085522Z",
68
+ "shell.execute_reply": "2025-04-04T12:54:53.084690Z",
69
+ "shell.execute_reply.started": "2025-04-04T12:54:52.950059Z"
70
+ },
71
+ "trusted": true
72
+ },
73
+ "outputs": [
74
+ {
75
+ "name": "stdout",
76
+ "output_type": "stream",
77
+ "text": [
78
+ "\u001b[0m\u001b[01;34mtest\u001b[0m/ \u001b[01;34mtrain\u001b[0m/\n"
79
+ ]
80
+ }
81
+ ],
82
+ "source": [
83
+ "ls '/kaggle/input/cifake-real-and-ai-generated-synthetic-images'"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": null,
89
+ "metadata": {
90
+ "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
91
+ "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
92
+ "execution": {
93
+ "iopub.execute_input": "2025-04-04T12:54:54.425136Z",
94
+ "iopub.status.busy": "2025-04-04T12:54:54.424819Z",
95
+ "iopub.status.idle": "2025-04-04T12:54:54.436446Z",
96
+ "shell.execute_reply": "2025-04-04T12:54:54.435720Z",
97
+ "shell.execute_reply.started": "2025-04-04T12:54:54.425108Z"
98
+ },
99
+ "trusted": true
100
+ },
101
+ "outputs": [],
102
+ "source": [
103
+ "import os\n",
104
+ "import glob\n",
105
+ "\n",
106
+ "data_dir = str(path)\n",
107
+ "\n",
108
+ "training_dir = os.path.join(data_dir,\"train\")\n",
109
+ "if not os.path.isdir(training_dir):\n",
110
+ " os.mkdir(training_dir)\n",
111
+ "\n",
112
+ "dog_training_dir = os.path.join(training_dir,\"REAL\")\n",
113
+ "if not os.path.isdir(dog_training_dir):\n",
114
+ " os.mkdir(dog_training_dir)\n",
115
+ "\n",
116
+ "\n",
117
+ "cat_training_dir = os.path.join(training_dir,\"FAKE\")\n",
118
+ "if not os.path.isdir(cat_training_dir):\n",
119
+ " os.mkdir(cat_training_dir)\n",
120
+ "\n",
121
+ "\n",
122
+ "validation_dir = os.path.join(data_dir,\"test\")\n",
123
+ "if not os.path.isdir(validation_dir):\n",
124
+ " os.mkdir(validation_dir)\n",
125
+ "\n",
126
+ "dog_validation_dir = os.path.join(validation_dir,\"REAL\")\n",
127
+ "if not os.path.isdir(dog_validation_dir):\n",
128
+ " os.mkdir(dog_validation_dir)\n",
129
+ "\n",
130
+ "\n",
131
+ "cat_validation_dir = os.path.join(validation_dir,\"FAKE\")\n",
132
+ "if not os.path.isdir(cat_validation_dir):\n",
133
+ " os.mkdir(cat_validation_dir)"
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "execution_count": 6,
139
+ "metadata": {
140
+ "execution": {
141
+ "iopub.execute_input": "2025-04-04T12:54:57.162476Z",
142
+ "iopub.status.busy": "2025-04-04T12:54:57.162155Z",
143
+ "iopub.status.idle": "2025-04-04T12:54:57.169644Z",
144
+ "shell.execute_reply": "2025-04-04T12:54:57.168755Z",
145
+ "shell.execute_reply.started": "2025-04-04T12:54:57.162453Z"
146
+ },
147
+ "trusted": true
148
+ },
149
+ "outputs": [],
150
+ "source": [
151
+ "import shutil\n",
152
+ "\n",
153
+ "split_size = 0.80\n",
154
+ "cat_imgs_size = len(glob.glob(\"/content/data/train/FAKE*\"))\n",
155
+ "dog_imgs_size = len(glob.glob(\"/content/data/train/REAL*\"))\n",
156
+ "\n",
157
+ "for i,img in enumerate(glob.glob(\"/content/data/train/FAKE*\")):\n",
158
+ " if i < (cat_imgs_size * split_size):\n",
159
+ " shutil.move(img,cat_training_dir)\n",
160
+ " else:\n",
161
+ " shutil.move(img,cat_validation_dir)\n",
162
+ "\n",
163
+ "for i,img in enumerate(glob.glob(\"/content/data/train/REAL*\")):\n",
164
+ " if i < (dog_imgs_size * split_size):\n",
165
+ " shutil.move(img,dog_training_dir)\n",
166
+ " else:\n",
167
+ " shutil.move(img,dog_validation_dir)"
168
+ ]
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "execution_count": null,
173
+ "metadata": {
174
+ "execution": {
175
+ "iopub.execute_input": "2025-04-04T12:55:00.489855Z",
176
+ "iopub.status.busy": "2025-04-04T12:55:00.489560Z",
177
+ "iopub.status.idle": "2025-04-04T12:56:01.137810Z",
178
+ "shell.execute_reply": "2025-04-04T12:56:01.137109Z",
179
+ "shell.execute_reply.started": "2025-04-04T12:55:00.489835Z"
180
+ },
181
+ "trusted": true
182
+ },
183
+ "outputs": [],
184
+ "source": [
185
+ "import torch\n",
186
+ "import torchvision\n",
187
+ "from torchvision import datasets, transforms\n",
188
+ "\n",
189
+ "traindir = path+\"/train\"\n",
190
+ "testdir = path+\"/test\"\n",
191
+ "\n",
192
+ "train_transforms = transforms.Compose([transforms.Resize((224,224)),\n",
193
+ " transforms.ToTensor(), \n",
194
+ " torchvision.transforms.Normalize(\n",
195
+ " mean=[0.485, 0.456, 0.406],\n",
196
+ " std=[0.229, 0.224, 0.225],\n",
197
+ " ),\n",
198
+ " ])\n",
199
+ "test_transforms = transforms.Compose([transforms.Resize((224,224)),\n",
200
+ " transforms.ToTensor(),\n",
201
+ " torchvision.transforms.Normalize(\n",
202
+ " mean=[0.485, 0.456, 0.406],\n",
203
+ " std=[0.229, 0.224, 0.225],\n",
204
+ " ),\n",
205
+ " ])\n",
206
+ "\n",
207
+ "train_data = datasets.ImageFolder(traindir,transform=train_transforms)\n",
208
+ "test_data = datasets.ImageFolder(testdir,transform=test_transforms)\n",
209
+ "\n",
210
+ "trainloader = torch.utils.data.DataLoader(train_data, shuffle = True, batch_size=16)\n",
211
+ "testloader = torch.utils.data.DataLoader(test_data, shuffle = True, batch_size=16)\n"
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "code",
216
+ "execution_count": null,
217
+ "metadata": {
218
+ "execution": {
219
+ "iopub.execute_input": "2025-04-04T12:56:11.815228Z",
220
+ "iopub.status.busy": "2025-04-04T12:56:11.814667Z",
221
+ "iopub.status.idle": "2025-04-04T12:56:11.820038Z",
222
+ "shell.execute_reply": "2025-04-04T12:56:11.818970Z",
223
+ "shell.execute_reply.started": "2025-04-04T12:56:11.815175Z"
224
+ },
225
+ "trusted": true
226
+ },
227
+ "outputs": [],
228
+ "source": [
229
+ "def make_train_step(model, optimizer, loss_fn):\n",
230
+ " def train_step(x,y):\n",
231
+ " yhat = model(x)\n",
232
+ " model.train()\n",
233
+ " loss = loss_fn(yhat,y)\n",
234
+ "\n",
235
+ " loss.backward()\n",
236
+ " optimizer.step()\n",
237
+ " optimizer.zero_grad()\n",
238
+ " #optimizer.cleargrads()\n",
239
+ "\n",
240
+ " return loss\n",
241
+ " return train_step"
242
+ ]
243
+ },
244
+ {
245
+ "cell_type": "code",
246
+ "execution_count": null,
247
+ "metadata": {
248
+ "execution": {
249
+ "iopub.execute_input": "2025-04-04T12:56:24.422670Z",
250
+ "iopub.status.busy": "2025-04-04T12:56:24.422329Z",
251
+ "iopub.status.idle": "2025-04-04T12:56:24.664399Z",
252
+ "shell.execute_reply": "2025-04-04T12:56:24.663683Z",
253
+ "shell.execute_reply.started": "2025-04-04T12:56:24.422644Z"
254
+ },
255
+ "trusted": true
256
+ },
257
+ "outputs": [],
258
+ "source": [
259
+ "from torchvision import datasets, models, transforms\n",
260
+ "import torch.nn as nn\n",
261
+ "\n",
262
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
263
+ "model = models.resnet18(pretrained=True)\n",
264
+ "\n",
265
+ "for params in model.parameters():\n",
266
+ " params.requires_grad_ = False\n",
267
+ "\n",
268
+ "nr_filters = model.fc.in_features \n",
269
+ "model.fc = nn.Linear(nr_filters, 1)\n",
270
+ "\n",
271
+ "model = model.to(device)"
272
+ ]
273
+ },
274
+ {
275
+ "cell_type": "code",
276
+ "execution_count": null,
277
+ "metadata": {
278
+ "execution": {
279
+ "iopub.execute_input": "2025-04-04T12:56:26.628077Z",
280
+ "iopub.status.busy": "2025-04-04T12:56:26.627763Z",
281
+ "iopub.status.idle": "2025-04-04T12:56:26.632616Z",
282
+ "shell.execute_reply": "2025-04-04T12:56:26.631736Z",
283
+ "shell.execute_reply.started": "2025-04-04T12:56:26.628054Z"
284
+ },
285
+ "trusted": true
286
+ },
287
+ "outputs": [],
288
+ "source": [
289
+ "from torch.nn.modules.loss import BCEWithLogitsLoss\n",
290
+ "from torch.optim import lr_scheduler\n",
291
+ "\n",
292
+ "loss_fn = BCEWithLogitsLoss()\n",
293
+ "optimizer = torch.optim.Adam(model.fc.parameters()) \n",
294
+ "\n",
295
+ "train_step = make_train_step(model, optimizer, loss_fn)"
296
+ ]
297
+ },
298
+ {
299
+ "cell_type": "code",
300
+ "execution_count": null,
301
+ "metadata": {
302
+ "execution": {
303
+ "iopub.execute_input": "2025-04-04T12:56:29.126146Z",
304
+ "iopub.status.busy": "2025-04-04T12:56:29.125852Z",
305
+ "iopub.status.idle": "2025-04-04T12:57:41.905138Z",
306
+ "shell.execute_reply": "2025-04-04T12:57:41.904311Z",
307
+ "shell.execute_reply.started": "2025-04-04T12:56:29.126124Z"
308
+ },
309
+ "trusted": true
310
+ },
311
+ "outputs": [
312
+ {
313
+ "ename": "KeyboardInterrupt",
314
+ "evalue": "",
315
+ "output_type": "error",
316
+ "traceback": [
317
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
318
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
319
+ "\u001b[0;32m<ipython-input-12-fecf91831327>\u001b[0m in \u001b[0;36m<cell line: 19>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Set model to train mode\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainloader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtotal\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainloader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0mx_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_batch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mx_batch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx_batch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
320
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/tqdm/std.py\u001b[0m in \u001b[0;36m__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1179\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1180\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1181\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mobj\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miterable\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1182\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1183\u001b[0m \u001b[0;31m# Update and possibly print the progressbar.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
321
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 699\u001b[0m \u001b[0;31m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 700\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_reset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[call-arg]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 701\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 702\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 703\u001b[0m if (\n",
322
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 755\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 756\u001b[0m \u001b[0mindex\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 757\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_fetcher\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 758\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 759\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pin_memory_device\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
323
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py\u001b[0m in \u001b[0;36mfetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__getitems__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 52\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 53\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
324
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__getitems__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 52\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 53\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
325
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torchvision/datasets/folder.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, index)\u001b[0m\n\u001b[1;32m 243\u001b[0m \"\"\"\n\u001b[1;32m 244\u001b[0m \u001b[0mpath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msamples\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 245\u001b[0;31m \u001b[0msample\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 246\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[0msample\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
326
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torchvision/datasets/folder.py\u001b[0m in \u001b[0;36mdefault_loader\u001b[0;34m(path)\u001b[0m\n\u001b[1;32m 282\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0maccimage_loader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 283\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 284\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mpil_loader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 285\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 286\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
327
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torchvision/datasets/folder.py\u001b[0m in \u001b[0;36mpil_loader\u001b[0;34m(path)\u001b[0m\n\u001b[1;32m 261\u001b[0m \u001b[0;31m# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 262\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"rb\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 263\u001b[0;31m \u001b[0mimg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mImage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 264\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"RGB\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
328
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/PIL/Image.py\u001b[0m in \u001b[0;36mopen\u001b[0;34m(fp, mode, formats)\u001b[0m\n\u001b[1;32m 3478\u001b[0m \u001b[0mexclusive_fp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3479\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3480\u001b[0;31m \u001b[0mprefix\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m16\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3481\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3482\u001b[0m \u001b[0mpreinit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
329
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
330
+ ]
331
+ }
332
+ ],
333
+ "source": [
334
+ "%%capture\n",
335
+ "!pip install tqdm\n",
336
+ "\n",
337
+ "from tqdm import tqdm\n",
338
+ "import torch\n",
339
+ "\n",
340
+ "losses = []\n",
341
+ "val_losses = []\n",
342
+ "\n",
343
+ "epoch_train_losses = []\n",
344
+ "epoch_test_losses = []\n",
345
+ "\n",
346
+ "n_epochs = 10\n",
347
+ "early_stopping_tolerance = 3\n",
348
+ "early_stopping_threshold = 0.03\n",
349
+ "early_stopping_counter = 0 \n",
350
+ "\n",
351
+ "best_loss = float(\"inf\") \n",
352
+ "\n",
353
+ "for epoch in range(n_epochs):\n",
354
+ " optimizer.zero_grad()\n",
355
+ "\n",
356
+ " epoch_loss = 0\n",
357
+ " model.train() \n",
358
+ "\n",
359
+ " for i, data in tqdm(enumerate(trainloader), total=len(trainloader)):\n",
360
+ " x_batch, y_batch = data\n",
361
+ " x_batch = x_batch.to(device)\n",
362
+ " y_batch = y_batch.unsqueeze(1).float().to(device)\n",
363
+ "\n",
364
+ " loss = train_step(x_batch, y_batch)\n",
365
+ " epoch_loss += loss / len(trainloader)\n",
366
+ " losses.append(loss)\n",
367
+ "\n",
368
+ " epoch_train_losses.append(epoch_loss)\n",
369
+ " print(f\"\\nEpoch: {epoch+1}, train loss: {epoch_loss:.4f}\")\n",
370
+ "\n",
371
+ " model.eval()\n",
372
+ " with torch.no_grad():\n",
373
+ " cum_loss = 0\n",
374
+ " for x_batch, y_batch in testloader:\n",
375
+ " x_batch = x_batch.to(device)\n",
376
+ " y_batch = y_batch.unsqueeze(1).float().to(device)\n",
377
+ "\n",
378
+ " yhat = model(x_batch)\n",
379
+ " val_loss = loss_fn(yhat, y_batch)\n",
380
+ " cum_loss += val_loss.item() / len(testloader)\n",
381
+ " val_losses.append(val_loss.item())\n",
382
+ "\n",
383
+ " epoch_test_losses.append(cum_loss)\n",
384
+ " print(f\"Epoch: {epoch+1}, val loss: {cum_loss:.4f}\")\n",
385
+ "\n",
386
+ " if cum_loss < best_loss:\n",
387
+ " best_loss = cum_loss\n",
388
+ " best_model_wts = model.state_dict()\n",
389
+ " early_stopping_counter = 0\n",
390
+ " else:\n",
391
+ " early_stopping_counter += 1\n",
392
+ "\n",
393
+ " if early_stopping_counter == early_stopping_tolerance or best_loss <= early_stopping_threshold:\n",
394
+ " print(\"\\nTerminating: early stopping\")\n",
395
+ " break\n",
396
+ "\n",
397
+ "model.load_state_dict(best_model_wts)\n"
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "code",
402
+ "execution_count": null,
403
+ "metadata": {
404
+ "trusted": true
405
+ },
406
+ "outputs": [],
407
+ "source": []
408
+ },
409
+ {
410
+ "cell_type": "code",
411
+ "execution_count": null,
412
+ "metadata": {
413
+ "trusted": true
414
+ },
415
+ "outputs": [],
416
+ "source": []
417
+ },
418
+ {
419
+ "cell_type": "code",
420
+ "execution_count": 13,
421
+ "metadata": {
422
+ "execution": {
423
+ "iopub.execute_input": "2025-04-04T13:00:51.514501Z",
424
+ "iopub.status.busy": "2025-04-04T13:00:51.514093Z",
425
+ "iopub.status.idle": "2025-04-04T13:00:51.883630Z",
426
+ "shell.execute_reply": "2025-04-04T13:00:51.882714Z",
427
+ "shell.execute_reply.started": "2025-04-04T13:00:51.514470Z"
428
+ },
429
+ "trusted": true
430
+ },
431
+ "outputs": [
432
+ {
433
+ "name": "stdout",
434
+ "output_type": "stream",
435
+ "text": [
436
+ "Fri Apr 4 13:00:51 2025 \n",
437
+ "+-----------------------------------------------------------------------------------------+\n",
438
+ "| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 |\n",
439
+ "|-----------------------------------------+------------------------+----------------------+\n",
440
+ "| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
441
+ "| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
442
+ "| | | MIG M. |\n",
443
+ "|=========================================+========================+======================|\n",
444
+ "| 0 Tesla P100-PCIE-16GB Off | 00000000:00:04.0 Off | 0 |\n",
445
+ "| N/A 36C P0 32W / 250W | 929MiB / 16384MiB | 0% Default |\n",
446
+ "| | | N/A |\n",
447
+ "+-----------------------------------------+------------------------+----------------------+\n",
448
+ " \n",
449
+ "+-----------------------------------------------------------------------------------------+\n",
450
+ "| Processes: |\n",
451
+ "| GPU GI CI PID Type Process name GPU Memory |\n",
452
+ "| ID ID Usage |\n",
453
+ "|=========================================================================================|\n",
454
+ "+-----------------------------------------------------------------------------------------+\n"
455
+ ]
456
+ }
457
+ ],
458
+ "source": [
459
+ "!nvidia-smi"
460
+ ]
461
+ },
462
+ {
463
+ "cell_type": "code",
464
+ "execution_count": null,
465
+ "metadata": {
466
+ "execution": {
467
+ "iopub.execute_input": "2025-04-04T13:52:03.047692Z",
468
+ "iopub.status.busy": "2025-04-04T13:52:03.047345Z",
469
+ "iopub.status.idle": "2025-04-04T14:32:23.869395Z",
470
+ "shell.execute_reply": "2025-04-04T14:32:23.868306Z",
471
+ "shell.execute_reply.started": "2025-04-04T13:52:03.047664Z"
472
+ },
473
+ "trusted": true
474
+ },
475
+ "outputs": [
476
+ {
477
+ "name": "stdout",
478
+ "output_type": "stream",
479
+ "text": [
480
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (4.67.1)\n"
481
+ ]
482
+ },
483
+ {
484
+ "name": "stderr",
485
+ "output_type": "stream",
486
+ "text": [
487
+ "100%|██████████| 6250/6250 [07:21<00:00, 14.15it/s]\n"
488
+ ]
489
+ },
490
+ {
491
+ "name": "stdout",
492
+ "output_type": "stream",
493
+ "text": [
494
+ "\n",
495
+ "Epoch: 1, train loss: 0.3295\n",
496
+ "Epoch: 1, val loss: 0.2714\n"
497
+ ]
498
+ },
499
+ {
500
+ "name": "stderr",
501
+ "output_type": "stream",
502
+ "text": [
503
+ "100%|██████████| 6250/6250 [07:49<00:00, 13.32it/s]\n"
504
+ ]
505
+ },
506
+ {
507
+ "name": "stdout",
508
+ "output_type": "stream",
509
+ "text": [
510
+ "\n",
511
+ "Epoch: 2, train loss: 0.3302\n",
512
+ "Epoch: 2, val loss: 0.2683\n"
513
+ ]
514
+ },
515
+ {
516
+ "name": "stderr",
517
+ "output_type": "stream",
518
+ "text": [
519
+ "100%|██████████| 6250/6250 [06:44<00:00, 15.47it/s]\n"
520
+ ]
521
+ },
522
+ {
523
+ "name": "stdout",
524
+ "output_type": "stream",
525
+ "text": [
526
+ "\n",
527
+ "Epoch: 3, train loss: 0.3320\n",
528
+ "Epoch: 3, val loss: 0.2689\n"
529
+ ]
530
+ },
531
+ {
532
+ "name": "stderr",
533
+ "output_type": "stream",
534
+ "text": [
535
+ "100%|██████████| 6250/6250 [06:30<00:00, 15.99it/s]\n"
536
+ ]
537
+ },
538
+ {
539
+ "name": "stdout",
540
+ "output_type": "stream",
541
+ "text": [
542
+ "\n",
543
+ "Epoch: 4, train loss: 0.3316\n",
544
+ "Epoch: 4, val loss: 0.2745\n"
545
+ ]
546
+ },
547
+ {
548
+ "name": "stderr",
549
+ "output_type": "stream",
550
+ "text": [
551
+ "100%|████████���█| 6250/6250 [06:30<00:00, 16.01it/s]\n"
552
+ ]
553
+ },
554
+ {
555
+ "name": "stdout",
556
+ "output_type": "stream",
557
+ "text": [
558
+ "\n",
559
+ "Epoch: 5, train loss: 0.3331\n",
560
+ "Epoch: 5, val loss: 0.2716\n",
561
+ "\n",
562
+ "Terminating: early stopping\n"
563
+ ]
564
+ },
565
+ {
566
+ "data": {
567
+ "text/plain": [
568
+ "<All keys matched successfully>"
569
+ ]
570
+ },
571
+ "execution_count": 17,
572
+ "metadata": {},
573
+ "output_type": "execute_result"
574
+ }
575
+ ],
576
+ "source": [
577
+ "\n",
578
+ "!pip install tqdm\n",
579
+ "\n",
580
+ "from tqdm import tqdm\n",
581
+ "import torch\n",
582
+ "\n",
583
+ "losses = []\n",
584
+ "val_losses = []\n",
585
+ "\n",
586
+ "epoch_train_losses = []\n",
587
+ "epoch_test_losses = []\n",
588
+ "\n",
589
+ "n_epochs = 10\n",
590
+ "early_stopping_tolerance = 3\n",
591
+ "early_stopping_threshold = 0.03\n",
592
+ "early_stopping_counter = 0\n",
593
+ "\n",
594
+ "best_loss = float(\"inf\")\n",
595
+ "\n",
596
+ "for epoch in range(n_epochs):\n",
597
+ " optimizer.zero_grad() \n",
598
+ "\n",
599
+ " epoch_loss = 0\n",
600
+ " model.train()\n",
601
+ "\n",
602
+ " for i, data in tqdm(enumerate(trainloader), total=len(trainloader)):\n",
603
+ " x_batch, y_batch = data\n",
604
+ " x_batch = x_batch.to(device)\n",
605
+ " y_batch = y_batch.unsqueeze(1).float().to(device)\n",
606
+ "\n",
607
+ " loss = train_step(x_batch, y_batch) \n",
608
+ " loss_value = loss.item() \n",
609
+ "\n",
610
+ " epoch_loss += loss_value / len(trainloader)\n",
611
+ " losses.append(loss_value) \n",
612
+ " epoch_train_losses.append(epoch_loss)\n",
613
+ " print(f\"\\nEpoch: {epoch+1}, train loss: {epoch_loss:.4f}\")\n",
614
+ "\n",
615
+ " model.eval()\n",
616
+ " with torch.no_grad():\n",
617
+ " cum_loss = 0\n",
618
+ " for x_batch, y_batch in testloader:\n",
619
+ " x_batch = x_batch.to(device)\n",
620
+ " y_batch = y_batch.unsqueeze(1).float().to(device)\n",
621
+ "\n",
622
+ " yhat = model(x_batch)\n",
623
+ " val_loss = loss_fn(yhat, y_batch)\n",
624
+ " val_loss_value = val_loss.item() \n",
625
+ " cum_loss += val_loss_value / len(testloader)\n",
626
+ " val_losses.append(val_loss_value) \n",
627
+ "\n",
628
+ " epoch_test_losses.append(cum_loss)\n",
629
+ " print(f\"Epoch: {epoch+1}, val loss: {cum_loss:.4f}\")\n",
630
+ "\n",
631
+ " if cum_loss < best_loss:\n",
632
+ " best_loss = cum_loss\n",
633
+ " best_model_wts = model.state_dict()\n",
634
+ " early_stopping_counter = 0\n",
635
+ " else:\n",
636
+ " early_stopping_counter += 1\n",
637
+ "\n",
638
+ " if early_stopping_counter == early_stopping_tolerance or best_loss <= early_stopping_threshold:\n",
639
+ " print(\"\\nTerminating: early stopping\")\n",
640
+ " break\n",
641
+ "\n",
642
+ "model.load_state_dict(best_model_wts)\n"
643
+ ]
644
+ },
645
+ {
646
+ "cell_type": "code",
647
+ "execution_count": 19,
648
+ "metadata": {
649
+ "execution": {
650
+ "iopub.execute_input": "2025-04-04T14:34:04.284009Z",
651
+ "iopub.status.busy": "2025-04-04T14:34:04.283650Z",
652
+ "iopub.status.idle": "2025-04-04T14:34:04.364630Z",
653
+ "shell.execute_reply": "2025-04-04T14:34:04.363652Z",
654
+ "shell.execute_reply.started": "2025-04-04T14:34:04.283981Z"
655
+ },
656
+ "trusted": true
657
+ },
658
+ "outputs": [],
659
+ "source": [
660
+ "\n",
661
+ "torch.save(model.state_dict(), \"my_model.pth\")\n"
662
+ ]
663
+ },
664
+ {
665
+ "cell_type": "code",
666
+ "execution_count": null,
667
+ "metadata": {
668
+ "execution": {
669
+ "iopub.execute_input": "2025-04-04T14:34:16.117938Z",
670
+ "iopub.status.busy": "2025-04-04T14:34:16.117620Z",
671
+ "iopub.status.idle": "2025-04-04T14:34:16.821580Z",
672
+ "shell.execute_reply": "2025-04-04T14:34:16.820869Z",
673
+ "shell.execute_reply.started": "2025-04-04T14:34:16.117913Z"
674
+ },
675
+ "trusted": true
676
+ },
677
+ "outputs": [],
678
+ "source": [
679
+ "from safetensors.torch import save_file\n",
680
+ "\n",
681
+ "save_file(model.state_dict(), \"my_model.safetensors\")\n",
682
+ "\n",
683
+ "import h5py\n",
684
+ "\n",
685
+ "state_dict = model.state_dict()\n",
686
+ "\n",
687
+ "with h5py.File(\"my_model.h5\", \"w\") as f:\n",
688
+ " for key, tensor in state_dict.items():\n",
689
+ " f.create_dataset(key, data=tensor.cpu().numpy())\n"
690
+ ]
691
+ },
692
+ {
693
+ "cell_type": "code",
694
+ "execution_count": 23,
695
+ "metadata": {
696
+ "execution": {
697
+ "iopub.execute_input": "2025-04-04T14:40:54.022447Z",
698
+ "iopub.status.busy": "2025-04-04T14:40:54.021956Z",
699
+ "iopub.status.idle": "2025-04-04T14:40:54.031801Z",
700
+ "shell.execute_reply": "2025-04-04T14:40:54.030996Z",
701
+ "shell.execute_reply.started": "2025-04-04T14:40:54.022406Z"
702
+ },
703
+ "trusted": true
704
+ },
705
+ "outputs": [],
706
+ "source": [
707
+ "#inference\n",
708
+ "import os\n",
709
+ "import torch\n",
710
+ "from torchvision import models, transforms\n",
711
+ "from torch.utils.data import Dataset, DataLoader\n",
712
+ "from PIL import Image\n",
713
+ "import pandas as pd\n",
714
+ "\n",
715
+ "class InferenceDataset(Dataset):\n",
716
+ " def __init__(self, folder, transform):\n",
717
+ " self.paths = [os.path.join(folder, f) for f in os.listdir(folder)\n",
718
+ " if f.lower().endswith((\"png\", \"jpg\", \"jpeg\"))]\n",
719
+ " self.transform = transform\n",
720
+ "\n",
721
+ " def __len__(self):\n",
722
+ " return len(self.paths)\n",
723
+ "\n",
724
+ " def __getitem__(self, idx):\n",
725
+ " img = Image.open(self.paths[idx]).convert(\"RGB\")\n",
726
+ " return self.transform(img), self.paths[idx]\n",
727
+ "\n",
728
+ "def run_inference(image_folder, output_csv=\"predictions.csv\"):\n",
729
+ " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
730
+ "\n",
731
+ " model = models.resnet18(pretrained=True)\n",
732
+ " for p in model.parameters():\n",
733
+ " p.requires_grad = False\n",
734
+ " model.fc = torch.nn.Linear(model.fc.in_features, 1)\n",
735
+ " model = model.to(device)\n",
736
+ " model.eval()\n",
737
+ "\n",
738
+ " transform = transforms.Compose([\n",
739
+ " transforms.Resize((224, 224)),\n",
740
+ " transforms.ToTensor(),\n",
741
+ " transforms.Normalize([0.485, 0.456, 0.406],\n",
742
+ " [0.229, 0.224, 0.225])\n",
743
+ " ])\n",
744
+ "\n",
745
+ " dataset = InferenceDataset(image_folder, transform)\n",
746
+ " loader = DataLoader(dataset, batch_size=1, shuffle=False)\n",
747
+ "\n",
748
+ " results = []\n",
749
+ " with torch.no_grad():\n",
750
+ " for img, path in loader:\n",
751
+ " img = img.to(device)\n",
752
+ " pred = torch.sigmoid(model(img)).item()\n",
753
+ " label = \"REAL\" if pred >= 0.5 else \"FAKE\"\n",
754
+ " results.append({\"image_path\": path[0], \"prediction\": label, \"score\": pred})\n",
755
+ "\n",
756
+ " pd.DataFrame(results).to_csv(output_csv, index=False)\n"
757
+ ]
758
+ },
759
+ {
760
+ "cell_type": "code",
761
+ "execution_count": 24,
762
+ "metadata": {
763
+ "execution": {
764
+ "iopub.execute_input": "2025-04-04T14:41:16.510441Z",
765
+ "iopub.status.busy": "2025-04-04T14:41:16.510058Z",
766
+ "iopub.status.idle": "2025-04-04T14:41:17.143128Z",
767
+ "shell.execute_reply": "2025-04-04T14:41:17.142398Z",
768
+ "shell.execute_reply.started": "2025-04-04T14:41:16.510411Z"
769
+ },
770
+ "trusted": true
771
+ },
772
+ "outputs": [],
773
+ "source": [
774
+ "final_path = \"/kaggle/input/finald/Test datasets/Test_dataset_2\"\n",
775
+ "run_inference(final_path, \"outputdata1.csv\")\n",
776
+ "run_inference(final_path, \"outputdata2.csv\")"
777
+ ]
778
+ }
779
+ ],
780
+ "metadata": {
781
+ "kaggle": {
782
+ "accelerator": "gpu",
783
+ "dataSources": [
784
+ {
785
+ "datasetId": 3041726,
786
+ "sourceId": 5256696,
787
+ "sourceType": "datasetVersion"
788
+ },
789
+ {
790
+ "datasetId": 7049439,
791
+ "sourceId": 11276085,
792
+ "sourceType": "datasetVersion"
793
+ }
794
+ ],
795
+ "dockerImageVersionId": 30919,
796
+ "isGpuEnabled": true,
797
+ "isInternetEnabled": true,
798
+ "language": "python",
799
+ "sourceType": "notebook"
800
+ },
801
+ "kernelspec": {
802
+ "display_name": "Python 3",
803
+ "language": "python",
804
+ "name": "python3"
805
+ },
806
+ "language_info": {
807
+ "codemirror_mode": {
808
+ "name": "ipython",
809
+ "version": 3
810
+ },
811
+ "file_extension": ".py",
812
+ "mimetype": "text/x-python",
813
+ "name": "python",
814
+ "nbconvert_exporter": "python",
815
+ "pygments_lexer": "ipython3",
816
+ "version": "3.10.12"
817
+ }
818
+ },
819
+ "nbformat": 4,
820
+ "nbformat_minor": 4
821
+ }
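The notebook above saves the fine-tuned weights in three formats (my_model.pth, my_model.safetensors, my_model.h5) but never shows how to restore them; its final run_inference cell rebuilds an ImageNet-pretrained ResNet-18 with a freshly initialised fc head and predicts without loading the trained weights. Restoring the HDF5 dump is the least obvious of the three, so here is a minimal sketch, assuming the my_model.h5 file written by the h5py cell is in the working directory:

```python
# Minimal sketch (assumes "my_model.h5" as written by the notebook's h5py cell):
# rebuild a PyTorch state dict from the HDF5 datasets and load it into the
# same ResNet-18 + 1-logit-head architecture used for training.
import h5py
import torch
from torchvision import models

model = models.resnet18(pretrained=False)
model.fc = torch.nn.Linear(model.fc.in_features, 1)

with h5py.File("my_model.h5", "r") as f:
    # torch.as_tensor also handles scalar datasets such as num_batches_tracked
    state_dict = {key: torch.as_tensor(f[key][()]) for key in f.keys()}

model.load_state_dict(state_dict)
model.eval()
```

Passing a model restored this way into the inference loop, instead of the freshly built one inside run_inference, would make the exported CSV scores reflect the trained classifier.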
Kaggle Output Data 2.csv ADDED
@@ -0,0 +1,21 @@
1
+ image_path,prediction,score
2
+ /kaggle/input/finald/Test datasets/Test_dataset_2/11.png,REAL,0.5577288866043091
3
+ /kaggle/input/finald/Test datasets/Test_dataset_2/4.png,FAKE,0.3429737687110901
4
+ /kaggle/input/finald/Test datasets/Test_dataset_2/9.png,FAKE,0.4281144440174103
5
+ /kaggle/input/finald/Test datasets/Test_dataset_2/14.png,FAKE,0.42051035165786743
6
+ /kaggle/input/finald/Test datasets/Test_dataset_2/1.png,FAKE,0.4234199821949005
7
+ /kaggle/input/finald/Test datasets/Test_dataset_2/20.png,FAKE,0.4581557512283325
8
+ /kaggle/input/finald/Test datasets/Test_dataset_2/2.png,REAL,0.674893319606781
9
+ /kaggle/input/finald/Test datasets/Test_dataset_2/10.png,FAKE,0.31975361704826355
10
+ /kaggle/input/finald/Test datasets/Test_dataset_2/18.png,FAKE,0.38838866353034973
11
+ /kaggle/input/finald/Test datasets/Test_dataset_2/12.png,REAL,0.5615072846412659
12
+ /kaggle/input/finald/Test datasets/Test_dataset_2/7.png,FAKE,0.48348483443260193
13
+ /kaggle/input/finald/Test datasets/Test_dataset_2/17.png,FAKE,0.46091413497924805
14
+ /kaggle/input/finald/Test datasets/Test_dataset_2/5.png,REAL,0.5818506479263306
15
+ /kaggle/input/finald/Test datasets/Test_dataset_2/3.png,REAL,0.5805583000183105
16
+ /kaggle/input/finald/Test datasets/Test_dataset_2/16.png,FAKE,0.4133031666278839
17
+ /kaggle/input/finald/Test datasets/Test_dataset_2/8.png,REAL,0.5182304978370667
18
+ /kaggle/input/finald/Test datasets/Test_dataset_2/6.png,FAKE,0.44089677929878235
19
+ /kaggle/input/finald/Test datasets/Test_dataset_2/15.png,REAL,0.5523123741149902
20
+ /kaggle/input/finald/Test datasets/Test_dataset_2/13.png,FAKE,0.4346674084663391
21
+ /kaggle/input/finald/Test datasets/Test_dataset_2/19.png,FAKE,0.450478732585907
My Model.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:efe70c5a580d3ef3438c5038bde8374276500f45a73b8e60888bc289ecf55fff
3
+ size 44787068
My Model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:991fc24e8abd20955af8cfc521a6d05bf84dfd57be3ca1024a092df70e10d35f
3
+ size 44757412
Readme.md ADDED
@@ -0,0 +1,9 @@
1
+ # ImageDetectionModel : An AI image detection model to detect whether an image is original or AI-generated
2
+ Models are in this folder:
3
+ -> H5
4
+ -> safetensors
5
+ -> pth
6
+ ==============================
7
+
8
+ For inference, run through the
9
+ inference.py file
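The inference.py mentioned above is not among the uploaded files. As a rough, non-authoritative guide, single-image inference with the uploaded My Model.safetensors weights could look like the sketch below; the ResNet-18 backbone, 1-logit head, normalisation constants, 0.5 threshold, and REAL/FAKE label order are taken from the Kaggle notebook, while the file paths are placeholders:

```python
# Rough single-image inference sketch; "My Model.safetensors" is the weight
# file in this repository and "example.png" is a placeholder input path.
import torch
from torchvision import models, transforms
from safetensors.torch import load_file
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = models.resnet18(pretrained=False)
model.fc = torch.nn.Linear(model.fc.in_features, 1)
model.load_state_dict(load_file("My Model.safetensors"))
model = model.to(device).eval()

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

img = transform(Image.open("example.png").convert("RGB")).unsqueeze(0).to(device)
with torch.no_grad():
    score = torch.sigmoid(model(img)).item()
print("REAL" if score >= 0.5 else "FAKE", round(score, 4))
```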