|
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
|
LEARNING_RATE = 0.001 |
|
WEIGHT_DECAY = 0.0001 |
|
BATCH_SIZE = 32 |
|
NUM_EPOCHS = 100 |
|
IMAGE_SIZE = 72 |
|
PATCH_SIZE = 6 |
|
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2 |
|
PROJECTION_DIM = 64 |
|
NUM_HEADS = 4 |
|
TRANSFORMER_LAYERS = 8 |
|
MLP_HEAD_UNITS = [2048, 1024] |
|
|
|
class CreatePatchesLayer(torch.nn.Module): |
|
|
|
def __init__( |
|
self, |
|
patch_size: int, |
|
strides: int, |
|
) -> None: |
|
super().__init__() |
|
self.unfold_layer = torch.nn.Unfold( |
|
kernel_size=patch_size, stride=strides |
|
) |
|
|
|
def forward(self, images: torch.Tensor) -> torch.Tensor: |
|
patched_images = self.unfold_layer(images) |
|
return patched_images.permute((0, 2, 1)) |
|
|
|
batch_of_images = next(iter(trainloader))[0][0].unsqueeze(dim=0) |
|
|
|
plt.figure(figsize=(4, 4)) |
|
image = torch.permute(batch_of_images[0], (1, 2, 0)).numpy() |
|
plt.imshow(image) |
|
plt.axis("off") |
|
plt.savefig("img.png", bbox_inches="tight", pad_inches=0) |
|
plt.clf() |
|
|
|
patch_layer = CreatePatchesLayer(patch_size=PATCH_SIZE, strides=PATCH_SIZE) |
|
patched_image = patch_layer(batch_of_images) |
|
patched_image = patched_image.squeeze() |
|
|
|
plt.figure(figsize=(4, 4)) |
|
for idx, patch in enumerate(patched_image): |
|
ax = plt.subplot(NUM_PATCHES, NUM_PATCHES, idx + 1) |
|
patch_img = torch.reshape(patch, (3, PATCH_SIZE, PATCH_SIZE)) |
|
patch_img = torch.permute(patch_img, (1, 2, 0)) |
|
plt.imshow(patch_img.numpy()) |
|
plt.axis("off") |
|
plt.savefig("patched_img.png", bbox_inches="tight", pad_inches=0) |