import joblib
import uvicorn
import xgboost as xgb
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, File, UploadFile, HTTPException
from fastapi.responses import HTMLResponse, JSONResponse
import asyncio
import json
import pickle
import warnings
import os
import io
import timeit
from PIL import Image
import numpy as np
import cv2
# OpenAPI documentation helpers
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from models.detr_model import DETR
from models.glpn_model import GLPDepth
from models.lstm_model import LSTM_Model
from models.predict_z_location_single_row_lstm import predict_z_location_single_row_lstm
from utils.processing import PROCESSING
from config import CONFIG

warnings.filterwarnings("ignore")

# Initialize FastAPI app
app = FastAPI(
    title="Real-Time WebSocket Image Processing API",
    description="API for object detection and depth estimation using WebSocket for real-time image processing.",
)

try:
    # Load models and utilities
    device = CONFIG['device']
    print("Loading models...")
    detr = DETR()  # Object detection model (DETR)
    print("DETR model loaded.")
    glpn = GLPDepth()  # Depth estimation model (GLPN)
    print("GLPDepth model loaded.")
    zlocE_LSTM = LSTM_Model()  # LSTM model for Z-location prediction
    print("LSTM model loaded.")
    with open(CONFIG['lstm_scaler_path'], 'rb') as f:
        lstm_scaler = pickle.load(f)  # Pre-fitted scaler for the LSTM inputs
    print("LSTM scaler loaded.")
    processing = PROCESSING()  # Utility class for post-processing
    print("Processing utilities loaded.")
except Exception as e:
    print(f"Failed to load models: {e}")
    raise  # Fail fast: the API cannot serve requests without its models

# Serve HTML documentation for the API
@app.get("/", response_class=HTMLResponse)  # route path assumed
async def get_docs():
    """
    Serve HTML documentation for the WebSocket-based image processing API.
    The HTML file must be available in the same directory.
    Returns a 404 error if the documentation file is not found.
    """
    html_path = os.path.join(os.path.dirname(__file__), "api_documentation.html")
    if not os.path.exists(html_path):
        return HTMLResponse(content="api_documentation.html file not found", status_code=404)
    with open(html_path, "r") as f:
        return HTMLResponse(f.read())

# Serve an interactive "try it" page for the API
@app.get("/try", response_class=HTMLResponse)  # route path assumed
async def get_try_page():
    """
    Serve an interactive page for trying the WebSocket-based image processing API.
    The HTML file must be available in the same directory.
    Returns a 404 error if the page file is not found.
    """
    html_path = os.path.join(os.path.dirname(__file__), "try_page.html")
    if not os.path.exists(html_path):
        return HTMLResponse(content="try_page.html file not found", status_code=404)
    with open(html_path, "r") as f:
        return HTMLResponse(f.read())

# Function to decode the image received via WebSocket
async def decode_image(image_bytes):
    """
    Decodes image bytes into a PIL Image and returns the image along with its shape.
    Args:
        image_bytes (bytes): The image data received from the client.
    Returns:
        tuple: A tuple containing:
            - pil_image (PIL.Image): The decoded image.
            - img_shape (tuple): Shape of the image as (height, width).
            - decode_time (float): Time taken to decode the image in seconds.
    Raises:
        ValueError: If image decoding fails.
    """
    start = timeit.default_timer()
    nparr = np.frombuffer(image_bytes, np.uint8)
    frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if frame is None:
        raise ValueError("Failed to decode image")
    color_converted = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(color_converted)
    img_shape = color_converted.shape[0:2]  # (height, width)
    end = timeit.default_timer()
    return pil_image, img_shape, end - start
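
# Sketch of how a client might produce the `image_bytes` payload expected above
# (any encoding cv2.imdecode understands, such as JPEG or PNG, will work):
#
#   import cv2
#   frame = cv2.imread("frame.jpg")        # BGR image from disk
#   ok, buf = cv2.imencode(".jpg", frame)  # compress to JPEG bytes
#   image_bytes = buf.tobytes()            # bytes to send to the API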

# Function to run the DETR model for object detection
async def run_detr_model(pil_image):
    """
    Runs the DETR (DEtection TRansformer) model to perform object detection on the input image.
    Args:
        pil_image (PIL.Image): The image to be processed by the DETR model.
    Returns:
        tuple: A tuple containing:
            - detr_result (tuple): The DETR model output consisting of detection scores and boxes.
            - detr_time (float): Time taken to run the DETR model in seconds.
    """
    start = timeit.default_timer()
    detr_result = await asyncio.to_thread(detr.detect, pil_image)
    end = timeit.default_timer()
    return detr_result, end - start

# Function to run the GLPN model for depth estimation
async def run_glpn_model(pil_image, img_shape):
    """
    Runs the GLPN (Global-Local Path Networks) model to estimate the depth of the objects in the image.
    Args:
        pil_image (PIL.Image): The image to be processed by the GLPN model.
        img_shape (tuple): The shape of the image as (height, width).
    Returns:
        tuple: A tuple containing:
            - depth_map (numpy.ndarray): The depth map for the input image.
            - glpn_time (float): Time taken to run the GLPN model in seconds.
    """
    start = timeit.default_timer()
    depth_map = await asyncio.to_thread(glpn.predict, pil_image, img_shape)
    end = timeit.default_timer()
    return depth_map, end - start

# Function to process the detections with the depth map
async def process_detections(scores, boxes, depth_map):
    """
    Processes the DETR model detections and integrates depth information from the GLPN model.
    Args:
        scores (numpy.ndarray): The detection scores for the detected objects.
        boxes (numpy.ndarray): The bounding boxes for the detected objects.
        depth_map (numpy.ndarray): The depth map generated by the GLPN model.
    Returns:
        tuple: A tuple containing:
            - pdata (pandas.DataFrame): Processed detection data including depth and bounding box information.
            - process_time (float): Time taken for processing detections in seconds.
    """
    start = timeit.default_timer()
    # Run in a worker thread, like the model wrappers, to avoid blocking the event loop
    pdata = await asyncio.to_thread(processing.process_detections, scores, boxes, depth_map, detr)
    end = timeit.default_timer()
    return pdata, end - start

# Function to generate JSON output for LSTM predictions
async def generate_json_output(data):
    """
    Predict the Z-location for each object in the data and prepare the JSON output.
    Uses the module-level LSTM model (zlocE_LSTM) and scaler (lstm_scaler).
    Args:
        data (pandas.DataFrame): Bounding box coordinates, depth statistics, and class type per object.
    Returns:
        tuple: A tuple containing:
            - dict: JSON structure with object class, estimated distance, and relevant features.
            - float: Time taken to generate the output in seconds.
    """
    output_json = []
    start = timeit.default_timer()
    # Iterate over each row in the data
    for _, row in data.iterrows():
        # Predict distance for each object using the single-row prediction function
        distance = predict_z_location_single_row_lstm(row, zlocE_LSTM, lstm_scaler)
        # Create object info dictionary
        object_info = {
            "class": row["class"],  # Object class (e.g., 'car', 'truck')
            "distance_estimated": float(distance),  # Ensure the distance is a plain float
            "features": {
                "xmin": float(row["xmin"]),  # Bounding box xmin
                "ymin": float(row["ymin"]),  # Bounding box ymin
                "xmax": float(row["xmax"]),  # Bounding box xmax
                "ymax": float(row["ymax"]),  # Bounding box ymax
                "mean_depth": float(row["depth_mean"]),  # Mean depth inside the box
                "depth_mean_trim": float(row["depth_mean_trim"]),  # Trimmed mean depth
                "depth_median": float(row["depth_median"]),  # Median depth
                "width": float(row["width"]),  # Object width
                "height": float(row["height"])  # Object height
            }
        }
        # Append each object info to the output JSON list
        output_json.append(object_info)
    end = timeit.default_timer()
    # Return the final JSON output structure and the elapsed time
    return {"objects": output_json}, end - start

# Function to process a single frame (image) in the WebSocket stream
async def process_frame(frame_id, image_bytes):
    """
    Processes a single frame (image) from the WebSocket stream. The process includes:
    - Decoding the image.
    - Running the DETR and GLPN models concurrently.
    - Processing detections and generating the final output JSON.
    Args:
        frame_id (int): The identifier for the frame being processed.
        image_bytes (bytes): The image data received from the WebSocket.
    Returns:
        dict: A dictionary containing the output JSON and timing information for each processing step.
    """
    timings = {}
    try:
        # Step 1: Decode the image
        pil_image, img_shape, decode_time = await decode_image(image_bytes)
        timings["decode_time"] = decode_time
        # Step 2: Run DETR and GLPN models in parallel
        (detr_result, detr_time), (depth_map, glpn_time) = await asyncio.gather(
            run_detr_model(pil_image),
            run_glpn_model(pil_image, img_shape)
        )
        models_time = max(detr_time, glpn_time)  # The models run concurrently, so take the longer of the two times
        timings["models_time"] = models_time
        # Step 3: Process detections with depth map
        scores, boxes = detr_result
        pdata, process_time = await process_detections(scores, boxes, depth_map)
        timings["process_time"] = process_time
        # Step 4: Generate output JSON
        output_json, json_time = await generate_json_output(pdata)
        timings["json_time"] = json_time
        timings["total_time"] = decode_time + models_time + process_time + json_time
        # Add frame_id and timings to the JSON output
        output_json["frame_id"] = frame_id
        output_json["timings"] = timings
        return output_json
    except Exception as e:
        return {
            "error": str(e),
            "frame_id": frame_id,
            "timings": timings
        }

# Process a single uploaded image (route path assumed)
@app.post("/process-image")
async def process_image(file: UploadFile = File(...)):
    """
    Process a single image for object detection and depth estimation.
    The endpoint performs:
    - Object detection using the DETR model
    - Depth estimation using the GLPN model
    - Z-location prediction using the LSTM model
    Parameters:
    - file: Image file to process (JPEG, PNG)
    Returns:
    - JSON response with detected objects, estimated distances, and timing information
    """
    try:
        # Read image content
        image_bytes = await file.read()
        if not image_bytes:
            raise HTTPException(status_code=400, detail="Empty file")
        # Use the same processing pipeline as the WebSocket endpoint
        result = await process_frame(0, image_bytes)
        # Propagate any pipeline error as a server error
        if "error" in result:
            raise HTTPException(status_code=500, detail=result["error"])
        return JSONResponse(content=result)
    except HTTPException:
        raise  # Preserve intended status codes (e.g., 400 for an empty file)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
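
# Example client call for the endpoint above, as a sketch (assumes the assumed
# /process-image route and a server running on localhost:8000):
#
#   import requests
#   with open("frame.jpg", "rb") as f:
#       resp = requests.post(
#           "http://localhost:8000/process-image",
#           files={"file": ("frame.jpg", f, "image/jpeg")},
#       )
#   resp.raise_for_status()
#   print(resp.json()["objects"])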

# Custom Swagger UI documentation (route path assumed)
@app.get("/api/docs", include_in_schema=False)
async def custom_swagger_ui_html():
    return get_swagger_ui_html(
        openapi_url="/api/openapi.json",
        title="Real-Time Image Processing API Documentation",
        swagger_js_url="https://cdnjs.cloudflare.com/ajax/libs/swagger-ui/4.18.3/swagger-ui-bundle.js",
        swagger_css_url="https://cdnjs.cloudflare.com/ajax/libs/swagger-ui/4.18.3/swagger-ui.css",
    )

# OpenAPI schema endpoint, mounted at the URL referenced by the Swagger UI above
@app.get("/api/openapi.json", include_in_schema=False)
async def get_open_api_endpoint():
    return get_openapi(
        title="Real-Time Image Processing API",
        version="1.0.0",
        description="API for object detection, depth estimation, and Z-location prediction using computer vision models",
        routes=app.routes,
    )

# WebSocket endpoint for real-time image processing (route path assumed)
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """
    WebSocket endpoint for real-time image processing. Clients send image frames and
    receive JSON output containing object detections, depth estimates, and predictions
    in real time.
    - Handles the reception of image data over the WebSocket.
    - Calls the image processing pipeline and returns the result.
    Args:
        websocket (WebSocket): The WebSocket connection to the client.
    """
    await websocket.accept()
    frame_id = 0
    try:
        while True:
            # Receive image bytes from the client
            image_bytes = await websocket.receive_bytes()
            frame_id += 1
            # Process the frame and send the result back to the client
            result = await process_frame(frame_id, image_bytes)
            await websocket.send_text(json.dumps(result))
    except WebSocketDisconnect:
        print(f"Client disconnected after processing {frame_id} frames.")
    except Exception as e:
        print(f"Unexpected error: {e}")
        await websocket.close()
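
# Example WebSocket client, as a sketch (assumes the /ws route above, a server on
# localhost:8000, and the third-party `websockets` package):
#
#   import asyncio, json, websockets
#
#   async def send_frame(image_bytes):
#       async with websockets.connect("ws://localhost:8000/ws") as ws:
#           await ws.send(image_bytes)            # binary frame in
#           result = json.loads(await ws.recv())  # JSON result out
#           print(result["frame_id"], result.get("objects"))
#
#   asyncio.run(send_frame(open("frame.jpg", "rb").read()))

# Run the app with uvicorn when executed directly (host and port are assumptions)
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)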