Skip to content

Commit

Permalink
Merge branch 'main' into nnpa-fix-nchw-rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
tungld authored Sep 11, 2023
2 parents e65d766 + 562b9b9 commit 638ecff
Show file tree
Hide file tree
Showing 35 changed files with 169 additions and 3,508 deletions.
2 changes: 1 addition & 1 deletion src/Builder/OpBuildTable.inc
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ op_dialect_version_map_["OptionalGetElement"] = {18};
op_dialect_version_map_["OptionalHasElement"] = {18};
op_dialect_version_map_["Or"] = {7};
op_dialect_version_map_["PRelu"] = {16};
op_dialect_version_map_["Pad"] = {18, 13, 11, 2};
op_dialect_version_map_["Pad"] = {19, 13, 11, 2};
op_dialect_version_map_["Pow"] = {15};
op_dialect_version_map_["QLinearConv"] = {10};
op_dialect_version_map_["QLinearMatMul"] = {10};
Expand Down
9 changes: 0 additions & 9 deletions src/Compiler/CompilerOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ std::string instrumentOps; // onnx-mlir only
unsigned instrumentControlBits; // onnx-mlir only
bool instrumentONNXSignature; // onnx-mlir only
std::string ONNXOpStats; // onnx-mlir only
bool enableMemoryBundling; // onnx-mlir only
int onnxOpTransformThreshold; // onnx-mlir only
bool onnxOpTransformReport; // onnx-mlir only
bool enableParallel; // onnx-mlir only
Expand Down Expand Up @@ -376,14 +375,6 @@ static llvm::cl::opt<std::string, true> ONNXOpStatsOpt("onnx-op-stats",
llvm::cl::location(ONNXOpStats), llvm::cl::init(""),
llvm::cl::cat(OnnxMlirOptions));

static llvm::cl::opt<bool, true> enableMemoryBundlingOpt(
"enable-memory-bundling",
llvm::cl::desc(
"Enable memory bundling related optimizations (default=false)\n"
"Set to 'false' if you experience significant compile time."),
llvm::cl::location(enableMemoryBundling), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));

static llvm::cl::opt<int, true> onnxOpTransformThresholdOpt(
"onnx-op-transform-threshold",
llvm::cl::desc(
Expand Down
1 change: 0 additions & 1 deletion src/Compiler/CompilerOptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ extern std::string instrumentOps; // onnx-mlir only
extern unsigned instrumentControlBits; // onnx-mlir only
extern bool instrumentONNXSignature; // onnx-mlir only
extern std::string ONNXOpStats; // onnx-mlir only
extern bool enableMemoryBundling; // onnx-mlir only
extern int onnxOpTransformThreshold; // onnx-mlir only
extern bool onnxOpTransformReport; // onnx-mlir only
extern bool enableParallel; // onnx-mlir only
Expand Down
6 changes: 0 additions & 6 deletions src/Compiler/CompilerPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,12 +215,6 @@ void addKrnlToLLVMPasses(
// https://mlir.llvm.org/docs/BufferDeallocationInternals.
pm.addNestedPass<func::FuncOp>(
mlir::bufferization::createBufferDeallocationPass());
if (enableMemoryBundling) {
pm.addNestedPass<func::FuncOp>(krnl::createKrnlEnableMemoryPoolPass());
pm.addNestedPass<func::FuncOp>(krnl::createKrnlBundleMemoryPoolsPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(krnl::createKrnlOptimizeMemoryPoolsPass());
}

// The pass below is needed for subview and collapseShape.. Unfortunately,
// MLIR supports only collapse for scalar loaded by scalar memory at this
Expand Down
4 changes: 2 additions & 2 deletions src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
* SPDX-License-Identifier: Apache-2.0
*/

//===------ KrnlGetRefOp.cpp - Lower KrnlGetRefOp -------------------------===//
//===------ KrnlVectorTypeCastOp.cpp - Lower KrnlVectorTypeCastOp ---------===//
//
// Copyright 2019-2023 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the KrnlGetRefOp operator.
// This file lowers the KrnlVectorTypeCastOp operator.
//
//===----------------------------------------------------------------------===//

Expand Down
1 change: 1 addition & 0 deletions src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ struct ONNXCategoryMapperOpLowering
(shape[i] == ShapedType::kDynamic) ? 1 : shape[i]);
auto memRefType = MemRefType::get(
newShape, krnl::StringType::get(elementType.getContext()));
// Sole use of krnl.getRef.
Value stringMemRef = createKrnl.getRef(memRefType, memref, zero);
inputElem = createKrnl.load(stringMemRef, loopInd);
})
Expand Down
2 changes: 0 additions & 2 deletions src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,8 +458,6 @@ void populateLoweringONNXShapeTransformOpPattern(
void populateLoweringONNXCustomOpPattern(
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);

bool checkOpResultIsUsedByGetRef(mlir::memref::AllocOp *allocOp);

/// This function returns the index in the list of alloc arguments of the
/// dynamic dimension corresponding to `index` in the MemRef shape.
/// As an example:
Expand Down
13 changes: 10 additions & 3 deletions src/Conversion/ONNXToKrnl/RNN/GRU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,8 @@ template <>
void calculateState<GruState, GruActivationPack, GruWeightPack, GruBiasPack>(
ConversionPatternRewriter &rewriter, Location loc, Value Xt, GruState state,
GruActivationPack activationPack, GruWeightPack weightPack,
GruBiasPack biasPack, Value sequenceIV, Value directionIV, bool isForward) {
GruBiasPack biasPack, Value sequenceIV, Value directionIV,
Value sequenceLens, Value initialH, bool isForward) {
// Equations (Default: f=Sigmoid, g=Tanh):"
// zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)"
// rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)"
Expand Down Expand Up @@ -498,7 +499,10 @@ void calculateState<GruState, GruActivationPack, GruWeightPack, GruBiasPack>(
Value nextHt = createMath.add(ztht, ztHt);

// Store the intermediate Ht.
createKrnl.store(nextHt, Ht, indices);
// Handle sequence_lens
nextHt = handleSequenceLens(createKrnl, createMath, sequenceLens,
initialH, nextHt, sequenceIV, directionIV, bs, hs, Ht);

if (!isNoneValue(state.allH))
createKrnl.store(
nextHt, state.allH, {sequenceIV, directionIV, bs, hs});
Expand Down Expand Up @@ -602,7 +606,10 @@ void calculateState<GruState, GruActivationPack, GruWeightPack, GruBiasPack>(
Value nextHt = createMath.add(ztht, ztHt);

// Store the intermediate Ht.
createKrnl.store(nextHt, Ht, indices);
// Handle sequence_lens
nextHt = handleSequenceLens(createKrnl, createMath, sequenceLens,
initialH, nextHt, sequenceIV, directionIV, bs, hs, Ht);

if (!isNoneValue(state.allH))
createKrnl.store(
nextHt, state.allH, {sequenceIV, directionIV, bs, hs});
Expand Down
9 changes: 8 additions & 1 deletion src/Conversion/ONNXToKrnl/RNN/LSTM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ void calculateState<LstmState, LstmActivationPack, LstmWeightPack,
LstmBiasPack>(ConversionPatternRewriter &rewriter, Location loc, Value Xt,
LstmState state, LstmActivationPack activationPack,
LstmWeightPack weightPack, LstmBiasPack biasPack, Value sequenceIV,
Value directionIV, bool isForward) {
Value directionIV, Value sequenceLens, Value initialH, bool isForward) {
// Equations for LSTM.
// it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
// ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
Expand All @@ -452,6 +452,13 @@ void calculateState<LstmState, LstmActivationPack, LstmWeightPack,
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
// Ht = ot (.) h(Ct)

// ToFix: add support of sequence lens for LSTM
// The assert will fail the test_lstm_with_peephole.
// In that test case, the length of the input is used as sequence_lens.
// Therefore, onnx-mlir can pass the test by ignoring the sequence_lens
// paramenter.
// assert(isNoneValue(sequenceLens) && "not implemented yet");

// TODO remove scope
MultiDialectBuilder<KrnlBuilder, MathBuilder, MemRefBuilder, OnnxBuilder>
create(rewriter, loc);
Expand Down
6 changes: 5 additions & 1 deletion src/Conversion/ONNXToKrnl/RNN/RNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,8 @@ template <>
void calculateState<RnnState, RnnActivationPack, RnnWeightPack, RnnBiasPack>(
ConversionPatternRewriter &rewriter, Location loc, Value Xt, RnnState state,
RnnActivationPack activationPack, RnnWeightPack weightPack,
RnnBiasPack biasPack, Value sequenceIV, Value directionIV, bool isForward) {
RnnBiasPack biasPack, Value sequenceIV, Value directionIV,
Value sequenceLens, Value initialH, bool isForward) {
// Equations for RNN.
// Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
// Shape information:
Expand All @@ -311,6 +312,9 @@ void calculateState<RnnState, RnnActivationPack, RnnWeightPack, RnnBiasPack>(
// Wbi: [hidden_size]
// Rbi: [hidden_size]

// ToFix: add support of sequenceLens for RNN
assert(isNoneValue(sequenceLens) && "not implemented yet");

MultiDialectBuilder<KrnlBuilder, MathBuilder, MemRefBuilder, OnnxBuilder>
create(rewriter, loc);

Expand Down
29 changes: 29 additions & 0 deletions src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,4 +336,33 @@ Value emitXSliceAt(ConversionPatternRewriter &rewriter, Location loc, Value X,
return sliceX;
}

// Change the nextHt and Ht value if sequenceLens is defined.
// When a sample reachs the limit of its sequence len, nextHt will be padded
// with 0 (or initialH), and Ht will keep the last value at the sequence end
// so that the final value Ht is the last value at their sequence len.
Value handleSequenceLens(KrnlBuilder &createKrnl, MathBuilder &createMath,
Value sequenceLens, Value initialH, Value nextHt, Value sequenceIV,
Value directionIV, Value bs, Value hs, Value Ht) {
if (!isNoneValue(sequenceLens)) {
Value sequenceUB = createKrnl.load(sequenceLens, {bs});
Value initial;
if (isNoneValue(initialH)) {
initial = createMath.constant(nextHt.getType(), 0.);
} else {
initial = createKrnl.load(initialH, {directionIV, bs, hs});
}
Value cond = createMath.sge(
createMath.cast(sequenceUB.getType(), sequenceIV), sequenceUB);
nextHt = createMath.select(cond, /*padding*/ initial, nextHt);

// Last HT should be the last in sequenceLens or the current result
Value lastHt =
createMath.select(cond, createKrnl.load(Ht, {bs, hs}), nextHt);
createKrnl.store(lastHt, Ht, {bs, hs});
} else {
createKrnl.store(nextHt, Ht, {bs, hs});
}
return nextHt;
}

} // namespace onnx_mlir
18 changes: 15 additions & 3 deletions src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,15 @@ mlir::Value applyActivation(mlir::OpBuilder &rewriter, mlir::Location loc,
mlir::Value emitXSliceAt(mlir::ConversionPatternRewriter &rewriter,
mlir::Location loc, mlir::Value X, mlir::Value timestep);

// Change the nextHt and Ht value if sequenceLens is defined.
// When a sample reachs the limit of its sequence len, nextHt will be padded
// with 0 (or initialH), and Ht will keep the last value at the sequence end
// so that the final value Ht is the last value at their sequence len.
mlir::Value handleSequenceLens(KrnlBuilder &createKrnl, MathBuilder &createMath,
mlir::Value sequenceLens, mlir::Value initialH, mlir::Value nextHt,
mlir::Value sequenceIV, mlir::Value directionIV, mlir::Value bs,
mlir::Value hs, mlir::Value Ht);

// Override the following methods when lowering an RNN operation:
// - hasAllNoneOutput
// - getActivationPack
Expand Down Expand Up @@ -116,7 +125,8 @@ S allocAndInitializeStates(mlir::ConversionPatternRewriter &rewriter,
template <typename S, typename A, typename W, typename B>
void calculateState(mlir::ConversionPatternRewriter &rewriter,
mlir::Location loc, mlir::Value Xt, S state, A activationSet, W weight,
B bias, mlir::Value sequenceIV, mlir::Value directionIV, bool isForward);
B bias, mlir::Value sequenceIV, mlir::Value directionIV,
mlir::Value sequenceLens, mlir::Value initialH, bool isForward);

// Write states to the RNN's outputs.
template <typename RNNOp, typename S>
Expand All @@ -136,6 +146,8 @@ struct ONNXRNNOpLowering : public mlir::OpConversionPattern<RNNOp> {
mlir::Operation *op = rnnOp.getOperation();
mlir::Location loc = ONNXLoc<RNNOp>(op);
mlir::Value X = adaptor.getX();
mlir::Value sequenceLens = adaptor.getSequenceLens();
mlir::Value initialH = adaptor.getInitialH();

if (hasAllNoneOutput<RNNOp>(&rnnOp)) {
rewriter.eraseOp(op);
Expand Down Expand Up @@ -188,7 +200,7 @@ struct ONNXRNNOpLowering : public mlir::OpConversionPattern<RNNOp> {
// Emit calculation for one RNN step.
calculateState<S, A, W, B>(rewriter, loc, Xt, state,
activationForward, weightForward, biasForward, sequenceIV,
directionIV,
directionIV, sequenceLens, initialH,
/*isForward=*/true);
});
}
Expand Down Expand Up @@ -226,7 +238,7 @@ struct ONNXRNNOpLowering : public mlir::OpConversionPattern<RNNOp> {
// Emit calculation for one RNN step.
calculateState<S, A, W, B>(rewriter, loc, Xt, state,
activationReverse, weightReverse, biasReverse,
reverseSequenceIV, directionIV,
reverseSequenceIV, directionIV, sequenceLens, initialH,
/*isForward=*/false);
});
}
Expand Down
16 changes: 12 additions & 4 deletions src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -629,21 +629,29 @@ ElementsAttr ElementsAttrBuilder::gather(
ArrayRef<int64_t> inputShape = inputType.getShape();
assert(axis < inputShape.size() && "gather axis out of range");
auto postAxisShape = inputShape.drop_front(axis + 1);
ArrayRef<int64_t> indicesShape = indices.getShapedType().getShape();
ShapedType indicesType = indices.getShapedType();
assert(indicesType.getElementType().isSignlessInteger() &&
"gather indices must be i32 or i64");
ArrayRef<int64_t> indicesShape = indicesType.getShape();
SmallVector<int64_t> outShape(inputShape.take_front(axis));
outShape.append(indicesShape.begin(), indicesShape.end());
outShape.append(postAxisShape.begin(), postAxisShape.end());
auto outType = inputType.clone(outShape);
return fromWideNums(outType, [&](MutableArrayRef<WideNum> dst) {
size_t postAxisNumElements = ShapedType::getNumElements(postAxisShape);
ArrayBuffer<WideNum> src = getElementsWideNums(input);
ArrayBuffer<int64_t> indicesArray = getElementsArray<int64_t>(indices);
// Convert indices of any signed int element type to int64 by
// first promoting to WideNum and then casting to int64.
// In practice we support both int32 and int64 in this way.
ArrayBuffer<WideNum> indicesWideNums = getElementsWideNums(indices);
ArrayRef<int64_t> indicesArray =
castArrayRef<int64_t, WideNum>(indicesWideNums.get());
size_t axisInputSize = inputShape[axis];
size_t inputBlockLen = axisInputSize * postAxisNumElements;
size_t outBlockLen = indicesArray.get().size() * postAxisNumElements;
size_t outBlockLen = indicesArray.size() * postAxisNumElements;
size_t start = 0;
WideNum *out = dst.begin();
for (int64_t idx : indicesArray.get()) {
for (int64_t idx : indicesArray) {
int64_t adjustedIdx = idx < 0 ? idx + axisInputSize : idx;
const WideNum *in = src.get().begin() + adjustedIdx * postAxisNumElements;
for (size_t offset = start; offset < dst.size(); offset += outBlockLen) {
Expand Down
25 changes: 25 additions & 0 deletions src/Dialect/ONNX/ONNXOps.td.inc
Original file line number Diff line number Diff line change
Expand Up @@ -5168,6 +5168,8 @@ def ONNXPadOp:ONNX_Op<"Pad",

3) `edge` - pads with the edge values of array

4) `wrap` - wrap-around padding as if the data tensor forms a torus


Example 1 (`constant` mode):

Expand Down Expand Up @@ -5232,6 +5234,29 @@ def ONNXPadOp:ONNX_Op<"Pad",
[4.5, 4.5, 4.5, 5.7],
]
```

Example 4 (`wrap` mode):

```
data = [
[1.0, 1.2],
[2.3, 3.4],
[4.5, 5.7],
]

pads = [2, 1, 1, 1]

mode = 'wrap'

output = [
[3.4, 2.3, 3.4, 2.3],
[5.7, 4.5, 5.7, 4.5],
[1.2, 1.0, 1.2, 1.0],
[3.4, 2.3, 3.4, 2.3],
[5.7, 4.5, 5.7, 4.5],
[1.2, 1.0, 1.2, 1.0],
]
```
}];
let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>]>:$data,
TensorOf<[I64]>:$pads,
Expand Down
12 changes: 11 additions & 1 deletion src/Dialect/ONNX/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,15 @@ bool AreTheSameAxesConstant(int64_t rank, Value lhs, Value rhs) {
createArrayAttrFromConstantOp(rhsConstOp));
}

/// Test if two values have the same static shape or not.
bool haveSameStaticShape(Value lhs, Value rhs) {
if (!hasShapeAndRank(lhs) || !hasShapeAndRank(rhs))
return false;
Type lhsT = lhs.getType();
Type rhsT = rhs.getType();
return hasStaticShape(lhsT) && (getShape(lhsT) == getShape(rhsT));
}

} // namespace onnx_mlir

// =============================================================================
Expand Down Expand Up @@ -1020,7 +1029,8 @@ void ONNXOrOp::getCanonicalizationPatterns(
void ONNXReshapeOp::getCanonicalizationPatterns(
RewritePatternSet &result, MLIRContext *context) {
result.insert<FuseReshapePattern>(context);
result.insert<RemoveIdentityReshapePattern>(context);
result.insert<RemoveIdentityReshapePattern1>(context);
result.insert<RemoveIdentityReshapePattern2>(context);
result.insert<SwapReshapeMatMulPattern>(context);
}

Expand Down
16 changes: 14 additions & 2 deletions src/Dialect/ONNX/Rewrite.td
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ class HaveSameDim<int dim>: Constraint<
"$1.getType().cast<RankedTensorType>().getShape()[" # dim # "])">,
"Two tensors have the same specified dimension">;

def HaveSameStaticShape: Constraint<
CPred<"onnx_mlir::haveSameStaticShape($0, $1)">,
"Two tensors have the same static shape">;

// Create a unit constant that will be used as none input.
def CreateNoneValue : NativeCodeCall<"$_builder.create<ONNXNoneOp>($_loc).getResult()">;

Expand Down Expand Up @@ -575,14 +579,22 @@ def FuseReshapePattern: Pat<
// Remove the first reshape op.
(ONNXReshapeOp $v, $s2, $az2)>;

def RemoveIdentityReshapePattern: Pat<
def RemoveIdentityReshapePattern1: Pat<
// Remove an identity pattern. Input tensor already has the specified shape.
(ONNXReshapeOp $val, $shape, $az),
// Remove the transpose.
// Remove the reshape.
(replaceWithValue $val),
// Check that val has the specified shape.
[(HasSpecifiedConstantShape $val, $shape)]>;

def RemoveIdentityReshapePattern2: Pat<
// Remove an identity pattern. Output and input shapes are static and the same.
(ONNXReshapeOp:$out $val, $_, $_),
// Remove the reshape.
(replaceWithValue $val),
// Check that val and out have the same static shape.
[(HaveSameStaticShape $out, $val)]>;

def GetReturnTypeForMatMulOpND2D: NativeCodeCall<
"onnx_mlir::getReturnTypeForMatMulOpND2D($0, $1)"
>;
Expand Down
Loading

0 comments on commit 638ecff

Please sign in to comment.