Skip to content

Commit

Permalink
deduplicate getConstPropPatterns() implementation
Browse files Browse the repository at this point in the history
Signed-off-by: Soren Lassen <[email protected]>
  • Loading branch information
sorenlassen committed Sep 12, 2023
1 parent 3006263 commit 2524321
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions src/Transform/ONNX/ConstProp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1013,12 +1013,6 @@ class SplitOfConst : public OpRewritePattern<ONNXSplitOp> {
}
};

void getPatterns(RewritePatternSet &patterns) {
populateWithGenerated(patterns);
if (isNotDisabled("SplitOfConst"))
patterns.insert<SplitOfConst>(patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Code to manage the pass.
//===----------------------------------------------------------------------===//
Expand All @@ -1042,15 +1036,17 @@ void ConstPropONNXToONNXPass::runOnOperation() {
MLIRContext *context = &getContext();

RewritePatternSet patterns(context);
getPatterns(patterns);
getConstPropPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(function, std::move(patterns))))
signalPassFailure();
}

} // end anonymous namespace.

void onnx_mlir::getConstPropPatterns(RewritePatternSet &patterns) {
getPatterns(patterns);
populateWithGenerated(patterns);
if (isNotDisabled("SplitOfConst"))
patterns.insert<SplitOfConst>(patterns.getContext());
}

void onnx_mlir::configureConstPropONNXToONNXPass(
Expand Down

0 comments on commit 2524321

Please sign in to comment.