From 74b9709fbb03daca05e880ba6d5d10be1317ed79 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Sun, 28 Apr 2024 00:58:50 +0800 Subject: [PATCH] [Torch] emit aten.ne.str and add folder (#3242) --- lib/Dialect/Torch/IR/TorchOps.cpp | 41 +++++++++++++++++++++++++++- test/Dialect/Torch/canonicalize.mlir | 41 ++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 29911961dc06..65ecf88c52a2 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -4715,6 +4715,45 @@ LogicalResult AtenPermuteOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// PrimsConvertElementTypeOp +//===----------------------------------------------------------------------===// + +OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) { + auto inputType = cast(getA().getType()); + auto outputType = cast(getResult().getType()); + if (inputType != outputType) + return nullptr; + if (!inputType.hasDtype() || !outputType.hasDtype()) + return nullptr; + if (inputType.getDtype() != outputType.getDtype()) + return nullptr; + return getA(); +} + +//===----------------------------------------------------------------------===// +// AtenMaxPool2dWithIndicesOp +//===----------------------------------------------------------------------===// + +void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add(+[](AtenMaxPool2dWithIndicesOp op, PatternRewriter &rewriter) { + if (!op.getResult1().use_empty()) { + return rewriter.notifyMatchFailure( + op, "result1 of MaxPool2dWithIndices should be unused"); + } + + Value result = rewriter.create( + op->getLoc(), op.getResult0().getType(), op.getSelf(), + op.getKernelSize(), op.getStride(), op.getPadding(), op.getDilation(), + op.getCeilMode()); + + op.getResult0().replaceAllUsesWith(result); + rewriter.eraseOp(op); + return success(); + }); +} + //===----------------------------------------------------------------------===// // AtenLinalgCrossOp //===----------------------------------------------------------------------===// @@ -4967,4 +5006,4 @@ LogicalResult InitializeGlobalSlotsOp::verify() { if (getInitialValues().size() != getSlotSymNames().size()) return emitOpError("expected number of operands to match number of slots"); return success(); -} +} \ No newline at end of file diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index a317e4011b3e..94393104ae91 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -2974,3 +2974,44 @@ func.func @aten_log$fold_splat_f32() -> !torch.vtensor<[4], f32> { %result = torch.aten.log %cst : !torch.vtensor<[4], f32> -> !torch.vtensor<[4], f32> return %result : !torch.vtensor<[4], f32> } + +// ----- + +// CHECK-LABEL: func.func @torch.prims.convert_element_type$fold( +// CHECK: %[[ARG:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],f32> { +// CHECK: return %[[ARG]] : !torch.vtensor<[64],f32> +func.func @torch.prims.convert_element_type$fold(%arg0: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],f32> { + %int6 = torch.constant.int 6 + %0 = torch.prims.convert_element_type %arg0, %int6 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + return %0 : !torch.vtensor<[64],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.prims.convert_element_type$no_fold( +// CHECK: %[[ARG:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],si32> { +// CHECK: %[[RET:.*]] = torch.prims.convert_element_type %[[ARG]], %{{.*}} : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],si32> +// CHECK: return %[[RET]] : !torch.vtensor<[64],si32> +func.func @torch.prims.convert_element_type$no_fold(%arg0: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],si32> { + %int6 = torch.constant.int 6 + %0 = torch.prims.convert_element_type %arg0, %int6 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],si32> + return %0 : !torch.vtensor<[64],si32> +} + +// ----- + +// CHECK-LABEL: @torch.aten.max_pool2d_with_indices$canonicalize( +// CHECK: %[[ARG:.*]]: !torch.vtensor<[10,64,112,112],f32>) -> !torch.vtensor<[10,64,56,56],f32> { +// CHECK: %[[RET:.*]] = torch.aten.max_pool2d %[[ARG]] +// CHECK: return %[[RET]] : !torch.vtensor<[10,64,56,56],f32> +func.func @torch.aten.max_pool2d_with_indices$canonicalize(%arg0: !torch.vtensor<[10,64,112,112],f32>) -> !torch.vtensor<[10,64,56,56],f32> { + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %29 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list + %30 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list + %31 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %result0, %result1 = torch.aten.max_pool2d_with_indices %arg0, %29, %30, %31, %31, %false : !torch.vtensor<[10,64,112,112],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[10,64,56,56],f32>, !torch.vtensor<[10,64,56,56],si64> + return %result0 : !torch.vtensor<[10,64,56,56],f32> +} \ No newline at end of file