class ViTBinaryClassifier(nn.Module):
    """Binary classifier: a timm ViT backbone followed by an MLP head.

    The backbone's own classification head is replaced with ``nn.Identity`` so
    it emits features. Those features go through an MLP head. For each hidden
    width the head applies Linear -> BatchNorm1d -> ReLU -> Dropout, and then
    a final Linear to one unit followed by a Sigmoid.

    Note: ``nn.BatchNorm1d`` raises an error in training mode when a batch
    contains a single sample. Use batch sizes > 1 (or ``drop_last=True`` in the
    DataLoader) while training.
    """

    def __init__(
        self,
        pretrained=True,
        *,
        model_name="vit_medium_patch16_224",
        hidden_dims=(512, 128),
        dropout=0.3,
    ):
        """Build the backbone and the classification head.

        Args:
            pretrained: Passed to ``timm.create_model`` to request pretrained
                backbone weights.
            model_name: timm model identifier for the backbone. The backbone
                is assumed to expose ``.head`` as a layer with
                ``in_features`` (true for the original ViT choice).
            hidden_dims: Widths of the hidden MLP layers, in order. The
                defaults reproduce the original 512 -> 128 head.
            dropout: Dropout probability used after every hidden layer.
        """
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained)
        # Read the feature width from the original head before discarding it.
        in_features = self.backbone.head.in_features
        self.backbone.head = nn.Identity()
        self.classifier = self._build_classifier(in_features, hidden_dims, dropout)

    @staticmethod
    def _build_classifier(in_features, hidden_dims, dropout):
        """Return the MLP head as an ``nn.Sequential`` ending in Sigmoid.

        The layer order (and therefore the ``state_dict`` indices) matches the
        original hand-written head for the default ``hidden_dims``.
        """
        layers = []
        prev = in_features
        for width in hidden_dims:
            layers += [
                nn.Linear(prev, width),
                nn.BatchNorm1d(width),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]
            prev = width
        layers += [nn.Linear(prev, 1), nn.Sigmoid()]
        return nn.Sequential(*layers)

    def forward(self, x, *, return_logits=False):
        """Classify a batch of images.

        Args:
            x: Input batch for the backbone. It is presumably
                ``(batch, 3, H, W)`` at the resolution the chosen timm model
                expects (TODO confirm against the data pipeline).
            return_logits: If True, skip the final Sigmoid and return raw
                logits. Logits are what ``nn.BCEWithLogitsLoss`` expects. That
                loss is more numerically stable than ``BCELoss`` on
                probabilities, and it works under AMP autocast, which rejects
                ``BCELoss``.

        Returns:
            Tensor of shape ``(batch, 1)``: probabilities in ``(0, 1)`` by
            default, or logits when ``return_logits`` is True.
        """
        features = self.backbone(x)
        if return_logits:
            # Every layer except the trailing Sigmoid.
            return self.classifier[:-1](features)
        return self.classifier(features)