rahideer commited on
Commit
7b5fc70
·
verified ·
1 Parent(s): 059922d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -0
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, pipeline
3
+ import torch
4
+ import plotly.express as px
5
+ import numpy as np
6
+ from utils import visualize_attention, list_supported_models
7
+
8
+ st.set_page_config(page_title="Transformer Visualizer", layout="wide")
9
+
10
+ st.title("🧠 Transformer Visualizer")
11
+ st.markdown("Explore how Transformer models process and understand language.")
12
+
13
+ task = st.sidebar.selectbox("Select Task", ["Text Classification", "Text Generation", "Question Answering"])
14
+ model_name = st.sidebar.selectbox("Select Model", list_supported_models(task))
15
+
16
+ text_input = st.text_area("Enter input text", "The quick brown fox jumps over the lazy dog.")
17
+
18
+ if st.button("Run"):
19
+ st.info(f"Loading model: `{model_name}`...")
20
+
21
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
22
+
23
+ if task == "Text Classification":
24
+ model = AutoModelForSequenceClassification.from_pretrained(model_name, output_attentions=True)
25
+ else:
26
+ model = AutoModel.from_pretrained(model_name, output_attentions=True)
27
+
28
+ inputs = tokenizer(text_input, return_tensors="pt")
29
+ outputs = model(**inputs)
30
+ attentions = outputs.attentions
31
+
32
+ st.success("Model inference complete!")
33
+
34
+ if attentions:
35
+ st.subheader("Attention Visualization")
36
+ fig = visualize_attention(attentions, tokenizer, inputs)
37
+ st.plotly_chart(fig, use_container_width=True)
38
+ else:
39
+ st.warning("This model does not return attention weights.")
40
+
41
+ if task == "Text Classification":
42
+ st.subheader("Prediction")
43
+ pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
44
+ prediction = pipe(text_input)
45
+ st.write(prediction)
46
+
47
+ st.sidebar.markdown("---")
48
+ st.sidebar.write("App by Rahiya Esar 💖")