|
|
|
|
|
import torch |
|
import torchvision.transforms as transforms |
|
from PIL import Image |
|
from transformers import T5ForConditionalGeneration, T5Tokenizer |
|
import torch.nn.functional as F |
|
|
|
# Hugging Face repository ids for the two models. CLASSIFICATION_MODEL_REPO is
# not used to load anything below; it is kept as a public constant in case
# other code imports it.
CLASSIFICATION_MODEL_REPO = "Neurazum/Vbai-DPA-2.3"

T5_MODEL_REPO = "Neurazum/Tbai-DPA-1.0"

# Local TorchScript export of the image classifier.
# NOTE(review): presumably the exported checkpoint from CLASSIFICATION_MODEL_REPO;
# confirm, and make sure this file is shipped alongside the app.
CLASSIFICATION_MODEL_PATH = "Vbai-DPA 2.3c.pt"

# Load the classifier once, at import time, onto the CPU and switch it to
# inference mode. A torch.hub.load_state_dict_from_url download of
# vbai_model.pt used to precede this line, but its result was immediately
# overwritten here. It only cost a network round-trip and could make the import
# fail, so it was removed.
classification_model = torch.jit.load(CLASSIFICATION_MODEL_PATH, map_location="cpu")
classification_model.eval()
|
|
|
# Text-generation side: tokenizer and seq2seq model, both pulled from
# T5_MODEL_REPO. predict() feeds them a prompt built from the classifier label.
t5_tokenizer = T5Tokenizer.from_pretrained(T5_MODEL_REPO)
# nn.Module.eval() returns the module itself, so the model can be switched to
# inference mode in the same expression that loads it.
t5_model = T5ForConditionalGeneration.from_pretrained(T5_MODEL_REPO).eval()
|
|
|
# Classifier preprocessing: resize every input to a fixed 224x224, then convert
# the PIL image into a (C, H, W) float tensor.
# NOTE(review): no Normalize step is applied; confirm this matches the
# preprocessing the classifier was trained with.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)
|
|
|
# Label for each classifier output index: predict() takes the argmax over the
# softmax of the model output and looks the label up here, so position i must
# correspond to logit i.
# NOTE(review): the order looks alphabetical (as a sorted-folder dataset loader
# would produce); confirm it matches the label order used in training.
class_names = [
    "Alzheimer Disease",
    "Mild Alzheimer Risk",
    "Moderate Alzheimer Risk",
    "Very Mild Alzheimer Risk",
    "No Risk",
    "Parkinson Disease",
]
|
|
|
def predict(image: Image.Image, question: str = ""):
    """Classify an image and generate a T5 text comment about the predicted label.

    The image is preprocessed with ``transform`` and scored by
    ``classification_model``. The highest-probability label is then combined
    with the optional question into a prompt for ``t5_model``.

    Args:
        image: PIL image to classify. It is converted to RGB before preprocessing.
        question: Optional free-text question for the T5 model. When it is empty,
            ``None`` or whitespace-only, a default Turkish prompt is used instead.

    Returns:
        A tuple ``(prediction, confidence, comment)``:
        - the predicted label from ``class_names``;
        - its softmax probability as a Python float;
        - the decoded T5 output, generated with ``max_length=50``.

    Raises:
        ValueError: If ``image`` is None.
    """
    if image is None:
        raise ValueError("predict() requires an image, got None")

    # NOTE(review): assumes the classifier expects 3-channel input. Converting
    # here stops RGBA, palette and grayscale images from producing a 4- or
    # 1-channel tensor.
    img_tensor = transform(image.convert("RGB")).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        logits = classification_model(img_tensor)
        probs = F.softmax(logits, dim=1)[0]  # single-image batch -> take row 0
        confidence, pred_idx = torch.max(probs, dim=0)
    prediction = class_names[pred_idx.item()]

    # Fall back to the default prompt (Turkish: "Make a medical comment about the
    # condition") when no meaningful question was given.
    question = (question or "").strip()
    prompt = question if question else "Durum hakkında tıbbi yorum yap"
    input_text = f"Input: {prediction}. Question: {prompt}"

    # truncation=True caps an over-long user question at the tokenizer's
    # model_max_length instead of passing an unbounded sequence to the model.
    t5_input = t5_tokenizer.encode(input_text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        t5_output = t5_model.generate(t5_input, max_length=50)
    comment = t5_tokenizer.decode(t5_output[0], skip_special_tokens=True)

    return prediction, confidence.item(), comment
|
|