Add the implementation of the Gridding layer (arXiv 2006.03761).
hzxie committed Dec 22, 2020
1 parent 48e3a16 commit 8488978
Showing 6 changed files with 390 additions and 14 deletions.
25 changes: 12 additions & 13 deletions cuda/include/cuda_utils.h
@@ -32,23 +32,22 @@ inline dim3 opt_block_config(int x, int y)
 // from https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomic-functions
 #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600
 #else
-__device__ double atomicAdd(double* address, double val)
-{
-  unsigned long long int* address_as_ull =
-      (unsigned long long int*)address;
-  unsigned long long int old = *address_as_ull, assumed;
+__device__ double atomicAdd(double* address, double val)
+{
+    unsigned long long int* address_as_ull = (unsigned long long int*)address;
+    unsigned long long int old = *address_as_ull, assumed;
 
-  do {
-    assumed = old;
-    old = atomicCAS(address_as_ull, assumed,
-                    __double_as_longlong(val +
-                                         __longlong_as_double(assumed)));
+    do
+    {
+        assumed = old;
+        old = atomicCAS(address_as_ull, assumed,
+                        __double_as_longlong(val + __longlong_as_double(assumed)));
 
         // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
-  } while (assumed != old);
+    } while (assumed != old);
 
-  return __longlong_as_double(old);
-}
+    return __longlong_as_double(old);
+}
 #endif

#define CUDA_CHECK_ERRORS() \
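The change to cuda_utils.h above only reformats the existing software fallback for atomicAdd on double, taken from the CUDA C Programming Guide: GPUs with compute capability below 6.0 have no native double-precision atomicAdd, so it is emulated with an atomicCAS loop. As a minimal illustration (not part of this commit), a kernel that accumulates doubles compiles against this fallback on older devices:

// Illustrative only (not from this commit): sum an array of doubles with atomicAdd.
// On devices with __CUDA_ARCH__ < 600 the call resolves to the CAS-loop fallback
// defined in cuda_utils.h; on newer devices it uses the hardware instruction.
#include <cuda_runtime.h>

__global__ void accumulate_double(int n, const double* __restrict__ values, double* total)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
    {
        atomicAdd(total, values[i]);
    }
}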
18 changes: 18 additions & 0 deletions cuda/include/gridding.h
@@ -0,0 +1,18 @@
#include <vector>

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

std::vector<torch::Tensor> gridding_kernel_warpper(float min_x, float max_x, float min_y,
                                                   float max_y, float min_z, float max_z,
                                                   torch::Tensor ptcloud, cudaStream_t stream);

torch::Tensor gridding_grad_kernel_warpper(torch::Tensor grid_pt_weights,
                                           torch::Tensor grid_pt_indexes, torch::Tensor grad_grid,
                                           cudaStream_t stream);

std::vector<torch::Tensor> gridding(float min_x, float max_x, float min_y, float max_y, float min_z,
                                    float max_z, torch::Tensor ptcloud);

torch::Tensor gridding_grad(torch::Tensor grid_pt_weights, torch::Tensor grid_pt_indexes,
                            torch::Tensor grad_grid);
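Per the GRNet paper (arXiv 2006.03761), gridding converts an unordered point cloud into a regular 3D grid: each point spreads interpolation weights over the eight vertices of the grid cell that contains it, and gridding_grad uses the saved per-point weights (grid_pt_weights) and vertex indexes (grid_pt_indexes) to route gradients from the grid back to the point coordinates. The CUDA kernels behind these declarations live in a file that did not load in this view; the following is only a rough CPU-style sketch of the forward weighting idea, with the function name, unit cell size, and grid layout assumed for illustration rather than taken from the commit:

#include <array>
#include <cmath>
#include <vector>

// Illustrative CPU-style sketch (not the commit's CUDA kernel): distribute each
// point's trilinear weights over the eight vertices of its grid cell.
// Assumes unit-sized cells, "scale" vertices per axis, and points strictly inside
// the grid so that x0 + 1 < scale (and likewise for y and z).
void gridding_forward_sketch(const std::vector<std::array<float, 3>>& ptcloud,
                             float min_x, float min_y, float min_z, int scale,
                             std::vector<float>& grid)  // grid holds scale^3 values
{
    for (const auto& p : ptcloud)
    {
        float gx = p[0] - min_x, gy = p[1] - min_y, gz = p[2] - min_z;
        int x0 = static_cast<int>(std::floor(gx));
        int y0 = static_cast<int>(std::floor(gy));
        int z0 = static_cast<int>(std::floor(gz));
        float dx = gx - x0, dy = gy - y0, dz = gz - z0;

        // Each vertex (x0 + i, y0 + j, z0 + k) receives the product of its
        // per-axis weights; the eight weights of one point sum to 1.
        for (int i = 0; i < 2; ++i)
            for (int j = 0; j < 2; ++j)
                for (int k = 0; k < 2; ++k)
                {
                    float w = (i ? dx : 1.0f - dx) * (j ? dy : 1.0f - dy) * (k ? dz : 1.0f - dz);
                    grid[((x0 + i) * scale + (y0 + j)) * scale + (z0 + k)] += w;
                }
    }
}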
4 changes: 4 additions & 0 deletions cuda/src/bindings.cpp
@@ -1,6 +1,7 @@
 #include "ball_query.h"
 #include "chamfer_dist.h"
 #include "cubic_feature_sampling.h"
+#include "gridding.h"
 #include "interpolate.h"
 #include "metrics.h"
 #include "sampling.h"
@@ -23,4 +24,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 
     m.def("cubic_feature_sampling", &cubic_feature_sampling);
     m.def("cubic_feature_sampling_grad", &cubic_feature_sampling_grad);
+
+    m.def("gridding", &gridding);
+    m.def("gridding_grad", &gridding_grad);
 }
2 changes: 1 addition & 1 deletion cuda/src/chamfer_dist_gpu.cu
@@ -2,8 +2,8 @@
 #include <cuda_runtime.h>
 #include <torch/extension.h>
 
-#include <vector>
 #include "cuda_utils.h"
+#include <vector>
 
 template <typename scalar_t>
 __global__ void chamfer_dist_kernel(int batch_size, int n, const scalar_t* __restrict__ xyz1, int m,
26 changes: 26 additions & 0 deletions cuda/src/gridding.cpp
@@ -0,0 +1,26 @@
#include "gridding.h"
#include "utils.h"

std::vector<torch::Tensor> gridding(float min_x, float max_x, float min_y, float max_y, float min_z,
                                    float max_z, torch::Tensor ptcloud)
{
    CHECK_CUDA(ptcloud);
    CHECK_CONTIGUOUS(ptcloud);

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    return gridding_kernel_warpper(min_x, max_x, min_y, max_y, min_z, max_z, ptcloud, stream);
}

torch::Tensor gridding_grad(torch::Tensor grid_pt_weights, torch::Tensor grid_pt_indexes,
                            torch::Tensor grad_grid)
{
    CHECK_CUDA(grid_pt_weights);
    CHECK_CONTIGUOUS(grid_pt_weights);
    CHECK_CUDA(grid_pt_indexes);
    CHECK_CONTIGUOUS(grid_pt_indexes);
    CHECK_CUDA(grad_grid);
    CHECK_CONTIGUOUS(grad_grid);

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    return gridding_grad_kernel_warpper(grid_pt_weights, grid_pt_indexes, grad_grid, stream);
}
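gridding.cpp is the host-side glue: it validates the inputs, then dispatches to the CUDA wrappers on the current PyTorch stream. CHECK_CUDA and CHECK_CONTIGUOUS come from utils.h, which is not part of this diff; in the usual PyTorch C++ extension pattern they look roughly like the following, though the repository's actual definitions may differ:

// Assumed sketch of the validation macros (utils.h is not shown in this commit);
// this follows the common PyTorch C++ extension convention.
#include <torch/extension.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")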
