Skip to content

Commit

Permalink
OnnxToTorch bicubic interpolation
Browse files Browse the repository at this point in the history
  • Loading branch information
aldesilv committed Oct 23, 2024
1 parent 0a86deb commit 71487e6
Show file tree
Hide file tree
Showing 2 changed files with 234 additions and 20 deletions.
12 changes: 8 additions & 4 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2914,7 +2914,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
llvm::SmallVector<Value> operands;
std::string mode, nearest_mode, coordTfMode;
int64_t antialias, exclude_outside;
float extrapolation_value;
float extrapolation_value, cubic_coeff_a;
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());

if (auto attr = binder.op->getAttr("torch.onnx.axes")) {
Expand All @@ -2939,7 +2939,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.f32FloatAttr(extrapolation_value, "extrapolation_value",
0.0) ||
binder.customOpNameStringAttr(nearest_mode, "nearest_mode",
"round_prefer_floor"))
"round_prefer_floor") ||
binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75))
return failure();
if (antialias != 0) {
return rewriter.notifyMatchFailure(
Expand Down Expand Up @@ -2983,8 +2984,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
Value alignCorners =
coordTfMode == "align_corners" ? cstTrue : cstFalse;
if (mode == "cubic") {
return rewriter.notifyMatchFailure(binder.op,
"unimplemented: bicubic mode");
std::string modeStr = "cubic";
if (coordTfMode != "half_pixel")
modeStr = modeStr + "_" + coordTfMode;
modeStrValue =
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
}
// supported modes:
// bilinear (half_pixel), bilinear with align_corners,
Expand Down
242 changes: 226 additions & 16 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2740,12 +2740,13 @@ static Value NearestInterpolate(OpBuilder &b, Location loc,
return retVal;
}

static Value BilinearInterpolate(OpBuilder &b,
Aten__InterpolateSizeListScaleListOp op,
Location loc, SmallVector<Value> outputSizes,
Value input, SmallVector<Value> inputSizes,
SmallVector<Value> scaleValues,
std::string coordStr) {
static SmallVector<Value>
CoordinateTransform(OpBuilder &b, Aten__InterpolateSizeListScaleListOp op,
Location loc, SmallVector<Value> outputSizes, Value input,
SmallVector<Value> inputSizes,
SmallVector<Value> scaleValues, std::string coordStr,
bool alignCornersBool, SmallVector<Value> indices) {

unsigned dimOffset = 2;
auto inputType = cast<RankedTensorType>(input.getType());
auto inputRank = inputType.getRank();
Expand All @@ -2754,15 +2755,7 @@ static Value BilinearInterpolate(OpBuilder &b,
Value cstHalf = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.5));
Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));

bool alignCornersBool;
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));

SmallVector<Value> indices;
for (unsigned i = 0; i < inputRank; i++) {
indices.push_back(b.create<linalg::IndexOp>(loc, i));
}

SmallVector<Value> proj, projEps, high, low, highFP, lowFP;
SmallVector<Value> proj;
for (unsigned i = 0; i < inputRank - dimOffset; i++) {
// length_original
Value inputFP =
Expand Down Expand Up @@ -2832,6 +2825,40 @@ static Value BilinearInterpolate(OpBuilder &b,
// clip to [0,length_original - 1].
// proj is properly within the input image.
proj.push_back(b.create<arith::MinimumFOp>(loc, max, inputSubOne));
}
return proj;
}

static Value BilinearInterpolate(OpBuilder &b,
Aten__InterpolateSizeListScaleListOp op,
Location loc, SmallVector<Value> outputSizes,
Value input, SmallVector<Value> inputSizes,
SmallVector<Value> scaleValues,
std::string coordStr) {
unsigned dimOffset = 2;
auto inputType = cast<RankedTensorType>(input.getType());
auto inputRank = inputType.getRank();

Value cstOneFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.0));
Value cstHalf = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.5));
Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));

bool alignCornersBool;
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));

SmallVector<Value> indices;
for (unsigned i = 0; i < inputRank; i++) {
indices.push_back(b.create<linalg::IndexOp>(loc, i));
}

SmallVector<Value> proj, high, low, highFP, lowFP;
proj = CoordinateTransform(b, op, loc, outputSizes, input, inputSizes,
scaleValues, coordStr, alignCornersBool, indices);
for (unsigned i = 0; i < inputRank - dimOffset; i++) {
// length_original
Value inputFP =
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[i]);
Value inputSubOne = b.create<arith::SubFOp>(loc, inputFP, cstOneFloat);

// for bilinear interpolation, we look for the nearest indices below and
// above proj
Expand Down Expand Up @@ -2895,6 +2922,184 @@ static Value BilinearInterpolate(OpBuilder &b,
return b.create<arith::AddFOp>(loc, left, right);
}

static Value WeightFunction(OpBuilder &b, Location loc, Value xDistance) {
Value a = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(-0.75));
Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));
Value cstOneFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.0));
Value cstTwoFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(2.0));
Value cstThreeFloat =
b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(3.0));
Value cstFourFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(4.0));
Value cstFiveFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(5.0));
Value cstEightFloat =
b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(8.0));

Value xDistanceSquared = b.create<arith::MulFOp>(loc, xDistance, xDistance);
Value xDistanceCubed =
b.create<arith::MulFOp>(loc, xDistanceSquared, xDistance);
Value lessThanTwo = b.create<arith::MulFOp>(loc, xDistanceCubed, a);
Value fiveA = b.create<arith::MulFOp>(loc, xDistanceSquared, a);
fiveA = b.create<arith::MulFOp>(loc, fiveA, cstFiveFloat);
lessThanTwo = b.create<arith::AddFOp>(loc, fiveA, lessThanTwo);
Value eightA = b.create<arith::MulFOp>(loc, a, xDistance);
eightA = b.create<arith::MulFOp>(loc, eightA, cstEightFloat);
lessThanTwo = b.create<arith::AddFOp>(loc, eightA, lessThanTwo);
Value fourA = b.create<arith::MulFOp>(loc, a, cstFourFloat);
lessThanTwo = b.create<arith::AddFOp>(loc, fourA, lessThanTwo);

Value greaterthanOrEqualToTwo = zero;

Value lessEqualOne = b.create<arith::AddFOp>(loc, a, cstTwoFloat);
lessEqualOne = b.create<arith::MulFOp>(loc, xDistanceCubed, lessEqualOne);
Value aPlusThree = b.create<arith::AddFOp>(loc, a, cstThreeFloat);
aPlusThree = b.create<arith::MulFOp>(loc, xDistanceSquared, aPlusThree);
lessEqualOne = b.create<arith::AddFOp>(loc, lessEqualOne, aPlusThree);
lessEqualOne = b.create<arith::AddFOp>(loc, lessEqualOne, cstOneFloat);

Value cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGE, xDistance,
cstTwoFloat);
Value greaterThanOne =
b.create<arith::SelectOp>(loc, cmp, greaterthanOrEqualToTwo, lessThanTwo);
cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, xDistance,
cstOneFloat);
Value middle =
b.create<arith::SelectOp>(loc, cmp, lessEqualOne, greaterThanOne);

return middle;
}

static Value BicubicInterpolate(OpBuilder &b,
Aten__InterpolateSizeListScaleListOp op,
Location loc, SmallVector<Value> outputSizes,
Value input, SmallVector<Value> inputSizes,
SmallVector<Value> scaleValues,
std::string coordStr) {
unsigned dimOffset = 2;
auto inputType = cast<RankedTensorType>(input.getType());
auto inputRank = inputType.getRank();

Value inputFPH =
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[0]);
Value inputFPW =
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[1]);

Value cstOneFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.0));
Value cstHalf = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.5));
Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));
Value cstNegativeOneFloat =
b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(-1.0));
bool alignCornersBool;
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));

SmallVector<Value> indices;
for (unsigned i = 0; i < inputRank; i++) {
indices.push_back(b.create<linalg::IndexOp>(loc, i));
}

SmallVector<Value> proj;
proj = CoordinateTransform(b, op, loc, outputSizes, input, inputSizes,
scaleValues, coordStr, alignCornersBool, indices);

Value x1 = b.create<math::CeilOp>(loc, proj[1]);
Value x_1 = b.create<arith::SubFOp>(loc, x1, cstOneFloat);
Value x_2 = b.create<arith::SubFOp>(loc, x_1, cstOneFloat);
Value x2 = b.create<arith::AddFOp>(loc, x1, cstOneFloat);

Value y1 = b.create<math::CeilOp>(loc, proj[0]);
Value y_1 = b.create<arith::SubFOp>(loc, y1, cstOneFloat);
Value y_2 = b.create<arith::SubFOp>(loc, y_1, cstOneFloat);
Value y2 = b.create<arith::AddFOp>(loc, y1, cstOneFloat);

// the offset is zero if x_2 is inside to image (leftwise greater than 0)
Value max = b.create<arith::MaximumFOp>(loc, x_2, zero);
Value cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ, max, x_2);
Value xOffset = b.create<arith::MulFOp>(loc, x_2, cstNegativeOneFloat);
xOffset = b.create<arith::SelectOp>(loc, cmp, zero, xOffset);

Value inputWSubOne = b.create<arith::SubFOp>(loc, inputFPW, cstOneFloat);
Value min = b.create<arith::MinimumFOp>(loc, x2, inputWSubOne);
cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ, min, x2);
Value highOffset = b.create<arith::SubFOp>(loc, x2, inputWSubOne);
highOffset = b.create<arith::MulFOp>(loc, highOffset, cstNegativeOneFloat);
xOffset = b.create<arith::SelectOp>(loc, cmp, xOffset, highOffset);

// get y offset
max = b.create<arith::MaximumFOp>(loc, y_2, zero);
cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ, max, y_2);
Value yOffset = b.create<arith::MulFOp>(loc, y_2, cstNegativeOneFloat);
yOffset = b.create<arith::SelectOp>(loc, cmp, zero, yOffset);

Value inputHSubOne = b.create<arith::SubFOp>(loc, inputFPH, cstOneFloat);
min = b.create<arith::MinimumFOp>(loc, y2, inputHSubOne);
cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ, min, y2);
highOffset = b.create<arith::SubFOp>(loc, y2, inputHSubOne);
highOffset = b.create<arith::MulFOp>(loc, highOffset, cstNegativeOneFloat);
yOffset = b.create<arith::SelectOp>(loc, cmp, yOffset, highOffset);

x1 = b.create<arith::AddFOp>(loc, x1, xOffset);
x_1 = b.create<arith::AddFOp>(loc, x_1, xOffset);
x_2 = b.create<arith::AddFOp>(loc, x_2, xOffset);
x2 = b.create<arith::AddFOp>(loc, x2, xOffset);

y1 = b.create<arith::AddFOp>(loc, y1, yOffset);
y_1 = b.create<arith::AddFOp>(loc, y_1, yOffset);
y_2 = b.create<arith::AddFOp>(loc, y_2, yOffset);
y2 = b.create<arith::AddFOp>(loc, y2, yOffset);

Value x1Distance = b.create<arith::SubFOp>(loc, proj[1], x1);
x1Distance = b.create<math::AbsFOp>(loc, x1Distance);
Value x_1Distance = b.create<arith::SubFOp>(loc, proj[1], x_1);
x_1Distance = b.create<math::AbsFOp>(loc, x_1Distance);
Value x_2Distance = b.create<arith::SubFOp>(loc, proj[1], x_2);
x_2Distance = b.create<math::AbsFOp>(loc, x_2Distance);
Value x2Distance = b.create<arith::SubFOp>(loc, proj[1], x2);
x2Distance = b.create<math::AbsFOp>(loc, x2Distance);

Value y1Distance = b.create<arith::SubFOp>(loc, proj[0], y1);
y1Distance = b.create<math::AbsFOp>(loc, y1Distance);
Value y_1Distance = b.create<arith::SubFOp>(loc, proj[0], y_1);
y_1Distance = b.create<math::AbsFOp>(loc, y_1Distance);
Value y_2Distance = b.create<arith::SubFOp>(loc, proj[0], y_2);
y_2Distance = b.create<math::AbsFOp>(loc, y_2Distance);
Value y2Distance = b.create<arith::SubFOp>(loc, proj[0], y2);
y2Distance = b.create<math::AbsFOp>(loc, y2Distance);

SmallVector<Value> y{y_2, y_1, y1, y2};
SmallVector<Value> x{x_2, x_1, x1, x2};
SmallVector<Value> yDistance{y_2Distance, y_1Distance, y1Distance,
y2Distance};
SmallVector<Value> xDistance{x_2Distance, x_1Distance, x1Distance,
x2Distance};
SmallVector<Value> xInterp{zero, zero, zero, zero};

// f(x_orig, y_orig) = Sum_y Sum_x W(x_original - x)*input[x,y] * W(y_original
// -y)
Value fxy = zero;
for (int j = 0; j < 4; j++) {
Value wy = WeightFunction(b, loc, yDistance[j]);
Value xInterpy = xInterp[j];
for (int i = 0; i < 4; i++) {
Value wx = WeightFunction(b, loc, xDistance[i]);

Value yInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), y[j]);
Value yIndex = b.create<arith::IndexCastOp>(loc, b.getIndexType(), yInt);
indices[dimOffset] = yIndex;

Value xInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), x[i]);
Value xIndex = b.create<arith::IndexCastOp>(loc, b.getIndexType(), xInt);
indices[dimOffset + 1] = xIndex;

Value p = b.create<tensor::ExtractOp>(loc, input, indices);
Value wxp = b.create<arith::MulFOp>(loc, wx, p);
xInterpy = b.create<arith::AddFOp>(loc, xInterpy, wxp);
}
Value wyXInterpy = b.create<arith::MulFOp>(loc, wy, xInterpy);
fxy = b.create<arith::AddFOp>(loc, fxy, wyXInterpy);
}

return fxy;
}

namespace {
class ConvertInterpolateOp
: public OpConversionPattern<Aten__InterpolateSizeListScaleListOp> {
Expand All @@ -2910,7 +3115,8 @@ class ConvertInterpolateOp
// coordinate_transformation_mode="asymmetric" will lower to an interpolate
// op with the non-standard mode="bilinear_asymmetric".
matchPattern(op.getMode(), m_TorchConstantStr(mode));
if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest") {
if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest" &&
mode.substr(0, 5) != "cubic") {
return failure();
}

Expand Down Expand Up @@ -2999,6 +3205,10 @@ class ConvertInterpolateOp
retVal = BilinearInterpolate(
b, op, loc, outputSizeIntValues, input, inputSizes,
ScaleFactorFloatValues, mode.substr(8));
} else if (mode.substr(0, 5) == "cubic") {
retVal = BicubicInterpolate(
b, op, loc, outputSizeIntValues, input, inputSizes,
ScaleFactorFloatValues, mode.substr(5));
}
b.create<linalg::YieldOp>(loc, retVal);
})
Expand Down

0 comments on commit 71487e6

Please sign in to comment.