Skip to content

Commit

Permalink
Create the unit test for Gridding.
Browse files Browse the repository at this point in the history
  • Loading branch information
hzxie committed Dec 22, 2020
1 parent 8488978 commit e2c315b
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 39 deletions.
62 changes: 32 additions & 30 deletions cuda/src/gridding_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,16 @@ __device__ int compute_index(int offset_x, int offset_y, int offset_z, int len_y
return offset_x * len_y * len_z + offset_y * len_z + offset_z;
}

__device__ float compute_weight(float x, float x0)
// Weight of a point coordinate `x` with respect to a grid-vertex coordinate
// `x0` along one axis: one minus the absolute distance between the two.
// Templated on scalar_t so it follows the floating type of the calling kernel.
template <typename scalar_t>
__device__ scalar_t compute_weight(scalar_t x, scalar_t x0)
{
  const scalar_t distance = abs(x - x0);
  return 1 - distance;
}

template <typename scalar_t>
__global__ void
gridding_kernel(int n_grid_vertices, int n_pts, float min_x, float min_y, float min_z,
int len_y, int len_z, const scalar_t* __restrict__ ptcloud,
gridding_kernel(int n_grid_vertices, int n_pts, float min_x, float min_y, float min_z, int len_y,
int len_z, const scalar_t* __restrict__ ptcloud,
scalar_t* __restrict__ grid_weights, scalar_t* __restrict__ grid_pt_weights,
int* __restrict__ grid_pt_indexes)
{
Expand Down Expand Up @@ -72,51 +73,51 @@ gridding_kernel(int n_grid_vertices, int n_pts, float min_x, float min_y, float
// Compute weights and corresponding positions, a loop for 8 points
// LLL -> Lower X, Lower Y, Lower Z
grid_pt_indexes[j * 8 + 0] = compute_index(lx_offset, ly_offset, lz_offset, len_y, len_z);
grid_pt_weights[j * 24 + 0] = compute_weight(pt_x, lower_x);
grid_pt_weights[j * 24 + 1] = compute_weight(pt_y, lower_y);
grid_pt_weights[j * 24 + 2] = compute_weight(pt_z, lower_z);
grid_pt_weights[j * 24 + 0] = compute_weight<scalar_t>(pt_x, lower_x);
grid_pt_weights[j * 24 + 1] = compute_weight<scalar_t>(pt_y, lower_y);
grid_pt_weights[j * 24 + 2] = compute_weight<scalar_t>(pt_z, lower_z);

// LLU -> Lower X, Lower Y, Upper Z
grid_pt_indexes[j * 8 + 1] = compute_index(lx_offset, ly_offset, uz_offset, len_y, len_z);
grid_pt_weights[j * 24 + 3] = compute_weight(pt_x, lower_x);
grid_pt_weights[j * 24 + 4] = compute_weight(pt_y, lower_y);
grid_pt_weights[j * 24 + 5] = compute_weight(pt_z, upper_z);
grid_pt_weights[j * 24 + 3] = compute_weight<scalar_t>(pt_x, lower_x);
grid_pt_weights[j * 24 + 4] = compute_weight<scalar_t>(pt_y, lower_y);
grid_pt_weights[j * 24 + 5] = compute_weight<scalar_t>(pt_z, upper_z);

// LUL -> Lower X, Upper Y, Lower Z
grid_pt_indexes[j * 8 + 2] = compute_index(lx_offset, uy_offset, lz_offset, len_y, len_z);
grid_pt_weights[j * 24 + 6] = compute_weight(pt_x, lower_x);
grid_pt_weights[j * 24 + 7] = compute_weight(pt_y, upper_y);
grid_pt_weights[j * 24 + 8] = compute_weight(pt_z, lower_z);
grid_pt_weights[j * 24 + 6] = compute_weight<scalar_t>(pt_x, lower_x);
grid_pt_weights[j * 24 + 7] = compute_weight<scalar_t>(pt_y, upper_y);
grid_pt_weights[j * 24 + 8] = compute_weight<scalar_t>(pt_z, lower_z);

// LUU -> Lower X, Upper Y, Upper Z
grid_pt_indexes[j * 8 + 3] = compute_index(lx_offset, uy_offset, uz_offset, len_y, len_z);
grid_pt_weights[j * 24 + 9] = compute_weight(pt_x, lower_x);
grid_pt_weights[j * 24 + 10] = compute_weight(pt_y, upper_y);
grid_pt_weights[j * 24 + 11] = compute_weight(pt_z, upper_z);
grid_pt_weights[j * 24 + 9] = compute_weight<scalar_t>(pt_x, lower_x);
grid_pt_weights[j * 24 + 10] = compute_weight<scalar_t>(pt_y, upper_y);
grid_pt_weights[j * 24 + 11] = compute_weight<scalar_t>(pt_z, upper_z);

// ULL -> Upper X, Lower Y, Lower Z
grid_pt_indexes[j * 8 + 4] = compute_index(ux_offset, ly_offset, lz_offset, len_y, len_z);
grid_pt_weights[j * 24 + 12] = compute_weight(pt_x, upper_x);
grid_pt_weights[j * 24 + 13] = compute_weight(pt_y, lower_y);
grid_pt_weights[j * 24 + 14] = compute_weight(pt_z, lower_z);
grid_pt_weights[j * 24 + 12] = compute_weight<scalar_t>(pt_x, upper_x);
grid_pt_weights[j * 24 + 13] = compute_weight<scalar_t>(pt_y, lower_y);
grid_pt_weights[j * 24 + 14] = compute_weight<scalar_t>(pt_z, lower_z);

// ULU -> Upper X, Lower Y, Upper Z
grid_pt_indexes[j * 8 + 5] = compute_index(ux_offset, ly_offset, uz_offset, len_y, len_z);
grid_pt_weights[j * 24 + 15] = compute_weight(pt_x, upper_x);
grid_pt_weights[j * 24 + 16] = compute_weight(pt_y, lower_y);
grid_pt_weights[j * 24 + 17] = compute_weight(pt_z, upper_z);
grid_pt_weights[j * 24 + 15] = compute_weight<scalar_t>(pt_x, upper_x);
grid_pt_weights[j * 24 + 16] = compute_weight<scalar_t>(pt_y, lower_y);
grid_pt_weights[j * 24 + 17] = compute_weight<scalar_t>(pt_z, upper_z);

// UUL -> Upper X, Upper Y, Lower Z
grid_pt_indexes[j * 8 + 6] = compute_index(ux_offset, uy_offset, lz_offset, len_y, len_z);
grid_pt_weights[j * 24 + 18] = compute_weight(pt_x, upper_x);
grid_pt_weights[j * 24 + 19] = compute_weight(pt_y, upper_y);
grid_pt_weights[j * 24 + 20] = compute_weight(pt_z, lower_z);
grid_pt_weights[j * 24 + 18] = compute_weight<scalar_t>(pt_x, upper_x);
grid_pt_weights[j * 24 + 19] = compute_weight<scalar_t>(pt_y, upper_y);
grid_pt_weights[j * 24 + 20] = compute_weight<scalar_t>(pt_z, lower_z);

// UUU -> Upper X, Upper Y, Upper Z
grid_pt_indexes[j * 8 + 7] = compute_index(ux_offset, uy_offset, uz_offset, len_y, len_z);
grid_pt_weights[j * 24 + 21] = compute_weight(pt_x, upper_x);
grid_pt_weights[j * 24 + 22] = compute_weight(pt_y, upper_y);
grid_pt_weights[j * 24 + 23] = compute_weight(pt_z, upper_z);
grid_pt_weights[j * 24 + 21] = compute_weight<scalar_t>(pt_x, upper_x);
grid_pt_weights[j * 24 + 22] = compute_weight<scalar_t>(pt_y, upper_y);
grid_pt_weights[j * 24 + 23] = compute_weight<scalar_t>(pt_z, upper_z);
}

__syncthreads();
Expand Down Expand Up @@ -179,9 +180,9 @@ std::vector<torch::Tensor> gridding_kernel_warpper(float min_x, float max_x, flo
int n_grid_vertices = len_x * len_y * len_z;

torch::Tensor grid_weights =
torch::zeros({batch_size, n_grid_vertices}, torch::CUDA(torch::kFloat));
torch::zeros({batch_size, n_grid_vertices}, torch::CUDA(ptcloud.scalar_type()));
torch::Tensor grid_pt_weights =
torch::zeros({batch_size, n_pts, 8, 3}, torch::CUDA(torch::kFloat));
torch::zeros({batch_size, n_pts, 8, 3}, torch::CUDA(ptcloud.scalar_type()));
torch::Tensor grid_pt_indexes = torch::zeros({batch_size, n_pts, 8}, torch::CUDA(torch::kInt));

AT_DISPATCH_FLOATING_TYPES(
Expand Down Expand Up @@ -310,7 +311,8 @@ torch::Tensor gridding_grad_kernel_warpper(torch::Tensor grid_pt_weights,
int n_grid_vertices = grad_grid.size(1);
int n_pts = grid_pt_indexes.size(1);

torch::Tensor grad_ptcloud = torch::zeros({batch_size, n_pts, 3}, torch::CUDA(torch::kFloat));
torch::Tensor grad_ptcloud =
torch::zeros({batch_size, n_pts, 3}, torch::CUDA(grid_pt_weights.scalar_type()));

AT_DISPATCH_FLOATING_TYPES(
grid_pt_weights.scalar_type(), "gridding_grad_cuda", ([&] {
Expand Down
30 changes: 30 additions & 0 deletions test/test_gridding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import numpy as np
import os
import sys
import torch
import unittest

from torch.autograd import gradcheck

from . import run_if_cuda


ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")
sys.path.insert(0, ROOT)

from torch_points_kernels.gridding import GriddingFunction


class TestGridding(unittest.TestCase):
    """Gradient checks for ``GriddingFunction``; every test is wrapped in ``run_if_cuda``."""

    def _gradcheck_passes(self, n_pts, scale):
        """Run ``gradcheck`` on a single random cloud of ``n_pts`` points.

        The cloud is sampled on the CPU, flagged as requiring gradients, and
        then converted to double precision and moved to the GPU before being
        handed to ``gradcheck`` together with ``scale``.
        """
        cloud = torch.rand(1, n_pts, 3)
        cloud.requires_grad = True
        inputs = [cloud.double().cuda(), scale]
        return gradcheck(GriddingFunction.apply, inputs)

    @run_if_cuda
    def test_gridding_function_32pts(self):
        # 32 random points, scale argument 4.
        self.assertTrue(self._gradcheck_passes(32, 4))

    @run_if_cuda
    def test_gridding_function_64pts(self):
        # 64 random points, scale argument 8.
        self.assertTrue(self._gradcheck_passes(64, 8))

1 change: 1 addition & 0 deletions torch_points_kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@
"instance_iou",
"chamfer_dist",
"cubic_feature_sampling",
"gridding",
]
11 changes: 3 additions & 8 deletions torch_points_kernels/chamfer_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@ class ChamferFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, xyz1, xyz2):
if not torch.cuda.is_available():
raise NotImplementedError(
"CPU version is not available for Chamfer Distance"
)
raise NotImplementedError("CPU version is not available for Chamfer Distance")

dist1, dist2, idx1, idx2 = tpcuda.chamfer_dist(xyz1, xyz2)
ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
Expand All @@ -20,9 +18,7 @@ def forward(ctx, xyz1, xyz2):
@staticmethod
def backward(ctx, grad_dist1, grad_dist2):
xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
grad_xyz1, grad_xyz2 = tpcuda.chamfer_dist_grad(
xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2
)
grad_xyz1, grad_xyz2 = tpcuda.chamfer_dist_grad(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2)
return grad_xyz1, grad_xyz2


Expand All @@ -45,7 +41,7 @@ def chamfer_dist(xyz1, xyz2, ignore_zeros=False):
(B, ): the distances between B pairs of point clouds
"""
if len(xyz1.shape) != 3 or xyz1.size(2) != 3 or len(xyz2.shape) != 3 or xyz2.size(2) != 3:
raise ValueError('The input point cloud should be of size (B, n_pts, 3)')
raise ValueError("The input point cloud should be of size (B, n_pts, 3)")

batch_size = xyz1.size(0)
if batch_size == 1 and ignore_zeros:
Expand All @@ -56,4 +52,3 @@ def chamfer_dist(xyz1, xyz2, ignore_zeros=False):

dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
return torch.mean(dist1) + torch.mean(dist2)

60 changes: 60 additions & 0 deletions torch_points_kernels/gridding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import torch

if torch.cuda.is_available():
import torch_points_kernels.points_cuda as tpcuda


class GriddingFunction(torch.autograd.Function):
    """Autograd function wrapping ``tpcuda.gridding`` and ``tpcuda.gridding_grad``."""

    @staticmethod
    def forward(ctx, ptcloud, scale):
        """Grid a point cloud with the CUDA kernel.

        Parameters
        ----------
        ptcloud : torch.Tensor
            point cloud passed straight to ``tpcuda.gridding``
        scale : int
            half-extent of the grid; the bounds handed to the kernel are
            ``-scale`` and ``scale - 1`` on each of the three axes

        Returns
        -------
        grid : torch.Tensor
            first output of ``tpcuda.gridding``

        Raises
        ------
        NotImplementedError
            if CUDA is not available
        """
        if not torch.cuda.is_available():
            # The message used to name "Chamfer Distance" (copy-paste slip).
            raise NotImplementedError("CPU version is not available for Gridding")

        grid, grid_pt_weights, grid_pt_indexes = tpcuda.gridding(
            -scale, scale - 1, -scale, scale - 1, -scale, scale - 1, ptcloud
        )
        # Shapes as recorded by the original author:
        #   grid:            (batch_size, n_grid_vertices)
        #   grid_pt_weights: (batch_size, n_pts, 8, 3)
        #   grid_pt_indexes: (batch_size, n_pts, 8)
        # Only the per-point weights and indexes are needed by backward.
        ctx.save_for_backward(grid_pt_weights, grid_pt_indexes)

        return grid

    @staticmethod
    def backward(ctx, grad_grid):
        """Propagate the grid gradient back to the points.

        Returns a pair ``(grad_ptcloud, None)``; ``None`` is the gradient slot
        for the non-tensor ``scale`` argument of ``forward``.
        """
        grid_pt_weights, grid_pt_indexes = ctx.saved_tensors
        # grad_ptcloud was noted by the original author as (batch_size, n_pts, 3).
        grad_ptcloud = tpcuda.gridding_grad(grid_pt_weights, grid_pt_indexes, grad_grid)

        return grad_ptcloud, None


def gridding(ptcloud, scale):
    r"""
    Converts the input point clouds into 3D grids by trilinear interpolation.
    Please refer to https://arxiv.org/pdf/2006.03761 for more information

    Parameters
    ----------
    ptcloud : torch.Tensor (floating point)
        (B, n_pts, 3) B point clouds containing n_pts points. Points whose
        three coordinates are all exactly zero are treated as padding and are
        removed before gridding.
    scale : Int
        the resolution of the 3D grid; coordinates are multiplied by ``scale``
        and the grid bounds passed to the kernel are ``-scale`` .. ``scale - 1``
        on every axis

    Returns
    -------
    grid: torch.Tensor
        (B, n_grid_vertices): one flattened grid per point cloud, concatenated
        along the batch dimension (n_grid_vertices is presumably
        ``(2 * scale) ** 3`` given the bounds above -- confirm against the CUDA
        wrapper)

    Raises
    ------
    ValueError
        if ``ptcloud`` is not of size (B, n_pts, 3)
    """
    if len(ptcloud.shape) != 3 or ptcloud.size(2) != 3:
        raise ValueError("The input point cloud should be of size (B, n_pts, 3)")

    ptcloud = ptcloud * scale
    grids = []
    # Each cloud is processed on its own because the number of points left
    # after dropping the padding can differ from cloud to cloud.
    for p in torch.split(ptcloud, 1, dim=0):
        # Keep a point when at least one coordinate is non-zero. Testing the
        # coordinate *sum* instead would also discard genuine points such as
        # (1, -1, 0).
        non_zeros = p.ne(0).any(dim=2)
        p = p[non_zeros].unsqueeze(dim=0)
        grids.append(GriddingFunction.apply(p, scale))

    return torch.cat(grids, dim=0).contiguous()
1 change: 0 additions & 1 deletion torch_points_kernels/torchpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,4 +216,3 @@ def ball_query(
return ball_query_dense(radius, nsample, x, y, sort=sort)
else:
raise Exception("unrecognized mode {}".format(mode))

0 comments on commit e2c315b

Please sign in to comment.