"""Gradio demo: generate a multiple-choice science question from a passage."""

import random

import gradio as gr
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

title = "Generate a Science Question"
description = """
Paste a short science passage (and, optionally, the answer you want the
question to target). The model writes a multiple-choice question with three
distractors. Examples are drawn from the SciQ test set.
"""
article = "Check out [the original Rick and Morty Bot](https://huggingface.co/spaces/kingabzpro/Rick_and_Morty_Bot) that this demo is based off of."

CHECKPOINT = "./gemma-2b-sciq-checkpoint"
NUM_EXAMPLES = 10
MAX_NEW_TOKENS = 100
# Newline-separated fields of the decoded output, in order. The decoded
# sequence starts with the prompt itself (context, then answer), and the
# model continues with the remaining fields.
OUTPUT_FIELDS = (
    "context",
    "answer",
    "question",
    "distractor1",
    "distractor2",
    "distractor3",
)

tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(CHECKPOINT)


def _one_line(text):
    """Return *text* with every line break replaced by a space.

    The prompt format uses newlines as field separators, so user-supplied
    text must not contain any.
    """
    return " ".join(text.splitlines())


dataset = load_dataset("allenai/sciq")
test_split = dataset["test"]
random_test_samples = test_split.select(
    random.sample(range(len(test_split)), min(NUM_EXAMPLES, len(test_split)))
)

examples = []
for row in random_test_samples:
    support = _one_line(row["support"])
    if not support.strip():
        # Some rows have no supporting passage; they make useless examples.
        continue
    # Each example must provide a value for every Interface input: one
    # without an answer (the model picks one) and one with the gold answer.
    examples.append([support, ""])
    examples.append([support, _one_line(row["correct_answer"])])


def predict(context, answer=None):
    """Generate a question and three distractors for *context*.

    Args:
        context: Passage the question should be based on.
        answer: Optional answer the question should target. ``None``, an
            empty string, or whitespace lets the model choose the answer.

    Returns:
        A dict with the keys in ``OUTPUT_FIELDS``, or ``None`` if the
        model output has fewer newline-separated fields than expected.
    """
    # Compute the single-line text outside the f-string: a backslash inside
    # an f-string expression is a SyntaxError before Python 3.12.
    prompt = f"{_one_line(context)}\n"
    # Gradio passes "" (not None) for an empty textbox, so test truthiness.
    if answer is not None and answer.strip():
        prompt += f"{_one_line(answer)}\n"

    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
    decoded_outputs = tokenizer.decode(outputs[0], skip_special_tokens=True)

    split_outputs = decoded_outputs.split("\n")
    # The model may keep generating past the last distractor (or end with a
    # newline); only the first len(OUTPUT_FIELDS) lines are meaningful.
    if len(split_outputs) < len(OUTPUT_FIELDS):
        return None
    return dict(zip(OUTPUT_FIELDS, split_outputs))


def _format_prediction(prediction):
    """Render a ``predict`` result as plain text for the output box."""
    if prediction is None:
        return (
            "The model output could not be parsed into a question. "
            "Try rephrasing or extending the context."
        )
    return "\n".join(
        [
            f"Question: {prediction['question']}",
            f"Answer: {prediction['answer']}",
            f"Distractor 1: {prediction['distractor1']}",
            f"Distractor 2: {prediction['distractor2']}",
            f"Distractor 3: {prediction['distractor3']}",
        ]
    )


def generate_quiz(context, answer):
    """Gradio callback: validate the inputs, run ``predict``, return text."""
    if not context or not context.strip():
        return "Please enter a context."
    return _format_prediction(predict(context, answer))


support_gr = gr.TextArea(
    label="Context",
    info="Make sure you use proper punctuation.",
    value="Bananas are yellow and curved.",
)
answer_gr = gr.TextArea(
    label="Answer (optional)",
    info="Make sure you use proper punctuation.",
    value="yellow",
)
output_gr = gr.TextArea(label="Output", value="")

demo = gr.Interface(
    fn=generate_quiz,
    inputs=[support_gr, answer_gr],
    outputs=[output_gr],
    title=title,
    description=description,
    article=article,
    examples=examples,
)

if __name__ == "__main__":
    demo.launch()