Skip to content

Commit

Permalink
Merge pull request #23 from nicolas-chaulet/longtype
Browse files Browse the repository at this point in the history
Longtype
  • Loading branch information
nicolas-chaulet authored Feb 28, 2020
2 parents 6dafff9 + a0dd45b commit 55c3605
Show file tree
Hide file tree
Showing 10 changed files with 91 additions and 75 deletions.
8 changes: 5 additions & 3 deletions cpu/include/ball_query.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
#pragma once
#include <torch/extension.h>
std::pair<at::Tensor, at::Tensor> ball_query(at::Tensor query, at::Tensor support, float radius,
int max_num, int mode);
int max_num, int mode, bool sorted);

std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor query, at::Tensor support,
at::Tensor query_batch, at::Tensor support_batch,
float radius, int max_num, int mode);
float radius, int max_num, int mode,
bool sorted);

std::pair<at::Tensor, at::Tensor> dense_ball_query(at::Tensor query, at::Tensor support,
float radius, int max_num, int mode);
float radius, int max_num, int mode,
bool sorted);
4 changes: 2 additions & 2 deletions cpu/include/neighbors.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ using namespace std;
template <typename scalar_t>
int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
vector<long>& neighbors_indices, vector<float>& dists, float radius,
int max_num, int mode);
int max_num, int mode, bool sorted);

template <typename scalar_t>
int batch_nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
vector<long>& q_batches, vector<long>& s_batches,
vector<long>& neighbors_indices, vector<float>& dists, float radius,
int max_num, int mode);
int max_num, int mode, bool sorted);

template <typename scalar_t>
void nanoflann_knn_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
Expand Down
16 changes: 8 additions & 8 deletions cpu/src/ball_query.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include <torch/extension.h>

std::pair<at::Tensor, at::Tensor> ball_query(at::Tensor support, at::Tensor query, float radius,
int max_num, int mode)
int max_num, int mode, bool sorted)
{
CHECK_CONTIGUOUS(support);
CHECK_CONTIGUOUS(query);
Expand All @@ -31,7 +31,7 @@ std::pair<at::Tensor, at::Tensor> ball_query(at::Tensor support, at::Tensor quer
std::vector<scalar_t>(data_s, data_s + support.size(0) * support.size(1));

max_count = nanoflann_neighbors<scalar_t>(queries_stl, supports_stl, neighbors_indices,
neighbors_dists, radius, max_num, mode);
neighbors_dists, radius, max_num, mode, sorted);
});
auto neighbors_dists_ptr = neighbors_dists.data();
long* neighbors_indices_ptr = neighbors_indices.data();
Expand Down Expand Up @@ -62,7 +62,7 @@ at::Tensor degree(at::Tensor row, int64_t num_nodes)

std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor support, at::Tensor query,
at::Tensor support_batch, at::Tensor query_batch,
float radius, int max_num, int mode)
float radius, int max_num, int mode, bool sorted)
{
CHECK_CONTIGUOUS(support);
CHECK_CONTIGUOUS(query);
Expand Down Expand Up @@ -97,9 +97,9 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor support, at::Tenso
std::vector<scalar_t> supports_stl(support.DATA_PTR<scalar_t>(),
support.DATA_PTR<scalar_t>() + support.numel());

max_count = batch_nanoflann_neighbors<scalar_t>(queries_stl, supports_stl, query_batch_stl,
support_batch_stl, neighbors_indices,
neighbors_dists, radius, max_num, mode);
max_count = batch_nanoflann_neighbors<scalar_t>(
queries_stl, supports_stl, query_batch_stl, support_batch_stl, neighbors_indices,
neighbors_dists, radius, max_num, mode, sorted);
});
auto neighbors_dists_ptr = neighbors_dists.data();
long* neighbors_indices_ptr = neighbors_indices.data();
Expand All @@ -122,7 +122,7 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor support, at::Tenso
}

std::pair<at::Tensor, at::Tensor> dense_ball_query(at::Tensor support, at::Tensor query,
float radius, int max_num, int mode)
float radius, int max_num, int mode, bool sorted)
{
CHECK_CONTIGUOUS(support);
CHECK_CONTIGUOUS(query);
Expand All @@ -132,7 +132,7 @@ std::pair<at::Tensor, at::Tensor> dense_ball_query(at::Tensor support, at::Tenso
vector<at::Tensor> batch_dist;
for (int i = 0; i < b; i++)
{
auto out_pair = ball_query(query[i], support[i], radius, max_num, mode);
auto out_pair = ball_query(query[i], support[i], radius, max_num, mode, sorted);
batch_idx.push_back(out_pair.first);
batch_dist.push_back(out_pair.second);
}
Expand Down
6 changes: 3 additions & 3 deletions cpu/src/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
"maximum number of neighbors found if mode = 0, if mode=1 return a "
"tensor of size Num_edge x 2 and return a tensor containing the "
"squared distance of the neighbors",
"support"_a, "querry"_a, "radius"_a, "max_num"_a = -1, "mode"_a = 0);
"support"_a, "querry"_a, "radius"_a, "max_num"_a = -1, "mode"_a = 0, "sorted"_a = false);

m.def("batch_ball_query", &batch_ball_query,
"compute the radius search of a point cloud for each batch using "
Expand All @@ -53,7 +53,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
"tensor of size Num_edge x 2 and return a tensor containing the "
"squared distance of the neighbors",
"support"_a, "querry"_a, "query_batch"_a, "support_batch"_a, "radius"_a, "max_num"_a = -1,
"mode"_a = 0);
"mode"_a = 0, "sorted"_a = false);
m.def("dense_ball_query", &dense_ball_query,
"compute the radius search of a batch of point cloud using nanoflann"
"- support : a pytorch tensor of size B x N1 x 3, points where the "
Expand All @@ -69,5 +69,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
"maximum number of neighbors found if mode = 0, if mode=1 return a "
"tensor of size Num_edge x 2 and return a tensor containing the "
"squared distance of the neighbors",
"support"_a, "querry"_a, "radius"_a, "max_num"_a = -1, "mode"_a = 0);
"support"_a, "querry"_a, "radius"_a, "max_num"_a = -1, "mode"_a = 0, "sorted"_a = false);
}
35 changes: 23 additions & 12 deletions cpu/src/neighbors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@
// Taken from https://github.com/HuguesTHOMAS/KPConv

#include "neighbors.h"
#include <random>

template <typename scalar_t>
int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
vector<long>& neighbors_indices, vector<float>& dists, float radius,
int max_num, int mode)
int max_num, int mode, bool sorted)
{
// Initiate variables
// ******************
std::random_device rd;
std::mt19937 g(rd());

// square radius

const float search_radius = static_cast<float>(radius * radius);

// indices
Expand Down Expand Up @@ -47,7 +49,7 @@ int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,

// Search params
nanoflann::SearchParams search_params;
search_params.sorted = true;
search_params.sorted = sorted;
std::vector<std::vector<std::pair<size_t, scalar_t>>> list_matches(pcd_query.pts.size());

for (auto& p0 : pcd_query.pts)
Expand All @@ -62,7 +64,11 @@ int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
if (nMatches == 0)
list_matches[i0] = {std::make_pair(0, -1)};
else
{
if (!sorted)
std::shuffle(ret_matches.begin(), ret_matches.end(), g);
list_matches[i0] = ret_matches;
}
max_count = max(max_count, nMatches);
i0++;
}
Expand Down Expand Up @@ -132,10 +138,13 @@ template <typename scalar_t>
int batch_nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
vector<long>& q_batches, vector<long>& s_batches,
vector<long>& neighbors_indices, vector<float>& dists, float radius,
int max_num, int mode)
int max_num, int mode, bool sorted)
{
// Initiate variables
// ******************
std::random_device rd;
std::mt19937 g(rd());

// indices
int i0 = 0;

Expand Down Expand Up @@ -173,7 +182,7 @@ int batch_nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& suppo
// ***********************
// Search params
nanoflann::SearchParams search_params;
search_params.sorted = true;
search_params.sorted = sorted;
for (auto& p0 : query_pcd.pts)
{
// Check if we changed batch
Expand All @@ -192,16 +201,18 @@ int batch_nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& suppo
index->buildIndex();
}

// Initial guess of neighbors size

all_inds_dists[i0].reserve(max_count);
// Find neighbors
// std::cerr << p0.x << p0.y << p0.z<<std::endl;
// Find neighbors
std::vector<std::pair<size_t, scalar_t>> ret_matches;
ret_matches.reserve(max_count);
scalar_t query_pt[3] = {p0.x, p0.y, p0.z};
size_t nMatches = index->radiusSearch(query_pt, r2, ret_matches, search_params);

size_t nMatches = index->radiusSearch(query_pt, r2, all_inds_dists[i0], search_params);
// Update max count
// Shuffle if needed
if (!sorted)
std::shuffle(ret_matches.begin(), ret_matches.end(), g);
all_inds_dists[i0] = ret_matches;

// Update max count
if (nMatches > (size_t)max_count)
max_count = nMatches;
// Increment query idx
Expand Down
6 changes: 3 additions & 3 deletions cuda/src/ball_query.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include "utils.h"

void query_ball_point_kernel_dense_wrapper(int b, int n, int m, float radius, int nsample,
const float* new_xyz, const float* xyz, int* idx,
const float* new_xyz, const float* xyz, long* idx,
float* dist_out);

void query_ball_point_kernel_partial_wrapper(long batch_size, int size_x, int size_y, float radius,
Expand All @@ -25,15 +25,15 @@ std::pair<at::Tensor, at::Tensor> ball_query_dense(at::Tensor new_xyz, at::Tenso
}

at::Tensor idx = torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample},
at::device(new_xyz.device()).dtype(at::ScalarType::Int));
at::device(new_xyz.device()).dtype(at::ScalarType::Long));
at::Tensor dist = torch::full({new_xyz.size(0), new_xyz.size(1), nsample}, -1,
at::device(new_xyz.device()).dtype(at::ScalarType::Float));

if (new_xyz.type().is_cuda())
{
query_ball_point_kernel_dense_wrapper(
xyz.size(0), xyz.size(1), new_xyz.size(1), radius, nsample, new_xyz.DATA_PTR<float>(),
xyz.DATA_PTR<float>(), idx.DATA_PTR<int>(), dist.DATA_PTR<float>());
xyz.DATA_PTR<float>(), idx.DATA_PTR<long>(), dist.DATA_PTR<float>());
}
else
{
Expand Down
4 changes: 2 additions & 2 deletions cuda/src/ball_query_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
__global__ void query_ball_point_kernel_dense(int b, int n, int m, float radius, int nsample,
const float* __restrict__ new_xyz,
const float* __restrict__ xyz,
int* __restrict__ idx_out,
long* __restrict__ idx_out,
float* __restrict__ dist_out)
{
int batch_index = blockIdx.x;
Expand Down Expand Up @@ -93,7 +93,7 @@ __global__ void query_ball_point_kernel_partial_dense(
}

void query_ball_point_kernel_dense_wrapper(int b, int n, int m, float radius, int nsample,
const float* new_xyz, const float* xyz, int* idx,float* dist_out)
const float* new_xyz, const float* xyz, long* idx,float* dist_out)
{
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
query_ball_point_kernel_dense<<<b, opt_n_threads(m), 0, stream>>>(b, n, m, radius, nsample,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@

setup(
name="torch_points",
version="0.4.0",
version="0.4.1",
author="Nicolas Chaulet",
packages=find_packages(),
install_requires=requirements,
Expand Down
24 changes: 14 additions & 10 deletions test/test_ballquerry.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,24 @@ def test_simple_gpu(self):
a = torch.tensor([[[0, 0, 0], [1, 0, 0], [2, 0, 0]], [[0, 0, 0], [1, 0, 0], [2, 0, 0]]]).to(torch.float).cuda()
b = torch.tensor([[[0, 0, 0]], [[3, 0, 0]]]).to(torch.float).cuda()
idx, dist = ball_query(1.01, 2, a, b)
torch.testing.assert_allclose(idx.long().cpu(), torch.tensor([[[0, 1]], [[2, 2]]]))
torch.testing.assert_allclose(idx.cpu(), torch.tensor([[[0, 1]], [[2, 2]]]))
torch.testing.assert_allclose(dist.cpu(), torch.tensor([[[0, 1]], [[1, -1]]]).float())

def test_simple_cpu(self):
a = torch.tensor([[[0, 0, 0], [1, 0, 0], [2, 0, 0]], [[0, 0, 0], [1, 0, 0], [2, 0, 0]]]).to(torch.float)
b = torch.tensor([[[0, 0, 0]], [[3, 0, 0]]]).to(torch.float)
idx, dist = ball_query(1.01, 2, a, b)
torch.testing.assert_allclose(idx.long(), torch.tensor([[[0, 1]], [[2, 2]]]))
idx, dist = ball_query(1.01, 2, a, b, sort=True)
torch.testing.assert_allclose(idx, torch.tensor([[[0, 1]], [[2, 2]]]))
torch.testing.assert_allclose(dist, torch.tensor([[[0, 1]], [[1, -1]]]).float())

a = torch.tensor([[[0, 0, 0], [1, 0, 0], [1, 1, 0]]]).to(torch.float)
idx, dist = ball_query(1.01, 3, a, a)
torch.testing.assert_allclose(idx.long(),torch.tensor([[[0, 1, 0],[1,0,2],[2,1,2]]]))
idx, dist = ball_query(1.01, 3, a, a, sort=True)
torch.testing.assert_allclose(idx, torch.tensor([[[0, 1, 0], [1, 0, 2], [2, 1, 2]]]))

@run_if_cuda
def test_larger_gpu(self):
a = torch.randn(32, 4096, 3).to(torch.float).cuda()
idx,dist = ball_query(1, 64, a, a)
idx, dist = ball_query(1, 64, a, a)
self.assertGreaterEqual(idx.min(), 0)

@run_if_cuda
Expand Down Expand Up @@ -70,7 +70,7 @@ def test_simple_gpu(self):
dist2 = dist2.detach().cpu().numpy()

idx_answer = np.asarray([[1, -1]])
dist2_answer = np.asarray([[0.0100, -1.0000]]).astype(np.float32)
dist2_answer = np.asarray([[0.100, -1.0000]]).astype(np.float32)

npt.assert_array_almost_equal(idx, idx_answer)
npt.assert_array_almost_equal(dist2, dist2_answer)
Expand All @@ -88,7 +88,7 @@ def test_simple_cpu(self):
dist2 = dist2.detach().cpu().numpy()

idx_answer = np.asarray([[1, -1]])
dist2_answer = np.asarray([[0.0100, -1.0000]]).astype(np.float32)
dist2_answer = np.asarray([[0.100, -1.0000]]).astype(np.float32)

npt.assert_array_almost_equal(idx, idx_answer)
npt.assert_array_almost_equal(dist2, dist2_answer)
Expand All @@ -100,9 +100,13 @@ def test_random_cpu(self):
batch_b = torch.tensor([0 for i in range(b.shape[0] // 2)] + [1 for i in range(b.shape[0] // 2, b.shape[0])])
R = 1

idx, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b)
idx1, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b)
idx, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=True)
idx1, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=True)
torch.testing.assert_allclose(idx1, idx)
with self.assertRaises(AssertionError):
idx, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=False)
idx1, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=False)
torch.testing.assert_allclose(idx1, idx)

self.assertEqual(idx.shape[0], b.shape[0])
self.assertEqual(dist.shape[0], b.shape[0])
Expand Down
Loading

0 comments on commit 55c3605

Please sign in to comment.