Spaces:
Build error
Build error
Update src/chat.py
Browse files- src/chat.py +33 -0
src/chat.py
CHANGED
@@ -16,6 +16,9 @@ class SchoolChatbot:
|
|
16 |
"""
|
17 |
model_id = MY_MODEL if MY_MODEL else BASE_MODEL # define MY_MODEL in config.py if you create a new model in the HuggingFace Hub
|
18 |
self.client = InferenceClient(model=model_id, token=HF_TOKEN)
|
|
|
|
|
|
|
19 |
|
20 |
def format_prompt(self, user_input):
|
21 |
"""
|
@@ -42,6 +45,25 @@ class SchoolChatbot:
|
|
42 |
f"<|user|>{user_input}<|end|>\n"
|
43 |
"<|assistant|>"
|
44 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
def get_response(self, user_input):
|
47 |
"""
|
@@ -62,6 +84,17 @@ class SchoolChatbot:
|
|
62 |
- Use self.format_prompt() to format the user's input
|
63 |
- Use self.client to generate responses
|
64 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
prompt = self.format_prompt(user_input)
|
66 |
response = self.client.text_generation(
|
67 |
prompt,
|
|
|
16 |
"""
|
17 |
model_id = MY_MODEL if MY_MODEL else BASE_MODEL # define MY_MODEL in config.py if you create a new model in the HuggingFace Hub
|
18 |
self.client = InferenceClient(model=model_id, token=HF_TOKEN)
|
19 |
+
self.df = pd.read_csv("bps_data.csv")
|
20 |
+
with open("keyword_to_column_map.json") as f:
|
21 |
+
self.keyword_map = json.load(f)
|
22 |
|
23 |
def format_prompt(self, user_input):
|
24 |
"""
|
|
|
45 |
f"<|user|>{user_input}<|end|>\n"
|
46 |
"<|assistant|>"
|
47 |
)
|
48 |
+
|
49 |
+
def lookup_structured_data(self, query, school_name=None):
    """Search the structured BPS dataset for fields mentioned in a query.

    Args:
        query: Free-text user question. Each key of ``self.keyword_map`` is
            looked up as a substring of the lower-cased query.
        school_name: Optional school name. When given, rows are restricted
            to those whose ``BPS_School_Name`` contains it as a literal,
            case-insensitive substring.

    Returns:
        A newline-joined string of ``"<Keyword>: <value>"`` lines taken from
        the first matching row, or ``None`` when no row matches or no
        keyword from ``self.keyword_map`` appears in the query.
    """
    df_filtered = self.df

    if school_name:
        # regex=False: school names are plain text. Without it, names that
        # contain characters such as "(", ")", "+" or "." are interpreted as
        # a regular expression and can fail to match their own row.
        name_matches = self.df["BPS_School_Name"].str.contains(
            school_name, case=False, na=False, regex=False
        )
        df_filtered = self.df[name_matches]

    if df_filtered.empty:
        return None

    # NOTE(review): when no school_name is supplied this is simply the first
    # row of the whole dataset -- confirm that is the intended fallback.
    row = df_filtered.iloc[0]
    query_lower = query.lower()  # invariant across keywords, computed once

    results = []
    for key, col in self.keyword_map.items():
        if key not in query_lower:
            continue
        # A column absent from the data is reported as "N/A"; a column that
        # exists but holds a missing (NaN/None) value is skipped.
        val = row.get(col, "N/A")
        if pd.notna(val):
            results.append(f"{key.title()}: {val}")
    return "\n".join(results) if results else None
|
67 |
|
68 |
def get_response(self, user_input):
|
69 |
"""
|
|
|
84 |
- Use self.format_prompt() to format the user's input
|
85 |
- Use self.client to generate responses
|
86 |
"""
|
87 |
+
matched_school = None
|
88 |
+
for name in self.df["BPS_School_Name"].dropna():
|
89 |
+
if name.lower() in user_input.lower():
|
90 |
+
matched_school = name
|
91 |
+
break
|
92 |
+
|
93 |
+
structured_response = self.lookup_structured_data(user_input, matched_school)
|
94 |
+
|
95 |
+
if structured_response:
|
96 |
+
return f"Here’s what I found based on school data:\n{structured_response}"
|
97 |
+
|
98 |
prompt = self.format_prompt(user_input)
|
99 |
response = self.client.text_generation(
|
100 |
prompt,
|