From ecb57471d6a405cfb8c2c3dd7abd1f73cfeb80bf Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Thu, 16 May 2024 15:50:54 +0100 Subject: [PATCH 01/19] Epilogue compiling with the pipeline * Next need to integrate EVT --- .../sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp | 31 +- .../collective/collective_epilogue.hpp | 4 + .../collective/intel_pvc_epilogue.hpp | 286 ++++++++++++++++++ include/cutlass/epilogue/dispatch_policy.hpp | 4 + .../epilogue/fusion/intel_pvc_epilogue.hpp | 131 ++++++++ .../cutlass/gemm/kernel/intel_pvc_gemm.hpp | 19 +- 6 files changed, 456 insertions(+), 19 deletions(-) create mode 100644 include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp create mode 100644 include/cutlass/epilogue/fusion/intel_pvc_epilogue.hpp diff --git a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp index 731edfa15f..5a169c48e6 100644 --- a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp +++ b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp @@ -31,6 +31,8 @@ #include "cutlass/gemm/device/gemm.h" #include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/intel_pvc_epilogue.hpp" +#include "cutlass/epilogue/fusion/intel_pvc_epilogue.hpp" #include "cutlass/gemm/device/gemm_universal.h" #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/collective/collective_mma.hpp" @@ -360,26 +362,29 @@ int main(int argc, const char** argv) Layout>, Tile<_32,_64,_32>>; // Subgroup level-tile - using DispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue; - using EpilogueOp = cutlass::epilogue::thread::LinearCombination< - ElementOutput, // <- data type of output matrix - 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized - // memory access. For a byte, it's 16 - // elements. 
This becomes the vector width of - // math instructions in the epilogue too - ElementAccumulator, // <- data type of accumulator - ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + EpilogueShape, + ElementAccumulator, cutlass::gemm::TagToStrideC_t, + ElementOutput, cutlass::gemm::TagToStrideC_t, - EpilogueOp, - cutlass::gemm::EpilogueDefault>; + FusionCallBacks, + XE_2D_U16x16x16x2x1_LD_N, + void, void, + XE_2D_U32x8x16x1x1_ST_N, + void, void>; // Mainloop using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< - DispatchPolicy, + GEMMDispatchPolicy, TileShape, ElementInputA, cutlass::gemm::TagToStrideA_t, diff --git a/include/cutlass/epilogue/collective/collective_epilogue.hpp b/include/cutlass/epilogue/collective/collective_epilogue.hpp index d61f59f729..00ddd37d63 100644 --- a/include/cutlass/epilogue/collective/collective_epilogue.hpp +++ b/include/cutlass/epilogue/collective/collective_epilogue.hpp @@ -56,7 +56,11 @@ class CollectiveEpilogue { #include "default_epilogue.hpp" #include "default_epilogue_array.hpp" #include "epilogue_tensor_broadcast.hpp" +#if defined (SYCL_INTEL_TARGET) +#include "intel_pvc_epilogue.hpp" +#else #include "sm70_epilogue_vectorized.hpp" #include "sm90_epilogue_tma_warpspecialized.hpp" #include "sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp" +#endif ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp new file mode 100644 index 0000000000..f81cbf07ae --- /dev/null +++ b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp @@ -0,0 +1,286 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +// #include "cutlass/arch/barrier.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +// #include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/epilogue/fusion/callbacks.hpp" +// #include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" +#include "cutlass/detail/layout.hpp" +// #include "cutlass/trace.h" + +#include "cute/tensor.hpp" +// #include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class CtaTileMNK_, // (CTA_M,CTA_N,CTA_K) + class EpilogueTile_, // (EPI_TILE_M,EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class FusionCallbacks_, + class CopyOpG2R_, + class SmemLayoutAtomC_, + class CopyOpS2R_, + class CopyOpR2G_, + class SmemLayoutAtomD_, + class CopyOpR2S_ +> +class CollectiveEpilogue< + IntelPVCEpilogue, + CtaTileMNK_, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpG2R_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpR2G_, + SmemLayoutAtomD_, + CopyOpR2S_ +> { +public: + // + // Type Aliases + // + using DispatchPolicy = IntelPVCEpilogue; + using CtaTileMNK = CtaTileMNK_; + using EpilogueTile = EpilogueTile_; + using FusionCallbacks = FusionCallbacks_; + using ElementC = ElementC_; + using ElementAccumulator = ElementC_; + using StrideC = StrideC_; + using ElementD = ElementD_; + using StrideD = StrideD_; + using CopyOpG2R = CopyOpG2R_; + using SmemLayoutAtomC = SmemLayoutAtomC_; + using CopyOpS2R = void; + using CopyOpR2G = CopyOpR2G_; + using SmemLayoutAtomD = void; // SmemLayoutAtomD_; + using CopyOpR2S = void; + + using ThreadEpilogueOp = typename fusion::FusionCallbacksTraits::Operation; + using GmemTiledCopyC = CopyOpG2R; + using GmemTiledCopyD = CopyOpR2G; + using ElementOutput = typename FusionCallbacks::ElementOutput; + using ElementCompute = typename FusionCallbacks::ElementCompute; + + static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); + static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); + static_assert(size<0>(CtaTileMNK{}) % size<0>(shape(EpilogueTile{})) == 0, "EPI_TILE_M must divide CTA_M"); + static_assert(size<1>(CtaTileMNK{}) % size<1>(shape(EpilogueTile{})) == 0, "EPI_TILE_N must divide 
CTA_N"); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); + +private: + constexpr static bool is_source_supported = not cute::is_void_v; + constexpr static bool is_destination_supported = not cute::is_void_v; + + constexpr static bool is_m_major_C = detail::is_m_major(); + constexpr static bool is_m_major_D = detail::is_m_major(); + +public: + + // Host side epilogue arguments + struct Arguments { + typename FusionCallbacks::Arguments thread{}; + ElementC const* ptr_C; + StrideC dC; + ElementD const* ptr_D; + StrideD dD; + }; + + // Device side epilogue params + struct Params { + using XE_Copy_C = decltype(make_xe_2d_copy( + make_tensor(make_gmem_ptr(static_cast(nullptr)), + repeat_like(StrideC{}, int32_t(0)), StrideC{}))); + using XE_Copy_D = decltype(make_xe_2d_copy( + make_tensor(make_gmem_ptr(static_cast(nullptr)), + repeat_like(StrideD{}, int32_t(0)), StrideD{}))); + + typename FusionCallbacks::Params thread{}; + XE_Copy_C xe_load_c; + XE_Copy_D xe_store_d; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + typename Params::XE_Copy_C xe_load_c = {}; + if constexpr (is_source_supported) { + Tensor tensor_c = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M,N,L), args.dC)); + xe_load_c = make_xe_2d_copy(tensor_c); + } + + typename Params::XE_Copy_D xe_store_d = {}; + if constexpr (is_destination_supported) { + Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M,N,L), args.dD)); + xe_store_d = make_xe_2d_copy(tensor_d); + } + + return { + FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), + xe_load_c, + xe_store_d + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return FusionCallbacks::get_workspace_size(problem_shape, args.thread); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter); + } + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + // constexpr int tma_alignment_bits = 128; + // auto problem_shape_MNKL = append<4>(problem_shape, 1); + // auto [M,N,K,L] = problem_shape_MNKL; + + // bool implementable = true; + // if constexpr (is_destination_supported) { + // constexpr int min_tma_aligned_elements_D = tma_alignment_bits / cutlass::sizeof_bits::value; + // implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), StrideD{}); + // } + + // if constexpr (not cute::is_void_v) { + // constexpr int min_tma_aligned_elements_C = tma_alignment_bits / cutlass::sizeof_bits::value; + // implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), StrideC{}); + // } + + // if (!implementable) { + // CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements 
for TMA.\n"); + // } + + // return implementable; + return true; + } + + CUTLASS_HOST_DEVICE + CollectiveEpilogue(Params const& params_) + : params(params_) {} + + // CUTLASS_DEVICE + // bool + // is_producer_load_needed() const { + // return fusion_callbacks.is_producer_load_needed(); + // } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class Accumulator, + class TiledMma, + class ResidueMNK + > + CUTLASS_DEVICE void + operator() ( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + Accumulator accumulators, + TiledMma tiled_mma, + ResidueMNK residue_mnk, + int thread_idx, + char* smem) { + + (void) tiled_mma; + (void) residue_mnk; + (void) thread_idx; + (void) smem; + using namespace cute; + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + + printf("PVC Epilogue\n"); + + } + +private: + Params const& params; + FusionCallbacks fusion_callbacks; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/dispatch_policy.hpp b/include/cutlass/epilogue/dispatch_policy.hpp index 409ff74dd9..f5f88a073d 100644 --- a/include/cutlass/epilogue/dispatch_policy.hpp +++ b/include/cutlass/epilogue/dispatch_policy.hpp @@ -156,6 +156,10 @@ struct Sm90TmaWarpSpecializedBiasElementwise { constexpr static int FragmentSize = FragmentSize_; }; +#if defined (SYCL_INTEL_TARGET) +struct IntelPVCEpilogue {}; +#endif + ////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::epilogue diff --git a/include/cutlass/epilogue/fusion/intel_pvc_epilogue.hpp b/include/cutlass/epilogue/fusion/intel_pvc_epilogue.hpp new file mode 100644 index 0000000000..04859f52a3 --- /dev/null +++ b/include/cutlass/epilogue/fusion/intel_pvc_epilogue.hpp @@ -0,0 +1,131 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Fusion callbacks specializations for the Intel PVC epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +// #include "cutlass/epilogue/fusion/callbacks.hpp" +// #include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" +// #include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +// #include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp" +// #include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp" +// #include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// template +// using Sm90EVT = Sm90TreeVisitor; + +template < + class ElementOutput_, + class ElementCompute_, + class ElementSource_, + class ElementScalar_, + FloatRoundStyle RoundStyle_, + class CtaTileShapeMNK_, + class EpilogueTile_ +> +struct FusionCallbacks< + epilogue::IntelPVCEpilogue, + fusion::LinearCombination, + CtaTileShapeMNK_, + EpilogueTile_, + void, void +> {//: Sm90LinearCombination::type, ElementCompute, ElementSource, ElementScalar, RoundStyle_> { + + // using Impl = Sm90LinearCombination::type, ElementCompute, ElementSource, ElementScalar, RoundStyle_>; + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementSource = ElementSource_; + using ElementScalar = ElementScalar_; + using Operation = fusion::LinearCombination; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + // operator typename Impl::Arguments() const { + // return + // { // ternary op : beta * C + (alpha * acc) + // {{beta}, {beta_ptr}}, // leaf args : beta + // {}, // leaf args : C + // { // binary op : alpha * acc + // {{alpha}, {alpha_ptr}}, // leaf args : alpha + // {}, // leaf args : acc + // {} // binary args : multiplies + // }, // end binary op + // {} // ternary args : multiply_add + // }; // end ternary op + // } + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const&, Arguments const& args, void*) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const&, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + // Ctor inheritance +// using Impl::Impl; +}; + +} 
// namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp index 5e8fce8b4e..f0d4912c3e 100644 --- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp +++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp @@ -214,6 +214,7 @@ class GemmUniversal< const int m_coord = BlockIdxX() * get<0>(subgroup_shape); const int n_coord = BlockIdxY() * get<1>(workgroup_shape) + thread_idx / SubgroupSize * get<1>(subgroup_shape); const int l_coord = BlockIdxZ(); + const auto tile_coord = make_coord(m_coord, n_coord, _, l_coord); Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor( make_coord(m_coord, 0, 0), @@ -253,13 +254,19 @@ class GemmUniversal< smem_buf, params.mainloop ); - auto gmem_tiled_copy_c = make_xe_2d_copy(make_tensor(params.epilogue.ptr_D, make_shape(M, N, L), params.epilogue.dD)); - Tensor tCi = gmem_tiled_copy_c.get_pvc_tensor(make_coord(m_coord, n_coord, 0), - make_shape(Int{}, Int{}, L), - make_stride(get<0>(MmaAtomShape()), get<1>(MmaAtomShape()))); - - copy(gmem_tiled_copy_c, accumulators, tCi(_,_,_,l_coord)); + // copy(gmem_tiled_copy_c, accumulators, tCi(_,_,_,l_coord)); + CollectiveEpilogue epilogue{params.epilogue}; + epilogue( + problem_shape_MNKL, + subgroup_shape, + tile_coord, + accumulators, + tiled_mma, + residue_mnk, + thread_idx, + smem_buf + ); } }; From a3ada688bd053820eac7fe89842988e50069c5cb Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Fri, 17 May 2024 12:08:23 +0100 Subject: [PATCH 02/19] Linking callbacks to EVT --- .../sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp | 2 +- .../collective/intel_pvc_epilogue.hpp | 10 ++-- .../epilogue/fusion/intel_pvc_epilogue.hpp | 50 +++++++++---------- .../sm90_callbacks_tma_warpspecialized.hpp | 2 +- ...90_visitor_compute_tma_warpspecialized.hpp | 2 +- ...sm90_visitor_store_tma_warpspecialized.hpp | 4 +- 6 files changed, 33 insertions(+), 37 deletions(-) diff --git a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp index 5a169c48e6..8549648dc5 100644 --- a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp +++ b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp @@ -367,7 +367,7 @@ int main(int argc, const char** argv) using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; - using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< EpilogueDispatchPolicy, TileShape, diff --git a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp index f81cbf07ae..cccd40848b 100644 --- a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp +++ b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp @@ -232,11 +232,11 @@ class CollectiveEpilogue< CollectiveEpilogue(Params const& params_) : params(params_) {} - // CUTLASS_DEVICE - // bool - // is_producer_load_needed() const { - // return fusion_callbacks.is_producer_load_needed(); - // } + CUTLASS_DEVICE + bool + is_producer_load_needed() const { + return fusion_callbacks.is_producer_load_needed(); + } template< class ProblemShapeMNKL, diff --git a/include/cutlass/epilogue/fusion/intel_pvc_epilogue.hpp b/include/cutlass/epilogue/fusion/intel_pvc_epilogue.hpp index 04859f52a3..6ea803e8f6 100644 --- 
a/include/cutlass/epilogue/fusion/intel_pvc_epilogue.hpp +++ b/include/cutlass/epilogue/fusion/intel_pvc_epilogue.hpp @@ -40,12 +40,12 @@ #include "cute/tensor.hpp" #include "cutlass/epilogue/dispatch_policy.hpp" -// #include "cutlass/epilogue/fusion/callbacks.hpp" -// #include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" -// #include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" -// #include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp" -// #include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp" -// #include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -53,9 +53,6 @@ namespace cutlass::epilogue::fusion { ///////////////////////////////////////////////////////////////////////////////////////////////// -// template -// using Sm90EVT = Sm90TreeVisitor; - template < class ElementOutput_, class ElementCompute_, @@ -69,11 +66,10 @@ struct FusionCallbacks< epilogue::IntelPVCEpilogue, fusion::LinearCombination, CtaTileShapeMNK_, - EpilogueTile_, - void, void -> {//: Sm90LinearCombination::type, ElementCompute, ElementSource, ElementScalar, RoundStyle_> { + EpilogueTile_ +> : Sm90LinearCombination::type, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> { - // using Impl = Sm90LinearCombination::type, ElementCompute, ElementSource, ElementScalar, RoundStyle_>; + using Impl = Sm90LinearCombination::type, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_>; using ElementOutput = ElementOutput_; using ElementCompute = ElementCompute_; using ElementSource = ElementSource_; @@ -86,19 +82,19 @@ struct FusionCallbacks< ElementScalar const* alpha_ptr = nullptr; ElementScalar const* beta_ptr = nullptr; - // operator typename Impl::Arguments() const { - // return - // { // ternary op : beta * C + (alpha * acc) - // {{beta}, {beta_ptr}}, // leaf args : beta - // {}, // leaf args : C - // { // binary op : alpha * acc - // {{alpha}, {alpha_ptr}}, // leaf args : alpha - // {}, // leaf args : acc - // {} // binary args : multiplies - // }, // end binary op - // {} // ternary args : multiply_add - // }; // end ternary op - // } + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }; // end ternary op + } }; using Params = Arguments; @@ -123,7 +119,7 @@ struct FusionCallbacks< } // Ctor inheritance -// using Impl::Impl; + using Impl::Impl; }; } // namespace cutlass::epilogue::fusion diff --git a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp index 6c729c10de..7d41952eac 100644 --- a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp +++ 
b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp @@ -910,7 +910,7 @@ using Sm90ScaledLinCombPerRowBiasEltActAmaxAuxNotFp8 = Sm90EVT, // activation(Z) Sm90EVT, // Aux = Z // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias - Sm90ScaledLinCombPerRowBias, + Sm90ScaledLinCombPerRowBias > > >, diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp index e3160fa132..d0a7b2e90d 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp @@ -526,7 +526,7 @@ struct Sm90TreeVisitor< if (lane_idx == i) { copy_if(FunctionPredTensor(predicate_fn), tC_rAux, tC_gAux); } - __syncwarp(); + syncwarp(); } } } diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp index c8d941b62b..51f619fbd8 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp @@ -781,7 +781,7 @@ struct Sm90RowReduction { // if constexpr (not IsAtomic && FinalReduction) { // Ensure gmem writes are visible to other threads before incrementing counter - __threadfence(); + threadfence(); sync_fn(); // Collective thread 0 increments atomic tile counter and copies value to smem int* prev_tile_count = reinterpret_cast(raw_pointer_cast(smem_buffer.data())); @@ -1255,7 +1255,7 @@ struct Sm90ColReduction { // if constexpr (not IsAtomic && FinalReduction) { // Ensure gmem writes are visible to other threads before incrementing counter - __threadfence(); + threadfence(); sync_fn(); // Collective thread 0 increments atomic tile counter and copies value to smem int* prev_tile_count = reinterpret_cast(raw_pointer_cast(smem_buffer.data())); From b6203f416aed9633203a2d0bd4b0906257c2c4ff Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Mon, 27 May 2024 07:45:30 +0100 Subject: [PATCH 03/19] Implemented Epiloge callbacks * Alpha scaling working, need to fix copy atom for C --- .../sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp | 2 +- include/cute/arch/copy_xe.hpp | 2 +- .../collective/intel_pvc_epilogue.hpp | 124 +++++++++++++++--- include/cutlass/epilogue/dispatch_policy.hpp | 4 +- ...c_epilogue.hpp => intel_pvc_callbacks.hpp} | 21 --- .../cutlass/gemm/kernel/intel_pvc_gemm.hpp | 15 ++- 6 files changed, 126 insertions(+), 42 deletions(-) rename include/cutlass/epilogue/fusion/{intel_pvc_epilogue.hpp => intel_pvc_callbacks.hpp} (88%) diff --git a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp index 8549648dc5..d403e99b0d 100644 --- a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp +++ b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp @@ -32,7 +32,7 @@ #include "cutlass/gemm/device/gemm.h" #include "cutlass/epilogue/collective/default_epilogue.hpp" #include "cutlass/epilogue/collective/intel_pvc_epilogue.hpp" -#include "cutlass/epilogue/fusion/intel_pvc_epilogue.hpp" +#include "cutlass/epilogue/fusion/intel_pvc_callbacks.hpp" #include "cutlass/gemm/device/gemm_universal.h" #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/collective/collective_mma.hpp" diff --git a/include/cute/arch/copy_xe.hpp b/include/cute/arch/copy_xe.hpp index 2646dcae1e..3bfc5c8535 100644 --- a/include/cute/arch/copy_xe.hpp 
+++ b/include/cute/arch/copy_xe.hpp @@ -287,7 +287,7 @@ struct XE_2D_U16x16x16x1x1_V struct XE_2D_U32x8x16x1x1_ST_N { template - CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, const T *src) { #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 4, "Expected T to have size 4"); diff --git a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp index cccd40848b..ef7f449d7d 100644 --- a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp +++ b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp @@ -36,18 +36,15 @@ #include #include "cutlass/cutlass.h" -// #include "cutlass/arch/barrier.h" #include "cutlass/epilogue/dispatch_policy.hpp" #include "cutlass/epilogue/collective/collective_epilogue.hpp" #include "cutlass/epilogue/collective/detail.hpp" -// #include "cutlass/epilogue/thread/scale_type.h" #include "cutlass/epilogue/fusion/callbacks.hpp" -// #include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" #include "cutlass/detail/layout.hpp" -// #include "cutlass/trace.h" + #include "cute/tensor.hpp" -// #include "cutlass/cuda_host_adapter.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -114,6 +111,8 @@ class CollectiveEpilogue< using ElementOutput = typename FusionCallbacks::ElementOutput; using ElementCompute = typename FusionCallbacks::ElementCompute; + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); static_assert(cute::rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); @@ -131,6 +130,38 @@ class CollectiveEpilogue< public: + using EmptyType = cute::tuple<>; + using SmemCStorage = cute::conditional_t, + EmptyType>; + using SmemDStorage = cute::conditional_t, + EmptyType>; + + struct TensorStorageImpl: cute::tuple { + using Base = cute::tuple; + + constexpr decltype(auto) + smem_C() { + return cute::get<0>(static_cast(*this)); + } + + constexpr decltype(auto) + smem_D() { + return cute::get<1>(static_cast(*this)); + } + + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + }; + + struct SharedStorage { + using TensorStorage = TensorStorageImpl; + + TensorStorage tensors; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + // Host side epilogue arguments struct Arguments { typename FusionCallbacks::Arguments thread{}; @@ -143,10 +174,10 @@ class CollectiveEpilogue< // Device side epilogue params struct Params { using XE_Copy_C = decltype(make_xe_2d_copy( - make_tensor(make_gmem_ptr(static_cast(nullptr)), + make_tensor(static_cast(nullptr), repeat_like(StrideC{}, int32_t(0)), StrideC{}))); using XE_Copy_D = decltype(make_xe_2d_copy( - make_tensor(make_gmem_ptr(static_cast(nullptr)), + make_tensor(static_cast(nullptr), repeat_like(StrideD{}, int32_t(0)), StrideD{}))); typename FusionCallbacks::Params thread{}; @@ -170,13 +201,13 @@ class CollectiveEpilogue< typename Params::XE_Copy_C xe_load_c = {}; if constexpr (is_source_supported) { - Tensor tensor_c = make_tensor(make_gmem_ptr(args.ptr_C), 
make_layout(make_shape(M,N,L), args.dC)); + Tensor tensor_c = make_tensor(args.ptr_C, make_layout(make_shape(M,N,L), args.dC)); xe_load_c = make_xe_2d_copy(tensor_c); } typename Params::XE_Copy_D xe_store_d = {}; if constexpr (is_destination_supported) { - Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M,N,L), args.dD)); + Tensor tensor_d = make_tensor(args.ptr_D, make_layout(make_shape(M,N,L), args.dD)); xe_store_d = make_xe_2d_copy(tensor_d); } @@ -190,14 +221,14 @@ class CollectiveEpilogue< template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return FusionCallbacks::get_workspace_size(problem_shape, args.thread); + return 0; } template static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { - return FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter); + return Status::kSuccess; } template @@ -229,8 +260,8 @@ class CollectiveEpilogue< } CUTLASS_HOST_DEVICE - CollectiveEpilogue(Params const& params_) - : params(params_) {} + CollectiveEpilogue(Params const& params_, TensorStorage const& shared_storage_) + : params(params_), fusion_callbacks(params_.thread, shared_storage_.thread) {} CUTLASS_DEVICE bool @@ -259,15 +290,78 @@ class CollectiveEpilogue< (void) tiled_mma; (void) residue_mnk; - (void) thread_idx; (void) smem; using namespace cute; + static constexpr int DpasM = get<0>(shape(typename TiledMma::LayoutA_TV{})); // rows per dpas operation per sub_group for Matrix A + static constexpr int DpasN = get<1>(shape(typename TiledMma::LayoutB_TV{})); // cols per dpas operation per sub_group for Matrix B + + static constexpr int FragsM = get<0>(EpilogueTile{}) / DpasM; // A frags per sub_group + static constexpr int FragsN = get<1>(EpilogueTile{}) / DpasN; // B frags per sub_group + + static constexpr int FragmentSize = (DpasN * DpasM) / SubgroupSize; + // Indexing variables auto [M, N, K, L] = problem_shape_mnkl; auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; - printf("PVC Epilogue\n"); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + Tensor trC = make_tensor(Shape, Int, Int>{}); + Tensor trD = make_tensor(Shape, Int, Int>{}); + Tensor tOuti = params.xe_store_d.get_pvc_tensor(make_coord(m_coord, n_coord, 0), + make_shape(Int{}, Int{}, L), + make_stride(Int{}, Int{})); + + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); + Tensor cD = local_tile(mD_crd, take<0,2>(TileShapeMNK{}), make_coord(m_coord, n_coord)); + // Get the fusion callbacks + constexpr bool RefSrc = true; // Register tensors reference R2S copy src layout + auto residue_mn = make_coord(M, N); + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + problem_shape_mnkl, + TileShapeMNK{}, + tile_coord_mnkl, + residue_mn, + EpilogueTile{}, + params.xe_load_c, + thread_idx, + cD, + cD/*tRS_cD*/, + trC + }; + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + + cst_callbacks.begin(); + + // if (is_C_load_needed) { + // copy(params.xe_load_c, tOuti(_,_,_,l_coord), trC); + // } + + auto acc_frag = recast>(accumulators); + auto c_frag = recast>(trC); + auto trD_frag = recast>(trD); + + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < FragsN; epi_n++) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < FragsM; epi_m++) { + + cst_callbacks.previsit(epi_m, epi_n, 
0, is_C_load_needed); + + auto acc_frag_mn = acc_frag(_, epi_m, epi_n); + auto trD_frag_mn = trD_frag(_, epi_m, epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < FragmentSize; ++epi_v) { + trD_frag_mn(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); + } + } + } + + cst_callbacks.end(); + + copy(params.xe_store_d, trD, tOuti(_,_,_,l_coord)); } diff --git a/include/cutlass/epilogue/dispatch_policy.hpp b/include/cutlass/epilogue/dispatch_policy.hpp index f5f88a073d..e49f94c023 100644 --- a/include/cutlass/epilogue/dispatch_policy.hpp +++ b/include/cutlass/epilogue/dispatch_policy.hpp @@ -157,7 +157,9 @@ struct Sm90TmaWarpSpecializedBiasElementwise { }; #if defined (SYCL_INTEL_TARGET) -struct IntelPVCEpilogue {}; +struct IntelPVCEpilogue { + static constexpr int SubgroupSize = 16; +}; #endif ////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/fusion/intel_pvc_epilogue.hpp b/include/cutlass/epilogue/fusion/intel_pvc_callbacks.hpp similarity index 88% rename from include/cutlass/epilogue/fusion/intel_pvc_epilogue.hpp rename to include/cutlass/epilogue/fusion/intel_pvc_callbacks.hpp index 6ea803e8f6..c0e662b778 100644 --- a/include/cutlass/epilogue/fusion/intel_pvc_epilogue.hpp +++ b/include/cutlass/epilogue/fusion/intel_pvc_callbacks.hpp @@ -97,27 +97,6 @@ struct FusionCallbacks< } }; - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const&, Arguments const& args, void*) { - return args; - } - - template - static size_t - get_workspace_size(ProblemShape const&, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - // Ctor inheritance using Impl::Impl; }; diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp index f0d4912c3e..24cae2dfef 100644 --- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp +++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp @@ -111,6 +111,16 @@ class GemmUniversal< static constexpr int VecC = CollectiveMainloop::VecC; + // Kernel level shared memory storage + struct SharedStorage { + // Mainloop and epilogue don't use smem concurrently since kernel is non-persistent, so we can use a union + union TensorStorage { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + EpilogueTensorStorage epilogue; + } tensors; + }; + // Device side arguments struct Arguments { GemmUniversalMode mode{}; @@ -188,7 +198,7 @@ class GemmUniversal< void operator()(Params const& params, char* smem_buf) { - (void)smem_buf; + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); // Preconditions CUTE_STATIC_ASSERT(is_static::value); @@ -255,8 +265,7 @@ class GemmUniversal< params.mainloop ); - // copy(gmem_tiled_copy_c, accumulators, tCi(_,_,_,l_coord)); - CollectiveEpilogue epilogue{params.epilogue}; + CollectiveEpilogue epilogue{params.epilogue, shared_storage.tensors.epilogue}; epilogue( problem_shape_MNKL, subgroup_shape, From 6fcf0880710e90db4d0971c9ba4c32cd304863d3 Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Tue, 28 May 2024 16:48:48 +0100 Subject: [PATCH 04/19] Fixed load for C, need to address precision issues --- include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp | 7 ++++--- 1 file changed, 4 insertions(+), 
3 deletions(-) diff --git a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp index ef7f449d7d..6bea1e0a2f 100644 --- a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp +++ b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp @@ -334,9 +334,10 @@ class CollectiveEpilogue< cst_callbacks.begin(); - // if (is_C_load_needed) { - // copy(params.xe_load_c, tOuti(_,_,_,l_coord), trC); - // } + if (is_C_load_needed) { + Tensor trC_recast = recast(trC); + copy(params.xe_load_c, tOuti(_,_,_,l_coord), trC_recast); + } auto acc_frag = recast>(accumulators); auto c_frag = recast>(trC); From 8766578c895e34c4294e98140e38591ce7c56b4d Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Thu, 30 May 2024 10:19:23 +0100 Subject: [PATCH 05/19] Epilogue working with EVT * Need to remove register spill * Need to reduce error margin --- examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp | 8 ++++---- .../cutlass/epilogue/collective/intel_pvc_epilogue.hpp | 4 +--- .../fusion/sm90_visitor_load_tma_warpspecialized.hpp | 4 ++++ 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp index d403e99b0d..597ad64a4e 100644 --- a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp +++ b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp @@ -51,7 +51,7 @@ template static void fill_matrix(std::vector &vector) { std::generate(std::begin(vector), std::end(vector), [&] { - return static_cast( (rand() / double(RAND_MAX)) ); + return static_cast(10*rand() / double(RAND_MAX) -1 ); }); } @@ -208,8 +208,8 @@ struct ExampleRunner { // Check if output from CUTLASS kernel and reference kernel are relatively equal or not // need to set a larger error margin for comparison to succeed - auto epsilon = static_cast(0.1f); - auto nonzero_floor = static_cast(0.1f); + auto epsilon = static_cast(0.5f); + auto nonzero_floor = static_cast(0.5f); bool passed = cutlass::reference::device::BlockCompareRelativelyEqual( block_ref_D.get(), block_D.get(), block_D.size(), @@ -377,7 +377,7 @@ int main(int argc, const char** argv) ElementOutput, cutlass::gemm::TagToStrideC_t, FusionCallBacks, - XE_2D_U16x16x16x2x1_LD_N, + XE_2D_U32x8x16x1x1_LD_N, void, void, XE_2D_U32x8x16x1x1_ST_N, void, void>; diff --git a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp index 6bea1e0a2f..a75f317ed5 100644 --- a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp +++ b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp @@ -335,12 +335,10 @@ class CollectiveEpilogue< cst_callbacks.begin(); if (is_C_load_needed) { - Tensor trC_recast = recast(trC); - copy(params.xe_load_c, tOuti(_,_,_,l_coord), trC_recast); + copy(params.xe_load_c, tOuti(_,_,_,l_coord), trC); } auto acc_frag = recast>(accumulators); - auto c_frag = recast>(trC); auto trD_frag = recast>(trD); CUTLASS_PRAGMA_UNROLL diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp index 1ea663f6f0..320472e3c5 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp @@ -116,7 +116,11 @@ struct Sm90SrcFetch : Sm90VisitorImpl<> { template CUTLASS_DEVICE Array visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { +#ifdef SYCL_INTEL_TARGET 
+ return recast>(tCrC)(epi_v, epi_m, epi_n); +#else return recast>(tCrC)(epi_v); +#endif } }; From 5d18844c3e276598221f1d924cea336b55683a86 Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Thu, 30 May 2024 14:46:04 +0100 Subject: [PATCH 06/19] Set SmemType to EmptyType --- .../cutlass/epilogue/collective/intel_pvc_epilogue.hpp | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp index a75f317ed5..89ec3f43f9 100644 --- a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp +++ b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp @@ -131,12 +131,8 @@ class CollectiveEpilogue< public: using EmptyType = cute::tuple<>; - using SmemCStorage = cute::conditional_t, - EmptyType>; - using SmemDStorage = cute::conditional_t, - EmptyType>; + using SmemCStorage = EmptyType; + using SmemDStorage = EmptyType; struct TensorStorageImpl: cute::tuple { using Base = cute::tuple; From d470b5366da9ca96a91b00ad087a523af6434f45 Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Thu, 30 May 2024 14:54:20 +0100 Subject: [PATCH 07/19] Reverted fill_matrix change --- examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp index 597ad64a4e..3e9bb78d60 100644 --- a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp +++ b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp @@ -51,7 +51,7 @@ template static void fill_matrix(std::vector &vector) { std::generate(std::begin(vector), std::end(vector), [&] { - return static_cast(10*rand() / double(RAND_MAX) -1 ); + return static_cast(rand() / double(RAND_MAX) ); }); } From cabd5878b1c6e17fa0a4b0e8f04a4f69acfd5e9e Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Thu, 30 May 2024 15:37:54 +0100 Subject: [PATCH 08/19] WIP --- include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp | 7 ++----- .../include/cutlass/util/reference/device/tensor_compare.h | 1 + 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp index 89ec3f43f9..c03d449fa7 100644 --- a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp +++ b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp @@ -304,7 +304,6 @@ class CollectiveEpilogue< bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); Tensor trC = make_tensor(Shape, Int, Int>{}); - Tensor trD = make_tensor(Shape, Int, Int>{}); Tensor tOuti = params.xe_store_d.get_pvc_tensor(make_coord(m_coord, n_coord, 0), make_shape(Int{}, Int{}, L), make_stride(Int{}, Int{})); @@ -335,7 +334,6 @@ class CollectiveEpilogue< } auto acc_frag = recast>(accumulators); - auto trD_frag = recast>(trD); CUTLASS_PRAGMA_UNROLL for (int epi_n = 0; epi_n < FragsN; epi_n++) { @@ -345,18 +343,17 @@ class CollectiveEpilogue< cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed); auto acc_frag_mn = acc_frag(_, epi_m, epi_n); - auto trD_frag_mn = trD_frag(_, epi_m, epi_n); CUTLASS_PRAGMA_UNROLL for (int epi_v = 0; epi_v < FragmentSize; ++epi_v) { - trD_frag_mn(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); + acc_frag_mn(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); } } } cst_callbacks.end(); - copy(params.xe_store_d, trD, tOuti(_,_,_,l_coord)); + 
copy(params.xe_store_d, accumulators, tOuti(_,_,_,l_coord)); } diff --git a/tools/util/include/cutlass/util/reference/device/tensor_compare.h b/tools/util/include/cutlass/util/reference/device/tensor_compare.h index 3c312f5ff8..9fb68f24ff 100644 --- a/tools/util/include/cutlass/util/reference/device/tensor_compare.h +++ b/tools/util/include/cutlass/util/reference/device/tensor_compare.h @@ -101,6 +101,7 @@ __global__ void Element b = cutlass::ReferenceFactory::get(ptr_B, idx); if (!relatively_equal(a, b, epsilon, nonzero_floor)) { + printf("idx :%lu | a: %f | b: %f\n", idx, a, b); *equal = 0; return; } From fcb4ac6316cb54d0c9d2a6f67cf136fd684f0037 Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Fri, 31 May 2024 15:21:29 +0100 Subject: [PATCH 09/19] Reduce C registers --- .../epilogue/collective/intel_pvc_epilogue.hpp | 13 +++++++------ .../sm90_visitor_load_tma_warpspecialized.hpp | 4 ---- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp index c03d449fa7..b02e6c9ceb 100644 --- a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp +++ b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp @@ -303,11 +303,12 @@ class CollectiveEpilogue< bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); - Tensor trC = make_tensor(Shape, Int, Int>{}); + Tensor trC = make_tensor(Shape>{}); Tensor tOuti = params.xe_store_d.get_pvc_tensor(make_coord(m_coord, n_coord, 0), make_shape(Int{}, Int{}, L), make_stride(Int{}, Int{})); + Tensor rw_coord = tOuti(_,_,_,l_coord); Tensor mD_crd = make_identity_tensor(make_shape(M,N)); Tensor cD = local_tile(mD_crd, take<0,2>(TileShapeMNK{}), make_coord(m_coord, n_coord)); // Get the fusion callbacks @@ -329,10 +330,6 @@ class CollectiveEpilogue< cst_callbacks.begin(); - if (is_C_load_needed) { - copy(params.xe_load_c, tOuti(_,_,_,l_coord), trC); - } - auto acc_frag = recast>(accumulators); CUTLASS_PRAGMA_UNROLL @@ -340,6 +337,10 @@ class CollectiveEpilogue< CUTLASS_PRAGMA_UNROLL for (int epi_m = 0; epi_m < FragsM; epi_m++) { + if (is_C_load_needed) { + copy(params.xe_load_c, rw_coord(_, epi_m * FragsM, epi_n * FragsN), trC); + } + cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed); auto acc_frag_mn = acc_frag(_, epi_m, epi_n); @@ -353,7 +354,7 @@ class CollectiveEpilogue< cst_callbacks.end(); - copy(params.xe_store_d, accumulators, tOuti(_,_,_,l_coord)); + copy(params.xe_store_d, accumulators, rw_coord); } diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp index 320472e3c5..1ea663f6f0 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp @@ -116,11 +116,7 @@ struct Sm90SrcFetch : Sm90VisitorImpl<> { template CUTLASS_DEVICE Array visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { -#ifdef SYCL_INTEL_TARGET - return recast>(tCrC)(epi_v, epi_m, epi_n); -#else return recast>(tCrC)(epi_v); -#endif } }; From 968a40edf6d8fc7f7e7cc12cafad320ab31301c9 Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Fri, 31 May 2024 16:02:54 +0100 Subject: [PATCH 10/19] add D registers --- .../cutlass/epilogue/collective/intel_pvc_epilogue.hpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp 
b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp index b02e6c9ceb..22297c6eba 100644 --- a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp +++ b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp @@ -304,6 +304,7 @@ class CollectiveEpilogue< bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); Tensor trC = make_tensor(Shape>{}); + Tensor trD = make_tensor(Shape>{}); Tensor tOuti = params.xe_store_d.get_pvc_tensor(make_coord(m_coord, n_coord, 0), make_shape(Int{}, Int{}, L), make_stride(Int{}, Int{})); @@ -331,6 +332,7 @@ class CollectiveEpilogue< cst_callbacks.begin(); auto acc_frag = recast>(accumulators); + auto trD_frag = recast>(trD); CUTLASS_PRAGMA_UNROLL for (int epi_n = 0; epi_n < FragsN; epi_n++) { @@ -347,15 +349,15 @@ class CollectiveEpilogue< CUTLASS_PRAGMA_UNROLL for (int epi_v = 0; epi_v < FragmentSize; ++epi_v) { - acc_frag_mn(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); + trD_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); } + + copy(params.xe_store_d, trD, rw_coord(_, epi_m * FragsM, epi_n * FragsN)); } } cst_callbacks.end(); - copy(params.xe_store_d, accumulators, rw_coord); - } private: From 415d7847991bc8a7fc666305be8ac2ca2d239013 Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Fri, 31 May 2024 16:09:02 +0100 Subject: [PATCH 11/19] Fix read_write coordinates --- include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp index 22297c6eba..ba95f2e2ea 100644 --- a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp +++ b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp @@ -340,7 +340,7 @@ class CollectiveEpilogue< for (int epi_m = 0; epi_m < FragsM; epi_m++) { if (is_C_load_needed) { - copy(params.xe_load_c, rw_coord(_, epi_m * FragsM, epi_n * FragsN), trC); + copy(params.xe_load_c, rw_coord(_, epi_m, epi_n), trC); } cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed); @@ -352,7 +352,7 @@ class CollectiveEpilogue< trD_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); } - copy(params.xe_store_d, trD, rw_coord(_, epi_m * FragsM, epi_n * FragsN)); + copy(params.xe_store_d, trD, rw_coord(_, epi_m, epi_n)); } } From 1ee50b565b7fbdae3e362f27d9c5eef18fca2431 Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Fri, 31 May 2024 16:12:12 +0100 Subject: [PATCH 12/19] Remove printf --- .../util/include/cutlass/util/reference/device/tensor_compare.h | 1 - 1 file changed, 1 deletion(-) diff --git a/tools/util/include/cutlass/util/reference/device/tensor_compare.h b/tools/util/include/cutlass/util/reference/device/tensor_compare.h index 9fb68f24ff..3c312f5ff8 100644 --- a/tools/util/include/cutlass/util/reference/device/tensor_compare.h +++ b/tools/util/include/cutlass/util/reference/device/tensor_compare.h @@ -101,7 +101,6 @@ __global__ void Element b = cutlass::ReferenceFactory::get(ptr_B, idx); if (!relatively_equal(a, b, epsilon, nonzero_floor)) { - printf("idx :%lu | a: %f | b: %f\n", idx, a, b); *equal = 0; return; } From ef94bfb36a5b89dbba23d4e7757d4fb72c651c08 Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Fri, 31 May 2024 16:30:14 +0100 Subject: [PATCH 13/19] Remove unnecessary code --- .../collective/intel_pvc_epilogue.hpp | 37 ++++++------------- 1 file changed, 11 insertions(+), 26 deletions(-) diff --git 
a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp index ba95f2e2ea..fa2708cbae 100644 --- a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp +++ b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp @@ -55,8 +55,8 @@ namespace collective { ///////////////////////////////////////////////////////////////////////////////////////////////// template < - class CtaTileMNK_, // (CTA_M,CTA_N,CTA_K) - class EpilogueTile_, // (EPI_TILE_M,EPI_TILE_N) + class CtaTileMNK_, + class EpilogueTile_, class ElementC_, class StrideC_, class ElementD_, @@ -100,10 +100,10 @@ class CollectiveEpilogue< using StrideD = StrideD_; using CopyOpG2R = CopyOpG2R_; using SmemLayoutAtomC = SmemLayoutAtomC_; - using CopyOpS2R = void; + using CopyOpS2R = CopyOpS2R_; using CopyOpR2G = CopyOpR2G_; - using SmemLayoutAtomD = void; // SmemLayoutAtomD_; - using CopyOpR2S = void; + using SmemLayoutAtomD = SmemLayoutAtomD_; + using CopyOpR2S = CopyOpR2S_; using ThreadEpilogueOp = typename fusion::FusionCallbacksTraits::Operation; using GmemTiledCopyC = CopyOpG2R; @@ -121,6 +121,11 @@ class CollectiveEpilogue< static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); + static_assert(std::is_same_v, "Intel PVC does not support shared memory"); + static_assert(std::is_same_v, "Intel PVC does not support shared memory"); + static_assert(std::is_same_v, "Intel PVC does not support shared memory"); + static_assert(std::is_same_v, "Intel PVC does not support shared memory"); + private: constexpr static bool is_source_supported = not cute::is_void_v; constexpr static bool is_destination_supported = not cute::is_void_v; @@ -232,26 +237,6 @@ class CollectiveEpilogue< can_implement( ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) { - // constexpr int tma_alignment_bits = 128; - // auto problem_shape_MNKL = append<4>(problem_shape, 1); - // auto [M,N,K,L] = problem_shape_MNKL; - - // bool implementable = true; - // if constexpr (is_destination_supported) { - // constexpr int min_tma_aligned_elements_D = tma_alignment_bits / cutlass::sizeof_bits::value; - // implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), StrideD{}); - // } - - // if constexpr (not cute::is_void_v) { - // constexpr int min_tma_aligned_elements_C = tma_alignment_bits / cutlass::sizeof_bits::value; - // implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), StrideC{}); - // } - - // if (!implementable) { - // CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); - // } - - // return implementable; return true; } @@ -324,7 +309,7 @@ class CollectiveEpilogue< params.xe_load_c, thread_idx, cD, - cD/*tRS_cD*/, + cD, trC }; auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); From 1a8513263becd6ac9b09c473b54c976b7596a6d5 Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Mon, 3 Jun 2024 10:02:40 +0100 Subject: [PATCH 14/19] Address PR feedback v1 --- examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp | 2 +- include/cutlass/gemm/kernel/intel_pvc_gemm.hpp | 10 +++------- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp index 3e9bb78d60..ee34f826f8 100644 --- 
a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp +++ b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp @@ -51,7 +51,7 @@ template static void fill_matrix(std::vector &vector) { std::generate(std::begin(vector), std::end(vector), [&] { - return static_cast(rand() / double(RAND_MAX) ); + return static_cast( 2 * (rand() / double(RAND_MAX)) -1 ); }); } diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp index 24cae2dfef..6e7aee895b 100644 --- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp +++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp @@ -113,12 +113,8 @@ class GemmUniversal< // Kernel level shared memory storage struct SharedStorage { - // Mainloop and epilogue don't use smem concurrently since kernel is non-persistent, so we can use a union - union TensorStorage { - using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; - - EpilogueTensorStorage epilogue; - } tensors; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + EpilogueTensorStorage epilogue; }; // Device side arguments @@ -265,7 +261,7 @@ class GemmUniversal< params.mainloop ); - CollectiveEpilogue epilogue{params.epilogue, shared_storage.tensors.epilogue}; + CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue}; epilogue( problem_shape_MNKL, subgroup_shape, From bf6e3edcd5baaf04ca7f70a742f2d9a205cacc05 Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Tue, 4 Jun 2024 10:59:04 +0100 Subject: [PATCH 15/19] Address PR feedback v2 --- examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp | 4 ++-- .../cutlass/epilogue/collective/intel_pvc_epilogue.hpp | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp index ee34f826f8..cf1d196479 100644 --- a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp +++ b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp @@ -208,8 +208,8 @@ struct ExampleRunner { // Check if output from CUTLASS kernel and reference kernel are relatively equal or not // need to set a larger error margin for comparison to succeed - auto epsilon = static_cast(0.5f); - auto nonzero_floor = static_cast(0.5f); + auto epsilon = static_cast(0.1f); + auto nonzero_floor = static_cast(0.1f); bool passed = cutlass::reference::device::BlockCompareRelativelyEqual( block_ref_D.get(), block_D.get(), block_D.size(), diff --git a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp index fa2708cbae..cacd4a5b7a 100644 --- a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp +++ b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp @@ -121,10 +121,10 @@ class CollectiveEpilogue< static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); - static_assert(std::is_same_v, "Intel PVC does not support shared memory"); - static_assert(std::is_same_v, "Intel PVC does not support shared memory"); - static_assert(std::is_same_v, "Intel PVC does not support shared memory"); - static_assert(std::is_same_v, "Intel PVC does not support shared memory"); + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, "Copy 
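// Simplified stand-in for the relative comparison used by the verification in the
// patch above (cutlass::relatively_equal); the real implementation may differ in
// detail. It shows the role of the two loosened constants: values under
// nonzero_floor are compared absolutely, everything else relative to the
// magnitudes involved, so an epsilon of 0.1f still absorbs bf16 accumulation error.
#include <cmath>
#include <cstdio>

bool relatively_equal_sketch(float a, float b, float epsilon, float nonzero_floor) {
  const float abs_a = std::fabs(a), abs_b = std::fabs(b);
  if (abs_a < nonzero_floor && abs_b < nonzero_floor)
    return std::fabs(a - b) < epsilon;                   // near zero: absolute test
  return std::fabs(a - b) < epsilon * (abs_a + abs_b);   // otherwise: relative test
}

int main() {
  std::printf("%d\n", relatively_equal_sketch(100.0f, 101.5f, 0.1f, 0.1f));  // prints 1
  return 0;
}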
operation to shared memory is not supported"); private: constexpr static bool is_source_supported = not cute::is_void_v; From 74720da315ab28286788d383f0d03faad48ab8a8 Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Wed, 5 Jun 2024 16:21:39 +0100 Subject: [PATCH 16/19] Removed unnecessary code --- .../epilogue/collective/intel_pvc_epilogue.hpp | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp index cacd4a5b7a..71569fcff2 100644 --- a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp +++ b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp @@ -140,18 +140,6 @@ class CollectiveEpilogue< using SmemDStorage = EmptyType; struct TensorStorageImpl: cute::tuple { - using Base = cute::tuple; - - constexpr decltype(auto) - smem_C() { - return cute::get<0>(static_cast(*this)); - } - - constexpr decltype(auto) - smem_D() { - return cute::get<1>(static_cast(*this)); - } - using FusionStorage = typename FusionCallbacks::SharedStorage; FusionStorage thread; }; @@ -298,7 +286,7 @@ class CollectiveEpilogue< Tensor mD_crd = make_identity_tensor(make_shape(M,N)); Tensor cD = local_tile(mD_crd, take<0,2>(TileShapeMNK{}), make_coord(m_coord, n_coord)); // Get the fusion callbacks - constexpr bool RefSrc = true; // Register tensors reference R2S copy src layout + constexpr bool RefSrc = true; auto residue_mn = make_coord(M, N); auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ problem_shape_mnkl, From d04b3369eaf8b3c634b61a4d16f3a171b9c215b7 Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Fri, 14 Jun 2024 16:11:33 +0100 Subject: [PATCH 17/19] Update dpas* variables --- .../cutlass/epilogue/collective/intel_pvc_epilogue.hpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp index 71569fcff2..42cc0fd52f 100644 --- a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp +++ b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp @@ -116,8 +116,8 @@ class CollectiveEpilogue< static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); static_assert(cute::rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); - static_assert(size<0>(CtaTileMNK{}) % size<0>(shape(EpilogueTile{})) == 0, "EPI_TILE_M must divide CTA_M"); - static_assert(size<1>(CtaTileMNK{}) % size<1>(shape(EpilogueTile{})) == 0, "EPI_TILE_N must divide CTA_N"); + //static_assert(size<0>(CtaTileMNK{}) % size<0>(shape(EpilogueTile{})) == 0, "EPI_TILE_M must divide CTA_M"); + //static_assert(size<1>(CtaTileMNK{}) % size<1>(shape(EpilogueTile{})) == 0, "EPI_TILE_N must divide CTA_N"); static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); @@ -262,8 +262,10 @@ class CollectiveEpilogue< (void) smem; using namespace cute; - static constexpr int DpasM = get<0>(shape(typename TiledMma::LayoutA_TV{})); // rows per dpas operation per sub_group for Matrix A - static constexpr int DpasN = get<1>(shape(typename TiledMma::LayoutB_TV{})); // cols per dpas operation per sub_group for Matrix B + using DpasShape = typename TiledMma::Shape_MNK; + + 
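// Minimal sketch of the simplified storage in the patch above: with the
// shared-memory C/D staging accessors gone, the epilogue's TensorStorage reduces
// to whatever the fusion callbacks request. FusionStorageT stands in for
// FusionCallbacks::SharedStorage; the concrete types below are placeholders.
template <class FusionStorageT>
struct TensorStorageSketch {
  FusionStorageT thread;   // only the fusion callbacks' storage remains
};

struct EmptyFusionStorage {};   // e.g. a stateless linear-combination fusion

static_assert(sizeof(TensorStorageSketch<EmptyFusionStorage>) == sizeof(EmptyFusionStorage),
              "no extra shared-memory buffers are carried along");

int main() { return 0; }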
static constexpr int DpasM = get<0>(DpasShape()); // rows per dpas operation per sub_group for Matrix A + static constexpr int DpasN = get<1>(DpasShape()); // cols per dpas operation per sub_group for Matrix B static constexpr int FragsM = get<0>(EpilogueTile{}) / DpasM; // A frags per sub_group static constexpr int FragsN = get<1>(EpilogueTile{}) / DpasN; // B frags per sub_group From 8d1f87bb7cf7adaa24fcc62882128979271e82b4 Mon Sep 17 00:00:00 2001 From: Alejandro Acosta Date: Wed, 19 Jun 2024 14:24:43 +0100 Subject: [PATCH 18/19] rebase --- examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp index cf1d196479..9fe7f4f4ec 100644 --- a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp +++ b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp @@ -51,7 +51,7 @@ template static void fill_matrix(std::vector &vector) { std::generate(std::begin(vector), std::end(vector), [&] { - return static_cast( 2 * (rand() / double(RAND_MAX)) -1 ); + return static_cast( (rand() / double(RAND_MAX)) ); }); } @@ -357,6 +357,7 @@ int main(int argc, const char** argv) // Workgroup-level tile using TileShape = Shape<_32, _256, _32>; + using EpilogueShape = Shape<_32, _64>; using TiledMma = TiledMMA, Layout>, From c3a3074d1957bea8e2784764fecc7b9803c63de0 Mon Sep 17 00:00:00 2001 From: Alejandro Acosta Date: Thu, 20 Jun 2024 16:15:51 +0100 Subject: [PATCH 19/19] Use subgroup tile information from MMATile --- .../sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp | 8 ++--- .../collective/intel_pvc_epilogue.hpp | 30 +++++++------------ 2 files changed, 15 insertions(+), 23 deletions(-) diff --git a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp index 9fe7f4f4ec..204214da11 100644 --- a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp +++ b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp @@ -357,7 +357,6 @@ int main(int argc, const char** argv) // Workgroup-level tile using TileShape = Shape<_32, _256, _32>; - using EpilogueShape = Shape<_32, _64>; using TiledMma = TiledMMA, Layout>, @@ -366,13 +365,14 @@ int main(int argc, const char** argv) using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated; using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue; - using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; - using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< EpilogueDispatchPolicy, TileShape, - EpilogueShape, ElementAccumulator, cutlass::gemm::TagToStrideC_t, ElementOutput, diff --git a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp index 42cc0fd52f..4d33308653 100644 --- a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp +++ b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp @@ -56,7 +56,6 @@ namespace collective { template < class CtaTileMNK_, - class EpilogueTile_, class ElementC_, class StrideC_, class ElementD_, @@ -72,7 +71,6 @@ template < class CollectiveEpilogue< IntelPVCEpilogue, CtaTileMNK_, - EpilogueTile_, ElementC_, StrideC_, ElementD_, @@ -91,7 +89,6 @@ class CollectiveEpilogue< // using DispatchPolicy = IntelPVCEpilogue; using CtaTileMNK = CtaTileMNK_; - 
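// Worked example of the fragment bookkeeping in the patch above, with plain
// constexpr ints standing in for the CuTe shapes. The 8x16 MMA-atom output tile
// and the 16-lane subgroup are assumed values for illustration; the 32x64
// per-subgroup tile matches the example's EpilogueShape of Shape<_32, _64>.
constexpr int SubgroupSize = 16;             // work-items per subgroup (assumed)
constexpr int AtomM = 8, AtomN = 16;         // C tile produced by one DPAS atom (assumed)
constexpr int SgTileM = 32, SgTileN = 64;    // output tile owned by one subgroup

constexpr int FragsM = SgTileM / AtomM;                       // 4 fragments along M
constexpr int FragsN = SgTileN / AtomN;                       // 4 fragments along N
constexpr int FragmentSize = (AtomM * AtomN) / SubgroupSize;  // 8 values per work-item

static_assert(SgTileM % AtomM == 0 && SgTileN % AtomN == 0,
              "the subgroup tile must be a whole number of MMA atoms");
static_assert(FragsM == 4 && FragsN == 4 && FragmentSize == 8, "sanity check");

int main() { return 0; }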
using EpilogueTile = EpilogueTile_; using FusionCallbacks = FusionCallbacks_; using ElementC = ElementC_; using ElementAccumulator = ElementC_; @@ -113,11 +110,7 @@ class CollectiveEpilogue< static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; - static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); - static_assert(cute::rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); - //static_assert(size<0>(CtaTileMNK{}) % size<0>(shape(EpilogueTile{})) == 0, "EPI_TILE_M must divide CTA_M"); - //static_assert(size<1>(CtaTileMNK{}) % size<1>(shape(EpilogueTile{})) == 0, "EPI_TILE_N must divide CTA_N"); static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); @@ -262,15 +255,13 @@ class CollectiveEpilogue< (void) smem; using namespace cute; - using DpasShape = typename TiledMma::Shape_MNK; + using MmaAtomShape = typename TiledMma::AtomShape_MNK; + using SubgroupTileShape = decltype(tile_shape(TiledMma())); - static constexpr int DpasM = get<0>(DpasShape()); // rows per dpas operation per sub_group for Matrix A - static constexpr int DpasN = get<1>(DpasShape()); // cols per dpas operation per sub_group for Matrix B + static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group + static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group - static constexpr int FragsM = get<0>(EpilogueTile{}) / DpasM; // A frags per sub_group - static constexpr int FragsN = get<1>(EpilogueTile{}) / DpasN; // B frags per sub_group - - static constexpr int FragmentSize = (DpasN * DpasM) / SubgroupSize; + static constexpr int FragmentSize = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; // Indexing variables auto [M, N, K, L] = problem_shape_mnkl; @@ -280,13 +271,14 @@ class CollectiveEpilogue< Tensor trC = make_tensor(Shape>{}); Tensor trD = make_tensor(Shape>{}); - Tensor tOuti = params.xe_store_d.get_pvc_tensor(make_coord(m_coord, n_coord, 0), - make_shape(Int{}, Int{}, L), - make_stride(Int{}, Int{})); + Tensor tOuti = params.xe_store_d.get_pvc_tensor( + make_coord(m_coord, n_coord, 0), + make_shape(Int{}, Int{}, L), + make_stride(Int(MmaAtomShape{})>{}, Int(MmaAtomShape{})>{})); Tensor rw_coord = tOuti(_,_,_,l_coord); Tensor mD_crd = make_identity_tensor(make_shape(M,N)); - Tensor cD = local_tile(mD_crd, take<0,2>(TileShapeMNK{}), make_coord(m_coord, n_coord)); + Tensor cD = local_tile(mD_crd, take<0,2>(SubgroupTileShape{}), make_coord(m_coord, n_coord)); // Get the fusion callbacks constexpr bool RefSrc = true; auto residue_mn = make_coord(M, N); @@ -295,7 +287,7 @@ class CollectiveEpilogue< TileShapeMNK{}, tile_coord_mnkl, residue_mn, - EpilogueTile{}, + SubgroupTileShape{}, params.xe_load_c, thread_idx, cD,
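// Sketch of the coordinate mapping the per-fragment store relies on: a subgroup
// at block coordinate (m_coord, n_coord) owns an SgTileM x SgTileN tile of D, and
// fragment (epi_m, epi_n) inside it starts AtomM / AtomN elements further along
// each mode. Tile and atom sizes reuse the assumed values from the previous sketch.
#include <cstdio>

constexpr int AtomM = 8, AtomN = 16, SgTileM = 32, SgTileN = 64;

void fragment_origin(int m_coord, int n_coord, int epi_m, int epi_n,
                     int& row, int& col) {
  row = m_coord * SgTileM + epi_m * AtomM;   // global row of the fragment's first element
  col = n_coord * SgTileN + epi_n * AtomN;   // global column of the fragment's first element
}

int main() {
  int row = 0, col = 0;
  fragment_origin(/*m_coord=*/1, /*n_coord=*/2, /*epi_m=*/3, /*epi_n=*/1, row, col);
  std::printf("fragment origin: (%d, %d)\n", row, col);   // (56, 144)
  return 0;
}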