From 71487e6e98dffae8c0e729c3532f0f0fd628b89c Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 15 Oct 2024 16:18:57 +0000 Subject: [PATCH] OnnxToTorch bicubic interpolation --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 12 +- .../TorchToLinalg/Uncategorized.cpp | 242 ++++++++++++++++-- 2 files changed, 234 insertions(+), 20 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 68868e95c385..6ba2a5887eea 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2914,7 +2914,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( llvm::SmallVector operands; std::string mode, nearest_mode, coordTfMode; int64_t antialias, exclude_outside; - float extrapolation_value; + float extrapolation_value, cubic_coeff_a; Value noneVal = rewriter.create(binder.getLoc()); if (auto attr = binder.op->getAttr("torch.onnx.axes")) { @@ -2939,7 +2939,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.f32FloatAttr(extrapolation_value, "extrapolation_value", 0.0) || binder.customOpNameStringAttr(nearest_mode, "nearest_mode", - "round_prefer_floor")) + "round_prefer_floor") || + binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75)) return failure(); if (antialias != 0) { return rewriter.notifyMatchFailure( @@ -2983,8 +2984,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value alignCorners = coordTfMode == "align_corners" ? 
cstTrue : cstFalse; if (mode == "cubic") { - return rewriter.notifyMatchFailure(binder.op, - "unimplemented: bicubic mode"); + std::string modeStr = "cubic"; + if (coordTfMode != "half_pixel") + modeStr = modeStr + "_" + coordTfMode; + modeStrValue = + rewriter.create(binder.getLoc(), modeStr); } // supported modes: // bilinear (half_pixel), bilinear with align_corners, diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 7823138c9672..1174a90d7b78 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -2740,12 +2740,13 @@ static Value NearestInterpolate(OpBuilder &b, Location loc, return retVal; } -static Value BilinearInterpolate(OpBuilder &b, - Aten__InterpolateSizeListScaleListOp op, - Location loc, SmallVector outputSizes, - Value input, SmallVector inputSizes, - SmallVector scaleValues, - std::string coordStr) { +static SmallVector +CoordinateTransform(OpBuilder &b, Aten__InterpolateSizeListScaleListOp op, + Location loc, SmallVector outputSizes, Value input, + SmallVector inputSizes, + SmallVector scaleValues, std::string coordStr, + bool alignCornersBool, SmallVector indices) { + unsigned dimOffset = 2; auto inputType = cast(input.getType()); auto inputRank = inputType.getRank(); @@ -2754,15 +2755,7 @@ static Value BilinearInterpolate(OpBuilder &b, Value cstHalf = b.create(loc, b.getF32FloatAttr(0.5)); Value zero = b.create(loc, b.getF32FloatAttr(0.0)); - bool alignCornersBool; - matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool)); - - SmallVector indices; - for (unsigned i = 0; i < inputRank; i++) { - indices.push_back(b.create(loc, i)); - } - - SmallVector proj, projEps, high, low, highFP, lowFP; + SmallVector proj; for (unsigned i = 0; i < inputRank - dimOffset; i++) { // length_original Value inputFP = @@ -2832,6 +2825,40 @@ static Value BilinearInterpolate(OpBuilder &b, // clip to [0,length_original - 1]. 
// proj is properly within the input image. proj.push_back(b.create(loc, max, inputSubOne)); + } + return proj; +} + +static Value BilinearInterpolate(OpBuilder &b, + Aten__InterpolateSizeListScaleListOp op, + Location loc, SmallVector outputSizes, + Value input, SmallVector inputSizes, + SmallVector scaleValues, + std::string coordStr) { + unsigned dimOffset = 2; + auto inputType = cast(input.getType()); + auto inputRank = inputType.getRank(); + + Value cstOneFloat = b.create(loc, b.getF32FloatAttr(1.0)); + Value cstHalf = b.create(loc, b.getF32FloatAttr(0.5)); + Value zero = b.create(loc, b.getF32FloatAttr(0.0)); + + bool alignCornersBool; + matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool)); + + SmallVector indices; + for (unsigned i = 0; i < inputRank; i++) { + indices.push_back(b.create(loc, i)); + } + + SmallVector proj, high, low, highFP, lowFP; + proj = CoordinateTransform(b, op, loc, outputSizes, input, inputSizes, + scaleValues, coordStr, alignCornersBool, indices); + for (unsigned i = 0; i < inputRank - dimOffset; i++) { + // length_original + Value inputFP = + b.create(loc, b.getF32Type(), inputSizes[i]); + Value inputSubOne = b.create(loc, inputFP, cstOneFloat); // for bilinear interpolation, we look for the nearest indices below and // above proj @@ -2895,6 +2922,184 @@ static Value BilinearInterpolate(OpBuilder &b, return b.create(loc, left, right); } +static Value WeightFunction(OpBuilder &b, Location loc, Value xDistance) { + Value a = b.create(loc, b.getF32FloatAttr(-0.75)); + Value zero = b.create(loc, b.getF32FloatAttr(0.0)); + Value cstOneFloat = b.create(loc, b.getF32FloatAttr(1.0)); + Value cstTwoFloat = b.create(loc, b.getF32FloatAttr(2.0)); + Value cstThreeFloat = + b.create(loc, b.getF32FloatAttr(3.0)); + Value cstFourFloat = b.create(loc, b.getF32FloatAttr(4.0)); + Value cstFiveFloat = b.create(loc, b.getF32FloatAttr(5.0)); + Value cstEightFloat = + b.create(loc, b.getF32FloatAttr(8.0)); + + Value xDistanceSquared 
= b.create(loc, xDistance, xDistance); + Value xDistanceCubed = + b.create(loc, xDistanceSquared, xDistance); + Value lessThanTwo = b.create(loc, xDistanceCubed, a); + Value fiveA = b.create(loc, xDistanceSquared, a); + fiveA = b.create(loc, fiveA, cstFiveFloat); + lessThanTwo = b.create(loc, fiveA, lessThanTwo); + Value eightA = b.create(loc, a, xDistance); + eightA = b.create(loc, eightA, cstEightFloat); + lessThanTwo = b.create(loc, eightA, lessThanTwo); + Value fourA = b.create(loc, a, cstFourFloat); + lessThanTwo = b.create(loc, fourA, lessThanTwo); + + Value greaterthanOrEqualToTwo = zero; + + Value lessEqualOne = b.create(loc, a, cstTwoFloat); + lessEqualOne = b.create(loc, xDistanceCubed, lessEqualOne); + Value aPlusThree = b.create(loc, a, cstThreeFloat); + aPlusThree = b.create(loc, xDistanceSquared, aPlusThree); + lessEqualOne = b.create(loc, lessEqualOne, aPlusThree); + lessEqualOne = b.create(loc, lessEqualOne, cstOneFloat); + + Value cmp = b.create(loc, arith::CmpFPredicate::UGE, xDistance, + cstTwoFloat); + Value greaterThanOne = + b.create(loc, cmp, greaterthanOrEqualToTwo, lessThanTwo); + cmp = b.create(loc, arith::CmpFPredicate::ULE, xDistance, + cstOneFloat); + Value middle = + b.create(loc, cmp, lessEqualOne, greaterThanOne); + + return middle; +} + +static Value BicubicInterpolate(OpBuilder &b, + Aten__InterpolateSizeListScaleListOp op, + Location loc, SmallVector outputSizes, + Value input, SmallVector inputSizes, + SmallVector scaleValues, + std::string coordStr) { + unsigned dimOffset = 2; + auto inputType = cast(input.getType()); + auto inputRank = inputType.getRank(); + + Value inputFPH = + b.create(loc, b.getF32Type(), inputSizes[0]); + Value inputFPW = + b.create(loc, b.getF32Type(), inputSizes[1]); + + Value cstOneFloat = b.create(loc, b.getF32FloatAttr(1.0)); + Value cstHalf = b.create(loc, b.getF32FloatAttr(0.5)); + Value zero = b.create(loc, b.getF32FloatAttr(0.0)); + Value cstNegativeOneFloat = + b.create(loc, 
b.getF32FloatAttr(-1.0)); + bool alignCornersBool; + matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool)); + + SmallVector indices; + for (unsigned i = 0; i < inputRank; i++) { + indices.push_back(b.create(loc, i)); + } + + SmallVector proj; + proj = CoordinateTransform(b, op, loc, outputSizes, input, inputSizes, + scaleValues, coordStr, alignCornersBool, indices); + + Value x1 = b.create(loc, proj[1]); + Value x_1 = b.create(loc, x1, cstOneFloat); + Value x_2 = b.create(loc, x_1, cstOneFloat); + Value x2 = b.create(loc, x1, cstOneFloat); + + Value y1 = b.create(loc, proj[0]); + Value y_1 = b.create(loc, y1, cstOneFloat); + Value y_2 = b.create(loc, y_1, cstOneFloat); + Value y2 = b.create(loc, y1, cstOneFloat); + + // the offset is zero if x_2 is inside the image (leftwise greater than 0) + Value max = b.create(loc, x_2, zero); + Value cmp = b.create(loc, arith::CmpFPredicate::UEQ, max, x_2); + Value xOffset = b.create(loc, x_2, cstNegativeOneFloat); + xOffset = b.create(loc, cmp, zero, xOffset); + + Value inputWSubOne = b.create(loc, inputFPW, cstOneFloat); + Value min = b.create(loc, x2, inputWSubOne); + cmp = b.create(loc, arith::CmpFPredicate::UEQ, min, x2); + Value highOffset = b.create(loc, x2, inputWSubOne); + highOffset = b.create(loc, highOffset, cstNegativeOneFloat); + xOffset = b.create(loc, cmp, xOffset, highOffset); + + // get y offset + max = b.create(loc, y_2, zero); + cmp = b.create(loc, arith::CmpFPredicate::UEQ, max, y_2); + Value yOffset = b.create(loc, y_2, cstNegativeOneFloat); + yOffset = b.create(loc, cmp, zero, yOffset); + + Value inputHSubOne = b.create(loc, inputFPH, cstOneFloat); + min = b.create(loc, y2, inputHSubOne); + cmp = b.create(loc, arith::CmpFPredicate::UEQ, min, y2); + highOffset = b.create(loc, y2, inputHSubOne); + highOffset = b.create(loc, highOffset, cstNegativeOneFloat); + yOffset = b.create(loc, cmp, yOffset, highOffset); + + x1 = b.create(loc, x1, xOffset); + x_1 = b.create(loc, x_1, xOffset); + x_2 
= b.create(loc, x_2, xOffset); + x2 = b.create(loc, x2, xOffset); + + y1 = b.create(loc, y1, yOffset); + y_1 = b.create(loc, y_1, yOffset); + y_2 = b.create(loc, y_2, yOffset); + y2 = b.create(loc, y2, yOffset); + + Value x1Distance = b.create(loc, proj[1], x1); + x1Distance = b.create(loc, x1Distance); + Value x_1Distance = b.create(loc, proj[1], x_1); + x_1Distance = b.create(loc, x_1Distance); + Value x_2Distance = b.create(loc, proj[1], x_2); + x_2Distance = b.create(loc, x_2Distance); + Value x2Distance = b.create(loc, proj[1], x2); + x2Distance = b.create(loc, x2Distance); + + Value y1Distance = b.create(loc, proj[0], y1); + y1Distance = b.create(loc, y1Distance); + Value y_1Distance = b.create(loc, proj[0], y_1); + y_1Distance = b.create(loc, y_1Distance); + Value y_2Distance = b.create(loc, proj[0], y_2); + y_2Distance = b.create(loc, y_2Distance); + Value y2Distance = b.create(loc, proj[0], y2); + y2Distance = b.create(loc, y2Distance); + + SmallVector y{y_2, y_1, y1, y2}; + SmallVector x{x_2, x_1, x1, x2}; + SmallVector yDistance{y_2Distance, y_1Distance, y1Distance, + y2Distance}; + SmallVector xDistance{x_2Distance, x_1Distance, x1Distance, + x2Distance}; + SmallVector xInterp{zero, zero, zero, zero}; + + // f(x_orig, y_orig) = Sum_y Sum_x W(x_original - x)*input[x,y] * W(y_original + // -y) + Value fxy = zero; + for (int j = 0; j < 4; j++) { + Value wy = WeightFunction(b, loc, yDistance[j]); + Value xInterpy = xInterp[j]; + for (int i = 0; i < 4; i++) { + Value wx = WeightFunction(b, loc, xDistance[i]); + + Value yInt = b.create(loc, b.getI64Type(), y[j]); + Value yIndex = b.create(loc, b.getIndexType(), yInt); + indices[dimOffset] = yIndex; + + Value xInt = b.create(loc, b.getI64Type(), x[i]); + Value xIndex = b.create(loc, b.getIndexType(), xInt); + indices[dimOffset + 1] = xIndex; + + Value p = b.create(loc, input, indices); + Value wxp = b.create(loc, wx, p); + xInterpy = b.create(loc, xInterpy, wxp); + } + Value wyXInterpy = b.create(loc, wy, 
xInterpy); + fxy = b.create(loc, fxy, wyXInterpy); + } + + return fxy; +} + namespace { class ConvertInterpolateOp : public OpConversionPattern { @@ -2910,7 +3115,8 @@ class ConvertInterpolateOp // coordinate_transformation_mode="asymmetric" will lower to an interpolate // op with the non-standard mode="bilinear_asymmetric". matchPattern(op.getMode(), m_TorchConstantStr(mode)); - if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest") { + if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest" && + mode.substr(0, 5) != "cubic") { return failure(); } @@ -2999,6 +3205,10 @@ class ConvertInterpolateOp retVal = BilinearInterpolate( b, op, loc, outputSizeIntValues, input, inputSizes, ScaleFactorFloatValues, mode.substr(8)); + } else if (mode.substr(0, 5) == "cubic") { + retVal = BicubicInterpolate( + b, op, loc, outputSizeIntValues, input, inputSizes, + ScaleFactorFloatValues, mode.substr(5)); } b.create(loc, retVal); })