Skip to content

Commit

Permalink
Merge pull request #21 from nicolas-chaulet/changetoken
Browse files Browse the repository at this point in the history
Change token to -1 for partial dense
  • Loading branch information
tchaton authored Feb 5, 2020
2 parents c03fe63 + e6190a5 commit 90bae82
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 17 deletions.
14 changes: 3 additions & 11 deletions cpu/src/neighbors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,43 +209,35 @@ int batch_nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& suppo
}
// how many neighbors do we keep
if (max_num > 0)
{
max_count = max_num;
}
// Reserve the memory

const int token = -1;
if (mode == 0)
{
neighbors_indices.resize(query_pcd.pts.size() * max_count);

dists.resize(query_pcd.pts.size() * max_count);
i0 = 0;

b = 0;

for (auto& inds_dists : all_inds_dists)
{ // Check if we changed batch

if (i0 == q_batches[b + 1] && b < (int)s_batches.size() - 1 &&
b < (int)q_batches.size() - 1)
{
b++;
}

for (int j = 0; j < max_count; j++)
{
if ((unsigned int)j < inds_dists.size())
if ((size_t)j < inds_dists.size())
{
neighbors_indices[i0 * max_count + j] = inds_dists[j].first + s_batches[b];
dists[i0 * max_count + j] = (float)inds_dists[j].second;
}
else
{
neighbors_indices[i0 * max_count + j] = supports.size() / 3;
neighbors_indices[i0 * max_count + j] = token;
dists[i0 * max_count + j] = -1;
}
}

i0++;
}
index.reset();
Expand Down
12 changes: 12 additions & 0 deletions cpu/src/torch_nearest_neighbors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
#include "compat.h"
#include "neighbors.cpp"
#include "neighbors.h"
#include "utils.h"
#include <iostream>
#include <torch/extension.h>

std::pair<at::Tensor, at::Tensor> ball_query(at::Tensor support, at::Tensor query, float radius,
int max_num, int mode)
{
CHECK_CONTIGUOUS(support);
CHECK_CONTIGUOUS(query);

at::Tensor out;
at::Tensor out_dists;
std::vector<long> neighbors_indices(query.size(0), 0);
Expand Down Expand Up @@ -60,6 +64,11 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor support, at::Tenso
at::Tensor support_batch, at::Tensor query_batch,
float radius, int max_num, int mode)
{
CHECK_CONTIGUOUS(support);
CHECK_CONTIGUOUS(query);
CHECK_CONTIGUOUS(support_batch);
CHECK_CONTIGUOUS(query_batch);

at::Tensor idx;

at::Tensor dist;
Expand Down Expand Up @@ -115,6 +124,9 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor support, at::Tenso
std::pair<at::Tensor, at::Tensor> dense_ball_query(at::Tensor support, at::Tensor query,
float radius, int max_num, int mode)
{
CHECK_CONTIGUOUS(support);
CHECK_CONTIGUOUS(query);

int b = query.size(0);
vector<at::Tensor> batch_idx;
vector<at::Tensor> batch_dist;
Expand Down
4 changes: 2 additions & 2 deletions cuda/src/ball_query.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ std::pair<at::Tensor, at::Tensor> ball_query_partial_dense(at::Tensor x, at::Ten
CHECK_CUDA(batch_y);
}

at::Tensor idx = torch::full({y.size(0), nsample}, x.size(0),
at::device(y.device()).dtype(at::ScalarType::Long));
at::Tensor idx =
torch::full({y.size(0), nsample}, -1, at::device(y.device()).dtype(at::ScalarType::Long));

at::Tensor dist =
torch::full({y.size(0), nsample}, -1, at::device(y.device()).dtype(at::ScalarType::Float));
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@

setup(
name="torch_points",
version="0.2.3",
version="0.3.0",
author="Nicolas Chaulet",
packages=find_packages(),
install_requires=requirements,
Expand Down
9 changes: 6 additions & 3 deletions test/test_ballquerry.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_simple_gpu(self):
idx = idx.detach().cpu().numpy()
dist2 = dist2.detach().cpu().numpy()

idx_answer = np.asarray([[1, 4]])
idx_answer = np.asarray([[1, -1]])
dist2_answer = np.asarray([[0.0100, -1.0000]]).astype(np.float32)

npt.assert_array_almost_equal(idx, idx_answer)
Expand All @@ -81,7 +81,7 @@ def test_simple_cpu(self):
idx = idx.detach().cpu().numpy()
dist2 = dist2.detach().cpu().numpy()

idx_answer = np.asarray([[1, 4]])
idx_answer = np.asarray([[1, -1]])
dist2_answer = np.asarray([[0.0100, -1.0000]]).astype(np.float32)

npt.assert_array_almost_equal(idx, idx_answer)
Expand All @@ -95,6 +95,9 @@ def test_random_cpu(self):
R = 1

idx, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b)
idx1, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b)
torch.testing.assert_allclose(idx1, idx)

self.assertEqual(idx.shape[0], b.shape[0])
self.assertEqual(dist.shape[0], b.shape[0])
self.assertLessEqual(idx.max().item(), len(batch_a))
Expand All @@ -104,7 +107,7 @@ def test_random_cpu(self):
idx3_sk = tree.query_radius(b.detach().numpy(), r=R)
i = np.random.randint(len(batch_b))
for p in idx[i].detach().numpy():
if p < len(batch_a):
if p >= 0 and p < len(batch_a):
assert p in idx3_sk[i]


Expand Down

0 comments on commit 90bae82

Please sign in to comment.