Add `handler.py` and `requirements.txt` (#23)
Browse files- Add `handler.py` and `requirements.txt` (acd5c25b4e94f3dff29ee1a1e3b3764c75d3e5f7)
- Update handler.py (1ebdf336dc12fef6aba071bb3b5a905272ede53d)
- Update README.md (3655c7c72dc109866aa38022f07735f7cc475b7b)
- Apply suggestions from
@ThomasDh-C
code review (54dd9132c195511be5b6f4dbc46f14b32934ee4f)
Co-authored-by: Alvaro Bartolome <[email protected]>
- README.md +4 -3
- handler.py +645 -0
- requirements.txt +7 -0
README.md
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
---
|
2 |
library_name: transformers
|
3 |
license: mit
|
4 |
-
|
|
|
|
|
5 |
---
|
6 |
📢 [[GitHub Repo](https://github.com/microsoft/OmniParser/tree/master)] [[OmniParser V2 Blog Post](https://www.microsoft.com/en-us/research/articles/omniparser-v2-turning-any-llm-into-a-computer-use-agent/)] [Huggingface demo](https://huggingface.co/spaces/microsoft/OmniParser-v2)
|
7 |
|
@@ -28,5 +30,4 @@ This model hub includes a finetuned version of YOLOv8 and a finetuned Florence-2
|
|
28 |
- While OmniParser only converts screenshot image into texts, it can be used to construct an GUI agent based on LLMs that is actionable. When developing and operating the agent using OmniParser, the developers need to be responsible and follow common safety standard.
|
29 |
|
30 |
# License
|
31 |
-
Please note that icon_detect model is under AGPL license, and icon_caption is under MIT license. Please refer to the LICENSE file in the folder of each model.
|
32 |
-
|
|
|
1 |
---
|
2 |
library_name: transformers
|
3 |
license: mit
|
4 |
+
tags:
|
5 |
+
- endpoint-template
|
6 |
+
- custom_code
|
7 |
---
|
8 |
📢 [[GitHub Repo](https://github.com/microsoft/OmniParser/tree/master)] [[OmniParser V2 Blog Post](https://www.microsoft.com/en-us/research/articles/omniparser-v2-turning-any-llm-into-a-computer-use-agent/)] [Huggingface demo](https://huggingface.co/spaces/microsoft/OmniParser-v2)
|
9 |
|
|
|
30 |
- While OmniParser only converts screenshot image into texts, it can be used to construct an GUI agent based on LLMs that is actionable. When developing and operating the agent using OmniParser, the developers need to be responsible and follow common safety standard.
|
31 |
|
32 |
# License
|
33 |
+
Please note that icon_detect model is under AGPL license, and icon_caption is under MIT license. Please refer to the LICENSE file in the folder of each model.
|
|
handler.py
ADDED
@@ -0,0 +1,645 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import io
|
3 |
+
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import easyocr
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from PIL import Image
|
10 |
+
from PIL.Image import Image as ImageType
|
11 |
+
from supervision.detection.core import Detections
|
12 |
+
from supervision.draw.color import Color, ColorPalette
|
13 |
+
from torchvision.ops import box_convert
|
14 |
+
from torchvision.transforms import ToPILImage
|
15 |
+
from transformers import AutoModelForCausalLM, AutoProcessor
|
16 |
+
from transformers.image_utils import load_image
|
17 |
+
from ultralytics import YOLO
|
18 |
+
|
19 |
+
# NOTE: here so that it's downloaded before hand so that the endpoint it not stuck listening, whilst the required
|
20 |
+
# files are still being downloaded
|
21 |
+
easyocr.Reader(["en"])
|
22 |
+
|
23 |
+
|
24 |
+
class EndpointHandler:
|
25 |
+
def __init__(self, model_dir: str = "/repository") -> None:
|
26 |
+
self.device = (
|
27 |
+
torch.device("cuda") if torch.cuda.is_available()
|
28 |
+
else (torch.device("mps") if torch.backends.mps.is_available()
|
29 |
+
else torch.device("cpu"))
|
30 |
+
)
|
31 |
+
|
32 |
+
# bounding box detection model
|
33 |
+
self.yolo = YOLO(f"{model_dir}/icon_detect/model.pt")
|
34 |
+
|
35 |
+
# captioning model
|
36 |
+
self.processor = AutoProcessor.from_pretrained(
|
37 |
+
"microsoft/Florence-2-base", trust_remote_code=True
|
38 |
+
)
|
39 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
40 |
+
f"{model_dir}/icon_caption",
|
41 |
+
torch_dtype=torch.float16,
|
42 |
+
trust_remote_code=True,
|
43 |
+
).to(self.device)
|
44 |
+
|
45 |
+
# ocr
|
46 |
+
self.ocr = easyocr.Reader(["en"])
|
47 |
+
|
48 |
+
# box annotator
|
49 |
+
self.annotator = BoxAnnotator()
|
50 |
+
|
51 |
+
def __call__(self, data: Dict[str, Any]) -> Any:
|
52 |
+
# data should contain the following:
|
53 |
+
# "inputs": {
|
54 |
+
# "image": url/base64,
|
55 |
+
# (optional) "image_size": {"w": int, "h": int},
|
56 |
+
# (optional) "bbox_threshold": float,
|
57 |
+
# (optional) "iou_threshold": float,
|
58 |
+
# }
|
59 |
+
data = data.pop("inputs")
|
60 |
+
|
61 |
+
# read image from either url or base64 encoding
|
62 |
+
image = load_image(data["image"])
|
63 |
+
|
64 |
+
ocr_texts, ocr_bboxes = self.check_ocr_bboxes(
|
65 |
+
image,
|
66 |
+
out_format="xyxy",
|
67 |
+
ocr_kwargs={"text_threshold": 0.8},
|
68 |
+
)
|
69 |
+
annotated_image, filtered_bboxes_out = self.get_som_labeled_img(
|
70 |
+
image,
|
71 |
+
image_size=data.get("image_size", None),
|
72 |
+
ocr_texts=ocr_texts,
|
73 |
+
ocr_bboxes=ocr_bboxes,
|
74 |
+
bbox_threshold=data.get("bbox_threshold", 0.05),
|
75 |
+
iou_threshold=data.get("iou_threshold", None),
|
76 |
+
)
|
77 |
+
return {
|
78 |
+
"image": annotated_image,
|
79 |
+
"bboxes": filtered_bboxes_out,
|
80 |
+
}
|
81 |
+
|
82 |
+
def check_ocr_bboxes(
|
83 |
+
self,
|
84 |
+
image: ImageType,
|
85 |
+
out_format: Literal["xywh", "xyxy"] = "xywh",
|
86 |
+
ocr_kwargs: Optional[Dict[str, Any]] = {},
|
87 |
+
) -> Tuple[List[str], List[List[int]]]:
|
88 |
+
if image.mode == "RBGA":
|
89 |
+
image = image.convert("RGB")
|
90 |
+
|
91 |
+
result = self.ocr.readtext(np.array(image), **ocr_kwargs) # type: ignore
|
92 |
+
texts = [str(item[1]) for item in result]
|
93 |
+
bboxes = [
|
94 |
+
self.coordinates_to_bbox(item[0], format=out_format) for item in result
|
95 |
+
]
|
96 |
+
return (texts, bboxes)
|
97 |
+
|
98 |
+
@staticmethod
|
99 |
+
def coordinates_to_bbox(
|
100 |
+
coordinates: np.ndarray, format: Literal["xywh", "xyxy"] = "xywh"
|
101 |
+
) -> List[int]:
|
102 |
+
match format:
|
103 |
+
case "xywh":
|
104 |
+
return [
|
105 |
+
int(coordinates[0][0]),
|
106 |
+
int(coordinates[0][1]),
|
107 |
+
int(coordinates[2][0] - coordinates[0][0]),
|
108 |
+
int(coordinates[2][1] - coordinates[0][1]),
|
109 |
+
]
|
110 |
+
case "xyxy":
|
111 |
+
return [
|
112 |
+
int(coordinates[0][0]),
|
113 |
+
int(coordinates[0][1]),
|
114 |
+
int(coordinates[2][0]),
|
115 |
+
int(coordinates[2][1]),
|
116 |
+
]
|
117 |
+
|
118 |
+
@staticmethod
|
119 |
+
def bbox_area(bbox: List[int], w: int, h: int) -> int:
|
120 |
+
bbox = [bbox[0] * w, bbox[1] * h, bbox[2] * w, bbox[3] * h]
|
121 |
+
return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
|
122 |
+
|
123 |
+
@staticmethod
|
124 |
+
def remove_bbox_overlap(
|
125 |
+
xyxy_bboxes: List[Dict[str, Any]],
|
126 |
+
ocr_bboxes: Optional[List[Dict[str, Any]]] = None,
|
127 |
+
iou_threshold: Optional[float] = 0.7,
|
128 |
+
) -> List[Dict[str, Any]]:
|
129 |
+
filtered_bboxes = []
|
130 |
+
if ocr_bboxes is not None:
|
131 |
+
filtered_bboxes.extend(ocr_bboxes)
|
132 |
+
|
133 |
+
for i, bbox_outter in enumerate(xyxy_bboxes):
|
134 |
+
bbox_left = bbox_outter["bbox"]
|
135 |
+
valid_bbox = True
|
136 |
+
|
137 |
+
for j, bbox_inner in enumerate(xyxy_bboxes):
|
138 |
+
if i == j:
|
139 |
+
continue
|
140 |
+
|
141 |
+
bbox_right = bbox_inner["bbox"]
|
142 |
+
if (
|
143 |
+
intersection_over_union(
|
144 |
+
bbox_left,
|
145 |
+
bbox_right,
|
146 |
+
)
|
147 |
+
> iou_threshold # type: ignore
|
148 |
+
) and (area(bbox_left) > area(bbox_right)):
|
149 |
+
valid_bbox = False
|
150 |
+
break
|
151 |
+
|
152 |
+
if valid_bbox is False:
|
153 |
+
continue
|
154 |
+
|
155 |
+
if ocr_bboxes is None:
|
156 |
+
filtered_bboxes.append(bbox_outter)
|
157 |
+
continue
|
158 |
+
|
159 |
+
box_added = False
|
160 |
+
ocr_labels = []
|
161 |
+
for ocr_bbox in ocr_bboxes:
|
162 |
+
if not box_added:
|
163 |
+
bbox_right = ocr_bbox["bbox"]
|
164 |
+
if overlap(bbox_right, bbox_left):
|
165 |
+
try:
|
166 |
+
ocr_labels.append(ocr_bbox["content"])
|
167 |
+
filtered_bboxes.remove(ocr_bbox)
|
168 |
+
except Exception:
|
169 |
+
continue
|
170 |
+
elif overlap(bbox_left, bbox_right):
|
171 |
+
box_added = True
|
172 |
+
break
|
173 |
+
|
174 |
+
if not box_added:
|
175 |
+
filtered_bboxes.append(
|
176 |
+
{
|
177 |
+
"type": "icon",
|
178 |
+
"bbox": bbox_outter["bbox"],
|
179 |
+
"interactivity": True,
|
180 |
+
"content": " ".join(ocr_labels) if ocr_labels else None,
|
181 |
+
}
|
182 |
+
)
|
183 |
+
|
184 |
+
return filtered_bboxes
|
185 |
+
|
186 |
+
def get_som_labeled_img(
|
187 |
+
self,
|
188 |
+
image: ImageType,
|
189 |
+
image_size: Optional[Dict[Literal["w", "h"], int]] = None,
|
190 |
+
ocr_texts: Optional[List[str]] = None,
|
191 |
+
ocr_bboxes: Optional[List[List[int]]] = None,
|
192 |
+
bbox_threshold: float = 0.01,
|
193 |
+
iou_threshold: Optional[float] = None,
|
194 |
+
caption_prompt: Optional[str] = None,
|
195 |
+
caption_batch_size: int = 64, # ~2GiB of GPU VRAM (can be increased to 128 which is ~4GiB of GPU VRAM)
|
196 |
+
) -> Tuple[str, List[Dict[str, Any]]]:
|
197 |
+
if image.mode == "RBGA":
|
198 |
+
image = image.convert("RGB")
|
199 |
+
|
200 |
+
w, h = image.size
|
201 |
+
if image_size is None:
|
202 |
+
imgsz = {"h": h, "w": w}
|
203 |
+
else:
|
204 |
+
imgsz = [image_size.get("h", h), image_size.get("w", w)]
|
205 |
+
|
206 |
+
out = self.yolo.predict(
|
207 |
+
image,
|
208 |
+
imgsz=imgsz,
|
209 |
+
conf=bbox_threshold,
|
210 |
+
iou=iou_threshold or 0.7,
|
211 |
+
verbose=False,
|
212 |
+
)[0]
|
213 |
+
if out.boxes is None:
|
214 |
+
raise RuntimeError(
|
215 |
+
"YOLO prediction failed to produce the bounding boxes..."
|
216 |
+
)
|
217 |
+
|
218 |
+
xyxy_bboxes = out.boxes.xyxy
|
219 |
+
xyxy_bboxes = xyxy_bboxes / torch.Tensor([w, h, w, h]).to(xyxy_bboxes.device)
|
220 |
+
image_np = np.asarray(image) # type: ignore
|
221 |
+
|
222 |
+
if ocr_bboxes:
|
223 |
+
ocr_bboxes = torch.tensor(ocr_bboxes) / torch.Tensor([w, h, w, h]) # type: ignore
|
224 |
+
ocr_bboxes = ocr_bboxes.tolist() # type: ignore
|
225 |
+
|
226 |
+
ocr_bboxes = [
|
227 |
+
{
|
228 |
+
"type": "text",
|
229 |
+
"bbox": bbox,
|
230 |
+
"interactivity": False,
|
231 |
+
"content": text,
|
232 |
+
"source": "box_ocr_content_ocr",
|
233 |
+
}
|
234 |
+
for bbox, text in zip(ocr_bboxes, ocr_texts) # type: ignore
|
235 |
+
if self.bbox_area(bbox, w, h) > 0
|
236 |
+
]
|
237 |
+
xyxy_bboxes = [
|
238 |
+
{
|
239 |
+
"type": "icon",
|
240 |
+
"bbox": bbox,
|
241 |
+
"interactivity": True,
|
242 |
+
"content": None,
|
243 |
+
"source": "box_yolo_content_yolo",
|
244 |
+
}
|
245 |
+
for bbox in xyxy_bboxes.tolist()
|
246 |
+
if self.bbox_area(bbox, w, h) > 0
|
247 |
+
]
|
248 |
+
|
249 |
+
filtered_bboxes = self.remove_bbox_overlap(
|
250 |
+
xyxy_bboxes=xyxy_bboxes,
|
251 |
+
ocr_bboxes=ocr_bboxes, # type: ignore
|
252 |
+
iou_threshold=iou_threshold or 0.7,
|
253 |
+
)
|
254 |
+
|
255 |
+
filtered_bboxes_out = sorted(
|
256 |
+
filtered_bboxes, key=lambda x: x["content"] is None
|
257 |
+
)
|
258 |
+
starting_idx = next(
|
259 |
+
(
|
260 |
+
idx
|
261 |
+
for idx, bbox in enumerate(filtered_bboxes_out)
|
262 |
+
if bbox["content"] is None
|
263 |
+
),
|
264 |
+
-1,
|
265 |
+
)
|
266 |
+
|
267 |
+
filtered_bboxes = torch.tensor([box["bbox"] for box in filtered_bboxes_out])
|
268 |
+
non_ocr_bboxes = filtered_bboxes[starting_idx:]
|
269 |
+
|
270 |
+
bbox_images = []
|
271 |
+
for _, coordinates in enumerate(non_ocr_bboxes):
|
272 |
+
try:
|
273 |
+
xmin, xmax = (
|
274 |
+
int(coordinates[0] * image_np.shape[1]),
|
275 |
+
int(coordinates[2] * image_np.shape[1]),
|
276 |
+
)
|
277 |
+
ymin, ymax = (
|
278 |
+
int(coordinates[1] * image_np.shape[0]),
|
279 |
+
int(coordinates[3] * image_np.shape[0]),
|
280 |
+
)
|
281 |
+
cropped_image = image_np[ymin:ymax, xmin:xmax, :]
|
282 |
+
cropped_image = cv2.resize(cropped_image, (64, 64))
|
283 |
+
bbox_images.append(ToPILImage()(cropped_image))
|
284 |
+
except Exception:
|
285 |
+
continue
|
286 |
+
|
287 |
+
if caption_prompt is None:
|
288 |
+
caption_prompt = "<CAPTION>"
|
289 |
+
|
290 |
+
captions = []
|
291 |
+
for idx in range(0, len(bbox_images), caption_batch_size): # type: ignore
|
292 |
+
batch = bbox_images[idx : idx + caption_batch_size] # type: ignore
|
293 |
+
inputs = self.processor(
|
294 |
+
images=batch,
|
295 |
+
text=[caption_prompt] * len(batch),
|
296 |
+
return_tensors="pt",
|
297 |
+
do_resize=False,
|
298 |
+
)
|
299 |
+
if self.device.type in {"cuda", "mps"}:
|
300 |
+
inputs = inputs.to(device=self.device, dtype=torch.float16)
|
301 |
+
|
302 |
+
with torch.inference_mode():
|
303 |
+
generated_ids = self.model.generate(
|
304 |
+
input_ids=inputs["input_ids"],
|
305 |
+
pixel_values=inputs["pixel_values"],
|
306 |
+
max_new_tokens=20,
|
307 |
+
num_beams=1,
|
308 |
+
do_sample=False,
|
309 |
+
early_stopping=False,
|
310 |
+
)
|
311 |
+
|
312 |
+
generated_texts = self.processor.batch_decode(
|
313 |
+
generated_ids, skip_special_tokens=True
|
314 |
+
)
|
315 |
+
captions.extend([text.strip() for text in generated_texts])
|
316 |
+
|
317 |
+
ocr_texts = [f"Text Box ID {idx}: {text}" for idx, text in enumerate(ocr_texts)] # type: ignore
|
318 |
+
for _, bbox in enumerate(filtered_bboxes_out):
|
319 |
+
if bbox["content"] is None:
|
320 |
+
bbox["content"] = captions.pop(0)
|
321 |
+
|
322 |
+
filtered_bboxes = box_convert(
|
323 |
+
boxes=filtered_bboxes, in_fmt="xyxy", out_fmt="cxcywh"
|
324 |
+
)
|
325 |
+
|
326 |
+
annotated_image = image_np.copy()
|
327 |
+
bboxes_annotate = filtered_bboxes * torch.Tensor([w, h, w, h])
|
328 |
+
xyxy_annotate = box_convert(
|
329 |
+
bboxes_annotate, in_fmt="cxcywh", out_fmt="xyxy"
|
330 |
+
).numpy()
|
331 |
+
detections = Detections(xyxy=xyxy_annotate)
|
332 |
+
labels = [str(idx) for idx in range(bboxes_annotate.shape[0])]
|
333 |
+
|
334 |
+
annotated_image = self.annotator.annotate(
|
335 |
+
scene=annotated_image,
|
336 |
+
detections=detections,
|
337 |
+
labels=labels,
|
338 |
+
image_size=(w, h),
|
339 |
+
)
|
340 |
+
assert w == annotated_image.shape[1] and h == annotated_image.shape[0]
|
341 |
+
|
342 |
+
out_image = Image.fromarray(annotated_image)
|
343 |
+
out_buffer = io.BytesIO()
|
344 |
+
out_image.save(out_buffer, format="PNG")
|
345 |
+
encoded_image = base64.b64encode(out_buffer.getvalue()).decode("ascii")
|
346 |
+
|
347 |
+
return encoded_image, filtered_bboxes_out
|
348 |
+
|
349 |
+
|
350 |
+
def area(bbox: List[int]) -> int:
|
351 |
+
return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
|
352 |
+
|
353 |
+
|
354 |
+
def intersection_area(bbox_left: List[int], bbox_right: List[int]) -> int:
|
355 |
+
return max(
|
356 |
+
0, min(bbox_left[2], bbox_right[2]) - min(bbox_left[0], bbox_right[0])
|
357 |
+
) * max(0, min(bbox_left[3], bbox_right[3]) - min(bbox_left[1], bbox_right[1]))
|
358 |
+
|
359 |
+
|
360 |
+
def intersection_over_union(bbox_left: List[int], bbox_right: List[int]) -> float:
|
361 |
+
intersection = intersection_area(bbox_left, bbox_right)
|
362 |
+
bbox_left_area = area(bbox_left)
|
363 |
+
bbox_right_area = area(bbox_right)
|
364 |
+
union = bbox_left_area + bbox_right_area - intersection + 1e-6
|
365 |
+
|
366 |
+
ratio_left, ratio_right = 0, 0
|
367 |
+
if bbox_left_area > 0 and bbox_right_area > 0:
|
368 |
+
ratio_left = intersection / bbox_left_area
|
369 |
+
ratio_right = intersection / bbox_right_area
|
370 |
+
return max(intersection / union, ratio_left, ratio_right)
|
371 |
+
|
372 |
+
|
373 |
+
def overlap(bbox_left: List[int], bbox_right: List[int]) -> bool:
|
374 |
+
intersection = intersection_area(bbox_left, bbox_right)
|
375 |
+
ratio_left = intersection / area(bbox_left)
|
376 |
+
return ratio_left > 0.80
|
377 |
+
|
378 |
+
|
379 |
+
class BoxAnnotator:
|
380 |
+
def __init__(
|
381 |
+
self,
|
382 |
+
color: Union[Color, ColorPalette] = ColorPalette.DEFAULT, # type: ignore
|
383 |
+
thickness: int = 3,
|
384 |
+
text_color: Color = Color.BLACK, # type: ignore
|
385 |
+
text_scale: float = 0.5,
|
386 |
+
text_thickness: int = 2,
|
387 |
+
text_padding: int = 10,
|
388 |
+
avoid_overlap: bool = True,
|
389 |
+
):
|
390 |
+
self.color: Union[Color, ColorPalette] = color
|
391 |
+
self.thickness: int = thickness
|
392 |
+
self.text_color: Color = text_color
|
393 |
+
self.text_scale: float = text_scale
|
394 |
+
self.text_thickness: int = text_thickness
|
395 |
+
self.text_padding: int = text_padding
|
396 |
+
self.avoid_overlap: bool = avoid_overlap
|
397 |
+
|
398 |
+
def annotate(
|
399 |
+
self,
|
400 |
+
scene: np.ndarray,
|
401 |
+
detections: Detections,
|
402 |
+
labels: Optional[List[str]] = None,
|
403 |
+
skip_label: bool = False,
|
404 |
+
image_size: Optional[Tuple[int, int]] = None,
|
405 |
+
) -> np.ndarray:
|
406 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
407 |
+
for i in range(len(detections)):
|
408 |
+
x1, y1, x2, y2 = detections.xyxy[i].astype(int)
|
409 |
+
class_id = (
|
410 |
+
detections.class_id[i] if detections.class_id is not None else None
|
411 |
+
)
|
412 |
+
idx = class_id if class_id is not None else i
|
413 |
+
color = (
|
414 |
+
self.color.by_idx(idx)
|
415 |
+
if isinstance(self.color, ColorPalette)
|
416 |
+
else self.color
|
417 |
+
)
|
418 |
+
cv2.rectangle(
|
419 |
+
img=scene,
|
420 |
+
pt1=(x1, y1),
|
421 |
+
pt2=(x2, y2),
|
422 |
+
color=color.as_bgr(),
|
423 |
+
thickness=self.thickness,
|
424 |
+
)
|
425 |
+
if skip_label:
|
426 |
+
continue
|
427 |
+
|
428 |
+
text = (
|
429 |
+
f"{class_id}"
|
430 |
+
if (labels is None or len(detections) != len(labels))
|
431 |
+
else labels[i]
|
432 |
+
)
|
433 |
+
|
434 |
+
text_width, text_height = cv2.getTextSize(
|
435 |
+
text=text,
|
436 |
+
fontFace=font,
|
437 |
+
fontScale=self.text_scale,
|
438 |
+
thickness=self.text_thickness,
|
439 |
+
)[0]
|
440 |
+
|
441 |
+
if not self.avoid_overlap:
|
442 |
+
text_x = x1 + self.text_padding
|
443 |
+
text_y = y1 - self.text_padding
|
444 |
+
|
445 |
+
text_background_x1 = x1
|
446 |
+
text_background_y1 = y1 - 2 * self.text_padding - text_height
|
447 |
+
|
448 |
+
text_background_x2 = x1 + 2 * self.text_padding + text_width
|
449 |
+
text_background_y2 = y1
|
450 |
+
else:
|
451 |
+
(
|
452 |
+
text_x,
|
453 |
+
text_y,
|
454 |
+
text_background_x1,
|
455 |
+
text_background_y1,
|
456 |
+
text_background_x2,
|
457 |
+
text_background_y2,
|
458 |
+
) = self.get_optimal_label_pos(
|
459 |
+
self.text_padding,
|
460 |
+
text_width,
|
461 |
+
text_height,
|
462 |
+
x1,
|
463 |
+
y1,
|
464 |
+
x2,
|
465 |
+
y2,
|
466 |
+
detections,
|
467 |
+
image_size,
|
468 |
+
)
|
469 |
+
|
470 |
+
cv2.rectangle(
|
471 |
+
img=scene,
|
472 |
+
pt1=(text_background_x1, text_background_y1),
|
473 |
+
pt2=(text_background_x2, text_background_y2),
|
474 |
+
color=color.as_bgr(),
|
475 |
+
thickness=cv2.FILLED,
|
476 |
+
)
|
477 |
+
box_color = color.as_rgb()
|
478 |
+
luminance = (
|
479 |
+
0.299 * box_color[0] + 0.587 * box_color[1] + 0.114 * box_color[2]
|
480 |
+
)
|
481 |
+
text_color = (0, 0, 0) if luminance > 160 else (255, 255, 255)
|
482 |
+
cv2.putText(
|
483 |
+
img=scene,
|
484 |
+
text=text,
|
485 |
+
org=(text_x, text_y),
|
486 |
+
fontFace=font,
|
487 |
+
fontScale=self.text_scale,
|
488 |
+
color=text_color,
|
489 |
+
thickness=self.text_thickness,
|
490 |
+
lineType=cv2.LINE_AA,
|
491 |
+
)
|
492 |
+
return scene
|
493 |
+
|
494 |
+
@staticmethod
|
495 |
+
def get_optimal_label_pos(
|
496 |
+
text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size
|
497 |
+
):
|
498 |
+
def get_is_overlap(
|
499 |
+
detections,
|
500 |
+
text_background_x1,
|
501 |
+
text_background_y1,
|
502 |
+
text_background_x2,
|
503 |
+
text_background_y2,
|
504 |
+
image_size,
|
505 |
+
):
|
506 |
+
is_overlap = False
|
507 |
+
for i in range(len(detections)):
|
508 |
+
detection = detections.xyxy[i].astype(int)
|
509 |
+
if (
|
510 |
+
intersection_over_union(
|
511 |
+
[
|
512 |
+
text_background_x1,
|
513 |
+
text_background_y1,
|
514 |
+
text_background_x2,
|
515 |
+
text_background_y2,
|
516 |
+
],
|
517 |
+
detection,
|
518 |
+
)
|
519 |
+
> 0.3
|
520 |
+
):
|
521 |
+
is_overlap = True
|
522 |
+
break
|
523 |
+
if (
|
524 |
+
text_background_x1 < 0
|
525 |
+
or text_background_x2 > image_size[0]
|
526 |
+
or text_background_y1 < 0
|
527 |
+
or text_background_y2 > image_size[1]
|
528 |
+
):
|
529 |
+
is_overlap = True
|
530 |
+
return is_overlap
|
531 |
+
|
532 |
+
text_x = x1 + text_padding
|
533 |
+
text_y = y1 - text_padding
|
534 |
+
|
535 |
+
text_background_x1 = x1
|
536 |
+
text_background_y1 = y1 - 2 * text_padding - text_height
|
537 |
+
|
538 |
+
text_background_x2 = x1 + 2 * text_padding + text_width
|
539 |
+
text_background_y2 = y1
|
540 |
+
is_overlap = get_is_overlap(
|
541 |
+
detections,
|
542 |
+
text_background_x1,
|
543 |
+
text_background_y1,
|
544 |
+
text_background_x2,
|
545 |
+
text_background_y2,
|
546 |
+
image_size,
|
547 |
+
)
|
548 |
+
if not is_overlap:
|
549 |
+
return (
|
550 |
+
text_x,
|
551 |
+
text_y,
|
552 |
+
text_background_x1,
|
553 |
+
text_background_y1,
|
554 |
+
text_background_x2,
|
555 |
+
text_background_y2,
|
556 |
+
)
|
557 |
+
|
558 |
+
text_x = x1 - text_padding - text_width
|
559 |
+
text_y = y1 + text_padding + text_height
|
560 |
+
|
561 |
+
text_background_x1 = x1 - 2 * text_padding - text_width
|
562 |
+
text_background_y1 = y1
|
563 |
+
|
564 |
+
text_background_x2 = x1
|
565 |
+
text_background_y2 = y1 + 2 * text_padding + text_height
|
566 |
+
is_overlap = get_is_overlap(
|
567 |
+
detections,
|
568 |
+
text_background_x1,
|
569 |
+
text_background_y1,
|
570 |
+
text_background_x2,
|
571 |
+
text_background_y2,
|
572 |
+
image_size,
|
573 |
+
)
|
574 |
+
if not is_overlap:
|
575 |
+
return (
|
576 |
+
text_x,
|
577 |
+
text_y,
|
578 |
+
text_background_x1,
|
579 |
+
text_background_y1,
|
580 |
+
text_background_x2,
|
581 |
+
text_background_y2,
|
582 |
+
)
|
583 |
+
|
584 |
+
text_x = x2 + text_padding
|
585 |
+
text_y = y1 + text_padding + text_height
|
586 |
+
|
587 |
+
text_background_x1 = x2
|
588 |
+
text_background_y1 = y1
|
589 |
+
|
590 |
+
text_background_x2 = x2 + 2 * text_padding + text_width
|
591 |
+
text_background_y2 = y1 + 2 * text_padding + text_height
|
592 |
+
|
593 |
+
is_overlap = get_is_overlap(
|
594 |
+
detections,
|
595 |
+
text_background_x1,
|
596 |
+
text_background_y1,
|
597 |
+
text_background_x2,
|
598 |
+
text_background_y2,
|
599 |
+
image_size,
|
600 |
+
)
|
601 |
+
if not is_overlap:
|
602 |
+
return (
|
603 |
+
text_x,
|
604 |
+
text_y,
|
605 |
+
text_background_x1,
|
606 |
+
text_background_y1,
|
607 |
+
text_background_x2,
|
608 |
+
text_background_y2,
|
609 |
+
)
|
610 |
+
|
611 |
+
text_x = x2 - text_padding - text_width
|
612 |
+
text_y = y1 - text_padding
|
613 |
+
|
614 |
+
text_background_x1 = x2 - 2 * text_padding - text_width
|
615 |
+
text_background_y1 = y1 - 2 * text_padding - text_height
|
616 |
+
|
617 |
+
text_background_x2 = x2
|
618 |
+
text_background_y2 = y1
|
619 |
+
|
620 |
+
is_overlap = get_is_overlap(
|
621 |
+
detections,
|
622 |
+
text_background_x1,
|
623 |
+
text_background_y1,
|
624 |
+
text_background_x2,
|
625 |
+
text_background_y2,
|
626 |
+
image_size,
|
627 |
+
)
|
628 |
+
if not is_overlap:
|
629 |
+
return (
|
630 |
+
text_x,
|
631 |
+
text_y,
|
632 |
+
text_background_x1,
|
633 |
+
text_background_y1,
|
634 |
+
text_background_x2,
|
635 |
+
text_background_y2,
|
636 |
+
)
|
637 |
+
|
638 |
+
return (
|
639 |
+
text_x,
|
640 |
+
text_y,
|
641 |
+
text_background_x1,
|
642 |
+
text_background_y1,
|
643 |
+
text_background_x2,
|
644 |
+
text_background_y2,
|
645 |
+
)
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
easyocr
|
2 |
+
einops==0.8.0
|
3 |
+
opencv-python
|
4 |
+
opencv-python-headless
|
5 |
+
supervision==0.18.0
|
6 |
+
timm
|
7 |
+
ultralytics==8.3.70
|