# formatting util module providing formatting functions for the model input and output
# external imports
import re
import numpy as np
from numpy import ndarray
# function to format the model response nicely
def format_output_text(output: list):
    # remove special tokens from the token list
    formatted_output = format_tokens(output)

    # start the string with the first list item if it is not empty
    if formatted_output[0] != "":
        output_str = formatted_output[0]
        remaining_tokens = formatted_output[1:]
    else:
        # alternatively start with the second list item and skip it below
        output_str = formatted_output[1]
        remaining_tokens = formatted_output[2:]

    # add the remaining list items with a space in between
    for txt in remaining_tokens:
        # append punctuation marks without a leading space
        if txt in [".", ",", "!", "?"]:
            output_str += txt
        # append non-empty tokens with a leading space
        elif txt != "":
            output_str += " " + txt

    # return the combined string with repeated spaces collapsed
    return re.sub(" +", " ", output_str)

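# illustrative example (the input list is assumed tokenizer output, not part of the original module):
#   format_output_text(["[CLS]", "Hello", "world", "!", "[SEP]"])
#   -> "Hello world!"
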
# format the tokens by removing special tokens and special marker characters
def format_tokens(tokens: list):
    # define the special tokens and marker characters to remove
    special_tokens = [
        "[CLS]",
        "[SEP]",
        "[PAD]",
        "[UNK]",
        "[MASK]",
        "▁",
        "Ġ",
        "</w>",
        "\n",
    ]
    updated_tokens = []

    # loop through the tokens
    for t in tokens:
        # strip the subword marker from the start of the token if present
        if t.startswith("▁"):
            t = t.lstrip("▁")
        # remove every remaining special token or marker from the token
        for s in special_tokens:
            t = t.replace(s, "")
        # add the cleaned token to the list
        updated_tokens.append(t)

    # return the list of cleaned tokens
    return updated_tokens

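# illustrative example (assumed SentencePiece-style tokens, not part of the original module):
#   format_tokens(["▁Hello", "▁world", "[SEP]"])
#   -> ["Hello", "world", ""]
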
# flatten attribution values into a 2d array by summing the explanation values along the given axis
def flatten_attribution(values: ndarray, axis: int = 0):
    return np.sum(values, axis=axis)


# flatten attention values into a 2d array by averaging them along the given axis
def flatten_attention(values: ndarray, axis: int = 0):
    return np.mean(values, axis=axis)

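# illustrative example (hypothetical 2x3 attribution matrix, not part of the original module):
#   flatten_attribution(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]))
#   -> array([5., 7., 9.])
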
# average the decoder attention values over all attention heads
def avg_attention(attention_values):
    # take the first decoder attention tensor for the first batch entry as a numpy array
    attention = attention_values.decoder_attentions[0][0].detach().numpy()
    # average over the attention heads (first axis)
    return np.mean(attention, axis=0)
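
# usage sketch (assumes `outputs` is a Hugging Face seq2seq forward-pass output created with
# output_attentions=True; the variable names are illustrative, not part of the original module):
#   outputs = model(**batch, output_attentions=True)
#   head_avg = avg_attention(outputs)  # first-layer attention averaged over heads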