diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index cc21f2155e46..5354ca2339db 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1201,6 +1201,21 @@ class DecomposeAtenIsposinfOp : public OpRewritePattern { }; } // namespace +namespace { +class DecomposeAtenNonzeroOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenNonzeroOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value zeroScalar = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), + zeroScalar); + return success(); + } +}; +} // namespace namespace { class DecomposeAtenReshapeOp : public OpRewritePattern { public: @@ -7740,10 +7755,13 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + // is-xxx ops addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index 530160f990ae..39b150339392 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -78,3 +78,10 @@ func.func @torch.aten.type_as$fold(%arg0: !torch.tensor<[?], f16>, %arg1: !torch %0 = torch.aten.type_as %arg0, %arg1 : !torch.tensor<[?], f16>, !torch.tensor<[?,?],f16> -> !torch.tensor<[?], f16> return %0 : !torch.tensor<[?], f16> } + +// ----- +// CHECK-LABEL: func.func @torch.aten.nonzero +func.func @torch.aten.nonzero(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> { + %0 = torch.aten.nonzero %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],si64> + return %0 : !torch.vtensor<[3,4,5],si64> +}