diff --git a/csrc/dispatch.h b/csrc/dispatch.h
index 4fe0f86cc5f..77b650b88dc 100644
--- a/csrc/dispatch.h
+++ b/csrc/dispatch.h
@@ -146,6 +146,7 @@ class Val;
   f(HostUnit);                    \
   f(PostOnStream);                \
   f(SetCurrentStream);            \
+  f(GetCurrentStream);            \
   f(Wait);                        \
   f(Synchronize);                 \
   f(StartCoalescing);             \
diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp
index 69b5b9c704d..1b2554cdabb 100644
--- a/csrc/host_ir/executor.cpp
+++ b/csrc/host_ir/executor.cpp
@@ -274,6 +274,17 @@ void HostIrEvaluator::handle(SetCurrentStream* set_current_stream) {
   setCurrentCUDAStream(getCUDAStream(set_current_stream->stream()));
 }
 
+// Binds the Stream node of `get_current_stream` to the CUDA stream that is
+// current for my_device_index_ at the time this expression is evaluated.
+void HostIrEvaluator::handle(GetCurrentStream* get_current_stream) {
+  // insert() would leave an existing entry untouched, so evaluating the same
+  // node again would keep a stale stream; insert_or_assign refreshes it.
+  streams_.insert_or_assign(
+      get_current_stream->stream(),
+      c10::cuda::getCurrentCUDAStream(
+          static_cast<c10::DeviceIndex>(my_device_index_)));
+}
+
 void HostIrEvaluator::handle(Synchronize* synchronize) {
   cudaStream_t current_stream =
       c10::cuda::getCurrentCUDAStream(
diff --git a/csrc/host_ir/executor.h b/csrc/host_ir/executor.h
index 6f9070b810a..a51dc32aed4 100644
--- a/csrc/host_ir/executor.h
+++ b/csrc/host_ir/executor.h
@@ -112,6 +112,7 @@ class HostIrEvaluator final : public OptOutDispatch {
  private:
   using OptOutDispatch::handle;
   void handle(SetCurrentStream* set_current_stream) override;
+  void handle(GetCurrentStream* get_current_stream) override;
   void handle(Synchronize* synchronize) override;
   void handle(PostOnStream* post_ir) override;
   void handle(Communication* communication) override;
diff --git a/csrc/host_ir/host_ir.cpp b/csrc/host_ir/host_ir.cpp
index 492b2b22aab..49b33f59823 100644
--- a/csrc/host_ir/host_ir.cpp
+++ b/csrc/host_ir/host_ir.cpp
@@ -179,6 +179,25 @@ bool SetCurrentStream::sameAs(const Statement* other) const {
   return false;
 }
 
+// Creates a new Stream node in the passkey's container and registers it as
+// this expression's first attribute (index 0), which stream() returns.
+GetCurrentStream::GetCurrentStream(IrBuilderPasskey passkey) : Expr(passkey) {
+  NVF_ERROR(passkey.ir_container_ != nullptr);
+  NVF_ERROR(passkey.ir_container_->isA<HostIrContainer>());
+  auto stream = IrBuilder::createInContainer<Stream>(passkey.ir_container_);
+  addAttribute(stream);
+}
+
+NVFUSER_DEFINE_CLONE_AND_CREATE(GetCurrentStream)
+
+// Renders "GetCurrentStream into <stream>" followed by a newline.
+std::string GetCurrentStream::toString(int indent_size) const {
+  std::stringstream ss;
+  indent(ss, indent_size) << "GetCurrentStream into " << stream()->toString()
+                          << std::endl;
+  return ss.str();
+}
+
 Wait::Wait(IrBuilderPasskey passkey, Expr* expr)
     : Expr(passkey, {}, {}, {expr}) {
   NVF_ERROR(passkey.ir_container_ != nullptr);
diff --git a/csrc/host_ir/host_ir.h b/csrc/host_ir/host_ir.h
index 587ffc43638..82d67d6f4cc 100644
--- a/csrc/host_ir/host_ir.h
+++ b/csrc/host_ir/host_ir.h
@@ -161,6 +161,32 @@ class SetCurrentStream : public Expr {
   }
 };
 
+// Host IR expression that, when evaluated, associates its Stream node (see
+// stream()) with the CUDA stream that is current at that moment, so the node
+// can later be handed to SetCurrentStream to switch back to that stream.
+class GetCurrentStream : public Expr {
+ public:
+  using Expr::Expr;
+  GetCurrentStream(IrBuilderPasskey passkey);
+
+  GetCurrentStream(const GetCurrentStream& other) = delete;
+  GetCurrentStream& operator=(const GetCurrentStream& other) = delete;
+  GetCurrentStream(GetCurrentStream&& other) = delete;
+  GetCurrentStream& operator=(GetCurrentStream&& other) = delete;
+
+  NVFUSER_DECLARE_CLONE_AND_CREATE
+
+  std::string toString(int indent_size = 0) const override;
+  const char* getOpString() const override {
+    return "hir::GetCurrentStream";
+  }
+
+  // The Stream node created by the constructor (attribute 0).
+  Stream* stream() const {
+    return attributes_.at(0)->as<Stream>();
+  }
+};
+
 class Wait : public Expr {
  public:
   using Expr::Expr;
diff --git a/csrc/scheduler/matmul_heuristic.h b/csrc/scheduler/matmul_heuristic.h
index 7e8ee6dc4d7..c2b42dcf20b 100644
--- a/csrc/scheduler/matmul_heuristic.h
+++ b/csrc/scheduler/matmul_heuristic.h
@@ -216,6 +216,8 @@ class MatmulParams : public HeuristicParams {
                : "column-major")
        << "\n"
        << "Grid swizzle factor: " << grid_swizzle_factor << "\n"
+       << "Cluster dimensions: " << std::get<0>(cluster_dims) << " "
+       << std::get<1>(cluster_dims) << " " << std::get<2>(cluster_dims) << "\n"
        << "Use shared memory epilogue: " << use_smem_epilogue << "\n"
        << "Promote re-use of prologue shared memory: "
        << promote_prologue_smem_reuse << "\n"
diff --git
a/csrc/scheduler/matmul_heuristic_plugin.cpp b/csrc/scheduler/matmul_heuristic_plugin.cpp
index 01333727841..ef0954f2185 100644
--- a/csrc/scheduler/matmul_heuristic_plugin.cpp
+++ b/csrc/scheduler/matmul_heuristic_plugin.cpp
@@ -141,6 +141,11 @@ void copyParamsToConfig(KernelConfig* config, const MatmulParams* mparams) {
   setConfigTile(config->cta_tile, mparams->tile_sizes.cta_tile);
   setConfigTile(config->warp_tile, mparams->tile_sizes.warp_tile);
   setConfigTile(config->instruction_tile, getMmaOpShape(mparams->mma_macro));
+  // Mirror the three cluster dimensions element by element into the config.
+  const auto& param_cluster_dims = mparams->cluster_dims;
+  config->cluster_dims[0] = std::get<0>(param_cluster_dims);
+  config->cluster_dims[1] = std::get<1>(param_cluster_dims);
+  config->cluster_dims[2] = std::get<2>(param_cluster_dims);
   config->splitk_factor = mparams->splitk_factor;
   config->grid_swizzle_factor = mparams->grid_swizzle_factor;
   config->cta_order =
@@ -163,6 +168,11 @@ void copyConfigToParams(MatmulParams* mparams, const KernelConfig* config) {
   };
   setGemmTile(mparams->tile_sizes.cta_tile, config->cta_tile);
   setGemmTile(mparams->tile_sizes.warp_tile, config->warp_tile);
+  // Read the three cluster dimensions back from the config.
+  auto& param_cluster_dims = mparams->cluster_dims;
+  std::get<0>(param_cluster_dims) = config->cluster_dims[0];
+  std::get<1>(param_cluster_dims) = config->cluster_dims[1];
+  std::get<2>(param_cluster_dims) = config->cluster_dims[2];
   mparams->circular_buffer_options.smem_circular_buffer_stage =
       config->load_stages;
   mparams->circular_buffer_options.smem_circular_buffer_prefetch_gap =
diff --git a/csrc/scheduler/matmul_heuristic_plugin_api.h b/csrc/scheduler/matmul_heuristic_plugin_api.h
index 207da96e9a8..1cd028b6a0a 100644
--- a/csrc/scheduler/matmul_heuristic_plugin_api.h
+++ b/csrc/scheduler/matmul_heuristic_plugin_api.h
@@ -72,6 +72,8 @@ struct KernelConfig {
   Tile cta_tile = {128, 128, 32};
   Tile warp_tile = {64, 64, 32};
   Tile instruction_tile = {16, 16, 16};
+  // Cluster dimensions exchanged with MatmulParams::cluster_dims.
+  Tile cluster_dims = {1, 1, 1};
   uint16_t splitk_factor = 1;
   uint8_t load_stages = 2;
   // The circular buffering prefetch distance will be set to
diff --git a/tests/cpp/test_host_irs.cpp
b/tests/cpp/test_host_irs.cpp
index 64aa2a0564b..e97550309e1 100644
--- a/tests/cpp/test_host_irs.cpp
+++ b/tests/cpp/test_host_irs.cpp
@@ -513,6 +513,29 @@ TEST_F(StreamTest, HostIrDefaultStream) {
       c10::cuda::getDefaultCUDAStream(0), c10::cuda::getCurrentCUDAStream(0));
 }
 
+// The host program captures the current stream, switches to another stream and
+// then switches back to the captured one; afterwards the stream that was
+// current before the run must be current again.
+TEST_F(StreamTest, HostIrGetCurrentStream) {
+  auto hic = std::make_unique<HostIrContainer>();
+  FusionGuard fg(hic.get());
+  auto get_stream = IrBuilder::create<GetCurrentStream>();
+  auto current_stream = get_stream->stream();
+  auto other_stream = IrBuilder::create<Stream>();
+  hic->pushBackTopLevelExprs(get_stream);
+  hic->pushBackTopLevelExprs(IrBuilder::create<SetCurrentStream>(other_stream));
+  hic->pushBackTopLevelExprs(
+      IrBuilder::create<SetCurrentStream>(current_stream));
+
+  auto cuda_stream = c10::cuda::getStreamFromPool();
+  setCurrentCUDAStream(cuda_stream);
+
+  HostIrEvaluator hie(std::move(hic));
+  hie.runWithInput({});
+
+  EXPECT_EQ(cuda_stream, c10::cuda::getCurrentCUDAStream(0));
+}
+
 TEST_F(StreamTest, ByIndex) {
   constexpr int64_t kStreamIndex1 = 2;
   constexpr int64_t kStreamIndex2 = 3;