ellawang9 committed on
Commit
2b3dacc
·
verified ·
1 Parent(s): fee4cb4

Update src/chat.py

Browse files
Files changed (1) hide show
  1. src/chat.py +39 -7
src/chat.py CHANGED
@@ -7,18 +7,16 @@ import re
7
  from difflib import get_close_matches
8
 
9
  class SchoolChatbot:
10
- """
11
- A chatbot that integrates structured school data and language generation to assist with Boston Public School queries.
12
- """
13
 
14
  def __init__(self):
15
  model_id = MY_MODEL if MY_MODEL else BASE_MODEL
16
  self.client = InferenceClient(model=model_id, token=HF_TOKEN)
17
  self.df = pd.read_csv("bps_data.csv")
18
- with open("keyword_to_column_map.json") as f:
19
  self.keyword_map = json.load(f)
20
 
21
- # Create school name map with aliases
22
  self.school_name_map = {}
23
  for _, row in self.df.iterrows():
24
  primary = row.get("BPS_School_Name")
@@ -31,13 +29,12 @@ class SchoolChatbot:
31
  if pd.notna(abbrev):
32
  self.school_name_map[abbrev.lower()] = primary
33
 
34
- # Add custom aliases
35
  self.school_name_map.update({
36
  "acc": "Another Course to College*",
37
  "baldwin": "Baldwin Early Learning Pilot Academy",
38
  "adams elementary": "Adams, Samuel Elementary",
39
  "alighieri montessori": "Alighieri, Dante Montessori School",
40
- "phineas bates": "Bates, Phineas Elementary"
41
  })
42
 
43
  def format_prompt(self, user_input):
@@ -77,7 +74,42 @@ class SchoolChatbot:
77
  context_items.append(f"The school's {kw} is {val.lower()}.")
78
  return context_items
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  def get_response(self, user_input):
 
 
 
 
 
 
81
  matched_school = self.match_school_name(user_input)
82
  structured_facts = self.extract_context_with_keywords(user_input, matched_school)
83
 
 
7
  from difflib import get_close_matches
8
 
9
  class SchoolChatbot:
10
+ """Boston School Chatbot integrating structured data, vector context, and model completion."""
 
 
11
 
12
  def __init__(self):
13
  model_id = MY_MODEL if MY_MODEL else BASE_MODEL
14
  self.client = InferenceClient(model=model_id, token=HF_TOKEN)
15
  self.df = pd.read_csv("bps_data.csv")
16
+ with open("cleaned_keyword_to_column_map.json") as f:
17
  self.keyword_map = json.load(f)
18
 
19
+ # Build name variants for school matching
20
  self.school_name_map = {}
21
  for _, row in self.df.iterrows():
22
  primary = row.get("BPS_School_Name")
 
29
  if pd.notna(abbrev):
30
  self.school_name_map[abbrev.lower()] = primary
31
 
 
32
  self.school_name_map.update({
33
  "acc": "Another Course to College*",
34
  "baldwin": "Baldwin Early Learning Pilot Academy",
35
  "adams elementary": "Adams, Samuel Elementary",
36
  "alighieri montessori": "Alighieri, Dante Montessori School",
37
+ "phineas bates": "Bates, Phineas Elementary",
38
  })
39
 
40
  def format_prompt(self, user_input):
 
74
  context_items.append(f"The school's {kw} is {val.lower()}.")
75
  return context_items
76
 
77
def query_schools_by_feature(self, query):
    """Return a bulleted list of schools whose data matches a feature query.

    Each word of *query* is fuzzy-matched against the keys of
    ``self.keyword_map``; every matched keyword is mapped to a column of
    ``self.df`` and that column's text is classified as positive or not.

    Args:
        query: Free-text user question.

    Returns:
        A string starting with "The following schools match your
        criteria:" followed by one "- <school name>" line per school
        (sorted), or ``None`` when no keyword or no school matches.
    """
    tokens = re.findall(r"\b\w+\b", query.lower())

    keywords = list(self.keyword_map)
    matched_keywords = set()
    for token in tokens:
        matched_keywords.update(get_close_matches(token, keywords, cutoff=0.85))

    # Whole-word patterns: a bare substring search would let "accessible"
    # match inside "inaccessible" and "adequate" inside "inadequate".
    positive_pattern = r"\b(?:yes|accessible|adequate|good|excellent|present)\b"
    # A leading "not" negates whatever follows ("not accessible",
    # "not present"), so it is treated as a negative marker on its own.
    negative_pattern = r"\b(?:no|not|inadequate|poor|bad|limited)\b"

    # Compare against whole query words so that "note" or "another" do not
    # trigger the inverse search through the substring "not".
    inverse = not {"not", "inaccessible", "bad", "poor", "lacking"}.isdisjoint(tokens)

    matching_schools = set()
    for keyword in matched_keywords:
        col = self.keyword_map.get(keyword)
        if not col or col not in self.df.columns:
            continue

        # Rows with no recorded value say nothing either way; leave them
        # out instead of reporting them as lacking the feature.
        has_value = self.df[col].notna()
        values = self.df[col].astype(str).str.lower()
        is_negative = values.str.contains(negative_pattern, na=False)
        is_positive = values.str.contains(positive_pattern, na=False) & ~is_negative

        selected = has_value & (~is_positive if inverse else is_positive)
        names = self.df.loc[selected, "BPS_School_Name"].dropna().unique().tolist()
        matching_schools.update(names)

    if not matching_schools:
        return None
    return (
        "The following schools match your criteria:\n"
        + "\n".join(f"- {s}" for s in sorted(matching_schools))
    )
105
+
106
  def get_response(self, user_input):
107
+ # School-wide filter query
108
+ school_filter = self.query_schools_by_feature(user_input)
109
+ if school_filter:
110
+ return school_filter
111
+
112
+ # Per-school context query
113
  matched_school = self.match_school_name(user_input)
114
  structured_facts = self.extract_context_with_keywords(user_input, matched_school)
115