import functools

import gradio as gr
import torch
from torch import nn
import torchvision.transforms as T
from LINEA.models import build_linea
from LINEA.util.slconfig import DictAction, SLConfig
from PIL import Image, ImageDraw

# Display name -> LINEA config file. The last character of the display name
# (N/S/M/L) also selects the matching pretrained checkpoint in create_model().
LINEA_MODELS = {
    "LINEA-N": './LINEA/configs/linea/linea_hgnetv2_n.py',
    "LINEA-S": './LINEA/configs/linea/linea_hgnetv2_s.py',
    "LINEA-M": './LINEA/configs/linea/linea_hgnetv2_m.py',
    "LINEA-L": './LINEA/configs/linea/linea_hgnetv2_l.py',
}

# Preprocessing applied to every input image: resize to a fixed 640x640,
# convert to a tensor, then normalize with the per-channel mean/std below.
transforms = T.Compose(
    [
        T.Resize((640, 640)),
        T.ToTensor(),
        T.Normalize(mean=[0.538, 0.494, 0.453], std=[0.257, 0.263, 0.273]),
    ]
)

# One-element rows: each example only supplies a value for the image input.
example_images = [
    ["assets/example1.jpg"],
    ["assets/example2.jpg"],
    ["assets/example3.jpg"],
    ["assets/example4.jpg"],
]

description = """

LINEA
Fast and accurate line detection using scalable transformers

Sebastian Janampa and Marios Pattichis

GitHub | Colab

## Getting Started

LINEA is a family of transformer models that detect line segments in an image.
Its key component is a new attention mechanism called **line attention**.

To get started, upload an image or select one of the examples below.
You can choose between different model sizes, change the confidence threshold and visualize the results.
"""


@functools.lru_cache(maxsize=len(LINEA_MODELS))
def create_model(model_size):
    """Build a LINEA model of the given size and load its pretrained weights.

    The result is cached per ``model_size``. Each checkpoint is therefore
    built and loaded at most once per process rather than on every request.
    Every caller shares the same eval-mode instance.

    Args:
        model_size: A key of ``LINEA_MODELS`` (e.g. ``"LINEA-M"``).

    Returns:
        An ``nn.Module`` in eval mode. Its ``forward(images, orig_target_sizes)``
        runs the deployed network and then the deployed postprocessor.

    Raises:
        KeyError: If ``model_size`` is not a key of ``LINEA_MODELS``.
    """
    cfg = SLConfig.fromfile(LINEA_MODELS[model_size])
    # Weights come from the release checkpoint below, not the config's
    # backbone pretraining.
    cfg.pretrained = False
    model, postprocessor = build_linea(cfg)

    letter = model_size[-1].lower()
    url = f"https://github.com/SebastianJanampa/storage/releases/download/LINEA/linea_hgnetv2_{letter}.pth"
    state_dict = torch.hub.load_state_dict_from_url(
        url, map_location="cpu", file_name=f"linea_hgnetv2_{letter}.pth"
    )
    model.load_state_dict(state_dict['model'], strict=True)

    class Model(nn.Module):
        """Bundle the deployed network and postprocessor into a single module."""

        def __init__(self):
            super().__init__()
            self.model = model.deploy()
            self.postprocessor = postprocessor.deploy()

        def forward(self, images, orig_target_sizes):
            outputs = self.model(images)
            outputs = self.postprocessor(outputs, orig_target_sizes)
            return outputs

    wrapped = Model()
    wrapped.eval()
    return wrapped


def draw(images, lines, scores, thrh):
    """Draw, in place, every line segment scoring strictly above ``thrh``.

    Each kept segment is drawn in red. Its score, rounded to 2 decimals, is
    written in blue at the segment's first endpoint.

    Args:
        images: List of PIL images. Each one is modified in place.
        lines: Per-image tensors of segments, aligned with ``images``. Each
            segment is read as ``(x1, y1, x2, y2)``; this is assumed from
            ``format_output``'s labels, so confirm against the postprocessor.
        scores: Per-image score tensors aligned with ``lines``.
        thrh: Score threshold (exclusive).

    Returns:
        The same ``images`` list, for convenience.
    """
    for im, img_lines, img_scores in zip(images, lines, scores):
        canvas = ImageDraw.Draw(im)
        keep = img_scores > thrh
        for segment, score in zip(img_lines[keep], img_scores[keep]):
            coords = segment.tolist()
            canvas.line(coords, fill="red", width=5)
            canvas.text(
                (coords[0], coords[1]),
                text=f"{round(score.item(), 2)}",
                fill="blue",
            )
    return images


def filter(lines, scores, threshold):
    """Keep, for each image, only the lines whose score is strictly above ``threshold``.

    NOTE: this shadows the builtin ``filter``. The name is kept so existing
    callers keep working.

    Args:
        lines: Per-image tensors of line segments.
        scores: Per-image score tensors aligned with ``lines``.
        threshold: Score threshold (exclusive).

    Returns:
        A ``(filtered_lines, filtered_scores)`` pair of lists, one entry per image.
    """
    filtered_lines, filtered_scores = [], []
    for img_lines, img_scores in zip(lines, scores):
        keep = img_scores > threshold
        filtered_lines.append(img_lines[keep])
        filtered_scores.append(img_scores[keep])
    return filtered_lines, filtered_scores


def format_output(lines, scores):
    """Render the detections of the first image as human-readable text.

    Args:
        lines: Per-image line tensors. Only ``lines[0]`` is reported, and
            elements 0..3 of each row are printed as x1, y1, x2, y2.
        scores: Per-image score tensors aligned with ``lines``.

    Returns:
        A multi-line string: a count header followed by one tab-separated row
        per detected line.
    """
    rows = [f"{len(lines[0])} lines were detected\n", "Detected lines:\n"]
    for line, scr in zip(lines[0], scores[0]):
        # Each coordinate comes from its own element. The previous version
        # printed line[0] four times.
        x1, y1, x2, y2 = line.tolist()[:4]
        rows.append(
            f"\tx1: {x1:.2f}"
            f"\ty1: {y1:.2f}"
            f"\tx2: {x2:.2f}"
            f"\ty2: {y2:.2f}"
            f"\tscore: {scr.item():.2f}\n"
        )
    return "".join(rows)


def process_results(image_path, model_size, threshold):
    """Process the image and return the detected lines.

    Args:
        image_path: Path of the uploaded image, or ``None`` if there is none.
        model_size: A key of ``LINEA_MODELS``.
        threshold: Confidence threshold used for drawing and for the text report.

    Returns:
        A ``(text_report, annotated_image, (lines, scores))`` tuple. The last
        element holds the unfiltered detections so the threshold can be
        changed later without re-running the model.

    Raises:
        gr.Error: If no image has been uploaded.
    """
    if image_path is None:
        raise gr.Error("Please upload an image first.")

    model = create_model(model_size)

    # convert() returns a new image, so the file handle can be closed right away.
    with Image.open(image_path) as im:
        im_pil = im.convert("RGB")
    w, h = im_pil.size
    orig_size = torch.tensor([[w, h]])
    im_data = transforms(im_pil).unsqueeze(0)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        lines, scores = model(im_data, orig_size)

    result_images = draw([im_pil], lines, scores, thrh=threshold)
    filtered_lines, filtered_scores = filter(lines, scores, threshold)
    return format_output(filtered_lines, filtered_scores), result_images[0], (lines, scores)


def update_threshold(image_path, raw_results, threshold):
    """Redraw cached detections with a new threshold, without re-running the model.

    Returns:
        A ``(text_report, annotated_image)`` pair. Returns ``(None, None)`` when
        there is no image or no detection has been run for the current image yet.
    """
    if image_path is None or raw_results is None:
        return None, None

    lines, scores = raw_results
    with Image.open(image_path) as im:
        im_pil = im.convert("RGB")
    result_images = draw([im_pil], lines, scores, thrh=threshold)
    filtered_lines, filtered_scores = filter(lines, scores, threshold)
    return format_output(filtered_lines, filtered_scores), result_images[0]


def update_model(image_path, model_size, threshold):
    """Load the selected model and re-run detection on the current image.

    Raises:
        gr.Error: If no image has been uploaded. The model is still loaded
            into the cache first, so a later detection starts faster.
    """
    create_model(model_size)
    if image_path is None:
        raise gr.Error("Please upload an image first.")
    return process_results(image_path, model_size, threshold)


# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            gr.Markdown("""## Input Image""")
            image_path = gr.Image(label="Upload image", type="filepath")
            model_size = gr.Dropdown(
                choices=list(LINEA_MODELS.keys()),
                label="Choose a LINEA model.",
                value="LINEA-M",
            )
            threshold = gr.Slider(
                label="Confidence Threshold",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                interactive=True,
                value=0.30,
            )
            submit_btn = gr.Button("Detect Lines")
            # The examples only provide images, so only the image input is listed.
            gr.Examples(examples=example_images, inputs=[image_path])
        with gr.Column():
            gr.Markdown("""## Results""")
            image_output = gr.Image(label="Detected Lines")
            text_output = gr.Textbox(label="Predicted lines", type="text", lines=5)

    # Unfiltered (lines, scores) from the last detection, reused when the
    # threshold changes.
    raw_results = gr.State()

    # Define the action when the button is clicked
    submit_btn.click(
        fn=process_results,
        inputs=[image_path, model_size, threshold],
        outputs=[text_output, image_output, raw_results],
    )

    # A new or cleared image invalidates the cached detections. Otherwise,
    # moving the threshold slider would redraw stale lines onto a different image.
    image_path.change(fn=lambda: None, inputs=None, outputs=raw_results)

    # Redraw with the new threshold whenever the slider moves.
    threshold.change(
        fn=update_threshold,
        inputs=[image_path, raw_results, threshold],
        outputs=[text_output, image_output],
    )

demo.launch()