import os
import streamlit as st
from langchain_community.graphs import Neo4jGraph
import pandas as pd
import json
import time

from ki_gen.planner import build_planner_graph
from ki_gen.utils import init_app, memory, ConfigSchema, State
from ki_gen.prompts import get_initial_prompt
from neo4j import GraphDatabase
# Set page config
st.set_page_config(page_title="Key Issue Generator", layout="wide")

# Neo4j database configuration
NEO4J_URI = "neo4j+s://4985272f.databases.neo4j.io"
NEO4J_USERNAME = "neo4j"
NEO4J_PASSWORD = os.getenv("neo4j_password")

# API keys for LLM services (Groq has been replaced by Gemini)
OPENAI_API_KEY = os.getenv("openai_api_key")
GEMINI_API_KEY = os.getenv("gemini_api_key")
LANGSMITH_API_KEY = os.getenv("langsmith_api_key")
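# The credentials above come from environment variables; on a typical deployment
# they would be set before launching the app, e.g. (shell, illustrative):
#   export neo4j_password="..."
#   export openai_api_key="..."
#   export gemini_api_key="..."
#   export langsmith_api_key="..."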
def verify_neo4j_connectivity():
    """Verify the connection to the Neo4j database.

    Returns True on success, or an error string describing the failure.
    """
    try:
        # The context manager guarantees the driver is closed
        with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
            driver.verify_connectivity()
        return True
    except Exception as e:
        return f"Error: {str(e)}"
def load_config() -> ConfigSchema | None:
    """Load the application configuration with custom parameters.

    Returns None if the Neo4j database is unreachable.
    """
    custom_config = {
        "main_llm": "gemini-2.0-flash",
        "plan_method": "generation",
        "use_detailed_query": False,
        "cypher_gen_method": "guided",
        "validate_cypher": False,
        "summarize_model": "gemini-2.0-flash",
        "eval_method": "binary",
        "eval_threshold": 0.7,
        "max_docs": 15,
        "compression_method": "llm_lingua",
        "compress_rate": 0.33,
        "force_tokens": ["."],  # List format, as expected by the application
        "eval_model": "gemini-2.0-flash",
        "thread_id": "3",  # Consider making this dynamic or user-specific (see sketch below)
    }
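    # A minimal sketch of a per-session thread_id (an assumption, not wired in):
    # giving each Streamlit session its own LangGraph thread keeps concurrent
    # users from sharing checkpointer state.
    #
    # import uuid
    # if "thread_id" not in st.session_state:
    #     st.session_state["thread_id"] = str(uuid.uuid4())
    # custom_config["thread_id"] = st.session_state["thread_id"]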
    # Attach the Neo4j graph object to the config
    neo_graph = None
    try:
        # Verify connectivity once and reuse the result
        connectivity = verify_neo4j_connectivity()
        if connectivity is True:
            neo_graph = Neo4jGraph(
                url=NEO4J_URI,
                username=NEO4J_USERNAME,
                password=NEO4J_PASSWORD,
            )
            custom_config["graph"] = neo_graph
        else:
            st.error(f"Neo4j connection issue: {connectivity}")
            # The graph is essential, so bail out early
            return None
    except Exception as e:
        st.error(f"Error creating Neo4jGraph object: {e}")
        return None

    # Wrap in a 'configurable' key, as expected by LangGraph
    return {"configurable": custom_config}
def generate_key_issues(user_query):
    """Main function to generate key issues from Neo4j data."""
    # Initialize the application with API keys (the Groq key is no longer needed)
    init_app(
        openai_key=OPENAI_API_KEY,
        langsmith_key=LANGSMITH_API_KEY,
    )

    # Load configuration with custom parameters
    config = load_config()
    if not config or "configurable" not in config or not config["configurable"].get("graph"):
        st.error("Failed to load configuration or connect to Neo4j. Cannot proceed.")
        return None, []
    # Create status containers
    plan_status = st.empty()
    plan_display = st.empty()
    retrieval_status = st.empty()
    processing_status = st.empty()

    # Build the planner graph, passing the full config dictionary
    plan_status.info("Building planner graph...")
    graph = build_planner_graph(memory, config)

    # Generate the initial prompt for plan generation
    plan_status.info(f"Generating plan for query: {user_query}")
    messages_content = []
    initial_prompt_data = get_initial_prompt(config, user_query)
    # Stream the initial plan generation
    try:
        for event in graph.stream(initial_prompt_data, config, stream_mode="values"):
            if "messages" in event and event["messages"]:
                event["messages"][-1].pretty_print()
                messages_content.append(event["messages"][-1].content)
            # Node-specific checks (e.g. on "__start__") could drive
            # finer-grained status updates here
    except Exception as e:
        st.error(f"Error during initial graph stream: {e}")
        return None, []
    # Get the state with the generated plan (after the initial stream/interrupt).
    # The thread_id in config must match the one used by the checkpointer.
    try:
        state = graph.get_state(config)
        stored_plan = state.values.get("store_plan", [])
        if isinstance(stored_plan, list) and stored_plan:
            steps = list(range(1, len(stored_plan) + 1))
            plan_df = pd.DataFrame({"Plan steps": steps, "Description": stored_plan})
            plan_status.success("Plan generation complete!")
            plan_display.dataframe(plan_df, use_container_width=True)
        else:
            plan_status.warning("Plan not found or empty in graph state after generation.")
            plan_display.empty()  # Clear the display if there is no plan
    except Exception as e:
        st.error(f"Error getting graph state or displaying plan: {e}")
        return None, []
    # Continue plan execution for document retrieval.
    # The graph is expected to resume after the first interrupt.
    retrieval_status.info("Retrieving documents...")
    try:
        # Streaming with None as input resumes from the current state
        for event in graph.stream(None, config, stream_mode="values"):
            if "messages" in event and event["messages"]:
                event["messages"][-1].pretty_print()
                messages_content.append(event["messages"][-1].content)
    except Exception as e:
        st.error(f"Error during document retrieval stream: {e}")
        return None, []
    # Get the updated state after the document retrieval interrupt
    try:
        snapshot = graph.get_state(config)
        valid_docs_retrieved = snapshot.values.get("valid_docs", [])
        doc_count = len(valid_docs_retrieved) if isinstance(valid_docs_retrieved, list) else 0
        retrieval_status.success(f"Retrieved {doc_count} documents")

        # --- Human validation / processing steps ---
        # Interactive validation would require extra UI logic; for now we use a
        # default processing step and mark validation as complete. See the
        # commented sketch after this block for one way to make it interactive.
        processing_status.info("Processing documents...")
        process_steps = ["summarize"]  # Default: just summarize

        # The state update must happen *before* the next stream call that
        # triggers processing
        graph.update_state(config, {"human_validated": True, "process_steps": process_steps})
    except Exception as e:
        st.error(f"Error getting state after retrieval or setting up processing: {e}")
        return None, []
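    # A hedged sketch of interactive validation (hypothetical, not wired in):
    # Streamlit widgets could let the user choose processing steps before the
    # state update above. Real code would need st.session_state to survive the
    # rerun that widget interaction triggers.
    #
    # process_steps = st.multiselect(
    #     "Processing steps to apply",
    #     options=["summarize", "compress"],  # step names are assumptions
    #     default=["summarize"],
    # )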
    # Continue execution with document processing
    try:
        for event in graph.stream(None, config, stream_mode="values"):
            if "messages" in event and event["messages"]:
                event["messages"][-1].pretty_print()
                messages_content.append(event["messages"][-1].content)
    except Exception as e:
        st.error(f"Error during document processing stream: {e}")
        return None, []
    # Get the final state after processing
    try:
        final_snapshot = graph.get_state(config)
        processing_status.success("Document processing complete!")

        # Extract the final result and documents
        final_result = None
        if "messages" in final_snapshot.values and final_snapshot.values["messages"]:
            # The last message is assumed to contain the final result
            final_result = final_snapshot.values["messages"][-1].content

        # valid_docs may now hold processed summaries rather than raw documents
        valid_docs_final = final_snapshot.values.get("valid_docs", [])
        if not isinstance(valid_docs_final, list):
            valid_docs_final = []

        return final_result, valid_docs_final
    except Exception as e:
        st.error(f"Error getting final state or extracting results: {e}")
        return None, []
# App header
st.title("Key Issue Generator")
st.write("Generate key issues from a Neo4j knowledge graph using advanced language models.")

# Check database connectivity
connectivity_status = verify_neo4j_connectivity()
st.sidebar.header("Database Status")
if connectivity_status is True:
    st.sidebar.success("Connected to Neo4j database")
else:
    st.sidebar.error(f"Database connection issue: {connectivity_status}")

# User input section
st.header("Enter Your Query")
user_query = st.text_area(
    "What would you like to explore?",
    "What are the main challenges in AI adoption for healthcare systems?",
    height=100,
)
# Process button
if st.button("Generate Key Issues", type="primary"):
    if not OPENAI_API_KEY or not GEMINI_API_KEY or not LANGSMITH_API_KEY or not NEO4J_PASSWORD:
        st.error(
            "Required API keys (OpenAI, Gemini, LangSmith) or database credentials "
            "are missing. Please check your environment variables."
        )
    elif connectivity_status is not True:  # Re-check the DB connection before starting
        st.error(f"Cannot start: Neo4j connection issue: {connectivity_status}")
    else:
        with st.spinner("Processing your query..."):
            start_time = time.time()
            final_result, valid_docs = generate_key_issues(user_query)
            end_time = time.time()

            if final_result is not None:
                # Display execution time
                st.sidebar.info(f"Total execution time: {round(end_time - start_time, 2)} seconds")

                # Display the final result
                st.header("Generated Key Issues")
                st.markdown(final_result)

                # Option to download results
                st.download_button(
                    label="Download Results",
                    data=final_result,
                    file_name="key_issues_results.txt",
                    mime="text/plain",
                )

                # Display retrieved/processed documents in an expandable section
                if valid_docs:
                    with st.expander("View Processed Documents"):
                        for i, doc in enumerate(valid_docs):
                            st.markdown(f"### Document {i+1}")
                            # Documents may be dicts (raw) or strings (summaries)
                            if isinstance(doc, dict):
                                for key in doc:
                                    st.markdown(f"**{key}**: {doc[key]}")
                            elif isinstance(doc, str):
                                st.markdown(doc)
                            else:
                                st.markdown(str(doc))  # Fallback for other types
                            st.divider()
            else:
                # Specific error messages are shown inside generate_key_issues;
                # this is a generic fallback for anything not caught there.
                st.error("Processing failed. Please check the console/logs for errors.")
# Help information in sidebar
with st.sidebar:
    st.header("About")
    st.info("""
    This application uses advanced language models (like Google Gemini) to analyze a Neo4j
    knowledge graph and generate key issues based on your query. The process involves:

    1. Creating a plan based on your query
    2. Retrieving relevant documents from the database
    3. Processing and summarizing the information
    4. Generating a comprehensive response
    """)