File size: 1,953 Bytes
5d92054
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
import numpy as np


def predict_z_location_single_row_lstm(row, ZlocE, scaler):
    """
    Predict the Z-location for a single detection row.

    Nine bounding-box/depth fields read from ``row`` are joined with a
    six-slot one-hot encoding of ``row['class']``, passed through
    ``scaler.transform`` and then handed to ``ZlocE.predict``.

    Parameters:
    - row: Record indexable by column name; must provide 'class' plus the
      nine numeric fields listed in ``feature_columns`` below, and support
      selection by a list of names returning an object with ``.values``.
    - ZlocE: Model object exposing ``predict``; its result must support
      ``.detach().numpy()``.
    - scaler: Object exposing ``transform`` for the 15-value feature vector.

    Returns:
    - z_loc_prediction: Element 0 of the model output after conversion to NumPy.
    """
    feature_columns = ['xmin', 'ymin', 'xmax', 'ymax', 'width', 'height', 'depth_mean', 'depth_median',
                       'depth_mean_trim']
    # Slots 1..5 of the one-hot vector, in order; slot 0 is the fallback
    # used for every class that is not listed here.
    known_labels = ('bicycle', 'car', 'person', 'train', 'truck')

    label = row['class']
    hot_slot = next(
        (slot for slot, known in enumerate(known_labels, start=1) if label == known),
        0,
    )
    encoding = [0] * (len(known_labels) + 1)
    encoding[hot_slot] = 1
    class_tensor = torch.tensor([encoding], dtype=torch.float32)

    # Numeric features as a float32 array with one leading axis added by the
    # list wrapper (np.array copies, so the tensor never aliases the row).
    numeric = np.array([row[feature_columns].values], dtype=np.float32)
    features = torch.cat([torch.from_numpy(numeric), class_tensor], dim=-1)

    # Normalise, then prepend two singleton axes before calling the model.
    # NOTE(review): this yields two extra leading dimensions on top of
    # whatever ``transform`` returns -- confirm that is the layout
    # ``ZlocE.predict`` expects.
    scaled = torch.tensor(scaler.transform(features), dtype=torch.float32)
    model_input = scaled.unsqueeze(0)
    model_input = model_input.unsqueeze(0)

    raw_output = ZlocE.predict(model_input)
    return raw_output.detach().numpy()[0]