diff --git a/include/cutlass/arch/barrier.h b/include/cutlass/arch/barrier.h
index db2777e24..0e1f344f2 100644
--- a/include/cutlass/arch/barrier.h
+++ b/include/cutlass/arch/barrier.h
@@ -37,7 +37,7 @@
 #include
 #include
 
-#if defined SYCL_INTEL_TARGET
+#if defined(SYCL_INTEL_TARGET)
 SYCL_EXTERNAL __attribute__((convergent)) void __spirv_ControlBarrierWaitINTEL(int execution_scope, int memory_scope, int memory_semantics);
 SYCL_EXTERNAL __attribute__((convergent)) void __spirv_ControlBarrierArriveINTEL(int execution_scope, int memory_scope, int memory_semantics);
 
@@ -160,10 +160,10 @@ class NamedBarrier {
 private:
   CUTLASS_DEVICE
   static void arrive_and_wait_internal(uint32_t num_threads, uint32_t barrier_id) {
-#if defined SYCL_INTEL_TARGET
+#if defined(SYCL_INTEL_TARGET)
     __spirv_ControlBarrierArriveINTEL(EXECUTION_SCOPE_WORK_GROUP, MEMORY_SCOPE_WORK_GROUP, MEMORY_SEMANTICS_RELAXED);
     __spirv_ControlBarrierWaitINTEL(EXECUTION_SCOPE_WORK_GROUP, MEMORY_SCOPE_WORK_GROUP, MEMORY_SEMANTICS_RELAXED);
-#elif defined CUDA_BARRIER_ENABLED
+#elif CUDA_BARRIER_ENABLED
     asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads));
 #elif defined(__CUDA_ARCH__)
     asm volatile ("brkpt;\n" ::);
@@ -172,7 +172,7 @@ class NamedBarrier {
 
   CUTLASS_DEVICE
   static void arrive_internal(uint32_t num_threads, uint32_t barrier_id) {
-#if defined SYCL_INTEL_TARGET
+#if defined(SYCL_INTEL_TARGET)
     __spirv_ControlBarrierArriveINTEL(EXECUTION_SCOPE_WORK_GROUP, MEMORY_SCOPE_WORK_GROUP, MEMORY_SEMANTICS_RELAXED);
 #elif CUDA_BARRIER_ENABLED
     asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads));
diff --git a/include/cutlass/gemm/kernel/xe_persistent_tile_scheduler_params_streamk.hpp b/include/cutlass/gemm/kernel/xe_persistent_tile_scheduler_params_streamk.hpp
index 443e2cf10..b98613b54 100644
--- a/include/cutlass/gemm/kernel/xe_persistent_tile_scheduler_params_streamk.hpp
+++ b/include/cutlass/gemm/kernel/xe_persistent_tile_scheduler_params_streamk.hpp
@@ -244,8 +244,6 @@ struct PersistentTileSchedulerXeStreamKParams {
 
     // Calculate the number of work units covering the data-parallel and stream-K tiles.
     // A "work unit" is a single index in the linearized ID space used by the scheduler.
-    // We distinguish it from a "block," which is typically tied to a hardware unit
-    // (e.g., the callers into this scheduler will be persistent thread blocks).
    // A work unit can encompass multiple output tiles worth of work (as will be the
     // case for stream-K blocks).
     // Since splitting is not required for data-parallel tiles, only one data-parallel unit
@@ -438,12 +436,12 @@ struct PersistentTileSchedulerXeStreamKParams {
   static uint64_t
   get_num_sk_units(uint64_t wgs_per_sk_wave, uint32_t sk_tiles, uint32_t k_tiles_per_output_tile) {
     // If there are stream-K tiles to compute and a sufficiently large number of k iterations
-    // across them, they will be covered by a single wave of persistent threadblocks. Thus, there
-    // will be as many work units as there are threadblocks in a single wave.
+    // across them, they will be covered by a single wave of persistent work_groups. Thus, there
+    // will be as many work units as there are work_groups in a single wave.
     //
     // When the total k iterations across stream-K tiles is too small to justify distributing
-    // across an entire wave of blocks, we instead distribute the iterations over a smaller
-    // set of blocks.
+    // across an entire wave of work_groups, we instead distribute the iterations over a smaller
+    // set of work_groups.
 
     // Calculate the number of stream-K units that would be needed if each stream-K unit
     // computed the minimum allowable k iterations.
diff --git a/include/cutlass/gemm/kernel/xe_tile_scheduler_streamk.hpp b/include/cutlass/gemm/kernel/xe_tile_scheduler_streamk.hpp
index 1caf2cba0..4b405b39c 100644
--- a/include/cutlass/gemm/kernel/xe_tile_scheduler_streamk.hpp
+++ b/include/cutlass/gemm/kernel/xe_tile_scheduler_streamk.hpp
@@ -245,8 +245,7 @@ class PersistentTileSchedulerXeStreamK {
     current_work_linear_idx_ += uint64_t(GridDimX()) * uint64_t(GridDimY()) * uint64_t(GridDimZ()) * uint64_t(advance_count);
   }
 
-  // Given the inputs, computes the total number of output blocks this problem will compute over
-  // Note that this is only the logical size of our grid, not the physical grid we will actually launch.
+  // Given the inputs, computes the total number of output work-groups this problem will compute over.
   template
   CUTLASS_HOST_DEVICE static
   dim3
@@ -254,7 +253,7 @@
     return Params::get_tiled_wg_shape_mnl(to_gemm_coord(problem_shape_mnkl), to_gemm_coord(cta_shape));
   }
 
-  // Given the cluster shape, computes the physical grid we should launch.
+  // Computes the physical grid we should launch.
   template
   CUTLASS_HOST_DEVICE static
   dim3
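Side note on the preprocessor changes in barrier.h: SYCL_INTEL_TARGET is wrapped as defined(SYCL_INTEL_TARGET) because that macro is either present or absent, while the CUDA branch drops `defined` and tests the value of CUDA_BARRIER_ENABLED directly. Assuming the usual CUTLASS convention that CUDA_BARRIER_ENABLED is always defined, to either 0 or 1 (an assumption; its definition is not visible in this diff), the minimal standalone sketch below shows why `#elif defined CUDA_BARRIER_ENABLED` would take the CUDA barrier path even when that path is disabled:

// why_value_test.cpp -- illustration only, not part of the patch.
// Build: g++ -std=c++17 why_value_test.cpp && ./a.out
#include <cstdio>

// Assumed CUTLASS-style convention: the macro always exists and carries a 0/1 value.
#define CUDA_BARRIER_ENABLED 0

int main() {
#if defined(CUDA_BARRIER_ENABLED)
  // Taken even though the feature is disabled: `defined` only checks that the
  // macro exists, not what it expands to.
  std::puts("defined(CUDA_BARRIER_ENABLED): branch taken");
#endif

#if CUDA_BARRIER_ENABLED
  std::puts("CUDA_BARRIER_ENABLED: branch taken");
#else
  // Taken here: testing the value respects the disabled state, which is what
  // the `#elif CUDA_BARRIER_ENABLED` form in the patch relies on.
  std::puts("CUDA_BARRIER_ENABLED: branch skipped");
#endif
  return 0;
}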