SebasJanampa committed on
Commit
b74625d
·
1 Parent(s): 73b6c01

first commit

Browse files
LINEA ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 12ac0a326ddb7ec9809bf7080a19c8509dcffe45
app.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torch import nn
4
+ import torchvision.transforms as T
5
+
6
+ from LINEA.models import build_linea
7
+ from LINEA.util.slconfig import DictAction, SLConfig
8
+
9
+ from PIL import Image, ImageDraw
10
+
11
# Maps each model name offered in the UI dropdown to the LINEA config file
# used to build that model.
LINEA_MODELS = {
    f"LINEA-{size.upper()}": f"./LINEA/configs/linea/linea_hgnetv2_{size}.py"
    for size in ("n", "s", "m", "l")
}

# Preprocessing applied to an input image before it is passed to the model:
# resize to 640x640, convert to a tensor, then normalize per channel.
transforms = T.Compose([
    T.Resize((640, 640)),
    T.ToTensor(),
    T.Normalize(mean=[0.538, 0.494, 0.453], std=[0.257, 0.263, 0.273]),
])

# One row per example shown in the UI; each row holds a single image path.
example_images = [[f"assets/example{index}.jpg"] for index in range(1, 5)]
32
+
33
# Markdown/HTML header rendered at the top of the demo page.
description = """
<h1 align="center">
<ins>LINEA</ins>
<br>
Fast and accurate line detection using scalable transformers
</h1>

<h2 align="center">
<a href="https://www.linkedin.com/in/sebastianjr/">Sebastian Janampa</a>
and
<a href="https://www.linkedin.com/in/marios-pattichis-207b0119/">Marios Pattichis</a>
</h2>

<h2 align="center">
<a href="https://github.com/SebastianJanampa/LINEA.git">GitHub</a> |
<a href="https://colab.research.google.com/github/SebastianJanampa/LINEA/blob/master/LINEA_tutorial.ipynb">Colab</a>
</h2>


## Getting Started

LINEA is a family of transformer models that detect line segments in an image.
Its key component is its new attention mechanism called **line attention**.

To get started, upload an image or select one of the examples below.
You can choose between different model sizes, change the confidence threshold, and visualize the results.
"""
60
+
61
# Deploy-ready models keyed by model name, so repeated requests reuse the same
# instance instead of rebuilding the network and reloading the checkpoint.
_MODEL_CACHE = {}


def create_model(model_size):
    """Return the deploy-ready LINEA model for ``model_size``.

    The model is built from the config registered in ``LINEA_MODELS``, its
    weights are loaded from the matching release checkpoint (on CPU), and the
    network plus its postprocessor are wrapped in a single ``nn.Module`` set
    to eval mode. The result is cached, so later calls with the same name
    return the same object.

    Args:
        model_size: A key of ``LINEA_MODELS`` (e.g. ``"LINEA-M"``). Its last
            character selects the checkpoint file.

    Returns:
        An ``nn.Module`` whose ``forward(images, orig_target_sizes)`` runs the
        network and then the postprocessor.

    Raises:
        KeyError: If ``model_size`` is not a key of ``LINEA_MODELS``.
    """
    cached = _MODEL_CACHE.get(model_size)
    if cached is not None:
        return cached

    cfg = SLConfig.fromfile(LINEA_MODELS[model_size])
    # NOTE(review): presumably disables the config's own pretrained-weight
    # loading, since the checkpoint is loaded explicitly below -- confirm.
    cfg.pretrained = False

    model, postprocessor = build_linea(cfg)

    letter = model_size[-1].lower()
    checkpoint_name = f"linea_hgnetv2_{letter}.pth"
    url = f"https://github.com/SebastianJanampa/storage/releases/download/LINEA/{checkpoint_name}"
    state_dict = torch.hub.load_state_dict_from_url(
        url, map_location="cpu", file_name=checkpoint_name
    )

    model.load_state_dict(state_dict['model'], strict=True)

    class Model(nn.Module):
        """Runs the deployed network followed by its postprocessor."""

        def __init__(self):
            super().__init__()
            self.model = model.deploy()
            self.postprocessor = postprocessor.deploy()

        def forward(self, images, orig_target_sizes):
            outputs = self.model(images)
            outputs = self.postprocessor(outputs, orig_target_sizes)
            return outputs

    deployed = Model()
    deployed.eval()

    _MODEL_CACHE[model_size] = deployed
    return deployed
90
+
91
def draw(images, lines, scores, thrh):
    """Draw the detections scoring above ``thrh`` onto each image, in place.

    For image ``i`` the entries of ``lines[i]`` whose score in ``scores[i]``
    is greater than ``thrh`` are drawn as red segments, and each one is
    labelled in blue with its score rounded to two decimals at its first
    two coordinates.

    Args:
        images: Sequence of PIL images; they are modified in place.
        lines: Per-image line collections, indexable by a boolean mask
            (assumes tensors of ``x1, y1, x2, y2`` rows -- TODO confirm).
        scores: Per-image score collections aligned with ``lines``.
        thrh: Score threshold; only strictly greater scores are drawn.

    Returns:
        The same ``images`` object that was passed in.
    """
    for index, image in enumerate(images):
        canvas = ImageDraw.Draw(image)

        image_scores = scores[index]
        keep = image_scores > thrh
        kept_lines = lines[index][keep]
        kept_scores = image_scores[keep]

        for segment, score in zip(kept_lines, kept_scores):
            canvas.line(list(segment), fill="red", width=5)
            label = str(round(score.item(), 2))
            canvas.text((segment[0], segment[1]), text=label, fill="blue")

    return images
108
+
109
def filter(lines, scores, threshold):
    """Keep, per image, only the lines whose score exceeds ``threshold``.

    Args:
        lines: Per-image line collections, indexable by a boolean mask.
        scores: Per-image score collections aligned with ``lines``.
        threshold: Scores must be strictly greater than this to be kept.

    Returns:
        A ``(filtered_lines, filtered_scores)`` tuple of two lists, each with
        one entry per ``(lines, scores)`` pair.
    """
    # Build the masks over the zipped pairs so a length mismatch truncates
    # to the shorter input.
    masks = [image_scores > threshold for _, image_scores in zip(lines, scores)]
    kept_lines = [image_lines[mask] for image_lines, mask in zip(lines, masks)]
    kept_scores = [image_scores[mask] for image_scores, mask in zip(scores, masks)]
    return kept_lines, kept_scores
116
+
117
def format_output(lines, scores):
    """Render the detections of the first image as a text report.

    Only ``lines[0]`` and ``scores[0]`` are read. The report starts with the
    number of lines, followed by one tab-separated row per line giving its
    ``x1``, ``y1``, ``x2``, ``y2`` and score, each with two decimals.

    Args:
        lines: Per-image line collections; each line exposes four values
            through integer indexing, each with an ``.item()`` method.
        scores: Per-image score collections aligned with ``lines``.

    Returns:
        The formatted multi-line string.
    """
    first_lines, first_scores = lines[0], scores[0]

    parts = [f"{len(first_lines)} lines were detected\n", "Detected lines:\n"]
    for line, scr in zip(first_lines, first_scores):
        # Each coordinate reads its own index; reading index 0 for all four
        # would repeat x1 in every column.
        parts.append(
            f"\tx1: {line[0].item():.2f}"
            f"\ty1: {line[1].item():.2f}"
            f"\tx2: {line[2].item():.2f}"
            f"\ty2: {line[3].item():.2f}"
            f"\tscore: {scr.item():.2f}\n"
        )
    return "".join(parts)
129
+
130
def process_results(
    image_path,
    model_size,
    threshold
):
    """Run line detection on an image and build the UI outputs.

    Args:
        image_path: Path of the image to process, or ``None`` if nothing has
            been uploaded.
        model_size: Model name passed to ``create_model``.
        threshold: Confidence threshold used for drawing and for the report.

    Returns:
        A tuple ``(report, image, raw_results)``: the text produced by
        ``format_output`` for the detections above ``threshold``, the RGB
        image with those detections drawn on it, and the unfiltered
        ``(lines, scores)`` pair returned by the model.

    Raises:
        gr.Error: If ``image_path`` is ``None``.
    """
    if image_path is None:
        raise gr.Error("Please upload an image first.")

    model = create_model(model_size)

    # ``convert`` loads the pixel data and returns a separate image, so the
    # underlying file can be closed as soon as the conversion is done.
    with Image.open(image_path) as source:
        im_pil = source.convert("RGB")
    w, h = im_pil.size
    orig_size = torch.tensor([[w, h]])

    im_data = transforms(im_pil).unsqueeze(0)

    # Inference only: skip autograd so the outputs kept as raw results do not
    # hold on to a computation graph.
    with torch.no_grad():
        output = model(im_data, orig_size)
    lines, scores = output

    result_images = draw([im_pil], lines, scores, thrh=threshold)
    filtered_lines, filtered_scores = filter(lines, scores, threshold)

    return format_output(filtered_lines, filtered_scores), result_images[0], (lines, scores)
154
+
155
def update_threshold(
    image_path,
    raw_results,
    threshold
):
    """Redraw stored detections with a new confidence threshold.

    Args:
        image_path: Path of the image the detections belong to, or ``None``.
        raw_results: The ``(lines, scores)`` pair stored by a previous
            detection, or ``None`` if no detection has been run yet.
        threshold: New confidence threshold.

    Returns:
        A tuple ``(report, image)``. When there is no image or no stored
        detection, ``("", None)`` is returned instead of failing.
    """
    # The threshold can change before any detection has run, or after the
    # image was cleared; there is nothing to redraw in either case.
    if image_path is None or raw_results is None:
        return "", None

    lines, scores = raw_results
    # ``convert`` returns a separate, fully loaded image, so the file can be
    # closed right away.
    with Image.open(image_path) as source:
        im_pil = source.convert("RGB")

    result_images = draw([im_pil], lines, scores, thrh=threshold)
    filtered_lines, filtered_scores = filter(lines, scores, threshold)
    return format_output(filtered_lines, filtered_scores), result_images[0]
166
+
167
def update_model(
    image_path,
    model_size,
    threshold
):
    """Switch to another model and re-run detection on the current image.

    Args:
        image_path: Path of the image to process, or ``None``.
        model_size: Model name passed to ``create_model``.
        threshold: Confidence threshold forwarded to ``process_results``.

    Returns:
        Whatever ``process_results`` returns for the given arguments.

    Raises:
        gr.Error: If ``image_path`` is ``None``.
    """
    if image_path is None:
        # Still build the selected model so its checkpoint is fetched even
        # though there is nothing to run it on yet.
        create_model(model_size)
        raise gr.Error("Please upload an image first.")

    # process_results builds the model itself; no separate build is needed.
    return process_results(image_path, model_size, threshold)
179
+
180
+
181
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            gr.Markdown("""## Input Image""")
            image_path = gr.Image(label="Upload image", type="filepath")
            model_size = gr.Dropdown(
                choices=list(LINEA_MODELS.keys()), label="Choose a LINEA model.", value="LINEA-M"
            )
            threshold = gr.Slider(
                label="Confidence Threshold",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                interactive=True,
                value=0.30,
            )

            submit_btn = gr.Button("Detect Lines")
            # Each example row holds only an image path, so the image is the
            # only component the examples fill in.
            gr.Examples(examples=example_images, inputs=[image_path])

        with gr.Column():
            gr.Markdown("""## Results""")
            image_output = gr.Image(label="Detected Lines")

            text_output = gr.Textbox(label="Predicted lines", type="text", lines=5)

    # Unfiltered model output of the last detection, kept so the threshold
    # can be changed without running the model again.
    raw_results = gr.State()

    # Run detection when the button is clicked
    submit_btn.click(
        fn=process_results,
        inputs=[image_path, model_size, threshold],
        outputs=[text_output, image_output, raw_results],
    )

    # Re-filter and redraw the stored detections when the threshold slider moves
    threshold.change(
        fn=update_threshold,
        inputs=[image_path, raw_results, threshold],
        outputs=[text_output, image_output],
    )

demo.launch()
assets/example1.jpg ADDED
assets/example2.jpg ADDED
assets/example3.jpg ADDED
assets/example4.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch>=2.0.1
2
+ torchvision>=0.15.2
3
+ transformers
4
+ yapf
5
+ addict
6
+ scipy