Add support for smem_epilogue when mma output is not cast to half #3620

Merged · 10 commits · Dec 25, 2024
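This PR lets the Hopper matmul scheduler use a shared-memory epilogue even when the mma result is kept in float rather than being cast back to half. The StMatrix path still requires 16-bit elements, so the scheduler now falls back to a plain shared-memory store (followed by a TMA store) for wider outputs. A minimal sketch of the user-visible effect, assembled from the test changes in this diff (`fusedMultiplySum`, `MatmulParams`, and `SchedulerEntry` are the existing nvFuser test-suite APIs used below; the surrounding fusion setup is elided):

```cpp
// Sketch: a GEMM whose output stays in float (never cast back to half).
auto tv2 = fusedMultiplySum(tv0, tv1, {-1});  // float accumulator
fusion->addOutput(tv2);                       // output dtype is float

MatmulParams mparams;
// Before this PR, use_smem_epilogue had to stay false for a non-16-bit
// output because the scheduler unconditionally emitted stmatrix stores.
mparams.use_smem_epilogue = true;
SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
    ->schedule(fusion.get(), &mparams);
```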
66 changes: 43 additions & 23 deletions csrc/scheduler/hopper_multi_matmul.cpp
@@ -508,8 +508,12 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
TensorView* d_smem = cacheAfter(dc, LoadStoreOpType::Set);

std::vector<TensorView*> tvs_to_schedule{d, d_smem};
if (std::find(mma_results_.begin(), mma_results_.end(), dc) ==
mma_results_.end()) {

bool dc_in_mma_results =
std::find(mma_results_.begin(), mma_results_.end(), dc) !=
mma_results_.end();

if (!dc_in_mma_results) {
// Skip scheduling dc if it is an mma_result. This can happen if we are
// not casting back to half-precision in the output
tvs_to_schedule.push_back(dc);
@@ -519,14 +523,13 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
dc->setMemoryType(MemoryType::Local);
d_smem->setMemoryType(MemoryType::Shared);

// Set LoadStoreOp
// TODO: extend support when mma is not cast to half
NVF_CHECK(
dataTypeSize(dc->dtype()) == 2,
"We support use_smem_epilogue on Hopper only when the output is 16-bit");
auto store_with_stmatrix = dataTypeSize(dc->dtype()) == 2;

d_smem->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::StMatrix);
if (store_with_stmatrix) {
// Set LoadStoreOp
d_smem->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::StMatrix);
}
d->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::CpAsyncBulkTensorTile);

@@ -539,23 +542,40 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
transformLikeMmaOutput(tv, /*is_mma_result=*/false);
}

auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
dc->getLoopDomain());
dc->setLoopDomain(s.as<IterDomain*>());
dc->setAllocationDomain(s.as<IterDomain*>(), true);

scheduler_utils::BoundedDirectionalTransformPropagator::backward(
dc,
-1,
propagate_to,
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());
// Should not propagate if dc is an mma output, as the mma output has
// already been scheduled.
if (!dc_in_mma_results) {
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
dc->getLoopDomain());
dc->setLoopDomain(s.as<IterDomain*>());
dc->setAllocationDomain(s.as<IterDomain*>(), true);

scheduler_utils::BoundedDirectionalTransformPropagator::backward(
dc,
-1,
propagate_to,
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());
}

MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(d_smem);

// Schedule shared memory cache; Output from StMatrix
mma_utils::scheduleStMatrixForMmaOutput(
d_smem, swizzle, stmatrix_tile_m, stmatrix_tile_n);
// [M, N] -> [128(TIDx), N/8 , m(2) , n(2)]
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
d_smem->getLoopDomain());
if (swizzle != MmaInputSmemSwizzle::None) {
// Create tma store allocation domain with swizzle
mma_utils::scheduleTMAStoreForMmaOutput(d_smem, swizzle);
}
d_smem->setLoopDomain(s.as<IterDomain*>());

if (store_with_stmatrix) {
// Schedule shared memory cache; Output from StMatrix
mma_utils::scheduleStMatrixForMmaOutput(
d_smem, swizzle, stmatrix_tile_m, stmatrix_tile_n);
}

d_smem->axis(-1)->parallelize(ParallelType::Vectorize);

// Schedule global memory output; Output from TMA Store
mma_utils::scheduleTMAStoreForMmaOutput(d, swizzle);
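Taken together, the hunks above add two guards to `scheduleEpilogue()`. A condensed sketch of the gating (names copied from the diff; not a verbatim excerpt of the final file):

```cpp
// dc is the epilogue cache of output d; when the output is not cast to half,
// dc can be the mma result itself, which has already been scheduled.
bool dc_in_mma_results =
    std::find(mma_results_.begin(), mma_results_.end(), dc) != mma_results_.end();

// stmatrix only supports 16-bit element types, so wider outputs take a plain
// shared-memory store instead.
bool store_with_stmatrix = dataTypeSize(dc->dtype()) == 2;

if (!dc_in_mma_results) {
  tvs_to_schedule.push_back(dc);  // only schedule dc when it is not the mma result
}
if (store_with_stmatrix) {
  d_smem->definition()->as<LoadStoreOp>()->setOpType(LoadStoreOpType::StMatrix);
}
// The global output d is always written with a TMA store.
d->definition()->as<LoadStoreOp>()->setOpType(LoadStoreOpType::CpAsyncBulkTensorTile);
```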
12 changes: 0 additions & 12 deletions csrc/scheduler/mma_utils.cpp
@@ -1315,17 +1315,6 @@ void scheduleStMatrixForMmaOutput(
dataTypeSize(tv->dtype()) == 2,
"we only support 16-bit types in stmatrix");

// [M, N] -> [128(TIDx), N/8 , 2 , 2]
auto s =
mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(tv->getLoopDomain());

if (swizzle != MmaInputSmemSwizzle::None) {
// Create tma store allocation domain with swizzle
mma_utils::scheduleTMAStoreForMmaOutput(tv, swizzle);
}

tv->setLoopDomain(s.as<IterDomain*>());

if (tile_m == 16 && tile_n == 16) {
// Let [M, N] be [64, 32]
// After scheduleMmaOutputAllocation: [128(TIDx), 4, 2, 2]
@@ -1344,7 +1333,6 @@
// [2, 128(TIDx), 2, 2] -> [2, 128(TIDx), 4(vectorize)]
tv->merge(-2);
}
tv->axis(-1)->parallelize(ParallelType::Vectorize);
}

MatmulOperandInnerDimsOpt getOperandInnerDims(Fusion* fusion) {
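With this refactor, `scheduleStMatrixForMmaOutput()` no longer schedules the swizzled allocation domain or parallelizes the final vectorized axis; its callers do that themselves, so the same allocation scheduling applies whether or not stmatrix is used. The resulting caller-side pattern, as it appears in the scheduler and test hunks of this PR (a sketch, not a complete schedule; names are taken from the hopper_multi_matmul.cpp hunk above):

```cpp
// Schedule the shared-memory stage d_smem.
MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(d_smem);

// [M, N] -> [128(TIDx), N/8, m(2), n(2)]
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(d_smem->getLoopDomain());
if (swizzle != MmaInputSmemSwizzle::None) {
  // Create the TMA-store allocation domain with swizzle.
  mma_utils::scheduleTMAStoreForMmaOutput(d_smem, swizzle);
}
d_smem->setLoopDomain(s.as<IterDomain*>());

if (store_with_stmatrix) {
  // Tile the loop domain for stmatrix; this call no longer touches the
  // allocation domain or the parallelization.
  mma_utils::scheduleStMatrixForMmaOutput(
      d_smem, swizzle, stmatrix_tile_m, stmatrix_tile_n);
}
d_smem->axis(-1)->parallelize(ParallelType::Vectorize);
```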
26 changes: 18 additions & 8 deletions tests/cpp/test_matmul_scheduler.cpp
@@ -2660,7 +2660,7 @@ TEST_F(MatmulSchedulerTest, SegmentMatmulOpUnsupportedDtype) {
testValidate(executor_cache.fusion(), outputs, {t0, t1}, __LINE__, __FILE__);
}

TEST_F(MatmulSchedulerTest, PreBroadcastGEMM) {
TEST_F(MatmulSchedulerTest, PreBroadcastMmaBiasNeg) {
// TODO: fix up params or switch to FusionExecutorCache when ready, then
// enable Ampere
NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
@@ -2671,12 +2671,20 @@ TEST_F(MatmulSchedulerTest, PreBroadcastGEMM) {
// A - tv0, B - tv1
auto tv0 = makeContigConcreteTensor({-1, 1, -1}, DataType::Half);
auto tv1 = makeContigConcreteTensor({1, -1, -1}, DataType::Half);
TensorView* tv2 = makeContigConcreteTensor({-1}, DataType::Half);
fusion->addInput(tv0);
fusion->addInput(tv1);
fusion->addInput(tv2);

auto tv2 = fusedMultiplySum(tv0, tv1, {-1});
auto tv3 = fusedMultiplySum(tv0, tv1, {-1});
// We add these computations to test
// scheduling (with epilogue) when the output of mma is not
// cast to half.
auto tv4 = maybeCastOp(DataType::Float, tv2);
auto tv5 = biasEpilogue(tv3, tv4);
auto tv6 = neg(tv5);

fusion->addOutput(tv2);
fusion->addOutput(tv6);

NVF_CHECK(
1 == ir_utils::getOpsOfType<MmaOp>(fusion.get()).size(),
@@ -2689,10 +2697,14 @@
auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
auto a = at::randn({M, K}, options);
auto b = at::randn({N, K}, options);
auto c = at::randn({M}, options);
auto t0 = a.unsqueeze(1);
auto t1 = b.unsqueeze(0);
auto tref = at::matmul(a.to(at::kFloat), b.to(at::kFloat).t());
std::vector<c10::IValue> inputs{t0, t1};
auto tref =
atBiasEpilogue(
at::matmul(a.to(at::kFloat), b.to(at::kFloat).t()), c.to(at::kFloat))
.neg_();
std::vector<c10::IValue> inputs{t0, t1, c};

MatmulParams mparams;
mparams.supported_vec_size = {8, 8, 4};
@@ -2705,9 +2717,7 @@
mparams.circular_buffer_options.circular_buffer_smem_write = true;
mparams.circular_buffer_options.circular_buffer_smem_read = true;
mparams.circular_buffer_options.smem_circular_buffer_stage = 2;
// TODO: Currently we use stmatrix whenever this is true. We cannot do that
// when the dtype is not 16 bits.
mparams.use_smem_epilogue = false;
mparams.use_smem_epilogue = true;
mparams.promote_prologue_smem_reuse = false;

SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
12 changes: 8 additions & 4 deletions tests/cpp/test_memory.cpp
@@ -2860,14 +2860,18 @@ TEST_P(StMatrixTest, Regular) {
tv0->split(0, 32);
tv0->axis(1)->parallelize(ParallelType::TIDx);

auto s =
mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(tv1->getLoopDomain());
tv1->setLoopDomain(s.as<IterDomain*>());
tv1->setAllocationDomain(s.as<IterDomain*>(), true);
for (auto tv : {tv1, tv2}) {
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv->getLoopDomain());
tv->setLoopDomain(s.as<IterDomain*>());
}
tv1->setAllocationDomain(tv1->getLoopDomain(), true);

mma_utils::scheduleStMatrixForMmaOutput(
tv2, /*swizzle=*/MmaInputSmemSwizzle::None, tile_m, tile_n);

tv2->axis(-1)->parallelize(ParallelType::Vectorize);

tv3->merge(0);
tv3->split(0, 32);
tv3->axis(1)->parallelize(ParallelType::TIDx);
24 changes: 18 additions & 6 deletions tests/cpp/test_mma.cpp
@@ -515,12 +515,6 @@ TEST_P(HopperRSStmatrix, SingleTileWithTMALoadStoreStMatrix) {
EXPECT_TRUE(tv3->getMemoryType() == MemoryType::Shared);
EXPECT_TRUE(tv4->getMemoryType() == MemoryType::Global);

{
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv3c->getLoopDomain());
tv3c->setLoopDomain(s.as<IterDomain*>());
tv3c->setAllocationDomain(s.as<IterDomain*>(), true);
}
{
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv2->getLoopDomain());
@@ -531,8 +525,26 @@
tv2->axis(-3)->parallelize(ParallelType::Mma);
}

{
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv3c->getLoopDomain());
tv3c->setLoopDomain(s.as<IterDomain*>());
tv3c->setAllocationDomain(s.as<IterDomain*>(), true);
}

MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(tv3);
{
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv3->getLoopDomain());

if (swizzle != MmaInputSmemSwizzle::None) {
mma_utils::scheduleTMAStoreForMmaOutput(tv3, swizzle);
}

tv3->setLoopDomain(s.as<IterDomain*>());
}
mma_utils::scheduleStMatrixForMmaOutput(tv3, swizzle, tile_m, tile_n);
tv3->axis(-1)->parallelize(ParallelType::Vectorize);

mma_utils::scheduleTMAStoreForMmaOutput(tv4, swizzle);
