|
import torch as th |
|
import glob |
|
import cv2 |
|
import numpy as np |
|
import torch |
|
from torch.utils.data import Dataset, DataLoader |
|
|
|
|
|
class CustomDataset(Dataset): |
|
def __init__(self, v_path="v3"): |
|
self.imgs_path = v_path |
|
file_list = glob.glob(self.imgs_path + "*") |
|
print(file_list) |
|
self.data = [] |
|
for class_path in file_list: |
|
class_name = class_path.split("/")[-1] |
|
for img_path in glob.glob(class_path + "/*.jpg"): |
|
self.data.append([img_path, class_name]) |
|
print(self.data) |
|
self.class_map = {"Fake" : 0, "Real": 1} |
|
self.img_dim = (416, 416) |
|
def __len__(self): |
|
return len(self.data) |
|
def __getitem__(self, idx): |
|
img_path, class_name = self.data[idx] |
|
img = cv2.imread(img_path) |
|
img = cv2.resize(img, self.img_dim) |
|
class_id = self.class_map[class_name] |
|
img_tensor = torch.from_numpy(img) |
|
img_tensor = img_tensor.permute(2, 0, 1) |
|
class_id = torch.tensor([class_id]) |
|
return img_tensor, class_id |
|
|