From 5b1f514dcb26f811d4b30e3ff2bfab3f7bfd1538 Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Tue, 23 Jul 2024 00:52:16 -0700 Subject: [PATCH] rm epilogue and revert gemm example --- build.sh | 40 -- examples/sycl/pvc/pvc_gemm.cpp | 437 +++++++----------- .../epilogue/collective/default_epilogue.hpp | 33 -- .../intel_pvc_epilogue_tensor_softmax.hpp | 157 ------- .../epilogue/thread/linear_combination_relu.h | 11 - 5 files changed, 164 insertions(+), 514 deletions(-) delete mode 100644 build.sh delete mode 100644 include/cutlass/epilogue/collective/intel_pvc_epilogue_tensor_softmax.hpp diff --git a/build.sh b/build.sh deleted file mode 100644 index 3ceb147e9e..0000000000 --- a/build.sh +++ /dev/null @@ -1,40 +0,0 @@ -script_dir=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) -cp ${script_dir}/tools/clang-format/clang-format.hook ${script_dir}/.git/hooks/pre-commit -chmod +x ${script_dir}/.git/hooks/pre-commit - -# https://github.com/intel/llvm/releases/tag/nightly-2024-07-03 -sycl_compiler_path=/opt/cutlass/compiler/0703/ - -# https://ubit-gfx.intel.com/build/19168301/artifacts -gpu_driver_path=/opt/cutlass/gpu_driver/gfx-driver-ci-comp_igc-25012/extract/ - -# AOT compile -output=intel_gpu_pvc -# jit compile -#output=spir64 - -unset epilogue - -# epilogue relu -# epilogue+=" -DEPILOGUE_RELU " - -# epilogue softmax -# epilogue+=" -DEPILOGUE_SOFTMAX " - -export ZE_AFFINITY_MASK=0 -export CPATH=$sycl_compiler_path:$sycl_compiler_path/include/:$sycl_compiler_path/include/sycl/ -export LIBRARY_PATH=$gpu_driver_path/usr/lib/x86_64-linux-gnu/:$sycl_compiler_path/lib/ -export LD_LIBRARY_PATH=$LIBRARY_PATH -export IGC_EnableVISANoSchedule=1 -export IGC_ShaderDumpEnable=1 -export IGC_DumpToCustomDir=./mm_dumps -export IGC_VATemp=1 -export ONEAPI_DEVICE_SELECTOR=level_zero:gpu - -target=./examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute -rm -rf * - -cmake .. -G Ninja -DCMAKE_CUDA_HOST_COMPILER=${sycl_compiler_path}/bin/clang++ \ --DCUTLASS_ENABLE_SYCL=ON -DDPCPP_SYCL_TARGET=$output -DCMAKE_CXX_COMPILER=${sycl_compiler_path}/bin/clang++ \ --DCMAKE_CXX_FLAGS=" -DPREFETCH_DEFAULT -DSYCL_INTEL_TARGET ${epilogue} " \ -&& ninja -v $target && $target diff --git a/examples/sycl/pvc/pvc_gemm.cpp b/examples/sycl/pvc/pvc_gemm.cpp index 93d6dc78de..51c44d6a79 100644 --- a/examples/sycl/pvc/pvc_gemm.cpp +++ b/examples/sycl/pvc/pvc_gemm.cpp @@ -32,77 +32,73 @@ #define CUTLASS_SYCLCOMPAT_PROFILING_ENABLED #include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/gemm/collective/collective_mma.hpp" #include "cutlass/epilogue/collective/intel_pvc_epilogue.hpp" #include "cutlass/epilogue/fusion/intel_pvc_callbacks.hpp" -#include "cutlass/gemm/device/gemm.h" #include "cutlass/gemm/device/gemm_universal.h" #include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" #include "cutlass/util/GPU_Clock.hpp" #include #include -#include "cutlass/epilogue/collective/intel_pvc_epilogue_tensor_softmax.hpp" #include "cutlass/util/command_line.h" #include "cutlass/util/device_memory.h" #include "cutlass/util/packed_stride.hpp" #include "cutlass/util/reference/device/gemm_complex.h" #include "cutlass/util/reference/device/tensor_compare.h" -template static void fill_matrix(std::vector& M) { - std::random_device dev; - std::mt19937 rng(dev()); - std::uniform_real_distribution dist((T)0.0, (T)1.0); - std::generate(std::begin(M), std::end(M), [&] { return static_cast(dist(rng)); }); +template +static void fill_matrix(std::vector &vector) +{ + std::generate(std::begin(vector), std::end(vector), [&] { + return static_cast( (rand() / double(RAND_MAX)) ); + }); } template -static void vnni_matrix(T* dst, T const* src, int batch, int numRows, int numCols, int factor) { - for (int b = 0; b < batch; b++) { - for (int r = 0; r < numRows / factor; r++) { - for (int c = 0; c < numCols; c++) { - for (int k = 0; k < factor; k++) { - dst[((b * (numRows / factor) + r) * numCols + c) * factor + k] = - src[((b * (numRows / factor) + r) * factor + k) * numCols + c]; - } +static void vnni_matrix( + T* dst, const T* src, + int batch, int numRows, int numCols, int factor) +{ + for (int b = 0; b < batch; b++) { + for (int r = 0; r < numRows / factor; r++) { + for (int c = 0; c < numCols; c++) { + for (int k = 0; k < factor; k++) { + dst[((b * (numRows / factor) + r) * numCols + c) * factor + k] = + src[((b * (numRows / factor) + r) * factor + k) * numCols + c]; + } + } } } - } } using namespace cute; -using ElementAccumulator = float; // <- data type of accumulator -using ElementComputeEpilogue = float; // <- data type of epilogue operations -using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A -using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B -using ElementOutput = float; // <- data type of elements in output matrix D - /////////////////////////////////////////////////////////////////////////////////////////////////// // Command line options parsing struct Options { - bool help; bool error; - int m, n, k, l, iterations; float alpha, beta; - Options() - : help(false), error(false), m(4096), n(4096), k(4096), l(1), iterations(100), alpha(1.f), - beta(0.f) {} + Options(): + help(false), + error(false), + m(4096), n(4096), k(4096), l(1), iterations(100), + alpha(1.f), beta(0.f) + { } // Parses the command line - void parse(int argc, char const** args) { + void parse(int argc, char const **args) { cutlass::CommandLine cmd(argc, args); if (cmd.check_cmd_line_flag("help")) { help = true; return; } - cmd.get_cmd_line_argument("m", m, 4096); cmd.get_cmd_line_argument("n", n, 4096); cmd.get_cmd_line_argument("k", k, 4096); @@ -113,20 +109,18 @@ struct Options { } /// Prints the usage statement. - std::ostream& print_usage(std::ostream& out) const { + std::ostream & print_usage(std::ostream &out) const { out << "PVC GEMM Example\n\n" - << "Options:\n\n" - << " --help If specified, displays this " - "usage statement\n\n" - << " --m= Sets the M extent of the GEMM\n" - << " --n= Sets the N extent of the GEMM\n" - << " --k= Sets the K extent of the GEMM\n" - << " --l= Sets the L extent (batch count) " - "of the GEMM\n" - << " --alpha= Epilogue scalar alpha\n" - << " --beta= Epilogue scalar beta\n\n" - << " --iterations= Iterations\n\n"; + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; return out; } @@ -134,34 +128,31 @@ struct Options { /////////////////////////////////////////////////////////////////////////////////////////////////// -template struct ExampleRunner { +template < + class Gemm +> +struct ExampleRunner { using StrideA = typename Gemm::GemmKernel::StrideA; using StrideB = typename Gemm::GemmKernel::StrideB; using StrideC = typename Gemm::GemmKernel::StrideC; using StrideD = typename Gemm::GemmKernel::StrideD; - using LayoutA = typename Gemm::LayoutA; using LayoutB = typename Gemm::LayoutB; using LayoutC = typename Gemm::LayoutC; using LayoutD = typename Gemm::LayoutD; - using ElementA = typename Gemm::ElementA; using ElementB = typename Gemm::ElementB; using ElementAcc = typename Gemm::ElementAccumulator; - using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; using ElementC = typename Gemm::ElementC; using ElementOutput = typename CollectiveEpilogue::ElementOutput; using ElementCompute = typename CollectiveEpilogue::ElementCompute; using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; - using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; - // // Data members // - /// Initialization StrideA stride_A; StrideB stride_B; @@ -170,256 +161,181 @@ template struct ExampleRunner { cutlass::DeviceAllocation block_A; cutlass::DeviceAllocation block_B; - // cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_B_vnni; + cutlass::DeviceAllocation block_C; cutlass::DeviceAllocation block_D; cutlass::DeviceAllocation block_ref_D; - static auto constexpr l3_cache_size = 256 * 1024 * 1024; - - size_t PINGPONG_ITER = 1; - size_t pingpong_size_a; - size_t pingpong_size_b; - size_t pingpong_size_d; - - std::vector a; - std::vector b; - std::vector d; // // Methods // - bool verify(ProblemShapeType const& problem_size, ElementCompute alpha, ElementCompute beta) { + bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { auto [M, N, K, L] = problem_size; cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); - cutlass::TensorRef ref_C((ElementC*)nullptr /*block_C.get()*/, LayoutC::packed({M, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); cutlass::reference::device::GemmComplex( - {M, N, K}, alpha, ref_A, cutlass::ComplexTransform::kNone, ref_B, - cutlass::ComplexTransform::kNone, beta, ref_C, ref_D, ElementAccumulator(0), - L, // batch_count - M * K, // batch_stride_A - K * N, // batch_stride_B - M * N, // batch_stride_C - M * N // batch_stride_D - ); + {M, N, K}, + alpha, + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta, + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); syclcompat::wait(); - // Check if output from CUTLASS kernel and reference kernel are relatively - // equal or not need to set a larger error margin for comparison to succeed + // Check if output from CUTLASS kernel and reference kernel are relatively equal or not + // need to set a larger error margin for comparison to succeed + auto epsilon = static_cast(0.1f); + auto nonzero_floor = static_cast(0.1f); + bool passed = cutlass::reference::device::BlockCompareRelativelyEqual( - block_ref_D.get(), block_D.get(), M * N * L, 0.5f, 0.5f); + block_ref_D.get(), block_D.get(), block_D.size(), + epsilon, nonzero_floor); return passed; } - void init_cache_clear(ProblemShapeType const& problem_size) { + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); auto [M, N, K, L] = problem_shape_MNKL; - pingpong_size_a = max((size_t)M * K * L, l3_cache_size / sizeof(ElementA)); - pingpong_size_b = max((size_t)K * N * L, l3_cache_size / sizeof(ElementB)); - pingpong_size_d = max((size_t)M * N * L, l3_cache_size / sizeof(ElementOutput)); - auto gmem_size = syclcompat::get_current_device().get_global_mem_size(); - PINGPONG_ITER = std::min((size_t)3, - std::max((size_t)1, (size_t)gmem_size / ((pingpong_size_a * sizeof(ElementA) + - pingpong_size_b * sizeof(ElementB) + - pingpong_size_d * sizeof(ElementOutput))) - - 1)); - block_A.reset(pingpong_size_a * PINGPONG_ITER); - block_B.reset(pingpong_size_b * PINGPONG_ITER); - // block_C.reset(M * N * L * ITER); - block_D.reset(pingpong_size_d * PINGPONG_ITER); - - for (int i = 0; i < PINGPONG_ITER; i++) { - syclcompat::memcpy( - block_A.get() + i * pingpong_size_a, a.data(), a.size() * sizeof(ElementA)); - syclcompat::memcpy( - block_B.get() + i * pingpong_size_b, b.data(), b.size() * sizeof(ElementB)); - syclcompat::memcpy( - block_D.get() + i * pingpong_size_d, d.data(), d.size() * sizeof(ElementC)); - } - // syclcompat::wait(); - } - - /// Initialize operands to be used in the GEMM and reference GEMM - void initialize(ProblemShapeType const& problem_size) { - auto [M, N, K, L] = problem_size; - stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); - stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(K, N, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); - block_A.reset((size_t)M * K * L); - block_B.reset((size_t)K * N * L); - // block_C.reset(M * N * L); - block_D.reset((size_t)M * N * L); - block_ref_D.reset((size_t)max(l3_cache_size / sizeof(ElementOutput), (size_t)M * N * L)); + + block_A.reset(M * K * L); + block_B.reset(K * N * L); + block_B_vnni.reset(K * N * L); + block_C.reset(M * N * L); + block_D.reset(M * N * L); + block_ref_D.reset(M * N * L); // TODO: Enable initialization on device directly once RNG is // available through SYCL. - a = std::vector((size_t)M * K * L); - b = std::vector((size_t)K * N * L); - d = std::vector((size_t)M * N * L, ElementC{0}); - std::cout << "random generating..." << std::endl; + std::vector a(K * M * L); + std::vector b(K * N * L); + std::vector b_vnni(b.size()); + std::vector c(M * N * L); + std::vector d(M * N * L, ElementC{0}); + fill_matrix(a); fill_matrix(b); + fill_matrix(c); + vnni_matrix(b_vnni.data(), b.data(), L, K, N, 2); + syclcompat::memcpy(block_A.get(), a.data(), a.size() * sizeof(ElementA)); syclcompat::memcpy(block_B.get(), b.data(), b.size() * sizeof(ElementB)); - // syclcompat::memcpy(block_C.get(), c.data(), c.size() * sizeof(ElementC)); + syclcompat::memcpy(block_B_vnni.get(), b_vnni.data(), b.size() * sizeof(ElementB)); + syclcompat::memcpy(block_C.get(), c.data(), c.size() * sizeof(ElementC)); syclcompat::memcpy(block_D.get(), d.data(), d.size() * sizeof(ElementC)); } - template - void run(int M, int K, int N, int L, cutlass::KernelHardwareInfo const& hw_info) { - static auto constexpr warmup = 10; - static auto constexpr testIterations = 10; - static auto constexpr total_iterations = warmup + testIterations; - ProblemShapeType problem_size = ProblemShapeType{M, N, K, L}; + void run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + initialize(problem_size); sycl::property_list prop = { sycl::property::queue::in_order(), sycl::property::queue::enable_profiling() }; - auto q = sycl::queue(syclcompat::get_default_context(), syclcompat::get_current_device(), prop); syclcompat::set_default_queue(q); typename Gemm::GemmKernel::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, - {block_A.get(), stride_A, block_B.get(), stride_B}, - {{1, 0.f}, - nullptr /*block_C.get()*/, - stride_C, - block_D.get(), - stride_D}, - hw_info}; - Gemm gemm_op_verify; + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B_vnni.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info + }; + + Gemm gemm_op; size_t workspace_size = Gemm::get_workspace_size(arguments); cutlass::device_memory::allocation workspace(workspace_size); - gemm_op_verify.can_implement(arguments); + gemm_op.can_implement(arguments); - gemm_op_verify.initialize(arguments, workspace.get()); + gemm_op.initialize(arguments, workspace.get()); // Run the GEMM - gemm_op_verify.run(); + gemm_op.run(); + syclcompat::wait(); // Verify that the result is correct - bool passed = verify(problem_size, 1, 0.f); - if (!passed) { - printf("PVC GEMM Example %s, MKNL(%d, %d,%d,%d), Config(%d, " - "%d,%d,%d,%d) !!!!!!!!!!!!!\n\n", - (passed ? "Passed" : "Failed"), M, K, N, L, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, - sg_tile_k); - // return; - } - - // ================ init cache clear ================ - if constexpr (cache_clear) { - init_cache_clear(problem_size); - } - - // ================ run and collect performance data ================ - if (total_iterations > 0) { - auto total_time = 0.f; - auto best = 999.f; - auto worst = 0.f; - - for (int i = 0; i < testIterations + warmup; ++i) { - typename Gemm::GemmKernel::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, - {block_A.get() + (i % PINGPONG_ITER) * pingpong_size_a, stride_A, - block_B.get() + (i % PINGPONG_ITER) * pingpong_size_b, stride_B}, - {{1, 0.f}, nullptr /*block_C.get() + i * M * N * L*/, stride_C, - block_D.get() + (i % PINGPONG_ITER) * pingpong_size_d, stride_D}, - hw_info}; - - Gemm gemm_op; - gemm_op.can_implement(arguments); - gemm_op.initialize(arguments, workspace.get()); - - GPU_Clock timer; - timer.start(); + bool passed = verify(problem_size, options.alpha, options.beta); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if (passed && options.iterations > 0) { + GPU_Clock timer; + // timer.start(); + for (int i = 0; i < options.iterations; ++i) { + if (i == 10) timer.start(); gemm_op.run(); - syclcompat::wait(); - - auto current_time = timer.seconds(); - if (i >= warmup) { - total_time += current_time; - - best = min(best, current_time); - - worst = max(worst, current_time); - } } + syclcompat::wait(); - float average = total_time / testIterations; - double tflops = (2.0 * M * N * K * L) * 1e-12; - - double hbm = L * - (M * K * sizeof(ElementInputA) + K * N * sizeof(ElementInputB) + - M * N * sizeof(ElementOutput)) * - 1e-9; - - printf("Collective pvc gemm, MKNL(%d, %d, %d, %d), Config(%d, %d, " - "%d, %d, %d):\n max: (%6.4f)ms, (%4.2f)TFlop/s, " - "(%4.2f)GB/s\n min: (%6.4f)ms, (%4.2f)TFlop/s, " - "(%4.2f)GB/s\n average: (%6.4f)ms, (%4.2f)TFlop/s, " - "(%4.2f)GB/s\n\n\n", - M, K, N, L, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, best * 1000, - tflops / best, hbm / best, worst * 1000, tflops / worst, hbm / worst, average * 1000, - tflops / average, hbm / average); + float cute_time = timer.seconds() / options.iterations; + double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); } + + return; } + }; -template -void collective_gemm(int M, int K, int N, int L = 1) { +int main(int argc, const char** argv) +{ // // Parse options // Options options; - // options.parse(argc, argv); + options.parse(argc, argv); if (options.help) { options.print_usage(std::cout) << std::endl; - return; + return 0; } if (options.error) { std::cerr << "Aborting execution." << std::endl; - return; + return -1; } // // Run examples // - // The KernelHardwareInfo struct holds the number of EUs on the GPU with a - // given device ID. This information is used by the underlying kernel. + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. cutlass::KernelHardwareInfo hw_info; - // Change device_id to another value if you are running on a machine with - // multiple GPUs and wish to use a GPU other than that with device ID 0. - hw_info.sm_count = - cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); bool passed; @@ -431,9 +347,6 @@ void collective_gemm(int M, int K, int N, int L = 1) { using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B using ElementOutput = float; // <- data type of elements in output matrix D - // The code section below describes datatype for input, output matrices and - // computation between elements in input matrices. - using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::RowMajor; using LayoutC = cutlass::layout::RowMajor; @@ -443,13 +356,11 @@ void collective_gemm(int M, int K, int N, int L = 1) { using GmemTiledCopyB = XE_2D_U16x16x16x2x2_V; using TileShape = Shape<_256, _256, _32>; - // using TileShape = - // Shape, Int, Int, Int, Int>; using TiledMma = TiledMMA, Layout>, Tile<_32,_64,_32>>; - + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated; using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue; @@ -458,63 +369,43 @@ void collective_gemm(int M, int K, int N, int L = 1) { using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; - - // Mainloop - using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma, ElementInputB, - cutlass::gemm::TagToStrideB_t, TiledMma, GmemTiledCopyA, void, void, - cute::identity, // A - GmemTiledCopyB, void, void, cute::identity // B - >; - using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< - EpilogueDispatchPolicy, - TileShape, - ElementAccumulator, - cutlass::gemm::TagToStrideC_t, - ElementOutput, - cutlass::gemm::TagToStrideC_t, - FusionCallBacks, - XE_2D_U32x8x16x1x1_LD_N, - void, void, - XE_2D_U32x8x16x1x1_ST_N, - void, void>; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal, - CollectiveMainloop, CollectiveEpilogue>; + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16x1x1_LD_N, + void, void, + XE_2D_U32x8x16x1x1_ST_N, + void, void>; + +// Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementInputA, + cutlass::gemm::TagToStrideA_t, + ElementInputB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - ExampleRunner runner; + ExampleRunner runner; - runner.template run(M, K, N, L, hw_info); -} + runner.run(options, hw_info); -int main() { - collective_gemm<256, 256, 32, 64, 32>(4096, 4096, 4096); - collective_gemm<256, 256, 32, 64, 32>(8192, 8192, 8192); - collective_gemm<256, 256, 32, 64, 32>(1, 5120, 13824); - collective_gemm<256, 256, 32, 64, 32>(1024, 28672, 8192); - collective_gemm<256, 256, 32, 64, 32>(3072, 4096, 3072); - collective_gemm<256, 256, 32, 64, 32>(4, 4096, 12288); - - // collective shape from habana - collective_gemm<256, 256, 32, 64, 32>(512, 8192, 8192); - collective_gemm<256, 256, 32, 64, 32>(512, 8192, 32768); - collective_gemm<256, 256, 32, 64, 32>(512, 32768, 8192); - collective_gemm<256, 256, 32, 64, 32>(16384, 8192, 1024); - collective_gemm<256, 256, 32, 64, 32>(16384, 1024, 8192); - collective_gemm<256, 256, 32, 64, 32>(16384, 8192, 4096); - collective_gemm<256, 256, 32, 64, 32>(16384, 4096, 8192); - collective_gemm<256, 256, 32, 64, 32>(4096, 16384, 8192); - collective_gemm<256, 256, 32, 64, 32>(8192, 16384, 4096); - collective_gemm<256, 256, 32, 64, 32>(1024, 16384, 8192); - collective_gemm<256, 256, 32, 64, 32>(8192, 16384, 1024); - - collective_gemm<256, 256, 32, 64, 32>(8, 128, 16384, 4096); - collective_gemm<16, 512, 16, 16, 32>(8, 16384, 128, 4096); - - collective_gemm<256, 256, 32, 64, 32>(32768, 128, 4096, 4); - collective_gemm<256, 256, 32, 64, 32>(32768, 4096, 128, 4); - collective_gemm<256, 256, 32, 64, 32>(4096, 4096, 128, 32); -} + return 0; +} \ No newline at end of file diff --git a/include/cutlass/epilogue/collective/default_epilogue.hpp b/include/cutlass/epilogue/collective/default_epilogue.hpp index de24020265..bbeeacacd3 100644 --- a/include/cutlass/epilogue/collective/default_epilogue.hpp +++ b/include/cutlass/epilogue/collective/default_epilogue.hpp @@ -147,39 +147,6 @@ class DefaultEpilogue { return epilogue_op.is_source_needed(); } - template< - class ProblemShapeMNKL, - class BlockShapeMNK, - class BlockCoordMNKL, - class FrgEngine, class FrgLayout> - CUTLASS_HOST_DEVICE void - operator()( - ProblemShapeMNKL problem_shape_mnkl, - BlockShapeMNK blk_shape_MNK, - BlockCoordMNKL blk_coord_mnkl, - cute::Tensor & accumulators){ - auto M = get<0>(problem_shape_mnkl); - auto N = get<1>(problem_shape_mnkl); - auto L = get<3>(problem_shape_mnkl); - - auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; - if (epilogue_op.is_source_needed()) { - auto source = make_fragment_like(accumulators); - auto gmem_tiled_copy_c = - make_xe_2d_copy(make_tensor( - params.ptr_C, make_shape(M, N, L), params.dC)); - - Tensor tCi = gmem_tiled_copy_c.get_pvc_tensor( - make_coord(m_coord, n_coord, l_coord), - make_shape(size<1>(accumulators), size<2>(accumulators), L), - make_stride(size<0>(blk_shape_MNK), size<1>(blk_shape_MNK))); - copy(gmem_tiled_copy_c, tCi(_, _, _, l_coord), source); - epilogue_op(accumulators, source); - } else { - epilogue_op(accumulators); - } - } - template< class ProblemShapeMNKL, class BlockShapeMNK, diff --git a/include/cutlass/epilogue/collective/intel_pvc_epilogue_tensor_softmax.hpp b/include/cutlass/epilogue/collective/intel_pvc_epilogue_tensor_softmax.hpp deleted file mode 100644 index 01bd25b7ec..0000000000 --- a/include/cutlass/epilogue/collective/intel_pvc_epilogue_tensor_softmax.hpp +++ /dev/null @@ -1,157 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/epilogue/collective/detail.hpp" - -#include "cute/tensor.hpp" -#include "cutlass/cuda_host_adapter.hpp" -#include - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace collective { -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -class PvcEpilogueTensorSoftmax { -public: - using EpilogueSchedule = EpilogueSchedule_; - using DispatchPolicy = EpilogueSchedule_; - - // derived types of output thread level operator - using ThreadEpilogueOp = ThreadEpilogueOp_; - using ElementOutput = typename ThreadEpilogueOp::ElementOutput; - using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; - using ElementCompute = typename ThreadEpilogueOp::ElementCompute; - using ElementScalar = ElementCompute; - using ElementC = typename ThreadEpilogueOp::ElementC; - using StrideC = StrideC_; - using ElementD = typename ThreadEpilogueOp::ElementD; - using StrideD = StrideD_; - - using GmemTiledCopyC = void; - using GmemTiledCopyD = void; - - // Host side epilogue arguments - struct Arguments { - typename ThreadEpilogueOp::Params thread{}; - ElementC const* ptr_C = nullptr; - StrideC dC{}; - ElementD* ptr_D = nullptr; - StrideD dD{}; - }; - - // Device side epilogue params - using Params = Arguments; - - template - static Params constexpr to_underlying_arguments([[maybe_unused]] ProblemShape const& _, - Arguments const& args, - [[maybe_unused]] void* workspace) { - return args; - } - - template CUTLASS_DEVICE void operator()(T& t) { - static_assert(cute::is_same_v && m <= 32); - - auto const& group = sycl::ext::oneapi::experimental::this_nd_item<3>().get_group(); - - static auto constexpr vec_size = 4; - - static_assert((m % vec_size) == 0 && vec_size <= 16); - static auto constexpr loop_cnt = m / vec_size; - - sycl::vec local_max; - sycl::vec local_plus; - - for (int loop = 0; loop < loop_cnt; loop++) { - - auto base_row = loop * vec_size; - // init local max - for (int i = 0; i < vec_size; i++) { - local_max[i] = t[(base_row + i) * n]; - } - - for (int i = 0; i < vec_size; i++) { - for (int j = 0; j < n; j++) { - local_max[i] = max(local_max[i], t((base_row + i) * n + j)); - } - } - - // get group max - auto group_max = reduce_over_group(group, local_max, sycl::maximum<>()); - - // -max, exp, and get local plus - for (int i = 0; i < vec_size; i++) { - for (int j = 0; j < n; j++) { - auto offset = (base_row + i) * n + j; - t[offset] -= group_max[i]; - t[offset] = sycl::exp(t[offset]); - - local_plus[i] += t[offset]; - } - } - - // get group plus - auto group_plus = reduce_over_group(group, local_plus, sycl::plus<>()); - - // last div - for (int i = 0; i < vec_size; i++) { - for (int j = 0; j < n; j++) { - auto offset = (base_row + i) * n + j; - t[offset] = t[offset] / group_plus[i]; - // local_sum += t[i * n + j]; - } - } - } - - // printf("verify softmax, local_sum: %f, group_sum: %f\n", local_sum, - // reduce_over_group(group, local_sum, sycl::plus<>())); - // } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace collective -} // namespace epilogue -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/thread/linear_combination_relu.h b/include/cutlass/epilogue/thread/linear_combination_relu.h index 343e2a9ec2..2d66a4e2a8 100644 --- a/include/cutlass/epilogue/thread/linear_combination_relu.h +++ b/include/cutlass/epilogue/thread/linear_combination_relu.h @@ -184,17 +184,6 @@ class LinearCombinationRelu { } } - using ElementC = ElementOutput_; - using ElementD = ElementOutput_; - - template - CUTLASS_HOST_DEVICE - void operator()(TensorType &accumulators) const { - for (int i = 0; i < size(accumulators); i++) { - accumulators(i) = accumulators(i) < 0 ? 0 : accumulators(i); - } - } - template CUTLASS_HOST_DEVICE void operator()(TensorDst &accumulators,