# Scraped from the Hugging Face file view of model.py
# (uploaded by eyupipler; commit 549beca "Update model.py", verified; 1.71 kB).
# model.py
import torch
import torchvision.transforms as transforms
from PIL import Image
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch.nn.functional as F
CLASSIFICATION_MODEL_REPO = "Neurazum/Vbai-DPA-2.3"
T5_MODEL_REPO = "Neurazum/Tbai-DPA-1.0"

# Image classifier: a TorchScript archive read from a path relative to the
# current working directory and mapped onto the CPU.
#
# Note: this module previously also called torch.hub.load_state_dict_from_url()
# on f".../{CLASSIFICATION_MODEL_REPO}/resolve/main/vbai_model.pt" and then
# immediately overwrote the result with the torch.jit.load() below. That
# download was never used. It only added network/cache I/O to every import and
# could make the import fail when the URL was unreachable, so it was removed.
classification_model = torch.jit.load("Vbai-DPA 2.3c.pt", map_location="cpu")
classification_model.eval()  # inference mode for the classifier

# Text model that turns the predicted label (plus an optional question) into a
# generated comment; tokenizer and weights are both fetched from T5_MODEL_REPO.
t5_tokenizer = T5Tokenizer.from_pretrained(T5_MODEL_REPO)
t5_model = T5ForConditionalGeneration.from_pretrained(T5_MODEL_REPO)
t5_model.eval()  # inference mode for the T5 model
# Classifier preprocessing: every input image is resized to a fixed 224x224
# and then converted from a PIL image into a tensor.
transform = transforms.Compose(
    [transforms.Resize(size=(224, 224)), transforms.ToTensor()]
)
# Labels for the classifier's outputs, indexed by output position.
# predict() takes the arg-max of the model's softmax and uses that index
# directly into this list, so the order must match the model's output layout.
class_names = [
    "Alzheimer Disease",
    "Mild Alzheimer Risk",
    "Moderate Alzheimer Risk",
    "Very Mild Alzheimer Risk",
    "No Risk",
    "Parkinson Disease",
]
def predict(image: Image.Image, question: str = "", max_length: int = 50):
    """Classify an image and generate a T5 comment about the predicted label.

    The image is preprocessed with ``transform`` and run through
    ``classification_model``. The arg-max of the softmax picks a label from
    ``class_names``. That label, plus an optional question, is formatted into
    a prompt for ``t5_model``, whose decoded output is returned as the comment.

    Args:
        image: Input PIL image. Non-RGB images (e.g. ``L``, ``RGBA``, ``P``)
            are converted to RGB before preprocessing.
        question: Optional question placed in the T5 prompt. If it is empty,
            ``None`` or whitespace-only, the default Turkish prompt
            "Durum hakkında tıbbi yorum yap" ("make a medical comment about
            the condition") is used instead. Any other question is passed
            through unchanged.
        max_length: ``max_length`` forwarded to ``t5_model.generate``.
            Defaults to 50, the previously hard-coded value.

    Returns:
        A ``(prediction, confidence, comment)`` tuple: the predicted class
        name, its softmax probability as a Python float, and the decoded T5
        output with special tokens removed.
    """
    # A grayscale or RGBA image would give a 1- or 4-channel tensor.
    # NOTE(review): assumes the classifier expects 3-channel RGB input —
    # confirm against the model's first layer.
    if image.mode != "RGB":
        image = image.convert("RGB")
    img_tensor = transform(image).unsqueeze(0)  # add a batch dimension of 1

    with torch.no_grad():  # pure inference: no autograd bookkeeping needed
        output = classification_model(img_tensor)
        # [0] selects the single batch item; softmax runs over dim=1.
        probs = F.softmax(output, dim=1)[0]
        confidence, pred_idx = torch.max(probs, dim=0)
        prediction = class_names[pred_idx.item()]

        # Blank or whitespace-only questions fall back to the default prompt.
        asked = (
            question
            if question and question.strip()
            else "Durum hakkında tıbbi yorum yap"
        )
        input_text = f"Input: {prediction}. Question: {asked}"
        t5_input = t5_tokenizer.encode(input_text, return_tensors="pt")
        t5_output = t5_model.generate(t5_input, max_length=max_length)

    comment = t5_tokenizer.decode(t5_output[0], skip_special_tokens=True)
    return prediction, confidence.item(), comment