OussamaElfila21 committed (verified) · Commit 5d92054 · 1 Parent(s): c300ae3

Upload 35 files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+utils/build/temp.win-amd64-cpython-311/Release/processing.obj filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,36 @@
+# Use a specific Python version
+FROM python:3.11.10-slim
+
+# Set the working directory inside the container
+WORKDIR /usr/src/app
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    build-essential \
+    libgl1-mesa-glx \
+    libglib2.0-0 \
+    gcc \
+    python3-dev \
+    && rm -rf /var/lib/apt/lists/*
+
+# Upgrade pip
+RUN pip install --upgrade pip
+
+# Install Python dependencies
+COPY requirements.txt .
+RUN pip install -r requirements.txt
+
+# Copy the application code into the container
+COPY . .
+
+# Build Cython extensions
+RUN cd utils && python setup.py build_ext --inplace
+
+# Create necessary directories if they don't exist
+RUN mkdir -p data/models
+
+# Expose the port for the FastAPI application
+EXPOSE 8000
+
+# Run the application with Gunicorn and Uvicorn workers
+CMD ["gunicorn", "app:app", "--workers", "4", "--worker-class", "uvicorn.workers.UvicornWorker", "--bind", "0.0.0.0:8000"]
api_documentation.html ADDED
@@ -0,0 +1,612 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Real-Time Image Processing API Documentation</title>
7
+ <script src="https://cdn.tailwindcss.com"></script>
8
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/swagger-ui/4.18.3/swagger-ui.min.css">
9
+ <style>
10
+ body {
11
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif;
12
+ margin: 0;
13
+ padding: 0;
14
+ color: #333;
15
+ }
16
+ .container {
17
+ max-width: 1200px;
18
+ margin: 0 auto;
19
+ padding: 20px;
20
+ }
21
+ header {
22
+ background-color: #252f3f;
23
+ color: white;
24
+ padding: 20px;
25
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
26
+ }
27
+ header h1 {
28
+ margin: 0;
29
+ font-size: 2em;
30
+ }
31
+ header p {
32
+ margin-top: 10px;
33
+ opacity: 0.8;
34
+ }
35
+ .endpoint-cards {
36
+ display: flex;
37
+ flex-wrap: wrap;
38
+ gap: 20px;
39
+ margin-top: 30px;
40
+ }
41
+ .endpoint-card {
42
+ background-color: white;
43
+ border-radius: 8px;
44
+ box-shadow: 0 2px 10px rgba(0,0,0,0.1);
45
+ padding: 20px;
46
+ flex: 1 1 300px;
47
+ transition: transform 0.2s, box-shadow 0.2s;
48
+ }
49
+ .endpoint-card:hover {
50
+ transform: translateY(-5px);
51
+ box-shadow: 0 5px 15px rgba(0,0,0,0.15);
52
+ }
53
+ .http-endpoint {
54
+ border-left: 5px solid #38a169;
55
+ }
56
+ .ws-endpoint {
57
+ border-left: 5px solid #3182ce;
58
+ }
59
+ .method {
60
+ display: inline-block;
61
+ padding: 5px 10px;
62
+ border-radius: 4px;
63
+ font-weight: bold;
64
+ margin-right: 10px;
65
+ }
66
+ .post {
67
+ background-color: #38a169;
68
+ color: white;
69
+ }
70
+ .ws {
71
+ background-color: #3182ce;
72
+ color: white;
73
+ }
74
+ .endpoint-title {
75
+ font-size: 1.2em;
76
+ margin-bottom: 15px;
77
+ display: flex;
78
+ align-items: center;
79
+ }
80
+ .section {
81
+ margin-top: 40px;
82
+ margin-bottom: 30px;
83
+ }
84
+ h2 {
85
+ border-bottom: 1px solid #eee;
86
+ padding-bottom: 10px;
87
+ color: #252f3f;
88
+ }
89
+ .code {
90
+ background-color: #f7fafc;
91
+ border-radius: 4px;
92
+ padding: 15px;
93
+ font-family: monospace;
94
+ overflow-x: auto;
95
+ margin: 15px 0;
96
+ }
97
+ .parameter-table, .response-table {
98
+ width: 100%;
99
+ border-collapse: collapse;
100
+ margin: 15px 0;
101
+ }
102
+ .parameter-table th, .parameter-table td,
103
+ .response-table th, .response-table td {
104
+ text-align: left;
105
+ padding: 10px;
106
+ border-bottom: 1px solid #eee;
107
+ }
108
+ .parameter-table th, .response-table th {
109
+ background-color: #f7fafc;
110
+ }
111
+ .try-it {
112
+ margin-top: 30px;
113
+ padding-top: 30px;
114
+ border-top: 1px solid #eee;
115
+ }
116
+ .swagger-ui .wrapper { max-width: 100%; }
117
+ #swagger-ui {
118
+ margin-top: 30px;
119
+ border: 1px solid #eee;
120
+ border-radius: 8px;
121
+ overflow: hidden;
122
+ }
123
+ .info-box {
124
+ background-color: #ebf8ff;
125
+ border-left: 5px solid #3182ce;
126
+ padding: 15px;
127
+ margin: 20px 0;
128
+ border-radius: 4px;
129
+ }
130
+ .response-example {
131
+ background-color: #f0fff4;
132
+ border-left: 5px solid #38a169;
133
+ padding: 15px;
134
+ border-radius: 4px;
135
+ margin-top: 20px;
136
+ }
137
+
138
+ /* Updated tabs styles */
139
+ .tabs {
140
+ display: flex;
141
+ margin-top: 30px;
142
+ border-bottom: 1px solid #e2e8f0;
143
+ align-items: center;
144
+ }
145
+ .tab {
146
+ padding: 10px 20px;
147
+ cursor: pointer;
148
+ border-bottom: 2px solid transparent;
149
+ transition: all 0.2s ease;
150
+ }
151
+ .tab.active {
152
+ border-bottom: 2px solid #3182ce;
153
+ color: #3182ce;
154
+ font-weight: bold;
155
+ }
156
+
157
+ /* New try-button styles */
158
+ .try-button {
159
+ margin-left: auto;
160
+ display: flex;
161
+ align-items: center;
162
+ padding: 10px 15px;
163
+ font-size: 14px;
164
+ font-weight: bold;
165
+ border-radius: 4px;
166
+ background-color: #4299e1;
167
+ color: white;
168
+ text-decoration: none;
169
+ border: none;
170
+ transition: background-color 0.2s;
171
+ }
172
+ .try-button:hover {
173
+ background-color: #3182ce;
174
+ }
175
+ .try-button svg {
176
+ margin-left: 8px;
177
+ height: 16px;
178
+ width: 16px;
179
+ }
180
+
181
+ .tab-content {
182
+ display: none;
183
+ padding: 20px 0;
184
+ }
185
+ .tab-content.active {
186
+ display: block;
187
+ }
188
+ button {
189
+ background-color: #4299e1;
190
+ color: white;
191
+ border: none;
192
+ padding: 10px 15px;
193
+ border-radius: 4px;
194
+ cursor: pointer;
195
+ font-weight: bold;
196
+ transition: background-color 0.2s;
197
+ }
198
+ button:hover {
199
+ background-color: #3182ce;
200
+ }
201
+ </style>
202
+ </head>
203
+ <body>
204
+ <header>
205
+ <div class="container">
206
+ <h1>Advanced Driver Vision Assistance with Near Collision Estimation System</h1>
207
+ <p>API for object detection, depth estimation, and distance prediction using computer vision models</p>
208
+ </div>
209
+ </header>
210
+
211
+ <div class="container">
212
+ <section class="section">
213
+ <h2>Overview</h2>
214
+ <p>This API provides access to advanced computer vision models for real-time image processing. It leverages:</p>
215
+ <ul>
216
+ <li><strong>DETR (DEtection TRansformer)</strong> - For accurate object detection</li>
217
+ <li><strong>GLPN (Global-Local Path Networks)</strong> - For depth estimation</li>
218
+ <li><strong>LSTM Model</strong> - For Z-location prediction</li>
219
+ </ul>
220
+ <p>The API supports both HTTP and WebSocket protocols:</p>
221
+ <div class="endpoint-cards">
222
+ <div class="endpoint-card http-endpoint">
223
+ <div class="endpoint-title">
224
+ <span class="method post">POST</span>
225
+ <span>/api/predict</span>
226
+ </div>
227
+ <p>Process a single image via HTTP request</p>
228
+ </div>
229
+ <div class="endpoint-card ws-endpoint">
230
+ <div class="endpoint-title">
231
+ <span class="method ws">WS</span>
232
+ <span>/ws/predict</span>
233
+ </div>
234
+ <p>Stream images for real-time processing via WebSocket</p>
235
+ </div>
236
+ </div>
237
+ </section>
238
+
239
+ <div class="tabs">
240
+ <div class="tab active" onclick="switchTab('http-api')">HTTP API</div>
241
+ <div class="tab" onclick="switchTab('websocket-api')">WebSocket API</div>
242
+ <div class="tab">
243
+ <a href="./try_page" class="ml-2 inline-flex items-center px-4 py-2 border border-transparent text-sm font-medium rounded-md shadow-sm text-white bg-blue-600 hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-blue-500 transition-all duration-200">
244
+ Try it
245
+ <svg xmlns="http://www.w3.org/2000/svg" class="ml-1 h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
246
+ <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M14 5l7 7m0 0l-7 7m7-7H3" />
247
+ </svg>
248
+ </a>
249
+ </div>
250
+ </div>
251
+
252
+ <div id="http-api" class="tab-content active">
253
+ <section class="section">
254
+ <h2>HTTP API Reference</h2>
255
+ <h3>POST /api/predict</h3>
256
+ <p>Process a single image for object detection, depth estimation, and distance prediction.</p>
257
+
258
+ <h4>Request</h4>
259
+ <p><strong>Content-Type:</strong> multipart/form-data</p>
260
+
261
+ <table class="parameter-table">
262
+ <tr>
263
+ <th>Parameter</th>
264
+ <th>Type</th>
265
+ <th>Required</th>
266
+ <th>Description</th>
267
+ </tr>
268
+ <tr>
269
+ <td>file</td>
270
+ <td>File</td>
271
+ <td>Yes</td>
272
+ <td>The image file to process (JPEG, PNG)</td>
273
+ </tr>
274
+ </table>
275
+
276
+ <h4>Request Example</h4>
277
+ <div class="code">
278
+ # Python example using requests
279
+ import requests
280
+
281
+ url = "http://localhost:8000/api/predict"
282
+ files = {"file": open("image.jpg", "rb")}
283
+
284
+ response = requests.post(url, files=files)
285
+ data = response.json()
286
+ print(data)
287
+ </div>
288
+
289
+ <h4>Response</h4>
290
+ <p>Returns a JSON object containing:</p>
291
+ <table class="response-table">
292
+ <tr>
293
+ <th>Field</th>
294
+ <th>Type</th>
295
+ <th>Description</th>
296
+ </tr>
297
+ <tr>
298
+ <td>objects</td>
299
+ <td>Array</td>
300
+ <td>Array of detected objects with their properties</td>
301
+ </tr>
302
+ <tr>
303
+ <td>objects[].class</td>
304
+ <td>String</td>
305
+ <td>Class of the detected object (e.g., 'car', 'person')</td>
306
+ </tr>
307
+ <tr>
308
+ <td>objects[].distance_estimated</td>
309
+ <td>Number</td>
310
+ <td>Estimated distance of the object</td>
311
+ </tr>
312
+ <tr>
313
+ <td>objects[].features</td>
314
+ <td>Object</td>
315
+ <td>Features used for prediction (bounding box, depth information)</td>
316
+ </tr>
317
+ <tr>
318
+ <td>frame_id</td>
319
+ <td>Number</td>
320
+ <td>ID of the processed frame (0 for HTTP requests)</td>
321
+ </tr>
322
+ <tr>
323
+ <td>timings</td>
324
+ <td>Object</td>
325
+ <td>Processing time metrics for each step</td>
326
+ </tr>
327
+ </table>
328
+
329
+ <div class="response-example">
330
+ <h4>Response Example</h4>
331
+ <pre class="code">
332
+ {
333
+ "objects": [
334
+ {
335
+ "class": "car",
336
+ "distance_estimated": 15.42,
337
+ "features": {
338
+ "xmin": 120.5,
339
+ "ymin": 230.8,
340
+ "xmax": 350.2,
341
+ "ymax": 480.3,
342
+ "mean_depth": 0.75,
343
+ "depth_mean_trim": 0.72,
344
+ "depth_median": 0.71,
345
+ "width": 229.7,
346
+ "height": 249.5
347
+ }
348
+ },
349
+ {
350
+ "class": "person",
351
+ "distance_estimated": 8.76,
352
+ "features": {
353
+ "xmin": 450.1,
354
+ "ymin": 200.4,
355
+ "xmax": 510.8,
356
+ "ymax": 380.2,
357
+ "mean_depth": 0.58,
358
+ "depth_mean_trim": 0.56,
359
+ "depth_median": 0.55,
360
+ "width": 60.7,
361
+ "height": 179.8
362
+ }
363
+ }
364
+ ],
365
+ "frame_id": 0,
366
+ "timings": {
367
+ "decode_time": 0.015,
368
+ "models_time": 0.452,
369
+ "process_time": 0.063,
370
+ "json_time": 0.021,
371
+ "total_time": 0.551
372
+ }
373
+ }
374
+ </pre>
375
+ </div>
376
+
377
+ <h4>HTTP Status Codes</h4>
378
+ <table class="response-table">
379
+ <tr>
380
+ <th>Status Code</th>
381
+ <th>Description</th>
382
+ </tr>
383
+ <tr>
384
+ <td>200</td>
385
+ <td>OK - Request was successful</td>
386
+ </tr>
387
+ <tr>
388
+ <td>400</td>
389
+ <td>Bad Request - Empty file or invalid format</td>
390
+ </tr>
391
+ <tr>
392
+ <td>500</td>
393
+ <td>Internal Server Error - Processing error</td>
394
+ </tr>
395
+ </table>
396
+ </section>
397
+ </div>
398
+
399
+ <div id="websocket-api" class="tab-content">
400
+ <section class="section">
401
+ <h2>WebSocket API Reference</h2>
402
+ <h3>WebSocket /ws/predict</h3>
403
+ <p>
404
+ Stream images for real-time processing and get instant results.
405
+ Ideal for video feeds and applications requiring continuous processing.
406
+ </p>
407
+
408
+ <div class="info-box">
409
+ <p><strong>Note:</strong> WebSocket offers better performance for real-time applications. Use this endpoint for processing video feeds or when you need to process multiple images in rapid succession.</p>
410
+ </div>
411
+
412
+ <h4>Connection</h4>
413
+ <div class="code">
414
+ # JavaScript example
415
+ const socket = new WebSocket('ws://localhost:8000/ws/predict');
416
+
417
+ socket.onopen = function(e) {
418
+ console.log('Connection established');
419
+ };
420
+
421
+ socket.onmessage = function(event) {
422
+ const response = JSON.parse(event.data);
423
+ console.log('Received:', response);
424
+ };
425
+
426
+ socket.onclose = function(event) {
427
+ console.log('Connection closed');
428
+ };
429
+ </div>
430
+
431
+ <h4>Sending Images</h4>
432
+ <p>Send binary image data directly over the WebSocket connection:</p>
433
+ <div class="code">
434
+ // JavaScript example: Sending an image from canvas or file
435
+ function sendImageFromCanvas(canvas) {
436
+ canvas.toBlob(function(blob) {
437
+ const reader = new FileReader();
438
+ reader.onload = function() {
439
+ socket.send(reader.result);
440
+ };
441
+ reader.readAsArrayBuffer(blob);
442
+ }, 'image/jpeg');
443
+ }
444
+
445
+ // Or from input file
446
+ fileInput.onchange = function() {
447
+ const file = this.files[0];
448
+ const reader = new FileReader();
449
+ reader.onload = function() {
450
+ socket.send(reader.result);
451
+ };
452
+ reader.readAsArrayBuffer(file);
453
+ };
454
+ </div>
455
+
456
+ <h4>Response Format</h4>
457
+ <p>The WebSocket API returns the same JSON structure as the HTTP API, with incrementing frame_id values.</p>
458
+ <div class="response-example">
459
+ <h4>Response Example</h4>
460
+ <pre class="code">
461
+ {
462
+ "objects": [
463
+ {
464
+ "class": "car",
465
+ "distance_estimated": 14.86,
466
+ "features": {
467
+ "xmin": 125.3,
468
+ "ymin": 235.1,
469
+ "xmax": 355.7,
470
+ "ymax": 485.9,
471
+ "mean_depth": 0.77,
472
+ "depth_mean_trim": 0.74,
473
+ "depth_median": 0.73,
474
+ "width": 230.4,
475
+ "height": 250.8
476
+ }
477
+ }
478
+ ],
479
+ "frame_id": 42,
480
+ "timings": {
481
+ "decode_time": 0.014,
482
+ "models_time": 0.445,
483
+ "process_time": 0.061,
484
+ "json_time": 0.020,
485
+ "total_time": 0.540
486
+ }
487
+ }
488
+ </pre>
489
+ </div>
490
+ </section>
491
+ </div>
492
+
493
+ <div id="try-it" class="tab-content">
494
+ <section class="section">
495
+ <h2>Try The API</h2>
496
+ <p>You can test the API directly using the interactive Swagger UI below:</p>
497
+
498
+ <div id="swagger-ui"></div>
499
+
500
+ <h3>Simple WebSocket Client</h3>
501
+ <p>Upload an image to test the WebSocket endpoint:</p>
502
+
503
+ <div>
504
+ <input type="file" id="wsFileInput" accept="image/*">
505
+ <button id="connectWsButton">Connect to WebSocket</button>
506
+ <button id="disconnectWsButton" disabled>Disconnect</button>
507
+ </div>
508
+
509
+ <div>
510
+ <p><strong>Status: </strong><span id="wsStatus">Disconnected</span></p>
511
+ <p><strong>Last Response: </strong></p>
512
+ <pre id="wsResponse" class="code" style="max-height: 300px; overflow-y: auto;"></pre>
513
+ </div>
514
+ </section>
515
+ </div>
516
+ </div>
517
+
518
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/swagger-ui/4.18.3/swagger-ui-bundle.js"></script>
519
+ <script>
520
+ // Handle tab switching
521
+ function switchTab(tabId) {
522
+ document.querySelectorAll('.tab').forEach(tab => tab.classList.remove('active'));
523
+ document.querySelectorAll('.tab-content').forEach(content => content.classList.remove('active'));
524
+
525
+ document.querySelector(`.tab[onclick="switchTab('${tabId}')"]`).classList.add('active');
526
+ document.getElementById(tabId).classList.add('active');
527
+ }
528
+
529
+ // Initialize Swagger UI
530
+ window.onload = function() {
531
+ SwaggerUIBundle({
532
+ url: "/api/openapi.json",
533
+ dom_id: '#swagger-ui',
534
+ presets: [
535
+ SwaggerUIBundle.presets.apis,
536
+ SwaggerUIBundle.SwaggerUIStandalonePreset
537
+ ],
538
+ layout: "BaseLayout",
539
+ deepLinking: true
540
+ });
541
+
542
+ // WebSocket client setup
543
+ let socket = null;
544
+
545
+ const connectButton = document.getElementById('connectWsButton');
546
+ const disconnectButton = document.getElementById('disconnectWsButton');
547
+ const fileInput = document.getElementById('wsFileInput');
548
+ const statusElement = document.getElementById('wsStatus');
549
+ const responseElement = document.getElementById('wsResponse');
550
+
551
+ connectButton.addEventListener('click', () => {
552
+ if (socket) {
553
+ socket.close();
554
+ }
555
+
556
+ const wsUrl = window.location.protocol === 'https:'
557
+ ? `wss://${window.location.host}/ws/predict`
558
+ : `ws://${window.location.host}/ws/predict`;
559
+
560
+ socket = new WebSocket(wsUrl);
561
+
562
+ socket.onopen = function() {
563
+ statusElement.textContent = 'Connected';
564
+ statusElement.style.color = 'green';
565
+ connectButton.disabled = true;
566
+ disconnectButton.disabled = false;
567
+ };
568
+
569
+ socket.onmessage = function(event) {
570
+ const response = JSON.parse(event.data);
571
+ responseElement.textContent = JSON.stringify(response, null, 2);
572
+ };
573
+
574
+ socket.onclose = function() {
575
+ statusElement.textContent = 'Disconnected';
576
+ statusElement.style.color = 'red';
577
+ connectButton.disabled = false;
578
+ disconnectButton.disabled = true;
579
+ socket = null;
580
+ };
581
+
582
+ socket.onerror = function(error) {
583
+ statusElement.textContent = 'Error: ' + error.message;
584
+ statusElement.style.color = 'red';
585
+ };
586
+ });
587
+
588
+ disconnectButton.addEventListener('click', () => {
589
+ if (socket) {
590
+ socket.close();
591
+ }
592
+ });
593
+
594
+ fileInput.addEventListener('change', () => {
595
+ if (!socket || socket.readyState !== WebSocket.OPEN) {
596
+ alert('Please connect to WebSocket first');
597
+ return;
598
+ }
599
+
600
+ const file = fileInput.files[0];
601
+ if (!file) return;
602
+
603
+ const reader = new FileReader();
604
+ reader.onload = function() {
605
+ socket.send(reader.result);
606
+ };
607
+ reader.readAsArrayBuffer(file);
608
+ });
609
+ };
610
+ </script>
611
+ </body>
612
+ </html>
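Editor's note: the WebSocket examples in the documentation above are JavaScript only. Because the requirements file already includes the websockets package, the same send-bytes / receive-JSON protocol can be exercised from Python as well; a minimal sketch, assuming the server from this commit is running locally on port 8000 and an image.jpg is available:

# Minimal Python client for the /ws/predict protocol documented above.
import asyncio
import json
import websockets

async def stream_image(path, url="ws://localhost:8000/ws/predict"):
    async with websockets.connect(url) as ws:
        with open(path, "rb") as f:
            await ws.send(f.read())            # send raw image bytes, as the docs describe
        result = json.loads(await ws.recv())   # server replies with a JSON text frame
        for obj in result["objects"]:
            print(obj["class"], obj["distance_estimated"])
        print("frame", result["frame_id"], "took", result["timings"]["total_time"], "s")

asyncio.run(stream_image("image.jpg"))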
app.py ADDED
@@ -0,0 +1,379 @@
+import joblib
+import uvicorn
+import xgboost as xgb
+from fastapi import FastAPI, WebSocket, WebSocketDisconnect, File, UploadFile, HTTPException
+from fastapi.responses import HTMLResponse, JSONResponse
+import asyncio
+import json
+import pickle
+import warnings
+import os
+import io
+
+import timeit
+from PIL import Image
+import numpy as np
+import cv2
+
+# OpenAPI / Swagger documentation helpers
+from fastapi.openapi.docs import get_swagger_ui_html
+from fastapi.openapi.utils import get_openapi
+
+from models.detr_model import DETR
+from models.glpn_model import GLPDepth
+from models.lstm_model import LSTM_Model
+from models.predict_z_location_single_row_lstm import predict_z_location_single_row_lstm
+from utils.processing import PROCESSING
+from config import CONFIG
+
+warnings.filterwarnings("ignore")
+
+# Initialize FastAPI app
+app = FastAPI(
+    title="Real-Time WebSocket Image Processing API",
+    description="API for object detection and depth estimation using WebSocket for real-time image processing.",
+)
+
+try:
+    # Load models and utilities
+    device = CONFIG['device']
+    print("Loading models...")
+
+    detr = DETR()  # Object detection model (DETR)
+    print("DETR model loaded.")
+
+    glpn = GLPDepth()  # Depth estimation model (GLPN)
+    print("GLPDepth model loaded.")
+
+    zlocE_LSTM = LSTM_Model()  # LSTM model for Z-location prediction
+    print("LSTM model loaded.")
+
+    lstm_scaler = pickle.load(open(CONFIG['lstm_scaler_path'], 'rb'))  # Pre-trained scaler for the LSTM
+    print("LSTM Scaler loaded.")
+
+    processing = PROCESSING()  # Utility class for post-processing
+    print("Processing utilities loaded.")
+
+except Exception as e:
+    print(f"An unexpected error occurred while loading models. Details: {e}")
+
+
+# Serve HTML documentation for the API
+@app.get("/", response_class=HTMLResponse)
+async def get_docs():
+    """
+    Serve HTML documentation for the WebSocket-based image processing API.
+    The HTML file must be available in the same directory.
+    Returns a 404 error if the documentation file is not found.
+    """
+    html_path = os.path.join(os.path.dirname(__file__), "api_documentation.html")
+    if not os.path.exists(html_path):
+        return HTMLResponse(content="api_documentation.html file not found", status_code=404)
+    with open(html_path, "r") as f:
+        return HTMLResponse(f.read())
+
+
+@app.get("/try_page", response_class=HTMLResponse)
+async def get_try_page():
+    """
+    Serve the interactive "try it" page for the image processing API.
+    The HTML file must be available in the same directory.
+    Returns a 404 error if the page is not found.
+    """
+    html_path = os.path.join(os.path.dirname(__file__), "try_page.html")
+    if not os.path.exists(html_path):
+        return HTMLResponse(content="try_page.html file not found", status_code=404)
+    with open(html_path, "r") as f:
+        return HTMLResponse(f.read())
+
+
+# Function to decode the image received via WebSocket
+async def decode_image(image_bytes):
+    """
+    Decodes image bytes into a PIL Image and returns the image along with its shape.
+
+    Args:
+        image_bytes (bytes): The image data received from the client.
+
+    Returns:
+        tuple: A tuple containing:
+            - pil_image (PIL.Image): The decoded image.
+            - img_shape (tuple): Shape of the image as (height, width).
+            - decode_time (float): Time taken to decode the image in seconds.
+
+    Raises:
+        ValueError: If image decoding fails.
+    """
+    start = timeit.default_timer()
+    nparr = np.frombuffer(image_bytes, np.uint8)
+    frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+    if frame is None:
+        raise ValueError("Failed to decode image")
+    color_converted = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+    pil_image = Image.fromarray(color_converted)
+    img_shape = color_converted.shape[0:2]  # (height, width)
+    end = timeit.default_timer()
+    return pil_image, img_shape, end - start
+
+
+# Function to run the DETR model for object detection
+async def run_detr_model(pil_image):
+    """
+    Runs the DETR (DEtection TRansformer) model to perform object detection on the input image.
+
+    Args:
+        pil_image (PIL.Image): The image to be processed by the DETR model.
+
+    Returns:
+        tuple: A tuple containing:
+            - detr_result (tuple): The DETR model output consisting of detection scores and boxes.
+            - detr_time (float): Time taken to run the DETR model in seconds.
+    """
+    start = timeit.default_timer()
+    detr_result = await asyncio.to_thread(detr.detect, pil_image)
+    end = timeit.default_timer()
+    return detr_result, end - start
+
+
+# Function to run the GLPN model for depth estimation
+async def run_glpn_model(pil_image, img_shape):
+    """
+    Runs the GLPN (Global-Local Path Networks) model to estimate the depth of the objects in the image.
+
+    Args:
+        pil_image (PIL.Image): The image to be processed by the GLPN model.
+        img_shape (tuple): The shape of the image as (height, width).
+
+    Returns:
+        tuple: A tuple containing:
+            - depth_map (numpy.ndarray): The depth map for the input image.
+            - glpn_time (float): Time taken to run the GLPN model in seconds.
+    """
+    start = timeit.default_timer()
+    depth_map = await asyncio.to_thread(glpn.predict, pil_image, img_shape)
+    end = timeit.default_timer()
+    return depth_map, end - start
+
+
+# Function to process the detections with the depth map
+async def process_detections(scores, boxes, depth_map):
+    """
+    Processes the DETR model detections and integrates depth information from the GLPN model.
+
+    Args:
+        scores (numpy.ndarray): The detection scores for the detected objects.
+        boxes (numpy.ndarray): The bounding boxes for the detected objects.
+        depth_map (numpy.ndarray): The depth map generated by the GLPN model.
+
+    Returns:
+        tuple: A tuple containing:
+            - pdata (DataFrame): Processed detection data including depth and bounding box information.
+            - process_time (float): Time taken for processing detections in seconds.
+    """
+    start = timeit.default_timer()
+    pdata = processing.process_detections(scores, boxes, depth_map, detr)
+    end = timeit.default_timer()
+    return pdata, end - start
+
+
+# Function to generate JSON output for LSTM predictions
+async def generate_json_output(data):
+    """
+    Predict the Z-location for each object in the data and prepare the JSON output.
+
+    Parameters:
+    - data: DataFrame with bounding box coordinates, depth information, and class type.
+      Uses the module-level LSTM model (zlocE_LSTM) and scaler (lstm_scaler) for prediction.
+
+    Returns:
+    - JSON structure with object class, estimated distance, and relevant features, plus the elapsed time.
+    """
+    output_json = []
+    start = timeit.default_timer()
+
+    # Iterate over each row in the data
+    for i, row in data.iterrows():
+        # Predict distance for each object using the single-row prediction function
+        distance = predict_z_location_single_row_lstm(row, zlocE_LSTM, lstm_scaler)
+
+        # Create object info dictionary
+        object_info = {
+            "class": row["class"],  # Object class (e.g., 'car', 'truck')
+            "distance_estimated": float(distance),  # Convert distance to a plain float
+            "features": {
+                "xmin": float(row["xmin"]),  # Bounding box xmin
+                "ymin": float(row["ymin"]),  # Bounding box ymin
+                "xmax": float(row["xmax"]),  # Bounding box xmax
+                "ymax": float(row["ymax"]),  # Bounding box ymax
+                "mean_depth": float(row["depth_mean"]),  # Depth mean
+                "depth_mean_trim": float(row["depth_mean_trim"]),  # Trimmed depth mean
+                "depth_median": float(row["depth_median"]),  # Depth median
+                "width": float(row["width"]),  # Object width
+                "height": float(row["height"])  # Object height
+            }
+        }
+
+        # Append each object info to the output JSON list
+        output_json.append(object_info)
+
+    end = timeit.default_timer()
+
+    # Return the final JSON output structure, and time
+    return {"objects": output_json}, end - start
+
+
+# Function to process a single frame (image) in the WebSocket stream
+async def process_frame(frame_id, image_bytes):
+    """
+    Processes a single frame (image) from the WebSocket stream. The process includes:
+    - Decoding the image.
+    - Running the DETR and GLPN models concurrently.
+    - Processing detections and generating the final output JSON.
+
+    Args:
+        frame_id (int): The identifier for the frame being processed.
+        image_bytes (bytes): The image data received from the WebSocket.
+
+    Returns:
+        dict: A dictionary containing the output JSON and timing information for each processing step.
+    """
+    timings = {}
+    try:
+        # Step 1: Decode the image
+        pil_image, img_shape, decode_time = await decode_image(image_bytes)
+        timings["decode_time"] = decode_time
+
+        # Step 2: Run DETR and GLPN models in parallel
+        (detr_result, detr_time), (depth_map, glpn_time) = await asyncio.gather(
+            run_detr_model(pil_image),
+            run_glpn_model(pil_image, img_shape)
+        )
+        models_time = max(detr_time, glpn_time)  # Take the longer of the two model times
+        timings["models_time"] = models_time
+
+        # Step 3: Process detections with the depth map
+        scores, boxes = detr_result
+        pdata, process_time = await process_detections(scores, boxes, depth_map)
+        timings["process_time"] = process_time
+
+        # Step 4: Generate output JSON
+        output_json, json_time = await generate_json_output(pdata)
+        timings["json_time"] = json_time
+
+        timings["total_time"] = decode_time + models_time + process_time + json_time
+
+        # Add frame_id and timings to the JSON output
+        output_json["frame_id"] = frame_id
+        output_json["timings"] = timings
+
+        return output_json
+
+    except Exception as e:
+        return {
+            "error": str(e),
+            "frame_id": frame_id,
+            "timings": timings
+        }
+
+
+@app.post("/api/predict", summary="Process a single image for object detection and depth estimation")
+async def process_image(file: UploadFile = File(...)):
+    """
+    Process a single image for object detection and depth estimation.
+
+    The endpoint performs:
+    - Object detection using the DETR model
+    - Depth estimation using the GLPN model
+    - Z-location prediction using the LSTM model
+
+    Parameters:
+    - file: Image file to process (JPEG, PNG)
+
+    Returns:
+    - JSON response with detected objects, estimated distances, and timing information
+    """
+    try:
+        # Read image content
+        image_bytes = await file.read()
+        if not image_bytes:
+            raise HTTPException(status_code=400, detail="Empty file")
+
+        # Use the same processing pipeline as the WebSocket endpoint
+        result = await process_frame(0, image_bytes)
+
+        # Check if there's an error
+        if "error" in result:
+            raise HTTPException(status_code=500, detail=result["error"])
+
+        return JSONResponse(content=result)
+
+    except HTTPException:
+        # Re-raise HTTP errors (e.g., 400 for an empty file) without converting them to 500
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+# Add custom OpenAPI documentation
+@app.get("/api/docs", include_in_schema=False)
+async def custom_swagger_ui_html():
+    return get_swagger_ui_html(
+        openapi_url="/api/openapi.json",
+        title="Real-Time Image Processing API Documentation",
+        swagger_js_url="https://cdnjs.cloudflare.com/ajax/libs/swagger-ui/4.18.3/swagger-ui-bundle.js",
+        swagger_css_url="https://cdnjs.cloudflare.com/ajax/libs/swagger-ui/4.18.3/swagger-ui.css",
+    )
+
+
+@app.get("/api/openapi.json", include_in_schema=False)
+async def get_open_api_endpoint():
+    return get_openapi(
+        title="Real-Time Image Processing API",
+        version="1.0.0",
+        description="API for object detection, depth estimation, and z-location prediction using computer vision models",
+        routes=app.routes,
+    )
+
+
+@app.websocket("/ws/predict")
+async def websocket_endpoint(websocket: WebSocket):
+    """
+    WebSocket endpoint for real-time image processing. Clients can send image frames to be processed
+    and receive JSON output containing object detection, depth estimation, and predictions in real time.
+
+    - Handles the reception of image data over the WebSocket.
+    - Calls the image processing pipeline and returns the result.
+
+    Args:
+        websocket (WebSocket): The WebSocket connection to the client.
+    """
+    await websocket.accept()
+    frame_id = 0
+
+    try:
+        while True:
+            frame_id += 1
+
+            # Receive image bytes from the client
+            image_bytes = await websocket.receive_bytes()
+
+            # Process the frame asynchronously
+            processing_task = asyncio.create_task(process_frame(frame_id, image_bytes))
+            result = await processing_task
+
+            # Send the result back to the client
+            await websocket.send_text(json.dumps(result))
+
+    except WebSocketDisconnect:
+        print(f"Client disconnected after processing {frame_id} frames.")
+    except Exception as e:
+        print(f"Unexpected error: {e}")
+    finally:
+        await websocket.close()
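Editor's note: the HTTP route above can be exercised without a running server through FastAPI's TestClient; a rough smoke-test sketch, assuming the model weights under data/models/ are present so that importing app succeeds and an image.jpg is available:

# Local smoke test for /api/predict using FastAPI's in-process test client.
from fastapi.testclient import TestClient
from app import app

client = TestClient(app)

with open("image.jpg", "rb") as f:
    response = client.post("/api/predict", files={"file": ("image.jpg", f, "image/jpeg")})

assert response.status_code == 200
payload = response.json()
print("objects:", len(payload["objects"]), "total_time:", payload["timings"]["total_time"])
for obj in payload["objects"]:
    print(f'{obj["class"]}: {obj["distance_estimated"]:.2f}')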
config.py ADDED
@@ -0,0 +1,12 @@
+import torch
+
+CONFIG = {
+    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
+    'detr_model_path': 'facebook/detr-resnet-101',
+    'glpn_model_path': 'vinvino02/glpn-kitti',
+    'lstm_model_path': 'data/models/pretrained_lstm.pth',
+    'lstm_scaler_path': 'data/models/lstm_scaler.pkl',
+    'xgboost_path': 'data/models/xgboost_model.json',
+    'xgboost_scaler_path': 'data/models/scaler.joblib',
+    'numerical_cols_path': 'data/models/numerical_columns.joblib'
+}
data/models/lstm_scaler.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3e21e30154403927e4696795a086191bacf12a03aecb904992faa2dc9fb17343
+size 809
data/models/pretrained_lstm.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d9a7728cd4db9ebb4c8c0cb2fef644ae7d159359864002aa94148f9ea127005c
+size 31160407
models/__init__.py ADDED
File without changes
models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (167 Bytes).
models/__pycache__/detr_model.cpython-311.pyc ADDED
Binary file (7.7 kB).
models/__pycache__/distance_prediction.cpython-311.pyc ADDED
Binary file (2.92 kB).
models/__pycache__/glpn_model.cpython-311.pyc ADDED
Binary file (2.95 kB).
models/__pycache__/lstm_model.cpython-311.pyc ADDED
Binary file (4.22 kB).
models/__pycache__/predict_z_location_single_row_lstm.cpython-311.pyc ADDED
Binary file (2.89 kB).
models/detr_model.py ADDED
@@ -0,0 +1,117 @@
+"""
+Created on Sat Apr 9 04:08:02 2022
+@author: Admin_with ODD Team
+
+Edited by our team: Sat Oct 5 10:00 2024
+
+references: https://github.com/vinvino02/GLPDepth
+"""
+
+import io
+import torch
+import base64
+from config import CONFIG
+from torchvision import transforms
+from matplotlib import pyplot as plt
+from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
+from transformers import DetrForObjectDetection, DetrImageProcessor
+
+class DETR:
+    def __init__(self):
+
+        self.CLASSES = [
+            'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+            'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
+            'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
+            'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
+            'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
+            'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
+            'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
+            'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
+            'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
+            'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
+            'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
+            'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
+            'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
+            'toothbrush'
+        ]
+
+        self.COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098],
+                       [0.929, 0.694, 0.125], [0, 0, 1], [0.466, 0.674, 0.188],
+                       [0.301, 0.745, 0.933]]
+
+        self.transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+        ])
+
+        self.model = DetrForObjectDetection.from_pretrained(CONFIG['detr_model_path'], revision="no_timm")
+        self.model.to(CONFIG['device'])
+        self.model.eval()
+
+    def box_cxcywh_to_xyxy(self, x):
+        x_c, y_c, w, h = x.unbind(1)
+        b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
+             (x_c + 0.5 * w), (y_c + 0.5 * h)]
+        return torch.stack(b, dim=1).to(CONFIG['device'])
+
+    def rescale_bboxes(self, out_bbox, size):
+        img_w, img_h = size
+        b = self.box_cxcywh_to_xyxy(out_bbox)
+        b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(CONFIG['device'])
+        return b
+
+    def detect(self, im):
+        img = self.transform(im).unsqueeze(0).to(CONFIG['device'])
+        assert img.shape[-2] <= 1600 and img.shape[-1] <= 1600, 'Image too large'
+        outputs = self.model(img)
+        probas = outputs['logits'].softmax(-1)[0, :, :-1]
+        keep = probas.max(-1).values > 0.7
+        bboxes_scaled = self.rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
+        return probas[keep], bboxes_scaled
+
+    def visualize(self, im, probas, bboxes):
+        """
+        Visualizes the detected bounding boxes and class probabilities on the image.
+
+        Parameters:
+            im (PIL.Image): The original input image.
+            probas (Tensor): Class probabilities for detected objects.
+            bboxes (Tensor): Bounding boxes for detected objects.
+        """
+        # Plot the image with matplotlib
+        fig, ax = plt.subplots(figsize=(10, 6))
+        ax.imshow(im)
+
+        # Iterate over detections and draw bounding boxes and labels
+        for p, (xmin, ymin, xmax, ymax), color in zip(probas, bboxes, self.COLORS * 100):
+            # Detach tensors and convert to float
+            xmin, ymin, xmax, ymax = map(lambda x: x.detach().cpu().numpy().item(), (xmin, ymin, xmax, ymax))
+
+            ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
+                                       fill=False, color=color, linewidth=3))
+            cl = p.argmax()
+            text = f'{self.CLASSES[cl]}: {p[cl].detach().cpu().numpy():0.2f}'  # Detach the probability as well
+            ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5))
+
+        ax.axis('off')
+
+        # Render the plot to PNG bytes
+        canvas = FigureCanvas(fig)
+        buf = io.BytesIO()
+        canvas.print_png(buf)
+        buf.seek(0)
+
+        # Base64 encode the image
+        img_bytes = buf.getvalue()
+        img_base64 = base64.b64encode(img_bytes).decode('utf-8')
+
+        # Close the figure to release memory
+        plt.close(fig)
+
+        return img_base64
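Editor's note: a short usage sketch for the wrapper above. It assumes the facebook/detr-resnet-101 weights referenced in config.py can be downloaded (or are already cached) and that a test image is available locally:

# Standalone check of DETR.detect on a single image.
from PIL import Image
from models.detr_model import DETR

detr = DETR()
image = Image.open("image.jpg").convert("RGB")
probas, boxes = detr.detect(image)  # keeps detections with confidence > 0.7, boxes scaled to pixels
for p, box in zip(probas, boxes):
    label = detr.CLASSES[p.argmax().item()]
    xmin, ymin, xmax, ymax = [float(v) for v in box]
    print(f"{label}: {p.max().item():.2f} at ({xmin:.0f}, {ymin:.0f}, {xmax:.0f}, {ymax:.0f})")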
models/glpn_model.py ADDED
@@ -0,0 +1,51 @@
+"""
+Created on Sat Apr 9 04:08:02 2022
+@author: Admin_with ODD Team
+
+Edited by our team: Sat Oct 5 10:00 2024
+
+references: https://github.com/vinvino02/GLPDepth
+"""
+
+import torch
+from transformers import GLPNForDepthEstimation, GLPNFeatureExtractor
+from PIL import Image
+from config import CONFIG
+
+
+# GLPDepth Model Class
+class GLPDepth:
+    def __init__(self):
+        self.feature_extractor = GLPNFeatureExtractor.from_pretrained(CONFIG['glpn_model_path'])
+        self.model = GLPNForDepthEstimation.from_pretrained(CONFIG['glpn_model_path'])
+        self.model.to(CONFIG['device'])
+        self.model.eval()
+
+    def predict(self, img: Image.Image, img_shape: tuple):
+        """Predict the depth map of the input image.
+
+        Args:
+            img (PIL.Image): Input image for depth estimation.
+            img_shape (tuple): Original image size (height, width).
+
+        Returns:
+            np.ndarray: The predicted depth map as a numpy array.
+        """
+        with torch.no_grad():
+            # Preprocess the image and move it to the appropriate device
+            pixel_values = self.feature_extractor(img, return_tensors="pt").pixel_values.to(CONFIG['device'])
+            # Get model output
+            outputs = self.model(pixel_values)
+            predicted_depth = outputs.predicted_depth
+
+            # Resize the depth prediction to the original image size
+            prediction = torch.nn.functional.interpolate(
+                predicted_depth.unsqueeze(1),
+                size=img_shape[:2],  # Interpolate to original image size
+                mode="bicubic",
+                align_corners=False,
+            )
+            prediction = prediction.squeeze().cpu().numpy()  # Convert to numpy array (shape: (H, W))
+
+        return prediction
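Editor's note: the depth wrapper can be sanity-checked the same way. Note that predict expects the original size as (height, width), the reverse of PIL's (width, height); a minimal sketch assuming a local test image:

# Quick shape check for GLPDepth.predict.
from PIL import Image
from models.glpn_model import GLPDepth

glpn = GLPDepth()
image = Image.open("image.jpg").convert("RGB")
height_width = image.size[::-1]          # PIL gives (width, height); the model wants (height, width)
depth_map = glpn.predict(image, height_width)
print(depth_map.shape, float(depth_map.min()), float(depth_map.max()))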
models/lstm_model.py ADDED
@@ -0,0 +1,78 @@
+# -*- coding: utf-8 -*-
+"""
+Z-Location Estimator Model for Deployment
+Created on Mon May 23 04:55:50 2022
+@author: ODD_team
+Edited by our team: Sat Oct 4 11:00 PM 2024
+@based on LSTM model
+"""
+
+import torch
+import torch.nn as nn
+from config import CONFIG
+
+device = CONFIG['device']
+
+# Define the LSTM-based Z-location estimator model
+class Zloc_Estimator(nn.Module):
+    def __init__(self, input_dim, hidden_dim, layer_dim):
+        super(Zloc_Estimator, self).__init__()
+
+        # LSTM layer
+        self.rnn = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True, bidirectional=False)
+
+        # Fully connected layers
+        layersize = [306, 154, 76]
+        layerlist = []
+        n_in = hidden_dim
+        for i in layersize:
+            layerlist.append(nn.Linear(n_in, i))
+            layerlist.append(nn.ReLU())
+            n_in = i
+        layerlist.append(nn.Linear(layersize[-1], 1))  # Final output layer
+
+        self.fc = nn.Sequential(*layerlist)
+
+    def forward(self, x):
+        out, hn = self.rnn(x)
+        output = self.fc(out[:, -1])  # Use the last time step's output for prediction
+        return output
+
+# Deployment-ready class for handling the model
+class LSTM_Model:
+    def __init__(self):
+        """
+        Initializes the LSTM model for deployment with predefined parameters
+        and loads the pre-trained weights from CONFIG['lstm_model_path'].
+        """
+        self.input_dim = 15
+        self.hidden_dim = 612
+        self.layer_dim = 3
+
+        # Initialize the Z-location estimator model
+        self.model = Zloc_Estimator(self.input_dim, self.hidden_dim, self.layer_dim)
+
+        # Load the state dictionary, using map_location so CPU-only machines also work
+        state_dict = torch.load(CONFIG['lstm_model_path'], map_location=device)
+
+        # Load the model with the state dictionary
+        self.model.load_state_dict(state_dict, strict=False)
+        self.model.to(device)  # Move the model to the configured device
+        self.model.eval()  # Set the model to evaluation mode
+
+    def predict(self, data):
+        """
+        Predicts the z-location based on input data.
+
+        :param data: Input tensor of shape (batch_size, input_dim)
+        :return: Predicted z-location as a tensor
+        """
+        with torch.no_grad():  # Disable gradient computation for deployment
+            data = data.to(device)  # Move data to the appropriate device
+            data = data.reshape(-1, 1, self.input_dim)  # Reshape to (batch_size, sequence_length, input_dim)
+            zloc = self.model(data)
+        return zloc.cpu()  # Return the output on CPU for further processing
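Editor's note: a minimal shape check for the deployment wrapper: one 15-feature row in (9 numeric features plus the 6-way class encoding built downstream), one Z value out. The zero vector is only a placeholder; real inputs arrive already scaled.

# Shape check only -- a zero row is not a meaningful detection.
import torch
from models.lstm_model import LSTM_Model

zloc_estimator = LSTM_Model()        # loads data/models/pretrained_lstm.pth onto CONFIG['device']
features = torch.zeros(1, 15)        # placeholder for one scaled feature row
prediction = zloc_estimator.predict(features)
print(prediction.shape)              # torch.Size([1, 1])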
models/predict_z_location_single_row_lstm.py ADDED
@@ -0,0 +1,48 @@
+import torch
+import numpy as np
+
+
+def predict_z_location_single_row_lstm(row, ZlocE, scaler):
+    """
+    Preprocess bounding box coordinates, depth information, and class type
+    to predict the Z-location using the LSTM model for a single row.
+
+    Parameters:
+    - row: A single row of a DataFrame with bounding box coordinates, depth info, and class type.
+    - ZlocE: Pre-loaded LSTM model for Z-location prediction.
+    - scaler: Scaler for normalizing input data.
+
+    Returns:
+    - z_loc_prediction: Predicted Z-location for the given row.
+    """
+    # One-hot encoding of the class type
+    class_type = row['class']
+
+    if class_type == 'bicycle':
+        class_tensor = torch.tensor([[0, 1, 0, 0, 0, 0]], dtype=torch.float32)
+    elif class_type == 'car':
+        class_tensor = torch.tensor([[0, 0, 1, 0, 0, 0]], dtype=torch.float32)
+    elif class_type == 'person':
+        class_tensor = torch.tensor([[0, 0, 0, 1, 0, 0]], dtype=torch.float32)
+    elif class_type == 'train':
+        class_tensor = torch.tensor([[0, 0, 0, 0, 1, 0]], dtype=torch.float32)
+    elif class_type == 'truck':
+        class_tensor = torch.tensor([[0, 0, 0, 0, 0, 1]], dtype=torch.float32)
+    else:
+        class_tensor = torch.tensor([[1, 0, 0, 0, 0, 0]], dtype=torch.float32)
+
+    # Prepare input data (bounding box + depth info)
+    input_data = np.array([row[['xmin', 'ymin', 'xmax', 'ymax', 'width', 'height', 'depth_mean', 'depth_median',
+                                'depth_mean_trim']].values], dtype=np.float32)
+    input_data = torch.from_numpy(input_data)
+
+    # Concatenate the class information
+    input_data = torch.cat([input_data, class_tensor], dim=-1)
+
+    # Scale the input data
+    scaled_input = torch.tensor(scaler.transform(input_data), dtype=torch.float32).unsqueeze(0).unsqueeze(0)
+
+    # Use the LSTM model to predict the Z-location
+    z_loc_prediction = ZlocE.predict(scaled_input).detach().numpy()[0]
+
+    return z_loc_prediction
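Editor's note: the field names expected by this function match the columns produced by the processing step (and echoed in the API's features block), so a hand-built row is enough to exercise it; a sketch, assuming the scaler and LSTM weights under data/models/ are present:

# Calling the single-row predictor with a hand-built detection row.
import pickle
import pandas as pd
from config import CONFIG
from models.lstm_model import LSTM_Model
from models.predict_z_location_single_row_lstm import predict_z_location_single_row_lstm

row = pd.Series({
    "class": "car",
    "xmin": 120.5, "ymin": 230.8, "xmax": 350.2, "ymax": 480.3,
    "width": 229.7, "height": 249.5,
    "depth_mean": 0.75, "depth_median": 0.71, "depth_mean_trim": 0.72,
})

zlocE = LSTM_Model()
with open(CONFIG["lstm_scaler_path"], "rb") as f:
    scaler = pickle.load(f)

print(float(predict_z_location_single_row_lstm(row, zlocE, scaler)))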
requirements.txt ADDED
@@ -0,0 +1,20 @@
+torch>=2.0.0
+torchvision>=0.15.0
+transformers~=4.32.1
+matplotlib>=3.7.0
+pandas>=2.0.0
+opencv-python>=4.8.0
+uvicorn>=0.24.0
+fastapi>=0.104.1
+gunicorn==21.2.0
+setuptools>=68.0.0
+scipy>=1.11.0
+scikit-learn>=1.3.0
+websockets>=11.0.0
+shapely>=2.0.0
+cython>=3.0.0
+numpy>=1.24.0
+pillow~=9.4.0
+joblib~=1.2.0
+xgboost==3.0.0
try_page.html ADDED
@@ -0,0 +1,778 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>HTTP Image Upload Demo</title>
7
+ <style>
8
+ /* Reset and base styles */
9
+ * {
10
+ box-sizing: border-box;
11
+ margin: 0;
12
+ padding: 0;
13
+ }
14
+
15
+ body {
16
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
17
+ color: #333;
18
+ line-height: 1.6;
19
+ }
20
+
21
+ /* Demo section styling */
22
+ .execution-section {
23
+ max-width: 1200px;
24
+ margin: 0 auto;
25
+ padding: 2rem;
26
+ background-color: #f8f9fa;
27
+ border-radius: 8px;
28
+ box-shadow: 0 4px 6px rgba(0,0,0,0.1);
29
+ }
30
+
31
+ .section-title {
32
+ font-size: 2rem;
33
+ color: #384B70;
34
+ margin-bottom: 1rem;
35
+ padding-bottom: 0.5rem;
36
+ border-bottom: 2px solid #507687;
37
+ }
38
+
39
+ .demo-container {
40
+ display: flex;
41
+ flex-wrap: wrap;
42
+ gap: 2rem;
43
+ margin-top: 1.5rem;
44
+ }
45
+
46
+ .upload-container, .response-container {
47
+ flex: 1;
48
+ min-width: 300px;
49
+ padding: 1.5rem;
50
+ background-color: white;
51
+ border-radius: 8px;
52
+ box-shadow: 0 2px 4px rgba(0,0,0,0.05);
53
+ }
54
+
55
+ .container-title {
56
+ font-size: 1.5rem;
57
+ margin-bottom: 1rem;
58
+ color: #384B70;
59
+ }
60
+
61
+ /* Upload area styling */
62
+ .file-input-container {
63
+ border: 2px dashed #ccc;
64
+ border-radius: 5px;
65
+ padding: 2rem;
66
+ text-align: center;
67
+ margin-bottom: 1rem;
68
+ transition: all 0.3s ease;
69
+ }
70
+
71
+ .file-input-container:hover {
72
+ border-color: #507687;
73
+ background-color: #f8f9fa;
74
+ }
75
+
76
+ #fileInput {
77
+ display: none;
78
+ }
79
+
80
+ .file-label {
81
+ cursor: pointer;
82
+ display: flex;
83
+ flex-direction: column;
84
+ align-items: center;
85
+ gap: 0.5rem;
86
+ }
87
+
88
+ .file-icon {
89
+ font-size: 2.5rem;
90
+ color: #507687;
91
+ width: 64px;
92
+ height: 64px;
93
+ }
94
+ .file-placeholder {
95
+ max-width: 100%;
96
+ height: auto;
97
+ margin-top: 1rem;
98
+ border-radius: 4px;
99
+ display: none;
100
+ }
101
+
102
+ #sendButton {
103
+ background-color: #384B70;
104
+ color: white;
105
+ border: none;
106
+ border-radius: 4px;
107
+ padding: 0.75rem 1.5rem;
108
+ font-size: 1rem;
109
+ cursor: pointer;
110
+ transition: background-color 0.3s;
111
+ width: 100%;
112
+ margin-top: 1rem;
113
+ }
114
+
115
+ #sendButton:disabled {
116
+ background-color: #cccccc;
117
+ cursor: not-allowed;
118
+ }
119
+
120
+ #sendButton:hover:not(:disabled) {
121
+ background-color: #507687;
122
+ }
123
+
124
+ /* Response area styling */
125
+ .response-output {
126
+ height: 300px;
127
+ overflow-y: auto;
128
+ background-color: #f8f9fa;
129
+ border: 1px solid #ddd;
130
+ border-radius: 4px;
131
+ padding: 1rem;
132
+ font-family: monospace;
133
+ white-space: pre-wrap;
134
+ }
135
+
136
+ /* Tabs styling */
137
+ .tabs {
138
+ display: flex;
139
+ border-bottom: 1px solid #ddd;
140
+ margin-bottom: 1rem;
141
+ }
142
+
143
+ .tab-button {
144
+ padding: 0.5rem 1rem;
145
+ background-color: #f1f1f1;
146
+ border: none;
147
+ cursor: pointer;
148
+ transition: background-color 0.3s;
149
+ font-size: 1rem;
150
+ }
151
+
152
+ .tab-button.active {
153
+ background-color: #384B70;
154
+ color: white;
155
+ }
156
+
157
+ .tab-content {
158
+ display: none;
159
+ height: 300px;
160
+ }
161
+
162
+ .tab-content.active {
163
+ display: block;
164
+ }
165
+
166
+ /* Visualization area styling */
167
+ #visualizationContainer {
168
+ position: relative;
169
+ height: 100%;
170
+ overflow: auto;
171
+ background-color: #f8f9fa;
172
+ border: 1px solid #ddd;
173
+ border-radius: 4px;
174
+ }
175
+
176
+ .detection-canvas {
177
+ display: block;
178
+ margin: 0 auto;
179
+ }
180
+
181
+ /* Utilities */
182
+ #loading {
183
+ display: none;
184
+ margin-top: 1rem;
185
+ color: #384B70;
186
+ font-weight: bold;
187
+ text-align: center;
188
+ }
189
+
190
+ #message {
191
+ margin-top: 1rem;
192
+ padding: 0.75rem;
193
+ border-radius: 4px;
194
+ text-align: center;
195
+ display: none;
196
+ }
197
+
198
+ .error {
199
+ background-color: #ffebee;
200
+ color: #d32f2f;
201
+ }
202
+
203
+ .success {
204
+ background-color: #e8f5e9;
205
+ color: #388e3c;
206
+ }
207
+
208
+ .info {
209
+ font-size: 0.9rem;
210
+ color: #666;
211
+ margin-top: 0.5rem;
212
+ }
213
+
214
+ .stats {
215
+ margin-top: 1rem;
216
+ font-size: 0.9rem;
217
+ color: #666;
218
+ }
219
+
220
+ /* Debug output */
221
+ #debugOutput {
222
+ margin-top: 0.5rem;
223
+ font-size: 0.8rem;
224
+ color: #999;
225
+ border-top: 1px dashed #ddd;
226
+ padding-top: 0.5rem;
227
+ display: none;
228
+ }
229
+ </style>
230
+ </head>
231
+ <body>
232
+ <!-- Interactive Demo Section -->
233
+ <section class="execution-section">
234
+ <h2 class="section-title">Try It Yourself</h2>
235
+ <p>Upload an image and see the object detection and depth estimation results in real-time.</p>
236
+
237
+ <div class="demo-container">
238
+ <!-- Upload Container -->
239
+ <div class="upload-container">
240
+ <h3 class="container-title">Upload Image</h3>
241
+
242
+ <div class="file-input-container">
243
+ <label for="fileInput" class="file-label">
244
+ <img src="https://upload.wikimedia.org/wikipedia/commons/a/a1/Icons8_flat_folder.svg" class="file-icon"/>
245
+ <span>Click to select image</span>
246
+ <p class="info">PNG or JPEG, max 2MB</p>
247
+ </label>
248
+ <input type="file" accept="image/*" id="fileInput" />
249
+ <img id="imagePreview" class="file-placeholder" alt="Image preview" />
250
+ </div>
251
+
252
+ <button id="sendButton" disabled>Process Image</button>
253
+ <div id="loading">Processing your image...</div>
254
+ <div id="message"></div>
255
+
256
+ <div class="stats">
257
+ <div id="imageSize"></div>
258
+ <div id="processingTime"></div>
259
+ </div>
260
+
261
+ <div id="debugOutput"></div>
262
+ </div>
263
+
264
+ <!-- Response Container with Tabs -->
265
+ <div class="response-container">
266
+ <h3 class="container-title">Response</h3>
267
+
268
+ <div class="tabs">
269
+ <button class="tab-button active" data-tab="raw">Raw Output</button>
270
+ <button class="tab-button" data-tab="visual">Visual Output</button>
271
+ </div>
272
+
273
+ <!-- Raw Output Tab -->
274
+ <div id="rawTab" class="tab-content active">
275
+ <pre class="response-output" id="responseOutput">// Response will appear here after processing</pre>
276
+ </div>
277
+
278
+ <!-- Visual Output Tab -->
279
+ <div id="visualTab" class="tab-content">
280
+ <div id="visualizationContainer">
281
+ <canvas id="detectionCanvas" class="detection-canvas"></canvas>
282
+ </div>
283
+ </div>
284
+ </div>
285
+ </div>
286
+ </section>
287
+
288
+ <script>
289
+ // DOM Elements
290
+ const fileInput = document.getElementById('fileInput');
291
+ const imagePreview = document.getElementById('imagePreview');
292
+ const sendButton = document.getElementById('sendButton');
293
+ const loading = document.getElementById('loading');
294
+ const message = document.getElementById('message');
295
+ const responseOutput = document.getElementById('responseOutput');
296
+ const imageSizeInfo = document.getElementById('imageSize');
297
+ const processingTimeInfo = document.getElementById('processingTime');
298
+ const tabButtons = document.querySelectorAll('.tab-button');
299
+ const tabContents = document.querySelectorAll('.tab-content');
300
+ const detectionCanvas = document.getElementById('detectionCanvas');
301
+ const ctx = detectionCanvas.getContext('2d');
302
+ const debugOutput = document.getElementById('debugOutput');
303
+
304
+ // Enable debug mode (set to false in production)
305
+ const DEBUG = true;
306
+
307
+ // API endpoint URL
308
+ const API_URL = '/api/predict';
309
+
310
+ let imageFile = null;
311
+ let startTime = null;
312
+ let originalImage = null;
313
+ let processingWidth = 0;
314
+ let processingHeight = 0;
315
+ let responseData = null;
316
+
317
+ // Tab switching functionality
318
+ tabButtons.forEach(button => {
319
+ button.addEventListener('click', () => {
320
+ const tabName = button.getAttribute('data-tab');
321
+
322
+ // Update button states
323
+ tabButtons.forEach(btn => btn.classList.remove('active'));
324
+ button.classList.add('active');
325
+
326
+ // Update tab content visibility
327
+ tabContents.forEach(content => content.classList.remove('active'));
328
+ document.getElementById(tabName + 'Tab').classList.add('active');
329
+
330
+ // If switching to visual tab and we have data, ensure visualization is rendered
331
+ if (tabName === 'visual' && responseData && originalImage) {
332
+ visualizeResults(originalImage, responseData);
333
+ }
334
+ });
335
+ });
336
+
337
+ // Handle file input change
338
+ fileInput.addEventListener('change', (event) => {
339
+ const file = event.target.files[0];
340
+
341
+ // Clear previous selections
342
+ imageFile = null;
343
+ imagePreview.style.display = 'none';
344
+ sendButton.disabled = true;
345
+ originalImage = null;
346
+ responseData = null;
347
+
348
+ // Validate file
349
+ if (!file) return;
350
+
351
+ if (file.size > 2 * 1024 * 1024) {
352
+ showMessage('File size exceeds 2MB limit.', 'error');
353
+ return;
354
+ }
355
+
356
+ if (!['image/png', 'image/jpeg'].includes(file.type)) {
357
+ showMessage('Only PNG and JPEG formats are supported.', 'error');
358
+ return;
359
+ }
360
+
361
+ // Store file for upload
362
+ imageFile = file;
363
+
364
+ // Show image preview
365
+ const reader = new FileReader();
366
+ reader.onload = (e) => {
367
+ const image = new Image();
368
+ image.src = e.target.result;
369
+
370
+ image.onload = () => {
371
+ // Store original image for visualization
372
+ originalImage = image;
373
+
374
+ // Set preview
375
+ imagePreview.src = e.target.result;
376
+ imagePreview.style.display = 'block';
377
+
378
+ // Update image info
379
+ imageSizeInfo.textContent = `Original size: ${image.width}x${image.height} pixels`;
380
+
381
+ // Calculate processing dimensions (for visualization)
382
+ calculateProcessingDimensions(image.width, image.height);
383
+
384
+ // Enable send button
385
+ sendButton.disabled = false;
386
+ showMessage('Image ready to process.', 'info');
387
+ };
388
+ };
389
+ reader.readAsDataURL(file);
390
+ });
391
+
392
+ // Calculate dimensions for processing visualization
393
+ function calculateProcessingDimensions(width, height) {
394
+ const maxWidth = 640;
395
+ const maxHeight = 320;
396
+
397
+ // Calculate dimensions
398
+ if (width > height) {
399
+ if (width > maxWidth) {
400
+ height = Math.round((height * maxWidth) / width);
401
+ width = maxWidth;
402
+ }
403
+ } else {
404
+ if (height > maxHeight) {
405
+ width = Math.round((width * maxHeight) / height);
406
+ height = maxHeight;
407
+ }
408
+ }
409
+
410
+ // Store processing dimensions for visualization
411
+ processingWidth = width;
412
+ processingHeight = height;
413
+ }
414
+
415
+ // Handle send button click
416
+ sendButton.addEventListener('click', async () => {
417
+ if (!imageFile) {
418
+ showMessage('No image selected.', 'error');
419
+ return;
420
+ }
421
+
422
+ // Clear previous response
423
+ responseOutput.textContent = "// Processing...";
424
+ clearCanvas();
425
+ responseData = null;
426
+ debugOutput.style.display = 'none';
427
+
428
+ // Show loading state
429
+ loading.style.display = 'block';
430
+ message.style.display = 'none';
431
+
432
+ // Reset processing time
433
+ processingTimeInfo.textContent = '';
434
+
435
+ // Record start time
436
+ startTime = performance.now();
437
+
438
+ // Create form data for HTTP request
439
+ const formData = new FormData();
440
+ formData.append('file', imageFile);
441
+
442
+ try {
443
+ // Send HTTP request
444
+ const response = await fetch(API_URL, {
445
+ method: 'POST',
446
+ body: formData
447
+ });
448
+
449
+ // Handle response
450
+ if (!response.ok) {
451
+ const errorText = await response.text();
452
+ throw new Error(`HTTP error ${response.status}: ${errorText}`);
453
+ }
454
+
455
+ // Parse JSON response
456
+ const data = await response.json();
457
+ responseData = data;
458
+
459
+ // Calculate processing time
460
+ const endTime = performance.now();
461
+ const timeTaken = endTime - startTime;
462
+
463
+ // Format and display raw response
464
+ responseOutput.textContent = JSON.stringify(data, null, 2);
465
+ processingTimeInfo.textContent = `Processing time: ${timeTaken.toFixed(2)} ms`;
466
+
467
+ // Visualize the results
468
+ if (originalImage) {
469
+ visualizeResults(originalImage, data);
470
+ }
471
+
472
+ // Show success message
473
+ showMessage('Image processed successfully!', 'success');
474
+ } catch (error) {
475
+ console.error('Error processing image:', error);
476
+ showMessage(`Error: ${error.message}`, 'error');
477
+ responseOutput.textContent = `// Error: ${error.message}`;
478
+
479
+ if (DEBUG) {
480
+ debugOutput.style.display = 'block';
481
+ debugOutput.textContent = `Error: ${error.message}\n${error.stack || ''}`;
482
+ }
483
+ } finally {
484
+ loading.style.display = 'none';
485
+ }
486
+ });
487
+
488
+ // Visualize detection results
489
+ function visualizeResults(image, data) {
490
+ try {
491
+ // Set canvas dimensions
492
+ detectionCanvas.width = processingWidth;
493
+ detectionCanvas.height = processingHeight;
494
+
495
+ // Draw the original image
496
+ ctx.drawImage(image, 0, 0, processingWidth, processingHeight);
497
+
498
+ // Set styles for bounding boxes
499
+ ctx.lineWidth = 3;
500
+ ctx.font = 'bold 14px Arial';
501
+
502
+ // Find detections (checking all common formats)
503
+ let detections = [];
504
+ let detectionSource = '';
505
+
506
+ if (data.detections && Array.isArray(data.detections)) {
507
+ detections = data.detections;
508
+ detectionSource = 'detections';
509
+ } else if (data.predictions && Array.isArray(data.predictions)) {
510
+ detections = data.predictions;
511
+ detectionSource = 'predictions';
512
+ } else if (data.objects && Array.isArray(data.objects)) {
513
+ detections = data.objects;
514
+ detectionSource = 'objects';
515
+ } else if (data.results && Array.isArray(data.results)) {
516
+ detections = data.results;
517
+ detectionSource = 'results';
518
+ } else {
519
+ // Try to look one level deeper if no detections found
520
+ for (const key in data) {
521
+ if (typeof data[key] === 'object' && data[key] !== null) {
522
+ if (Array.isArray(data[key])) {
523
+ detections = data[key];
524
+ detectionSource = key;
525
+ break;
526
+ } else {
527
+ // Look one more level down
528
+ for (const subKey in data[key]) {
529
+ if (Array.isArray(data[key][subKey])) {
530
+ detections = data[key][subKey];
531
+ detectionSource = `${key}.${subKey}`;
532
+ break;
533
+ }
534
+ }
535
+ }
536
+ }
537
+ }
538
+ }
539
+
540
+ // Process each detection
541
+ detections.forEach((detection, index) => {
542
+ // Try to extract bounding box information
543
+ let bbox = null;
544
+ let label = null;
545
+ let confidence = null;
546
+ let distance = null;
547
+
548
+ // Extract label/class
549
+ if (detection.class !== undefined) {
550
+ label = detection.class;
551
+ } else {
552
+ // Fallback to other common property names
553
+ for (const key of ['label', 'name', 'category', 'className']) {
554
+ if (detection[key] !== undefined) {
555
+ label = detection[key];
556
+ break;
557
+ }
558
+ }
559
+ }
560
+
561
+ // Default label if none found
562
+ if (!label) label = `Object ${index + 1}`;
563
+
564
+ // Extract confidence score if available
565
+ for (const key of ['confidence', 'score', 'probability', 'conf']) {
566
+ if (detection[key] !== undefined) {
567
+ confidence = detection[key];
568
+ break;
569
+ }
570
+ }
571
+
572
+ // Extract distance - specifically look for distance_estimated first
573
+ if (detection.distance_estimated !== undefined) {
574
+ distance = detection.distance_estimated;
575
+ } else {
576
+ // Fallback to other common distance properties
577
+ for (const key of ['distance', 'depth', 'z', 'dist', 'range']) {
578
+ if (detection[key] !== undefined) {
579
+ distance = detection[key];
580
+ break;
581
+ }
582
+ }
583
+ }
584
+
585
+ // Look for bounding box in features
586
+ if (detection.features &&
587
+ detection.features.xmin !== undefined &&
588
+ detection.features.ymin !== undefined &&
589
+ detection.features.xmax !== undefined &&
590
+ detection.features.ymax !== undefined) {
591
+
592
+ bbox = {
593
+ xmin: detection.features.xmin,
594
+ ymin: detection.features.ymin,
595
+ xmax: detection.features.xmax,
596
+ ymax: detection.features.ymax
597
+ };
598
+ } else {
599
+ // Recursively search for bbox-like properties
600
+ function findBBox(obj, path = '') {
601
+ if (!obj || typeof obj !== 'object') return null;
602
+
603
+ // Check if this object looks like a bbox
604
+ if ((obj.x !== undefined && obj.y !== undefined &&
605
+ (obj.width !== undefined || obj.w !== undefined ||
606
+ obj.height !== undefined || obj.h !== undefined)) ||
607
+ (obj.xmin !== undefined && obj.ymin !== undefined &&
608
+ obj.xmax !== undefined && obj.ymax !== undefined)) {
609
+ return obj;
610
+ }
611
+
612
+ // Check if it's an array of 4 numbers (potential bbox)
613
+ if (Array.isArray(obj) && obj.length === 4 &&
614
+ obj.every(item => typeof item === 'number')) {
615
+ return obj;
616
+ }
617
+
618
+ // Check common bbox property names
619
+ for (const key of ['bbox', 'box', 'bounding_box', 'boundingBox']) {
620
+ if (obj[key] !== undefined) {
621
+ return obj[key];
622
+ }
623
+ }
624
+
625
+ // Search nested properties
626
+ for (const key in obj) {
627
+ const result = findBBox(obj[key], path ? `${path}.${key}` : key);
628
+ if (result) return result;
629
+ }
630
+
631
+ return null;
632
+ }
633
+
634
+ // Find bbox using recursive search as fallback
635
+ bbox = findBBox(detection);
636
+ }
637
+
638
+ // If we found a bounding box, draw it
639
+ if (bbox) {
640
+ // Parse different bbox formats
641
+ let x, y, width, height;
642
+
643
+ if (Array.isArray(bbox)) {
644
+ // Try to determine array format
645
+ if (bbox.length === 4) {
646
+ if (bbox[0] >= 0 && bbox[1] >= 0 && bbox[2] <= 1 && bbox[3] <= 1) {
647
+ // Likely normalized [x1, y1, x2, y2]
648
+ x = bbox[0] * processingWidth;
649
+ y = bbox[1] * processingHeight;
650
+ width = (bbox[2] - bbox[0]) * processingWidth;
651
+ height = (bbox[3] - bbox[1]) * processingHeight;
652
+ } else if (bbox[2] > bbox[0] && bbox[3] > bbox[1]) {
653
+ // Likely [x1, y1, x2, y2]
654
+ x = bbox[0];
655
+ y = bbox[1];
656
+ width = bbox[2] - bbox[0];
657
+ height = bbox[3] - bbox[1];
658
+ } else {
659
+ // Assume [x, y, width, height]
660
+ x = bbox[0];
661
+ y = bbox[1];
662
+ width = bbox[2];
663
+ height = bbox[3];
664
+ }
665
+ }
666
+ } else {
667
+ // Object format with named properties
668
+ if (bbox.x !== undefined && bbox.y !== undefined) {
669
+ x = bbox.x;
670
+ y = bbox.y;
671
+ width = bbox.width || bbox.w || 0;
672
+ height = bbox.height || bbox.h || 0;
673
+ } else if (bbox.xmin !== undefined && bbox.ymin !== undefined) {
674
+ x = bbox.xmin;
675
+ y = bbox.ymin;
676
+ width = (bbox.xmax || 0) - bbox.xmin;
677
+ height = (bbox.ymax || 0) - bbox.ymin;
678
+ }
679
+ }
680
+
681
+ // Validate coordinates
682
+ if (x === undefined || y === undefined || width === undefined || height === undefined) {
683
+ return;
684
+ }
685
+
686
+ // Check if we need to scale normalized coordinates (0-1)
687
+ if (x >= 0 && x <= 1 && y >= 0 && y <= 1 && width >= 0 && width <= 1 && height >= 0 && height <= 1) {
688
+ x = x * processingWidth;
689
+ y = y * processingHeight;
690
+ width = width * processingWidth;
691
+ height = height * processingHeight;
692
+ }
693
+
694
+ // Generate a color based on the class name
695
+ const hue = stringToHue(label);
696
+ ctx.strokeStyle = `hsl(${hue}, 100%, 40%)`;
697
+ ctx.fillStyle = `hsla(${hue}, 100%, 40%, 0.3)`;
698
+
699
+ // Draw bounding box
700
+ ctx.beginPath();
701
+ ctx.rect(x, y, width, height);
702
+ ctx.stroke();
703
+ ctx.fill();
704
+
705
+ // Format confidence value
706
+ let confidenceText = "";
707
+ if (confidence !== null && confidence !== undefined) {
708
+ // Convert to percentage if it's a probability (0-1)
709
+ if (confidence <= 1) {
710
+ confidence = (confidence * 100).toFixed(0);
711
+ } else {
712
+ confidence = confidence.toFixed(0);
713
+ }
714
+ confidenceText = ` ${confidence}%`;
715
+ }
716
+
717
+ // Format distance value
718
+ let distanceText = "";
719
+ if (distance !== null && distance !== undefined) {
720
+ distanceText = ` : ${distance.toFixed(2)} m`;
721
+ }
722
+
723
+ // Create label text
724
+ const labelText = `${label}${confidenceText}${distanceText}`;
725
+
726
+ // Measure text width
727
+ const textWidth = ctx.measureText(labelText).width + 10;
728
+
729
+ // Draw label background
730
+ ctx.fillStyle = `hsl(${hue}, 100%, 40%)`;
731
+ ctx.fillRect(x, y - 20, textWidth, 20);
732
+
733
+ // Draw label text
734
+ ctx.fillStyle = "white";
735
+ ctx.fillText(labelText, x + 5, y - 5);
736
+ }
737
+ });
738
+
739
+ } catch (error) {
740
+ console.error('Error visualizing results:', error);
741
+ debugOutput.style.display = 'block';
742
+ debugOutput.textContent += `VISUALIZATION ERROR: ${error.message}\n`;
743
+ debugOutput.textContent += `Error stack: ${error.stack}\n`;
744
+ }
745
+ }
746
+
747
+ // Generate consistent hue for string
748
+ function stringToHue(str) {
749
+ let hash = 0;
750
+ for (let i = 0; i < str.length; i++) {
751
+ hash = str.charCodeAt(i) + ((hash << 5) - hash);
752
+ }
753
+ return hash % 360;
754
+ }
755
+
756
+ // Clear canvas
757
+ function clearCanvas() {
758
+ if (detectionCanvas.getContext) {
759
+ ctx.clearRect(0, 0, detectionCanvas.width, detectionCanvas.height);
760
+ }
761
+ }
762
+
763
+ // Show message function
764
+ function showMessage(text, type) {
765
+ message.textContent = text;
766
+ message.className = '';
767
+ message.classList.add(type);
768
+ message.style.display = 'block';
769
+
770
+ if (type === 'info') {
771
+ setTimeout(() => {
772
+ message.style.display = 'none';
773
+ }, 3000);
774
+ }
775
+ }
776
+ </script>
777
+ </body>
778
+ </html>
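For reference, the same endpoint can be exercised outside the browser demo. The sketch below is illustrative, not part of the commit: it assumes the service is reachable at http://localhost:8000, and it reuses the /api/predict route and the multipart field name "file" from the script above; the response layout ("objects" with class, distance_estimated and features) follows utils/JSON_output.py further down.

    import requests  # third-party HTTP client

    # Hypothetical host and image path; adjust to wherever the container is actually served.
    url = "http://localhost:8000/api/predict"

    with open("street_scene.jpg", "rb") as f:
        # The demo page posts the image as multipart form data under the field name "file".
        response = requests.post(url, files={"file": ("street_scene.jpg", f, "image/jpeg")})

    response.raise_for_status()
    for obj in response.json().get("objects", []):
        print(obj["class"], obj["distance_estimated"], obj["features"]["xmin"])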
utils/JSON_output.py ADDED
@@ -0,0 +1,43 @@
1
+ from models.predict_z_location_single_row_lstm import predict_z_location_single_row
2
+
3
+ def generate_output_json(data, ZlocE, scaler):
4
+ """
5
+ Predict Z-location for each object in the data and prepare the JSON output.
6
+
7
+ Parameters:
8
+ - data: DataFrame with bounding box coordinates, depth information, and class type.
9
+ - ZlocE: Pre-loaded LSTM model for Z-location prediction.
10
+ - scaler: Scaler for normalizing input data.
11
+
12
+ Returns:
13
+ - JSON structure with object class, distance estimated, and relevant features.
14
+ """
15
+ output_json = []
16
+
17
+ # Iterate over each row in the data
18
+ for i, row in data.iterrows():
19
+ # Predict distance for each object using the single-row prediction function
20
+ distance = predict_z_location_single_row(row, ZlocE, scaler)
21
+
22
+ # Create object info dictionary
23
+ object_info = {
24
+ "class": row["class"], # Object class (e.g., 'car', 'truck')
25
+ "distance_estimated": float(distance), # Convert distance to float (if necessary)
26
+ "features": {
27
+ "xmin": float(row["xmin"]), # Bounding box xmin
28
+ "ymin": float(row["ymin"]), # Bounding box ymin
29
+ "xmax": float(row["xmax"]), # Bounding box xmax
30
+ "ymax": float(row["ymax"]), # Bounding box ymax
31
+ "mean_depth": float(row["depth_mean"]), # Depth mean
32
+ "depth_mean_trim": float(row["depth_mean_trim"]), # Depth mean trim
33
+ "depth_median": float(row["depth_median"]), # Depth median
34
+ "width": float(row["width"]), # Object width
35
+ "height": float(row["height"]) # Object height
36
+ }
37
+ }
38
+
39
+ # Append each object info to the output JSON list
40
+ output_json.append(object_info)
41
+
42
+ # Return the final JSON output structure
43
+ return {"objects": output_json}
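For orientation, this is roughly the structure the helper returns for a single detected object (the numbers are invented; the model and scaler arguments are whatever the caller has already loaded):

    # Illustrative return value of generate_output_json (values made up):
    {
        "objects": [
            {
                "class": "car",
                "distance_estimated": 12.47,
                "features": {
                    "xmin": 104.0, "ymin": 212.0, "xmax": 298.0, "ymax": 330.0,
                    "mean_depth": 11.9, "depth_mean_trim": 12.1, "depth_median": 12.0,
                    "width": 194.0, "height": 118.0
                }
            }
        ]
    }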
utils/__init__.py ADDED
File without changes
utils/__pycache__/JSON_output.cpython-311.pyc ADDED
Binary file (1.93 kB). View file
 
utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (183 Bytes). View file
 
utils/__pycache__/processing.cpython-311.pyc ADDED
Binary file (3.25 kB). View file
 
utils/build/lib.win-amd64-cpython-311/processing_cy.cp311-win_amd64.pyd ADDED
Binary file (71.7 kB). View file
 
utils/build/temp.win-amd64-cpython-311/Release/processing.obj ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:651a39060235c15b7fecf63a9e462ae0c379f5a194a745110cf259131fc235a2
3
+ size 539061
utils/build/temp.win-amd64-cpython-311/Release/processing_cy.cp311-win_amd64.exp ADDED
Binary file (806 Bytes). View file
 
utils/build/temp.win-amd64-cpython-311/Release/processing_cy.cp311-win_amd64.lib ADDED
Binary file (2.12 kB). View file
 
utils/processing.c ADDED
The diff for this file is too large to render. See raw diff
 
utils/processing.html ADDED
The diff for this file is too large to render. See raw diff
 
utils/processing.py ADDED
@@ -0,0 +1,59 @@
1
+ """
2
+ Created on Sat Apr 9 04:08:02 2022
3
+ @author: Admin_with ODD Team
4
+
5
+ Edited by our team: Sat Oct 12, 2024
6
+
7
+ references: https://github.com/vinvino02/GLPDepth
8
+ """
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ import torch
13
+
14
+ from utils.processing_cy import process_bbox_depth_cy, handle_overlaps_cy
15
+
16
+ class PROCESSING:
17
+ def process_detections(self, scores, boxes, depth_map, detr):
18
+ self.data = pd.DataFrame(columns=['xmin','ymin','xmax','ymax','width', 'height',
19
+ 'depth_mean_trim','depth_mean','depth_median',
20
+ 'class', 'rgb'])
21
+
22
+ boxes_array = np.array([[int(box[1]), int(box[0]), int(box[3]), int(box[2])]
23
+ for box in boxes.tolist()], dtype=np.int32)
24
+
25
+ # Use Cython-optimized overlap handling
26
+ valid_indices = handle_overlaps_cy(depth_map, boxes_array)
27
+
28
+ for idx in valid_indices:
29
+ p = scores[idx]
30
+ box = boxes[idx]
31
+ xmin, ymin, xmax, ymax = map(int, box)
32
+
33
+ detected_class = p.argmax()
34
+ class_label = detr.CLASSES[detected_class]
35
+
36
+ # Map classes
37
+ if class_label == 'motorcycle':
38
+ class_label = 'bicycle'
39
+ elif class_label == 'bus':
40
+ class_label = 'train'
41
+ elif class_label not in ['person', 'truck', 'car', 'bicycle', 'train']:
42
+ class_label = 'Misc'
43
+
44
+ if class_label in ['Misc', 'person', 'truck', 'car', 'bicycle', 'train']:
45
+ # Use Cython-optimized depth calculations
46
+ depth_mean, depth_median, (depth_trim_low, depth_trim_high) = \
47
+ process_bbox_depth_cy(depth_map, ymin, ymax, xmin, xmax)
48
+
49
+ class_index = ['Misc', 'person', 'truck', 'car', 'bicycle', 'train'].index(class_label)
50
+ r, g, b = detr.COLORS[class_index]
51
+ rgb = (r * 255, g * 255, b * 255)
52
+
53
+ new_row = pd.DataFrame([[xmin, ymin, xmax, ymax, xmax - xmin, ymax - ymin,
54
+ (depth_trim_low + depth_trim_high) / 2,
55
+ depth_mean, depth_median, class_label, rgb]],
56
+ columns=self.data.columns)
57
+ self.data = pd.concat([self.data, new_row], ignore_index=True)
58
+
59
+ return self.data
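A minimal way to drive this class end to end, assuming the Cython extension has been built in place, using synthetic inputs and a stand-in detector object (only its CLASSES and COLORS attributes are read here) — a sketch, not code from the commit:

    import numpy as np
    import torch
    from utils.processing import PROCESSING

    class FakeDetr:
        # Stand-in for the real detector wrapper; only CLASSES and COLORS are used.
        CLASSES = ["person", "car", "truck", "bicycle", "train", "motorcycle", "bus"]
        COLORS = [(0.1 * i, 0.2, 0.5) for i in range(6)]

    depth_map = (np.random.rand(320, 640) * 40.0).astype(np.float32)   # synthetic metric depth
    scores = torch.tensor([[0.05, 0.90, 0.05, 0.0, 0.0, 0.0, 0.0]])    # one confident "car"
    boxes = torch.tensor([[100.0, 120.0, 300.0, 260.0]])               # [xmin, ymin, xmax, ymax]

    df = PROCESSING().process_detections(scores, boxes, depth_map, FakeDetr())
    print(df[["class", "depth_mean", "depth_median", "depth_mean_trim"]])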
utils/processing.pyx ADDED
@@ -0,0 +1,101 @@
1
+ # File: processing.pyx
2
+ # cython: language_level=3, boundscheck=False, wraparound=False, nonecheck=False, cdivision=True
3
+
4
+ import numpy as np
5
+ cimport numpy as np
6
+ import pandas as pd
7
+ # NOTE: pandas ships no .pxd, so it cannot be cimported; the Python-level import above suffices
8
+ from libc.math cimport isnan
9
+ from cpython.mem cimport PyMem_Malloc, PyMem_Free
10
+
11
+ # Define C types for better performance
12
+ ctypedef np.float32_t DTYPE_t
13
+ ctypedef np.int32_t ITYPE_t
14
+
15
+ def process_bbox_depth_cy(np.ndarray[DTYPE_t, ndim=2] depth_map,
16
+ int y_min, int y_max, int x_min, int x_max):
17
+ """
18
+ Optimized bbox depth calculations using Cython
19
+ """
20
+ cdef:
21
+ int i, j, count = 0
22
+ double sum_val = 0.0
23
+ double mean_val = 0.0
24
+ np.ndarray[DTYPE_t, ndim=1] flat_vals
25
+ int flat_size = 0
26
+
27
+ for i in range(y_min, y_max):
28
+ for j in range(x_min, x_max):
29
+ if not isnan(depth_map[i, j]):
30
+ sum_val += depth_map[i, j]
31
+ count += 1
32
+
33
+ if count > 0:
34
+ mean_val = sum_val / count
35
+
36
+ # Create array for trimmed mean calculation
37
+ flat_vals = np.zeros(count, dtype=np.float32)
38
+ flat_size = 0
39
+
40
+ for i in range(y_min, y_max):
41
+ for j in range(x_min, x_max):
42
+ if not isnan(depth_map[i, j]):
43
+ flat_vals[flat_size] = depth_map[i, j]
44
+ flat_size += 1
45
+
46
+ return mean_val, np.median(flat_vals), np.percentile(flat_vals, [20, 80])
47
+
48
+ def handle_overlaps_cy(np.ndarray[DTYPE_t, ndim=2] depth_map,
49
+ np.ndarray[ITYPE_t, ndim=2] boxes):
50
+ """
51
+ Optimized overlap handling using Cython
52
+ """
53
+ cdef:
54
+ int n_boxes = boxes.shape[0]
55
+ int i, j
56
+ int y_min1, y_max1, x_min1, x_max1
57
+ int y_min2, y_max2, x_min2, x_max2
58
+ double area1, area2, area_intersection
59
+ bint* to_remove = <bint*>PyMem_Malloc(n_boxes * sizeof(bint))
60
+
61
+ if not to_remove:
62
+ raise MemoryError()
63
+
64
+ try:
65
+ for i in range(n_boxes):
66
+ to_remove[i] = False
67
+
68
+ for i in range(n_boxes):
69
+ if to_remove[i]:
70
+ continue
71
+
72
+ y_min1, x_min1, y_max1, x_max1 = boxes[i]
73
+
74
+ for j in range(i + 1, n_boxes):
75
+ if to_remove[j]:
76
+ continue
77
+
78
+ y_min2, x_min2, y_max2, x_max2 = boxes[j]
79
+
80
+ # Calculate intersection
81
+ y_min_int = max(y_min1, y_min2)
82
+ y_max_int = min(y_max1, y_max2)
83
+ x_min_int = max(x_min1, x_min2)
84
+ x_max_int = min(x_max1, x_max2)
85
+
86
+ if y_min_int < y_max_int and x_min_int < x_max_int:
87
+ area1 = (y_max1 - y_min1) * (x_max1 - x_min1)
88
+ area2 = (y_max2 - y_min2) * (x_max2 - x_min2)
89
+ area_intersection = (y_max_int - y_min_int) * (x_max_int - x_min_int)
90
+
91
+ if area_intersection / min(area1, area2) >= 0.70:
92
+ if area1 < area2:
93
+ to_remove[i] = True
94
+ break
95
+ else:
96
+ to_remove[j] = True
97
+
98
+ return np.array([i for i in range(n_boxes) if not to_remove[i]], dtype=np.int32)
99
+
100
+ finally:
101
+ PyMem_Free(to_remove)
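For sanity-checking the compiled extension, the same suppression rule can be written in plain Python/NumPy. Note that the depth_map argument of handle_overlaps_cy is not consulted by the overlap decision above, so this reference (a sketch, not a module in the commit) takes only the box array of [ymin, xmin, ymax, xmax] rows:

    import numpy as np

    def handle_overlaps_reference(boxes):
        # boxes: (N, 4) integer array of [ymin, xmin, ymax, xmax]; returns the kept indices.
        n = len(boxes)
        removed = [False] * n
        for i in range(n):
            if removed[i]:
                continue
            y1a, x1a, y2a, x2a = boxes[i]
            for j in range(i + 1, n):
                if removed[j]:
                    continue
                y1b, x1b, y2b, x2b = boxes[j]
                inter_h = min(y2a, y2b) - max(y1a, y1b)
                inter_w = min(x2a, x2b) - max(x1a, x1b)
                if inter_h > 0 and inter_w > 0:
                    area_a = (y2a - y1a) * (x2a - x1a)
                    area_b = (y2b - y1b) * (x2b - x1b)
                    # Drop the smaller box when >= 70% of it is covered by the other box.
                    if inter_h * inter_w / min(area_a, area_b) >= 0.70:
                        if area_a < area_b:
                            removed[i] = True
                            break
                        removed[j] = True
        return np.array([k for k in range(n) if not removed[k]], dtype=np.int32)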
utils/processing_cy.cp311-win_amd64.pyd ADDED
Binary file (71.7 kB). View file
 
utils/setup.py ADDED
@@ -0,0 +1,18 @@
1
+ from setuptools import setup, Extension
2
+ from Cython.Build import cythonize
3
+ import numpy as np
4
+
5
+ extensions = [
6
+ Extension(
7
+ "processing_cy",
8
+ ["processing.pyx"],
9
+ include_dirs=[np.get_include()],
10
+ extra_compile_args=["-O3", "-march=native", "-fopenmp"],
11
+ extra_link_args=["-fopenmp"]
12
+ )
13
+ ]
14
+
15
+ setup(
16
+ ext_modules=cythonize(extensions, annotate=True),
17
+ include_dirs=[np.get_include()]
18
+ )
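With this script in place, the extension is conventionally compiled in place by running python setup.py build_ext --inplace from the utils directory, which is what produces the processing_cy binary imported by utils/processing.py. The -O3, -march=native and -fopenmp flags are GCC/Clang-style options; an MSVC build on Windows typically warns about such unknown options and ignores them rather than failing.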
utils/visualization.py ADDED
@@ -0,0 +1,42 @@
1
+ import cv2
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import matplotlib.patches as patches
5
+
6
+
7
+ def plot_depth_with_boxes(depth_map, depth_data):
8
+ """
9
+ Plots the depth map with bounding boxes overlayed.
10
+
11
+ Args:
12
+ depth_map (numpy.ndarray): The depth map to visualize.
13
+ depth_data (pandas.DataFrame): DataFrame containing bounding box coordinates, depth statistics, and class labels.
14
+ """
15
+ # Normalize the depth map for better visualization
16
+ depth_map_normalized = cv2.normalize(depth_map, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
17
+
18
+ # Create a figure and axis
19
+ fig, ax = plt.subplots(1, figsize=(12, 6))
20
+
21
+ # Display the depth map
22
+ ax.imshow(depth_map_normalized, cmap='plasma') # You can change the colormap as desired
23
+ ax.axis('off') # Hide the axes
24
+
25
+ # Loop through the DataFrame and add rectangles
26
+ for index, row in depth_data.iterrows():
27
+ xmin, ymin, xmax, ymax = row[['xmin', 'ymin', 'xmax', 'ymax']]
28
+ class_label = row['class']
29
+ score = row['depth_mean'] # or whichever statistic you prefer to display
30
+
31
+ # Create a rectangle patch
32
+ rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, linewidth=2, edgecolor='yellow', facecolor='none')
33
+
34
+ # Add the rectangle to the plot
35
+ ax.add_patch(rect)
36
+
37
+ # Add a text label
38
+ ax.text(xmin, ymin - 5, f'{class_label}: {score:.2f}', color='white', fontsize=12, weight='bold')
39
+
40
+ plt.title('Depth Map with Object Detection Bounding Boxes', fontsize=16)
41
+ plt.show()
42
+
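A quick, self-contained way to exercise this plotting helper with synthetic data (illustrative only; the values are made up):

    import numpy as np
    import pandas as pd
    from utils.visualization import plot_depth_with_boxes

    # Synthetic depth map plus one fake detection, just to exercise the overlay.
    depth_map = (np.random.rand(320, 640) * 50.0).astype(np.float32)
    depth_data = pd.DataFrame([{
        "xmin": 100, "ymin": 120, "xmax": 300, "ymax": 260,
        "class": "car", "depth_mean": 14.2,
    }])
    plot_depth_with_boxes(depth_map, depth_data)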