import gradio as gr
import torch
from torch import nn
import torchvision.transforms as T
from LINEA.models import build_linea
from LINEA.util.slconfig import SLConfig
from PIL import Image, ImageDraw
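
# Map each model name shown in the UI to its configuration file in the LINEA repository.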
LINEA_MODELS = {
    "LINEA-N": './LINEA/configs/linea/linea_hgnetv2_n.py',
    "LINEA-S": './LINEA/configs/linea/linea_hgnetv2_s.py',
    "LINEA-M": './LINEA/configs/linea/linea_hgnetv2_m.py',
    "LINEA-L": './LINEA/configs/linea/linea_hgnetv2_l.py',
}
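
# Preprocessing applied to every uploaded image before inference:
# resize to the fixed 640x640 input resolution and normalize.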
transforms = T.Compose(
    [
        T.Resize((640, 640)),
        T.ToTensor(),
        T.Normalize(mean=[0.538, 0.494, 0.453], std=[0.257, 0.263, 0.273]),
    ]
)
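
# Example images offered in the Gradio UI.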
example_images = [
    ["assets/example1.jpg"],
    ["assets/example2.jpg"],
    ["assets/example3.jpg"],
    ["assets/example4.jpg"],
]
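
# Markdown shown at the top of the demo.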
description = """
# LINEA
Fast and accurate line detection using scalable transformers.

## Getting Started
LINEA is a family of transformer models that detect line segments in an image.
Its key component is a new attention mechanism called **line attention**.

To get started, upload an image or select one of the examples below.
You can choose between different model sizes, adjust the confidence threshold, and visualize the results.
"""

# Build a LINEA model from its config file, load the released checkpoint,
# and wrap the model and postprocessor into a single inference module.
def create_model(model_size):
    cfg = SLConfig.fromfile(LINEA_MODELS[model_size])
    cfg.pretrained = False
    model, postprocessor = build_linea(cfg)

    # Download the pretrained weights that match the selected model size.
    letter = model_size[-1].lower()
    url = f"https://github.com/SebastianJanampa/storage/releases/download/LINEA/linea_hgnetv2_{letter}.pth"
    state_dict = torch.hub.load_state_dict_from_url(
        url, map_location="cpu", file_name=f"linea_hgnetv2_{letter}.pth"
    )
    model.load_state_dict(state_dict['model'], strict=True)

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.model = model.deploy()
            self.postprocessor = postprocessor.deploy()

        def forward(self, images, orig_target_sizes):
            outputs = self.model(images)
            outputs = self.postprocessor(outputs, orig_target_sizes)
            return outputs

    model = Model()
    model.eval()
    return model
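
# Draw the predicted line segments (and their scores) that pass the confidence threshold.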
def draw(images, lines, scores, thrh):
    for i, im in enumerate(images):
        drawer = ImageDraw.Draw(im)
        scr = scores[i]
        line = lines[i][scr > thrh]
        scrs = scr[scr > thrh]
        for j, l in enumerate(line):
            l = l.tolist()  # (x1, y1, x2, y2) as plain floats for PIL
            drawer.line(l, fill="red", width=5)
            drawer.text(
                (l[0], l[1]),
                text=f"{round(scrs[j].item(), 2)}",
                fill="blue",
            )
    return images
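
# Keep only the lines whose score exceeds the confidence threshold.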
def filter(lines, scores, threshold):
    filtered_lines, filtered_scores = [], []
    for line, scr in zip(lines, scores):
        idx = scr > threshold
        filtered_lines.append(line[idx])
        filtered_scores.append(scr[idx])
    return filtered_lines, filtered_scores
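
# Build a human-readable summary of the filtered detections for the textbox output.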
def format_output(lines, scores):
    n = len(lines[0])
    txt = f"{n} lines were detected\n"
    txt += "Detected lines:\n"
    for line, scr in zip(lines[0], scores[0]):
        txt += f"\tx1: {line[0].item():.2f}"
        txt += f"\ty1: {line[1].item():.2f}"
        txt += f"\tx2: {line[2].item():.2f}"
        txt += f"\ty2: {line[3].item():.2f}"
        txt += f"\tscore: {scr.item():.2f}\n"
    return txt
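
# Main Gradio callback: build the model, run inference, and prepare all three outputs.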
def process_results(image_path, model_size, threshold):
    """Process the image and return the detected lines."""
    if image_path is None:
        raise gr.Error("Please upload an image first.")

    model = create_model(model_size)

    im_pil = Image.open(image_path).convert("RGB")
    w, h = im_pil.size
    orig_size = torch.tensor([[w, h]])
    im_data = transforms(im_pil).unsqueeze(0)

    with torch.no_grad():
        output = model(im_data, orig_size)
    lines, scores = output

    result_images = draw([im_pil], lines, scores, thrh=threshold)
    filtered_lines, filtered_scores = filter(lines, scores, threshold)
    return format_output(filtered_lines, filtered_scores), result_images[0], (lines, scores)
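
# Re-filter and re-draw the cached predictions when the confidence slider moves,
# without running the model again.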
def update_threshold(image_path, raw_results, threshold):
    if image_path is None or raw_results is None:
        raise gr.Error("Please upload an image and run the detection first.")
    lines, scores = raw_results
    im_pil = Image.open(image_path).convert("RGB")
    result_images = draw([im_pil], lines, scores, thrh=threshold)
    filtered_lines, filtered_scores = filter(lines, scores, threshold)
    return format_output(filtered_lines, filtered_scores), result_images[0]
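
# Run a full detection pass with the newly selected model size.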
def update_model(image_path, model_size, threshold):
    if image_path is None:
        return None, None, None
    return process_results(image_path, model_size, threshold)

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            gr.Markdown("""## Input Image""")
            image_path = gr.Image(label="Upload image", type="filepath")
            model_size = gr.Dropdown(
                choices=list(LINEA_MODELS.keys()), label="Choose a LINEA model.", value="LINEA-M"
            )
            threshold = gr.Slider(
                label="Confidence Threshold",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                interactive=True,
                value=0.30,
            )
            submit_btn = gr.Button("Detect Lines")
            gr.Examples(examples=example_images, inputs=[image_path])
        with gr.Column():
            gr.Markdown("""## Results""")
            image_output = gr.Image(label="Detected Lines")
            text_output = gr.Textbox(label="Predicted lines", type="text", lines=5)

    # Cache the raw predictions so the threshold slider can reuse them.
    raw_results = gr.State()

    # Define the action when the button is clicked
    submit_btn.click(
        fn=process_results,
        inputs=[image_path, model_size, threshold],
        outputs=[text_output, image_output, raw_results],
    )

    # Define the action when the confidence threshold slider changes
    threshold.change(
        fn=update_threshold,
        inputs=[image_path, raw_results, threshold],
        outputs=[text_output, image_output],
    )
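
    # Optional wiring (assumed intent of the otherwise-unused update_model helper):
    # re-run detection when a different model size is selected.
    model_size.change(
        fn=update_model,
        inputs=[image_path, model_size, threshold],
        outputs=[text_output, image_output, raw_results],
    )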
demo.launch()