diff --git a/docs/SupportedONNXOps-cpu.md b/docs/SupportedONNXOps-cpu.md index 554c097f95..1200e38810 100644 --- a/docs/SupportedONNXOps-cpu.md +++ b/docs/SupportedONNXOps-cpu.md @@ -26,7 +26,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 19. Limitatio | **Asinh** |9 - * | | | | **Atan** |7 - * | | | | **Atanh** |9 - * | | | -| **AveragePool** |6 - 18 | | | +| **AveragePool** |6 - * | | | | **BatchNormalization** |6 - * |Training not supported. | | | **Bernoulli** |none | | | | | **Binarizer** |none | | | | diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td index 2762893b3d..5d6d2ab924 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td @@ -405,7 +405,7 @@ def GetI64ArrayAttrStridesAveragePool: NativeCodeCall< "($0.getDefiningOp()))">; def replaceONNXAveragePoolPattern : Pattern< - (ONNXAveragePoolOp:$res $x, $_, $_, $_, $_, $_, $_), + (ONNXAveragePoolOp:$res $x, $_, $_, $_, $_, $_, $_, $_), [ // Get attributes using shape helper (GetStrAttrPaddingtypeAveragePool:$padtype $res), diff --git a/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp b/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp index e2db3a9e29..b0e9c7b34e 100644 --- a/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp +++ b/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp @@ -133,6 +133,12 @@ class StickViewUnstickRemovalPattern : public OpRewritePattern { ZLowStickOp stickOp, PatternRewriter &rewriter) const override { Value stickInput = stickOp.getX(); + // Do not handle NCHW layout stickification that transposes data + // internally. + std::string stickLayout = stickOp.getLayout().value().str(); + if (stickLayout == LAYOUT_NCHW) + return failure(); + // Input is a block argument, ignore it. 
if (stickInput.dyn_cast()) return failure(); @@ -157,6 +163,11 @@ class StickViewUnstickRemovalPattern : public OpRewritePattern { ZLowUnstickOp userOp = llvm::dyn_cast(user); if (!userOp) continue; + // Do not handle NCHW layout stickification that transposes data + // internally. + std::string unstickLayout = userOp.getLayout().value().str(); + if (unstickLayout == LAYOUT_NCHW) + continue; // UnstickOp must be before the view operation. if (userOp.getOut() == viewSource && user->isBeforeInBlock(viewOp.getOperation())) { diff --git a/src/Builder/FrontendDialectTransformer.cpp b/src/Builder/FrontendDialectTransformer.cpp index 004df5abc9..ab51de133e 100644 --- a/src/Builder/FrontendDialectTransformer.cpp +++ b/src/Builder/FrontendDialectTransformer.cpp @@ -1235,7 +1235,10 @@ class FrontendGenImpl { } for (auto &name : functionProto.output()) { - outputs.push_back(LookupOnnxName(name)); + // Skip missing optional outputs: they are not mapped. + if (const Value *valuePtr = frontend_symbols_.GetByOnnxName(name)) { + outputs.push_back(*valuePtr); + } } frontend_symbols_.popScope(scopeName); diff --git a/src/Builder/OpBuildTable.inc b/src/Builder/OpBuildTable.inc index 5867fc7e1a..b07cf9d9fd 100644 --- a/src/Builder/OpBuildTable.inc +++ b/src/Builder/OpBuildTable.inc @@ -18,7 +18,7 @@ op_dialect_version_map_["Asin"] = {7}; op_dialect_version_map_["Asinh"] = {9}; op_dialect_version_map_["Atan"] = {7}; op_dialect_version_map_["Atanh"] = {9}; -op_dialect_version_map_["AveragePool"] = {11}; +op_dialect_version_map_["AveragePool"] = {19}; op_dialect_version_map_["BatchNormalization"] = {15}; op_dialect_version_map_["Bernoulli"] = {15}; op_dialect_version_map_["Binarizer"] = {1}; @@ -157,7 +157,7 @@ op_dialect_version_map_["ReduceSum"] = {13, 11}; op_dialect_version_map_["ReduceSumSquare"] = {18, 13}; op_dialect_version_map_["Relu"] = {14}; op_dialect_version_map_["Reshape"] = {19}; -op_dialect_version_map_["Resize"] = {18, 13, 11, 10}; +op_dialect_version_map_["Resize"] = 
{19, 13, 11, 10}; op_dialect_version_map_["ReverseSequence"] = {10}; op_dialect_version_map_["RoiAlign"] = {16}; op_dialect_version_map_["Round"] = {11}; diff --git a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp index fa656219ce..9344e4a806 100644 --- a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp @@ -1119,7 +1119,7 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter, CheckIfCustomScalarOpIsSupported(elementType); Value dividend = scalarOperands[0]; Value divisor = scalarOperands[1]; - MultiDialectBuilder create(rewriter, loc); + MultiDialectBuilder create(rewriter, loc); // TODO: here we assume fmod=1, what should if that is not the case? if (create.math.isFloatWithVector(elementType)) { @@ -1136,9 +1136,59 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter, #endif } if (create.math.isIntegerWithVector(elementType)) { - // TODO: implement - llvm_unreachable("not support integers at this moment since MLIR integers " - "are signless."); + // "math.rem" returns "minus" for minus dividend and "plus or zero" for plus + // dividend. We call the math.rem's return value "mathRemainder". However + // onnx.ModOp should return "minus" for minus divisor and "plus or zero" for + // plus divisor. We call the value that onnx.Mod op should return "onnxMod". + // The following table shows mathRemainder, onnxMod and their difference + // (=onnxMod-mathRemainder) for some inputs. 
+ // + // dividend | 7 | 7 | -7 | -7 | 6 | 6 | -6 | -6 | + // divisor | 3 | -3 | 3 | -3 | 3 | -3 | 3 | -3 | + // ------------------------+-----+----+----+----+----+----+----+----+ + // mathRemainder | 1 | 1 | -1 | -1 | 0 | 0 | 0 | 0 | + // onnxMod | 1 | -2 | 2 | -1 | 0 | 0 | 0 | 0 | + // onnxMod - mathRemainder | 0 | -3 | 3 | 0 | 0 | 0 | 0 | 0 | + // + // The following code shows logic to get onnxMod from mathRemainder + // + // int dividend, divisor; + // int mathRemainder = dividend % divisor; + // int adjustedRemainder = mathRemainder + divisor; + // + // if ((mathRemainder != 0) && ((dividend < 0) ^ (divisor < 0))) # c.f. "^" + // shows "exclusive or". + // return adjustedRemainder; + // return mathRemainder; + + Value mathRemainder = create.math.rem(dividend, divisor); + Value adjustedRemainder = create.math.add(mathRemainder, divisor); + Value zero = create.math.constant(elementType, 0); + Value falseVal = create.math.constant(rewriter.getI1Type(), 0); + Value isMathRemainderNonZero = + create.math.eq(create.math.eq(mathRemainder, zero), falseVal); + Value isDividendMinus = create.math.slt(dividend, zero); + Value isDivisorMinus = create.math.slt(divisor, zero); + Value exclusiveOrOfIsDividendMinusAndIsDivisorMinus = create.math.eq( + create.math.eq(isDividendMinus, isDivisorMinus), falseVal); + Value needAdjust = create.math.andi( + isMathRemainderNonZero, exclusiveOrOfIsDividendMinusAndIsDivisorMinus); + Value answer = + create.math.select(needAdjust, adjustedRemainder, mathRemainder); + +#ifdef DEBUG_ONNX_MOD + create.krnl.printf("XXXX emitScalarOpFor: diviend=", dividend, + dividend.getType()); + create.krnl.printf(", divisor=", divisor, divisor.getType()); + create.krnl.printf( + ", mathReminder=", mathRemainder, mathRemainder.getType()); + create.krnl.printf( + ", adjustedReminder=", adjustedRemainder, adjustedRemainder.getType()); + create.krnl.printf(", Answer=", answer, answer.getType()); + create.krnl.printf("\n"); +#endif + + return answer; } 
llvm_unreachable("unsupported element type"); } diff --git a/src/Conversion/ONNXToKrnl/NN/Pooling.cpp b/src/Conversion/ONNXToKrnl/NN/Pooling.cpp index 97405615b6..f2e8449298 100644 --- a/src/Conversion/ONNXToKrnl/NN/Pooling.cpp +++ b/src/Conversion/ONNXToKrnl/NN/Pooling.cpp @@ -57,15 +57,8 @@ Value emitScalarOpFor( // template std::vector getDilations(PoolOp poolOp) { - return {}; -} - -// MaxPool has dilations attribute. -template <> -std::vector getDilations( - ONNXMaxPoolSingleOutOp poolOp) { std::vector dilations; - auto dilationsAttribute = poolOp.getDilationsAttr(); + ArrayAttr dilationsAttribute = poolOp.getDilationsAttr(); bool isDefaultDilations = true; for (auto dilation : dilationsAttribute.getValue()) { int64_t dilationValue = dilation.cast().getInt(); @@ -84,13 +77,6 @@ std::vector getDilations( // template std::optional getDilationAttr(PoolOp poolOp) { - return std::nullopt; -} - -// MaxPool has dilations attribute. -template <> -std::optional getDilationAttr( - ONNXMaxPoolSingleOutOp poolOp) { return poolOp.getDilations(); } diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp index 5c327182b7..73b8fda8c8 100644 --- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp +++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp @@ -629,7 +629,10 @@ ElementsAttr ElementsAttrBuilder::gather( ArrayRef inputShape = inputType.getShape(); assert(axis < inputShape.size() && "gather axis out of range"); auto postAxisShape = inputShape.drop_front(axis + 1); - ArrayRef indicesShape = indices.getShapedType().getShape(); + ShapedType indicesType = indices.getShapedType(); + assert(indicesType.getElementType().isSignlessInteger() && + "gather indices must be i32 or i64"); + ArrayRef indicesShape = indicesType.getShape(); SmallVector outShape(inputShape.take_front(axis)); outShape.append(indicesShape.begin(), indicesShape.end()); outShape.append(postAxisShape.begin(), postAxisShape.end()); @@ 
-637,13 +640,18 @@ ElementsAttr ElementsAttrBuilder::gather( return fromWideNums(outType, [&](MutableArrayRef dst) { size_t postAxisNumElements = ShapedType::getNumElements(postAxisShape); ArrayBuffer src = getElementsWideNums(input); - ArrayBuffer indicesArray = getElementsArray(indices); + // Convert indices of any signed int element type to int64 by + // first promoting to WideNum and then casting to int64. + // In practice we support both int32 and int64 in this way. + ArrayBuffer indicesWideNums = getElementsWideNums(indices); + ArrayRef indicesArray = + castArrayRef(indicesWideNums.get()); size_t axisInputSize = inputShape[axis]; size_t inputBlockLen = axisInputSize * postAxisNumElements; - size_t outBlockLen = indicesArray.get().size() * postAxisNumElements; + size_t outBlockLen = indicesArray.size() * postAxisNumElements; size_t start = 0; WideNum *out = dst.begin(); - for (int64_t idx : indicesArray.get()) { + for (int64_t idx : indicesArray) { int64_t adjustedIdx = idx < 0 ? idx + axisInputSize : idx; const WideNum *in = src.get().begin() + adjustedIdx * postAxisNumElements; for (size_t offset = start; offset < dst.size(); offset += outBlockLen) { diff --git a/src/Dialect/ONNX/ONNXDimAnalysis.cpp b/src/Dialect/ONNX/ONNXDimAnalysis.cpp index 2eb7ca4768..2bb3595904 100644 --- a/src/Dialect/ONNX/ONNXDimAnalysis.cpp +++ b/src/Dialect/ONNX/ONNXDimAnalysis.cpp @@ -68,21 +68,30 @@ static bool areOverlapping( static std::optional insertDimWhenUseful(const Value tensor, const uint64_t dimIndex, DimAnalysis::DimSetT &sameDims) { auto tensorType = cast(tensor.getType()); + uint64_t axis = dimIndex; bool okToInsert = false; if (tensor.isa()) { okToInsert = true; } else { Operation *op = tensor.getDefiningOp(); - if (isa(op) || - tensorType.isDynamicDim(dimIndex)) + // A constant of -1 to define a dynamic value, e.g. -1 in Reshape means a + // dynamic dimension computed from the other dimensions. It's a constant, no + // need to insert it. 
+ if (isa(op)) + okToInsert = false; + else if (auto dimOp = dyn_cast(op)) { + // The correct axis is from ONNXDimOp. + axis = dimOp.getAxis(); + okToInsert = true; + } else if (isa(op) || tensorType.isDynamicDim(axis)) okToInsert = true; } if (!okToInsert) return std::nullopt; - DimAnalysis::DimT dim(tensor, dimIndex); + DimAnalysis::DimT dim(tensor, axis); sameDims.insert(dim); return dim; } @@ -122,9 +131,224 @@ static void findAndAddSameDim(const QuestionmarkIndexExpr &qmOuputIE, } } -/// Given a dynamic dimension, find the same dynamic dimensions in the inputs. -/// This function uses ShapeHelper to explore the same dynamic dimensions. -/// Use this function for operations that use adaptor to compute shape. +/// Given a dynamic dimension of a tensor, find the same dynamic dimensions in +/// the input tensors of the consuming operator. +/// +/// For example, in MatMul(A, B) : MxN * NxP, dimA[1] = dimB[0] = N. +static void exploreSameDimsFromConsumingOperators( + const DimAnalysis::DimT &dim, DimAnalysis::DimSetT &sameDims) { + LLVM_DEBUG(llvm::dbgs() << "Explore using consuming operators\n"); + for (Operation *op : dim.first.getUsers()) { + LLVM_DEBUG({ + llvm::dbgs() << " - exploring "; + op->dump(); + }); + if (auto concatOp = dyn_cast(op)) { + // Dimensions on the same axis (except the concatenating axis) are the + // same across all inputs. + int64_t axis = concatOp.getAxis(); + ValueRange operands = concatOp.getOperands(); + if (llvm::any_of(operands, [](Value v) { return !hasShapeAndRank(v); })) + continue; + uint64_t rank = getRank(operands[0].getType()); + if (axis < 0) + axis += rank; + // Find the axis of the working dimension. 
+ int64_t dimAxis = -1; + for (uint64_t i = 0; i < operands.size(); ++i) { + for (uint64_t r = 0; r < rank; ++r) { + DimAnalysis::DimT NinA(operands[i], r); + if (NinA == dim) { + dimAxis = r; + break; + } + } + } + if (dimAxis == -1 || dimAxis == axis) + continue; + for (uint64_t i = 0; i < operands.size(); ++i) { + if (operands[i] == dim.first) + continue; + if (auto d = insertDimWhenUseful(operands[i], dimAxis, sameDims)) + LLVM_DEBUG(llvm::dbgs() + << " - [Concat] Added a new dim(" << d.value().first + << ", " << d.value().second << ")\n"); + } + continue; + } + if (auto gemmOp = dyn_cast(op)) { + Value A = gemmOp.getA(); + Value B = gemmOp.getB(); + if (!hasShapeAndRank(A) || !hasShapeAndRank(B)) + continue; + bool transA = (gemmOp.getTransA() != 0); + bool transB = (gemmOp.getTransB() != 0); + DimAnalysis::DimT NinA(A, transA ? 0 : 1); + DimAnalysis::DimT NinB(B, transB ? 1 : 0); + if (NinA == dim) { + if (auto d = insertDimWhenUseful(NinB.first, NinB.second, sameDims)) + LLVM_DEBUG(llvm::dbgs() + << " - [Gemm] Added a new dim(" << d.value().first << ", " + << d.value().second << ")\n"); + } else if (NinB == dim) { + if (auto d = insertDimWhenUseful(NinA.first, NinA.second, sameDims)) + LLVM_DEBUG(llvm::dbgs() + << " - [Gemm] Added a new dim(" << d.value().first << ", " + << d.value().second << ")\n"); + } + continue; + } + if (auto gruOp = dyn_cast(op)) { + int64_t layout = gruOp.getLayout(); + // In GRU, sequence_lens and batch_size are potentially dynamic. + // Only batch_size is used in multiple inputs, so we'll check batch_size. + // X: [seq_length, batch_size, input_size] + Value X = gruOp.getX(); + // seqLen: [batch_size] + Value seqLen = gruOp.getSequenceLens(); + // initialH: [num_directions, batch_size, hidden_size] + Value initialH = gruOp.getInitialH(); + + // Collect batch_size dimensions from each input. + SmallVector batchDims; + if (hasShapeAndRank(X)) { + DimAnalysis::DimT d(X, (layout == 0) ? 
1 : 0); + batchDims.emplace_back(d); + } + if (!isNoneValue(seqLen) && hasShapeAndRank(seqLen)) { + DimAnalysis::DimT d(seqLen, 0); + batchDims.emplace_back(d); + } + if (!isNoneValue(initialH) && hasShapeAndRank(initialH)) { + DimAnalysis::DimT d(initialH, (layout == 0) ? 1 : 0); + batchDims.emplace_back(d); + } + + // Found same dims if the working dim is in the batch dim set. + if (llvm::any_of( + batchDims, [&dim](DimAnalysis::DimT d) { return d == dim; })) { + for (auto bd : batchDims) { + if (auto d = insertDimWhenUseful(bd.first, bd.second, sameDims)) + LLVM_DEBUG(llvm::dbgs() + << " - [GRU] Added a new dim(" << d.value().first + << ", " << d.value().second << ")\n"); + } + } + continue; + } + if (auto lstmOp = dyn_cast(op)) { + int64_t layout = lstmOp.getLayout(); + // In LSTM, sequence_lens and batch_size are potentially dynamic. + // Only batch_size is used in multiple inputs, so we'll check batch_size. + // X: [seq_length, batch_size, input_size] + Value X = lstmOp.getX(); + // seqLen: [batch_size] + Value seqLen = lstmOp.getSequenceLens(); + // initialH: [num_directions, batch_size, hidden_size] + Value initialH = lstmOp.getInitialH(); + // initialC: [num_directions, batch_size, hidden_size] + Value initialC = lstmOp.getInitialC(); + + // Collect batch_size dimensions from each input. + SmallVector batchDims; + if (hasShapeAndRank(X)) { + DimAnalysis::DimT d(X, (layout == 0) ? 1 : 0); + batchDims.emplace_back(d); + } + if (!isNoneValue(seqLen) && hasShapeAndRank(seqLen)) { + DimAnalysis::DimT d(seqLen, 0); + batchDims.emplace_back(d); + } + if (!isNoneValue(initialH) && hasShapeAndRank(initialH)) { + DimAnalysis::DimT d(initialH, (layout == 0) ? 1 : 0); + batchDims.emplace_back(d); + } + if (!isNoneValue(initialC) && hasShapeAndRank(initialC)) { + DimAnalysis::DimT d(initialC, (layout == 0) ? 1 : 0); + batchDims.emplace_back(d); + } + + // Found same dims if the working dim is in the batch dim set. 
+ if (llvm::any_of( + batchDims, [&dim](DimAnalysis::DimT d) { return d == dim; })) { + for (auto bd : batchDims) { + if (auto d = insertDimWhenUseful(bd.first, bd.second, sameDims)) + LLVM_DEBUG(llvm::dbgs() + << " - [LSTM] Added a new dim(" << d.value().first + << ", " << d.value().second << ")\n"); + } + } + continue; + } + if (isa(op)) { + Value A = op->getOperands()[0]; + Value B = op->getOperands()[1]; + if (!hasShapeAndRank(A) || !hasShapeAndRank(B)) + continue; + uint64_t aRank = getRank(A.getType()); + uint64_t bRank = getRank(B.getType()); + if (aRank >= 2 && bRank >= 2) { + DimAnalysis::DimT NinA(A, aRank - 1); + DimAnalysis::DimT NinB(B, bRank - 2); + if (NinA == dim) { + if (auto d = insertDimWhenUseful(NinB.first, NinB.second, sameDims)) + LLVM_DEBUG(llvm::dbgs() + << " - [MatMul] Added a new dim(" << d.value().first + << ", " << d.value().second << ")\n"); + } else if (NinB == dim) { + if (auto d = insertDimWhenUseful(NinA.first, NinA.second, sameDims)) + LLVM_DEBUG(llvm::dbgs() + << " - [MatMul] Added a new dim(" << d.value().first + << ", " << d.value().second << ")\n"); + } + } + continue; + } + if (auto rnnOp = dyn_cast(op)) { + int64_t layout = rnnOp.getLayout(); + // In RNN, sequence_lens and batch_size are potentially dynamic. + // Only batch_size is used in multiple inputs, so we'll check batch_size. + // X: [seq_length, batch_size, input_size] + Value X = rnnOp.getX(); + // seqLen: [batch_size] + Value seqLen = rnnOp.getSequenceLens(); + // initialH: [num_directions, batch_size, hidden_size] + Value initialH = rnnOp.getInitialH(); + + // Collect batch_size dimensions from each input. + SmallVector batchDims; + if (hasShapeAndRank(X)) { + DimAnalysis::DimT d(X, (layout == 0) ? 
1 : 0); + batchDims.emplace_back(d); + } + if (!isNoneValue(seqLen) && hasShapeAndRank(seqLen)) { + DimAnalysis::DimT d(seqLen, 0); + batchDims.emplace_back(d); + } + if (!isNoneValue(initialH) && hasShapeAndRank(initialH)) { + DimAnalysis::DimT d(initialH, (layout == 0) ? 1 : 0); + batchDims.emplace_back(d); + } + + // Found same dims if the working dim is in the batch dim set. + if (llvm::any_of( + batchDims, [&dim](DimAnalysis::DimT d) { return d == dim; })) { + for (auto bd : batchDims) { + if (auto d = insertDimWhenUseful(bd.first, bd.second, sameDims)) + LLVM_DEBUG(llvm::dbgs() + << " - [RNN] Added a new dim(" << d.value().first + << ", " << d.value().second << ")\n"); + } + } + continue; + } + } +} + +/// Given an output dynamic dimension, find the same dynamic dimensions in the +/// inputs. This function uses ShapeHelper to explore the same dynamic +/// dimensions. Use this function for operations that use adaptor to compute +/// shape. static bool exploreSameDimsUsingShapeHelper(const DimAnalysis::DimT &dim, mlir::Operation *op, DimAnalysis::DimSetT &sameDims) { LLVM_DEBUG(llvm::dbgs() << "Explore using shape helper\n"); @@ -278,6 +502,10 @@ DimAnalysis::DimAnalysis(ArrayRef vals) { DimAnalysis::DimAnalysis(ModuleOp moduleOp) { moduleOp.walk([&](Operation *op) { + if (auto funcOp = dyn_cast(op)) { + for (Value arg : funcOp.getArguments()) + build(arg); + } for (Value output : op->getResults()) build(output); }); @@ -492,16 +720,25 @@ void DimAnalysis::visitDim( Value tensor = dim.first; uint64_t dimIndex = dim.second; - // Tensor is a block argument. Nothing to do further. + LLVM_DEBUG( + llvm::dbgs() << "\nVisiting dim(" << tensor << ", " << dimIndex << ")\n"); + + // When the current tensor is consumed by an operator, find the relation + // between dimensions of the input operands. + // + // For example, in MatMul(A, B) : MxN * NxP, dimA[1] = dimB[0]. 
+ exploreSameDimsFromConsumingOperators(dim, sameDims); + + // The remaining code will find where a dimension comes from, depending on + // operation semantics, by *exploring the defining operator*. We utilize the + // operation's shape helper for this purpose as much as possible. + + // Tensor is a block argument. There is no defining operator to explore. if (tensor.isa()) return; - // Find where a dimension comes from, depending on operation semantics. - // We utilize the operation's shape helper for this purpose as much as - // possible. + // Get the defining operator. Operation *op = tensor.getDefiningOp(); - LLVM_DEBUG( - llvm::dbgs() << "\nVisiting dim(" << tensor << ", " << dimIndex << ")\n"); // Tensor is from a constant. Nothing to do further. if (isa(op)) @@ -647,19 +884,20 @@ void DimAnalysis::visitDim( int64_t dataRank = dataType.getRank(); int64_t outputRank = outputType.getRank(); if ((dataRank == 2) && (outputRank == 2)) { - // Find if the output dim is from an input dim. + // Find if the other output dim is from an input dim. int64_t iDim = -1; for (int64_t i = 0; i < dataRank; ++i) { - if (sameDynDim(data, i, output, 1 - dimIndex)) { + if (sameDim(data, i, output, 1 - dimIndex)) { iDim = i; - // The other output dim must be the same as the other input dim. - if (auto d = insertDimWhenUseful(data, 1 - iDim, sameDims)) - LLVM_DEBUG(llvm::dbgs() - << " - Case 2: Added a new dim(" << d.value().first - << ", " << d.value().second << ")\n"); + break; } - break; } + if (iDim != -1) + // The current output dim must be the same as the other input dim. 
+ if (auto d = insertDimWhenUseful(data, 1 - iDim, sameDims)) + LLVM_DEBUG(llvm::dbgs() + << " - Case 2: Added a new dim(" << d.value().first + << ", " << d.value().second << ")\n"); } } } diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index d75ee6b66d..287d90944b 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -424,11 +424,7 @@ def ONNXAveragePoolOp:ONNX_Op<"AveragePool", ``` output_spatial_shape[i] = ceil((input_spatial_shape[i] + pad_shape[i] - ((kernel_spatial_shape[i] - 1) * dilations[i] + 1)) / strides_spatial_shape[i] + 1) ``` - if ceil_mode is enabled - - ``` - * pad_shape[i] is sum of pads along axis i - ``` + if ceil_mode is enabled `pad_shape[i]` is the sum of pads along axis `i`. `auto_pad` is a DEPRECATED attribute. If you are using them currently, the output spatial shape will be following: ``` @@ -446,6 +442,7 @@ def ONNXAveragePoolOp:ONNX_Op<"AveragePool", DefaultValuedStrAttr:$auto_pad, DefaultValuedAttr:$ceil_mode, DefaultValuedAttr:$count_include_pad, + OptionalAttr:$dilations, I64ArrayAttr:$kernel_shape, OptionalAttr:$pads, OptionalAttr:$strides); @@ -6967,8 +6964,10 @@ def ONNXResizeOp:ONNX_Op<"Resize", let summary = "ONNX Resize operation"; let description = [{ Resize the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood (a.k.a. sampling locations) in the input tensor. - Each dimension value of the output tensor is:
- `output_dimension = floor(input_dimension * (roi_end - roi_start) * scale)`
+ Each dimension value of the output tensor is: + ``` + output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) + ``` if input \\"sizes\\" is not specified. }]; let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$X, diff --git a/src/Dialect/ONNX/ONNXOps/NN/Pooling.cpp b/src/Dialect/ONNX/ONNXOps/NN/Pooling.cpp index 5426e541f7..95fca161ee 100644 --- a/src/Dialect/ONNX/ONNXOps/NN/Pooling.cpp +++ b/src/Dialect/ONNX/ONNXOps/NN/Pooling.cpp @@ -86,8 +86,8 @@ LogicalResult ONNXAveragePoolOpShapeHelper::computeShape() { ONNXAveragePoolOp poolOp = llvm::cast(op); return customComputeShape(operandAdaptor.getX(), /*W*/ nullptr, poolOp.getKernelShape(), poolOp.getAutoPad(), poolOp.getPads(), - poolOp.getStrides(), - /*dilation*/ std::nullopt, /*hasFilter*/ false, poolOp.getCeilMode()); + poolOp.getStrides(), poolOp.getDilations(), /*hasFilter*/ false, + poolOp.getCeilMode()); } } // namespace onnx_mlir @@ -117,6 +117,8 @@ LogicalResult ONNXAveragePoolOp::verify() { return failure(); if (failed(verifyStrides(this, spatialRank))) return failure(); + if (failed(verifyDilations(this, spatialRank))) + return failure(); if (failed(verifyPadding(this, spatialRank))) return failure(); return success(); diff --git a/test/backend/inference_backend.py b/test/backend/inference_backend.py index 9e58bd08c4..a727e316b2 100644 --- a/test/backend/inference_backend.py +++ b/test/backend/inference_backend.py @@ -148,7 +148,6 @@ def get_test_models(): # ==OP== AveragePool # ==MIN== 1 - # ==UNSUPPORTED== 19 # TODO: original comment stated "same_upper/lower with dynamic padding-shapes not supported." # However, I see the dyn shape test being done on all tests, including same_upper. 
So I am # assuming that this comment is outdated. @@ -659,8 +658,8 @@ def get_test_models(): # "test_mod_broadcast_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, # "test_mod_int64_fmod_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, # "test_mod_mixed_sign_int16_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, - # "test_mod_mixed_sign_int32_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, - # "test_mod_mixed_sign_int64_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, + "test_mod_mixed_sign_int32_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, + "test_mod_mixed_sign_int64_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, # "test_mod_mixed_sign_int8_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, # "test_mod_uint16_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, # "test_mod_uint32_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, diff --git a/test/mlir/accelerators/nnpa/analysis/dyn-dim-analysis.mlir b/test/mlir/accelerators/nnpa/analysis/dyn-dim-analysis.mlir index 96926e91d9..9d3f9b39d7 100644 --- a/test/mlir/accelerators/nnpa/analysis/dyn-dim-analysis.mlir +++ b/test/mlir/accelerators/nnpa/analysis/dyn-dim-analysis.mlir @@ -72,22 +72,22 @@ func.func @test_nhwc_layout(%arg0 : tensor) -> tensor // CHECK-LABEL: func.func @test_nhwc_layout // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { // CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 3 : si64, group_id = 11 : si64} : (tensor) -> () -// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 2 : si64, group_id = 2 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 2 : si64, group_id = 14 : si64} : (tensor) -> () // CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 1 : si64, group_id = 7 : si64} : (tensor) -> () // CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 0 : si64, 
group_id = 0 : si64} : (tensor) -> () // CHECK: [[VAR_0_:%.+]] = "onnx.Sigmoid"([[PARAM_0_]]) : (tensor) -> tensor // CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 3 : si64, group_id = 11 : si64} : (tensor) -> () -// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 2 : si64, group_id = 2 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 2 : si64, group_id = 14 : si64} : (tensor) -> () // CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 1 : si64, group_id = 7 : si64} : (tensor) -> () // CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () // CHECK: [[VAR_1_:%.+]] = "zhigh.Stick"([[VAR_0_]]) {layout = "NHWC"} : (tensor) -> tensor> // CHECK-DAG: "onnx.DimGroup"([[VAR_1_]]) {axis = 2 : si64, group_id = 11 : si64} : (tensor>) -> () -// CHECK-DAG: "onnx.DimGroup"([[VAR_1_]]) {axis = 1 : si64, group_id = 2 : si64} : (tensor>) -> () +// CHECK-DAG: "onnx.DimGroup"([[VAR_1_]]) {axis = 1 : si64, group_id = 14 : si64} : (tensor>) -> () // CHECK-DAG: "onnx.DimGroup"([[VAR_1_]]) {axis = 3 : si64, group_id = 7 : si64} : (tensor>) -> () // CHECK-DAG: "onnx.DimGroup"([[VAR_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor>) -> () // CHECK: [[VAR_2_:%.+]] = "zhigh.Unstick"([[VAR_1_]]) : (tensor>) -> tensor // CHECK-DAG: "onnx.DimGroup"([[VAR_2_]]) {axis = 3 : si64, group_id = 11 : si64} : (tensor) -> () -// CHECK-DAG: "onnx.DimGroup"([[VAR_2_]]) {axis = 2 : si64, group_id = 2 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[VAR_2_]]) {axis = 2 : si64, group_id = 14 : si64} : (tensor) -> () // CHECK-DAG: "onnx.DimGroup"([[VAR_2_]]) {axis = 1 : si64, group_id = 7 : si64} : (tensor) -> () // CHECK-DAG: "onnx.DimGroup"([[VAR_2_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () // CHECK: onnx.Return [[VAR_2_]] : tensor diff --git a/test/mlir/accelerators/nnpa/transform/zlow-rewrite.mlir b/test/mlir/accelerators/nnpa/transform/zlow-rewrite.mlir index 82b678412c..a5a3488c7e 100644 --- 
a/test/mlir/accelerators/nnpa/transform/zlow-rewrite.mlir +++ b/test/mlir/accelerators/nnpa/transform/zlow-rewrite.mlir @@ -137,6 +137,22 @@ func.func @test_remove_unstick_view_stick(%arg0: memref<7x4x1x8x32x64xf16>) -> ( // ----- +func.func @test_should_not_remove_unstick_view_stick_nchw(%arg0: memref<1x1x1x1x32x64xf16>) -> (memref<1x1x1x1x32x64xf16>){ + %0 = memref.alloc() {alignment = 16 : i64} : memref<1x32x1x22xf32> + "zlow.unstick"(%arg0, %0) {layout = "NCHW"} : (memref<1x1x1x1x32x64xf16>, memref<1x32x1x22xf32>) -> () + %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [1, 32, 22], strides: [704, 22, 1] : memref<1x32x1x22xf32> to memref<1x32x22xf32> + %2 = memref.alloc() {alignment = 4096 : i64} : memref<1x1x1x1x32x64xf16> + "zlow.stick"(%1, %2) {layout = "3DS"} : (memref<1x32x22xf32>, memref<1x1x1x1x32x64xf16>) -> () + "func.return"(%2) : (memref<1x1x1x1x32x64xf16>) -> () + + // CHECK-LABEL: test_should_not_remove_unstick_view_stick_nchw + // CHECK: "zlow.unstick" + // CHECK: memref.reinterpret_cast + // CHECK: "zlow.stick" +} + +// ----- + // Remove zlow.stick and zlow.unstick in pattern: unstick -> transpose -> stick. // Test a simple transpose. 
diff --git a/test/mlir/conversion/onnx_to_krnl/Math/Elementwise.mlir b/test/mlir/conversion/onnx_to_krnl/Math/Elementwise.mlir index 5f7a8ea289..2660dd4246 100644 --- a/test/mlir/conversion/onnx_to_krnl/Math/Elementwise.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Math/Elementwise.mlir @@ -115,6 +115,52 @@ func.func private @test_div(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32> // ----- +func.func private @test_signed_int_mod(%arg0 : tensor<10x10xi64>, %arg1 : tensor<10x10xi64>) -> tensor<*xi64> { + %0 = "onnx.Mod"(%arg0, %arg1) : (tensor<10x10xi64>, tensor<10x10xi64>) -> tensor<*xi64> + "func.return"(%0) : (tensor<*xi64>) -> () +// mlir2FileCheck.py +// CHECK-LABEL: func.func private @test_signed_int_mod +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<10x10xi64>, [[PARAM_1_:%.+]]: memref<10x10xi64>) -> memref<10x10xi64> { +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_10_:%.+]] = arith.constant 10 : index +// CHECK-DAG: [[CST_10_1_:%.+]] = arith.constant 10 : index +// CHECK-DAG: [[CST_10_2_:%.+]] = arith.constant 10 : index +// CHECK-DAG: [[CST_10_3_:%.+]] = arith.constant 10 : index +// CHECK-DAG: [[CST_1_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<10x10xi64> +// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_10_4_:%.+]] = arith.constant 10 : index +// CHECK-DAG: [[CST_10_5_:%.+]] = arith.constant 10 : index +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 10, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 10){ +// CHECK-DAG: [[VAR_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[CST_10_6_:%.+]] = arith.constant 10 : index +// CHECK-DAG: [[CST_10_7_:%.+]] = arith.constant 10 : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = 
krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1] : memref<10x10xi64> +// CHECK-DAG: [[CST_10_8_:%.+]] = arith.constant 10 : index +// CHECK-DAG: [[CST_10_9_:%.+]] = arith.constant 10 : index +// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1] : memref<10x10xi64> +// CHECK: [[VAR_4_:%.+]] = arith.remsi [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : i64 +// CHECK-DAG: [[VAR_5_:%.+]] = arith.addi [[VAR_4_]], [[LOAD_PARAM_1_MEM_]] : i64 +// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : i64 +// CHECK-DAG: [[VAR_false_:%.+]] = arith.constant false +// CHECK: [[VAR_6_:%.+]] = arith.cmpi eq, [[VAR_4_]], [[CST_0_1_]] : i64 +// CHECK-DAG: [[VAR_7_:%.+]] = arith.cmpi eq, [[VAR_6_]], [[VAR_false_]] : i1 +// CHECK-DAG: [[VAR_8_:%.+]] = arith.cmpi slt, [[LOAD_PARAM_0_MEM_]], [[CST_0_1_]] : i64 +// CHECK-DAG: [[VAR_9_:%.+]] = arith.cmpi slt, [[LOAD_PARAM_1_MEM_]], [[CST_0_1_]] : i64 +// CHECK: [[VAR_10_:%.+]] = arith.cmpi eq, [[VAR_8_]], [[VAR_9_]] : i1 +// CHECK: [[VAR_11_:%.+]] = arith.cmpi eq, [[VAR_10_]], [[VAR_false_]] : i1 +// CHECK: [[VAR_12_:%.+]] = arith.andi [[VAR_7_]], [[VAR_11_]] : i1 +// CHECK: [[VAR_13_:%.+]] = arith.select [[VAR_12_]], [[VAR_5_]], [[VAR_4_]] : i64 +// CHECK: krnl.store [[VAR_13_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1] : memref<10x10xi64> +// CHECK: } +// CHECK: return [[RES_]] : memref<10x10xi64> +// CHECK: } +} + +// ----- + func.func private @test_sub(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> { %0 = "onnx.Sub"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> "func.return"(%0) : (tensor<*xf32>) -> () diff --git a/test/mlir/onnx/onnx_constprop.mlir b/test/mlir/onnx/onnx_constprop.mlir index 1342fd9cb7..e96f8eb238 100644 --- a/test/mlir/onnx/onnx_constprop.mlir +++ b/test/mlir/onnx/onnx_constprop.mlir @@ -1663,6 +1663,21 @@ func.func @test_gather_negative_index() -> tensor<*xf32>{ // ----- +func.func @test_gather_rank0_int32_indices() -> 
tensor<*xf32>{ + %0 = onnx.Constant dense<[[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]]> : tensor<3x2xf32> + %1 = onnx.Constant dense<1> : tensor + %2 = "onnx.Gather"(%0, %1) {axis = 0 : si64} : (tensor<3x2xf32>, tensor) -> tensor<*xf32> + "onnx.Return"(%2) : (tensor<*xf32>) -> () + + // CHECK-LABEL: func @test_gather_rank0_int32_indices + // CHECK-SAME: () -> tensor<2xf32> { + // CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[2.300000e+00, 3.400000e+00]> : tensor<2xf32> + // CHECK: onnx.Return [[VAR_0_]] : tensor<2xf32> + // CHECK: } +} + +// ----- + func.func @test_reshape() -> tensor<*xf32> { %0 = onnx.Constant dense<[[1.0, 1.2, 1.9], [2.3, 3.4, 3.9], [4.5, 5.7, 5.9]]> : tensor<3x3xf32> %1 = onnx.Constant dense<[1, -1]> : tensor<2xi64> diff --git a/test/mlir/onnx/onnx_dim_analysis.mlir b/test/mlir/onnx/onnx_dim_analysis.mlir index e4d6aa4eda..858bd104f5 100644 --- a/test/mlir/onnx/onnx_dim_analysis.mlir +++ b/test/mlir/onnx/onnx_dim_analysis.mlir @@ -218,12 +218,11 @@ func.func @test_reshape_rank_2(%arg0: tensor) -> tensor { // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { // CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 1 : si64, group_id = 0 : si64} : (tensor) -> () // CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<-1> : tensor<1xi64> -// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor<1xi64>) -> () // CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor) -> tensor<1xi64> -// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 0 : si64, group_id = 2 : si64} : (tensor) -> () // CHECK: [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_1_]], [[VAR_0_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64> // CHECK: [[VAR_3_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_2_]]) {allowzero = 0 : si64} : (tensor, tensor<2xi64>) -> tensor -// CHECK: "onnx.DimGroup"([[VAR_3_]]) {axis = 1 : si64, group_id = 1 : 
si64} : (tensor) -> () +// CHECK: "onnx.DimGroup"([[VAR_3_]]) {axis = 1 : si64, group_id = 2 : si64} : (tensor) -> () // CHECK: "onnx.DimGroup"([[VAR_3_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () // CHECK: onnx.Return [[VAR_3_]] : tensor // CHECK: } @@ -339,17 +338,166 @@ func.func @test_max_unpool(%arg0: tensor<1x1x2x2xf32>, %arg1: tensor<1x1x2x2xi64 // CHECK-DAG: "onnx.DimGroup"([[PARAM_2_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () // CHECK-DAG: "onnx.DimGroup"([[PARAM_2_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor) -> () // CHECK-DAG: "onnx.DimGroup"([[PARAM_2_]]) {axis = 2 : si64, group_id = 2 : si64} : (tensor) -> () -// CHECK-DAG: "onnx.DimGroup"([[PARAM_2_]]) {axis = 3 : si64, group_id = 3 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_2_]]) {axis = 3 : si64, group_id = 7 : si64} : (tensor) -> () // CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Dim"([[PARAM_2_]]) {axis = 0 : si64} : (tensor) -> tensor<1xi64> // CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Dim"([[PARAM_2_]]) {axis = 1 : si64} : (tensor) -> tensor<1xi64> // CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Dim"([[PARAM_2_]]) {axis = 2 : si64} : (tensor) -> tensor<1xi64> // CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Dim"([[PARAM_2_]]) {axis = 3 : si64} : (tensor) -> tensor<1xi64> // CHECK: [[VAR_4_:%.+]] = "onnx.Concat"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]], [[VAR_3_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64> // CHECK: [[VAR_5_:%.+]] = "onnx.MaxUnpool"([[PARAM_0_]], [[PARAM_1_]], [[VAR_4_]]) {kernel_shape = [2, 2], strides = [2, 2]} : (tensor<1x1x2x2xf32>, tensor<1x1x2x2xi64>, tensor<4xi64>) -> tensor -// CHECK: "onnx.DimGroup"([[VAR_5_]]) {axis = 3 : si64, group_id = 3 : si64} : (tensor) -> () -// CHECK: "onnx.DimGroup"([[VAR_5_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor) -> () -// CHECK: "onnx.DimGroup"([[VAR_5_]]) {axis = 2 : si64, group_id = 2 : si64} : (tensor) -> () -// CHECK: "onnx.DimGroup"([[VAR_5_]]) {axis = 0 : 
si64, group_id = 0 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[VAR_5_]]) {axis = 3 : si64, group_id = 7 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[VAR_5_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[VAR_5_]]) {axis = 2 : si64, group_id = 2 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[VAR_5_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () // CHECK: return [[VAR_5_]] : tensor // CHECK: } } + +// ----- + +func.func @test_correct_dimgroup_axis_for_onnx_dim(%arg0: tensor<1x?xi64>) -> tensor<1x1x1x?xf32> { + %0 = onnx.Constant dense<1> : tensor<1xi64> + %1 = "onnx.Dim"(%arg0) {axis = 1 : si64} : (tensor<1x?xi64>) -> tensor<1xi64> + %2 = "onnx.Concat"(%0, %0, %0, %1) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64> + %3 = onnx.ConstantOfShape(%2) {value = dense<0.000000e+00> : tensor<1xf32>} : (tensor<4xi64>) -> tensor<1x1x1x?xf32> + return %3: tensor<1x1x1x?xf32> + +// CHECK-LABEL: func.func @test_correct_dimgroup_axis_for_onnx_dim +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x?xi64>) -> tensor<1x1x1x?xf32> { +// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 1 : si64, group_id = 0 : si64} : (tensor<1x?xi64>) -> () +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1> : tensor<1xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<1x?xi64>) -> tensor<1xi64> +// CHECK: [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_0_]], [[VAR_0_]], [[VAR_0_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64> +// CHECK: [[VAR_3_:%.+]] = onnx.ConstantOfShape([[VAR_2_]]) {value = dense<0.000000e+00> : tensor<1xf32>} : (tensor<4xi64>) -> tensor<1x1x1x?xf32> +// CHECK: "onnx.DimGroup"([[VAR_3_]]) {axis = 3 : si64, group_id = 0 : si64} : (tensor<1x1x1x?xf32>) -> () +// CHECK: return [[VAR_3_]] : tensor<1x1x1x?xf32> +// CHECK: } +} + +// ----- + +func.func 
@test_matmul_reduction_dimension(%arg0: tensor<5x?xf32>, %arg1: tensor) -> tensor<5x10xf32> { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<5x?xf32>, tensor) -> tensor<5x10xf32> + return %0 : tensor<5x10xf32> +// CHECK-LABEL: func.func @test_matmul_reduction_dimension +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<5x?xf32>, [[PARAM_1_:%.+]]: tensor) -> tensor<5x10xf32> { +// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 1 : si64, group_id = 0 : si64} : (tensor<5x?xf32>) -> () +// CHECK: [[VAR_0_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<5x?xf32>, tensor) -> tensor<5x10xf32> +// CHECK: return [[VAR_0_]] : tensor<5x10xf32> +// CHECK: } +} + +// ----- + +func.func @test_gemm_reduction_dimension(%arg0: tensor<5x?xf32>, %arg1: tensor, %arg2: tensor<10xf32>) -> tensor<5x10xf32> { + %0 = "onnx.Gemm"(%arg0, %arg1, %arg2) : (tensor<5x?xf32>, tensor, tensor<10xf32>) -> tensor<5x10xf32> + return %0 : tensor<5x10xf32> + +// CHECK-LABEL: func.func @test_gemm_reduction_dimension +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<5x?xf32>, [[PARAM_1_:%.+]]: tensor, [[PARAM_2_:%.+]]: tensor<10xf32>) -> tensor<5x10xf32> { +// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 1 : si64, group_id = 0 : si64} : (tensor<5x?xf32>) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () +// CHECK: [[VAR_0_:%.+]] = "onnx.Gemm"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : si64, transB = 0 : si64} : (tensor<5x?xf32>, tensor, tensor<10xf32>) -> tensor<5x10xf32> +// CHECK: return [[VAR_0_]] : tensor<5x10xf32> +// CHECK: } +} + +// ----- + +func.func @test_gemm_reduction_dimension_trans(%arg0: tensor, %arg1: tensor<10x?xf32>, %arg2: tensor<10xf32>) -> tensor<5x10xf32> { + %0 = "onnx.Gemm"(%arg0, %arg1, %arg2) {transA = 1 : si64, transB = 1 : si64} : (tensor, tensor<10x?xf32>, 
tensor<10xf32>) -> tensor<5x10xf32> + return %0 : tensor<5x10xf32> + +// CHECK-LABEL: func.func @test_gemm_reduction_dimension_trans +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor<10x?xf32>, [[PARAM_2_:%.+]]: tensor<10xf32>) -> tensor<5x10xf32> { +// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 1 : si64, group_id = 0 : si64} : (tensor<10x?xf32>) -> () +// CHECK: [[VAR_0_:%.+]] = "onnx.Gemm"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 1 : si64, transB = 1 : si64} : (tensor, tensor<10x?xf32>, tensor<10xf32>) -> tensor<5x10xf32> +// CHECK: return [[VAR_0_]] : tensor<5x10xf32> +// CHECK: } +} + +// ----- + +func.func @test_concat_input_dims(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = "onnx.Concat"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor, tensor, tensor) -> tensor + return %0 : tensor + +// CHECK-LABEL: func.func @test_concat_input_dims +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor, [[PARAM_2_:%.+]]: tensor) -> tensor { +// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_2_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 1 : si64, group_id = 4 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 1 : si64, group_id = 7 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_2_]]) {axis = 1 : si64, group_id = 10 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 2 : si64, group_id = 2 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 2 : si64, group_id = 2 : si64} : (tensor) -> () +// CHECK-DAG: 
"onnx.DimGroup"([[PARAM_2_]]) {axis = 2 : si64, group_id = 2 : si64} : (tensor) -> () +// CHECK: [[VAR_0_:%.+]] = "onnx.Concat"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 1 : si64} : (tensor, tensor, tensor) -> tensor +// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 2 : si64, group_id = 2 : si64} : (tensor) -> () +// CHECK: return [[VAR_0_]] : tensor +// CHECK: } +} + +// ----- + +func.func @test_lstm_input_dims(%X: tensor, %W: tensor<1x16x10xf32>, %R: tensor<1x16x4xf32>, %B: tensor<1x32xf32>, %seq_len: tensor, %initial_h: tensor, %initial_c: tensor) -> tensor<*xf32> { + %cst = "onnx.NoValue"() {value} : () -> none + %Y, %Y_h, %Y_c = "onnx.LSTM"(%X, %W, %R, %B, %seq_len, %initial_h, %initial_c, %cst) {hidden_size = 4 : si64} : (tensor, tensor<1x16x10xf32>, tensor<1x16x4xf32>, tensor<1x32xf32>, tensor, tensor, tensor, none) -> (none, tensor<*xf32>, none) + return %Y_h : tensor<*xf32> + +// CHECK-LABEL: func.func @test_lstm_input_dims +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor<1x16x10xf32>, [[PARAM_2_:%.+]]: tensor<1x16x4xf32>, [[PARAM_3_:%.+]]: tensor<1x32xf32>, [[PARAM_4_:%.+]]: tensor, [[PARAM_5_:%.+]]: tensor, [[PARAM_6_:%.+]]: tensor) -> tensor<*xf32> { +// CHECK-DAG: "onnx.DimGroup"([[PARAM_6_]]) {axis = 0 : si64, group_id = 5 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_5_]]) {axis = 0 : si64, group_id = 3 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_5_]]) {axis = 1 : si64, group_id = 2 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_6_]]) {axis = 1 : si64, group_id = 2 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_4_]]) {axis = 0 : si64, group_id = 2 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 1 : si64, group_id = 2 : si64} : 
(tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () +// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK: [[Y_:%.+]], [[Y_h_:%.+]], [[VAR_Y_c_:%.+]] = "onnx.LSTM"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[PARAM_4_]], [[PARAM_5_]], [[PARAM_6_]], [[VAR_0_]]) {direction = "forward", hidden_size = 4 : si64, input_forget = 0 : si64, layout = 0 : si64} : (tensor, tensor<1x16x10xf32>, tensor<1x16x4xf32>, tensor<1x32xf32>, tensor, tensor, tensor, none) -> (none, tensor<*xf32>, none) +// CHECK: return [[Y_h_]] : tensor<*xf32> +// CHECK: } +} + +// ----- + +func.func @test_gru_input_dims(%X: tensor, %W: tensor<1x12x10xf32>, %R: tensor<1x12x4xf32>, %B: tensor<1x24xf32>, %seq_len: tensor, %initial_h: tensor) -> tensor<*xf32> { + %Y, %Y_h = "onnx.GRU"(%X, %W, %R, %B, %seq_len, %initial_h) {hidden_size = 4 : si64} : (tensor, tensor<1x12x10xf32>, tensor<1x12x4xf32>, tensor<1x24xf32>, tensor, tensor) -> (none, tensor<*xf32>) + return %Y_h : tensor<*xf32> + +// CHECK-LABEL: func.func @test_gru_input_dims +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor<1x12x10xf32>, [[PARAM_2_:%.+]]: tensor<1x12x4xf32>, [[PARAM_3_:%.+]]: tensor<1x24xf32>, [[PARAM_4_:%.+]]: tensor, [[PARAM_5_:%.+]]: tensor) -> tensor<*xf32> { +// CHECK-DAG: "onnx.DimGroup"([[PARAM_5_]]) {axis = 0 : si64, group_id = 3 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_4_]]) {axis = 0 : si64, group_id = 2 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 1 : si64, group_id = 2 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_5_]]) {axis = 1 : si64, group_id = 2 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () +// CHECK: [[Y_:%.+]], [[VAR_Y_h_:%.+]] = "onnx.GRU"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[PARAM_4_]], [[PARAM_5_]]) {direction = "forward", 
hidden_size = 4 : si64, layout = 0 : si64, linear_before_reset = 0 : si64} : (tensor, tensor<1x12x10xf32>, tensor<1x12x4xf32>, tensor<1x24xf32>, tensor, tensor) -> (none, tensor<*xf32>) +// CHECK: return [[VAR_Y_h_]] : tensor<*xf32> +// CHECK: } +} + +// ----- + +func.func @test_rnn_input_dims(%X: tensor, %W: tensor<1x4x10xf32>, %R: tensor<1x4x4xf32>, %B: tensor<1x8xf32>, %seq_len: tensor, %initial_h: tensor) -> tensor<*xf32> { + %Y, %Y_h = "onnx.RNN"(%X, %W, %R, %B, %seq_len, %initial_h) {hidden_size = 4 : si64} : (tensor, tensor<1x4x10xf32>, tensor<1x4x4xf32>, tensor<1x8xf32>, tensor, tensor) -> (none, tensor<*xf32>) + return %Y_h : tensor<*xf32> + +// CHECK-LABEL: func.func @test_rnn_input_dims +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor<1x4x10xf32>, [[PARAM_2_:%.+]]: tensor<1x4x4xf32>, [[PARAM_3_:%.+]]: tensor<1x8xf32>, [[PARAM_4_:%.+]]: tensor, [[PARAM_5_:%.+]]: tensor) -> tensor<*xf32> { +// CHECK-DAG: "onnx.DimGroup"([[PARAM_5_]]) {axis = 0 : si64, group_id = 3 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_4_]]) {axis = 0 : si64, group_id = 2 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_5_]]) {axis = 1 : si64, group_id = 2 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 1 : si64, group_id = 2 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () +// CHECK: [[Y_:%.+]], [[VAR_Y_h_:%.+]] = "onnx.RNN"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[PARAM_4_]], [[PARAM_5_]]) {activations = ["Tanh", "Tanh"], direction = "forward", hidden_size = 4 : si64, layout = 0 : si64} : (tensor, tensor<1x4x10xf32>, tensor<1x4x4xf32>, tensor<1x8xf32>, tensor, tensor) -> (none, tensor<*xf32>) +// CHECK: return [[VAR_Y_h_]] : tensor<*xf32> +// CHECK: } +} + diff --git a/test/mlir/onnx/parse/layer_normalization_function_decomposition.onnxtext 
b/test/mlir/onnx/parse/layer_normalization_function_decomposition.onnxtext new file mode 100644 index 0000000000..29a9831c08 --- /dev/null +++ b/test/mlir/onnx/parse/layer_normalization_function_decomposition.onnxtext @@ -0,0 +1,52 @@ +// RUN: onnx-mlir --functions-to-decompose=LayerNormalization --EmitONNXBasic --printIR %s | FileCheck %s + +// from onnx-mlir issue #2492 +< + ir_version: 8, + opset_import: ["" : 17] +> +agraph (float[12,3,5] X, float[5] S) => (float[12,3,5] LN) { + LN = LayerNormalization (X, S) +} +// CHECK-LABEL: func.func @main_graph +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<12x3x5xf32>, [[PARAM_1_:%.+]]: tensor<5xf32>) -> tensor<12x3x5xf32> attributes {input_names = ["X", "S"], output_names = ["LN"]} { +// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<9.99999974E-6> : tensor +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Cast"([[VAR_1_]]) {saturate = 1 : si64, to = f32} : (tensor) -> tensor +// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {start = 0 : si64} : (tensor<12x3x5xf32>) -> tensor<3xi64> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Size"([[VAR_3_]]) : (tensor<3xi64>) -> tensor +// CHECK-DAG: [[VAR_5_:%.+]] = onnx.Constant dense<0> : tensor<1xi64> +// CHECK-DAG: [[VAR_6_:%.+]] = onnx.Constant dense<-1> : tensor<1xi64> +// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-DAG: [[VAR_8_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_9_:%.+]] = "onnx.Slice"([[VAR_3_]], [[VAR_5_]], [[VAR_6_]], [[VAR_7_]], [[VAR_8_]]) : (tensor<3xi64>, tensor<1xi64>, tensor<1xi64>, none, none) -> tensor<2xi64> +// CHECK-DAG: [[VAR_10_:%.+]] = "onnx.Neg"([[VAR_6_]]) : (tensor<1xi64>) -> tensor<1xi64> +// CHECK: [[VAR_11_:%.+]] = onnx.ConstantOfShape([[VAR_10_]]) {value = dense<1> : tensor<1xi64>} : (tensor<1xi64>) 
-> tensor +// CHECK-DAG: [[VAR_12_:%.+]] = "onnx.Concat"([[VAR_9_]], [[VAR_11_]]) {axis = 0 : si64} : (tensor<2xi64>, tensor) -> tensor +// CHECK-DAG: [[VAR_13_:%.+]] = "onnx.Flatten"([[PARAM_0_]]) {axis = -1 : si64} : (tensor<12x3x5xf32>) -> tensor<36x5xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_14_:%.+]] = "onnx.Cast"([[VAR_13_]]) {saturate = 1 : si64, to = f32} : (tensor<36x5xf32>) -> tensor<36x5xf32> +// CHECK-DAG: [[VAR_15_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_16_:%.+]] = "onnx.ReduceMean"([[VAR_14_]], [[VAR_15_]]) {axes = [1], keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<36x5xf32>, none) -> tensor<36x1xf32> +// CHECK-DAG: [[VAR_17_:%.+]] = "onnx.Mul"([[VAR_14_]], [[VAR_14_]]) : (tensor<36x5xf32>, tensor<36x5xf32>) -> tensor<36x5xf32> +// CHECK-DAG: [[VAR_18_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_19_:%.+]] = "onnx.ReduceMean"([[VAR_17_]], [[VAR_18_]]) {axes = [1], keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<36x5xf32>, none) -> tensor<36x1xf32> +// CHECK-DAG: [[VAR_20_:%.+]] = "onnx.Mul"([[VAR_16_]], [[VAR_16_]]) : (tensor<36x1xf32>, tensor<36x1xf32>) -> tensor<36x1xf32> +// CHECK: [[VAR_21_:%.+]] = "onnx.Sub"([[VAR_19_]], [[VAR_20_]]) : (tensor<36x1xf32>, tensor<36x1xf32>) -> tensor<36x1xf32> +// CHECK: [[VAR_22_:%.+]] = "onnx.Add"([[VAR_21_]], [[VAR_2_]]) : (tensor<36x1xf32>, tensor) -> tensor<36x1xf32> +// CHECK-DAG: [[VAR_23_:%.+]] = "onnx.Sqrt"([[VAR_22_]]) : (tensor<36x1xf32>) -> tensor<36x1xf32> +// CHECK-DAG: [[VAR_24_:%.+]] = "onnx.Sub"([[VAR_14_]], [[VAR_16_]]) : (tensor<36x5xf32>, tensor<36x1xf32>) -> tensor<36x5xf32> +// CHECK: [[VAR_25_:%.+]] = "onnx.Div"([[VAR_24_]], [[VAR_23_]]) : (tensor<36x5xf32>, tensor<36x1xf32>) -> tensor<36x5xf32> +// CHECK-DAG: [[VAR_26_:%.+]] = "onnx.Cast"([[VAR_25_]]) {saturate = 1 : si64, to = f32} : 
(tensor<36x5xf32>) -> tensor<36x5xf32> +// CHECK-DAG: [[VAR_27_:%.+]] = "onnx.Flatten"([[PARAM_1_]]) {axis = 0 : si64} : (tensor<5xf32>) -> tensor<1x5xf32> +// CHECK: [[VAR_28_:%.+]] = "onnx.Mul"([[VAR_26_]], [[VAR_27_]]) : (tensor<36x5xf32>, tensor<1x5xf32>) -> tensor<36x5xf32> +// CHECK: [[VAR_29_:%.+]] = "onnx.Identity"([[VAR_28_]]) : (tensor<36x5xf32>) -> tensor<36x5xf32> +// CHECK-DAG: [[VAR_30_:%.+]] = "onnx.Reshape"([[VAR_29_]], [[VAR_3_]]) {allowzero = 0 : si64} : (tensor<36x5xf32>, tensor<3xi64>) -> tensor<12x3x5xf32> +// CHECK-DAG: [[VAR_31_:%.+]] = "onnx.Reciprocal"([[VAR_23_]]) : (tensor<36x1xf32>) -> tensor<36x1xf32> +// CHECK: onnx.Return [[VAR_30_]] : tensor<12x3x5xf32> +// CHECK: } diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index 95f4bb2253..8f3220e131 100755 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -87,7 +87,7 @@ 'Asinh': [9], 'Atan': [7], 'Atanh': [9], - 'AveragePool': [11], + 'AveragePool': [19], 'BatchNormalization': [15], 'Bernoulli': [15], 'Binarizer': [1], @@ -226,7 +226,7 @@ 'ReduceSumSquare': [18, 13], 'Relu': [14], 'Reshape': [19], - 'Resize': [18, 13, 11, 10], + 'Resize': [19, 13, 11, 10], 'ReverseSequence': [10], 'RoiAlign': [16], 'Round': [11],