Skip to content

Commit

Permalink
matmul_nbits: Use GPU_WARP_SIZE_HOST for host side code
Browse files Browse the repository at this point in the history
For ROCm device, the host side code needs to call GPU_WARP_SIZE_HOST to
query warpsize of the underlying GPU device.
  • Loading branch information
jagadish-amd authored Sep 10, 2024
1 parent eaa26ae commit d7e1c61
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ bool TryMatMul4Bits(
return false;
}
dim3 blocks((n + kColsPerThreadBlock - 1) / kColsPerThreadBlock, m);
dim3 threads(kWarpSize, kColsPerThreadBlock);
dim3 threads(GPU_WARP_SIZE_HOST, kColsPerThreadBlock);
int blocks_per_K = (k + block_size - 1) / block_size;
int shared_mem_size = sizeof(T) * blocks_per_K * kColsPerThreadBlock +
(zero_points != nullptr ? (blocks_per_K + 1) / 2 * kColsPerThreadBlock * 2 : 0);
Expand Down

0 comments on commit d7e1c61

Please sign in to comment.