diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
index 010da8b3be..359cf7c862 100644
--- a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
+++ b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
@@ -45,12 +45,11 @@ using namespace onnx_mlir;
 namespace onnx_mlir {
-void addONNXToZHighPasses(
-    mlir::PassManager &pm, ArrayRef execNodesOnCpu) {
+void addONNXToZHighPasses(mlir::PassManager &pm) {
   for (unsigned i = 0; i < 3; i++) {
     // Repeat this process so that shape-related ops such as Shape, Expand,
     // Gather generated during RewriteONNXForZHigh will become constants.
-    pm.addPass(onnx_mlir::createRewriteONNXForZHighPass(execNodesOnCpu));
+    pm.addPass(onnx_mlir::createRewriteONNXForZHighPass());
     // Simplify shape-related ops, including ShapeOp-to-DimOp replacement,
     // constant propagation, shape inference and canonicalize.
     pm.addPass(onnx_mlir::createSimplifyShapeRelatedOpsPass());
@@ -75,7 +74,7 @@ void addONNXToZHighPasses(
     pm.addNestedPass(
         onnx_mlir::createInstrumentPass(instrumentOps, instrumentActions));
-  pm.addPass(onnx_mlir::createONNXToZHighPass(execNodesOnCpu));
+  pm.addPass(onnx_mlir::createONNXToZHighPass());
   pm.addNestedPass(onnx_mlir::createShapeInferencePass());
   // There are more opportunities for const propagation once all zhigh ops were
   // generated.
@@ -150,12 +149,14 @@ void addPassesNNPA(mlir::OwningOpRef &module,
   // InputIRLevelType inputIRLevel = determineInputIRLevel(module);
   // LLVM_DEBUG(llvm::dbgs() << "Adding NNPA passes" << std::endl;);
-  if (emissionTarget >= EmitONNXIR)
+  if (emissionTarget >= EmitONNXIR) {
     addONNXToMLIRPasses(pm, /*target CPU*/ maccel.empty());
+    pm.addPass(onnx_mlir::createDevicePlacementPass());
+  }
   if (emissionTarget >= EmitMLIR) {
     // Lower zAIU-compatible ONNX ops to ZHigh dialect where possible.
-    addONNXToZHighPasses(pm, execNodesOnCpu);
+    addONNXToZHighPasses(pm);
     if (nnpaEmissionTarget >= EmitZHighIR)
       emissionTarget = EmitMLIR;
diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt
index 1fc6c4aac8..4630ef92fc 100644
--- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt
+++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt
@@ -53,3 +53,18 @@ add_onnx_mlir_library(OMZHighToONNX
   ACCEL_INCLUDE_DIRS PRIVATE
   ${NNPA_INCLUDE_PATH}
   )
+
+add_onnx_mlir_library(OMDevicePlacement
+  DevicePlacement.cpp
+
+  DEPENDS
+  libzdnn
+
+  LINK_LIBS PUBLIC
+  OMONNXOps
+  OMONNXToZHigh
+  OMRewriteONNXForZHigh
+
+  ACCEL_INCLUDE_DIRS PRIVATE
+  ${NNPA_INCLUDE_PATH}
+  )
diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacement.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacement.cpp
new file mode 100644
index 0000000000..0ff0bbe1a8
--- /dev/null
+++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacement.cpp
@@ -0,0 +1,129 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+//===-------- DevicePlacement.cpp - Device Placement for NNPA -------------===//
+//
+// Copyright 2023 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This pass assigns a device (CPU or NNPA) to each operation at the ONNX level.
+// Device placement can be decided by:
+// - user configuration file if given
+// - a cost model
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/SetOperations.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/Support/Debug.h"
+
+#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.hpp"
+#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp"
+#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.hpp"
+#include "src/Dialect/ONNX/ONNXOps.hpp"
+#include "src/Pass/Passes.hpp"
+
+#define DEBUG_TYPE "device-placement"
+
+using namespace mlir;
+using namespace onnx_mlir;
+
+namespace {
+
+struct DevicePlacementPass
+    : public PassWrapper> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DevicePlacementPass)
+
+  StringRef getArgument() const override { return "device-placement"; }
+
+  StringRef getDescription() const override {
+    return "Device placement for NNPA";
+  }
+
+  void runOnOperation() final;
+};
+
+void DevicePlacementPass::runOnOperation() {
+  using OpSetType = DenseSet;
+  ModuleOp module = getOperation();
+  MLIRContext *context = &getContext();
+
+  // Run the unknown dimension analysis to help check equality of unknown
+  // dimensions at compile time.
+  DimAnalysis dimAnalysis(module);
+  dimAnalysis.analyze();
+
+  // A cost model and a user configuration file, if given, would be handled here.
+  // (Reserved for cost model and user configuration file)
+
+  // Run the patterns that convert ONNX to ZHigh in analysis mode to collect
+  // the operations that are not converted. Those non-converted ops will run on
+  // the host instead of the accelerator.
+  // Keep the order of these calls in sync with RewriteONNXForZHigh.cpp and
+  // ONNXToZHigh.cpp.
+
+  OpSetType legalizedOps1, legalizedOps2, legalizedOps3;
+
+  ConversionTarget target(*context);
+  target.addLegalDialect();
+
+  // Call RewriteONNXForZHigh pass.
+  RewritePatternSet Patterns1(context);
+  getRewriteONNXForZHighPatterns(Patterns1, &dimAnalysis);
+  getRewriteONNXForZHighDynamicallyLegal(&target, &dimAnalysis);
+  (void)applyAnalysisConversion(
+      module, target, std::move(Patterns1), legalizedOps1);
+
+  // Call ONNXToZHigh pass for lowering multiple ONNX ops at once to ZHigh,
+  // e.g. `onnx.Relu (onnx.Conv)` to zhigh.Conv.
+  RewritePatternSet Patterns2(context);
+  getONNXToZHighMultipleOpPatterns(Patterns2);
+  (void)applyAnalysisConversion(
+      module, target, std::move(Patterns2), legalizedOps2);
+
+  // Call ONNXToZHigh pass for lowering a single ONNX op to ZHigh.
+  RewritePatternSet Patterns3(context);
+  getONNXToZHighOneOpPatterns(Patterns3);
+  getONNXToZHighOneOpDynamicallyLegal(&target, &dimAnalysis);
+  (void)applyAnalysisConversion(
+      module, target, std::move(Patterns3), legalizedOps3);
+
+  // Get the legalized ops that will run on the host.
+  OpSetType cpuOps = llvm::set_intersection(
+      legalizedOps1, llvm::set_intersection(legalizedOps2, legalizedOps3));
+
+  // Now annotate accelerator operations in the IR with the `device` attribute.
+  module.walk([&](Operation *op) -> WalkResult {
+    if (op->getDialect()->getNamespace() != ONNXDialect::getDialectNamespace())
+      return WalkResult::advance();
+    // No annotation for these ops.
+    if (isa(op))
+      return WalkResult::advance();
+    // If `device` is already set, respect it.
+ StringAttr device = op->getAttrOfType(DEVICE_ATTRIBUTE); + if (device && !device.getValue().empty()) + return WalkResult::advance(); + // Otherwise, set device. + if (!cpuOps.contains(op)) + op->setAttr(DEVICE_ATTRIBUTE, StringAttr::get(context, NNPA_DEVICE)); + return WalkResult::advance(); + }); +} + +} // namespace + +namespace onnx_mlir { + +/*! + * Create a DevicePlacement pass. + */ +std::unique_ptr createDevicePlacementPass() { + return std::make_unique(); +} + +} // namespace onnx_mlir diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp index 1b8f8ed6fd..17cff8ffb1 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp @@ -13,6 +13,7 @@ // //===----------------------------------------------------------------------===// +#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.hpp" #include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp" #include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp" #include "src/Accelerators/NNPA/Pass/NNPAPasses.hpp" @@ -257,21 +258,49 @@ struct ONNXToZHighLoweringPass ONNXToZHighLoweringPass() = default; ONNXToZHighLoweringPass(const ONNXToZHighLoweringPass &pass) : PassWrapper>() {} - ONNXToZHighLoweringPass(mlir::ArrayRef execNodesOnCpu) { - this->execNodesOnCpu = execNodesOnCpu; - } void runOnOperation() final; - -public: - ListOption execNodesOnCpu{*this, "execNodesOnCpu", - llvm::cl::desc("Comma-separated list of node names in an onnx graph. The " - "specified nodes are forced to run on the CPU instead of " - "using the zDNN. The node name is an optional attribute " - "in onnx graph, which is `onnx_node_name` in ONNX IR"), - llvm::cl::ZeroOrMore}; }; } // end anonymous namespace. 
+void getONNXToZHighOneOpPatterns(RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + populateWithGenerated(patterns); + patterns.insert(context); +} + +void getONNXToZHighOneOpDynamicallyLegal( + ConversionTarget *target, const DimAnalysis *dimAnalysis) { + addDynamicallyLegalOpFor(target, dimAnalysis); + addDynamicallyLegalOpFor(target, dimAnalysis); + addDynamicallyLegalOpFor(target, dimAnalysis); + addDynamicallyLegalOpFor(target, dimAnalysis); + addDynamicallyLegalOpFor(target, dimAnalysis); + addDynamicallyLegalOpFor(target, dimAnalysis); + addDynamicallyLegalOpFor(target, dimAnalysis); + addDynamicallyLegalOpFor(target, dimAnalysis); + addDynamicallyLegalOpFor(target, dimAnalysis); + addDynamicallyLegalOpFor(target, dimAnalysis); + addDynamicallyLegalOpFor(target, dimAnalysis); + addDynamicallyLegalOpFor(target, dimAnalysis); + addDynamicallyLegalOpFor(target, dimAnalysis); + addDynamicallyLegalOpFor(target, dimAnalysis); + addDynamicallyLegalOpFor(target, dimAnalysis); + addDynamicallyLegalOpFor(target, dimAnalysis); + addDynamicallyLegalOpFor(target, dimAnalysis); + addDynamicallyLegalOpFor(target, dimAnalysis); + addDynamicallyLegalOpFor(target, dimAnalysis); + addDynamicallyLegalOpFor(target, dimAnalysis); + addDynamicallyLegalOpFor(target, dimAnalysis); +} + +void getONNXToZHighMultipleOpPatterns(RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); +} + void ONNXToZHighLoweringPass::runOnOperation() { ModuleOp module = getOperation(); @@ -289,6 +318,10 @@ void ONNXToZHighLoweringPass::runOnOperation() { target.addLegalDialect(); + // NOTE: if we change the order of calling combinedPatterns and single op + // patterns, make sure to change the order in DevicePlacement.cpp also to make + // them synced. + // Combined ONNX ops to ZHigh lowering. // There are some combinations of ONNX ops that can be lowering into a single // ZHigh op, e.g. ONNXMatMul and ONNXAdd can be lowered to ZHighMatmul. @@ -296,18 +329,14 @@ void ONNXToZHighLoweringPass::runOnOperation() { // a single ONNX Op, because the single op lowering might have conditions that // prohibit the combined ops lowering happened. RewritePatternSet combinedPatterns(&getContext()); - combinedPatterns.insert(&getContext()); - combinedPatterns.insert(&getContext()); - combinedPatterns.insert(&getContext()); - combinedPatterns.insert(&getContext()); + onnx_mlir::getONNXToZHighMultipleOpPatterns(combinedPatterns); // It's ok to fail. (void)applyPatternsAndFoldGreedily(module, std::move(combinedPatterns)); // Single ONNX to ZHigh operation lowering. RewritePatternSet patterns(&getContext()); - populateWithGenerated(patterns); - patterns.insert(&getContext()); + onnx_mlir::getONNXToZHighOneOpPatterns(patterns); // This is to make sure we don't want to alloc any MemRef at this high-level // representation. @@ -317,32 +346,7 @@ void ONNXToZHighLoweringPass::runOnOperation() { // ONNX ops to ZHigh dialect under specific conditions. // When adding a new op, need to implement a method, i.e. isSuitableForZDNN, // for the op in ONNXLegalityCheck.cpp. 
- addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor( - &target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor( - &target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor( - &target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor( - &target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor( - &target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); + getONNXToZHighOneOpDynamicallyLegal(&target, &dimAnalysis); // With the target and rewrite patterns defined, we can now attempt the // conversion. The conversion will signal failure if any of our `illegal` @@ -355,9 +359,4 @@ std::unique_ptr createONNXToZHighPass() { return std::make_unique(); } -std::unique_ptr createONNXToZHighPass( - mlir::ArrayRef execNodesOnCpu) { - return std::make_unique(execNodesOnCpu); -} - } // namespace onnx_mlir diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.hpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.hpp new file mode 100644 index 0000000000..334600a8aa --- /dev/null +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.hpp @@ -0,0 +1,32 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//====------ ONNXToZHigh.hpp - ONNX dialect to ZHigh lowering -------------===// +// +// Copyright 2019-2022 The IBM Research Authors. +// +// ============================================================================= +// +// This file implements the lowering of ONNX operations to a combination of +// ONNX and ZHigh operations. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include "src/Dialect/ONNX/ONNXDimAnalysis.hpp" + +namespace onnx_mlir { + +// Exports ONNXtoZHigh patterns. +void getONNXToZHighOneOpPatterns(mlir::RewritePatternSet &patterns); +void getONNXToZHighMultipleOpPatterns(mlir::RewritePatternSet &patterns); + +// Exports ONNXtoZHigh dynamically legal checks. 
+void getONNXToZHighOneOpDynamicallyLegal( + mlir::ConversionTarget *target, const DimAnalysis *dimAnalysis); + +} // namespace onnx_mlir diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.cpp index 23430e92e6..5f7e768c69 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.cpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.cpp @@ -17,7 +17,7 @@ #include "src/Dialect/ONNX/DialectBuilder.hpp" using namespace mlir; -using namespace onnx_mlir; +namespace onnx_mlir { /// Get transposed tensor by using a permutation array. Value emitONNXTranspose( @@ -67,3 +67,5 @@ ValueRange splitAlongAxis( ValueRange splits = create.onnx.split(splitTy, X, splitSizes, axis); return splits; } + +} // namespace onnx_mlir diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp index b66c0cc927..287efec6e8 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp @@ -14,49 +14,65 @@ #pragma once +#include "llvm/ADT/STLExtras.h" + #include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/NNPALimit.h" #include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.hpp" #include "src/Accelerators/NNPA/Support/LayoutHelper.hpp" #include "src/Dialect/ONNX/ONNXDimAnalysis.hpp" +namespace onnx_mlir { + +const std::string DEVICE_ATTRIBUTE = "device"; +const std::string CPU_DEVICE = "cpu"; +const std::string NNPA_DEVICE = "nnpa"; + template void addDynamicallyLegalOpFor(mlir::ConversionTarget *target, const onnx_mlir::DimAnalysis *dimAnalysis, - mlir::ArrayRef execNodesOnCpu) { - target->addDynamicallyLegalOp([dimAnalysis, execNodesOnCpu]( + llvm::function_ref checkLegalityFn = + nullptr) { + target->addDynamicallyLegalOp([dimAnalysis, checkLegalityFn]( OP_TYPE op) { - // Check operations to be forced to run on CPU. mlir::Operation *genericOp = op.getOperation(); - mlir::StringAttr nodeName = - genericOp->getAttrOfType("onnx_node_name"); - if (nodeName) { - bool exists = - llvm::any_of(execNodesOnCpu, [nodeName](llvm::StringRef val) { - return nodeName.getValue().equals_insensitive(val); - }); - if (exists) - return true; - } + mlir::StringAttr device = + genericOp->getAttrOfType(DEVICE_ATTRIBUTE); + assert((!device || + (device && + (device.getValue().equals_insensitive("") || + device.getValue().equals_insensitive(CPU_DEVICE) || + device.getValue().equals_insensitive(NNPA_DEVICE)))) && + "Invalid device name"); - // Check zDNN limitations - // TODO: Check tensor size NNPA_MAXIMUM_TENSOR_SIZE of another limitation - bool exceedLimit = - llvm::any_of(genericOp->getOperands(), [](mlir::Value operand) { - if (auto valueType = operand.getType().dyn_cast()) { - // Check if static dimension size exceeds zDNN limitations - llvm::ArrayRef valueShape = valueType.getShape(); - if (llvm::any_of(valueShape, [](int64_t dim) { - return (!mlir::ShapedType::isDynamic(dim)) && - (dim > NNPA_MAXIMUM_DIMENSION_INDEX_SIZE); - })) - return true; - } - return false; - }); - if (exceedLimit) + // If device is CPU, force to run the op on CPU. + if (device && device.getValue().equals_insensitive(CPU_DEVICE)) return true; - return !isSuitableForZDNN(op, dimAnalysis); + // If not CPU, check if the op is legal for NNPA. 
+ bool isLegalForNNPA = false; + if (checkLegalityFn) + isLegalForNNPA = !checkLegalityFn(op, dimAnalysis); + else { + // Check zDNN limitations for each input tensors. + // TODO: Check tensor size NNPA_MAXIMUM_TENSOR_SIZE of another limitation + bool exceedLimit = + llvm::any_of(genericOp->getOperands(), [](mlir::Value operand) { + if (auto valueType = + operand.getType().dyn_cast()) { + // Check if static dimension size exceeds zDNN limitations + llvm::ArrayRef valueShape = valueType.getShape(); + if (llvm::any_of(valueShape, [](int64_t dim) { + return (!mlir::ShapedType::isDynamic(dim)) && + (dim > NNPA_MAXIMUM_DIMENSION_INDEX_SIZE); + })) + return true; + } + return false; + }); + isLegalForNNPA = + !exceedLimit && isSuitableForZDNN(op, dimAnalysis); + } + return !isLegalForNNPA; }); } @@ -75,3 +91,5 @@ mlir::Value emitONNXTransposeWithType(mlir::Location loc, mlir::ValueRange splitAlongAxis( onnx_mlir::MultiDialectBuilder &create, mlir::Value X, int64_t axis); + +} // namespace onnx_mlir diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.cpp index 0e6b3fdb07..3e875e1f56 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.cpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.cpp @@ -23,6 +23,7 @@ // //===----------------------------------------------------------------------===// +#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.hpp" #include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/NNPALimit.h" #include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp" #include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp" @@ -319,6 +320,13 @@ class SplitLargeMatMulPattern : public OpRewritePattern { LogicalResult matchAndRewrite( ONNXMatMulOp matmulOp, PatternRewriter &rewriter) const override { + // Expect N or M exceeds NNPA limitation. + bool nExceeded = false; + bool mExceeded = false; + if (!canBeRewritten(matmulOp, nExceeded, mExceeded)) + return failure(); + + // Rewrite Location loc = matmulOp.getLoc(); Operation *op = matmulOp.getOperation(); Value A = matmulOp.getA(); // NxK @@ -330,25 +338,10 @@ class SplitLargeMatMulPattern : public OpRewritePattern { int64_t aRank = getRank(aType); int64_t bRank = getRank(bType); int64_t outputRank = getRank(outputType); - ArrayRef aShape = getShape(aType); - ArrayRef bShape = getShape(bType); ArrayRef outputShape = getShape(outputType); Type elementType = getElementType(bType); auto unrankedType = UnrankedTensorType::get(elementType); - // Expect 2D or 3D input. - if (!((aRank == 2 || aRank == 3) && (bRank == 2 || bRank == 3))) - return failure(); - - // Expect N or M exceeds NNPA limitation. 
- int64_t N = aShape[aRank - 2]; - int64_t M = bShape[bRank - 1]; - bool nExceeded = N > NNPA_MAXIMUM_DIMENSION_INDEX_SIZE; - bool mExceeded = M > NNPA_MAXIMUM_DIMENSION_INDEX_SIZE; - if (!(nExceeded || mExceeded)) - return failure(); - - // Rewrite MultiDialectBuilder create(rewriter, loc); ValueRange subAs(A), subBs(B); if (nExceeded) { @@ -388,6 +381,33 @@ class SplitLargeMatMulPattern : public OpRewritePattern { rewriter.replaceOp(op, res); return success(); } + + static bool canBeRewritten( + ONNXMatMulOp matmulOp, bool &nExceeded, bool &mExceeded) { + Value A = matmulOp.getA(); // NxK + Value B = matmulOp.getB(); // KxM + + Type aType = A.getType(); + Type bType = B.getType(); + int64_t aRank = getRank(aType); + int64_t bRank = getRank(bType); + ArrayRef aShape = getShape(aType); + ArrayRef bShape = getShape(bType); + + // Expect 2D or 3D input. + if (!((aRank == 2 || aRank == 3) && (bRank == 2 || bRank == 3))) + return false; + + // Expect N or M exceeds NNPA limitation. + int64_t N = aShape[aRank - 2]; + int64_t M = bShape[bRank - 1]; + nExceeded = N > NNPA_MAXIMUM_DIMENSION_INDEX_SIZE; + mExceeded = M > NNPA_MAXIMUM_DIMENSION_INDEX_SIZE; + if (!(nExceeded || mExceeded)) + return false; + + return true; + } }; /// This pattern is to replace `C = add/sub(A, B)` by `A` when B is a zero @@ -413,7 +433,7 @@ class AddSubWithRHSZeroExpandPattern : public OpRewritePattern { return success(); } - static bool canBeRewritten(OP_TYPE binaryOp, DimAnalysis *dimAnalysis) { + static bool canBeRewritten(OP_TYPE binaryOp, const DimAnalysis *dimAnalysis) { Value A = binaryOp.getA(); Value B = binaryOp.getB(); Value C = binaryOp.getC(); @@ -453,45 +473,23 @@ class AddSubWithRHSZeroExpandPattern : public OpRewritePattern { /// Include the patterns defined in the Declarative Rewrite framework. #include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXRewriteONNXForZHigh.inc" -struct RewriteONNXForZHighPass - : public PassWrapper> { - - StringRef getArgument() const override { return "rewrite-onnx-for-zhigh"; } - - StringRef getDescription() const override { - return "Rewrite ONNX ops for ZHigh."; - } - - RewriteONNXForZHighPass() = default; - RewriteONNXForZHighPass(mlir::ArrayRef execNodesOnCpu) - : execNodesOnCpu(execNodesOnCpu) {} - void runOnOperation() final; - -public: - mlir::ArrayRef execNodesOnCpu = mlir::ArrayRef(); -}; - -void RewriteONNXForZHighPass::runOnOperation() { - ModuleOp module = getOperation(); - - // Run the unknown dimension analysis to help check equality of unknown - // dimensions at compile time. - onnx_mlir::DimAnalysis dimAnalysis(module); - dimAnalysis.analyze(); - - // The first thing to define is the conversion target. This will define the - // final target for this lowering. - ConversionTarget target(getContext()); - - // We define the specific operations, or dialects, that are legal targets for - // this lowering. 
- target.addLegalDialect(); +void getRewriteONNXForZHighPatterns( + RewritePatternSet &patterns, DimAnalysis *dimAnalysis) { + populateWithGenerated(patterns); + patterns.insert(patterns.getContext()); + patterns.insert>( + patterns.getContext(), dimAnalysis); + patterns.insert>( + patterns.getContext(), dimAnalysis); +} +void getRewriteONNXForZHighDynamicallyLegal( + mlir::ConversionTarget *target, const DimAnalysis *dimAnalysis) { // `ONNXBatchNormalizationInferenceModeOp` to `ZHigh.BatchNorm`, // generating `ONNX.Add`, `ONNX.Sub`, `ONNX.Mul`, `ONNX.Div`, // and `ONNX.Sqrt` to calculate inputs(`a` and `b`) addDynamicallyLegalOpFor( - &target, &dimAnalysis, execNodesOnCpu); + target, dimAnalysis); // Illegalize BinaryOp if one of the two inputs is a constant and // unidirectional broadcastable to the other input. Rewrite patterns will be @@ -499,34 +497,38 @@ void RewriteONNXForZHighPass::runOnOperation() { // // This is preferred for NNPA because NNPA BinaryOp does not support // broadcasting. - target.addDynamicallyLegalOp([&dimAnalysis](ONNXAddOp op) { - return !((isDefinedByONNXConstantOp(op.getA()) && - isUniBroadcatableFirstToSecond(op.getA(), op.getB())) || - (isDefinedByONNXConstantOp(op.getB()) && - isUniBroadcatableFirstToSecond(op.getB(), op.getA())) || - AddSubWithRHSZeroExpandPattern::canBeRewritten( - op, &dimAnalysis)); - }); - target.addDynamicallyLegalOp([](ONNXDivOp op) { - return !((isDefinedByONNXConstantOp(op.getA()) && - isUniBroadcatableFirstToSecond(op.getA(), op.getB())) || - (isDefinedByONNXConstantOp(op.getB()) && - isUniBroadcatableFirstToSecond(op.getB(), op.getA()))); - }); - target.addDynamicallyLegalOp([](ONNXMulOp op) { - return !((isDefinedByONNXConstantOp(op.getA()) && - isUniBroadcatableFirstToSecond(op.getA(), op.getB())) || - (isDefinedByONNXConstantOp(op.getB()) && - isUniBroadcatableFirstToSecond(op.getB(), op.getA()))); - }); - target.addDynamicallyLegalOp([&dimAnalysis](ONNXSubOp op) { - return !((isDefinedByONNXConstantOp(op.getA()) && - isUniBroadcatableFirstToSecond(op.getA(), op.getB())) || - (isDefinedByONNXConstantOp(op.getB()) && - isUniBroadcatableFirstToSecond(op.getB(), op.getA())) || - AddSubWithRHSZeroExpandPattern::canBeRewritten( - op, &dimAnalysis)); - }); + addDynamicallyLegalOpFor( + target, dimAnalysis, [](ONNXAddOp op, const DimAnalysis *dimAnalysis) { + return !((isDefinedByONNXConstantOp(op.getA()) && + isUniBroadcatableFirstToSecond(op.getA(), op.getB())) || + (isDefinedByONNXConstantOp(op.getB()) && + isUniBroadcatableFirstToSecond(op.getB(), op.getA())) || + AddSubWithRHSZeroExpandPattern::canBeRewritten( + op, dimAnalysis)); + }); + addDynamicallyLegalOpFor( + target, dimAnalysis, [](ONNXDivOp op, const DimAnalysis *dimAnalysis) { + return !((isDefinedByONNXConstantOp(op.getA()) && + isUniBroadcatableFirstToSecond(op.getA(), op.getB())) || + (isDefinedByONNXConstantOp(op.getB()) && + isUniBroadcatableFirstToSecond(op.getB(), op.getA()))); + }); + addDynamicallyLegalOpFor( + target, dimAnalysis, [](ONNXMulOp op, const DimAnalysis *dimAnalysis) { + return !((isDefinedByONNXConstantOp(op.getA()) && + isUniBroadcatableFirstToSecond(op.getA(), op.getB())) || + (isDefinedByONNXConstantOp(op.getB()) && + isUniBroadcatableFirstToSecond(op.getB(), op.getA()))); + }); + addDynamicallyLegalOpFor( + target, dimAnalysis, [](ONNXSubOp op, const DimAnalysis *dimAnalysis) { + return !((isDefinedByONNXConstantOp(op.getA()) && + isUniBroadcatableFirstToSecond(op.getA(), op.getB())) || + (isDefinedByONNXConstantOp(op.getB()) && + 
isUniBroadcatableFirstToSecond(op.getB(), op.getA())) || + AddSubWithRHSZeroExpandPattern::canBeRewritten( + op, dimAnalysis)); + }); // Determine if MatMulOp is already legal (no need to rewrite) or need to // rewrite. The following cases must be rewritten: @@ -536,72 +538,102 @@ void RewriteONNXForZHighPass::runOnOperation() { // // For such cases, rewrite patterns will be added to turn MatMulOp into the // one where N-D will become 3-D or to split MatMul into smaller MatMuls. - target.addDynamicallyLegalOp([&dimAnalysis](ONNXMatMulOp op) { - Type aType = op.getA().getType(); - Type bType = op.getB().getType(); - if (!isRankedShapedType(aType) || !isRankedShapedType(bType)) - return true; - - int64_t aRank = getRank(aType); - int64_t bRank = getRank(bType); - ArrayRef aShape = getShape(aType); - ArrayRef bShape = getShape(bType); - - // - one input is N-D (N > 3) and the other is 2-D. - if (aRank == 2 && bRank > 3) - return false; - if (bRank == 2 && aRank > 3) - return false; - // No input is N-D (N > 3) but dimension N or M (NxK * KxM) is dynamic or - // exceeds NNPA limitation. - if ((aRank == 2 || aRank == 3) && (bRank == 2 || bRank == 3) && - ((aShape[aRank - 2] > NNPA_MAXIMUM_DIMENSION_INDEX_SIZE) || - (bShape[bRank - 1] > NNPA_MAXIMUM_DIMENSION_INDEX_SIZE))) - return false; - - // - both inputs are *the same* N-D, N > 3 and there is no broadcasting - if (aRank > 3 && (aRank == bRank)) { - bool sameBatchDims = true; - for (int64_t i = 0; i < aRank - 2; ++i) { - sameBatchDims &= (aShape[i] == bShape[i]); - if (sameBatchDims && ShapedType::isDynamic(aShape[i])) - sameBatchDims = dimAnalysis.sameDynDim(op.getA(), i, op.getB(), i); - } - return !sameBatchDims; - } + addDynamicallyLegalOpFor( + target, dimAnalysis, [](ONNXMatMulOp op, const DimAnalysis *dimAnalysis) { + Type aType = op.getA().getType(); + Type bType = op.getB().getType(); + if (!isRankedShapedType(aType) || !isRankedShapedType(bType)) + return true; + + int64_t aRank = getRank(aType); + int64_t bRank = getRank(bType); + ArrayRef aShape = getShape(aType); + ArrayRef bShape = getShape(bType); + + // No input is N-D (N > 3) but dimension N or M (NxK * KxM) is dynamic + // or exceeds NNPA limitation. + bool nExceeded, mExceeded; + if (SplitLargeMatMulPattern::canBeRewritten(op, nExceeded, mExceeded)) + return false; + + // - one input is N-D (N > 3) and the other is 2-D. + if (aRank == 2 && bRank > 3) + return false; + if (bRank == 2 && aRank > 3) + return false; + + // - both inputs are *the same* N-D, N > 3 and there is no broadcasting + if (aRank > 3 && (aRank == bRank)) { + bool sameBatchDims = true; + for (int64_t i = 0; i < aRank - 2; ++i) { + sameBatchDims &= (aShape[i] == bShape[i]); + if (sameBatchDims && ShapedType::isDynamic(aShape[i])) + sameBatchDims = + dimAnalysis->sameDynDim(op.getA(), i, op.getB(), i); + } + return !sameBatchDims; + } - // Make other cases legal. - return true; - }); + // Make other cases legal. + return true; + }); // Illegalize SoftmaxOp if // - axis is the last dimension. // This SoftmaxOp will be rewritten in which its input is reshaped to 3D. 
- target.addDynamicallyLegalOp([](ONNXSoftmaxOp op) { - Value input = op.getInput(); - if (auto shapedType = input.getType().dyn_cast()) { - if ((shapedType.getRank() > 3) && - ((op.getAxis() == shapedType.getRank() - 1) || - (op.getAxis() == -1))) { - return false; - } - } - return true; - }); + addDynamicallyLegalOpFor(target, dimAnalysis, + [](ONNXSoftmaxOp op, const DimAnalysis *dimAnalysis) { + Value input = op.getInput(); + if (auto shapedType = input.getType().dyn_cast()) { + if ((shapedType.getRank() > 3) && + ((op.getAxis() == shapedType.getRank() - 1) || + (op.getAxis() == -1))) { + return false; + } + } + return true; + }); + + addDynamicallyLegalOpFor( + target, dimAnalysis, [](ONNXConvOp op, const DimAnalysis *dimAnalysis) { + return isSuitableForZDNN(op) || + !canInferencePadsForNNPAConv(op); + }); +} - target.addDynamicallyLegalOp([](ONNXConvOp op) { - return isSuitableForZDNN(op) || - !canInferencePadsForNNPAConv(op); - }); +struct RewriteONNXForZHighPass + : public PassWrapper> { + + StringRef getArgument() const override { return "rewrite-onnx-for-zhigh"; } + + StringRef getDescription() const override { + return "Rewrite ONNX ops for ZHigh."; + } + + RewriteONNXForZHighPass() = default; + void runOnOperation() final; +}; + +void RewriteONNXForZHighPass::runOnOperation() { + ModuleOp module = getOperation(); + + // Run the unknown dimension analysis to help check equality of unknown + // dimensions at compile time. + DimAnalysis dimAnalysis(module); + dimAnalysis.analyze(); + + // The first thing to define is the conversion target. This will define the + // final target for this lowering. + ConversionTarget target(getContext()); + + // We define the specific operations, or dialects, that are legal targets for + // this lowering. + target.addLegalDialect(); + onnx_mlir::getRewriteONNXForZHighDynamicallyLegal(&target, &dimAnalysis); // Single ONNX to ZHigh operation lowering. RewritePatternSet patterns(&getContext()); - populateWithGenerated(patterns); - patterns.insert(&getContext()); - patterns.insert>( - &getContext(), &dimAnalysis); - patterns.insert>( - &getContext(), &dimAnalysis); + onnx_mlir::getRewriteONNXForZHighPatterns(patterns, &dimAnalysis); // With the target and rewrite patterns defined, we can now attempt the // conversion. The conversion will signal failure if any of our `illegal` @@ -614,9 +646,4 @@ std::unique_ptr createRewriteONNXForZHighPass() { return std::make_unique(); } -std::unique_ptr createRewriteONNXForZHighPass( - mlir::ArrayRef execNodesOnCpu) { - return std::make_unique(execNodesOnCpu); -} - } // namespace onnx_mlir diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.hpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.hpp new file mode 100644 index 0000000000..4ab40101ac --- /dev/null +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.hpp @@ -0,0 +1,30 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===--- RewriteONNXForZHigh.hpp - Rewrite ONNX ops for ZHigh lowering ----===// +// +// Copyright 2019-2023 The IBM Research Authors. +// +// ============================================================================= +// +// This file implements pass for rewriting of ONNX operations to generate +// combination of ONNX and ZHigh operations. + +#pragma once + +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include "src/Dialect/ONNX/ONNXDimAnalysis.hpp" + +namespace onnx_mlir { + +// Exports RewriteONNXForZHigh patterns. 
+void getRewriteONNXForZHighPatterns( + mlir::RewritePatternSet &patterns, DimAnalysis *dimAnalysis); + +// Exports RewriteONNXForZHigh dynamically legal checks. +void getRewriteONNXForZHighDynamicallyLegal( + mlir::ConversionTarget *target, const DimAnalysis *dimAnalysis); + +} // namespace onnx_mlir diff --git a/src/Accelerators/NNPA/NNPAAccelerator.cpp b/src/Accelerators/NNPA/NNPAAccelerator.cpp index 132206e531..cc1ddae500 100644 --- a/src/Accelerators/NNPA/NNPAAccelerator.cpp +++ b/src/Accelerators/NNPA/NNPAAccelerator.cpp @@ -72,6 +72,10 @@ void NNPAAccelerator::registerDialects(mlir::DialectRegistry ®istry) const { void NNPAAccelerator::registerPasses(int optLevel) const { LLVM_DEBUG(llvm::dbgs() << "Registering passes for NNPA accelerator\n"); + mlir::registerPass([]() -> std::unique_ptr { + return onnx_mlir::createDevicePlacementPass(); + }); + mlir::registerPass([]() -> std::unique_ptr { return onnx_mlir::createONNXToZHighPass(); }); diff --git a/src/Accelerators/NNPA/Pass/NNPAPasses.hpp b/src/Accelerators/NNPA/Pass/NNPAPasses.hpp index f994b7c7bc..e166556145 100644 --- a/src/Accelerators/NNPA/Pass/NNPAPasses.hpp +++ b/src/Accelerators/NNPA/Pass/NNPAPasses.hpp @@ -19,15 +19,16 @@ namespace onnx_mlir { +// Add pass for device placement. +std::unique_ptr createDevicePlacementPass(); + /// Add pass for lowering ONNX ops to ZHigh ops. std::unique_ptr createONNXToZHighPass(); -std::unique_ptr createONNXToZHighPass( - mlir::ArrayRef execNodesOnCpu); +std::unique_ptr createONNXToZHighPass(); /// Add pass for rewriting ONNX ops for ZHigh. std::unique_ptr createRewriteONNXForZHighPass(); -std::unique_ptr createRewriteONNXForZHighPass( - mlir::ArrayRef execNodesOnCpu); +std::unique_ptr createRewriteONNXForZHighPass(); /// Add pass for re-construct ONNX ops from ZHigh ops. 
std::unique_ptr createZHighToONNXPass(); diff --git a/test/mlir/accelerators/nnpa/conversion/device-placement/device_placement_pass.mlir b/test/mlir/accelerators/nnpa/conversion/device-placement/device_placement_pass.mlir new file mode 100644 index 0000000000..76b14d6322 --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/device-placement/device_placement_pass.mlir @@ -0,0 +1,51 @@ +// RUN: onnx-mlir-opt --device-placement --maccel=NNPA --split-input-file %s | FileCheck %s + +module attributes {llvm.data_layout = "E-m:e-i1:8:16-i8:8:16-i64:64-f128:64-v128:64-a:8:16-n32:64", llvm.target_triple = "s390x-ibm-linux", "onnx-mlir.symbol-postfix" = "model"} { + func.func @mnist(%arg0: tensor<1x1x28x28xf32>) -> tensor<1x10xf32> attributes {input_names = ["Input3"], output_names = ["Plus214_Output_0"]} { + %0 = onnx.Constant dense<[-0.0822488219, -0.108868778, -0.141039595, -0.204869166, -0.17913565, -0.215438381, -0.133805066, -0.195724562, -0.268250644, -0.258212209, -0.0761560649, 0.0132841459, -0.00444464432, -0.414740831, -0.17879115, -0.0386558883]> : tensor<16xf32> + %1 = onnx.Constant dense<[-0.161539719, -0.433835655, 0.091641359, -0.0168522168, -0.0650264397, -0.131737873, 0.0204175506, -0.121110231]> : tensor<8xf32> + %2 = onnx.Constant dense_resource<__elided__> : tensor<16x4x4x10xf32> + %3 = onnx.Constant dense_resource<__elided__> : tensor<16x8x5x5xf32> + %4 = onnx.Constant dense_resource<__elided__> : tensor<8x1x5x5xf32> + %5 = onnx.Constant dense<[1, 256]> : tensor<2xi64> + %6 = onnx.Constant dense<[256, 10]> : tensor<2xi64> + %7 = onnx.Constant dense<[[-0.0448560268, 0.00779166119, 0.0681008175, 0.0299937408, -0.126409635, 0.14021875, -0.0552849025, -0.0493838154, 0.0843220502, -0.0545404144]]> : tensor<1x10xf32> + %8 = "onnx.Reshape"(%2, %6) {allowzero = 0 : si64, onnx_node_name = "Times212_reshape1"} : (tensor<16x4x4x10xf32>, tensor<2xi64>) -> tensor<256x10xf32> + %9 = "onnx.Conv"(%arg0, %4, %1) {auto_pad = "SAME_UPPER", dilations = [1, 1], group = 1 : si64, kernel_shape = [5, 5], strides = [1, 1]} : (tensor<1x1x28x28xf32>, tensor<8x1x5x5xf32>, tensor<8xf32>) -> tensor<1x8x28x28xf32> + %10 = "onnx.Relu"(%9) {onnx_node_name = "ReLU32"} : (tensor<1x8x28x28xf32>) -> tensor<1x8x28x28xf32> + %11 = "onnx.MaxPoolSingleOut"(%10) {auto_pad = "NOTSET", ceil_mode = 0 : si64, kernel_shape = [2, 2], onnx_node_name = "Pooling66", pads = [0, 0, 0, 0], storage_order = 0 : si64, strides = [2, 2]} : (tensor<1x8x28x28xf32>) -> tensor<1x8x14x14xf32> + %12 = "onnx.Conv"(%11, %3, %0) {auto_pad = "SAME_UPPER", dilations = [1, 1], group = 1 : si64, kernel_shape = [5, 5], strides = [1, 1]} : (tensor<1x8x14x14xf32>, tensor<16x8x5x5xf32>, tensor<16xf32>) -> tensor<1x16x14x14xf32> + %13 = "onnx.Relu"(%12) {onnx_node_name = "ReLU114"} : (tensor<1x16x14x14xf32>) -> tensor<1x16x14x14xf32> + %14 = "onnx.MaxPoolSingleOut"(%13) {auto_pad = "NOTSET", ceil_mode = 0 : si64, kernel_shape = [3, 3], onnx_node_name = "Pooling160", pads = [0, 0, 0, 0], storage_order = 0 : si64, strides = [3, 3]} : (tensor<1x16x14x14xf32>) -> tensor<1x16x4x4xf32> + %15 = "onnx.Reshape"(%14, %5) {allowzero = 0 : si64, onnx_node_name = "Times212_reshape0"} : (tensor<1x16x4x4xf32>, tensor<2xi64>) -> tensor<1x256xf32> + %16 = "onnx.Gemm"(%15, %8, %7) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : si64, transB = 0 : si64} : (tensor<1x256xf32>, tensor<256x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> + return %16 : tensor<1x10xf32> + } + "onnx.EntryPoint"() {func = @mnist} : () -> () + +// mlir2FileCheck.py +// 
CHECK-LABEL: func.func @mnist +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x28x28xf32>) -> tensor<1x10xf32> attributes {input_names = ["Input3"], output_names = ["Plus214_Output_0"]} { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[-0.0822488219, -0.108868778, -0.141039595, -0.204869166, -0.17913565, -0.215438381, -0.133805066, -0.195724562, -0.268250644, -0.258212209, -0.0761560649, 0.0132841459, -0.00444464432, -0.414740831, -0.17879115, -0.0386558883]> : tensor<16xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[-0.161539719, -0.433835655, 0.091641359, -0.0168522168, -0.0650264397, -0.131737873, 0.0204175506, -0.121110231]> : tensor<8xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense_resource<__elided__> : tensor<16x4x4x10xf32> +// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense_resource<__elided__> : tensor<16x8x5x5xf32> +// CHECK-DAG: [[VAR_4_:%.+]] = onnx.Constant dense_resource<__elided__> : tensor<8x1x5x5xf32> +// CHECK-DAG: [[VAR_5_:%.+]] = onnx.Constant dense<[1, 256]> : tensor<2xi64> +// CHECK-DAG: [[VAR_6_:%.+]] = onnx.Constant dense<[256, 10]> : tensor<2xi64> +// CHECK-DAG: [[VAR_7_:%.+]] = onnx.Constant dense<{{.}}[-0.0448560268, 0.00779166119, 0.0681008175, 0.0299937408, -0.126409635, 0.14021875, -0.0552849025, -0.0493838154, 0.0843220502, -0.0545404144]{{.}}> : tensor<1x10xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_8_:%.+]] = "onnx.Reshape"([[VAR_2_]], [[VAR_6_]]) {allowzero = 0 : si64, onnx_node_name = "Times212_reshape1"} : (tensor<16x4x4x10xf32>, tensor<2xi64>) -> tensor<256x10xf32> +// CHECK-DAG: [[VAR_9_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[VAR_4_]], [[VAR_1_]]) {auto_pad = "SAME_UPPER", device = "nnpa", dilations = [1, 1], group = 1 : si64, kernel_shape = [5, 5], strides = [1, 1]} : (tensor<1x1x28x28xf32>, tensor<8x1x5x5xf32>, tensor<8xf32>) -> tensor<1x8x28x28xf32> +// CHECK: [[VAR_10_:%.+]] = "onnx.Relu"([[VAR_9_]]) {device = "nnpa", onnx_node_name = "ReLU32"} : (tensor<1x8x28x28xf32>) -> tensor<1x8x28x28xf32> +// CHECK: [[VAR_11_:%.+]] = "onnx.MaxPoolSingleOut"([[VAR_10_]]) {auto_pad = "NOTSET", ceil_mode = 0 : si64, device = "nnpa", kernel_shape = [2, 2], onnx_node_name = "Pooling66", pads = [0, 0, 0, 0], storage_order = 0 : si64, strides = [2, 2]} : (tensor<1x8x28x28xf32>) -> tensor<1x8x14x14xf32> +// CHECK: [[VAR_12_:%.+]] = "onnx.Conv"([[VAR_11_]], [[VAR_3_]], [[VAR_0_]]) {auto_pad = "SAME_UPPER", device = "nnpa", dilations = [1, 1], group = 1 : si64, kernel_shape = [5, 5], strides = [1, 1]} : (tensor<1x8x14x14xf32>, tensor<16x8x5x5xf32>, tensor<16xf32>) -> tensor<1x16x14x14xf32> +// CHECK: [[VAR_13_:%.+]] = "onnx.Relu"([[VAR_12_]]) {device = "nnpa", onnx_node_name = "ReLU114"} : (tensor<1x16x14x14xf32>) -> tensor<1x16x14x14xf32> +// CHECK: [[VAR_14_:%.+]] = "onnx.MaxPoolSingleOut"([[VAR_13_]]) {auto_pad = "NOTSET", ceil_mode = 0 : si64, device = "nnpa", kernel_shape = [3, 3], onnx_node_name = "Pooling160", pads = [0, 0, 0, 0], storage_order = 0 : si64, strides = [3, 3]} : (tensor<1x16x14x14xf32>) -> tensor<1x16x4x4xf32> +// CHECK: [[VAR_15_:%.+]] = "onnx.Reshape"([[VAR_14_]], [[VAR_5_]]) {allowzero = 0 : si64, onnx_node_name = "Times212_reshape0"} : (tensor<1x16x4x4xf32>, tensor<2xi64>) -> tensor<1x256xf32> +// CHECK: [[VAR_16_:%.+]] = "onnx.Gemm"([[VAR_15_]], [[VAR_8_]], [[VAR_7_]]) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, device = "nnpa", transA = 0 : si64, transB = 0 : si64} : (tensor<1x256xf32>, tensor<256x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> +// CHECK: return [[VAR_16_]] : 
tensor<1x10xf32> +// CHECK: } +// CHECK: "onnx.EntryPoint"() {func = @mnist} : () -> () +} + diff --git a/test/mlir/accelerators/nnpa/conversion/device-placement/emit-onnxir.mlir b/test/mlir/accelerators/nnpa/conversion/device-placement/emit-onnxir.mlir new file mode 100644 index 0000000000..a2bc2c8f5a --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/device-placement/emit-onnxir.mlir @@ -0,0 +1,49 @@ +// RUN: onnx-mlir --EmitONNXIR --maccel=NNPA --printIR %s | FileCheck %s + +module attributes {llvm.data_layout = "E-m:e-i1:8:16-i8:8:16-i64:64-f128:64-v128:64-a:8:16-n32:64", llvm.target_triple = "s390x-ibm-linux", "onnx-mlir.symbol-postfix" = "model"} { + func.func @mnist(%arg0: tensor<1x1x28x28xf32>) -> tensor<1x10xf32> attributes {input_names = ["Input3"], output_names = ["Plus214_Output_0"]} { + %0 = onnx.Constant dense<[-0.0822488219, -0.108868778, -0.141039595, -0.204869166, -0.17913565, -0.215438381, -0.133805066, -0.195724562, -0.268250644, -0.258212209, -0.0761560649, 0.0132841459, -0.00444464432, -0.414740831, -0.17879115, -0.0386558883]> : tensor<16xf32> + %1 = onnx.Constant dense<[-0.161539719, -0.433835655, 0.091641359, -0.0168522168, -0.0650264397, -0.131737873, 0.0204175506, -0.121110231]> : tensor<8xf32> + %2 = onnx.Constant dense_resource<__elided__> : tensor<16x4x4x10xf32> + %3 = onnx.Constant dense_resource<__elided__> : tensor<16x8x5x5xf32> + %4 = onnx.Constant dense_resource<__elided__> : tensor<8x1x5x5xf32> + %5 = onnx.Constant dense<[1, 256]> : tensor<2xi64> + %6 = onnx.Constant dense<[256, 10]> : tensor<2xi64> + %7 = onnx.Constant dense<[[-0.0448560268, 0.00779166119, 0.0681008175, 0.0299937408, -0.126409635, 0.14021875, -0.0552849025, -0.0493838154, 0.0843220502, -0.0545404144]]> : tensor<1x10xf32> + %8 = "onnx.Reshape"(%2, %6) {allowzero = 0 : si64, onnx_node_name = "Times212_reshape1"} : (tensor<16x4x4x10xf32>, tensor<2xi64>) -> tensor<256x10xf32> + %9 = "onnx.Conv"(%arg0, %4, %1) {auto_pad = "SAME_UPPER", dilations = [1, 1], group = 1 : si64, kernel_shape = [5, 5], strides = [1, 1]} : (tensor<1x1x28x28xf32>, tensor<8x1x5x5xf32>, tensor<8xf32>) -> tensor<1x8x28x28xf32> + %10 = "onnx.Relu"(%9) {onnx_node_name = "ReLU32"} : (tensor<1x8x28x28xf32>) -> tensor<1x8x28x28xf32> + %11 = "onnx.MaxPoolSingleOut"(%10) {auto_pad = "NOTSET", ceil_mode = 0 : si64, kernel_shape = [2, 2], onnx_node_name = "Pooling66", pads = [0, 0, 0, 0], storage_order = 0 : si64, strides = [2, 2]} : (tensor<1x8x28x28xf32>) -> tensor<1x8x14x14xf32> + %12 = "onnx.Conv"(%11, %3, %0) {auto_pad = "SAME_UPPER", dilations = [1, 1], group = 1 : si64, kernel_shape = [5, 5], strides = [1, 1]} : (tensor<1x8x14x14xf32>, tensor<16x8x5x5xf32>, tensor<16xf32>) -> tensor<1x16x14x14xf32> + %13 = "onnx.Relu"(%12) {onnx_node_name = "ReLU114"} : (tensor<1x16x14x14xf32>) -> tensor<1x16x14x14xf32> + %14 = "onnx.MaxPoolSingleOut"(%13) {auto_pad = "NOTSET", ceil_mode = 0 : si64, kernel_shape = [3, 3], onnx_node_name = "Pooling160", pads = [0, 0, 0, 0], storage_order = 0 : si64, strides = [3, 3]} : (tensor<1x16x14x14xf32>) -> tensor<1x16x4x4xf32> + %15 = "onnx.Reshape"(%14, %5) {allowzero = 0 : si64, onnx_node_name = "Times212_reshape0"} : (tensor<1x16x4x4xf32>, tensor<2xi64>) -> tensor<1x256xf32> + %16 = "onnx.Gemm"(%15, %8, %7) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : si64, transB = 0 : si64} : (tensor<1x256xf32>, tensor<256x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> + return %16 : tensor<1x10xf32> + } + "onnx.EntryPoint"() {func = @mnist} : () -> () + +// CHECK-LABEL: func.func 
@mnist +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x28x28xf32>) -> tensor<1x10xf32> attributes {input_names = ["Input3"], output_names = ["Plus214_Output_0"]} { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[-0.0822488219, -0.108868778, -0.141039595, -0.204869166, -0.17913565, -0.215438381, -0.133805066, -0.195724562, -0.268250644, -0.258212209, -0.0761560649, 0.0132841459, -0.00444464432, -0.414740831, -0.17879115, -0.0386558883]> : tensor<16xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[-0.161539719, -0.433835655, 0.091641359, -0.0168522168, -0.0650264397, -0.131737873, 0.0204175506, -0.121110231]> : tensor<8xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense_resource<__elided__> : tensor<16x4x4x10xf32> +// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense_resource<__elided__> : tensor<16x8x5x5xf32> +// CHECK-DAG: [[VAR_4_:%.+]] = onnx.Constant dense_resource<__elided__> : tensor<8x1x5x5xf32> +// CHECK-DAG: [[VAR_5_:%.+]] = onnx.Constant dense<[1, 256]> : tensor<2xi64> +// CHECK-DAG: [[VAR_6_:%.+]] = onnx.Constant dense<[256, 10]> : tensor<2xi64> +// CHECK-DAG: [[VAR_7_:%.+]] = onnx.Constant dense<{{.}}[-0.0448560268, 0.00779166119, 0.0681008175, 0.0299937408, -0.126409635, 0.14021875, -0.0552849025, -0.0493838154, 0.0843220502, -0.0545404144]{{.}}> : tensor<1x10xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_8_:%.+]] = "onnx.Reshape"([[VAR_2_]], [[VAR_6_]]) {allowzero = 0 : si64, onnx_node_name = "Times212_reshape1"} : (tensor<16x4x4x10xf32>, tensor<2xi64>) -> tensor<256x10xf32> +// CHECK-DAG: [[VAR_9_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[VAR_4_]], [[VAR_1_]]) {auto_pad = "SAME_UPPER", device = "nnpa", dilations = [1, 1], group = 1 : si64, kernel_shape = [5, 5], strides = [1, 1]} : (tensor<1x1x28x28xf32>, tensor<8x1x5x5xf32>, tensor<8xf32>) -> tensor<1x8x28x28xf32> +// CHECK: [[VAR_10_:%.+]] = "onnx.Relu"([[VAR_9_]]) {device = "nnpa", onnx_node_name = "ReLU32"} : (tensor<1x8x28x28xf32>) -> tensor<1x8x28x28xf32> +// CHECK: [[VAR_11_:%.+]] = "onnx.MaxPoolSingleOut"([[VAR_10_]]) {auto_pad = "NOTSET", ceil_mode = 0 : si64, device = "nnpa", kernel_shape = [2, 2], onnx_node_name = "Pooling66", pads = [0, 0, 0, 0], storage_order = 0 : si64, strides = [2, 2]} : (tensor<1x8x28x28xf32>) -> tensor<1x8x14x14xf32> +// CHECK: [[VAR_12_:%.+]] = "onnx.Conv"([[VAR_11_]], [[VAR_3_]], [[VAR_0_]]) {auto_pad = "SAME_UPPER", device = "nnpa", dilations = [1, 1], group = 1 : si64, kernel_shape = [5, 5], strides = [1, 1]} : (tensor<1x8x14x14xf32>, tensor<16x8x5x5xf32>, tensor<16xf32>) -> tensor<1x16x14x14xf32> +// CHECK: [[VAR_13_:%.+]] = "onnx.Relu"([[VAR_12_]]) {device = "nnpa", onnx_node_name = "ReLU114"} : (tensor<1x16x14x14xf32>) -> tensor<1x16x14x14xf32> +// CHECK: [[VAR_14_:%.+]] = "onnx.MaxPoolSingleOut"([[VAR_13_]]) {auto_pad = "NOTSET", ceil_mode = 0 : si64, device = "nnpa", kernel_shape = [3, 3], onnx_node_name = "Pooling160", pads = [0, 0, 0, 0], storage_order = 0 : si64, strides = [3, 3]} : (tensor<1x16x14x14xf32>) -> tensor<1x16x4x4xf32> +// CHECK: [[VAR_15_:%.+]] = "onnx.Reshape"([[VAR_14_]], [[VAR_5_]]) {allowzero = 0 : si64, onnx_node_name = "Times212_reshape0"} : (tensor<1x16x4x4xf32>, tensor<2xi64>) -> tensor<1x256xf32> +// CHECK: [[VAR_16_:%.+]] = "onnx.Gemm"([[VAR_15_]], [[VAR_8_]], [[VAR_7_]]) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, device = "nnpa", transA = 0 : si64, transB = 0 : si64} : (tensor<1x256xf32>, tensor<256x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> +// CHECK: return [[VAR_16_]] : tensor<1x10xf32> +// CHECK: } 
+// CHECK: "onnx.EntryPoint"() {func = @mnist} : () -> () +} diff --git a/test/mlir/accelerators/nnpa/conversion/device-placement/emit-zhighir.mlir b/test/mlir/accelerators/nnpa/conversion/device-placement/emit-zhighir.mlir new file mode 100644 index 0000000000..ad8a0a08ac --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/device-placement/emit-zhighir.mlir @@ -0,0 +1,56 @@ +// RUN: onnx-mlir --EmitZHighIR --maccel=NNPA --printIR %s | FileCheck %s + +// Note that, we intentionally add `device=cpu` into onnx.Gemm to force it run on CPU. +module { + func.func @mnist(%arg0: tensor<1x1x28x28xf32>) -> tensor<1x10xf32> attributes {input_names = ["Input3"], output_names = ["Plus214_Output_0"]} { + %0 = onnx.Constant dense<[-0.0822488219, -0.108868778, -0.141039595, -0.204869166, -0.17913565, -0.215438381, -0.133805066, -0.195724562, -0.268250644, -0.258212209, -0.0761560649, 0.0132841459, -0.00444464432, -0.414740831, -0.17879115, -0.0386558883]> : tensor<16xf32> + %1 = onnx.Constant dense<[-0.161539719, -0.433835655, 0.091641359, -0.0168522168, -0.0650264397, -0.131737873, 0.0204175506, -0.121110231]> : tensor<8xf32> + %2 = onnx.Constant dense_resource<__elided__> : tensor<16x4x4x10xf32> + %3 = onnx.Constant dense_resource<__elided__> : tensor<16x8x5x5xf32> + %4 = onnx.Constant dense_resource<__elided__> : tensor<8x1x5x5xf32> + %5 = onnx.Constant dense<[1, 256]> : tensor<2xi64> + %6 = onnx.Constant dense<[256, 10]> : tensor<2xi64> + %7 = onnx.Constant dense<[[-0.0448560268, 0.00779166119, 0.0681008175, 0.0299937408, -0.126409635, 0.14021875, -0.0552849025, -0.0493838154, 0.0843220502, -0.0545404144]]> : tensor<1x10xf32> + %8 = "onnx.Reshape"(%2, %6) {allowzero = 0 : si64, onnx_node_name = "Times212_reshape1"} : (tensor<16x4x4x10xf32>, tensor<2xi64>) -> tensor<256x10xf32> + %9 = "onnx.Conv"(%arg0, %4, %1) {auto_pad = "SAME_UPPER", device = "nnpa", dilations = [1, 1], group = 1 : si64, kernel_shape = [5, 5], strides = [1, 1]} : (tensor<1x1x28x28xf32>, tensor<8x1x5x5xf32>, tensor<8xf32>) -> tensor<1x8x28x28xf32> + %10 = "onnx.Relu"(%9) {device = "nnpa", onnx_node_name = "ReLU32"} : (tensor<1x8x28x28xf32>) -> tensor<1x8x28x28xf32> + %11 = "onnx.MaxPoolSingleOut"(%10) {auto_pad = "NOTSET", ceil_mode = 0 : si64, device = "nnpa", kernel_shape = [2, 2], onnx_node_name = "Pooling66", pads = [0, 0, 0, 0], storage_order = 0 : si64, strides = [2, 2]} : (tensor<1x8x28x28xf32>) -> tensor<1x8x14x14xf32> + %12 = "onnx.Conv"(%11, %3, %0) {auto_pad = "SAME_UPPER", device = "nnpa", dilations = [1, 1], group = 1 : si64, kernel_shape = [5, 5], strides = [1, 1]} : (tensor<1x8x14x14xf32>, tensor<16x8x5x5xf32>, tensor<16xf32>) -> tensor<1x16x14x14xf32> + %13 = "onnx.Relu"(%12) {device = "nnpa", onnx_node_name = "ReLU114"} : (tensor<1x16x14x14xf32>) -> tensor<1x16x14x14xf32> + %14 = "onnx.MaxPoolSingleOut"(%13) {auto_pad = "NOTSET", ceil_mode = 0 : si64, device = "nnpa", kernel_shape = [3, 3], onnx_node_name = "Pooling160", pads = [0, 0, 0, 0], storage_order = 0 : si64, strides = [3, 3]} : (tensor<1x16x14x14xf32>) -> tensor<1x16x4x4xf32> + %15 = "onnx.Reshape"(%14, %5) {allowzero = 0 : si64, onnx_node_name = "Times212_reshape0"} : (tensor<1x16x4x4xf32>, tensor<2xi64>) -> tensor<1x256xf32> + %16 = "onnx.Gemm"(%15, %8, %7) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, device = "cpu", transA = 0 : si64, transB = 0 : si64} : (tensor<1x256xf32>, tensor<256x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> + return %16 : tensor<1x10xf32> + } + "onnx.EntryPoint"() {func = @mnist} : () -> () + +// CHECK-LABEL: 
func.func @mnist +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x28x28xf32>) -> tensor<1x10xf32> attributes {input_names = ["Input3"], output_names = ["Plus214_Output_0"]} { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense_resource<__elided__> : tensor<16x4x4x10xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense_resource<__elided__> : tensor<16x8x5x5xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense_resource<__elided__> : tensor<8x1x5x5xf32> +// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<[1, 256]> : tensor<2xi64> +// CHECK-DAG: [[VAR_4_:%.+]] = onnx.Constant dense<[256, 10]> : tensor<2xi64> +// CHECK-DAG: [[VAR_5_:%.+]] = onnx.Constant dense<{{.}}[-0.0448560268, 0.00779166119, 0.0681008175, 0.0299937408, -0.126409635, 0.14021875, -0.0552849025, -0.0493838154, 0.0843220502, -0.0545404144]{{.}}> : tensor<1x10xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Reshape"([[VAR_0_]], [[VAR_4_]]) {allowzero = 0 : si64, onnx_node_name = "Times212_reshape1"} : (tensor<16x4x4x10xf32>, tensor<2xi64>) -> tensor<256x10xf32> +// CHECK-DAG: [[VAR_7_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "NHWC"} : (tensor<1x1x28x28xf32>) -> tensor<1x28x28x1xf32, #zhigh.layout<{dataLayout = "NHWC"}>> +// CHECK-DAG: [[VAR_8_:%.+]] = "onnx.Transpose"([[VAR_2_]]) {perm = [2, 3, 1, 0]} : (tensor<8x1x5x5xf32>) -> tensor<5x5x1x8xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_9_:%.+]] = "zhigh.Stick"([[VAR_8_]]) {layout = "HWCK"} : (tensor<5x5x1x8xf32>) -> tensor<5x5x1x8xf32, #zhigh.layout<{dataLayout = "HWCK"}>> +// CHECK-DAG: [[VAR_10_:%.+]] = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<4096xi8>} : () -> tensor<8xf32, #zhigh.layout<{dataLayout = "1D"}>> +// CHECK: [[VAR_11_:%.+]] = "zhigh.Conv2D"([[VAR_7_]], [[VAR_9_]], [[VAR_10_]]) {act_func = "ACT_RELU", kernel_shape = [5, 5], padding_type = "SAME_PADDING", strides = [1, 1]} : (tensor<1x28x28x1xf32, #zhigh.layout<{dataLayout = "NHWC"}>>, tensor<5x5x1x8xf32, #zhigh.layout<{dataLayout = "HWCK"}>>, tensor<8xf32, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<1x28x28x8xf32, #zhigh.layout<{dataLayout = "NHWC"}>> +// CHECK-DAG: [[VAR_12_:%.+]] = "zhigh.MaxPool2D"([[VAR_11_]]) {kernel_shape = [2, 2], padding_type = "VALID_PADDING", strides = [2, 2]} : (tensor<1x28x28x8xf32, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x14x14x8xf32, #zhigh.layout<{dataLayout = "NHWC"}>> +// CHECK-DAG: [[VAR_13_:%.+]] = "onnx.Transpose"([[VAR_1_]]) {perm = [2, 3, 1, 0]} : (tensor<16x8x5x5xf32>) -> tensor<5x5x8x16xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_14_:%.+]] = "zhigh.Stick"([[VAR_13_]]) {layout = "HWCK"} : (tensor<5x5x8x16xf32>) -> tensor<5x5x8x16xf32, #zhigh.layout<{dataLayout = "HWCK"}>> +// CHECK-DAG: [[VAR_15_:%.+]] = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<4096xi8>} : () -> tensor<16xf32, #zhigh.layout<{dataLayout = "1D"}>> +// CHECK: [[VAR_16_:%.+]] = "zhigh.Conv2D"([[VAR_12_]], [[VAR_14_]], [[VAR_15_]]) {act_func = "ACT_RELU", kernel_shape = [5, 5], padding_type = "SAME_PADDING", strides = [1, 1]} : (tensor<1x14x14x8xf32, #zhigh.layout<{dataLayout = "NHWC"}>>, tensor<5x5x8x16xf32, #zhigh.layout<{dataLayout = "HWCK"}>>, tensor<16xf32, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<1x14x14x16xf32, #zhigh.layout<{dataLayout = "NHWC"}>> +// CHECK: [[VAR_17_:%.+]] = "zhigh.MaxPool2D"([[VAR_16_]]) {kernel_shape = [3, 3], padding_type = "VALID_PADDING", strides = [3, 3]} : 
(tensor<1x14x14x16xf32, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x4x4x16xf32, #zhigh.layout<{dataLayout = "NHWC"}>> +// CHECK: [[VAR_18_:%.+]] = "zhigh.Unstick"([[VAR_17_]]) : (tensor<1x4x4x16xf32, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x16x4x4xf32> +// CHECK: [[VAR_19_:%.+]] = "onnx.Reshape"([[VAR_18_]], [[VAR_3_]]) {allowzero = 0 : si64, onnx_node_name = "Times212_reshape0"} : (tensor<1x16x4x4xf32>, tensor<2xi64>) -> tensor<1x256xf32> +// CHECK: [[VAR_20_:%.+]] = "onnx.Gemm"([[VAR_19_]], [[VAR_6_]], [[VAR_5_]]) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, device = "cpu", transA = 0 : si64, transB = 0 : si64} : (tensor<1x256xf32>, tensor<256x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> +// CHECK: return [[VAR_20_]] : tensor<1x10xf32> +// CHECK: } +// CHECK: "onnx.EntryPoint"() {func = @mnist} : () -> () +} diff --git a/test/mlir/accelerators/nnpa/conversion/device-placement/lit.local.cfg b/test/mlir/accelerators/nnpa/conversion/device-placement/lit.local.cfg new file mode 100644 index 0000000000..ac7f7ec3e6 --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/device-placement/lit.local.cfg @@ -0,0 +1,6 @@ +if sys.byteorder == "little": + config.unsupported = True +else: + config.unsupported = False + +root = config.root diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/add-exec-cpu-opt.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/add-exec-cpu-opt.mlir index 1ef74665e1..451591f64b 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/add-exec-cpu-opt.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/add-exec-cpu-opt.mlir @@ -1,18 +1,18 @@ -// RUN: onnx-mlir-opt --maccel=NNPA --shape-inference --convert-onnx-to-zhigh=execNodesOnCpu=test/add0,test/add2 %s | FileCheck %s +// RUN: onnx-mlir-opt --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s | FileCheck %s func.func @test_add_force_cpu_opt(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> { - %0 = "onnx.Add"(%arg0, %arg1) {onnx_node_name = "test/add0"} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> + %0 = "onnx.Add"(%arg0, %arg1) {device = "cpu", onnx_node_name = "test/add0"} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> %1 = "onnx.Add"(%0, %arg0) {onnx_node_name = "test/add1"} : (tensor<*xf32>, tensor<10x10xf32>) -> tensor<*xf32> - %2 = "onnx.Add"(%1, %arg1) {onnx_node_name = "test/add2"} : (tensor<*xf32>, tensor<10x10xf32>) -> tensor<*xf32> + %2 = "onnx.Add"(%1, %arg1) {device = "cpu", onnx_node_name = "test/add2"} : (tensor<*xf32>, tensor<10x10xf32>) -> tensor<*xf32> "func.return"(%2) : (tensor<*xf32>) -> () // CHECK-LABEL: func @test_add_force_cpu_opt - // CHECK: "onnx.Add"({{.*}}, {{.*}}) {onnx_node_name = "test/add0"} + // CHECK: "onnx.Add"({{.*}}, {{.*}}) {device = "cpu", onnx_node_name = "test/add0"} // CHECK: "zhigh.Stick" // CHECK: "zhigh.Stick" // CHECK: "zhigh.Add" // CHECK: "zhigh.Unstick" - // CHECK: "onnx.Add"({{.*}}, {{.*}}) {onnx_node_name = "test/add2"} + // CHECK: "onnx.Add"({{.*}}, {{.*}}) {device = "cpu", onnx_node_name = "test/add2"} // CHECK: return } diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/add-exec-cpu.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/add-exec-cpu.mlir index 1133707547..4e1a7c0646 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/add-exec-cpu.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/add-exec-cpu.mlir @@ -1,14 +1,14 @@ -// RUN: onnx-mlir --maccel=NNPA --printIR --EmitZHighIR 
--execNodesOnCpu=test/add0,test/add2 -tag="test" %s | FileCheck %s +// RUN: onnx-mlir --maccel=NNPA --printIR --EmitZHighIR -tag="test" %s | FileCheck %s func.func @test_add_force_cpu(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> { - %0 = "onnx.Add"(%arg0, %arg1) {onnx_node_name = "test/add0"} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> + %0 = "onnx.Add"(%arg0, %arg1) {device = "cpu", onnx_node_name = "test/add0"} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> %1 = "onnx.Add"(%0, %arg0) {onnx_node_name = "test/add1"} : (tensor<*xf32>, tensor<10x10xf32>) -> tensor<*xf32> - %2 = "onnx.Add"(%1, %arg1) {onnx_node_name = "test/add2"} : (tensor<*xf32>, tensor<10x10xf32>) -> tensor<*xf32> + %2 = "onnx.Add"(%1, %arg1) {device = "cpu", onnx_node_name = "test/add2"} : (tensor<*xf32>, tensor<10x10xf32>) -> tensor<*xf32> "onnx.Return"(%2) : (tensor<*xf32>) -> () // CHECK-LABEL: func @test_add_force_cpu - // CHECK: "onnx.Add"({{.*}}, {{.*}}) {onnx_node_name = "test/add0"} - // CHECK: "onnx.Add"({{.*}}, {{.*}}) {onnx_node_name = "test/add2"} + // CHECK: "onnx.Add"({{.*}}, {{.*}}) {device = "cpu", onnx_node_name = "test/add0"} + // CHECK: "onnx.Add"({{.*}}, {{.*}}) {device = "cpu", onnx_node_name = "test/add2"} // CHECK: return }
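
Note (illustrative, not part of the patch): with --execNodesOnCpu removed, pinning an op to the host is now expressed with the `device` attribute directly in the ONNX IR, and the device-placement pass fills in the attribute for the remaining ops. A minimal sketch of the new workflow, mirroring the RUN line of device_placement_pass.mlir above; the function and op names below are made up for illustration:

// RUN: onnx-mlir-opt --device-placement --maccel=NNPA example.mlir
func.func @example(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> {
  // Pinned to the host by the user: the pass respects an existing device attribute.
  %0 = "onnx.Add"(%arg0, %arg1) {device = "cpu"} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32>
  // Left unannotated: the pass may add device = "nnpa" if the op is zDNN-suitable.
  %1 = "onnx.Relu"(%0) : (tensor<*xf32>) -> tensor<*xf32>
  "onnx.Return"(%1) : (tensor<*xf32>) -> ()
}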