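"""
Gradio demo for LINEA, a family of transformer models for fast and accurate
line-segment detection. Users upload an image, pick a model size, and adjust
the confidence threshold to visualize the detected lines.
"""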
import gradio as gr
import torch
from torch import nn
import torchvision.transforms as T
from PIL import Image, ImageDraw

from LINEA.models import build_linea
from LINEA.util.slconfig import SLConfig
# Config file for each available model size.
LINEA_MODELS = {
    "LINEA-N": './LINEA/configs/linea/linea_hgnetv2_n.py',
    "LINEA-S": './LINEA/configs/linea/linea_hgnetv2_s.py',
    "LINEA-M": './LINEA/configs/linea/linea_hgnetv2_m.py',
    "LINEA-L": './LINEA/configs/linea/linea_hgnetv2_l.py',
}
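# Preprocessing: resize to the 640x640 model input and normalize with
# dataset-specific statistics (note these differ from the common ImageNet values).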
transforms = T.Compose(
    [
        T.Resize((640, 640)),
        T.ToTensor(),
        T.Normalize(mean=[0.538, 0.494, 0.453], std=[0.257, 0.263, 0.273]),
    ]
)
example_images = [
    ["assets/example1.jpg"],
    ["assets/example2.jpg"],
    ["assets/example3.jpg"],
    ["assets/example4.jpg"],
]
description = """
<h1 align="center">
    <ins>LINEA</ins>
    <br>
    Fast and accurate line detection using scalable transformers
</h1>
<h2 align="center">
    <a href="https://www.linkedin.com/in/sebastianjr/">Sebastian Janampa</a>
    and
    <a href="https://www.linkedin.com/in/marios-pattichis-207b0119/">Marios Pattichis</a>
</h2>
<h2 align="center">
    <a href="https://github.com/SebastianJanampa/LINEA.git">GitHub</a> |
    <a href="https://colab.research.google.com/github/SebastianJanampa/LINEA/blob/master/LINEA_tutorial.ipynb">Colab</a>
</h2>

## Getting Started
LINEA is a family of transformer models that detects line segments in an image.
Its key component is a new attention mechanism called **line attention**.
To get started, upload an image or select one of the examples below.
You can choose between different model sizes, adjust the confidence threshold, and visualize the results.
"""
def create_model(model_size):
    """Build a LINEA model and load its pretrained weights."""
    cfg = SLConfig.fromfile(LINEA_MODELS[model_size])
    cfg.pretrained = False
    model, postprocessor = build_linea(cfg)

    # Download the checkpoint matching the chosen model size.
    letter = model_size[-1].lower()
    url = f"https://github.com/SebastianJanampa/storage/releases/download/LINEA/linea_hgnetv2_{letter}.pth"
    state_dict = torch.hub.load_state_dict_from_url(
        url, map_location="cpu", file_name=f"linea_hgnetv2_{letter}.pth"
    )
    model.load_state_dict(state_dict['model'], strict=True)

    class Model(nn.Module):
        """Wraps the deployed model and postprocessor into a single module."""

        def __init__(self):
            super().__init__()
            self.model = model.deploy()
            self.postprocessor = postprocessor.deploy()

        def forward(self, images, orig_target_sizes):
            outputs = self.model(images)
            outputs = self.postprocessor(outputs, orig_target_sizes)
            return outputs

    model = Model()
    model.eval()
    return model
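# Note: create_model rebuilds the network on every request (torch.hub caches the
# downloaded weights on disk, so only the first call downloads them). A simple
# memoization layer (a hypothetical optimization, not part of the original app)
# could skip the rebuild, e.g.:
#
#     from functools import lru_cache
#
#     @lru_cache(maxsize=None)
#     def get_cached_model(model_size):
#         return create_model(model_size)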
def draw(images, lines, scores, thrh):
    """Draw the lines scoring above the threshold, labeled with their scores."""
    for i, im in enumerate(images):
        draw = ImageDraw.Draw(im)
        scr = scores[i]
        line = lines[i][scr > thrh]
        scrs = scr[scr > thrh]
        for j, l in enumerate(line):
            draw.line(list(l), fill="red", width=5)
            draw.text(
                (l[0], l[1]),
                text=f"{round(scrs[j].item(), 2)}",
                fill="blue",
            )
    return images
def filter_lines(lines, scores, threshold):
    """Keep only the lines whose score exceeds the threshold.

    Renamed from `filter` to avoid shadowing the Python builtin.
    """
    filtered_lines, filtered_scores = [], []
    for line, scr in zip(lines, scores):
        idx = scr > threshold
        filtered_lines.append(line[idx])
        filtered_scores.append(scr[idx])
    return filtered_lines, filtered_scores
def format_output(lines, scores):
    """Format the detected lines as human-readable text."""
    n = len(lines[0])
    txt = f"{n} lines were detected\n"
    txt += "Detected lines:\n"
    for line, scr in zip(lines[0], scores[0]):
        txt += f"\tx1: {line[0].item():.2f}"
        txt += f"\ty1: {line[1].item():.2f}"
        txt += f"\tx2: {line[2].item():.2f}"
        txt += f"\ty2: {line[3].item():.2f}"
        txt += f"\tscore: {scr.item():.2f}\n"
    return txt
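# Illustrative output for two detections (hypothetical values):
#
#     2 lines were detected
#     Detected lines:
#         x1: 10.00    y1: 20.00    x2: 30.00    y2: 40.00    score: 0.95
#         x1: 50.00    y1: 60.00    x2: 70.00    y2: 80.00    score: 0.42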
def process_results(
    image_path,
    model_size,
    threshold
):
    """Process the image and return the detected lines."""
    if image_path is None:
        raise gr.Error("Please upload an image first.")

    model = create_model(model_size)

    im_pil = Image.open(image_path).convert("RGB")
    w, h = im_pil.size
    orig_size = torch.tensor([[w, h]])

    im_data = transforms(im_pil).unsqueeze(0)
    # Inference only; no gradients needed.
    with torch.no_grad():
        output = model(im_data, orig_size)
    lines, scores = output

    result_images = draw([im_pil], lines, scores, thrh=threshold)
    filtered_lines, filtered_scores = filter_lines(lines, scores, threshold)
    return format_output(filtered_lines, filtered_scores), result_images[0], (lines, scores)
def update_threshold(
    image_path,
    raw_results,
    threshold
):
    """Redraw and refilter the cached detections with a new threshold."""
    if image_path is None or raw_results is None:
        raise gr.Error("Please upload an image and run detection first.")
    lines, scores = raw_results
    im_pil = Image.open(image_path).convert("RGB")
    result_images = draw([im_pil], lines, scores, thrh=threshold)
    filtered_lines, filtered_scores = filter_lines(lines, scores, threshold)
    return format_output(filtered_lines, filtered_scores), result_images[0]
def update_model(
    image_path,
    model_size,
    threshold
):
    """Rerun detection after the model size changes."""
    create_model(model_size)  # Warm up the newly selected model.
    if image_path is None:
        # No image yet: nothing to detect, so clear the outputs.
        return None, None, None
    return process_results(image_path, model_size, threshold)
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown(description)

    with gr.Row():
        with gr.Column():
            gr.Markdown("""## Input Image""")
            image_path = gr.Image(label="Upload image", type="filepath")
            model_size = gr.Dropdown(
                choices=list(LINEA_MODELS.keys()), label="Choose a LINEA model.", value="LINEA-M"
            )
            threshold = gr.Slider(
                label="Confidence Threshold",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                interactive=True,
                value=0.30,
            )
            submit_btn = gr.Button("Detect Lines")
            # Each example provides only an image, so list only that input.
            gr.Examples(examples=example_images, inputs=[image_path])

        with gr.Column():
            gr.Markdown("""## Results""")
            image_output = gr.Image(label="Detected Lines")
            text_output = gr.Textbox(label="Predicted lines", type="text", lines=5)
    # Cache the raw detections so the threshold slider can refilter them
    # without another forward pass.
    raw_results = gr.State()

    # Define the action when the button is clicked
    submit_btn.click(
        fn=process_results,
        inputs=[image_path, model_size, threshold],
        outputs=[text_output, image_output, raw_results],
    )
    # Define the action when the threshold slider changes
    threshold.change(
        fn=update_threshold,
        inputs=[image_path, raw_results, threshold],
        outputs=[text_output, image_output],
    )
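    # Assumed intended wiring for update_model (the function was defined but
    # never connected in the original): rerun detection on a model-size change.
    model_size.change(
        fn=update_model,
        inputs=[image_path, model_size, threshold],
        outputs=[text_output, image_output, raw_results],
    )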
demo.launch()