import json
import random

import streamlit as st

# Paths are resolved relative to the directory the app is launched from.
CKPT_PATH = './ckpt_densenet_emb_encoder_2_AUPR.pt'
NACC_DATA_FILE = "./data/nacc_test_with_np_cli.csv"
NACC_CONF_FILE = "./data/default_conf_new.toml"


@st.cache_resource
def load_model():
    """Load the ADRD model checkpoint onto the CPU.

    Decorated with ``st.cache_resource`` so the checkpoint is read once and
    the same model object is reused across Streamlit reruns.

    Returns:
        The model built by ``adrd.model.ADRDModel.from_ckpt``.
    """
    # Imported inside the function so the dependency is only loaded when the
    # cache is empty and the body actually runs.
    import adrd
    model = adrd.model.ADRDModel.from_ckpt(CKPT_PATH, device='cpu')
    return model


@st.cache_resource
def load_nacc_data():
    """Load the NACC testing dataset used to draw random example cases.

    Decorated with ``st.cache_resource`` so the CSV is parsed once rather
    than on every rerun of the script.

    Returns:
        A ``CSVDataset`` built from ``NACC_DATA_FILE`` and ``NACC_CONF_FILE``.
    """
    from data.dataset_csv import CSVDataset
    dat = CSVDataset(
        dat_file=NACC_DATA_FILE,
        cnf_file=NACC_CONF_FILE,
    )
    return dat


model = load_model()
# Use only the cached dataset; rebuilding it here would re-parse the CSV on
# every rerun and discard the cached object.
dat_tst = load_nacc_data()


def predict_proba(data_dict):
    """Return the model's probability output for a single input case.

    Args:
        data_dict: Mapping of input feature name to value for one case.

    Returns:
        Element ``[1][0]`` of ``model.predict_proba([data_dict])``, i.e. the
        second item of the model's output, taken for the first (and only)
        sample. NOTE(review): presumably a label -> probability mapping, since
        it is rendered with ``st.json`` — confirm against the adrd API.
    """
    pred_dict = model.predict_proba([data_dict])[1][0]
    return pred_dict


def _rerun():
    """Restart the script run with whichever rerun API this Streamlit has."""
    # ``st.rerun`` supersedes the deprecated ``st.experimental_rerun``; fall
    # back to the old name on Streamlit versions that predate ``st.rerun``.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


# Initialize the session-state value backing the text box on the first run.
if 'input_text' not in st.session_state:
    st.session_state.input_text = ""

# Form for user input.
with st.form("json_input_form"):
    st.write(
        "Please enter the input features in the textbox below, formatted as "
        "a JSON dictionary. Click the \"Random NACC Case\" button to populate "
        "the textbox with a randomly selected case from the NACC testing "
        "dataset. Use the \"Predict\" button to submit your input to the "
        "model, which will then provide probability predictions for mental "
        "status and all 10 etiologies."
    )
    json_input = st.text_area(
        "Please enter JSON-formatted input features:",
        value=st.session_state.input_text,
        height=250,
    )
    # Three columns; nothing is placed in the middle one, so it only acts as
    # a spacer between the two buttons.
    left_col, _spacer_col, right_col = st.columns([3, 4, 1])
    with left_col:
        sample_button = st.form_submit_button("Random NACC Case")
    with right_col:
        submit_button = st.form_submit_button("Predict")

if sample_button:
    if len(dat_tst) == 0:
        st.error("The NACC testing dataset is empty; no example case is available.")
    else:
        idx = random.randrange(len(dat_tst))
        # NOTE(review): item [0] of a dataset entry is presumably the case's
        # feature dict — confirm against CSVDataset.__getitem__.
        example = dat_tst[idx][0]
        st.session_state.input_text = json.dumps(example)
        # Rerun so the text area is rebuilt with the updated session value.
        _rerun()
elif submit_button:
    try:
        # Parse the JSON input into a Python object.
        data_dict = json.loads(json_input)
    except json.JSONDecodeError as e:
        # Handle JSON parsing errors.
        st.error(f"An error occurred: {e}")
    else:
        if not isinstance(data_dict, dict):
            st.error("The input must be a JSON object (a dictionary of features).")
        else:
            pred_dict = predict_proba(data_dict)
            st.write("Predicted probabilities:")
            st.json(pred_dict)