import copy
from collections import defaultdict
from pathlib import Path
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import torch
from loguru import logger
from torchvision.ops import box_iou


class Validator:
    def __init__(
        self,
        gt: List[Dict[str, torch.Tensor]],
        preds: List[Dict[str, torch.Tensor]],
        conf_thresh=0.5,
        iou_thresh=0.5,
    ) -> None:
        """
        Format example:
        gt = [{'labels': tensor([0]), 'boxes': tensor([[561.0, 297.0, 661.0, 359.0]])}, ...]
        len(gt) is the number of images
        bboxes are in format [x1, y1, x2, y2], absolute values
        """
        self.gt = gt
        self.preds = preds
        self.conf_thresh = conf_thresh
        self.iou_thresh = iou_thresh
        self.thresholds = np.arange(0.2, 1.0, 0.05)
        self.conf_matrix = None

    def compute_metrics(self, extended=False) -> Dict[str, float]:
        filtered_preds = filter_preds(copy.deepcopy(self.preds), self.conf_thresh)
        metrics = self._compute_main_metrics(filtered_preds)
        if not extended:
            metrics.pop("extended_metrics", None)
        return metrics

    def _compute_main_metrics(self, preds):
        (
            self.metrics_per_class,
            self.conf_matrix,
            self.class_to_idx,
        ) = self._compute_metrics_and_confusion_matrix(preds)
        tps, fps, fns = 0, 0, 0
        ious = []
        extended_metrics = {}
        for key, value in self.metrics_per_class.items():
            tps += value["TPs"]
            fps += value["FPs"]
            fns += value["FNs"]
            ious.extend(value["IoUs"])

            extended_metrics[f"precision_{key}"] = (
                value["TPs"] / (value["TPs"] + value["FPs"])
                if value["TPs"] + value["FPs"] > 0
                else 0
            )
            extended_metrics[f"recall_{key}"] = (
                value["TPs"] / (value["TPs"] + value["FNs"])
                if value["TPs"] + value["FNs"] > 0
                else 0
            )
            extended_metrics[f"iou_{key}"] = np.mean(value["IoUs"]) if value["IoUs"] else 0

        precision = tps / (tps + fps) if (tps + fps) > 0 else 0
        recall = tps / (tps + fns) if (tps + fns) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        iou = np.mean(ious).item() if ious else 0
        return {
            "f1": f1,
            "precision": precision,
            "recall": recall,
            "iou": iou,
            "TPs": tps,
            "FPs": fps,
            "FNs": fns,
            "extended_metrics": extended_metrics,
        }

    def _compute_matrix_multi_class(self, preds):
        metrics_per_class = defaultdict(lambda: {"TPs": 0, "FPs": 0, "FNs": 0, "IoUs": []})
        for pred, gt in zip(preds, self.gt):
            pred_boxes = pred["boxes"]
            pred_labels = pred["labels"]
            gt_boxes = gt["boxes"]
            gt_labels = gt["labels"]

            # isolate each class
            labels = torch.unique(torch.cat([pred_labels, gt_labels]))
            for label in labels:
                pred_cl_boxes = pred_boxes[pred_labels == label]  # filter by bool mask
                gt_cl_boxes = gt_boxes[gt_labels == label]

                n_preds = len(pred_cl_boxes)
                n_gts = len(gt_cl_boxes)
                if not (n_preds or n_gts):
                    continue
                if not n_preds:
                    metrics_per_class[label.item()]["FNs"] += n_gts
                    metrics_per_class[label.item()]["IoUs"].extend([0] * n_gts)
                    continue
                if not n_gts:
                    metrics_per_class[label.item()]["FPs"] += n_preds
                    metrics_per_class[label.item()]["IoUs"].extend([0] * n_preds)
                    continue

                ious = box_iou(pred_cl_boxes, gt_cl_boxes)  # matrix of all IoUs
                ious_mask = ious >= self.iou_thresh

                # indices of boxes that have IoU >= threshold
                pred_indices, gt_indices = torch.nonzero(ious_mask, as_tuple=True)
                if not pred_indices.numel():  # no predictions matched GTs
                    metrics_per_class[label.item()]["FNs"] += n_gts
                    metrics_per_class[label.item()]["IoUs"].extend([0] * n_gts)
                    metrics_per_class[label.item()]["FPs"] += n_preds
                    metrics_per_class[label.item()]["IoUs"].extend([0] * n_preds)
                    continue

                iou_values = ious[pred_indices, gt_indices]

                # sorting by IoU to match highest scores first
                sorted_indices = torch.argsort(-iou_values)
                pred_indices = pred_indices[sorted_indices]
                gt_indices = gt_indices[sorted_indices]
                iou_values = iou_values[sorted_indices]

                matched_preds = set()
                matched_gts = set()
                for pred_idx, gt_idx, iou in zip(pred_indices, gt_indices, iou_values):
                    if gt_idx.item() not in matched_gts and pred_idx.item() not in matched_preds:
                        matched_preds.add(pred_idx.item())
                        matched_gts.add(gt_idx.item())
                        metrics_per_class[label.item()]["TPs"] += 1
                        metrics_per_class[label.item()]["IoUs"].append(iou.item())

                unmatched_preds = set(range(n_preds)) - matched_preds
                unmatched_gts = set(range(n_gts)) - matched_gts
                metrics_per_class[label.item()]["FPs"] += len(unmatched_preds)
                metrics_per_class[label.item()]["IoUs"].extend([0] * len(unmatched_preds))
                metrics_per_class[label.item()]["FNs"] += len(unmatched_gts)
                metrics_per_class[label.item()]["IoUs"].extend([0] * len(unmatched_gts))

        return metrics_per_class

    def _compute_metrics_and_confusion_matrix(self, preds):
        # Initialize per-class metrics
        metrics_per_class = defaultdict(lambda: {"TPs": 0, "FPs": 0, "FNs": 0, "IoUs": []})

        # Collect all class IDs
        all_classes = set()
        for pred in preds:
            all_classes.update(pred["labels"].tolist())
        for gt in self.gt:
            all_classes.update(gt["labels"].tolist())
        all_classes = sorted(list(all_classes))
        class_to_idx = {cls_id: idx for idx, cls_id in enumerate(all_classes)}
        n_classes = len(all_classes)
        conf_matrix = np.zeros((n_classes + 1, n_classes + 1), dtype=int)  # +1 for background class

        for pred, gt in zip(preds, self.gt):
            pred_boxes = pred["boxes"]
            pred_labels = pred["labels"]
            gt_boxes = gt["boxes"]
            gt_labels = gt["labels"]

            n_preds = len(pred_boxes)
            n_gts = len(gt_boxes)
            if n_preds == 0 and n_gts == 0:
                continue

            ious = box_iou(pred_boxes, gt_boxes) if n_preds > 0 and n_gts > 0 else torch.tensor([])

            # Assign matches between preds and gts
            matched_pred_indices = set()
            matched_gt_indices = set()

            if ious.numel() > 0:
                # Collect all pred/gt pairs with IoU above the threshold
                ious_mask = ious >= self.iou_thresh
                pred_indices, gt_indices = torch.nonzero(ious_mask, as_tuple=True)
                iou_values = ious[pred_indices, gt_indices]

                # Sorting by IoU to match highest scores first
                sorted_indices = torch.argsort(-iou_values)
                pred_indices = pred_indices[sorted_indices]
                gt_indices = gt_indices[sorted_indices]
                iou_values = iou_values[sorted_indices]

                for pred_idx, gt_idx, iou in zip(pred_indices, gt_indices, iou_values):
                    if (
                        pred_idx.item() in matched_pred_indices
                        or gt_idx.item() in matched_gt_indices
                    ):
                        continue
                    matched_pred_indices.add(pred_idx.item())
                    matched_gt_indices.add(gt_idx.item())

                    pred_label = pred_labels[pred_idx].item()
                    gt_label = gt_labels[gt_idx].item()
                    pred_cls_idx = class_to_idx[pred_label]
                    gt_cls_idx = class_to_idx[gt_label]

                    # Update confusion matrix
                    conf_matrix[gt_cls_idx, pred_cls_idx] += 1

                    # Update per-class metrics
                    if pred_label == gt_label:
                        metrics_per_class[gt_label]["TPs"] += 1
                        metrics_per_class[gt_label]["IoUs"].append(iou.item())
                    else:
                        # Misclassification
                        metrics_per_class[gt_label]["FNs"] += 1
                        metrics_per_class[pred_label]["FPs"] += 1
                        metrics_per_class[gt_label]["IoUs"].append(0)
                        metrics_per_class[pred_label]["IoUs"].append(0)

            # Unmatched predictions (False Positives)
            unmatched_pred_indices = set(range(n_preds)) - matched_pred_indices
            for pred_idx in unmatched_pred_indices:
                pred_label = pred_labels[pred_idx].item()
                pred_cls_idx = class_to_idx[pred_label]
                # Update confusion matrix: background row
                conf_matrix[n_classes, pred_cls_idx] += 1
                # Update per-class metrics
                metrics_per_class[pred_label]["FPs"] += 1
                metrics_per_class[pred_label]["IoUs"].append(0)

            # Unmatched ground truths (False Negatives)
            unmatched_gt_indices = set(range(n_gts)) - matched_gt_indices
            for gt_idx in unmatched_gt_indices:
                gt_label = gt_labels[gt_idx].item()
                gt_cls_idx = class_to_idx[gt_label]
                # Update confusion matrix: background column
                conf_matrix[gt_cls_idx, n_classes] += 1
                # Update per-class metrics
                metrics_per_class[gt_label]["FNs"] += 1
                metrics_per_class[gt_label]["IoUs"].append(0)

        return metrics_per_class, conf_matrix, class_to_idx

    def save_plots(self, path_to_save) -> None:
        path_to_save = Path(path_to_save)
        path_to_save.mkdir(parents=True, exist_ok=True)

        if self.conf_matrix is not None:
            class_labels = [str(cls_id) for cls_id in self.class_to_idx.keys()] + ["background"]

            plt.figure(figsize=(10, 8))
            plt.imshow(self.conf_matrix, interpolation="nearest", cmap=plt.cm.Blues)
            plt.title("Confusion Matrix")
            plt.colorbar()
            tick_marks = np.arange(len(class_labels))
            plt.xticks(tick_marks, class_labels, rotation=45)
            plt.yticks(tick_marks, class_labels)

            # Add labels to each cell
            thresh = self.conf_matrix.max() / 2.0
            for i in range(self.conf_matrix.shape[0]):
                for j in range(self.conf_matrix.shape[1]):
                    plt.text(
                        j,
                        i,
                        format(self.conf_matrix[i, j], "d"),
                        horizontalalignment="center",
                        color="white" if self.conf_matrix[i, j] > thresh else "black",
                    )

            plt.ylabel("True label")
            plt.xlabel("Predicted label")
            plt.tight_layout()
            plt.savefig(path_to_save / "confusion_matrix.png")
            plt.close()

        thresholds = self.thresholds
        precisions, recalls, f1_scores = [], [], []

        # Store the original predictions to reset after each threshold
        original_preds = copy.deepcopy(self.preds)

        for threshold in thresholds:
            # Filter predictions based on the current threshold
            filtered_preds = filter_preds(copy.deepcopy(original_preds), threshold)
            # Compute metrics with the filtered predictions
            metrics = self._compute_main_metrics(filtered_preds)
            precisions.append(metrics["precision"])
            recalls.append(metrics["recall"])
            f1_scores.append(metrics["f1"])

        # Plot Precision and Recall vs Threshold
        plt.figure()
        plt.plot(thresholds, precisions, label="Precision", marker="o")
        plt.plot(thresholds, recalls, label="Recall", marker="o")
        plt.xlabel("Threshold")
        plt.ylabel("Value")
        plt.title("Precision and Recall vs Threshold")
        plt.legend()
        plt.grid(True)
        plt.savefig(path_to_save / "precision_recall_vs_threshold.png")
        plt.close()

        # Plot F1 Score vs Threshold
        plt.figure()
        plt.plot(thresholds, f1_scores, label="F1 Score", marker="o")
        plt.xlabel("Threshold")
        plt.ylabel("F1 Score")
        plt.title("F1 Score vs Threshold")
        plt.grid(True)
        plt.savefig(path_to_save / "f1_score_vs_threshold.png")
        plt.close()

        # Find the best threshold based on F1 Score (last occurrence)
        best_idx = len(f1_scores) - np.argmax(f1_scores[::-1]) - 1
        best_threshold = thresholds[best_idx]
        best_f1 = f1_scores[best_idx]
        logger.info(
            f"Best Threshold: {round(best_threshold, 2)} with F1 Score: {round(best_f1, 3)}"
        )


def filter_preds(preds, conf_thresh):
    for pred in preds:
        keep_idxs = pred["scores"] >= conf_thresh
        pred["scores"] = pred["scores"][keep_idxs]
        pred["boxes"] = pred["boxes"][keep_idxs]
        pred["labels"] = pred["labels"][keep_idxs]
    return preds


def scale_boxes(boxes, orig_shape, resized_shape):
    """
    boxes in format: [x1, y1, x2, y2], absolute values
    orig_shape: [height, width]
    resized_shape: [height, width]
    """
    scale_x = orig_shape[1] / resized_shape[1]
    scale_y = orig_shape[0] / resized_shape[0]
    boxes[:, 0] *= scale_x
    boxes[:, 2] *= scale_x
    boxes[:, 1] *= scale_y
    boxes[:, 3] *= scale_y
    return boxes
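

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): a toy example, assuming the
    # dict format described in the Validator.__init__ docstring. The boxes, scores, and the
    # "val_plots" output directory below are illustrative, not taken from real data.
    gt = [
        {
            "labels": torch.tensor([0]),
            "boxes": torch.tensor([[561.0, 297.0, 661.0, 359.0]]),
        }
    ]
    preds = [
        {
            "labels": torch.tensor([0, 1]),
            "boxes": torch.tensor(
                [
                    [560.0, 300.0, 660.0, 360.0],  # overlaps the GT box well -> counted as a TP
                    [100.0, 100.0, 150.0, 150.0],  # spurious low-score detection
                ]
            ),
            "scores": torch.tensor([0.9, 0.3]),  # the 0.3 box is dropped at conf_thresh=0.5
        }
    ]

    validator = Validator(gt, preds, conf_thresh=0.5, iou_thresh=0.5)
    print(validator.compute_metrics(extended=True))
    validator.save_plots("val_plots")  # confusion matrix and threshold-sweep curves as PNGs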