Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for int64_t indices and offsets in TBE inference [3/N] #3124

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,20 @@ using Tensor = at::Tensor;

namespace nbit {

template <typename index_t>
__global__
__launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel(
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
indices,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
offsets,
const pta::PackedTensorAccessor64<int32_t, 2, at::RestrictPtrTraits>
const pta::PackedTensorAccessor64<index_t, 2, at::RestrictPtrTraits>
hash_table,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
hash_table_offsets,
const int32_t B,
const int32_t T,
pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
dense_indices) {
// uint32_t capacity = hash_table.size(0);
const int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;
Expand All @@ -35,9 +36,9 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
if (b_t >= B * T) {
return;
}
const int32_t indices_start = offsets[t * B + b];
const int32_t indices_end = offsets[t * B + b + 1];
const int32_t L = indices_end - indices_start;
const index_t indices_start = offsets[t * B + b];
const index_t indices_end = offsets[t * B + b + 1];
const index_t L = indices_end - indices_start;

const int64_t table_start = hash_table_offsets[t];
const int64_t table_end = hash_table_offsets[t + 1];
Expand All @@ -51,20 +52,25 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
return;
}

using hash_t =
std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;

const uint32_t subwarp_id = threadIdx.x / 4;
const uint32_t subwarp_tid = threadIdx.x % 4;
#ifdef USE_ROCM
const uint64_t subwarp_mask = static_cast<uint64_t>(0xF) << (4 * subwarp_id);
#else
const uint32_t subwarp_mask = static_cast<uint32_t>(0xF) << (4 * subwarp_id);
#endif

for (int32_t l_start = 0; l_start + subwarp_id < L;
l_start += kWarpSize / 4) {
const int32_t idx = indices[indices_start + l_start + subwarp_id];
uint32_t slot_start =
pruned_hash_function(static_cast<uint32_t>(idx)) % capacity;
const index_t idx = indices[indices_start + l_start + subwarp_id];
hash_t slot_start =
pruned_hash_function(static_cast<hash_t>(idx)) % capacity;

while (true) {
const uint32_t slot = (slot_start + subwarp_tid) % capacity;
const hash_t slot = (slot_start + subwarp_tid) % capacity;
const int2 val = *reinterpret_cast<const int2*>(
&hash_table[table_start + static_cast<int64_t>(slot)][0]);
const int32_t slot_sparse_idx = val.x;
Expand All @@ -78,6 +84,7 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
found = true;
dense_indices[indices_start + l_start + subwarp_id] = slot_dense_idx;
}

if (__any_sync(subwarp_mask, found)) {
break;
} else if (__any_sync(subwarp_mask, empty)) {
Expand All @@ -89,56 +96,60 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
}
}

template <typename index_t>
__global__
__launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel(
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
indices,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
offsets,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
index_remappings,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
index_remappings_offsets,
const int32_t B,
const int32_t T,
pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
dense_indices) {
const int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;
const int32_t t = b_t / B;
const int32_t b = b_t % B;
if (b_t >= B * T) {
return;
}
const int32_t indices_start = offsets[t * B + b];
const int32_t indices_end = offsets[t * B + b + 1];
const int32_t L = indices_end - indices_start;
const index_t indices_start = offsets[t * B + b];
const index_t indices_end = offsets[t * B + b + 1];
const index_t L = indices_end - indices_start;

const int64_t index_remappings_start = index_remappings_offsets[t];
const int64_t index_remappings_end = index_remappings_offsets[t + 1];
const int64_t capacity = index_remappings_end - index_remappings_start;

if (capacity > 0) {
for (int32_t l = threadIdx.x; l < L; l += blockDim.x) {
int32_t idx = indices[indices_start + l];
for (index_t l = threadIdx.x; l < L; l += blockDim.x) {
index_t idx = indices[indices_start + l];
dense_indices[indices_start + l] =
index_remappings[index_remappings_start + idx];
}
} else {
for (int32_t l = threadIdx.x; l < L; l += blockDim.x) {
for (index_t l = threadIdx.x; l < L; l += blockDim.x) {
dense_indices[indices_start + l] = indices[indices_start + l];
}
}
}

} // namespace nbit

using namespace nbit;

Tensor pruned_hashmap_lookup_cuda(
Tensor indices,
Tensor offsets,
Tensor hash_table,
Tensor hash_table_offsets) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
indices, offsets, hash_table, hash_table_offsets);
TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, hash_table);

CUDA_DEVICE_GUARD(indices);

Expand All @@ -149,23 +160,25 @@ Tensor pruned_hashmap_lookup_cuda(
TORCH_CHECK(hash_table.size(0) < std::numeric_limits<int32_t>::max());
constexpr size_t kForwardMaxThreads = 256;

AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_lookup", [&] {
#ifdef FBGEMM_GPU_MEMCHECK
const auto func_name =
"int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel";
const auto func_name =
"int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel";
#endif

nbit::int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel<<<
nbit::div_round_up(B * T + 1, kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
0,
at::cuda::getCurrentCUDAStream()>>>(
MAKE_PTA_WITH_NAME(func_name, indices, int32_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, offsets, int32_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, hash_table, int32_t, 2, 64),
MAKE_PTA_WITH_NAME(func_name, hash_table_offsets, int64_t, 1, 32),
B,
T,
MAKE_PTA_WITH_NAME(func_name, dense_indices, int32_t, 1, 32));
int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel<<<
nbit::div_round_up(B * T + 1, kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
0,
at::cuda::getCurrentCUDAStream()>>>(
MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, hash_table, index_t, 2, 64),
MAKE_PTA_WITH_NAME(func_name, hash_table_offsets, int64_t, 1, 32),
B,
T,
MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32));
});

C10_CUDA_KERNEL_LAUNCH_CHECK();
return dense_indices;
Expand All @@ -178,6 +191,7 @@ Tensor pruned_array_lookup_cuda(
Tensor index_remappings_offsets) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
indices, offsets, index_remappings, index_remappings_offsets);
TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, index_remappings);

CUDA_DEVICE_GUARD(indices);

Expand All @@ -204,23 +218,26 @@ Tensor pruned_array_lookup_cuda(
TORCH_CHECK(dense_indices.dim() == 1, "Tensor dim: ", dense_indices.dim());
constexpr size_t kForwardMaxThreads = 256;

AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup", [&] {
#ifdef FBGEMM_GPU_MEMCHECK
const auto func_name =
"int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel";
const auto func_name =
"int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel";
#endif

nbit::int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<<<
nbit::div_round_up(offsets.size(0), kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
0,
at::cuda::getCurrentCUDAStream()>>>(
MAKE_PTA_WITH_NAME(func_name, indices, int32_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, offsets, int32_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, index_remappings, int32_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, index_remappings_offsets, int64_t, 1, 32),
B,
T,
MAKE_PTA_WITH_NAME(func_name, dense_indices, int32_t, 1, 32));
int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<<<
nbit::div_round_up(offsets.size(0), kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
0,
at::cuda::getCurrentCUDAStream()>>>(
MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, index_remappings, index_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, index_remappings_offsets, int64_t, 1, 32),
B,
T,
MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32));
});

C10_CUDA_KERNEL_LAUNCH_CHECK();
return dense_indices;
}
Loading
Loading