import torch
import warnings
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
from PIL import Image
from model import (
    load_classification_model,
    load_t5_model,
    predict_image,
    generate_comment_turkce,
    calculate_performance_metrics
)
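
# Suppress noisy FutureWarning/UserWarning output from the underlying libraries.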
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
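
# Load the three Vbai classification variants ("f", "c", "q") and the Tbai T5 comment model once at startup.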
cnn_model_f = load_classification_model(device, model_type="f", num_classes=6)
cnn_model_c = load_classification_model(device, model_type="c", num_classes=6)
cnn_model_q = load_classification_model(device, model_type="q", num_classes=6)
t5_tokenizer, t5_model = load_t5_model(device)
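
# Pre-compute performance metrics for each classification variant (not surfaced in the UI below).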
perf_metrics_f = calculate_performance_metrics(cnn_model_f, device)
perf_metrics_c = calculate_performance_metrics(cnn_model_c, device)
perf_metrics_q = calculate_performance_metrics(cnn_model_q, device)
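
# English class labels (indexed by the model's output) and their Turkish translations shown in the UI.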
class_names_en = [
"Alzheimer Disease",
"Mild Alzheimer Risk",
"Moderate Alzheimer Risk",
"Very Mild Alzheimer Risk",
"No Risk",
"Parkinson Disease"
]
en2tr = {
"Alzheimer Disease": "Alzheimer Hastalığı",
"Mild Alzheimer Risk": "Hafif Alzheimer Riski",
"Moderate Alzheimer Risk": "Orta Düzey Alzheimer Riski",
"Very Mild Alzheimer Risk": "Çok Hafif Alzheimer Riski",
"No Risk": "Risk Yok",
"Parkinson Disease": "Parkinson Hastalığı"
}
def gradio_predict(image, model_type, question):
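    """Classify an uploaded MRI image with the selected Vbai model and produce a Turkish Tbai comment."""
    # Pick the CNN variant chosen in the UI ("f", "c" or "q").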
if model_type == "f":
cnn_model = cnn_model_f
elif model_type == "c":
cnn_model = cnn_model_c
else:
cnn_model = cnn_model_q
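
    # Run inference: predict_image returns the predicted class index, its confidence,
    # the preprocessed input tensor and all class probabilities (unused here).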
    idx, conf, inp_tensor, all_probs = predict_image(cnn_model, image, device)
    pred_en = class_names_en[idx]
    pred_tr = en2tr[pred_en]
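
    # Without a question, generate a generic Turkish comment for the predicted class;
    # otherwise condition the T5 model on the class and the user's question.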
    if not question or question.strip() == "":
        comment = generate_comment_turkce(t5_tokenizer, t5_model, pred_tr, device)
    else:
        input_text = f"Sınıf: {pred_tr}. Soru: {question}"
        inputs = t5_tokenizer(
            input_text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=64
        ).to(device)
        out_ids = t5_model.generate(
            **inputs,
            max_length=64,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            no_repeat_ngram_size=2,
            early_stopping=True
        )
        comment = t5_tokenizer.decode(out_ids[0], skip_special_tokens=True)
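
    # Undo the ImageNet normalization (mean/std below) so the preprocessed tensor
    # can be displayed as a regular RGB image.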
    inp_np = inp_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img_show = inp_np * std + mean
    img_show = np.clip(img_show, 0, 1)

    tahmin_metni = f"Tahmin: {pred_en} — {conf:.2f}%"
    return img_show, tahmin_metni, comment
css = "img {border-radius: 8px;}"
demo = gr.Interface(
    fn=gradio_predict,
    inputs=[
        gr.Image(type="pil", label="MRI Görüntüsü Yükleyin"),
        gr.Radio(choices=["f", "c", "q"], label="Vbai Model Tipi Seçin"),
        gr.Textbox(label="Görselle ilgili soru sorun. (İsteğe bağlı)", placeholder="Örnek: Bu hasta tedaviye ihtiyaç duyar mı?")
    ],
    outputs=[
        gr.Image(type="numpy", label="İşlenmiş Görsel"),
        gr.Textbox(label="Tahmin ve Güven"),
        gr.Textbox(label="Tbai Yorumu")
    ],
    title="🧠 Vbai-DPA 2.3 + Tbai-DPA 1.0-BETA Yorumu",
    description=(
        "1) MRI görüntünüzü yükleyin.\n"
        "2) Vbai için f, c veya q modellerinden birini seçin.\n"
        "3) İsterseniz görselle ilgili kısa bir soru girin.\n"
        "4) Tahmin ve Tbai tabanlı yorum ekranda gösterilecek.\n"
        "5) NOT: Sonuçlar yapay zeka tarafından tahmin edildiğinden, doktorunuza ulaşmanız gerekmektedir!"
    ),
    theme="default",
    css=css
)
if __name__ == "__main__":
    demo.launch()