File size: 5,523 Bytes
bd3d7ac
c487b8d
 
8e49ca4
cccb502
bd3d7ac
 
c487b8d
 
2b3dacc
c487b8d
 
bd3d7ac
c487b8d
9c5915a
2b3dacc
9c5915a
c487b8d
2b3dacc
8e036bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b3dacc
8e036bb
 
bd3d7ac
c487b8d
 
 
 
 
9c5915a
8e036bb
 
 
 
 
 
bd3d7ac
 
 
 
 
 
 
9c5915a
bd3d7ac
 
9c5915a
 
 
bd3d7ac
9c5915a
 
bd3d7ac
 
 
 
 
 
 
c487b8d
2b3dacc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd3d7ac
2b3dacc
 
 
 
 
 
8e036bb
bd3d7ac
9c5915a
bd3d7ac
 
 
 
 
 
 
 
 
 
 
 
 
 
9c5915a
c487b8d
 
 
 
 
 
 
bd3d7ac
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139

from huggingface_hub import InferenceClient
from config import BASE_MODEL, MY_MODEL, HF_TOKEN
import pandas as pd
import json
import re
from difflib import get_close_matches

class SchoolChatbot:
    """Boston school chatbot combining structured CSV lookups with LLM text generation.

    School-wide questions ("which schools are accessible?") are answered directly
    from ``bps_data.csv``; questions about one named school are answered by the
    model, with matching CSV facts injected into the prompt.
    """

    # Whole-word patterns used to classify free-text column values. A value is
    # "positive" only if it matches the positive pattern and none of the
    # negative words (so "Not accessible" / "inadequate" are not positive).
    _POSITIVE_PATTERN = r"\b(?:yes|accessible|adequate|good|excellent|present)\b"
    _NEGATIVE_PATTERN = r"\b(?:no|not|inadequate|poor|bad|limited)\b"
    # Query words that flip a feature search into "schools lacking the feature".
    _INVERSE_WORDS = frozenset({"not", "inaccessible", "bad", "poor", "lacking"})

    def __init__(self):
        model_id = MY_MODEL if MY_MODEL else BASE_MODEL
        self.client = InferenceClient(model=model_id, token=HF_TOKEN)
        self.df = pd.read_csv("bps_data.csv")
        with open("cleaned_keyword_to_column_map.json", encoding="utf-8") as f:
            # Maps user-facing keyword -> column name in self.df.
            self.keyword_map = json.load(f)

        # Map every lowercase name variant (current, historical, abbreviated)
        # to the canonical BPS_School_Name. Rows without a canonical name are
        # skipped so no alias can resolve to NaN.
        self.school_name_map = {}
        for _, row in self.df.iterrows():
            primary = row.get("BPS_School_Name")
            if pd.isna(primary):
                continue
            primary = str(primary)
            for variant in (
                primary,
                row.get("BPS_Historical_Name"),
                row.get("SMMA_Abbreviated_Name"),
            ):
                if pd.notna(variant):
                    self.school_name_map[str(variant).lower()] = primary

        # Hand-curated aliases for names users commonly type.
        self.school_name_map.update({
            "acc": "Another Course to College*",
            "baldwin": "Baldwin Early Learning Pilot Academy",
            "adams elementary": "Adams, Samuel Elementary",
            "alighieri montessori": "Alighieri, Dante Montessori School",
            "phineas bates": "Bates, Phineas Elementary",
        })

    def format_prompt(self, user_input):
        """Wrap *user_input* in the chat template used when no facts are available."""
        return (
            "<|system|>You are a helpful assistant that specializes in Boston public school enrollment.<|end|>\n"
            f"<|user|>{user_input}<|end|>\n"
            "<|assistant|>"
        )

    def match_school_name(self, query):
        """Return the canonical school name mentioned in *query*, or None.

        Keys must appear as whole words (so the alias "acc" does not match
        "accessible"), and longer keys are tried first so a full name wins
        over a shorter alias it contains.
        """
        text = query.lower()
        for key in sorted(self.school_name_map, key=len, reverse=True):
            # Lookarounds instead of \b so keys ending in punctuation
            # (e.g. "college*") still match at the end of the query.
            if key and re.search(rf"(?<!\w){re.escape(key)}(?!\w)", text):
                return self.school_name_map[key]
        return None

    def _match_keywords(self, text):
        """Return the keyword_map keys that closely match any word in *text*."""
        keys = list(self.keyword_map)
        matched = set()
        for token in re.findall(r"\b\w+\b", text.lower()):
            matched.update(get_close_matches(token, keys, cutoff=0.85))
        return matched

    def extract_context_with_keywords(self, prompt, school_name=None):
        """Return sentences stating the CSV values for keywords found in *prompt*.

        Args:
            prompt: User text scanned for keywords from keyword_map.
            school_name: School to look up. An exact (case-insensitive) name
                match is preferred; otherwise a literal substring match is used.
                When omitted, the first CSV row is used.

        Returns:
            A list of "The school's <keyword> is <value>." strings (sorted by
            keyword), or [] when no school row matches.
        """
        matched_keywords = self._match_keywords(prompt)
        df_filtered = self.df
        if school_name:
            names = self.df["BPS_School_Name"]
            mask = names.str.lower().eq(school_name.lower())
            if not mask.any():
                # regex=False: school names contain regex metacharacters
                # such as "*" and parentheses.
                mask = names.str.contains(school_name, case=False, na=False, regex=False)
            df_filtered = self.df[mask]
        if df_filtered.empty:
            return []

        row = df_filtered.iloc[0]
        context_items = []
        for kw in sorted(matched_keywords):
            col = self.keyword_map.get(kw)
            if not col:
                continue
            val = row.get(col)
            if pd.notna(val):
                # str(): numeric cells (counts, years) have no .lower().
                context_items.append(f"The school's {kw} is {str(val).lower()}.")
        return context_items

    def query_schools_by_feature(self, query):
        """List schools whose feature columns (matched from *query*) are positive.

        If the query contains a negating word (see _INVERSE_WORDS), schools
        whose known value is NOT positive are listed instead. Schools with a
        missing value are never listed.

        Returns:
            A bullet list string of school names, or None if nothing matched.
        """
        matched_keywords = self._match_keywords(query)
        tokens = set(re.findall(r"\b\w+\b", query.lower()))
        inverse = not self._INVERSE_WORDS.isdisjoint(tokens)

        matching_schools = set()
        for keyword in matched_keywords:
            col = self.keyword_map.get(keyword)
            if not col or col not in self.df.columns:
                continue
            present = self.df[col].notna()
            values = self.df[col].astype(str).str.lower()
            positive = (
                values.str.contains(self._POSITIVE_PATTERN, regex=True, na=False)
                & ~values.str.contains(self._NEGATIVE_PATTERN, regex=True, na=False)
            )
            mask = present & (~positive if inverse else positive)
            schools = self.df.loc[mask, "BPS_School_Name"].dropna().unique().tolist()
            matching_schools.update(schools)

        if not matching_schools:
            return None
        return (
            "The following schools match your criteria:\n" +
            "\n".join(f"- {s}" for s in sorted(matching_schools))
        )

    def get_response(self, user_input):
        """Answer *user_input*, from the CSV directly or via the language model.

        A question naming a specific school goes to the model with that
        school's CSV facts as context. Otherwise the school-wide feature filter
        is tried first, falling back to a plain model completion.
        """
        matched_school = self.match_school_name(user_input)

        if matched_school:
            structured_facts = self.extract_context_with_keywords(user_input, matched_school)
        else:
            # Only run the school-wide filter when no school is named;
            # otherwise it would answer every per-school feature question.
            school_filter = self.query_schools_by_feature(user_input)
            if school_filter:
                return school_filter
            # Without a named school there is no row whose facts apply.
            structured_facts = []

        if structured_facts:
            natural_context = (
                f"You know the following facts about {matched_school}:\n"
                + "\n".join(f"- {fact}" for fact in structured_facts)
            )
            prompt = (
                "<|system|>You are a helpful assistant that specializes in Boston public school enrollment. "
                "Use any known facts about the school to answer helpfully.<|end|>\n"
                f"<|user|>{user_input}<|end|>\n"
                f"<|context|>{natural_context}<|end|>\n"
                "<|assistant|>"
            )
        else:
            prompt = self.format_prompt(user_input)

        response = self.client.text_generation(
            prompt,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            stop_sequences=["<|end|>"]
        )
        return response.strip()