Add workgroup level TileShape #84

Merged
20 changes: 11 additions & 9 deletions benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp
@@ -67,10 +67,8 @@ int main(int argc, const char** argv)
// to use a GPU other than that with device ID 0.
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

-  bool passed;
-
-  // The code section below describes datatype for input, output matrices and computation between
-  // elements in input matrices.
+  // The code section below describes datatype for input, output matrices and computation between
+  // elements in input matrices.
using ElementAccumulator = float; // <- data type of accumulator
using ElementComputeEpilogue = float; // <- data type of epilogue operations
using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A
@@ -82,16 +80,20 @@ int main(int argc, const char** argv)
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;

-  using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N;
-  using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N;
+  // Workgroup-level tile
+  using TileShape = Shape<_32, _256, _32>;

-  using TileShape = Shape<_32, _64, _32>;
+  using TiledMma = TiledMMA<
+      MMA_Atom<XE_8x16x16_F32BF16BF16F32_TN>,
+      Layout<Shape<_1,_1,_1>>,
+      Tile<_32,_64,_32>>; // Subgroup level-tile

-  using TiledMma = TiledMMA<MMA_Atom<XE_8x16x16_BF16BF16F32F32_NN>,
-                            Layout<Shape<_8,_16,_1>>>;
+  using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N;
+  using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N;

using DispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated;

// This code section describes the epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput, // <- data type of output matrix
128 / cutlass::sizeof_bits<ElementOutput>::value, // <- the number of elements per vectorized
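Note on the benchmark change: TileShape now names the workgroup-level tile (32x256x32), while the Tile<_32,_64,_32> inside TiledMMA names the subgroup-level tile. A minimal standalone sanity check of the arithmetic this split implies, using plain C++ constants that mirror the shapes above (not the actual CuTe types):

  // Standalone sketch: how many subgroups and work items the tile split implies.
  int main() {
    // Workgroup-level tile: Shape<_32, _256, _32> (assumed from the diff above).
    constexpr int wg_m = 32, wg_n = 256, wg_k = 32;
    // Subgroup-level tile: Tile<_32, _64, _32> (assumed from the diff above).
    constexpr int sg_m = 32, sg_n = 64, sg_k = 32;
    constexpr int subgroup_size = 16; // PVC subgroup size

    // Subgroups per workgroup = product of the per-mode ratios.
    constexpr int sgs_per_wg = (wg_m / sg_m) * (wg_n / sg_n) * (wg_k / sg_k);
    static_assert(sgs_per_wg == 4, "32x256x32 / 32x64x32 = 4 subgroups");

    // Work items per workgroup, matching the new MaxThreadsPerBlock formula.
    static_assert(sgs_per_wg * subgroup_size == 64, "4 subgroups x 16 lanes");
    return 0;
  }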
5 changes: 3 additions & 2 deletions examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp
@@ -353,11 +353,12 @@ int main(int argc, const char** argv)
using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N;
using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N;

-  using TileShape = Shape<_1, _1, _1>;
+  // Workgroup-level tile
+  using TileShape = Shape<_32, _256, _32>;

using TiledMma = TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TN>,
Layout<Shape<_1,_1,_1>>,
-                            Tile<_32,_64,_32>>;
+                            Tile<_32,_64,_32>>; // Subgroup level-tile

using DispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated;

28 changes: 15 additions & 13 deletions include/cutlass/gemm/collective/intel_pvc_mma.hpp
@@ -81,7 +81,7 @@ struct CollectiveMma<
// Type Aliases
//
using DispatchPolicy = MainloopIntelPVCUnpredicated;
-  using TileShape = TileShape_;
+  using WorkgroupTileShape = TileShape_;
using ElementA = ElementA_;
using StrideA = StrideA_;
using ElementB = ElementB_;
@@ -100,21 +100,22 @@

static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;

-  using DpasShape = typename TiledMma::Shape_MNK;
-  using TileDpasShape = decltype(tile_shape(TiledMma()));
+  using MmaAtomShape = typename TiledMma::AtomShape_MNK;
+  using SubgroupTileShape = decltype(tile_shape(TiledMma()));

-  static constexpr uint32_t MaxThreadsPerBlock = get<0>(DpasShape()) * get<1>(DpasShape());
+  static constexpr uint32_t MaxThreadsPerBlock =
+      cute::size(WorkgroupTileShape{}) / cute::size(SubgroupTileShape{}) * SubgroupSize;

-  static constexpr int FragsM = get<0>(TileDpasShape{}) / get<0>(DpasShape()); // A frags per sub_group
-  static constexpr int FragsN = get<1>(TileDpasShape{}) / get<1>(DpasShape()); // B frags per sub_group
-  static constexpr int FragsK = get<2>(TileDpasShape{}) / get<2>(DpasShape());
+  static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group
+  static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group
+  static constexpr int FragsK = get<2>(SubgroupTileShape{}) / get<2>(MmaAtomShape());

// Calculate the vector width based on the amount of registers
// required per work item by dividing the total fragment size by
// the sub_group size.
-  static constexpr int VecC = (get<1>(DpasShape()) * get<0>(DpasShape())) / SubgroupSize;
-  static constexpr int VecA = (get<0>(DpasShape()) * get<2>(DpasShape())) / SubgroupSize;
-  static constexpr int VecB = (get<1>(DpasShape()) * get<2>(DpasShape())) / SubgroupSize;
+  static constexpr int VecC = (get<1>(MmaAtomShape()) * get<0>(MmaAtomShape())) / SubgroupSize;
+  static constexpr int VecA = (get<0>(MmaAtomShape()) * get<2>(MmaAtomShape())) / SubgroupSize;
+  static constexpr int VecB = (get<1>(MmaAtomShape()) * get<2>(MmaAtomShape())) / SubgroupSize;

// Host side kernel arguments
struct Arguments {
@@ -186,8 +187,9 @@ struct CollectiveMma<
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");

// Tensor to hold input data
-    Tensor tAr = make_tensor<typename TiledMma::ValTypeA>(Shape<Int<get<0>(TileDpasShape{}) * FragsK>, Int<1>>{});
-    Tensor tBr = make_tensor<typename TiledMma::ValTypeB>(Shape<Int<FragsK * get<1>(TileDpasShape{}) / FragsN>, Int<FragsN>>{});
+    Tensor tAr = make_tensor<typename TiledMma::ValTypeA>(Shape<Int<get<0>(SubgroupTileShape{}) * FragsK>, Int<1>>{});
+    Tensor tBr = make_tensor<typename TiledMma::ValTypeB>(
+        Shape<Int<FragsK * get<1>(SubgroupTileShape{}) / FragsN>, Int<FragsN>>{});

Tensor tAr_view = make_tensor(static_cast<decltype(tAr) &&>(tAr).data(),
Shape<Int<VecA>, Int<FragsM>, Int<FragsK>>{});
@@ -200,7 +202,7 @@
//
// Mainloop
//
-    for (int k_tile = 0, k = 0; k_tile < k_tile_count; ++k_tile, k += get<2>(DpasShape()) * FragsK)
+    for (int k_tile = 0, k = 0; k_tile < k_tile_count; ++k_tile, k += get<2>(MmaAtomShape()) * FragsK)
{
// Copy gmem to rmem for the first k_tile
copy(mainloop.gmem_tiled_copy_a, gA(_,_,k), tAr);
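To see what the renamed constants evaluate to, here is a compile-time check with plain integers that mirror the shapes from the examples (8x16x16 DPAS atom, 32x64x32 subgroup tile, 32x256x32 workgroup tile); the names are illustrative stand-ins, not the real CuTe types:

  // Sketch: evaluate FragsM/N/K, VecA/B/C and MaxThreadsPerBlock for the example shapes.
  constexpr int atom_m = 8, atom_n = 16, atom_k = 16;  // XE_8x16x16 DPAS atom
  constexpr int sg_m = 32, sg_n = 64, sg_k = 32;       // subgroup tile
  constexpr int wg_m = 32, wg_n = 256, wg_k = 32;      // workgroup tile
  constexpr int subgroup_size = 16;

  constexpr int frags_m = sg_m / atom_m;  // A fragments per subgroup
  constexpr int frags_n = sg_n / atom_n;  // B fragments per subgroup
  constexpr int frags_k = sg_k / atom_k;  // K fragments per subgroup
  static_assert(frags_m == 4 && frags_n == 4 && frags_k == 2, "32x64x32 over 8x16x16");

  // Vector widths: one atom face's elements divided across the subgroup lanes.
  static_assert((atom_n * atom_m) / subgroup_size == 8,  "VecC");
  static_assert((atom_m * atom_k) / subgroup_size == 8,  "VecA");
  static_assert((atom_n * atom_k) / subgroup_size == 16, "VecB");

  // MaxThreadsPerBlock: subgroups per workgroup times the subgroup size.
  static_assert((wg_m * wg_n * wg_k) / (sg_m * sg_n * sg_k) * subgroup_size == 64,
                "4 subgroups x 16 lanes");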
48 changes: 21 additions & 27 deletions include/cutlass/gemm/kernel/intel_pvc_gemm.hpp
@@ -65,7 +65,8 @@ class GemmUniversal<

// Mainloop derived types
using CollectiveMainloop = CollectiveMainloop_;
-  using TileShape = typename CollectiveMainloop::TileShape;
+  using TileShape = typename CollectiveMainloop::WorkgroupTileShape;
+  using WorkgroupTileShape = TileShape;
using TiledMma = typename CollectiveMainloop::TiledMma;
using ArchTag = typename CollectiveMainloop::ArchTag;
using ElementA = typename CollectiveMainloop::ElementA;
@@ -81,7 +82,7 @@
"Intel PVC does not support specializing the tile scheduler.");
using TileSchedulerTag = TileScheduler_;
using TileScheduler = typename detail::TileSchedulerSelector<
-      TileScheduler_, ArchTag, TileShape,
+      TileScheduler_, ArchTag, WorkgroupTileShape,
cute::Shape<cute::Int<1>, cute::Int<1>, cute::Int<1>>>::Scheduler;
using TileSchedulerArguments = typename TileScheduler::Arguments;

@@ -101,13 +102,9 @@

static constexpr int SubgroupSize = CollectiveMainloop::SubgroupSize; // sub_group size
static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock;
-  static constexpr uint32_t MinBlocksPerMultiprocessor = CollectiveMainloop::MinBlocksPerMultiprocessor;

static constexpr int num_sg = MaxThreadsPerBlock / SubgroupSize; // number of sub_groups per work group

-  using DpasShape = typename CollectiveMainloop::DpasShape;
-  using TileDpasShape = typename CollectiveMainloop::TileDpasShape;
+  using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape;
+  using SubgroupTileShape = typename CollectiveMainloop::SubgroupTileShape;

static constexpr int FragsM = CollectiveMainloop::FragsM;
static constexpr int FragsN = CollectiveMainloop::FragsN;
@@ -175,16 +172,10 @@
batch_count = cute::size<3>(params.problem_shape);
}

-    auto M = get<0>(params.problem_shape);
-    auto N = get<1>(params.problem_shape);
-
-    const int sg_m = (M - 1) / get<0>(TileDpasShape{}) + 1; // sub_groups required to process A fragments
-    const int sg_n = (N - 1) / get<1>(TileDpasShape{}) + 1; // sub_groups required to process B fragments
-
     return dim3(
-        sg_m,
-        cute::ceil_div(sg_n, num_sg),
-        batch_count
+        cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(WorkgroupTileShape{}))),
+        cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(WorkgroupTileShape{}))),
+        batch_count
);
}
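For a concrete problem size, the new grid computation agrees with the old subgroup-counting one. A small sketch, with M, N and the tile sizes assumed from the examples above:

  // Sketch: old vs. new grid_shape arithmetic for an assumed 4096x4096 problem.
  #include <cstdio>

  constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

  int main() {
    constexpr int M = 4096, N = 4096;
    constexpr int wg_m = 32, wg_n = 256;  // WorkgroupTileShape M/N
    constexpr int sg_m = 32, sg_n = 64;   // SubgroupTileShape M/N
    constexpr int num_sg = 4;             // subgroups per workgroup (see mma.hpp)

    // New: divide the problem directly by the workgroup tile.
    constexpr int grid_x = ceil_div(M, wg_m);  // 128
    constexpr int grid_y = ceil_div(N, wg_n);  // 16

    // Old: count subgroup tiles, then fold num_sg into Y. Same result.
    static_assert(ceil_div(M, sg_m) == grid_x, "X unchanged");
    static_assert(ceil_div(ceil_div(N, sg_n), num_sg) == grid_y, "Y unchanged");

    std::printf("grid = (%d, %d)\n", grid_x, grid_y);
    return 0;
  }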

@@ -200,7 +191,7 @@
(void)smem_buf;

// Preconditions
-    CUTE_STATIC_ASSERT(is_static<TileShape>::value);
+    CUTE_STATIC_ASSERT(is_static<WorkgroupTileShape>::value);

// Separate out problem shape for convenience
// Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK)
@@ -218,18 +209,21 @@

// Get the appropriate blocks for this sub_group -- potential for sub_group locality
int thread_idx = int(ThreadIdxX());
-    auto subgroup_shape = TileDpasShape{};                 // (SUB_M,SUB_N,SUB_K)
+    constexpr auto workgroup_shape = WorkgroupTileShape{}; // (BLK_M,BLK_N,BLK_K)
+    constexpr auto subgroup_shape = SubgroupTileShape{};   // (SUB_M,SUB_N,SUB_K)
     const int m_coord = BlockIdxX() * get<0>(subgroup_shape);
-    const int n_coord = (BlockIdxY() * num_sg + thread_idx / SubgroupSize) * get<1>(subgroup_shape);
+    const int n_coord = BlockIdxY() * get<1>(workgroup_shape) + thread_idx / SubgroupSize * get<1>(subgroup_shape);
const int l_coord = BlockIdxZ();

-    Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor(make_coord(m_coord, 0, 0),
-        make_shape(_1{}, K, L),
-        make_stride(Int<FragsM>{} * get<0>(DpasShape()), _1{}));
+    Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor(
+        make_coord(m_coord, 0, 0),
+        make_shape(_1{}, K, L),
+        make_stride(Int<FragsM>{} * get<0>(MmaAtomShape()), _1{}));

-    Tensor tBi = params.mainloop.gmem_tiled_copy_b.get_pvc_tensor(make_coord(0, n_coord, 0),
-        make_shape(K, Int<FragsN>{}, L),
-        make_stride(_1{}, get<1>(DpasShape())));
+    Tensor tBi = params.mainloop.gmem_tiled_copy_b.get_pvc_tensor(
+        make_coord(0, n_coord, 0),
+        make_shape(K, Int<FragsN>{}, L),
+        make_stride(_1{}, get<1>(MmaAtomShape())));

// Compute tile residues for predication
auto m_max_coord = M - get<0>(subgroup_shape) * m_coord; // M - SUB_M * m_coord
@@ -263,7 +257,7 @@

Tensor tCi = gmem_tiled_copy_c.get_pvc_tensor(make_coord(m_coord, n_coord, 0),
make_shape(Int<FragsM>{}, Int<FragsN>{}, L),
-        make_stride(get<0>(DpasShape()), get<1>(DpasShape())));
+        make_stride(get<0>(MmaAtomShape()), get<1>(MmaAtomShape())));

copy(gmem_tiled_copy_c, accumulators, tCi(_,_,_,l_coord));
}
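The new n_coord is algebraically the same as the old one once num_sg is folded into the workgroup tile: (BlockIdxY() * num_sg + sg_id) * SG_N equals BlockIdxY() * WG_N + sg_id * SG_N whenever WG_N == num_sg * SG_N. A compile-time sketch with assumed values:

  // Sketch: old vs. new n_coord for the example tiles (plain C++, assumed values).
  constexpr int sg_n = 64;   // get<1>(SubgroupTileShape{})
  constexpr int wg_n = 256;  // get<1>(WorkgroupTileShape{})
  constexpr int num_sg = 4;  // wg_n / sg_n

  constexpr int n_coord_old(int block_y, int sg_id) {
    return (block_y * num_sg + sg_id) * sg_n;
  }
  constexpr int n_coord_new(int block_y, int sg_id) {
    return block_y * wg_n + sg_id * sg_n;
  }

  static_assert(n_coord_old(0, 0) == n_coord_new(0, 0), "origin");
  static_assert(n_coord_old(0, 3) == n_coord_new(0, 3), "last subgroup, 192");
  static_assert(n_coord_old(5, 2) == n_coord_new(5, 2), "interior block, 1408");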