eyupipler commited on
Commit
549beca
·
verified ·
1 Parent(s): d1ab300

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +49 -51
model.py CHANGED
@@ -1,52 +1,50 @@
 
 
1
  import torch
2
- import torch.nn as nn
3
- from huggingface_hub import hf_hub_download
4
-
5
-
6
- class SimpleCNN(nn.Module):
7
- def __init__(self, num_classes=6):
8
- super(SimpleCNN, self).__init__()
9
- self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
10
- self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
11
- self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
12
- self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
13
- self.relu = nn.ReLU()
14
- self.dropout = nn.Dropout(0.5)
15
- self._initialize_fc(num_classes)
16
-
17
- def _initialize_fc(self, num_classes):
18
- dummy_input = torch.zeros(1, 3, 448, 448)
19
- x = self.pool(self.relu(self.conv1(dummy_input)))
20
- x = self.pool(self.relu(self.conv2(x)))
21
- x = self.pool(self.relu(self.conv3(x)))
22
- x = x.view(x.size(0), -1)
23
- flattened_size = x.shape[1]
24
- self.fc1 = nn.Linear(flattened_size, 512)
25
- self.fc2 = nn.Linear(512, num_classes)
26
-
27
- def forward(self, x):
28
- x = self.pool(self.relu(self.conv1(x)))
29
- x = self.pool(self.relu(self.conv2(x)))
30
- x = self.pool(self.relu(self.conv3(x)))
31
- x = x.view(x.size(0), -1)
32
- x = self.dropout(self.relu(self.fc1(x)))
33
- x = self.fc2(x)
34
- return x
35
-
36
- def load_model(device: str = 'cpu'):
37
- """
38
- Downloads and loads the pretrained SimpleCNN model for the 'c' version.
39
- """
40
- torch_device = torch.device(device)
41
-
42
- weights_path = hf_hub_download(
43
- repo_id="Neurazum/Vbai-DPA-2.3",
44
- filename="Vbai-DPA 2.3c.pt",
45
- repo_type="model"
46
- )
47
-
48
- model = SimpleCNN(num_classes=6).to(torch_device)
49
- state = torch.load(weights_path, map_location=torch_device)
50
- model.load_state_dict(state)
51
- model.eval()
52
- return model
 
1
+ # model.py
2
+
3
  import torch
4
+ import torchvision.transforms as transforms
5
+ from PIL import Image
6
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
7
+ import torch.nn.functional as F
8
+
9
+ CLASSIFICATION_MODEL_REPO = "Neurazum/Vbai-DPA-2.3"
10
+ T5_MODEL_REPO = "Neurazum/Tbai-DPA-1.0"
11
+
12
+ classification_model = torch.hub.load_state_dict_from_url(
13
+ f"https://huggingface.co/{CLASSIFICATION_MODEL_REPO}/resolve/main/vbai_model.pt",
14
+ map_location=torch.device('cpu')
15
+ )
16
+ classification_model = torch.jit.load("Vbai-DPA 2.3c.pt", map_location="cpu")
17
+ classification_model.eval()
18
+
19
+ t5_tokenizer = T5Tokenizer.from_pretrained(T5_MODEL_REPO)
20
+ t5_model = T5ForConditionalGeneration.from_pretrained(T5_MODEL_REPO)
21
+ t5_model.eval()
22
+
23
+ transform = transforms.Compose([
24
+ transforms.Resize((224, 224)),
25
+ transforms.ToTensor()
26
+ ])
27
+
28
+ class_names = [
29
+ 'Alzheimer Disease',
30
+ 'Mild Alzheimer Risk',
31
+ 'Moderate Alzheimer Risk',
32
+ 'Very Mild Alzheimer Risk',
33
+ 'No Risk',
34
+ 'Parkinson Disease'
35
+ ]
36
+
37
+ def predict(image: Image.Image, question: str = ""):
38
+ img_tensor = transform(image).unsqueeze(0)
39
+ with torch.no_grad():
40
+ output = classification_model(img_tensor)
41
+ probs = F.softmax(output, dim=1)[0]
42
+ confidence, pred_idx = torch.max(probs, dim=0)
43
+ prediction = class_names[pred_idx.item()]
44
+
45
+ input_text = f"Input: {prediction}. Question: {question if question else 'Durum hakkında tıbbi yorum yap'}"
46
+ t5_input = t5_tokenizer.encode(input_text, return_tensors="pt")
47
+ t5_output = t5_model.generate(t5_input, max_length=50)
48
+ comment = t5_tokenizer.decode(t5_output[0], skip_special_tokens=True)
49
+
50
+ return prediction, confidence.item(), comment