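# Web file-viewer header captured together with the source (not part of the program):
#   max-long's picture | Update app.py | a9ff64c verified | raw | history | blame | 2.56 kB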
import pandas as pd
import random
from gliner import GLiNER
import gradio as gr
from datasets import load_dataset
# Name of the CSV column that holds the text shown to the user.
text_column = "Description"

# Catalogue data from which random snippets are drawn.
df = pd.read_csv("1921_catalogue_SMG.csv")

# GLiNER checkpoint used for entity prediction.
model = GLiNER.from_pretrained(
    "max-long/textile_machines_3_oct",
    trust_remote_code=True,
)
def get_new_snippet():
    """Return one randomly chosen text snippet from the loaded CSV data.

    Reads the module-level ``df`` and takes the ``text_column`` value of a
    single randomly sampled row. When ``df`` has no rows, a fixed fallback
    message is returned instead.
    """
    # Guard clause: an empty DataFrame has nothing to sample from.
    if len(df) == 0:
        return "No more snippets available."

    picked_row = df.sample(n=1)
    return picked_row[text_column].values[0]
def ner(text: str):
    """Find "Textile Machinery" entities in *text* and highlight them as HTML.

    Runs the module-level GLiNER ``model`` over ``text`` with the single
    label "Textile Machinery" and a score threshold of 0.5, then wraps each
    predicted span in a yellow, bold ``<span>``.

    Args:
        text: The raw text to analyse.

    Returns:
        A ``(highlighted_html, entities)`` tuple. ``highlighted_html`` is
        ``text`` with ``&``, ``<`` and ``>`` HTML-escaped and each entity
        span wrapped in highlight markup. ``entities`` is a list of dicts
        with the keys ``entity``, ``word``, ``start``, ``end`` and ``score``;
        ``start``/``end`` are offsets into the original, unescaped ``text``.
    """
    import html  # function-scope import: only this block needs it

    labels = ["Textile Machinery"]
    threshold = 0.5

    # Predict entities using the fine-tuned GLiNER model
    entities = model.predict_entities(text, labels, flat_ner=True, threshold=threshold)
    textile_entities = [
        {
            "entity": ent["label"],
            "word": ent["text"],
            "start": ent["start"],
            "end": ent["end"],
            "score": ent.get("score", 0),
        }
        for ent in entities
        if ent["label"] == "Textile Machinery"
    ]

    # Build the markup in one left-to-right pass so every piece of the
    # user-supplied text can be escaped before it reaches the HTML component;
    # inserting it raw would let "<", ">" or "&" in the input break or inject
    # markup.
    pieces = []
    cursor = 0
    for ent in sorted(textile_entities, key=lambda item: item["start"]):
        start, end = ent["start"], ent["end"]
        if start < cursor:
            # Overlaps a span that is already highlighted; it is still
            # reported in the returned list, just not wrapped a second time.
            continue
        pieces.append(html.escape(text[cursor:start], quote=False))
        pieces.append(
            "<span style='background-color: yellow; font-weight: bold;'>"
            + html.escape(text[start:end], quote=False)
            + "</span>"
        )
        cursor = end
    pieces.append(html.escape(text[cursor:], quote=False))

    return "".join(pieces), textile_entities
# Gradio Interface
with gr.Blocks(title="Textile Machinery NER Demo") as demo:
    # Page heading and a one-paragraph description of the demo.
    gr.Markdown(
        "\n"
        "# Textile Machinery Entity Recognition Demo\n"
        "This demo selects a random text snippet from a CSV file and identifies "
        "\"Textile Machinery\" entities using a fine-tuned GLiNER model.\n"
    )

    # Editable text area; "Get New Snippet" overwrites its content.
    input_text = gr.Textbox(
        label="Text input",
        placeholder="Enter your text here",
        lines=5,
        value="Enter or refresh to get text from CSV",
    )

    # Result widgets: highlighted markup plus the entity list as JSON.
    output_highlighted = gr.HTML(label="Predicted Entities")
    output_entities = gr.JSON(label="Entities")

    # Action buttons.
    submit_btn = gr.Button("Find Textile Machinery!")
    refresh_btn = gr.Button("Get New Snippet")

    # Event wiring: refresh fills the text box, submit runs the recogniser.
    refresh_btn.click(fn=get_new_snippet, outputs=input_text)
    submit_btn.click(
        fn=ner,
        inputs=[input_text],
        outputs=[output_highlighted, output_entities],
    )

demo.queue()
demo.launch(debug=True)