|
import streamlit as st |
|
import json |
|
import random |
|
|
|
@st.cache_resource
def load_model():
    """Build the ADRD model from the bundled checkpoint and return it.

    The model is created on the CPU. The function is wrapped in
    ``st.cache_resource`` so that Streamlit can reuse the returned object
    instead of reloading the checkpoint on every script rerun.
    """
    # Imported here rather than at module level, as in the original code.
    import adrd

    return adrd.model.ADRDModel.from_ckpt(
        './ckpt_densenet_emb_encoder_2_AUPR.pt',
        device='cpu',
    )
|
|
|
@st.cache_resource
def load_nacc_data():
    """Construct and return the NACC test dataset.

    The dataset is a ``CSVDataset`` built from the test CSV file and the
    TOML configuration file under ``./data``. The function is wrapped in
    ``st.cache_resource`` so that Streamlit can reuse the returned object
    across script reruns.
    """
    from data.dataset_csv import CSVDataset

    dataset_kwargs = {
        "dat_file": "./data/nacc_test_with_np_cli.csv",
        "cnf_file": "./data/default_conf_new.toml",
    }
    return CSVDataset(**dataset_kwargs)
|
|
|
# Module-level handles used by the rest of the script. Both loaders are
# decorated with st.cache_resource, so these calls are expected to be cheap
# after the first run.
model = load_model()
dat_tst = load_nacc_data()
|
|
|
def predict_proba(data_dict):
    """Run the model on a single case and return its prediction entry.

    ``data_dict`` is wrapped in a one-element list before being handed to
    ``model.predict_proba``; the value returned is element ``[1][0]`` of the
    model's output.
    """
    # NOTE(review): index 1 is presumably the per-case probability output and
    # index 0 the single submitted case -- confirm against the adrd API.
    batch_output = model.predict_proba([data_dict])
    return batch_output[1][0]
|
|
|
|
|
# Import retained in case other code in this file refers to CSVDataset.
from data.dataset_csv import CSVDataset

# Use the cached loader rather than constructing CSVDataset directly here:
# a direct construction would re-read the CSV on every Streamlit rerun and
# replace the cached dataset, which defeats @st.cache_resource.
# load_nacc_data() builds the dataset from the same two files.
dat_tst = load_nacc_data()
|
|
|
|
|
# Seed the text-area contents with an empty string the first time the
# script runs in a session; later runs keep whatever was stored.
if 'input_text' not in st.session_state:
    st.session_state['input_text'] = ""
|
|
|
|
|
# Input form: instructions, a JSON text area, and two submit buttons.
# It defines json_input, sample_button and submit_button for the code below.
with st.form("json_input_form"):
    # The button name quoted in the instructions must match the actual
    # button label ("Random NACC Case") so users can find it.
    st.write(
        "Please enter the input features in the textbox below, formatted as a "
        "JSON dictionary. Click the \"Random NACC Case\" button to populate the "
        "textbox with a randomly selected case from the NACC testing dataset. "
        "Use the \"Predict\" button to submit your input to the model, which "
        "will then provide probability predictions for mental status and all "
        "10 etiologies."
    )

    # Pre-filled from session state so a sampled case appears after a rerun.
    json_input = st.text_area(
        "Please enter JSON-formatted input features:",
        value=st.session_state.input_text,
        height=250,
    )

    # Three columns; only the outer two hold buttons, the middle one is
    # left empty.
    left_col, _spacer_col, right_col = st.columns([3, 4, 1])

    with left_col:
        sample_button = st.form_submit_button("Random NACC Case")

    with right_col:
        submit_button = st.form_submit_button("Predict")
|
|
|
# React to whichever form button was pressed on this run.
if sample_button:
    # Pick a random dataset entry; element 0 of the entry is serialized as
    # the example input.
    # NOTE(review): presumably element 0 is the feature dict and it is
    # JSON-serializable -- confirm against CSVDataset.
    idx = random.randrange(len(dat_tst))
    example = dat_tst[idx][0]
    st.session_state.input_text = json.dumps(example)

    # Rerun the script so the text area is rebuilt with the new value.
    # st.experimental_rerun is deprecated in favour of st.rerun; prefer the
    # new name and fall back only on Streamlit versions that lack it.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()

elif submit_button:
    # Keep the try body limited to the parse step so that only malformed
    # JSON is reported through this handler.
    try:
        data_dict = json.loads(json_input)
    except json.JSONDecodeError as e:
        st.error(f"An error occurred: {e}")
    else:
        if not isinstance(data_dict, dict):
            # Valid JSON that is not an object (e.g. a list or a number)
            # is not the dictionary the instructions ask for.
            st.error(
                "An error occurred: the input must be a JSON dictionary "
                "(object) of input features."
            )
        else:
            pred_dict = predict_proba(data_dict)
            st.write("Predicted probabilities:")
            st.json(pred_dict)
|
|