|
from typing import Tuple, Dict, Union |
|
|
|
from pytorch_lightning import LightningModule |
|
import torch |
|
import kornia |
|
import torch.nn as nn |
|
from .utils.data_distr import FeatureScalerZScore |
|
|
|
|
|
class CameraParameterWLensDistDictZScore(LightningModule):
    """Holds individual camera parameters including lens distortion parameters as nn.Module.

    Raw (normalised) parameters are stored zero-initialised in ``param_dict`` and, when lens
    distortion is enabled, ``param_dict_dist``. :meth:`forward` passes each of them through
    its own ``FeatureScalerZScore`` module, built from the ``"mean_std"`` entry of the
    corresponding distribution spec.

    Args:
        cam_distr: mapping from camera parameter name to a spec providing ``"dimension"``
            (shape handed to ``torch.zeros``) and ``"mean_std"`` (positional arguments of
            ``FeatureScalerZScore``). An optional ``"no_grad"`` entry equal to ``True``
            creates that parameter with ``requires_grad=False``.
        dist_distr: same kind of mapping for the lens distortion coefficients. Each spec
            additionally needs ``"minmax"`` (clamp bounds used in :meth:`forward`).
            ``None`` disables lens distortion.
        device: device the parameters are created on.
    """

    def __init__(self, cam_distr, dist_distr, device="cpu"):
        super(CameraParameterWLensDistDictZScore, self).__init__()

        self.cam_distr = cam_distr
        self._device = device

        # One zero tensor per camera parameter; trainable unless its spec sets no_grad == True.
        cam_params = {}
        for k in cam_distr.keys():
            spec = cam_distr[k]
            frozen = ("no_grad" in spec) and (spec["no_grad"] == True)  # noqa: E712
            cam_params[k] = torch.nn.parameter.Parameter(
                torch.zeros(*spec["dimension"], device=device),
                requires_grad=False if frozen else True,
            )
        self.param_dict = torch.nn.ParameterDict(cam_params)

        self.feature_scaler = torch.nn.ModuleDict(
            {k: FeatureScalerZScore(*cam_distr[k]["mean_std"]) for k in cam_distr.keys()}
        )

        self.dist_distr = dist_distr
        if self.dist_distr is not None:
            self.param_dict_dist = torch.nn.ParameterDict(
                {
                    k: torch.nn.Parameter(torch.zeros(*dist_distr[k]["dimension"], device=device))
                    for k in dist_distr.keys()
                }
            )
            self.feature_scaler_dist_coeff = torch.nn.ModuleDict(
                {k: FeatureScalerZScore(*dist_distr[k]["mean_std"]) for k in dist_distr.keys()}
            )

    def initialize(
        self,
        update_dict_cam: Union[Dict[str, Union[float, torch.tensor]], None],
        update_dict_dist=None,
    ):
        """Initializes all camera parameters with zeros and replaces specific values with provided values.

        New tensors are allocated on each parameter's *current* device. This means
        re-initialising after the module was moved (e.g. by ``.to(...)``) keeps the data
        where the module lives, instead of moving it back to the construction-time device.

        Args:
            update_dict_cam (Dict[str, Union[float, torch.tensor]]): parameters to be updated.
                Each value is added to (i.e. broadcast onto) a zero tensor of the declared
                ``"dimension"``. ``None`` or an empty mapping leaves all parameters at zero.
            update_dict_dist: not supported yet; must be ``None``.

        Raises:
            NotImplementedError: if ``update_dict_dist`` is given. This is raised before any
                parameter is modified.
        """
        if update_dict_dist is not None:
            raise NotImplementedError

        for k, param in self.param_dict.items():
            param.data = torch.zeros(*self.cam_distr[k]["dimension"], device=param.device)

        if self.dist_distr is not None:
            for k in self.dist_distr.keys():
                param = self.param_dict_dist[k]
                param.data = torch.zeros(*self.dist_distr[k]["dimension"], device=param.device)

        if update_dict_cam:
            for k, v in update_dict_cam.items():
                param = self.param_dict[k]
                # Move tensor values to the parameter's device so the broadcast add cannot
                # mix devices; Python scalars are wrapped the same way.
                param.data = torch.zeros(
                    *self.cam_distr[k]["dimension"], device=param.device
                ) + torch.as_tensor(v, device=param.device)

    def forward(self):
        """Scale all raw parameters.

        Returns:
            Tuple[Dict[str, torch.Tensor], Optional[torch.Tensor]]: ``(phi_dict, psi)``.
            ``phi_dict`` maps each camera parameter name to the output of its
            ``FeatureScalerZScore``. ``psi`` holds the scaled distortion coefficients, each
            clamped to its ``"minmax"`` bounds and stacked along a new last dimension. It is
            ``None`` when lens distortion is disabled.
        """
        phi_dict = {k: self.feature_scaler[k](param) for k, param in self.param_dict.items()}

        if self.dist_distr is None:
            return phi_dict, None

        psi = torch.stack(
            [
                torch.clamp(
                    self.feature_scaler_dist_coeff[k](param),
                    min=self.dist_distr[k]["minmax"][0],
                    max=self.dist_distr[k]["minmax"][1],
                )
                for k, param in self.param_dict_dist.items()
            ],
            dim=-1,
        )
        return phi_dict, psi
|
|
|
|
|
class SNProjectiveCamera:
    """Batched projective camera P = K @ R @ [I | -t] with optional lens distortion.

    Per-sample matrices are stored with a fused pseudo batch dimension of size B*T, obtained
    by flattening the (B, T) inputs row-major (pseudo index ``b * T + t``).
    """

    def __init__(
        self,
        phi_dict: Dict[str, torch.tensor],
        psi: torch.tensor,
        principal_point: Tuple[float, float],
        image_width: int,
        image_height: int,
        device: str = "cpu",
        nan_check=True,
    ) -> None:
        """Projective camera defined as K @ R [I|-t] with lens distortion module and batch dimensions B,T.

        Following the Euler angles convention, we use a ZXZ succession of intrinsic rotations to
        describe the orientation of the camera:
        - Starting from the world reference axis system, we first rotate around the Z axis to pan
          the camera.
        - The obtained axis system is then rotated around its x axis to tilt the camera.
        - The last rotation, around the z axis of the new axis system, rolls the camera. This z
          axis is the principal axis of the camera.

        The camera location is only given per batch element (no T axis), so it is assumed to be
        fixed across T. Lens distortion coefficients are given per (B, T).

        phi_dict is a dict of parameters containing:
        {
            'aov, torch.Size([B, T])',
            'pan, torch.Size([B, T])',
            'tilt, torch.Size([B, T])',
            'roll, torch.Size([B, T])',
            'c_x, torch.Size([B, 1])',
            'c_y, torch.Size([B, 1])',
            'c_z, torch.Size([B, 1])',
        }

        Internally, the B and T dimensions are fused into a pseudo batch dimension:
        {
            'aov, torch.Size([B*T])',
            'pan, torch.Size([B*T])',
            'tilt, torch.Size([B*T])',
            'roll, torch.Size([B*T])',
            'c_x, torch.Size([B])',
            'c_y, torch.Size([B])',
            'c_z, torch.Size([B])',
        }

        aov, pan, tilt and roll are assumed to be in radian.

        Note on lens distortion:
            Lens distortion coefficients are independent of the image resolution: distorting
            NDC points with K_ndc and raster points with K_raster yields the same image.

        Args:
            phi_dict (Dict[str, torch.tensor]): see the layout above.
            psi (Union[None, torch.Tensor]): distortion coefficients of shape (B, T, 2), i.e.
                the first two radial coefficients in the order of
                https://kornia.readthedocs.io/en/latest/geometry.calibration.html.
                ``None`` disables lens distortion.
            principal_point (Tuple[float, float]): principal point (raster coordinates),
                assumed to be fixed across all samples (B, T).
            image_width (int): assumed to be fixed across all samples (B, T).
            image_height (int): assumed to be fixed across all samples (B, T).
            device (str): device for the matrices built here.
            nan_check (bool): raise a RuntimeWarning when a projection produces NaN.

        Raises:
            NotImplementedError: if ``psi`` does not have exactly 2 coefficients.
        """
        # Fuse (B, T[, C]) -> (B*T[, C]); tensors of any other rank are not kept.
        phi_dict_flat = {}
        for k, v in phi_dict.items():
            if len(v.shape) == 2:
                phi_dict_flat[k] = v.view(v.shape[0] * v.shape[1])
            elif len(v.shape) == 3:
                phi_dict_flat[k] = v.view(v.shape[0] * v.shape[1], v.shape[-1])

        self.batch_dim, self.temporal_dim = phi_dict["pan"].shape
        self.pseudo_batch_size = phi_dict_flat["pan"].shape[0]
        self.phi_dict_flat = phi_dict_flat

        self.principal_point = principal_point
        self.image_width = image_width
        self.image_height = image_height
        self.device = device

        self.psi = psi
        # Defined (as None) even without distortion so debug output that prints it does not crash.
        self.lens_dist_coeff = None
        if self.psi is not None:
            if self.psi.shape[-1] != 2:
                raise NotImplementedError
            # kornia's coefficient vector is (k1, k2, p1, p2, ...) with at least 4 entries:
            # pad the two radial coefficients with zero tangential ones.
            psi_ext = torch.zeros(*list(self.psi.shape[:-1]), 4)
            psi_ext[..., :2] = self.psi
            self.psi = psi_ext
            self.lens_dist_coeff = self.psi.view(self.pseudo_batch_size, self.psi.shape[-1]).to(
                self.device
            )

        self.intrinsics_ndc = self.construct_intrinsics_ndc()
        self.intrinsics_raster = self.construct_intrinsics_raster()

        self.rotation = self.rotation_from_euler_angles(
            *[phi_dict_flat[k] for k in ["pan", "tilt", "roll"]]
        )
        # Position is given per batch element: repeat it for every time step -> (B*T, 3).
        self.position = torch.stack([phi_dict_flat[k] for k in ["c_x", "c_y", "c_z"]], dim=-1)
        self.position = self.position.repeat_interleave(
            int(self.pseudo_batch_size / self.batch_dim), dim=0
        )
        self.P_ndc = self.construct_projection_matrix(self.intrinsics_ndc)
        self.P_raster = self.construct_projection_matrix(self.intrinsics_raster)
        self.phi_dict = phi_dict

        self.nan_check = nan_check
        super().__init__()

    def _batched_identity(self, n):
        """Return ``n`` stacked 3x3 identity matrices of shape (n, 3, 3) on ``self.device``."""
        return (
            torch.eye(3, requires_grad=False, device=self.device)
            .reshape((1, 3, 3))
            .repeat(n, 1, 1)
        )

    def construct_projection_matrix(self, intrinsics):
        """Return ``intrinsics @ R @ [I | -t]`` of shape (B*T, 3, 4); also stores ``[I | -t]`` as ``self.It``."""
        It = torch.eye(4, device=self.device)[:-1].repeat(self.pseudo_batch_size, 1, 1)
        It[:, :, -1] = -self.position
        self.It = It
        return intrinsics @ self.rotation @ It

    def construct_intrinsics_ndc(self):
        """Return the (B*T, 3, 3) NDC intrinsics: focal lengths only, principal point at the origin."""
        K = self._batched_identity(self.pseudo_batch_size)
        K[:, 0, 0] = self.get_fl_from_aov_rad(self.phi_dict_flat["aov"], d=2)
        K[:, 1, 1] = self.get_fl_from_aov_rad(
            self.phi_dict_flat["aov"], d=2 * self.image_width / self.image_height
        )
        return K

    def construct_intrinsics_raster(self):
        """Return the (B*T, 3, 3) raster intrinsics (square pixels, fixed principal point)."""
        K = self._batched_identity(self.pseudo_batch_size)
        K[:, 0, 0] = self.get_fl_from_aov_rad(self.phi_dict_flat["aov"], d=self.image_width)
        K[:, 1, 1] = self.get_fl_from_aov_rad(self.phi_dict_flat["aov"], d=self.image_width)
        K[:, 0, 2] = self.principal_point[0]
        K[:, 1, 2] = self.principal_point[1]
        return K

    def __str__(self) -> str:
        return f"aov_deg={torch.rad2deg(self.phi_dict['aov'])}, t={torch.stack([self.phi_dict[k] for k in ['c_x', 'c_y', 'c_z']], dim=-1)}, pan_deg={torch.rad2deg(self.phi_dict['pan'])} tilt_deg={torch.rad2deg(self.phi_dict['tilt'])} roll_deg={torch.rad2deg(self.phi_dict['roll'])}"

    def str_pan_tilt_roll_fl(self, b, t):
        """Human-readable FOV/pan/tilt/roll (degrees) of sample (b, t)."""
        r = f"FOV={torch.rad2deg(self.phi_dict['aov'][b, t]):.1f}°, pan={torch.rad2deg(self.phi_dict['pan'][b, t]):.1f}° tilt={torch.rad2deg(self.phi_dict['tilt'][b, t]):.1f}° roll={torch.rad2deg(self.phi_dict['roll'][b, t]):.1f}°"
        return r

    def str_lens_distortion_coeff(self, b):
        """Human-readable first two distortion coefficients of pseudo-batch sample ``b``."""
        return f"lens dist coeff=" + " ".join(
            [f"{x:.2f}" for x in self.lens_dist_coeff[b, :2]]
        )

    def __repr__(self) -> str:
        return f"{self.__class__}:" + self.__str__()

    def __len__(self):
        return self.pseudo_batch_size

    def _apply_lens_distortion(self, projected_points, intrinsics):
        """Distort (B*T, N, 2) points, raising if distortion is disabled for this camera."""
        if self.psi is None:
            raise RuntimeError("Lens distortion requested, but deactivated in module")
        return self.distort_points(projected_points, intrinsics)

    def _project_to_image_plane(self, points3d, intrinsics):
        """Move world points into the camera frame, apply the perspective divide and ``intrinsics``.

        Returns:
            torch.tensor: points of shape (B*T, N, 2).
        """
        position = self.position.view(self.pseudo_batch_size, 1, 3)
        point = points3d - position
        rotated_point = self.rotation @ point.transpose(1, 2)
        # Depth along the principal axis, shaped (B*T, 1, N) to broadcast over x, y, z.
        dist_point2cam = rotated_point[:, 2].view(
            self.pseudo_batch_size, 1, rotated_point.shape[-1]
        )
        rotated_point = rotated_point / dist_point2cam

        projected_points = intrinsics @ rotated_point
        projected_points = projected_points.transpose(-1, -2)
        return kornia.geometry.convert_points_from_homogeneous(projected_points)

    def project_point2pixel(self, points3d: torch.tensor, lens_distortion: bool) -> torch.tensor:
        """Project world coordinates to pixel coordinates.

        Args:
            points3d (torch.tensor): of shape (N, 3) or (1, N, 3)
            lens_distortion (bool): apply the lens distortion model after projection

        Returns:
            torch.tensor: projected points of shape (B, T, N, 2)

        Raises:
            RuntimeError: if distortion is requested but ``psi`` was None.
            RuntimeWarning: if ``nan_check`` is set and the result contains NaN.
        """
        projected_points = self._project_to_image_plane(points3d, self.intrinsics_raster)
        if lens_distortion:
            projected_points = self._apply_lens_distortion(
                projected_points, self.intrinsics_raster
            )

        projected_points = projected_points.view(
            self.batch_dim, self.temporal_dim, projected_points.shape[-2], 2
        )
        if self.nan_check:
            if torch.isnan(projected_points).any().item():
                print(self.phi_dict_flat)
                print(projected_points)
                raise RuntimeWarning("NaN in project_point2pixel")
        return projected_points

    def project_point2ndc(self, points3d: torch.tensor, lens_distortion: bool) -> torch.tensor:
        """Project world coordinates to normalized device coordinates (NDC).

        Args:
            points3d (torch.tensor): of shape (N, 3) or (1, N, 3)
            lens_distortion (bool): apply the lens distortion model after projection

        Returns:
            torch.tensor: projected points of shape (B, T, N, 2)

        Raises:
            RuntimeError: if distortion is requested but ``psi`` was None.
            RuntimeWarning: if ``nan_check`` is set and NaN appears before or after distortion.
        """
        projected_points = self._project_to_image_plane(points3d, self.intrinsics_ndc)
        if self.nan_check:
            if torch.isnan(projected_points).any().item():
                print(projected_points)
                print(self.phi_dict_flat)
                print("lens distortion", self.lens_dist_coeff)
                raise RuntimeWarning("NaN in project_point2ndc before distort")
        if lens_distortion:
            projected_points = self._apply_lens_distortion(projected_points, self.intrinsics_ndc)

        projected_points = projected_points.view(
            self.batch_dim, self.temporal_dim, projected_points.shape[-2], 2
        )
        if self.nan_check:
            if torch.isnan(projected_points).any().item():
                print(self.phi_dict_flat)
                print(projected_points)
                raise RuntimeWarning("NaN in project_point2ndc after distort")
        return projected_points

    def _project_from_P(self, points3d, P, intrinsics, lens_distortion):
        """Project (1, N, 3) world points with the (B*T, 3, 4) matrix ``P``; returns (B, T, N, 2)."""
        points3d = kornia.geometry.conversions.convert_points_to_homogeneous(points3d).transpose(
            1, 2
        )
        projected_points = torch.bmm(P, points3d.repeat(self.pseudo_batch_size, 1, 1))
        normalize_by = projected_points[:, -1].view(
            self.pseudo_batch_size, 1, projected_points.shape[-1]
        )
        # Out-of-place on purpose: normalize_by is a view into projected_points. An in-place
        # `/=` would overwrite values it still has to read, which PyTorch either rejects or
        # computes in an order-dependent way.
        projected_points = projected_points / normalize_by
        projected_points = projected_points.transpose(-1, -2)
        projected_points = kornia.geometry.convert_points_from_homogeneous(projected_points)
        if lens_distortion:
            projected_points = self._apply_lens_distortion(projected_points, intrinsics)

        return projected_points.view(
            self.batch_dim, self.temporal_dim, projected_points.shape[-2], 2
        )

    def project_point2pixel_from_P(
        self, points3d: torch.tensor, lens_distortion: bool
    ) -> torch.tensor:
        """Project world coordinates to pixel coordinates from the projection matrix.

        Args:
            points3d (torch.tensor): of shape (1, N, 3)
            lens_distortion (bool): apply the lens distortion model after projection

        Returns:
            torch.tensor: projected points of shape (B, T, N, 2)
        """
        return self._project_from_P(
            points3d, self.P_raster, self.intrinsics_raster, lens_distortion
        )

    def project_point2ndc_from_P(
        self, points3d: torch.tensor, lens_distortion: bool
    ) -> torch.tensor:
        """Project world coordinates to NDC coordinates from the projection matrix.

        Args:
            points3d (torch.tensor): of shape (1, N, 3)
            lens_distortion (bool): apply the lens distortion model after projection

        Returns:
            torch.tensor: projected points of shape (B, T, N, 2)
        """
        return self._project_from_P(points3d, self.P_ndc, self.intrinsics_ndc, lens_distortion)

    def rotation_from_euler_angles(self, pan, tilt, roll):
        """Build batched rotation matrices from ZXZ (pan, tilt, roll) angles in radian.

        Args:
            pan, tilt, roll (torch.tensor): 1-D tensors of equal length N.

        Returns:
            torch.tensor: rotation matrices of shape (N, 3, 3).
        """
        sin_pan, cos_pan = torch.sin(pan), torch.cos(pan)
        sin_tilt, cos_tilt = torch.sin(tilt), torch.cos(tilt)
        sin_roll, cos_roll = torch.sin(roll), torch.cos(roll)

        mask = self._batched_identity(pan.shape[0])
        mask[:, 0, 0] = -sin_pan * sin_roll * cos_tilt + cos_pan * cos_roll
        mask[:, 0, 1] = sin_pan * cos_roll + sin_roll * cos_pan * cos_tilt
        mask[:, 0, 2] = sin_roll * sin_tilt

        mask[:, 1, 0] = -sin_pan * cos_roll * cos_tilt - sin_roll * cos_pan
        mask[:, 1, 1] = -sin_pan * sin_roll + cos_pan * cos_roll * cos_tilt
        mask[:, 1, 2] = sin_tilt * cos_roll

        mask[:, 2, 0] = sin_pan * sin_tilt
        mask[:, 2, 1] = -sin_tilt * cos_pan
        mask[:, 2, 2] = cos_tilt
        return mask

    def get_homography_raster(self):
        """Return the (B*T, 3, 3) inverse of P_raster restricted to columns (0, 1, 3), i.e. the z=0 plane."""
        return self.P_raster[:, :, [0, 1, 3]].inverse()

    def get_rays_world(self, x):
        """Not implemented.

        Args:
            x: presumably pixel coordinates of shape (B, 3, N) — TODO confirm.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    @staticmethod
    def get_aov_rad(d: float, fl: torch.tensor):
        """Angle of view (radian) spanned by extent ``d`` at focal length ``fl``."""
        return 2 * torch.arctan(d / (2 * fl))

    @staticmethod
    def get_fl_from_aov_rad(aov_rad: torch.tensor, d: float):
        """Focal length at which extent ``d`` spans the angle of view ``aov_rad`` (radian)."""
        return 0.5 * d * (1 / torch.tan(0.5 * aov_rad))

    def _truncate_to_batch(self, true_batch_size, intrinsics, lens_dist_coeff):
        """Keep only the pseudo-batch rows of the first ``true_batch_size`` batch elements.

        The rows are ordered ``b * T + t``, so the first ``true_batch_size * T`` rows belong to
        the first ``true_batch_size`` batch elements.
        """
        if true_batch_size < self.batch_dim:
            n = true_batch_size * self.temporal_dim
            intrinsics = intrinsics[:n]
            lens_dist_coeff = lens_dist_coeff[:n]
        return intrinsics, lens_dist_coeff

    def undistort_points(self, points_pixel: torch.tensor, intrinsics, num_iters=5) -> torch.tensor:
        """Compensate for lens distortion a set of 2D image points.

        Wrapper for kornia.geometry.undistort_points()

        Args:
            points_pixel (torch.tensor): tensor of shape (B, T, N, 2); B may be smaller than
                this camera's batch size.
            intrinsics (torch.tensor): (B*T, 3, 3) intrinsics matching the points' coordinates.
            num_iters (int): iterations of kornia's iterative undistortion.

        Returns:
            torch.tensor: undistorted points of shape (B, T, N, 2)
        """
        batch_dim, temporal_dim, N, _ = points_pixel.shape
        points_pixel = points_pixel.view(batch_dim * temporal_dim, N, 2)
        intrinsics, lens_dist_coeff = self._truncate_to_batch(
            batch_dim, intrinsics, self.lens_dist_coeff
        )
        return kornia.geometry.undistort_points(
            points_pixel, intrinsics, dist=lens_dist_coeff, num_iters=num_iters
        ).view(batch_dim, temporal_dim, N, 2)

    def distort_points(self, points_pixel: torch.tensor, intrinsics) -> torch.tensor:
        """Distortion of a set of 2D points based on the lens distortion model.

        Wrapper for kornia.geometry.distort_points()

        Args:
            points_pixel (torch.tensor): tensor of shape (B*T, N, 2)
            intrinsics (torch.tensor): (B*T, 3, 3) intrinsics matching the points' coordinates.

        Returns:
            torch.tensor: distorted points of shape (B*T, N, 2)
        """
        return kornia.geometry.distort_points(points_pixel, intrinsics, dist=self.lens_dist_coeff)

    def undistort_images(self, images):
        """Undistort raster images of shape (B, T, 3, H, W); B may be smaller than this camera's batch size.

        Returns:
            torch.tensor: undistorted images of shape (B, T, 3, H, W) on ``self.device``.
        """
        true_batch_size, T = images.shape[:2]
        images = images.view(true_batch_size * T, 3, self.image_height, self.image_width).to(
            self.device
        )
        intrinsics, lens_dist_coeff = self._truncate_to_batch(
            true_batch_size, self.intrinsics_raster, self.lens_dist_coeff
        )
        return kornia.geometry.calibration.undistort_image(
            images, intrinsics, lens_dist_coeff
        ).view(true_batch_size, self.temporal_dim, 3, self.image_height, self.image_width)

    def get_parameters(self, true_batch_size=None):
        """Get dict of relevant camera parameters and homography matrix.

        Every entry is batched as (B, T, ...).

        Args:
            true_batch_size (int, optional): if given and different from B, every entry is
                truncated to its first ``true_batch_size`` batch elements.

        Returns:
            dict: the parameter dictionary.
        """
        out_dict = {
            "pan_degrees": torch.rad2deg(self.phi_dict["pan"]),
            "tilt_degrees": torch.rad2deg(self.phi_dict["tilt"]),
            "roll_degrees": torch.rad2deg(self.phi_dict["roll"]),
            # (B, 1) per coordinate -> (B, 3) -> repeated over time -> (B, T, 3)
            "position_meters": torch.stack([self.phi_dict[k] for k in ["c_x", "c_y", "c_z"]], dim=1)
            .squeeze(-1)
            .unsqueeze(-2)
            .repeat(1, self.temporal_dim, 1),
            "aov_radian": self.phi_dict["aov"],
            "aov_degrees": torch.rad2deg(self.phi_dict["aov"]),
            "x_focal_length": self.get_fl_from_aov_rad(self.phi_dict["aov"], d=self.image_width),
            "y_focal_length": self.get_fl_from_aov_rad(self.phi_dict["aov"], d=self.image_width),
            "principal_point": torch.tensor(
                [[self.principal_point] * self.temporal_dim] * self.batch_dim
            ),
        }
        # (B*T, 3, 3) -> (B, T, 3, 3) like every other entry, so truncating along dim 0 keeps
        # whole batch elements (identical to the former unsqueeze(1) when T == 1).
        out_dict["homography"] = self.get_homography_raster().view(
            self.batch_dim, self.temporal_dim, 3, 3
        )

        # Full coefficient layout (6 radial, 2 tangential, 4 thin prism); only the first two
        # radial coefficients are modelled here.
        out_dict["radial_distortion"] = torch.zeros(self.batch_dim, self.temporal_dim, 6)
        out_dict["tangential_distortion"] = torch.zeros(self.batch_dim, self.temporal_dim, 2)
        out_dict["thin_prism_distortion"] = torch.zeros(self.batch_dim, self.temporal_dim, 4)
        if self.psi is not None:
            out_dict["radial_distortion"][..., :2] = self.psi[..., :2]

        if true_batch_size is None or true_batch_size == self.batch_dim:
            return out_dict

        for k in out_dict.keys():
            out_dict[k] = out_dict[k][:true_batch_size]
        return out_dict

    @staticmethod
    def static_undistort_points(points, cam):
        """Undistort the x, y rows of raster points of shape (B, T, 3, S, N) with ``cam``'s model.

        NOTE: the result is written into a view of ``points``, so the caller's tensor is
        modified in place as well as returned.

        Args:
            points (torch.tensor): (B, T, 3, S, N); B may be smaller than ``cam.batch_dim``.
            cam (SNProjectiveCamera): camera providing raster intrinsics and coefficients.

        Returns:
            torch.tensor: points of shape (B, T, 3, S, N) with undistorted x, y.
        """
        intrinsics = cam.intrinsics_raster
        lens_dist_coeff = cam.lens_dist_coeff

        batch_size, T, _, S, N = points.shape
        if batch_size < cam.batch_dim:
            # Pseudo-batch rows are ordered b * T + t.
            n = batch_size * cam.temporal_dim
            intrinsics = intrinsics[:n]
            lens_dist_coeff = lens_dist_coeff[:n]

        points = points.view(batch_size, T, 3, S * N).transpose(2, 3)
        points[..., :2] = kornia.geometry.undistort_points(
            points[..., :2].view(batch_size * T, S * N, 2),
            intrinsics,
            dist=lens_dist_coeff,
            num_iters=1,
        ).view(batch_size, T, S * N, 2)
        points = points.transpose(2, 3).view(batch_size, T, 3, S, N)
        return points
|
|