# 23RAG7 / app.py
# cb1716pics's picture
# Upload 2 files
# 75c991a verified
# raw
# history blame
# 4.97 kB
import streamlit as st
from generator import generate_response_from_document
from retrieval import retrieve_documents_hybrid,find_query_dataset
from evaluation import calculate_metrics
from data_processing import load_recent_questions, save_recent_questions
import time
import matplotlib.pyplot as plt
# Page Title
st.title("RAG7 - Real World RAG System")

# CSS override applied to every Streamlit text area on the page: white
# background, black 24px text. It is injected as raw HTML, which is why
# unsafe_allow_html must be enabled for the markdown call below.
_text_area_css = """
<style>
.stTextArea textarea {
background-color: white !important;
font-size: 24px !important;
color: black !important;
}
</style>
"""
st.markdown(_text_area_css, unsafe_allow_html=True)
# Initialize session state
# Streamlit re-executes this script on every interaction, so each key is
# seeded only when it is absent; values written by earlier runs are kept.
# The dict literal is rebuilt on every run, so the empty list used for
# "recent_questions" is always a fresh object.
_session_defaults = {
    # History of asked questions; loading persisted history through
    # load_recent_questions() is currently disabled.
    "recent_questions": [],
    "last_question": None,
    "response_time": None,
    "retrieved_documents": None,
    "response": None,
    "metrics": None,
    # The two keys below are written by the Submit handler and read by the
    # "Show Metrics" handler. Seeding them here keeps that read from raising
    # AttributeError when metrics are requested before any submission.
    "query_dataset": None,
    "time_taken_for_response": None,
}
for _state_key, _state_default in _session_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Sidebar: overall RMSE, response-time chart and the question history.
if st.session_state.recent_questions:
    # Newest question first.
    recent_qns = list(reversed(st.session_state.recent_questions))

    # Overall RMSE
    st.sidebar.title("Overall RMSE")
    # An entry's metrics stay empty until "Show Metrics" has been used for it,
    # so only entries that actually carry an RMSE value are averaged.
    rmse_values = [
        q["metrics"]["RMSE"]
        for q in recent_qns
        if q.get("metrics") and q["metrics"].get("RMSE") is not None
    ]
    # Test for "has values", not for truthiness of the values themselves:
    # an RMSE of 0 is a legitimate score and must still be reported.
    if rmse_values:
        average_rmse = sum(rmse_values) / len(rmse_values)
        st.sidebar.write(f"📊 **Average RMSE:** {average_rmse:.4f} for {len(rmse_values)} questions")
    st.sidebar.markdown("---")

    st.sidebar.title("Analytics")
    # Extract response times and labels; `or {}` tolerates entries whose
    # metrics are missing or None.
    response_time = [(q.get("metrics") or {}).get("response_time") for q in recent_qns]
    labels = [f"Q{i}" for i in range(1, len(response_time) + 1)]
    # Plot graph only when at least one response time has been recorded.
    if any(response_time):
        # Unmeasured questions become NaN so they appear as gaps in the line
        # while every label keeps its position on the x axis.
        plotted_times = [t if t is not None else float("nan") for t in response_time]
        fig, ax = plt.subplots()
        ax.plot(labels, plotted_times, marker="o", linestyle="-", color="skyblue")
        ax.set_xlabel("Recent Questions")
        ax.set_ylabel("Time Taken for Response (seconds)")
        ax.set_title("Response Time Analysis")
        st.sidebar.pyplot(fig)
        # pyplot keeps every figure alive until it is closed; without this a
        # new figure would accumulate on each script rerun.
        plt.close(fig)
    st.sidebar.markdown("---")

    # Display Recent Questions
    st.sidebar.title("Recent Questions")
    for q in recent_qns:  # Show latest first
        st.sidebar.write(f"🔹 {q['question']}")
else:
    st.sidebar.title("No recent questions")
# Question Section
st.subheader("Hi, What do you want to know today?")
question = st.text_area("Enter your question:", placeholder="Type your question here...", height=100)
question = question.strip()

if st.button("Submit"):
    if not question:
        st.error("Please enter a question before submitting.")
    else:
        st.session_state.last_question = question
        # Clear the previous answer and metrics before the pipeline runs.
        st.session_state.metrics = {}
        st.session_state.response = ""

        # perf_counter() is monotonic, so the measured duration cannot be
        # skewed by wall-clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        st.session_state.query_dataset = find_query_dataset(question)
        st.session_state.retrieved_documents = retrieve_documents_hybrid(
            question, st.session_state.query_dataset, 5
        )
        st.session_state.response = generate_response_from_document(
            question, st.session_state.retrieved_documents
        )
        st.session_state.time_taken_for_response = time.perf_counter() - start_time

        # Record the question in the history only once; a repeated question
        # keeps its existing history entry.
        already_asked = any(
            entry["question"] == question for entry in st.session_state.recent_questions
        )
        if not already_asked:
            st.session_state.recent_questions.append(
                {
                    "question": question,
                    "metrics": st.session_state.metrics,
                }
            )
            # Persisting through save_recent_questions() is currently disabled.
# Display stored response
st.subheader("Response")
st.text_area("Generated Response:", value=st.session_state.response, height=150, disabled=True)

col1, col2 = st.columns([1, 3])  # Creating two columns for button and metrics display

# Calculate Metrics Button
with col1:
    if st.button("Show Metrics"):
        asked_question = st.session_state.last_question
        # time_taken_for_response is written only after a submission has run
        # to completion, so it doubles as the "there is something to
        # evaluate" signal. .get() avoids AttributeError when the key was
        # never created.
        elapsed = st.session_state.get("time_taken_for_response")
        if asked_question is None or elapsed is None:
            st.warning("Submit a question before requesting metrics.")
        else:
            # Evaluate the question that produced the stored response, not the
            # current text-area content, which may have been edited since.
            st.session_state.metrics = calculate_metrics(
                asked_question,
                st.session_state.get("query_dataset"),
                st.session_state.response,
                st.session_state.retrieved_documents,
                elapsed,
            )
            # Attach the fresh metrics to the matching history entry.
            for q in st.session_state.recent_questions:
                if q["question"] == asked_question:
                    q["metrics"] = st.session_state.metrics
            # Persisting through save_recent_questions() is currently disabled.

with col2:
    # Show an empty JSON object until metrics have been calculated.
    metrics_ = st.session_state.metrics if st.session_state.metrics is not None else {}
    st.json(metrics_)