diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 26c69b310..b194713c5 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -149,12 +149,14 @@ if (NOT CUTLASS_ENABLE_SYCL) 57_hopper_grouped_gemm 58_ada_fp8_gemm 59_ampere_gather_scatter_conv + ampere ) add_subdirectory(${EXAMPLE}) endforeach() else() foreach(EXAMPLE 14_ampere_tf32_tensorop_gemm + ampere cute sycl ) diff --git a/examples/ampere/CMakeLists.txt b/examples/ampere/CMakeLists.txt new file mode 100644 index 000000000..d494b5114 --- /dev/null +++ b/examples/ampere/CMakeLists.txt @@ -0,0 +1,29 @@ +# Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +add_subdirectory(ampere_gemm_with_collective_builder) \ No newline at end of file diff --git a/examples/ampere/ampere_gemm_with_collective_builder/CMakeLists.txt b/examples/ampere/ampere_gemm_with_collective_builder/CMakeLists.txt new file mode 100644 index 000000000..f23ce77ca --- /dev/null +++ b/examples/ampere/ampere_gemm_with_collective_builder/CMakeLists.txt @@ -0,0 +1,32 @@ +# Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_example_add_executable( + ampere_collective_builder + ampere_collective_builder.cu + ) diff --git a/examples/ampere/ampere_gemm_with_collective_builder/ampere_collective_builder.cu b/examples/ampere/ampere_gemm_with_collective_builder/ampere_collective_builder.cu new file mode 100644 index 000000000..5784997ae --- /dev/null +++ b/examples/ampere/ampere_gemm_with_collective_builder/ampere_collective_builder.cu @@ -0,0 +1,402 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/kernel_hardware_info.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/GPU_Clock.hpp" + +#include "cutlass/util/reference/device/tensor_relu.h" +#include "cutlass/tensor_view.h" +#include "cutlass/coord.h" + +#if defined(CUTLASS_ENABLE_SYCL) +#include "cutlass/util/reference/device/sycl_tensor_fill.h" +#else +#include "cutlass/util/reference/device/tensor_fill.h" +#endif + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + + int m, n, k, l, iterations; + float alpha, beta; + + Options(): + help(false), + error(false), + m(5120), n(4096), k(4096), l(1), iterations(100), + alpha(1.f), beta(0.f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 5120); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations, 100); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "GEMM Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + return true; +} + +template < + class Gemm +> +struct ExampleRunner { + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha, + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta, + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + #if defined(CUTLASS_ENABLE_SYCL) + syclcompat::wait(); + #else + cudaDeviceSynchronize(); + #endif + + #if defined(CUTLASS_ENABLE_SYCL) + syclcompat::wait(); + #else + cudaDeviceSynchronize(); + #endif + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual( + block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(M * K * L); + block_B.reset(K * N * L); + block_C.reset(M * N * L); + block_D.reset(M * N * L); + block_ref_D.reset(M * N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } + + void run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + gemm_op.can_implement(arguments); + + gemm_op.initialize(arguments, workspace.get()); + + // Run the GEMM + gemm_op.run(); + + #if defined(CUTLASS_ENABLE_SYCL) + syclcompat::wait(); + #else + cudaDeviceSynchronize(); + #endif + + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if (passed && options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + gemm_op.run(); + } + #if defined(CUTLASS_ENABLE_SYCL) + syclcompat::wait(); + #else + cudaDeviceSynchronize(); + #endif + + float cute_time = timer.seconds() / options.iterations; + double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); + } + + return; + } + +}; + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. + using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputA = cute::half_t; // <- data type of elements in input matrix A + using ElementInputB = cute::half_t; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + + constexpr int AlignmentA = sizeof(ElementInputA); + constexpr int AlignmentB = sizeof(ElementInputB); + constexpr int AlignmentC = sizeof(ElementAccumulator); + constexpr int AlignmentD = sizeof(ElementOutput); + + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using LayoutD = cutlass::layout::ColumnMajor; + + // Workgroup-level tile + using TileShape = Shape<_128, _128, _64>; + + using CollectiveMainloop = cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm80, cutlass::arch::OpClassTensorOp, + ElementInputA, LayoutA, AlignmentA, + ElementInputB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, Shape<_1, _1, _1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination< + ElementOutput, ElementComputeEpilogue, ElementAccumulator, + ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>; + + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm80, cutlass::arch::OpClassTensorOp, + TileShape, Shape<_1, _1, _1>, + cutlass::epilogue::collective::EpilogueTileAuto, ElementComputeEpilogue, + ElementAccumulator, + ElementAccumulator, LayoutC, AlignmentC, + ElementOutput, LayoutD, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto, + EpilogueOp + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + runner.run(options, hw_info); + + return 0; +} diff --git a/include/cutlass/epilogue/collective/builders/sm80_builder.inl b/include/cutlass/epilogue/collective/builders/sm80_builder.inl new file mode 100644 index 000000000..56215c0c7 --- /dev/null +++ b/include/cutlass/epilogue/collective/builders/sm80_builder.inl @@ -0,0 +1,93 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +#include "cutlass/epilogue/collective/default_epilogue.hpp" + + +namespace cutlass::epilogue::collective { + + template< + class TileShape_MNK, + class EpilogueTileType, + class ElementAccumulator_, + class ElementCompute_, + class ElementC_, + class GmemLayoutTagC_, + int AlignmentC_, + class ElementD_, + class GmemLayoutTagD_, + int AlignmentD_, + class FusionOpOrCallbacks + > + struct CollectiveBuilder< + arch::Sm80, + arch::OpClassTensorOp, + TileShape_MNK, + Shape<_1, _1, _1>, + EpilogueTileType, + ElementAccumulator_, + ElementCompute_, + ElementC_, + GmemLayoutTagC_, + AlignmentC_, + ElementD_, + GmemLayoutTagD_, + AlignmentD_, + EpilogueScheduleAuto, + FusionOpOrCallbacks, + cute::enable_if_t< + (cute::is_same_v>) + > + > + { + + using ElementD = ElementD_; + using ElementOutput = ElementD_; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementAccumulator_; + + static constexpr int FragmentSize = 128 / cutlass::sizeof_bits::value; + using ThreadOp = thread::LinearCombination< + ElementD, FragmentSize, ElementAccumulator, ElementCompute>; + + using CollectiveOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::detail::TagToStrideC_t, + cutlass::detail::TagToStrideC_t, + ThreadOp, + cutlass::gemm::EpilogueDefault>; + }; +} diff --git a/include/cutlass/epilogue/collective/collective_builder.hpp b/include/cutlass/epilogue/collective/collective_builder.hpp index becb1fb82..4184bb95a 100644 --- a/include/cutlass/epilogue/collective/collective_builder.hpp +++ b/include/cutlass/epilogue/collective/collective_builder.hpp @@ -114,6 +114,7 @@ struct CallbacksBuilder< ///////////////////////////////////////////////////////////////////////////////////////////////// #include "builders/sm90_builder.inl" +#include "builders/sm80_builder.inl" #if defined(SYCL_INTEL_TARGET) #include "builders/xe_builder.inl" diff --git a/include/cutlass/gemm/collective/builders/sm80_common.inl b/include/cutlass/gemm/collective/builders/sm80_common.inl new file mode 100644 index 000000000..0be009ac8 --- /dev/null +++ b/include/cutlass/gemm/collective/builders/sm80_common.inl @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cute/arch/mma_sm80.hpp" +#include "cute/atom/mma_traits_sm80.hpp" +#include "cute/arch/copy_sm80.hpp" + + +namespace cutlass::gemm::collective::detail { + //================== MMA Types ==================// + + template + struct Sm80_TiledMMA { + using MMA_Atom = MMA_Atom>; + using TiledMMA = TiledMMA>>; + }; + + template<> + struct Sm80_TiledMMA { + using MMA_Atom = MMA_Atom; + using TiledMMA = TiledMMA>, + Tile<_32, _32, _16>>; + }; + + template<> + struct Sm80_TiledMMA { + using MMA_Atom = MMA_Atom; + using TiledMMA = TiledMMA>, + Tile<_32, _32, _16>>; + }; + + template<> + struct Sm80_TiledMMA { + using MMA_Atom = MMA_Atom; + using TiledMMA = TiledMMA>, + Tile<_32, _32, _16>>; + }; + + template<> + struct Sm80_TiledMMA { + using MMA_Atom = MMA_Atom; + using TiledMMA = TiledMMA, Stride<_2, _1, _1>>, + Tile<_32, _32, _8>>; + }; + + template<> + struct Sm80_TiledMMA { + using MMA_Atom = MMA_Atom; + using TiledMMA = TiledMMA>, + Tile, Stride<_2, _1>>, + Layout, Stride<_2, _1>>, + Underscore>>; + }; + + template<> + struct Sm80_TiledMMA { + using MMA_Atom = MMA_Atom; + using TiledMMA = TiledMMA>, + Tile<_32, _32, _32>>; + }; + + template<> + struct Sm80_TiledMMA { + using MMA_Atom = MMA_Atom; + using TiledMMA = TiledMMA>, + Tile<_32, _32, _32>>; + }; + + template<> + struct Sm80_TiledMMA { + using MMA_Atom = MMA_Atom; + using TiledMMA = TiledMMA>, + Tile<_32, _32, _32>>; + }; + + template<> + struct Sm80_TiledMMA { + using MMA_Atom = MMA_Atom; + using TiledMMA = TiledMMA>, + Tile<_32, _32, _32>>; + }; + + ////////////////////////////////////////////////////////////////////////////////////////////////// + + template + struct Sm80_MemoryAtomsA; + + template + struct Sm80_MemoryAtomsB; + + template<> + struct Sm80_MemoryAtomsA { + using SmemLayoutAtom = decltype( + composition(Swizzle<3, 3, 3>{}, + Layout, + Stride<_64, _1>>{})); + using SmemCopyAtom = Copy_Atom; + + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, cute::half_t>{}, + Layout, + Stride<_8, _1>>{}, + Layout>{})); + }; + + template<> + struct Sm80_MemoryAtomsA { + using SmemLayoutAtom = decltype( + composition(Swizzle<3, 3, 3>{}, + Layout, + Stride<_1, _64>>{})); + using SmemCopyAtom = Copy_Atom; + + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, cute::half_t>{}, + Layout, + Stride<_1, _16>>{}, + Layout>{})); + }; + + template<> + struct Sm80_MemoryAtomsB : + Sm80_MemoryAtomsA{}; + + template<> + struct Sm80_MemoryAtomsB : + Sm80_MemoryAtomsA{}; + + // We can re-use half_t memory layouts for bf16 as well + template<> + struct Sm80_MemoryAtomsA : + Sm80_MemoryAtomsA{}; + + template<> + struct Sm80_MemoryAtomsA : + Sm80_MemoryAtomsA{}; + + template<> + struct Sm80_MemoryAtomsB : + Sm80_MemoryAtomsB{}; + + template<> + struct Sm80_MemoryAtomsB : + Sm80_MemoryAtomsB{}; + + template<> + struct Sm80_MemoryAtomsA { + using SmemLayoutAtom = decltype( + composition(Swizzle<3, 3, 3>{}, + Layout, + Stride<_32, _1>>{})); + using SmemCopyAtom = Copy_Atom; + + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, tfloat32_t>{}, + Layout, + Stride<_8, _1>>{}, + Layout>{})); + }; + + template<> + struct Sm80_MemoryAtomsA { + using SmemLayoutAtom = decltype( + composition(Swizzle<3, 2, 3>{}, + Layout, + Stride<_1, _32>>{})); + using SmemCopyAtom = Copy_Atom, tfloat32_t>; + + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, tfloat32_t>{}, + Layout, + Stride<_1, _16>>{}, + Layout>{})); + }; + + + template<> + struct Sm80_MemoryAtomsB : + Sm80_MemoryAtomsA{}; + + template<> + struct Sm80_MemoryAtomsB : + Sm80_MemoryAtomsA{}; + + + template<> + struct Sm80_MemoryAtomsA { + using SmemLayoutAtom = decltype( + composition(Swizzle<2, 0, 4>{}, + Layout, + Stride<_1, _4>>{})); + using SmemCopyAtom = Copy_Atom; + + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, double>{}, + Layout, + Stride<_16, _1>>{}, + Layout>{})); + }; + + template<> + struct Sm80_MemoryAtomsA { + using SmemLayoutAtom = decltype( + composition(Swizzle<2, 2, 2>{}, + Layout, + Stride<_1, _16>>{})); + using SmemCopyAtom = Copy_Atom; + + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, double>{}, + Layout, + Stride<_1, _16>>{}, + Layout>{})); + }; + + template<> + struct Sm80_MemoryAtomsB : + Sm80_MemoryAtomsA{}; + + template<> + struct Sm80_MemoryAtomsB : + Sm80_MemoryAtomsA{}; + + template<> + struct Sm80_MemoryAtomsA { + using SmemLayoutAtom = decltype( + composition(Swizzle<2, 4, 3>{}, + Layout, + Stride<_64, _1>>{})); + using SmemCopyAtom = Copy_Atom; + + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, int8_t>{}, + Layout, + Stride<_4, _1>>{}, + Layout>{})); + }; + + template<> + struct Sm80_MemoryAtomsA { + using SmemLayoutAtom = decltype( + composition(Swizzle<2, 0, 8>{}, + Layout, + Stride<_1, _64>>{})); + using SmemCopyAtom = Copy_Atom; + + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, int8_t>{}, + Layout, + Stride<_1, _4>>{}, + Layout>>{})); + }; + + template<> + struct Sm80_MemoryAtomsB : + Sm80_MemoryAtomsA{}; + + template<> + struct Sm80_MemoryAtomsB : + Sm80_MemoryAtomsA{}; + + template<> + struct Sm80_MemoryAtomsA { + using SmemLayoutAtom = decltype( + composition(Swizzle<2, 4, 3>{}, + Layout, + Stride<_64, _1>>{})); + using SmemCopyAtom = Copy_Atom; + + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, uint8_t>{}, + Layout, + Stride<_4, _1>>{}, + Layout>{})); + }; + + template<> + struct Sm80_MemoryAtomsA { + using SmemLayoutAtom = decltype( + composition(Swizzle<2, 0, 8>{}, + Layout, + Stride<_1, _64>>{})); + using SmemCopyAtom = Copy_Atom; + + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, uint8_t>{}, + Layout, + Stride<_1, _4>>{}, + Layout>>{})); + }; + + template<> + struct Sm80_MemoryAtomsB : + Sm80_MemoryAtomsA{}; + + template<> + struct Sm80_MemoryAtomsB : + Sm80_MemoryAtomsA{}; +} diff --git a/include/cutlass/gemm/collective/builders/sm80_mma_builder.inl b/include/cutlass/gemm/collective/builders/sm80_mma_builder.inl new file mode 100644 index 000000000..a96ff4d6c --- /dev/null +++ b/include/cutlass/gemm/collective/builders/sm80_mma_builder.inl @@ -0,0 +1,104 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include "cutlass/gemm/collective/sm80_mma_multistage.hpp" + +#include "cutlass/gemm/collective/builders/sm80_common.inl" + +using namespace cute; + +namespace cutlass::gemm::collective { + +template < + class ElementA, + class GmemLayoutA, + int AlignmentA, + class ElementB, + class GmemLayoutB, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder < + arch::Sm80, + arch::OpClassTensorOp, + ElementA, + GmemLayoutA, + AlignmentA, + ElementB, + GmemLayoutB, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + Shape<_1, _1, _1>, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + (cute::is_same_v || + cute::is_same_v) + > +> { + using DispatchPolicy = MainloopSm80CpAsync<3>; + using GmemTiledCopyA = typename detail::Sm80_MemoryAtomsA::GmemTiledCopy; + using GmemTiledCopyB = typename detail::Sm80_MemoryAtomsB::GmemTiledCopy; + + using SmemLayoutAtomA = typename detail::Sm80_MemoryAtomsA::SmemLayoutAtom; + using SmemLayoutAtomB = typename detail::Sm80_MemoryAtomsB::SmemLayoutAtom; + + using SmemCopyAtomA = typename detail::Sm80_MemoryAtomsA::SmemCopyAtom; + using SmemCopyAtomB = typename detail::Sm80_MemoryAtomsB::SmemCopyAtom; + + using TiledMMA = typename detail::Sm80_TiledMMA::TiledMMA; + + using CollectiveOp = collective::CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + cutlass::gemm::TagToStrideA_t, + ElementB, + cutlass::gemm::TagToStrideB_t, + TiledMMA, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; + }; +} diff --git a/include/cutlass/gemm/collective/collective_builder.hpp b/include/cutlass/gemm/collective/collective_builder.hpp index 3698cdfc6..c865c4e1c 100644 --- a/include/cutlass/gemm/collective/collective_builder.hpp +++ b/include/cutlass/gemm/collective/collective_builder.hpp @@ -38,6 +38,7 @@ #include "cutlass/gemm/collective/collective_builder_decl.hpp" #include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl" +#include "cutlass/gemm/collective/builders/sm80_mma_builder.inl" #if defined(SYCL_INTEL_TARGET) #include "cutlass/gemm/collective/builders/xe_mma_builder.inl"