diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 168040d9b289..a7f707cae9bb 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -1087,9 +1087,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) return rewriter.notifyMatchFailure(binder.op, "auto_pad bind failure"); - if (autoPad != "NOTSET") - return rewriter.notifyMatchFailure( - binder.op, "unsupported conversion: auto_pad != NOTSET"); Torch::ValueTensorType resultTypeOut; Value operand; @@ -1136,6 +1133,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return rewriter.notifyMatchFailure(binder.op, "dilations bind failure"); + // set default padding if (padding.empty()) padding.resize(spatial, 0); if (strides.empty()) @@ -1143,6 +1141,34 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( if (dilations.empty()) dilations.resize(spatial, 1); + auto inputTensorType = cast(operand.getType()); + + // Padding for the beginning and ending along each spatial axis, it can + // take any value greater than or equal to 0. The value represent the + // number of pixels added to the beginning and end part of the + // corresponding axis. pads format should be as follow [x1_begin, + // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added + // at the beginning of axis i and xi_end, the number of pixels added at + // the end of axis i. + if (autoPad != "NOTSET" && autoPad != "VALID") { + const bool isSameLower = autoPad == "SAME_LOWER"; + ArrayRef inputShape = inputTensorType.getSizes(); + padding.resize_for_overwrite(2 * spatial); + for (unsigned dimIdx = 0; dimIdx < spatial; dimIdx++) { + const int64_t dilatedKernelSize = + dilations[dimIdx] * (kernel[dimIdx] - 1) + 1; + int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) / + strides[dimIdx] - + 1) * + strides[dimIdx] + + dilatedKernelSize - inputShape[dimIdx + 2]; + totalPad = totalPad >= 0 ? totalPad : 0; + padding[dimIdx] = + isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2); + padding[spatial + dimIdx] = totalPad - padding[dimIdx]; + } + } + // If the padding is symmetric we can push the padding operation to the // torch operator. if (padding.size() == static_cast(2 * spatial)) { diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 21be2a65f4a6..d567db79fdf8 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -730,6 +730,86 @@ func.func @test_maxpool_pad(%arg0: !torch.vtensor<[1,64,111,111],f32>) -> !torch return %0 : !torch.vtensor<[1,64,56,56],f32> } +// ----- + +// CHECK-LABEL: func.func @test_maxpool_2d_same_lower +func.func @test_maxpool_2d_same_lower(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} { + // CHECK: %[[int1:.*]] = torch.constant.int 1 + // CHECK: %[[int0:.*]] = torch.constant.int 0 + // CHECK: %[[int1_0:.*]] = torch.constant.int 1 + // CHECK: %[[int0_1:.*]] = torch.constant.int 0 + // CHECK: %[[list0:.*]] = torch.prim.ListConstruct %[[int1]], %[[int0]], %[[int1_0]], %[[int0_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[FLOAT0:.*]] = torch.constant.float -1.7976931348623157E+308 + // CHECK: %[[FUNC1:.*]] = torch.aten.constant_pad_nd %arg0, %[[list0]], %[[FLOAT0]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list, !torch.float -> !torch.vtensor<[1,3,33,33],f32> + // CHECK: %[[int2:.*]] = torch.constant.int 2 + // CHECK: %[[int2_2:.*]] = torch.constant.int 2 + // CHECK: %[[list1:.*]] = torch.prim.ListConstruct %[[int2]], %[[int2_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int0_3:.*]] = torch.constant.int 0 + // CHECK: %[[int0_4:.*]] = torch.constant.int 0 + // CHECK: %[[list2:.*]] = torch.prim.ListConstruct %[[int0_3]], %[[int0_4]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int1_5:.*]] = torch.constant.int 1 + // CHECK: %[[int1_6:.*]] = torch.constant.int 1 + // CHECK: %[[list3:.*]] = torch.prim.ListConstruct %[[int1_5]], %[[int1_6]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int1_7:.*]] = torch.constant.int 1 + // CHECK: %[[int1_8:.*]] = torch.constant.int 1 + // CHECK: %[[list4:.*]] = torch.prim.ListConstruct %[[int1_7]], %[[int1_8]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[FUNC6:.*]] = torch.aten.max_pool2d %[[FUNC1]], %[[list1]], %[[list3]], %[[list2]], %[[list4]], %[[FALSE]] : !torch.vtensor<[1,3,33,33],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,3,32,32],f32> + %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.auto_pad = "SAME_LOWER", torch.onnx.kernel_shape = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32> + return %0 : !torch.vtensor<[1,3,32,32],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_maxpool_2d_same_upper +func.func @test_maxpool_2d_same_upper(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[int0:.*]] = torch.constant.int 0 + // CHECK: %[[int1:.*]] = torch.constant.int 1 + // CHECK: %[[int0_0:.*]] = torch.constant.int 0 + // CHECK: %[[int1_1:.*]] = torch.constant.int 1 + // CHECK: %[[list0:.*]] = torch.prim.ListConstruct %[[int0]], %[[int1]], %[[int0_0]], %[[int1_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[FLOAT0:.*]] = torch.constant.float -1.7976931348623157E+308 + // CHECK: %[[FUNC1:.*]] = torch.aten.constant_pad_nd %arg0, %[[list0]], %[[FLOAT0]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list, !torch.float -> !torch.vtensor<[1,3,33,33],f32> + // CHECK: %[[int2:.*]] = torch.constant.int 2 + // CHECK: %[[int2_2:.*]] = torch.constant.int 2 + // CHECK: %[[list1:.*]] = torch.prim.ListConstruct %[[int2]], %[[int2_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int0_3:.*]] = torch.constant.int 0 + // CHECK: %[[int0_4:.*]] = torch.constant.int 0 + // CHECK: %[[list2:.*]] = torch.prim.ListConstruct %[[int0_3]], %[[int0_4]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int1_5:.*]] = torch.constant.int 1 + // CHECK: %[[int1_6:.*]] = torch.constant.int 1 + // CHECK: %[[list3:.*]] = torch.prim.ListConstruct %[[int1_5]], %[[int1_6]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int1_7:.*]] = torch.constant.int 1 + // CHECK: %[[int1_8:.*]] = torch.constant.int 1 + // CHECK: %[[list4:.*]] = torch.prim.ListConstruct %[[int1_7]], %[[int1_8]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[FUNC6:.*]] = torch.aten.max_pool2d %[[FUNC1]], %[[list1]], %[[list3]], %[[list2]], %[[list4]], %[[FALSE]] : !torch.vtensor<[1,3,33,33],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,3,32,32],f32> + %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.auto_pad = "SAME_UPPER", torch.onnx.kernel_shape = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32> + return %0 : !torch.vtensor<[1,3,32,32],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_maxpool_2d_precomputed_same_upper +func.func @test_maxpool_2d_precomputed_same_upper(%arg0: !torch.vtensor<[1,1,5,5],f32>) -> !torch.vtensor<[1,1,3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64}{ + // CHECK: %[[int3:.*]] = torch.constant.int 3 + // CHECK: %[[int3_0:.*]] = torch.constant.int 3 + // CHECK: %[[list0:.*]] = torch.prim.ListConstruct %[[int3]], %[[int3_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int1:.*]] = torch.constant.int 1 + // CHECK: %[[int1_1:.*]] = torch.constant.int 1 + // CHECK: %[[list1:.*]] = torch.prim.ListConstruct %[[int1]], %[[int1_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int2:.*]] = torch.constant.int 2 + // CHECK: %[[int2_2:.*]] = torch.constant.int 2 + // CHECK: %[[list2:.*]] = torch.prim.ListConstruct %[[int2]], %[[int2_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int1_3:.*]] = torch.constant.int 1 + // CHECK: %[[int1_4:.*]] = torch.constant.int 1 + // CHECK: %[[list3:.*]] = torch.prim.ListConstruct %[[int1_3]], %[[int1_4]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[FUNC4:.*]] = torch.aten.max_pool2d %arg0, %[[list0]], %[[list2]], %[[list1]], %[[list3]], %[[FALSE]] : !torch.vtensor<[1,1,5,5],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,1,3,3],f32> +%0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.auto_pad = "SAME_UPPER", torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,5,5],f32>) -> !torch.vtensor<[1,1,3,3],f32> +return %0 : !torch.vtensor<[1,1,3,3],f32> +} + // -----