# Source: iit_fakerv1.0 / fakeimg / encoder.py
# (Hugging Face upload by Pingsz, "Upload 15 files", commit 68286c7, 1.07 kB;
#  tags: timm, English)
import glob
import os

import cv2
import numpy as np
import torch
import torch as th
from torch.utils.data import Dataset, DataLoader
class CustomDataset(Dataset):
    """Image dataset laid out as ``<v_path>/<class_name>/*.jpg``.

    Every entry directly under ``v_path`` is treated as a class directory whose
    name must be a key of ``self.class_map`` ({"Fake": 0, "Real": 1}); only
    ``*.jpg`` files inside those directories are indexed.

    Each item is ``(img_tensor, class_id)``:
      * ``img_tensor`` - ``torch.uint8`` tensor of shape ``(3, H, W)`` in the
        channel order produced by ``cv2.imread`` (BGR), resized to ``img_dim``.
      * ``class_id`` - ``torch.int64`` tensor of shape ``(1,)``.
    """

    def __init__(self, v_path="v3", *, img_dim=(416, 416), verbose=False):
        """Index all ``*.jpg`` files below the class sub-directories of *v_path*.

        Args:
            v_path: Dataset root directory; a trailing path separator is optional.
            img_dim: ``(width, height)`` target size passed to ``cv2.resize``.
            verbose: When True, print the discovered class directories and the
                resulting ``[img_path, class_name]`` index.
        """
        self.imgs_path = v_path
        # os.path.join makes "v3" and "v3/" equivalent. The previous
        # ``v_path + "*"`` matched the root directory itself unless the caller
        # supplied a trailing slash. sorted() fixes the sample order, which
        # glob otherwise leaves to the filesystem.
        file_list = sorted(glob.glob(os.path.join(self.imgs_path, "*")))
        if verbose:
            print(file_list)
        self.data = []
        for class_path in file_list:
            # basename is separator-agnostic; split("/") breaks on Windows paths.
            class_name = os.path.basename(class_path)
            for img_path in sorted(glob.glob(os.path.join(class_path, "*.jpg"))):
                self.data.append([img_path, class_name])
        if verbose:
            print(self.data)
        self.class_map = {"Fake": 0, "Real": 1}
        self.img_dim = tuple(img_dim)

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load, resize and return image *idx* with its integer class label.

        Raises:
            OSError: if OpenCV cannot read the image file.
            KeyError: if the image's directory name is not in ``class_map``.
        """
        img_path, class_name = self.data[idx]
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports failure by returning None rather than raising;
            # surface the offending path instead of an opaque cv2.resize error.
            raise OSError(f"could not read image: {img_path}")
        img = cv2.resize(img, self.img_dim)
        class_id = self.class_map[class_name]
        # HWC (OpenCV layout) -> CHW (PyTorch layout).
        img_tensor = torch.from_numpy(img).permute(2, 0, 1)
        return img_tensor, torch.tensor([class_id])