import torch
from functools import partial
from pathlib import Path
from typing import Any, Dict, Tuple

# Imports from the common package (assumed to live at the same level as tvcalib)
from common.infer.base import InferDataModule, InferModule
# from common.registry import Registry  # still commented out: source unknown
# from common.utils import to_cuda  # still commented out: source unknown
# import project as p  # removed: probably tied to the full project

import torchvision.transforms as T

# Relative imports within tvcalib (kept relative)
from ..sn_segmentation.src.custom_extremities import (
    generate_class_synthesis,
    get_line_extremities,
)
from ..models.segmentation import InferenceSegmentationModel
from ..data.dataset import InferenceDatasetCalibration
from ..data.utils import custom_list_collate
from ..cam_modules import CameraParameterWLensDistDictZScore, SNProjectiveCamera
from ..utils.linalg import distance_line_pointcloud_3d, distance_point_pointcloud
from ..utils.objects_3d import SoccerPitchLineCircleSegments, SoccerPitchSNCircleCentralSplit
from ..cam_distr.tv_main_center import get_cam_distr, get_dist_distr
from ..utils.io import detach_dict, tensor2list

# Import from the common package
from common.data.utils import yards

from kornia.geometry.conversions import convert_points_to_homogeneous
from tqdm.auto import tqdm

# Commented out: tied to the 'robust' method and may pull in extra dependencies
# from methods.robust.loggers.preview import RobustPreviewLogger
import numpy as np


class TvCalibInferModule(InferModule):
    def __init__(
        self,
        segmentation_checkpoint: Path,
        image_shape=(720, 1280),
        optim_steps=2000,
        lens_dist: bool = False,
        playfield_size=(105, 68),
        make_images: bool = False,
    ):
        self.image_shape = image_shape
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.make_images = make_images

        # We would use the logger to draw visualizations, but the
        # RobustPreviewLogger import is commented out above.
        # self.previewer = RobustPreviewLogger(None, num_images=1)

        self.fn_generate_class_synthesis = partial(generate_class_synthesis, radius=4)
        self.fn_get_line_extremities = partial(
            get_line_extremities,
            maxdist=30,
            width=455,
            height=256,
            num_points_lines=4,
            num_points_circles=8,
        )

        # Segmentation model
        self.model_seg = InferenceSegmentationModel(segmentation_checkpoint, self.device)

        self.object3d = SoccerPitchLineCircleSegments(
            device=self.device, base_field=SoccerPitchSNCircleCentralSplit()
        )
        self.object3dcpu = SoccerPitchLineCircleSegments(
            device="cpu", base_field=SoccerPitchSNCircleCentralSplit()
        )

        # Calibration module
        batch_size_calib = 1
        self.model_calib = TVCalibModule(
            self.object3d,
            get_cam_distr(1.96, batch_size_calib, 1),
            get_dist_distr(batch_size_calib, 1) if lens_dist else None,
            (image_shape[0], image_shape[1]),
            optim_steps,
            self.device,
            log_per_step=False,
            tqdm_kwargs=None,
        )

        self.resize = T.Compose([T.Resize(size=(256, 455))])

        # Translation that shifts pitch coordinates by half the playfield size,
        # moving the origin from the pitch center to the corner.
        self.offset = np.array(
            [
                [1, 0, playfield_size[0] / 2.0],
                [0, 1, playfield_size[1] / 2.0],
                [0, 0, 1],
            ]
        )

    def setup(self, datamodule: InferDataModule):
        pass

    def predict(self, x: Any) -> Dict:
        """
        1. Run segmentation & pick keypoints
        2. Calibrate based on the selected points
        """
        # Segment
        keypoints = self._segment(x["image"])
        # Calibrate
        homo = self._calibrate(keypoints)

        # Rescaling to 720p and drawing the predicted playing field both require
        # the previewer, which is commented out above; the original code is kept
        # below for reference.
        # image_720p = self.previewer.to_image(x["image"].clone().detach().cpu())
        # if homo is not None:
        #     to_yards = np.array([
        #         [yards(1.0), 0, 0],
        #         [0, yards(1.0), 0],
        #         [0, 0, 1],
        #     ])
        #     homo_yd = to_yards @ homo
        #     try:
        #         inv_homo = np.linalg.inv(homo_yd) @ self.previewer.scale
        #         image_720p = self.previewer.draw_playfield(
        #             image_720p,
        #             self.previewer.image_playfield,
        #             inv_homo,
        #             color=(255, 0, 0),
        #             alpha=1.0,
        #             flip=False,
        #         )
        #     except np.linalg.LinAlgError:
        #         # The homography might not be invertible.
        #         pass

        result = {"homography": homo}
        if self.make_images:
            # result["image_720p"] = image_720p  # unavailable without the previewer
            pass
        return result

    def _segment(self, image):
        # Image -> <1;3;256;455>
        image = self.resize(image)
        with torch.no_grad():
            sem_lines = self.model_seg.inference(image.unsqueeze(0).to(self.device))
        # sem_lines = sem_lines.detach().cpu().numpy().astype(np.uint8)
        # Point selection
        skeletons_batch = self.fn_generate_class_synthesis(sem_lines[0])
        keypoints_raw_batch = self.fn_get_line_extremities(skeletons_batch)
        # Return the keypoints
        return keypoints_raw_batch

    def _calibrate(self, keypoints):
        # Just wrap around the keypoints
        ds = InferenceDatasetCalibration(
            [keypoints],
            self.image_shape[1],
            self.image_shape[0],
            self.object3d,
        )
        # Get the first item and optimize it
        _batch_size = 1
        x_dict = custom_list_collate([ds[0]])
        try:
            # previous_params handling is done inside self_optim_batch
            per_sample_loss, cam, _ = self.model_calib.self_optim_batch(x_dict)
            output_dict = tensor2list(
                detach_dict({**cam.get_parameters(_batch_size), **per_sample_loss})
            )
            homo = output_dict["homography"][0]
            if len(homo) > 0:
                homo = np.array(homo[0])
                to_yards = np.array(
                    [
                        [yards(1), 0, 0],
                        [0, yards(1), 0],
                        [0, 0, 1],
                    ]
                )
                # Shift the homography by half the playing field
                homo = to_yards @ self.offset @ homo
            else:
                homo = None
        except Exception as e:
            print(f"Calibration failed: {e}")
            homo = None
        return homo
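
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module's API): how a
# single frame might be pushed through the module. The checkpoint path is a
# placeholder, and the CHW float tensor stands in for a real RGB frame.
#
#   module = TvCalibInferModule(
#       segmentation_checkpoint=Path("checkpoints/segmentation.pt"),  # hypothetical path
#       image_shape=(720, 1280),
#       optim_steps=2000,
#   )
#   frame = torch.rand(3, 720, 1280)        # stand-in for a real frame
#   out = module.predict({"image": frame})  # {"homography": 3x3 ndarray or None}
# ---------------------------------------------------------------------------
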
class TVCalibModule(torch.nn.Module):
    def __init__(
        self,
        model3d,
        cam_distr,
        dist_distr,
        image_dim: Tuple[int, int],
        optim_steps: int,
        device="cpu",
        tqdm_kwargs=None,
        log_per_step=False,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.image_height, self.image_width = image_dim
        self.principal_point = (self.image_width / 2, self.image_height / 2)
        self.model3d = model3d
        self.cam_param_dict = CameraParameterWLensDistDictZScore(
            cam_distr, dist_distr, device=device
        )

        self.lens_distortion_active = dist_distr is not None
        self.optim_steps = optim_steps
        self._device = device

        # Attribute storing the previous batch's optimized parameters (warm start)
        self.previous_params = None

        self.optim = torch.optim.AdamW(
            self.cam_param_dict.param_dict.parameters(), lr=0.1, weight_decay=0.01
        )
        self.Scheduler = partial(
            torch.optim.lr_scheduler.OneCycleLR,
            max_lr=0.05,
            total_steps=self.optim_steps,
            pct_start=0.5,
        )

        if self.lens_distortion_active:
            self.optim_lens_distortion = torch.optim.AdamW(
                self.cam_param_dict.param_dict_dist.parameters(), lr=1e-3, weight_decay=0.01
            )
            self.Scheduler_lens_distortion = partial(
                torch.optim.lr_scheduler.OneCycleLR,
                max_lr=1e-3,
                total_steps=self.optim_steps,
                pct_start=0.33,
                optimizer=self.optim_lens_distortion,
            )

        self.tqdm_kwargs = tqdm_kwargs if tqdm_kwargs is not None else {}
        self.hparams = {"optim": str(self.optim), "scheduler": str(self.Scheduler)}
        self.log_per_step = log_per_step
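
    # Note on the setup above (descriptive): camera parameters and lens-distortion
    # parameters are optimized by two separate AdamW instances, each paired with
    # its own OneCycleLR schedule (the distortion schedule warms up faster:
    # pct_start=0.33 vs. 0.5). Both schedulers are stored as partials and
    # re-instantiated in self_optim_batch, so every batch restarts its
    # learning-rate cycle from scratch.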

    def forward(self, x):
        # Individual camera parameters & distortion parameters
        phi_hat, psi_hat = self.cam_param_dict()
        cam = SNProjectiveCamera(
            phi_hat,
            psi_hat,
            self.principal_point,
            self.image_width,
            self.image_height,
            device=self._device,
            nan_check=False,
        )

        # (batch_size, num_views_per_cam, 3, num_segments, num_points)
        points_px_lines_true = x["lines__ndc_projected_selection_shuffled"].to(self._device)
        batch_size, T_l, _, S_l, N_l = points_px_lines_true.shape

        # Project circle points
        points_px_circles_true = x["circles__ndc_projected_selection_shuffled"].to(self._device)
        _, T_c, _, S_c, N_c = points_px_circles_true.shape
        assert T_c == T_l

        #################### line-to-point distance in pixel space ####################

        # Start and end point (in world coordinates) for each line segment
        points3d_lines_keypoints = self.model3d.line_segments  # (3, S_l, 2)
        # (3, S_l, 2) to (S_l * 2, 3)
        points3d_lines_keypoints = points3d_lines_keypoints.reshape(3, S_l * 2).transpose(0, 1)
        points_px_lines_keypoints = convert_points_to_homogeneous(
            cam.project_point2ndc(points3d_lines_keypoints, lens_distortion=False)
        )  # (batch_size, T_l, S_l*2, 3)

        if batch_size < cam.batch_dim:
            # Actual batch_size smaller than expected, i.e. the last batch
            points_px_lines_keypoints = points_px_lines_keypoints[:batch_size]

        points_px_lines_keypoints = points_px_lines_keypoints.view(batch_size, T_l, S_l, 2, 3)
        lp1 = points_px_lines_keypoints[..., 0, :].unsqueeze(-2)  # -> (batch_size, T_l, S_l, 1, 3)
        lp2 = points_px_lines_keypoints[..., 1, :].unsqueeze(-2)  # -> (batch_size, T_l, S_l, 1, 3)

        # (batch_size, T, 3, S, N) -> (batch_size, T, 3, S*N) -> (batch_size, T, S*N, 3)
        # -> (batch_size, T, S, N, 3)
        pc = (
            points_px_lines_true.view(batch_size, T_l, 3, S_l * N_l)
            .transpose(2, 3)
            .view(batch_size, T_l, S_l, N_l, 3)
        )
        if self.lens_distortion_active:
            # Undistort the given points
            pc = pc.view(batch_size, T_l, S_l * N_l, 3)
            pc = pc.detach().clone()
            pc[..., :2] = cam.undistort_points(
                pc[..., :2], cam.intrinsics_ndc, num_iters=1
            )  # num_iters=1 might be enough for a good approximation
            pc = pc.view(batch_size, T_l, S_l, N_l, 3)

        distances_px_lines_raw = distance_line_pointcloud_3d(
            e1=lp2 - lp1, r1=lp1, pc=pc, reduce=None
        )  # (batch_size, T_l, S_l, N_l)
        distances_px_lines_raw = distances_px_lines_raw.unsqueeze(-3)
        # (..., 1, S_l, N_l), i.e. (batch_size, T_l, 1, S_l, N_l)

        #################### circle-to-point distance in pixel space ####################

        # Circle segments are approximated as point clouds of size N_c_star
        points3d_circles_pc = self.model3d.circle_segments
        _, S_c, N_c_star = points3d_circles_pc.shape
        points3d_circles_pc = points3d_circles_pc.reshape(3, S_c * N_c_star).transpose(0, 1)
        points_px_circles_pc = cam.project_point2ndc(points3d_circles_pc, lens_distortion=False)

        if batch_size < cam.batch_dim:
            # Actual batch_size smaller than expected, i.e. the last batch
            points_px_circles_pc = points_px_circles_pc[:batch_size]

        if self.lens_distortion_active:
            # (batch_size, T_c, 3, S_c, N_c) -> (batch_size, T_c, S_c*N_c, 3)
            points_px_circles_true = points_px_circles_true.view(
                batch_size, T_c, 3, S_c * N_c
            ).transpose(2, 3)
            points_px_circles_true = points_px_circles_true.detach().clone()
            points_px_circles_true[..., :2] = cam.undistort_points(
                points_px_circles_true[..., :2], cam.intrinsics_ndc, num_iters=1
            )
            points_px_circles_true = points_px_circles_true.transpose(2, 3).view(
                batch_size, T_c, 3, S_c, N_c
            )

        distances_px_circles_raw = distance_point_pointcloud(
            points_px_circles_true, points_px_circles_pc.view(batch_size, T_c, S_c, N_c_star, 2)
        )

        distances_dict = {
            "loss_ndc_lines": distances_px_lines_raw,  # (batch_size, T_l, 1, S_l, N_l)
            "loss_ndc_circles": distances_px_circles_raw,  # (batch_size, T_c, 1, S_c, N_c)
        }
        return distances_dict, cam
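
    # Shape walkthrough for forward() (descriptive): with S_l line segments
    # sampled at N_l points and S_c circle segments sampled at N_c points,
    # the returned distance maps are
    #   loss_ndc_lines:   (batch_size, T_l, 1, S_l, N_l)  point-to-line distances in NDC
    #   loss_ndc_circles: (batch_size, T_c, 1, S_c, N_c)  point-to-point-cloud distances in NDC
    # self_optim_batch masks these with the per-point keypoint masks before
    # reducing them to a scalar loss.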

    def self_optim_batch(self, x, *args, **kwargs):
        scheduler = self.Scheduler(self.optim)  # re-initialize lr scheduler for every batch
        if self.lens_distortion_active:
            scheduler_lens_distortion = self.Scheduler_lens_distortion()

        # Initialize from the previous batch's parameters when available
        if self.previous_params is not None:
            print("Initializing from the previous frame's parameters")
            update_dict = {k: v.detach().clone() for k, v in self.previous_params.items()}
            self.cam_param_dict.initialize(update_dict)
        else:
            print("First frame: initializing from scratch")
            self.cam_param_dict.initialize(None)

        self.optim.zero_grad()
        if self.lens_distortion_active:
            self.optim_lens_distortion.zero_grad()

        keypoint_masks = {
            "loss_ndc_lines": x["lines__is_keypoint_mask"].to(self._device),
            "loss_ndc_circles": x["circles__is_keypoint_mask"].to(self._device),
        }
        num_actual_points = {
            "loss_ndc_lines": keypoint_masks["loss_ndc_lines"].sum(dim=(-1, -2)),
            "loss_ndc_circles": keypoint_masks["loss_ndc_circles"].sum(dim=(-1, -2)),
        }

        per_sample_loss = {}
        per_sample_loss["mask_lines"] = keypoint_masks["loss_ndc_lines"]
        per_sample_loss["mask_circles"] = keypoint_masks["loss_ndc_circles"]

        per_step_info = {"loss": [], "lr": []}

        # Parameters for the (currently disabled) early-stopping criteria
        loss_target = 0.001  # lowered for potentially better precision
        loss_patience = 10  # number of steps over which stagnation is checked
        loss_tolerance = 1e-4  # tolerance on the relative loss variation
        loss_history = []  # history of loss values
        best_loss = float("inf")  # best loss seen so far
        steps_without_improvement = 0  # steps since the last improvement

        # with torch.autograd.detect_anomaly():
        with tqdm(range(self.optim_steps), **self.tqdm_kwargs) as pbar:
            for step in pbar:
                self.optim.zero_grad()
                if self.lens_distortion_active:
                    self.optim_lens_distortion.zero_grad()

                # Forward pass
                distances_dict, cam = self(x)

                # Distance computation with masked input and output
                losses = {}
                for key_dist, distances in distances_dict.items():
                    distances[~keypoint_masks[key_dist]] = 0.0
                    per_sample_loss[f"{key_dist}_distances_raw"] = distances
                    distances_reduced = distances.sum(dim=(-1, -2))
                    distances_reduced = distances_reduced / num_actual_points[key_dist]
                    distances_reduced[num_actual_points[key_dist] == 0] = 0.0
                    distances_reduced = distances_reduced.squeeze(-1)
                    per_sample_loss[key_dist] = distances_reduced
                    loss = distances_reduced.mean(dim=-1)
                    loss = loss.sum()
                    losses[key_dist] = loss

                loss_total_dist = losses["loss_ndc_lines"] + losses["loss_ndc_circles"]
                loss_total = loss_total_dist
                current_loss = loss_total.item()

                # Update the loss history
                loss_history.append(current_loss)
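
                # Reduction recap (descriptive): each masked distance map of shape
                # (B, T, 1, S, N) is summed over (S, N), normalized by the number
                # of valid keypoints, averaged over the T views, and summed over
                # the batch, yielding one scalar per loss term.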

                # Check whether we have a new best loss
                if current_loss < best_loss:
                    best_loss = current_loss
                    steps_without_improvement = 0
                else:
                    steps_without_improvement += 1

                # Early-stopping criteria (commented out to force the full number of steps)
                # if len(loss_history) >= loss_patience:
                #     # Relative variation over the most recent iterations
                #     recent_losses = loss_history[-loss_patience:]
                #     # Handle the case where all recent losses are (close to) zero
                #     max_recent_loss = max(max(recent_losses), 1e-9)  # avoids division by zero
                #     loss_variation = abs(max(recent_losses) - min(recent_losses)) / max_recent_loss
                #
                #     # Stopping conditions
                #     if (current_loss <= loss_target  # target value reached
                #             or loss_variation < loss_tolerance  # loss no longer varies significantly
                #             or steps_without_improvement >= loss_patience):  # no improvement for a while
                #         print(f"\nEarly stop at iteration {step + 1}:")
                #         print(f"Final loss: {current_loss:.5f}")
                #         print(f"Best loss: {best_loss:.5f}")
                #         print(f"Relative variation: {loss_variation:.6f}")
                #         break

                if self.log_per_step:
                    per_step_info["lr"].append(scheduler.get_last_lr())
                    per_step_info["loss"].append(distances_reduced)
                if step % 50 == 0:
                    pbar.set_postfix(
                        loss=f"{loss_total_dist.item():.5f}",
                        loss_lines=f'{losses["loss_ndc_lines"].item():.3f}',
                        loss_circles=f'{losses["loss_ndc_circles"].item():.3f}',
                    )

                loss_total.backward()
                self.optim.step()
                scheduler.step()
                if self.lens_distortion_active:
                    self.optim_lens_distortion.step()
                    scheduler_lens_distortion.step()

        # Save the optimized parameters for the next frame
        self.previous_params = {
            k: v.detach().clone() for k, v in self.cam_param_dict.param_dict.items()
        }

        per_sample_loss["loss_ndc_total"] = torch.sum(
            torch.stack([per_sample_loss[key_dist] for key_dist in distances_dict.keys()], dim=0),
            dim=0,
        )

        if self.log_per_step:
            per_step_info["loss"] = torch.stack(per_step_info["loss"], dim=-1)
            per_step_info["lr"] = torch.tensor(per_step_info["lr"])
        return per_sample_loss, cam, per_step_info
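
# ---------------------------------------------------------------------------
# Warm-start sketch (illustrative): self_optim_batch stores the optimized
# parameters in self.previous_params, so consecutive calls on frames of the
# same video reuse the previous solution as the initialization. A hypothetical
# loop over an assumed iterable of CHW tensors:
#
#   for frame in video_frames:
#       out = module.predict({"image": frame})
#       if out["homography"] is not None:
#           ...  # e.g. project detections into pitch coordinates
# ---------------------------------------------------------------------------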