Update app.py
Browse files
app.py
CHANGED
@@ -1,25 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
|
|
2 |
from PIL import Image
|
3 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
|
9 |
demo = gr.Interface(
|
10 |
-
fn=
|
11 |
inputs=[
|
12 |
gr.Image(type="pil", label="MRI Görüntüsü Yükle"),
|
13 |
-
gr.
|
|
|
14 |
],
|
15 |
outputs=[
|
16 |
-
gr.
|
|
|
|
|
17 |
gr.Textbox(label="Tbai Yorumu")
|
18 |
],
|
19 |
-
title="Vbai-DPA 2.
|
20 |
-
description=
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
)
|
23 |
|
24 |
if __name__ == "__main__":
|
25 |
-
demo.launch()
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import torch
|
4 |
+
import warnings
|
5 |
+
import numpy as np
|
6 |
import gradio as gr
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
from PIL import Image
|
9 |
+
from sklearn.metrics import average_precision_score
|
10 |
+
from model import (
|
11 |
+
load_classification_model,
|
12 |
+
load_t5_model,
|
13 |
+
predict_image,
|
14 |
+
generate_comment_turkce,
|
15 |
+
calculate_performance_metrics
|
16 |
+
)
|
17 |
+
|
18 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
19 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
20 |
+
|
21 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
22 |
+
|
23 |
+
cnn_model_f = load_classification_model(device, model_type="f", num_classes=6)
|
24 |
+
cnn_model_c = load_classification_model(device, model_type="c", num_classes=6)
|
25 |
+
cnn_model_q = load_classification_model(device, model_type="q", num_classes=6)
|
26 |
+
|
27 |
+
t5_tokenizer, t5_model = load_t5_model(device)
|
28 |
+
|
29 |
+
perf_metrics_f = calculate_performance_metrics(cnn_model_f, device)
|
30 |
+
perf_metrics_c = calculate_performance_metrics(cnn_model_c, device)
|
31 |
+
perf_metrics_q = calculate_performance_metrics(cnn_model_q, device)
|
32 |
+
|
33 |
+
class_names_en = [
|
34 |
+
"Alzheimer Disease",
|
35 |
+
"Mild Alzheimer Risk",
|
36 |
+
"Moderate Alzheimer Risk",
|
37 |
+
"Very Mild Alzheimer Risk",
|
38 |
+
"No Risk",
|
39 |
+
"Parkinson Disease"
|
40 |
+
]
|
41 |
+
en2tr = {
|
42 |
+
"Alzheimer Disease": "Alzheimer Hastalığı",
|
43 |
+
"Mild Alzheimer Risk": "Hafif Alzheimer Riski",
|
44 |
+
"Moderate Alzheimer Risk": "Orta Düzey Alzheimer Riski",
|
45 |
+
"Very Mild Alzheimer Risk": "Çok Hafif Alzheimer Riski",
|
46 |
+
"No Risk": "Risk Yok",
|
47 |
+
"Parkinson Disease": "Parkinson Hastalığı"
|
48 |
+
}
|
49 |
+
|
50 |
+
def gradio_predict(image, model_type, question):
|
51 |
+
if model_type == "f":
|
52 |
+
cnn_model = cnn_model_f
|
53 |
+
perf_metrics = perf_metrics_f
|
54 |
+
elif model_type == "c":
|
55 |
+
cnn_model = cnn_model_c
|
56 |
+
perf_metrics = perf_metrics_c
|
57 |
+
elif model_type == "q":
|
58 |
+
cnn_model = cnn_model_q
|
59 |
+
perf_metrics = perf_metrics_q
|
60 |
+
else:
|
61 |
+
return None, "Hata: Geçersiz model tipi", "", ""
|
62 |
+
|
63 |
+
idx, conf, inp_tensor, all_probs = predict_image(cnn_model, image, device)
|
64 |
+
pred_en = class_names_en[idx]
|
65 |
+
pred_tr = en2tr[pred_en]
|
66 |
+
|
67 |
+
if question is None or question.strip() == "":
|
68 |
+
comment = generate_comment_turkce(t5_tokenizer, t5_model, pred_tr, device)
|
69 |
+
else:
|
70 |
+
input_text = f"Sınıf: {pred_tr}. Soru: {question}"
|
71 |
+
inputs = t5_tokenizer(
|
72 |
+
input_text,
|
73 |
+
return_tensors="pt",
|
74 |
+
padding="longest",
|
75 |
+
truncation=True,
|
76 |
+
max_length=64
|
77 |
+
).to(device)
|
78 |
+
out_ids = t5_model.generate(
|
79 |
+
**inputs,
|
80 |
+
max_length=64,
|
81 |
+
do_sample=True,
|
82 |
+
top_k=50,
|
83 |
+
top_p=0.95,
|
84 |
+
no_repeat_ngram_size=2,
|
85 |
+
early_stopping=True
|
86 |
+
)
|
87 |
+
comment = t5_tokenizer.decode(out_ids[0], skip_special_tokens=True)
|
88 |
+
|
89 |
+
num_classes = len(class_names_en)
|
90 |
+
true_onehot = np.zeros(num_classes, dtype=int)
|
91 |
+
true_onehot[idx] = 1
|
92 |
+
ap_scores = []
|
93 |
+
for cls_i in range(num_classes):
|
94 |
+
true_binary = (true_onehot == 1).astype(int) if cls_i == idx else (true_onehot == 0).astype(int)
|
95 |
+
score = all_probs[cls_i]
|
96 |
+
ap = average_precision_score(true_binary, [score])
|
97 |
+
ap_scores.append(ap)
|
98 |
+
map_score = np.mean(ap_scores)
|
99 |
+
|
100 |
+
inp_np = inp_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
|
101 |
+
mean = np.array([0.485, 0.456, 0.406])
|
102 |
+
std = np.array([0.229, 0.224, 0.225])
|
103 |
+
img_show = inp_np * std + mean
|
104 |
+
img_show = np.clip(img_show, 0, 1)
|
105 |
+
|
106 |
+
tahmin_metni = f"Tahmin: {pred_en} — {conf:.2f}%"
|
107 |
+
ap_map_metni = f"AP (sınıf {pred_en}): {ap_scores[idx]:.4f} | mAP: {map_score:.4f}"
|
108 |
+
|
109 |
+
return img_show, tahmin_metni, ap_map_metni, comment
|
110 |
|
111 |
+
css = """
|
112 |
+
img {border-radius: 8px;}
|
113 |
+
"""
|
114 |
|
115 |
demo = gr.Interface(
|
116 |
+
fn=gradio_predict,
|
117 |
inputs=[
|
118 |
gr.Image(type="pil", label="MRI Görüntüsü Yükle"),
|
119 |
+
gr.Radio(choices=["f", "c", "q"], label="Vbai Model Tipi Seçin"),
|
120 |
+
gr.Textbox(label="Görselle ilgili soru (isteğe bağlı)", placeholder="Örnek: Bu hasta tedaviye ihtiyaç duyar mı?")
|
121 |
],
|
122 |
outputs=[
|
123 |
+
gr.Image(type="numpy", label="Ön İşlenmiş Görsel"),
|
124 |
+
gr.Textbox(label="Tahmin ve Güven"),
|
125 |
+
gr.Textbox(label="AP ve mAP"),
|
126 |
gr.Textbox(label="Tbai Yorumu")
|
127 |
],
|
128 |
+
title="🧠 Vbai-DPA 2.3 | Seçilebilir Vbai (f/c/q) + Tbai Yorumu",
|
129 |
+
description=(
|
130 |
+
"1) MRI görüntüsünü yükleyin.\n"
|
131 |
+
"2) Vbai için f, c veya q modellerinden birini seçin.\n"
|
132 |
+
"3) Opsiyonel olarak görselle ilgili kısa bir soru girin.\n"
|
133 |
+
"4) Tahmin sonucu, gerçek AP/mAP ve Tbai tabanlı Türkçe yorum ekranda görülecek.\n"
|
134 |
+
"5) NOT: Sonuçlar yapay zeka tarafından üretildiğinden, doktorunuzdan en doğru bilgiye ulaşmanız gereklidir!"
|
135 |
+
),
|
136 |
+
theme="default",
|
137 |
+
css=css
|
138 |
)
|
139 |
|
140 |
if __name__ == "__main__":
|
141 |
+
demo.launch()
|