-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
126 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import numpy as np | ||
import os | ||
import sys | ||
import torch | ||
import unittest | ||
|
||
from torch.autograd import gradcheck | ||
|
||
from . import run_if_cuda | ||
|
||
|
||
ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..") | ||
sys.path.insert(0, ROOT) | ||
|
||
from torch_points_kernels.gridding import GriddingFunction | ||
|
||
|
||
class TestGridding(unittest.TestCase): | ||
@run_if_cuda | ||
def test_gridding_function_32pts(self): | ||
x = torch.rand(1, 32, 3) | ||
x.requires_grad = True | ||
self.assertTrue(gradcheck(GriddingFunction.apply, [x.double().cuda(), 4])) | ||
|
||
@run_if_cuda | ||
def test_gridding_function_64pts(self): | ||
x = torch.rand(1, 64, 3) | ||
x.requires_grad = True | ||
self.assertTrue(gradcheck(GriddingFunction.apply, [x.double().cuda(), 8])) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,4 +15,5 @@ | |
"instance_iou", | ||
"chamfer_dist", | ||
"cubic_feature_sampling", | ||
"gridding", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import torch | ||
|
||
if torch.cuda.is_available(): | ||
import torch_points_kernels.points_cuda as tpcuda | ||
|
||
|
||
class GriddingFunction(torch.autograd.Function): | ||
@staticmethod | ||
def forward(ctx, ptcloud, scale): | ||
if not torch.cuda.is_available(): | ||
raise NotImplementedError("CPU version is not available for Chamfer Distance") | ||
|
||
grid, grid_pt_weights, grid_pt_indexes = tpcuda.gridding( | ||
-scale, scale - 1, -scale, scale - 1, -scale, scale - 1, ptcloud | ||
) | ||
# print(grid.size()) # torch.Size(batch_size, n_grid_vertices) | ||
# print(grid_pt_weights.size()) # torch.Size(batch_size, n_pts, 8, 3) | ||
# print(grid_pt_indexes.size()) # torch.Size(batch_size, n_pts, 8) | ||
ctx.save_for_backward(grid_pt_weights, grid_pt_indexes) | ||
|
||
return grid | ||
|
||
@staticmethod | ||
def backward(ctx, grad_grid): | ||
grid_pt_weights, grid_pt_indexes = ctx.saved_tensors | ||
grad_ptcloud = tpcuda.gridding_grad(grid_pt_weights, grid_pt_indexes, grad_grid) | ||
# print(grad_ptcloud.size()) # torch.Size(batch_size, n_pts, 3) | ||
|
||
return grad_ptcloud, None | ||
|
||
|
||
def gridding(ptcloud, scale): | ||
r""" | ||
Converts the input point clouds into 3D grids by trilinear interpolcation. | ||
Please refer to https://arxiv.org/pdf/2006.03761 for more information | ||
Parameters | ||
---------- | ||
ptcloud : torch.Tensor (dtype=torch.float32) | ||
(B, n_pts, 3) B point clouds containing n_pts points | ||
scale : Int | ||
the resolution of the 3D grid | ||
Returns | ||
------- | ||
grid: torch.Tensor | ||
(B, scale, scale, scale): the grid of the resolution of scale * scale * scale | ||
""" | ||
if len(ptcloud.shape) != 3 or ptcloud.size(2) != 3: | ||
raise ValueError("The input point cloud should be of size (B, n_pts, 3)") | ||
|
||
ptcloud = ptcloud * scale | ||
_ptcloud = torch.split(ptcloud, 1, dim=0) | ||
grids = [] | ||
for p in _ptcloud: | ||
non_zeros = torch.sum(p, dim=2).ne(0) | ||
p = p[non_zeros].unsqueeze(dim=0) | ||
grids.append(GriddingFunction.apply(p, scale)) | ||
|
||
return torch.cat(grids, dim=0).contiguous() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters