Darwinkel committed (verified)
Commit 2496b46 · Parent(s): b5e7ebd

Create app.py

Files changed (1):
  app.py +80 -0
app.py ADDED
@@ -0,0 +1,80 @@
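# Gradio demo: given a science context passage (and optionally its answer), a local
# gemma-2b SciQ checkpoint generates a question plus three distractor answers.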
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import random
from datasets import load_dataset

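# Title, description (with an inline image), and footer article shown in the Gradio UI.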
title = "Ask Rick a Question"
description = """
The bot was trained to answer questions based on Rick and Morty dialogues. Ask Rick anything!
<img src="https://huggingface.co/spaces/course-demos/Rick_and_Morty_QA/resolve/main/rick.png" width=200px>
"""

article = "Check out [the original Rick and Morty Bot](https://huggingface.co/spaces/kingabzpro/Rick_and_Morty_Bot) that this demo is based off of."

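# Load the tokenizer and model from the local gemma-2b SciQ checkpoint.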
tokenizer = AutoTokenizer.from_pretrained("./gemma-2b-sciq-checkpoint")
model = AutoModelForCausalLM.from_pretrained("./gemma-2b-sciq-checkpoint")

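# Pick 10 random SciQ test items as clickable examples: each item is added once with
# only its support passage and once with the support passage plus the correct answer.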
dataset = load_dataset("allenai/sciq")
random_test_samples = dataset["test"].select(random.sample(range(0, len(dataset["test"])), 10))

examples = []
for row in random_test_samples:
    examples.append([row['support'].replace('\n', ' ')])
    examples.append([row['support'].replace('\n', ' '), row['correct_answer'].replace('\n', ' ')])

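# Build a newline-separated prompt from the context (and optional answer), generate up
# to 100 new tokens, and parse the output into context/answer/question/distractor fields.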
def predict(context, answer):
    clean_context = context.replace("\n", " ")
    formatted = f"{clean_context}\n"

    if answer is not None:
        clean_answer = answer.replace("\n", " ")
        formatted = f"{clean_context}\n{clean_answer}\n"

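    # Tokenize the prompt, generate, and split the decoded text on newlines; a valid
    # completion has exactly six fields, otherwise return None.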
    inputs = tokenizer(formatted, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=100)
    decoded_outputs = tokenizer.decode(outputs[0], skip_special_tokens=True)
    split_outputs = decoded_outputs.split("\n")

    if len(split_outputs) == 6:
        return {
            "context": split_outputs[0],
            "answer": split_outputs[1],
            "question": split_outputs[2],
            "distractor1": split_outputs[3],
            "distractor2": split_outputs[4],
            "distractor3": split_outputs[5],
        }

    return None

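# Gradio components: context and optional answer inputs, a generate button, and the output area.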
support_gr = gr.TextArea(
    label="Context",
    info="Make sure you use proper punctuation.",
    value="Bananas are yellow and curved."
)

answer_gr = gr.TextArea(
    label="Answer optional",
    info="Make sure you use proper punctuation.",
    value="yellow"
)

button = gr.Button("Generate", elem_id="send-btn", visible=True)

output_gr = gr.TextArea(
    label="Output",
    info="Make sure you use proper punctuation.",
    value=""
)

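# Wire the inputs and output into an Interface, attach the example rows, and launch the app.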
gr.Interface(
    fn=predict,
    inputs=[support_gr, answer_gr],
    outputs=[output_gr],
    title=title,
    description=description,
    article=article,
    examples=examples,
).launch()