Skip to content

Commit

Permalink
add const prop to hybrid pass
Browse files Browse the repository at this point in the history
Signed-off-by: Soren Lassen <[email protected]>
  • Loading branch information
sorenlassen committed Sep 12, 2023
1 parent 971f39a commit 3006263
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 103 deletions.
1 change: 1 addition & 0 deletions src/Transform/ONNX/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ add_onnx_mlir_library(OMHybridTransform
LINK_LIBS PUBLIC
OMONNXOps
OMShapeInferenceOpInterface
OMONNXRewrite
MLIRPass
MLIRTransforms
OMShapeInference
Expand Down
18 changes: 14 additions & 4 deletions src/Transform/ONNX/ConstProp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
//
//===----------------------------------------------------------------------===//

#include "src/Transform/ONNX/ConstProp.hpp"
#include "src/Pass/Passes.hpp"

#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
Expand All @@ -29,7 +32,6 @@
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
#include "src/Dialect/ONNX/OnnxElementsAttrBuilder.hpp"
#include "src/Pass/Passes.hpp"
#include "src/Support/TypeUtilities.hpp"

#include <math.h>
Expand Down Expand Up @@ -1011,6 +1013,12 @@ class SplitOfConst : public OpRewritePattern<ONNXSplitOp> {
}
};

void getPatterns(RewritePatternSet &patterns) {
populateWithGenerated(patterns);
if (isNotDisabled("SplitOfConst"))
patterns.insert<SplitOfConst>(patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Code to manage the pass.
//===----------------------------------------------------------------------===//
Expand All @@ -1034,15 +1042,17 @@ void ConstPropONNXToONNXPass::runOnOperation() {
MLIRContext *context = &getContext();

RewritePatternSet patterns(context);
populateWithGenerated(patterns);
if (isNotDisabled("SplitOfConst"))
patterns.insert<SplitOfConst>(context);
getPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(function, std::move(patterns))))
signalPassFailure();
}

} // end anonymous namespace.

void onnx_mlir::getConstPropPatterns(RewritePatternSet &patterns) {
getPatterns(patterns);
}

void onnx_mlir::configureConstPropONNXToONNXPass(
int expansionBound, ArrayRef<std::string> disabledPatterns) {
ConstPropONNXToONNXPassConfiguration::expansionBound = expansionBound;
Expand Down
13 changes: 13 additions & 0 deletions src/Transform/ONNX/ConstProp.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include "mlir/IR/PatternMatch.h"

namespace onnx_mlir {

void getConstPropPatterns(mlir::RewritePatternSet &patterns);

}
9 changes: 6 additions & 3 deletions src/Transform/ONNX/ONNXHybridTransformPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
//===------------------ ONNXHybridTransformPass.cpp -----------------------===//
//
// Hybrid ONNX transformation pass that combines conversion patterns for
// shape inference and canonicalization.
// shape inference and canonicalization and constant propagation.
//
// TODO: add constant propagation and decomposition
// TODO: add decomposition
//
//===----------------------------------------------------------------------===//

Expand All @@ -17,6 +17,7 @@
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Interface/ShapeInferenceOpInterface.hpp"
#include "src/Pass/Passes.hpp"
#include "src/Transform/ONNX/ConstProp.hpp"
#include "src/Transform/ONNX/ShapeInference.hpp"

using namespace mlir;
Expand Down Expand Up @@ -54,7 +55,9 @@ struct ONNXHybridTransformPass
for (RegisteredOperationName op : context->getRegisteredOperations())
op.getCanonicalizationPatterns(cumulativePatterns, context);

// TODO: constant propagation, decomposition
getConstPropPatterns(cumulativePatterns);

// TODO: decomposition

patterns = FrozenRewritePatternSet(std::move(cumulativePatterns));
return success();
Expand Down
Loading

0 comments on commit 3006263

Please sign in to comment.