Skip to content

Commit

Permalink
Add Epilogue Pipeline for PVC using EVT (codeplaysoftware#80)
Browse files Browse the repository at this point in the history
This PR introduces the Epilogue implementation for PVC using the Epilogue Visitor Tree (EVT) available for NVIDIA SM90 (and later) GPUs. Through this PR we only support the fusion::LinearCombination operation for PVC, i.e. D = alpha * A * B + beta * C, but it can be extended further to add other fusion operations by partial specialization of the FusionCallBacks struct available in the include/cutlass/epilogue/fusion/intel_pvc_callbacks.hpp file.

---------

Co-authored-by: Alejandro Acosta <[email protected]>
  • Loading branch information
2 people authored and Jiaxingla committed Jul 16, 2024
1 parent 773dc82 commit d270223
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 6 deletions.
2 changes: 1 addition & 1 deletion include/cute/arch/copy_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ struct XE_2D_U16x16x16x1x1_V
struct XE_2D_U32x8x16x1x1_ST_N
{
template <class T>
CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height,
CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height,
int pitch, intel::coord_t coord, const T *src) {
static_assert(sizeof(T) == 4, "Expected T to have size 4");
__builtin_IB_subgroup_block_write_flat_u32_m8k16v1(
Expand Down
4 changes: 4 additions & 0 deletions include/cutlass/epilogue/collective/collective_epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@ class CollectiveEpilogue {
#include "default_epilogue.hpp"
#include "default_epilogue_array.hpp"
#include "epilogue_tensor_broadcast.hpp"
#if defined (SYCL_INTEL_TARGET)
#include "intel_pvc_epilogue.hpp"
#else
#include "sm70_epilogue_vectorized.hpp"
#include "sm90_epilogue_tma_warpspecialized.hpp"
#include "sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp"
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
6 changes: 6 additions & 0 deletions include/cutlass/epilogue/dispatch_policy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,12 @@ struct Sm90TmaWarpSpecializedBiasElementwise {
constexpr static int FragmentSize = FragmentSize_;
};

#if defined (SYCL_INTEL_TARGET)
// Dispatch-policy tag selecting the Intel PVC collective epilogue
// (see include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp).
struct IntelPVCEpilogue {
  // Width of a PVC hardware sub-group; epilogue work is partitioned
  // across lanes of this size.
  static constexpr int SubgroupSize = 16;
};
#endif

//////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::epilogue
Original file line number Diff line number Diff line change
Expand Up @@ -910,7 +910,7 @@ using Sm90ScaledLinCombPerRowBiasEltActAmaxAuxNotFp8 =
Sm90EVT<Sm90Compute<ActivationFn, ElementCompute, ElementCompute, RoundStyle>, // activation(Z)
Sm90EVT<Sm90AuxStore<StagesD, EpilogueTile, ElementAux, RoundStyle, StrideAux, SmemLayoutAtom, CopyOpR2S, AlignmentAux>, // Aux = Z
// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias
Sm90ScaledLinCombPerRowBias<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>,
Sm90ScaledLinCombPerRowBias<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>
>
>
>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ struct Sm90TreeVisitor<
if (lane_idx == i) {
copy_if(FunctionPredTensor(predicate_fn), tC_rAux, tC_gAux);
}
__syncwarp();
syncwarp();
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,7 @@ struct Sm90RowReduction {
//
if constexpr (not IsAtomic && FinalReduction) {
// Ensure gmem writes are visible to other threads before incrementing counter
__threadfence();
threadfence();
sync_fn();
// Collective thread 0 increments atomic tile counter and copies value to smem
int* prev_tile_count = reinterpret_cast<int*>(raw_pointer_cast(smem_buffer.data()));
Expand Down Expand Up @@ -1255,7 +1255,7 @@ struct Sm90ColReduction {
//
if constexpr (not IsAtomic && FinalReduction) {
// Ensure gmem writes are visible to other threads before incrementing counter
__threadfence();
threadfence();
sync_fn();
// Collective thread 0 increments atomic tile counter and copies value to smem
int* prev_tile_count = reinterpret_cast<int*>(raw_pointer_cast(smem_buffer.data()));
Expand Down
9 changes: 8 additions & 1 deletion include/cutlass/gemm/kernel/intel_pvc_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ class GemmUniversal<
static constexpr int VecC = CollectiveMainloop::VecC;
// Kernel level shared memory storage
  // Kernel-level shared (local) memory storage. Currently only the
  // collective epilogue reserves shared memory; the mainloop on PVC
  // does not contribute a member here.
  struct SharedStorage {
    using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
    // Backing storage handed to the collective epilogue instance.
    EpilogueTensorStorage epilogue;
  };
// Device side arguments
struct Arguments {
GemmUniversalMode mode{};
Expand Down Expand Up @@ -197,7 +203,7 @@ class GemmUniversal<
CUTLASS_DEVICE
void operator()(Params const& params, char* smem_buf) {
(void)smem_buf;
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
// Preconditions
CUTE_STATIC_ASSERT(is_static<TileShape>::value);
Expand Down Expand Up @@ -231,6 +237,7 @@ class GemmUniversal<
BlockIdxX() * CollectiveMainloop::wg_tile_n +
(get_sub_group_id() % sg_per_wg_n) * CollectiveMainloop::sg_tile_n;
const int l_coord = BlockIdxZ();
const auto tile_coord = make_coord(m_coord, n_coord, _, l_coord);
Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor(
make_coord(m_coord, 0, l_coord), make_shape(_1{}, K, _1{}),
Expand Down

0 comments on commit d270223

Please sign in to comment.