|
import torch |
|
import warnings |
|
import numpy as np |
|
import gradio as gr |
|
import matplotlib.pyplot as plt |
|
|
|
from PIL import Image |
|
|
|
from model import ( |
|
load_classification_model, |
|
load_t5_model, |
|
predict_image, |
|
generate_comment_turkce, |
|
calculate_performance_metrics |
|
) |
|
|
|
# Suppress every FutureWarning and UserWarning process-wide (the same two
# filters the app installs, applied in the same order).
for _ignored_category in (FutureWarning, UserWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)

# Run all models on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The three Vbai classifier variants, loaded once at import time in f, c, q
# order. Each is a 6-class model (matching the six entries of class_names_en).
_VBAI_MODEL_TYPES = ("f", "c", "q")
_vbai_models = {
    kind: load_classification_model(device, model_type=kind, num_classes=6)
    for kind in _VBAI_MODEL_TYPES
}
cnn_model_f = _vbai_models["f"]
cnn_model_c = _vbai_models["c"]
cnn_model_q = _vbai_models["q"]

# Tokenizer + seq2seq model used to generate the Turkish commentary.
t5_tokenizer, t5_model = load_t5_model(device)

# Per-variant performance metrics, computed at startup in f, c, q order.
# NOTE(review): these names are not referenced elsewhere in this file —
# confirm whether the startup cost is still needed.
perf_metrics_f, perf_metrics_c, perf_metrics_q = (
    calculate_performance_metrics(_vbai_models[kind], device)
    for kind in _VBAI_MODEL_TYPES
)
|
|
|
|
|
# (English, Turkish) label pairs, in the index order of the classifier output
# returned by predict_image. Keeping each English name next to its
# translation prevents the two tables from drifting apart.
# NOTE(review): order presumably mirrors the training label order — verify
# against model.py.
_CLASS_LABELS = (
    ("Alzheimer Disease", "Alzheimer Hastalığı"),
    ("Mild Alzheimer Risk", "Hafif Alzheimer Riski"),
    ("Moderate Alzheimer Risk", "Orta Düzey Alzheimer Riski"),
    ("Very Mild Alzheimer Risk", "Çok Hafif Alzheimer Riski"),
    ("No Risk", "Risk Yok"),
    ("Parkinson Disease", "Parkinson Hastalığı"),
)

# English display names, indexed by predicted class index.
class_names_en = [english for english, _ in _CLASS_LABELS]

# English -> Turkish label translation.
en2tr = dict(_CLASS_LABELS)
|
|
|
|
|
# Per-channel ImageNet mean/std, used to undo the input normalisation for
# display. Assumed to match the preprocessing inside predict_image — TODO
# confirm against model.py.
_IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
_IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def _answer_question(pred_tr, question):
    """Generate a T5 answer to ``question`` conditioned on the Turkish class name."""
    input_text = f"Sınıf: {pred_tr}. Soru: {question}"
    inputs = t5_tokenizer(
        input_text,
        return_tensors="pt",
        padding="longest",
        truncation=True,
        max_length=64,
    ).to(device)
    # Top-k / nucleus sampling, so answers vary between calls. early_stopping
    # only influences beam-based generation; it is kept so behaviour is
    # unchanged if the model's generation config enables beams.
    out_ids = t5_model.generate(
        **inputs,
        max_length=64,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )
    return t5_tokenizer.decode(out_ids[0], skip_special_tokens=True)


def _denormalize(inp_tensor):
    """Undo per-channel normalisation; return an HxWxC float array clipped to [0, 1]."""
    # assumes inp_tensor is (1, C, H, W): squeeze(0) drops the batch axis and
    # permute moves channels last for image display. detach() makes .numpy()
    # safe even if the tensor is part of an autograd graph.
    inp_np = inp_tensor.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    return np.clip(inp_np * _IMAGENET_STD + _IMAGENET_MEAN, 0, 1)


def gradio_predict(image, model_type, question):
    """Classify an MRI image and produce a Turkish commentary for the Gradio UI.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.
        model_type: Selected Vbai variant, ``"f"`` or ``"c"``. Any other value,
            including ``None`` when nothing is selected, uses the ``"q"`` model.
        question: Optional free-text question. Empty or whitespace-only input
            means "no question", and the generic comment generator is used.

    Returns:
        A 3-tuple ``(img_show, tahmin_metni, comment)`` matching the three
        Interface outputs:
        - the de-normalised input image (or ``None`` if no image was given);
        - the prediction/confidence text;
        - the T5-generated comment.
    """
    # Without an image there is nothing to classify; return a readable message
    # instead of letting predict_image fail on None.
    if image is None:
        return None, "Lütfen bir MRI görüntüsü yükleyin.", ""

    # Unknown or unselected model types fall back to the "q" variant.
    cnn_model = {"f": cnn_model_f, "c": cnn_model_c}.get(model_type, cnn_model_q)

    idx, conf, inp_tensor, _ = predict_image(cnn_model, image, device)
    pred_en = class_names_en[idx]
    pred_tr = en2tr[pred_en]

    if question and question.strip():
        comment = _answer_question(pred_tr, question)
    else:
        comment = generate_comment_turkce(t5_tokenizer, t5_model, pred_tr, device)

    img_show = _denormalize(inp_tensor)
    tahmin_metni = f"Tahmin: {pred_en} — {conf:.2f}%"
    return img_show, tahmin_metni, comment
|
|
|
# Round the corners of every <img> element on the page.
css = "img {border-radius: 8px;}"

# Gradio UI: the inputs map positionally onto
# gradio_predict(image, model_type, question), and the outputs map onto its
# (image, text, text) return tuple.
demo = gr.Interface(
    fn=gradio_predict,
    inputs=[
        gr.Image(type="pil", label="MRI Görüntüsü Yükleyin"),
        # Preselect "q": gradio_predict already routes an unselected (None)
        # model type to the "q" model. Showing that default makes the model
        # actually used visible to the user instead of applying it silently.
        gr.Radio(choices=["f", "c", "q"], value="q", label="Vbai Model Tipi Seçin"),
        gr.Textbox(
            label="Görselle ilgili soru sorun. (İsteğe bağlı)",
            placeholder="Örnek: Bu hasta tedaviye ihtiyaç duyar mı?",
        ),
    ],
    outputs=[
        gr.Image(type="numpy", label="İşlenmiş Görsel"),
        gr.Textbox(label="Tahmin ve Güven"),
        gr.Textbox(label="Tbai Yorumu"),
    ],
    title="🧠 Vbai-DPA 2.3 + Tbai-DPA 1.0-BETA Yorumu",
    description=(
        "1) MRI görüntünüzü yükleyin.\n"
        "2) Vbai için f, c veya q modellerinden birini seçin.\n"
        "3) İsterseniz görselle ilgili kısa bir soru girin.\n"
        "4) Tahmin ve Tbai tabanlı yorum ekranda gösterilecek.\n"
        "5) NOT: Sonuçlar yapay zeka tarafından tahmin edildiğinden, doktorunuza ulaşmanız gerekmektedir!"
    ),
    theme="default",
    css=css,
)
|
|
|
def main():
    """Start the Gradio server hosting the Vbai/Tbai demo interface."""
    demo.launch()


if __name__ == "__main__":
    main()
|
|