Skip to content

Commit

Permalink
Merge branch 'main' into hamptonm/feature/roberta_still
Browse files Browse the repository at this point in the history
  • Loading branch information
hamptonm1 authored Sep 19, 2023
2 parents 7554032 + e462c6f commit ae607aa
Show file tree
Hide file tree
Showing 11 changed files with 573 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@
#include "src/Accelerators/NNPA/Pass/NNPAPasses.hpp"
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "src/Dialect/ONNX/DialectBuilder.hpp"
#include "src/Dialect/ONNX/ElementsAttr/WideNum.hpp"
#include "src/Dialect/ONNX/ONNXDimAnalysis.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
#include "src/Dialect/ONNX/OnnxElementsAttrBuilder.hpp"
#include "src/Support/TypeUtilities.hpp"

using namespace mlir;
Expand Down Expand Up @@ -388,6 +390,62 @@ class SplitLargeMatMulPattern : public OpRewritePattern<ONNXMatMulOp> {
}
};

/// Replace `C = add/sub(A, B)` by `A` when B is a zero tensor obtained by
/// expanding a scalar zero constant and C has the same shape as A. In other
/// words, the output does not depend on the second operand.
/// This pattern is similar to Add/SubZerosOnRhs in ConstProp.td but also
/// supports dynamic shapes (via DimAnalysis).
template <typename OP_TYPE>
class AddSubWithRHSZeroExpandPattern : public OpRewritePattern<OP_TYPE> {
public:
  DimAnalysis *dimAnalysis;

  AddSubWithRHSZeroExpandPattern(MLIRContext *context, DimAnalysis *dimAnalysis)
      : OpRewritePattern<OP_TYPE>(context, 1001), dimAnalysis(dimAnalysis) {}

  LogicalResult matchAndRewrite(
      OP_TYPE binaryOp, PatternRewriter &rewriter) const override {
    // Match.
    if (!canBeRewritten(binaryOp, dimAnalysis))
      return failure();
    // Rewrite: the op contributes nothing, forward the first operand.
    rewriter.replaceOp(binaryOp.getOperation(), {binaryOp.getA()});
    return success();
  }

  static bool canBeRewritten(OP_TYPE binaryOp, DimAnalysis *dimAnalysis) {
    Value A = binaryOp.getA();
    Value B = binaryOp.getB();
    Value C = binaryOp.getC();

    // The output C must have the same shape as the first operand A.
    if (!dimAnalysis->sameShape(A, C))
      return false;
    // B must be produced by an operation (not a block argument) ...
    if (isa<BlockArgument>(B))
      return false;
    // ... namely an Expand of a dense scalar constant.
    auto expandOp = dyn_cast<ONNXExpandOp>(B.getDefiningOp());
    if (!expandOp)
      return false;
    Value expandInput = expandOp.getInput();
    if (!isDenseONNXConstant(expandInput))
      return false;
    ElementsAttr elements = getElementAttributeFromONNXValue(expandInput);
    Type elemType = elements.getElementType();
    // Booleans are not supported.
    if (elemType.isInteger(1))
      return false;
    // Build a zero of the constant's element type and check that every
    // element of the constant equals it.
    WideNum zero = wideZeroDispatch(elemType, [](auto wideZero) {
      using cpptype = decltype(wideZero);
      constexpr BType TAG = toBType<cpptype>;
      return WideNum::widen<TAG>(static_cast<cpptype>(0.0));
    });
    return ElementsAttrBuilder::allEqual(elements, zero);
  }
};

//===----------------------------------------------------------------------===//
// Rewrite ONNX ops to ZHigh ops and ONNX ops for ZHigh.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -441,11 +499,13 @@ void RewriteONNXForZHighPass::runOnOperation() {
//
// This is preferred for NNPA because NNPA BinaryOp does not support
// broadcasting.
target.addDynamicallyLegalOp<ONNXAddOp>([](ONNXAddOp op) {
target.addDynamicallyLegalOp<ONNXAddOp>([&dimAnalysis](ONNXAddOp op) {
return !((isDefinedByONNXConstantOp(op.getA()) &&
isUniBroadcatableFirstToSecond(op.getA(), op.getB())) ||
(isDefinedByONNXConstantOp(op.getB()) &&
isUniBroadcatableFirstToSecond(op.getB(), op.getA())));
isUniBroadcatableFirstToSecond(op.getB(), op.getA())) ||
AddSubWithRHSZeroExpandPattern<ONNXAddOp>::canBeRewritten(
op, &dimAnalysis));
});
target.addDynamicallyLegalOp<ONNXDivOp>([](ONNXDivOp op) {
return !((isDefinedByONNXConstantOp(op.getA()) &&
Expand All @@ -459,11 +519,13 @@ void RewriteONNXForZHighPass::runOnOperation() {
(isDefinedByONNXConstantOp(op.getB()) &&
isUniBroadcatableFirstToSecond(op.getB(), op.getA())));
});
target.addDynamicallyLegalOp<ONNXSubOp>([](ONNXSubOp op) {
target.addDynamicallyLegalOp<ONNXSubOp>([&dimAnalysis](ONNXSubOp op) {
return !((isDefinedByONNXConstantOp(op.getA()) &&
isUniBroadcatableFirstToSecond(op.getA(), op.getB())) ||
(isDefinedByONNXConstantOp(op.getB()) &&
isUniBroadcatableFirstToSecond(op.getB(), op.getA())));
isUniBroadcatableFirstToSecond(op.getB(), op.getA())) ||
AddSubWithRHSZeroExpandPattern<ONNXSubOp>::canBeRewritten(
op, &dimAnalysis));
});

// Determine if MatMulOp is already legal (no need to rewrite) or need to
Expand Down Expand Up @@ -536,6 +598,10 @@ void RewriteONNXForZHighPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateWithGenerated(patterns);
patterns.insert<SplitLargeMatMulPattern>(&getContext());
patterns.insert<AddSubWithRHSZeroExpandPattern<ONNXAddOp>>(
&getContext(), &dimAnalysis);
patterns.insert<AddSubWithRHSZeroExpandPattern<ONNXSubOp>>(
&getContext(), &dimAnalysis);

// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`
Expand Down
5 changes: 5 additions & 0 deletions src/Dialect/ONNX/DialectBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ Value OnnxBuilder::div(Value A, Value B) const {
return createOpAndInferShapes<ONNXDivOp>(toTensor(A), toTensor(B));
}

// Create an ONNXExpandOp (with inferred shapes) broadcasting `input` to
// `shape`, producing a value of `outputType`.
Value OnnxBuilder::expand(Type outputType, Value input, Value shape) const {
  Value tensorInput = toTensor(input);
  Value tensorShape = toTensor(shape);
  return createOpAndInferShapes<ONNXExpandOp>(
      outputType, tensorInput, tensorShape);
}

Value OnnxBuilder::matmul(Type Y, Value A, Value B, bool useGemm) const {
// Gemm only supports rank 2.
bool canUseGemm = useGemm && A.getType().isa<ShapedType>() &&
Expand Down
4 changes: 4 additions & 0 deletions src/Dialect/ONNX/DialectBuilder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ struct OnnxBuilder : DialectBuilder {
// ONNXDimGroupOp
void dimGroup(mlir::Value input, int axis, int groupID) const;

// ONNXExpandOp
mlir::Value expand(
mlir::Type outputType, mlir::Value input, mlir::Value shape) const;

// ONNXMatMulOp or ONNXGemmOp
mlir::Value matmul(
mlir::Type Y, mlir::Value A, mlir::Value B, bool useGemm = false) const;
Expand Down
7 changes: 7 additions & 0 deletions src/Dialect/ONNX/ONNXOps/OpHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,7 @@ template bool definedBy<ONNXCastOp>(Value v);
template bool definedBy<ONNXConcatOp>(Value v);
template bool definedBy<ONNXConstantOp>(Value v);
template bool definedBy<ONNXDimOp>(Value v);
template bool definedBy<ONNXExpandOp>(Value v);

/// Check if a value is to store dimensions, meaning it is defined by
/// Dim/Constant/Cast/Concat.
Expand Down Expand Up @@ -681,6 +682,12 @@ int64_t mlirTypeToOnnxType(Type elemType) {
return onnxType;
}

// A value is a scalar tensor when it is ranked and is either rank 0
// (tensor<dtype>) or rank 1 with exactly one element (tensor<1xdtype>).
bool isScalarTensor(Value v) {
  if (!hasShapeAndRank(v))
    return false;
  int64_t rank = getRank(v.getType());
  if (rank == 0)
    return true;
  return (rank == 1) && (getShape(v.getType())[0] == 1);
}

//===----------------------------------------------------------------------===//
// Support for location.
//===----------------------------------------------------------------------===//
Expand Down
3 changes: 3 additions & 0 deletions src/Dialect/ONNX/ONNXOps/OpHelper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,9 @@ mlir::Type convertONNXTypeToMLIRType(
/// Get the ONNX type corresponding to an MLIR type.
int64_t mlirTypeToOnnxType(mlir::Type elemType);

/// Check if a value is a scalar tensor.
bool isScalarTensor(mlir::Value v);

//===----------------------------------------------------------------------===//
// Support for dim operations.
//===----------------------------------------------------------------------===//
Expand Down
139 changes: 139 additions & 0 deletions src/Dialect/ONNX/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,68 @@ class BinaryOpBroadcastAxisPattern : public OpRewritePattern<OP_TYPE> {
}
};

// A pattern to turn
//   `BinaryOp(Constant_X, ExpandOp(Constant_Y))`
// into
//   `ExpandOp(BinaryOp(Constant_X, Constant_Y))`
// which puts the constants together so that BinaryOp can be constant-folded.
// This pattern only handles the case where one operand is a scalar constant
// and the other is an Expand of a scalar constant: for such a case the shape
// operand of the original ExpandOp can be reused for the resulting ExpandOp.

template <typename OP_TYPE>
class PropagateScalarConstantExpandPattern : public OpRewritePattern<OP_TYPE> {
public:
  using OpRewritePattern<OP_TYPE>::OpRewritePattern;

  LogicalResult matchAndRewrite(
      OP_TYPE binaryOp, PatternRewriter &rewriter) const override {
    Operation *op = binaryOp.getOperation();
    Location loc = binaryOp.getLoc();

    assert(op->getNumOperands() == 2 && "op must be binary");
    Value lhs = op->getOperand(0);
    Value rhs = op->getOperand(1);
    Type outputType = op->getResult(0).getType();

    // Match
    // - lhs is a scalar constant, and
    // - rhs is ExpandOp whose input is a scalar constant, or vice versa.
    Value expandShape = nullptr;
    // Returns the scalar dense constant feeding `v` (looking through at most
    // one ExpandOp), or nullptr when `v` does not match. Side effect: records
    // the shape of the first ExpandOp seen in `expandShape`. Note the order
    // dependence: an Expand is only looked through while `expandShape` is
    // still null, so at most one operand is unwrapped; if both operands were
    // Expands, the second call leaves the Expand in place, the constant check
    // below fails, and the whole pattern conservatively bails out.
    auto matchValue = [&expandShape](Value v) -> Value {
      Value res = v;
      if (auto expandOp =
              dyn_cast_if_present<ONNXExpandOp>(res.getDefiningOp())) {
        if (!expandShape) {
          res = expandOp.getInput();
          expandShape = expandOp.getShape();
        }
      }
      if (isDenseONNXConstant(res) && isScalarTensor(res))
        return res;
      return nullptr;
    };
    Value lhsConstant = matchValue(lhs);
    Value rhsConstant = matchValue(rhs);
    // Exactly one ExpandOp must have been seen, and both operands must have
    // resolved to scalar dense constants.
    if (!expandShape || !lhsConstant || !rhsConstant)
      return failure();
    // Does not handle empty shape in ExpandOp, e.g. of type tensor<0xdtype>.
    if (!hasShapeAndRank(expandShape))
      return failure();
    ArrayRef<int64_t> dims = getShape(expandShape.getType());
    if ((dims.size() == 1) && (dims[0] == 0))
      return failure();

    // Rewrite: fold the two scalar constants through BinaryOp first (operand
    // order is preserved, so non-commutative ops such as Sub/Div stay
    // correct), then Expand the scalar result with the recorded shape.
    MultiDialectBuilder<OnnxBuilder> create(rewriter, loc);
    Value res = create.onnx.expand(outputType,
        create.onnx.createOpAndInferShapes<OP_TYPE>(lhsConstant, rhsConstant),
        expandShape);

    rewriter.replaceOp(op, {res});
    return success();
  }
};

// =============================================================================
// Rewrite pattern for Resize (not handled in Rewrite.td).
// =============================================================================
Expand Down Expand Up @@ -874,6 +936,78 @@ class PowToMulRewritePattern : public OpRewritePattern<ONNXPowOp> {
int64_t maxPower;
};

// Rewrite a pattern like the following:
//
// %shape = onnx.Concat(%dim1, %dim2)
// %data = onnx.Expand(%input, %shape)
// %u = "onnx.Unsqueeze"(%data, %axes)
//
// into
//
// %new_shape = onnx.Concat(%dim1, %dim2, 1)
// %u = onnx.Expand(%input, %new_shape)
//
// This is valid because %input is a scalar, so expanding it directly to the
// unsqueezed shape yields the same tensor as expanding then unsqueezing.
class ReplaceUnsqueezeOfExpandRewritePattern
    : public OpRewritePattern<ONNXUnsqueezeOp> {
public:
  using OpRewritePattern<ONNXUnsqueezeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(
      ONNXUnsqueezeOp unsqueezeOp, PatternRewriter &rewriter) const override {
    Operation *op = unsqueezeOp.getOperation();
    Location loc = unsqueezeOp.getLoc();
    Value data = unsqueezeOp.getData();
    Value axes = unsqueezeOp.getAxes();

    // Match
    // 1. data is from ExpandOp, axes is from ConstantOp.
    if (!definedBy<ONNXExpandOp>(data) || !definedBy<ONNXConstantOp>(axes))
      return failure();
    auto expandOp = cast<ONNXExpandOp>(data.getDefiningOp());
    // 2. ExpandOp's input is a scalar tensor so that it is safe to use a new
    // shape that does not violate the broadcasting rule.
    if (!isScalarTensor(expandOp.getInput()))
      return failure();
    // 3. ExpandOp's shape is defined by dimensions.
    if (!areDims(expandOp.getShape()))
      return failure();

    // Rewrite
    MultiDialectBuilder<OnnxBuilder> create(rewriter, loc);
    // Get the old shape.
    SmallVector<Value, 4> oldDims;
    getDims(expandOp.getShape(), oldDims);
    int64_t oldRank = oldDims.size();
    // Get unsqueeze axes. Per the ONNX spec, negative axes count from the
    // back of the *output* tensor, whose rank is oldRank + axes.size(). Using
    // oldRank here would place the unsqueezed 1 at the wrong position (e.g.
    // axis -1 with oldRank 2 must become 2, not 1).
    ElementsAttr axesAttrs = getElementAttributeFromONNXValue(axes);
    SmallVector<int64_t> axesI64(axesAttrs.getValues<int64_t>());
    int64_t newRank = oldRank + static_cast<int64_t>(axesI64.size());
    for (unsigned int i = 0; i < axesI64.size(); ++i)
      if (axesI64[i] < 0)
        axesI64[i] += newRank;

    // Construct a new shape: a literal 1 at each unsqueeze axis, the original
    // dimensions (in order) at the remaining positions. Every position in
    // [0, newRank) is either an unsqueeze axis or takes one old dim, so the
    // single bound on i suffices.
    SmallVector<Value, 4> newDims;
    Value one = create.onnx.constantInt64(ArrayRef<int64_t>({1}));
    for (int64_t i = 0, j = 0; i < newRank; ++i)
      if (std::find(axesI64.begin(), axesI64.end(), i) != axesI64.end())
        // Found i in unsqueeze axes.
        newDims.emplace_back(one);
      else
        // Original axis.
        newDims.emplace_back(oldDims[j++]);
    Value newShape = create.onnx.concat(
        RankedTensorType::get({newRank}, rewriter.getI64Type()), newDims, 0);

    // Expand the scalar input directly to the unsqueezed shape.
    Value res = create.onnx.expand(
        op->getResult(0).getType(), expandOp.getInput(), newShape);
    rewriter.replaceOp(op, {res});
    return success();
  }
};

// =============================================================================
/// Register optimization patterns as "canonicalization" patterns.
/// Add op to OpsWithCanonicalizer in gen_onnx_mlir.py to activate.
Expand All @@ -897,6 +1031,7 @@ void ONNXAddOp::getCanonicalizationPatterns(
results.insert<FuseAddConvPattern>(context);
results.insert<FuseAddConvNullBiasPattern>(context);
results.insert<BinaryOpBroadcastAxisPattern<ONNXAddOp>>(context);
results.insert<PropagateScalarConstantExpandPattern<ONNXAddOp>>(context);
}

/// on the ONNXAndOp.
Expand Down Expand Up @@ -934,6 +1069,7 @@ void ONNXDepthToSpaceOp::getCanonicalizationPatterns(
/// Register canonicalization patterns on the ONNXDivOp.
void ONNXDivOp::getCanonicalizationPatterns(
    RewritePatternSet &result, MLIRContext *context) {
  result.insert<BinaryOpBroadcastAxisPattern<ONNXDivOp>,
      PropagateScalarConstantExpandPattern<ONNXDivOp>>(context);
}

/// on the ONNXDropoutOp.
Expand Down Expand Up @@ -1017,6 +1153,7 @@ void ONNXMulOp::getCanonicalizationPatterns(
results.insert<NormalizeMulPattern>(context);
results.insert<FuseMulConvNullBiasPattern>(context);
results.insert<BinaryOpBroadcastAxisPattern<ONNXMulOp>>(context);
results.insert<PropagateScalarConstantExpandPattern<ONNXMulOp>>(context);
}

/// on the ONNXOrOp.
Expand Down Expand Up @@ -1057,6 +1194,7 @@ void ONNXShapeOp::getCanonicalizationPatterns(
/// Register canonicalization patterns on the ONNXSubOp.
void ONNXSubOp::getCanonicalizationPatterns(
    RewritePatternSet &result, MLIRContext *context) {
  result.insert<BinaryOpBroadcastAxisPattern<ONNXSubOp>,
      PropagateScalarConstantExpandPattern<ONNXSubOp>>(context);
}

/// on ONNXShapeTransformOp
Expand Down Expand Up @@ -1143,6 +1281,7 @@ void ONNXUnsqueezeOp::getCanonicalizationPatterns(
RewritePatternSet &result, MLIRContext *context) {
result.insert<RemoveUnsqueezeSqueezePattern>(context);
result.insert<RemoveUnsqueezeCastSqueezePattern>(context);
result.insert<ReplaceUnsqueezeOfExpandRewritePattern>(context);
}

void ONNXUnsqueezeV11Op::getCanonicalizationPatterns(
Expand Down
1 change: 1 addition & 0 deletions src/Transform/ONNX/Decompose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,7 @@ void DecomposeONNXToONNXPass::runOnOperation() {
target.addIllegalOp<ONNXClipV6Op>();
target.addIllegalOp<ONNXClipV11Op>();
target.addIllegalOp<ONNXClipV12Op>();
target.addIllegalOp<ONNXConstantOfShapeOp>();
target.addIllegalOp<ONNXEinsumOp>();
target.addIllegalOp<ONNXLogSoftmaxOp>();
target.addIllegalOp<ONNXPadV2Op>();
Expand Down
17 changes: 17 additions & 0 deletions src/Transform/ONNX/Decompose.td
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ def CreateNoneValue : NativeCodeCall<"$_builder.create<ONNXNoneOp>($_loc)">;
def createScalarDenseAttrRank0
: NativeCodeCall<"::onnx_mlir::createScalarDenseAttr($_builder, $0)">;

// Create a scalar DenseElementsAttr of tensor<dtype> from an ElementsAttr.
// The input ElementsAttr must have exactly one element; reshape to rank 0
// requires a single-element attribute, so multi-element inputs are invalid.
def ReshapeElementsAttrToRank0 : NativeCodeCall<
  "onnx_mlir::OnnxElementsAttrBuilder($0.getContext()).reshape(cast<ElementsAttr>($0), {})">;

// Create a DenseElementsAttr from a single attribute.
def createDenseArrayAttrFromSingleAttr
: NativeCodeCall<"::onnx_mlir::createDenseArrayAttr($_builder, $_builder.getArrayAttr($0))">;
Expand Down Expand Up @@ -511,4 +517,15 @@ def ConvTransposeOpPattern2: Pattern<
(addBenefit 0)
>;

//===----------------------------------------------------------------------===//
// Rewrite `ONNXConstantOfShapeOp {value} (%shape)` into
// `ONNXExpandOp(ONNXConstantOp {value}, %shape)`.
// The scalar `value` attribute is reshaped to rank 0 so it can seed a scalar
// ONNXConstantOp, which Expand then broadcasts to %shape.
//===----------------------------------------------------------------------===//

def ConstantOfShapePattern: Pat<
  (ONNXConstantOfShapeOp:$res $shape, $value),
  (ONNXExpandOp (ONNXConstantOpFromDenseAttr (ReshapeElementsAttrToRank0 $value)),
    $shape)
>;

#endif // ONNX_DECOMPOSE
Loading

0 comments on commit ae607aa

Please sign in to comment.