advances / app.py
OussamaElfila21's picture
Upload 35 files
5d92054 verified
import joblib
import uvicorn
import xgboost as xgb
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, File, UploadFile, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.responses import JSONResponse
import asyncio
import json
import pickle
import warnings
import os
import io
import timeit
from PIL import Image
import numpy as np
import cv2
# OpenAPI helpers used by the custom /api/docs and /api/openapi.json routes
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from models.detr_model import DETR
from models.glpn_model import GLPDepth
from models.lstm_model import LSTM_Model
from models.predict_z_location_single_row_lstm import predict_z_location_single_row_lstm
from utils.processing import PROCESSING
from config import CONFIG
warnings.filterwarnings("ignore")
# Create the FastAPI application instance that every route below is registered on.
app = FastAPI(
    title="Real-Time WebSocket Image Processing API",
    description="API for object detection and depth estimation using WebSocket for real-time image processing.",
)

# Load the models and helpers once, at import time, so the request handlers
# can use them as module-level globals.
try:
    device = CONFIG['device']
    print("Loading models...")
    detr = DETR()  # Object detection model (DETR)
    print("DETR model loaded.")
    glpn = GLPDepth()  # Depth estimation model (GLPN)
    print("GLPDepth model loaded.")
    zlocE_LSTM = LSTM_Model()  # LSTM model for prediction (e.g., localization)
    print("LSTM model loaded.")
    # NOTE(security): unpickling executes code embedded in the file, so the
    # configured scaler path must point to a trusted artifact.
    # The context manager closes the file handle as soon as the scaler is read.
    with open(CONFIG['lstm_scaler_path'], 'rb') as scaler_file:
        lstm_scaler = pickle.load(scaler_file)
    print("LSTM Scaler loaded.")
    processing = PROCESSING()  # Utility class for post-processing
    print("Processing utilities loaded.")
except Exception as e:
    # Best-effort startup: the failure is only reported and the module keeps
    # importing. Any name that was not bound above stays undefined, so a
    # handler that needs it will fail with NameError when it is called.
    print(f"An unexpected error occurred. Details: {e}")
# Serve HTML documentation for the API
@app.get("/", response_class=HTMLResponse)
async def get_docs():
    """
    Serve the HTML documentation page for the image processing API.

    The page is read from ``api_documentation.html`` in the same directory
    as this module.

    Returns:
        HTMLResponse: The file content, or a 404 response carrying a short
        message when the file does not exist.
    """
    html_path = os.path.join(os.path.dirname(__file__), "api_documentation.html")
    try:
        # Open directly instead of checking existence first: this avoids the
        # window between the check and the read. The encoding is pinned so the
        # result does not depend on the host locale (assumes the page is saved
        # as UTF-8 -- confirm against the shipped file).
        with open(html_path, "r", encoding="utf-8") as f:
            return HTMLResponse(f.read())
    except FileNotFoundError:
        return HTMLResponse(content="api_documentation.html file not found", status_code=404)
# NOTE: this handler reuses the function name of the "/" handler. Both routes
# still work because each one is registered by its decorator when it is
# defined, but the module-level name ``get_docs`` ends up bound to this one.
@app.get("/try_page", response_class=HTMLResponse)
async def get_docs():
    """
    Serve the interactive "try it" HTML page.

    The page is read from ``try_page.html`` in the same directory as this
    module.

    Returns:
        HTMLResponse: The file content, or a 404 response carrying a short
        message when the file does not exist.
    """
    html_path = os.path.join(os.path.dirname(__file__), "try_page.html")
    try:
        # Open directly instead of checking existence first: this avoids the
        # window between the check and the read. The encoding is pinned so the
        # result does not depend on the host locale (assumes the page is saved
        # as UTF-8 -- confirm against the shipped file).
        with open(html_path, "r", encoding="utf-8") as f:
            return HTMLResponse(f.read())
    except FileNotFoundError:
        return HTMLResponse(content="try_page.html file not found", status_code=404)
# Function to decode the image received via WebSocket
async def decode_image(image_bytes):
    """
    Decode encoded image bytes into an RGB PIL image.

    Args:
        image_bytes (bytes): The encoded image data received from the client.

    Returns:
        tuple: A tuple containing:
            - pil_image (PIL.Image): The decoded image in RGB channel order.
            - img_shape (tuple): The first two dimensions of the decoded
              array, i.e. (height, width).
            - decode_time (float): Time taken to decode the image in seconds.

    Raises:
        ValueError: If the input is empty or cannot be decoded as an image.
    """
    start = timeit.default_timer()
    # An empty payload is never a valid image. Reject it here so the caller
    # gets the documented ValueError instead of whatever OpenCV raises for an
    # empty buffer.
    if not image_bytes:
        raise ValueError("Failed to decode image")
    nparr = np.frombuffer(image_bytes, np.uint8)
    frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if frame is None:
        raise ValueError("Failed to decode image")
    # OpenCV decodes to BGR channel order; convert to RGB before wrapping it
    # in a PIL image.
    color_converted = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(color_converted)
    img_shape = color_converted.shape[0:2]  # (height, width)
    end = timeit.default_timer()
    return pil_image, img_shape, end - start
# Object detection step: DETR inference, timed
async def run_detr_model(pil_image):
    """
    Run DETR object detection on one image and measure how long it takes.

    Args:
        pil_image (PIL.Image): The image handed to ``detr.detect``.

    Returns:
        tuple: ``(detr_result, detr_time)`` where ``detr_result`` is whatever
        ``detr.detect`` returns (the frame pipeline unpacks it as scores and
        boxes) and ``detr_time`` is the elapsed time in seconds.
    """
    started_at = timeit.default_timer()
    # The call is pushed to a worker thread via asyncio.to_thread and awaited.
    detections = await asyncio.to_thread(detr.detect, pil_image)
    elapsed = timeit.default_timer() - started_at
    return detections, elapsed
# Depth estimation step: GLPN inference, timed
async def run_glpn_model(pil_image, img_shape):
    """
    Run GLPN depth estimation on one image and measure how long it takes.

    Args:
        pil_image (PIL.Image): The image handed to ``glpn.predict``.
        img_shape (tuple): The image shape as (height, width), forwarded
            unchanged to ``glpn.predict``.

    Returns:
        tuple: ``(depth_map, glpn_time)`` where ``depth_map`` is whatever
        ``glpn.predict`` returns and ``glpn_time`` is the elapsed time in
        seconds.
    """
    started_at = timeit.default_timer()
    # The call is pushed to a worker thread via asyncio.to_thread and awaited.
    estimated_depth = await asyncio.to_thread(glpn.predict, pil_image, img_shape)
    elapsed = timeit.default_timer() - started_at
    return estimated_depth, elapsed
# Post-processing step: merge detections with the depth map, timed
async def process_detections(scores, boxes, depth_map):
    """
    Combine the DETR detections with the GLPN depth map.

    The work is delegated to ``processing.process_detections``, which also
    receives the module-level ``detr`` instance. The call runs inline, not
    in a worker thread.

    Args:
        scores: Detection scores produced by the DETR step.
        boxes: Bounding boxes produced by the DETR step.
        depth_map: Depth map produced by the GLPN step.

    Returns:
        tuple: ``(pdata, process_time)`` where ``pdata`` is the value returned
        by ``processing.process_detections`` (it is later iterated with
        ``.iterrows()``, so presumably a pandas DataFrame -- confirm) and
        ``process_time`` is the elapsed time in seconds.
    """
    t0 = timeit.default_timer()
    merged = processing.process_detections(scores, boxes, depth_map, detr)
    return merged, timeit.default_timer() - t0
# Final step: per-object distance prediction and JSON assembly, timed
async def generate_json_output(data):
    """
    Predict a distance for every detected object and build the response payload.

    Each row of ``data`` is passed to ``predict_z_location_single_row_lstm``
    together with the module-level LSTM model and scaler.

    Args:
        data: Row-iterable table (accessed through ``.iterrows()``) providing
            the columns ``class``, ``xmin``, ``ymin``, ``xmax``, ``ymax``,
            ``depth_mean``, ``depth_mean_trim``, ``depth_median``, ``width``
            and ``height``.

    Returns:
        tuple: ``(payload, json_time)`` where ``payload`` is
        ``{"objects": [...]}`` with one entry per row (class, estimated
        distance and the numeric features as floats) and ``json_time`` is the
        elapsed time in seconds.
    """
    # Output feature name -> column it is read from. Only "mean_depth" differs
    # from its source column ("depth_mean").
    feature_columns = {
        "xmin": "xmin",
        "ymin": "ymin",
        "xmax": "xmax",
        "ymax": "ymax",
        "mean_depth": "depth_mean",
        "depth_mean_trim": "depth_mean_trim",
        "depth_median": "depth_median",
        "width": "width",
        "height": "height",
    }
    objects = []
    started_at = timeit.default_timer()
    for _, row in data.iterrows():
        # The prediction runs first, before any field of the row is read.
        distance = predict_z_location_single_row_lstm(row, zlocE_LSTM, lstm_scaler)
        objects.append(
            {
                "class": row["class"],
                "distance_estimated": float(distance),
                "features": {
                    name: float(row[column])
                    for name, column in feature_columns.items()
                },
            }
        )
    elapsed = timeit.default_timer() - started_at
    return {"objects": objects}, elapsed
# Full pipeline for one frame: decode -> DETR + GLPN -> merge -> JSON
async def process_frame(frame_id, image_bytes):
    """
    Run the complete processing pipeline on a single encoded image.

    Stages, in order: image decoding, DETR and GLPN inference (awaited
    together through ``asyncio.gather``), detection/depth merging, and JSON
    generation. Any exception raised by a stage is converted into an error
    payload instead of propagating.

    Args:
        frame_id (int): Identifier echoed back in the result.
        image_bytes (bytes): The encoded image data.

    Returns:
        dict: On success, the JSON payload extended with ``frame_id`` and
        ``timings``. On failure, a dict with ``error``, ``frame_id`` and the
        timings collected before the failure.
    """
    # Filled stage by stage so that a failure still reports the stages that ran.
    timings = {}
    try:
        # Stage 1: decode
        pil_image, img_shape, decode_time = await decode_image(image_bytes)
        timings["decode_time"] = decode_time

        # Stage 2: both models are awaited together
        detr_output, glpn_output = await asyncio.gather(
            run_detr_model(pil_image),
            run_glpn_model(pil_image, img_shape),
        )
        detr_result, detr_time = detr_output
        depth_map, glpn_time = glpn_output
        # The models are awaited together, so the slower one sets the stage duration.
        models_time = max(detr_time, glpn_time)
        timings["models_time"] = models_time

        # Stage 3: merge detections with depth
        scores, boxes = detr_result
        pdata, process_time = await process_detections(scores, boxes, depth_map)
        timings["process_time"] = process_time

        # Stage 4: build the payload
        print("generate json")
        output_json, json_time = await generate_json_output(pdata)
        print(output_json)
        timings["json_time"] = json_time
        timings["total_time"] = decode_time + models_time + process_time + json_time

        # Attach bookkeeping fields to the payload
        output_json.update(frame_id=frame_id, timings=timings)
        return output_json
    except Exception as exc:
        return dict(error=str(exc), frame_id=frame_id, timings=timings)
@app.post("/api/predict", summary="Process a single image for object detection and depth estimation")
async def process_image(file: UploadFile = File(...)):
    """
    Process a single image for object detection and depth estimation.

    The endpoint performs:
    - Object detection using DETR model
    - Depth estimation using GLPN model
    - Z-location prediction using LSTM model

    Parameters:
    - file: Image file to process (JPEG, PNG)

    Returns:
    - JSON response with detected objects, estimated distances, and timing information

    Raises:
    - HTTPException 400 when the uploaded file is empty
    - HTTPException 500 when the processing pipeline reports an error or an
      unexpected exception occurs
    """
    try:
        # Read image content
        image_bytes = await file.read()
        if not image_bytes:
            raise HTTPException(status_code=400, detail="Empty file")
        # Use the same processing pipeline as the WebSocket endpoint
        result = await process_frame(0, image_bytes)
        # process_frame reports failures as an "error" key instead of raising
        if "error" in result:
            raise HTTPException(status_code=500, detail=result["error"])
        return JSONResponse(content=result)
    except HTTPException:
        # HTTPException is itself an Exception. Without this clause the 400
        # raised above would be caught below and turned into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# Custom Swagger UI page, kept out of the generated schema
@app.get("/api/docs", include_in_schema=False)
async def custom_swagger_ui_html():
    """
    Return the Swagger UI page built by ``get_swagger_ui_html``.

    The page is pointed at the schema served under ``/api/openapi.json`` and
    loads the Swagger UI script and stylesheet from the cdnjs URLs below.
    """
    page_options = {
        "openapi_url": "/api/openapi.json",
        "title": "Real-Time Image Processing API Documentation",
        "swagger_js_url": "https://cdnjs.cloudflare.com/ajax/libs/swagger-ui/4.18.3/swagger-ui-bundle.js",
        "swagger_css_url": "https://cdnjs.cloudflare.com/ajax/libs/swagger-ui/4.18.3/swagger-ui.css",
    }
    return get_swagger_ui_html(**page_options)
# OpenAPI schema for the custom docs page, kept out of the schema itself
@app.get("/api/openapi.json", include_in_schema=False)
async def get_open_api_endpoint():
    """
    Return the OpenAPI schema built by ``get_openapi`` from the app's routes.

    The schema is rebuilt on every request from ``app.routes`` with the
    title, version and description given below.
    """
    schema_options = {
        "title": "Real-Time Image Processing API",
        "version": "1.0.0",
        "description": "API for object detection, depth estimation, and z-location prediction using computer vision models",
        "routes": app.routes,
    }
    schema = get_openapi(**schema_options)
    return schema
@app.websocket("/ws/predict")
async def websocket_endpoint(websocket: WebSocket):
    """
    WebSocket endpoint for real-time image processing.

    The client sends one encoded image per binary message. Each image goes
    through ``process_frame`` and the resulting dict is sent back as a JSON
    text message. Frames are handled one at a time, in arrival order.

    Args:
        websocket (WebSocket): The WebSocket connection to the client.
    """
    await websocket.accept()
    frame_id = 0  # number of frames received so far; also the id of the latest frame
    client_disconnected = False
    try:
        while True:
            # Receive image bytes from the client
            image_bytes = await websocket.receive_bytes()
            # Count the frame only after it has arrived, so the disconnect
            # message below reports the real number of frames.
            frame_id += 1
            result = await process_frame(frame_id, image_bytes)
            # Send the result back to the client
            await websocket.send_text(json.dumps(result))
    except WebSocketDisconnect:
        client_disconnected = True
        print(f"Client disconnected after processing {frame_id} frames.")
    except Exception as e:
        print(f"Unexpected error: {e}")
    finally:
        # After a client disconnect there is nothing left to close. Sending a
        # close frame at that point is an error, so it is skipped.
        if not client_disconnected:
            try:
                await websocket.close()
            except RuntimeError:
                # The connection was already closed underneath us; the
                # handler is ending anyway.
                pass