File size: 1,708 Bytes
549beca
 
e2c4c7d
549beca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# model.py

import torch
import torchvision.transforms as transforms
from PIL import Image
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch.nn.functional as F

CLASSIFICATION_MODEL_REPO = "Neurazum/Vbai-DPA-2.3"
T5_MODEL_REPO = "Neurazum/Tbai-DPA-1.0"

# Local TorchScript archive for the image classifier. The path is relative, so
# it is resolved against the process's current working directory.
CLASSIFICATION_MODEL_PATH = "Vbai-DPA 2.3c.pt"

# The classifier is loaded only from the local TorchScript file. A
# torch.hub.load_state_dict_from_url download of ``vbai_model.pt`` from
# CLASSIFICATION_MODEL_REPO used to run first, but its result was overwritten
# immediately by this load, so it only added network I/O (and an import-time
# failure when offline). CLASSIFICATION_MODEL_REPO is kept as a public name.
classification_model = torch.jit.load(CLASSIFICATION_MODEL_PATH, map_location="cpu")
classification_model.eval()  # inference mode for layers like dropout/batchnorm

# Tokenizer + seq2seq model that predict() uses to generate a text comment.
t5_tokenizer = T5Tokenizer.from_pretrained(T5_MODEL_REPO)
t5_model = T5ForConditionalGeneration.from_pretrained(T5_MODEL_REPO)
t5_model.eval()

# Classifier preprocessing: resize every PIL image to a fixed 224x224, then
# convert it to a float tensor (ToTensor scales 8-bit pixel values to [0, 1]).
# No mean/std normalization is applied here.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

# Human-readable labels for the classifier's outputs; predict() maps the
# argmax index of the softmax probabilities into this list, so order matters.
class_names = [
    "Alzheimer Disease",
    "Mild Alzheimer Risk",
    "Moderate Alzheimer Risk",
    "Very Mild Alzheimer Risk",
    "No Risk",
    "Parkinson Disease",
]

def predict(image: Image.Image, question: str = "", *, max_length: int = 50):
    """Classify an image and generate a short text comment about the result.

    The image goes through the module-level ``transform`` and
    ``classification_model``. The highest-probability label from
    ``class_names`` is then combined with ``question`` into a prompt for the
    T5 model, whose decoded output is returned as the comment.

    Args:
        image: PIL image to classify. It is passed to ``transform`` as-is
            (no ``convert("RGB")``). NOTE(review): confirm the classifier
            accepts the channel count of images callers send (e.g. grayscale
            "L" or "RGBA" PNGs).
        question: Optional question for the T5 model. When empty (or falsy),
            the default Turkish prompt "Durum hakkında tıbbi yorum yap"
            ("Make a medical comment about the condition") is used.
        max_length: ``max_length`` passed to ``t5_model.generate``. Defaults
            to 50, the value previously hard-coded.

    Returns:
        ``(prediction, confidence, comment)``: the predicted class name, its
        softmax probability as a Python float, and the decoded T5 text.
    """
    # The classifier is called on a batch containing this single image.
    batch = transform(image).unsqueeze(0)

    with torch.no_grad():
        logits = classification_model(batch)
        # Softmax across dim=1 (indices into class_names), then take the
        # only row of the batch.
        probs = F.softmax(logits, dim=1)[0]
        confidence, pred_idx = torch.max(probs, dim=0)
    prediction = class_names[pred_idx.item()]

    prompt_question = question if question else "Durum hakkında tıbbi yorum yap"
    input_text = f"Input: {prediction}. Question: {prompt_question}"
    t5_input = t5_tokenizer.encode(input_text, return_tensors="pt")
    t5_output = t5_model.generate(t5_input, max_length=max_length)
    comment = t5_tokenizer.decode(t5_output[0], skip_special_tokens=True)

    return prediction, confidence.item(), comment