eyupipler committed
Commit f2bace7 · verified · 1 Parent(s): 5f89fef

Update model.py

Files changed (1):
  1. model.py +126 -37
model.py CHANGED
@@ -1,50 +1,139 @@
-# model.py
-
 import torch
-import torchvision.transforms as transforms
+import torch.nn as nn
+import torch.nn.functional as F
 from PIL import Image
+from torchvision import transforms
+from thop import profile
 from transformers import T5ForConditionalGeneration, T5Tokenizer
-import torch.nn.functional as F
+from huggingface_hub import hf_hub_download
+import time  # used by calculate_performance_metrics
 
 CLASSIFICATION_MODEL_REPO = "Neurazum/Vbai-DPA-2.3"
-T5_MODEL_REPO = "Neurazum/Tbai-DPA-1.0"
 
-classification_model = torch.hub.load_state_dict_from_url(
-    f"https://huggingface.co/Neurazum/Vbai-DPA-2.3/blob/main/Vbai-DPA%202.3c.pt",
-    map_location=torch.device('cpu')
-)
-classification_model = torch.jit.load("Vbai-DPA 2.3c.pt", map_location="cpu")
-classification_model.eval()
-
-t5_tokenizer = T5Tokenizer.from_pretrained(T5_MODEL_REPO)
-t5_model = T5ForConditionalGeneration.from_pretrained(T5_MODEL_REPO)
-t5_model.eval()
+CLASSIFICATION_MODEL_FILENAME_F = "Vbai-DPA 2.3f.pt"
+CLASSIFICATION_MODEL_FILENAME_C = "Vbai-DPA 2.3c.pt"
+CLASSIFICATION_MODEL_FILENAME_Q = "Vbai-DPA 2.3q.pt"
+
+T5_MODEL_REPO = "Neurazum/Tbai-DPA-1.0"
+
+class SimpleCNN(nn.Module):
+    def __init__(self, model_type="f", num_classes=6):
+        super(SimpleCNN, self).__init__()
+        self.num_classes = num_classes
+
+        if model_type == "f":
+            self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
+            self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
+            self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
+            self.fc1 = nn.Linear(64 * 28 * 28, 256)    # 224 / 2**3 = 28 after three 2x2 pools
+        elif model_type == "c":
+            self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
+            self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
+            self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
+            self.fc1 = nn.Linear(128 * 28 * 28, 512)
+        elif model_type == "q":
+            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
+            self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
+            self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
+            self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
+            self.fc1 = nn.Linear(512 * 14 * 14, 1024)  # 224 / 2**4 = 14 after four 2x2 pools
+
+        self.dropout = nn.Dropout(0.5)
+        self.fc2 = nn.Linear(self.fc1.out_features, num_classes)
+        self.relu = nn.ReLU()
+        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
 
+    def forward(self, x):
+        x = self.pool(self.relu(self.conv1(x)))
+        x = self.pool(self.relu(self.conv2(x)))
+        x = self.pool(self.relu(self.conv3(x)))
+        if hasattr(self, "conv4"):  # only the "q" variant has a fourth block
+            x = self.pool(self.relu(self.conv4(x)))
+        x = x.view(x.size(0), -1)
+        x = self.relu(self.fc1(x))
+        x = self.dropout(x)
+        x = self.fc2(x)
+        return x
+
+def load_classification_model(device, model_type="f", num_classes=6):
+    if model_type == "f":
+        filename = CLASSIFICATION_MODEL_FILENAME_F
+    elif model_type == "c":
+        filename = CLASSIFICATION_MODEL_FILENAME_C
+    elif model_type == "q":
+        filename = CLASSIFICATION_MODEL_FILENAME_Q
+    else:
+        raise ValueError(f"Invalid model_type: {model_type}")
+
+    local_pt = hf_hub_download(
+        repo_id=CLASSIFICATION_MODEL_REPO,
+        filename=filename,
+        use_auth_token=False
+    )
+
+    model = SimpleCNN(model_type=model_type, num_classes=num_classes).to(device)
+    try:
+        # Try a plain state dict first; fall back to a TorchScript archive.
+        state_dict = torch.load(local_pt, map_location=device)
+        model.load_state_dict(state_dict)
+    except RuntimeError:
+        model = torch.jit.load(local_pt, map_location=device)
+    model.eval()
+    return model
+
+def load_t5_model(device):
+    tokenizer = T5Tokenizer.from_pretrained(T5_MODEL_REPO)
+    model = T5ForConditionalGeneration.from_pretrained(T5_MODEL_REPO).to(device)
+    model.eval()
+    return tokenizer, model
 
 transform = transforms.Compose([
     transforms.Resize((224, 224)),
-    transforms.ToTensor()
+    transforms.ToTensor(),
+    transforms.Normalize([0.485, 0.456, 0.406],
+                         [0.229, 0.224, 0.225])
 ])
 
-class_names = [
-    'Alzheimer Disease',
-    'Mild Alzheimer Risk',
-    'Moderate Alzheimer Risk',
-    'Very Mild Alzheimer Risk',
-    'No Risk',
-    'Parkinson Disease'
-]
-
-def predict(image: Image.Image, question: str = ""):
-    img_tensor = transform(image).unsqueeze(0)
+def predict_image(model, image: Image.Image, device):
+    img_tensor = transform(image).unsqueeze(0).to(device)
     with torch.no_grad():
-        output = classification_model(img_tensor)
-        probs = F.softmax(output, dim=1)[0]
-        confidence, pred_idx = torch.max(probs, dim=0)
-    prediction = class_names[pred_idx.item()]
-
-    input_text = f"Input: {prediction}. Question: {question if question else 'Durum hakkında tıbbi yorum yap'}"
-    t5_input = t5_tokenizer.encode(input_text, return_tensors="pt")
-    t5_output = t5_model.generate(t5_input, max_length=50)
-    comment = t5_tokenizer.decode(t5_output[0], skip_special_tokens=True)
-
-    return prediction, confidence.item(), comment
+        logits = model(img_tensor)
+        probs = F.softmax(logits, dim=1)[0]
+        conf, idx = torch.max(probs, dim=0)
+    return idx.item(), conf.item() * 100, img_tensor, probs.cpu().numpy()
+
+def generate_comment_turkce(tokenizer, model, sinif_adi: str, device, max_length=64):
+    input_text = f"Sınıf: {sinif_adi}"
+    inputs = tokenizer(
+        input_text,
+        return_tensors="pt",
+        padding="longest",
+        truncation=True,
+        max_length=32
+    ).to(device)
+
+    out_ids = model.generate(
+        **inputs,
+        max_length=max_length,
+        do_sample=True,
+        top_k=50,
+        top_p=0.95,
+        no_repeat_ngram_size=2,
+        early_stopping=True
+    )
+    comment = tokenizer.decode(out_ids[0], skip_special_tokens=True)
+    return comment
+
+def calculate_performance_metrics(model, device):
+    model = model.to(device)
+    test_input = torch.randn((1, 3, 224, 224)).to(device)
+    flops, params = profile(model, inputs=(test_input,), verbose=False)
+    start = time.time()
+    _ = model(test_input)
+    cpu_time = (time.time() - start) * 1000
+    # b32/V100 figures are rough estimates scaled from the single CPU pass.
+    return {
+        "size_pixels": 224,
+        "speed_cpu_b1": cpu_time,
+        "speed_cpu_b32": cpu_time / 10,
+        "speed_v100_b1": cpu_time / 2,
+        "params_million": params / 1e6,
+        "flops_billion": flops / 1e9
+    }