From f362e6137a2adb213b27a239ed8c467b4fd9c8e7 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Thu, 26 Dec 2024 07:44:50 +0000 Subject: [PATCH] weight packed to [n, k//8] --- src/ATen/native/xpu/WeightInt4Pack.cpp | 22 ++----------------- .../native/xpu/sycl/WeightInt4PackKernel.cpp | 1 - 2 files changed, 2 insertions(+), 21 deletions(-) diff --git a/src/ATen/native/xpu/WeightInt4Pack.cpp b/src/ATen/native/xpu/WeightInt4Pack.cpp index 963c72113..06353a713 100644 --- a/src/ATen/native/xpu/WeightInt4Pack.cpp +++ b/src/ATen/native/xpu/WeightInt4Pack.cpp @@ -3,7 +3,7 @@ namespace at::native { // input is [n][k / 2] (uint8 dtype) -// output is [n / 8][k / (InnerKTiles * 16)][32][innerKTiles / 2] (int32 dtype) +// output is [n][k // 8] Tensor _convert_weight_to_int4pack_xpu(const Tensor& in, int64_t innerKTiles) { TORCH_CHECK(in.dim() == 2, __func__, " : expect weight to be 2D tensor."); TORCH_CHECK( @@ -18,25 +18,7 @@ Tensor _convert_weight_to_int4pack_xpu(const Tensor& in, int64_t innerKTiles) { auto N = weight.size(0); auto K = weight.size(1) * 2; - // Create fake shapes for cpu. The meta registration in dynamo requires - // operator has the same output shape for each device. So creating a fake - // shape {N / 8, K / (16 * innerKTiles), 32, innerKTiles / 2} - constexpr int64_t kNTileSize = 8; - constexpr int64_t kKTileSize = 16; - auto nTiles = (N + kNTileSize - 1) / kNTileSize; - - TORCH_CHECK(N % 16 == 0, __func__, " : expect N to be dividable by 16"); - const int64_t kSuperKTileSize = kKTileSize * innerKTiles; - TORCH_CHECK( - K % kSuperKTileSize == 0, - __func__, - " : epxect K to be dividable by ", - kSuperKTileSize); - auto kSuperTiles = (K + kSuperKTileSize - 1) / kSuperKTileSize; - - auto weight_packed = at::empty( - {nTiles, kSuperTiles, 32, innerKTiles / 2}, - at::TensorOptions().dtype(at::kInt).device(in.device())); + auto weight_packed = at::empty({N, K / 8}, at::TensorOptions().dtype(at::kInt).device(in.device())); xpu::weight_to_int4pack_kernel(weight_packed, weight, N, K); return weight_packed; diff --git a/src/ATen/native/xpu/sycl/WeightInt4PackKernel.cpp b/src/ATen/native/xpu/sycl/WeightInt4PackKernel.cpp index 4b8e59428..6348e5cfb 100644 --- a/src/ATen/native/xpu/sycl/WeightInt4PackKernel.cpp +++ b/src/ATen/native/xpu/sycl/WeightInt4PackKernel.cpp @@ -21,7 +21,6 @@ struct WeightToInt4PackKernelFunctor { vec_t output; #pragma unroll for (int i = 0; i < 4; i++) { - // output[i] = input[3 - i]; output[i] = input[i]; } *reinterpret_cast(&weight_packed_[out_y * K_div_8 + out_x]) = output;