Spaces:
Runtime error
Runtime error
File size: 2,263 Bytes
fe1089d 2492536 fe1089d d2116db fe1089d 2492536 fe1089d 2492536 fe1089d 2492536 fe1089d d2116db fe1089d 2492536 f5ebee7 fe1089d 2492536 fe1089d 2492536 fe1089d 2492536 fe1089d 2492536 fe1089d 2492536 fe1089d 2492536 fe1089d 2492536 fe1089d 2492536 c28c597 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
# interpret module that implements the SHAP interpretability method
# external imports
from shap import models, maskers, plots, PartitionExplainer
import torch
# internal imports
from utils import formatting as fmt
from .markup import markup_text
# global shap objects, (re)initialized by wrap_shap(); None until it has run
TEACHER_FORCING = TEXT_MASKER = None
# main explain function that returns a chat response with shap explanations
def chat_explained(model, prompt):
    """Explain a single prompt with a SHAP partition explainer.

    Args:
        model: project model wrapper exposing ``set_config``, ``MODEL`` and
            ``TOKENIZER``.
        prompt: the prompt to generate a response for and explain.

    Returns:
        A ``(response_text, graphic, marked_text)`` tuple: the response text
        formatted from the explanation's output names, the HTML explanation
        graphic as a string, and the marked-up text array for the prompt.
    """
    model.set_config({})

    # NOTE(review): the explainer is built from the raw model and tokenizer,
    # so the TEACHER_FORCING / TEXT_MASKER globals prepared by wrap_shap()
    # are not used here -- confirm this is intentional.
    explainer = PartitionExplainer(model.MODEL, model.TOKENIZER)
    explanation = explainer([prompt])

    # only one input was explained, so index 0 holds its data and values
    segments = explanation.data[0]
    scores = explanation.values[0]

    html_graphic = create_graphic(explanation)
    highlighted = markup_text(segments, scores, variant="shap")
    answer = fmt.format_output_text(explanation.output_names)

    return answer, html_graphic, highlighted
# function used to wrap the model with a shap model
def wrap_shap(model):
    """(Re)initialize the module-level shap teacher-forcing model and masker.

    Rebinds the ``TEACHER_FORCING`` and ``TEXT_MASKER`` globals using
    ``model.MODEL`` and ``model.TOKENIZER``. Nothing is returned.

    Args:
        model: project model wrapper exposing ``set_config``, ``MODEL`` and
            ``TOKENIZER``.
    """
    # rebinding the module-level shap objects, not creating locals
    global TEXT_MASKER, TEACHER_FORCING

    # run on the gpu when one is available, otherwise fall back to the cpu
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # updating the model settings with an empty config dict, the same call
    # chat_explained() makes before running shap
    model.set_config({})

    # creating a shap text generation model around the model and tokenizer
    text_generation = models.TextGeneration(model.MODEL, model.TOKENIZER)

    # wrapping the text generation model in a teacher forcing model; the same
    # model/tokenizer pair is also passed as the similarity model
    TEACHER_FORCING = models.TeacherForcing(
        text_generation,
        model.TOKENIZER,
        device=str(device),
        similarity_model=model.MODEL,
        similarity_tokenizer=model.TOKENIZER,
    )

    # text masker whose mask token is a single space (" "), with runs of
    # consecutive masked tokens collapsed into one mask token
    TEXT_MASKER = maskers.Text(model.TOKENIZER, " ", collapse_mask_token=True)
# graphic plotting function that creates a html graphic (as string) for the explanation
def create_graphic(shap_values):
    """Render ``shap_values`` with shap's text plot and return it as a string.

    The plot is produced with ``display=False`` so it is returned rather than
    shown, and the resulting string is meant to be embedded in an iFrame.
    """
    return str(plots.text(shap_values, display=False))
|