Update to Cutlass 3.5.1 #112

Merged
20 changes: 10 additions & 10 deletions examples/cute/tutorial/sgemm_sm80_sycl.cpp
@@ -150,12 +150,12 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
// Allocate the accumulators -- same size as the projected data
Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)

CUTE_STATIC_ASSERT_V( shape(tCrA) == shape(tCsA)); // (MMA,MMA_M,MMA_K)
CUTE_STATIC_ASSERT_V( shape(tCrB) == shape(tCsB)); // (MMA,MMA_N,MMA_K)
CUTE_STATIC_ASSERT_V( shape(tCrC) == shape(tCgC)); // (MMA,MMA_M,MMA_N)
CUTE_STATIC_ASSERT_V(size<1>(tCgC) == size<1>(tCsA)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(tCgC) == size<1>(tCsB)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // MMA_K
CUTE_STATIC_ASSERT_V(( shape(tCrA) == take<0,3>(shape(tCsA)))); // (MMA,MMA_M,MMA_K)
CUTE_STATIC_ASSERT_V(( shape(tCrB) == take<0,3>(shape(tCsB)))); // (MMA,MMA_N,MMA_K)
CUTE_STATIC_ASSERT_V(( shape(tCrC) == take<0,3>(shape(tCgC)))); // (MMA,MMA_M,MMA_N)
CUTE_STATIC_ASSERT_V((size<1>(tCgC) == size<1>(tCsA))); // MMA_M
CUTE_STATIC_ASSERT_V((size<2>(tCgC) == size<1>(tCsB))); // MMA_N
CUTE_STATIC_ASSERT_V((size<2>(tCsA) == size<2>(tCsB))); // MMA_K

// Clear the accumulators
clear(tCrC);
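A hedged aside on the reworked asserts (the shapes below are illustrative, not taken from the kernel): in CUTLASS 3.5.1 the partitioned smem tensors may carry modes beyond the three MMA modes (for example a pipeline-stage mode), so the register fragments are now compared against only the first three modes via `take<0,3>`.

```cpp
#include <cute/tensor.hpp>

// Minimal sketch, assuming a smem shape with one extra (pipeline) mode.
// The concrete extents and the PIPE depth here are made up for illustration.
void shape_compare_sketch() {
  using namespace cute;
  auto frag_shape = make_shape(Int<4>{}, Int<8>{}, Int<2>{});            // (MMA,MMA_M,MMA_K)
  auto smem_shape = make_shape(Int<4>{}, Int<8>{}, Int<2>{}, Int<3>{});  // (MMA,MMA_M,MMA_K,PIPE)
  // take<0,3> keeps modes [0,3), so only the MMA modes participate in the check.
  CUTE_STATIC_ASSERT_V((frag_shape == take<0,3>(smem_shape)));
}
```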
@@ -390,10 +390,10 @@ gemm_tn(int m, int n, int k,
auto bP = Int<3>{}; // Pipeline

// Define the smem layouts (static)
auto sA_atom = make_layout(make_shape ( bM, bK),
make_stride(Int<1>{}, bM+Int<1>{})); // (m,k) -> smem_idx; padded m-major
auto sB_atom = make_layout(make_shape ( bN, bK),
make_stride(Int<1>{}, bN+Int<1>{})); // (n,k) -> smem_idx; padded n-major
auto sA_atom = make_layout(make_shape ( bM, bK),
make_stride(Int<1>{}, bM+Int<1>{})); // (m,k) -> smem_idx; padded m-major
[[maybe_unused]] auto sB_atom = make_layout(make_shape ( bN, bK),
make_stride(Int<1>{}, bN+Int<1>{})); // (n,k) -> smem_idx; padded n-major
auto sA = tile_to_shape(sA_atom, make_shape(bM, bK, bP));
auto sB = tile_to_shape(sA_atom, make_shape(bN, bK, bP));
auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx
2 changes: 2 additions & 0 deletions include/cutlass/bfloat16.h
@@ -103,6 +103,7 @@ struct alignas(2) bfloat16_t {
/// Default constructor
bfloat16_t() = default;

#if !defined(CUTLASS_ENABLE_SYCL)
/// Reinterpret cast from CUDA's __nv_bfloat16 type
CUTLASS_HOST_DEVICE
explicit bfloat16_t(__nv_bfloat16 const & x) {
@@ -113,6 +114,7 @@ struct alignas(2) bfloat16_t {
std::memcpy(&storage, &raw.x, sizeof(storage));
#endif
}
#endif

/// Floating-point conversion - round toward nearest
CUTLASS_HOST_DEVICE
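For context, a hedged sketch of the guard pattern this hunk introduces (the type below is a stand-in, not the real `cutlass::bfloat16_t`): the `__nv_bfloat16` interop constructor only exists in CUDA builds, so it is compiled out when `CUTLASS_ENABLE_SYCL` is defined.

```cpp
#include <cstdint>
#include <cstring>
#if !defined(CUTLASS_ENABLE_SYCL)
#include <cuda_bf16.h>   // provides __nv_bfloat16 / __nv_bfloat16_raw
#endif

// Illustrative stand-in type; the real header wraps its CUDA interop
// constructor in the same guard.
struct bf16_sketch {
  uint16_t storage;

  bf16_sketch() = default;

#if !defined(CUTLASS_ENABLE_SYCL)
  // CUDA-only: __nv_bfloat16 is unavailable in a SYCL build.
  explicit bf16_sketch(__nv_bfloat16 const& x) {
    __nv_bfloat16_raw raw(x);
    std::memcpy(&storage, &raw.x, sizeof(storage));
  }
#endif
};
```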
2 changes: 1 addition & 1 deletion include/cutlass/complex.h
@@ -777,7 +777,7 @@ struct conjugate<complex<T>> {
}
};

#if ! defined(__CUDACC_RTC__)
#if ! defined(__CUDACC_RTC__) && !defined(CUTLASS_ENABLE_SYCL)
template <>
struct conjugate<cuFloatComplex> {
CUTLASS_HOST_DEVICE
10 changes: 6 additions & 4 deletions include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp
@@ -285,13 +285,15 @@ class CollectiveEpilogue<
problem_shape_mnkl,
TileShapeMNK{},
tile_coord_mnkl,
residue_mn,
SubgroupTileShape{},
tiled_mma,
SubgroupTileShape{}, // Epilogue tile
params.xe_load_c,
thread_idx,
cD,
residue_mn,
cD,
trC
residue_mn,
trC,
thread_idx,
};
auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks<RefSrc>(cst_args);
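A hedged note on why this hunk reshuffles the brace-initializer: the callback arguments are aggregate-initialized, so the argument list has to follow the members' declaration order, which changed upstream in 3.5.1. The structs below are illustrative only, not the real `ConsumerStoreArgs`.

```cpp
// Minimal sketch of why member order matters for aggregate initialization.
struct ArgsOld { int residue; int tile;   int thread; };
struct ArgsNew { int tile;    int thread; int residue; };  // upstream reorder

void build_args_sketch(int tile, int residue, int thread) {
  // ArgsNew wrong{residue, tile, thread};  // would still compile, silently misassigned
  ArgsNew args{tile, thread, residue};      // init list rewritten to match the new order
  (void)args;
}
```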

@@ -495,7 +495,7 @@ struct CollectiveMma<
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
Int<NumThreadsPerWarpGroup>{});

int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
int warp_group_idx = shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);

TiledMma tiled_mma;
auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx));
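These `__shfl_sync` call sites (here and in the hunks that follow) now go through an unprefixed `shfl_sync`, presumably so a SYCL build can supply its own implementation. Below is a minimal sketch of that kind of shim; the SYCL body and its explicit sub-group parameter are assumptions for illustration, not the actual CUTLASS definition.

```cpp
#if defined(CUTLASS_ENABLE_SYCL)
#include <sycl/sycl.hpp>

// SYCL side (sketch): broadcast `value` from `src_lane` of the given sub-group.
// Execution masks have no direct SYCL equivalent, so none is taken here.
template <class T>
T shfl_sync(sycl::sub_group sg, T value, int src_lane) {
  return sycl::select_from_group(sg, value, src_lane);
}
#else
// CUDA side: thin pass-through to the warp-shuffle intrinsic.
template <class T>
__device__ T shfl_sync(unsigned mask, T value, int src_lane) {
  return __shfl_sync(mask, value, src_lane);
}
#endif
```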
@@ -455,7 +455,7 @@ struct CollectiveMma<
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
Int<NumThreadsPerWarpGroup>{});

int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
int warp_group_idx = shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);

TiledMma tiled_mma;
auto mma_thread_slice = tiled_mma.get_thread_slice(thread_idx);
@@ -373,7 +373,7 @@ struct CollectiveMma<
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
Int<NumThreadsPerWarpGroup>{});

int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
int warp_group_idx = shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);

TiledMma tiled_mma;
auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx));
@@ -541,7 +541,7 @@ struct CollectiveMma<
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
Int<NumThreadsPerWarpGroup>{});

int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
int warp_group_idx = shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);

TiledMma tiled_mma;
auto mma_thread_slice = tiled_mma.get_thread_slice(thread_idx);
@@ -846,7 +846,7 @@ struct CollectiveMma<
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
Int<NumThreadsPerWarpGroup>{});

int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
int warp_group_idx = shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);

TiledMma tiled_mma;
auto mma_thread_slice = tiled_mma.get_thread_slice(thread_idx);
2 changes: 1 addition & 1 deletion include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp
@@ -411,7 +411,7 @@ struct CollectiveMma<
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
Int<NumThreadsPerWarpGroup>{});

int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
int warp_group_idx = shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);

TiledMma tiled_mma;
auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx));
@@ -446,7 +446,7 @@ struct CollectiveMma<
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
Int<NumThreadsPerWarpGroup>{});

int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
int warp_group_idx = shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);

TiledMma tiled_mma;
auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx));
@@ -446,7 +446,7 @@ struct CollectiveMma<
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
Int<NumThreadsPerWarpGroup>{});

int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
int warp_group_idx = shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);

TiledMma tiled_mma;
auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx));
2 changes: 2 additions & 0 deletions include/cutlass/gemm/device/gemm_universal_adapter.h
@@ -356,6 +356,7 @@ class GemmUniversalAdapter<
Status launch_result{ Status::kSuccess };
// Use extended launch API only for mainloops that use it
if constexpr (GemmKernel::ArchTag::kMinComputeCapability >= 90) {
#if !defined(CUTLASS_ENABLE_SYCL)
constexpr bool is_static_1x1x1 = cute::is_static_v<typename GemmKernel::DispatchPolicy::ClusterShape> and
cute::size(typename GemmKernel::DispatchPolicy::ClusterShape{}) == 1;
dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}),
@@ -400,6 +401,7 @@
}
}
}
#endif
}
else {
launch_result = Status::kSuccess;
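For context, the newly guarded region is the CUDA extended-launch path, which attaches a cluster-dimension attribute that a SYCL build cannot use. The snippet below is a hedged, simplified sketch of that CUDA runtime API, not the adapter's actual launch code; the function name and the fixed 2x1x1 cluster are assumptions for illustration.

```cpp
#include <cuda_runtime.h>

// Sketch: launch `kernel` with a 2x1x1 thread-block cluster (CUDA 12+ only).
template <class Params>
cudaError_t launch_with_cluster_sketch(void (*kernel)(Params), Params params,
                                       dim3 grid, dim3 block, size_t smem_bytes) {
  cudaLaunchConfig_t config = {};
  config.gridDim = grid;
  config.blockDim = block;
  config.dynamicSmemBytes = smem_bytes;

  cudaLaunchAttribute attrs[1];
  attrs[0].id = cudaLaunchAttributeClusterDimension;
  attrs[0].val.clusterDim.x = 2;
  attrs[0].val.clusterDim.y = 1;
  attrs[0].val.clusterDim.z = 1;
  config.attrs = attrs;
  config.numAttrs = 1;

  return cudaLaunchKernelEx(&config, kernel, params);
}
```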
4 changes: 2 additions & 2 deletions include/cutlass/pipeline/sm90_pipeline.hpp
@@ -1147,7 +1147,7 @@ pipeline_init_wait(int cluster_size) {
cute::cluster_wait();
}
else {
__syncthreads();
syncthreads();
}
}

@@ -1160,7 +1160,7 @@ pipeline_init_arrive_relaxed(int cluster_size) {
cute::cluster_arrive_relaxed();
}
else {
__syncthreads();
syncthreads();
}
}
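As with the shuffle rename above, the bare `syncthreads()` presumably lets a SYCL build substitute a work-group barrier for the CUDA intrinsic. A minimal sketch under that assumption follows; the explicit group parameter is illustrative, not the real signature.

```cpp
#if defined(CUTLASS_ENABLE_SYCL)
#include <sycl/sycl.hpp>

// SYCL side (sketch): barrier across the whole work-group.
inline void syncthreads(sycl::group<3> wg) {
  sycl::group_barrier(wg);
}
#else
// CUDA side: pass-through to the block-wide barrier intrinsic.
__device__ inline void syncthreads() {
  __syncthreads();
}
#endif
```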

2 changes: 1 addition & 1 deletion test/unit/transform/CMakeLists.txt
@@ -40,5 +40,5 @@ add_custom_target(
test_unit_transform
DEPENDS
test_unit_transform_threadblock
test_unit_transform_kernel
test_unit_transform_filter_format
)