ellawang9's picture
Update src/chat.py
9c5915a verified
raw
history blame
3.77 kB
import json

import pandas as pd
from huggingface_hub import InferenceClient

from config import BASE_MODEL, MY_MODEL, HF_TOKEN
class SchoolChatbot:
    """
    Scaffolding around a hosted model for answering Boston Public Schools
    enrollment questions. Modify this class to specify how the model receives
    prompts and generates responses.

    The bot first tries to answer from the structured BPS dataset
    (``bps_data.csv`` + ``keyword_to_column_map.json``); only when no
    structured match is found does it fall back to the language model.

    Example usage:
        chatbot = SchoolChatbot()
        response = chatbot.get_response("What schools offer Spanish programs?")
    """

    def __init__(self):
        """
        Initialize the chatbot with a HF model ID and load the structured data.

        Side effects: reads ``bps_data.csv`` and ``keyword_to_column_map.json``
        from the current working directory.
        """
        # Define MY_MODEL in config.py if you create a new model on the HF Hub;
        # otherwise fall back to the shared base model.
        model_id = MY_MODEL if MY_MODEL else BASE_MODEL
        self.client = InferenceClient(model=model_id, token=HF_TOKEN)
        self.df = pd.read_csv("bps_data.csv")
        with open("keyword_to_column_map.json") as f:
            # Maps lowercase query keywords -> DataFrame column names.
            self.keyword_map = json.load(f)

    def format_prompt(self, user_input):
        """
        Format the user's input into a chat-style prompt with system context
        and the special tokens the model expects.

        Args:
            user_input (str): The user's question about Boston schools

        Returns:
            str: A formatted prompt ready for the model, ending with the
            assistant tag so generation continues as the assistant turn.
        """
        return (
            "<|system|>You are a helpful assistant that specializes in Boston public school enrollment.<|end|>\n"
            f"<|user|>{user_input}<|end|>\n"
            "<|assistant|>"
        )

    def lookup_structured_data(self, query, school_name=None):
        """
        Search the structured BPS dataset for information relevant to *query*.

        Args:
            query (str): Free-text question; matched case-insensitively
                against the keys of ``self.keyword_map``.
            school_name (str | None): Optional school-name substring used to
                filter rows (case-insensitive). Only the first matching row
                is reported.

        Returns:
            str | None: Newline-joined "Keyword: value" lines, or None when
            no row or no keyword matches.
        """
        df_filtered = self.df
        if school_name:
            df_filtered = self.df[
                self.df["BPS_School_Name"].str.contains(school_name, case=False, na=False)
            ]
        if df_filtered.empty:
            return None

        row = df_filtered.iloc[0]
        query_lower = query.lower()  # hoisted: avoid re-lowering per keyword
        results = []
        for key, col in self.keyword_map.items():
            if key in query_lower:
                val = row.get(col, "N/A")
                if pd.notna(val):
                    results.append(f"{key.title()}: {val}")
        return "\n".join(results) if results else None

    def get_response(self, user_input):
        """
        Generate a response to a user question about Boston schools.

        Prefers an answer assembled from the structured dataset; falls back
        to model generation via ``self.client`` when no structured data
        matches.

        Args:
            user_input (str): The user's question about Boston schools

        Returns:
            str: The chatbot's response
        """
        # Try to detect which school the user is asking about by substring
        # match against known school names (case-insensitive).
        user_input_lower = user_input.lower()  # hoisted out of the loop
        matched_school = None
        for name in self.df["BPS_School_Name"].dropna():
            if name.lower() in user_input_lower:
                matched_school = name
                break

        structured_response = self.lookup_structured_data(user_input, matched_school)
        if structured_response:
            return f"Here’s what I found based on school data:\n{structured_response}"

        prompt = self.format_prompt(user_input)
        response = self.client.text_generation(
            prompt,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            stop_sequences=["<|end|>"]
        )
        # Some backends include the stop sequence in the returned text;
        # keep only the content before it. No-op when the marker is absent.
        return response.split("<|end|>")[0].strip()