Skip to content

Commit

Permalink
Merge branch 'main' into pbasu_smem_epi_no_stmatrix
Browse files Browse the repository at this point in the history
  • Loading branch information
protonu authored Dec 23, 2024
2 parents 50e1183 + 02ffc83 commit a3f138a
Show file tree
Hide file tree
Showing 9 changed files with 76 additions and 0 deletions.
1 change: 1 addition & 0 deletions csrc/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ class Val;
f(HostUnit); \
f(PostOnStream); \
f(SetCurrentStream); \
f(GetCurrentStream); \
f(Wait); \
f(Synchronize); \
f(StartCoalescing); \
Expand Down
7 changes: 7 additions & 0 deletions csrc/host_ir/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,13 @@ void HostIrEvaluator::handle(SetCurrentStream* set_current_stream) {
setCurrentCUDAStream(getCUDAStream(set_current_stream->stream()));
}

// Evaluates a GetCurrentStream expression: associates the IR-level Stream
// owned by `get_current_stream` with the CUDA stream that is current on
// `my_device_index_` at the moment this expression is evaluated.
void HostIrEvaluator::handle(GetCurrentStream* get_current_stream) {
  // insert_or_assign rather than insert: insert() is a no-op when the key is
  // already present, so evaluating the same expression a second time (e.g. the
  // host program is run again after the caller switched streams) would keep
  // the stream captured by the first evaluation instead of the current one.
  // NOTE(review): assumes streams_ is a std::map/std::unordered_map-style
  // container whose mapped type is assignable -- confirm against executor.h.
  streams_.insert_or_assign(
      get_current_stream->stream(),
      c10::cuda::getCurrentCUDAStream(
          static_cast<c10::DeviceIndex>(my_device_index_)));
}

void HostIrEvaluator::handle(Synchronize* synchronize) {
cudaStream_t current_stream =
c10::cuda::getCurrentCUDAStream(
Expand Down
1 change: 1 addition & 0 deletions csrc/host_ir/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ class HostIrEvaluator final : public OptOutDispatch {
private:
using OptOutDispatch::handle;
void handle(SetCurrentStream* set_current_stream) override;
void handle(GetCurrentStream* get_current_stream) override;
void handle(Synchronize* synchronize) override;
void handle(PostOnStream* post_ir) override;
void handle(Communication* communication) override;
Expand Down
16 changes: 16 additions & 0 deletions csrc/host_ir/host_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,22 @@ bool SetCurrentStream::sameAs(const Statement* other) const {
return false;
}

// Builds a GetCurrentStream expression. A new Stream is created in the
// passkey's container and registered as an attribute of this expression;
// stream() reads it back from attribute index 0.
GetCurrentStream::GetCurrentStream(IrBuilderPasskey passkey) : Expr(passkey) {
  // Reject construction without a container or outside of a HostIrContainer.
  NVF_ERROR(passkey.ir_container_ != nullptr);
  NVF_ERROR(passkey.ir_container_->isA<HostIrContainer>());
  // The Stream is owned by the same container as the expression itself.
  addAttribute(IrBuilder::createInContainer<Stream>(passkey.ir_container_));
}

NVFUSER_DEFINE_CLONE_AND_CREATE(GetCurrentStream)

// Returns a one-line textual form of this expression: whatever indent() emits
// for `indent_size`, followed by "GetCurrentStream into ", the string form of
// the target Stream, and a trailing newline.
std::string GetCurrentStream::toString(int indent_size) const {
  std::stringstream buffer;
  auto& line = indent(buffer, indent_size);
  line << "GetCurrentStream into ";
  line << stream()->toString();
  line << std::endl;
  return buffer.str();
}

Wait::Wait(IrBuilderPasskey passkey, Expr* expr)
: Expr(passkey, {}, {}, {expr}) {
NVF_ERROR(passkey.ir_container_ != nullptr);
Expand Down
22 changes: 22 additions & 0 deletions csrc/host_ir/host_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,28 @@ class SetCurrentStream : public Expr {
}
};

// Host IR expression that captures the CUDA stream that is current when the
// expression is evaluated. Construction creates a new Stream in the
// HostIrContainer and stores it as attribute 0; when evaluated,
// HostIrEvaluator::handle(GetCurrentStream*) associates that Stream with
// c10::cuda::getCurrentCUDAStream for the evaluator's device.
class GetCurrentStream : public Expr {
public:
using Expr::Expr;
// Requires passkey.ir_container_ to be a non-null HostIrContainer; creates
// the Stream returned by stream().
GetCurrentStream(IrBuilderPasskey passkey);

// Neither copyable nor movable.
GetCurrentStream(const GetCurrentStream& other) = delete;
GetCurrentStream& operator=(const GetCurrentStream& other) = delete;
GetCurrentStream(GetCurrentStream&& other) = delete;
GetCurrentStream& operator=(GetCurrentStream&& other) = delete;

// Matching definitions come from NVFUSER_DEFINE_CLONE_AND_CREATE in
// host_ir.cpp.
NVFUSER_DECLARE_CLONE_AND_CREATE

// Textual form: "GetCurrentStream into <stream>" followed by a newline.
std::string toString(int indent_size = 0) const override;
const char* getOpString() const override {
return "hir::GetCurrentStream";
}

// The Stream created by the constructor (attribute 0).
Stream* stream() const {
return attributes_.at(0)->as<Stream>();
}
};

class Wait : public Expr {
public:
using Expr::Expr;
Expand Down
2 changes: 2 additions & 0 deletions csrc/scheduler/matmul_heuristic.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@ class MatmulParams : public HeuristicParams {
: "column-major")
<< "\n"
<< "Grid swizzle factor: " << grid_swizzle_factor << "\n"
<< "Cluster dimensions: " << std::get<0>(cluster_dims) << " "
<< std::get<1>(cluster_dims) << " " << std::get<2>(cluster_dims) << "\n"
<< "Use shared memory epilogue: " << use_smem_epilogue << "\n"
<< "Promote re-use of prologue shared memory: "
<< promote_prologue_smem_reuse << "\n"
Expand Down
6 changes: 6 additions & 0 deletions csrc/scheduler/matmul_heuristic_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ void copyParamsToConfig(KernelConfig* config, const MatmulParams* mparams) {
setConfigTile(config->cta_tile, mparams->tile_sizes.cta_tile);
setConfigTile(config->warp_tile, mparams->tile_sizes.warp_tile);
setConfigTile(config->instruction_tile, getMmaOpShape(mparams->mma_macro));
config->cluster_dims[0] = std::get<0>(mparams->cluster_dims);
config->cluster_dims[1] = std::get<1>(mparams->cluster_dims);
config->cluster_dims[2] = std::get<2>(mparams->cluster_dims);
config->splitk_factor = mparams->splitk_factor;
config->grid_swizzle_factor = mparams->grid_swizzle_factor;
config->cta_order =
Expand All @@ -163,6 +166,9 @@ void copyConfigToParams(MatmulParams* mparams, const KernelConfig* config) {
};
setGemmTile(mparams->tile_sizes.cta_tile, config->cta_tile);
setGemmTile(mparams->tile_sizes.warp_tile, config->warp_tile);
std::get<0>(mparams->cluster_dims) = config->cluster_dims[0];
std::get<1>(mparams->cluster_dims) = config->cluster_dims[1];
std::get<2>(mparams->cluster_dims) = config->cluster_dims[2];
mparams->circular_buffer_options.smem_circular_buffer_stage =
config->load_stages;
mparams->circular_buffer_options.smem_circular_buffer_prefetch_gap =
Expand Down
1 change: 1 addition & 0 deletions csrc/scheduler/matmul_heuristic_plugin_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ struct KernelConfig {
Tile cta_tile = {128, 128, 32};
Tile warp_tile = {64, 64, 32};
Tile instruction_tile = {16, 16, 16};
Tile cluster_dims = {1, 1, 1};
uint16_t splitk_factor = 1;
uint8_t load_stages = 2;
// The circular buffering prefetch distance will be set to
Expand Down
20 changes: 20 additions & 0 deletions tests/cpp/test_host_irs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,26 @@ TEST_F(StreamTest, HostIrDefaultStream) {
c10::cuda::getDefaultCUDAStream(0), c10::cuda::getCurrentCUDAStream(0));
}

// Verifies that a stream captured with GetCurrentStream can be made current
// again later: the host program captures the current stream, switches to a
// different Stream, then switches back to the captured one. After the run, the
// CUDA stream that was current before the run must be current again.
TEST_F(StreamTest, HostIrGetCurrentStream) {
  auto container = std::make_unique<HostIrContainer>();
  FusionGuard guard(container.get());

  // Host program: capture -> switch away -> switch back.
  auto* capture_expr = IrBuilder::create<GetCurrentStream>();
  auto* captured_stream = capture_expr->stream();
  auto* scratch_stream = IrBuilder::create<Stream>();
  container->pushBackTopLevelExprs(capture_expr);
  container->pushBackTopLevelExprs(
      IrBuilder::create<SetCurrentStream>(scratch_stream));
  container->pushBackTopLevelExprs(
      IrBuilder::create<SetCurrentStream>(captured_stream));

  // Make a pool stream current so the captured stream is a known value.
  auto expected_stream = c10::cuda::getStreamFromPool();
  setCurrentCUDAStream(expected_stream);

  HostIrEvaluator evaluator(std::move(container));
  evaluator.runWithInput({});

  EXPECT_EQ(expected_stream, c10::cuda::getCurrentCUDAStream(0));
}

TEST_F(StreamTest, ByIndex) {
constexpr int64_t kStreamIndex1 = 2;
constexpr int64_t kStreamIndex2 = 3;
Expand Down

0 comments on commit a3f138a

Please sign in to comment.