Skip to content

Commit

Permalink
Addressing comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mtopalovicTT committed Sep 5, 2024
1 parent cd71235 commit d300d4c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 16 deletions.
17 changes: 3 additions & 14 deletions lib/Scheduler/Scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,17 @@
#include "ttmlir/Dialect/TTIR/IR/TTIROpsDialect.h.inc"
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/Operation.h>
#include <mlir/Interfaces/DestinationStyleOpInterface.h>

namespace mlir::tt::scheduler {

bool isTTIROp(mlir::Operation *op) {
return isa<ttir::TTIRDialect>(op->getDialect());
}

bool isDpsOp(mlir::Operation *op) {
return isa<mlir::DestinationStyleOpInterface>(op);
}

// Init the dependencies map of all TTIR ops
// which implement DestinationStyleOpInterface
// We check for DPS because all operations which do
// some computation in TTIR implement this interface
// Init the dependencies map of all ops which are TTIR ops
Scheduler::Scheduler(func::FuncOp *func) {
for (auto &op : func->getOps()) {
if (isTTIROp(&op) && isDpsOp(&op)) {
if (isTTIROp(&op)) {
dependencies[&op] = {};
unscheduledOps.insert(&op);
}
Expand All @@ -34,13 +26,10 @@ Scheduler::Scheduler(func::FuncOp *func) {
for (auto &op : func->getOps()) {
// Skip non TTIR operations
// Skip operations which do not implement DestinationStyleOpInterface
if (!isTTIROp(&op) || !isDpsOp(&op)) {
if (!isTTIROp(&op)) {
continue;
}

auto dpsOp = cast<mlir::DestinationStyleOpInterface>(&op);
assert(dpsOp.getNumDpsInits() == 1 &&
"Operation must have a single DPS operand");
OpResult result = op.getResult(0);

for (mlir::Operation *use : result.getUsers()) {
Expand Down
16 changes: 14 additions & 2 deletions test/unittests/TestScheduler/TestScheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ class SchedulerBase : public ::testing::Test {
};

// This tests chains all operations one after
// another, so output of scheduler should be same
// another, so output of scheduler order should
// be same as the order of operations created
TEST_F(SchedulerBase, FixedSchedule) {
mlir::Value dest = createEmptyTensor();
mlir::Value lhs = func.getBody().getBlocks().front().getArgument(0);
Expand Down Expand Up @@ -165,6 +166,7 @@ TEST_F(SchedulerBase, FixedSchedule) {
EXPECT_NE(ops[0].getOperation(), schedule[1]);
}

// This tests the scheduler with a single operation
TEST_F(SchedulerBase, SingleOp) {
mlir::Value dest = createEmptyTensor();
mlir::Value lhs = func.getBody().getBlocks().front().getArgument(0);
Expand All @@ -186,7 +188,14 @@ TEST_F(SchedulerBase, SingleOp) {
ASSERT_EQ(scheduleableOps[0], op.getOperation());
}

// Test the scheduler with a fork in the graph
// First we have operation which works on arg1 and arg2
// Then we have two operations which work on the result of the first operation
// and arg1. Then we have forth operation which works on the result of the
// second and third operation. So the scheduler should first yield the first op
// and then the second and third op and after that the forth op.
TEST_F(SchedulerBase, VerifyFork) {
// Create the first operation which works on arg1 and arg2
mlir::Value dest = createEmptyTensor();
mlir::Value lhs = func.getBody().getBlocks().front().getArgument(0);
mlir::Value rhs = func.getBody().getBlocks().front().getArgument(1);
Expand All @@ -200,16 +209,19 @@ TEST_F(SchedulerBase, VerifyFork) {
lhs = func.getBody().getBlocks().front().getArgument(0);
rhs = op.getOperation()->getResult(0);

// Create the second operation which works on the result of the first
// operation and arg1
dest = createEmptyTensor();
op = builder.create<ttir::AddOp>(builder.getUnknownLoc(), lhs, rhs, dest,
attrs);
ops.push_back(op);

dest = createEmptyTensor();
op = builder.create<ttir::AddOp>(builder.getUnknownLoc(), lhs, rhs, dest,
attrs);
ops.push_back(op);

// Create the third operation which works on the result of the second and
// third operation
lhs = ops[ops.size() - 2].getOperation()->getResult(0);
rhs = ops[ops.size() - 1].getOperation()->getResult(0);
dest = createEmptyTensor();
Expand Down

0 comments on commit d300d4c

Please sign in to comment.