# nmed2024 / app.py
# Author: xf3227
# Commit d217981: "ok"
# (raw | history | blame — 2.6 kB)
import streamlit as st
import json
import random
@st.cache_resource
def load_model(ckpt_path='./ckpt_densenet_emb_encoder_2_AUPR.pt', device='cpu'):
    """Load the ADRD model from a checkpoint, cached by ``st.cache_resource``.

    Args:
        ckpt_path: Path to the checkpoint file, resolved relative to the
            current working directory. Defaults to the checkpoint this app
            has always used.
        device: Device string forwarded to ``ADRDModel.from_ckpt``;
            defaults to ``'cpu'``.

    Returns:
        The object returned by ``adrd.model.ADRDModel.from_ckpt``.
    """
    # Imported inside the function, so the import only runs on a cache miss.
    import adrd

    return adrd.model.ADRDModel.from_ckpt(ckpt_path, device=device)
@st.cache_resource
def load_nacc_data():
    """Build the NACC testing dataset, cached by ``st.cache_resource``.

    Returns:
        A ``CSVDataset`` constructed from ``./data/nacc_test_with_np_cli.csv``
        with the configuration file ``./data/default_conf_new.toml``.
    """
    # Imported inside the function, so the import only runs on a cache miss.
    from data.dataset_csv import CSVDataset

    return CSVDataset(
        dat_file="./data/nacc_test_with_np_cli.csv",
        cnf_file="./data/default_conf_new.toml",
    )
# Fetch the shared model and NACC testing dataset from the resource cache
# (loaded from disk only on the first call); the model is requested first.
model, dat_tst = load_model(), load_nacc_data()
def predict_proba(data_dict):
    """Score a single case with the module-level ``model``.

    Args:
        data_dict: Feature dictionary for one case.

    Returns:
        Element ``[1][0]`` of ``model.predict_proba([data_dict])``.
    """
    # The case is passed as a one-element list, i.e. a batch of one.
    outputs = model.predict_proba([data_dict])
    # NOTE(review): outputs[1] is presumably the per-sample probability results,
    # and [0] selects our single sample — confirm against adrd's predict_proba.
    per_sample = outputs[1]
    return per_sample[0]
# load NACC testing data
# `dat_tst` is already bound above by the cached load_nacc_data(). Do not
# construct CSVDataset directly here: module-level code runs on every
# Streamlit rerun, so an uncached construction would re-read the CSV each time
# and defeat @st.cache_resource.
# NOTE(review): this top-level import is retained only in case other code
# relies on `CSVDataset` being defined at module level — drop it if nothing does.
from data.dataset_csv import CSVDataset
# Initialize session state for the text input if it's not already set; it
# holds the text shown in the text area across reruns.
if 'input_text' not in st.session_state:
    st.session_state.input_text = ""

# Create a form for user input; nothing is submitted until a form button is pressed.
with st.form("json_input_form"):
    st.write("Please enter the input features in the textbox below, formatted as a JSON dictionary. Click the \"Random NACC Case\" button to populate the textbox with a randomly selected case from the NACC testing dataset. Use the \"Predict\" button to submit your input to the model, which will then provide probability predictions for mental status and all 10 etiologies.")
    json_input = st.text_area(
        "Please enter JSON-formatted input features:",
        value=st.session_state.input_text,
        height=250,
    )
    # Three columns; the wide middle one is an unused spacer that keeps the
    # two buttons apart.
    left_col, _, right_col = st.columns([3, 4, 1])
    with left_col:
        sample_button = st.form_submit_button("Random NACC Case")
    with right_col:
        submit_button = st.form_submit_button("Predict")

if sample_button:
    if len(dat_tst) == 0:
        # random.randint(0, -1) would raise on an empty dataset.
        st.error("The NACC testing dataset is empty; no example is available.")
    else:
        idx = random.randint(0, len(dat_tst) - 1)
        example = dat_tst[idx][0]
        st.session_state.input_text = json.dumps(example)
        # The text area above was already rendered with the old value, so rerun
        # the script to redraw it with the sampled case. st.experimental_rerun
        # was deprecated in favour of st.rerun (and later removed); fall back to
        # it only on Streamlit versions that lack st.rerun.
        rerun = getattr(st, "rerun", None) or st.experimental_rerun
        rerun()
elif submit_button:
    try:
        # Parse the JSON input into a Python object.
        data_dict = json.loads(json_input)
    except json.JSONDecodeError as e:
        # Handle JSON parsing errors (including an empty textbox).
        st.error(f"An error occurred: {e}")
    else:
        if isinstance(data_dict, dict):
            pred_dict = predict_proba(data_dict)
            st.write("Predicted probabilities:")
            st.json(pred_dict)
        else:
            # Valid JSON that is not an object (list, number, string, ...) is
            # not a feature dictionary, so don't pass it to the model.
            st.error(
                "An error occurred: expected a JSON dictionary of input "
                f"features, but got {type(data_dict).__name__}."
            )