Skip to content

Commit

Permalink
Merge branch 'main' into constantof-shape-related-canonicalize
Browse files Browse the repository at this point in the history
  • Loading branch information
tungld authored Sep 12, 2023
2 parents e55713b + c69721b commit 181bcae
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 4 deletions.
11 changes: 11 additions & 0 deletions src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ class StickViewUnstickRemovalPattern : public OpRewritePattern<ZLowStickOp> {
ZLowStickOp stickOp, PatternRewriter &rewriter) const override {
Value stickInput = stickOp.getX();

// Do not handle NCHW layout stickification that transposes data
// internally.
std::string stickLayout = stickOp.getLayout().value().str();
if (stickLayout == LAYOUT_NCHW)
return failure();

// Input is a block argument, ignore it.
if (stickInput.dyn_cast<BlockArgument>())
return failure();
Expand All @@ -157,6 +163,11 @@ class StickViewUnstickRemovalPattern : public OpRewritePattern<ZLowStickOp> {
ZLowUnstickOp userOp = llvm::dyn_cast<ZLowUnstickOp>(user);
if (!userOp)
continue;
// Do not handle NCHW layout stickification that transposes data
// internally.
std::string unstickLayout = userOp.getLayout().value().str();
if (unstickLayout == LAYOUT_NCHW)
continue;
// UnstickOp must be before the view operation.
if (userOp.getOut() == viewSource &&
user->isBeforeInBlock(viewOp.getOperation())) {
Expand Down
2 changes: 1 addition & 1 deletion src/Builder/OpBuildTable.inc
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ op_dialect_version_map_["ReduceSum"] = {13, 11};
op_dialect_version_map_["ReduceSumSquare"] = {18, 13};
op_dialect_version_map_["Relu"] = {14};
op_dialect_version_map_["Reshape"] = {19};
op_dialect_version_map_["Resize"] = {18, 13, 11, 10};
op_dialect_version_map_["Resize"] = {19, 13, 11, 10};
op_dialect_version_map_["ReverseSequence"] = {10};
op_dialect_version_map_["RoiAlign"] = {16};
op_dialect_version_map_["Round"] = {11};
Expand Down
6 changes: 4 additions & 2 deletions src/Dialect/ONNX/ONNXOps.td.inc
Original file line number Diff line number Diff line change
Expand Up @@ -6967,8 +6967,10 @@ def ONNXResizeOp:ONNX_Op<"Resize",
let summary = "ONNX Resize operation";
let description = [{
Resize the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood (a.k.a. sampling locations) in the input tensor.
Each dimension value of the output tensor is: <br/>
`output_dimension = floor(input_dimension * (roi_end - roi_start) * scale)` <br/>
Each dimension value of the output tensor is:
```
output_dimension = floor(input_dimension * (roi_end - roi_start) * scale)
```
if input \\"sizes\\" is not specified.
}];
let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>]>:$X,
Expand Down
16 changes: 16 additions & 0 deletions test/mlir/accelerators/nnpa/transform/zlow-rewrite.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,22 @@ func.func @test_remove_unstick_view_stick(%arg0: memref<7x4x1x8x32x64xf16>) -> (

// -----

func.func @test_should_not_remove_unstick_view_stick_nchw(%arg0: memref<1x1x1x1x32x64xf16>) -> (memref<1x1x1x1x32x64xf16>){
%0 = memref.alloc() {alignment = 16 : i64} : memref<1x32x1x22xf32>
"zlow.unstick"(%arg0, %0) {layout = "NCHW"} : (memref<1x1x1x1x32x64xf16>, memref<1x32x1x22xf32>) -> ()
%1 = memref.reinterpret_cast %0 to offset: [0], sizes: [1, 32, 22], strides: [704, 22, 1] : memref<1x32x1x22xf32> to memref<1x32x22xf32>
%2 = memref.alloc() {alignment = 4096 : i64} : memref<1x1x1x1x32x64xf16>
"zlow.stick"(%1, %2) {layout = "3DS"} : (memref<1x32x22xf32>, memref<1x1x1x1x32x64xf16>) -> ()
"func.return"(%2) : (memref<1x1x1x1x32x64xf16>) -> ()

// CHECK-LABEL: test_should_not_remove_unstick_view_stick_nchw
// CHECK: "zlow.unstick"
// CHECK: memref.reinterpret_cast
// CHECK: "zlow.stick"
}

// -----

// Remove zlow.stick and zlow.unstick in pattern: unstick -> transpose -> stick.
// Test a simple transpose.

Expand Down
2 changes: 1 addition & 1 deletion utils/gen_onnx_mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@
'ReduceSumSquare': [18, 13],
'Relu': [14],
'Reshape': [19],
'Resize': [18, 13, 11, 10],
'Resize': [19, 13, 11, 10],
'ReverseSequence': [10],
'RoiAlign': [16],
'Round': [11],
Expand Down

0 comments on commit 181bcae

Please sign in to comment.