# AEG-SmolLM2 / app.py
# IsmaelMousa's picture
# Update app.py
# 5355c87 verified
# raw
# history blame
# 2.33 kB
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from json_repair import repair_json
from json import loads
# Hugging Face Hub id of the fine-tuned grading model.
checkpoint = "IsmaelMousa/SmolLM2-135M-Instruct-EngSaf-217K"

# Use CUDA when torch reports it as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the tokenizer and weights once at import time; grade() reuses them
# through the pipeline below for every request.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

# Text-generation pipeline invoked by grade().
assistant = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
)
def extract(text):
    """Pull the grading JSON object out of raw model output.

    The generated text may carry extra tokens around the JSON object, or may
    be cut off by the generation limit before the object is closed. This
    function tries, in order:

    1. Decode a well-formed JSON value starting at the first ``{``. This
       correctly handles braces inside string values (e.g. a rationale that
       quotes ``{1, 2}``) and ignores any trailing text.
    2. Otherwise, run ``repair_json`` over the span from the first ``{`` to
       the next ``}``. When no ``}`` follows, which is what a truncated
       generation looks like, the whole tail is used. The repaired string is
       then parsed.

    Args:
        text: Raw generated text.

    Returns:
        The parsed JSON object on success; the repaired JSON string if it
        still does not parse; or ``text`` unchanged when it contains no ``{``.
    """
    from json import JSONDecoder

    start = text.find("{")
    if start == -1:
        return text

    # Fast path: output that is already valid JSON needs no repair.
    try:
        return JSONDecoder().raw_decode(text, start)[0]
    except ValueError:
        pass  # malformed or truncated; fall through to repair

    end = text.find("}", start)
    # No closing brace usually means the generation was truncated mid-object;
    # repair the tail rather than giving up and returning the raw text.
    candidate = text[start:] if end == -1 else text[start:end + 1]
    response = repair_json(candidate.strip())
    try:
        return loads(response)
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        return response
def grade(question, reference_answer, student_answer, mark_scheme):
    """Grade one student answer with the fine-tuned model.

    Builds a system + user chat prompt from the four inputs, generates
    greedily (``do_sample=False``) for at most 128 new tokens, and parses the
    reply with ``extract``.

    Args:
        question: The question text.
        reference_answer: The reference answer to compare against.
        student_answer: The answer being graded.
        mark_scheme: The mark scheme the score must stay within.

    Returns:
        Whatever ``extract`` returns: the parsed JSON object on success (the
        prompt requests ``score`` and ``rationale`` keys), otherwise a string.
    """
    system_content = "You are a grading assistant. Evaluate student answers based on the mark scheme. Respond only in JSON format with keys \"score\" (int) and \"rationale\" (string)."
    user_content = (
        "Provide both a score and a rationale by evaluating the student's answer strictly within the mark scheme range, "
        "grading based on how well it meets the question's requirements by comparing the student answer to the reference answer.\n"
        f"Question: {question}\n"
        f"Reference Answer: {reference_answer}\n"
        f"Student Answer: {student_answer}\n"
        f"Mark Scheme: {mark_scheme}"
    )
    messages = [
        {"role": "system", "content": system_content},
        {"role": "user", "content": user_content},
    ]
    # add_generation_prompt=True appends the assistant-turn header, so the
    # model starts its reply directly. Without it the prompt ends after the
    # user turn and the model must produce the role header itself.
    inputs = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    output = assistant(
        inputs, max_new_tokens=128, do_sample=False, return_full_text=False
    )[0]["generated_text"]
    return extract(output)
# Gradio UI: four text boxes, in the same order as grade()'s parameters,
# with the grading result rendered as JSON.
demo = gr.Interface(
    fn=grade,
    inputs=[
        gr.Textbox(label=label)
        for label in ("Question", "Reference Answer", "Student Answer", "Mark Scheme")
    ],
    outputs=gr.JSON(label="Evaluation Output"),
)

if __name__ == "__main__":
    demo.launch()