rahideer commited on
Commit
b1bec5c
Β·
verified Β·
1 Parent(s): b36e408

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -7
app.py CHANGED
@@ -3,10 +3,10 @@ from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassific
3
  import torch
4
  import plotly.express as px
5
  import numpy as np
6
- from utils import visualize_attention, list_supported_models
 
7
 
8
  st.set_page_config(page_title="Transformer Visualizer", layout="wide")
9
-
10
  st.title("🧠 Transformer Visualizer")
11
  st.markdown("Explore how Transformer models process and understand language.")
12
 
@@ -19,27 +19,40 @@ if st.button("Run"):
19
  st.info(f"Loading model: `{model_name}`...")
20
 
21
  tokenizer = AutoTokenizer.from_pretrained(model_name)
22
-
23
  if task == "Text Classification":
24
  model = AutoModelForSequenceClassification.from_pretrained(model_name, output_attentions=True)
25
  else:
26
  model = AutoModel.from_pretrained(model_name, output_attentions=True)
27
-
28
- inputs = tokenizer(text_input, return_tensors="pt")
29
  outputs = model(**inputs)
30
  attentions = outputs.attentions
31
 
32
  st.success("Model inference complete!")
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  if attentions:
35
- st.subheader("Attention Visualization")
36
  fig = visualize_attention(attentions, tokenizer, inputs)
37
  st.plotly_chart(fig, use_container_width=True)
38
  else:
39
  st.warning("This model does not return attention weights.")
40
 
41
  if task == "Text Classification":
42
- st.subheader("Prediction")
43
  pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
44
  prediction = pipe(text_input)
45
  st.write(prediction)
 
3
  import torch
4
  import plotly.express as px
5
  import numpy as np
6
+ from sklearn.decomposition import PCA
7
+ from utils import visualize_attention, list_supported_models, plot_token_embeddings
8
 
9
  st.set_page_config(page_title="Transformer Visualizer", layout="wide")
 
10
  st.title("🧠 Transformer Visualizer")
11
  st.markdown("Explore how Transformer models process and understand language.")
12
 
 
19
  st.info(f"Loading model: `{model_name}`...")
20
 
21
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
22
  if task == "Text Classification":
23
  model = AutoModelForSequenceClassification.from_pretrained(model_name, output_attentions=True)
24
  else:
25
  model = AutoModel.from_pretrained(model_name, output_attentions=True)
26
+
27
+ inputs = tokenizer(text_input, return_tensors="pt", return_token_type_ids=False)
28
  outputs = model(**inputs)
29
  attentions = outputs.attentions
30
 
31
  st.success("Model inference complete!")
32
 
33
+ # Tokenization Visualization
34
+ st.subheader("πŸ”  Tokenization")
35
+ tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
36
+ token_ids = inputs["input_ids"][0].tolist()
37
+ st.write(list(zip(tokens, token_ids)))
38
+
39
+ # Token Embeddings Visualization
40
+ st.subheader("🌐 Token Embedding Space (PCA)")
41
+ with torch.no_grad():
42
+ hidden_states = model.base_model.embeddings.word_embeddings(inputs["input_ids"]).squeeze(0)
43
+ fig_embed = plot_token_embeddings(hidden_states, tokens)
44
+ st.plotly_chart(fig_embed, use_container_width=True)
45
+
46
+ # Attention Visualization
47
  if attentions:
48
+ st.subheader("πŸ‘οΈ Attention Visualization")
49
  fig = visualize_attention(attentions, tokenizer, inputs)
50
  st.plotly_chart(fig, use_container_width=True)
51
  else:
52
  st.warning("This model does not return attention weights.")
53
 
54
  if task == "Text Classification":
55
+ st.subheader("βœ… Prediction")
56
  pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
57
  prediction = pipe(text_input)
58
  st.write(prediction)