File size: 2,457 Bytes
2496b46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import random
from datasets import load_dataset

# User-facing copy for the Interface header/footer. It describes what this
# app actually does: predict() turns a passage (plus an optional answer) into
# a question with an answer and three distractors, using a SciQ checkpoint.
title = "SciQ Question Generator"
description = """
Paste a short science passage and, optionally, the answer you want the question to be about.
The model generates a SciQ-style multiple-choice question from it: the answer, the question, and three distractors.
"""

# Attribution for the demo this app was adapted from.
article = "Check out [the original Rick and Morty Bot](https://huggingface.co/spaces/kingabzpro/Rick_and_Morty_Bot) that this demo is based off of."


# Tokenizer and weights are read from the same local checkpoint directory,
# once at import time; predict() uses both module-level objects.
CHECKPOINT_DIR = "./gemma-2b-sciq-checkpoint"
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
model = AutoModelForCausalLM.from_pretrained(CHECKPOINT_DIR)

# Draw 10 random passages from the SciQ test split to seed the example table.
dataset = load_dataset("allenai/sciq")
random_test_samples = dataset["test"].select(random.sample(range(len(dataset["test"])), 10))

# Each example must supply one value per input component (context, answer).
# Every sampled passage appears twice: once with no answer (None leaves the
# optional answer box empty) and once with its gold answer. Newlines are
# flattened because predict() treats each line as a separate field.
examples = []
for row in random_test_samples:
    support = row["support"].replace("\n", " ")
    examples.append([support, None])
    examples.append([support, row["correct_answer"].replace("\n", " ")])


def predict(context, answer):
    """Generate a multiple-choice question from a context passage.

    The prompt is the context on one line, optionally followed by the answer
    on the next line. The decoded generation is split on newlines and its
    first six lines are read as context, answer, question and three
    distractors.

    Args:
        context: Source passage. Newlines are flattened to spaces so it
            occupies exactly one prompt line.
        answer: Optional answer to build the question around. ``None`` or a
            blank string (what an empty Gradio textbox submits) means the
            answer line is omitted and the model has to produce it.

    Returns:
        A dict with keys ``context``, ``answer``, ``question``,
        ``distractor1``, ``distractor2`` and ``distractor3``, or ``None`` if the
        decoded output has fewer than six lines.
    """
    # Build the prompt without backslashes inside f-string replacement
    # fields: that is a SyntaxError before Python 3.12.
    prompt_lines = [context.replace("\n", " ")]
    # An empty textbox yields "" rather than None; appending it would insert
    # a blank line and shift every parsed field by one.
    if answer is not None and answer.strip():
        prompt_lines.append(answer.replace("\n", " "))
    formatted = "\n".join(prompt_lines) + "\n"

    inputs = tokenizer(formatted, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=100)
    decoded_outputs = tokenizer.decode(outputs[0], skip_special_tokens=True)
    split_outputs = decoded_outputs.split("\n")

    # The generation may run past the last distractor (or end with a
    # trailing newline), so accept any output with at least six lines and
    # keep only the first six.
    if len(split_outputs) < 6:
        return None

    keys = ("context", "answer", "question", "distractor1", "distractor2", "distractor3")
    return dict(zip(keys, split_outputs[:6]))
    

# Input: the passage the question is generated from.
support_gr = gr.TextArea(
    label="Context",
    info="Make sure you use proper punctuation.",
    value="Bananas are yellow and curved."
)

# Optional input: leave empty to let the model pick the answer itself.
answer_gr = gr.TextArea(
    label="Answer optional",
    info="Make sure you use proper punctuation.",
    value="yellow"
)

# predict() returns a dict (or None), so render it as JSON rather than the
# Python repr a text box would display. No separate Button is needed:
# gr.Interface creates its own submit button.
output_gr = gr.JSON(label="Output")

demo = gr.Interface(
    fn=predict,
    inputs=[support_gr, answer_gr],
    outputs=[output_gr],
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.launch()