# prompt: write a gradio app to infer the labels from the model we previously trained

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load the fine-tuned model and tokenizer
checkpoint_dir = "25b3nk/ollama-issues-classifier"  # Replace with your local checkpoint directory or Hugging Face Hub repo ID
model = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
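
# The checkpoint's config carries the id2label mapping saved at training time,
# so predicted label IDs can be decoded back into human-readable label names.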

# Move the model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
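# from_pretrained already returns the model in eval mode, but being explicit
# guards against dropout accidentally being active during inference.
model.eval()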


# Function to perform inference
def predict(text):
    prob_thresh = 0.3  # minimum probability for a label to be reported
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():  # no gradients needed for inference
        outputs = model(**inputs)
    logits = outputs.logits
    # Sigmoid, not softmax: in multi-label classification each label gets an
    # independent probability, and several labels can apply at once.
    probabilities = torch.sigmoid(logits)

    # Keep the column indices (label IDs) of every probability above the threshold
    predicted_labels = (probabilities > prob_thresh).nonzero()[:, 1].tolist()
    prob_values = probabilities[probabilities > prob_thresh].tolist()

    # Map label IDs back to label names, paired with their probabilities
    labels_dict = {model.config.id2label[label_id]: prob for label_id, prob in zip(predicted_labels, prob_values)}
    return labels_dict
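
# Quick sanity check (hypothetical input; the actual label names depend on the
# label set this classifier was trained on, e.g. "bug" or "feature request"):
# print(predict("Ollama crashes with a CUDA out-of-memory error on startup"))
# -> something like {"bug": 0.87}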


# Create the Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=5, placeholder="Enter the issue text here..."),
    outputs=gr.Label(num_top_classes=len(model.config.id2label)),  # Show each predicted label with its probability
    title="Ollama GitHub Issue Label Prediction",
    description="Enter an issue description to predict its labels.",
)

iface.launch(debug=True)
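# debug=True surfaces tracebacks in the console; adding share=True would also
# serve a temporary public URL, which is handy when running inside a Colab notebook.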