Skip to content

Commit

Permalink
add decomposition to hybrid pass (#2496)
Browse files Browse the repository at this point in the history
Signed-off-by: Soren Lassen <[email protected]>
  • Loading branch information
sorenlassen authored Sep 22, 2023
1 parent 2f1003b commit 4fb146a
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 67 deletions.
120 changes: 64 additions & 56 deletions src/Transform/ONNX/Decompose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
//
//===----------------------------------------------------------------------===//

#include "src/Transform/ONNX/Decompose.hpp"
#include "src/Pass/Passes.hpp"

#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
Expand All @@ -29,7 +32,6 @@
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
#include "src/Pass/Passes.hpp"
#include "src/Support/TypeUtilities.hpp"
#include "src/Transform/ONNX/DecomposeEinsum.hpp"

Expand Down Expand Up @@ -326,6 +328,17 @@ bool hasStaticSpatialDims(Value v) {
return !llvm::any_of(Ds, ShapedType::isDynamic);
}

bool shouldDecomposeConvTransposeOp(Value convTransposeResult) {
#ifdef ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE
ONNXConvTransposeOp op =
cast<ONNXConvTransposeOp>(convTransposeResult.getDefiningOp());
return hasStaticSpatialDims(op.getX()) && hasStaticSpatialDims(op.getW());
#else
// Disable the ONNXConvTransposeOp decomposition patterns.
return false;
#endif
}

// Split on the specified axis. The length of each output is one.
ValueRange emitSplitAxisOutputLength1(
PatternRewriter &rewriter, Location loc, Value input, int64_t axis) {
Expand Down Expand Up @@ -445,6 +458,9 @@ Value insertAdditionalPadsConvTranspose(PatternRewriter &rewriter, Location loc,
namespace {
/// Include the patterns defined in the Declarative Rewrite framework.
#include "src/Transform/ONNX/ONNXDecompose.inc"

#ifdef ONNX_MLIR_ENABLE_MHLO

RankedTensorType createResultType(
Type outputType, int64_t axisValue, bool keepDims) {
RankedTensorType outputShapeType = outputType.dyn_cast<RankedTensorType>();
Expand All @@ -465,26 +481,18 @@ RankedTensorType createResultType(
return resultType;
}

struct SoftmaxPattern : public ConversionPattern {
SoftmaxPattern(MLIRContext *context)
: ConversionPattern(ONNXSoftmaxOp::getOperationName(), 1, context) {}
LogicalResult matchAndRewrite(Operation *op0, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// Variables for capturing values and attributes used while creating ops.
IntegerAttr axis;
struct SoftmaxPattern : public OpRewritePattern<ONNXSoftmaxOp> {
using OpRewritePattern<ONNXSoftmaxOp>::OpRewritePattern;

LogicalResult matchAndRewrite(
ONNXSoftmaxOp softmaxOp, PatternRewriter &rewriter) const final {
// Match
ONNXSoftmaxOp softmaxOp = ::llvm::dyn_cast<ONNXSoftmaxOp>(op0);
Value input = softmaxOp.getInput();
Type inputType = input.getType();
axis = op0->getAttrOfType<IntegerAttr>("axis");
if (!axis)
axis = rewriter.getIntegerAttr(
rewriter.getIntegerType(64, /*isSigned=*/true), -1);
int64_t axisValue = axis.getSInt();
int64_t axisValue = softmaxOp.getAxis();

// Rewrite
Location odsLoc = op0->getLoc();
Location odsLoc = softmaxOp.getLoc();
onnx_mlir::MultiDialectBuilder<onnx_mlir::OnnxBuilder> create(
rewriter, odsLoc);

Expand All @@ -506,16 +514,16 @@ struct SoftmaxPattern : public ConversionPattern {
/*axis=*/axisOp, keepDimsAttr, noopWithEmptyAxes);
Value divValue =
rewriter.create<ONNXDivOp>(odsLoc, inputType, expValue, sumValue);
rewriter.replaceOp(op0, divValue);
rewriter.replaceOp(softmaxOp, divValue);
return success();
}
};

#ifdef ONNX_MLIR_ENABLE_MHLO
void populateDecomposingONNXBeforeMhloPatterns(
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.add<SoftmaxPattern>(ctx);
}

#endif

// Special Op fusion for the following pattern:
Expand All @@ -528,11 +536,11 @@ void populateDecomposingONNXBeforeMhloPatterns(

// Helper function: is the ConcatOp matched to the fusion pattern?
static bool isConcatFuseMatched(
Operation *op, ONNXShapeOp &shapeOp, ONNXTransposeOp &transposeOp) {
shapeOp = NULL;
transposeOp = NULL;
ONNXConcatOp concatOp, ONNXShapeOp &shapeOp, ONNXTransposeOp &transposeOp) {
shapeOp = nullptr;
transposeOp = nullptr;
bool failed = false;
for (Operation *user : op->getUsers()) {
for (Operation *user : concatOp->getUsers()) {
if (isa<ONNXShapeOp>(user) && !shapeOp)
shapeOp = cast<ONNXShapeOp>(user);
else if (isa<ONNXTransposeOp>(user) && !transposeOp)
Expand All @@ -543,31 +551,29 @@ static bool isConcatFuseMatched(
return (shapeOp && transposeOp && !failed);
}

struct ConcatFusePattern : public ConversionPattern {
ConcatFusePattern(MLIRContext *context)
: ConversionPattern(ONNXConcatOp::getOperationName(), 4, context) {}
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {

ONNXConcatOp concatOp = ::llvm::dyn_cast<ONNXConcatOp>(op);
struct ConcatFusePattern : public OpRewritePattern<ONNXConcatOp> {
using OpRewritePattern<ONNXConcatOp>::OpRewritePattern;

LogicalResult matchAndRewrite(
ONNXConcatOp concatOp, PatternRewriter &rewriter) const final {
// Match
ONNXShapeOp shapeOp = NULL;
ONNXTransposeOp transposeOp = NULL;
if (!isConcatFuseMatched(op, shapeOp, transposeOp))
ONNXShapeOp shapeOp;
ONNXTransposeOp transposeOp;
if (!isConcatFuseMatched(concatOp, shapeOp, transposeOp))
return failure();

// Rewrite
SmallVector<Type, 2> outputTypes;
outputTypes.emplace_back(shapeOp.getResult().getType());
outputTypes.emplace_back(transposeOp.getResult().getType());

auto fusedV = rewriter.create<ONNXConcatShapeTransposeOp>(op->getLoc(),
outputTypes, operands, concatOp.getAxisAttr(), shapeOp.getEndAttr(),
shapeOp.getStartAttr(), transposeOp.getPermAttr());
auto fusedV = rewriter.create<ONNXConcatShapeTransposeOp>(concatOp.getLoc(),
outputTypes, concatOp->getOperands(), concatOp.getAxisAttr(),
shapeOp.getEndAttr(), shapeOp.getStartAttr(),
transposeOp.getPermAttr());
rewriter.replaceOp(shapeOp.getOperation(), fusedV.getResults()[0]);
rewriter.replaceOp(transposeOp.getOperation(), fusedV.getResults()[1]);
rewriter.eraseOp(op);
rewriter.eraseOp(concatOp);
return success();
}
};
Expand All @@ -594,12 +600,12 @@ struct ConcatFusePattern : public ConversionPattern {
// transA = 0 : si64, transB = 1 : si64} :
// (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// ```
struct CustomOpFuseMatMulPattern : public OpConversionPattern<ONNXCustomOp> {
CustomOpFuseMatMulPattern(MLIRContext *context)
: OpConversionPattern(context) {}
LogicalResult matchAndRewrite(ONNXCustomOp customOp,
ONNXCustomOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const final {

struct CustomOpFuseMatMulPattern : public OpRewritePattern<ONNXCustomOp> {
using OpRewritePattern<ONNXCustomOp>::OpRewritePattern;

LogicalResult matchAndRewrite(
ONNXCustomOp customOp, PatternRewriter &rewriter) const final {
using namespace onnx_mlir;
Location loc = customOp.getLoc();

Expand Down Expand Up @@ -806,8 +812,8 @@ void DecomposeONNXToONNXPass::runOnOperation() {
target.addIllegalOp<ONNXUpsampleV7Op>();
target.addIllegalOp<ONNXUnsqueezeV11Op>();
target.addDynamicallyLegalOp<ONNXConcatOp>([](ONNXConcatOp op) {
ONNXShapeOp shapeOp = NULL;
ONNXTransposeOp transposeOp = NULL;
ONNXShapeOp shapeOp;
ONNXTransposeOp transposeOp;
return !isConcatFuseMatched(op, shapeOp, transposeOp);
});
// Decompose CustomOp FusedMatMul introduced by onnxruntime:
Expand All @@ -828,22 +834,15 @@ void DecomposeONNXToONNXPass::runOnOperation() {
#endif
target.addDynamicallyLegalOp<ONNXConvTransposeOp>(
[](ONNXConvTransposeOp op) {
return !(onnx_mlir::hasStaticSpatialDims(op.getX()) &&
onnx_mlir::hasStaticSpatialDims(op.getW()));
return !onnx_mlir::shouldDecomposeConvTransposeOp(op);
});
#ifdef ONNX_MLIR_ENABLE_MHLO
}
#endif
#endif

RewritePatternSet patterns(context);
populateWithGenerated(patterns);
patterns.insert<onnx_mlir::DecomposeEinsumPattern>(&getContext());
patterns.insert<ConcatFusePattern>(&getContext());
// Decompose CustomOp FusedMatMul introduced by onnxruntime:
// https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.FusedMatMul
patterns.insert<CustomOpFuseMatMulPattern>(&getContext());

onnx_mlir::getDecomposeONNXToONNXPatterns(patterns);
#ifdef ONNX_MLIR_ENABLE_MHLO
if (this->target == "mhlo") {
populateDecomposingONNXBeforeMhloPatterns(patterns, context);
Expand All @@ -857,14 +856,23 @@ void DecomposeONNXToONNXPass::runOnOperation() {

} // namespace

namespace onnx_mlir {
void onnx_mlir::getDecomposeONNXToONNXPatterns(
mlir::RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
populateWithGenerated(patterns);
patterns.insert<onnx_mlir::DecomposeEinsumPattern>(context);
patterns.insert<ConcatFusePattern>(context);
// Decompose CustomOp FusedMatMul introduced by onnxruntime:
// https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.FusedMatMul
patterns.insert<CustomOpFuseMatMulPattern>(context);

// TODO: consider whether to include SoftmaxPattern here
}

/*!
* Create a DecomposeONNX pass.
*/
std::unique_ptr<mlir::Pass> createDecomposeONNXToONNXPass(
std::unique_ptr<mlir::Pass> onnx_mlir::createDecomposeONNXToONNXPass(
const std::string &target) {
return std::make_unique<DecomposeONNXToONNXPass>(target);
}

} // namespace onnx_mlir
15 changes: 15 additions & 0 deletions src/Transform/ONNX/Decompose.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
 * SPDX-License-Identifier: Apache-2.0
 */

// Declarations for the ONNX-to-ONNX decomposition rewrite patterns so that
// other passes (e.g. the ONNX hybrid transform pass) can reuse them.

#pragma once

#include "mlir/IR/PatternMatch.h"

namespace onnx_mlir {

// Exports the DecomposeONNXToONNXPass patterns. They are all plain rewrite
// patterns that can be used with any PatternRewriter, not conversion patterns.
// Appends the patterns to `patterns`; the caller owns the set and decides how
// to apply them (greedy rewrite, conversion, etc.).
void getDecomposeONNXToONNXPatterns(mlir::RewritePatternSet &patterns);

} // namespace onnx_mlir
9 changes: 7 additions & 2 deletions src/Transform/ONNX/Decompose.td
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,11 @@ def HasUnitStrides: Constraint<
"has unit strides"
>;

// Guard for the ConvTranspose decomposition patterns: delegates to the C++
// helper onnx_mlir::shouldDecomposeConvTransposeOp, which checks that the
// op's X and W have static spatial dims and that decomposition is enabled
// at build time (ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE).
def ShouldDecomposeConvTransposeOp: Constraint<
CPred<"onnx_mlir::shouldDecomposeConvTransposeOp($_self)">,
"X and W have static spatial dims and ConvTransposeOp decomposition is enabled"
>;

def HasStaticSpatialDims: Constraint<
CPred<"onnx_mlir::hasStaticSpatialDims($_self)">,
"has static spatial dims"
Expand All @@ -496,7 +501,7 @@ def ConvTransposeOpPattern1: Pattern<
(GetNullStringAttr), $dilation, $group, $kernel_shape, $new_pads, $strides),
(insertAdditionalPadsConvTranspose $conv_res, $res)
],
[(HasUnitStrides:$strides), (HasStaticSpatialDims:$x), (HasStaticSpatialDims:$w)], [],
[(ShouldDecomposeConvTransposeOp:$res), (HasUnitStrides:$strides)], [],
(addBenefit 1)
>;

Expand All @@ -513,7 +518,7 @@ def ConvTransposeOpPattern2: Pattern<
(insertAdditionalPadsConvTranspose
$conv_res, $res)
],
[(HasStaticSpatialDims:$x), (HasStaticSpatialDims:$w)], [],
[(ShouldDecomposeConvTransposeOp:$res)], [],
(addBenefit 0)
>;

Expand Down
19 changes: 13 additions & 6 deletions src/Transform/ONNX/ONNXHybridTransformPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
//===------------------ ONNXHybridTransformPass.cpp -----------------------===//
//
// Hybrid ONNX transformation pass that combines conversion patterns for
// shape inference and canonicalization and constant propagation.
// shape inference, canonicalization, constant propagation, and decomposition.
//
// TODO: add decomposition
// Note that the decomposition patterns are applied "best effort" with a greedy
// rewrite, not a partial conversion with "legalization" to ensure that every
// decomposable op is decomposed.
//
//===----------------------------------------------------------------------===//

Expand All @@ -21,6 +23,7 @@
#include "src/Interface/ShapeInferenceOpInterface.hpp"
#include "src/Pass/Passes.hpp"
#include "src/Transform/ONNX/ConstProp.hpp"
#include "src/Transform/ONNX/Decompose.hpp"
#include "src/Transform/ONNX/ShapeInference.hpp"

using namespace mlir;
Expand Down Expand Up @@ -60,15 +63,17 @@ struct ONNXHybridTransformPass
llvm::cl::desc("Enable constant propagation in hybrid transform"),
llvm::cl::init(true)};

Option<bool> decomposition{*this, "decomposition",
llvm::cl::desc("Enable decomposition in hybrid transform"),
llvm::cl::init(true)};

FrozenRewritePatternSet patterns;

ONNXHybridTransformPass() = default;

ONNXHybridTransformPass(const ONNXHybridTransformPass &pass)
: patterns(pass.patterns) {
shapeInference = pass.shapeInference;
canonicalization = pass.canonicalization;
constantPropagation = pass.constantPropagation;
copyOptionValuesFrom(&pass);
}

StringRef getArgument() const override { return "onnx-hybrid-transform"; }
Expand All @@ -92,7 +97,9 @@ struct ONNXHybridTransformPass
getConstPropONNXToONNXPatterns(cumulativePatterns);
}

// TODO: decomposition
if (decomposition) {
getDecomposeONNXToONNXPatterns(cumulativePatterns);
}

patterns = FrozenRewritePatternSet(std::move(cumulativePatterns));
return success();
Expand Down
Loading

0 comments on commit 4fb146a

Please sign in to comment.