diff --git a/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp b/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp index 6dfe458bec..406c44cb91 100644 --- a/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp +++ b/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp @@ -133,6 +133,12 @@ class StickViewUnstickRemovalPattern : public OpRewritePattern { ZLowStickOp stickOp, PatternRewriter &rewriter) const override { Value stickInput = stickOp.getX(); + // Do not handle NCHW layout stickification that transposes data + // internally. + std::string stickLayout = stickOp.getLayout().value().str(); + if (stickLayout == LAYOUT_NCHW) + return failure(); + // Input is a block argument, ignore it. if (stickInput.dyn_cast()) return failure(); @@ -157,6 +163,11 @@ class StickViewUnstickRemovalPattern : public OpRewritePattern { ZLowUnstickOp userOp = llvm::dyn_cast(user); if (!userOp) continue; + // Do not handle NCHW layout stickification that transposes data + // internally. + std::string unstickLayout = userOp.getLayout().value().str(); + if (unstickLayout == LAYOUT_NCHW) + continue; // UnstickOp must be before the view operation. if (userOp.getOut() == viewSource && user->isBeforeInBlock(viewOp.getOperation())) { diff --git a/src/Builder/OpBuildTable.inc b/src/Builder/OpBuildTable.inc index 5867fc7e1a..e6b1ac4d38 100644 --- a/src/Builder/OpBuildTable.inc +++ b/src/Builder/OpBuildTable.inc @@ -157,7 +157,7 @@ op_dialect_version_map_["ReduceSum"] = {13, 11}; op_dialect_version_map_["ReduceSumSquare"] = {18, 13}; op_dialect_version_map_["Relu"] = {14}; op_dialect_version_map_["Reshape"] = {19}; -op_dialect_version_map_["Resize"] = {18, 13, 11, 10}; +op_dialect_version_map_["Resize"] = {19, 13, 11, 10}; op_dialect_version_map_["ReverseSequence"] = {10}; op_dialect_version_map_["RoiAlign"] = {16}; op_dialect_version_map_["Round"] = {11}; diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index d75ee6b66d..77ceec013a 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -6967,8 +6967,10 @@ def ONNXResizeOp:ONNX_Op<"Resize", let summary = "ONNX Resize operation"; let description = [{ Resize the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood (a.k.a. sampling locations) in the input tensor. - Each dimension value of the output tensor is:
- `output_dimension = floor(input_dimension * (roi_end - roi_start) * scale)`
+ Each dimension value of the output tensor is: + ``` + output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) + ``` if input \\"sizes\\" is not specified. }]; let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$X, diff --git a/test/mlir/accelerators/nnpa/transform/zlow-rewrite.mlir b/test/mlir/accelerators/nnpa/transform/zlow-rewrite.mlir index 82b678412c..a5a3488c7e 100644 --- a/test/mlir/accelerators/nnpa/transform/zlow-rewrite.mlir +++ b/test/mlir/accelerators/nnpa/transform/zlow-rewrite.mlir @@ -137,6 +137,22 @@ func.func @test_remove_unstick_view_stick(%arg0: memref<7x4x1x8x32x64xf16>) -> ( // ----- +func.func @test_should_not_remove_unstick_view_stick_nchw(%arg0: memref<1x1x1x1x32x64xf16>) -> (memref<1x1x1x1x32x64xf16>){ + %0 = memref.alloc() {alignment = 16 : i64} : memref<1x32x1x22xf32> + "zlow.unstick"(%arg0, %0) {layout = "NCHW"} : (memref<1x1x1x1x32x64xf16>, memref<1x32x1x22xf32>) -> () + %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [1, 32, 22], strides: [704, 22, 1] : memref<1x32x1x22xf32> to memref<1x32x22xf32> + %2 = memref.alloc() {alignment = 4096 : i64} : memref<1x1x1x1x32x64xf16> + "zlow.stick"(%1, %2) {layout = "3DS"} : (memref<1x32x22xf32>, memref<1x1x1x1x32x64xf16>) -> () + "func.return"(%2) : (memref<1x1x1x1x32x64xf16>) -> () + + // CHECK-LABEL: test_should_not_remove_unstick_view_stick_nchw + // CHECK: "zlow.unstick" + // CHECK: memref.reinterpret_cast + // CHECK: "zlow.stick" +} + +// ----- + // Remove zlow.stick and zlow.unstick in pattern: unstick -> transpose -> stick. // Test a simple transpose. diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index 95f4bb2253..488038eb62 100755 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -226,7 +226,7 @@ 'ReduceSumSquare': [18, 13], 'Relu': [14], 'Reshape': [19], - 'Resize': [18, 13, 11, 10], + 'Resize': [19, 13, 11, 10], 'ReverseSequence': [10], 'RoiAlign': [16], 'Round': [11],