Skip to content

Commit

Permalink
Keep NMS index gathering on cuda device
Browse files Browse the repository at this point in the history
  • Loading branch information
Ghelfi committed Nov 29, 2024
1 parent 8f8a195 commit 2db6ab2
Showing 1 changed file with 57 additions and 27 deletions.
84 changes: 57 additions & 27 deletions torchvision/csrc/ops/cuda/nms_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,48 @@ __global__ void nms_kernel_impl(
}
}

__global__ static void gather_keep_from_mask(bool *keep,
const unsigned long long *dev_mask,
const int n_boxes) {
// Taken and adapted from mmcv https://github.com/open-mmlab/mmcv/blob/03ce9208d18c0a63d7ffa087ea1c2f5661f2441a/mmcv/ops/csrc/common/cuda/nms_cuda_kernel.cuh#L76
const int col_blocks = ceil_div(n_boxes, threadsPerBlock);
const int thread_id = threadIdx.x;

// mark the bboxes which have been removed.
extern __shared__ unsigned long long removed[];

// initialize removed.
for (int i = thread_id; i < col_blocks; i += blockDim.x) {
removed[i] = 0;
}
__syncthreads();

for (int nblock = 0; nblock < col_blocks; nblock++) {
auto removed_val = removed[nblock];
__syncthreads();
const int i_offset = nblock * threadsPerBlock;
#pragma unroll
for (int inblock = 0; inblock < threadsPerBlock; inblock++) {
const int i = i_offset + inblock;
if (i >= n_boxes) break;
// select a candidate, check if it should kept.
if (!(removed_val & (1ULL << inblock))) {
if (thread_id == 0) {
// mark the output.
keep[i] = true;
}
auto p = dev_mask + i * col_blocks;
// remove all bboxes which overlap the candidate.
for (int j = thread_id; j < col_blocks; j += blockDim.x) {
if (j >= nblock) removed[j] |= p[j];
}
__syncthreads();
removed_val = removed[nblock];
}
}
}
}

at::Tensor nms_kernel(
const at::Tensor& dets,
const at::Tensor& scores,
Expand Down Expand Up @@ -133,35 +175,23 @@ at::Tensor nms_kernel(
(unsigned long long*)mask.data_ptr<int64_t>());
});

at::Tensor mask_cpu = mask.to(at::kCPU);
unsigned long long* mask_host =
(unsigned long long*)mask_cpu.data_ptr<int64_t>();

std::vector<unsigned long long> remv(col_blocks);
memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);

at::Tensor keep =
at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
int64_t* keep_out = keep.data_ptr<int64_t>();

int num_to_keep = 0;
for (int i = 0; i < dets_num; i++) {
int nblock = i / threadsPerBlock;
int inblock = i % threadsPerBlock;

if (!(remv[nblock] & (1ULL << inblock))) {
keep_out[num_to_keep++] = i;
unsigned long long* p = mask_host + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv[j] |= p[j];
}
}
}
at::Tensor keep = at::zeros(
{dets_num},
dets.options().dtype(at::kBool).device(at::kCUDA)
);

// Unwrap the mask to fill keep with proper values
// Keeping this unwrap on cuda instead of applying iterative for loops on cpu
// prevents the device -> cpu -> device transfer that could be bottleneck for
// large number of boxes.
// See https://github.com/pytorch/vision/issues/8713 for more details
gather_keep_from_mask<<<1, min(col_blocks, threadsPerBlock),
col_blocks * sizeof(unsigned long long), stream>>>(
keep.data_ptr<bool>(), (unsigned long long*)mask.data_ptr<int64_t>(),
dets_num);

AT_CUDA_CHECK(cudaGetLastError());
return order_t.index(
{keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)
.to(order_t.device(), keep.scalar_type())});
return order_t.masked_select(keep);
}

} // namespace
Expand Down

0 comments on commit 2db6ab2

Please sign in to comment.