Add workgroup level TileShape (#84)
* Add workgroup-level tile

* Rename tile shapes

* Rename mma shape

* Remove unused code

* Update benchmark
aacostadiaz authored Jun 18, 2024
1 parent d161fa7 commit b42305f
Showing 4 changed files with 50 additions and 51 deletions.
20 changes: 11 additions & 9 deletions benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp
@@ -67,10 +67,8 @@ int main(int argc, const char** argv)
// to use a GPU other than that with device ID 0.
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

bool passed;

// The code section below describes datatype for input, output matrices and computation between
// elements in input matrices.
// The code section below describes datatype for input, output matrices and computation between
// elements in input matrices.
using ElementAccumulator = float; // <- data type of accumulator
using ElementComputeEpilogue = float; // <- data type of epilogue operations
using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A
@@ -82,16 +80,20 @@ int main(int argc, const char** argv)
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;

using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N;
using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N;
// Workgroup-level tile
using TileShape = Shape<_32, _256, _32>;

using TileShape = Shape<_32, _64, _32>;
using TiledMma = TiledMMA<
MMA_Atom<XE_8x16x16_F32BF16BF16F32_TN>,
Layout<Shape<_1,_1,_1>>,
Tile<_32,_64,_32>>; // Subgroup level-tile

using TiledMma = TiledMMA<MMA_Atom<XE_8x16x16_BF16BF16F32F32_NN>,
Layout<Shape<_8,_16,_1>>>;
using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N;
using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N;

using DispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated;

// This code section describes the epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput, // <- data type of output matrix
128 / cutlass::sizeof_bits<ElementOutput>::value, // <- the number of elements per vectorized
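A quick sanity check on the new shapes (editorial sketch, not part of the commit): the workgroup-level TileShape of 32x256x32 is covered by sub-group tiles of 32x64x32 (the Tile<_32,_64,_32> in the TiledMMA), so each work-group needs 4 sub-groups of 16 work-items, i.e. 64 threads.

```cpp
// Editorial sketch only: arithmetic implied by the new tile shapes above.
#include <cstdio>

int main() {
  constexpr int wg_m = 32, wg_n = 256, wg_k = 32;  // workgroup-level TileShape
  constexpr int sg_m = 32, sg_n = 64,  sg_k = 32;  // sub-group tile (Tile<_32,_64,_32>)
  constexpr int subgroup_size = 16;                // work-items per PVC sub-group

  constexpr int sg_per_wg = (wg_m / sg_m) * (wg_n / sg_n);   // 1 * 4 = 4
  constexpr int threads_per_wg = sg_per_wg * subgroup_size;  // 64

  std::printf("sub-groups per work-group: %d, threads per work-group: %d\n",
              sg_per_wg, threads_per_wg);
}
```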
5 changes: 3 additions & 2 deletions examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp
@@ -353,11 +353,12 @@ int main(int argc, const char** argv)
using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N;
using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N;

using TileShape = Shape<_1, _1, _1>;
// Workgroup-level tile
using TileShape = Shape<_32, _256, _32>;

using TiledMma = TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TN>,
Layout<Shape<_1,_1,_1>>,
Tile<_32,_64,_32>>;
Tile<_32,_64,_32>>; // Subgroup level-tile

using DispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated;

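For readers less familiar with the TiledMMA spelling, a rough reading (my interpretation, not text from the commit): Layout<Shape<_1,_1,_1>> places a single 8x16x16 DPAS atom per sub-group, and Tile<_32,_64,_32> has that sub-group sweep the atom over a 32x64x32 tile, i.e. 4x4x2 atom invocations per K step. This matches the FragsM/FragsN/FragsK constants in the collective mainloop below.

```cpp
// Editorial sketch: how many 8x16x16 DPAS atoms cover one 32x64x32 sub-group tile.
#include <cstdio>

int main() {
  constexpr int atom_m = 8, atom_n = 16, atom_k = 16;  // XE_8x16x16_* atom (M, N, K)
  constexpr int sg_m = 32, sg_n = 64, sg_k = 32;       // sub-group tile (M, N, K)

  constexpr int reps_m = sg_m / atom_m;  // 4
  constexpr int reps_n = sg_n / atom_n;  // 4
  constexpr int reps_k = sg_k / atom_k;  // 2

  std::printf("atom repetitions per sub-group tile: %d x %d x %d = %d\n",
              reps_m, reps_n, reps_k, reps_m * reps_n * reps_k);
}
```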
28 changes: 15 additions & 13 deletions include/cutlass/gemm/collective/intel_pvc_mma.hpp
@@ -81,7 +81,7 @@ struct CollectiveMma<
// Type Aliases
//
using DispatchPolicy = MainloopIntelPVCUnpredicated;
using TileShape = TileShape_;
using WorkgroupTileShape = TileShape_;
using ElementA = ElementA_;
using StrideA = StrideA_;
using ElementB = ElementB_;
@@ -100,21 +100,22 @@

static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;

using DpasShape = typename TiledMma::Shape_MNK;
using TileDpasShape = decltype(tile_shape(TiledMma()));
using MmaAtomShape = typename TiledMma::AtomShape_MNK;
using SubgroupTileShape = decltype(tile_shape(TiledMma()));

static constexpr uint32_t MaxThreadsPerBlock = get<0>(DpasShape()) * get<1>(DpasShape());
static constexpr uint32_t MaxThreadsPerBlock =
cute::size(WorkgroupTileShape{}) / cute::size(SubgroupTileShape{})* SubgroupSize;

static constexpr int FragsM = get<0>(TileDpasShape{}) / get<0>(DpasShape()); // A frags per sub_group
static constexpr int FragsN = get<1>(TileDpasShape{}) / get<1>(DpasShape()); // B frags per sub_group
static constexpr int FragsK = get<2>(TileDpasShape{}) / get<2>(DpasShape());
static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group
static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group
static constexpr int FragsK = get<2>(SubgroupTileShape{}) / get<2>(MmaAtomShape());

// Calculate the vector width based on the amount of registers
// required per work item by dividing the total fragment size by
// the sub_group size.
static constexpr int VecC = (get<1>(DpasShape()) * get<0>(DpasShape())) / SubgroupSize;
static constexpr int VecA = (get<0>(DpasShape()) * get<2>(DpasShape())) / SubgroupSize;
static constexpr int VecB = (get<1>(DpasShape()) * get<2>(DpasShape())) / SubgroupSize;
static constexpr int VecC = (get<1>(MmaAtomShape()) * get<0>(MmaAtomShape())) / SubgroupSize;
static constexpr int VecA = (get<0>(MmaAtomShape()) * get<2>(MmaAtomShape())) / SubgroupSize;
static constexpr int VecB = (get<1>(MmaAtomShape()) * get<2>(MmaAtomShape())) / SubgroupSize;

// Host side kernel arguments
struct Arguments {
@@ -186,8 +187,9 @@ struct CollectiveMma<
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");

// Tensor to hold input data
Tensor tAr = make_tensor<typename TiledMma::ValTypeA>(Shape<Int<get<0>(TileDpasShape{}) * FragsK>, Int<1>>{});
Tensor tBr = make_tensor<typename TiledMma::ValTypeB>(Shape<Int<FragsK * get<1>(TileDpasShape{}) / FragsN>, Int<FragsN>>{});
Tensor tAr = make_tensor<typename TiledMma::ValTypeA>(Shape<Int<get<0>(SubgroupTileShape{}) * FragsK>, Int<1>>{});
Tensor tBr = make_tensor<typename TiledMma::ValTypeB>(
Shape<Int<FragsK * get<1>(SubgroupTileShape{}) / FragsN>, Int<FragsN>>{});

Tensor tAr_view = make_tensor(static_cast<decltype(tAr) &&>(tAr).data(),
Shape<Int<VecA>, Int<FragsM>, Int<FragsK>>{});
@@ -200,7 +202,7 @@
//
// Mainloop
//
for (int k_tile = 0, k = 0; k_tile < k_tile_count; ++k_tile, k += get<2>(DpasShape()) * FragsK)
for (int k_tile = 0, k = 0; k_tile < k_tile_count; ++k_tile, k += get<2>(MmaAtomShape()) * FragsK)
{
// Copy gmem to rmem for the first k_tile
copy(mainloop.gmem_tiled_copy_a, gA(_,_,k), tAr);
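Plugging the benchmark shapes into the formulas above gives concrete values for the derived constants (my worked example, not code from the commit):

```cpp
// Editorial worked example for the constants derived in intel_pvc_mma.hpp,
// using WorkgroupTileShape = 32x256x32, SubgroupTileShape = 32x64x32 and
// MmaAtomShape = 8x16x16 from the benchmark.
#include <cstdio>

int main() {
  constexpr int wg_m = 32, wg_n = 256, wg_k = 32;   // WorkgroupTileShape
  constexpr int sg_m = 32, sg_n = 64,  sg_k = 32;   // SubgroupTileShape
  constexpr int mma_m = 8, mma_n = 16, mma_k = 16;  // MmaAtomShape
  constexpr int subgroup_size = 16;

  // MaxThreadsPerBlock = size(WorkgroupTileShape) / size(SubgroupTileShape) * SubgroupSize
  constexpr int max_threads =
      (wg_m * wg_n * wg_k) / (sg_m * sg_n * sg_k) * subgroup_size;  // 4 * 16 = 64

  // Fragments per sub-group along each mode.
  constexpr int frags_m = sg_m / mma_m;  // 4
  constexpr int frags_n = sg_n / mma_n;  // 4
  constexpr int frags_k = sg_k / mma_k;  // 2

  // Per-work-item vector widths (atom elements split across the sub-group).
  constexpr int vec_a = mma_m * mma_k / subgroup_size;  // 8
  constexpr int vec_b = mma_n * mma_k / subgroup_size;  // 16
  constexpr int vec_c = mma_m * mma_n / subgroup_size;  // 8

  std::printf("threads/WG=%d  frags=(%d,%d,%d)  vec A/B/C=(%d,%d,%d)\n",
              max_threads, frags_m, frags_n, frags_k, vec_a, vec_b, vec_c);
}
```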
48 changes: 21 additions & 27 deletions include/cutlass/gemm/kernel/intel_pvc_gemm.hpp
@@ -65,7 +65,8 @@ class GemmUniversal<
// Mainloop derived types
using CollectiveMainloop = CollectiveMainloop_;
using TileShape = typename CollectiveMainloop::TileShape;
using TileShape = typename CollectiveMainloop::WorkgroupTileShape;
using WorkgroupTileShape = TileShape;
using TiledMma = typename CollectiveMainloop::TiledMma;
using ArchTag = typename CollectiveMainloop::ArchTag;
using ElementA = typename CollectiveMainloop::ElementA;
@@ -81,7 +82,7 @@
"Intel PVC does not support specializing the tile scheduler.");
using TileSchedulerTag = TileScheduler_;
using TileScheduler = typename detail::TileSchedulerSelector<
TileScheduler_, ArchTag, TileShape,
TileScheduler_, ArchTag, WorkgroupTileShape,
cute::Shape<cute::Int<1>, cute::Int<1>, cute::Int<1>>>::Scheduler;
using TileSchedulerArguments = typename TileScheduler::Arguments;
@@ -101,13 +102,9 @@
static constexpr int SubgroupSize = CollectiveMainloop::SubgroupSize; // sub_group size
static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock;
static constexpr uint32_t MinBlocksPerMultiprocessor = CollectiveMainloop::MinBlocksPerMultiprocessor;
static constexpr int num_sg = MaxThreadsPerBlock / SubgroupSize; // number of sub_groups per work group
using DpasShape = typename CollectiveMainloop::DpasShape;
using TileDpasShape = typename CollectiveMainloop::TileDpasShape;
using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape;
using SubgroupTileShape = typename CollectiveMainloop::SubgroupTileShape;
static constexpr int FragsM = CollectiveMainloop::FragsM;
static constexpr int FragsN = CollectiveMainloop::FragsN;
@@ -175,16 +172,10 @@
batch_count = cute::size<3>(params.problem_shape);
}
auto M = get<0>(params.problem_shape);
auto N = get<1>(params.problem_shape);
const int sg_m = (M - 1) / get<0>(TileDpasShape{}) + 1; // sub_groups required to process A fragments
const int sg_n = (N - 1) / get<1>(TileDpasShape{}) + 1; // sub_groups required to process B fragments
return dim3(
sg_m,
cute::ceil_div(sg_n, num_sg),
batch_count
cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(WorkgroupTileShape{}))),
cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(WorkgroupTileShape{}))),
batch_count
);
}
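The rewritten get_grid_shape above now launches one work-group per workgroup tile rather than per sub-group tile. A small illustration (editorial, with a made-up 4096x4096 problem size):

```cpp
// Editorial sketch of the new grid-shape computation, one work-group per
// (wg_m x wg_n) output tile; M, N and batch below are made-up values.
#include <cstdio>

int main() {
  constexpr int wg_m = 32, wg_n = 256;      // WorkgroupTileShape M, N
  const int M = 4096, N = 4096, batch = 1;  // hypothetical problem size

  const int grid_x = (M + wg_m - 1) / wg_m;  // ceil_div(M, wg_m) = 128
  const int grid_y = (N + wg_n - 1) / wg_n;  // ceil_div(N, wg_n) = 16
  const int grid_z = batch;

  std::printf("grid = (%d, %d, %d)\n", grid_x, grid_y, grid_z);
}
```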
@@ -200,7 +191,7 @@
(void)smem_buf;
// Preconditions
CUTE_STATIC_ASSERT(is_static<TileShape>::value);
CUTE_STATIC_ASSERT(is_static<WorkgroupTileShape>::value);
// Separate out problem shape for convenience
// Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK)
@@ -218,18 +209,21 @@
// Get the appropriate blocks for this sub_group -- potential for sub_group locality
int thread_idx = int(ThreadIdxX());
auto subgroup_shape = TileDpasShape{}; // (SUB_M,SUB_N,SUB_K)
constexpr auto workgroup_shape = WorkgroupTileShape{}; // (SUB_M,SUB_N,SUB_K)
constexpr auto subgroup_shape = SubgroupTileShape{}; // (SUB_M,SUB_N,SUB_K)
const int m_coord = BlockIdxX() * get<0>(subgroup_shape);
const int n_coord = (BlockIdxY() * num_sg + thread_idx / SubgroupSize) * get<1>(subgroup_shape);
const int n_coord = BlockIdxY() * get<1>(workgroup_shape) + thread_idx / SubgroupSize * get<1>(subgroup_shape);
const int l_coord = BlockIdxZ();
Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor(make_coord(m_coord, 0, 0),
make_shape(_1{}, K, L),
make_stride(Int<FragsM>{} * get<0>(DpasShape()), _1{}));
Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor(
make_coord(m_coord, 0, 0),
make_shape(_1{}, K, L),
make_stride(Int<FragsM>{} * get<0>(MmaAtomShape()),_1{}));
Tensor tBi = params.mainloop.gmem_tiled_copy_b.get_pvc_tensor(make_coord(0, n_coord, 0),
make_shape(K, Int<FragsN>{}, L),
make_stride(_1{}, get<1>(DpasShape())));
Tensor tBi = params.mainloop.gmem_tiled_copy_b.get_pvc_tensor(
make_coord(0, n_coord, 0),
make_shape(K, Int<FragsN>{}, L),
make_stride(_1{}, get<1>(MmaAtomShape())));
// Compute tile residues for predication
auto m_max_coord = M - get<0>(subgroup_shape) * m_coord; // M - SUB_M * m_coord
@@ -263,7 +257,7 @@
Tensor tCi = gmem_tiled_copy_c.get_pvc_tensor(make_coord(m_coord, n_coord, 0),
make_shape(Int<FragsM>{}, Int<FragsN>{}, L),
make_stride(get<0>(DpasShape()), get<1>(DpasShape())));
make_stride(get<0>(MmaAtomShape()), get<1>(MmaAtomShape())));
copy(gmem_tiled_copy_c, accumulators, tCi(_,_,_,l_coord));
}
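Finally, the new coordinate computation: each work-group starts at (BlockIdxX * 32, BlockIdxY * 256), and its four sub-groups fan out along N in 64-element steps. An editorial sketch with made-up block indices:

```cpp
// Editorial sketch of the per-sub-group tile origins produced by the new
// m_coord / n_coord expressions; block indices below are made-up values.
#include <cstdio>

int main() {
  constexpr int wg_n = 256, sg_m = 32, sg_n = 64, subgroup_size = 16;
  constexpr int threads_per_wg = 64;
  const int block_idx_x = 3, block_idx_y = 2;  // hypothetical work-group indices

  for (int thread_idx = 0; thread_idx < threads_per_wg; thread_idx += subgroup_size) {
    const int m_coord = block_idx_x * sg_m;  // 96 (only one sub-group row per work-group here)
    const int n_coord = block_idx_y * wg_n
                      + (thread_idx / subgroup_size) * sg_n;  // 512, 576, 640, 704
    std::printf("sub-group %d -> tile origin (m=%d, n=%d)\n",
                thread_idx / subgroup_size, m_coord, n_coord);
  }
}
```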
