Skip to content

Commit

Permalink
Recompose multiple ops into a single ONNXGelu (#2965)
Browse files Browse the repository at this point in the history
Recompose multiple ops into a single ONNXGelu (#2965)

Signed-off-by: Tung D. Le <[email protected]>

---------

Signed-off-by: Tung D. Le <[email protected]>
  • Loading branch information
tungld authored Oct 4, 2024
1 parent 265ee60 commit 7c58751
Show file tree
Hide file tree
Showing 8 changed files with 480 additions and 20 deletions.
5 changes: 5 additions & 0 deletions src/Dialect/ONNX/DialectBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,11 @@ Value OnnxBuilder::expand(Type outputType, Value input, Value shape) const {
outputType, toTensor(input), toTensor(shape));
}

// ONNXGeluOp: build a Gelu operation on `input` with the given `approximate`
// attribute. The result type is the tensor form of the input type.
// The operand is routed through toTensor(...) as well, in the same way the
// other builders in this file treat their operands (see `expand`), so that the
// operand and the result type are converted consistently.
Value OnnxBuilder::gelu(Value input, StringAttr approximateAttr) const {
  return createOpAndInferShapes<ONNXGeluOp>(
      toTensor(input.getType()), toTensor(input), approximateAttr);
}

// ONNXLayerNormalizationOp, version with one output only (Y).
Value OnnxBuilder::layerNorm(Type outputType, Value input, Value scale,
Value bias, int64_t axis, FloatAttr epsilon) const {
Expand Down
3 changes: 3 additions & 0 deletions src/Dialect/ONNX/DialectBuilder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ struct OnnxBuilder : DialectBuilder {
mlir::Value expand(
mlir::Type outputType, mlir::Value input, mlir::Value shape) const;

// ONNXGeluOp: creates a Gelu operation applied to `input`, forwarding
// `approximateAttr` as the op's approximation attribute. The result type is
// derived from the type of `input` (converted to a tensor type).
mlir::Value gelu(mlir::Value input, mlir::StringAttr approximateAttr) const;

// ONNXLayerNormalizationOp, version with one output only (Y).
mlir::Value layerNorm(mlir::Type outputType, mlir::Value input,
mlir::Value scale, mlir::Value bias, int64_t axis,
Expand Down
18 changes: 18 additions & 0 deletions src/Dialect/ONNX/ONNXOps/OpHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,24 @@ RESULT_TYPE getScalarValue(ONNXConstantOp constantOp) {
template double getScalarValue<double>(ONNXConstantOp constantOp);
template int64_t getScalarValue<int64_t>(ONNXConstantOp constantOp);

/// Wrap the scalar `n` in a WideNum. The value is first cast to the C++ type
/// that wideZeroDispatch selects for `elemType`, then widened with the BType
/// tag matching that C++ type.
WideNum asWideNum(double n, Type elemType) {
  return wideZeroDispatch(elemType, [n](auto zeroOfWideType) {
    // The dispatched argument is only used for its type.
    using WideCppType = decltype(zeroOfWideType);
    constexpr BType kWideTag = toBType<WideCppType>;
    const WideCppType converted = static_cast<WideCppType>(n);
    return WideNum::widen<kWideTag>(converted);
  });
}

/// Checks whether a constant tensor's elements are all equal to a given scalar.
///
/// \param constValue value expected to be produced by an ONNX constant.
/// \param n          scalar every element is compared against, after being
///                   converted to the tensor's element type via asWideNum.
/// \return true when all elements equal `n`; false otherwise, including when
///         no elements attribute can be obtained from `constValue`.
///
/// Boolean (i1) tensors are not supported and trigger an assertion.
bool isConstOf(Value constValue, double n) {
  ElementsAttr constElements = getElementAttributeFromONNXValue(constValue);
  // Guard against a null attribute (e.g. `constValue` is not backed by a
  // constant carrying a value): such a value is not "a constant of n", and
  // dereferencing the null attribute below would crash.
  if (!constElements)
    return false;
  Type elemType = constElements.getElementType();
  assert(!elemType.isInteger(1) && "booleans are not supported");
  WideNum w = asWideNum(n, elemType);
  return ElementsAttrBuilder::allEqual(constElements, w);
}

// Convert type to MLIR type.
// A complete list of types can be found in:
// <onnx-mlir-build-folder>/third_party/onnx/onnx/onnx.pb.h
Expand Down
43 changes: 43 additions & 0 deletions src/Dialect/ONNX/ONNXOps/OpHelper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,12 @@ RESULT_TYPE getScalarValue(mlir::ElementsAttr denseAttr, mlir::Type type);
template <typename RESULT_TYPE>
RESULT_TYPE getScalarValue(mlir::ONNXConstantOp constantOp);

/// Return `n` as a WideNum: the value is cast to the wide C++ type selected
/// for `elemType` and then wrapped in a WideNum.
WideNum asWideNum(double n, mlir::Type elemType);

/// Checks whether a constant tensor's elements are all equal to a given scalar.
/// `n` is converted to the tensor's element type before the comparison.
/// Boolean (i1) element types are not supported (asserted in the definition).
bool isConstOf(mlir::Value constValue, double n);

mlir::Type convertONNXTypeToMLIRType(
mlir::Builder &builder, onnx::TensorProto_DataType onnxType);

Expand Down Expand Up @@ -277,6 +283,43 @@ bool operandOfOpDefinedBy(mlir::Operation *&matchOp, mlir::Operation *op,
mlir::Value &matchOperand0, mlir::Value &matchOperand1,
int64_t matchThisOperandIndex);

// Recognize the operands of a binary op, e.g. A*B, where one of A and B is a
// dense ONNX constant whose elements all equal `cst` and the other one is
// defined by an operation of type OP.
// Note: this function handles the commutative property of the binary op, i.e.
// the constant may be either A or B.
//
// On success the function returns true and stores the matched operation in its
// last argument; otherwise it returns false and leaves that argument untouched.
//
// For example, to recognize this pattern:
// %x = "onnx.Tanh"()
// %y = 0.5 * %x // or %x * 0.5
//
// we call
// ```
// ONNXTanhOp tanhOp;
// bool found = matchConstAndOp<ONNXTanhOp>(A, B, 0.5, tanhOp);
// ```
// where `A` and `B` are operands of ONNXMul that produces %y.
template <typename OP>
bool matchConstAndOp(mlir::Value A, mlir::Value B, double cst, OP &op);

// Recognize the operands of a binary op, e.g. A*B, where one of A and B is the
// given value `matchValue` and the other one is defined by an operation of
// type OP.
// Note: this function handles the commutative property of the binary op, i.e.
// `matchValue` may be either A or B.
//
// On success the function returns true and stores the matched operation in
// `matchOp`; otherwise it returns false and leaves `matchOp` untouched.
//
// For example, to recognize this pattern where %z is one of the inputs of *,
// and the other input of * is defined by onnx.Tanh:
// %x = "onnx.Tanh"()
// %y = %z * %x // or %x * %z
//
// we call
// ```
// Value z;
// ONNXTanhOp tanhOp;
// bool found = matchValueAndOp<ONNXTanhOp>(A, B, z, tanhOp);
// ```
// where `A` and `B` are operands of ONNXMul that produces %y.
template <typename OP>
bool matchValueAndOp(
mlir::Value A, mlir::Value B, mlir::Value matchValue, OP &matchOp);

/// Check if a value is to store dimensions, meaning it is a tensor of one
/// element or concatenation of one-element tensors.
bool areDims(mlir::Value val);
Expand Down
62 changes: 62 additions & 0 deletions src/Dialect/ONNX/ONNXOps/OpHelper.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,65 @@ bool operandOfOpDefinedBy(mlir::Operation *&matchOp, mlir::Operation *op,
}
return false;
}

// Recognize the operands of a binary op, e.g. A*B, where one of A and B is a
// dense ONNX constant whose elements all equal `cst` and the other one is
// defined by an operation of type OP.
// Note: this function handles the commutative property of the binary op, i.e.
// the constant may be either A or B. The orientation "A is the constant" is
// tried first.
//
// On success the function returns true and stores the matched operation in
// `matchOp`; otherwise it returns false and leaves `matchOp` untouched.
//
// For example, to recognize this pattern:
// %x = "onnx.Tanh"()
// %y = 0.5 * %x // or %x * 0.5
//
// we call
// ```
// ONNXTanhOp tanhOp;
// bool found = matchConstAndOp<ONNXTanhOp>(A, B, 0.5, tanhOp);
// ```
// where `A` and `B` are operands of ONNXMul that produces %y.
template <typename OP>
bool matchConstAndOp(mlir::Value A, mlir::Value B, double cst, OP &matchOp) {
  // Try a single orientation: `constSide` must be the constant equal to `cst`
  // and `opSide` must be produced by an OP.
  auto tryOrientation = [&](mlir::Value constSide, mlir::Value opSide) {
    // Cheap test first: no need to scan the constant's elements when the other
    // operand is not defined by the requested operation.
    auto definingOp = opSide.getDefiningOp<OP>();
    if (!definingOp)
      return false;
    if (!onnx_mlir::isDenseONNXConstant(constSide) ||
        !onnx_mlir::isConstOf(constSide, cst))
      return false;
    matchOp = definingOp;
    return true;
  };
  return tryOrientation(A, B) || tryOrientation(B, A);
}

// Recognize the operands of a binary op, e.g. A*B, where one of A and B is the
// given value `matchValue` and the other one is defined by an operation of
// type OP.
// Note: this function handles the commutative property of the binary op, i.e.
// `matchValue` may be either A or B. The orientation "A is `matchValue`" is
// tried first.
//
// On success the function returns true and stores the matched operation in
// `matchOp`; otherwise it returns false and leaves `matchOp` untouched.
//
// For example, to recognize this pattern where %z is one of the inputs of *,
// and the other input of * is defined by onnx.Tanh:
// %x = "onnx.Tanh"()
// %y = %z * %x // or %x * %z
//
// we call
// ```
// Value z;
// ONNXTanhOp tanhOp;
// bool found = matchValueAndOp<ONNXTanhOp>(A, B, z, tanhOp);
// ```
// where `A` and `B` are operands of ONNXMul that produces %y.
template <typename OP>
bool matchValueAndOp(mlir::Value A, mlir::Value B, mlir::Value matchValue, OP &matchOp) {
  // Case 1: A is the given value, B must come from an OP.
  if (A == matchValue) {
    if (auto producerOfB = B.getDefiningOp<OP>()) {
      matchOp = producerOfB;
      return true;
    }
  }
  // Case 2: B is the given value, A must come from an OP.
  if (B == matchValue) {
    if (auto producerOfA = A.getDefiningOp<OP>()) {
      matchOp = producerOfA;
      return true;
    }
  }
  return false;
}
17 changes: 0 additions & 17 deletions src/Dialect/ONNX/Transforms/ConstProp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,23 +186,6 @@ Value createMinimumValueForClip(
llvm::APFloat::getLargest, true, llvm::APInt::getMinValue);
}

/// Wrap the scalar `n` in a WideNum after casting it to the C++ type that
/// wideZeroDispatch selects for `elemType`.
WideNum asWideNum(double n, Type elemType) {
  return wideZeroDispatch(elemType, [n](auto wideZeroValue) {
    // Only the type of the dispatched zero is of interest here.
    using WideT = decltype(wideZeroValue);
    constexpr BType kTag = toBType<WideT>;
    const WideT asWide = static_cast<WideT>(n);
    return WideNum::widen<kTag>(asWide);
  });
}

/// Returns true when every element of the constant tensor `constValue` is
/// equal to the scalar `n` (converted to the tensor's element type).
/// Boolean (i1) tensors are not supported and trigger an assertion.
bool isConstOf(Value constValue, double n) {
  ElementsAttr elements = getConstValueElements(constValue);
  Type elementType = elements.getElementType();
  assert(!elementType.isInteger(1) && "booleans are not supported");
  return ElementsAttrBuilder::allEqual(elements, asWideNum(n, elementType));
}

// Extracts number from a scalar constant value.
WideNum getScalarNum(Value constValue) {
ElementsAttr elements = getConstValueElements(constValue);
Expand Down
Loading

0 comments on commit 7c58751

Please sign in to comment.