File size: 3,924 Bytes
dabc064
 
 
16bb15b
dabc064
f3e3032
16bb15b
f3e3032
dabc064
 
 
 
 
 
 
 
 
 
 
72c85cf
dabc064
 
 
 
 
 
 
 
 
 
 
 
72c85cf
dabc064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72c85cf
dabc064
 
 
 
 
f3e3032
dabc064
e24f336
dabc064
 
 
e24f336
f3e3032
dabc064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16bb15b
e0843a5
e67dd2a
 
16bb15b
 
dabc064
16bb15b
e67dd2a
e461067
 
16bb15b
 
e461067
dabc064
e461067
16bb15b
e461067
dabc064
e67dd2a
e461067
e67dd2a
e205224
e461067
dabc064
 
 
16bb15b
 
 
e67dd2a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
import warnings
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt

from PIL import Image

from model import (
    load_classification_model,
    load_t5_model,
    predict_image,
    generate_comment_turkce,
    calculate_performance_metrics
)

# Suppress FutureWarning and UserWarning messages for the whole process.
for _ignored_category in (FutureWarning, UserWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)


# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One classifier per selectable variant, loaded in the order f, c, q.
cnn_model_f, cnn_model_c, cnn_model_q = (
    load_classification_model(device, model_type=variant, num_classes=6)
    for variant in ("f", "c", "q")
)

# Tokenizer/model pair used to generate the textual comment.
t5_tokenizer, t5_model = load_t5_model(device)

# Performance metrics are computed once at start-up for every classifier,
# in the same f, c, q order.
perf_metrics_f, perf_metrics_c, perf_metrics_q = (
    calculate_performance_metrics(classifier, device)
    for classifier in (cnn_model_f, cnn_model_c, cnn_model_q)
)


# English class label -> Turkish display name.
# Insertion order matters: it fixes the position of every label in
# ``class_names_en``, which is indexed by the predicted class index.
en2tr = {
    "Alzheimer Disease": "Alzheimer Hastalığı",
    "Mild Alzheimer Risk": "Hafif Alzheimer Riski",
    "Moderate Alzheimer Risk": "Orta Düzey Alzheimer Riski",
    "Very Mild Alzheimer Risk": "Çok Hafif Alzheimer Riski",
    "No Risk": "Risk Yok",
    "Parkinson Disease": "Parkinson Hastalığı",
}

# English class labels in class-index order (the keys of ``en2tr``).
class_names_en = list(en2tr.keys())


def gradio_predict(image, model_type, question):
    """Classify an uploaded image and generate a comment for the Gradio UI.

    Args:
        image: The uploaded image as delivered by the image input
            component, or ``None`` when nothing was uploaded.
        model_type: Selected classifier variant. ``"f"`` and ``"c"`` pick
            the matching model; any other value (including no selection)
            falls back to the ``q`` model.
        question: Optional free-text question. When it is empty or only
            whitespace, the comment is produced by
            ``generate_comment_turkce`` from the predicted class alone;
            otherwise the question is appended to the class in the prompt
            passed to the T5 model.

    Returns:
        A 3-tuple ``(img_show, tahmin_metni, comment)``:
        the de-normalised input image as a NumPy array clipped to
        ``[0, 1]``, the prediction line (English class name plus
        confidence), and the generated comment text.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        # Gradio passes None when the form is submitted without an upload;
        # show a readable message instead of failing inside predict_image.
        raise gr.Error("Lütfen önce bir MRI görüntüsü yükleyin.")

    if model_type == "f":
        cnn_model = cnn_model_f
    elif model_type == "c":
        cnn_model = cnn_model_c
    else:
        cnn_model = cnn_model_q

    # The fourth return value (per-class probabilities) is not needed here.
    idx, conf, inp_tensor, _ = predict_image(cnn_model, image, device)
    pred_en = class_names_en[idx]
    pred_tr = en2tr[pred_en]

    if not question or not question.strip():
        comment = generate_comment_turkce(t5_tokenizer, t5_model, pred_tr, device)
    else:
        input_text = f"Sınıf: {pred_tr}. Soru: {question}"
        inputs = t5_tokenizer(
            input_text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=64
        ).to(device)
        # NOTE(review): early_stopping is documented as a beam-search
        # option; with sampling and the default beam count it is
        # presumably ignored - confirm before relying on it.
        out_ids = t5_model.generate(
            **inputs,
            max_length=64,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            no_repeat_ngram_size=2,
            early_stopping=True
        )
        comment = t5_tokenizer.decode(out_ids[0], skip_special_tokens=True)

    # Drop the leading axis and move the first remaining axis last so the
    # per-channel constants below broadcast over the last dimension.
    inp_np = inp_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # Presumably the same mean/std that predict_image used to normalise the
    # input - TODO confirm against model.py.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img_show = np.clip(inp_np * std + mean, 0, 1)

    # Keep the class name and the confidence visibly separated; they were
    # previously concatenated with nothing in between.
    # NOTE(review): assumes conf is already expressed as a percentage.
    tahmin_metni = f"Tahmin: {pred_en} — Güven: {conf:.2f}%"

    return img_show, tahmin_metni, comment

# Custom stylesheet: rounded corners for every <img> element in the page.
css = "img {border-radius: 8px;}"

# Interface wiring: three inputs (image, model variant, optional question)
# feed ``gradio_predict``; its three return values fill the three outputs.
demo = gr.Interface(
    fn=gradio_predict,
    inputs=[
        gr.Image(
            type="pil",
            label="MRI Görüntüsü Yükleyin",
        ),
        gr.Radio(
            choices=["f", "c", "q"],
            label="Vbai Model Tipi Seçin",
        ),
        gr.Textbox(
            label="Görselle ilgili soru sorun. (İsteğe bağlı)",
            placeholder="Örnek: Bu hasta tedaviye ihtiyaç duyar mı?",
        ),
    ],
    outputs=[
        gr.Image(
            type="numpy",
            label="İşlenmiş Görsel",
        ),
        gr.Textbox(label="Tahmin ve Güven"),
        gr.Textbox(label="Tbai Yorumu"),
    ],
    title="🧠 Vbai-DPA 2.3 + Tbai-DPA 1.0-BETA Yorumu",
    # Numbered usage steps, one per line.
    description="\n".join(
        [
            "1) MRI görüntünüzü yükleyin.",
            "2) Vbai için f, c veya q modellerinden birini seçin.",
            "3) İsterseniz görselle ilgili kısa bir soru girin.",
            "4) Tahmin ve Tbai tabanlı yorum ekranda gösterilecek.",
            "5) NOT: Sonuçlar yapay zeka tarafından tahmin edildiğinden, doktorunuza ulaşmanız gerekmektedir!",
        ]
    ),
    theme="default",
    css=css,
)

# Start the web UI only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()