Upload 35 files
- .gitattributes +1 -0
- Dockerfile +36 -0
- api_documentation.html +612 -0
- app.py +379 -0
- config.py +12 -0
- data/models/lstm_scaler.pkl +3 -0
- data/models/pretrained_lstm.pth +3 -0
- models/__init__.py +0 -0
- models/__pycache__/__init__.cpython-311.pyc +0 -0
- models/__pycache__/detr_model.cpython-311.pyc +0 -0
- models/__pycache__/distance_prediction.cpython-311.pyc +0 -0
- models/__pycache__/glpn_model.cpython-311.pyc +0 -0
- models/__pycache__/lstm_model.cpython-311.pyc +0 -0
- models/__pycache__/predict_z_location_single_row_lstm.cpython-311.pyc +0 -0
- models/detr_model.py +117 -0
- models/glpn_model.py +51 -0
- models/lstm_model.py +78 -0
- models/predict_z_location_single_row_lstm.py +48 -0
- requirements.txt +20 -0
- try_page.html +778 -0
- utils/JSON_output.py +43 -0
- utils/__init__.py +0 -0
- utils/__pycache__/JSON_output.cpython-311.pyc +0 -0
- utils/__pycache__/__init__.cpython-311.pyc +0 -0
- utils/__pycache__/processing.cpython-311.pyc +0 -0
- utils/build/lib.win-amd64-cpython-311/processing_cy.cp311-win_amd64.pyd +0 -0
- utils/build/temp.win-amd64-cpython-311/Release/processing.obj +3 -0
- utils/build/temp.win-amd64-cpython-311/Release/processing_cy.cp311-win_amd64.exp +0 -0
- utils/build/temp.win-amd64-cpython-311/Release/processing_cy.cp311-win_amd64.lib +0 -0
- utils/processing.c +0 -0
- utils/processing.html +0 -0
- utils/processing.py +59 -0
- utils/processing.pyx +101 -0
- utils/processing_cy.cp311-win_amd64.pyd +0 -0
- utils/setup.py +18 -0
- utils/visualization.py +42 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+utils/build/temp.win-amd64-cpython-311/Release/processing.obj filter=lfs diff=lfs merge=lfs -text

Dockerfile
ADDED
@@ -0,0 +1,36 @@
# Use a specific Python version
FROM python:3.11.10-slim

# Set the working directory inside the container
WORKDIR /usr/src/app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    build-essential \
    libgl1-mesa-glx \
    libglib2.0-0 \
    gcc \
    python3-dev \
    && rm -rf /var/lib/apt/lists/*

# Upgrade pip
RUN pip install --upgrade pip

# Install Python dependencies
COPY requirements.txt .
RUN pip install -r requirements.txt

# Copy your application code into the container
COPY . .

# Build Cython extensions
RUN cd utils && python setup.py build_ext --inplace

# Create necessary directories if they don't exist
RUN mkdir -p data/models

# Expose the port for FastAPI application
EXPOSE 8000

# Specify the command to run the application with Gunicorn
CMD ["gunicorn", "app:app", "--workers", "4", "--worker-class", "uvicorn.workers.UvicornWorker", "--bind", "0.0.0.0:8000"]

api_documentation.html
ADDED
@@ -0,0 +1,612 @@
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>Real-Time Image Processing API Documentation</title>
  <script src="https://cdn.tailwindcss.com"></script>
  <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/swagger-ui/4.18.3/swagger-ui.min.css">
  <style>
    body {
      font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif;
      margin: 0;
      padding: 0;
      color: #333;
    }
    .container {
      max-width: 1200px;
      margin: 0 auto;
      padding: 20px;
    }
    header {
      background-color: #252f3f;
      color: white;
      padding: 20px;
      box-shadow: 0 2px 4px rgba(0,0,0,0.1);
    }
    header h1 {
      margin: 0;
      font-size: 2em;
    }
    header p {
      margin-top: 10px;
      opacity: 0.8;
    }
    .endpoint-cards {
      display: flex;
      flex-wrap: wrap;
      gap: 20px;
      margin-top: 30px;
    }
    .endpoint-card {
      background-color: white;
      border-radius: 8px;
      box-shadow: 0 2px 10px rgba(0,0,0,0.1);
      padding: 20px;
      flex: 1 1 300px;
      transition: transform 0.2s, box-shadow 0.2s;
    }
    .endpoint-card:hover {
      transform: translateY(-5px);
      box-shadow: 0 5px 15px rgba(0,0,0,0.15);
    }
    .http-endpoint {
      border-left: 5px solid #38a169;
    }
    .ws-endpoint {
      border-left: 5px solid #3182ce;
    }
    .method {
      display: inline-block;
      padding: 5px 10px;
      border-radius: 4px;
      font-weight: bold;
      margin-right: 10px;
    }
    .post {
      background-color: #38a169;
      color: white;
    }
    .ws {
      background-color: #3182ce;
      color: white;
    }
    .endpoint-title {
      font-size: 1.2em;
      margin-bottom: 15px;
      display: flex;
      align-items: center;
    }
    .section {
      margin-top: 40px;
      margin-bottom: 30px;
    }
    h2 {
      border-bottom: 1px solid #eee;
      padding-bottom: 10px;
      color: #252f3f;
    }
    .code {
      background-color: #f7fafc;
      border-radius: 4px;
      padding: 15px;
      font-family: monospace;
      overflow-x: auto;
      margin: 15px 0;
    }
    .parameter-table, .response-table {
      width: 100%;
      border-collapse: collapse;
      margin: 15px 0;
    }
    .parameter-table th, .parameter-table td,
    .response-table th, .response-table td {
      text-align: left;
      padding: 10px;
      border-bottom: 1px solid #eee;
    }
    .parameter-table th, .response-table th {
      background-color: #f7fafc;
    }
    .try-it {
      margin-top: 30px;
      padding-top: 30px;
      border-top: 1px solid #eee;
    }
    .swagger-ui .wrapper { max-width: 100%; }
    #swagger-ui {
      margin-top: 30px;
      border: 1px solid #eee;
      border-radius: 8px;
      overflow: hidden;
    }
    .info-box {
      background-color: #ebf8ff;
      border-left: 5px solid #3182ce;
      padding: 15px;
      margin: 20px 0;
      border-radius: 4px;
    }
    .response-example {
      background-color: #f0fff4;
      border-left: 5px solid #38a169;
      padding: 15px;
      border-radius: 4px;
      margin-top: 20px;
    }

    /* Updated tabs styles */
    .tabs {
      display: flex;
      margin-top: 30px;
      border-bottom: 1px solid #e2e8f0;
      align-items: center;
    }
    .tab {
      padding: 10px 20px;
      cursor: pointer;
      border-bottom: 2px solid transparent;
      transition: all 0.2s ease;
    }
    .tab.active {
      border-bottom: 2px solid #3182ce;
      color: #3182ce;
      font-weight: bold;
    }

    /* New try-button styles */
    .try-button {
      margin-left: auto;
      display: flex;
      align-items: center;
      padding: 10px 15px;
      font-size: 14px;
      font-weight: bold;
      border-radius: 4px;
      background-color: #4299e1;
      color: white;
      text-decoration: none;
      border: none;
      transition: background-color 0.2s;
    }
    .try-button:hover {
      background-color: #3182ce;
    }
    .try-button svg {
      margin-left: 8px;
      height: 16px;
      width: 16px;
    }

    .tab-content {
      display: none;
      padding: 20px 0;
    }
    .tab-content.active {
      display: block;
    }
    button {
      background-color: #4299e1;
      color: white;
      border: none;
      padding: 10px 15px;
      border-radius: 4px;
      cursor: pointer;
      font-weight: bold;
      transition: background-color 0.2s;
    }
    button:hover {
      background-color: #3182ce;
    }
  </style>
</head>
<body>
  <header>
    <div class="container">
      <h1>Advanced Driver Vision Assistance with Near Collision Estimation System</h1>
      <p>API for object detection, depth estimation, and distance prediction using computer vision models</p>
    </div>
  </header>

  <div class="container">
    <section class="section">
      <h2>Overview</h2>
      <p>This API provides access to advanced computer vision models for real-time image processing. It leverages:</p>
      <ul>
        <li><strong>DETR (DEtection TRansformer)</strong> - For accurate object detection</li>
        <li><strong>GLPN (Global-Local Path Networks)</strong> - For depth estimation</li>
        <li><strong>LSTM Model</strong> - For Z-location prediction</li>
      </ul>
      <p>The API supports both HTTP and WebSocket protocols:</p>
      <div class="endpoint-cards">
        <div class="endpoint-card http-endpoint">
          <div class="endpoint-title">
            <span class="method post">POST</span>
            <span>/api/predict</span>
          </div>
          <p>Process a single image via HTTP request</p>
        </div>
        <div class="endpoint-card ws-endpoint">
          <div class="endpoint-title">
            <span class="method ws">WS</span>
            <span>/ws/predict</span>
          </div>
          <p>Stream images for real-time processing via WebSocket</p>
        </div>
      </div>
    </section>

    <div class="tabs">
      <div class="tab active" onclick="switchTab('http-api')">HTTP API</div>
      <div class="tab" onclick="switchTab('websocket-api')">WebSocket API</div>
      <div class="tab">
        <a href="./try_page" class="ml-2 inline-flex items-center px-4 py-2 border border-transparent text-sm font-medium rounded-md shadow-sm text-white bg-blue-600 hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-blue-500 transition-all duration-200">
          Try it
          <svg xmlns="http://www.w3.org/2000/svg" class="ml-1 h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
            <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M14 5l7 7m0 0l-7 7m7-7H3" />
          </svg>
        </a>
      </div>
    </div>

    <div id="http-api" class="tab-content active">
      <section class="section">
        <h2>HTTP API Reference</h2>
        <h3>POST /api/predict</h3>
        <p>Process a single image for object detection, depth estimation, and distance prediction.</p>

        <h4>Request</h4>
        <p><strong>Content-Type:</strong> multipart/form-data</p>

        <table class="parameter-table">
          <tr>
            <th>Parameter</th>
            <th>Type</th>
            <th>Required</th>
            <th>Description</th>
          </tr>
          <tr>
            <td>file</td>
            <td>File</td>
            <td>Yes</td>
            <td>The image file to process (JPEG, PNG)</td>
          </tr>
        </table>

        <h4>Request Example</h4>
        <div class="code">
          # Python example using requests
          import requests

          url = "http://localhost:8000/api/predict"
          files = {"file": open("image.jpg", "rb")}

          response = requests.post(url, files=files)
          data = response.json()
          print(data)
        </div>

        <h4>Response</h4>
        <p>Returns a JSON object containing:</p>
        <table class="response-table">
          <tr>
            <th>Field</th>
            <th>Type</th>
            <th>Description</th>
          </tr>
          <tr>
            <td>objects</td>
            <td>Array</td>
            <td>Array of detected objects with their properties</td>
          </tr>
          <tr>
            <td>objects[].class</td>
            <td>String</td>
            <td>Class of the detected object (e.g., 'car', 'person')</td>
          </tr>
          <tr>
            <td>objects[].distance_estimated</td>
            <td>Number</td>
            <td>Estimated distance of the object</td>
          </tr>
          <tr>
            <td>objects[].features</td>
            <td>Object</td>
            <td>Features used for prediction (bounding box, depth information)</td>
          </tr>
          <tr>
            <td>frame_id</td>
            <td>Number</td>
            <td>ID of the processed frame (0 for HTTP requests)</td>
          </tr>
          <tr>
            <td>timings</td>
            <td>Object</td>
            <td>Processing time metrics for each step</td>
          </tr>
        </table>

        <div class="response-example">
          <h4>Response Example</h4>
          <pre class="code">
{
  "objects": [
    {
      "class": "car",
      "distance_estimated": 15.42,
      "features": {
        "xmin": 120.5,
        "ymin": 230.8,
        "xmax": 350.2,
        "ymax": 480.3,
        "mean_depth": 0.75,
        "depth_mean_trim": 0.72,
        "depth_median": 0.71,
        "width": 229.7,
        "height": 249.5
      }
    },
    {
      "class": "person",
      "distance_estimated": 8.76,
      "features": {
        "xmin": 450.1,
        "ymin": 200.4,
        "xmax": 510.8,
        "ymax": 380.2,
        "mean_depth": 0.58,
        "depth_mean_trim": 0.56,
        "depth_median": 0.55,
        "width": 60.7,
        "height": 179.8
      }
    }
  ],
  "frame_id": 0,
  "timings": {
    "decode_time": 0.015,
    "models_time": 0.452,
    "process_time": 0.063,
    "json_time": 0.021,
    "total_time": 0.551
  }
}
          </pre>
        </div>

        <h4>HTTP Status Codes</h4>
        <table class="response-table">
          <tr>
            <th>Status Code</th>
            <th>Description</th>
          </tr>
          <tr>
            <td>200</td>
            <td>OK - Request was successful</td>
          </tr>
          <tr>
            <td>400</td>
            <td>Bad Request - Empty file or invalid format</td>
          </tr>
          <tr>
            <td>500</td>
            <td>Internal Server Error - Processing error</td>
          </tr>
        </table>
      </section>
    </div>

    <div id="websocket-api" class="tab-content">
      <section class="section">
        <h2>WebSocket API Reference</h2>
        <h3>WebSocket /ws/predict</h3>
        <p>
          Stream images for real-time processing and get instant results.
          Ideal for video feeds and applications requiring continuous processing.
        </p>

        <div class="info-box">
          <p><strong>Note:</strong> WebSocket offers better performance for real-time applications. Use this endpoint for processing video feeds or when you need to process multiple images in rapid succession.</p>
        </div>

        <h4>Connection</h4>
        <div class="code">
          # JavaScript example
          const socket = new WebSocket('ws://localhost:8000/ws/predict');

          socket.onopen = function(e) {
            console.log('Connection established');
          };

          socket.onmessage = function(event) {
            const response = JSON.parse(event.data);
            console.log('Received:', response);
          };

          socket.onclose = function(event) {
            console.log('Connection closed');
          };
        </div>

        <h4>Sending Images</h4>
        <p>Send binary image data directly over the WebSocket connection:</p>
        <div class="code">
          // JavaScript example: Sending an image from canvas or file
          function sendImageFromCanvas(canvas) {
            canvas.toBlob(function(blob) {
              const reader = new FileReader();
              reader.onload = function() {
                socket.send(reader.result);
              };
              reader.readAsArrayBuffer(blob);
            }, 'image/jpeg');
          }

          // Or from input file
          fileInput.onchange = function() {
            const file = this.files[0];
            const reader = new FileReader();
            reader.onload = function() {
              socket.send(reader.result);
            };
            reader.readAsArrayBuffer(file);
          };
        </div>

        <h4>Response Format</h4>
        <p>The WebSocket API returns the same JSON structure as the HTTP API, with incrementing frame_id values.</p>
        <div class="response-example">
          <h4>Response Example</h4>
          <pre class="code">
{
  "objects": [
    {
      "class": "car",
      "distance_estimated": 14.86,
      "features": {
        "xmin": 125.3,
        "ymin": 235.1,
        "xmax": 355.7,
        "ymax": 485.9,
        "mean_depth": 0.77,
        "depth_mean_trim": 0.74,
        "depth_median": 0.73,
        "width": 230.4,
        "height": 250.8
      }
    }
  ],
  "frame_id": 42,
  "timings": {
    "decode_time": 0.014,
    "models_time": 0.445,
    "process_time": 0.061,
    "json_time": 0.020,
    "total_time": 0.540
  }
}
          </pre>
        </div>
      </section>
    </div>

    <div id="try-it" class="tab-content">
      <section class="section">
        <h2>Try The API</h2>
        <p>You can test the API directly using the interactive Swagger UI below:</p>

        <div id="swagger-ui"></div>

        <h3>Simple WebSocket Client</h3>
        <p>Upload an image to test the WebSocket endpoint:</p>

        <div>
          <input type="file" id="wsFileInput" accept="image/*">
          <button id="connectWsButton">Connect to WebSocket</button>
          <button id="disconnectWsButton" disabled>Disconnect</button>
        </div>

        <div>
          <p><strong>Status: </strong><span id="wsStatus">Disconnected</span></p>
          <p><strong>Last Response: </strong></p>
          <pre id="wsResponse" class="code" style="max-height: 300px; overflow-y: auto;"></pre>
        </div>
      </section>
    </div>
  </div>

  <script src="https://cdnjs.cloudflare.com/ajax/libs/swagger-ui/4.18.3/swagger-ui-bundle.js"></script>
  <script>
    // Handle tab switching
    function switchTab(tabId) {
      document.querySelectorAll('.tab').forEach(tab => tab.classList.remove('active'));
      document.querySelectorAll('.tab-content').forEach(content => content.classList.remove('active'));

      document.querySelector(`.tab[onclick="switchTab('${tabId}')"]`).classList.add('active');
      document.getElementById(tabId).classList.add('active');
    }

    // Initialize Swagger UI
    window.onload = function() {
      SwaggerUIBundle({
        url: "/api/openapi.json",
        dom_id: '#swagger-ui',
        presets: [
          SwaggerUIBundle.presets.apis,
          SwaggerUIBundle.SwaggerUIStandalonePreset
        ],
        layout: "BaseLayout",
        deepLinking: true
      });

      // WebSocket client setup
      let socket = null;

      const connectButton = document.getElementById('connectWsButton');
      const disconnectButton = document.getElementById('disconnectWsButton');
      const fileInput = document.getElementById('wsFileInput');
      const statusElement = document.getElementById('wsStatus');
      const responseElement = document.getElementById('wsResponse');

      connectButton.addEventListener('click', () => {
        if (socket) {
          socket.close();
        }

        const wsUrl = window.location.protocol === 'https:'
          ? `wss://${window.location.host}/ws/predict`
          : `ws://${window.location.host}/ws/predict`;

        socket = new WebSocket(wsUrl);

        socket.onopen = function() {
          statusElement.textContent = 'Connected';
          statusElement.style.color = 'green';
          connectButton.disabled = true;
          disconnectButton.disabled = false;
        };

        socket.onmessage = function(event) {
          const response = JSON.parse(event.data);
          responseElement.textContent = JSON.stringify(response, null, 2);
        };

        socket.onclose = function() {
          statusElement.textContent = 'Disconnected';
          statusElement.style.color = 'red';
          connectButton.disabled = false;
          disconnectButton.disabled = true;
          socket = null;
        };

        socket.onerror = function(error) {
          statusElement.textContent = 'Error: ' + error.message;
          statusElement.style.color = 'red';
        };
      });

      disconnectButton.addEventListener('click', () => {
        if (socket) {
          socket.close();
        }
      });

      fileInput.addEventListener('change', () => {
        if (!socket || socket.readyState !== WebSocket.OPEN) {
          alert('Please connect to WebSocket first');
          return;
        }

        const file = fileInput.files[0];
        if (!file) return;

        const reader = new FileReader();
        reader.onload = function() {
          socket.send(reader.result);
        };
        reader.readAsArrayBuffer(file);
      });
    };
  </script>
</body>
</html>

app.py
ADDED
@@ -0,0 +1,379 @@
import joblib
import uvicorn
import xgboost as xgb
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, File, UploadFile, HTTPException

from fastapi.responses import HTMLResponse
from fastapi.responses import JSONResponse
import asyncio
import json
import pickle
import warnings
import os
import io

import timeit
from PIL import Image
import numpy as np
import cv2

# Add this to your existing imports if not already present
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.openapi.utils import get_openapi

from models.detr_model import DETR
from models.glpn_model import GLPDepth
from models.lstm_model import LSTM_Model
from models.predict_z_location_single_row_lstm import predict_z_location_single_row_lstm
from utils.processing import PROCESSING
from config import CONFIG

warnings.filterwarnings("ignore")

# Initialize FastAPI app
app = FastAPI(
    title="Real-Time WebSocket Image Processing API",
    description="API for object detection and depth estimation using WebSocket for real-time image processing.",
)

try:
    # Load models and utilities
    device = CONFIG['device']
    print("Loading models...")

    detr = DETR()  # Object detection model (DETR)
    print("DETR model loaded.")

    glpn = GLPDepth()  # Depth estimation model (GLPN)
    print("GLPDepth model loaded.")

    zlocE_LSTM = LSTM_Model()  # LSTM model for Z-location prediction
    print("LSTM model loaded.")

    lstm_scaler = pickle.load(open(CONFIG['lstm_scaler_path'], 'rb'))  # Load pre-trained scaler for LSTM
    print("LSTM Scaler loaded.")

    processing = PROCESSING()  # Utility class for post-processing
    print("Processing utilities loaded.")

except Exception as e:
    print(f"An unexpected error occurred. Details: {e}")


# Serve HTML documentation for the API
@app.get("/", response_class=HTMLResponse)
async def get_docs():
    """
    Serve HTML documentation for the WebSocket-based image processing API.
    The HTML file must be available in the same directory.
    Returns a 404 error if the documentation file is not found.
    """
    html_path = os.path.join(os.path.dirname(__file__), "api_documentation.html")
    if not os.path.exists(html_path):
        return HTMLResponse(content="api_documentation.html file not found", status_code=404)
    with open(html_path, "r") as f:
        return HTMLResponse(f.read())


@app.get("/try_page", response_class=HTMLResponse)
async def get_try_page():
    """
    Serve the interactive try-it page for the image processing API.
    The HTML file must be available in the same directory.
    Returns a 404 error if the page file is not found.
    """
    html_path = os.path.join(os.path.dirname(__file__), "try_page.html")
    if not os.path.exists(html_path):
        return HTMLResponse(content="try_page.html file not found", status_code=404)
    with open(html_path, "r") as f:
        return HTMLResponse(f.read())


# Function to decode the image received via WebSocket
async def decode_image(image_bytes):
    """
    Decodes image bytes into a PIL Image and returns the image along with its shape.

    Args:
        image_bytes (bytes): The image data received from the client.

    Returns:
        tuple: A tuple containing:
            - pil_image (PIL.Image): The decoded image.
            - img_shape (tuple): Shape of the image as (height, width).
            - decode_time (float): Time taken to decode the image in seconds.

    Raises:
        ValueError: If image decoding fails.
    """
    start = timeit.default_timer()
    nparr = np.frombuffer(image_bytes, np.uint8)
    frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if frame is None:
        raise ValueError("Failed to decode image")
    color_converted = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(color_converted)
    img_shape = color_converted.shape[0:2]  # (height, width)
    end = timeit.default_timer()
    return pil_image, img_shape, end - start


# Function to run the DETR model for object detection
async def run_detr_model(pil_image):
    """
    Runs the DETR (DEtection TRansformer) model to perform object detection on the input image.

    Args:
        pil_image (PIL.Image): The image to be processed by the DETR model.

    Returns:
        tuple: A tuple containing:
            - detr_result (tuple): The DETR model output consisting of detection scores and boxes.
            - detr_time (float): Time taken to run the DETR model in seconds.
    """
    start = timeit.default_timer()
    detr_result = await asyncio.to_thread(detr.detect, pil_image)
    end = timeit.default_timer()
    return detr_result, end - start


# Function to run the GLPN model for depth estimation
async def run_glpn_model(pil_image, img_shape):
    """
    Runs the GLPN (Global-Local Path Networks) model to estimate the depth of the objects in the image.

    Args:
        pil_image (PIL.Image): The image to be processed by the GLPN model.
        img_shape (tuple): The shape of the image as (height, width).

    Returns:
        tuple: A tuple containing:
            - depth_map (numpy.ndarray): The depth map for the input image.
            - glpn_time (float): Time taken to run the GLPN model in seconds.
    """
    start = timeit.default_timer()
    depth_map = await asyncio.to_thread(glpn.predict, pil_image, img_shape)
    end = timeit.default_timer()
    return depth_map, end - start


# Function to process the detections with depth map
async def process_detections(scores, boxes, depth_map):
    """
    Processes the DETR model detections and integrates depth information from the GLPN model.

    Args:
        scores (numpy.ndarray): The detection scores for the detected objects.
        boxes (numpy.ndarray): The bounding boxes for the detected objects.
        depth_map (numpy.ndarray): The depth map generated by the GLPN model.

    Returns:
        tuple: A tuple containing:
            - pdata (pandas.DataFrame): Processed detection data including depth and bounding box information.
            - process_time (float): Time taken for processing detections in seconds.
    """
    start = timeit.default_timer()
    pdata = processing.process_detections(scores, boxes, depth_map, detr)
    end = timeit.default_timer()
    return pdata, end - start


# Function to generate JSON output for LSTM predictions
async def generate_json_output(data):
    """
    Predict Z-location for each object in the data and prepare the JSON output.

    Parameters:
    - data: DataFrame with bounding box coordinates, depth information, and class type.
    - Uses the module-level zlocE_LSTM model for Z-location prediction.
    - Uses the module-level lstm_scaler for normalizing input data.

    Returns:
    - JSON structure with object class, estimated distance, and relevant features.
    """
    output_json = []
    start = timeit.default_timer()

    # Iterate over each row in the data
    for i, row in data.iterrows():
        # Predict distance for each object using the single-row prediction function
        distance = predict_z_location_single_row_lstm(row, zlocE_LSTM, lstm_scaler)

        # Create object info dictionary
        object_info = {
            "class": row["class"],  # Object class (e.g., 'car', 'truck')
            "distance_estimated": float(distance),  # Convert distance to float (if necessary)
            "features": {
                "xmin": float(row["xmin"]),  # Bounding box xmin
                "ymin": float(row["ymin"]),  # Bounding box ymin
                "xmax": float(row["xmax"]),  # Bounding box xmax
                "ymax": float(row["ymax"]),  # Bounding box ymax
                "mean_depth": float(row["depth_mean"]),  # Depth mean
                "depth_mean_trim": float(row["depth_mean_trim"]),  # Depth mean trim
                "depth_median": float(row["depth_median"]),  # Depth median
                "width": float(row["width"]),  # Object width
                "height": float(row["height"])  # Object height
            }
        }

        # Append each object info to the output JSON list
        output_json.append(object_info)

    end = timeit.default_timer()

    # Return the final JSON output structure, and time
    return {"objects": output_json}, end - start


# Function to process a single frame (image) in the WebSocket stream
async def process_frame(frame_id, image_bytes):
    """
    Processes a single frame (image) from the WebSocket stream. The process includes:
    - Decoding the image.
    - Running the DETR and GLPN models concurrently.
    - Processing detections and generating the final output JSON.

    Args:
        frame_id (int): The identifier for the frame being processed.
        image_bytes (bytes): The image data received from the WebSocket.

    Returns:
        dict: A dictionary containing the output JSON and timing information for each processing step.
    """
    timings = {}
    try:
        # Step 1: Decode the image
        pil_image, img_shape, decode_time = await decode_image(image_bytes)
        timings["decode_time"] = decode_time

        # Step 2: Run DETR and GLPN models in parallel
        (detr_result, detr_time), (depth_map, glpn_time) = await asyncio.gather(
            run_detr_model(pil_image),
            run_glpn_model(pil_image, img_shape)
        )
        models_time = max(detr_time, glpn_time)  # Take the longest time of the two models
        timings["models_time"] = models_time

        # Step 3: Process detections with depth map
        scores, boxes = detr_result
        pdata, process_time = await process_detections(scores, boxes, depth_map)
        timings["process_time"] = process_time

        # Step 4: Generate output JSON
        print("generate json")
        output_json, json_time = await generate_json_output(pdata)
        print(output_json)
        timings["json_time"] = json_time

        timings["total_time"] = decode_time + models_time + process_time + json_time

        # Add frame_id and timings to the JSON output
        output_json["frame_id"] = frame_id
        output_json["timings"] = timings

        return output_json

    except Exception as e:
        return {
            "error": str(e),
            "frame_id": frame_id,
            "timings": timings
        }


@app.post("/api/predict", summary="Process a single image for object detection and depth estimation")
async def process_image(file: UploadFile = File(...)):
    """
    Process a single image for object detection and depth estimation.

    The endpoint performs:
    - Object detection using DETR model
    - Depth estimation using GLPN model
    - Z-location prediction using LSTM model

    Parameters:
    - file: Image file to process (JPEG, PNG)

    Returns:
    - JSON response with detected objects, estimated distances, and timing information
    """
    try:
        # Read image content
        image_bytes = await file.read()
        if not image_bytes:
            raise HTTPException(status_code=400, detail="Empty file")

        # Use the same processing pipeline as the WebSocket endpoint
        result = await process_frame(0, image_bytes)

        # Check if there's an error
        if "error" in result:
            raise HTTPException(status_code=500, detail=result["error"])

        return JSONResponse(content=result)

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


# Add custom OpenAPI documentation
@app.get("/api/docs", include_in_schema=False)
async def custom_swagger_ui_html():
    return get_swagger_ui_html(
        openapi_url="/api/openapi.json",
        title="Real-Time Image Processing API Documentation",
        swagger_js_url="https://cdnjs.cloudflare.com/ajax/libs/swagger-ui/4.18.3/swagger-ui-bundle.js",
        swagger_css_url="https://cdnjs.cloudflare.com/ajax/libs/swagger-ui/4.18.3/swagger-ui.css",
    )


@app.get("/api/openapi.json", include_in_schema=False)
async def get_open_api_endpoint():
    return get_openapi(
        title="Real-Time Image Processing API",
        version="1.0.0",
        description="API for object detection, depth estimation, and z-location prediction using computer vision models",
        routes=app.routes,
    )


@app.websocket("/ws/predict")
async def websocket_endpoint(websocket: WebSocket):
    """
    WebSocket endpoint for real-time image processing. Clients can send image frames to be processed
    and receive JSON output containing object detection, depth estimation, and predictions in real-time.

    - Handles the reception of image data over the WebSocket.
    - Calls the image processing pipeline and returns the result.

    Args:
        websocket (WebSocket): The WebSocket connection to the client.
    """
    await websocket.accept()
    frame_id = 0

    try:
        while True:
            frame_id += 1

            # Receive image bytes from the client
            image_bytes = await websocket.receive_bytes()

            # Process the frame asynchronously
            processing_task = asyncio.create_task(process_frame(frame_id, image_bytes))
            result = await processing_task

            # Send the result back to the client
            await websocket.send_text(json.dumps(result))

    except WebSocketDisconnect:
        print(f"Client disconnected after processing {frame_id} frames.")
    except Exception as e:
        print(f"Unexpected error: {e}")
    finally:
        await websocket.close()

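The /ws/predict endpoint above accepts raw encoded image bytes and replies with one JSON message per frame. The HTML documentation ships a JavaScript client; the following is a minimal Python client sketch, assuming the server runs locally on port 8000 and using "sample.jpg" as a placeholder image path (it relies on the websockets package already listed in requirements.txt).

# Minimal WebSocket client sketch for /ws/predict (assumes a local server on port 8000;
# "sample.jpg" is a placeholder path, not a file shipped with this repo).
import asyncio
import json

import websockets


async def stream_one_frame(path: str = "sample.jpg") -> None:
    uri = "ws://localhost:8000/ws/predict"
    async with websockets.connect(uri) as ws:
        # Send the raw encoded image bytes, exactly as the endpoint expects.
        with open(path, "rb") as f:
            await ws.send(f.read())
        # Receive the JSON result for this frame.
        result = json.loads(await ws.recv())
        for obj in result.get("objects", []):
            print(obj["class"], obj["distance_estimated"])


if __name__ == "__main__":
    asyncio.run(stream_one_frame())
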
config.py
ADDED
@@ -0,0 +1,12 @@
import torch

CONFIG = {
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'detr_model_path': 'facebook/detr-resnet-101',
    'glpn_model_path': 'vinvino02/glpn-kitti',
    'lstm_model_path': 'data/models/pretrained_lstm.pth',
    'lstm_scaler_path': 'data/models/lstm_scaler.pkl',
    'xgboost_path': 'data/models/xgboost_model.json',
    'xgboost_scaler_path': 'data/models/scaler.joblib',
    'numerical_cols_path': 'data/models/numerical_columns.joblib'
}

data/models/lstm_scaler.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3e21e30154403927e4696795a086191bacf12a03aecb904992faa2dc9fb17343
size 809

data/models/pretrained_lstm.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d9a7728cd4db9ebb4c8c0cb2fef644ae7d159359864002aa94148f9ea127005c
size 31160407

models/__init__.py
ADDED
File without changes

models/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (167 Bytes)

models/__pycache__/detr_model.cpython-311.pyc
ADDED
Binary file (7.7 kB)

models/__pycache__/distance_prediction.cpython-311.pyc
ADDED
Binary file (2.92 kB)

models/__pycache__/glpn_model.cpython-311.pyc
ADDED
Binary file (2.95 kB)

models/__pycache__/lstm_model.cpython-311.pyc
ADDED
Binary file (4.22 kB)

models/__pycache__/predict_z_location_single_row_lstm.cpython-311.pyc
ADDED
Binary file (2.89 kB)

models/detr_model.py
ADDED
@@ -0,0 +1,117 @@
"""
Created on Sat Apr 9 04:08:02 2022
@author: Admin_with ODD Team

Edited by our team : Sat Oct 5 10:00 2024

references: https://github.com/vinvino02/GLPDepth
"""

import io
import torch
import base64
from config import CONFIG
from torchvision import transforms
from matplotlib import pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from transformers import DetrForObjectDetection, DetrImageProcessor


class DETR:
    def __init__(self):

        self.CLASSES = [
            'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
            'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
            'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
            'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
            'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
            'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
            'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
            'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
            'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
            'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
            'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
            'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
            'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
            'toothbrush'
        ]

        self.COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098],
                       [0.929, 0.694, 0.125], [0, 0, 1], [0.466, 0.674, 0.188],
                       [0.301, 0.745, 0.933]]

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        self.model = DetrForObjectDetection.from_pretrained(CONFIG['detr_model_path'], revision="no_timm")
        self.model.to(CONFIG['device'])
        self.model.eval()

    def box_cxcywh_to_xyxy(self, x):
        x_c, y_c, w, h = x.unbind(1)
        b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
             (x_c + 0.5 * w), (y_c + 0.5 * h)]
        return torch.stack(b, dim=1).to(CONFIG['device'])

    def rescale_bboxes(self, out_bbox, size):
        img_w, img_h = size
        b = self.box_cxcywh_to_xyxy(out_bbox)
        b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(CONFIG['device'])
        return b

    def detect(self, im):
        img = self.transform(im).unsqueeze(0).to(CONFIG['device'])
        assert img.shape[-2] <= 1600 and img.shape[-1] <= 1600, 'Image too large'
        outputs = self.model(img)
        probas = outputs['logits'].softmax(-1)[0, :, :-1]
        keep = probas.max(-1).values > 0.7
        bboxes_scaled = self.rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
        return probas[keep], bboxes_scaled

    def visualize(self, im, probas, bboxes):
        """
        Visualizes the detected bounding boxes and class probabilities on the image.

        Parameters:
            im (PIL.Image): The original input image.
            probas (Tensor): Class probabilities for detected objects.
            bboxes (Tensor): Bounding boxes for detected objects.
        """
        # Convert image to RGB format for matplotlib
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.imshow(im)

        # Iterate over detections and draw bounding boxes and labels
        for p, (xmin, ymin, xmax, ymax), color in zip(probas, bboxes, self.COLORS * 100):
            # Detach tensors and convert to float
            xmin, ymin, xmax, ymax = map(lambda x: x.detach().cpu().numpy().item(), (xmin, ymin, xmax, ymax))

            ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                       fill=False, color=color, linewidth=3))
            cl = p.argmax()
            text = f'{self.CLASSES[cl]}: {p[cl].detach().cpu().numpy():0.2f}'  # Detach probability as well
            ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5))

        ax.axis('off')

        # Convert the plot to a PIL Image and then to bytes
        canvas = FigureCanvas(fig)
        buf = io.BytesIO()
        canvas.print_png(buf)
        buf.seek(0)

        # Base64 encode the image
        img_bytes = buf.getvalue()
        img_base64 = base64.b64encode(img_bytes).decode('utf-8')

        # Close the figure to release memory
        plt.close(fig)

        return img_base64

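For reference, a minimal usage sketch of the class above, assuming the DETR weights download succeeds and using "street.jpg" as a placeholder image path.

# Minimal usage sketch for DETR (assumes model weights are downloadable;
# "street.jpg" is a placeholder image path).
from PIL import Image

from models.detr_model import DETR

detr = DETR()
image = Image.open("street.jpg").convert("RGB")

# detect() returns per-detection class probabilities and xyxy boxes scaled to the image size.
probas, boxes = detr.detect(image)
for p, box in zip(probas, boxes):
    cls = detr.CLASSES[p.argmax()]
    print(cls, [round(float(v), 1) for v in box.tolist()])
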
models/glpn_model.py
ADDED
@@ -0,0 +1,51 @@
"""
Created on Sat Apr 9 04:08:02 2022
@author: Admin_with ODD Team

Edited by our team : Sat Oct 5 10:00 2024

references: https://github.com/vinvino02/GLPDepth
"""

import torch
from transformers import GLPNForDepthEstimation, GLPNFeatureExtractor
from PIL import Image
from config import CONFIG


# GLPDepth Model Class
class GLPDepth:
    def __init__(self):
        self.feature_extractor = GLPNFeatureExtractor.from_pretrained(CONFIG['glpn_model_path'])
        self.model = GLPNForDepthEstimation.from_pretrained(CONFIG['glpn_model_path'])
        self.model.to(CONFIG['device'])
        self.model.eval()

    def predict(self, img: Image.Image, img_shape: tuple):
        """Predict the depth map of the input image.

        Args:
            img (PIL.Image): Input image for depth estimation.
            img_shape (tuple): Original image size (height, width).

        Returns:
            np.ndarray: The predicted depth map in numpy array format.
        """
        with torch.no_grad():
            # Preprocess image and move to the appropriate device
            pixel_values = self.feature_extractor(img, return_tensors="pt").pixel_values.to(CONFIG['device'])
            # Get model output
            outputs = self.model(pixel_values)
            predicted_depth = outputs.predicted_depth

            # Resize depth prediction to original image size
            prediction = torch.nn.functional.interpolate(
                predicted_depth.unsqueeze(1),
                size=img_shape[:2],  # Interpolate to original image size
                mode="bicubic",
                align_corners=False,
            )
            prediction = prediction.squeeze().cpu().numpy()  # Convert to numpy array (shape: (H, W))

        return prediction

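A minimal usage sketch for the depth model, assuming the GLPN weights download succeeds and using "street.jpg" as a placeholder path; note that predict() expects the original size as (height, width).

# Minimal usage sketch for GLPDepth (assumes weights are downloadable;
# "street.jpg" is a placeholder image path).
from PIL import Image

from models.glpn_model import GLPDepth

glpn = GLPDepth()
image = Image.open("street.jpg").convert("RGB")

# predict() wants the original size as (height, width); PIL's .size is (width, height).
depth_map = glpn.predict(image, (image.height, image.width))
print(depth_map.shape, depth_map.min(), depth_map.max())
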
models/lstm_model.py
ADDED
@@ -0,0 +1,78 @@
# -*- coding: utf-8 -*-
"""
Z-Location Estimator Model for Deployment
Created on Mon May 23 04:55:50 2022
@author: ODD_team
Edited by our team : Sat Oct 4 11:00 PM 2024
@based on LSTM model
"""

import torch
import torch.nn as nn
from config import CONFIG

device = CONFIG['device']


# Define the LSTM-based Z-location estimator model
class Zloc_Estimator(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim):
        super(Zloc_Estimator, self).__init__()

        # LSTM layer
        self.rnn = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True, bidirectional=False)

        # Fully connected layers
        layersize = [306, 154, 76]
        layerlist = []
        n_in = hidden_dim
        for i in layersize:
            layerlist.append(nn.Linear(n_in, i))
            layerlist.append(nn.ReLU())
            n_in = i
        layerlist.append(nn.Linear(layersize[-1], 1))  # Final output layer

        self.fc = nn.Sequential(*layerlist)

    def forward(self, x):
        out, hn = self.rnn(x)
        output = self.fc(out[:, -1])  # Get the last output for prediction
        return output


# Deployment-ready class for handling the model
class LSTM_Model:
    def __init__(self):
        """
        Initializes the LSTM model for deployment with predefined parameters
        and loads the pre-trained model weights from CONFIG['lstm_model_path'].
        """
        self.input_dim = 15
        self.hidden_dim = 612
        self.layer_dim = 3

        # Initialize the Z-location estimator model
        self.model = Zloc_Estimator(self.input_dim, self.hidden_dim, self.layer_dim)

        # Load the state dictionary from the file, using map_location in torch.load()
        state_dict = torch.load(CONFIG['lstm_model_path'], map_location=device)

        # Load the model with the state dictionary
        self.model.load_state_dict(state_dict, strict=False)
        self.model.to(device)  # This line ensures the model is moved to the right device
        self.model.eval()  # Set the model to evaluation mode

    def predict(self, data):
        """
        Predicts the z-location based on input data.

        :param data: Input tensor of shape (batch_size, input_dim)
        :return: Predicted z-location as a tensor
        """
        with torch.no_grad():  # Disable gradient computation for deployment
            data = data.to(device)  # Move data to the appropriate device
            data = data.reshape(-1, 1, self.input_dim)  # Reshape data to (batch_size, sequence_length, input_dim)
            zloc = self.model(data)
        return zloc.cpu()  # Return the output in CPU memory for further processing

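A quick shape-only smoke test for the wrapper above; the 15 random features match input_dim and are not meaningful data.

# Shape-only smoke test for LSTM_Model (random features, not meaningful data).
import torch

from models.lstm_model import LSTM_Model

zloc_model = LSTM_Model()

# predict() reshapes its input to (batch, 1, input_dim), so a (1, 15) tensor is enough.
dummy_features = torch.randn(1, 15)
print(zloc_model.predict(dummy_features).shape)  # -> torch.Size([1, 1])
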
models/predict_z_location_single_row_lstm.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
import numpy as np


def predict_z_location_single_row_lstm(row, ZlocE, scaler):
    """
    Preprocess bounding box coordinates, depth information, and class type
    to predict Z-location using the LSTM model for a single row.

    Parameters:
    - row: A single row of DataFrame with bounding box coordinates, depth info, and class type.
    - ZlocE: Pre-loaded LSTM model for Z-location prediction.
    - scaler: Scaler for normalizing input data.

    Returns:
    - z_loc_prediction: Predicted Z-location for the given row.
    """
    # One-hot encoding of class type
    class_type = row['class']

    if class_type == 'bicycle':
        class_tensor = torch.tensor([[0, 1, 0, 0, 0, 0]], dtype=torch.float32)
    elif class_type == 'car':
        class_tensor = torch.tensor([[0, 0, 1, 0, 0, 0]], dtype=torch.float32)
    elif class_type == 'person':
        class_tensor = torch.tensor([[0, 0, 0, 1, 0, 0]], dtype=torch.float32)
    elif class_type == 'train':
        class_tensor = torch.tensor([[0, 0, 0, 0, 1, 0]], dtype=torch.float32)
    elif class_type == 'truck':
        class_tensor = torch.tensor([[0, 0, 0, 0, 0, 1]], dtype=torch.float32)
    else:
        class_tensor = torch.tensor([[1, 0, 0, 0, 0, 0]], dtype=torch.float32)

    # Prepare input data (bounding box + depth info)
    input_data = np.array([row[['xmin', 'ymin', 'xmax', 'ymax', 'width', 'height', 'depth_mean', 'depth_median',
                                'depth_mean_trim']].values], dtype=np.float32)
    input_data = torch.from_numpy(input_data)

    # Concatenate class information
    input_data = torch.cat([input_data, class_tensor], dim=-1)

    # Scale the input data
    scaled_input = torch.tensor(scaler.transform(input_data), dtype=torch.float32).unsqueeze(0).unsqueeze(0)

    # Use the LSTM model to predict the Z-location
    z_loc_prediction = ZlocE.predict(scaled_input).detach().numpy()[0]

    return z_loc_prediction
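
A hedged example of the row this function expects (synthetic values); ZlocE would be an LSTM_Model instance and scaler the object loaded from data/models/lstm_scaler.pkl.

# Illustrative input row (synthetic values only).
import pandas as pd

row = pd.Series({'xmin': 100, 'ymin': 150, 'xmax': 300, 'ymax': 400,
                 'width': 200, 'height': 250,
                 'depth_mean': 12.3, 'depth_median': 11.8, 'depth_mean_trim': 12.0,
                 'class': 'car'})
# z = predict_z_location_single_row_lstm(row, ZlocE, scaler)   # ZlocE: LSTM_Model, scaler: loaded scaler
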
requirements.txt
ADDED
@@ -0,0 +1,20 @@
torch>=2.0.0
torchvision>=0.15.0
transformers~=4.32.1
matplotlib>=3.7.0
pandas>=2.0.0
opencv-python>=4.8.0
uvicorn>=0.24.0
fastapi>=0.104.1
gunicorn==21.2.0
setuptools>=68.0.0
scipy>=1.11.0
scikit-learn>=1.3.0
websockets>=11.0.0
shapely>=2.0.0
cython>=3.0.0
numpy>=1.24.0
pillow~=9.4.0
joblib~=1.2.0
xgboost==3.0.0
try_page.html
ADDED
@@ -0,0 +1,778 @@
1 |
+
<!DOCTYPE html>
|
2 |
+
<html lang="en">
|
3 |
+
<head>
|
4 |
+
<meta charset="UTF-8">
|
5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
6 |
+
<title>HTTP Image Upload Demo</title>
|
7 |
+
<style>
|
8 |
+
/* Reset and base styles */
|
9 |
+
* {
|
10 |
+
box-sizing: border-box;
|
11 |
+
margin: 0;
|
12 |
+
padding: 0;
|
13 |
+
}
|
14 |
+
|
15 |
+
body {
|
16 |
+
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
17 |
+
color: #333;
|
18 |
+
line-height: 1.6;
|
19 |
+
}
|
20 |
+
|
21 |
+
/* Demo section styling */
|
22 |
+
.execution-section {
|
23 |
+
max-width: 1200px;
|
24 |
+
margin: 0 auto;
|
25 |
+
padding: 2rem;
|
26 |
+
background-color: #f8f9fa;
|
27 |
+
border-radius: 8px;
|
28 |
+
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
|
29 |
+
}
|
30 |
+
|
31 |
+
.section-title {
|
32 |
+
font-size: 2rem;
|
33 |
+
color: #384B70;
|
34 |
+
margin-bottom: 1rem;
|
35 |
+
padding-bottom: 0.5rem;
|
36 |
+
border-bottom: 2px solid #507687;
|
37 |
+
}
|
38 |
+
|
39 |
+
.demo-container {
|
40 |
+
display: flex;
|
41 |
+
flex-wrap: wrap;
|
42 |
+
gap: 2rem;
|
43 |
+
margin-top: 1.5rem;
|
44 |
+
}
|
45 |
+
|
46 |
+
.upload-container, .response-container {
|
47 |
+
flex: 1;
|
48 |
+
min-width: 300px;
|
49 |
+
padding: 1.5rem;
|
50 |
+
background-color: white;
|
51 |
+
border-radius: 8px;
|
52 |
+
box-shadow: 0 2px 4px rgba(0,0,0,0.05);
|
53 |
+
}
|
54 |
+
|
55 |
+
.container-title {
|
56 |
+
font-size: 1.5rem;
|
57 |
+
margin-bottom: 1rem;
|
58 |
+
color: #384B70;
|
59 |
+
}
|
60 |
+
|
61 |
+
/* Upload area styling */
|
62 |
+
.file-input-container {
|
63 |
+
border: 2px dashed #ccc;
|
64 |
+
border-radius: 5px;
|
65 |
+
padding: 2rem;
|
66 |
+
text-align: center;
|
67 |
+
margin-bottom: 1rem;
|
68 |
+
transition: all 0.3s ease;
|
69 |
+
}
|
70 |
+
|
71 |
+
.file-input-container:hover {
|
72 |
+
border-color: #507687;
|
73 |
+
background-color: #f8f9fa;
|
74 |
+
}
|
75 |
+
|
76 |
+
#fileInput {
|
77 |
+
display: none;
|
78 |
+
}
|
79 |
+
|
80 |
+
.file-label {
|
81 |
+
cursor: pointer;
|
82 |
+
display: flex;
|
83 |
+
flex-direction: column;
|
84 |
+
align-items: center;
|
85 |
+
gap: 0.5rem;
|
86 |
+
}
|
87 |
+
|
88 |
+
.file-icon {
|
89 |
+
font-size: 2.5rem;
|
90 |
+
color: #507687;
|
91 |
+
width: 64px;
|
92 |
+
height: 64px;
|
93 |
+
}
|
94 |
+
.file-placeholder {
|
95 |
+
max-width: 100%;
|
96 |
+
height: auto;
|
97 |
+
margin-top: 1rem;
|
98 |
+
border-radius: 4px;
|
99 |
+
display: none;
|
100 |
+
}
|
101 |
+
|
102 |
+
#sendButton {
|
103 |
+
background-color: #384B70;
|
104 |
+
color: white;
|
105 |
+
border: none;
|
106 |
+
border-radius: 4px;
|
107 |
+
padding: 0.75rem 1.5rem;
|
108 |
+
font-size: 1rem;
|
109 |
+
cursor: pointer;
|
110 |
+
transition: background-color 0.3s;
|
111 |
+
width: 100%;
|
112 |
+
margin-top: 1rem;
|
113 |
+
}
|
114 |
+
|
115 |
+
#sendButton:disabled {
|
116 |
+
background-color: #cccccc;
|
117 |
+
cursor: not-allowed;
|
118 |
+
}
|
119 |
+
|
120 |
+
#sendButton:hover:not(:disabled) {
|
121 |
+
background-color: #507687;
|
122 |
+
}
|
123 |
+
|
124 |
+
/* Response area styling */
|
125 |
+
.response-output {
|
126 |
+
height: 300px;
|
127 |
+
overflow-y: auto;
|
128 |
+
background-color: #f8f9fa;
|
129 |
+
border: 1px solid #ddd;
|
130 |
+
border-radius: 4px;
|
131 |
+
padding: 1rem;
|
132 |
+
font-family: monospace;
|
133 |
+
white-space: pre-wrap;
|
134 |
+
}
|
135 |
+
|
136 |
+
/* Tabs styling */
|
137 |
+
.tabs {
|
138 |
+
display: flex;
|
139 |
+
border-bottom: 1px solid #ddd;
|
140 |
+
margin-bottom: 1rem;
|
141 |
+
}
|
142 |
+
|
143 |
+
.tab-button {
|
144 |
+
padding: 0.5rem 1rem;
|
145 |
+
background-color: #f1f1f1;
|
146 |
+
border: none;
|
147 |
+
cursor: pointer;
|
148 |
+
transition: background-color 0.3s;
|
149 |
+
font-size: 1rem;
|
150 |
+
}
|
151 |
+
|
152 |
+
.tab-button.active {
|
153 |
+
background-color: #384B70;
|
154 |
+
color: white;
|
155 |
+
}
|
156 |
+
|
157 |
+
.tab-content {
|
158 |
+
display: none;
|
159 |
+
height: 300px;
|
160 |
+
}
|
161 |
+
|
162 |
+
.tab-content.active {
|
163 |
+
display: block;
|
164 |
+
}
|
165 |
+
|
166 |
+
/* Visualization area styling */
|
167 |
+
#visualizationContainer {
|
168 |
+
position: relative;
|
169 |
+
height: 100%;
|
170 |
+
overflow: auto;
|
171 |
+
background-color: #f8f9fa;
|
172 |
+
border: 1px solid #ddd;
|
173 |
+
border-radius: 4px;
|
174 |
+
}
|
175 |
+
|
176 |
+
.detection-canvas {
|
177 |
+
display: block;
|
178 |
+
margin: 0 auto;
|
179 |
+
}
|
180 |
+
|
181 |
+
/* Utilities */
|
182 |
+
#loading {
|
183 |
+
display: none;
|
184 |
+
margin-top: 1rem;
|
185 |
+
color: #384B70;
|
186 |
+
font-weight: bold;
|
187 |
+
text-align: center;
|
188 |
+
}
|
189 |
+
|
190 |
+
#message {
|
191 |
+
margin-top: 1rem;
|
192 |
+
padding: 0.75rem;
|
193 |
+
border-radius: 4px;
|
194 |
+
text-align: center;
|
195 |
+
display: none;
|
196 |
+
}
|
197 |
+
|
198 |
+
.error {
|
199 |
+
background-color: #ffebee;
|
200 |
+
color: #d32f2f;
|
201 |
+
}
|
202 |
+
|
203 |
+
.success {
|
204 |
+
background-color: #e8f5e9;
|
205 |
+
color: #388e3c;
|
206 |
+
}
|
207 |
+
|
208 |
+
.info {
|
209 |
+
font-size: 0.9rem;
|
210 |
+
color: #666;
|
211 |
+
margin-top: 0.5rem;
|
212 |
+
}
|
213 |
+
|
214 |
+
.stats {
|
215 |
+
margin-top: 1rem;
|
216 |
+
font-size: 0.9rem;
|
217 |
+
color: #666;
|
218 |
+
}
|
219 |
+
|
220 |
+
/* Debug output */
|
221 |
+
#debugOutput {
|
222 |
+
margin-top: 0.5rem;
|
223 |
+
font-size: 0.8rem;
|
224 |
+
color: #999;
|
225 |
+
border-top: 1px dashed #ddd;
|
226 |
+
padding-top: 0.5rem;
|
227 |
+
display: none;
|
228 |
+
}
|
229 |
+
</style>
|
230 |
+
</head>
|
231 |
+
<body>
|
232 |
+
<!-- Interactive Demo Section -->
|
233 |
+
<section class="execution-section">
|
234 |
+
<h2 class="section-title">Try It Yourself</h2>
|
235 |
+
<p>Upload an image and see the object detection and depth estimation results in real-time.</p>
|
236 |
+
|
237 |
+
<div class="demo-container">
|
238 |
+
<!-- Upload Container -->
|
239 |
+
<div class="upload-container">
|
240 |
+
<h3 class="container-title">Upload Image</h3>
|
241 |
+
|
242 |
+
<div class="file-input-container">
|
243 |
+
<label for="fileInput" class="file-label">
|
244 |
+
<img src="https://upload.wikimedia.org/wikipedia/commons/a/a1/Icons8_flat_folder.svg" class="file-icon"/>
|
245 |
+
<span>Click to select image</span>
|
246 |
+
<p class="info">PNG or JPEG, max 2MB</p>
|
247 |
+
</label>
|
248 |
+
<input type="file" accept="image/*" id="fileInput" />
|
249 |
+
<img id="imagePreview" class="file-placeholder" alt="Image preview" />
|
250 |
+
</div>
|
251 |
+
|
252 |
+
<button id="sendButton" disabled>Process Image</button>
|
253 |
+
<div id="loading">Processing your image...</div>
|
254 |
+
<div id="message"></div>
|
255 |
+
|
256 |
+
<div class="stats">
|
257 |
+
<div id="imageSize"></div>
|
258 |
+
<div id="processingTime"></div>
|
259 |
+
</div>
|
260 |
+
|
261 |
+
<div id="debugOutput"></div>
|
262 |
+
</div>
|
263 |
+
|
264 |
+
<!-- Response Container with Tabs -->
|
265 |
+
<div class="response-container">
|
266 |
+
<h3 class="container-title">Response</h3>
|
267 |
+
|
268 |
+
<div class="tabs">
|
269 |
+
<button class="tab-button active" data-tab="raw">Raw Output</button>
|
270 |
+
<button class="tab-button" data-tab="visual">Visual Output</button>
|
271 |
+
</div>
|
272 |
+
|
273 |
+
<!-- Raw Output Tab -->
|
274 |
+
<div id="rawTab" class="tab-content active">
|
275 |
+
<pre class="response-output" id="responseOutput">// Response will appear here after processing</pre>
|
276 |
+
</div>
|
277 |
+
|
278 |
+
<!-- Visual Output Tab -->
|
279 |
+
<div id="visualTab" class="tab-content">
|
280 |
+
<div id="visualizationContainer">
|
281 |
+
<canvas id="detectionCanvas" class="detection-canvas"></canvas>
|
282 |
+
</div>
|
283 |
+
</div>
|
284 |
+
</div>
|
285 |
+
</div>
|
286 |
+
</section>
|
287 |
+
|
288 |
+
<script>
|
289 |
+
// DOM Elements
|
290 |
+
const fileInput = document.getElementById('fileInput');
|
291 |
+
const imagePreview = document.getElementById('imagePreview');
|
292 |
+
const sendButton = document.getElementById('sendButton');
|
293 |
+
const loading = document.getElementById('loading');
|
294 |
+
const message = document.getElementById('message');
|
295 |
+
const responseOutput = document.getElementById('responseOutput');
|
296 |
+
const imageSizeInfo = document.getElementById('imageSize');
|
297 |
+
const processingTimeInfo = document.getElementById('processingTime');
|
298 |
+
const tabButtons = document.querySelectorAll('.tab-button');
|
299 |
+
const tabContents = document.querySelectorAll('.tab-content');
|
300 |
+
const detectionCanvas = document.getElementById('detectionCanvas');
|
301 |
+
const ctx = detectionCanvas.getContext('2d');
|
302 |
+
const debugOutput = document.getElementById('debugOutput');
|
303 |
+
|
304 |
+
// Enable debug mode (set to false in production)
|
305 |
+
const DEBUG = true;
|
306 |
+
|
307 |
+
// API endpoint URL
|
308 |
+
const API_URL = '/api/predict';
|
309 |
+
|
310 |
+
let imageFile = null;
|
311 |
+
let startTime = null;
|
312 |
+
let originalImage = null;
|
313 |
+
let processingWidth = 0;
|
314 |
+
let processingHeight = 0;
|
315 |
+
let responseData = null;
|
316 |
+
|
317 |
+
// Tab switching functionality
|
318 |
+
tabButtons.forEach(button => {
|
319 |
+
button.addEventListener('click', () => {
|
320 |
+
const tabName = button.getAttribute('data-tab');
|
321 |
+
|
322 |
+
// Update button states
|
323 |
+
tabButtons.forEach(btn => btn.classList.remove('active'));
|
324 |
+
button.classList.add('active');
|
325 |
+
|
326 |
+
// Update tab content visibility
|
327 |
+
tabContents.forEach(content => content.classList.remove('active'));
|
328 |
+
document.getElementById(tabName + 'Tab').classList.add('active');
|
329 |
+
|
330 |
+
// If switching to visual tab and we have data, ensure visualization is rendered
|
331 |
+
if (tabName === 'visual' && responseData && originalImage) {
|
332 |
+
visualizeResults(originalImage, responseData);
|
333 |
+
}
|
334 |
+
});
|
335 |
+
});
|
336 |
+
|
337 |
+
// Handle file input change
|
338 |
+
fileInput.addEventListener('change', (event) => {
|
339 |
+
const file = event.target.files[0];
|
340 |
+
|
341 |
+
// Clear previous selections
|
342 |
+
imageFile = null;
|
343 |
+
imagePreview.style.display = 'none';
|
344 |
+
sendButton.disabled = true;
|
345 |
+
originalImage = null;
|
346 |
+
responseData = null;
|
347 |
+
|
348 |
+
// Validate file
|
349 |
+
if (!file) return;
|
350 |
+
|
351 |
+
if (file.size > 2 * 1024 * 1024) {
|
352 |
+
showMessage('File size exceeds 2MB limit.', 'error');
|
353 |
+
return;
|
354 |
+
}
|
355 |
+
|
356 |
+
if (!['image/png', 'image/jpeg'].includes(file.type)) {
|
357 |
+
showMessage('Only PNG and JPEG formats are supported.', 'error');
|
358 |
+
return;
|
359 |
+
}
|
360 |
+
|
361 |
+
// Store file for upload
|
362 |
+
imageFile = file;
|
363 |
+
|
364 |
+
// Show image preview
|
365 |
+
const reader = new FileReader();
|
366 |
+
reader.onload = (e) => {
|
367 |
+
const image = new Image();
|
368 |
+
image.src = e.target.result;
|
369 |
+
|
370 |
+
image.onload = () => {
|
371 |
+
// Store original image for visualization
|
372 |
+
originalImage = image;
|
373 |
+
|
374 |
+
// Set preview
|
375 |
+
imagePreview.src = e.target.result;
|
376 |
+
imagePreview.style.display = 'block';
|
377 |
+
|
378 |
+
// Update image info
|
379 |
+
imageSizeInfo.textContent = `Original size: ${image.width}x${image.height} pixels`;
|
380 |
+
|
381 |
+
// Calculate processing dimensions (for visualization)
|
382 |
+
calculateProcessingDimensions(image.width, image.height);
|
383 |
+
|
384 |
+
// Enable send button
|
385 |
+
sendButton.disabled = false;
|
386 |
+
showMessage('Image ready to process.', 'info');
|
387 |
+
};
|
388 |
+
};
|
389 |
+
reader.readAsDataURL(file);
|
390 |
+
});
|
391 |
+
|
392 |
+
// Calculate dimensions for processing visualization
|
393 |
+
function calculateProcessingDimensions(width, height) {
|
394 |
+
const maxWidth = 640;
|
395 |
+
const maxHeight = 320;
|
396 |
+
|
397 |
+
// Calculate dimensions
|
398 |
+
if (width > height) {
|
399 |
+
if (width > maxWidth) {
|
400 |
+
height = Math.round((height * maxWidth) / width);
|
401 |
+
width = maxWidth;
|
402 |
+
}
|
403 |
+
} else {
|
404 |
+
if (height > maxHeight) {
|
405 |
+
width = Math.round((width * maxHeight) / height);
|
406 |
+
height = maxHeight;
|
407 |
+
}
|
408 |
+
}
|
409 |
+
|
410 |
+
// Store processing dimensions for visualization
|
411 |
+
processingWidth = width;
|
412 |
+
processingHeight = height;
|
413 |
+
}
|
414 |
+
|
415 |
+
// Handle send button click
|
416 |
+
sendButton.addEventListener('click', async () => {
|
417 |
+
if (!imageFile) {
|
418 |
+
showMessage('No image selected.', 'error');
|
419 |
+
return;
|
420 |
+
}
|
421 |
+
|
422 |
+
// Clear previous response
|
423 |
+
responseOutput.textContent = "// Processing...";
|
424 |
+
clearCanvas();
|
425 |
+
responseData = null;
|
426 |
+
debugOutput.style.display = 'none';
|
427 |
+
|
428 |
+
// Show loading state
|
429 |
+
loading.style.display = 'block';
|
430 |
+
message.style.display = 'none';
|
431 |
+
|
432 |
+
// Reset processing time
|
433 |
+
processingTimeInfo.textContent = '';
|
434 |
+
|
435 |
+
// Record start time
|
436 |
+
startTime = performance.now();
|
437 |
+
|
438 |
+
// Create form data for HTTP request
|
439 |
+
const formData = new FormData();
|
440 |
+
formData.append('file', imageFile);
|
441 |
+
|
442 |
+
try {
|
443 |
+
// Send HTTP request
|
444 |
+
const response = await fetch(API_URL, {
|
445 |
+
method: 'POST',
|
446 |
+
body: formData
|
447 |
+
});
|
448 |
+
|
449 |
+
// Handle response
|
450 |
+
if (!response.ok) {
|
451 |
+
const errorText = await response.text();
|
452 |
+
throw new Error(`HTTP error ${response.status}: ${errorText}`);
|
453 |
+
}
|
454 |
+
|
455 |
+
// Parse JSON response
|
456 |
+
const data = await response.json();
|
457 |
+
responseData = data;
|
458 |
+
|
459 |
+
// Calculate processing time
|
460 |
+
const endTime = performance.now();
|
461 |
+
const timeTaken = endTime - startTime;
|
462 |
+
|
463 |
+
// Format and display raw response
|
464 |
+
responseOutput.textContent = JSON.stringify(data, null, 2);
|
465 |
+
processingTimeInfo.textContent = `Processing time: ${timeTaken.toFixed(2)} ms`;
|
466 |
+
|
467 |
+
// Visualize the results
|
468 |
+
if (originalImage) {
|
469 |
+
visualizeResults(originalImage, data);
|
470 |
+
}
|
471 |
+
|
472 |
+
// Show success message
|
473 |
+
showMessage('Image processed successfully!', 'success');
|
474 |
+
} catch (error) {
|
475 |
+
console.error('Error processing image:', error);
|
476 |
+
showMessage(`Error: ${error.message}`, 'error');
|
477 |
+
responseOutput.textContent = `// Error: ${error.message}`;
|
478 |
+
|
479 |
+
if (DEBUG) {
|
480 |
+
debugOutput.style.display = 'block';
|
481 |
+
debugOutput.textContent = `Error: ${error.message}\n${error.stack || ''}`;
|
482 |
+
}
|
483 |
+
} finally {
|
484 |
+
loading.style.display = 'none';
|
485 |
+
}
|
486 |
+
});
|
487 |
+
|
488 |
+
// Visualize detection results
|
489 |
+
function visualizeResults(image, data) {
|
490 |
+
try {
|
491 |
+
// Set canvas dimensions
|
492 |
+
detectionCanvas.width = processingWidth;
|
493 |
+
detectionCanvas.height = processingHeight;
|
494 |
+
|
495 |
+
// Draw the original image
|
496 |
+
ctx.drawImage(image, 0, 0, processingWidth, processingHeight);
|
497 |
+
|
498 |
+
// Set styles for bounding boxes
|
499 |
+
ctx.lineWidth = 3;
|
500 |
+
ctx.font = 'bold 14px Arial';
|
501 |
+
|
502 |
+
// Find detections (checking all common formats)
|
503 |
+
let detections = [];
|
504 |
+
let detectionSource = '';
|
505 |
+
|
506 |
+
if (data.detections && Array.isArray(data.detections)) {
|
507 |
+
detections = data.detections;
|
508 |
+
detectionSource = 'detections';
|
509 |
+
} else if (data.predictions && Array.isArray(data.predictions)) {
|
510 |
+
detections = data.predictions;
|
511 |
+
detectionSource = 'predictions';
|
512 |
+
} else if (data.objects && Array.isArray(data.objects)) {
|
513 |
+
detections = data.objects;
|
514 |
+
detectionSource = 'objects';
|
515 |
+
} else if (data.results && Array.isArray(data.results)) {
|
516 |
+
detections = data.results;
|
517 |
+
detectionSource = 'results';
|
518 |
+
} else {
|
519 |
+
// Try to look one level deeper if no detections found
|
520 |
+
for (const key in data) {
|
521 |
+
if (typeof data[key] === 'object' && data[key] !== null) {
|
522 |
+
if (Array.isArray(data[key])) {
|
523 |
+
detections = data[key];
|
524 |
+
detectionSource = key;
|
525 |
+
break;
|
526 |
+
} else {
|
527 |
+
// Look one more level down
|
528 |
+
for (const subKey in data[key]) {
|
529 |
+
if (Array.isArray(data[key][subKey])) {
|
530 |
+
detections = data[key][subKey];
|
531 |
+
detectionSource = `${key}.${subKey}`;
|
532 |
+
break;
|
533 |
+
}
|
534 |
+
}
|
535 |
+
}
|
536 |
+
}
|
537 |
+
}
|
538 |
+
}
|
539 |
+
|
540 |
+
// Process each detection
|
541 |
+
detections.forEach((detection, index) => {
|
542 |
+
// Try to extract bounding box information
|
543 |
+
let bbox = null;
|
544 |
+
let label = null;
|
545 |
+
let confidence = null;
|
546 |
+
let distance = null;
|
547 |
+
|
548 |
+
// Extract label/class
|
549 |
+
if (detection.class !== undefined) {
|
550 |
+
label = detection.class;
|
551 |
+
} else {
|
552 |
+
// Fallback to other common property names
|
553 |
+
for (const key of ['label', 'name', 'category', 'className']) {
|
554 |
+
if (detection[key] !== undefined) {
|
555 |
+
label = detection[key];
|
556 |
+
break;
|
557 |
+
}
|
558 |
+
}
|
559 |
+
}
|
560 |
+
|
561 |
+
// Default label if none found
|
562 |
+
if (!label) label = `Object ${index + 1}`;
|
563 |
+
|
564 |
+
// Extract confidence score if available
|
565 |
+
for (const key of ['confidence', 'score', 'probability', 'conf']) {
|
566 |
+
if (detection[key] !== undefined) {
|
567 |
+
confidence = detection[key];
|
568 |
+
break;
|
569 |
+
}
|
570 |
+
}
|
571 |
+
|
572 |
+
// Extract distance - specifically look for distance_estimated first
|
573 |
+
if (detection.distance_estimated !== undefined) {
|
574 |
+
distance = detection.distance_estimated;
|
575 |
+
} else {
|
576 |
+
// Fallback to other common distance properties
|
577 |
+
for (const key of ['distance', 'depth', 'z', 'dist', 'range']) {
|
578 |
+
if (detection[key] !== undefined) {
|
579 |
+
distance = detection[key];
|
580 |
+
break;
|
581 |
+
}
|
582 |
+
}
|
583 |
+
}
|
584 |
+
|
585 |
+
// Look for bounding box in features
|
586 |
+
if (detection.features &&
|
587 |
+
detection.features.xmin !== undefined &&
|
588 |
+
detection.features.ymin !== undefined &&
|
589 |
+
detection.features.xmax !== undefined &&
|
590 |
+
detection.features.ymax !== undefined) {
|
591 |
+
|
592 |
+
bbox = {
|
593 |
+
xmin: detection.features.xmin,
|
594 |
+
ymin: detection.features.ymin,
|
595 |
+
xmax: detection.features.xmax,
|
596 |
+
ymax: detection.features.ymax
|
597 |
+
};
|
598 |
+
} else {
|
599 |
+
// Recursively search for bbox-like properties
|
600 |
+
function findBBox(obj, path = '') {
|
601 |
+
if (!obj || typeof obj !== 'object') return null;
|
602 |
+
|
603 |
+
// Check if this object looks like a bbox
|
604 |
+
if ((obj.x !== undefined && obj.y !== undefined &&
|
605 |
+
(obj.width !== undefined || obj.w !== undefined ||
|
606 |
+
obj.height !== undefined || obj.h !== undefined)) ||
|
607 |
+
(obj.xmin !== undefined && obj.ymin !== undefined &&
|
608 |
+
obj.xmax !== undefined && obj.ymax !== undefined)) {
|
609 |
+
return obj;
|
610 |
+
}
|
611 |
+
|
612 |
+
// Check if it's an array of 4 numbers (potential bbox)
|
613 |
+
if (Array.isArray(obj) && obj.length === 4 &&
|
614 |
+
obj.every(item => typeof item === 'number')) {
|
615 |
+
return obj;
|
616 |
+
}
|
617 |
+
|
618 |
+
// Check common bbox property names
|
619 |
+
for (const key of ['bbox', 'box', 'bounding_box', 'boundingBox']) {
|
620 |
+
if (obj[key] !== undefined) {
|
621 |
+
return obj[key];
|
622 |
+
}
|
623 |
+
}
|
624 |
+
|
625 |
+
// Search nested properties
|
626 |
+
for (const key in obj) {
|
627 |
+
const result = findBBox(obj[key], path ? `${path}.${key}` : key);
|
628 |
+
if (result) return result;
|
629 |
+
}
|
630 |
+
|
631 |
+
return null;
|
632 |
+
}
|
633 |
+
|
634 |
+
// Find bbox using recursive search as fallback
|
635 |
+
bbox = findBBox(detection);
|
636 |
+
}
|
637 |
+
|
638 |
+
// If we found a bounding box, draw it
|
639 |
+
if (bbox) {
|
640 |
+
// Parse different bbox formats
|
641 |
+
let x, y, width, height;
|
642 |
+
|
643 |
+
if (Array.isArray(bbox)) {
|
644 |
+
// Try to determine array format
|
645 |
+
if (bbox.length === 4) {
|
646 |
+
if (bbox[0] >= 0 && bbox[1] >= 0 && bbox[2] <= 1 && bbox[3] <= 1) {
|
647 |
+
// Likely normalized [x1, y1, x2, y2]
|
648 |
+
x = bbox[0] * processingWidth;
|
649 |
+
y = bbox[1] * processingHeight;
|
650 |
+
width = (bbox[2] - bbox[0]) * processingWidth;
|
651 |
+
height = (bbox[3] - bbox[1]) * processingHeight;
|
652 |
+
} else if (bbox[2] > bbox[0] && bbox[3] > bbox[1]) {
|
653 |
+
// Likely [x1, y1, x2, y2]
|
654 |
+
x = bbox[0];
|
655 |
+
y = bbox[1];
|
656 |
+
width = bbox[2] - bbox[0];
|
657 |
+
height = bbox[3] - bbox[1];
|
658 |
+
} else {
|
659 |
+
// Assume [x, y, width, height]
|
660 |
+
x = bbox[0];
|
661 |
+
y = bbox[1];
|
662 |
+
width = bbox[2];
|
663 |
+
height = bbox[3];
|
664 |
+
}
|
665 |
+
}
|
666 |
+
} else {
|
667 |
+
// Object format with named properties
|
668 |
+
if (bbox.x !== undefined && bbox.y !== undefined) {
|
669 |
+
x = bbox.x;
|
670 |
+
y = bbox.y;
|
671 |
+
width = bbox.width || bbox.w || 0;
|
672 |
+
height = bbox.height || bbox.h || 0;
|
673 |
+
} else if (bbox.xmin !== undefined && bbox.ymin !== undefined) {
|
674 |
+
x = bbox.xmin;
|
675 |
+
y = bbox.ymin;
|
676 |
+
width = (bbox.xmax || 0) - bbox.xmin;
|
677 |
+
height = (bbox.ymax || 0) - bbox.ymin;
|
678 |
+
}
|
679 |
+
}
|
680 |
+
|
681 |
+
// Validate coordinates
|
682 |
+
if (x === undefined || y === undefined || width === undefined || height === undefined) {
|
683 |
+
return;
|
684 |
+
}
|
685 |
+
|
686 |
+
// Check if we need to scale normalized coordinates (0-1)
|
687 |
+
if (x >= 0 && x <= 1 && y >= 0 && y <= 1 && width >= 0 && width <= 1 && height >= 0 && height <= 1) {
|
688 |
+
x = x * processingWidth;
|
689 |
+
y = y * processingHeight;
|
690 |
+
width = width * processingWidth;
|
691 |
+
height = height * processingHeight;
|
692 |
+
}
|
693 |
+
|
694 |
+
// Generate a color based on the class name
|
695 |
+
const hue = stringToHue(label);
|
696 |
+
ctx.strokeStyle = `hsl(${hue}, 100%, 40%)`;
|
697 |
+
ctx.fillStyle = `hsla(${hue}, 100%, 40%, 0.3)`;
|
698 |
+
|
699 |
+
// Draw bounding box
|
700 |
+
ctx.beginPath();
|
701 |
+
ctx.rect(x, y, width, height);
|
702 |
+
ctx.stroke();
|
703 |
+
ctx.fill();
|
704 |
+
|
705 |
+
// Format confidence value
|
706 |
+
let confidenceText = "";
|
707 |
+
if (confidence !== null && confidence !== undefined) {
|
708 |
+
// Convert to percentage if it's a probability (0-1)
|
709 |
+
if (confidence <= 1) {
|
710 |
+
confidence = (confidence * 100).toFixed(0);
|
711 |
+
} else {
|
712 |
+
confidence = confidence.toFixed(0);
|
713 |
+
}
|
714 |
+
confidenceText = ` ${confidence}%`;
|
715 |
+
}
|
716 |
+
|
717 |
+
// Format distance value
|
718 |
+
let distanceText = "";
|
719 |
+
if (distance !== null && distance !== undefined) {
|
720 |
+
distanceText = ` : ${distance.toFixed(2)} m`;
|
721 |
+
}
|
722 |
+
|
723 |
+
// Create label text
|
724 |
+
const labelText = `${label}${confidenceText}${distanceText}`;
|
725 |
+
|
726 |
+
// Measure text width
|
727 |
+
const textWidth = ctx.measureText(labelText).width + 10;
|
728 |
+
|
729 |
+
// Draw label background
|
730 |
+
ctx.fillStyle = `hsl(${hue}, 100%, 40%)`;
|
731 |
+
ctx.fillRect(x, y - 20, textWidth, 20);
|
732 |
+
|
733 |
+
// Draw label text
|
734 |
+
ctx.fillStyle = "white";
|
735 |
+
ctx.fillText(labelText, x + 5, y - 5);
|
736 |
+
}
|
737 |
+
});
|
738 |
+
|
739 |
+
} catch (error) {
|
740 |
+
console.error('Error visualizing results:', error);
|
741 |
+
debugOutput.style.display = 'block';
|
742 |
+
debugOutput.textContent += `VISUALIZATION ERROR: ${error.message}\n`;
|
743 |
+
debugOutput.textContent += `Error stack: ${error.stack}\n`;
|
744 |
+
}
|
745 |
+
}
|
746 |
+
|
747 |
+
// Generate consistent hue for string
|
748 |
+
function stringToHue(str) {
|
749 |
+
let hash = 0;
|
750 |
+
for (let i = 0; i < str.length; i++) {
|
751 |
+
hash = str.charCodeAt(i) + ((hash << 5) - hash);
|
752 |
+
}
|
753 |
+
return hash % 360;
|
754 |
+
}
|
755 |
+
|
756 |
+
// Clear canvas
|
757 |
+
function clearCanvas() {
|
758 |
+
if (detectionCanvas.getContext) {
|
759 |
+
ctx.clearRect(0, 0, detectionCanvas.width, detectionCanvas.height);
|
760 |
+
}
|
761 |
+
}
|
762 |
+
|
763 |
+
// Show message function
|
764 |
+
function showMessage(text, type) {
|
765 |
+
message.textContent = text;
|
766 |
+
message.className = '';
|
767 |
+
message.classList.add(type);
|
768 |
+
message.style.display = 'block';
|
769 |
+
|
770 |
+
if (type === 'info') {
|
771 |
+
setTimeout(() => {
|
772 |
+
message.style.display = 'none';
|
773 |
+
}, 3000);
|
774 |
+
}
|
775 |
+
}
|
776 |
+
</script>
|
777 |
+
</body>
|
778 |
+
</html>
|
utils/JSON_output.py
ADDED
@@ -0,0 +1,43 @@
from models.predict_z_location_single_row_lstm import predict_z_location_single_row_lstm

def generate_output_json(data, ZlocE, scaler):
    """
    Predict Z-location for each object in the data and prepare the JSON output.

    Parameters:
    - data: DataFrame with bounding box coordinates, depth information, and class type.
    - ZlocE: Pre-loaded LSTM model for Z-location prediction.
    - scaler: Scaler for normalizing input data.

    Returns:
    - JSON structure with object class, distance estimated, and relevant features.
    """
    output_json = []

    # Iterate over each row in the data
    for i, row in data.iterrows():
        # Predict distance for each object using the single-row prediction function
        distance = predict_z_location_single_row_lstm(row, ZlocE, scaler)

        # Create object info dictionary
        object_info = {
            "class": row["class"],                   # Object class (e.g., 'car', 'truck')
            "distance_estimated": float(distance),   # Convert distance to float (if necessary)
            "features": {
                "xmin": float(row["xmin"]),                    # Bounding box xmin
                "ymin": float(row["ymin"]),                    # Bounding box ymin
                "xmax": float(row["xmax"]),                    # Bounding box xmax
                "ymax": float(row["ymax"]),                    # Bounding box ymax
                "mean_depth": float(row["depth_mean"]),        # Depth mean
                "depth_mean_trim": float(row["depth_mean_trim"]),  # Trimmed depth mean
                "depth_median": float(row["depth_median"]),    # Depth median
                "width": float(row["width"]),                  # Object width
                "height": float(row["height"])                 # Object height
            }
        }

        # Append each object info to the output JSON list
        output_json.append(object_info)

    # Return the final JSON output structure
    return {"objects": output_json}
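
For reference, one entry of the structure returned above looks roughly like this (numbers are illustrative, not real model output):

{
  "objects": [
    {
      "class": "car",
      "distance_estimated": 12.4,
      "features": {
        "xmin": 100.0, "ymin": 150.0, "xmax": 300.0, "ymax": 400.0,
        "mean_depth": 12.3, "depth_mean_trim": 12.0, "depth_median": 11.8,
        "width": 200.0, "height": 250.0
      }
    }
  ]
}
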
utils/__init__.py
ADDED
File without changes
utils/__pycache__/JSON_output.cpython-311.pyc
ADDED
Binary file (1.93 kB)
utils/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (183 Bytes)
utils/__pycache__/processing.cpython-311.pyc
ADDED
Binary file (3.25 kB)
utils/build/lib.win-amd64-cpython-311/processing_cy.cp311-win_amd64.pyd
ADDED
Binary file (71.7 kB)
utils/build/temp.win-amd64-cpython-311/Release/processing.obj
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:651a39060235c15b7fecf63a9e462ae0c379f5a194a745110cf259131fc235a2
size 539061
utils/build/temp.win-amd64-cpython-311/Release/processing_cy.cp311-win_amd64.exp
ADDED
Binary file (806 Bytes)
utils/build/temp.win-amd64-cpython-311/Release/processing_cy.cp311-win_amd64.lib
ADDED
Binary file (2.12 kB)
utils/processing.c
ADDED
The diff for this file is too large to render.
utils/processing.html
ADDED
The diff for this file is too large to render.
utils/processing.py
ADDED
@@ -0,0 +1,59 @@
"""
Created on Sat Apr 9 04:08:02 2022
@author: Admin_with ODD Team

Edited by our team: Sat Oct 12 2024

references: https://github.com/vinvino02/GLPDepth
"""

import numpy as np
import pandas as pd
import torch

from utils.processing_cy import process_bbox_depth_cy, handle_overlaps_cy

class PROCESSING:
    def process_detections(self, scores, boxes, depth_map, detr):
        self.data = pd.DataFrame(columns=['xmin', 'ymin', 'xmax', 'ymax', 'width', 'height',
                                          'depth_mean_trim', 'depth_mean', 'depth_median',
                                          'class', 'rgb'])

        boxes_array = np.array([[int(box[1]), int(box[0]), int(box[3]), int(box[2])]
                                for box in boxes.tolist()], dtype=np.int32)

        # Use Cython-optimized overlap handling
        valid_indices = handle_overlaps_cy(depth_map, boxes_array)

        for idx in valid_indices:
            p = scores[idx]
            box = boxes[idx]
            xmin, ymin, xmax, ymax = map(int, box)

            detected_class = p.argmax()
            class_label = detr.CLASSES[detected_class]

            # Map classes
            if class_label == 'motorcycle':
                class_label = 'bicycle'
            elif class_label == 'bus':
                class_label = 'train'
            elif class_label not in ['person', 'truck', 'car', 'bicycle', 'train']:
                class_label = 'Misc'

            if class_label in ['Misc', 'person', 'truck', 'car', 'bicycle', 'train']:
                # Use Cython-optimized depth calculations
                depth_mean, depth_median, (depth_trim_low, depth_trim_high) = \
                    process_bbox_depth_cy(depth_map, ymin, ymax, xmin, xmax)

                class_index = ['Misc', 'person', 'truck', 'car', 'bicycle', 'train'].index(class_label)
                r, g, b = detr.COLORS[class_index]
                rgb = (r * 255, g * 255, b * 255)

                new_row = pd.DataFrame([[xmin, ymin, xmax, ymax, xmax - xmin, ymax - ymin,
                                         (depth_trim_low + depth_trim_high) / 2,
                                         depth_mean, depth_median, class_label, rgb]],
                                       columns=self.data.columns)
                self.data = pd.concat([self.data, new_row], ignore_index=True)

        return self.data
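
A rough sketch of the inputs process_detections expects (all synthetic; the compiled processing_cy extension must be built before the commented call will run): scores holds one per-class probability vector per detection, boxes holds pixel-space (xmin, ymin, xmax, ymax), depth_map is the float32 depth output, and the detr argument only needs CLASSES and COLORS attributes.

# Synthetic example of the expected inputs; the real DETR/GLPN wrappers live elsewhere in the repo.
import numpy as np
import torch

class _DetrStub:                        # hypothetical stand-in for the DETR wrapper
    CLASSES = ['N/A', 'person', 'car']
    COLORS = [(0.1, 0.4, 0.6)] * 6      # needs at least 6 entries, one per mapped class

scores = np.array([[0.05, 0.05, 0.90]], dtype=np.float32)   # one detection, most likely 'car'
boxes = torch.tensor([[50.0, 60.0, 200.0, 180.0]])          # xmin, ymin, xmax, ymax in pixels
depth_map = np.random.rand(320, 640).astype(np.float32)     # stand-in for the GLPN depth map

# data = PROCESSING().process_detections(scores, boxes, depth_map, _DetrStub())
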
utils/processing.pyx
ADDED
@@ -0,0 +1,101 @@
# File: processing.pyx
# cython: language_level=3, boundscheck=False, wraparound=False, nonecheck=False, cdivision=True

import numpy as np
cimport numpy as np
import pandas as pd
from libc.math cimport isnan
from cpython.mem cimport PyMem_Malloc, PyMem_Free

# Define C types for better performance
ctypedef np.float32_t DTYPE_t
ctypedef np.int32_t ITYPE_t

def process_bbox_depth_cy(np.ndarray[DTYPE_t, ndim=2] depth_map,
                          int y_min, int y_max, int x_min, int x_max):
    """
    Optimized bbox depth calculations using Cython
    """
    cdef:
        int i, j, count = 0
        double sum_val = 0.0
        double mean_val = 0.0
        np.ndarray[DTYPE_t, ndim=1] flat_vals
        int flat_size = 0

    for i in range(y_min, y_max):
        for j in range(x_min, x_max):
            if not isnan(depth_map[i, j]):
                sum_val += depth_map[i, j]
                count += 1

    if count > 0:
        mean_val = sum_val / count

    # Create array for trimmed mean calculation
    flat_vals = np.zeros(count, dtype=np.float32)
    flat_size = 0

    for i in range(y_min, y_max):
        for j in range(x_min, x_max):
            if not isnan(depth_map[i, j]):
                flat_vals[flat_size] = depth_map[i, j]
                flat_size += 1

    return mean_val, np.median(flat_vals), np.percentile(flat_vals, [20, 80])

def handle_overlaps_cy(np.ndarray[DTYPE_t, ndim=2] depth_map,
                       np.ndarray[ITYPE_t, ndim=2] boxes):
    """
    Optimized overlap handling using Cython
    """
    cdef:
        int n_boxes = boxes.shape[0]
        int i, j
        int y_min1, y_max1, x_min1, x_max1
        int y_min2, y_max2, x_min2, x_max2
        double area1, area2, area_intersection
        bint* to_remove = <bint*>PyMem_Malloc(n_boxes * sizeof(bint))

    if not to_remove:
        raise MemoryError()

    try:
        for i in range(n_boxes):
            to_remove[i] = False

        for i in range(n_boxes):
            if to_remove[i]:
                continue

            y_min1, x_min1, y_max1, x_max1 = boxes[i]

            for j in range(i + 1, n_boxes):
                if to_remove[j]:
                    continue

                y_min2, x_min2, y_max2, x_max2 = boxes[j]

                # Calculate intersection
                y_min_int = max(y_min1, y_min2)
                y_max_int = min(y_max1, y_max2)
                x_min_int = max(x_min1, x_min2)
                x_max_int = min(x_max1, x_max2)

                if y_min_int < y_max_int and x_min_int < x_max_int:
                    area1 = (y_max1 - y_min1) * (x_max1 - x_min1)
                    area2 = (y_max2 - y_min2) * (x_max2 - x_min2)
                    area_intersection = (y_max_int - y_min_int) * (x_max_int - x_min_int)

                    if area_intersection / min(area1, area2) >= 0.70:
                        if area1 < area2:
                            to_remove[i] = True
                            break
                        else:
                            to_remove[j] = True

        return np.array([i for i in range(n_boxes) if not to_remove[i]], dtype=np.int32)

    finally:
        PyMem_Free(to_remove)
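
The suppression rule above keeps only one of two boxes whose intersection covers at least 70% of the smaller box, dropping the smaller one. A pure-Python reference of the same rule (not part of the upload, and omitting the unused depth_map argument) can be handy for checking the compiled extension:

# Pure-Python reference of the overlap rule, for sanity-checking the Cython build.
import numpy as np

def handle_overlaps_py(boxes):
    n = len(boxes)
    keep = [True] * n
    for i in range(n):
        if not keep[i]:
            continue
        y1a, x1a, y2a, x2a = boxes[i]
        for j in range(i + 1, n):
            if not keep[j]:
                continue
            y1b, x1b, y2b, x2b = boxes[j]
            yi1, yi2 = max(y1a, y1b), min(y2a, y2b)
            xi1, xi2 = max(x1a, x1b), min(x2a, x2b)
            if yi1 < yi2 and xi1 < xi2:
                a1 = (y2a - y1a) * (x2a - x1a)
                a2 = (y2b - y1b) * (x2b - x1b)
                inter = (yi2 - yi1) * (xi2 - xi1)
                if inter / min(a1, a2) >= 0.70:
                    if a1 < a2:
                        keep[i] = False   # drop the smaller box (box i)
                        break
                    else:
                        keep[j] = False   # drop the smaller box (box j)
    return np.array([i for i in range(n) if keep[i]], dtype=np.int32)
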
utils/processing_cy.cp311-win_amd64.pyd
ADDED
Binary file (71.7 kB)
utils/setup.py
ADDED
@@ -0,0 +1,18 @@
from setuptools import setup, Extension
from Cython.Build import cythonize
import numpy as np

extensions = [
    Extension(
        "processing_cy",
        ["processing.pyx"],
        include_dirs=[np.get_include()],
        extra_compile_args=["-O3", "-march=native", "-fopenmp"],
        extra_link_args=["-fopenmp"]
    )
]

setup(
    ext_modules=cythonize(extensions, annotate=True),
    include_dirs=[np.get_include()]
)
utils/visualization.py
ADDED
@@ -0,0 +1,42 @@
import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches


def plot_depth_with_boxes(depth_map, depth_data):
    """
    Plots the depth map with bounding boxes overlayed.

    Args:
        depth_map (numpy.ndarray): The depth map to visualize.
        depth_data (pandas.DataFrame): DataFrame containing bounding box coordinates, depth statistics, and class labels.
    """
    # Normalize the depth map for better visualization
    depth_map_normalized = cv2.normalize(depth_map, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

    # Create a figure and axis
    fig, ax = plt.subplots(1, figsize=(12, 6))

    # Display the depth map
    ax.imshow(depth_map_normalized, cmap='plasma')  # You can change the colormap as desired
    ax.axis('off')  # Hide the axes

    # Loop through the DataFrame and add rectangles
    for index, row in depth_data.iterrows():
        xmin, ymin, xmax, ymax = row[['xmin', 'ymin', 'xmax', 'ymax']]
        class_label = row['class']
        score = row['depth_mean']  # or whichever statistic you prefer to display

        # Create a rectangle patch
        rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, linewidth=2, edgecolor='yellow', facecolor='none')

        # Add the rectangle to the plot
        ax.add_patch(rect)

        # Add a text label
        ax.text(xmin, ymin - 5, f'{class_label}: {score:.2f}', color='white', fontsize=12, weight='bold')

    plt.title('Depth Map with Object Detection Bounding Boxes', fontsize=16)
    plt.show()