From 99fb12b8716ea5c9b30775cf6b9118e3e0d74a84 Mon Sep 17 00:00:00 2001 From: samnordmann Date: Mon, 23 Dec 2024 14:48:31 +0100 Subject: [PATCH 1/2] Host IR: add `GetCurrentStream` (#3605) # What adds the primitive `GetCurrentStream` to Host Ir stack. # Why needed for - https://github.com/NVIDIA/Fuser/pull/3606 The idea is that if we want to use multiple stream internally, we need before hand to capture the user stream and to set it back to being the active stream when returning --- csrc/dispatch.h | 1 + csrc/host_ir/executor.cpp | 7 +++++++ csrc/host_ir/executor.h | 1 + csrc/host_ir/host_ir.cpp | 16 ++++++++++++++++ csrc/host_ir/host_ir.h | 22 ++++++++++++++++++++++ tests/cpp/test_host_irs.cpp | 20 ++++++++++++++++++++ 6 files changed, 67 insertions(+) diff --git a/csrc/dispatch.h b/csrc/dispatch.h index 4fe0f86cc5f..77b650b88dc 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -146,6 +146,7 @@ class Val; f(HostUnit); \ f(PostOnStream); \ f(SetCurrentStream); \ + f(GetCurrentStream); \ f(Wait); \ f(Synchronize); \ f(StartCoalescing); \ diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index 69b5b9c704d..1b2554cdabb 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -274,6 +274,13 @@ void HostIrEvaluator::handle(SetCurrentStream* set_current_stream) { setCurrentCUDAStream(getCUDAStream(set_current_stream->stream())); } +void HostIrEvaluator::handle(GetCurrentStream* get_current_stream) { + streams_.insert( + {get_current_stream->stream(), + c10::cuda::getCurrentCUDAStream( + static_cast(my_device_index_))}); +} + void HostIrEvaluator::handle(Synchronize* synchronize) { cudaStream_t current_stream = c10::cuda::getCurrentCUDAStream( diff --git a/csrc/host_ir/executor.h b/csrc/host_ir/executor.h index 6f9070b810a..a51dc32aed4 100644 --- a/csrc/host_ir/executor.h +++ b/csrc/host_ir/executor.h @@ -112,6 +112,7 @@ class HostIrEvaluator final : public OptOutDispatch { private: using OptOutDispatch::handle; void handle(SetCurrentStream* set_current_stream) override; + void handle(GetCurrentStream* get_current_stream) override; void handle(Synchronize* synchronize) override; void handle(PostOnStream* post_ir) override; void handle(Communication* communication) override; diff --git a/csrc/host_ir/host_ir.cpp b/csrc/host_ir/host_ir.cpp index 492b2b22aab..49b33f59823 100644 --- a/csrc/host_ir/host_ir.cpp +++ b/csrc/host_ir/host_ir.cpp @@ -179,6 +179,22 @@ bool SetCurrentStream::sameAs(const Statement* other) const { return false; } +GetCurrentStream::GetCurrentStream(IrBuilderPasskey passkey) : Expr(passkey) { + NVF_ERROR(passkey.ir_container_ != nullptr); + NVF_ERROR(passkey.ir_container_->isA()); + auto stream = IrBuilder::createInContainer(passkey.ir_container_); + addAttribute(stream); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(GetCurrentStream) + +std::string GetCurrentStream::toString(int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << "GetCurrentStream into " << stream()->toString() + << std::endl; + return ss.str(); +} + Wait::Wait(IrBuilderPasskey passkey, Expr* expr) : Expr(passkey, {}, {}, {expr}) { NVF_ERROR(passkey.ir_container_ != nullptr); diff --git a/csrc/host_ir/host_ir.h b/csrc/host_ir/host_ir.h index 587ffc43638..82d67d6f4cc 100644 --- a/csrc/host_ir/host_ir.h +++ b/csrc/host_ir/host_ir.h @@ -161,6 +161,28 @@ class SetCurrentStream : public Expr { } }; +class GetCurrentStream : public Expr { + public: + using Expr::Expr; + GetCurrentStream(IrBuilderPasskey passkey); + + GetCurrentStream(const GetCurrentStream& other) = delete; + GetCurrentStream& operator=(const GetCurrentStream& other) = delete; + GetCurrentStream(GetCurrentStream&& other) = delete; + GetCurrentStream& operator=(GetCurrentStream&& other) = delete; + + NVFUSER_DECLARE_CLONE_AND_CREATE + + std::string toString(int indent_size = 0) const override; + const char* getOpString() const override { + return "hir::GetCurrentStream"; + } + + Stream* stream() const { + return attributes_.at(0)->as(); + } +}; + class Wait : public Expr { public: using Expr::Expr; diff --git a/tests/cpp/test_host_irs.cpp b/tests/cpp/test_host_irs.cpp index 64aa2a0564b..e97550309e1 100644 --- a/tests/cpp/test_host_irs.cpp +++ b/tests/cpp/test_host_irs.cpp @@ -513,6 +513,26 @@ TEST_F(StreamTest, HostIrDefaultStream) { c10::cuda::getDefaultCUDAStream(0), c10::cuda::getCurrentCUDAStream(0)); } +TEST_F(StreamTest, HostIrGetCurrentStream) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + auto get_stream = IrBuilder::create(); + auto current_stream = get_stream->stream(); + auto other_stream = IrBuilder::create(); + hic->pushBackTopLevelExprs(get_stream); + hic->pushBackTopLevelExprs(IrBuilder::create(other_stream)); + hic->pushBackTopLevelExprs( + IrBuilder::create(current_stream)); + + auto cuda_stream = c10::cuda::getStreamFromPool(); + setCurrentCUDAStream(cuda_stream); + + HostIrEvaluator hie(std::move(hic)); + hie.runWithInput({}); + + EXPECT_EQ(cuda_stream, c10::cuda::getCurrentCUDAStream(0)); +} + TEST_F(StreamTest, ByIndex) { constexpr int64_t kStreamIndex1 = 2; constexpr int64_t kStreamIndex2 = 3; From 02ffc838df553d7fc271ea92a10483532d5692bf Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Mon, 23 Dec 2024 12:35:12 -0500 Subject: [PATCH 2/2] Allow matmul heuristic plugin to set cluster dimensions (#3634) This allows heuristic plugins to set cluster dimensions. By default, the cluster dims are set to {1, 1, 1}, which disables this feature, so if a plugin does not explicitly handle cluster dims, then it will just not make use of this feature. This is necessary because setting to an invalid value can cause a launch failure. This PR also prints the cluster dims in `MatmulParams::toString`. --- csrc/scheduler/matmul_heuristic.h | 2 ++ csrc/scheduler/matmul_heuristic_plugin.cpp | 6 ++++++ csrc/scheduler/matmul_heuristic_plugin_api.h | 1 + 3 files changed, 9 insertions(+) diff --git a/csrc/scheduler/matmul_heuristic.h b/csrc/scheduler/matmul_heuristic.h index 7e8ee6dc4d7..c2b42dcf20b 100644 --- a/csrc/scheduler/matmul_heuristic.h +++ b/csrc/scheduler/matmul_heuristic.h @@ -216,6 +216,8 @@ class MatmulParams : public HeuristicParams { : "column-major") << "\n" << "Grid swizzle factor: " << grid_swizzle_factor << "\n" + << "Cluster dimensions: " << std::get<0>(cluster_dims) << " " + << std::get<1>(cluster_dims) << " " << std::get<2>(cluster_dims) << "\n" << "Use shared memory epilogue: " << use_smem_epilogue << "\n" << "Promote re-use of prologue shared memory: " << promote_prologue_smem_reuse << "\n" diff --git a/csrc/scheduler/matmul_heuristic_plugin.cpp b/csrc/scheduler/matmul_heuristic_plugin.cpp index 01333727841..ef0954f2185 100644 --- a/csrc/scheduler/matmul_heuristic_plugin.cpp +++ b/csrc/scheduler/matmul_heuristic_plugin.cpp @@ -141,6 +141,9 @@ void copyParamsToConfig(KernelConfig* config, const MatmulParams* mparams) { setConfigTile(config->cta_tile, mparams->tile_sizes.cta_tile); setConfigTile(config->warp_tile, mparams->tile_sizes.warp_tile); setConfigTile(config->instruction_tile, getMmaOpShape(mparams->mma_macro)); + config->cluster_dims[0] = std::get<0>(mparams->cluster_dims); + config->cluster_dims[1] = std::get<1>(mparams->cluster_dims); + config->cluster_dims[2] = std::get<2>(mparams->cluster_dims); config->splitk_factor = mparams->splitk_factor; config->grid_swizzle_factor = mparams->grid_swizzle_factor; config->cta_order = @@ -163,6 +166,9 @@ void copyConfigToParams(MatmulParams* mparams, const KernelConfig* config) { }; setGemmTile(mparams->tile_sizes.cta_tile, config->cta_tile); setGemmTile(mparams->tile_sizes.warp_tile, config->warp_tile); + std::get<0>(mparams->cluster_dims) = config->cluster_dims[0]; + std::get<1>(mparams->cluster_dims) = config->cluster_dims[1]; + std::get<2>(mparams->cluster_dims) = config->cluster_dims[2]; mparams->circular_buffer_options.smem_circular_buffer_stage = config->load_stages; mparams->circular_buffer_options.smem_circular_buffer_prefetch_gap = diff --git a/csrc/scheduler/matmul_heuristic_plugin_api.h b/csrc/scheduler/matmul_heuristic_plugin_api.h index 207da96e9a8..1cd028b6a0a 100644 --- a/csrc/scheduler/matmul_heuristic_plugin_api.h +++ b/csrc/scheduler/matmul_heuristic_plugin_api.h @@ -72,6 +72,7 @@ struct KernelConfig { Tile cta_tile = {128, 128, 32}; Tile warp_tile = {64, 64, 32}; Tile instruction_tile = {16, 16, 16}; + Tile cluster_dims = {1, 1, 1}; uint16_t splitk_factor = 1; uint8_t load_stages = 2; // The circular buffering prefetch distance will be set to