25b3nk committed on
Commit 95f8f1a · verified · 1 Parent(s): b215cc7

Create app.py

Files changed (1)
  1. app.py +38 -0
app.py ADDED
@@ -0,0 +1,38 @@
+ # prompt: write a gradio app to infer the labels from the model we previously trained
+
+ import gradio as gr
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+ # Load the fine-tuned model and tokenizer
+ checkpoint = "25b3nk/ollama-issues-classifier"  # Replace with the actual path to your checkpoint directory
+ model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+
+ # Move the model to GPU if available and switch to inference mode
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+ model.eval()
+
+ # Function to perform inference
+ def predict(text):
+     inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
+     with torch.no_grad():
+         outputs = model(**inputs)
+     probabilities = torch.sigmoid(outputs.logits)[0]  # Sigmoid for multi-label classification
+
+     # gr.Label expects a dict mapping label name -> confidence, so return every
+     # label with its probability (scores above ~0.5 count as predicted labels)
+     return {model.config.id2label[i]: float(p) for i, p in enumerate(probabilities)}
+
+
+ # Create the Gradio interface
+ iface = gr.Interface(
+     fn=predict,
+     inputs=gr.Textbox(lines=5, placeholder="Enter the issue text here..."),
+     outputs=gr.Label(num_top_classes=len(model.config.id2label)),  # Display predicted labels
+     title="Issue Label Prediction",
+     description="Enter an issue description to predict its labels.",
+ )
+
+ iface.launch(debug=True)
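
Once the app is running, gr.Interface also exposes predict over an HTTP API. A minimal sanity check with gradio_client, as a sketch: it assumes the app is running locally on Gradio's default port, and "/predict" is the api_name Gradio assigns to a single-function Interface.

from gradio_client import Client

# Point the client at the URL printed by iface.launch()
# (http://127.0.0.1:7860 by default on a local run)
client = Client("http://127.0.0.1:7860")

# Prints the predicted labels and their confidences for a sample issue
result = client.predict(
    "Ollama crashes with an out-of-memory error when loading a large model",
    api_name="/predict",
)
print(result)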