File size: 1,067 Bytes
68286c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import glob
import os

import cv2
import numpy as np
import torch
import torch as th
from torch.utils.data import Dataset, DataLoader
class CustomDataset(Dataset):
    """Image-classification dataset read from a ``<root>/<class_name>/*.jpg`` tree.

    Every immediate subdirectory of ``v_path`` that contains ``.jpg`` files is
    treated as one class; its directory name must be a key of ``class_map``.
    Items are returned as ``(image_tensor, class_id_tensor)`` pairs.

    Args:
        v_path: Root directory holding one subdirectory per class. A trailing
            path separator is optional.
        img_dim: Target size passed to ``cv2.resize`` as ``dsize``, i.e.
            ``(width, height)``. Defaults to ``(416, 416)``.
        class_map: Mapping from class-directory name to integer label.
            Defaults to ``{"Fake": 0, "Real": 1}``.
        verbose: If true, print the discovered root entries and the sample
            list after indexing (the original always printed these).

    Raises:
        ValueError: If a subdirectory containing ``.jpg`` files is not a key
            of ``class_map``.
    """

    def __init__(self, v_path="v3", img_dim=(416, 416), class_map=None,
                 verbose=False):
        self.imgs_path = v_path
        self.class_map = (
            dict(class_map) if class_map is not None else {"Fake": 0, "Real": 1}
        )
        self.img_dim = tuple(img_dim)

        # os.path.join (instead of string "+") lists the entries *inside* the
        # root whether or not v_path ends with a separator; sorting makes the
        # sample order reproducible, since glob returns entries in arbitrary order.
        file_list = sorted(glob.glob(os.path.join(self.imgs_path, "*")))
        if verbose:
            print(file_list)

        self.data = []
        for class_path in file_list:
            # basename is separator-agnostic; split("/") breaks on Windows paths.
            class_name = os.path.basename(class_path)
            img_paths = sorted(glob.glob(os.path.join(class_path, "*.jpg")))
            if not img_paths:
                # Plain files and directories without .jpg files add no samples.
                continue
            if class_name not in self.class_map:
                # Fail at construction time rather than with a KeyError deep
                # inside a training epoch.
                raise ValueError(
                    f"Class directory {class_name!r} in {self.imgs_path!r} is "
                    f"not in class_map {sorted(self.class_map)}"
                )
            self.data.extend([img_path, class_name] for img_path in img_paths)
        if verbose:
            print(self.data)

    def __len__(self):
        """Return the number of indexed image samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load, resize and return sample ``idx``.

        Returns:
            A ``(img_tensor, class_id)`` tuple: ``img_tensor`` is the resized
            image permuted to channels-first, and ``class_id`` is a 1-element
            tensor holding the label from ``class_map``.

        Raises:
            OSError: If OpenCV cannot read the image file.
        """
        img_path, class_name = self.data[idx]
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports unreadable/corrupt files by returning None
            # instead of raising; surface the offending path explicitly.
            raise OSError(f"Could not read image: {img_path}")
        img = cv2.resize(img, self.img_dim)
        class_id = self.class_map[class_name]
        # cv2.imread's default flag yields an HxWx3 (BGR) array; move the
        # channel axis first, as torch convolution layers expect.
        img_tensor = torch.from_numpy(img).permute(2, 0, 1)
        class_id = torch.tensor([class_id])
        return img_tensor, class_id
|