From c9a52cbcc6289c2f5c180393be1a77c6c99005bd Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Mon, 10 Jun 2024 14:52:56 +0000 Subject: [PATCH 1/2] Fix resize ceil numerics and add support for half_pixel_symmetric --- .../TorchToLinalg/Uncategorized.cpp | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index b6fc225c42fe..0f3b26901946 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -2647,14 +2647,21 @@ static Value NearestInterpolate(OpBuilder &b, Location loc, nearestFP = b.create(loc, cmp, floor, ceil); } else if (nearestMode == "round_prefer_ceil") { Value cstHalf = b.create(loc, b.getF32FloatAttr(0.5)); + Value cstOne = b.create(loc, b.getF32FloatAttr(1)); Value floor = b.create(loc, proj); Value ceil = b.create(loc, proj); Value decimal = b.create(loc, proj, floor); Value cmp = b.create(loc, arith::CmpFPredicate::UGE, decimal, cstHalf); nearestFP = b.create(loc, cmp, ceil, floor); + Value inputSizeMOne = b.create(loc, inputSizeFP, cstOne); + // don't extract out of bounds + nearestFP = b.create(loc, nearestFP, inputSizeMOne); } else if (nearestMode == "ceil") { + Value cstOne = b.create(loc, b.getF32FloatAttr(1)); + Value inputSizeMOne = b.create(loc, inputSizeFP, cstOne); nearestFP = b.create(loc, proj); + nearestFP = b.create(loc, nearestFP, inputSizeMOne); } Value nearestInt = b.create(loc, b.getI64Type(), nearestFP); @@ -2726,7 +2733,8 @@ static Value BilinearInterpolate(OpBuilder &b, if (coordStr == "_asymmetric") { preClip = b.create(loc, outFP, scale); } - if (coordStr == "_pytorch_half_pixel" || coordStr == "") { + if (coordStr == "_pytorch_half_pixel" || coordStr == "" || + coordStr == "_half_pixel_symmetric") { // half-pixel modes // y_resized + 0.5 Value outPlusHalf = b.create(loc, outFP, cstHalf); @@ -2735,6 +2743,18 @@ static Value BilinearInterpolate(OpBuilder &b, // _ - 0.5 preClip = b.create(loc, outDivScale, cstHalf); } + // for half_pixel_symmetric, need to compute offset from raw scales + if (coordStr == "_half_pixel_symmetric" && !scaleValues.empty()) { + Value outputSizeFromScale = b.create(loc, inputFP, scale); + Value adjustment = + b.create(loc, outputSizeFP, outputSizeFromScale); + Value cstTwo = b.create(loc, b.getF32FloatAttr(2.0)); + Value center = b.create(loc, inputFP, cstTwo); + Value oneMAdjustment = + b.create(loc, cstOneFloat, adjustment); + Value offset = b.create(loc, center, oneMAdjustment); + preClip = b.create(loc, offset, preClip); + } // for pytorch half pixel , special case for length_resized == 1: if (coordStr == "_pytorch_half_pixel") { Value cmp = b.create(loc, arith::CmpFPredicate::UEQ, From 2ecb77ebe185e6079ae4503ef5db1924ea81907d Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Tue, 11 Jun 2024 17:50:00 +0000 Subject: [PATCH 2/2] add some lit tests --- test/Conversion/TorchToLinalg/resize.mlir | 84 ++++++++++++++++++++++- 1 file changed, 83 insertions(+), 1 deletion(-) diff --git a/test/Conversion/TorchToLinalg/resize.mlir b/test/Conversion/TorchToLinalg/resize.mlir index 6847d25736f1..64198d03f2a1 100644 --- a/test/Conversion/TorchToLinalg/resize.mlir +++ b/test/Conversion/TorchToLinalg/resize.mlir @@ -156,7 +156,89 @@ func.func @test_resize_nearest_3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: return %7 : !torch.vtensor<[?,?,?,?,?],f32> } -// CHECK-LABEL: func.func @test_resize_nearest_half_pixel +// ----- + +// CHECK-LABEL: func.func @test_resize_nearest_ceil +func.func @test_resize_nearest_ceil(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],f32> { + // CHECK: %[[GENERIC:.*]] = linalg.generic + // CHECK: %[[x11:.*]] = linalg.index 0 : index + // CHECK: %[[x12:.*]] = linalg.index 1 : index + // CHECK: %[[x13:.*]] = linalg.index 2 : index + // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32 + // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 + // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32 + // CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64 + // CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32 + // CHECK: %[[cst:.*]] = arith.constant 5.000000e-01 : f32 + // CHECK: %[[add:.*]] = arith.addf %[[x24]], %[[cst]] : f32 + // CHECK: %[[x25:.*]] = arith.divf %[[add]], %[[x21]] : f32 + // CHECK: %[[sub:.*]] = arith.subf %[[x25]], %[[cst]] : f32 + // CHECK: %[[cst3:.*]] = arith.constant 1.000000e+00 : f32 + // CHECK: %[[nM1:.*]] = arith.subf %[[inputsizefp:.*]], %[[cst3]] + // CHECK: %[[ceil:.*]] = math.ceil %[[sub]] : f32 + // CHECK: %[[minindex:.*]] = arith.minimumf %[[ceil]], %[[nM1]] + // CHECK: %[[x31:.*]] = arith.fptosi %[[minindex]] : f32 to i64 + // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index + // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x11]], %[[x12]], %[[x32]]] : tensor + // CHECK: linalg.yield %[[extracted]] : f32 + %none = torch.constant.none + %none_0 = torch.constant.none + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %true = torch.constant.bool true + %str = torch.constant.str "nearest_half_pixel,ceil" + %int2 = torch.constant.int 2 + %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + %4 = torch.prim.ListConstruct %1 : (!torch.int) -> !torch.list + %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[?,?,?],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?],f32> + return %5 : !torch.vtensor<[?,?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_resize_scales_linear_half_pixel_symmetric +func.func @test_resize_scales_linear_half_pixel_symmetric(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4] +,f64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[generic:.*]] = linalg.generic + // CHECK: %[[cst7:.*]] = arith.constant 2.0 + // CHECK: %[[halfsize:.*]] = arith.divf %[[sizefp:.*]], %[[cst7]] + // CHECK: %[[modifier:.*]] = arith.subf %[[cstOne:.*]], %[[adjustment:.*]] + // CHECK: %[[offset:.*]] = arith.mulf %[[halfsize]], %[[modifier]] + // CHECK: %[[preClip:.*]] = arith.addf %[[offset]], %[[halfpixelbase:.*]] + // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x1:.*]], %[[x2:.*]], %[[x3:.*]], %[[x4:.*]]] : tensor<1x1x2x4xf32> + // CHECK: %[[extracted_7:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]] + // CHECK: %[[extracted_8:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]] + // CHECK: %[[extracted_9:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]] + // CHECK: %[[dx0p00:.*]] = arith.mulf %[[dx0:.*]], %[[extracted]] + // CHECK: %[[dx1p01:.*]] = arith.mulf %[[dx1:.*]], %[[extracted_7]] + // CHECK: %[[sum:.*]] = arith.addf %[[dx0p00]], %[[dx1p01]] + // CHECK: %[[left:.*]] = arith.mulf %[[dy0:.*]], %[[sum]] + // CHECK: %[[dx0p10:.*]] = arith.mulf %[[dx0]], %[[extracted_8]] + // CHECK: %[[dx1p11:.*]] = arith.mulf %[[dx1]], %[[extracted_9]] + // CHECK: %[[sum2:.*]] = arith.addf %[[dx0p10]], %[[dx1p11]] + // CHECK: %[[right:.*]] = arith.mulf %[[dy1:.*]], %[[sum2]] + // CHECK: %[[retval:.*]] = arith.addf %[[left]], %[[right]] + %none = torch.constant.none + %none_0 = torch.constant.none + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %true = torch.constant.bool true + %str = torch.constant.str "bilinear_half_pixel_symmetric" + %int2 = torch.constant.int 2 + %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],f64>, !torch.int, !torch.int -> !torch.vtensor<[1],f64> + %1 = torch.aten.item %0 : !torch.vtensor<[1],f64> -> !torch.float + %int3 = torch.constant.int 3 + %2 = torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],f64>, !torch.int, !torch.int -> !torch.vtensor<[1],f64> + %3 = torch.aten.item %2 : !torch.vtensor<[1],f64> -> !torch.float + %4 = torch.prim.ListConstruct %1, %3 : (!torch.float, !torch.float) -> !torch.list + %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %none_0, %4, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.list, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + return %5 : !torch.vtensor<[?,?,?,?],f32> + } + +// ----- + +// CHECK-LABEL: func.func @test_resize_nearest_half_pixel_round_prefer_floor func.func @test_resize_nearest_half_pixel_round_prefer_floor(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],f32> { // CHECK: %[[GENERIC:.*]] = linalg.generic // CHECK: %[[x11:.*]] = linalg.index 0 : index