"""Streamlit app: generate "key issues" from a Neo4j knowledge graph via a LangGraph planner.

Flow: verify Neo4j connectivity -> build config -> build planner graph ->
stream plan generation -> stream document retrieval -> mark human validation
done -> stream document processing -> display/download the final result.
"""

import os
import streamlit as st
from langchain_community.graphs import Neo4jGraph
import pandas as pd
import json
import time
from ki_gen.planner import build_planner_graph
# Update import path if init_app moved or args changed
from ki_gen.utils import init_app, memory, ConfigSchema, State  # Import necessary types
from ki_gen.prompts import get_initial_prompt
from neo4j import GraphDatabase

# Set page config
st.set_page_config(page_title="Key Issue Generator", layout="wide")

# Neo4j Database Configuration
NEO4J_URI = "neo4j+s://4985272f.databases.neo4j.io"
NEO4J_USERNAME = "neo4j"
NEO4J_PASSWORD = os.getenv("neo4j_password")

# API Keys for LLM services
OPENAI_API_KEY = os.getenv("openai_api_key")
# GROQ_API_KEY is removed as we switch to Gemini
# Ensure Gemini API key is available in the environment
GEMINI_API_KEY = os.getenv("gemini_api_key")
LANGSMITH_API_KEY = os.getenv("langsmith_api_key")


def verify_neo4j_connectivity():
    """Verify connection to Neo4j database.

    Returns:
        True on success, or an ``"Error: ..."`` string describing the failure.
        (Callers below compare with ``is True``.)
    """
    try:
        # Context manager guarantees the driver is closed even if
        # verify_connectivity() raises (the previous explicit close() leaked
        # the driver on failure).
        with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
            driver.verify_connectivity()
        return True  # Return simple boolean
    except Exception as e:
        return f"Error: {str(e)}"


def load_config() -> ConfigSchema:
    """Load configuration with custom parameters.

    Returns:
        A dict wrapped in the ``"configurable"`` key as expected by LangGraph,
        or ``None`` when Neo4j is unreachable or the graph object cannot be
        created (an error is surfaced in the UI in both cases).
    """
    # Custom configuration based on provided parameters
    # Default models are gemini-2.0-flash
    custom_config = {
        "main_llm": "gemini-2.0-flash",
        "plan_method": "generation",
        "use_detailed_query": False,
        "cypher_gen_method": "guided",
        "validate_cypher": False,
        "summarize_model": "gemini-2.0-flash",
        "eval_method": "binary",
        "eval_threshold": 0.7,
        "max_docs": 15,
        "compression_method": "llm_lingua",
        "compress_rate": 0.33,
        "force_tokens": ["."],  # List format as expected by the application
        "eval_model": "gemini-2.0-flash",
        "thread_id": "3",  # TODO: consider making thread_id dynamic or user-specific
    }

    # Check connectivity ONCE and reuse the result for the error message
    # (previously this was called twice, opening a second network connection
    # just to format the error string).
    connectivity = verify_neo4j_connectivity()
    if connectivity is not True:
        st.error(f"Neo4j connection issue: {connectivity}")
        # Return None because the graph is essential downstream
        return None

    # Add Neo4j graph object to config
    try:
        custom_config["graph"] = Neo4jGraph(
            url=NEO4J_URI,
            username=NEO4J_USERNAME,
            password=NEO4J_PASSWORD,
        )
    except Exception as e:
        st.error(f"Error creating Neo4jGraph object: {e}")
        return None

    # Return wrapped in 'configurable' key as expected by LangGraph
    return {"configurable": custom_config}


def generate_key_issues(user_query):
    """Main function to generate key issues from Neo4j data.

    Args:
        user_query: Free-text question typed by the user.

    Returns:
        Tuple ``(final_result, valid_docs)`` where ``final_result`` is the last
        message content from the graph (or ``None`` on any failure) and
        ``valid_docs`` is the list of retrieved/processed documents
        (empty list on failure).
    """
    # Initialize application with API keys (Groq key removed)
    init_app(
        openai_key=OPENAI_API_KEY,
        langsmith_key=LANGSMITH_API_KEY,
    )

    # Load configuration with custom parameters
    config = load_config()
    if not config or "configurable" not in config or not config["configurable"].get("graph"):
        st.error("Failed to load configuration or connect to Neo4j. Cannot proceed.")
        return None, []

    # Create status containers
    plan_status = st.empty()
    plan_display = st.empty()
    retrieval_status = st.empty()
    processing_status = st.empty()

    # Build planner graph
    plan_status.info("Building planner graph...")
    # Pass the full config dictionary to build_planner_graph
    graph = build_planner_graph(memory, config)

    # Execute initial prompt generation
    plan_status.info(f"Generating plan for query: {user_query}")
    messages_content = []
    initial_prompt_data = get_initial_prompt(config, user_query)

    # Stream initial plan generation
    try:
        for event in graph.stream(initial_prompt_data, config, stream_mode="values"):
            if "messages" in event and event["messages"]:
                event["messages"][-1].pretty_print()
                messages_content.append(event["messages"][-1].content)
    except Exception as e:
        st.error(f"Error during initial graph stream: {e}")
        return None, []

    # Get the state with the generated plan (after initial stream/interrupt)
    try:
        # Ensure thread_id matches what's used internally if applicable
        state = graph.get_state(config)
        # Check if 'store_plan' exists and is a non-empty list
        stored_plan = state.values.get('store_plan', [])
        if isinstance(stored_plan, list) and stored_plan:
            steps = [i for i in range(1, len(stored_plan) + 1)]
            plan_df = pd.DataFrame({'Plan steps': steps, 'Description': stored_plan})
            plan_status.success("Plan generation complete!")
            plan_display.dataframe(plan_df, use_container_width=True)
        else:
            plan_status.warning("Plan not found or empty in graph state after generation.")
            plan_display.empty()  # Clear display if no plan
    except Exception as e:
        st.error(f"Error getting graph state or displaying plan: {e}")
        return None, []

    # Continue with plan execution for document retrieval.
    # This part assumes the graph will continue after the first interrupt.
    retrieval_status.info("Retrieving documents...")
    try:
        # Stream from the current state (None indicates continue)
        for event in graph.stream(None, config, stream_mode="values"):
            if "messages" in event and event["messages"]:
                event["messages"][-1].pretty_print()
                messages_content.append(event["messages"][-1].content)
    except Exception as e:
        st.error(f"Error during document retrieval stream: {e}")
        return None, []

    # Get updated state after document retrieval interrupt
    try:
        snapshot = graph.get_state(config)
        valid_docs_retrieved = snapshot.values.get('valid_docs', [])
        doc_count = len(valid_docs_retrieved) if isinstance(valid_docs_retrieved, list) else 0
        retrieval_status.success(f"Retrieved {doc_count} documents")

        # --- Human Validation / Processing Steps ---
        # This section needs interaction logic if manual validation is desired.
        # For now, setting default processing steps and marking as validated.
        processing_status.info("Processing documents...")
        process_steps = ["summarize"]  # Default: just summarize
        # Update state to indicate human validation is complete and specify
        # processing steps; must happen *before* the next stream call that
        # triggers processing.
        graph.update_state(config, {'human_validated': True, 'process_steps': process_steps})
    except Exception as e:
        st.error(f"Error getting state after retrieval or setting up processing: {e}")
        return None, []

    # Continue execution with document processing
    try:
        for event in graph.stream(None, config, stream_mode="values"):
            if "messages" in event and event["messages"]:
                event["messages"][-1].pretty_print()
                messages_content.append(event["messages"][-1].content)
    except Exception as e:
        st.error(f"Error during document processing stream: {e}")
        return None, []

    # Get final state after processing
    try:
        final_snapshot = graph.get_state(config)
        processing_status.success("Document processing complete!")

        # Extract final result and documents
        final_result = None
        valid_docs_final = []
        if "messages" in final_snapshot.values and final_snapshot.values["messages"]:
            # Assume the last message contains the final result
            final_result = final_snapshot.values["messages"][-1].content

        # Get the final state of valid_docs (might be processed summaries)
        valid_docs_final = final_snapshot.values.get('valid_docs', [])
        if not isinstance(valid_docs_final, list):  # Ensure it's a list
            valid_docs_final = []

        return final_result, valid_docs_final
    except Exception as e:
        st.error(f"Error getting final state or extracting results: {e}")
        return None, []


# App header
st.title("Key Issue Generator")
st.write("Generate key issues from a Neo4j knowledge graph using advanced language models.")

# Check database connectivity
connectivity_status = verify_neo4j_connectivity()
st.sidebar.header("Database Status")
# Use boolean check (connectivity_status is either True or an error string)
if connectivity_status is True:
    st.sidebar.success("Connected to Neo4j database")
else:
    # Display the error message returned
    st.sidebar.error(f"Database connection issue: {connectivity_status}")

# User input section
st.header("Enter Your Query")
user_query = st.text_area(
    "What would you like to explore?",
    "What are the main challenges in AI adoption for healthcare systems?",
    height=100,
)

# Process button
if st.button("Generate Key Issues", type="primary"):
    # API key check covers Gemini (Groq no longer required)
    if not OPENAI_API_KEY or not GEMINI_API_KEY or not LANGSMITH_API_KEY or not NEO4J_PASSWORD:
        st.error("Required API keys (OpenAI, Gemini, Langsmith) or database credentials are missing. Please check your environment variables.")
    elif connectivity_status is not True:
        # Check DB connection again before starting
        st.error(f"Cannot start: Neo4j connection issue: {connectivity_status}")
    else:
        with st.spinner("Processing your query..."):
            start_time = time.time()
            # Call the main generation function
            final_result, valid_docs = generate_key_issues(user_query)
            end_time = time.time()

            if final_result is not None:  # None indicates failure
                # Display execution time
                st.sidebar.info(f"Total execution time: {round(end_time - start_time, 2)} seconds")

                # Display final result
                st.header("Generated Key Issues")
                st.markdown(final_result)

                # Option to download results
                st.download_button(
                    label="Download Results",
                    data=final_result,  # Ensure final_result is string data
                    file_name="key_issues_results.txt",
                    mime="text/plain",
                )

                # Display retrieved/processed documents in expandable section
                if valid_docs:
                    with st.expander("View Processed Documents"):
                        for i, doc in enumerate(valid_docs):
                            st.markdown(f"### Document {i+1}")
                            # Handle doc format (could be string summary or original dict)
                            if isinstance(doc, dict):
                                for key in doc:
                                    st.markdown(f"**{key}**: {doc[key]}")
                            elif isinstance(doc, str):
                                st.markdown(doc)  # Display string directly if it's a summary
                            else:
                                st.markdown(str(doc))  # Fallback for other types
                            st.divider()
            else:
                # Specific error messages are shown inside generate_key_issues;
                # this is a generic fallback. (final_result is None here by
                # construction, so no further check is needed.)
                st.error("Processing failed. Please check the console/logs for errors.")

# Help information in sidebar
with st.sidebar:
    st.header("About")
    st.info("""
    This application uses advanced language models (like Google Gemini) to analyze a Neo4j knowledge graph and generate key issues based on your query.
    
    The process involves:
    1. Creating a plan based on your query
    2. Retrieving relevant documents from the database
    3. Processing and summarizing the information
    4. Generating a comprehensive response
    """)