From 7b5180fe931f55c17aaca2369f76bfde5f7172e2 Mon Sep 17 00:00:00 2001 From: Jagadish Krishnamoorthy Date: Tue, 10 Sep 2024 11:46:40 -0700 Subject: [PATCH] Add GPU_WARP_SIZE_HOST in threads dim constructor. Signed-off-by: Jagadish Krishnamoorthy --- onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu index b73e2d7742c30..ce6c07fbed2bc 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu @@ -288,9 +288,8 @@ bool TryMatMul4Bits( if (n % kColsPerThreadBlock != 0 || k % 8 != 0 || m > 1) { return false; } - const int kWarpSize = GPU_WARP_SIZE_HOST; dim3 blocks((n + kColsPerThreadBlock - 1) / kColsPerThreadBlock, m); - dim3 threads(kWarpSize, kColsPerThreadBlock); + dim3 threads(GPU_WARP_SIZE_HOST, kColsPerThreadBlock); int blocks_per_K = (k + block_size - 1) / block_size; int shared_mem_size = sizeof(T) * blocks_per_K * kColsPerThreadBlock + (zero_points != nullptr ? (blocks_per_K + 1) / 2 * kColsPerThreadBlock * 2 : 0);