cb1716pics committed
Commit f78495c · verified · 1 Parent(s): c053f96

Upload 3 files

Files changed (2)
  1. app.py +42 -11
  2. retrieval.py +11 -1
app.py CHANGED
@@ -27,28 +27,59 @@ time_taken_for_response = 'N/A'
 st.subheader("Hi, What do you want to know today?")
 question = st.text_area("Enter your question:", placeholder="Type your question here...", height=100)
 
+# # Submit Button
+# if st.button("Submit"):
+#     start_time = time.time()
+#     retrieved_documents = retrieve_documents_hybrid(question, 10)
+#     response = generate_response_from_document(question, retrieved_documents)
+#     end_time = time.time()
+#     time_taken_for_response = end_time-start_time
+# else:
+#     response = ""
+
+# # Response Section
+# st.subheader("Response")
+# st.text_area("Generated Response:", value=response, height=150, disabled=True)
+
+# # Metrics Section
+# st.subheader("Metrics")
+
+# col1, col2 = st.columns([1, 3]) # Creating two columns for button and metrics display
+
+# with col1:
+#     if st.button("Calculate Metrics"):
+#         metrics = calculate_metrics(question, response, retrieved_documents, time_taken_for_response)
+#     else:
+#         metrics = ""
+
+# with col2:
+#     st.text_area("Metrics:", value=metrics, height=100, disabled=True)
+
+if "retrieved_documents" not in st.session_state:
+    st.session_state.retrieved_documents = []
+if "response" not in st.session_state:
+    st.session_state.response = ""
+if "time_taken_for_response" not in st.session_state:
+    st.session_state.time_taken_for_response = "N/A"
+
 # Submit Button
 if st.button("Submit"):
     start_time = time.time()
-    retrieved_documents = retrieve_documents_hybrid(question, 10)
-    response = generate_response_from_document(question, retrieved_documents)
+    st.session_state.retrieved_documents = retrieve_documents_hybrid(question, 10)
+    st.session_state.response = generate_response_from_document(question, st.session_state.retrieved_documents)
     end_time = time.time()
-    time_taken_for_response = end_time-start_time
-else:
-    response = ""
+    st.session_state.time_taken_for_response = end_time - start_time
 
-# Response Section
+# Display stored response
 st.subheader("Response")
-st.text_area("Generated Response:", value=response, height=150, disabled=True)
-
-# Metrics Section
-st.subheader("Metrics")
+st.text_area("Generated Response:", value=st.session_state.response, height=150, disabled=True)
 
 col1, col2 = st.columns([1, 3]) # Creating two columns for button and metrics display
 
+# Calculate Metrics Button
 with col1:
     if st.button("Calculate Metrics"):
-        metrics = calculate_metrics(question, response, retrieved_documents, time_taken_for_response)
+        metrics = calculate_metrics(question, st.session_state.response, st.session_state.retrieved_documents, st.session_state.time_taken_for_response)
     else:
         metrics = ""
 
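The app.py change works around Streamlit's rerun model: every button click re-executes the whole script, so values computed inside the `if st.button("Submit"):` branch (the retrieved documents, the generated response, the timing) are gone by the time "Calculate Metrics" is pressed. Persisting them in `st.session_state` keeps them available across reruns. Below is a minimal sketch of that pattern, separate from this repository; the `expensive_lookup` helper is a placeholder standing in for a call like `retrieve_documents_hybrid`.

```python
import time

import streamlit as st


def expensive_lookup(query: str) -> list:
    """Placeholder for a slow retrieval call such as retrieve_documents_hybrid(query, 10)."""
    time.sleep(1)
    return [f"document matching {query!r}"]


# Initialise persistent slots once per browser session
if "docs" not in st.session_state:
    st.session_state.docs = []

query = st.text_input("Query")

if st.button("Search"):
    # Runs only during the rerun triggered by this click
    st.session_state.docs = expensive_lookup(query)

if st.button("Show cached count"):
    # A later click reruns the whole script, but the stored value survives in session_state
    st.write(f"{len(st.session_state.docs)} documents cached")
```

Without `session_state`, the second button's rerun would not see the retrieval results produced by the first one, which is exactly the failure the commented-out version of app.py had with `calculate_metrics`.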
retrieval.py CHANGED
@@ -4,6 +4,9 @@ from langchain.schema import Document
 import faiss
 from rank_bm25 import BM25Okapi
 from data_processing import embedding_model #, index, actual_docs
+from sentence_transformers import CrossEncoder
+
+reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
 
 retrieved_docs = None
 
@@ -36,8 +39,10 @@ def retrieve_documents_hybrid(query, top_k=5):
 
     # Merge FAISS + BM25 Results
     retrieved_docs = list(set(faiss_docs + bm25_docs))[:top_k]
+
+    reranked_docs = rerank_documents(query, retrieved_docs)
 
-    return retrieved_docs
+    return reranked_docs
 
 # Retrieval Function
 def retrieve_documents(query, top_k=5):
@@ -80,3 +85,8 @@ def find_query_dataset(query):
     best_dataset = dataset_names[nearest_index[0][0]]
     return best_dataset
 
+def rerank_documents(query, retrieved_docs):
+    doc_texts = [doc for doc in retrieved_docs]
+    scores = reranker.predict([[query, doc] for doc in doc_texts])
+    ranked_docs = [doc for _, doc in sorted(zip(scores, retrieved_docs), reverse=True)]
+    return ranked_docs[:5]  # Return top k most relevant
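retrieval.py now adds a cross-encoder reranking pass after the FAISS and BM25 candidates are merged: `cross-encoder/ms-marco-MiniLM-L-6-v2` scores each (query, document) pair jointly, and the candidates are re-sorted by that score before the top results are returned. Note that `rerank_documents` passes the retrieved items straight to `predict()`, so it assumes they are plain strings rather than `langchain` `Document` objects. A standalone sketch of the same scoring call follows; the query and passages are illustrative only.

```python
from sentence_transformers import CrossEncoder

# Same model as in retrieval.py; the weights are downloaded on first use
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

query = "How does hybrid retrieval work?"
candidates = [
    "Hybrid retrieval merges sparse BM25 matches with dense embedding search results.",
    "FAISS provides fast similarity search over dense vectors.",
    "Tomorrow's weather forecast is sunny with light winds.",
]

# predict() takes (query, passage) pairs and returns one relevance score per pair
scores = reranker.predict([(query, passage) for passage in candidates])

# Sort by score only (highest first) so ties never fall back to comparing the passage texts
ranked = [passage for _, passage in sorted(zip(scores, candidates), key=lambda pair: pair[0], reverse=True)]

print(ranked[0])  # Most relevant passage for the query
```

Sorting on the score alone (via `key=lambda pair: pair[0]`) is a slightly more defensive variant of the committed `sorted(zip(scores, retrieved_docs), reverse=True)`, which compares the documents themselves whenever two scores tie.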