Spaces:
Runtime error
Runtime error
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
void compute_raydirs_forward_cuda( | |
int N, int H, int W, | |
float * viewposim, | |
float * viewrotim, | |
float * focalim, | |
float * princptim, | |
float * pixelcoordsim, | |
float volradius, | |
float * raypos, | |
float * raydir, | |
float * tminmax, | |
cudaStream_t stream); | |
void compute_raydirs_backward_cuda( | |
int N, int H, int W, | |
float * viewposim, | |
float * viewrotim, | |
float * focalim, | |
float * princptim, | |
float * pixelcoordsim, | |
float volradius, | |
float * raypos, | |
float * raydir, | |
float * tminmax, | |
float * grad_viewposim, | |
float * grad_viewrotim, | |
float * grad_focalim, | |
float * grad_princptim, | |
cudaStream_t stream); | |
// Python-facing wrapper for the ray-direction forward pass.
//
// Validates every tensor with CHECK_INPUT, checks the optional pixel
// coordinate tensor against the requested H/W, and hands raw float pointers
// to compute_raydirs_forward_cuda.
//
// Parameters:
//   viewposim, viewrotim, focalim, princptim  input tensors (float32).
//   pixelcoordsim   optional tensor; when present its dim 1 must equal H and
//                   its dim 2 must equal W. When absent, nullptr is passed on.
//   W, H            NOTE: the argument order here is (W, H), while the CUDA
//                   launcher is called with (N, H, W).
//   volradius       scalar forwarded unchanged.
//   rayposim, raydirim, tminmaxim  caller-allocated tensors handed to the
//                   launcher (presumably filled in by it -- confirm).
//
// Returns: always an empty vector; nothing is returned by value.
//
// Raises (as a Python exception via TORCH_CHECK / data_ptr<float>()):
//   - if pixelcoordsim is present and does not match H/W,
//   - if any tensor is not float32.
std::vector<torch::Tensor> compute_raydirs_forward(
        torch::Tensor viewposim,
        torch::Tensor viewrotim,
        torch::Tensor focalim,
        torch::Tensor princptim,
        torch::optional<torch::Tensor> pixelcoordsim,
        int W, int H,
        float volradius,
        torch::Tensor rayposim,
        torch::Tensor raydirim,
        torch::Tensor tminmaxim) {
    CHECK_INPUT(viewposim);
    CHECK_INPUT(viewrotim);
    CHECK_INPUT(focalim);
    CHECK_INPUT(princptim);
    if (pixelcoordsim) { CHECK_INPUT(*pixelcoordsim); }
    CHECK_INPUT(rayposim);
    CHECK_INPUT(raydirim);
    CHECK_INPUT(tminmaxim);

    const int N = static_cast<int>(viewposim.size(0));

    // The previous assert() called .size() on the optional itself (ill-formed)
    // and vanished under NDEBUG, so the shape was never actually verified.
    // TORCH_CHECK runs in every build and surfaces as a Python error.
    if (pixelcoordsim) {
        TORCH_CHECK(
            pixelcoordsim->dim() >= 3 &&
                pixelcoordsim->size(1) == H &&
                pixelcoordsim->size(2) == W,
            "pixelcoordsim must have size H=", H, " along dim 1 and W=", W,
            " along dim 2");
    }

    // data_ptr<float>() verifies the dtype, unlike a reinterpret_cast of the
    // untyped pointer, which would silently reinterpret non-float32 storage.
    // NOTE(review): the stream argument is hard-coded to 0; confirm this is
    // intended if callers ever run under a non-default CUDA stream.
    compute_raydirs_forward_cuda(N, H, W,
        viewposim.data_ptr<float>(),
        viewrotim.data_ptr<float>(),
        focalim.data_ptr<float>(),
        princptim.data_ptr<float>(),
        pixelcoordsim ? pixelcoordsim->data_ptr<float>() : nullptr,
        volradius,
        rayposim.data_ptr<float>(),
        raydirim.data_ptr<float>(),
        tminmaxim.data_ptr<float>(),
        0);

    return {};
}
// Python-facing wrapper for the ray-direction backward pass.
//
// Validates every tensor with CHECK_INPUT, checks the optional pixel
// coordinate tensor against the requested H/W, and hands raw float pointers
// to compute_raydirs_backward_cuda.
//
// Parameters:
//   viewposim, viewrotim, focalim, princptim  input tensors (float32).
//   pixelcoordsim   optional tensor; when present its dim 1 must equal H and
//                   its dim 2 must equal W. When absent, nullptr is passed on.
//   W, H            NOTE: the argument order here is (W, H), while the CUDA
//                   launcher is called with (N, H, W).
//   volradius       scalar forwarded unchanged.
//   rayposim, raydirim, tminmaxim  the caller's ray tensors.
//   grad_viewpos, grad_viewrot, grad_focal, grad_princpt
//                   caller-allocated gradient tensors handed to the launcher
//                   (presumably written by it -- confirm).
//
// Returns: always an empty vector; nothing is returned by value.
//
// Raises (as a Python exception via TORCH_CHECK / data_ptr<float>()):
//   - if pixelcoordsim is present and does not match H/W,
//   - if any tensor is not float32.
std::vector<torch::Tensor> compute_raydirs_backward(
        torch::Tensor viewposim,
        torch::Tensor viewrotim,
        torch::Tensor focalim,
        torch::Tensor princptim,
        torch::optional<torch::Tensor> pixelcoordsim,
        int W, int H,
        float volradius,
        torch::Tensor rayposim,
        torch::Tensor raydirim,
        torch::Tensor tminmaxim,
        torch::Tensor grad_viewpos,
        torch::Tensor grad_viewrot,
        torch::Tensor grad_focal,
        torch::Tensor grad_princpt) {
    CHECK_INPUT(viewposim);
    CHECK_INPUT(viewrotim);
    CHECK_INPUT(focalim);
    CHECK_INPUT(princptim);
    if (pixelcoordsim) { CHECK_INPUT(*pixelcoordsim); }
    CHECK_INPUT(rayposim);
    CHECK_INPUT(raydirim);
    CHECK_INPUT(tminmaxim);
    CHECK_INPUT(grad_viewpos);
    CHECK_INPUT(grad_viewrot);
    CHECK_INPUT(grad_focal);
    CHECK_INPUT(grad_princpt);

    const int N = static_cast<int>(viewposim.size(0));

    // The previous assert() called .size() on the optional itself (ill-formed)
    // and vanished under NDEBUG, so the shape was never actually verified.
    // TORCH_CHECK runs in every build and surfaces as a Python error.
    if (pixelcoordsim) {
        TORCH_CHECK(
            pixelcoordsim->dim() >= 3 &&
                pixelcoordsim->size(1) == H &&
                pixelcoordsim->size(2) == W,
            "pixelcoordsim must have size H=", H, " along dim 1 and W=", W,
            " along dim 2");
    }

    // data_ptr<float>() verifies the dtype, unlike a reinterpret_cast of the
    // untyped pointer, which would silently reinterpret non-float32 storage.
    // NOTE(review): the stream argument is hard-coded to 0; confirm this is
    // intended if callers ever run under a non-default CUDA stream.
    compute_raydirs_backward_cuda(N, H, W,
        viewposim.data_ptr<float>(),
        viewrotim.data_ptr<float>(),
        focalim.data_ptr<float>(),
        princptim.data_ptr<float>(),
        pixelcoordsim ? pixelcoordsim->data_ptr<float>() : nullptr,
        volradius,
        rayposim.data_ptr<float>(),
        raydirim.data_ptr<float>(),
        tminmaxim.data_ptr<float>(),
        grad_viewpos.data_ptr<float>(),
        grad_viewrot.data_ptr<float>(),
        grad_focal.data_ptr<float>(),
        grad_princpt.data_ptr<float>(),
        0);

    return {};
}
// Python module definition: exposes the two host wrappers above under the
// same names they have in C++. TORCH_EXTENSION_NAME is supplied by the
// build, so the importable module name is decided there, not here.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Forward pass entry point.
    m.def("compute_raydirs_forward", &compute_raydirs_forward, "raydirs forward (CUDA)");
    // Backward pass entry point.
    m.def("compute_raydirs_backward", &compute_raydirs_backward, "raydirs backward (CUDA)");
}