Spaces:
Sleeping
Sleeping
Upload 3 files
Browse files- app.py +42 -11
- retrieval.py +11 -1
app.py
CHANGED
@@ -27,28 +27,59 @@ time_taken_for_response = 'N/A'
|
|
27 |
st.subheader("Hi, What do you want to know today?")
|
28 |
question = st.text_area("Enter your question:", placeholder="Type your question here...", height=100)
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
# Submit Button
|
31 |
if st.button("Submit"):
|
32 |
start_time = time.time()
|
33 |
-
retrieved_documents = retrieve_documents_hybrid(question, 10)
|
34 |
-
response = generate_response_from_document(question, retrieved_documents)
|
35 |
end_time = time.time()
|
36 |
-
time_taken_for_response = end_time-start_time
|
37 |
-
else:
|
38 |
-
response = ""
|
39 |
|
40 |
-
#
|
41 |
st.subheader("Response")
|
42 |
-
st.text_area("Generated Response:", value=response, height=150, disabled=True)
|
43 |
-
|
44 |
-
# Metrics Section
|
45 |
-
st.subheader("Metrics")
|
46 |
|
47 |
col1, col2 = st.columns([1, 3]) # Creating two columns for button and metrics display
|
48 |
|
|
|
49 |
with col1:
|
50 |
if st.button("Calculate Metrics"):
|
51 |
-
metrics = calculate_metrics(question, response, retrieved_documents, time_taken_for_response)
|
52 |
else:
|
53 |
metrics = ""
|
54 |
|
|
|
27 |
st.subheader("Hi, What do you want to know today?")
|
28 |
question = st.text_area("Enter your question:", placeholder="Type your question here...", height=100)
|
29 |
|
30 |
+
# # Submit Button
|
31 |
+
# if st.button("Submit"):
|
32 |
+
# start_time = time.time()
|
33 |
+
# retrieved_documents = retrieve_documents_hybrid(question, 10)
|
34 |
+
# response = generate_response_from_document(question, retrieved_documents)
|
35 |
+
# end_time = time.time()
|
36 |
+
# time_taken_for_response = end_time-start_time
|
37 |
+
# else:
|
38 |
+
# response = ""
|
39 |
+
|
40 |
+
# # Response Section
|
41 |
+
# st.subheader("Response")
|
42 |
+
# st.text_area("Generated Response:", value=response, height=150, disabled=True)
|
43 |
+
|
44 |
+
# # Metrics Section
|
45 |
+
# st.subheader("Metrics")
|
46 |
+
|
47 |
+
# col1, col2 = st.columns([1, 3]) # Creating two columns for button and metrics display
|
48 |
+
|
49 |
+
# with col1:
|
50 |
+
# if st.button("Calculate Metrics"):
|
51 |
+
# metrics = calculate_metrics(question, response, retrieved_documents, time_taken_for_response)
|
52 |
+
# else:
|
53 |
+
# metrics = ""
|
54 |
+
|
55 |
+
# with col2:
|
56 |
+
# st.text_area("Metrics:", value=metrics, height=100, disabled=True)
|
57 |
+
|
58 |
+
if "retrieved_documents" not in st.session_state:
|
59 |
+
st.session_state.retrieved_documents = []
|
60 |
+
if "response" not in st.session_state:
|
61 |
+
st.session_state.response = ""
|
62 |
+
if "time_taken_for_response" not in st.session_state:
|
63 |
+
st.session_state.time_taken_for_response = "N/A"
|
64 |
+
|
65 |
# Submit Button
|
66 |
if st.button("Submit"):
|
67 |
start_time = time.time()
|
68 |
+
st.session_state.retrieved_documents = retrieve_documents_hybrid(question, 10)
|
69 |
+
st.session_state.response = generate_response_from_document(question, st.session_state.retrieved_documents)
|
70 |
end_time = time.time()
|
71 |
+
st.session_state.time_taken_for_response = end_time - start_time
|
|
|
|
|
72 |
|
73 |
+
# Display stored response
|
74 |
st.subheader("Response")
|
75 |
+
st.text_area("Generated Response:", value=st.session_state.response, height=150, disabled=True)
|
|
|
|
|
|
|
76 |
|
77 |
col1, col2 = st.columns([1, 3]) # Creating two columns for button and metrics display
|
78 |
|
79 |
+
# Calculate Metrics Button
|
80 |
with col1:
|
81 |
if st.button("Calculate Metrics"):
|
82 |
+
metrics = calculate_metrics(question, st.session_state.response, st.session_state.retrieved_documents, st.session_state.time_taken_for_response)
|
83 |
else:
|
84 |
metrics = ""
|
85 |
|
retrieval.py
CHANGED
@@ -4,6 +4,9 @@ from langchain.schema import Document
|
|
4 |
import faiss
|
5 |
from rank_bm25 import BM25Okapi
|
6 |
from data_processing import embedding_model #, index, actual_docs
|
|
|
|
|
|
|
7 |
|
8 |
retrieved_docs = None
|
9 |
|
@@ -36,8 +39,10 @@ def retrieve_documents_hybrid(query, top_k=5):
|
|
36 |
|
37 |
# Merge FAISS + BM25 Results
|
38 |
retrieved_docs = list(set(faiss_docs + bm25_docs))[:top_k]
|
|
|
|
|
39 |
|
40 |
-
return
|
41 |
|
42 |
# Retrieval Function
|
43 |
def retrieve_documents(query, top_k=5):
|
@@ -80,3 +85,8 @@ def find_query_dataset(query):
|
|
80 |
best_dataset = dataset_names[nearest_index[0][0]]
|
81 |
return best_dataset
|
82 |
|
|
|
|
|
|
|
|
|
|
|
|
4 |
import faiss
|
5 |
from rank_bm25 import BM25Okapi
|
6 |
from data_processing import embedding_model #, index, actual_docs
|
7 |
+
from sentence_transformers import CrossEncoder
|
8 |
+
|
9 |
+
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
|
10 |
|
11 |
retrieved_docs = None
|
12 |
|
|
|
39 |
|
40 |
# Merge FAISS + BM25 Results
|
41 |
retrieved_docs = list(set(faiss_docs + bm25_docs))[:top_k]
|
42 |
+
|
43 |
+
reranked_docs = rerank_documents(query, retrieved_docs)
|
44 |
|
45 |
+
return reranked_docs
|
46 |
|
47 |
# Retrieval Function
|
48 |
def retrieve_documents(query, top_k=5):
|
|
|
85 |
best_dataset = dataset_names[nearest_index[0][0]]
|
86 |
return best_dataset
|
87 |
|
88 |
+
def rerank_documents(query, retrieved_docs):
|
89 |
+
doc_texts = [doc for doc in retrieved_docs]
|
90 |
+
scores = reranker.predict([[query, doc] for doc in doc_texts])
|
91 |
+
ranked_docs = [doc for _, doc in sorted(zip(scores, retrieved_docs), reverse=True)]
|
92 |
+
return ranked_docs[:5] # Return top k most relevant
|