# interpret module that implements the SHAP-based interpretability method

# external imports
from shap import models, maskers, plots, PartitionExplainer
import torch

# internal imports
from utils import formatting as fmt
from .markup import markup_text

# global variables
TEACHER_FORCING = None
TEXT_MASKER = None


# main explain function that returns the model response together with its explanations
def chat_explained(model, prompt):
    model.set_config({})

    # (re)wrap the model and tokenizer in the shap helper models and text masker
    wrap_shap(model)

    # create the shap partition explainer from the wrapped model and masker
    shap_explainer = PartitionExplainer(TEACHER_FORCING, TEXT_MASKER)

    # get the shap values for the prompt
    shap_values = shap_explainer([prompt])

    # create the explanation graphic and marked text array
    graphic = create_graphic(shap_values)
    marked_text = markup_text(
        shap_values.data[0], shap_values.values[0], variant="shap"
    )

    # create the response text
    response_text = fmt.format_output_text(shap_values.output_names)

    # return response, graphic and marked_text array
    return response_text, graphic, marked_text


# function that wraps the model and tokenizer in shap helper models
def wrap_shap(model):
    # declare the module-level globals so they can be (re)assigned
    global TEXT_MASKER, TEACHER_FORCING

    # set the device to cuda if gpu is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # updating the model settings
    model.set_config()

    # (re)initialize the shap models and masker
    # creating a shap text_generation model
    text_generation = models.TextGeneration(model.MODEL, model.TOKENIZER)
    # wrapping the text generation model in a teacher forcing model
    TEACHER_FORCING = models.TeacherForcing(
        text_generation,
        model.TOKENIZER,
        device=str(device),
        similarity_model=model.MODEL,
        similarity_tokenizer=model.TOKENIZER,
    )
    # setting the text masker to mask with whitespace and collapse adjacent mask tokens
    TEXT_MASKER = maskers.Text(model.TOKENIZER, " ", collapse_mask_token=True)


# graphic plotting function that creates an HTML graphic (as a string) for the explanation
def create_graphic(shap_values):

    # create the html graphic using the shap text plot function
    graphic_html = plots.text(shap_values, display=False)

    # return the html graphic as a string to display in an iFrame
    return str(graphic_html)
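

# ------------------------------------------------------------------
# usage sketch (an assumption, not part of the original module): the
# functions above expect a wrapper object exposing .MODEL, .TOKENIZER
# and .set_config(); "ModelWrapper" and the example prompt below are
# hypothetical stand-ins showing how chat_explained would be called.
#
# from model import ModelWrapper  # hypothetical model wrapper
#
# model = ModelWrapper()
# response_text, graphic, marked_text = chat_explained(
#     model, "Does money buy happiness?"
# )
# # response_text -> formatted output text of the generation
# # graphic       -> shap text plot as an html string (e.g. for an iFrame)
# # marked_text   -> per-token shap markup produced by markup_text
# ------------------------------------------------------------------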