Skip to content

Commit

Permalink
updating a mma test with stmatrix
Browse files Browse the repository at this point in the history
  • Loading branch information
protonu committed Dec 23, 2024
1 parent ca17048 commit 494aea7
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions tests/cpp/test_mma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -515,12 +515,6 @@ TEST_P(HopperRSStmatrix, SingleTileWithTMALoadStoreStMatrix) {
EXPECT_TRUE(tv3->getMemoryType() == MemoryType::Shared);
EXPECT_TRUE(tv4->getMemoryType() == MemoryType::Global);

{
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv3c->getLoopDomain());
tv3c->setLoopDomain(s.as<IterDomain*>());
tv3c->setAllocationDomain(s.as<IterDomain*>(), true);
}
{
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv2->getLoopDomain());
Expand All @@ -531,8 +525,26 @@ TEST_P(HopperRSStmatrix, SingleTileWithTMALoadStoreStMatrix) {
tv2->axis(-3)->parallelize(ParallelType::Mma);
}

{
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv3c->getLoopDomain());
tv3c->setLoopDomain(s.as<IterDomain*>());
tv3c->setAllocationDomain(s.as<IterDomain*>(), true);
}

MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(tv3);
{
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv3->getLoopDomain());

if (swizzle != MmaInputSmemSwizzle::None) {
mma_utils::scheduleTMAStoreForMmaOutput(tv3, swizzle);
}

tv3->setLoopDomain(s.as<IterDomain*>());
}
mma_utils::scheduleStMatrixForMmaOutput(tv3, swizzle, tile_m, tile_n);
tv3->axis(-1)->parallelize(ParallelType::Vectorize);

mma_utils::scheduleTMAStoreForMmaOutput(tv4, swizzle);

Expand Down

0 comments on commit 494aea7

Please sign in to comment.