Skip to content

Commit

Permalink
fix the issue of batch gemm
Browse files Browse the repository at this point in the history
  • Loading branch information
Jiaxingla committed Jul 18, 2024
1 parent 962766b commit 7739df6
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 9 deletions.
1 change: 0 additions & 1 deletion examples/sycl/pvc/pvc_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,6 @@ void collective_gemm(int M, int K, int N, int L = 1) {
}

int main() {
auto gmem_size = syclcompat::get_current_device().get_global_mem_size();
collective_gemm<256, 256, 32, 64, 32>(4096, 4096, 4096);
collective_gemm<256, 256, 32, 64, 32>(8192, 8192, 8192);
collective_gemm<256, 256, 32, 64, 32>(1, 5120, 13824);
Expand Down
6 changes: 3 additions & 3 deletions include/cute/atom/copy_traits_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ struct XE_2D_LD_Unpack {
int W = size<1>(traits.tensor)
* sizeof(typename Copy_Traits::CopyInternalType);
auto [y, x, z] = src.data().coord_;
CopyOp::copy(traits.tensor.data() + z * W * H / sizeof(typename Copy_Traits::CopyInternalType), W, H, W, intel::coord_t {x, y},
CopyOp::copy(traits.tensor.data() + z, W, H, W, intel::coord_t {x, y},
&*dst.data());
}

Expand Down Expand Up @@ -99,7 +99,7 @@ struct XE_2D_PF_Unpack {
int H = size<0>(traits.tensor);
int W = size<1>(traits.tensor) * sizeof(T);
auto [y, x, z] = src.data().coord_;
CopyOp::template copy<T>(traits.tensor.data() + z * W * H / sizeof(T), W, H, W,
CopyOp::template copy<T>(traits.tensor.data() + z, W, H, W,
intel::coord_t {static_cast<int>(x), static_cast<int>(y)});
}
};
Expand Down Expand Up @@ -416,7 +416,7 @@ struct XE_2D_ST_Unpack {
* sizeof(typename Copy_Traits::CopyInternalType);
auto [y, x, z] = dst.data().coord_;

CopyOp::copy(traits.tensor.data() + z * W * H / sizeof(typename Copy_Traits::CopyInternalType), W, H, W, intel::coord_t {x, y},
CopyOp::copy(traits.tensor.data() + z, W, H, W, intel::coord_t {x, y},
&*src.data());
}

Expand Down
10 changes: 5 additions & 5 deletions include/cutlass/gemm/kernel/intel_pvc_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ class GemmUniversal<
const auto tile_coord = make_coord(m_coord, n_coord, _, l_coord);
Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor(
make_coord(m_coord, 0, l_coord), make_shape(_1{}, K, _1{}),
make_coord(m_coord, 0, 0), make_shape(_1{}, K, L),
make_stride(Int<FragsM * DpasM>{}, _1{}));
constexpr int version =
is_same_v<typename CollectiveMainloop::GmemTiledCopyB,
Expand All @@ -248,8 +248,8 @@ class GemmUniversal<
: 2;
Tensor tBi = params.mainloop.gmem_tiled_copy_b.get_pvc_tensor(
make_coord(0, n_coord, l_coord),
make_shape(K, Int<FragsN / version>{}, _1{}),
make_coord(0, n_coord, 0),
make_shape(K, Int<FragsN / version>{}, L),
make_stride(_1{}, Int<version * DpasN>{}));
// Compute tile residues for predication
Expand All @@ -271,8 +271,8 @@ class GemmUniversal<
CollectiveMainloop collective_mma;
collective_mma(
accumulators,
tAi(_,_,_,0),
tBi(_,_,_,0),
tAi(_,_,_,l_coord),
tBi(_,_,_,l_coord),
accumulators,
k_tile_iter, k_tile_count,
residue_mnk,
Expand Down

0 comments on commit 7739df6

Please sign in to comment.