From 584bad6d4e91bc57ce8b77f548907bee5d63fa66 Mon Sep 17 00:00:00 2001 From: aldesilv <35313749+aldesilv@users.noreply.github.com> Date: Wed, 8 May 2024 14:35:03 -0700 Subject: [PATCH 01/11] OnnxToTorch lowering resize op (#3013) https://github.com/nod-ai/SHARK-Turbine/issues/358 Adds a lowering from ONNX to linalg for bilinear and nearest resize, with support for using either scales or sizes to compute the resized shape. Uses the half_pixel coordinate transformation for bilinear mode and asymmetric for nearest mode. See https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize. Added two interpolation helpers -- one for bilinear and the other for nearest. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 29 ++ .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 152 ++++++++ .../TorchToLinalg/Uncategorized.cpp | 337 ++++++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 68 ++++ .../build_tools/abstract_interp_lib_gen.py | 24 ++ .../build_tools/torch_ods_gen.py | 6 +- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 20 ++ test/Conversion/TorchToLinalg/resize.mlir | 142 ++++++++ 8 files changed, 775 insertions(+), 3 deletions(-) create mode 100644 test/Conversion/TorchToLinalg/resize.mlir diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 5b985a80b301..c7ce2f39eb6d 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -6984,6 +6984,35 @@ def Torch_AtenMaskedScatter_Op : Torch_Op<"aten.masked_scatter_", [ }]; } +def Torch_Aten__InterpolateSizeListScaleListOp : Torch_Op<"aten.__interpolate.size_list_scale_list", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::__interpolate.size_list_scale_list : (Tensor, int[]?, float[]?, str, bool?, bool?, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchOptionalListOfTorchIntType:$size, + 
AnyTorchOptionalListOfTorchFloatType:$scale_factor, + Torch_StringType:$mode, + AnyTorchOptionalBoolType:$align_corners, + AnyTorchOptionalBoolType:$recompute_scale_factor, + Torch_BoolType:$antialias + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten__InterpolateSizeListScaleListOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void Aten__InterpolateSizeListScaleListOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + def Torch_AtenAdaptiveAvgPool1dOp : Torch_Op<"aten.adaptive_avg_pool1d", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index b5e9162bc2bf..2a55378bc4a9 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2099,4 +2099,156 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, resultType, operand); return success(); }); + patterns.onOp( + "Resize", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + std::string mode, nearest_mode, coordTfMode; + Value noneVal = rewriter.create(binder.getLoc()); + + if (auto attr = binder.op->getAttr("torch.onnx.antialias")) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented: support not present for antialias attribute"); + } + if (auto attr = binder.op->getAttr("torch.onnx.axes")) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented: support not present for axes attribute"); + } + if (auto attr = binder.op->getAttr("torch.onnx.exclude_outside")) { + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: support not present for " + "exclude_outside attribute"); + } + if (auto attr = 
binder.op->getAttr("torch.onnx.extrapolation_value")) { + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: support not present for " + "extrapolation_value attribute"); + } + if (auto attr = + binder.op->getAttr("torch.onnx.keep_aspect_ratio_policy")) { + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: support not present for " + "keep_aspect_ratio_policy attribute"); + } + + if (binder.tensorOperandsList(operands) || + binder.tensorResultType(resultType) || + binder.customOpNameStringAttr(mode, "mode", "nearest") || + binder.customOpNameStringAttr( + coordTfMode, "coordinate_transformation_mode", "half_pixel") || + binder.customOpNameStringAttr(nearest_mode, "nearest_mode", "")) + return failure(); + + if (mode == "nearest" && nearest_mode != "floor") { + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: support not present for nearest_mode " + "except floor"); + } + + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + + Value cstFalse = + rewriter.create(binder.getLoc(), false); + Value cstTrue = + rewriter.create(binder.getLoc(), true); + Value modeStrValue; + + auto extract = [&rewriter, &binder](Value x, Value v) { + auto xTy = x.getType().cast(); + Type extractTy = rewriter.getType(); + if (isa(xTy.getDtype())) + extractTy = rewriter.getType(); + + return rewriter.create(binder.getLoc(), extractTy, + v); + }; + + auto getValueList = [&](Value operand) { + SmallVector itemList; + auto sizes = + dyn_cast(operand.getType()).getSizes(); + Torch::BaseTensorType operandType = + operand.getType().cast(); + + SmallVector selectSizes; + selectSizes.push_back(1); + Type selectResultType = operandType.getWithSizesAndDtype( + llvm::ArrayRef(selectSizes), operandType.getOptionalDtype()); + + MLIRContext *context = binder.op->getContext(); + for (int i = sizes[0] - 2; i < sizes[0]; i++) { + Value selectIndex = rewriter.create( + binder.getLoc(), 
rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value ext = rewriter.create( + binder.getLoc(), selectResultType, operand, zero, selectIndex); + Value item = extract(operand, ext); + itemList.push_back(item); + } + auto xTy = operand.getType().cast(); + Value ValueList; + if (isa(xTy.getDtype())) { + ValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(context)), itemList); + } else { + ValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::FloatType::get(context)), itemList); + } + return ValueList; + }; + + Value scalesValueList = noneVal; + Value sizesValueList = noneVal; + Value alignCorners = + coordTfMode == "align_corners" ? cstTrue : cstFalse; + + if (mode == "cubic") { + return rewriter.notifyMatchFailure(binder.op, + "unimplemented: bicubic mode"); + } + if (mode == "linear") { + modeStrValue = rewriter.create(binder.getLoc(), + "bilinear"); + if (operands.size() < 4) { + Value scaleOperand = operands[2]; + scalesValueList = getValueList(scaleOperand); + sizesValueList = noneVal; + } else { + Value sizeOperand = operands[3]; + scalesValueList = noneVal; + sizesValueList = getValueList(sizeOperand); + } + } + if (mode == "nearest") { + modeStrValue = + rewriter.create(binder.getLoc(), "nearest"); + if (operands.size() < 4) { + Value scaleOperand = operands[2]; + scalesValueList = getValueList(scaleOperand); + sizesValueList = noneVal; + } else { + Value sizesOperand = operands[3]; + scalesValueList = noneVal; + sizesValueList = getValueList(sizesOperand); + } + } + if (scalesValueList.getType().isa() && + sizesValueList.getType().isa()) { + return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode"); + } + rewriter + .replaceOpWithNewOp( + binder.op, resultType, operands[0], sizesValueList, + scalesValueList, modeStrValue, + /* AnyTorchOptionalBoolType:$align_corners */ alignCorners, + /* AnyTorchOptionalBoolType:$recompute_scale_factor */ noneVal, + 
/*Torch_BoolType:$antialias*/ cstFalse); + return success(); + }); } diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 86bc4578178f..dafeafc7bc80 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -2589,6 +2589,341 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { }; } // namespace +static Value NearestInterpolate(OpBuilder &b, Location loc, Value outputSizeH, + Value outputSizeW, Value input, + Value inputSizeH, Value inputSizeW) { + + auto inputType = input.getType().cast(); + auto inputRank = inputType.getRank(); + + Value yOut = b.create(loc, 2); + Value xOut = b.create(loc, 3); + + Value inputHFP = b.create(loc, b.getF32Type(), inputSizeH); + Value inputWFP = b.create(loc, b.getF32Type(), inputSizeW); + + Value outputSizeHFP = + b.create(loc, b.getF32Type(), outputSizeH); + Value outputSizeWFP = + b.create(loc, b.getF32Type(), outputSizeW); + + // scale = length_resized / length_original + // x_original = x_resized / scale + Value hScale = b.create(loc, outputSizeHFP, inputHFP); + Value wScale = b.create(loc, outputSizeWFP, inputWFP); + + Value yOutInt = b.create(loc, b.getI64Type(), yOut); + Value yOutFP = b.create(loc, b.getF32Type(), yOutInt); + Value yProj = b.create(loc, yOutFP, hScale); + + Value xOutInt = b.create(loc, b.getI64Type(), xOut); + Value xOutFP = b.create(loc, b.getF32Type(), xOutInt); + Value xProj = b.create(loc, xOutFP, wScale); + + // get nearest pixel using floor + Value yNearestFP = b.create(loc, yProj); + Value xNearestFP = b.create(loc, xProj); + + Value yNearestInt = + b.create(loc, b.getI64Type(), yNearestFP); + Value yNearest = + b.create(loc, b.getIndexType(), yNearestInt); + + Value xNearestInt = + b.create(loc, b.getI64Type(), xNearestFP); + Value xNearest = + b.create(loc, b.getIndexType(), xNearestInt); + + SmallVector indices; + for (unsigned i = 0; i < inputRank; i++) { + 
indices.push_back(b.create(loc, i)); + } + + int hDimOffset = 2; + indices[hDimOffset] = yNearest; + indices[hDimOffset + 1] = xNearest; + Value retVal = b.create(loc, input, indices); + return retVal; +} + +static Value BilinearInterpolate(OpBuilder &b, + Aten__InterpolateSizeListScaleListOp op, + Location loc, Value outputSizeH, + Value outputSizeW, Value input, + Value inputSizeH, Value inputSizeW) { + int hDimOffset = 2; + auto inputType = input.getType().cast(); + auto inputRank = inputType.getRank(); + + Value cstOneEps = b.create(loc, b.getF32FloatAttr(1.001)); + Value cstOneFloat = b.create(loc, b.getF32FloatAttr(1.0)); + Value cstHalf = b.create(loc, b.getF32FloatAttr(0.5)); + Value zero = b.create(loc, b.getF32FloatAttr(0.0)); + + Value yOut = b.create(loc, 2); + Value xOut = b.create(loc, 3); + + bool alignCornersBool = false; // default if not a constant bool + matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool)); + + Value yProj, xProj; + if (alignCornersBool) { + // x_original = x_resized * (length_original - 1) / (length_resized - 1) + Value inputHFP = b.create(loc, b.getF32Type(), inputSizeH); + Value outputSizeHFP = + b.create(loc, b.getF32Type(), outputSizeH); + Value yOutInt = b.create(loc, b.getI64Type(), yOut); + Value yOutFP = b.create(loc, b.getF32Type(), yOutInt); + Value inputHSubOne = b.create(loc, inputHFP, cstOneFloat); + Value outputSizeHSubOne = + b.create(loc, outputSizeHFP, cstOneFloat); + Value hScale = + b.create(loc, inputHSubOne, outputSizeHSubOne); + Value yProjBeforeClamp = b.create(loc, yOutFP, hScale); + Value yMax = b.create(loc, yProjBeforeClamp, zero); + // Clamp to the input extent so floor(yProj + 1) stays in bounds. + Value inputHSubOneEps = b.create(loc, inputHFP, cstOneEps); + yProj = b.create(loc, inputHSubOneEps, yMax); + + Value inputWFP = b.create(loc, b.getF32Type(), inputSizeW); + Value outputSizeWFP = + b.create(loc, b.getF32Type(), outputSizeW); + Value xOutInt = b.create(loc, b.getI64Type(), xOut); + Value xOutFP = b.create(loc, b.getF32Type(), xOutInt); + Value inputWSubOne = 
b.create(loc, inputWFP, cstOneFloat); + Value outputSizeWSubOne = + b.create(loc, outputSizeWFP, cstOneFloat); + Value wScale = + b.create(loc, inputWSubOne, outputSizeWSubOne); + Value xProjBeforeClamp = b.create(loc, xOutFP, wScale); + Value xMax = b.create(loc, xProjBeforeClamp, zero); + // Clamp to the input extent so floor(xProj + 1) stays in bounds. + Value inputWSubOneEps = b.create(loc, inputWFP, cstOneEps); + xProj = b.create(loc, inputWSubOneEps, xMax); + } else { + // y_original = (y_resized + 0.5) / scale - 0.5 + Value inputHFP = b.create(loc, b.getF32Type(), inputSizeH); + Value outputSizeHFP = + b.create(loc, b.getF32Type(), outputSizeH); + Value hScale = b.create(loc, outputSizeHFP, inputHFP); + Value yOutInt = b.create(loc, b.getI64Type(), yOut); + Value yOutFP = b.create(loc, b.getF32Type(), yOutInt); + Value yPlusHalf = b.create(loc, yOutFP, cstHalf); + Value yDivScale = b.create(loc, yPlusHalf, hScale); + Value ySubHalf = b.create(loc, yDivScale, cstHalf); + Value yMax = b.create(loc, ySubHalf, zero); + Value inputHSubOne = b.create(loc, inputHFP, cstOneEps); + yProj = b.create(loc, yMax, inputHSubOne); + + Value inputWFP = b.create(loc, b.getF32Type(), inputSizeW); + Value outputSizeWFP = + b.create(loc, b.getF32Type(), outputSizeW); + Value wScale = b.create(loc, outputSizeWFP, inputWFP); + Value xOutInt = b.create(loc, b.getI64Type(), xOut); + Value xOutFP = b.create(loc, b.getF32Type(), xOutInt); + Value xPlusHalf = b.create(loc, xOutFP, cstHalf); + Value xDivScale = b.create(loc, xPlusHalf, wScale); + Value xSubHalf = b.create(loc, xDivScale, cstHalf); + // clamp to [0, inputW - 1.001] so floor(xProj + 1) stays in bounds + Value xMax = b.create(loc, xSubHalf, zero); + Value inputWSubOne = b.create(loc, inputWFP, cstOneEps); + xProj = b.create(loc, xMax, inputWSubOne); + } + Value yLow = b.create(loc, yProj); + Value yProjPlusOne = b.create(loc, cstOneFloat, yProj); + Value yHigh = b.create(loc, yProjPlusOne); + + Value xLow = b.create(loc, xProj); + Value xProjPlusOne = b.create(loc, cstOneFloat, xProj); + Value xHigh = b.create(loc, xProjPlusOne); 
+ + SmallVector indices; + for (unsigned i = 0; i < inputRank; i++) { + indices.push_back(b.create(loc, i)); + } + Value yLowInt = b.create(loc, b.getI64Type(), yLow); + Value yLowIdx = b.create(loc, b.getIndexType(), yLowInt); + + Value xLowInt = b.create(loc, b.getI64Type(), xLow); + Value xLowIdx = b.create(loc, b.getIndexType(), xLowInt); + + Value yHighInt = b.create(loc, b.getI64Type(), yHigh); + Value yHighIdx = + b.create(loc, b.getIndexType(), yHighInt); + + Value xHighInt = b.create(loc, b.getI64Type(), xHigh); + Value xHighIdx = + b.create(loc, b.getIndexType(), xHighInt); + + indices[hDimOffset] = yLowIdx; + indices[hDimOffset + 1] = xLowIdx; + Value p00 = b.create(loc, input, indices); + + indices[hDimOffset] = yLowIdx; + indices[hDimOffset + 1] = xHighIdx; + Value p01 = b.create(loc, input, indices); + + indices[hDimOffset] = yHighIdx; + indices[hDimOffset + 1] = xLowIdx; + Value p10 = b.create(loc, input, indices); + + indices[hDimOffset] = yHighIdx; + indices[hDimOffset + 1] = xHighIdx; + Value p11 = b.create(loc, input, indices); + + // p00 p01 + // p10 p11 + // (xhigh - xproj) / (xhigh - xlow) * p00 + (xproj - xlow) / + // (xhigh - xlow) * p01 + Value xHighMinusxProj = b.create(loc, xHigh, xProj); + Value xHighMinusxLow = b.create(loc, xHigh, xLow); + Value w0 = b.create(loc, xHighMinusxProj, xHighMinusxLow); + Value lhs = b.create(loc, w0, p00); + + Value xProjMinusxLow = b.create(loc, xProj, xLow); + Value w1 = b.create(loc, xProjMinusxLow, xHighMinusxLow); + Value rhs = b.create(loc, w1, p01); + + Value xInter = b.create(loc, lhs, rhs); + + // (xhigh - xproj) / (xhigh - xlow) * p10 + (xproj - xlow) / + // (xhigh - xlow) * p11 + lhs = b.create(loc, w0, p10); + rhs = b.create(loc, w1, p11); + + Value xInter1 = b.create(loc, lhs, rhs); + + // (yhigh - yproj) / (yhigh - ylow) * xInter + (yproj - ylow) + // / (yhigh - ylow) * xInter1 + Value yHighMinusyProj = b.create(loc, yHigh, yProj); + Value yHighMinusyLow = b.create(loc, yHigh, yLow); + w0 = 
b.create(loc, yHighMinusyProj, yHighMinusyLow); + lhs = b.create(loc, w0, xInter); + + Value yProjMinusyLow = b.create(loc, yProj, yLow); + w1 = b.create(loc, yProjMinusyLow, yHighMinusyLow); + rhs = b.create(loc, w1, xInter1); + + Value retVal = b.create(loc, lhs, rhs); + + return retVal; +} + +namespace { +class ConvertInterpolateOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(Aten__InterpolateSizeListScaleListOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + std::string mode; + matchPattern(op.getMode(), m_TorchConstantStr(mode)); + if (mode != "bilinear" && mode != "nearest") { + return failure(); + } + + Location loc = op->getLoc(); + Value input = adaptor.getInput(); + auto inputType = input.getType().cast(); + auto inputRank = inputType.getRank(); + + if (inputType.isDynamicDim(2) || inputType.isDynamicDim(3)) { + return rewriter.notifyMatchFailure(op, "error: Dynamic dim on resize op"); + } + + SmallVector outputSizeIntValues; + + if (!op.getScaleFactor().getType().isa()) { + SmallVector ScaleFactorTorchFloat; + if (!getListConstructElements(op.getScaleFactor(), ScaleFactorTorchFloat)) + return rewriter.notifyMatchFailure( + op, "unimplemented: the output_size is not constructed from " + "ListConstruct"); + SmallVector ScaleFactorFloatValues; + ScaleFactorFloatValues = getTypeConvertedValues( + rewriter, loc, getTypeConverter(), ScaleFactorTorchFloat); + Value inputSizeH = rewriter.create( + loc, rewriter.getI64IntegerAttr(inputType.getShape()[2])); + Value inputHFP = rewriter.create( + loc, rewriter.getF32Type(), inputSizeH); + Value scale = rewriter.create(loc, inputHFP.getType(), + ScaleFactorFloatValues[0]); + Value outputSizeH = rewriter.create(loc, inputHFP, scale); + Value outputH = rewriter.create(loc, outputSizeH); + outputH = + rewriter.create(loc, rewriter.getI64Type(), outputH); + + Value inputSizeW = rewriter.create( + loc, 
rewriter.getI64IntegerAttr(inputType.getShape()[3])); + Value inputWFP = rewriter.create( + loc, rewriter.getF32Type(), inputSizeW); + scale = rewriter.create(loc, inputWFP.getType(), + ScaleFactorFloatValues[1]); + Value outputSizeW = rewriter.create(loc, inputWFP, scale); + Value outputW = rewriter.create(loc, outputSizeW); + outputW = + rewriter.create(loc, rewriter.getI64Type(), outputW); + + outputSizeIntValues.push_back(outputH); + outputSizeIntValues.push_back(outputW); + } else { + SmallVector outputSizeTorchInt; + if (!getListConstructElements(op.getSize(), outputSizeTorchInt)) + return rewriter.notifyMatchFailure( + op, "unimplemented: the output_size is not constructed from " + "ListConstruct"); + outputSizeIntValues = getTypeConvertedValues( + rewriter, loc, getTypeConverter(), outputSizeTorchInt); + } + int hDimOffset = 2; + SmallVector dims = getTensorSizes(rewriter, loc, input); + dims[hDimOffset] = castIntToIndex(rewriter, loc, outputSizeIntValues[0]); + dims[hDimOffset + 1] = + castIntToIndex(rewriter, loc, outputSizeIntValues[1]); + + Value outTensor = rewriter.create( + loc, getAsOpFoldResult(dims), inputType.getElementType()); + + AffineMap idMap = rewriter.getMultiDimIdentityMap(inputRank); + + SmallVector iteratorTypes( + inputRank, utils::IteratorType::parallel); + + Value finalRes = + rewriter + .create( + loc, outTensor.getType(), ValueRange{}, outTensor, + /*indexingMaps=*/idMap, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value outputSizeH = outputSizeIntValues[0]; + Value outputSizeW = outputSizeIntValues[1]; + Value inputSizeH = b.create( + loc, b.getI64IntegerAttr(inputType.getShape()[2])); + Value inputSizeW = b.create( + loc, b.getI64IntegerAttr(inputType.getShape()[3])); + Value retVal; + if (mode == "nearest") { + retVal = + NearestInterpolate(b, loc, outputSizeH, outputSizeW, + input, inputSizeH, inputSizeW); + } else if (mode == "bilinear") { + retVal = BilinearInterpolate(b, op, loc, 
outputSizeH, + outputSizeW, input, inputSizeH, + inputSizeW); + } + b.create(loc, retVal); + }) + .getResult(0); + Type newResultType = + getTypeConverter()->convertType(op.getResult().getType()); + rewriter.replaceOpWithNewOp(op, newResultType, finalRes); + return success(); + } +}; +} // namespace void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { @@ -2644,4 +2979,6 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); } diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 06d36f58d1c8..65aeb6ddad4f 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6608,6 +6608,70 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" " return %4 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.__interpolate.size_list_scale_list\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.str, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.bool) -> !torch.list {\n" +" %str = torch.constant.str \"AssertionError: Either size or scale_factor must be presented\"\n" +" %str_0 = torch.constant.str \"AssertionError: Must specify exactly one of size and scale_factor\"\n" +" %none = torch.constant.none\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %int3 = torch.constant.int 3\n" +" %false = torch.constant.bool false\n" +" %true = torch.constant.bool true\n" +" %0 = torch.prim.Uninitialized : 
!torch.list\n" +" %1 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %2 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.prim.ListConstruct %1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = torch.aten.__isnot__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %5:2 = torch.prim.If %4 -> (!torch.bool, !torch.list) {\n" +" %7 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list\n" +" %8 = torch.aten.__is__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.aten.__getitem__.t %7, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %10 = torch.aten.append.t %3, %9 : !torch.list, !torch.int -> !torch.list\n" +" %11 = torch.aten.__getitem__.t %7, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.append.t %3, %11 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield %true, %3 : !torch.bool, !torch.list\n" +" } else {\n" +" %7 = torch.aten.__isnot__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %8:2 = torch.prim.If %7 -> (!torch.bool, !torch.list) {\n" +" %9 = torch.prim.unchecked_cast %arg2 : !torch.optional> -> !torch.list\n" +" %10 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" torch.prim.If %10 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %11 = torch.aten.__getitem__.t %9, %int0 : !torch.list, !torch.int -> !torch.float\n" +" %12 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %13 = torch.operator \"aten.mul.float_int\"(%11, %12) : (!torch.float, !torch.int) -> !torch.float \n" +" %14 = 
torch.aten.Int.float %13 : !torch.float -> !torch.int\n" +" %15 = torch.aten.append.t %3, %14 : !torch.list, !torch.int -> !torch.list\n" +" %16 = torch.aten.__getitem__.t %9, %int1 : !torch.list, !torch.int -> !torch.float\n" +" %17 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int\n" +" %18 = torch.operator \"aten.mul.float_int\"(%16, %17) : (!torch.float, !torch.int) -> !torch.float \n" +" %19 = torch.aten.Int.float %18 : !torch.float -> !torch.int\n" +" %20 = torch.aten.append.t %3, %19 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield %true, %3 : !torch.bool, !torch.list\n" +" } else {\n" +" torch.prim.If.yield %false, %0 : !torch.bool, !torch.list\n" +" }\n" +" torch.prim.If.yield %8#0, %8#1 : !torch.bool, !torch.list\n" +" }\n" +" %6 = torch.prim.If %5#0 -> (!torch.list) {\n" +" torch.prim.If.yield %5#1 : !torch.list\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %3 : !torch.list\n" +" }\n" +" return %6 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.prims.collapse\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list {\n" " %true = torch.constant.bool true\n" " %str = torch.constant.str \"AssertionError: start must be less than or equal to end\"\n" @@ -9938,6 +10002,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.__interpolate.size_list_scale_list\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.str, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.reflection_pad1d\"(%arg0: !torch.tuple, %arg1: 
!torch.list) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: padding size expected to be 2\"\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 31ce183bb7a0..f21d2d57fcb5 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -291,6 +291,26 @@ def aten〇grid_sampler〡shape(input: List[int], grid: List[int], interpolation output = [input[0],input[1],grid[1],grid[2]] return output +def aten〇__interpolate〇size_list_scale_list〡shape(input: List[int], size: Optional[List[int]] = None, scale_factor: Optional[List[float]] = None, mode: str = "nearest", align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False) -> List[int]: + output = [input[0], input[1]] + if size is not None: + assert ( + scale_factor is None + ), "Must specify exactly one of size and scale_factor" + output.append(size[0]) + output.append(size[1]) + return output + elif scale_factor is not None: + assert ( + size is None + ), "Must specify exactly one of size and scale_factor" + output.append(int(scale_factor[0] * input[2])) + output.append(int(scale_factor[1] * input[3])) + return output + assert 0, "Either size or scale_factor must be presented" + return output + + def prims〇collapse〡shape(a: List[int], start: int, end: int) -> List[int]: # Obtained through trial and error on a few examples in PyTorch: assert start < len(a), "start out of bounds" @@ -2217,6 +2237,10 @@ def aten〇grid_sampler〡dtype(input_rank_dtype: Tuple[int, int], grid_rank_dty grid_rank, grid_dtype = input_rank_dtype return input_dtype +def aten〇__interpolate〇size_list_scale_list〡dtype(input_rank_dtype: Tuple[int, int], size: Optional[List[int]] = None, 
scale_factor: Optional[List[float]] = None, mode: str = "nearest", align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype + @check_dtype_function([ErrorInvocation(TensorOfShape(2, 3, 4), padding=1), ErrorInvocation(TensorOfShape(2, 3, 4), padding=[]), ErrorInvocation(TensorOfShape(2, 3, 4), padding=[2]), diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index a16279c9df78..6096afcfc195 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -501,9 +501,9 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::_log_softmax : (Tensor, int, bool) -> (Tensor)" ) - emit_with_mutating_variants("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)") - emit_with_mutating_variants("aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)") - emit_with_mutating_variants("aten::masked_scatter : (Tensor, Tensor, Tensor) -> (Tensor)") + emit( + "aten::__interpolate.size_list_scale_list : (Tensor, int[]?, float[]?, str, bool?, bool?, bool) -> (Tensor)" + ) emit("aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)") emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)") emit("aten::_adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)") diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 508ed55d3337..13b25e2b16ca 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -1664,3 +1664,23 @@ func.func @test_size(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[],si return %0 : !torch.vtensor<[],si32> } +// ----- + +// CHECK-LABEL: 
func.func @test_resize_sizes_nearest + func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %none = torch.constant.none + // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?,?],f32> + } + +// ----- + +// CHECK-LABEL: func.func @test_resize_sizes_linear + func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?], +f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %none = torch.constant.none + // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?,?],f32> + } diff --git 
a/test/Conversion/TorchToLinalg/resize.mlir b/test/Conversion/TorchToLinalg/resize.mlir new file mode 100644 index 000000000000..480454b3f1fc --- /dev/null +++ b/test/Conversion/TorchToLinalg/resize.mlir @@ -0,0 +1,142 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @test_resize_sizes_linear +func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4] +,si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[generic:.*]] = linalg.generic + // CHECK: %[[c2_i64:.*]] = arith.constant 2 : i64 + // CHECK: %[[c4_i64:.*]] = arith.constant 4 : i64 + // CHECK: %[[cst:.*]] = arith.constant 1.001000e+00 : f32 + // CHECK: %[[cst_4:.*]] = arith.constant 1.000000e+00 : f32 + // CHECK: %[[cst_5:.*]] = arith.constant 5.000000e-01 : f32 + // CHECK: %[[cst_6:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK: %[[x13:.*]] = linalg.index 2 : index + // CHECK: %[[x14:.*]] = linalg.index 3 : index + // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64]] : i64 to f32 + // CHECK: %[[x16:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 + // CHECK: %[[x17:.*]] = arith.divf %[[x16]], %[[x15]] : f32 + // CHECK: %[[x18:.*]] = arith.index_cast %[[x13]] : index to i64 + // CHECK: %[[x19:.*]] = arith.sitofp %[[x18]] : i64 to f32 + // CHECK: %[[x20:.*]] = arith.addf %[[x19]], %[[cst_5]] : f32 + // CHECK: %[[x21:.*]] = arith.divf %[[x20]], %[[x17]] : f32 + // CHECK: %[[x22:.*]] = arith.subf %[[x21]], %[[cst_5]] : f32 + // CHECK: %[[x23:.*]] = arith.maximumf %[[x22]], %[[cst_6]] : f32 + // CHECK: %[[x24:.*]] = arith.subf %[[x15]], %[[cst]] : f32 + // CHECK: %[[x25:.*]] = arith.minimumf %[[x23]], %[[x24]] : f32 + // CHECK: %[[x26:.*]] = arith.sitofp %[[c4_i64]] : i64 to f32 + // CHECK: %[[x27:.*]] = 
arith.sitofp %[[x7:.*]] : i64 to f32 + // CHECK: %[[x28:.*]] = arith.divf %[[x27]], %[[x26]] : f32 + // CHECK: %[[x29:.*]] = arith.index_cast %[[x14]] : index to i64 + // CHECK: %[[x30:.*]] = arith.sitofp %[[x29]] : i64 to f32 + // CHECK: %[[x31:.*]] = arith.addf %[[x30]], %[[cst_5]] : f32 + // CHECK: %[[x32:.*]] = arith.divf %[[x31]], %[[x28]] : f32 + // CHECK: %[[x33:.*]] = arith.subf %[[x32]], %[[cst_5]] : f32 + // CHECK: %[[x34:.*]] = arith.maximumf %[[x33]], %[[cst_6]] : f32 + // CHECK: %[[x35:.*]] = arith.subf %[[x26]], %[[cst]] : f32 + // CHECK: %[[x36:.*]] = arith.minimumf %[[x34]], %[[x35]] : f32 + // CHECK: %[[x37:.*]] = math.floor %[[x25]] : f32 + // CHECK: %[[x38:.*]] = arith.addf %[[cst_4]], %[[x25]] : f32 + // CHECK: %[[x39:.*]] = math.floor %[[x38]] : f32 + // CHECK: %[[x40:.*]] = math.floor %[[x36]] : f32 + // CHECK: %[[x41:.*]] = arith.addf %[[cst_4]], %[[x36]] : f32 + // CHECK: %[[x42:.*]] = math.floor %[[x41]] : f32 + // CHECK: %[[x43:.*]] = linalg.index 0 : index + // CHECK: %[[x44:.*]] = linalg.index 1 : index + // CHECK: %[[x45:.*]] = linalg.index 2 : index + // CHECK: %[[x46:.*]] = linalg.index 3 : index + // CHECK: %[[x47:.*]] = arith.fptosi %[[x37]] : f32 to i64 + // CHECK: %[[x48:.*]] = arith.index_cast %[[x47]] : i64 to index + // CHECK: %[[x49:.*]] = arith.fptosi %[[x40]] : f32 to i64 + // CHECK: %[[x50:.*]] = arith.index_cast %[[x49]] : i64 to index + // CHECK: %[[x51:.*]] = arith.fptosi %[[x39]] : f32 to i64 + // CHECK: %[[x52:.*]] = arith.index_cast %[[x51]] : i64 to index + // CHECK: %[[x53:.*]] = arith.fptosi %[[x42]] : f32 to i64 + // CHECK: %[[x54:.*]] = arith.index_cast %[[x53]] : i64 to index + // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x43]], %[[x44]], %[[x48]], %[[x50]]] : tensor<1x1x2x4xf32> + // CHECK: %[[extracted_7:.*]] = tensor.extract %[[x0]][%[[x43]], %[[x44]], %[[x48]], %[[x54]]] : tensor<1x1x2x4xf32> + // CHECK: %[[extracted_8:.*]] = tensor.extract %[[x0]][%[[x43]], %[[x44]], %[[x52]], %[[x50]]] : 
tensor<1x1x2x4xf32> + // CHECK: %[[extracted_9:.*]] = tensor.extract %[[x0]][%[[x43]], %[[x44]], %[[x52]], %[[x54]]] : tensor<1x1x2x4xf32> + // CHECK: %[[x55:.*]] = arith.subf %[[x42]], %[[x36]] : f32 + // CHECK: %[[x56:.*]] = arith.subf %[[x42]], %[[x40]] : f32 + // CHECK: %[[x57:.*]] = arith.divf %[[x55]], %[[x56]] : f32 + // CHECK: %[[x58:.*]] = arith.mulf %[[x57]], %extracted : f32 + // CHECK: %[[x59:.*]] = arith.subf %[[x36]], %[[x40]] : f32 + // CHECK: %[[x60:.*]] = arith.divf %[[x59]], %[[x56]] : f32 + // CHECK: %[[x61:.*]] = arith.mulf %[[x60]], %[[extracted_7]] : f32 + // CHECK: %[[x62:.*]] = arith.addf %[[x58]], %[[x61]] : f32 + // CHECK: %[[x63:.*]] = arith.mulf %[[x57]], %[[extracted_8]] : f32 + // CHECK: %[[x64:.*]] = arith.mulf %[[x60]], %[[extracted_9]] : f32 + // CHECK: %[[x65:.*]] = arith.addf %[[x63]], %[[x64]] : f32 + // CHECK: %[[x66:.*]] = arith.subf %[[x39]], %[[x25]] : f32 + // CHECK: %[[x67:.*]] = arith.subf %[[x39]], %[[x37]] : f32 + // CHECK: %[[x68:.*]] = arith.divf %[[x66]], %[[x67]] : f32 + // CHECK: %[[x69:.*]] = arith.mulf %[[x68]], %[[x62]] : f32 + // CHECK: %[[x70:.*]] = arith.subf %[[x25]], %[[x37]] : f32 + // CHECK: %[[x71:.*]] = arith.divf %[[x70]], %[[x67]] : f32 + // CHECK: %[[x72:.*]] = arith.mulf %[[x71]], %[[x65]] : f32 + // CHECK: %[[x73:.*]] = arith.addf %[[x69]], %[[x72]] : f32 + %none = torch.constant.none + %none_0 = torch.constant.none + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %true = torch.constant.bool true + %str = torch.constant.str "bilinear" + %int2 = torch.constant.int 2 + %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + %int3 = torch.constant.int 3 + %2 = torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %3 = torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int 
+ %4 = torch.prim.ListConstruct %1, %3 : (!torch.int, !torch.int) -> !torch.list + %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + return %5 : !torch.vtensor<[?,?,?,?],f32> + } + +// ----- + +func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[GENERIC:.*]] = linalg.generic + // CHECK: %[[c2_i64:.*]] = arith.constant 2 : i64 + // CHECK: %[[c4_i64:.*]] = arith.constant 4 : i64 + // CHECK: %[[x13:.*]] = linalg.index 2 : index + // CHECK: %[[x14:.*]] = linalg.index 3 : index + // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64]] : i64 to f32 + // CHECK: %[[x16:.*]] = arith.sitofp %[[c4_i64]] : i64 to f32 + // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 + // CHECK: %[[x20:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32 + // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32 + // CHECK: %[[x22:.*]] = arith.divf %[[x20]], %[[x16]] : f32 + // CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64 + // CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32 + // CHECK: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32 + // CHECK: %[[x26:.*]] = arith.index_cast %[[x14]] : index to i64 + // CHECK: %[[x27:.*]] = arith.sitofp %[[x26]] : i64 to f32 + // CHECK: %[[x28:.*]] = arith.divf %[[x27]], %[[x22]] : f32 + // CHECK: %[[x29:.*]] = math.floor %[[x25]] : f32 + // CHECK: %[[x30:.*]] = math.floor %[[x28]] : f32 + // CHECK: %[[x31:.*]] = arith.fptosi %[[x29]] : f32 to i64 + // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index + // CHECK: %[[x33:.*]] = arith.fptosi %[[x30]] : 
f32 to i64 + // CHECK: %[[x34:.*]] = arith.index_cast %[[x33]] : i64 to index + // CHECK: %[[x35:.*]] = linalg.index 0 : index + // CHECK: %[[x36:.*]] = linalg.index 1 : index + // CHECK: %[[x37:.*]] = linalg.index 2 : index + // CHECK: %[[x38:.*]] = linalg.index 3 : index + // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x35]], %[[x36]], %[[x32]], %[[x34]]] : tensor<1x1x2x4xf32> + // CHECK: linalg.yield %[[extracted]] : f32 + %none = torch.constant.none + %none_0 = torch.constant.none + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %true = torch.constant.bool true + %str = torch.constant.str "nearest" + %int2 = torch.constant.int 2 + %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + %int3 = torch.constant.int 3 + %2 = torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %3 = torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int + %4 = torch.prim.ListConstruct %1, %3 : (!torch.int, !torch.int) -> !torch.list + %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + return %5 : !torch.vtensor<[?,?,?,?],f32> + } From 7ec27a90c2c5635d4b940b53d25ae3fb74ead3fe Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Fri, 17 May 2024 14:18:57 -0500 Subject: [PATCH 02/11] [ONNX][TorchToLinalg] Add support for dynamic dims in Interpolate lowering (#3351) Addresses [Shark-Turbine Related tracker [Shark-Turbine Related onnx.Resize issues [Shark-Turbine --- .../TorchToLinalg/Uncategorized.cpp | 26 +++++++------------ projects/pt1/e2e_testing/xfail_sets.py | 4 --- test/Conversion/TorchToLinalg/resize.mlir | 
12 +++------ 3 files changed, 13 insertions(+), 29 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index dafeafc7bc80..e73fb1e88dc4 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -2829,11 +2829,13 @@ class ConvertInterpolateOp auto inputType = input.getType().cast(); auto inputRank = inputType.getRank(); - if (inputType.isDynamicDim(2) || inputType.isDynamicDim(3)) { - return rewriter.notifyMatchFailure(op, "error: Dynamic dim on resize op"); - } - SmallVector outputSizeIntValues; + Value inputSizeH = getDimOp(rewriter, loc, input, 2); + inputSizeH = rewriter.create( + loc, rewriter.getIntegerType(64), inputSizeH); + Value inputSizeW = getDimOp(rewriter, loc, input, 3); + inputSizeW = rewriter.create( + loc, rewriter.getIntegerType(64), inputSizeW); if (!op.getScaleFactor().getType().isa()) { SmallVector ScaleFactorTorchFloat; @@ -2844,8 +2846,6 @@ class ConvertInterpolateOp SmallVector ScaleFactorFloatValues; ScaleFactorFloatValues = getTypeConvertedValues( rewriter, loc, getTypeConverter(), ScaleFactorTorchFloat); - Value inputSizeH = rewriter.create( - loc, rewriter.getI64IntegerAttr(inputType.getShape()[2])); Value inputHFP = rewriter.create( loc, rewriter.getF32Type(), inputSizeH); Value scale = rewriter.create(loc, inputHFP.getType(), @@ -2855,8 +2855,6 @@ class ConvertInterpolateOp outputH = rewriter.create(loc, rewriter.getI64Type(), outputH); - Value inputSizeW = rewriter.create( - loc, rewriter.getI64IntegerAttr(inputType.getShape()[3])); Value inputWFP = rewriter.create( loc, rewriter.getF32Type(), inputSizeW); scale = rewriter.create(loc, inputWFP.getType(), @@ -2877,11 +2875,9 @@ class ConvertInterpolateOp outputSizeIntValues = getTypeConvertedValues( rewriter, loc, getTypeConverter(), outputSizeTorchInt); } - int hDimOffset = 2; - SmallVector dims = getTensorSizes(rewriter, loc, input); - dims[hDimOffset] = 
castIntToIndex(rewriter, loc, outputSizeIntValues[0]); - dims[hDimOffset + 1] = - castIntToIndex(rewriter, loc, outputSizeIntValues[1]); + SmallVector dims = getTensorSizesUntilDim(rewriter, loc, input, 1); + dims.push_back(castIntToIndex(rewriter, loc, outputSizeIntValues[0])); + dims.push_back(castIntToIndex(rewriter, loc, outputSizeIntValues[1])); Value outTensor = rewriter.create( loc, getAsOpFoldResult(dims), inputType.getElementType()); @@ -2900,10 +2896,6 @@ class ConvertInterpolateOp [&](OpBuilder &b, Location loc, ValueRange args) { Value outputSizeH = outputSizeIntValues[0]; Value outputSizeW = outputSizeIntValues[1]; - Value inputSizeH = b.create( - loc, b.getI64IntegerAttr(inputType.getShape()[2])); - Value inputSizeW = b.create( - loc, b.getI64IntegerAttr(inputType.getShape()[3])); Value retVal; if (mode == "nearest") { retVal = diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 671df14b3d34..b6160f54c39b 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2165,10 +2165,6 @@ "ReduceSumSignedIntModule_basic", "ReduceSumUnsignedIntModule_basic", - # Failure - onnx_lowering: onnx.Resize - "UpSampleNearest2dDynamicSize_basic", - "UpSampleNearest2dStaticSize_basic", - # Failure - onnx_lowering: onnx.ScatterElements "ScatterSrcModule_basic", "ScatterSrcStaticModule_basic", diff --git a/test/Conversion/TorchToLinalg/resize.mlir b/test/Conversion/TorchToLinalg/resize.mlir index 480454b3f1fc..9850a5fdabd6 100644 --- a/test/Conversion/TorchToLinalg/resize.mlir +++ b/test/Conversion/TorchToLinalg/resize.mlir @@ -4,15 +4,13 @@ func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4] ,si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: 
%[[generic:.*]] = linalg.generic - // CHECK: %[[c2_i64:.*]] = arith.constant 2 : i64 - // CHECK: %[[c4_i64:.*]] = arith.constant 4 : i64 // CHECK: %[[cst:.*]] = arith.constant 1.001000e+00 : f32 // CHECK: %[[cst_4:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[cst_5:.*]] = arith.constant 5.000000e-01 : f32 // CHECK: %[[cst_6:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[x13:.*]] = linalg.index 2 : index // CHECK: %[[x14:.*]] = linalg.index 3 : index - // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64]] : i64 to f32 + // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32 // CHECK: %[[x16:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 // CHECK: %[[x17:.*]] = arith.divf %[[x16]], %[[x15]] : f32 // CHECK: %[[x18:.*]] = arith.index_cast %[[x13]] : index to i64 @@ -23,7 +21,7 @@ func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: // CHECK: %[[x23:.*]] = arith.maximumf %[[x22]], %[[cst_6]] : f32 // CHECK: %[[x24:.*]] = arith.subf %[[x15]], %[[cst]] : f32 // CHECK: %[[x25:.*]] = arith.minimumf %[[x23]], %[[x24]] : f32 - // CHECK: %[[x26:.*]] = arith.sitofp %[[c4_i64]] : i64 to f32 + // CHECK: %[[x26:.*]] = arith.sitofp %[[c4_i64:.*]] : i64 to f32 // CHECK: %[[x27:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32 // CHECK: %[[x28:.*]] = arith.divf %[[x27]], %[[x26]] : f32 // CHECK: %[[x29:.*]] = arith.index_cast %[[x14]] : index to i64 @@ -96,12 +94,10 @@ func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[GENERIC:.*]] = linalg.generic - // CHECK: %[[c2_i64:.*]] = arith.constant 2 : i64 - // CHECK: %[[c4_i64:.*]] = arith.constant 4 : i64 // CHECK: %[[x13:.*]] = 
linalg.index 2 : index // CHECK: %[[x14:.*]] = linalg.index 3 : index - // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64]] : i64 to f32 - // CHECK: %[[x16:.*]] = arith.sitofp %[[c4_i64]] : i64 to f32 + // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32 + // CHECK: %[[x16:.*]] = arith.sitofp %[[c4_i64:.*]] : i64 to f32 // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 // CHECK: %[[x20:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32 // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32 From f1e7ed2db32efc5ff56e3708d9f9cb4a98bcae8c Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Mon, 20 May 2024 15:35:27 -0500 Subject: [PATCH 03/11] onnx.Resize and aten._interpolate : allow n spatial dims. (#3368) The old lowering only had logic for 2d (i.e. images). this patch allows interpolation for n spatial dims, which is required for some 3d vision models such as - onnx/models/pytorch-3dunet_vaiq_int8 which successfully compiles and runs with this patch. 
--- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 2 +- .../TorchToLinalg/Uncategorized.cpp | 151 ++++++++---------- test/Conversion/TorchToLinalg/resize.mlir | 94 +++++++++-- 3 files changed, 151 insertions(+), 96 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 2a55378bc4a9..35a3204b7e36 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2180,7 +2180,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( llvm::ArrayRef(selectSizes), operandType.getOptionalDtype()); MLIRContext *context = binder.op->getContext(); - for (int i = sizes[0] - 2; i < sizes[0]; i++) { + for (int i = 2; i < sizes[0]; i++) { Value selectIndex = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index e73fb1e88dc4..0648508f75bb 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -2589,68 +2589,58 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { }; } // namespace -static Value NearestInterpolate(OpBuilder &b, Location loc, Value outputSizeH, - Value outputSizeW, Value input, - Value inputSizeH, Value inputSizeW) { +static Value NearestInterpolate(OpBuilder &b, Location loc, + SmallVector outputSizes, Value input, + SmallVector inputSizes) { auto inputType = input.getType().cast(); auto inputRank = inputType.getRank(); - Value yOut = b.create(loc, 2); - Value xOut = b.create(loc, 3); - - Value inputHFP = b.create(loc, b.getF32Type(), inputSizeH); - Value inputWFP = b.create(loc, b.getF32Type(), inputSizeW); + SmallVector indices; + for (unsigned i = 0; i < inputRank; i++) { + indices.push_back(b.create(loc, i)); + } - Value outputSizeHFP = - b.create(loc, b.getF32Type(), 
outputSizeH); - Value outputSizeWFP = - b.create(loc, b.getF32Type(), outputSizeW); + for (unsigned i = 2; i < inputRank; i++) { + Value outIndex = indices[i]; - // scale = length_resized / length_original - // x_original = x_resized / scale - Value hScale = b.create(loc, outputSizeHFP, inputHFP); - Value wScale = b.create(loc, outputSizeWFP, inputWFP); + Value inputSizeFP = + b.create(loc, b.getF32Type(), inputSizes[i - 2]); - Value yOutInt = b.create(loc, b.getI64Type(), yOut); - Value yOutFP = b.create(loc, b.getF32Type(), yOutInt); - Value yProj = b.create(loc, yOutFP, hScale); + Value outputSizeFP = + b.create(loc, b.getF32Type(), outputSizes[i - 2]); - Value xOutInt = b.create(loc, b.getI64Type(), xOut); - Value xOutFP = b.create(loc, b.getF32Type(), xOutInt); - Value xProj = b.create(loc, xOutFP, wScale); + // scale = length_resized / length_original + // x_original = x_resized / scale + Value scale = b.create(loc, outputSizeFP, inputSizeFP); - // get nearest pixel using floor - Value yNearestFP = b.create(loc, yProj); - Value xNearestFP = b.create(loc, xProj); + Value outInt = b.create(loc, b.getI64Type(), outIndex); + Value outFP = b.create(loc, b.getF32Type(), outInt); + Value proj = b.create(loc, outFP, scale); - Value yNearestInt = - b.create(loc, b.getI64Type(), yNearestFP); - Value yNearest = - b.create(loc, b.getIndexType(), yNearestInt); + // get nearest pixel using floor + Value nearestFP = b.create(loc, proj); - Value xNearestInt = - b.create(loc, b.getI64Type(), xNearestFP); - Value xNearest = - b.create(loc, b.getIndexType(), xNearestInt); + Value nearestInt = + b.create(loc, b.getI64Type(), nearestFP); + Value nearest = + b.create(loc, b.getIndexType(), nearestInt); - SmallVector indices; - for (unsigned i = 0; i < inputRank; i++) { - indices.push_back(b.create(loc, i)); + indices[i] = nearest; } - - int hDimOffset = 2; - indices[hDimOffset] = yNearest; - indices[hDimOffset + 1] = xNearest; Value retVal = b.create(loc, input, indices); return 
retVal; } static Value BilinearInterpolate(OpBuilder &b, Aten__InterpolateSizeListScaleListOp op, - Location loc, Value outputSizeH, - Value outputSizeW, Value input, - Value inputSizeH, Value inputSizeW) { + Location loc, SmallVector outputSizes, + Value input, SmallVector inputSizes) { + Value inputSizeH = inputSizes[0]; + Value inputSizeW = inputSizes[1]; + Value outputSizeH = outputSizes[0]; + Value outputSizeW = outputSizes[1]; + int hDimOffset = 2; auto inputType = input.getType().cast(); auto inputRank = inputType.getRank(); @@ -2805,7 +2795,6 @@ static Value BilinearInterpolate(OpBuilder &b, rhs = b.create(loc, w1, xInter1); Value retVal = b.create(loc, lhs, rhs); - return retVal; } @@ -2828,46 +2817,43 @@ class ConvertInterpolateOp Value input = adaptor.getInput(); auto inputType = input.getType().cast(); auto inputRank = inputType.getRank(); + if (mode == "bilinear" && inputRank != 4) + return rewriter.notifyMatchFailure( + op, + "cannot perform bilinear interpolation when input spatial dims != 2"); - SmallVector outputSizeIntValues; - Value inputSizeH = getDimOp(rewriter, loc, input, 2); - inputSizeH = rewriter.create( - loc, rewriter.getIntegerType(64), inputSizeH); - Value inputSizeW = getDimOp(rewriter, loc, input, 3); - inputSizeW = rewriter.create( - loc, rewriter.getIntegerType(64), inputSizeW); + SmallVector outputSizeIntValues; + SmallVector inputSizes; + for (unsigned i = 2; i < inputRank; i++) { + Value inputSize = getDimOp(rewriter, loc, input, i); + inputSizes.push_back(rewriter.create( + loc, rewriter.getIntegerType(64), inputSize)); + } if (!op.getScaleFactor().getType().isa()) { - SmallVector ScaleFactorTorchFloat; + SmallVector ScaleFactorTorchFloat; if (!getListConstructElements(op.getScaleFactor(), ScaleFactorTorchFloat)) return rewriter.notifyMatchFailure( op, "unimplemented: the output_size is not constructed from " "ListConstruct"); - SmallVector ScaleFactorFloatValues; + SmallVector ScaleFactorFloatValues; ScaleFactorFloatValues = 
getTypeConvertedValues( rewriter, loc, getTypeConverter(), ScaleFactorTorchFloat); - Value inputHFP = rewriter.create( - loc, rewriter.getF32Type(), inputSizeH); - Value scale = rewriter.create(loc, inputHFP.getType(), - ScaleFactorFloatValues[0]); - Value outputSizeH = rewriter.create(loc, inputHFP, scale); - Value outputH = rewriter.create(loc, outputSizeH); - outputH = - rewriter.create(loc, rewriter.getI64Type(), outputH); - - Value inputWFP = rewriter.create( - loc, rewriter.getF32Type(), inputSizeW); - scale = rewriter.create(loc, inputWFP.getType(), - ScaleFactorFloatValues[1]); - Value outputSizeW = rewriter.create(loc, inputWFP, scale); - Value outputW = rewriter.create(loc, outputSizeW); - outputW = - rewriter.create(loc, rewriter.getI64Type(), outputW); - - outputSizeIntValues.push_back(outputH); - outputSizeIntValues.push_back(outputW); + for (unsigned i = 0; i < inputRank - 2; i++) { + Value inputSizeFP = rewriter.create( + loc, rewriter.getF32Type(), inputSizes[i]); + Value scale = rewriter.create( + loc, inputSizeFP.getType(), ScaleFactorFloatValues[i]); + Value outputSize = + rewriter.create(loc, inputSizeFP, scale); + outputSize = rewriter.create(loc, outputSize); + outputSize = rewriter.create( + loc, rewriter.getI64Type(), outputSize); + + outputSizeIntValues.push_back(outputSize); + } } else { - SmallVector outputSizeTorchInt; + SmallVector outputSizeTorchInt; if (!getListConstructElements(op.getSize(), outputSizeTorchInt)) return rewriter.notifyMatchFailure( op, "unimplemented: the output_size is not constructed from " @@ -2876,8 +2862,9 @@ class ConvertInterpolateOp rewriter, loc, getTypeConverter(), outputSizeTorchInt); } SmallVector dims = getTensorSizesUntilDim(rewriter, loc, input, 1); - dims.push_back(castIntToIndex(rewriter, loc, outputSizeIntValues[0])); - dims.push_back(castIntToIndex(rewriter, loc, outputSizeIntValues[1])); + for (unsigned i = 2; i < inputRank; i++) { + dims.push_back(castIntToIndex(rewriter, loc, 
outputSizeIntValues[i - 2])); + } Value outTensor = rewriter.create( loc, getAsOpFoldResult(dims), inputType.getElementType()); @@ -2894,17 +2881,13 @@ class ConvertInterpolateOp /*indexingMaps=*/idMap, /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { - Value outputSizeH = outputSizeIntValues[0]; - Value outputSizeW = outputSizeIntValues[1]; Value retVal; if (mode == "nearest") { - retVal = - NearestInterpolate(b, loc, outputSizeH, outputSizeW, - input, inputSizeH, inputSizeW); + retVal = NearestInterpolate(b, loc, outputSizeIntValues, + input, inputSizes); } else if (mode == "bilinear") { - retVal = BilinearInterpolate(b, op, loc, outputSizeH, - outputSizeW, input, inputSizeH, - inputSizeW); + retVal = BilinearInterpolate( + b, op, loc, outputSizeIntValues, input, inputSizes); } b.create(loc, retVal); }) diff --git a/test/Conversion/TorchToLinalg/resize.mlir b/test/Conversion/TorchToLinalg/resize.mlir index 9850a5fdabd6..1f6b69a50af0 100644 --- a/test/Conversion/TorchToLinalg/resize.mlir +++ b/test/Conversion/TorchToLinalg/resize.mlir @@ -94,31 +94,29 @@ func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[GENERIC:.*]] = linalg.generic + // CHECK: %[[x11:.*]] = linalg.index 0 : index + // CHECK: %[[x12:.*]] = linalg.index 1 : index // CHECK: %[[x13:.*]] = linalg.index 2 : index // CHECK: %[[x14:.*]] = linalg.index 3 : index // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32 - // CHECK: %[[x16:.*]] = arith.sitofp %[[c4_i64:.*]] : i64 to f32 // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 - // CHECK: %[[x20:.*]] = arith.sitofp %[[x7:.*]] : i64 to 
f32 // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32 - // CHECK: %[[x22:.*]] = arith.divf %[[x20]], %[[x16]] : f32 // CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64 // CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32 // CHECK: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32 + // CHECK: %[[x29:.*]] = math.floor %[[x25]] : f32 + // CHECK: %[[x31:.*]] = arith.fptosi %[[x29]] : f32 to i64 + // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index + // CHECK: %[[x16:.*]] = arith.sitofp %[[c4_i64:.*]] : i64 to f32 + // CHECK: %[[x20:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32 + // CHECK: %[[x22:.*]] = arith.divf %[[x20]], %[[x16]] : f32 // CHECK: %[[x26:.*]] = arith.index_cast %[[x14]] : index to i64 // CHECK: %[[x27:.*]] = arith.sitofp %[[x26]] : i64 to f32 // CHECK: %[[x28:.*]] = arith.divf %[[x27]], %[[x22]] : f32 - // CHECK: %[[x29:.*]] = math.floor %[[x25]] : f32 // CHECK: %[[x30:.*]] = math.floor %[[x28]] : f32 - // CHECK: %[[x31:.*]] = arith.fptosi %[[x29]] : f32 to i64 - // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index // CHECK: %[[x33:.*]] = arith.fptosi %[[x30]] : f32 to i64 // CHECK: %[[x34:.*]] = arith.index_cast %[[x33]] : i64 to index - // CHECK: %[[x35:.*]] = linalg.index 0 : index - // CHECK: %[[x36:.*]] = linalg.index 1 : index - // CHECK: %[[x37:.*]] = linalg.index 2 : index - // CHECK: %[[x38:.*]] = linalg.index 3 : index - // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x35]], %[[x36]], %[[x32]], %[[x34]]] : tensor<1x1x2x4xf32> + // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x11]], %[[x12]], %[[x32]], %[[x34]]] : tensor<1x1x2x4xf32> // CHECK: linalg.yield %[[extracted]] : f32 %none = torch.constant.none %none_0 = torch.constant.none @@ -136,3 +134,77 @@ func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1 %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, 
!torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> return %5 : !torch.vtensor<[?,?,?,?],f32> } + +// ----- + +func.func @test_resize_nearest_1d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],f32> { + // CHECK: %[[GENERIC:.*]] = linalg.generic + // CHECK: %[[x11:.*]] = linalg.index 0 : index + // CHECK: %[[x12:.*]] = linalg.index 1 : index + // CHECK: %[[x13:.*]] = linalg.index 2 : index + // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32 + // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 + // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32 + // CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64 + // CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32 + // CHECK: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32 + // CHECK: %[[x29:.*]] = math.floor %[[x25]] : f32 + // CHECK: %[[x31:.*]] = arith.fptosi %[[x29]] : f32 to i64 + // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index + // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x11]], %[[x12]], %[[x32]]] : tensor + // CHECK: linalg.yield %[[extracted]] : f32 + %none = torch.constant.none + %none_0 = torch.constant.none + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %true = torch.constant.bool true + %str = torch.constant.str "nearest" + %int2 = torch.constant.int 2 + %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + %4 = torch.prim.ListConstruct %1 : (!torch.int) -> !torch.list + %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[?,?,?],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?],f32> + return %5 : !torch.vtensor<[?,?,?],f32> +} + +// ----- + 
+func.func @test_resize_nearest_3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[5],si64>) -> !torch.vtensor<[?,?,?,?,?],f32> { + // CHECK: %[[GENERIC:.*]] = linalg.generic + // CHECK: %[[x11:.*]] = linalg.index 0 : index + // CHECK: %[[x12:.*]] = linalg.index 1 : index + // CHECK: %[[x13:.*]] = linalg.index 2 : index + // CHECK: %[[x14:.*]] = linalg.index 3 : index + // CHECK: %[[index4:.*]] = linalg.index 4 : index + // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32 + // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 + // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32 + // CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64 + // CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32 + // CHECK: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32 + // CHECK: %[[x29:.*]] = math.floor %[[x25]] : f32 + // CHECK: %[[x31:.*]] = arith.fptosi %[[x29]] : f32 to i64 + // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index + // CHECK: %[[x34:.*]] = arith.index_cast %[[Wfptosi:.*]] : i64 to index + // CHECK: %[[x35:.*]] = arith.index_cast %[[Dfptosi:.*]] : i64 to index + // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x11]], %[[x12]], %[[x32]], %[[x34]], %[[x35]]] : tensor + // CHECK: linalg.yield %[[extracted]] : f32 + %none = torch.constant.none + %none_0 = torch.constant.none + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %true = torch.constant.bool true + %str = torch.constant.str "nearest" + %int2 = torch.constant.int 2 + %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[5],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + %int3 = torch.constant.int 3 + %2 = torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[5],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %3 = torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int + %int4 = torch.constant.int 4 + 
%4 = torch.aten.select.int %arg1, %int0, %int4 : !torch.vtensor<[5],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %5 = torch.aten.item %4 : !torch.vtensor<[1],si64> -> !torch.int + %6 = torch.prim.ListConstruct %1, %3, %5: (!torch.int, !torch.int, !torch.int) -> !torch.list + %7 = torch.aten.__interpolate.size_list_scale_list %arg0, %6, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[?,?,?,?,?],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?,?],f32> + return %7 : !torch.vtensor<[?,?,?,?,?],f32> + } From 8d6a5ffcbc0cf8ee55564f5566aa8f868c1ff297 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Mon, 29 Apr 2024 10:51:17 +0800 Subject: [PATCH 04/11] [Torch] emit aten.__contains__.str_list and add folder (#3249) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 25 +++++++++++++++ .../torch-mlir/Dialect/Torch/IR/TorchOps.h | 31 +++++++++++++++++++ lib/Dialect/Torch/IR/TorchOps.cpp | 24 ++++++++++++++ .../build_tools/torch_ods_gen.py | 1 + test/Dialect/Torch/canonicalize.mlir | 28 +++++++++++++++-- 5 files changed, 107 insertions(+), 2 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index c7ce2f39eb6d..ca7a28b156b2 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -13278,6 +13278,31 @@ def Torch_AtenWarnOp : Torch_Op<"aten.warn", [ }]; } +def Torch_Aten__Contains__StrListOp : Torch_Op<"aten.__contains__.str_list", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::__contains__.str_list : (str[], str) -> (bool)`"; + let arguments = (ins + AnyTorchListOfTorchStringType:$l, + Torch_StringType:$item + ); + let results = (outs + Torch_BoolType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult 
Aten__Contains__StrListOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void Aten__Contains__StrListOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasFolder = 1; +} + def Torch_AtenFloatScalarOp : Torch_Op<"aten.Float.Scalar", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h index e6a9e1622cc1..dbeb2f522b33 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h @@ -239,6 +239,37 @@ m_TorchListOfConstantBools(SmallVectorImpl &bind_values) { return detail::torch_list_of_constant_bools_op_binder(bind_values); } +namespace detail { +/// Matches the constant strs stored in a `torch.ListConstruct`. +struct torch_list_of_constant_strs_op_binder { + SmallVectorImpl &bind_values; + + /// Creates a matcher instance that binds the value to bvs if match succeeds. + torch_list_of_constant_strs_op_binder(SmallVectorImpl &bvs) + : bind_values(bvs) {} + + bool match(Operation *op) { + auto listConstruct = dyn_cast(op); + if (!listConstruct) + return false; + for (Value value : listConstruct.getElements()) { + std::string str; + if (matchPattern(value, m_TorchConstantStr(str))) + bind_values.push_back(str); + else + return false; + } + return true; + } +}; +} // namespace detail + +/// Matches the constant strs stored in a `torch.prim.ListConstruct`. +inline detail::torch_list_of_constant_strs_op_binder +m_TorchListOfConstantStrs(SmallVectorImpl &bind_values) { + return detail::torch_list_of_constant_strs_op_binder(bind_values); +} + namespace detail { /// Matches the expected tensor and dim from `torch.aten.size.int`. 
struct torch_tensor_size_int_op_binder { diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index fff872b32198..6da620cd61d3 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2103,6 +2103,30 @@ OpFoldResult AtenEqStrOp::fold(FoldAdaptor adaptor) { return nullptr; } +//===----------------------------------------------------------------------===// +// Aten__Contains__StrListOp +//===----------------------------------------------------------------------===// + +OpFoldResult Aten__Contains__StrListOp::fold(FoldAdaptor adaptor) { + StringAttr item = dyn_cast(adaptor.getItem()); + if (!item) + return nullptr; + + if (auto listConstruct = getL().getDefiningOp()) { + if (isListPotentiallyMutated(listConstruct)) + return nullptr; + } + llvm::SmallVector strs; + if (matchPattern(getL(), m_TorchListOfConstantStrs(strs))) { + for (const auto &str : strs) { + if (item.getValue().str() == str) + return getI1IntegerAttr(getContext(), true); + } + return getI1IntegerAttr(getContext(), false); + } + return nullptr; +} + //===----------------------------------------------------------------------===// // AtenLtIntOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 6096afcfc195..58afa0c4747d 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -757,6 +757,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::format : (...) -> (str)") emit("aten::join : (str, str[]) -> (str)") emit("aten::warn : (str, int) -> ()") + emit("aten::__contains__.str_list : (str[], str) -> (bool)", has_folder=True) # Type conversion ops. 
emit("aten::Float.Scalar : (Scalar) -> (float)", has_folder=True) diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index a607365f4918..a1db60e43c40 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -504,8 +504,8 @@ func.func @torch.aten.eq.str$different_value() -> !torch.bool { // CHECK-LABEL: func.func @torch.aten.eq.str$same_operand( // CHECK-SAME: %{{.*}}: !torch.str) -> !torch.bool { -// CHECK-NEXT: %[[F:.*]] = torch.constant.bool true -// CHECK-NEXT: return %[[F]] : !torch.bool +// CHECK-NEXT: %[[TRUE:.*]] = torch.constant.bool true +// CHECK-NEXT: return %[[TRUE]] : !torch.bool func.func @torch.aten.eq.str$same_operand(%arg0: !torch.str) -> !torch.bool { %0 = torch.aten.eq.str %arg0, %arg0 : !torch.str, !torch.str -> !torch.bool return %0 : !torch.bool @@ -539,6 +539,30 @@ func.func @torch.aten.len.str$empty() -> !torch.int { return %2 : !torch.int } +// CHECK-LABEL: func.func @torch.aten.__contains__.str_list$false() -> !torch.bool { +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: return %[[FALSE]] : !torch.bool +func.func @torch.aten.__contains__.str_list$false() -> !torch.bool { + %str = torch.constant.str "c" + %str_0 = torch.constant.str "b" + %str_1 = torch.constant.str "a" + %1 = torch.prim.ListConstruct %str_1, %str_0 : (!torch.str, !torch.str) -> !torch.list + %2 = torch.aten.__contains__.str_list %1, %str : !torch.list, !torch.str -> !torch.bool + return %2 : !torch.bool +} + +// CHECK-LABEL: func.func @torch.aten.__contains__.str_list$true() -> !torch.bool { +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[TRUE]] : !torch.bool +func.func @torch.aten.__contains__.str_list$true() -> !torch.bool { + %str = torch.constant.str "aa" + %str_0 = torch.constant.str "aa" + %str_1 = torch.constant.str "ccc" + %1 = torch.prim.ListConstruct %str_1, %str_0 : (!torch.str, !torch.str) -> !torch.list + %2 = torch.aten.__contains__.str_list 
%1, %str : !torch.list, !torch.str -> !torch.bool + return %2 : !torch.bool +} + // CHECK-LABEL: func.func @torch.aten.__not__ // CHECK: %[[TRUE:.*]] = torch.constant.bool true // CHECK: return %[[TRUE]] : !torch.bool From c90ce0d920a8d18875bc470bdbc1b49fcbb4931c Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Thu, 30 May 2024 19:34:37 -0500 Subject: [PATCH 05/11] Modifies onnx resize lowering to fix numerical issues (#3381) Updates: - some unsupported modes are now going to report a match failure for unsupported coordinate transformation modes. - fixes a bug that was introduced in the last patch for resize (my bad...) - uses actual x and y coordinates for computing weights in bilinear interpolation (rather than eps modified values) - slightly simplifies the bilinear interpolation payload for readability and performance - passes coordinate transformation mode information from an onnx.Resize op to the mode string for the aten._interpolate op. This allows us to perform custom logic in the torch->linalg lowering to support onnx.Resize options without losing the default behaviors of the interpolate op. 
--- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 68 ++-- .../TorchToLinalg/Uncategorized.cpp | 298 +++++++++--------- lib/Dialect/Torch/IR/TorchOps.cpp | 2 +- projects/pt1/e2e_testing/xfail_sets.py | 10 + .../test_suite/reshape_like.py | 96 +++++- test/Conversion/TorchToLinalg/resize.mlir | 82 +---- 6 files changed, 307 insertions(+), 249 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 35a3204b7e36..670638711ca9 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2140,12 +2140,18 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( coordTfMode, "coordinate_transformation_mode", "half_pixel") || binder.customOpNameStringAttr(nearest_mode, "nearest_mode", "")) return failure(); - + if (coordTfMode == "tf_crop_and_resize") + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: coordinate transformation mode: " + "tf_crop_and_resize"); if (mode == "nearest" && nearest_mode != "floor") { return rewriter.notifyMatchFailure( binder.op, "unimplemented: support not present for nearest_mode " "except floor"); } + unsigned rank = dyn_cast(operands[0].getType()) + .getSizes() + .size(); Value zero = rewriter.create( binder.getLoc(), rewriter.getType(), @@ -2207,36 +2213,54 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value sizesValueList = noneVal; Value alignCorners = coordTfMode == "align_corners" ? 
cstTrue : cstFalse; - if (mode == "cubic") { return rewriter.notifyMatchFailure(binder.op, "unimplemented: bicubic mode"); } + // supported modes: + // bilinear (half_pixel), bilinear with align_corners, + // bilinear_pytorch_half_pixel, bilinear_asymmetric nearest + // (asymmetric), nearest with align_corners, nearest_half_pixel, + // nearest_pytorch_half_pixel if (mode == "linear") { - modeStrValue = rewriter.create(binder.getLoc(), - "bilinear"); - if (operands.size() < 4) { - Value scaleOperand = operands[2]; - scalesValueList = getValueList(scaleOperand); - sizesValueList = noneVal; - } else { - Value sizeOperand = operands[3]; - scalesValueList = noneVal; - sizesValueList = getValueList(sizeOperand); + std::string modeStr; + switch (rank) { + case 3: + modeStr = "linear"; + break; + case 4: + modeStr = "bilinear"; + break; + case 5: + modeStr = "trilinear"; + break; + default: + return failure(); } + // Confusingly enough, the default coordTfMode for pytorch bilinear + // mode is apparently half_pixel, NOT pytorch_half_pixel + if (coordTfMode != "half_pixel" && coordTfMode != "align_corners") + modeStr = (modeStr + "_") + coordTfMode; + modeStrValue = + rewriter.create(binder.getLoc(), modeStr); } if (mode == "nearest") { + std::string modeStr = "nearest"; + // The default coordTfMode for pytorch with mode = nearest is + // apparently asymmetric + if (coordTfMode != "asymmetric" && coordTfMode != "align_corners") + modeStr = (modeStr + "_") + coordTfMode; modeStrValue = - rewriter.create(binder.getLoc(), "nearest"); - if (operands.size() < 4) { - Value scaleOperand = operands[2]; - scalesValueList = getValueList(scaleOperand); - sizesValueList = noneVal; - } else { - Value sizesOperand = operands[3]; - scalesValueList = noneVal; - sizesValueList = getValueList(sizesOperand); - } + rewriter.create(binder.getLoc(), modeStr); + } + if (operands.size() < 4) { + Value scaleOperand = operands[2]; + scalesValueList = getValueList(scaleOperand); + sizesValueList = 
noneVal; + } else { + Value sizeOperand = operands[3]; + scalesValueList = noneVal; + sizesValueList = getValueList(sizeOperand); } if (scalesValueList.getType().isa() && sizesValueList.getType().isa()) { diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 0648508f75bb..9a4e9c7ffd02 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -2591,7 +2591,9 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { static Value NearestInterpolate(OpBuilder &b, Location loc, SmallVector outputSizes, Value input, - SmallVector inputSizes) { + SmallVector inputSizes, + SmallVector scaleValues, + std::string coordStr) { auto inputType = input.getType().cast(); auto inputRank = inputType.getRank(); @@ -2612,7 +2614,11 @@ static Value NearestInterpolate(OpBuilder &b, Location loc, // scale = length_resized / length_original // x_original = x_resized / scale - Value scale = b.create(loc, outputSizeFP, inputSizeFP); + Value scale; + if (scaleValues.empty()) + scale = b.create(loc, outputSizeFP, inputSizeFP); + else + scale = scaleValues[i - 2]; Value outInt = b.create(loc, b.getI64Type(), outIndex); Value outFP = b.create(loc, b.getF32Type(), outInt); @@ -2635,167 +2641,139 @@ static Value NearestInterpolate(OpBuilder &b, Location loc, static Value BilinearInterpolate(OpBuilder &b, Aten__InterpolateSizeListScaleListOp op, Location loc, SmallVector outputSizes, - Value input, SmallVector inputSizes) { - Value inputSizeH = inputSizes[0]; - Value inputSizeW = inputSizes[1]; - Value outputSizeH = outputSizes[0]; - Value outputSizeW = outputSizes[1]; - - int hDimOffset = 2; + Value input, SmallVector inputSizes, + SmallVector scaleValues, + std::string coordStr) { + unsigned dimOffset = 2; auto inputType = input.getType().cast(); auto inputRank = inputType.getRank(); - Value cstOneEps = b.create(loc, b.getF32FloatAttr(1.001)); + Value cstOneEps = + 
b.create(loc, b.getF32FloatAttr(1.000001)); Value cstOneFloat = b.create(loc, b.getF32FloatAttr(1.0)); Value cstHalf = b.create(loc, b.getF32FloatAttr(0.5)); Value zero = b.create(loc, b.getF32FloatAttr(0.0)); - Value yOut = b.create(loc, 2); - Value xOut = b.create(loc, 3); - bool alignCornersBool; matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool)); - Value yProj, xProj; - if (alignCornersBool) { - // x_original = x_resized * (length_original - 1) / (length_resized - 1) - Value inputHFP = b.create(loc, b.getF32Type(), inputSizeH); - Value outputSizeHFP = - b.create(loc, b.getF32Type(), outputSizeH); - Value yOutInt = b.create(loc, b.getI64Type(), yOut); - Value yOutFP = b.create(loc, b.getF32Type(), yOutInt); - Value inputHSubOne = b.create(loc, inputHFP, cstOneFloat); - Value outputSizeHSubOne = - b.create(loc, outputSizeHFP, cstOneFloat); - Value hScale = - b.create(loc, inputHSubOne, outputSizeHSubOne); - Value yProjBeforeClamp = b.create(loc, yOutFP, hScale); - Value yMax = b.create(loc, yProjBeforeClamp, zero); - Value outputSizeHSubOneEps = - b.create(loc, outputSizeHFP, cstOneEps); - yProj = b.create(loc, outputSizeHSubOneEps, yMax); - - Value inputWFP = b.create(loc, b.getF32Type(), inputSizeW); - Value outputSizeWFP = - b.create(loc, b.getF32Type(), outputSizeW); - Value xOutInt = b.create(loc, b.getI64Type(), xOut); - Value xOutFP = b.create(loc, b.getF32Type(), xOutInt); - Value inputWSubOne = b.create(loc, inputWFP, cstOneFloat); - Value outputSizeWSubOne = - b.create(loc, outputSizeWFP, cstOneFloat); - Value wScale = - b.create(loc, inputWSubOne, outputSizeWSubOne); - Value xProjBeforeClamp = b.create(loc, xOutFP, wScale); - Value xMax = b.create(loc, xProjBeforeClamp, zero); - Value outputSizeWSubOneEps = - b.create(loc, outputSizeWFP, cstOneEps); - xProj = b.create(loc, outputSizeWSubOneEps, xMax); - } else { - // y_original = (y_resized + 0.5) / scale - 0.5 - Value inputHFP = b.create(loc, b.getF32Type(), inputSizeH); - 
Value outputSizeHFP = - b.create(loc, b.getF32Type(), outputSizeH); - Value hScale = b.create(loc, outputSizeHFP, inputHFP); - Value yOutInt = b.create(loc, b.getI64Type(), yOut); - Value yOutFP = b.create(loc, b.getF32Type(), yOutInt); - Value yPlusHalf = b.create(loc, yOutFP, cstHalf); - Value yDivScale = b.create(loc, yPlusHalf, hScale); - Value ySubHalf = b.create(loc, yDivScale, cstHalf); - Value yMax = b.create(loc, ySubHalf, zero); - Value inputHSubOne = b.create(loc, inputHFP, cstOneEps); - yProj = b.create(loc, yMax, inputHSubOne); - - Value inputWFP = b.create(loc, b.getF32Type(), inputSizeW); - Value outputSizeWFP = - b.create(loc, b.getF32Type(), outputSizeW); - Value wScale = b.create(loc, outputSizeWFP, inputWFP); - Value xOutInt = b.create(loc, b.getI64Type(), xOut); - Value xOutFP = b.create(loc, b.getF32Type(), xOutInt); - Value xPlusHalf = b.create(loc, xOutFP, cstHalf); - Value xDivScale = b.create(loc, xPlusHalf, wScale); - Value xSubHalf = b.create(loc, xDivScale, cstHalf); - // clamp - Value xMax = b.create(loc, xSubHalf, zero); - Value inputWSubOne = b.create(loc, inputWFP, cstOneEps); - xProj = b.create(loc, xMax, inputWSubOne); - } - Value yLow = b.create(loc, yProj); - Value yProjPlusOne = b.create(loc, cstOneFloat, yProj); - Value yHigh = b.create(loc, yProjPlusOne); - - Value xLow = b.create(loc, xProj); - Value xProjPlusOne = b.create(loc, cstOneFloat, xProj); - Value xHigh = b.create(loc, xProjPlusOne); - SmallVector indices; for (unsigned i = 0; i < inputRank; i++) { indices.push_back(b.create(loc, i)); } - Value yLowInt = b.create(loc, b.getI64Type(), yLow); - Value yLowIdx = b.create(loc, b.getIndexType(), yLowInt); - - Value xLowInt = b.create(loc, b.getI64Type(), xLow); - Value xLowIdx = b.create(loc, b.getIndexType(), xLowInt); - - Value yHighInt = b.create(loc, b.getI64Type(), yHigh); - Value yHighIdx = - b.create(loc, b.getIndexType(), yHighInt); - Value xHighInt = b.create(loc, b.getI64Type(), xHigh); - Value xHighIdx = - 
b.create(loc, b.getIndexType(), xHighInt); - - indices[hDimOffset] = yLowIdx; - indices[hDimOffset + 1] = xLowIdx; + SmallVector proj, projEps, high, low, highFP, lowFP; + for (unsigned i = 0; i < inputRank - dimOffset; i++) { + // length_original + Value inputFP = + b.create(loc, b.getF32Type(), inputSizes[i]); + // length_resized + Value outputSizeFP = + b.create(loc, b.getF32Type(), outputSizes[i]); + // scale = length_resized/length_original + Value scale; + if (alignCornersBool) { + // x_original = x_resized * (length_original - 1) / (length_resized - 1) + Value inputSubOne = b.create(loc, inputFP, cstOneFloat); + Value outputSizeSubOne = + b.create(loc, outputSizeFP, cstOneFloat); + Value cmp = b.create(loc, arith::CmpFPredicate::UEQ, + outputSizeSubOne, zero); + scale = b.create(loc, inputSubOne, outputSizeSubOne); + scale = b.create(loc, cmp, zero, scale); + coordStr = "_align_corners"; + } else if (scaleValues.empty()) + scale = b.create(loc, outputSizeFP, inputFP); + else + scale = scaleValues[i]; + // y_resized + Value outInt = b.create(loc, b.getI64Type(), + indices[i + dimOffset]); + Value outFP = b.create(loc, b.getF32Type(), outInt); + Value preClip; + if (coordStr == "_align_corners") { + preClip = b.create(loc, outFP, scale); + } + if (coordStr == "_asymmetric") { + preClip = b.create(loc, outFP, scale); + } + if (coordStr == "_pytorch_half_pixel" || coordStr == "") { + // half-pixel modes + // y_resized + 0.5 + Value outPlusHalf = b.create(loc, outFP, cstHalf); + // (y_resized + 0.5) / scale + Value outDivScale = b.create(loc, outPlusHalf, scale); + // _ - 0.5 + preClip = b.create(loc, outDivScale, cstHalf); + } + // for pytorch half pixel , special case for length_resized == 1: + if (coordStr == "_pytorch_half_pixel") { + Value cmp = b.create(loc, arith::CmpFPredicate::UEQ, + outputSizeFP, cstOneFloat); + preClip = b.create(loc, cmp, zero, preClip); + } + // clip to 0,inf + Value max = b.create(loc, preClip, zero); + // length_original - 1.001 + 
Value inputSubOneEps = b.create(loc, inputFP, cstOneEps); + Value inputSubOne = b.create(loc, inputFP, cstOneFloat); + // clip to [0,length_original - 1.001] + projEps.push_back(b.create(loc, max, inputSubOneEps)); + proj.push_back(b.create(loc, max, inputSubOne)); + + lowFP.push_back(b.create(loc, projEps[i])); + Value projPlusOne = b.create(loc, cstOneFloat, projEps[i]); + highFP.push_back(b.create(loc, projPlusOne)); + + Value lowInt = b.create(loc, b.getI64Type(), lowFP[i]); + low.push_back(b.create(loc, b.getIndexType(), lowInt)); + + Value highInt = b.create(loc, b.getI64Type(), highFP[i]); + high.push_back( + b.create(loc, b.getIndexType(), highInt)); + } + + SmallVector cornerValues; + indices[dimOffset] = low[0]; + indices[dimOffset + 1] = low[1]; Value p00 = b.create(loc, input, indices); - indices[hDimOffset] = yLowIdx; - indices[hDimOffset + 1] = xHighIdx; + indices[dimOffset] = low[0]; + indices[dimOffset + 1] = high[1]; Value p01 = b.create(loc, input, indices); - indices[hDimOffset] = yHighIdx; - indices[hDimOffset + 1] = xLowIdx; + indices[dimOffset] = high[0]; + indices[dimOffset + 1] = low[1]; Value p10 = b.create(loc, input, indices); - indices[hDimOffset] = yHighIdx; - indices[hDimOffset + 1] = xHighIdx; + indices[dimOffset] = high[0]; + indices[dimOffset + 1] = high[1]; Value p11 = b.create(loc, input, indices); - // p00 p01 - // p10 p11 - // (xhigh - xproj) / (xhigh - xlow) * p00 + (xproj - xlow) / - // (xhigh - xlow) * p01 - Value xHighMinusxProj = b.create(loc, xHigh, xProj); - Value xHighMinusxLow = b.create(loc, xHigh, xLow); - Value w0 = b.create(loc, xHighMinusxProj, xHighMinusxLow); - Value lhs = b.create(loc, w0, p00); - - Value xProjMinusxLow = b.create(loc, xProj, xLow); - Value w1 = b.create(loc, xProjMinusxLow, xHighMinusxLow); - Value rhs = b.create(loc, w1, p01); - - Value xInter = b.create(loc, lhs, rhs); - - // (xhigh - xproj) / (xhigh - xlow) * p10 + (xproj - xlow) / - // (xhigh - xlow) * p11 - lhs = b.create(loc, w0, p10); - 
rhs = b.create(loc, w1, p11); - - Value xInter1 = b.create(loc, lhs, rhs); - - // (yhigh - yproj) / (yhigh - ylow) * xInter + (yproj - ylow) - // / (yhigh - ylow) * xInter1 - Value yHighMinusyProj = b.create(loc, yHigh, yProj); - Value yHighMinusyLow = b.create(loc, yHigh, yLow); - w0 = b.create(loc, yHighMinusyProj, yHighMinusyLow); - lhs = b.create(loc, w0, xInter); - - Value yProjMinusyLow = b.create(loc, yProj, yLow); - w1 = b.create(loc, yProjMinusyLow, yHighMinusyLow); - rhs = b.create(loc, w1, xInter1); - - Value retVal = b.create(loc, lhs, rhs); - return retVal; + // Let Aij := area rect((yProj,xProj) <-> (y_i*,x_j*)), + // where i* = i+1 mod 2 and x_0 = xLow, x_1 = xHigh etc. + // We interpolate via the weighted average of pij by weights Aij + // the formula is retval = Sum(pij*Aij for i and j in range(2)) + // Note: we do not need to divide by total rect area == 1 + + // lengths : Aij == dyi*dxj + Value dy0 = b.create(loc, highFP[0], proj[0]); + Value dy1 = b.create(loc, proj[0], lowFP[0]); + Value dx0 = b.create(loc, highFP[1], proj[1]); + Value dx1 = b.create(loc, proj[1], lowFP[1]); + + // left = A00*p00 + A01*p01 = dy0(dx0p00 + dx1p01) + Value dx0p00 = b.create(loc, dx0, p00); + Value dx1p01 = b.create(loc, dx1, p01); + Value sum = b.create(loc, dx0p00, dx1p01); + Value left = b.create(loc, dy0, sum); + // right = A10*p10 + A11*p11 = dy1(dx0p10 + dx1p11) + Value dx0p10 = b.create(loc, dx0, p10); + Value dx1p11 = b.create(loc, dx1, p11); + sum = b.create(loc, dx0p10, dx1p11); + Value right = b.create(loc, dy1, sum); + + return b.create(loc, left, right); } namespace { @@ -2808,8 +2786,12 @@ class ConvertInterpolateOp ConversionPatternRewriter &rewriter) const override { std::string mode; + // note: to support onnx.Resize, we are passing some extra options through + // the mode attribute. 
For example, onnx.Resize with mode="linear" and + // coordinate_transformation_mode="asymmetric" will lower to an interpolate + // op with the non-standard mode="bilinear_asymmetric". matchPattern(op.getMode(), m_TorchConstantStr(mode)); - if (mode != "bilinear" && mode != "nearest") { + if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest") { return failure(); } @@ -2817,41 +2799,46 @@ class ConvertInterpolateOp Value input = adaptor.getInput(); auto inputType = input.getType().cast(); auto inputRank = inputType.getRank(); - if (mode == "bilinear" && inputRank != 4) + if (mode.substr(0, 8) == "bilinear" && inputRank != 4) return rewriter.notifyMatchFailure( op, "cannot perform bilinear interpolation when input spatial dims != 2"); SmallVector outputSizeIntValues; SmallVector inputSizes; + SmallVector ScaleFactorFloatValues; for (unsigned i = 2; i < inputRank; i++) { - Value inputSize = getDimOp(rewriter, loc, input, 2); + Value inputSize = getDimOp(rewriter, loc, input, i); inputSizes.push_back(rewriter.create( loc, rewriter.getIntegerType(64), inputSize)); } if (!op.getScaleFactor().getType().isa()) { + bool recompScale; + if (!matchPattern(op.getRecomputeScaleFactor(), + m_TorchConstantBool(&recompScale))) + recompScale = false; SmallVector ScaleFactorTorchFloat; if (!getListConstructElements(op.getScaleFactor(), ScaleFactorTorchFloat)) return rewriter.notifyMatchFailure( op, "unimplemented: the output_size is not constructed from " "ListConstruct"); - SmallVector ScaleFactorFloatValues; ScaleFactorFloatValues = getTypeConvertedValues( rewriter, loc, getTypeConverter(), ScaleFactorTorchFloat); for (unsigned i = 0; i < inputRank - 2; i++) { Value inputSizeFP = rewriter.create( loc, rewriter.getF32Type(), inputSizes[i]); - Value scale = rewriter.create( + ScaleFactorFloatValues[i] = rewriter.create( loc, inputSizeFP.getType(), ScaleFactorFloatValues[i]); - Value outputSize = - rewriter.create(loc, inputSizeFP, scale); + Value outputSize = 
rewriter.create( + loc, inputSizeFP, ScaleFactorFloatValues[i]); outputSize = rewriter.create(loc, outputSize); outputSize = rewriter.create( loc, rewriter.getI64Type(), outputSize); - outputSizeIntValues.push_back(outputSize); } + if (recompScale) + ScaleFactorFloatValues.clear(); } else { SmallVector outputSizeTorchInt; if (!getListConstructElements(op.getSize(), outputSizeTorchInt)) @@ -2868,12 +2855,9 @@ class ConvertInterpolateOp Value outTensor = rewriter.create( loc, getAsOpFoldResult(dims), inputType.getElementType()); - AffineMap idMap = rewriter.getMultiDimIdentityMap(inputRank); - SmallVector iteratorTypes( inputRank, utils::IteratorType::parallel); - Value finalRes = rewriter .create( @@ -2882,12 +2866,14 @@ class ConvertInterpolateOp /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value retVal; - if (mode == "nearest") { - retVal = NearestInterpolate(b, loc, outputSizeIntValues, - input, inputSizes); - } else if (mode == "bilinear") { + if (mode.substr(0, 7) == "nearest") { + retVal = NearestInterpolate( + b, loc, outputSizeIntValues, input, inputSizes, + ScaleFactorFloatValues, mode.substr(7)); + } else if (mode.substr(0, 8) == "bilinear") { retVal = BilinearInterpolate( - b, op, loc, outputSizeIntValues, input, inputSizes); + b, op, loc, outputSizeIntValues, input, inputSizes, + ScaleFactorFloatValues, mode.substr(8)); } b.create(loc, retVal); }) diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 6da620cd61d3..a70e8368720b 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2108,7 +2108,7 @@ OpFoldResult AtenEqStrOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult Aten__Contains__StrListOp::fold(FoldAdaptor adaptor) { - StringAttr item = dyn_cast(adaptor.getItem()); + StringAttr item = dyn_cast_or_null(adaptor.getItem()); if (!item) return nullptr; diff --git 
a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index b6160f54c39b..6ec35e9576c4 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -29,6 +29,12 @@ "IscloseStaticModule_basic", "IscloseStaticModuleTrue_basic", "SplitWithSizes_Module_basic", + # lowering to torch backend IR fails due to unsupported op: aten.upsample_[mode/dims].vec + # these interpolate tests are added specifically to test onnx.Resize. + "InterpolateDynamicModule_sizes_bilinear", + "InterpolateDynamicModule_sizes_nearest", + "InterpolateStaticModule_scales_bilinear_align_corners", + "InterpolateDynamicModule_scales_recompute_bilinear", } @@ -1814,6 +1820,10 @@ "IndexPutImpl3DFloatAccumulateModule_basic", "IndexPutImpl3DFloatNonAccumulateModule_basic", "IndexPutImplIndexWithNoneModule_basic", + "InterpolateDynamicModule_sizes_bilinear", + "InterpolateDynamicModule_sizes_nearest", + "InterpolateStaticModule_scales_bilinear_align_corners", + "InterpolateDynamicModule_scales_recompute_bilinear", "IntFloatModule_basic", "IntImplicitModule_basic", "IouOfModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py index 73371058cf46..a5dabd018cc5 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py @@ -1099,4 +1099,98 @@ def forward(self, tensor1, tensor2): @register_test_case(module_factory=lambda: EinsumStaticContractRhsModule()) def EinsumStaticContractRhsModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 4, 5), tu.rand(4, 5)) \ No newline at end of file + module.forward(tu.rand(3, 4, 5), tu.rand(4, 5)) + + +class InterpolateModule(torch.nn.Module): + def __init__( + self, + size=None, + scale_factor=None, + mode="nearest", + align_corners=None, + recompute_scale_factor=None, + antialias=False, + ): + 
self.size = size + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + self.recompute_scale_factor = recompute_scale_factor + self.antialias = antialias + super().__init__() + + def _forward(self, input): + return torch.nn.functional.interpolate( + input, + size=self.size, + scale_factor=self.scale_factor, + mode=self.mode, + align_corners=self.align_corners, + recompute_scale_factor=self.recompute_scale_factor, + antialias=self.antialias, + ) + + +class InterpolateStaticModule(InterpolateModule): + @export + @annotate_args( + [ + None, + ([1, 1, 4, 5], torch.float32, True), + ] + ) + def forward(self, input): + return self._forward(input) + + +class InterpolateDynamicModule(InterpolateModule): + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, input): + return self._forward(input) + + +@register_test_case( + module_factory=lambda: InterpolateStaticModule( + scale_factor=0.41, mode="bilinear", align_corners=True + ) +) +def InterpolateStaticModule_scales_bilinear_align_corners(module, tu: TestUtils): + input = torch.arange(20).to(dtype=torch.float32) + input = input.reshape((1, 1, 4, 5)) + module.forward(input) + + +@register_test_case( + module_factory=lambda: InterpolateDynamicModule(size=(2, 7), mode="nearest") +) +def InterpolateDynamicModule_sizes_nearest(module, tu: TestUtils): + input = torch.arange(20).to(dtype=torch.float32) + input = input.reshape((1, 1, 4, 5)) + module.forward(input) + + +@register_test_case( + module_factory=lambda: InterpolateDynamicModule(size=(2, 7), mode="bilinear") +) +def InterpolateDynamicModule_sizes_bilinear(module, tu: TestUtils): + input = torch.arange(20).to(dtype=torch.float32) + input = input.reshape((1, 1, 4, 5)) + module.forward(input) + + +@register_test_case( + module_factory=lambda: InterpolateDynamicModule( + scale_factor=(1.9, 2.4), mode="bilinear", recompute_scale_factor=True + ) +) +def 
InterpolateDynamicModule_scales_recompute_bilinear(module, tu: TestUtils): + input = torch.arange(20).to(dtype=torch.float32) + input = input.reshape((1, 1, 4, 5)) + module.forward(input) diff --git a/test/Conversion/TorchToLinalg/resize.mlir b/test/Conversion/TorchToLinalg/resize.mlir index 1f6b69a50af0..542f251c6024 100644 --- a/test/Conversion/TorchToLinalg/resize.mlir +++ b/test/Conversion/TorchToLinalg/resize.mlir @@ -4,75 +4,19 @@ func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4] ,si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[generic:.*]] = linalg.generic - // CHECK: %[[cst:.*]] = arith.constant 1.001000e+00 : f32 - // CHECK: %[[cst_4:.*]] = arith.constant 1.000000e+00 : f32 - // CHECK: %[[cst_5:.*]] = arith.constant 5.000000e-01 : f32 - // CHECK: %[[cst_6:.*]] = arith.constant 0.000000e+00 : f32 - // CHECK: %[[x13:.*]] = linalg.index 2 : index - // CHECK: %[[x14:.*]] = linalg.index 3 : index - // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32 - // CHECK: %[[x16:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 - // CHECK: %[[x17:.*]] = arith.divf %[[x16]], %[[x15]] : f32 - // CHECK: %[[x18:.*]] = arith.index_cast %[[x13]] : index to i64 - // CHECK: %[[x19:.*]] = arith.sitofp %[[x18]] : i64 to f32 - // CHECK: %[[x20:.*]] = arith.addf %[[x19]], %[[cst_5]] : f32 - // CHECK: %[[x21:.*]] = arith.divf %[[x20]], %[[x17]] : f32 - // CHECK: %[[x22:.*]] = arith.subf %[[x21]], %[[cst_5]] : f32 - // CHECK: %[[x23:.*]] = arith.maximumf %[[x22]], %[[cst_6]] : f32 - // CHECK: %[[x24:.*]] = arith.subf %[[x15]], %[[cst]] : f32 - // CHECK: %[[x25:.*]] = arith.minimumf %[[x23]], %[[x24]] : f32 - // CHECK: %[[x26:.*]] = arith.sitofp %[[c4_i64:.*]] : i64 to f32 - // CHECK: %[[x27:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32 - // CHECK: 
%[[x28:.*]] = arith.divf %[[x27]], %[[x26]] : f32 - // CHECK: %[[x29:.*]] = arith.index_cast %[[x14]] : index to i64 - // CHECK: %[[x30:.*]] = arith.sitofp %[[x29]] : i64 to f32 - // CHECK: %[[x31:.*]] = arith.addf %[[x30]], %[[cst_5]] : f32 - // CHECK: %[[x32:.*]] = arith.divf %[[x31]], %[[x28]] : f32 - // CHECK: %[[x33:.*]] = arith.subf %[[x32]], %[[cst_5]] : f32 - // CHECK: %[[x34:.*]] = arith.maximumf %[[x33]], %[[cst_6]] : f32 - // CHECK: %[[x35:.*]] = arith.subf %[[x26]], %[[cst]] : f32 - // CHECK: %[[x36:.*]] = arith.minimumf %[[x34]], %[[x35]] : f32 - // CHECK: %[[x37:.*]] = math.floor %[[x25]] : f32 - // CHECK: %[[x38:.*]] = arith.addf %[[cst_4]], %[[x25]] : f32 - // CHECK: %[[x39:.*]] = math.floor %[[x38]] : f32 - // CHECK: %[[x40:.*]] = math.floor %[[x36]] : f32 - // CHECK: %[[x41:.*]] = arith.addf %[[cst_4]], %[[x36]] : f32 - // CHECK: %[[x42:.*]] = math.floor %[[x41]] : f32 - // CHECK: %[[x43:.*]] = linalg.index 0 : index - // CHECK: %[[x44:.*]] = linalg.index 1 : index - // CHECK: %[[x45:.*]] = linalg.index 2 : index - // CHECK: %[[x46:.*]] = linalg.index 3 : index - // CHECK: %[[x47:.*]] = arith.fptosi %[[x37]] : f32 to i64 - // CHECK: %[[x48:.*]] = arith.index_cast %[[x47]] : i64 to index - // CHECK: %[[x49:.*]] = arith.fptosi %[[x40]] : f32 to i64 - // CHECK: %[[x50:.*]] = arith.index_cast %[[x49]] : i64 to index - // CHECK: %[[x51:.*]] = arith.fptosi %[[x39]] : f32 to i64 - // CHECK: %[[x52:.*]] = arith.index_cast %[[x51]] : i64 to index - // CHECK: %[[x53:.*]] = arith.fptosi %[[x42]] : f32 to i64 - // CHECK: %[[x54:.*]] = arith.index_cast %[[x53]] : i64 to index - // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x43]], %[[x44]], %[[x48]], %[[x50]]] : tensor<1x1x2x4xf32> - // CHECK: %[[extracted_7:.*]] = tensor.extract %[[x0]][%[[x43]], %[[x44]], %[[x48]], %[[x54]]] : tensor<1x1x2x4xf32> - // CHECK: %[[extracted_8:.*]] = tensor.extract %[[x0]][%[[x43]], %[[x44]], %[[x52]], %[[x50]]] : tensor<1x1x2x4xf32> - // CHECK: %[[extracted_9:.*]] = 
tensor.extract %[[x0]][%[[x43]], %[[x44]], %[[x52]], %[[x54]]] : tensor<1x1x2x4xf32> - // CHECK: %[[x55:.*]] = arith.subf %[[x42]], %[[x36]] : f32 - // CHECK: %[[x56:.*]] = arith.subf %[[x42]], %[[x40]] : f32 - // CHECK: %[[x57:.*]] = arith.divf %[[x55]], %[[x56]] : f32 - // CHECK: %[[x58:.*]] = arith.mulf %[[x57]], %extracted : f32 - // CHECK: %[[x59:.*]] = arith.subf %[[x36]], %[[x40]] : f32 - // CHECK: %[[x60:.*]] = arith.divf %[[x59]], %[[x56]] : f32 - // CHECK: %[[x61:.*]] = arith.mulf %[[x60]], %[[extracted_7]] : f32 - // CHECK: %[[x62:.*]] = arith.addf %[[x58]], %[[x61]] : f32 - // CHECK: %[[x63:.*]] = arith.mulf %[[x57]], %[[extracted_8]] : f32 - // CHECK: %[[x64:.*]] = arith.mulf %[[x60]], %[[extracted_9]] : f32 - // CHECK: %[[x65:.*]] = arith.addf %[[x63]], %[[x64]] : f32 - // CHECK: %[[x66:.*]] = arith.subf %[[x39]], %[[x25]] : f32 - // CHECK: %[[x67:.*]] = arith.subf %[[x39]], %[[x37]] : f32 - // CHECK: %[[x68:.*]] = arith.divf %[[x66]], %[[x67]] : f32 - // CHECK: %[[x69:.*]] = arith.mulf %[[x68]], %[[x62]] : f32 - // CHECK: %[[x70:.*]] = arith.subf %[[x25]], %[[x37]] : f32 - // CHECK: %[[x71:.*]] = arith.divf %[[x70]], %[[x67]] : f32 - // CHECK: %[[x72:.*]] = arith.mulf %[[x71]], %[[x65]] : f32 - // CHECK: %[[x73:.*]] = arith.addf %[[x69]], %[[x72]] : f32 + // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x1:.*]], %[[x2:.*]], %[[x3:.*]], %[[x4:.*]]] : tensor<1x1x2x4xf32> + // CHECK: %[[extracted_7:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]] + // CHECK: %[[extracted_8:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]] + // CHECK: %[[extracted_9:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]] + // CHECK: %[[dx0p00:.*]] = arith.mulf %[[dx0:.*]], %[[extracted]] + // CHECK: %[[dx1p01:.*]] = arith.mulf %[[dx1:.*]], %[[extracted_7]] + // CHECK: %[[sum:.*]] = arith.addf %[[dx0p00]], %[[dx1p01]] + // CHECK: %[[left:.*]] = arith.mulf %[[dy0:.*]], %[[sum]] + // CHECK: %[[dx0p10:.*]] = arith.mulf %[[dx0]], %[[extracted_8]] + // CHECK: %[[dx1p11:.*]] = 
arith.mulf %[[dx1]], %[[extracted_9]] + // CHECK: %[[sum2:.*]] = arith.addf %[[dx0p10]], %[[dx1p11]] + // CHECK: %[[right:.*]] = arith.mulf %[[dy1:.*]], %[[sum2]] + // CHECK: %[[retval:.*]] = arith.addf %[[left]], %[[right]] %none = torch.constant.none %none_0 = torch.constant.none %int0 = torch.constant.int 0 From 23160c77bcf541064862815bebff8b309bece51e Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 5 Jun 2024 19:45:36 +0000 Subject: [PATCH 06/11] add resize nearest mode round_prefer_floor, round_prefer_ceil, ceil --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 10 ++++-- .../TorchToLinalg/Uncategorized.cpp | 34 ++++++++++++++++--- test/Conversion/TorchToLinalg/resize.mlir | 33 ++++++++++-------- 3 files changed, 55 insertions(+), 22 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 670638711ca9..89f6d9c180b3 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2144,11 +2144,13 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return rewriter.notifyMatchFailure( binder.op, "unimplemented: coordinate transformation mode: " "tf_crop_and_resize"); - if (mode == "nearest" && nearest_mode != "floor") { + + if (mode == "nearest" && coordTfMode != "asymmetric") { return rewriter.notifyMatchFailure( - binder.op, "unimplemented: support not present for nearest_mode " - "except floor"); + binder.op, "unimplemented: support not present for coord tf mode " + "except asymmetric"); } + unsigned rank = dyn_cast(operands[0].getType()) .getSizes() .size(); @@ -2250,6 +2252,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // apparently asymmetric if (coordTfMode != "asymmetric" && coordTfMode != "align_corners") modeStr = (modeStr + "_") + coordTfMode; + if (nearest_mode != "floor" && nearest_mode != "") + modeStr = modeStr + "," + nearest_mode; modeStrValue = rewriter.create(binder.getLoc(), 
modeStr); } diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 9a4e9c7ffd02..d6c5d521f871 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -2593,7 +2593,7 @@ static Value NearestInterpolate(OpBuilder &b, Location loc, SmallVector outputSizes, Value input, SmallVector inputSizes, SmallVector scaleValues, - std::string coordStr) { + std::string coordStr, std::string nearestMode) { auto inputType = input.getType().cast(); auto inputRank = inputType.getRank(); @@ -2624,9 +2624,29 @@ static Value NearestInterpolate(OpBuilder &b, Location loc, Value outFP = b.create(loc, b.getF32Type(), outInt); Value proj = b.create(loc, outFP, scale); + Value nearestFP; // get nearest pixel using floor - Value nearestFP = b.create(loc, proj); - + if (nearestMode == "floor" || nearestMode == "") { + nearestFP = b.create(loc, proj); + } else if (nearestMode == "round_prefer_floor") { + Value cstHalf = b.create(loc, b.getF32FloatAttr(0.5)); + Value floor = b.create(loc, proj); + Value ceil = b.create(loc, proj); + Value decimal = b.create(loc, proj, floor); + Value cmp = b.create(loc, arith::CmpFPredicate::ULE, + decimal, cstHalf); + nearestFP = b.create(loc, cmp, floor, ceil); + } else if (nearestMode == "round_prefer_ceil") { + Value cstHalf = b.create(loc, b.getF32FloatAttr(0.5)); + Value floor = b.create(loc, proj); + Value ceil = b.create(loc, proj); + Value decimal = b.create(loc, proj, floor); + Value cmp = b.create(loc, arith::CmpFPredicate::UGE, + decimal, cstHalf); + nearestFP = b.create(loc, cmp, ceil, floor); + } else if (nearestMode == "ceil") { + nearestFP = b.create(loc, proj); + } Value nearestInt = b.create(loc, b.getI64Type(), nearestFP); Value nearest = @@ -2867,9 +2887,15 @@ class ConvertInterpolateOp [&](OpBuilder &b, Location loc, ValueRange args) { Value retVal; if (mode.substr(0, 7) == "nearest") { + std::string coordTfMode = + 
mode.substr(7, mode.find(",") - 7); + std::string nearestMode = + (mode.find(",") == std::string::npos) + ? "" + : mode.substr(mode.find(",") + 1); retVal = NearestInterpolate( b, loc, outputSizeIntValues, input, inputSizes, - ScaleFactorFloatValues, mode.substr(7)); + ScaleFactorFloatValues, coordTfMode, nearestMode); } else if (mode.substr(0, 8) == "bilinear") { retVal = BilinearInterpolate( b, op, loc, outputSizeIntValues, input, inputSizes, diff --git a/test/Conversion/TorchToLinalg/resize.mlir b/test/Conversion/TorchToLinalg/resize.mlir index 542f251c6024..a2babe7a09c2 100644 --- a/test/Conversion/TorchToLinalg/resize.mlir +++ b/test/Conversion/TorchToLinalg/resize.mlir @@ -3,20 +3,20 @@ // CHECK-LABEL: func.func @test_resize_sizes_linear func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4] ,si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[generic:.*]] = linalg.generic - // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x1:.*]], %[[x2:.*]], %[[x3:.*]], %[[x4:.*]]] : tensor<1x1x2x4xf32> - // CHECK: %[[extracted_7:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]] - // CHECK: %[[extracted_8:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]] - // CHECK: %[[extracted_9:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]] - // CHECK: %[[dx0p00:.*]] = arith.mulf %[[dx0:.*]], %[[extracted]] - // CHECK: %[[dx1p01:.*]] = arith.mulf %[[dx1:.*]], %[[extracted_7]] - // CHECK: %[[sum:.*]] = arith.addf %[[dx0p00]], %[[dx1p01]] - // CHECK: %[[left:.*]] = arith.mulf %[[dy0:.*]], %[[sum]] - // CHECK: %[[dx0p10:.*]] = arith.mulf %[[dx0]], %[[extracted_8]] - // CHECK: %[[dx1p11:.*]] = arith.mulf %[[dx1]], %[[extracted_9]] - // CHECK: %[[sum2:.*]] = arith.addf %[[dx0p10]], %[[dx1p11]] - // CHECK: %[[right:.*]] = arith.mulf %[[dy1:.*]], %[[sum2]] - // 
CHECK: %[[retval:.*]] = arith.addf %[[left]], %[[right]] + // CHECK: %[[generic:.*]] = linalg.generic + // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x1:.*]], %[[x2:.*]], %[[x3:.*]], %[[x4:.*]]] : tensor<1x1x2x4xf32> + // CHECK: %[[extracted_7:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]] + // CHECK: %[[extracted_8:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]] + // CHECK: %[[extracted_9:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]] + // CHECK: %[[dx0p00:.*]] = arith.mulf %[[dx0:.*]], %[[extracted]] + // CHECK: %[[dx1p01:.*]] = arith.mulf %[[dx1:.*]], %[[extracted_7]] + // CHECK: %[[sum:.*]] = arith.addf %[[dx0p00]], %[[dx1p01]] + // CHECK: %[[left:.*]] = arith.mulf %[[dy0:.*]], %[[sum]] + // CHECK: %[[dx0p10:.*]] = arith.mulf %[[dx0]], %[[extracted_8]] + // CHECK: %[[dx1p11:.*]] = arith.mulf %[[dx1]], %[[extracted_9]] + // CHECK: %[[sum2:.*]] = arith.addf %[[dx0p10]], %[[dx1p11]] + // CHECK: %[[right:.*]] = arith.mulf %[[dy1:.*]], %[[sum2]] + // CHECK: %[[retval:.*]] = arith.addf %[[left]], %[[right]] %none = torch.constant.none %none_0 = torch.constant.none %int0 = torch.constant.int 0 @@ -36,6 +36,7 @@ func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: // ----- +// CHECK-LABEL: func.func @test_resize_sizes_nearest func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[GENERIC:.*]] = linalg.generic // CHECK: %[[x11:.*]] = linalg.index 0 : index @@ -81,6 +82,7 @@ func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1 // ----- +// CHECK-LABEL: func.func @test_resize_nearest_1d func.func @test_resize_nearest_1d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],f32> { // 
CHECK: %[[GENERIC:.*]] = linalg.generic // CHECK: %[[x11:.*]] = linalg.index 0 : index @@ -102,7 +104,7 @@ func.func @test_resize_nearest_1d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !to %int0 = torch.constant.int 0 %false = torch.constant.bool false %true = torch.constant.bool true - %str = torch.constant.str "nearest" + %str = torch.constant.str "nearest,floor" %int2 = torch.constant.int 2 %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int @@ -113,6 +115,7 @@ func.func @test_resize_nearest_1d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !to // ----- +// CHECK-LABEL: func.func @test_resize_nearest_3d func.func @test_resize_nearest_3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[5],si64>) -> !torch.vtensor<[?,?,?,?,?],f32> { // CHECK: %[[GENERIC:.*]] = linalg.generic // CHECK: %[[x11:.*]] = linalg.index 0 : index From 070e9cdf29fbb9fa329ebf230e3fb39b81c61e24 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Thu, 6 Jun 2024 18:11:37 +0200 Subject: [PATCH 07/11] fixup xfail --- projects/pt1/e2e_testing/xfail_sets.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 6ec35e9576c4..1942039f7757 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1822,7 +1822,6 @@ "IndexPutImplIndexWithNoneModule_basic", "InterpolateDynamicModule_sizes_bilinear", "InterpolateDynamicModule_sizes_nearest", - "InterpolateStaticModule_scales_bilinear_align_corners", "InterpolateDynamicModule_scales_recompute_bilinear", "IntFloatModule_basic", "IntImplicitModule_basic", @@ -2009,6 +2008,8 @@ "UpSampleNearest2dDynamicFactor_basic", "UpSampleNearest2dStaticFactor_basic", "UpSampleNearest2d_basic", + "UpSampleNearest2dDynamicSize_basic", + "UpSampleNearest2dStaticSize_basic", 
"VarCorrectionEmptyDimModule_basic", "VarDimEmptyDimModule_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic", From bbae91b2629e1342432e02683ba7080c77355ebf Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Fri, 7 Jun 2024 08:27:40 +0200 Subject: [PATCH 08/11] onnx.Resize: Default nearest_mode is round_prefer_floor --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 89f6d9c180b3..92f0da13e064 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2138,7 +2138,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.customOpNameStringAttr(mode, "mode", "nearest") || binder.customOpNameStringAttr( coordTfMode, "coordinate_transformation_mode", "half_pixel") || - binder.customOpNameStringAttr(nearest_mode, "nearest_mode", "")) + binder.customOpNameStringAttr(nearest_mode, "nearest_mode", "round_prefer_floor")) return failure(); if (coordTfMode == "tf_crop_and_resize") return rewriter.notifyMatchFailure( @@ -2252,7 +2252,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // apparently asymmetric if (coordTfMode != "asymmetric" && coordTfMode != "align_corners") modeStr = (modeStr + "_") + coordTfMode; - if (nearest_mode != "floor" && nearest_mode != "") + if (nearest_mode != "floor") modeStr = modeStr + "," + nearest_mode; modeStrValue = rewriter.create(binder.getLoc(), modeStr); From 96addd13ce8bad2bfb286aace0eb116265a62095 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Fri, 7 Jun 2024 08:35:15 +0200 Subject: [PATCH 09/11] onnx.resize: Add support for coordTfMode half_pixel --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 4 +- .../TorchToLinalg/Uncategorized.cpp | 14 ++++++- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 13 ++++++ test/Conversion/TorchToLinalg/resize.mlir | 41 
+++++++++++++++++++ 4 files changed, 69 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 92f0da13e064..abe2eff05600 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2145,10 +2145,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, "unimplemented: coordinate transformation mode: " "tf_crop_and_resize"); - if (mode == "nearest" && coordTfMode != "asymmetric") { + if (mode == "nearest" && coordTfMode != "asymmetric" && coordTfMode != "half_pixel") { return rewriter.notifyMatchFailure( binder.op, "unimplemented: support not present for coord tf mode " - "except asymmetric"); + "except asymmetric and half_pixel"); } unsigned rank = dyn_cast(operands[0].getType()) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index d6c5d521f871..25a4f807f7c8 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -2622,7 +2622,17 @@ static Value NearestInterpolate(OpBuilder &b, Location loc, Value outInt = b.create(loc, b.getI64Type(), outIndex); Value outFP = b.create(loc, b.getF32Type(), outInt); - Value proj = b.create(loc, outFP, scale); + Value proj; + if (coordStr.empty() || coordStr == "_asymmetric") { + proj = b.create(loc, outFP, scale); + } else if (coordStr == "_half_pixel"){ + Value cstHalf = b.create(loc, b.getF32FloatAttr(0.5)); + Value add = b.create(loc, outFP, cstHalf); + Value div = b.create(loc, add, scale); + proj = b.create(loc, div, cstHalf); + } else { + llvm_unreachable("Unsupported coordination transformation mode"); + } Value nearestFP; // get nearest pixel using floor @@ -2646,6 +2656,8 @@ static Value NearestInterpolate(OpBuilder &b, Location loc, nearestFP = b.create(loc, cmp, ceil, floor); } else if (nearestMode == "ceil") { 
nearestFP = b.create(loc, proj); + } else { + llvm_unreachable("Unsupported nearest mode"); } Value nearestInt = b.create(loc, b.getI64Type(), nearestFP); diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 13b25e2b16ca..afc85bccf6de 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -1676,6 +1676,19 @@ func.func @test_size(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[],si // ----- +// CHECK-LABEL: func.func @test_resize_sizes_nearest +func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %none = torch.constant.none + // CHECK: %[[STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor" + // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) { + torch.onnx.coordinate_transformation_mode = "half_pixel", + torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_resize_sizes_linear func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?], f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", 
torch.onnx_meta.producer_version = ""} { diff --git a/test/Conversion/TorchToLinalg/resize.mlir b/test/Conversion/TorchToLinalg/resize.mlir index a2babe7a09c2..4815a4a9211a 100644 --- a/test/Conversion/TorchToLinalg/resize.mlir +++ b/test/Conversion/TorchToLinalg/resize.mlir @@ -155,3 +155,44 @@ func.func @test_resize_nearest_3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: %7 = torch.aten.__interpolate.size_list_scale_list %arg0, %6, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[?,?,?,?,?],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?,?],f32> return %7 : !torch.vtensor<[?,?,?,?,?],f32> } + +// CHECK-LABEL: func.func @test_resize_nearest_half_pixel +func.func @test_resize_nearest_half_pixel_round_prefer_floor(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],f32> { + // CHECK: %[[GENERIC:.*]] = linalg.generic + // CHECK: %[[x11:.*]] = linalg.index 0 : index + // CHECK: %[[x12:.*]] = linalg.index 1 : index + // CHECK: %[[x13:.*]] = linalg.index 2 : index + // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32 + // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 + // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32 + // CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64 + // CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32 + // CHECK: %[[cst:.*]] = arith.constant 5.000000e-01 : f32 + // CHECK: %[[add:.*]] = arith.addf %[[x24]], %[[cst]] : f32 + // CHECK: %[[x25:.*]] = arith.divf %[[add]], %[[x21]] : f32 + // CHECK: %[[sub:.*]] = arith.subf %[[x25]], %[[cst]] : f32 + // CHECK: %[[cst3:.*]] = arith.constant 5.000000e-01 : f32 + // CHECK: %[[floor:.*]] = math.floor %[[sub]] : f32 + // CHECK: %[[ceil:.*]] = math.ceil %[[sub]] : f32 + // CHECK: %[[sub2:.*]] = arith.subf %[[sub]], %[[floor]] : f32 + // CHECK: %[[cmpf:.*]] = arith.cmpf ule, %[[sub2]], %[[cst3]] : f32 + // CHECK: %[[select:.*]] = arith.select 
%[[cmpf]], %[[floor]], %[[ceil]] : f32 + // CHECK: %[[x31:.*]] = arith.fptosi %[[select]] : f32 to i64 + // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index + // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x11]], %[[x12]], %[[x32]]] : tensor + // CHECK: linalg.yield %[[extracted]] : f32 + %none = torch.constant.none + %none_0 = torch.constant.none + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %true = torch.constant.bool true + %str = torch.constant.str "nearest_half_pixel,round_prefer_floor" + %int2 = torch.constant.int 2 + %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + %4 = torch.prim.ListConstruct %1 : (!torch.int) -> !torch.list + %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[?,?,?],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?],f32> + return %5 : !torch.vtensor<[?,?,?],f32> +} + +// ----- From 457908a5799073257df8e8d2d98ad231d4a22ab3 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Fri, 7 Jun 2024 16:07:39 +0200 Subject: [PATCH 10/11] Update xfail --- projects/pt1/e2e_testing/xfail_sets.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 1942039f7757..9c362df4a928 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -342,6 +342,11 @@ # Others "GridSamplerBasic1_basic", "GridSamplerBasic2_basic", + + "InterpolateDynamicModule_sizes_bilinear", + "InterpolateDynamicModule_sizes_nearest", + "InterpolateStaticModule_scales_bilinear_align_corners", + "InterpolateDynamicModule_scales_recompute_bilinear", } if torch_version_for_comparison() <= version.parse("2.2.0"): From 
3eab72478dabf8abf3c3e0414b972c40c664debf Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Fri, 7 Jun 2024 16:48:10 +0200 Subject: [PATCH 11/11] Update GeneratedTorchOps.td --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 58 +++++++++---------- .../build_tools/torch_ods_gen.py | 3 + 2 files changed, 32 insertions(+), 29 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index ca7a28b156b2..f9b5cada1049 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -6833,6 +6833,35 @@ def Torch_Aten_LogSoftmaxOp : Torch_Op<"aten._log_softmax", [ }]; } +def Torch_Aten__InterpolateSizeListScaleListOp : Torch_Op<"aten.__interpolate.size_list_scale_list", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::__interpolate.size_list_scale_list : (Tensor, int[]?, float[]?, str, bool?, bool?, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchOptionalListOfTorchIntType:$size, + AnyTorchOptionalListOfTorchFloatType:$scale_factor, + Torch_StringType:$mode, + AnyTorchOptionalBoolType:$align_corners, + AnyTorchOptionalBoolType:$recompute_scale_factor, + Torch_BoolType:$antialias + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten__InterpolateSizeListScaleListOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void Aten__InterpolateSizeListScaleListOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + def Torch_AtenScatterSrcOp : Torch_Op<"aten.scatter.src", [ AllowsTypeRefinement, HasValueSemantics, @@ -6984,35 +7013,6 @@ def Torch_AtenMaskedScatter_Op : Torch_Op<"aten.masked_scatter_", [ }]; } -def 
Torch_Aten__InterpolateSizeListScaleListOp : Torch_Op<"aten.__interpolate.size_list_scale_list", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::__interpolate.size_list_scale_list : (Tensor, int[]?, float[]?, str, bool?, bool?, bool) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$input, - AnyTorchOptionalListOfTorchIntType:$size, - AnyTorchOptionalListOfTorchFloatType:$scale_factor, - Torch_StringType:$mode, - AnyTorchOptionalBoolType:$align_corners, - AnyTorchOptionalBoolType:$recompute_scale_factor, - Torch_BoolType:$antialias - ); - let results = (outs - AnyTorchOptionalTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult Aten__InterpolateSizeListScaleListOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 7, 1); - } - void Aten__InterpolateSizeListScaleListOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 7, 1); - } - }]; -} - def Torch_AtenAdaptiveAvgPool1dOp : Torch_Op<"aten.adaptive_avg_pool1d", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 58afa0c4747d..7db3ea511164 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -504,6 +504,9 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::__interpolate.size_list_scale_list : (Tensor, int[]?, float[]?, str, bool?, bool?, bool) -> (Tensor)" ) + emit_with_mutating_variants("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)") + emit_with_mutating_variants("aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)") + emit_with_mutating_variants("aten::masked_scatter : (Tensor, Tensor, Tensor) -> (Tensor)") 
emit("aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)") emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)") emit("aten::_adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")