From fae9b2e19f799ba7ffbadec70d3ce9bd1e2701ed Mon Sep 17 00:00:00 2001 From: Alejandro Acosta Date: Tue, 16 Jul 2024 13:44:45 +0100 Subject: [PATCH] Epilogue with RELU (#86) --- examples/sycl/pvc/CMakeLists.txt | 9 +- ...bfloat_dpas_gemm_cute.cpp => pvc_gemm.cpp} | 1 - .../sycl/pvc/pvc_gemm_with_epilogue_relu.cpp | 420 ++++++++++++++++++ .../collective/intel_pvc_epilogue.hpp | 4 +- .../epilogue/fusion/intel_pvc_callbacks.hpp | 56 +++ .../util/reference/device/tensor_relu.h | 9 + 6 files changed, 493 insertions(+), 6 deletions(-) rename examples/sycl/pvc/{pvc_bfloat_dpas_gemm_cute.cpp => pvc_gemm.cpp} (99%) create mode 100644 examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp diff --git a/examples/sycl/pvc/CMakeLists.txt b/examples/sycl/pvc/CMakeLists.txt index 3ac67d2319..2911959327 100644 --- a/examples/sycl/pvc/CMakeLists.txt +++ b/examples/sycl/pvc/CMakeLists.txt @@ -28,6 +28,11 @@ cutlass_example_add_executable( - pvc_bfloat_dpas_gemm_cute - pvc_bfloat_dpas_gemm_cute.cpp + pvc_gemm + pvc_gemm.cpp +) + +cutlass_example_add_executable( + pvc_gemm_with_epilogue_relu + pvc_gemm_with_epilogue_relu.cpp ) diff --git a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp b/examples/sycl/pvc/pvc_gemm.cpp similarity index 99% rename from examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp rename to examples/sycl/pvc/pvc_gemm.cpp index f19c7b2165..5141a084cd 100644 --- a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp +++ b/examples/sycl/pvc/pvc_gemm.cpp @@ -31,7 +31,6 @@ #define CUTLASS_SYCLCOMPAT_PROFILING_ENABLED -#include "cutlass/gemm/device/gemm.h" #include "cutlass/epilogue/collective/default_epilogue.hpp" #include "cutlass/epilogue/collective/intel_pvc_epilogue.hpp" #include "cutlass/epilogue/fusion/intel_pvc_callbacks.hpp" diff --git a/examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp b/examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp new file mode 100644 index 0000000000..2075379580 --- /dev/null +++ b/examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp @@ -0,0 +1,420 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/intel_pvc_epilogue.hpp" +#include "cutlass/epilogue/fusion/intel_pvc_callbacks.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_relu.h" +#include "cutlass/tensor_view.h" +#include "cutlass/coord.h" + +template +static void fill_matrix(std::vector &vector) +{ + std::generate(std::begin(vector), std::end(vector), [&] { + return static_cast( (rand() / double(RAND_MAX)) ); + }); +} + +template +static void vnni_matrix( + T* dst, const T* src, + int batch, int numRows, int numCols, int factor) +{ + for (int b = 0; b < batch; b++) { + for (int r = 0; r < numRows / factor; r++) { + for (int c = 0; c < numCols; c++) { + for (int k = 0; k < factor; k++) { + dst[((b * (numRows / factor) + r) * numCols + c) * factor + k] = + src[((b * (numRows / factor) + r) * factor + k) * numCols + c]; + } + } + } + } +} + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + + int m, n, k, l, iterations; + float alpha, beta; + + Options(): + help(false), + error(false), + m(4096), n(4096), k(4096), l(1), iterations(100), + alpha(1.f), beta(0.f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 4096); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations, 100); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "PVC GEMM Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class Gemm +> +struct ExampleRunner { + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_B_vnni; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha, + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta, + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + syclcompat::wait(); + + using TensorView = cutlass::TensorView; + cutlass::reference::device::TensorReLu(TensorView(block_ref_D.get(), LayoutD::packed({M, N}), + cutlass::make_Coord(M, N))); + + syclcompat::wait(); + + // Check if output from CUTLASS kernel and reference kernel are relatively equal or not + // need to set a larger error margin for comparison to succeed + auto epsilon = static_cast(0.1f); + auto nonzero_floor = static_cast(0.1f); + + bool passed = cutlass::reference::device::BlockCompareRelativelyEqual( + block_ref_D.get(), block_D.get(), block_D.size(), + epsilon, nonzero_floor); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(M * K * L); + block_B.reset(K * N * L); + block_B_vnni.reset(K * N * L); + block_C.reset(M * N * L); + block_D.reset(M * N * L); + block_ref_D.reset(M * N * L); + + // TODO: Enable initialization on device directly once RNG is + // available through SYCL. + std::vector a(K * M * L); + std::vector b(K * N * L); + std::vector b_vnni(b.size()); + std::vector c(M * N * L); + std::vector d(M * N * L, ElementC{0}); + + fill_matrix(a); + fill_matrix(b); + fill_matrix(c); + vnni_matrix(b_vnni.data(), b.data(), L, K, N, 2); + + syclcompat::memcpy(block_A.get(), a.data(), a.size() * sizeof(ElementA)); + syclcompat::memcpy(block_B.get(), b.data(), b.size() * sizeof(ElementB)); + syclcompat::memcpy(block_B_vnni.get(), b_vnni.data(), b.size() * sizeof(ElementB)); + syclcompat::memcpy(block_C.get(), c.data(), c.size() * sizeof(ElementC)); + syclcompat::memcpy(block_D.get(), d.data(), d.size() * sizeof(ElementC)); + } + + void run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B_vnni.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + gemm_op.can_implement(arguments); + + gemm_op.initialize(arguments, workspace.get()); + + // Run the GEMM + gemm_op.run(); + + syclcompat::wait(); + + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if (passed && options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + gemm_op.run(); + } + syclcompat::wait(); + + float cute_time = timer.seconds() / options.iterations; + double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); + } + + return; + } + +}; + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool passed; + + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. + using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A + using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N; + using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N; + + // Workgroup-level tile + using TileShape = Shape<_32, _256, _32>; + + using TiledMma = TiledMMA, + Layout>, + Tile<_32,_64,_32>>; // Subgroup level-tile + + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue; + + using EpilogueOp = cutlass::epilogue::fusion::LinCombEltAct; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16x1x1_LD_N, + void, void, + XE_2D_U32x8x16x1x1_ST_N, + void, void>; + +// Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementInputA, + cutlass::gemm::TagToStrideA_t, + ElementInputB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + runner.run(options, hw_info); + + return 0; +} diff --git a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp index 4d33308653..f6e6f5c790 100644 --- a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp +++ b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp @@ -43,7 +43,6 @@ #include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" #include "cutlass/detail/layout.hpp" - #include "cute/tensor.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -55,7 +54,7 @@ namespace collective { ///////////////////////////////////////////////////////////////////////////////////////////////// template < - class CtaTileMNK_, + class CtaTileMNK_, class ElementC_, class StrideC_, class ElementD_, @@ -318,7 +317,6 @@ class CollectiveEpilogue< for (int epi_v = 0; epi_v < FragmentSize; ++epi_v) { trD_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); } - copy(params.xe_store_d, trD, rw_coord(_, epi_m, epi_n)); } } diff --git a/include/cutlass/epilogue/fusion/intel_pvc_callbacks.hpp b/include/cutlass/epilogue/fusion/intel_pvc_callbacks.hpp index c0e662b778..76fa811725 100644 --- a/include/cutlass/epilogue/fusion/intel_pvc_callbacks.hpp +++ b/include/cutlass/epilogue/fusion/intel_pvc_callbacks.hpp @@ -101,6 +101,62 @@ struct FusionCallbacks< using Impl::Impl; }; + +template < + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementSource_, + class ElementScalar_, + FloatRoundStyle RoundStyle_, + class CtaTileShapeMNK_, + class EpilogueTile_ +> +struct FusionCallbacks< + epilogue::IntelPVCEpilogue, + fusion::LinCombEltAct, + CtaTileShapeMNK_, + EpilogueTile_ +> : Sm90LinCombEltAct { + + using Impl = Sm90LinCombEltAct::type, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_>; + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementSource = ElementSource_; + using ElementScalar = ElementScalar_; + using Operation = fusion::LinCombEltAct; + + struct Arguments { + ElementScalar_ alpha = ElementScalar_(1); + ElementScalar_ beta = ElementScalar_(0); + ElementScalar_ const* alpha_ptr = nullptr; + ElementScalar_ const* beta_ptr = nullptr; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // unary op: activation(beta * C + (alpha * acc)) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args: activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + } // namespace cutlass::epilogue::fusion ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/util/include/cutlass/util/reference/device/tensor_relu.h b/tools/util/include/cutlass/util/reference/device/tensor_relu.h index 4e5a50403c..3cf499d951 100644 --- a/tools/util/include/cutlass/util/reference/device/tensor_relu.h +++ b/tools/util/include/cutlass/util/reference/device/tensor_relu.h @@ -139,3 +139,12 @@ void TensorReLu( } // namespace device } // namespace reference } // namespace cutlass + +#if (CUTLASS_ENABLE_SYCL) +namespace sycl { + template <> + struct is_device_copyable < + cutlass::reference::device::detail::TensorReLuFunc::Params> : std::true_type {}; +} +#endif