Update model.py
Browse files
model.py
CHANGED
@@ -1,52 +1,50 @@
|
|
|
|
|
|
1 |
import torch
|
2 |
-
import
|
3 |
-
from
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
)
|
47 |
-
|
48 |
-
|
49 |
-
state = torch.load(weights_path, map_location=torch_device)
|
50 |
-
model.load_state_dict(state)
|
51 |
-
model.eval()
|
52 |
-
return model
|
|
|
1 |
+
# model.py
|
2 |
+
|
3 |
import torch
|
4 |
+
import torchvision.transforms as transforms
|
5 |
+
from PIL import Image
|
6 |
+
from transformers import T5ForConditionalGeneration, T5Tokenizer
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
CLASSIFICATION_MODEL_REPO = "Neurazum/Vbai-DPA-2.3"
|
10 |
+
T5_MODEL_REPO = "Neurazum/Tbai-DPA-1.0"
|
11 |
+
|
12 |
+
classification_model = torch.hub.load_state_dict_from_url(
|
13 |
+
f"https://huggingface.co/{CLASSIFICATION_MODEL_REPO}/resolve/main/vbai_model.pt",
|
14 |
+
map_location=torch.device('cpu')
|
15 |
+
)
|
16 |
+
classification_model = torch.jit.load("Vbai-DPA 2.3c.pt", map_location="cpu")
|
17 |
+
classification_model.eval()
|
18 |
+
|
19 |
+
t5_tokenizer = T5Tokenizer.from_pretrained(T5_MODEL_REPO)
|
20 |
+
t5_model = T5ForConditionalGeneration.from_pretrained(T5_MODEL_REPO)
|
21 |
+
t5_model.eval()
|
22 |
+
|
23 |
+
transform = transforms.Compose([
|
24 |
+
transforms.Resize((224, 224)),
|
25 |
+
transforms.ToTensor()
|
26 |
+
])
|
27 |
+
|
28 |
+
class_names = [
|
29 |
+
'Alzheimer Disease',
|
30 |
+
'Mild Alzheimer Risk',
|
31 |
+
'Moderate Alzheimer Risk',
|
32 |
+
'Very Mild Alzheimer Risk',
|
33 |
+
'No Risk',
|
34 |
+
'Parkinson Disease'
|
35 |
+
]
|
36 |
+
|
37 |
+
def predict(image: Image.Image, question: str = ""):
|
38 |
+
img_tensor = transform(image).unsqueeze(0)
|
39 |
+
with torch.no_grad():
|
40 |
+
output = classification_model(img_tensor)
|
41 |
+
probs = F.softmax(output, dim=1)[0]
|
42 |
+
confidence, pred_idx = torch.max(probs, dim=0)
|
43 |
+
prediction = class_names[pred_idx.item()]
|
44 |
+
|
45 |
+
input_text = f"Input: {prediction}. Question: {question if question else 'Durum hakkında tıbbi yorum yap'}"
|
46 |
+
t5_input = t5_tokenizer.encode(input_text, return_tensors="pt")
|
47 |
+
t5_output = t5_model.generate(t5_input, max_length=50)
|
48 |
+
comment = t5_tokenizer.decode(t5_output[0], skip_special_tokens=True)
|
49 |
+
|
50 |
+
return prediction, confidence.item(), comment
|
|
|
|
|
|
|
|