Skip to content

Commit

Permalink
Adding more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mtopalovicTT committed Sep 5, 2024
1 parent fee5cd1 commit cd71235
Showing 1 changed file with 105 additions and 26 deletions.
131 changes: 105 additions & 26 deletions test/unittests/TestScheduler/TestScheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,15 @@ class SchedulerBase : public ::testing::Test {
mlir::MLIRContext context;
mlir::OwningOpRef<mlir::ModuleOp> module;
mlir::OpBuilder builder = mlir::OpBuilder(&context);
mlir::func::FuncOp func;

void SetUp() override {
// Initialize context and module
context.loadDialect<TTDialect>();
context.loadDialect<ttir::TTIRDialect>();
module = mlir::ModuleOp::create(builder.getUnknownLoc());
builder.setInsertionPointToStart(&module->getBodyRegion().front());
createFuncOp();
}

llvm::SmallVector<int64_t, 2> getTensorShape() {
Expand All @@ -85,34 +87,36 @@ class SchedulerBase : public ::testing::Test {
return operand_constraints;
}

void TearDown() override {}
};
mlir::func::FuncOp createFuncOp() {
mlir::SmallVector<mlir::Type> input;
input.push_back(getTensorType());

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wdefaulted-function-deleted"
// This tests chains all operations one after
// another, so output of scheduler should be same
TEST_F(SchedulerBase, FixedSchedule) {
mlir::SmallVector<mlir::Type> input;
input.push_back(getTensorType());
mlir::SmallVector<mlir::Type> output;
output.push_back(getTensorType());

auto funcType = builder.getType<mlir::FunctionType>(
mlir::TypeRange(input), mlir::TypeRange(output));
func = builder.create<mlir::func::FuncOp>(builder.getUnknownLoc(), "test",
funcType);

mlir::SmallVector<mlir::Type> output;
output.push_back(getTensorType());
mlir::Block *block = func.addEntryBlock();
block->addArgument(getTensorType(), builder.getUnknownLoc());
block->addArgument(getTensorType(), builder.getUnknownLoc());

auto funcType = builder.getType<mlir::FunctionType>(mlir::TypeRange(input),
mlir::TypeRange(output));
auto func = builder.create<mlir::func::FuncOp>(builder.getUnknownLoc(),
"test", funcType);
builder.setInsertionPointToStart(block);

mlir::Block *block = func.addEntryBlock();
block->addArgument(getTensorType(), builder.getUnknownLoc());
block->addArgument(getTensorType(), builder.getUnknownLoc());
return func;
}

builder.setInsertionPointToStart(block);
void TearDown() override {}
};

// This tests chains all operations one after
// another, so output of scheduler should be same
TEST_F(SchedulerBase, FixedSchedule) {
mlir::Value dest = createEmptyTensor();
mlir::Value lhs = block->getArgument(0);
mlir::Value rhs = block->getArgument(1);
mlir::Value lhs = func.getBody().getBlocks().front().getArgument(0);
mlir::Value rhs = func.getBody().getBlocks().front().getArgument(1);

mlir::ArrayAttr attrs = builder.getArrayAttr(createOperandConstraints());

Expand Down Expand Up @@ -140,14 +144,17 @@ TEST_F(SchedulerBase, FixedSchedule) {

// Run scheduler to get the schedule
mlir::tt::scheduler::Scheduler scheduler(&func);
while (scheduler.hasUnscheduledOps()) {
for (std::size_t i = 0; i < NumberOfOps; i++) {
llvm::SmallVector<mlir::Operation *> scheduleableOps =
scheduler.getScheduleableOps();
for (mlir::Operation *op : scheduleableOps) {
scheduler.scheduleOp(op);
}
ASSERT_EQ(scheduleableOps.size(), 1);
ASSERT_TRUE(scheduler.hasUnscheduledOps());
;
scheduler.scheduleOp(scheduleableOps[0]);
}

ASSERT_FALSE(scheduler.hasUnscheduledOps());

// Compare the schedule we got with the operations we created
llvm::SmallVector<mlir::Operation *> schedule = scheduler.getSchedule();
for (std::size_t i = 0; i < ops.size(); i++) {
Expand All @@ -157,4 +164,76 @@ TEST_F(SchedulerBase, FixedSchedule) {
// Just a sanity check that comparison is working
EXPECT_NE(ops[0].getOperation(), schedule[1]);
}
#pragma clang diagnostic pop

TEST_F(SchedulerBase, SingleOp) {
mlir::Value dest = createEmptyTensor();
mlir::Value lhs = func.getBody().getBlocks().front().getArgument(0);
mlir::Value rhs = func.getBody().getBlocks().front().getArgument(1);

mlir::ArrayAttr attrs = builder.getArrayAttr(createOperandConstraints());

// First operation has arg1 and arg2 and %0 as dps operand
ttir::TTIROp op = builder.create<ttir::AddOp>(builder.getUnknownLoc(), lhs,
rhs, dest, attrs);

mlir::tt::scheduler::Scheduler scheduler(&func);
ASSERT_TRUE(scheduler.hasUnscheduledOps());
llvm::SmallVector<mlir::Operation *> scheduleableOps =
scheduler.getScheduleableOps();
ASSERT_EQ(scheduleableOps.size(), 1);
scheduler.scheduleOp(scheduleableOps[0]);
ASSERT_FALSE(scheduler.hasUnscheduledOps());
ASSERT_EQ(scheduleableOps[0], op.getOperation());
}

TEST_F(SchedulerBase, VerifyFork) {
mlir::Value dest = createEmptyTensor();
mlir::Value lhs = func.getBody().getBlocks().front().getArgument(0);
mlir::Value rhs = func.getBody().getBlocks().front().getArgument(1);
mlir::ArrayAttr attrs = builder.getArrayAttr(createOperandConstraints());
ttir::TTIROp op = builder.create<ttir::AddOp>(builder.getUnknownLoc(), lhs,
rhs, dest, attrs);

std::vector<ttir::TTIROp> ops;
ops.push_back(op);

lhs = func.getBody().getBlocks().front().getArgument(0);
rhs = op.getOperation()->getResult(0);

dest = createEmptyTensor();
op = builder.create<ttir::AddOp>(builder.getUnknownLoc(), lhs, rhs, dest,
attrs);
ops.push_back(op);

dest = createEmptyTensor();
op = builder.create<ttir::AddOp>(builder.getUnknownLoc(), lhs, rhs, dest,
attrs);
ops.push_back(op);

lhs = ops[ops.size() - 2].getOperation()->getResult(0);
rhs = ops[ops.size() - 1].getOperation()->getResult(0);
dest = createEmptyTensor();
op = builder.create<ttir::AddOp>(builder.getUnknownLoc(), lhs, rhs, dest,
attrs);
ops.push_back(op);

mlir::tt::scheduler::Scheduler scheduler(&func);
llvm::SmallVector<mlir::Operation *> scheduleableOps =
scheduler.getScheduleableOps();
ASSERT_EQ(scheduleableOps.size(), 1);

scheduler.scheduleOp(scheduleableOps[0]);
scheduleableOps = scheduler.getScheduleableOps();
ASSERT_EQ(scheduleableOps.size(), 2);

scheduler.scheduleOp(scheduleableOps[0]);
scheduleableOps = scheduler.getScheduleableOps();
ASSERT_EQ(scheduleableOps.size(), 1);

scheduler.scheduleOp(scheduleableOps[0]);
scheduleableOps = scheduler.getScheduleableOps();
ASSERT_EQ(scheduleableOps.size(), 1);

scheduler.scheduleOp(scheduleableOps[0]);
ASSERT_FALSE(scheduler.hasUnscheduledOps());
}

0 comments on commit cd71235

Please sign in to comment.