Spaces:
Sleeping
Sleeping
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/xai.ipynb. | |
# %% auto 0 | |
__all__ = ['get_embeddings', 'get_dataset', 'umap_parameters', 'get_prjs', 'plot_projections', 'plot_projections_clusters', | |
'calculate_cluster_stats', 'anomaly_score', 'detector', 'plot_anomaly_scores_distribution', | |
'plot_clusters_with_anomalies', 'update_plot', 'plot_clusters_with_anomalies_interactive_plot', | |
'get_df_selected', 'shift_datetime', 'get_dateformat', 'get_anomalies', 'get_anomaly_styles', | |
'InteractiveAnomalyPlot', 'plot_save', 'plot_initial_config', 'merge_overlapping_windows', | |
'InteractiveTSPlot', 'add_selected_features', 'add_windows', 'setup_style', 'toggle_trace', | |
'set_features_buttons', 'move_left', 'move_right', 'move_down', 'move_up', 'delta_x_bigger', | |
'delta_y_bigger', 'delta_x_lower', 'delta_y_lower', 'add_movement_buttons', 'setup_boxes', 'initial_plot', | |
'show'] | |
# %% ../nbs/xai.ipynb 1 | |
#Weight & Biases | |
import wandb | |
#Yaml | |
from yaml import load, FullLoader | |
#Embeddings | |
from .all import * | |
from tsai.data.preparation import prepare_forecasting_data | |
from tsai.data.validation import get_forecasting_splits | |
from fastcore.all import * | |
#Dimensionality reduction | |
from tsai.imports import * | |
#Clustering | |
import hdbscan | |
import time | |
from .dr import get_PCA_prjs, get_UMAP_prjs, get_TSNE_prjs | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
import pandas as pd | |
import ipywidgets as widgets | |
from IPython.display import display | |
from functools import partial | |
from IPython.display import display, clear_output, HTML as IPHTML | |
from ipywidgets import Button, Output, VBox, HBox, HTML, Layout, FloatSlider | |
import plotly.graph_objs as go | |
import plotly.offline as py | |
import plotly.io as pio | |
#! pip install kaleido | |
import kaleido | |
# %% ../nbs/xai.ipynb 4 | |
def get_embeddings(config_lrp, run_lrp, api, print_flag = False):
    """Retrieve the embeddings artifact named by `config_lrp.emb_artifact`.

    The artifact is requested through `run_lrp.use_artifact` when
    `config_lrp.use_wandb` is truthy, and through `api.artifact` otherwise.

    Returns a tuple `(embeddings, emb_artifact, emb_config)`: the result of
    `emb_artifact.to_obj()`, the artifact itself, and the `config` of the
    object returned by `emb_artifact.logged_by()`.
    """
    if config_lrp.use_wandb:
        fetch_artifact = run_lrp.use_artifact
    else:
        fetch_artifact = api.artifact
    emb_artifact = fetch_artifact(config_lrp.emb_artifact, type='embeddings')
    if print_flag:
        print(emb_artifact.name)
    producer_run = emb_artifact.logged_by()
    emb_config = producer_run.config
    embeddings = emb_artifact.to_obj()
    return embeddings, emb_artifact, emb_config
# %% ../nbs/xai.ipynb 5 | |
def get_dataset(
    config_lrp,
    config_emb,
    config_dr,
    run_lrp,
    api,
    print_flag = False
):
    """Load the encoder learner and the dataset used for dimensionality reduction.

    Artifacts are requested through `run_lrp.use_artifact` when
    `config_lrp.use_wandb` is truthy and through `api.artifact` otherwise
    (workaround to use artifacts offline).

    Parameters:
        config_lrp: object exposing `use_wandb`.
        config_emb: mapping with the key 'enc_artifact'.
        config_dr: mapping with the keys 'dr_artifact' and 'enc_artifact'.
        run_lrp: object exposing `use_artifact` (only used with wandb).
        api: object exposing `artifact` (only used without wandb).
        print_flag: when True, print progress information.

    Returns:
        `(df, dr_artifact, enc_artifact, enc_learner)` where `df` is
        `dr_artifact.to_df()` and `enc_learner` is `enc_artifact.to_obj()`.
    """
    # Botch to use artifacts offline
    artifacts_gettr = run_lrp.use_artifact if config_lrp.use_wandb else api.artifact
    enc_artifact = artifacts_gettr(config_emb['enc_artifact'], type='learner')
    if print_flag:
        print(enc_artifact.name)
    # TODO: loading the learner has been seen to fail on the first attempt and
    # succeed on the second, so it is retried once. Only ordinary exceptions
    # are retried; KeyboardInterrupt/SystemExit propagate immediately.
    try:
        enc_learner = enc_artifact.to_obj()
    except Exception:
        enc_learner = enc_artifact.to_obj()

    # Restore the artifacts logged by the run that produced the encoder
    enc_logger = enc_artifact.logged_by()
    enc_artifact_train = artifacts_gettr(enc_logger.config['train_artifact'], type='dataset')
    if enc_logger.config['valid_artifact'] is not None:
        enc_artifact_valid = artifacts_gettr(enc_logger.config['valid_artifact'], type='dataset')
        if print_flag:
            print("enc_artifact_valid:", enc_artifact_valid.name)
    if print_flag:
        print("enc_artifact_train: ", enc_artifact_train.name)

    if config_dr['dr_artifact'] is not None:
        if print_flag:
            print("Is not none")
        # NOTE(review): the condition tests 'dr_artifact' but the artifact is
        # fetched with the 'enc_artifact' key -- confirm which key is intended.
        dr_artifact = artifacts_gettr(config_dr['enc_artifact'])
    else:
        dr_artifact = enc_artifact_train
    if print_flag:
        print("DR artifact train: ", dr_artifact.name)
        print("--> DR artifact name", dr_artifact.name)
    df = dr_artifact.to_df()
    if print_flag:
        print("--> DR After to df", df.shape)
        display(df.head())
    return df, dr_artifact, enc_artifact, enc_learner
# %% ../nbs/xai.ipynb 6 | |
def umap_parameters(config_dr, config):
    """Build the keyword arguments handed to the UMAP projection step.

    `config_dr` must expose `n_neighbors`, `min_dist`, `metric` and
    `cpu_flag`. A truthy `cpu_flag` selects the CPU parameter set, otherwise
    the GPU parameter set is returned. `config` is accepted but not used.
    """
    neighbors = config_dr.n_neighbors
    min_dist = config_dr.min_dist
    metric = config_dr.metric

    if config_dr.cpu_flag:
        # CPU set. Other options ('a', 'b', 'metric_kwds', 'output_metric',
        # 'n_epochs') are deliberately left to the library defaults.
        return {
            'n_neighbors': neighbors,
            'min_dist': min_dist,
            'random_state': np.uint64(822569775),
            'metric': metric,
            'verbose': 4,
        }

    # GPU set
    return {
        'n_neighbors': neighbors,
        'min_dist': min_dist,
        'random_state': np.uint64(1234),
        'metric': metric,
        'a': 1.5769434601962196,
        'b': 0.8950608779914887,
        'target_metric': 'euclidean',
        'target_n_neighbors': neighbors,
        'verbose': 4,
        'n_epochs': 200*3*2,
        'init': 'random',
        'hash_input': True,
    }
# %% ../nbs/xai.ipynb 7 | |
def get_prjs(embs_no_nan, config_dr, config, print_flag = False):
    """Project embeddings with PCA and then UMAP, returning the UMAP projections.

    The parameters produced by `umap_parameters(config_dr, config)` are
    forwarded as keyword arguments to both `get_PCA_prjs` and `get_UMAP_prjs`.
    When `print_flag` is True the shape of each intermediate result is printed.
    """
    umap_params = umap_parameters(config_dr, config)
    # NOTE(review): PCA is always requested with cpu=False while UMAP follows
    # config_dr.cpu_flag -- confirm this asymmetry is intended.
    prjs_pca = get_PCA_prjs(
        X = embs_no_nan,
        cpu = False,
        print_flag = print_flag,
        **umap_params
    )
    if print_flag:
        print(prjs_pca.shape)
    prjs_umap = get_UMAP_prjs(
        input_data = prjs_pca,
        cpu = config_dr.cpu_flag,
        print_flag = print_flag,
        **umap_params
    )
    if print_flag:
        # The shape was previously evaluated but never printed.
        print(prjs_umap.shape)
    return prjs_umap
# %% ../nbs/xai.ipynb 9 | |
def plot_projections(prjs, umap_params, fig_size = (25,25)):
    "Plot 2D projections through a connected scatter plot"
    coords = pd.DataFrame(prjs, columns = ['x1', 'x2'])
    width, height = fig_size[0], fig_size[1]
    figure = plt.figure(figsize=(width, height))
    ax = figure.add_subplot(111)
    xs, ys = coords['x1'], coords['x2']
    # Hollow markers for the points, plus a line joining them in order
    ax.scatter(xs, ys, marker='o', facecolors='none', edgecolors='b', alpha=0.1)
    ax.plot(xs, ys, alpha=0.5, picker=1)
    title = 'DR params - n_neighbors:{:d} min_dist:{:f}'.format(
        umap_params['n_neighbors'], umap_params['min_dist'])
    plt.title(title)
    return ax
# %% ../nbs/xai.ipynb 10 | |
def plot_projections_clusters(prjs, clusters_labels, umap_params, fig_size = (25,25)):
    "Plot 2D projections as a scatter plot with one series per cluster label"
    coords = pd.DataFrame(prjs, columns = ['x1', 'x2'])
    coords['cluster'] = clusters_labels
    width, height = fig_size[0], fig_size[1]
    figure = plt.figure(figsize=(width, height))
    ax = figure.add_subplot(111)
    # One scatter call per label, in order of first appearance
    labels_found = coords['cluster'].unique()
    print(labels_found)
    for label in labels_found:
        belongs = coords['cluster'] == label
        members = coords[belongs]
        ax.scatter(members['x1'], members['x2'], label=f'Cluster {label}')
    title = 'DR params - n_neighbors:{:d} min_dist:{:f}'.format(
        umap_params['n_neighbors'], umap_params['min_dist'])
    plt.title(title)
    return ax
# %% ../nbs/xai.ipynb 11 | |
def calculate_cluster_stats(data, labels):
    """Compute the mean and the standard deviation of every cluster.

    Parameters:
        data: array-like whose first axis indexes the points.
        labels: array-like with one cluster label per point.

    Returns:
        dict mapping each distinct label to a `(mean, std)` tuple computed
        along axis 0 over the points carrying that label only.
    """
    data = np.asarray(data)
    labels = np.asarray(labels)
    cluster_stats = {}
    for label in np.unique(labels):
        # Restrict to the members of this cluster; previously the whole data
        # set was used, so every label received identical statistics.
        members = data[labels == label]
        mean = np.mean(members, axis = 0)
        std = np.std(members, axis = 0)
        cluster_stats[label] = (mean, std)
    return cluster_stats
# %% ../nbs/xai.ipynb 12 | |
def anomaly_score(point, cluster_stats, label):
    """Return the norm of `point` standardized with the stats stored for `label`.

    `cluster_stats[label]` must be a `(mean, std)` pair; the score is
    `np.linalg.norm((point - mean) / std)`.
    """
    center, spread = cluster_stats[label]
    standardized = (point - center) / spread
    return np.linalg.norm(standardized)
# %% ../nbs/xai.ipynb 13 | |
def detector(data, labels):
    """Anomaly detection function.

    Computes the per-label statistics with `calculate_cluster_stats` and
    returns a NumPy array holding the `anomaly_score` of each point, in the
    order the points appear in `data`.
    """
    stats_by_label = calculate_cluster_stats(data, labels)
    per_point_scores = [
        anomaly_score(sample, stats_by_label, sample_label)
        for sample, sample_label in zip(data, labels)
    ]
    return np.array(per_point_scores)
# %% ../nbs/xai.ipynb 15 | |
def plot_anomaly_scores_distribution(anomaly_scores):
    "Plot the distribution of anomaly scores to check for normality"
    plt.figure(figsize=(10, 6))
    # Histogram with 30 bins and a kernel density estimate on top
    sns.histplot(anomaly_scores, kde=True, bins=30)
    # The title had been stored with a mis-encoded accented character.
    plt.title("Distribución de Anomaly Scores")
    plt.xlabel("Anomaly Score")
    plt.ylabel("Frecuencia")
    plt.show()
# %% ../nbs/xai.ipynb 16 | |
def plot_clusters_with_anomalies(prjs, clusters_labels, anomaly_scores, threshold, fig_size=(25, 25)):
    "Plot 2D projections of clusters and superimpose the points whose score exceeds `threshold`"
    coords = pd.DataFrame(prjs, columns=['x1', 'x2'])
    coords['cluster'] = clusters_labels
    coords['anomaly'] = anomaly_scores > threshold
    width, height = fig_size[0], fig_size[1]
    figure = plt.figure(figsize=(width, height))
    ax = figure.add_subplot(111)
    # One series per cluster label, in order of first appearance
    for label in coords['cluster'].unique():
        belongs = coords['cluster'] == label
        members = coords[belongs]
        ax.scatter(members['x1'], members['x2'], label=f'Cluster {label}', alpha=0.7)
    # Anomalous points are drawn last so they sit on top of the clusters
    flagged = coords[coords['anomaly']]
    ax.scatter(flagged['x1'], flagged['x2'], color='red', label='Anomalies', edgecolor='k', s=50)
    plt.title('Clusters and anomalies')
    plt.legend()
    plt.show()
def update_plot(threshold, prjs_umap, clusters_labels, anomaly_scores, fig_size):
    "Redraw the clusters/anomalies plot for the given `threshold`"
    plot_clusters_with_anomalies(
        prjs_umap,
        clusters_labels,
        anomaly_scores,
        threshold,
        fig_size,
    )
def plot_clusters_with_anomalies_interactive_plot(threshold, prjs_umap, clusters_labels, anomaly_scores, fig_size):
    """Display the clusters/anomalies plot together with a threshold slider.

    Moving the slider calls `update_plot` with the new threshold; the other
    arguments are passed through unchanged.

    Parameters:
        threshold: initial slider value (the slider spans 0.001 to 3).
        prjs_umap, clusters_labels, anomaly_scores: forwarded to `update_plot`.
        fig_size: figure size forwarded to `update_plot`.
    """
    threshold_slider = widgets.FloatSlider(value=threshold, min=0.001, max=3, step=0.001, description='Threshold')
    interactive_plot = widgets.interactive(
        update_plot,
        threshold = threshold_slider,
        prjs_umap = widgets.fixed(prjs_umap),
        clusters_labels = widgets.fixed(clusters_labels),
        anomaly_scores = widgets.fixed(anomaly_scores),
        # Honour the caller's size; a hard-coded (25,25) was used before.
        fig_size = widgets.fixed(fig_size)
    )
    display(interactive_plot)
# %% ../nbs/xai.ipynb 18 | |
import plotly.express as px | |
from datetime import timedelta | |
# %% ../nbs/xai.ipynb 19 | |
def get_df_selected(df, selected_indices, w, stride = 1):  # Careful with stride
    '''Links back the selected points to the original dataframe and returns the associated windows indices

    Parameters:
        df: DataFrame to slice by position.
        selected_indices: iterable of window ids.
        w: window length added to each window start.
        stride: multiplier turning a window id into a start position.

    Returns:
        `(window_ranges, n_windows, df_selected)` where `window_ranges` is a
        list of `(start, start + w)` tuples, `n_windows` is
        `len(selected_indices)` and `df_selected` concatenates the rows
        `start..end` (end included) of every window, keeping the original
        index. With no selected indices `df_selected` is an empty frame with
        the columns of `df`.
    '''
    n_windows = len(selected_indices)
    window_ranges = [
        (window_id * stride, window_id * stride + w)
        for window_id in selected_indices
    ]
    segments = [df.iloc[start:end + 1] for start, end in window_ranges]
    if segments:
        df_selected = pd.concat(segments, ignore_index=False)
    else:
        # pd.concat raises on an empty list; return an empty slice instead.
        df_selected = df.iloc[0:0]
    return window_ranges, n_windows, df_selected
# %% ../nbs/xai.ipynb 20 | |
def shift_datetime(dt, seconds, sign, dateformat="%Y-%m-%d %H:%M:%S.%f", print_flag = False):
    """
    Shift the date string `dt` by `seconds` seconds: to the future if `sign`
    is '+', to the past for any other value of `sign`.

    `dt` is parsed with `dateformat` first; if that fails, the formats
    "%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S" and "%Y-%m-%d" are tried in
    turn, so the strings returned by this function can be fed back into it.

    Returns the shifted date formatted as "%Y-%m-%d %H:%M:%S.%f".
    Raises ValueError when `dt` matches none of the formats.
    """
    output_format = "%Y-%m-%d %H:%M:%S.%f"
    known_formats = ("%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S", "%Y-%m-%d")
    # The caller's format always goes first so inputs that parsed before
    # keep parsing exactly the same way.
    candidates = [dateformat] + [fmt for fmt in known_formats if fmt != dateformat]

    if print_flag:
        print(dateformat)
        print("dt ", dt, "seconds", seconds, "sign", sign)

    new_dt = None
    for fmt in candidates:
        try:
            new_dt = datetime.strptime(dt, fmt)
        except ValueError as e:
            if print_flag: print("Error: ", e)
            continue
        break
    if new_dt is None:
        raise ValueError(
            f"Could not parse {dt!r} with any of the formats {candidates}"
        )
    if print_flag: print("ndt", new_dt)

    # Kept from the original behaviour: a time at 00:00:00 loses its
    # microseconds both before and after the shift.
    if new_dt.hour == 0 and new_dt.minute == 0 and new_dt.second == 0:
        new_dt = new_dt.replace(hour=0, minute=0, second=0, microsecond=0)

    offset = timedelta(seconds = seconds)
    if sign == '+':
        new_dt = new_dt + offset
    else:
        new_dt = new_dt - offset

    if new_dt.hour == 0 and new_dt.minute == 0 and new_dt.second == 0:
        if print_flag: print("replacing")
        new_dt = new_dt.replace(hour=0, minute=0, second=0, microsecond=0)

    if print_flag: print("new dt ", new_dt)
    return new_dt.strftime(output_format)
# %% ../nbs/xai.ipynb 21 | |
def get_dateformat(text_date):
    """Guess the strptime format of `text_date` from its shape.

    The text is split on whitespace:
    - one token: "%Y-%m-%d";
    - two tokens whose second token has three ':'-separated fields:
      "%Y-%m-%d %H:%M:%S.%f" when the last field splits on '.' into exactly
      two pieces, "%Y-%m-%d %H:%M:%S" otherwise;
    - two tokens with any other time layout: "unknown format 1";
    - any other number of tokens: "unknown format 2".
    """
    tokens = text_date.split()
    if len(tokens) == 1:
        return "%Y-%m-%d"
    if len(tokens) != 2:
        return "unknown format 2"

    clock_fields = tokens[1].split(':')
    if len(clock_fields) != 3:
        return "unknown format 1"

    has_fraction = len(clock_fields[2].split('.')) == 2
    if has_fraction:
        return "%Y-%m-%d %H:%M:%S.%f"
    return "%Y-%m-%d %H:%M:%S"
# %% ../nbs/xai.ipynb 23 | |
def get_anomalies(df, threshold, flag):
    """Write the column `df['anomaly']` in place.

    Each entry is `(score > threshold) and flag` for the corresponding value
    of `df['anomaly_score']`. Nothing is returned.
    """
    marks = []
    for score in df['anomaly_score']:
        marks.append((score > threshold) and flag)
    df['anomaly'] = marks
def get_anomaly_styles(df, threshold, anomaly_scores, flag = False, print_flag = False):
    """Return the marker symbols and marker outline colours for every row of `df`.

    `df['anomaly']` is written in place: first from `df['anomaly_score']`
    and, when `flag` is truthy, again from `anomaly_scores`. Each entry is
    `(score > threshold) and flag`.

    Returns `(symbols, line_colors)`, two lists with one entry per row:
    'x' / 'black' for anomalous rows and 'circle' / 'rgba(0,0,0,0)' for the
    rest. With a falsy `flag` every row gets the non-anomalous style.
    """
    def over_threshold(scores):
        return [(score > threshold) and flag for score in scores]

    if print_flag:
        print("Threshold: ", threshold)
        print("Flag", flag)
        print("df ~", df.shape)
    df['anomaly'] = over_threshold(df['anomaly_score'])
    if print_flag:
        print(df)
    # Snapshot of the flagged rows, only used for the debug print below
    anomalies = df[df['anomaly']]

    if not flag:
        n_rows = len(df['x1'])
        symbols = ['circle'] * n_rows
        line_colors = ['rgba(0,0,0,0)'] * n_rows
    else:
        df['anomaly'] = over_threshold(anomaly_scores)
        symbols, line_colors = [], []
        for is_anomaly in df['anomaly']:
            symbols.append('x' if is_anomaly else 'circle')
            line_colors.append('black' if is_anomaly else 'rgba(0,0,0,0)')

    if print_flag:
        print(anomalies)
    return symbols, line_colors
### Example of use
#prjs_df = pd.DataFrame(prjs_umap, columns = ['x1', 'x2'])
#prjs_df['anomaly_score'] = anomaly_scores
#s, l = get_anomaly_styles(prjs_df, 1, anomaly_scores, True)
# %% ../nbs/xai.ipynb 24 | |
class InteractiveAnomalyPlot():
    """Interactive scatter plot of 2D projections with lasso selection and
    optional anomaly highlighting.

    Attributes set by the constructor:
        selected_indices: indices confirmed with the "Update selected_indices" button.
        selected_indices_tmp: indices of the latest selection on the plot.
        threshold: threshold given at construction time.
        threshold_: threshold currently used to style anomalies (slider value).
        anomaly_flag: whether anomalies are highlighted.
        w, name, path: window size and the names derived from it.
        interaction_enabled: toggled by the pause/resume buttons.
    """
    def __init__(
        self, selected_indices = None,
        threshold = 0.15,
        anomaly_flag = False,
        path = "../imgs", w = 0
    ):
        # A literal `[]` default would be one list shared by every instance.
        if selected_indices is None:
            selected_indices = []
        self.selected_indices = selected_indices
        self.selected_indices_tmp = selected_indices
        self.threshold = threshold
        self.threshold_ = threshold
        self.anomaly_flag = anomaly_flag
        self.w = w
        self.name = f"w={self.w}"
        # NOTE(review): no separator is inserted between `path` and the file
        # name, and this attribute is not read anywhere in this class.
        self.path = f"{path}{self.name}.png"
        self.interaction_enabled = True

    def plot_projections_clusters_interactive(
        self, prjs, cluster_labels, umap_params, anomaly_scores=None, fig_size=(7,7), print_flag = False
    ):
        """Build and display the interactive figure with its control widgets.

        Parameters:
            prjs: 2D projections, turned into the columns 'x1' and 'x2'.
            cluster_labels: one cluster label per projection.
            umap_params: mapping with 'n_neighbors' and 'min_dist' (shown in the title).
            anomaly_scores: one score per projection; defaults to an empty list.
            fig_size: accepted but not used; the figure is laid out at 700x500.
            print_flag: when True, debug information is printed into the output widgets.

        The figure is also saved through `plot_save(fig, self.w)` before being displayed.
        """
        if anomaly_scores is None:
            anomaly_scores = []
        self.selected_indices_tmp = self.selected_indices
        py.init_notebook_mode()

        prjs_df, cluster_colors = plot_initial_config(prjs, cluster_labels, anomaly_scores)
        # The swatch glyph had been stored mis-encoded; it is a filled square.
        legend_items = [widgets.HTML(f'<b>Cluster {cluster}:</b> <span style="color:{color};">■</span>')
                        for cluster, color in cluster_colors.items()]
        legend = widgets.VBox(legend_items)

        marker_colors = prjs_df['cluster'].map(cluster_colors)

        symbols, line_colors = get_anomaly_styles(prjs_df, self.threshold_, anomaly_scores, self.anomaly_flag, print_flag)

        # Trace 0: the points (this is the trace that receives selections)
        fig = go.FigureWidget(
            [
                go.Scatter(
                    x=prjs_df['x1'], y=prjs_df['x2'],
                    mode="markers",
                    marker= {
                        'color': marker_colors,
                        'line': { 'color': line_colors, 'width': 1 },
                        'symbol': symbols
                    },
                    text = prjs_df.index
                )
            ]
        )

        # Trace 1: a line joining the points in order
        # (showlegend=False could be set here to hide it from the legend)
        line_trace = go.Scatter(
            x=prjs_df['x1'],
            y=prjs_df['x2'],
            mode="lines",
            line=dict(color='rgba(128, 128, 128, 0.5)', width=1)
        )

        fig.add_trace(line_trace)

        sca = fig.data[0]

        fig.update_layout(
            dragmode='lasso',
            width=700,
            height=500,
            title={
                'text': '<span style="font-weight:bold">DR params - n_neighbors:{:d} min_dist:{:f}</span>'.format(
                    umap_params['n_neighbors'], umap_params['min_dist']),
                'y':0.98,
                'x':0.5,
                'xanchor': 'center',
                'yanchor': 'top'
            },
            plot_bgcolor='white',
            paper_bgcolor='#f0f0f0',
            xaxis=dict(gridcolor='lightgray', zerolinecolor='black', title = 'x'),
            yaxis=dict(gridcolor='lightgray', zerolinecolor='black', title = 'y'),
            margin=dict(l=10, r=20, t=30, b=10)
        )

        # Output widgets that capture the debug prints of each callback
        output_tmp = Output()
        output_button = Output()
        output_anomaly = Output()
        output_threshold = Output()
        output_width = Output()

        def select_action(trace, points, selector):
            # Remember the latest selection; it is only committed by the button
            self.selected_indices_tmp = points.point_inds
            with output_tmp:
                output_tmp.clear_output(wait=True)
                if print_flag: print("Selected indices tmp:", self.selected_indices_tmp)

        def button_action(b):
            self.selected_indices = self.selected_indices_tmp
            with output_button:
                output_button.clear_output(wait = True)
                if print_flag: print("Selected indices:", self.selected_indices)

        def update_anomalies():
            if print_flag: print("About to update anomalies")
            symbols, line_colors = get_anomaly_styles(prjs_df, self.threshold_, anomaly_scores, self.anomaly_flag, print_flag)
            if print_flag: print("Anomaly styles got")
            with fig.batch_update():
                fig.data[0].marker.symbol = symbols
                fig.data[0].marker.line.color = line_colors
            if print_flag: print("Anomalies updated")
            if print_flag: print("Threshold: ", self.threshold_)
            if print_flag: print("Scores: ", anomaly_scores)

        def anomaly_action(b):
            with output_anomaly:
                output_anomaly.clear_output(wait=True)
                # Fixed: this line referenced the undefined name `print_fllag`,
                # which raised NameError on every click.
                if print_flag: print("Negate anomaly flag")
                self.anomaly_flag = not self.anomaly_flag
                if print_flag: print("Show anomalies:", self.anomaly_flag)
                update_anomalies()

        sca.on_selection(select_action)
        layout = widgets.Layout(width='auto', height='40px')
        button = Button(
            description="Update selected_indices",
            style = {'button_color': 'lightblue'},
            display = 'flex',
            flex_row = 'column',
            align_items = 'stretch',
            layout = layout
        )
        anomaly_button = Button(
            description = "Show anomalies",
            style = {'button_color': 'lightgray'},
            display = 'flex',
            flex_row = 'column',
            align_items = 'stretch',
            layout = layout
        )

        button.on_click(button_action)
        anomaly_button.on_click(anomaly_action)

        ##### Reactivity buttons
        pause_button = Button(
            description = "Pause interactiveness",
            style = {'button_color': 'pink'},
            display = 'flex',
            flex_row = 'column',
            align_items = 'stretch',
            layout = layout
        )
        resume_button = Button(
            description = "Resume interactiveness",
            style = {'button_color': 'lightgreen'},
            display = 'flex',
            flex_row = 'column',
            align_items = 'stretch',
            layout = layout
        )

        threshold_slider = FloatSlider(
            value=self.threshold_,
            min=0.0,
            max=float(np.ceil(self.threshold+5)),
            step=0.0001,
            description='Anomaly threshold:',
            continuous_update=False
        )

        def pause_interaction(b):
            self.interaction_enabled = False
            fig.update_layout(dragmode='pan')

        def resume_interaction(b):
            self.interaction_enabled = True
            fig.update_layout(dragmode='lasso')

        def update_threshold(change):
            with output_threshold:
                output_threshold.clear_output(wait = True)
                if print_flag: print("Update threshold")
                self.threshold_ = change.new
                if print_flag: print("Update anomalies threshold = ", self.threshold_)
                update_anomalies()

        #### Width
        width_slider = FloatSlider(
            value = 0.5,
            min = 0.0,
            max = 1.0,
            step = 0.0001,
            description = 'Line width:',
            continuous_update = False
        )

        def update_width(change):
            with output_width:
                try:
                    output_width.clear_output(wait = True)
                    if print_flag:
                        print("Change line width")
                        print("Trace to update:", fig.data[1])
                    with fig.batch_update():
                        # Apply the slider value as the width of the connecting line
                        fig.data[1].line.width = change.new
                    if print_flag: print("ChangeD line width")
                except Exception as e:
                    print("Error updating line width:", e)

        pause_button.on_click(pause_interaction)
        resume_button.on_click(resume_interaction)

        threshold_slider.observe(update_threshold, 'value')

        ####
        width_slider.observe(update_width, names = 'value')

        #####
        space = HTML(" ")

        vbox = VBox((output_tmp, output_button, output_anomaly, output_threshold, fig))
        hbox = HBox((space, button, space, pause_button, space, resume_button, anomaly_button))

        # Center the boxes horizontally inside the VBox
        box_layout = widgets.Layout(display='flex',
                                    flex_flow='column',
                                    align_items='center',
                                    width='100%')

        # The threshold slider is only laid out when anomalies are enabled at build time
        if self.anomaly_flag:
            box = VBox((hbox,threshold_slider,width_slider, output_width, vbox), layout = box_layout)
        else:
            box = VBox((hbox, width_slider, output_width, vbox), layout = box_layout)
        box.add_class("layout")
        plot_save(fig, self.w)

        display(box)
# %% ../nbs/xai.ipynb 25 | |
def plot_save(fig, w, path="../imgs"):
    """Render `fig` as a PNG and write it to `<path>/w=<w>.png`.

    Parameters:
        fig: figure handed to `pio.to_image`.
        w: value used to build the file name `w=<w>.png`.
        path: output directory; it is created when it does not exist.
            Defaults to "../imgs", the directory that used to be hard-coded.
    """
    from pathlib import Path  # local import: only needed here

    image_bytes = pio.to_image(fig, format='png')
    # Create the directory only after rendering succeeded
    out_dir = Path(path)
    out_dir.mkdir(parents=True, exist_ok=True)
    (out_dir / f"w={w}.png").write_bytes(image_bytes)
# %% ../nbs/xai.ipynb 26 | |
def plot_initial_config(prjs, cluster_labels, anomaly_scores, palette=None):
    """Build the projections frame and the colour assigned to every cluster.

    Parameters:
        prjs: 2D projections, stored in the columns 'x1' and 'x2'.
        cluster_labels: one label per projection (column 'cluster').
        anomaly_scores: one score per projection (column 'anomaly_score').
            `None` or an empty sequence fills the column with NaN.
        palette: optional sequence of colours; defaults to
            `px.colors.qualitative.Set1`.

    Returns:
        `(prjs_df, cluster_colors)` where `cluster_colors` maps each distinct
        label, in order of first appearance, to a colour. Colours are reused
        cyclically when there are more clusters than palette entries.
    """
    prjs_df = pd.DataFrame(prjs, columns = ['x1', 'x2'])
    prjs_df['cluster'] = cluster_labels
    if anomaly_scores is None or len(anomaly_scores) == 0:
        # An empty list cannot be assigned to a non-empty frame
        prjs_df['anomaly_score'] = float('nan')
    else:
        prjs_df['anomaly_score'] = anomaly_scores

    if palette is None:
        palette = px.colors.qualitative.Set1

    distinct = pd.DataFrame({'cluster': cluster_labels}).drop_duplicates()['cluster'].tolist()
    # Cycle through the palette: slicing it used to fail whenever there were
    # more clusters than colours.
    cluster_colors = {
        cluster: palette[position % len(palette)]
        for position, cluster in enumerate(distinct)
    }
    return prjs_df, cluster_colors
# %% ../nbs/xai.ipynb 27 | |
def merge_overlapping_windows(windows):
    """Merge windows that overlap or touch into single `(start, end)` windows.

    Windows are sorted by their start; a window whose start is not greater
    than the end of the previously kept window is fused with it. Returns a
    new list (empty for an empty input).
    """
    if not windows:
        return []
    ordered = sorted(windows, key=lambda win: win[0])
    result = ordered[:1]
    for current in ordered[1:]:
        kept_start, kept_end = result[-1][0], result[-1][1]
        if current[0] > kept_end:
            # Disjoint from the last kept window: start a new one
            result.append(current)
            continue
        # Overlapping or touching: extend the last kept window
        result[-1] = (kept_start, max(current[1], kept_end))
    return result
# %% ../nbs/xai.ipynb 29 | |
class InteractiveTSPlot:
    """Interactive time series plot with the selected windows highlighted.

    The constructor only stores the configuration and computes the windows;
    the plotting methods are attached to the class further down the module.
    """
    def __init__(
        self,
        df,
        selected_indices,
        meaningful_features_subset_ids,
        w,
        stride=1,
        print_flag=False,
        num_points=10000,
        dateformat='%Y-%m-%d %H:%M:%S',
        delta_x = 10,
        delta_y = 0.1
    ):
        """
        Parameters:
            df: DataFrame with the time series; its index is converted to
                strings on a shallow copy, the caller's frame is left untouched.
            selected_indices: ids of the selected windows.
            meaningful_features_subset_ids: column positions shown initially.
            w: window size.
            stride: stride between consecutive windows.
            print_flag: when True, print debug information.
            num_points: stored; not applied to the data in this constructor.
            dateformat: strptime format used for the x axis.
            delta_x: step used by the horizontal movement buttons.
            delta_y: step used by the vertical movement buttons.
        """
        # Shallow copy: replacing the index below must not leak into the
        # DataFrame owned by the caller.
        self.df = df.copy(deep=False)
        self.selected_indices = selected_indices
        self.meaningful_features_subset_ids = meaningful_features_subset_ids
        self.w = w
        self.stride = stride
        self.print_flag = print_flag
        self.num_points = num_points
        self.dateformat = dateformat
        self.buttons = []
        self.delta_x = delta_x
        self.delta_y = delta_y
        self.window_ranges, self.n_windows, self.df_selected = get_df_selected(
            self.df, self.selected_indices, self.w, self.stride
        )
        # Ensure the smallest possible number of windows to plot (like in the R Shiny App)
        self.window_ranges = merge_overlapping_windows(self.window_ranges)
        if self.print_flag:
            print("windows: ", self.n_windows, self.window_ranges)
            print("selected id: ", self.df_selected.index)
            print("points: ", self.num_points)
        self.df.index = self.df.index.astype(str)
        self.fig = go.FigureWidget()
        # One random colour per selected window (before merging)
        self.colors = [
            f'rgb({np.random.randint(0, 256)}, {np.random.randint(0, 256)}, {np.random.randint(0, 256)})'
            for _ in range(self.n_windows)
        ]
        ##############################
        # Outputs for debug printing #
        ##############################
        self.output_windows = Output()
        self.output_move = Output()
        self.output_delta_x = Output()
        self.output_delta_y = Output()
# %% ../nbs/xai.ipynb 30 | |
def add_selected_features(self: InteractiveTSPlot):
    "Add one line trace per column of `self.df`; only the meaningful features start visible"
    timestamps = self.df.index
    columns = self.df.columns
    for feature_id in columns:
        position = columns.get_loc(feature_id)
        starts_visible = position in self.meaningful_features_subset_ids
        feature_trace = go.Scatter(
            x = timestamps,
            y = self.df[feature_id],
            mode='lines',
            name=feature_id,
            visible=starts_visible,
            text=timestamps
        )
        self.fig.add_trace(feature_trace)
InteractiveTSPlot.add_selected_features = add_selected_features
# %% ../nbs/xai.ipynb 31 | |
def add_windows(self: InteractiveTSPlot):
    """Draw one shaded rectangle per entry of `self.window_ranges`.

    Each rectangle spans the full plot height and goes from
    `self.df.index[start]` to `self.df.index[end]`. Window ends beyond the
    last row are clamped to the last row, and windows starting beyond it are
    skipped. The limits of every window are printed into
    `self.output_windows`.
    """
    last_position = len(self.df.index) - 1
    for i, (start, end) in enumerate(self.window_ranges):
        if last_position < 0 or start > last_position:
            # Nothing to draw: the window lies outside the data
            continue
        # A window of size w starting at the last valid id ends one past the
        # data; indexing with it used to raise IndexError.
        end = min(end, last_position)
        x_start = self.df.index[start]
        x_end = self.df.index[end]
        self.fig.add_shape(
            type="rect",
            x0=x_start,
            x1=x_end,
            y0= 0,
            y1= 1,
            yref = "paper",
            fillcolor=self.colors[i],
            opacity=0.25,
            layer="below",
            line=dict(color=self.colors[i], width=1),
            name = f"w_{i}"
        )
        with self.output_windows:
            # Windows are merged and sorted, so position i no longer maps to
            # selected_indices[i]; label by the shape name instead.
            print(f"w_{i}=({x_start}, {x_end})")
InteractiveTSPlot.add_windows = add_windows
# %% ../nbs/xai.ipynb 32 | |
def setup_style(self: InteractiveTSPlot):
    "Set titles, margins, tick format and background of the figure and lock the y axes"
    x_axis_options = dict(
        tickformat = '%d-' + self.dateformat,
    )
    plot_margins = dict(l=10, r=10, t=30, b=10)
    self.fig.update_layout(
        title='Time Series with time window plot',
        xaxis_title='Datetime',
        yaxis_title='Value',
        legend_title='Variables',
        margin=plot_margins,
        xaxis=x_axis_options,
        paper_bgcolor='#f0f0f0'
    )
    # fixedrange=True is applied to every y axis of the figure
    self.fig.update_yaxes(fixedrange=True)
InteractiveTSPlot.setup_style = setup_style
# %% ../nbs/xai.ipynb 34 | |
def toggle_trace(self : InteractiveTSPlot, button : Button):
    "Flip the visibility of the trace whose column name is the description of `button`"
    column_name = button.description
    position = self.df.columns.get_loc(column_name)
    target = self.fig.data[position]
    target.visible = not target.visible
InteractiveTSPlot.toggle_trace = toggle_trace
# %% ../nbs/xai.ipynb 35 | |
def set_features_buttons(self):
    "Create one toggle button per column; buttons of the meaningful features get the 'success' style"
    columns = self.df.columns
    highlighted = self.meaningful_features_subset_ids
    created = []
    for feature_id in columns:
        if columns.get_loc(feature_id) in highlighted:
            style_name = 'success'
        else:
            style_name = ''
        created.append(Button(description=str(feature_id), button_style=style_name))
    self.buttons = created
    # Every button toggles the trace named in its description
    for feature_button in self.buttons:
        feature_button.on_click(self.toggle_trace)
InteractiveTSPlot.set_features_buttons = set_features_buttons
# %% ../nbs/xai.ipynb 36 | |
def _shift_x_range(self, sign):
    "Shift the x-axis range by `self.delta_x` seconds; '+' moves forward, '-' backward"
    with self.output_move:
        self.output_move.clear_output(wait=True)
        x_range = self.fig.layout.xaxis.range
        if x_range is None:
            # No range stored on the layout: there is nothing to shift
            if self.print_flag: print("x-axis range is not set; nothing to move")
            return
        start_date, end_date = x_range
        new_start_date = shift_datetime(start_date, self.delta_x, sign, self.dateformat, self.print_flag)
        new_end_date = shift_datetime(end_date, self.delta_x, sign, self.dateformat, self.print_flag)
        with self.fig.batch_update():
            self.fig.layout.xaxis.range = [new_start_date, new_end_date]

def _shift_y_range(self, offset):
    "Add `offset` to both ends of the y-axis range"
    with self.output_move:
        self.output_move.clear_output(wait=True)
        y_range = self.fig.layout.yaxis.range
        if y_range is None:
            # No range stored on the layout: there is nothing to shift
            if self.print_flag: print("y-axis range is not set; nothing to move")
            return
        start_y, end_y = y_range
        with self.fig.batch_update():
            self.fig.layout.yaxis.range = [start_y + offset, end_y + offset]

def move_left(self : InteractiveTSPlot, button : Button):
    "Button callback: move the visible x range `self.delta_x` seconds to the past"
    _shift_x_range(self, '-')

def move_right(self : InteractiveTSPlot, button : Button):
    "Button callback: move the visible x range `self.delta_x` seconds to the future"
    _shift_x_range(self, '+')

def move_down(self: InteractiveTSPlot, button : Button):
    "Button callback: move the visible y range down by `self.delta_y`"
    # Fixed: the range used to be assigned through the misspelled `self.ig`.
    _shift_y_range(self, -self.delta_y)

def move_up(self: InteractiveTSPlot, button : Button):
    "Button callback: move the visible y range up by `self.delta_y`"
    _shift_y_range(self, self.delta_y)

InteractiveTSPlot.move_left = move_left
InteractiveTSPlot.move_right = move_right
InteractiveTSPlot.move_down = move_down
InteractiveTSPlot.move_up = move_up
# %% ../nbs/xai.ipynb 37 | |
def delta_x_bigger(self: InteractiveTSPlot, button: Button = None):
    """Multiply the horizontal step `self.delta_x` by 10.

    `button` is the widget passed by `on_click`; it is optional so the method
    can still be called without arguments.
    """
    with self.output_delta_x:
        self.output_delta_x.clear_output(wait = True)
        if self.print_flag: print("Delta before", self.delta_x)
        self.delta_x *= 10
        if self.print_flag: print("delta_x:", self.delta_x)

def delta_y_bigger(self: InteractiveTSPlot, button: Button = None):
    """Multiply the vertical step `self.delta_y` by 10.

    `button` is the widget passed by `on_click`; it is optional so the method
    can still be called without arguments.
    """
    with self.output_delta_y:
        self.output_delta_y.clear_output(wait = True)
        if self.print_flag: print("Delta before", self.delta_y)
        self.delta_y *= 10
        if self.print_flag: print("delta_y:", self.delta_y)

def delta_x_lower(self: InteractiveTSPlot, button: Button = None):
    """Divide the horizontal step `self.delta_x` by 10.

    `button` is the widget passed by `on_click`; it is optional so the method
    can still be called without arguments.
    """
    with self.output_delta_x:
        self.output_delta_x.clear_output(wait = True)
        if self.print_flag: print("Delta before", self.delta_x)
        self.delta_x /= 10
        if self.print_flag: print("delta_x:", self.delta_x)

def delta_y_lower(self: InteractiveTSPlot, button: Button = None):
    """Divide the vertical step `self.delta_y` by 10.

    `button` is the widget passed by `on_click`; it is optional so the method
    can still be called without arguments.
    """
    with self.output_delta_y:
        self.output_delta_y.clear_output(wait = True)
        if self.print_flag: print("Delta before", self.delta_y)
        # Fixed: the step used to be multiplied by 10 here, making "lower"
        # behave exactly like "bigger".
        self.delta_y /= 10
        if self.print_flag: print("delta_y:", self.delta_y)

InteractiveTSPlot.delta_x_bigger = delta_x_bigger
InteractiveTSPlot.delta_y_bigger = delta_y_bigger
InteractiveTSPlot.delta_x_lower = delta_x_lower
InteractiveTSPlot.delta_y_lower = delta_y_lower
# %% ../nbs/xai.ipynb 38 | |
def add_movement_buttons(self: InteractiveTSPlot):
    """Create the navigation buttons and connect them to their callbacks.

    Four buttons move the visible range (left, right, up, down) and four
    buttons scale the step used for the movement along each axis.
    """
    # The arrow glyphs had been stored mis-encoded, which made every button
    # show the same label.
    self.button_left = Button(description="←")
    self.button_right = Button(description="→")
    self.button_up = Button(description="↑")
    self.button_down = Button(description="↓")

    self.button_step_x_up = Button(description="dx ↑")
    self.button_step_x_down = Button(description="dx ↓")
    self.button_step_y_up = Button(description="dy↑")
    self.button_step_y_down = Button(description="dy↓")

    # TODO: make sure the movement step can really be changed with these
    # buttons. `on_click` calls its callback with the clicked button, so the
    # step callbacks must accept that argument.
    self.button_step_x_up.on_click(self.delta_x_bigger)
    self.button_step_x_down.on_click(self.delta_x_lower)
    self.button_step_y_up.on_click(self.delta_y_bigger)
    self.button_step_y_down.on_click(self.delta_y_lower)

    self.button_left.on_click(self.move_left)
    self.button_right.on_click(self.move_right)
    self.button_up.on_click(self.move_up)
    self.button_down.on_click(self.move_down)
InteractiveTSPlot.add_movement_buttons = add_movement_buttons
# %% ../nbs/xai.ipynb 40 | |
def setup_boxes(self: InteractiveTSPlot):
    "Assemble the feature buttons, the movement buttons and the figure into `self.box`"
    self.steps_x = VBox([self.button_step_x_up, self.button_step_x_down])
    self.steps_y = VBox([self.button_step_y_up, self.button_step_y_down])
    navigation_row = HBox([
        self.button_left,
        self.button_right,
        self.button_up,
        self.button_down,
        self.steps_x,
        self.steps_y,
    ])
    features_layout = widgets.Layout(display='flex', flex_flow='row wrap', align_items='flex-start')
    features_row = HBox(self.buttons, layout=features_layout)
    column_layout = widgets.Layout(
        display='flex',
        flex_flow='column',
        align_items='center',
        width='100%'
    )
    # The debug outputs are only laid out when printing is enabled
    children = [features_row, navigation_row]
    if self.print_flag:
        children += [self.output_move, self.output_delta_x, self.output_delta_y]
    children += [self.fig, self.output_windows]
    self.box = VBox(children, layout=column_layout)
InteractiveTSPlot.setup_boxes = setup_boxes
# %% ../nbs/xai.ipynb 41 | |
def initial_plot(self: InteractiveTSPlot):
    "Run every build step of the plot, in order: traces, windows, style, buttons and boxes"
    build_steps = (
        self.add_selected_features,
        self.add_windows,
        self.setup_style,
        self.set_features_buttons,
        self.add_movement_buttons,
        self.setup_boxes,
    )
    for build_step in build_steps:
        build_step()
InteractiveTSPlot.initial_plot = initial_plot
# %% ../nbs/xai.ipynb 42 | |
def show(self : InteractiveTSPlot):
    "Build the plot with `initial_plot` and display the resulting `self.box`"
    self.initial_plot()
    assembled = self.box
    display(assembled)
InteractiveTSPlot.show = show