diff --git a/cpu/include/ball_query.h b/cpu/include/ball_query.h index 87451b4..d7f6ee9 100644 --- a/cpu/include/ball_query.h +++ b/cpu/include/ball_query.h @@ -1,11 +1,13 @@ #pragma once #include std::pair ball_query(at::Tensor query, at::Tensor support, float radius, - int max_num, int mode); + int max_num, int mode, bool sorted); std::pair batch_ball_query(at::Tensor query, at::Tensor support, at::Tensor query_batch, at::Tensor support_batch, - float radius, int max_num, int mode); + float radius, int max_num, int mode, + bool sorted); std::pair dense_ball_query(at::Tensor query, at::Tensor support, - float radius, int max_num, int mode); + float radius, int max_num, int mode, + bool sorted); diff --git a/cpu/include/neighbors.h b/cpu/include/neighbors.h index 1654f77..ce4d27d 100644 --- a/cpu/include/neighbors.h +++ b/cpu/include/neighbors.h @@ -10,13 +10,13 @@ using namespace std; template int nanoflann_neighbors(vector& queries, vector& supports, vector& neighbors_indices, vector& dists, float radius, - int max_num, int mode); + int max_num, int mode, bool sorted); template int batch_nanoflann_neighbors(vector& queries, vector& supports, vector& q_batches, vector& s_batches, vector& neighbors_indices, vector& dists, float radius, - int max_num, int mode); + int max_num, int mode, bool sorted); template void nanoflann_knn_neighbors(vector& queries, vector& supports, diff --git a/cpu/src/ball_query.cpp b/cpu/src/ball_query.cpp index ddd9140..954b67c 100644 --- a/cpu/src/ball_query.cpp +++ b/cpu/src/ball_query.cpp @@ -8,7 +8,7 @@ #include std::pair ball_query(at::Tensor support, at::Tensor query, float radius, - int max_num, int mode) + int max_num, int mode, bool sorted) { CHECK_CONTIGUOUS(support); CHECK_CONTIGUOUS(query); @@ -31,7 +31,7 @@ std::pair ball_query(at::Tensor support, at::Tensor quer std::vector(data_s, data_s + support.size(0) * support.size(1)); max_count = nanoflann_neighbors(queries_stl, supports_stl, neighbors_indices, - neighbors_dists, radius, max_num, mode); + neighbors_dists, radius, max_num, mode, sorted); }); auto neighbors_dists_ptr = neighbors_dists.data(); long* neighbors_indices_ptr = neighbors_indices.data(); @@ -62,7 +62,7 @@ at::Tensor degree(at::Tensor row, int64_t num_nodes) std::pair batch_ball_query(at::Tensor support, at::Tensor query, at::Tensor support_batch, at::Tensor query_batch, - float radius, int max_num, int mode) + float radius, int max_num, int mode, bool sorted) { CHECK_CONTIGUOUS(support); CHECK_CONTIGUOUS(query); @@ -97,9 +97,9 @@ std::pair batch_ball_query(at::Tensor support, at::Tenso std::vector supports_stl(support.DATA_PTR(), support.DATA_PTR() + support.numel()); - max_count = batch_nanoflann_neighbors(queries_stl, supports_stl, query_batch_stl, - support_batch_stl, neighbors_indices, - neighbors_dists, radius, max_num, mode); + max_count = batch_nanoflann_neighbors( + queries_stl, supports_stl, query_batch_stl, support_batch_stl, neighbors_indices, + neighbors_dists, radius, max_num, mode, sorted); }); auto neighbors_dists_ptr = neighbors_dists.data(); long* neighbors_indices_ptr = neighbors_indices.data(); @@ -122,7 +122,7 @@ std::pair batch_ball_query(at::Tensor support, at::Tenso } std::pair dense_ball_query(at::Tensor support, at::Tensor query, - float radius, int max_num, int mode) + float radius, int max_num, int mode, bool sorted) { CHECK_CONTIGUOUS(support); CHECK_CONTIGUOUS(query); @@ -132,7 +132,7 @@ std::pair dense_ball_query(at::Tensor support, at::Tenso vector batch_dist; for (int i = 0; i < b; i++) { - auto out_pair = ball_query(query[i], support[i], radius, max_num, mode); + auto out_pair = ball_query(query[i], support[i], radius, max_num, mode, sorted); batch_idx.push_back(out_pair.first); batch_dist.push_back(out_pair.second); } diff --git a/cpu/src/bindings.cpp b/cpu/src/bindings.cpp index af05529..042b271 100644 --- a/cpu/src/bindings.cpp +++ b/cpu/src/bindings.cpp @@ -28,7 +28,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) "maximum number of neighbors found if mode = 0, if mode=1 return a " "tensor of size Num_edge x 2 and return a tensor containing the " "squared distance of the neighbors", - "support"_a, "querry"_a, "radius"_a, "max_num"_a = -1, "mode"_a = 0); + "support"_a, "querry"_a, "radius"_a, "max_num"_a = -1, "mode"_a = 0, "sorted"_a = false); m.def("batch_ball_query", &batch_ball_query, "compute the radius search of a point cloud for each batch using " @@ -53,7 +53,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) "tensor of size Num_edge x 2 and return a tensor containing the " "squared distance of the neighbors", "support"_a, "querry"_a, "query_batch"_a, "support_batch"_a, "radius"_a, "max_num"_a = -1, - "mode"_a = 0); + "mode"_a = 0, "sorted"_a = false); m.def("dense_ball_query", &dense_ball_query, "compute the radius search of a batch of point cloud using nanoflann" "- support : a pytorch tensor of size B x N1 x 3, points where the " @@ -69,5 +69,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) "maximum number of neighbors found if mode = 0, if mode=1 return a " "tensor of size Num_edge x 2 and return a tensor containing the " "squared distance of the neighbors", - "support"_a, "querry"_a, "radius"_a, "max_num"_a = -1, "mode"_a = 0); + "support"_a, "querry"_a, "radius"_a, "max_num"_a = -1, "mode"_a = 0, "sorted"_a = false); } diff --git a/cpu/src/neighbors.cpp b/cpu/src/neighbors.cpp index 24d3ea4..9d35f44 100644 --- a/cpu/src/neighbors.cpp +++ b/cpu/src/neighbors.cpp @@ -2,17 +2,19 @@ // Taken from https://github.com/HuguesTHOMAS/KPConv #include "neighbors.h" +#include template int nanoflann_neighbors(vector& queries, vector& supports, vector& neighbors_indices, vector& dists, float radius, - int max_num, int mode) + int max_num, int mode, bool sorted) { // Initiate variables // ****************** + std::random_device rd; + std::mt19937 g(rd()); // square radius - const float search_radius = static_cast(radius * radius); // indices @@ -47,7 +49,7 @@ int nanoflann_neighbors(vector& queries, vector& supports, // Search params nanoflann::SearchParams search_params; - search_params.sorted = true; + search_params.sorted = sorted; std::vector>> list_matches(pcd_query.pts.size()); for (auto& p0 : pcd_query.pts) @@ -62,7 +64,11 @@ int nanoflann_neighbors(vector& queries, vector& supports, if (nMatches == 0) list_matches[i0] = {std::make_pair(0, -1)}; else + { + if (!sorted) + std::shuffle(ret_matches.begin(), ret_matches.end(), g); list_matches[i0] = ret_matches; + } max_count = max(max_count, nMatches); i0++; } @@ -132,10 +138,13 @@ template int batch_nanoflann_neighbors(vector& queries, vector& supports, vector& q_batches, vector& s_batches, vector& neighbors_indices, vector& dists, float radius, - int max_num, int mode) + int max_num, int mode, bool sorted) { // Initiate variables // ****************** + std::random_device rd; + std::mt19937 g(rd()); + // indices int i0 = 0; @@ -173,7 +182,7 @@ int batch_nanoflann_neighbors(vector& queries, vector& suppo // *********************** // Search params nanoflann::SearchParams search_params; - search_params.sorted = true; + search_params.sorted = sorted; for (auto& p0 : query_pcd.pts) { // Check if we changed batch @@ -192,16 +201,18 @@ int batch_nanoflann_neighbors(vector& queries, vector& suppo index->buildIndex(); } - // Initial guess of neighbors size - - all_inds_dists[i0].reserve(max_count); - // Find neighbors - // std::cerr << p0.x << p0.y << p0.z<> ret_matches; + ret_matches.reserve(max_count); scalar_t query_pt[3] = {p0.x, p0.y, p0.z}; + size_t nMatches = index->radiusSearch(query_pt, r2, ret_matches, search_params); - size_t nMatches = index->radiusSearch(query_pt, r2, all_inds_dists[i0], search_params); - // Update max count + // Shuffle if needed + if (!sorted) + std::shuffle(ret_matches.begin(), ret_matches.end(), g); + all_inds_dists[i0] = ret_matches; + // Update max count if (nMatches > (size_t)max_count) max_count = nMatches; // Increment query idx diff --git a/cuda/src/ball_query.cpp b/cuda/src/ball_query.cpp index 682426e..f4be045 100644 --- a/cuda/src/ball_query.cpp +++ b/cuda/src/ball_query.cpp @@ -3,7 +3,7 @@ #include "utils.h" void query_ball_point_kernel_dense_wrapper(int b, int n, int m, float radius, int nsample, - const float* new_xyz, const float* xyz, int* idx, + const float* new_xyz, const float* xyz, long* idx, float* dist_out); void query_ball_point_kernel_partial_wrapper(long batch_size, int size_x, int size_y, float radius, @@ -25,7 +25,7 @@ std::pair ball_query_dense(at::Tensor new_xyz, at::Tenso } at::Tensor idx = torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, - at::device(new_xyz.device()).dtype(at::ScalarType::Int)); + at::device(new_xyz.device()).dtype(at::ScalarType::Long)); at::Tensor dist = torch::full({new_xyz.size(0), new_xyz.size(1), nsample}, -1, at::device(new_xyz.device()).dtype(at::ScalarType::Float)); @@ -33,7 +33,7 @@ std::pair ball_query_dense(at::Tensor new_xyz, at::Tenso { query_ball_point_kernel_dense_wrapper( xyz.size(0), xyz.size(1), new_xyz.size(1), radius, nsample, new_xyz.DATA_PTR(), - xyz.DATA_PTR(), idx.DATA_PTR(), dist.DATA_PTR()); + xyz.DATA_PTR(), idx.DATA_PTR(), dist.DATA_PTR()); } else { diff --git a/cuda/src/ball_query_gpu.cu b/cuda/src/ball_query_gpu.cu index 92ae884..6a98f60 100644 --- a/cuda/src/ball_query_gpu.cu +++ b/cuda/src/ball_query_gpu.cu @@ -9,7 +9,7 @@ __global__ void query_ball_point_kernel_dense(int b, int n, int m, float radius, int nsample, const float* __restrict__ new_xyz, const float* __restrict__ xyz, - int* __restrict__ idx_out, + long* __restrict__ idx_out, float* __restrict__ dist_out) { int batch_index = blockIdx.x; @@ -93,7 +93,7 @@ __global__ void query_ball_point_kernel_partial_dense( } void query_ball_point_kernel_dense_wrapper(int b, int n, int m, float radius, int nsample, - const float* new_xyz, const float* xyz, int* idx,float* dist_out) + const float* new_xyz, const float* xyz, long* idx,float* dist_out) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(); query_ball_point_kernel_dense<<>>(b, n, m, radius, nsample, diff --git a/setup.py b/setup.py index 0c66ca0..03fbc4c 100644 --- a/setup.py +++ b/setup.py @@ -44,7 +44,7 @@ setup( name="torch_points", - version="0.4.0", + version="0.4.1", author="Nicolas Chaulet", packages=find_packages(), install_requires=requirements, diff --git a/test/test_ballquerry.py b/test/test_ballquerry.py index c1c2c0a..9337e84 100644 --- a/test/test_ballquerry.py +++ b/test/test_ballquerry.py @@ -14,24 +14,24 @@ def test_simple_gpu(self): a = torch.tensor([[[0, 0, 0], [1, 0, 0], [2, 0, 0]], [[0, 0, 0], [1, 0, 0], [2, 0, 0]]]).to(torch.float).cuda() b = torch.tensor([[[0, 0, 0]], [[3, 0, 0]]]).to(torch.float).cuda() idx, dist = ball_query(1.01, 2, a, b) - torch.testing.assert_allclose(idx.long().cpu(), torch.tensor([[[0, 1]], [[2, 2]]])) + torch.testing.assert_allclose(idx.cpu(), torch.tensor([[[0, 1]], [[2, 2]]])) torch.testing.assert_allclose(dist.cpu(), torch.tensor([[[0, 1]], [[1, -1]]]).float()) def test_simple_cpu(self): a = torch.tensor([[[0, 0, 0], [1, 0, 0], [2, 0, 0]], [[0, 0, 0], [1, 0, 0], [2, 0, 0]]]).to(torch.float) b = torch.tensor([[[0, 0, 0]], [[3, 0, 0]]]).to(torch.float) - idx, dist = ball_query(1.01, 2, a, b) - torch.testing.assert_allclose(idx.long(), torch.tensor([[[0, 1]], [[2, 2]]])) + idx, dist = ball_query(1.01, 2, a, b, sort=True) + torch.testing.assert_allclose(idx, torch.tensor([[[0, 1]], [[2, 2]]])) torch.testing.assert_allclose(dist, torch.tensor([[[0, 1]], [[1, -1]]]).float()) a = torch.tensor([[[0, 0, 0], [1, 0, 0], [1, 1, 0]]]).to(torch.float) - idx, dist = ball_query(1.01, 3, a, a) - torch.testing.assert_allclose(idx.long(),torch.tensor([[[0, 1, 0],[1,0,2],[2,1,2]]])) + idx, dist = ball_query(1.01, 3, a, a, sort=True) + torch.testing.assert_allclose(idx, torch.tensor([[[0, 1, 0], [1, 0, 2], [2, 1, 2]]])) @run_if_cuda def test_larger_gpu(self): a = torch.randn(32, 4096, 3).to(torch.float).cuda() - idx,dist = ball_query(1, 64, a, a) + idx, dist = ball_query(1, 64, a, a) self.assertGreaterEqual(idx.min(), 0) @run_if_cuda @@ -70,7 +70,7 @@ def test_simple_gpu(self): dist2 = dist2.detach().cpu().numpy() idx_answer = np.asarray([[1, -1]]) - dist2_answer = np.asarray([[0.0100, -1.0000]]).astype(np.float32) + dist2_answer = np.asarray([[0.100, -1.0000]]).astype(np.float32) npt.assert_array_almost_equal(idx, idx_answer) npt.assert_array_almost_equal(dist2, dist2_answer) @@ -88,7 +88,7 @@ def test_simple_cpu(self): dist2 = dist2.detach().cpu().numpy() idx_answer = np.asarray([[1, -1]]) - dist2_answer = np.asarray([[0.0100, -1.0000]]).astype(np.float32) + dist2_answer = np.asarray([[0.100, -1.0000]]).astype(np.float32) npt.assert_array_almost_equal(idx, idx_answer) npt.assert_array_almost_equal(dist2, dist2_answer) @@ -100,9 +100,13 @@ def test_random_cpu(self): batch_b = torch.tensor([0 for i in range(b.shape[0] // 2)] + [1 for i in range(b.shape[0] // 2, b.shape[0])]) R = 1 - idx, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b) - idx1, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b) + idx, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=True) + idx1, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=True) torch.testing.assert_allclose(idx1, idx) + with self.assertRaises(AssertionError): + idx, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=False) + idx1, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=False) + torch.testing.assert_allclose(idx1, idx) self.assertEqual(idx.shape[0], b.shape[0]) self.assertEqual(dist.shape[0], b.shape[0]) diff --git a/torch_points/torchpoints.py b/torch_points/torchpoints.py index 3fe662d..546d10e 100644 --- a/torch_points/torchpoints.py +++ b/torch_points/torchpoints.py @@ -146,33 +146,30 @@ def grouping_operation(features, idx): return grouped_features.reshape(idx.shape[0], features.shape[1], idx.shape[1], idx.shape[2]) -class BallQueryDense(Function): - @staticmethod - def forward(ctx, radius, nsample, xyz, new_xyz, batch_xyz=None, batch_new_xyz=None): - # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor - if new_xyz.is_cuda: - return tpcuda.ball_query_dense(new_xyz, xyz, radius, nsample) - else: - return tpcpu.dense_ball_query(new_xyz, xyz, radius, nsample, mode=0) - - @staticmethod - def backward(ctx, a=None): - return None, None, None, None - - -class BallQueryPartialDense(Function): - @staticmethod - def forward(ctx, radius, nsample, x, y, batch_x, batch_y): - # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor - if x.is_cuda: - return tpcuda.ball_query_partial_dense(x, y, batch_x, batch_y, radius, nsample) - else: - ind, dist = tpcpu.batch_ball_query(x, y, batch_x, batch_y, radius, nsample, mode=0) - return ind, dist - - @staticmethod - def backward(ctx, a=None): - return None, None, None, None +def ball_query_dense(radius, nsample, xyz, new_xyz, batch_xyz=None, batch_new_xyz=None, sort=False): + # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor + if new_xyz.is_cuda: + if sort: + raise NotImplementedError("CUDA version does not sort the neighbors") + ind, dist = tpcuda.ball_query_dense(new_xyz, xyz, radius, nsample) + else: + ind, dist = tpcpu.dense_ball_query(new_xyz, xyz, radius, nsample, mode=0, sorted=sort) + positive = dist > 0 + dist[positive] = torch.sqrt(dist[positive]) + return ind, dist + + +def ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y, sort=False): + # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor + if x.is_cuda: + if sort: + raise NotImplementedError("CUDA version does not sort the neighbors") + ind, dist = tpcuda.ball_query_partial_dense(x, y, batch_x, batch_y, radius, nsample) + else: + ind, dist = tpcpu.batch_ball_query(x, y, batch_x, batch_y, radius, nsample, mode=0, sorted=sort) + positive = dist > 0 + dist[positive] = torch.sqrt(dist[positive]) + return ind, dist def ball_query( @@ -183,6 +180,7 @@ def ball_query( mode: Optional[str] = "dense", batch_x: Optional[torch.tensor] = None, batch_y: Optional[torch.tensor] = None, + sort: Optional[bool] = False, ) -> torch.Tensor: """ Arguments: @@ -197,11 +195,12 @@ def ball_query( Keyword Arguments: batch_x -- (M, ) [partial_dense] or (B, M, 3) [dense] Contains indexes to indicate within batch it belongs to. batch_y -- (N, ) Contains indexes to indicate within batch it belongs to + sort -- bool wether the neighboors are sorted or not (closests first) Returns: idx: (npoint, nsample) or (B, npoint, nsample) [dense] It contains the indexes of the element within x at radius distance to y - dist2: (N, nsample) or (B, npoint, nsample) Default value: -1. - It contains the square distances of the element within x at radius distance to y + dist: (N, nsample) or (B, npoint, nsample) Default value: -1. + It contains the distance of the element within x at radius distance to y """ if mode is None: raise Exception('The mode should be defined within ["partial_dense | dense"]') @@ -212,12 +211,12 @@ def ball_query( assert x.size(0) == batch_x.size(0) assert y.size(0) == batch_y.size(0) assert x.dim() == 2 - return BallQueryPartialDense.apply(radius, nsample, x, y, batch_x, batch_y) + return ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y, sort=sort) elif mode.lower() == "dense": if (batch_x is not None) or (batch_y is not None): raise Exception("batch_x and batch_y should not be provided") assert x.dim() == 3 - return BallQueryDense.apply(radius, nsample, x, y) + return ball_query_dense(radius, nsample, x, y, sort=sort) else: raise Exception("unrecognized mode {}".format(mode))