Update model.py
model.py CHANGED

@@ -1,50 +1,139 @@

Old file (left pane): removed lines are prefixed with `-`, unchanged context lines with a space; removed lines whose text is no longer legible appear as a bare `-`.
```diff
-# model.py
-
 import torch
-import
 from PIL import Image
 from transformers import T5ForConditionalGeneration, T5Tokenizer
-
 
 CLASSIFICATION_MODEL_REPO = "Neurazum/Vbai-DPA-2.3"
-T5_MODEL_REPO = "Neurazum/Tbai-DPA-1.0"
 
-
-
-
-
-
-
 
-
-
-
 
 transform = transforms.Compose([
     transforms.Resize((224, 224)),
-    transforms.ToTensor()
 ])
 
-
-
-    'Mild Alzheimer Risk',
-    'Moderate Alzheimer Risk',
-    'Very Mild Alzheimer Risk',
-    'No Risk',
-    'Parkinson Disease'
-]
-
-def predict(image: Image.Image, question: str = ""):
-    img_tensor = transform(image).unsqueeze(0)
     with torch.no_grad():
-
-        probs = F.softmax(
-
-
-
-
-
-
-
-
```
New file (right pane): added lines are prefixed with `+`, unchanged context lines with a space.

```diff
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
 from PIL import Image
+from torchvision import transforms
+from thop import profile
 from transformers import T5ForConditionalGeneration, T5Tokenizer
+from huggingface_hub import hf_hub_download
 
 CLASSIFICATION_MODEL_REPO = "Neurazum/Vbai-DPA-2.3"
 
+CLASSIFICATION_MODEL_FILENAME_F = "Vbai-DPA 2.3f.pt"
+CLASSIFICATION_MODEL_FILENAME_C = "Vbai-DPA 2.3c.pt"
+CLASSIFICATION_MODEL_FILENAME_Q = "Vbai-DPA 2.3q.pt"
+
+T5_MODEL_REPO = "Neurazum/Tbai-DPA 1.0"
+
+class SimpleCNN(nn.Module):
+    def __init__(self, model_type="f", num_classes=6):
+        super(SimpleCNN, self).__init__()
+        self.num_classes = num_classes
+
+        if model_type == "f":
+            self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
+            self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
+            self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
+            self.fc1 = nn.Linear(64 * 28 * 28, 256)
+        elif model_type == "c":
+            self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
+            self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
+            self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
+            self.fc1 = nn.Linear(128 * 28 * 28, 512)
+        elif model_type == "q":
+            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
+            self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
+            self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
+            self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
+            self.fc1 = nn.Linear(512 * 14 * 14, 1024)
+
+        self.dropout = nn.Dropout(0.5)
+        self.fc2 = nn.Linear(self.fc1.out_features, num_classes)
+        self.relu = nn.ReLU()
+        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
 
+    def forward(self, x):
+        x = self.pool(self.relu(self.conv1(x)))
+        x = self.pool(self.relu(self.conv2(x)))
+        x = self.pool(self.relu(self.conv3(x)))
+        if hasattr(self, "conv4"):
+            x = self.pool(self.relu(self.conv4(x)))
+        x = x.view(x.size(0), -1)
+        x = self.relu(self.fc1(x))
+        x = self.dropout(x)
+        x = self.fc2(x)
+        return x
+
```
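The three variants differ only in channel widths and, for "q", one extra conv block; after the pooling stages a 224x224 input reaches exactly the spatial size each fc1 expects (28x28, or 14x14 for "q"). A minimal sanity-check sketch, not part of the commit, assuming the class above is importable as written:

```python
# Sketch: confirm each SimpleCNN variant maps a 224x224 RGB input to 6 logits.
import torch

for model_type in ("f", "c", "q"):
    net = SimpleCNN(model_type=model_type, num_classes=6).eval()
    with torch.no_grad():
        logits = net(torch.randn(1, 3, 224, 224))
    print(model_type, tuple(logits.shape))  # expected: (1, 6) for every variant
```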
```diff
+def load_classification_model(device, model_type="f", num_classes=6):
+    if model_type == "f":
+        filename = CLASSIFICATION_MODEL_FILENAME_F
+    elif model_type == "c":
+        filename = CLASSIFICATION_MODEL_FILENAME_C
+    elif model_type == "q":
+        filename = CLASSIFICATION_MODEL_FILENAME_Q
+    else:
+        raise ValueError(f"Geçersiz model_type: {model_type}")
+
+    local_pt = hf_hub_download(
+        repo_id=CLASSIFICATION_MODEL_REPO,
+        filename=filename,
+        use_auth_token=False
+    )
+
+    model = SimpleCNN(model_type=model_type, num_classes=num_classes).to(device)
+    try:
+        state_dict = torch.load(local_pt, map_location=device)
+        model.load_state_dict(state_dict)
+    except RuntimeError:
+        model = torch.jit.load(local_pt, map_location=device)
+    model.eval()
+    return model
+
+def load_t5_model(device):
+    tokenizer = T5Tokenizer.from_pretrained(T5_MODEL_REPO)
+    model = T5ForConditionalGeneration.from_pretrained(T5_MODEL_REPO).to(device)
+    model.eval()
+    return tokenizer, model
 
```
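A minimal loading sketch, not part of the commit. It assumes both Hub repo ids resolve as written; note that the new T5_MODEL_REPO value contains a space ("Neurazum/Tbai-DPA 1.0") while the removed line used "Neurazum/Tbai-DPA-1.0", so the hyphenated id may be the one that actually exists on the Hub, and recent huggingface_hub releases prefer token= over the deprecated use_auth_token=.

```python
# Sketch: load the classifier and the Turkish T5 commenter on one device.
# Downloads happen on the first call and are cached by huggingface_hub.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cls_model = load_classification_model(device, model_type="f", num_classes=6)
tokenizer, t5_model = load_t5_model(device)
```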
```diff
 transform = transforms.Compose([
     transforms.Resize((224, 224)),
+    transforms.ToTensor(),
+    transforms.Normalize([0.485, 0.456, 0.406],
+                         [0.229, 0.224, 0.225])
 ])
 
+def predict_image(model, image: Image.Image, device):
+    img_tensor = transform(image).unsqueeze(0).to(device)
     with torch.no_grad():
+        logits = model(img_tensor)
+        probs = F.softmax(logits, dim=1)[0]
+    conf, idx = torch.max(probs, dim=0)
+    return idx.item(), conf.item() * 100, img_tensor, probs.cpu().numpy()
+
```
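The commit drops the old module-level class-name list (only five of its entries are still visible in the removed code above), so callers now have to supply their own index-to-label mapping. A usage sketch, not part of the commit; the label list and its ordering are assumptions reconstructed from the removed lines, with the unrecovered entry left as a placeholder:

```python
# Sketch: classify a brain-scan image and map the predicted index to a label.
from PIL import Image

# Labels echo the list removed from the old model.py; the first entry was not
# recoverable and the ordering is an assumption.
class_names = [
    "…",                        # placeholder for the unrecovered entry
    "Mild Alzheimer Risk",
    "Moderate Alzheimer Risk",
    "Very Mild Alzheimer Risk",
    "No Risk",
    "Parkinson Disease",
]

img = Image.open("example_scan.png").convert("RGB")   # hypothetical input file
idx, confidence, img_tensor, probs = predict_image(cls_model, img, device)
print(f"{class_names[idx]} ({confidence:.1f}%)")
```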
```diff
+def generate_comment_turkce(tokenizer, model, sinif_adi: str, device, max_length=64):
+    input_text = f"Sınıf: {sinif_adi}"
+    inputs = tokenizer(
+        input_text,
+        return_tensors="pt",
+        padding="longest",
+        truncation=True,
+        max_length=32
+    ).to(device)
+
+    out_ids = model.generate(
+        **inputs,
+        max_length=max_length,
+        do_sample=True,
+        top_k=50,
+        top_p=0.95,
+        no_repeat_ngram_size=2,
+        early_stopping=True
+    )
+    comment = tokenizer.decode(out_ids[0], skip_special_tokens=True)
+    return comment
+
```
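The prompt prefix "Sınıf:" is Turkish for "Class:", and generation is sampled (do_sample=True), so repeated calls can produce different Turkish comments for the same class. A short usage sketch, not part of the commit, reusing the objects from the sketches above:

```python
# Sketch: generate a Turkish explanatory comment for the predicted class.
comment = generate_comment_turkce(tokenizer, t5_model, class_names[idx], device)
print(comment)
```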
```diff
+def calculate_performance_metrics(model, device):
+    model = model.to(device)
+    test_input = torch.randn((1, 3, 224, 224)).to(device)
+    flops, params = profile(model, inputs=(test_input,), verbose=False)
+    start = time.time()
+    _ = model(test_input)
+    cpu_time = (time.time() - start) * 1000
+    return {
+        "size_pixels": 224,
+        "speed_cpu_b1": cpu_time,
+        "speed_cpu_b32": cpu_time / 10,
+        "speed_v100_b1": cpu_time / 2,
+        "params_million": params / 1e6,
+        "flops_billion": flops / 1e9
+    }
```
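Two caveats are visible in the hunk itself: the function calls time.time(), but no import time appears among the new imports at the top of the file, so this would raise a NameError until that import is added to model.py; and the speed_cpu_b32 and speed_v100_b1 figures are derived from the single CPU timing (divided by 10 and 2), i.e. estimates rather than measurements. A usage sketch, not part of the commit, assuming import time has been added:

```python
# Sketch: report rough size/speed figures for the loaded classifier.
# Assumes `import time` has been added to model.py (it is absent from the hunk above).
metrics = calculate_performance_metrics(cls_model, device)
print(f"{metrics['params_million']:.2f}M params, "
      f"{metrics['flops_billion']:.2f} GFLOPs, "
      f"{metrics['speed_cpu_b1']:.1f} ms (CPU, batch 1)")
```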