import plotly.graph_objects as go import numpy as np def list_supported_models(task): if task == "Text Classification": return ["distilbert-base-uncased", "bert-base-uncased", "roberta-base"] elif task == "Text Generation": return ["gpt2", "distilgpt2"] elif task == "Question Answering": return ["deepset/roberta-base-squad2", "distilbert-base-cased-distilled-squad"] return [] def visualize_attention(attentions, tokenizer, inputs): tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) last_layer_attention = attentions[-1][0] # shape: [num_heads, seq_len, seq_len] avg_attention = last_layer_attention.mean(dim=0).detach().numpy() fig = go.Figure(data=go.Heatmap( z=avg_attention, x=tokens, y=tokens, colorscale='Viridis' )) fig.update_layout(title="Average Attention - Last Layer", xaxis_nticks=len(tokens)) return fig