Spaces:
Build error
Build error
Update src/chat.py
Browse files- src/chat.py +39 -7
src/chat.py
CHANGED
@@ -7,18 +7,16 @@ import re
|
|
7 |
from difflib import get_close_matches
|
8 |
|
9 |
class SchoolChatbot:
|
10 |
-
"""
|
11 |
-
A chatbot that integrates structured school data and language generation to assist with Boston Public School queries.
|
12 |
-
"""
|
13 |
|
14 |
def __init__(self):
|
15 |
model_id = MY_MODEL if MY_MODEL else BASE_MODEL
|
16 |
self.client = InferenceClient(model=model_id, token=HF_TOKEN)
|
17 |
self.df = pd.read_csv("bps_data.csv")
|
18 |
-
with open("
|
19 |
self.keyword_map = json.load(f)
|
20 |
|
21 |
-
#
|
22 |
self.school_name_map = {}
|
23 |
for _, row in self.df.iterrows():
|
24 |
primary = row.get("BPS_School_Name")
|
@@ -31,13 +29,12 @@ class SchoolChatbot:
|
|
31 |
if pd.notna(abbrev):
|
32 |
self.school_name_map[abbrev.lower()] = primary
|
33 |
|
34 |
-
# Add custom aliases
|
35 |
self.school_name_map.update({
|
36 |
"acc": "Another Course to College*",
|
37 |
"baldwin": "Baldwin Early Learning Pilot Academy",
|
38 |
"adams elementary": "Adams, Samuel Elementary",
|
39 |
"alighieri montessori": "Alighieri, Dante Montessori School",
|
40 |
-
"phineas bates": "Bates, Phineas Elementary"
|
41 |
})
|
42 |
|
43 |
def format_prompt(self, user_input):
|
@@ -77,7 +74,42 @@ class SchoolChatbot:
|
|
77 |
context_items.append(f"The school's {kw} is {val.lower()}.")
|
78 |
return context_items
|
79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
def get_response(self, user_input):
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
matched_school = self.match_school_name(user_input)
|
82 |
structured_facts = self.extract_context_with_keywords(user_input, matched_school)
|
83 |
|
|
|
7 |
from difflib import get_close_matches
|
8 |
|
9 |
class SchoolChatbot:
|
10 |
+
"""Boston School Chatbot integrating structured data, vector context, and model completion."""
|
|
|
|
|
11 |
|
12 |
def __init__(self):
|
13 |
model_id = MY_MODEL if MY_MODEL else BASE_MODEL
|
14 |
self.client = InferenceClient(model=model_id, token=HF_TOKEN)
|
15 |
self.df = pd.read_csv("bps_data.csv")
|
16 |
+
with open("cleaned_keyword_to_column_map.json") as f:
|
17 |
self.keyword_map = json.load(f)
|
18 |
|
19 |
+
# Build name variants for school matching
|
20 |
self.school_name_map = {}
|
21 |
for _, row in self.df.iterrows():
|
22 |
primary = row.get("BPS_School_Name")
|
|
|
29 |
if pd.notna(abbrev):
|
30 |
self.school_name_map[abbrev.lower()] = primary
|
31 |
|
|
|
32 |
self.school_name_map.update({
|
33 |
"acc": "Another Course to College*",
|
34 |
"baldwin": "Baldwin Early Learning Pilot Academy",
|
35 |
"adams elementary": "Adams, Samuel Elementary",
|
36 |
"alighieri montessori": "Alighieri, Dante Montessori School",
|
37 |
+
"phineas bates": "Bates, Phineas Elementary",
|
38 |
})
|
39 |
|
40 |
def format_prompt(self, user_input):
|
|
|
74 |
context_items.append(f"The school's {kw} is {val.lower()}.")
|
75 |
return context_items
|
76 |
|
77 |
+
def query_schools_by_feature(self, query):
    """Return a bulleted list of schools whose data matches a feature query.

    Each word of ``query`` is fuzzily matched (difflib, cutoff 0.85) against
    the keys of ``self.keyword_map``; every matched keyword is translated to
    a column of ``self.df`` and the rows of that column are classified as
    positive or not.

    Args:
        query: Free-text user question.

    Returns:
        A string starting with "The following schools match your criteria:"
        followed by one "- <name>" line per school (sorted), or ``None``
        when the query is empty, no keyword matches, or no school qualifies.
    """
    if not query:
        return None

    tokens = re.findall(r"\b\w+\b", query.lower())

    # Materialise the candidate keywords once instead of per token.
    known_keywords = list(self.keyword_map)
    matched_keywords = set()
    for token in tokens:
        matched_keywords.update(
            get_close_matches(token, known_keywords, cutoff=0.85)
        )
    if not matched_keywords:
        return None

    # Word boundaries keep "accessible" from matching inside "inaccessible";
    # non-capturing groups avoid pandas' match-group warning.
    positive_terms = r"\b(?:yes|accessible|adequate|good|excellent|present)\b"
    negative_terms = (
        r"\b(?:no|not accessible|inaccessible|inadequate|poor|bad|limited)\b"
    )

    # Decide on whole words: a substring test would treat "another" or
    # "notable" as containing the negation "not".
    negation_words = {"not", "inaccessible", "bad", "poor", "lacking"}
    inverse = not negation_words.isdisjoint(tokens)

    matching_schools = set()
    for keyword in matched_keywords:
        col = self.keyword_map.get(keyword)
        if not col or col not in self.df.columns:
            continue

        column = self.df[col]
        # Remember which cells hold data before astype(str) turns missing
        # values into text.
        has_value = column.notna()
        text = column.astype(str).str.lower()

        # A cell is positive only if it has a positive term and no negative
        # one, so "not accessible" is not counted as accessible.
        is_positive = (
            text.str.contains(positive_terms, regex=True, na=False)
            & ~text.str.contains(negative_terms, regex=True, na=False)
        )

        # Rows without data are never reported, in either direction.
        mask = has_value & (~is_positive if inverse else is_positive)
        schools = self.df.loc[mask, "BPS_School_Name"].dropna().unique().tolist()
        matching_schools.update(schools)

    if not matching_schools:
        return None
    return (
        "The following schools match your criteria:\n"
        + "\n".join(f"- {s}" for s in sorted(matching_schools))
    )
|
105 |
+
|
106 |
def get_response(self, user_input):
|
107 |
+
# School-wide filter query
|
108 |
+
school_filter = self.query_schools_by_feature(user_input)
|
109 |
+
if school_filter:
|
110 |
+
return school_filter
|
111 |
+
|
112 |
+
# Per-school context query
|
113 |
matched_school = self.match_school_name(user_input)
|
114 |
structured_facts = self.extract_context_with_keywords(user_input, matched_school)
|
115 |
|