diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu
index 7d4eebcce..86165bb39 100644
--- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu
+++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu
@@ -14,19 +14,20 @@ using Tensor = at::Tensor;
 
 namespace nbit {
 
+template <typename index_t>
 __global__
 __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel(
-    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         indices,
-    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
        offsets,
-    const pta::PackedTensorAccessor64<int32_t, 2, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor64<index_t, 2, at::RestrictPtrTraits>
        hash_table,
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
        hash_table_offsets,
     const int32_t B,
     const int32_t T,
-    pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
        dense_indices) {
   // uint32_t capacity = hash_table.size(0);
   const int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;
@@ -35,9 +36,9 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
   if (b_t >= B * T) {
     return;
   }
-  const int32_t indices_start = offsets[t * B + b];
-  const int32_t indices_end = offsets[t * B + b + 1];
-  const int32_t L = indices_end - indices_start;
+  const index_t indices_start = offsets[t * B + b];
+  const index_t indices_end = offsets[t * B + b + 1];
+  const index_t L = indices_end - indices_start;
 
   const int64_t table_start = hash_table_offsets[t];
   const int64_t table_end = hash_table_offsets[t + 1];
@@ -51,6 +52,9 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
     return;
   }
 
+  using hash_t =
+      std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;
+
   const uint32_t subwarp_id = threadIdx.x / 4;
   const uint32_t subwarp_tid = threadIdx.x % 4;
 #ifdef USE_ROCM
@@ -58,13 +62,15 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
 #else
   const uint32_t subwarp_mask = static_cast<uint32_t>(0xF) << (4 * subwarp_id);
 #endif
+
   for (int32_t l_start = 0; l_start + subwarp_id < L; l_start += kWarpSize / 4) {
-    const int32_t idx = indices[indices_start + l_start + subwarp_id];
-    uint32_t slot_start =
-        pruned_hash_function(static_cast<uint32_t>(idx)) % capacity;
+    const index_t idx = indices[indices_start + l_start + subwarp_id];
+    hash_t slot_start =
+        pruned_hash_function(static_cast<hash_t>(idx)) % capacity;
+
     while (true) {
-      const uint32_t slot = (slot_start + subwarp_tid) % capacity;
+      const hash_t slot = (slot_start + subwarp_tid) % capacity;
       const int2 val = *reinterpret_cast<const int2*>(
           &hash_table[table_start + static_cast<int64_t>(slot)][0]);
       const int32_t slot_sparse_idx = val.x;
@@ -78,6 +84,7 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
       found = true;
       dense_indices[indices_start + l_start + subwarp_id] = slot_dense_idx;
     }
+
     if (__any_sync(subwarp_mask, found)) {
       break;
     } else if (__any_sync(subwarp_mask, empty)) {
@@ -89,19 +96,20 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
   }
 }
 
+template <typename index_t>
 __global__
 __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel(
-    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
        indices,
-    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
        offsets,
-    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
        index_remappings,
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
        index_remappings_offsets,
     const int32_t B,
     const int32_t T,
-    pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
        dense_indices) {
   const int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;
   const int32_t t = b_t / B;
@@ -109,22 +117,22 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
   if (b_t >= B * T) {
     return;
   }
-  const int32_t indices_start = offsets[t * B + b];
-  const int32_t indices_end = offsets[t * B + b + 1];
-  const int32_t L = indices_end - indices_start;
+  const index_t indices_start = offsets[t * B + b];
+  const index_t indices_end = offsets[t * B + b + 1];
+  const index_t L = indices_end - indices_start;
 
   const int64_t index_remappings_start = index_remappings_offsets[t];
   const int64_t index_remappings_end = index_remappings_offsets[t + 1];
   const int64_t capacity = index_remappings_end - index_remappings_start;
 
   if (capacity > 0) {
-    for (int32_t l = threadIdx.x; l < L; l += blockDim.x) {
-      int32_t idx = indices[indices_start + l];
+    for (index_t l = threadIdx.x; l < L; l += blockDim.x) {
+      index_t idx = indices[indices_start + l];
       dense_indices[indices_start + l] =
           index_remappings[index_remappings_start + idx];
     }
   } else {
-    for (int32_t l = threadIdx.x; l < L; l += blockDim.x) {
+    for (index_t l = threadIdx.x; l < L; l += blockDim.x) {
       dense_indices[indices_start + l] = indices[indices_start + l];
     }
   }
@@ -132,6 +140,8 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
 
 } // namespace nbit
 
+using namespace nbit;
+
 Tensor pruned_hashmap_lookup_cuda(
     Tensor indices,
     Tensor offsets,
@@ -139,6 +149,7 @@ Tensor pruned_hashmap_lookup_cuda(
     Tensor hash_table_offsets) {
   TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
       indices, offsets, hash_table, hash_table_offsets);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, hash_table);
 
   CUDA_DEVICE_GUARD(indices);
 
@@ -149,23 +160,25 @@ Tensor pruned_hashmap_lookup_cuda(
   TORCH_CHECK(hash_table.size(0) < std::numeric_limits<int32_t>::max());
 
   constexpr size_t kForwardMaxThreads = 256;
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_lookup", [&] {
 #ifdef FBGEMM_GPU_MEMCHECK
-  const auto func_name =
-      "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel";
+    const auto func_name =
+        "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel";
 #endif
-  nbit::int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel<<<
-      nbit::div_round_up(B * T + 1, kForwardMaxThreads / kWarpSize),
-      dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
-      0,
-      at::cuda::getCurrentCUDAStream()>>>(
-      MAKE_PTA_WITH_NAME(func_name, indices, int32_t, 1, 32),
-      MAKE_PTA_WITH_NAME(func_name, offsets, int32_t, 1, 32),
-      MAKE_PTA_WITH_NAME(func_name, hash_table, int32_t, 2, 64),
-      MAKE_PTA_WITH_NAME(func_name, hash_table_offsets, int64_t, 1, 32),
-      B,
-      T,
-      MAKE_PTA_WITH_NAME(func_name, dense_indices, int32_t, 1, 32));
+    int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel<<<
+        nbit::div_round_up(B * T + 1, kForwardMaxThreads / kWarpSize),
+        dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
+        0,
+        at::cuda::getCurrentCUDAStream()>>>(
+        MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
+        MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
+        MAKE_PTA_WITH_NAME(func_name, hash_table, index_t, 2, 64),
+        MAKE_PTA_WITH_NAME(func_name, hash_table_offsets, int64_t, 1, 32),
+        B,
+        T,
+        MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32));
+  });
 
   C10_CUDA_KERNEL_LAUNCH_CHECK();
   return dense_indices;
@@ -178,6 +191,7 @@ Tensor pruned_array_lookup_cuda(
     Tensor indices,
     Tensor offsets,
     Tensor index_remappings,
     Tensor index_remappings_offsets) {
   TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
       indices, offsets, index_remappings, index_remappings_offsets);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, index_remappings);
 
   CUDA_DEVICE_GUARD(indices);
 
@@ -204,23 +218,26 @@ Tensor pruned_array_lookup_cuda(
   TORCH_CHECK(dense_indices.dim() == 1, "Tensor dim: ", dense_indices.dim());
 
   constexpr size_t kForwardMaxThreads = 256;
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup", [&] {
 #ifdef FBGEMM_GPU_MEMCHECK
-  const auto func_name =
-      "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel";
+    const auto func_name =
+        "int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel";
 #endif
-  nbit::int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<<<
-      nbit::div_round_up(offsets.size(0), kForwardMaxThreads / kWarpSize),
-      dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
-      0,
-      at::cuda::getCurrentCUDAStream()>>>(
-      MAKE_PTA_WITH_NAME(func_name, indices, int32_t, 1, 32),
-      MAKE_PTA_WITH_NAME(func_name, offsets, int32_t, 1, 32),
-      MAKE_PTA_WITH_NAME(func_name, index_remappings, int32_t, 1, 32),
-      MAKE_PTA_WITH_NAME(func_name, index_remappings_offsets, int64_t, 1, 32),
-      B,
-      T,
-      MAKE_PTA_WITH_NAME(func_name, dense_indices, int32_t, 1, 32));
+    int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<<<
+        nbit::div_round_up(offsets.size(0), kForwardMaxThreads / kWarpSize),
+        dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
+        0,
+        at::cuda::getCurrentCUDAStream()>>>(
+        MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
+        MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
+        MAKE_PTA_WITH_NAME(func_name, index_remappings, index_t, 1, 32),
+        MAKE_PTA_WITH_NAME(func_name, index_remappings_offsets, int64_t, 1, 32),
+        B,
+        T,
+        MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32));
+  });
+
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
   return dense_indices;
 }
diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu
index e7b908cdd..dea684dd6 100644
--- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu
+++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu
@@ -7,7 +7,7 @@
  */
 
 // clang-format off
-{% set wdesc = "weighted" if weighted else "unweighted" %}
+{%- set wdesc = "weighted" if weighted else "unweighted" %}
 
 #include "fbgemm_gpu/embedding_forward_template_helpers.cuh"
 #include "fbgemm_gpu/utils/tensor_accessor.h"
@@ -22,7 +22,7 @@ namespace nbit {
   `Tensor int_nbit_split_embedding*_codegen_forward_*_cuda(...)` later in the same generated source file.
 */
 
-{% for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"] %}
+{%- for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"] %}
 template
 __launch_bounds__(WarpsPerBlock * kWarpSize)
 __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_kernel_small_L(
@@ -31,30 +31,30 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> weights_placements,
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
     const pta::PackedTensorAccessor32<uint8_t, 1, at::RestrictPtrTraits> weights_tys,
-    {% if not nobag %}
+    {%- if not nobag %}
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
-    {% else %}
+    {%- else %}
     const int64_t D,
-    {% endif %}
+    {%- endif %}
     FixedDivisor fd_B, // FixedDivisor(div_round_up(B, OutputRowsPerThread))
     const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> indices,
     const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets,
-    {% if not nobag %}
+    {%- if not nobag %}
    const int64_t pooling_mode,
-    {% endif %}
+    {%- endif %}
    const int64_t row_alignment,
-    {% if weighted %}
+    {%- if weighted %}
    pta::PackedTensorAccessor32<float, 1, at::RestrictPtrTraits> indice_weights,
-    {% endif %}
-    {% if type_map[emb_weight_type].enum_name == "FP8" %}
+    {%- endif %}
+    {%- if type_map[emb_weight_type].enum_name == "FP8" %}
    const int fp8_exponent_bits,
    const int fp8_exponent_bias,
-    {% endif %}
+    {%- endif %}
    pta::PackedTensorAccessor32<output_t, 2, at::RestrictPtrTraits> output, // [B][total_D],
    const pta::PackedTensorAccessor64<uint8_t, 2, at::RestrictPtrTraits> lxu_cache_weights,
    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> lxu_cache_locations
    );
 
-{% endfor %} // for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"]
+{%- endfor %} // for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"]
 
 }
 
@@ -107,58 +107,7 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no
     C10_CUDA_KERNEL_LAUNCH_CHECK(); \
 {%- endmacro %}
-
-Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_cuda(
-    Tensor dev_weights,
-    Tensor uvm_weights,
-    Tensor weights_placements,
-    Tensor weights_offsets,
-    Tensor weights_tys,
-    {% if not nobag %}
-    Tensor D_offsets,
-    const int64_t total_D,
-    {% else %}
-    const int64_t D,
-    {% endif %}
-    const int64_t max_int2_D,
-    const int64_t max_int4_D,
-    const int64_t max_int8_D,
-    const int64_t max_float16_D,
-    const int64_t max_float32_D,
-    Tensor indices,
-    Tensor offsets,
-    {% if not nobag %}
-    const int64_t pooling_mode,
-    {% endif %}
-    const int64_t row_alignment,
-    {% if weighted %}
-    Tensor indice_weights,
-    {% endif %}
-    const int64_t output_dtype,
-    Tensor lxu_cache_weights,
-    Tensor lxu_cache_locations,
-    const int64_t max_float8_D,
-    const int64_t fp8_exponent_bits,
-    const int64_t fp8_exponent_bias
-) {
-    TENSOR_ON_CUDA_GPU(dev_weights);
-    TENSORS_ON_SAME_DEVICE(uvm_weights, dev_weights);
-    TENSORS_ON_SAME_DEVICE(weights_placements, dev_weights);
-    TENSORS_ON_SAME_DEVICE(weights_offsets, dev_weights);
-    TENSORS_ON_SAME_DEVICE(weights_tys, dev_weights);
-    {% if not nobag %}
-    TENSORS_ON_SAME_DEVICE(D_offsets, dev_weights);
-    {% endif %}
-    TENSORS_ON_SAME_DEVICE(indices, dev_weights);
-    TENSORS_ON_SAME_DEVICE(offsets, dev_weights);
-    {% if weighted %}
-    TENSORS_EMPTY_OR_ON_SAME_DEVICE(indice_weights, dev_weights);
-    {% endif %}
-    TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_weights, dev_weights);
-    TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_locations, dev_weights);
-
-    CUDA_DEVICE_GUARD(dev_weights);
-
+{%- macro construct_and_return_output_tensor() %}
   // kernels assume indices are contiguous.
   indices = indices.contiguous();
@@ -180,8 +129,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
   TORCH_CHECK(D > 0);
   {%- endif %}
 
+  // Construct output tensor
   Tensor output;
   const int kINT8QparamsBytes = 8;
+
   SparseType o_dtype = static_cast<SparseType>(output_dtype);
   TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 ||
       o_dtype == SparseType::BF16 || o_dtype == SparseType::INT8);
@@ -216,11 +167,63 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
   if (B == 0 || indices.numel() == 0) {
     return output;
   }
+{%- endmacro %}
 
-  using index_t = int32_t;
+template <typename index_t>
+Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_cuda_impl(
+    Tensor dev_weights,
+    Tensor uvm_weights,
+    Tensor weights_placements,
+    Tensor weights_offsets,
+    Tensor weights_tys,
+    {%- if not nobag %}
+    Tensor D_offsets,
+    const int64_t total_D,
+    {%- else %}
+    const int64_t D,
+    {%- endif %}
+    const int64_t max_int2_D,
+    const int64_t max_int4_D,
+    const int64_t max_int8_D,
+    const int64_t max_float16_D,
+    const int64_t max_float32_D,
+    Tensor indices,
+    Tensor offsets,
+    {%- if not nobag %}
+    const int64_t pooling_mode,
+    {%- endif %}
+    const int64_t row_alignment,
+    {%- if weighted %}
+    Tensor indice_weights,
+    {%- endif %}
+    const int64_t output_dtype,
+    Tensor lxu_cache_weights,
+    Tensor lxu_cache_locations,
+    const int64_t max_float8_D,
+    const int64_t fp8_exponent_bits,
+    const int64_t fp8_exponent_bias
+) {
+  TENSOR_ON_CUDA_GPU(dev_weights);
+  TENSORS_ON_SAME_DEVICE(uvm_weights, dev_weights);
+  TENSORS_ON_SAME_DEVICE(weights_placements, dev_weights);
+  TENSORS_ON_SAME_DEVICE(weights_offsets, dev_weights);
+  TENSORS_ON_SAME_DEVICE(weights_tys, dev_weights);
+  {%- if not nobag %}
+  TENSORS_ON_SAME_DEVICE(D_offsets, dev_weights);
+  {%- endif %}
+  TENSORS_ON_SAME_DEVICE(indices, dev_weights);
+  TENSORS_ON_SAME_DEVICE(offsets, dev_weights);
+  {%- if weighted %}
+  TENSORS_EMPTY_OR_ON_SAME_DEVICE(indice_weights, dev_weights);
+  {%- endif %}
+  TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_weights, dev_weights);
+  TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_locations, dev_weights);
 
-  constexpr int32_t kWarpsPerBlock = 4;
+  CUDA_DEVICE_GUARD(dev_weights);
+
+  {{- construct_and_return_output_tensor() }}
+  constexpr int32_t kWarpsPerBlock = 4;
   const auto device_only =
       lxu_cache_weights.numel() == 0 && uvm_weights.numel() == 0;
 #define Y(...)                                                                 \
   if (device_only) {                                                           \
     X(true, __VA_ARGS__)                                                       \
   } else {                                                                     \
     X(false, __VA_ARGS__)                                                      \
   };
@@ -397,6 +400,104 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
   }));
 #undef X
 
+  return output;
+}
+
+Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_cuda(
+    Tensor dev_weights,
+    Tensor uvm_weights,
+    Tensor weights_placements,
+    Tensor weights_offsets,
+    Tensor weights_tys,
+    {%- if not nobag %}
+    Tensor D_offsets,
+    const int64_t total_D,
+    {%- else %}
+    const int64_t D,
+    {%- endif %}
+    const int64_t max_int2_D,
+    const int64_t max_int4_D,
+    const int64_t max_int8_D,
+    const int64_t max_float16_D,
+    const int64_t max_float32_D,
+    Tensor indices,
+    Tensor offsets,
+    {%- if not nobag %}
+    const int64_t pooling_mode,
+    {%- endif %}
+    const int64_t row_alignment,
+    {%- if weighted %}
+    Tensor indice_weights,
+    {%- endif %}
+    const int64_t output_dtype,
+    Tensor lxu_cache_weights,
+    Tensor lxu_cache_locations,
+    const int64_t max_float8_D,
+    const int64_t fp8_exponent_bits,
+    const int64_t fp8_exponent_bias
+) {
+  // All argument tensors need to be on the same CUDA device
+  TENSOR_ON_CUDA_GPU(dev_weights);
+  TENSORS_ON_SAME_DEVICE(uvm_weights, dev_weights);
+  TENSORS_ON_SAME_DEVICE(weights_placements, dev_weights);
+  TENSORS_ON_SAME_DEVICE(weights_offsets, dev_weights);
+  TENSORS_ON_SAME_DEVICE(weights_tys, dev_weights);
+  {%- if not nobag %}
+  TENSORS_ON_SAME_DEVICE(D_offsets, dev_weights);
+  {%- endif %}
+  TENSORS_ON_SAME_DEVICE(indices, dev_weights);
+  TENSORS_ON_SAME_DEVICE(offsets, dev_weights);
+  {%- if weighted %}
+  TENSORS_EMPTY_OR_ON_SAME_DEVICE(indice_weights, dev_weights);
+  {%- endif %}
+  TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_weights, dev_weights);
+  TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_locations, dev_weights);
+
+  // indices and offsets need to have the same scalar type
+  TENSORS_HAVE_SAME_TYPE(indices, offsets);
+  // Only int32_t and int64_t indices are supported at the moment
+  TENSOR_SCALAR_TYPE_IS_ONE_OF(indices, at::ScalarType::Long, at::ScalarType::Int);
+
+  CUDA_DEVICE_GUARD(dev_weights);
+
+  // Create output tensor ref
+  Tensor output;
+
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ 'int_nbit_split_embedding' + ('_nobag' if nobag else '') + '_codegen_forward_' + wdesc + '_cuda' }}", [&] {
+    output = int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_cuda_impl<index_t>(
+        dev_weights,
+        uvm_weights,
+        weights_placements,
+        weights_offsets,
+        weights_tys,
+        {%- if not nobag %}
+        D_offsets,
+        total_D,
+        {%- else %}
+        D,
+        {%- endif %}
+        max_int2_D,
+        max_int4_D,
+        max_int8_D,
+        max_float16_D,
+        max_float32_D,
+        indices,
+        offsets,
+        {%- if not nobag %}
+        pooling_mode,
+        {%- endif %}
+        row_alignment,
+        {%- if weighted %}
+        indice_weights,
+        {%- endif %}
+        output_dtype,
+        lxu_cache_weights,
+        lxu_cache_locations,
+        max_float8_D,
+        fp8_exponent_bits,
+        fp8_exponent_bias);
+  });
+
   return output;
 }
diff --git a/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh b/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh
index 97353e03c..2164afd3e 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh
+++ b/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh
@@ -88,6 +88,7 @@ __device__ inline int32_t padded_D(
 __device__ inline uint32_t pruned_hash_function(uint32_t h) {
   // MurmorHash3 32-bit mixing function.
+  // https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp
   h ^= h >> 16;
   h *= 0x85ebca6b;
   h ^= h >> 13;
@@ -96,6 +97,17 @@ __device__ inline uint32_t pruned_hash_function(uint32_t h) {
   return h;
 }
 
+__device__ inline uint64_t pruned_hash_function(uint64_t k) {
+  // MurmurHash3 64-bit mixing function.
+  // https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp
+  k ^= k >> 33;
+  k *= (0xff51afd7ed558ccd);
+  k ^= k >> 33;
+  k *= (0xc4ceb9fe1a85ec53);
+  k ^= k >> 33;
+  return k;
+}
+
 // ---------------------- START cp.async helpers, copied from CUTLASS
 
 /// CUTLASS helper to get SMEM pointer
diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h b/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h
index b1ab0306c..60cca19ef 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h
@@ -299,3 +299,77 @@ inline at::Tensor aligned_grad_output_tensor_for_cuda_backwards(
   }
   return aligned_grad_output;
 }
+
+template <typename... ScalarTypes>
+std::string tensor_scalar_type_is_one_of(
+    const at::Tensor& ten,
+    const ScalarTypes&... ttypes) {
+  auto has_match = false;
+
+  (
+      [&](const auto& ttype) {
+        if (ten.scalar_type() == ttype) {
+          has_match = true;
+        }
+      }(ttypes),
+      ...);
+
+  if (has_match) {
+    return "";
+  }
+
+  std::string msg = "Tensor's scalar type (";
+  msg.append(toString(ten.scalar_type()));
+  msg.append(") did not match any one of the following types: [");
+  (
+      [&](const auto& ttype) {
+        msg.append(toString(ttype));
+        msg.append(", ");
+      }(ttypes),
+      ...);
+
+  msg.append("]");
+  return msg;
+}
+
+#define TENSOR_SCALAR_TYPE_IS_ONE_OF(...)                              \
+  do {                                                                 \
+    const auto has_match = tensor_scalar_type_is_one_of(__VA_ARGS__);  \
+    TORCH_CHECK(has_match.empty(), has_match);                         \
+  } while (false)
+
+template <typename... Tensors>
+std::string tensors_have_same_scalar_type(const Tensors&... tensors) {
+  std::optional<at::ScalarType> dtype;
+  bool have_same_type = true;
+
+  (
+      [&](const auto& tensor) {
+        if (!dtype) {
+          dtype = tensor.scalar_type();
+        } else if (*dtype != tensor.scalar_type()) {
+          have_same_type = false;
+        }
+      }(tensors),
+      ...);
+
+  if (have_same_type) {
+    return "";
+  }
+
+  std::string msg = "Tensors' scalar types (";
+  (
+      [&](const auto& tensor) {
+        msg.append(toString(tensor.scalar_type()));
+        msg.append(", ");
+      }(tensors),
+      ...);
+  msg.append(") are not one and the same!");
+  return msg;
+}
+
+#define TENSORS_HAVE_SAME_SCALAR_TYPE(...)                                   \
+  do {                                                                       \
+    const auto have_same_type = tensors_have_same_scalar_type(__VA_ARGS__);  \
+    TORCH_CHECK(have_same_type.empty(), have_same_type);                     \
+  } while (false)
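
Usage sketch (illustrative only, not part of the patch): the snippet below shows how the new TENSORS_HAVE_SAME_SCALAR_TYPE and TENSOR_SCALAR_TYPE_IS_ONE_OF checks compose with AT_DISPATCH_INDEX_TYPES, mirroring the pattern this patch applies in pruned_hashmap_lookup_cuda. The names example_lookup_cuda and example_kernel_launch are hypothetical placeholders.

// C++/CUDA sketch; assumes the FBGEMM tensor_utils.h header above is included.
Tensor example_lookup_cuda(Tensor indices, Tensor offsets, Tensor table) {
  // Fail fast with a readable error if callers mix int32 and int64 tensors.
  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, table);
  // Only int32_t and int64_t indices are supported.
  TENSOR_SCALAR_TYPE_IS_ONE_OF(indices, at::ScalarType::Long, at::ScalarType::Int);

  auto dense_indices = at::empty_like(indices);
  // AT_DISPATCH_INDEX_TYPES defines index_t as int32_t or int64_t inside the
  // lambda, so the templated kernel/impl can be instantiated for either type.
  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "example_lookup", [&] {
    example_kernel_launch<index_t>(indices, offsets, table, dense_indices);
  });
  return dense_indices;
}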