# kig_test / app.py
# Author: adrienbrdne — commit 32d54de ("Update app.py", verified)
import os
import streamlit as st
from langchain_community.graphs import Neo4jGraph
import pandas as pd
import json
import time
from ki_gen.planner import build_planner_graph
# Update import path if init_app moved or args changed
from ki_gen.utils import init_app, memory, ConfigSchema, State # Import necessary types
from ki_gen.prompts import get_initial_prompt
from neo4j import GraphDatabase
# Set page config (must be the first Streamlit call in the script)
st.set_page_config(page_title="Key Issue Generator", layout="wide")
# Neo4j Database Configuration
# NOTE(review): URI is hard-coded to a specific Aura instance; only the
# password comes from the environment.
NEO4J_URI = "neo4j+s://4985272f.databases.neo4j.io"
NEO4J_USERNAME = "neo4j"
NEO4J_PASSWORD = os.getenv("neo4j_password")
# API Keys for LLM services (all read from environment variables;
# any of these may be None if the variable is unset — checked at button time)
OPENAI_API_KEY = os.getenv("openai_api_key")
# GROQ_API_KEY is removed as we switch to Gemini
# GROQ_API_KEY = os.getenv("groq_api_key")
# Ensure Gemini API key is available in the environment
GEMINI_API_KEY = os.getenv("gemini_api_key")
LANGSMITH_API_KEY = os.getenv("langsmith_api_key")
def verify_neo4j_connectivity():
    """Verify connection to the Neo4j database.

    Returns:
        True on success, otherwise a string of the form "Error: <details>".
        Callers deliberately test with `is True`, so the mixed return type
        (bool | str) is part of the contract and must not change.
    """
    try:
        # Use the driver as a context manager so it is closed even when
        # verify_connectivity() raises (the previous version leaked the
        # driver on the exception path).
        with GraphDatabase.driver(
            NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)
        ) as driver:
            driver.verify_connectivity()
        return True  # Return simple boolean
    except Exception as e:
        return f"Error: {str(e)}"
# Update load_config defaults
def load_config() -> ConfigSchema:
    """Load the LangGraph run configuration with custom parameters.

    Builds the model/pipeline settings dict, attaches a live `Neo4jGraph`
    object under the "graph" key, and wraps everything in a "configurable"
    key as LangGraph expects.

    Returns:
        {"configurable": {...}} on success, or None if the Neo4j connection
        or Neo4jGraph construction fails (an error is shown via st.error).
    """
    # Custom configuration based on provided parameters
    # Default models are gemini-2.0-flash
    custom_config = {
        "main_llm": "gemini-2.0-flash",
        "plan_method": "generation",
        "use_detailed_query": False,
        "cypher_gen_method": "guided",
        "validate_cypher": False,
        "summarize_model": "gemini-2.0-flash",
        "eval_method": "binary",
        "eval_threshold": 0.7,
        "max_docs": 15,
        "compression_method": "llm_lingua",
        "compress_rate": 0.33,
        "force_tokens": ["."],  # Converting to list format as expected by the application
        "eval_model": "gemini-2.0-flash",
        "thread_id": "3"  # Consider making thread_id dynamic or user-specific
    }
    # Check connectivity ONCE and reuse the result; the previous version
    # called verify_neo4j_connectivity() a second time just to build the
    # error message, opening a redundant database connection.
    connectivity = verify_neo4j_connectivity()
    if connectivity is not True:
        st.error(f"Neo4j connection issue: {connectivity}")
        # Return None because the graph object is essential downstream
        return None
    try:
        neo_graph = Neo4jGraph(
            url=NEO4J_URI,
            username=NEO4J_USERNAME,
            password=NEO4J_PASSWORD
        )
        custom_config["graph"] = neo_graph
    except Exception as e:
        st.error(f"Error creating Neo4jGraph object: {e}")
        return None
    # Return wrapped in 'configurable' key as expected by LangGraph
    return {"configurable": custom_config}
def generate_key_issues(user_query):
    """Main function to generate key issues from Neo4j data.

    Drives the LangGraph planner through its interrupt points:
    plan generation -> document retrieval -> (auto) human validation ->
    document processing, surfacing progress via Streamlit placeholders.

    Args:
        user_query: Free-text query entered by the user.

    Returns:
        (final_result, valid_docs) where final_result is the content of the
        last graph message (str) and valid_docs is a list of retrieved or
        processed documents; (None, []) on any failure.
    """
    # Initialize application with API keys (remove groq_key)
    init_app(
        openai_key=OPENAI_API_KEY,
        # groq_key=GROQ_API_KEY, # Remove Groq key
        langsmith_key=LANGSMITH_API_KEY
    )
    # Load configuration with custom parameters
    config = load_config()
    if not config or "configurable" not in config or not config["configurable"].get("graph"):
        st.error("Failed to load configuration or connect to Neo4j. Cannot proceed.")
        return None, []
    # Create status containers (placeholders updated as the pipeline advances)
    plan_status = st.empty()
    plan_display = st.empty()
    retrieval_status = st.empty()
    processing_status = st.empty()
    # Build planner graph
    plan_status.info("Building planner graph...")
    # Pass the full config dictionary to build_planner_graph
    graph = build_planner_graph(memory, config)
    # Execute initial prompt generation
    plan_status.info(f"Generating plan for query: {user_query}")
    messages_content = []
    initial_prompt_data = get_initial_prompt(config, user_query)
    # Stream initial plan generation (runs until the graph's first interrupt)
    try:
        for event in graph.stream(initial_prompt_data, config, stream_mode="values"):
            if "messages" in event and event["messages"]:
                event["messages"][-1].pretty_print()
                messages_content.append(event["messages"][-1].content)
            # Add checks for specific nodes if needed for status updates
            # if "__start__" in event: # Example check
            #     plan_status.info("Starting plan generation...")
    except Exception as e:
        st.error(f"Error during initial graph stream: {e}")
        return None, []
    # Get the state with the generated plan (after initial stream/interrupt)
    try:
        # Ensure thread_id matches what's used internally if applicable
        state = graph.get_state(config)
        # Check if 'store_plan' exists and is a list
        stored_plan = state.values.get('store_plan', [])
        if isinstance(stored_plan, list) and stored_plan:
            steps = [i for i in range(1, len(stored_plan)+1)]
            plan_df = pd.DataFrame({'Plan steps': steps, 'Description': stored_plan})
            plan_status.success("Plan generation complete!")
            plan_display.dataframe(plan_df, use_container_width=True)
        else:
            plan_status.warning("Plan not found or empty in graph state after generation.")
            plan_display.empty() # Clear display if no plan
    except Exception as e:
        st.error(f"Error getting graph state or displaying plan: {e}")
        return None, []
    # Continue with plan execution for document retrieval
    # This part assumes the graph will continue after the first interrupt
    retrieval_status.info("Retrieving documents...")
    try:
        # Stream from the current state (None indicates continue)
        for event in graph.stream(None, config, stream_mode="values"):
            if "messages" in event and event["messages"]:
                event["messages"][-1].pretty_print()
                messages_content.append(event["messages"][-1].content)
            # Add checks for nodes like 'human_validation' if needed for status
    except Exception as e:
        st.error(f"Error during document retrieval stream: {e}")
        return None, []
    # Get updated state after document retrieval interrupt
    try:
        snapshot = graph.get_state(config)
        valid_docs_retrieved = snapshot.values.get('valid_docs', [])
        doc_count = len(valid_docs_retrieved) if isinstance(valid_docs_retrieved, list) else 0
        retrieval_status.success(f"Retrieved {doc_count} documents")
        # --- Human Validation / Processing Steps ---
        # This section needs interaction logic if manual validation is desired.
        # For now, setting default processing steps and marking as validated.
        processing_status.info("Processing documents...")
        process_steps = ["summarize"] # Default: just summarize
        # Update state to indicate human validation is complete and specify processing steps
        # This should happen *before* the next stream call that triggers processing
        graph.update_state(config, {'human_validated': True, 'process_steps': process_steps})
    except Exception as e:
        st.error(f"Error getting state after retrieval or setting up processing: {e}")
        return None, []
    # Continue execution with document processing
    try:
        for event in graph.stream(None, config, stream_mode="values"):
            if "messages" in event and event["messages"]:
                event["messages"][-1].pretty_print()
                messages_content.append(event["messages"][-1].content)
            # Check for the end node or final chatbot node if needed
    except Exception as e:
        st.error(f"Error during document processing stream: {e}")
        return None, []
    # Get final state after processing
    try:
        final_snapshot = graph.get_state(config)
        processing_status.success("Document processing complete!")
        # Extract final result and documents
        final_result = None
        valid_docs_final = []
        if "messages" in final_snapshot.values and final_snapshot.values["messages"]:
            # Assume the last message contains the final result
            final_result = final_snapshot.values["messages"][-1].content
        # Get the final state of valid_docs (might be processed summaries)
        valid_docs_final = final_snapshot.values.get('valid_docs', [])
        if not isinstance(valid_docs_final, list): # Ensure it's a list
            valid_docs_final = []
        return final_result, valid_docs_final
    except Exception as e:
        st.error(f"Error getting final state or extracting results: {e}")
        return None, []
# App header
st.title("Key Issue Generator")
st.write("Generate key issues from a Neo4j knowledge graph using advanced language models.")
# Check database connectivity once at startup; the result (True or an
# error string) is reused by the button handler below.
connectivity_status = verify_neo4j_connectivity()
st.sidebar.header("Database Status")
# Use boolean check (connectivity_status is a string on failure)
if connectivity_status is True:
    st.sidebar.success("Connected to Neo4j database")
else:
    # Display the error message returned
    st.sidebar.error(f"Database connection issue: {connectivity_status}")
# User input section
st.header("Enter Your Query")
user_query = st.text_area("What would you like to explore?",
                          "What are the main challenges in AI adoption for healthcare systems?",
                          height=100)
# Process button: validates prerequisites, runs the pipeline, renders results
if st.button("Generate Key Issues", type="primary"):
    # Update API key check for Gemini
    if not OPENAI_API_KEY or not GEMINI_API_KEY or not LANGSMITH_API_KEY or not NEO4J_PASSWORD:
        st.error("Required API keys (OpenAI, Gemini, Langsmith) or database credentials are missing. Please check your environment variables.")
    elif connectivity_status is not True: # Check DB connection again before starting
        st.error(f"Cannot start: Neo4j connection issue: {connectivity_status}")
    else:
        with st.spinner("Processing your query..."):
            start_time = time.time()
            # Call the main generation function
            final_result, valid_docs = generate_key_issues(user_query)
            end_time = time.time()
            if final_result is not None: # Check if result is not None (indicating success)
                # Display execution time
                st.sidebar.info(f"Total execution time: {round(end_time - start_time, 2)} seconds")
                # Display final result
                st.header("Generated Key Issues")
                st.markdown(final_result)
                # Option to download results
                st.download_button(
                    label="Download Results",
                    data=final_result, # Ensure final_result is string data
                    file_name="key_issues_results.txt",
                    mime="text/plain"
                )
                # Display retrieved/processed documents in expandable section
                if valid_docs:
                    with st.expander("View Processed Documents"): # Update title
                        for i, doc in enumerate(valid_docs):
                            st.markdown(f"### Document {i+1}")
                            # Handle doc format (could be string summary or original dict)
                            if isinstance(doc, dict):
                                for key in doc:
                                    st.markdown(f"**{key}**: {doc[key]}")
                            elif isinstance(doc, str):
                                st.markdown(doc) # Display string directly if it's a summary
                            else:
                                st.markdown(str(doc)) # Fallback for other types
                            st.divider()
            else:
                # Error messages are now shown within generate_key_issues
                # st.error("An error occurred during processing. Please check the logs or console output for details.")
                # Adding a placeholder here in case specific errors weren't caught
                if final_result is None: # Check explicit None return
                    st.error("Processing failed. Please check the console/logs for errors.")
# Help information in sidebar
with st.sidebar:
    st.header("About")
    st.info("""
    This application uses advanced language models (like Google Gemini) to analyze a Neo4j knowledge graph
    and generate key issues based on your query. The process involves:
    1. Creating a plan based on your query
    2. Retrieving relevant documents from the database
    3. Processing and summarizing the information
    4. Generating a comprehensive response
    """)