Address PR feedback
muhammad-tanvir-1211 committed Oct 24, 2024
1 parent 2997eca commit 29fe8a5
Showing 3 changed files with 10 additions and 13 deletions.
8 changes: 4 additions & 4 deletions include/cutlass/arch/barrier.h
@@ -37,7 +37,7 @@
 #include <cutlass/arch/memory_sm75.h>
 #include <cute/arch/cluster_sm90.hpp>
 
-#if defined SYCL_INTEL_TARGET
+#if defined(SYCL_INTEL_TARGET)
 SYCL_EXTERNAL __attribute__((convergent)) void __spirv_ControlBarrierWaitINTEL(int execution_scope, int memory_scope, int memory_semantics);
 SYCL_EXTERNAL __attribute__((convergent)) void __spirv_ControlBarrierArriveINTEL(int execution_scope, int memory_scope, int memory_semantics);
 
@@ -160,10 +160,10 @@ class NamedBarrier {
 private:
 CUTLASS_DEVICE
 static void arrive_and_wait_internal(uint32_t num_threads, uint32_t barrier_id) {
-#if defined SYCL_INTEL_TARGET
+#if defined(SYCL_INTEL_TARGET)
 __spirv_ControlBarrierArriveINTEL(EXECUTION_SCOPE_WORK_GROUP, MEMORY_SCOPE_WORK_GROUP, MEMORY_SEMANTICS_RELAXED);
 __spirv_ControlBarrierWaitINTEL(EXECUTION_SCOPE_WORK_GROUP, MEMORY_SCOPE_WORK_GROUP, MEMORY_SEMANTICS_RELAXED);
-#elif defined CUDA_BARRIER_ENABLED
+#elif CUDA_BARRIER_ENABLED
 asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads));
 #elif defined(__CUDA_ARCH__)
 asm volatile ("brkpt;\n" ::);
@@ -172,7 +172,7 @@ class NamedBarrier {
 
 CUTLASS_DEVICE
 static void arrive_internal(uint32_t num_threads, uint32_t barrier_id) {
-#if defined SYCL_INTEL_TARGET
+#if defined(SYCL_INTEL_TARGET)
 __spirv_ControlBarrierArriveINTEL(EXECUTION_SCOPE_WORK_GROUP, MEMORY_SCOPE_WORK_GROUP, MEMORY_SEMANTICS_RELAXED);
 #elif CUDA_BARRIER_ENABLED
 asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads));
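For context on the path these hunks touch: below is a minimal standalone sketch of the split arrive/wait pairing used on SYCL Intel targets. It is an illustration, not this file's code; the enum values are assumptions taken from the SPIR-V specification (Workgroup scope = 2, relaxed semantics = 0), not copied from CUTLASS headers.

// Sketch only: mirrors the arrive-then-wait pairing in
// NamedBarrier::arrive_and_wait_internal above; the constants are
// assumed SPIR-V values, not CUTLASS's definitions.
enum {
  EXECUTION_SCOPE_WORK_GROUP = 2,  // assumed SPIR-V Scope::Workgroup
  MEMORY_SCOPE_WORK_GROUP    = 2,  // assumed SPIR-V Scope::Workgroup
  MEMORY_SEMANTICS_RELAXED   = 0   // assumed SPIR-V MemorySemantics: none
};

inline void work_group_barrier_sketch() {
#if defined(SYCL_INTEL_TARGET)
  // Every work-item signals arrival first, then blocks until the whole
  // work-group has arrived; splitting the two calls allows independent
  // work to be scheduled between the arrive and the wait.
  __spirv_ControlBarrierArriveINTEL(EXECUTION_SCOPE_WORK_GROUP,
                                    MEMORY_SCOPE_WORK_GROUP,
                                    MEMORY_SEMANTICS_RELAXED);
  __spirv_ControlBarrierWaitINTEL(EXECUTION_SCOPE_WORK_GROUP,
                                  MEMORY_SCOPE_WORK_GROUP,
                                  MEMORY_SEMANTICS_RELAXED);
#endif
}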
10 changes: 4 additions & 6 deletions (file name not captured in this page)
@@ -244,8 +244,6 @@ struct PersistentTileSchedulerXeStreamKParams {
 
 // Calculate the number of work units covering the data-parallel and stream-K tiles.
 // A "work unit" is a single index in the linearized ID space used by the scheduler.
-// We distinguish it from a "block," which is typically tied to a hardware unit
-// (e.g., the callers into this scheduler will be persistent thread blocks).
 // A work unit can encompass multiple output tiles worth of work (as will be the
 // case for stream-K blocks).
 // Since splitting is not required for data-parallel tiles, only one data-parallel unit
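To make the "work unit" notion in this hunk concrete, here is a toy sketch under stated assumptions: the member names and the data-parallel-before-stream-K ordering are hypothetical, not this struct's actual layout.

#include <cstdint>

// Hypothetical sketch of the linearized unit space described above:
// each data-parallel tile owns exactly one unit, while all stream-K
// tiles are shared among a smaller pool of stream-K units, each of
// which may span several output tiles' worth of k iterations.
struct UnitSpaceSketch {
  uint64_t dp_units;  // one per data-parallel tile (no splitting needed)
  uint64_t sk_units;  // shared across all stream-K tiles
  uint64_t total_units() const { return dp_units + sk_units; }
  // The ordering below is an assumption for illustration only.
  bool is_stream_k(uint64_t linear_idx) const { return linear_idx >= dp_units; }
};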
@@ -438,12 +436,12 @@ struct PersistentTileSchedulerXeStreamKParams {
 static uint64_t
 get_num_sk_units(uint64_t wgs_per_sk_wave, uint32_t sk_tiles, uint32_t k_tiles_per_output_tile) {
 // If there are stream-K tiles to compute and a sufficiently large number of k iterations
-// across them, they will be covered by a single wave of persistent threadblocks. Thus, there
-// will be as many work units as there are threadblocks in a single wave.
+// across them, they will be covered by a single wave of persistent work_groups. Thus, there
+// will be as many work units as there are work_groups in a single wave.
 //
 // When the total k iterations across stream-K tiles is too small to justify distributing
-// across an entire wave of blocks, we instead distribute the iterations over a smaller
-// set of blocks.
+// across an entire wave of work_groups, we instead distribute the iterations over a smaller
+// set of work_groups.
 
 // Calculate the number of stream-K units that would be needed if each stream-K unit
 // computed the minimum allowable k iterations.
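A minimal sketch of the policy the rewritten comment describes, assuming a hypothetical minimum-iterations floor; the real function's constants and clamping logic differ.

#include <algorithm>
#include <cstdint>

// Hypothetical floor: the real scheduler's constant differs.
constexpr uint64_t min_k_iters_per_sk_unit = 8;

static uint64_t get_num_sk_units_sketch(uint64_t wgs_per_sk_wave,
                                        uint32_t sk_tiles,
                                        uint32_t k_tiles_per_output_tile) {
  uint64_t total_k_iters = uint64_t(sk_tiles) * k_tiles_per_output_tile;
  // Cap at one unit per work-group in a full stream-K wave; when the
  // total k iterations are too few to justify a full wave, shrink the
  // unit count so each unit still receives at least the minimum.
  uint64_t units_by_iters =
      std::max<uint64_t>(1, total_k_iters / min_k_iters_per_sk_unit);
  return std::min(wgs_per_sk_wave, units_by_iters);
}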
5 changes: 2 additions & 3 deletions include/cutlass/gemm/kernel/xe_tile_scheduler_streamk.hpp
@@ -245,16 +245,15 @@ class PersistentTileSchedulerXeStreamK {
 current_work_linear_idx_ += uint64_t(GridDimX()) * uint64_t(GridDimY()) * uint64_t(GridDimZ()) * uint64_t(advance_count);
 }
 
-// Given the inputs, computes the total number of output blocks this problem will compute over
-// Note that this is only the logical size of our grid, not the physical grid we will actually launch.
+// Given the inputs, computes the total number of output work-groups this problem will compute over.
 template <class ProblemShape>
 CUTLASS_HOST_DEVICE static
 dim3
 get_tiled_wg_shape_mnl(ProblemShape problem_shape_mnkl, TileShape cta_shape) {
 return Params::get_tiled_wg_shape_mnl(to_gemm_coord(problem_shape_mnkl), to_gemm_coord(cta_shape));
 }
 
-// Given the cluster shape, computes the physical grid we should launch.
+// Computes the physical grid we should launch.
 template <class ProblemShape>
 CUTLASS_HOST_DEVICE static
 dim3
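To illustrate the logical-versus-physical distinction the two rewritten comments draw, here is a self-contained sketch; every name in it is a hypothetical stand-in, not this class's API.

#include <algorithm>
#include <cstdint>

struct Dim3 { unsigned x, y, z; };  // stand-in for the launch dim3

// Logical grid: one entry per output tile (M/N tiles times batch L),
// independent of how many work-groups the device can keep resident.
Dim3 tiled_wg_shape_mnl_sketch(int M, int N, int L, int tile_m, int tile_n) {
  return { unsigned((M + tile_m - 1) / tile_m),
           unsigned((N + tile_n - 1) / tile_n),
           unsigned(L) };
}

// Physical grid: clamped to the device's persistent work-group count,
// so each launched work-group loops over multiple logical tiles.
unsigned physical_grid_size_sketch(Dim3 logical, unsigned max_persistent_wgs) {
  uint64_t tiles = uint64_t(logical.x) * logical.y * logical.z;
  return unsigned(std::min<uint64_t>(tiles, max_persistent_wgs));
}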
