Skip to content

Commit

Permalink
Add calls CSE pass
Browse files Browse the repository at this point in the history
  • Loading branch information
arrangabriel committed Oct 25, 2024
1 parent bd1ee53 commit da0eba0
Show file tree
Hide file tree
Showing 14 changed files with 1,903 additions and 631 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,14 @@ def EquationExpressionOpInterface
"void", "printExpression",
(ins "::llvm::raw_ostream&":$os,
"const ::llvm::DenseMap<::mlir::Value, int64_t>&":$inductions)>,
InterfaceMethod<
"Check if two expressions are equivalent",
"bool", "isEquivalent",
(ins "mlir::Operation*":$other,
"mlir::SymbolTableCollection&":$symbolTableCollection), "", [{
// Safely assume that the two expressions are different.
return false;
}]>,
InterfaceMethod<
"Get the number of elements.",
"uint64_t", "getNumOfExpressionElements",
Expand Down
13 changes: 13 additions & 0 deletions include/public/marco/Dialect/BaseModelica/Transforms/CallCSE.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#ifndef MARCO_DIALECT_BASEMODELICA_TRANSFORMS_CALLCSE_H
#define MARCO_DIALECT_BASEMODELICA_TRANSFORMS_CALLCSE_H

#include "mlir/Pass/Pass.h"

namespace mlir::bmodelica {
#define GEN_PASS_DECL_CALLCSEPASS
#include "marco/Dialect/BaseModelica/Transforms/Passes.h.inc"

std::unique_ptr<mlir::Pass> createCallCSEPass();
} // namespace mlir::bmodelica

#endif // MARCO_DIALECT_BASEMODELICA_TRANSFORMS_CALLCSE_H
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "marco/Dialect/BaseModelica/Transforms/AccessReplacementTest.h"
#include "marco/Dialect/BaseModelica/Transforms/AutomaticDifferentiation.h"
#include "marco/Dialect/BaseModelica/Transforms/BindingEquationConversion.h"
#include "marco/Dialect/BaseModelica/Transforms/CallCSE.h"
#include "marco/Dialect/BaseModelica/Transforms/DerivativeChainRule.h"
#include "marco/Dialect/BaseModelica/Transforms/DerivativesMaterialization.h"
#include "marco/Dialect/BaseModelica/Transforms/EquationAccessSplit.h"
Expand Down
21 changes: 21 additions & 0 deletions include/public/marco/Dialect/BaseModelica/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,27 @@ def EquationFunctionLoopHoistingPass
let constructor = "mlir::bmodelica::createEquationFunctionLoopHoistingPass()";
}

def CallCSEPass
: Pass<"call-cse", "mlir::ModuleOp">
{
let summary = "Move equal function calls to dedicated equation.";

let description = [{
Move equal function calls to dedicated equation.
}];

let dependentDialects = [
"mlir::bmodelica::BaseModelicaDialect"
];

let statistics = [
Statistic<"newCSEVariables", "new-cse-variables", "How many CSE variables have been created">,
Statistic<"replacedCalls", "replaced-calls", "How many calls were replaced by a CSE variable usage">
];

let constructor = "mlir::bmodelica::createCallCSEPass()";
}

def ReadOnlyVariablesPropagationPass
: Pass<"propagate-read-only-variables", "mlir::ModuleOp">
{
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/BaseModelica/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRBaseModelicaTransforms
AllocationOpInterfaceImpl.cpp
BindingEquationConversion.cpp
BufferizableOpInterfaceImpl.cpp
CallCSE.cpp
ConstantMaterializableTypeInterfaceImpl.cpp
DerivableOpInterfaceImpl.cpp
DerivableTypeInterfaceImpl.cpp
Expand Down
238 changes: 238 additions & 0 deletions lib/Dialect/BaseModelica/Transforms/CallCSE.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
#include "marco/Dialect/BaseModelica/Transforms/CallCSE.h"
#include "marco/Dialect/BaseModelica/IR/BaseModelica.h"

namespace mlir::bmodelica {
#define GEN_PASS_DEF_CALLCSEPASS
#include "marco/Dialect/BaseModelica/Transforms/Passes.h.inc"
} // namespace mlir::bmodelica

using namespace ::mlir::bmodelica;

namespace {
class CallCSEPass final : public impl::CallCSEPassBase<CallCSEPass> {
public:
using CallCSEPassBase<CallCSEPass>::CallCSEPassBase;

void runOnOperation() override;

private:
mlir::LogicalResult processModelOp(ModelOp modelOp);

/// Replace all calls in the equivalence group with gets to a generated
/// variable. The variable will be driven by an equation derived from the
/// first call in the group.
///
/// One variable and driver equation will be emitted per result,
/// if the call is to a function with multiple result values.
void emitCse(llvm::SmallVectorImpl<CallOp> &equivalenceGroup, ModelOp modelOp,
DynamicOp dynamicOp, mlir::SymbolTable &symbolTable,
mlir::RewriterBase &rewriter);
};

/// Get all call operations in the model.
void collectCallOps(ModelOp modelOp, llvm::SmallVectorImpl<CallOp> &callOps) {
llvm::SmallVector<EquationInstanceOp> dynamicEquationOps;
modelOp.collectMainEquations(dynamicEquationOps);

llvm::DenseSet<EquationTemplateOp> visitedTemplateOps;
for (EquationInstanceOp equationOp : dynamicEquationOps) {
EquationTemplateOp templateOp = equationOp.getTemplate();
if (!templateOp.getInductionVariables().empty() ||
visitedTemplateOps.contains(templateOp)) {
continue;
}
visitedTemplateOps.insert(templateOp);
templateOp->walk([&](CallOp callOp) { callOps.push_back(callOp); });
}
}

/// Partition the list of call operations into groups given by
/// EquationExpressionOpInterface::isEquivalent
void buildCallEquivalenceGroups(
llvm::SmallVectorImpl<CallOp> &callOps,
llvm::SmallVectorImpl<llvm::SmallVector<CallOp>> &callEquivalenceGroups) {
mlir::SymbolTableCollection symbolTableCollection;
llvm::SmallVector<llvm::SmallVector<CallOp>> tmpCallEquivalenceGroups;

for (CallOp callOp : callOps) {
auto callExpression =
mlir::cast<EquationExpressionOpInterface>(callOp.getOperation());

llvm::SmallVector<CallOp> *equivalenceGroup = find_if(
tmpCallEquivalenceGroups, [&](llvm::SmallVector<CallOp> &group) {
assert(!group.empty() && "groups should never be empty");
return callExpression.isEquivalent(group.front(),
symbolTableCollection);
});

if (equivalenceGroup != tmpCallEquivalenceGroups.end()) {
// Add equivalent call to existing group
equivalenceGroup->push_back(callOp);
} else {
// Create new equivalence group
tmpCallEquivalenceGroups.push_back({callOp});
}
}

for (llvm::SmallVector<CallOp> &group : tmpCallEquivalenceGroups) {
if (group.size() > 1) {
callEquivalenceGroups.push_back(std::move(group));
}
}
}

/// Clone `op` and its def-use chain, returning the cloned version of `op`.
mlir::Operation *cloneDefUseChain(mlir::Operation *op,
mlir::RewriterBase &rewriter) {
llvm::SmallVector<mlir::Operation *> toClone;
llvm::SmallVector<mlir::Operation *> worklist({op});

// DFS through the def-use chain of `op`
while (!worklist.empty()) {
mlir::Operation *current = worklist.back();
worklist.pop_back();
toClone.push_back(current);
for (mlir::Value operand : current->getOperands()) {
if (mlir::Operation *defOp = operand.getDefiningOp()) {
worklist.push_back(defOp);
}
}
// Find the dependencies on operations not defined within the regions of
// `current`. No need to do this if it is isolated from above.
if (!current->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>()) {
// Find all uses of values defined outside `current`.
current->walk([&](mlir::Operation *childOp) {
// Walk includes current, so skip it.
if (childOp == current) {
return;
}
for (mlir::Value operand : childOp->getOperands()) {
// If an operand is defined in the same scope as `current`,
// i.e. the equation template scope, add it to the worklist.
mlir::Operation *definingOp = operand.getDefiningOp();
if (definingOp && definingOp->getBlock() == current->getBlock()) {
worklist.push_back(definingOp);
}
}
});
}
}

mlir::IRMapping mapping;
mlir::Operation *root = nullptr;
for (mlir::Operation *opToClone : llvm::reverse(toClone)) {
// Skip repeated dependencies on the same operation
if (mapping.contains(opToClone)) {
continue;
}
root = rewriter.clone(*opToClone, mapping);
}
return root;
}

void CallCSEPass::emitCse(llvm::SmallVectorImpl<CallOp> &equivalenceGroup,
ModelOp modelOp, DynamicOp dynamicOp,
mlir::SymbolTable &symbolTable,
mlir::RewriterBase &rewriter) {
assert(!equivalenceGroup.empty() && "equivalenceGroup cannot be empty");
CallOp representative = equivalenceGroup.front();
const mlir::Location loc = representative.getLoc();

// Emit one variable per function result
llvm::SmallVector<VariableOp> cseVariables;
for (auto result : llvm::enumerate(representative.getResults())) {
rewriter.setInsertionPointToStart(modelOp.getBody());
// Emit cse variable
auto cseVariable = rewriter.create<VariableOp>(
loc, "_cse", VariableType::wrap(result.value().getType()));
symbolTable.insert(cseVariable);
cseVariables.push_back(cseVariable);

// Emit driver equation
rewriter.setInsertionPoint(dynamicOp);
auto equationTemplateOp = rewriter.create<EquationTemplateOp>(loc);
rewriter.setInsertionPointToStart(equationTemplateOp.createBody(0));
auto lhsOp = rewriter.create<EquationSideOp>(
loc, rewriter.create<VariableGetOp>(loc, cseVariable)->getResults());
auto rhsOp = rewriter.create<EquationSideOp>(
loc,
cloneDefUseChain(representative, rewriter)->getResult(result.index()));
rewriter.create<EquationSidesOp>(loc, lhsOp, rhsOp);

// Add driver equation to dynamic operation
rewriter.setInsertionPointToEnd(dynamicOp.getBody());
rewriter.create<EquationInstanceOp>(rewriter.getUnknownLoc(),
equationTemplateOp);
}

// Replace calls with get(s) to CSE variable(s)
for (auto &callOp : equivalenceGroup) {
rewriter.setInsertionPoint(callOp);

llvm::SmallVector<mlir::Value> results;
for (VariableOp cseVariable : cseVariables) {
results.push_back(
rewriter.create<VariableGetOp>(loc, cseVariable).getResult());
}
rewriter.replaceOp(callOp, results);
}

this->replacedCalls += equivalenceGroup.size();
++this->newCSEVariables;
}

mlir::LogicalResult CallCSEPass::processModelOp(ModelOp modelOp) {
mlir::IRRewriter rewriter(modelOp);
mlir::SymbolTable symbolTable(modelOp);

llvm::SmallVector<CallOp> callOps;
collectCallOps(modelOp, callOps);

llvm::SmallVector<llvm::SmallVector<CallOp>> callEquivalenceGroups;
buildCallEquivalenceGroups(callOps, callEquivalenceGroups);

if (callEquivalenceGroups.empty()) {
return mlir::success();
}

rewriter.setInsertionPointToEnd(modelOp.getBody());
DynamicOp dynamicOp = rewriter.create<DynamicOp>(rewriter.getUnknownLoc());
rewriter.createBlock(&dynamicOp.getRegion());

for (llvm::SmallVector<CallOp> &equivalenceGroup : callEquivalenceGroups) {
// Only emit CSEs that will lead to an equivalent, or lower amount of calls
if (equivalenceGroup.size() >= equivalenceGroup.front().getNumResults()) {
emitCse(equivalenceGroup, modelOp, dynamicOp, symbolTable, rewriter);
}
}

if (dynamicOp.getBody()->empty()) {
rewriter.eraseOp(dynamicOp);
}

return mlir::success();
}
} // namespace

void CallCSEPass::runOnOperation() {
llvm::SmallVector<ModelOp, 1> modelOps;

walkClasses(getOperation(), [&](mlir::Operation *op) {
if (auto modelOp = mlir::dyn_cast<ModelOp>(op)) {
modelOps.push_back(modelOp);
}
});

if (mlir::failed(mlir::failableParallelForEach(
&getContext(), modelOps, [&](mlir::Operation *op) {
return processModelOp(mlir::cast<ModelOp>(op));
}))) {
return signalPassFailure();
}
}

namespace mlir::bmodelica {
std::unique_ptr<mlir::Pass> createCallCSEPass() {
return std::make_unique<CallCSEPass>();
}
} // namespace mlir::bmodelica
Loading

0 comments on commit da0eba0

Please sign in to comment.