eyupipler's picture
Update app.py
e205224 verified
raw
history blame
3.92 kB
import torch
import warnings
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
from PIL import Image
from model import (
load_classification_model,
load_t5_model,
predict_image,
generate_comment_turkce,
calculate_performance_metrics
)
# Keep the demo logs quiet: silence every FutureWarning and UserWarning
# raised by the libraries below (torch / model loading / generation).
for _ignored_category in (FutureWarning, UserWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
del _ignored_category

# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# One 6-class CNN per Vbai variant, loaded in the order f, c, q.
cnn_model_f, cnn_model_c, cnn_model_q = (
    load_classification_model(device, model_type=variant, num_classes=6)
    for variant in ("f", "c", "q")
)

# Tokenizer + T5 model used to write the Turkish comment / answer.
t5_tokenizer, t5_model = load_t5_model(device)

# NOTE(review): perf_metrics_* are not read anywhere in this file; unless
# another module imports them, these calls only add startup time — confirm.
perf_metrics_f, perf_metrics_c, perf_metrics_q = (
    calculate_performance_metrics(cnn, device)
    for cnn in (cnn_model_f, cnn_model_c, cnn_model_q)
)
# (English label, Turkish label) for every class the CNNs predict. Position in
# this tuple is the class index: gradio_predict looks up
# class_names_en[idx] with the idx returned by predict_image.
_CLASS_LABELS = (
    ("Alzheimer Disease", "Alzheimer Hastalığı"),
    ("Mild Alzheimer Risk", "Hafif Alzheimer Riski"),
    ("Moderate Alzheimer Risk", "Orta Düzey Alzheimer Riski"),
    ("Very Mild Alzheimer Risk", "Çok Hafif Alzheimer Riski"),
    ("No Risk", "Risk Yok"),
    ("Parkinson Disease", "Parkinson Hastalığı"),
)

# Index -> English label.
class_names_en = [english for english, _turkish in _CLASS_LABELS]

# English label -> Turkish label (derived from the same pairs, so the two
# tables can never drift apart).
en2tr = dict(_CLASS_LABELS)
# ImageNet normalization statistics, used to undo the normalization of the
# tensor returned by predict_image. Presumably these match the transform in
# model.py — TODO confirm.
_IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
_IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def _select_model(model_type):
    """Return the CNN for the chosen Vbai variant; anything but "f"/"c" (e.g. None) gets "q"."""
    models = {"f": cnn_model_f, "c": cnn_model_c}
    return models.get(model_type, cnn_model_q)


def _answer_question(pred_tr, question):
    """Sample a T5 answer to ``question``, conditioned on the Turkish class label."""
    input_text = f"Sınıf: {pred_tr}. Soru: {question}"
    inputs = t5_tokenizer(
        input_text,
        return_tensors="pt",
        padding="longest",
        truncation=True,
        max_length=64,
    ).to(device)
    out_ids = t5_model.generate(
        **inputs,
        max_length=64,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )
    return t5_tokenizer.decode(out_ids[0], skip_special_tokens=True)


def _to_display_image(inp_tensor):
    """Undo the mean/std normalization and return an HWC float array clipped to [0, 1]."""
    # assumes inp_tensor is (1, 3, H, W): squeeze(0) drops the batch axis and
    # the 3-element mean/std broadcast over the channel axis — TODO confirm.
    # detach() lets .numpy() succeed even if the tensor still tracks gradients.
    img = inp_tensor.detach().squeeze(0).permute(1, 2, 0).cpu().numpy()
    return np.clip(img * _IMAGENET_STD + _IMAGENET_MEAN, 0, 1)


def gradio_predict(image, model_type, question):
    """Classify an MRI image and produce a Turkish T5 comment for the Gradio UI.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or None
            when nothing was uploaded.
        model_type: Vbai variant from the radio button ("f", "c" or "q");
            any other value falls back to the "q" model.
        question: optional free-text question. If it is empty or only
            whitespace, a generic comment for the predicted class is generated
            instead.

    Returns:
        A tuple ``(display_image, prediction_text, comment)``:
        - the de-normalized input as a float HWC array in [0, 1];
        - "Tahmin: <English label> - <confidence>%";
        - the T5-generated Turkish comment or answer.

    Raises:
        gr.Error: when no image was uploaded, so the UI shows a readable
            message instead of a traceback from predict_image.
    """
    if image is None:
        raise gr.Error("Lütfen bir MRI görüntüsü yükleyin.")

    cnn_model = _select_model(model_type)
    idx, conf, inp_tensor, _all_probs = predict_image(cnn_model, image, device)
    pred_en = class_names_en[idx]
    pred_tr = en2tr[pred_en]

    if question and question.strip():
        comment = _answer_question(pred_tr, question)
    else:
        comment = generate_comment_turkce(t5_tokenizer, t5_model, pred_tr, device)

    # Separator keeps the label and the score from running together
    # (previously rendered as e.g. "No Risk95.23%").
    # NOTE(review): the "%" suffix assumes predict_image returns conf already
    # scaled to 0-100 — confirm in model.py.
    tahmin_metni = f"Tahmin: {pred_en} - {conf:.2f}%"
    return _to_display_image(inp_tensor), tahmin_metni, comment
# Rounded corners for the input and output images.
css = "img {border-radius: 8px;}"

# Inputs, in gradio_predict's argument order: image, model variant, question.
_input_components = [
    gr.Image(type="pil", label="MRI Görüntüsü Yükleyin"),
    gr.Radio(choices=["f", "c", "q"], label="Vbai Model Tipi Seçin"),
    gr.Textbox(
        label="Görselle ilgili soru sorun. (İsteğe bağlı)",
        placeholder="Örnek: Bu hasta tedaviye ihtiyaç duyar mı?",
    ),
]

# Outputs, in the order of gradio_predict's return tuple.
_output_components = [
    gr.Image(type="numpy", label="İşlenmiş Görsel"),
    gr.Textbox(label="Tahmin ve Güven"),
    gr.Textbox(label="Tbai Yorumu"),
]

# User-facing instructions, one numbered step per line.
_description = "\n".join(
    [
        "1) MRI görüntünüzü yükleyin.",
        "2) Vbai için f, c veya q modellerinden birini seçin.",
        "3) İsterseniz görselle ilgili kısa bir soru girin.",
        "4) Tahmin ve Tbai tabanlı yorum ekranda gösterilecek.",
        "5) NOT: Sonuçlar yapay zeka tarafından tahmin edildiğinden, doktorunuza ulaşmanız gerekmektedir!",
    ]
)

demo = gr.Interface(
    fn=gradio_predict,
    inputs=_input_components,
    outputs=_output_components,
    title="🧠 Vbai-DPA 2.3 + Tbai-DPA 1.0-BETA Yorumu",
    description=_description,
    theme="default",
    css=css,
)
def _main():
    """Start the Gradio server for the demo interface."""
    demo.launch()


if __name__ == "__main__":
    _main()