Commit b0721f8
Parent(s): 226ad46
fix: fixing model config settings

Files changed:
- explanation/attention.py (+1, -1)
- main.py (+9, -8)
- model/godel.py (+12, -16)
explanation/attention.py CHANGED
@@ -15,7 +15,7 @@ def chat_explained(model, prompt):
     ).input_ids
     # generate output together with attentions of the model
     decoder_input_ids = model.MODEL.generate(
-        encoder_input_ids, output_attentions=True,
+        encoder_input_ids, output_attentions=True, generation_config=model.CONFIG
     )
 
     # get input and output text as list of strings
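
For context, this is the call pattern the fix enables: a GenerationConfig travels with the same generate() call that requests attention weights. A minimal sketch, assuming the GODEL checkpoint used elsewhere in this repo; the prompt text, the return_dict_in_generate flag, and the print statements are illustrative additions, not part of the commit:

# sketch: generation with attentions plus an explicit GenerationConfig;
# from_pretrained mirrors the commit and needs the hub checkpoint to be reachable
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
model = AutoModelForSeq2SeqLM.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")

config = GenerationConfig.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
config.update(max_new_tokens=50, min_length=8, top_p=0.9, do_sample=True)

encoder_input_ids = tokenizer("Instruction: ... Dialog: Hello!", return_tensors="pt").input_ids

# return_dict_in_generate=True makes generate() return a structured output
# that carries the requested attention tensors next to the token ids
outputs = model.generate(
    encoder_input_ids,
    output_attentions=True,
    return_dict_in_generate=True,
    generation_config=config,
)
print(outputs.sequences.shape)        # generated token ids
print(len(outputs.cross_attentions))  # one entry per generated decoder step
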
main.py CHANGED
@@ -110,9 +110,10 @@ with gr.Blocks(
                     label="System Prompt",
                     info="Set the models system prompt, dictating how it answers.",
                     # default system prompt is set to this in the backend
-                    placeholder=(
-                        …
-                        …
+                    placeholder=("""
+                        You are a helpful, respectful and honest assistant. Always
+                        answer as helpfully as possible, while being safe.
+                    """
                     ),
                 )
             # column that takes up 1/4 of the row
@@ -121,7 +122,7 @@ with gr.Blocks(
                 xai_selection = gr.Radio(
                     ["None", "SHAP", "Attention"],
                     label="Interpretability Settings",
-                    info="Select a Interpretability Implementation to use.",
+                    info="Select a Interpretability Approach Implementation to use.",
                     value="None",
                     interactive=True,
                     show_label=True,
@@ -133,15 +134,15 @@ with gr.Blocks(
                     ["GODEL", "Mistral"],
                     label="Model Settings",
                     info="Select a Model to use.",
-                    value="…
+                    value="Mistral",
                     interactive=True,
                     show_label=True,
                 )
 
         # calling info functions on inputs/submits for different settings
-        system_prompt.…
-        xai_selection.…
-        model_selection.…
+        system_prompt.change(system_prompt_info, [system_prompt])
+        xai_selection.change(xai_info, [xai_selection])
+        model_selection.change(model_info, [model_selection])
 
         # row with chatbot ui displaying "conversation" with the model
         with gr.Row(equal_height=True):
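
The right-hand side completes the three truncated listener lines as Gradio .change() bindings. A runnable sketch of that pattern, with a hypothetical gr.Info toast standing in for the repo's actual handlers (system_prompt_info, xai_info, and model_info are defined elsewhere in main.py):

# sketch of the .change() listener pattern; the handler body here is a
# hypothetical stand-in for the repo's real xai_info implementation
import gradio as gr

def xai_info(xai_radio):
    # fires with the radio's current value every time the selection changes
    gr.Info(f"Interpretability approach set to: {xai_radio}")

with gr.Blocks() as demo:
    xai_selection = gr.Radio(
        ["None", "SHAP", "Attention"],
        label="Interpretability Settings",
        value="None",
        interactive=True,
    )
    # components listed in the second argument are passed as handler inputs
    xai_selection.change(xai_info, [xai_selection])

demo.launch()
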
model/godel.py CHANGED
@@ -1,7 +1,7 @@
 # GODEL model module for chat interaction and model instance control
 
 # external imports
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, GenerationConfig
 
 # internal imports
 from utils import modelling as mdl
@@ -10,24 +10,20 @@ from utils import modelling as mdl
 TOKENIZER = AutoTokenizer.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
 MODEL = AutoModelForSeq2SeqLM.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
 
-…
-…
+
+# model config definition
+CONFIG = GenerationConfig.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
+base_config_dict = {"max_new_tokens": 50, "min_length": 8, "top_p": 0.9, "do_sample": True}
+CONFIG.update(**base_config_dict)
 
 
 # function to (re) set config
-def set_config(…
-    global CONFIG
+def set_config(config_dict: dict):
 
-    # if config dict is given, …
-    if …
-        …
-
-    # hard setting model config to default
-    # needed for shap
-    MODEL.config.max_new_tokens = 50
-    MODEL.config.min_length = 8
-    MODEL.config.top_p = 0.9
-    MODEL.config.do_sample = True
+    # if config dict is not given, set to default
+    if config_dict == {}:
+        config_dict = base_config_dict
+    CONFIG.update(**config_dict)
 
 
 # formatting class to formatting input for the model
@@ -67,7 +63,7 @@ def respond(prompt):
     input_ids = TOKENIZER(f"{prompt}", return_tensors="pt").input_ids
 
     # generating using config and decoding output
-    outputs = MODEL.generate(input_ids, …
+    outputs = MODEL.generate(input_ids, generation_config=CONFIG)
     output = TOKENIZER.decode(outputs[0], skip_special_tokens=True)
 
     # returns the model output string
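
The net effect of this file's rewrite: generation parameters move off MODEL.config into a module-level GenerationConfig that set_config() can update or reset, and that generate() consumes directly. A self-contained usage sketch of that pattern; default values are copied from the diff, and the calls at the bottom are illustrative:

# sketch of the update/reset pattern introduced above; from_pretrained
# mirrors the commit and needs the hub checkpoint to be reachable
from transformers import GenerationConfig

CONFIG = GenerationConfig.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
base_config_dict = {"max_new_tokens": 50, "min_length": 8, "top_p": 0.9, "do_sample": True}
CONFIG.update(**base_config_dict)

def set_config(config_dict: dict):
    # an empty dict resets the shared config back to the defaults above
    if config_dict == {}:
        config_dict = base_config_dict
    # update() only overwrites the keys that are actually passed in
    CONFIG.update(**config_dict)

set_config({"top_p": 0.7})  # tweak one sampling parameter
assert CONFIG.top_p == 0.7 and CONFIG.max_new_tokens == 50
set_config({})              # back to the defaults
assert CONFIG.top_p == 0.9

Sharing one CONFIG object is also what lets explanation/attention.py pass model.CONFIG into its own generate() call, so chat responses and attention explanations sample with identical settings.
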