From 848897801a8e448c059a0c7f383dedad46c8f770 Mon Sep 17 00:00:00 2001 From: Haozhe Xie Date: Sun, 20 Dec 2020 12:57:18 +0800 Subject: [PATCH 1/2] Add the implementation of the Gridding layer (arXiv 2006.03761). --- cuda/include/cuda_utils.h | 25 ++- cuda/include/gridding.h | 18 ++ cuda/src/bindings.cpp | 4 + cuda/src/chamfer_dist_gpu.cu | 2 +- cuda/src/gridding.cpp | 26 +++ cuda/src/gridding_gpu.cu | 329 +++++++++++++++++++++++++++++++++++ 6 files changed, 390 insertions(+), 14 deletions(-) create mode 100644 cuda/include/gridding.h create mode 100644 cuda/src/gridding.cpp create mode 100644 cuda/src/gridding_gpu.cu diff --git a/cuda/include/cuda_utils.h b/cuda/include/cuda_utils.h index 9867cf1..a8357b9 100644 --- a/cuda/include/cuda_utils.h +++ b/cuda/include/cuda_utils.h @@ -32,23 +32,22 @@ inline dim3 opt_block_config(int x, int y) // from https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomic-functions #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 #else - __device__ double atomicAdd(double* address, double val) - { - unsigned long long int* address_as_ull = - (unsigned long long int*)address; - unsigned long long int old = *address_as_ull, assumed; +__device__ double atomicAdd(double* address, double val) +{ + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = *address_as_ull, assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, - __double_as_longlong(val + - __longlong_as_double(assumed))); + do + { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __double_as_longlong(val + __longlong_as_double(assumed))); // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) - } while (assumed != old); + } while (assumed != old); - return __longlong_as_double(old); - } + return __longlong_as_double(old); +} #endif #define CUDA_CHECK_ERRORS() \ diff --git a/cuda/include/gridding.h b/cuda/include/gridding.h new file mode 100644 index 0000000..6e8d640 --- /dev/null +++ b/cuda/include/gridding.h @@ -0,0 +1,18 @@ +#include + +#include +#include + +std::vector gridding_kernel_warpper(float min_x, float max_x, float min_y, + float max_y, float min_z, float max_z, + torch::Tensor ptcloud, cudaStream_t stream); + +torch::Tensor gridding_grad_kernel_warpper(torch::Tensor grid_pt_weights, + torch::Tensor grid_pt_indexes, torch::Tensor grad_grid, + cudaStream_t stream); + +std::vector gridding(float min_x, float max_x, float min_y, float max_y, float min_z, + float max_z, torch::Tensor ptcloud); + +torch::Tensor gridding_grad(torch::Tensor grid_pt_weights, torch::Tensor grid_pt_indexes, + torch::Tensor grad_grid); diff --git a/cuda/src/bindings.cpp b/cuda/src/bindings.cpp index 5a14485..fe84353 100644 --- a/cuda/src/bindings.cpp +++ b/cuda/src/bindings.cpp @@ -1,6 +1,7 @@ #include "ball_query.h" #include "chamfer_dist.h" #include "cubic_feature_sampling.h" +#include "gridding.h" #include "interpolate.h" #include "metrics.h" #include "sampling.h" @@ -23,4 +24,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) m.def("cubic_feature_sampling", &cubic_feature_sampling); m.def("cubic_feature_sampling_grad", &cubic_feature_sampling_grad); + + m.def("gridding", &gridding); + m.def("gridding_grad", &gridding_grad); } diff --git a/cuda/src/chamfer_dist_gpu.cu b/cuda/src/chamfer_dist_gpu.cu index 5b7ad14..0b3eee3 100644 --- a/cuda/src/chamfer_dist_gpu.cu +++ b/cuda/src/chamfer_dist_gpu.cu @@ -2,8 +2,8 @@ #include #include -#include #include "cuda_utils.h" +#include template __global__ 
void chamfer_dist_kernel(int batch_size, int n, const scalar_t* __restrict__ xyz1, int m, diff --git a/cuda/src/gridding.cpp b/cuda/src/gridding.cpp new file mode 100644 index 0000000..d5c7014 --- /dev/null +++ b/cuda/src/gridding.cpp @@ -0,0 +1,26 @@ +#include "gridding.h" +#include "utils.h" + +std::vector gridding(float min_x, float max_x, float min_y, float max_y, float min_z, + float max_z, torch::Tensor ptcloud) +{ + CHECK_CUDA(ptcloud); + CHECK_CONTIGUOUS(ptcloud); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + return gridding_kernel_warpper(min_x, max_x, min_y, max_y, min_z, max_z, ptcloud, stream); +} + +torch::Tensor gridding_grad(torch::Tensor grid_pt_weights, torch::Tensor grid_pt_indexes, + torch::Tensor grad_grid) +{ + CHECK_CUDA(grid_pt_weights); + CHECK_CONTIGUOUS(grid_pt_weights); + CHECK_CUDA(grid_pt_indexes); + CHECK_CONTIGUOUS(grid_pt_indexes); + CHECK_CUDA(grad_grid); + CHECK_CONTIGUOUS(grad_grid); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + return gridding_grad_kernel_warpper(grid_pt_weights, grid_pt_indexes, grad_grid, stream); +} diff --git a/cuda/src/gridding_gpu.cu b/cuda/src/gridding_gpu.cu new file mode 100644 index 0000000..8b02ad5 --- /dev/null +++ b/cuda/src/gridding_gpu.cu @@ -0,0 +1,329 @@ +#include +#include +#include +#include + +#include "cuda_utils.h" + +#define CUDA_NUM_THREADS 512 + +// Computer the number of threads needed in GPU +inline int get_n_threads(int n) +{ + const int pow_2 = std::log(static_cast(n)) / std::log(2.0); + return max(min(1 << pow_2, CUDA_NUM_THREADS), 1); +} + +__device__ int compute_index(int offset_x, int offset_y, int offset_z, int len_y, int len_z) +{ + return offset_x * len_y * len_z + offset_y * len_z + offset_z; +} + +__device__ float compute_weight(float x, float x0) +{ + return 1 - abs(x - x0); +} + +template +__global__ void +gridding_kernel(int n_grid_vertices, int n_pts, float min_x, float min_y, float min_z, + int len_y, int len_z, const scalar_t* __restrict__ ptcloud, + scalar_t* __restrict__ grid_weights, scalar_t* __restrict__ grid_pt_weights, + int* __restrict__ grid_pt_indexes) +{ + int batch_index = blockIdx.x; + int index = threadIdx.x; + int stride = blockDim.x; + + ptcloud += batch_index * n_pts * 3; + grid_weights += batch_index * n_grid_vertices; + grid_pt_weights += batch_index * n_pts * 24; + grid_pt_indexes += batch_index * n_pts * 8; + + for (int j = index; j < n_pts; j += stride) + { + scalar_t pt_x = ptcloud[j * 3 + 0]; + scalar_t pt_y = ptcloud[j * 3 + 1]; + scalar_t pt_z = ptcloud[j * 3 + 2]; + + int lower_x = std::floor(pt_x); + int upper_x = std::ceil(pt_x); + if (lower_x == upper_x) + { + upper_x += 1; + } + int lower_y = std::floor(pt_y); + int upper_y = std::ceil(pt_y); + if (lower_y == upper_y) + { + upper_y += 1; + } + int lower_z = std::floor(pt_z); + int upper_z = std::ceil(pt_z); + if (lower_z == upper_z) + { + upper_z += 1; + } + + int lx_offset = lower_x - min_x, ux_offset = upper_x - min_x; + int ly_offset = lower_y - min_y, uy_offset = upper_y - min_y; + int lz_offset = lower_z - min_z, uz_offset = upper_z - min_z; + + // Compute weights and corresponding positions, a loop for 8 points + // LLL -> Lower X, Lower Y, Lower Z + grid_pt_indexes[j * 8 + 0] = compute_index(lx_offset, ly_offset, lz_offset, len_y, len_z); + grid_pt_weights[j * 24 + 0] = compute_weight(pt_x, lower_x); + grid_pt_weights[j * 24 + 1] = compute_weight(pt_y, lower_y); + grid_pt_weights[j * 24 + 2] = compute_weight(pt_z, lower_z); + + // LLU -> Lower X, Lower Y, Upper Z + 
grid_pt_indexes[j * 8 + 1] = compute_index(lx_offset, ly_offset, uz_offset, len_y, len_z); + grid_pt_weights[j * 24 + 3] = compute_weight(pt_x, lower_x); + grid_pt_weights[j * 24 + 4] = compute_weight(pt_y, lower_y); + grid_pt_weights[j * 24 + 5] = compute_weight(pt_z, upper_z); + + // LUL -> Lower X, Upper Y, Lower Z + grid_pt_indexes[j * 8 + 2] = compute_index(lx_offset, uy_offset, lz_offset, len_y, len_z); + grid_pt_weights[j * 24 + 6] = compute_weight(pt_x, lower_x); + grid_pt_weights[j * 24 + 7] = compute_weight(pt_y, upper_y); + grid_pt_weights[j * 24 + 8] = compute_weight(pt_z, lower_z); + + // LUU -> Lower X, Upper Y, Upper Z + grid_pt_indexes[j * 8 + 3] = compute_index(lx_offset, uy_offset, uz_offset, len_y, len_z); + grid_pt_weights[j * 24 + 9] = compute_weight(pt_x, lower_x); + grid_pt_weights[j * 24 + 10] = compute_weight(pt_y, upper_y); + grid_pt_weights[j * 24 + 11] = compute_weight(pt_z, upper_z); + + // ULL -> Upper X, Lower Y, Lower Z + grid_pt_indexes[j * 8 + 4] = compute_index(ux_offset, ly_offset, lz_offset, len_y, len_z); + grid_pt_weights[j * 24 + 12] = compute_weight(pt_x, upper_x); + grid_pt_weights[j * 24 + 13] = compute_weight(pt_y, lower_y); + grid_pt_weights[j * 24 + 14] = compute_weight(pt_z, lower_z); + + // ULU -> Upper X, Lower Y, Upper Z + grid_pt_indexes[j * 8 + 5] = compute_index(ux_offset, ly_offset, uz_offset, len_y, len_z); + grid_pt_weights[j * 24 + 15] = compute_weight(pt_x, upper_x); + grid_pt_weights[j * 24 + 16] = compute_weight(pt_y, lower_y); + grid_pt_weights[j * 24 + 17] = compute_weight(pt_z, upper_z); + + // UUL -> Upper X, Upper Y, Lower Z + grid_pt_indexes[j * 8 + 6] = compute_index(ux_offset, uy_offset, lz_offset, len_y, len_z); + grid_pt_weights[j * 24 + 18] = compute_weight(pt_x, upper_x); + grid_pt_weights[j * 24 + 19] = compute_weight(pt_y, upper_y); + grid_pt_weights[j * 24 + 20] = compute_weight(pt_z, lower_z); + + // UUU -> Upper X, Upper Y, Upper Z + grid_pt_indexes[j * 8 + 7] = compute_index(ux_offset, uy_offset, uz_offset, len_y, len_z); + grid_pt_weights[j * 24 + 21] = compute_weight(pt_x, upper_x); + grid_pt_weights[j * 24 + 22] = compute_weight(pt_y, upper_y); + grid_pt_weights[j * 24 + 23] = compute_weight(pt_z, upper_z); + } + + __syncthreads(); + + int gvtx_idx = 0; + for (int j = index; j < n_pts; j += stride) + { + // LLL -> Lower X, Lower Y, Lower Z + gvtx_idx = grid_pt_indexes[j * 8 + 0]; + atomicAdd(&(grid_weights[gvtx_idx]), grid_pt_weights[j * 24 + 0] * + grid_pt_weights[j * 24 + 1] * + grid_pt_weights[j * 24 + 2]); + // LLU -> Lower X, Lower Y, Upper Z + gvtx_idx = grid_pt_indexes[j * 8 + 1]; + atomicAdd(&(grid_weights[gvtx_idx]), grid_pt_weights[j * 24 + 3] * + grid_pt_weights[j * 24 + 4] * + grid_pt_weights[j * 24 + 5]); + // LUL -> Lower X, Upper Y, Lower Z + gvtx_idx = grid_pt_indexes[j * 8 + 2]; + atomicAdd(&(grid_weights[gvtx_idx]), grid_pt_weights[j * 24 + 6] * + grid_pt_weights[j * 24 + 7] * + grid_pt_weights[j * 24 + 8]); + // LUU -> Lower X, Upper Y, Upper Z + gvtx_idx = grid_pt_indexes[j * 8 + 3]; + atomicAdd(&(grid_weights[gvtx_idx]), grid_pt_weights[j * 24 + 9] * + grid_pt_weights[j * 24 + 10] * + grid_pt_weights[j * 24 + 11]); + // ULL -> Upper X, Lower Y, Lower Z + gvtx_idx = grid_pt_indexes[j * 8 + 4]; + atomicAdd(&(grid_weights[gvtx_idx]), grid_pt_weights[j * 24 + 12] * + grid_pt_weights[j * 24 + 13] * + grid_pt_weights[j * 24 + 14]); + // ULU -> Upper X, Lower Y, Upper Z + gvtx_idx = grid_pt_indexes[j * 8 + 5]; + atomicAdd(&(grid_weights[gvtx_idx]), grid_pt_weights[j * 24 + 15] * + 
grid_pt_weights[j * 24 + 16] * + grid_pt_weights[j * 24 + 17]); + // UUL -> Upper X, Upper Y, Lower Z + gvtx_idx = grid_pt_indexes[j * 8 + 6]; + atomicAdd(&(grid_weights[gvtx_idx]), grid_pt_weights[j * 24 + 18] * + grid_pt_weights[j * 24 + 19] * + grid_pt_weights[j * 24 + 20]); + // UUU -> Upper X, Upper Y, Upper Z + gvtx_idx = grid_pt_indexes[j * 8 + 7]; + atomicAdd(&(grid_weights[gvtx_idx]), grid_pt_weights[j * 24 + 21] * + grid_pt_weights[j * 24 + 22] * + grid_pt_weights[j * 24 + 23]); + } +} + +std::vector gridding_kernel_warpper(float min_x, float max_x, float min_y, + float max_y, float min_z, float max_z, + torch::Tensor ptcloud, cudaStream_t stream) +{ + int batch_size = ptcloud.size(0); + int n_pts = ptcloud.size(1); + int len_x = max_x - min_x + 1; + int len_y = max_y - min_y + 1; + int len_z = max_z - min_z + 1; + int n_grid_vertices = len_x * len_y * len_z; + + torch::Tensor grid_weights = + torch::zeros({batch_size, n_grid_vertices}, torch::CUDA(torch::kFloat)); + torch::Tensor grid_pt_weights = + torch::zeros({batch_size, n_pts, 8, 3}, torch::CUDA(torch::kFloat)); + torch::Tensor grid_pt_indexes = torch::zeros({batch_size, n_pts, 8}, torch::CUDA(torch::kInt)); + + AT_DISPATCH_FLOATING_TYPES( + ptcloud.scalar_type(), "gridding_cuda", ([&] { + gridding_kernel<<>>( + n_grid_vertices, n_pts, min_x, min_y, min_z, len_y, len_z, + ptcloud.data_ptr(), grid_weights.data_ptr(), + grid_pt_weights.data_ptr(), grid_pt_indexes.data_ptr()); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("Error in gridding_kernel_warpper: %s\n", cudaGetErrorString(err)); + } + return {grid_weights, grid_pt_weights, grid_pt_indexes}; +} + +template +__global__ void +gridding_grad_kernel(int n_grid_vertices, int n_pts, const scalar_t* __restrict__ grid_pt_weights, + const int* __restrict__ grid_pt_indexes, + const scalar_t* __restrict__ grad_grid, scalar_t* __restrict__ grad_ptcloud) +{ + int batch_index = blockIdx.x; + int index = threadIdx.x; + int stride = blockDim.x; + + grid_pt_weights += batch_index * n_pts * 24; + grid_pt_indexes += batch_index * n_pts * 8; + grad_grid += batch_index * n_grid_vertices; + grad_ptcloud += batch_index * n_pts * 3; + + int gvtx_idx = 0; + scalar_t grad_vtx = 0, x_weights = 0, y_weights = 0, z_weights = 0; + for (int j = index; j < n_pts; j += stride) + { + // Compute gradient for the corresponding positions, a loop for 8 points + // LLL -> Lower X, Lower Y, Lower Z + gvtx_idx = grid_pt_indexes[j * 8 + 0]; + grad_vtx = grad_grid[gvtx_idx]; + x_weights = grid_pt_weights[j * 24 + 0]; + y_weights = grid_pt_weights[j * 24 + 1]; + z_weights = grid_pt_weights[j * 24 + 2]; + atomicAdd(&(grad_ptcloud[j * 3 + 0]), -grad_vtx * y_weights * z_weights); + atomicAdd(&(grad_ptcloud[j * 3 + 1]), -grad_vtx * x_weights * z_weights); + atomicAdd(&(grad_ptcloud[j * 3 + 2]), -grad_vtx * x_weights * y_weights); + + // LLU -> Lower X, Lower Y, Upper Z + gvtx_idx = grid_pt_indexes[j * 8 + 1]; + grad_vtx = grad_grid[gvtx_idx]; + x_weights = grid_pt_weights[j * 24 + 3]; + y_weights = grid_pt_weights[j * 24 + 4]; + z_weights = grid_pt_weights[j * 24 + 5]; + atomicAdd(&(grad_ptcloud[j * 3 + 0]), -grad_vtx * y_weights * z_weights); + atomicAdd(&(grad_ptcloud[j * 3 + 1]), -grad_vtx * x_weights * z_weights); + atomicAdd(&(grad_ptcloud[j * 3 + 2]), grad_vtx * x_weights * y_weights); + + // LUL -> Lower X, Upper Y, Lower Z + gvtx_idx = grid_pt_indexes[j * 8 + 2]; + grad_vtx = grad_grid[gvtx_idx]; + x_weights = grid_pt_weights[j * 24 + 6]; + y_weights = 
grid_pt_weights[j * 24 + 7]; + z_weights = grid_pt_weights[j * 24 + 8]; + atomicAdd(&(grad_ptcloud[j * 3 + 0]), -grad_vtx * y_weights * z_weights); + atomicAdd(&(grad_ptcloud[j * 3 + 1]), grad_vtx * x_weights * z_weights); + atomicAdd(&(grad_ptcloud[j * 3 + 2]), -grad_vtx * x_weights * y_weights); + + // LUU -> Lower X, Upper Y, Upper Z + gvtx_idx = grid_pt_indexes[j * 8 + 3]; + grad_vtx = grad_grid[gvtx_idx]; + x_weights = grid_pt_weights[j * 24 + 9]; + y_weights = grid_pt_weights[j * 24 + 10]; + z_weights = grid_pt_weights[j * 24 + 11]; + atomicAdd(&(grad_ptcloud[j * 3 + 0]), -grad_vtx * y_weights * z_weights); + atomicAdd(&(grad_ptcloud[j * 3 + 1]), grad_vtx * x_weights * z_weights); + atomicAdd(&(grad_ptcloud[j * 3 + 2]), grad_vtx * x_weights * y_weights); + + // ULL -> Upper X, Lower Y, Lower Z + gvtx_idx = grid_pt_indexes[j * 8 + 4]; + grad_vtx = grad_grid[gvtx_idx]; + x_weights = grid_pt_weights[j * 24 + 12]; + y_weights = grid_pt_weights[j * 24 + 13]; + z_weights = grid_pt_weights[j * 24 + 14]; + atomicAdd(&(grad_ptcloud[j * 3 + 0]), grad_vtx * y_weights * z_weights); + atomicAdd(&(grad_ptcloud[j * 3 + 1]), -grad_vtx * x_weights * z_weights); + atomicAdd(&(grad_ptcloud[j * 3 + 2]), -grad_vtx * x_weights * y_weights); + + // ULU -> Upper X, Lower Y, Upper Z + gvtx_idx = grid_pt_indexes[j * 8 + 5]; + grad_vtx = grad_grid[gvtx_idx]; + x_weights = grid_pt_weights[j * 24 + 15]; + y_weights = grid_pt_weights[j * 24 + 16]; + z_weights = grid_pt_weights[j * 24 + 17]; + atomicAdd(&(grad_ptcloud[j * 3 + 0]), grad_vtx * y_weights * z_weights); + atomicAdd(&(grad_ptcloud[j * 3 + 1]), -grad_vtx * x_weights * z_weights); + atomicAdd(&(grad_ptcloud[j * 3 + 2]), grad_vtx * x_weights * y_weights); + + // UUL -> Upper X, Upper Y, Lower Z + gvtx_idx = grid_pt_indexes[j * 8 + 6]; + grad_vtx = grad_grid[gvtx_idx]; + x_weights = grid_pt_weights[j * 24 + 18]; + y_weights = grid_pt_weights[j * 24 + 19]; + z_weights = grid_pt_weights[j * 24 + 20]; + atomicAdd(&(grad_ptcloud[j * 3 + 0]), grad_vtx * y_weights * z_weights); + atomicAdd(&(grad_ptcloud[j * 3 + 1]), grad_vtx * x_weights * z_weights); + atomicAdd(&(grad_ptcloud[j * 3 + 2]), -grad_vtx * x_weights * y_weights); + + // UUU -> Upper X, Upper Y, Upper Z + gvtx_idx = grid_pt_indexes[j * 8 + 7]; + grad_vtx = grad_grid[gvtx_idx]; + x_weights = grid_pt_weights[j * 24 + 21]; + y_weights = grid_pt_weights[j * 24 + 22]; + z_weights = grid_pt_weights[j * 24 + 23]; + atomicAdd(&(grad_ptcloud[j * 3 + 0]), grad_vtx * y_weights * z_weights); + atomicAdd(&(grad_ptcloud[j * 3 + 1]), grad_vtx * x_weights * z_weights); + atomicAdd(&(grad_ptcloud[j * 3 + 2]), grad_vtx * x_weights * y_weights); + } +} + +torch::Tensor gridding_grad_kernel_warpper(torch::Tensor grid_pt_weights, + torch::Tensor grid_pt_indexes, torch::Tensor grad_grid, + cudaStream_t stream) +{ + int batch_size = grad_grid.size(0); + int n_grid_vertices = grad_grid.size(1); + int n_pts = grid_pt_indexes.size(1); + + torch::Tensor grad_ptcloud = torch::zeros({batch_size, n_pts, 3}, torch::CUDA(torch::kFloat)); + + AT_DISPATCH_FLOATING_TYPES( + grid_pt_weights.scalar_type(), "gridding_grad_cuda", ([&] { + gridding_grad_kernel<<>>( + n_grid_vertices, n_pts, grid_pt_weights.data_ptr(), + grid_pt_indexes.data_ptr(), grad_grid.data_ptr(), + grad_ptcloud.data_ptr()); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("Error in gridding_grad_kernel_warpper: %s\n", cudaGetErrorString(err)); + } + return grad_ptcloud; +} From e2c315b478f76375f281b834031126b7430894ec Mon Sep 17 
00:00:00 2001 From: Haozhe Xie Date: Tue, 22 Dec 2020 19:07:59 +0800 Subject: [PATCH 2/2] Create the unit test for Gridding. --- cuda/src/gridding_gpu.cu | 62 ++++++++++++++-------------- test/test_gridding.py | 30 ++++++++++++++ torch_points_kernels/__init__.py | 1 + torch_points_kernels/chamfer_dist.py | 11 ++--- torch_points_kernels/gridding.py | 60 +++++++++++++++++++++++++++ torch_points_kernels/torchpoints.py | 1 - 6 files changed, 126 insertions(+), 39 deletions(-) create mode 100644 test/test_gridding.py create mode 100644 torch_points_kernels/gridding.py diff --git a/cuda/src/gridding_gpu.cu b/cuda/src/gridding_gpu.cu index 8b02ad5..5fdc68b 100644 --- a/cuda/src/gridding_gpu.cu +++ b/cuda/src/gridding_gpu.cu @@ -19,15 +19,16 @@ __device__ int compute_index(int offset_x, int offset_y, int offset_z, int len_y return offset_x * len_y * len_z + offset_y * len_z + offset_z; } -__device__ float compute_weight(float x, float x0) +template +__device__ scalar_t compute_weight(scalar_t x, scalar_t x0) { return 1 - abs(x - x0); } template __global__ void -gridding_kernel(int n_grid_vertices, int n_pts, float min_x, float min_y, float min_z, - int len_y, int len_z, const scalar_t* __restrict__ ptcloud, +gridding_kernel(int n_grid_vertices, int n_pts, float min_x, float min_y, float min_z, int len_y, + int len_z, const scalar_t* __restrict__ ptcloud, scalar_t* __restrict__ grid_weights, scalar_t* __restrict__ grid_pt_weights, int* __restrict__ grid_pt_indexes) { @@ -72,51 +73,51 @@ gridding_kernel(int n_grid_vertices, int n_pts, float min_x, float min_y, float // Compute weights and corresponding positions, a loop for 8 points // LLL -> Lower X, Lower Y, Lower Z grid_pt_indexes[j * 8 + 0] = compute_index(lx_offset, ly_offset, lz_offset, len_y, len_z); - grid_pt_weights[j * 24 + 0] = compute_weight(pt_x, lower_x); - grid_pt_weights[j * 24 + 1] = compute_weight(pt_y, lower_y); - grid_pt_weights[j * 24 + 2] = compute_weight(pt_z, lower_z); + grid_pt_weights[j * 24 + 0] = compute_weight(pt_x, lower_x); + grid_pt_weights[j * 24 + 1] = compute_weight(pt_y, lower_y); + grid_pt_weights[j * 24 + 2] = compute_weight(pt_z, lower_z); // LLU -> Lower X, Lower Y, Upper Z grid_pt_indexes[j * 8 + 1] = compute_index(lx_offset, ly_offset, uz_offset, len_y, len_z); - grid_pt_weights[j * 24 + 3] = compute_weight(pt_x, lower_x); - grid_pt_weights[j * 24 + 4] = compute_weight(pt_y, lower_y); - grid_pt_weights[j * 24 + 5] = compute_weight(pt_z, upper_z); + grid_pt_weights[j * 24 + 3] = compute_weight(pt_x, lower_x); + grid_pt_weights[j * 24 + 4] = compute_weight(pt_y, lower_y); + grid_pt_weights[j * 24 + 5] = compute_weight(pt_z, upper_z); // LUL -> Lower X, Upper Y, Lower Z grid_pt_indexes[j * 8 + 2] = compute_index(lx_offset, uy_offset, lz_offset, len_y, len_z); - grid_pt_weights[j * 24 + 6] = compute_weight(pt_x, lower_x); - grid_pt_weights[j * 24 + 7] = compute_weight(pt_y, upper_y); - grid_pt_weights[j * 24 + 8] = compute_weight(pt_z, lower_z); + grid_pt_weights[j * 24 + 6] = compute_weight(pt_x, lower_x); + grid_pt_weights[j * 24 + 7] = compute_weight(pt_y, upper_y); + grid_pt_weights[j * 24 + 8] = compute_weight(pt_z, lower_z); // LUU -> Lower X, Upper Y, Upper Z grid_pt_indexes[j * 8 + 3] = compute_index(lx_offset, uy_offset, uz_offset, len_y, len_z); - grid_pt_weights[j * 24 + 9] = compute_weight(pt_x, lower_x); - grid_pt_weights[j * 24 + 10] = compute_weight(pt_y, upper_y); - grid_pt_weights[j * 24 + 11] = compute_weight(pt_z, upper_z); + grid_pt_weights[j * 24 + 9] = compute_weight(pt_x, lower_x); + 
grid_pt_weights[j * 24 + 10] = compute_weight(pt_y, upper_y); + grid_pt_weights[j * 24 + 11] = compute_weight(pt_z, upper_z); // ULL -> Upper X, Lower Y, Lower Z grid_pt_indexes[j * 8 + 4] = compute_index(ux_offset, ly_offset, lz_offset, len_y, len_z); - grid_pt_weights[j * 24 + 12] = compute_weight(pt_x, upper_x); - grid_pt_weights[j * 24 + 13] = compute_weight(pt_y, lower_y); - grid_pt_weights[j * 24 + 14] = compute_weight(pt_z, lower_z); + grid_pt_weights[j * 24 + 12] = compute_weight(pt_x, upper_x); + grid_pt_weights[j * 24 + 13] = compute_weight(pt_y, lower_y); + grid_pt_weights[j * 24 + 14] = compute_weight(pt_z, lower_z); // ULU -> Upper X, Lower Y, Upper Z grid_pt_indexes[j * 8 + 5] = compute_index(ux_offset, ly_offset, uz_offset, len_y, len_z); - grid_pt_weights[j * 24 + 15] = compute_weight(pt_x, upper_x); - grid_pt_weights[j * 24 + 16] = compute_weight(pt_y, lower_y); - grid_pt_weights[j * 24 + 17] = compute_weight(pt_z, upper_z); + grid_pt_weights[j * 24 + 15] = compute_weight(pt_x, upper_x); + grid_pt_weights[j * 24 + 16] = compute_weight(pt_y, lower_y); + grid_pt_weights[j * 24 + 17] = compute_weight(pt_z, upper_z); // UUL -> Upper X, Upper Y, Lower Z grid_pt_indexes[j * 8 + 6] = compute_index(ux_offset, uy_offset, lz_offset, len_y, len_z); - grid_pt_weights[j * 24 + 18] = compute_weight(pt_x, upper_x); - grid_pt_weights[j * 24 + 19] = compute_weight(pt_y, upper_y); - grid_pt_weights[j * 24 + 20] = compute_weight(pt_z, lower_z); + grid_pt_weights[j * 24 + 18] = compute_weight(pt_x, upper_x); + grid_pt_weights[j * 24 + 19] = compute_weight(pt_y, upper_y); + grid_pt_weights[j * 24 + 20] = compute_weight(pt_z, lower_z); // UUU -> Upper X, Upper Y, Upper Z grid_pt_indexes[j * 8 + 7] = compute_index(ux_offset, uy_offset, uz_offset, len_y, len_z); - grid_pt_weights[j * 24 + 21] = compute_weight(pt_x, upper_x); - grid_pt_weights[j * 24 + 22] = compute_weight(pt_y, upper_y); - grid_pt_weights[j * 24 + 23] = compute_weight(pt_z, upper_z); + grid_pt_weights[j * 24 + 21] = compute_weight(pt_x, upper_x); + grid_pt_weights[j * 24 + 22] = compute_weight(pt_y, upper_y); + grid_pt_weights[j * 24 + 23] = compute_weight(pt_z, upper_z); } __syncthreads(); @@ -179,9 +180,9 @@ std::vector gridding_kernel_warpper(float min_x, float max_x, flo int n_grid_vertices = len_x * len_y * len_z; torch::Tensor grid_weights = - torch::zeros({batch_size, n_grid_vertices}, torch::CUDA(torch::kFloat)); + torch::zeros({batch_size, n_grid_vertices}, torch::CUDA(ptcloud.scalar_type())); torch::Tensor grid_pt_weights = - torch::zeros({batch_size, n_pts, 8, 3}, torch::CUDA(torch::kFloat)); + torch::zeros({batch_size, n_pts, 8, 3}, torch::CUDA(ptcloud.scalar_type())); torch::Tensor grid_pt_indexes = torch::zeros({batch_size, n_pts, 8}, torch::CUDA(torch::kInt)); AT_DISPATCH_FLOATING_TYPES( @@ -310,7 +311,8 @@ torch::Tensor gridding_grad_kernel_warpper(torch::Tensor grid_pt_weights, int n_grid_vertices = grad_grid.size(1); int n_pts = grid_pt_indexes.size(1); - torch::Tensor grad_ptcloud = torch::zeros({batch_size, n_pts, 3}, torch::CUDA(torch::kFloat)); + torch::Tensor grad_ptcloud = + torch::zeros({batch_size, n_pts, 3}, torch::CUDA(grid_pt_weights.scalar_type())); AT_DISPATCH_FLOATING_TYPES( grid_pt_weights.scalar_type(), "gridding_grad_cuda", ([&] { diff --git a/test/test_gridding.py b/test/test_gridding.py new file mode 100644 index 0000000..4fe032b --- /dev/null +++ b/test/test_gridding.py @@ -0,0 +1,30 @@ +import numpy as np +import os +import sys +import torch +import unittest + +from torch.autograd import 
gradcheck + +from . import run_if_cuda + + +ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..") +sys.path.insert(0, ROOT) + +from torch_points_kernels.gridding import GriddingFunction + + +class TestGridding(unittest.TestCase): + @run_if_cuda + def test_gridding_function_32pts(self): + x = torch.rand(1, 32, 3) + x.requires_grad = True + self.assertTrue(gradcheck(GriddingFunction.apply, [x.double().cuda(), 4])) + + @run_if_cuda + def test_gridding_function_64pts(self): + x = torch.rand(1, 64, 3) + x.requires_grad = True + self.assertTrue(gradcheck(GriddingFunction.apply, [x.double().cuda(), 8])) + diff --git a/torch_points_kernels/__init__.py b/torch_points_kernels/__init__.py index d93af03..5bea7f0 100644 --- a/torch_points_kernels/__init__.py +++ b/torch_points_kernels/__init__.py @@ -15,4 +15,5 @@ "instance_iou", "chamfer_dist", "cubic_feature_sampling", + "gridding", ] diff --git a/torch_points_kernels/chamfer_dist.py b/torch_points_kernels/chamfer_dist.py index f7596b1..528216d 100644 --- a/torch_points_kernels/chamfer_dist.py +++ b/torch_points_kernels/chamfer_dist.py @@ -8,9 +8,7 @@ class ChamferFunction(torch.autograd.Function): @staticmethod def forward(ctx, xyz1, xyz2): if not torch.cuda.is_available(): - raise NotImplementedError( - "CPU version is not available for Chamfer Distance" - ) + raise NotImplementedError("CPU version is not available for Chamfer Distance") dist1, dist2, idx1, idx2 = tpcuda.chamfer_dist(xyz1, xyz2) ctx.save_for_backward(xyz1, xyz2, idx1, idx2) @@ -20,9 +18,7 @@ def forward(ctx, xyz1, xyz2): @staticmethod def backward(ctx, grad_dist1, grad_dist2): xyz1, xyz2, idx1, idx2 = ctx.saved_tensors - grad_xyz1, grad_xyz2 = tpcuda.chamfer_dist_grad( - xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2 - ) + grad_xyz1, grad_xyz2 = tpcuda.chamfer_dist_grad(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2) return grad_xyz1, grad_xyz2 @@ -45,7 +41,7 @@ def chamfer_dist(xyz1, xyz2, ignore_zeros=False): (B, ): the distances between B pairs of point clouds """ if len(xyz1.shape) != 3 or xyz1.size(2) != 3 or len(xyz2.shape) != 3 or xyz2.size(2) != 3: - raise ValueError('The input point cloud should be of size (B, n_pts, 3)') + raise ValueError("The input point cloud should be of size (B, n_pts, 3)") batch_size = xyz1.size(0) if batch_size == 1 and ignore_zeros: @@ -56,4 +52,3 @@ def chamfer_dist(xyz1, xyz2, ignore_zeros=False): dist1, dist2 = ChamferFunction.apply(xyz1, xyz2) return torch.mean(dist1) + torch.mean(dist2) - diff --git a/torch_points_kernels/gridding.py b/torch_points_kernels/gridding.py new file mode 100644 index 0000000..a2b187a --- /dev/null +++ b/torch_points_kernels/gridding.py @@ -0,0 +1,60 @@ +import torch + +if torch.cuda.is_available(): + import torch_points_kernels.points_cuda as tpcuda + + +class GriddingFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, ptcloud, scale): + if not torch.cuda.is_available(): + raise NotImplementedError("CPU version is not available for Chamfer Distance") + + grid, grid_pt_weights, grid_pt_indexes = tpcuda.gridding( + -scale, scale - 1, -scale, scale - 1, -scale, scale - 1, ptcloud + ) + # print(grid.size()) # torch.Size(batch_size, n_grid_vertices) + # print(grid_pt_weights.size()) # torch.Size(batch_size, n_pts, 8, 3) + # print(grid_pt_indexes.size()) # torch.Size(batch_size, n_pts, 8) + ctx.save_for_backward(grid_pt_weights, grid_pt_indexes) + + return grid + + @staticmethod + def backward(ctx, grad_grid): + grid_pt_weights, grid_pt_indexes = ctx.saved_tensors + grad_ptcloud = 
tpcuda.gridding_grad(grid_pt_weights, grid_pt_indexes, grad_grid) + # print(grad_ptcloud.size()) # torch.Size(batch_size, n_pts, 3) + + return grad_ptcloud, None + + +def gridding(ptcloud, scale): + r""" + Converts the input point clouds into 3D grids by trilinear interpolation. + Please refer to https://arxiv.org/pdf/2006.03761 for more information. + + Parameters + ---------- + ptcloud : torch.Tensor (dtype=torch.float32) + (B, n_pts, 3) B point clouds containing n_pts points + scale : int + half the resolution of the 3D grid along each axis + + Returns + ------- + grid: torch.Tensor + (B, (2 * scale) ** 3): the flattened vertex weights of a 2 * scale x 2 * scale x 2 * scale grid + """ + if len(ptcloud.shape) != 3 or ptcloud.size(2) != 3: + raise ValueError("The input point cloud should be of size (B, n_pts, 3)") + + ptcloud = ptcloud * scale + _ptcloud = torch.split(ptcloud, 1, dim=0) + grids = [] + for p in _ptcloud: + non_zeros = torch.sum(p, dim=2).ne(0) + p = p[non_zeros].unsqueeze(dim=0) + grids.append(GriddingFunction.apply(p, scale)) + + return torch.cat(grids, dim=0).contiguous() diff --git a/torch_points_kernels/torchpoints.py b/torch_points_kernels/torchpoints.py index 73eba1e..79b1994 100644 --- a/torch_points_kernels/torchpoints.py +++ b/torch_points_kernels/torchpoints.py @@ -216,4 +216,3 @@ def ball_query( return ball_query_dense(radius, nsample, x, y, sort=sort) else: raise Exception("unrecognized mode {}".format(mode)) -
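For reference, the forward pass added by these patches is a trilinear scatter: every point contributes to its eight surrounding grid vertices with the weight (1 - |x - x0|) * (1 - |y - y0|) * (1 - |z - z0|), and gridding_grad_kernel pushes the vertex gradients back onto the point coordinates with the matching sign pattern. The snippet below is a minimal pure-PyTorch sketch of that forward pass, useful only for sanity-checking the CUDA extension on tiny inputs. The helper name gridding_reference and its nested Python loops are illustrative and are not part of the patch; it assumes the point cloud has already been multiplied by scale, as gridding() does before calling GriddingFunction, so that every coordinate lies inside [-scale, scale - 1).

import itertools

import torch


def gridding_reference(ptcloud, scale):
    # ptcloud: (B, n_pts, 3); coordinates are assumed to lie inside [-scale, scale - 1),
    # the range the CUDA kernel indexes after the Python wrapper multiplies by `scale`.
    batch_size, n_pts, _ = ptcloud.shape
    size = 2 * scale                                    # len_x = len_y = len_z = max - min + 1
    grid = ptcloud.new_zeros(batch_size, size * size * size)

    for b, j in itertools.product(range(batch_size), range(n_pts)):
        pt = ptcloud[b, j]
        lower = torch.floor(pt)                         # lower_x / lower_y / lower_z
        upper = torch.ceil(pt)
        upper = torch.where(upper == lower, upper + 1, upper)   # same tie-break as the kernel
        for cx, cy, cz in itertools.product((0, 1), repeat=3):  # LLL, LLU, ..., UUU
            vtx = torch.stack([lower[0] if cx == 0 else upper[0],
                               lower[1] if cy == 0 else upper[1],
                               lower[2] if cz == 0 else upper[2]])
            # Trilinear weight: product over the axes of (1 - |coordinate - vertex|).
            weight = torch.prod(1 - (pt - vtx).abs())
            ox, oy, oz = (vtx + scale).long().tolist()  # offsets w.r.t. min_x = min_y = min_z = -scale
            grid[b, ox * size * size + oy * size + oz] += weight  # same flat layout as compute_index()
    return grid


# Hypothetical sanity check against the extension (requires a CUDA build of the package):
#   import torch_points_kernels.points_cuda as tpcuda
#   x = torch.rand(1, 32, 3, dtype=torch.double)        # coordinates in [0, 1), as in the unit test
#   grid, _, _ = tpcuda.gridding(-4.0, 3.0, -4.0, 3.0, -4.0, 3.0, x.cuda())
#   assert torch.allclose(gridding_reference(x, 4).cuda(), grid)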