From d27022390c4179ca8dacd1b8fdae42e89a44f7f9 Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Mon, 15 Jul 2024 17:26:43 +0100 Subject: [PATCH] Add Epilogue Pipeline for PVC using EVT (#80) This PR introduces the Epilogue implementation for PVC using the Epilogue Visitor Tree available for SM90 (and onwards) GPUs for NVIDIA. We only support fusion::LinearCombination operation for PVC i.e. D = alpha * A * B + beta * C through this PR, but it can be extended further to add other fusion operations by partial specialization of the FusionCallBacks struct available in the include/cutlass/epilogue/fusion/intel_pvc_callbacks.hpp file. --------- Co-authored-by: Alejandro Acosta --- include/cute/arch/copy_xe.hpp | 2 +- .../cutlass/epilogue/collective/collective_epilogue.hpp | 4 ++++ include/cutlass/epilogue/dispatch_policy.hpp | 6 ++++++ .../fusion/sm90_callbacks_tma_warpspecialized.hpp | 2 +- .../fusion/sm90_visitor_compute_tma_warpspecialized.hpp | 2 +- .../fusion/sm90_visitor_store_tma_warpspecialized.hpp | 4 ++-- include/cutlass/gemm/kernel/intel_pvc_gemm.hpp | 9 ++++++++- 7 files changed, 23 insertions(+), 6 deletions(-) diff --git a/include/cute/arch/copy_xe.hpp b/include/cute/arch/copy_xe.hpp index 96e83fed41..51200dd08b 100644 --- a/include/cute/arch/copy_xe.hpp +++ b/include/cute/arch/copy_xe.hpp @@ -346,7 +346,7 @@ struct XE_2D_U16x16x16x1x1_V struct XE_2D_U32x8x16x1x1_ST_N { template - CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, const T *src) { static_assert(sizeof(T) == 4, "Expected T to have size 4"); __builtin_IB_subgroup_block_write_flat_u32_m8k16v1( diff --git a/include/cutlass/epilogue/collective/collective_epilogue.hpp b/include/cutlass/epilogue/collective/collective_epilogue.hpp index d61f59f729..00ddd37d63 100644 --- a/include/cutlass/epilogue/collective/collective_epilogue.hpp +++ b/include/cutlass/epilogue/collective/collective_epilogue.hpp @@ -56,7 +56,11 @@ class CollectiveEpilogue { #include "default_epilogue.hpp" #include "default_epilogue_array.hpp" #include "epilogue_tensor_broadcast.hpp" +#if defined (SYCL_INTEL_TARGET) +#include "intel_pvc_epilogue.hpp" +#else #include "sm70_epilogue_vectorized.hpp" #include "sm90_epilogue_tma_warpspecialized.hpp" #include "sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp" +#endif ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/dispatch_policy.hpp b/include/cutlass/epilogue/dispatch_policy.hpp index 409ff74dd9..e49f94c023 100644 --- a/include/cutlass/epilogue/dispatch_policy.hpp +++ b/include/cutlass/epilogue/dispatch_policy.hpp @@ -156,6 +156,12 @@ struct Sm90TmaWarpSpecializedBiasElementwise { constexpr static int FragmentSize = FragmentSize_; }; +#if defined (SYCL_INTEL_TARGET) +struct IntelPVCEpilogue { + static constexpr int SubgroupSize = 16; +}; +#endif + ////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::epilogue diff --git a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp index 6c729c10de..7d41952eac 100644 --- a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp @@ -910,7 +910,7 @@ using Sm90ScaledLinCombPerRowBiasEltActAmaxAuxNotFp8 = Sm90EVT, // activation(Z) Sm90EVT, // Aux = Z // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias - Sm90ScaledLinCombPerRowBias, + Sm90ScaledLinCombPerRowBias > > >, diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp index e3160fa132..d0a7b2e90d 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp @@ -526,7 +526,7 @@ struct Sm90TreeVisitor< if (lane_idx == i) { copy_if(FunctionPredTensor(predicate_fn), tC_rAux, tC_gAux); } - __syncwarp(); + syncwarp(); } } } diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp index c8d941b62b..51f619fbd8 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp @@ -781,7 +781,7 @@ struct Sm90RowReduction { // if constexpr (not IsAtomic && FinalReduction) { // Ensure gmem writes are visible to other threads before incrementing counter - __threadfence(); + threadfence(); sync_fn(); // Collective thread 0 increments atomic tile counter and copies value to smem int* prev_tile_count = reinterpret_cast(raw_pointer_cast(smem_buffer.data())); @@ -1255,7 +1255,7 @@ struct Sm90ColReduction { // if constexpr (not IsAtomic && FinalReduction) { // Ensure gmem writes are visible to other threads before incrementing counter - __threadfence(); + threadfence(); sync_fn(); // Collective thread 0 increments atomic tile counter and copies value to smem int* prev_tile_count = reinterpret_cast(raw_pointer_cast(smem_buffer.data())); diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp index 2df3d83cda..d63f577162 100644 --- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp +++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp @@ -119,6 +119,12 @@ class GemmUniversal< static constexpr int VecC = CollectiveMainloop::VecC; + // Kernel level shared memory storage + struct SharedStorage { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + EpilogueTensorStorage epilogue; + }; + // Device side arguments struct Arguments { GemmUniversalMode mode{}; @@ -197,7 +203,7 @@ class GemmUniversal< CUTLASS_DEVICE void operator()(Params const& params, char* smem_buf) { - (void)smem_buf; + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); // Preconditions CUTE_STATIC_ASSERT(is_static::value); @@ -231,6 +237,7 @@ class GemmUniversal< BlockIdxX() * CollectiveMainloop::wg_tile_n + (get_sub_group_id() % sg_per_wg_n) * CollectiveMainloop::sg_tile_n; const int l_coord = BlockIdxZ(); + const auto tile_coord = make_coord(m_coord, n_coord, _, l_coord); Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor( make_coord(m_coord, 0, l_coord), make_shape(_1{}, K, _1{}),