from typing import Tuple, Dict, Union

import torch
import torch.nn as nn
import kornia
from pytorch_lightning import LightningModule

from .utils.data_distr import FeatureScalerZScore


class CameraParameterWLensDistDictZScore(LightningModule):
    """Holds individual camera parameters, including lens distortion parameters, as an nn.Module."""

    def __init__(self, cam_distr, dist_distr, device="cpu"):
        super().__init__()
        self.cam_distr = cam_distr
        self._device = device

        # phi raw
        self.param_dict = torch.nn.ParameterDict(
            {
                k: torch.nn.parameter.Parameter(
                    torch.zeros(*cam_distr[k]["dimension"], device=device),
                    requires_grad=not cam_distr[k].get("no_grad", False),
                )
                for k in cam_distr.keys()
            }
        )
        # denormalization module to get phi_target
        self.feature_scaler = torch.nn.ModuleDict(
            {k: FeatureScalerZScore(*cam_distr[k]["mean_std"]) for k in cam_distr.keys()}
        )

        self.dist_distr = dist_distr
        if self.dist_distr is not None:
            self.param_dict_dist = torch.nn.ParameterDict(
                {
                    k: torch.nn.Parameter(torch.zeros(*dist_distr[k]["dimension"], device=device))
                    for k in dist_distr.keys()
                }
            )
            # TODO: modify later to dynamically construct a tensor of shape
            # (k_1,k_2,p_1,p_2[,k_3[,k_4,k_5,k_6[,s_1,s_2,s_3,s_4[,\tau_x,\tau_y]]]])
            self.feature_scaler_dist_coeff = torch.nn.ModuleDict(
                {k: FeatureScalerZScore(*dist_distr[k]["mean_std"]) for k in dist_distr.keys()}
            )

    def initialize(
        self,
        update_dict_cam: Union[Dict[str, Union[float, torch.Tensor]], None],
        update_dict_dist=None,
    ):
        """Initializes all camera parameters with zeros and replaces specific values with the provided ones.

        Args:
            update_dict_cam (Dict[str, Union[float, torch.Tensor]]): Parameters to be updated
        """
        for k in self.param_dict.keys():
            self.param_dict[k].data = torch.zeros(
                *self.cam_distr[k]["dimension"], device=self._device
            )
        if self.dist_distr is not None:
            for k in self.dist_distr.keys():
                self.param_dict_dist[k].data = torch.zeros(
                    *self.dist_distr[k]["dimension"], device=self._device
                )
        if update_dict_cam is not None and len(update_dict_cam) > 0:
            for k, v in update_dict_cam.items():
                self.param_dict[k].data = (
                    torch.zeros(*self.cam_distr[k]["dimension"], device=self._device) + v
                )
        if update_dict_dist is not None:
            raise NotImplementedError

    def forward(self):
        phi_dict = {}
        for k, param in self.param_dict.items():
            phi_dict[k] = self.feature_scaler[k](param)

        if self.dist_distr is None:
            return phi_dict, None

        # psi is a vector with 4, 5, 8, 12 or 14 elements of shape (*, n), depending on the
        # provided dict of coefficients; the dict is assumed to be ordered according to
        # (k_1,k_2,p_1,p_2[,k_3[,k_4,k_5,k_6[,s_1,s_2,s_3,s_4[,\tau_x,\tau_y]]]])
        psi = torch.stack(
            [
                torch.clamp(
                    self.feature_scaler_dist_coeff[k](param),
                    min=self.dist_distr[k]["minmax"][0],
                    max=self.dist_distr[k]["minmax"][1],
                )
                for k, param in self.param_dict_dist.items()
            ],
            dim=-1,  # stack individual features and not arbitrary leading dimensions
        )
        return phi_dict, psi


class SNProjectiveCamera:
    def __init__(
        self,
        phi_dict: Dict[str, torch.Tensor],
        psi: torch.Tensor,
        principal_point: Tuple[float, float],
        image_width: int,
        image_height: int,
        device: str = "cpu",
        nan_check=True,
    ) -> None:
        """Projective camera defined as K @ R @ [I|-t] with lens distortion module and batch dimensions B, T.

        Following the Euler angles convention, we use a ZXZ succession of intrinsic rotations
        to describe the orientation of the camera. Starting from the world reference axis
        system, we first apply a rotation around the Z axis to pan the camera.
        Then the obtained axis system is rotated around its x axis in order to tilt the
        camera. The last rotation around the z axis of the new axis system allows to roll
        the camera. Note that this z axis is the principal axis of the camera.

        As T is not provided for camera location and lens distortion, these parameters are
        assumed to be fixed across T.

        phi_dict is a dict of parameters containing:
        {
            'aov': torch.Size([B, T]),
            'pan': torch.Size([B, T]),
            'tilt': torch.Size([B, T]),
            'roll': torch.Size([B, T]),
            'c_x': torch.Size([B, 1]),
            'c_y': torch.Size([B, 1]),
            'c_z': torch.Size([B, 1]),
        }

        Internally fuses the B and T dimensions to a pseudo batch dimension:
        {
            'aov': torch.Size([B*T]),
            'pan': torch.Size([B*T]),
            'tilt': torch.Size([B*T]),
            'roll': torch.Size([B*T]),
            'c_x': torch.Size([B]),
            'c_y': torch.Size([B]),
            'c_z': torch.Size([B]),
        }

        aov, pan, tilt, roll are assumed to be given in radian.

        Note on lens distortion: lens distortion coefficients are independent of the image
        resolution, i.e. we assume
        I(dist_points(K_ndc, dist_coeff, points2d_ndc)) == I(dist_points(K_raster, dist_coeff, points2d_raster)).

        Args:
            phi_dict (Dict[str, torch.Tensor]): See example above
            psi (Union[None, torch.Tensor]): distortion coefficients as a concatenated vector
                according to https://kornia.readthedocs.io/en/latest/geometry.calibration.html
                of shape (B, T, {2, 4, 5, 8, 12, 14})
            principal_point (Tuple[float, float]): principal point, assumed to be fixed across all samples (B, T)
            image_width (int): assumed to be fixed across all samples (B, T)
            image_height (int): assumed to be fixed across all samples (B, T)
        """
        # fuse B and T dimensions
        phi_dict_flat = {}
        for k, v in phi_dict.items():
            if len(v.shape) == 2:
                phi_dict_flat[k] = v.view(v.shape[0] * v.shape[1])
            elif len(v.shape) == 3:
                phi_dict_flat[k] = v.view(v.shape[0] * v.shape[1], v.shape[-1])

        self.batch_dim, self.temporal_dim = phi_dict["pan"].shape
        self.pseudo_batch_size = phi_dict_flat["pan"].shape[0]
        self.phi_dict_flat = phi_dict_flat
        self.principal_point = principal_point
        self.image_width = image_width
        self.image_height = image_height
        self.device = device

        self.psi = psi
        if self.psi is not None:
            if self.psi.shape[-1] != 2:
                raise NotImplementedError
            # psi is a vector with 2, 4, 5, 8, 12 or 14 elements of shape (*, n), ordered as
            # (k_1,k_2,p_1,p_2[,k_3[,k_4,k_5,k_6[,s_1,s_2,s_3,s_4[,\tau_x,\tau_y]]]])
            if self.psi.shape[-1] == 2:
                # assume zero tangential coefficients
                psi_ext = torch.zeros(*list(self.psi.shape[:-1]), 4, device=self.psi.device)
                psi_ext[..., :2] = self.psi
                self.psi = psi_ext
            self.lens_dist_coeff = self.psi.view(self.pseudo_batch_size, self.psi.shape[-1]).to(
                self.device
            )

        self.intrinsics_ndc = self.construct_intrinsics_ndc()
        self.intrinsics_raster = self.construct_intrinsics_raster()
        self.rotation = self.rotation_from_euler_angles(
            *[phi_dict_flat[k] for k in ["pan", "tilt", "roll"]]
        )
        self.position = torch.stack([phi_dict_flat[k] for k in ["c_x", "c_y", "c_z"]], dim=-1)
        self.position = self.position.repeat_interleave(
            int(self.pseudo_batch_size / self.batch_dim), dim=0
        )  # (B*T, 3)
        # TODO: probably needs modification if B > 0?
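        # Shape bookkeeping: per-clip parameters such as the camera position are
        # provided once per batch element (B, 3) and expanded to the pseudo batch
        # (B*T, 3) via repeat_interleave, so every frame t of a clip shares the
        # same camera location, e.g. for B=2, T=4: (2, 3) -> (8, 3).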
        self.P_ndc = self.construct_projection_matrix(self.intrinsics_ndc)
        self.P_raster = self.construct_projection_matrix(self.intrinsics_raster)
        self.phi_dict = phi_dict
        self.nan_check = nan_check
        super().__init__()

    def construct_projection_matrix(self, intrinsics):
        It = torch.eye(4, device=self.device)[:-1].repeat(self.pseudo_batch_size, 1, 1)
        It[:, :, -1] = -self.position  # (B, 3, 4)
        self.It = It
        return intrinsics @ self.rotation @ It  # (B, 3, 4)

    def construct_intrinsics_ndc(self):
        # assume that the principal point is (0, 0)
        K = torch.eye(3, requires_grad=False, device=self.device)
        K = K.reshape((1, 3, 3)).repeat(self.pseudo_batch_size, 1, 1)
        K[:, 0, 0] = self.get_fl_from_aov_rad(self.phi_dict_flat["aov"], d=2)
        K[:, 1, 1] = self.get_fl_from_aov_rad(
            self.phi_dict_flat["aov"], d=2 * self.image_width / self.image_height
        )
        return K

    def construct_intrinsics_raster(self):
        # assume that the principal point is (W/2, H/2)
        K = torch.eye(3, requires_grad=False, device=self.device)
        K = K.reshape((1, 3, 3)).repeat(self.pseudo_batch_size, 1, 1)
        K[:, 0, 0] = self.get_fl_from_aov_rad(self.phi_dict_flat["aov"], d=self.image_width)
        K[:, 1, 1] = self.get_fl_from_aov_rad(self.phi_dict_flat["aov"], d=self.image_width)
        K[:, 0, 2] = self.principal_point[0]
        K[:, 1, 2] = self.principal_point[1]
        return K

    def __str__(self) -> str:
        return (
            f"aov_deg={torch.rad2deg(self.phi_dict['aov'])}, "
            f"t={torch.stack([self.phi_dict[k] for k in ['c_x', 'c_y', 'c_z']], dim=-1)}, "
            f"pan_deg={torch.rad2deg(self.phi_dict['pan'])} "
            f"tilt_deg={torch.rad2deg(self.phi_dict['tilt'])} "
            f"roll_deg={torch.rad2deg(self.phi_dict['roll'])}"
        )

    def str_pan_tilt_roll_fl(self, b, t):
        r = (
            f"FOV={torch.rad2deg(self.phi_dict['aov'][b, t]):.1f}°, "
            f"pan={torch.rad2deg(self.phi_dict['pan'][b, t]):.1f}° "
            f"tilt={torch.rad2deg(self.phi_dict['tilt'][b, t]):.1f}° "
            f"roll={torch.rad2deg(self.phi_dict['roll'][b, t]):.1f}°"
        )
        return r

    def str_lens_distortion_coeff(self, b):
        # TODO: T! also need individual lens_dist_coeff for each t in T
        return "lens dist coeff=" + " ".join(
            [f"{x:.2f}" for x in self.lens_dist_coeff[b, :2]]
        )  # print only radial lens distortion coefficients

    def __repr__(self) -> str:
        return f"{self.__class__}:" + self.__str__()

    def __len__(self):
        return self.pseudo_batch_size  # e.g. self.intrinsics.shape[0]

    def project_point2pixel(self, points3d: torch.Tensor, lens_distortion: bool) -> torch.Tensor:
        """Project world coordinates to pixel coordinates.

        Args:
            points3d (torch.Tensor): of shape (N, 3) or (1, N, 3)

        Returns:
            torch.Tensor: projected points of shape (B, T, N, 2)
        """
        position = self.position.view(self.pseudo_batch_size, 1, 3)
        point = points3d - position
        rotated_point = self.rotation @ point.transpose(1, 2)  # (pseudo_batch_size, 3, N)
        dist_point2cam = rotated_point[:, 2]  # (B, N) depth of each point along the principal axis
        dist_point2cam = dist_point2cam.view(self.pseudo_batch_size, 1, rotated_point.shape[-1])
        rotated_point = rotated_point / dist_point2cam  # (B, 3, N) / (B, 1, N) -> (B, 3, N)
        projected_points = self.intrinsics_raster @ rotated_point  # (B, 3, N)
        # transpose vs view? here
        projected_points = projected_points.transpose(-1, -2)  # cannot use view()
        projected_points = kornia.geometry.convert_points_from_homogeneous(projected_points)
        if lens_distortion:
            if self.psi is None:
                raise RuntimeError("Lens distortion requested, but deactivated in module")
            projected_points = self.distort_points(projected_points, self.intrinsics_raster)

        # reshape back from (pseudo_batch_size, N, 2) to (B, T, N, 2)
        projected_points = projected_points.view(
            self.batch_dim, self.temporal_dim, projected_points.shape[-2], 2
        )
        if self.nan_check:
            if torch.isnan(projected_points).any().item():
                print(self.phi_dict_flat)
                print(projected_points)
                raise RuntimeWarning("NaN in project_point2pixel")
        return projected_points

    def project_point2ndc(self, points3d: torch.Tensor, lens_distortion: bool) -> torch.Tensor:
        """Project world coordinates to normalized device coordinates (NDC).

        Args:
            points3d (torch.Tensor): of shape (N, 3) or (1, N, 3)

        Returns:
            torch.Tensor: projected points of shape (B, T, N, 2)
        """
        position = self.position.view(self.pseudo_batch_size, 1, 3)
        point = points3d - position
        rotated_point = self.rotation @ point.transpose(1, 2)  # (pseudo_batch_size, 3, N)
        dist_point2cam = rotated_point[:, 2]  # (B, N) depth of each point along the principal axis
        dist_point2cam = dist_point2cam.view(self.pseudo_batch_size, 1, rotated_point.shape[-1])
        rotated_point = rotated_point / dist_point2cam  # (B, 3, N) / (B, 1, N) -> (B, 3, N)
        projected_points = self.intrinsics_ndc @ rotated_point  # (B, 3, N)
        # transpose vs view? here
        projected_points = projected_points.transpose(-1, -2)  # cannot use view()
        projected_points = kornia.geometry.convert_points_from_homogeneous(projected_points)
        if self.nan_check:
            if torch.isnan(projected_points).any().item():
                print(projected_points)
                print(self.phi_dict_flat)
                print("lens distortion", self.lens_dist_coeff)
                raise RuntimeWarning("NaN in project_point2ndc before distort")
        if lens_distortion:
            if self.psi is None:
                raise RuntimeError("Lens distortion requested, but deactivated in module")
            projected_points = self.distort_points(projected_points, self.intrinsics_ndc)

        # reshape back from (pseudo_batch_size, N, 2) to (B, T, N, 2)
        projected_points = projected_points.view(
            self.batch_dim, self.temporal_dim, projected_points.shape[-2], 2
        )
        if self.nan_check:
            if torch.isnan(projected_points).any().item():
                print(self.phi_dict_flat)
                print(projected_points)
                raise RuntimeWarning("NaN in project_point2ndc after distort")
        return projected_points

    def project_point2pixel_from_P(
        self, points3d: torch.Tensor, lens_distortion: bool
    ) -> torch.Tensor:
        """Project world coordinates to pixel coordinates from the projection matrix.
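
        Unlike project_point2pixel, this variant applies the precomputed projection
        matrix P = K @ R @ [I|-t] to homogeneous points in one step and then
        normalizes by the last coordinate.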

        Args:
            points3d (torch.Tensor): of shape (1, N, 3)

        Returns:
            torch.Tensor: projected points of shape (B, T, N, 2)
        """
        points3d = kornia.geometry.conversions.convert_points_to_homogeneous(points3d).transpose(
            1, 2
        )  # (B, 4, N)
        projected_points = torch.bmm(self.P_raster, points3d.repeat(self.pseudo_batch_size, 1, 1))
        normalize_by = projected_points[:, -1].view(
            self.pseudo_batch_size, 1, projected_points.shape[-1]
        )
        projected_points /= normalize_by
        projected_points = projected_points.transpose(-1, -2)  # cannot use view()
        projected_points = kornia.geometry.convert_points_from_homogeneous(projected_points)
        if lens_distortion:
            if self.psi is None:
                raise RuntimeError("Lens distortion requested, but deactivated in module")
            projected_points = self.distort_points(projected_points, self.intrinsics_raster)

        # reshape back from (pseudo_batch_size, N, 2) to (B, T, N, 2)
        projected_points = projected_points.view(
            self.batch_dim, self.temporal_dim, projected_points.shape[-2], 2
        )
        return projected_points  # (B, T, N, 2)

    def project_point2ndc_from_P(
        self, points3d: torch.Tensor, lens_distortion: bool
    ) -> torch.Tensor:
        """Project world coordinates to normalized device coordinates (NDC) from the projection matrix.

        Args:
            points3d (torch.Tensor): of shape (1, N, 3)

        Returns:
            torch.Tensor: projected points of shape (B, T, N, 2)
        """
        points3d = kornia.geometry.conversions.convert_points_to_homogeneous(points3d).transpose(
            1, 2
        )  # (B, 4, N)
        projected_points = torch.bmm(self.P_ndc, points3d.repeat(self.pseudo_batch_size, 1, 1))
        normalize_by = projected_points[:, -1].view(
            self.pseudo_batch_size, 1, projected_points.shape[-1]
        )
        projected_points /= normalize_by
        projected_points = projected_points.transpose(-1, -2)  # cannot use view()
        projected_points = kornia.geometry.convert_points_from_homogeneous(projected_points)
        if lens_distortion:
            if self.psi is None:
                raise RuntimeError("Lens distortion requested, but deactivated in module")
            projected_points = self.distort_points(projected_points, self.intrinsics_ndc)

        # reshape back from (pseudo_batch_size, N, 2) to (B, T, N, 2)
        projected_points = projected_points.view(
            self.batch_dim, self.temporal_dim, projected_points.shape[-2], 2
        )
        return projected_points  # (B, T, N, 2)

    def rotation_from_euler_angles(self, pan, tilt, roll):
        # rotation matrices from a batch of pan-tilt-roll [rad] vectors of shape (?,)
        mask = (
            torch.eye(3, requires_grad=False, device=self.device)
            .reshape((1, 3, 3))
            .repeat(pan.shape[0], 1, 1)
        )
        mask[:, 0, 0] = -torch.sin(pan) * torch.sin(roll) * torch.cos(tilt) + torch.cos(
            pan
        ) * torch.cos(roll)
        mask[:, 0, 1] = torch.sin(pan) * torch.cos(roll) + torch.sin(roll) * torch.cos(
            pan
        ) * torch.cos(tilt)
        mask[:, 0, 2] = torch.sin(roll) * torch.sin(tilt)
        mask[:, 1, 0] = -torch.sin(pan) * torch.cos(roll) * torch.cos(tilt) - torch.sin(
            roll
        ) * torch.cos(pan)
        mask[:, 1, 1] = -torch.sin(pan) * torch.sin(roll) + torch.cos(pan) * torch.cos(
            roll
        ) * torch.cos(tilt)
        mask[:, 1, 2] = torch.sin(tilt) * torch.cos(roll)
        mask[:, 2, 0] = torch.sin(pan) * torch.sin(tilt)
        mask[:, 2, 1] = -torch.sin(tilt) * torch.cos(pan)
        mask[:, 2, 2] = torch.cos(tilt)
        return mask

    def get_homography_raster(self):
        return self.P_raster[:, :, [0, 1, 3]].inverse()

    def get_rays_world(self, x):
        """Compute world-space viewing rays for image points (not implemented yet).

        Args:
            x: points of shape (B, 3, N)

        Returns:
            LineCollection
        """
        raise NotImplementedError
        # TODO: verify
        # ray_cam_trans = torch.bmm(self.rotation.inverse(), torch.bmm(self.intrinsics.inverse(), x))
        # # unnormalized direction vector in euclidean points (x, y, z) based on camera origin (0, 0, 0)
        # ray_cam_trans = torch.nn.functional.normalize(ray_cam_trans, p=2, dim=1)  # (B, 3, N)
        # # shift support vector to origin in world space, i.e. the translation vector
        # support = self.position.unsqueeze(-1).repeat(
        #     ray_cam_trans.shape[0], 1, ray_cam_trans.shape[2]
        # )  # (B, 3, N)
        # return LineCollection(support=support, direction_norm=ray_cam_trans)

    @staticmethod
    def get_aov_rad(d: float, fl: torch.Tensor):
        # https://en.wikipedia.org/wiki/Angle_of_view#Calculating_a_camera's_angle_of_view
        return 2 * torch.arctan(d / (2 * fl))  # in range [0.0, PI]

    @staticmethod
    def get_fl_from_aov_rad(aov_rad: torch.Tensor, d: float):
        return 0.5 * d * (1 / torch.tan(0.5 * aov_rad))

    def undistort_points(self, points_pixel: torch.Tensor, intrinsics, num_iters=5) -> torch.Tensor:
        """Compensate a set of 2D image points for lens distortion.

        Wrapper for kornia.geometry.undistort_points().

        Args:
            points_pixel (torch.Tensor): tensor of shape (B, T, N, 2)

        Returns:
            torch.Tensor: undistorted points of shape (B, T, N, 2)
        """
        batch_dim, temporal_dim, N, _ = points_pixel.shape
        points_pixel = points_pixel.view(batch_dim * temporal_dim, N, 2)
        true_batch_size = batch_dim
        lens_dist_coeff = self.lens_dist_coeff
        if true_batch_size < self.batch_dim:
            intrinsics = intrinsics[:true_batch_size]
            lens_dist_coeff = lens_dist_coeff[:true_batch_size]
        return kornia.geometry.undistort_points(
            points_pixel, intrinsics, dist=lens_dist_coeff, num_iters=num_iters
        ).view(batch_dim, temporal_dim, N, 2)

    def distort_points(self, points_pixel: torch.Tensor, intrinsics) -> torch.Tensor:
        """Distort a set of 2D points based on the lens distortion model.

        Wrapper for kornia.geometry.distort_points().

        Args:
            points_pixel (torch.Tensor): tensor of shape (B, N, 2)

        Returns:
            torch.Tensor: distorted points of shape (B, N, 2)
        """
        return kornia.geometry.distort_points(points_pixel, intrinsics, dist=self.lens_dist_coeff)

    def undistort_images(self, images):
        # images of shape (B, T, C, H, W)
        true_batch_size, T = images.shape[:2]
        images = images.view(true_batch_size * T, 3, self.image_height, self.image_width).to(
            self.device
        )
        intrinsics = self.intrinsics_raster
        lens_dist_coeff = self.lens_dist_coeff
        if true_batch_size < self.batch_dim:
            intrinsics = intrinsics[:true_batch_size]
            lens_dist_coeff = lens_dist_coeff[:true_batch_size]
        return kornia.geometry.calibration.undistort_image(
            images, intrinsics, lens_dist_coeff
        ).view(true_batch_size, T, 3, self.image_height, self.image_width)  # use the input T for partial clips

    def get_parameters(self, true_batch_size=None):
        """Get a dict of relevant camera parameters and the homography matrix.

        :return: The dictionary
        """
        out_dict = {
            "pan_degrees": torch.rad2deg(self.phi_dict["pan"]),
            "tilt_degrees": torch.rad2deg(self.phi_dict["tilt"]),
            "roll_degrees": torch.rad2deg(self.phi_dict["roll"]),
            "position_meters": torch.stack([self.phi_dict[k] for k in ["c_x", "c_y", "c_z"]], dim=1)
            .squeeze(-1)
            .unsqueeze(-2)
            .repeat(1, self.temporal_dim, 1),
            "aov_radian": self.phi_dict["aov"],
            "aov_degrees": torch.rad2deg(self.phi_dict["aov"]),
            "x_focal_length": self.get_fl_from_aov_rad(self.phi_dict["aov"], d=self.image_width),
            "y_focal_length": self.get_fl_from_aov_rad(self.phi_dict["aov"], d=self.image_width),
            "principal_point": torch.tensor(
                [[self.principal_point] * self.temporal_dim] * self.batch_dim
            ),
        }
        out_dict["homography"] = self.get_homography_raster().unsqueeze(1)  # (B, 1, 3, 3)
        # expected for SN evaluation
        out_dict["radial_distortion"] = torch.zeros(self.batch_dim, self.temporal_dim, 6)
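        # Note: kornia's coefficient vector (k1, k2, p1, p2, k3, k4, k5, k6, s1..s4)
        # is regrouped here into radial, tangential, and thin-prism parts; all are
        # zero-initialized, and only k1, k2 are populated below when psi is given.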
out_dict["tangential_distortion"] = torch.zeros(self.batch_dim, self.temporal_dim, 2) out_dict["thin_prism_distortion"] = torch.zeros(self.batch_dim, self.temporal_dim, 4) if self.psi is not None: # in case only k1 and k2 are provided out_dict["radial_distortion"][..., :2] = self.psi[..., :2] if true_batch_size is None or true_batch_size == self.batch_dim: return out_dict for k in out_dict.keys(): out_dict[k] = out_dict[k][:true_batch_size] return out_dict @staticmethod def static_undistort_points(points, cam): intrinsics = cam.intrinsics_raster lens_dist_coeff = cam.lens_dist_coeff true_batch_size = points.shape[0] if true_batch_size < cam.batch_dim: intrinsics = intrinsics[:true_batch_size] lens_dist_coeff = lens_dist_coeff[:true_batch_size] # points in homogenous coordinates # (B, T, 3, S, N) -> (T, 3, S*N) -> (T, S*N, 3) batch_size, T, _, S, N = points.shape points = points.view(batch_size, T, 3, S * N).transpose(2, 3) points[..., :2] = kornia.geometry.undistort_points( points[..., :2].view(batch_size * T, S * N, 2), intrinsics, dist=lens_dist_coeff, num_iters=1, ).view(batch_size, T, S * N, 2) # (T, S*N, 3) -> (T, 3, S*N) -> (B, T, 3, S, N) points = points.transpose(2, 3).view(batch_size, T, 3, S, N) return points