diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 5b985a80b301..f9b5cada1049 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -6833,6 +6833,35 @@ def Torch_Aten_LogSoftmaxOp : Torch_Op<"aten._log_softmax", [ }]; } +def Torch_Aten__InterpolateSizeListScaleListOp : Torch_Op<"aten.__interpolate.size_list_scale_list", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::__interpolate.size_list_scale_list : (Tensor, int[]?, float[]?, str, bool?, bool?, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchOptionalListOfTorchIntType:$size, + AnyTorchOptionalListOfTorchFloatType:$scale_factor, + Torch_StringType:$mode, + AnyTorchOptionalBoolType:$align_corners, + AnyTorchOptionalBoolType:$recompute_scale_factor, + Torch_BoolType:$antialias + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten__InterpolateSizeListScaleListOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void Aten__InterpolateSizeListScaleListOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + def Torch_AtenScatterSrcOp : Torch_Op<"aten.scatter.src", [ AllowsTypeRefinement, HasValueSemantics, @@ -13249,6 +13278,31 @@ def Torch_AtenWarnOp : Torch_Op<"aten.warn", [ }]; } +def Torch_Aten__Contains__StrListOp : Torch_Op<"aten.__contains__.str_list", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::__contains__.str_list : (str[], str) -> (bool)`"; + let arguments = (ins + AnyTorchListOfTorchStringType:$l, + Torch_StringType:$item + ); + let results = (outs + Torch_BoolType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten__Contains__StrListOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void Aten__Contains__StrListOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasFolder = 1; +} + def Torch_AtenFloatScalarOp : Torch_Op<"aten.Float.Scalar", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h index e6a9e1622cc1..dbeb2f522b33 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h @@ -239,6 +239,37 @@ m_TorchListOfConstantBools(SmallVectorImpl &bind_values) { return detail::torch_list_of_constant_bools_op_binder(bind_values); } +namespace detail { +/// Matches the constant strs stored in a `torch.ListConstruct`. +struct torch_list_of_constant_strs_op_binder { + SmallVectorImpl &bind_values; + + /// Creates a matcher instance that binds the value to bvs if match succeeds. 
+ torch_list_of_constant_strs_op_binder(SmallVectorImpl &bvs) + : bind_values(bvs) {} + + bool match(Operation *op) { + auto listConstruct = dyn_cast(op); + if (!listConstruct) + return false; + for (Value value : listConstruct.getElements()) { + std::string str; + if (matchPattern(value, m_TorchConstantStr(str))) + bind_values.push_back(str); + else + return false; + } + return true; + } +}; +} // namespace detail + +/// Matches the constant strs stored in a `torch.prim.ListConstruct`. +inline detail::torch_list_of_constant_strs_op_binder +m_TorchListOfConstantStrs(SmallVectorImpl &bind_values) { + return detail::torch_list_of_constant_strs_op_binder(bind_values); +} + namespace detail { /// Matches the expected tensor and dim from `torch.aten.size.int`. struct torch_tensor_size_int_op_binder { diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index b5e9162bc2bf..abe2eff05600 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2099,4 +2099,184 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, resultType, operand); return success(); }); + patterns.onOp( + "Resize", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + std::string mode, nearest_mode, coordTfMode; + Value noneVal = rewriter.create(binder.getLoc()); + + if (auto attr = binder.op->getAttr("torch.onnx.antialias")) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented: support not present for antialias attribute"); + } + if (auto attr = binder.op->getAttr("torch.onnx.axes")) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented: support not present for axes attribute"); + } + if (auto attr = binder.op->getAttr("torch.onnx.exclude_outside")) { + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: support not present for " + "exclude_outside attribute"); + } + if (auto attr = binder.op->getAttr("torch.onnx.extrapolation_value")) { + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: support not present for " + "extrapolation_value attribute"); + } + if (auto attr = + binder.op->getAttr("torch.onnx.keep_aspect_ratio_policy")) { + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: support not present for " + "keep_aspect_ratio_policy attribute"); + } + + if (binder.tensorOperandsList(operands) || + binder.tensorResultType(resultType) || + binder.customOpNameStringAttr(mode, "mode", "nearest") || + binder.customOpNameStringAttr( + coordTfMode, "coordinate_transformation_mode", "half_pixel") || + binder.customOpNameStringAttr(nearest_mode, "nearest_mode", "round_prefer_floor")) + return failure(); + if (coordTfMode == "tf_crop_and_resize") + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: coordinate transformation mode: " + "tf_crop_and_resize"); + + if (mode == "nearest" && coordTfMode != "asymmetric" && coordTfMode != "half_pixel") { + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: support not present for coord tf mode " + "except asymmetric and half_pixel"); + } + + unsigned rank = dyn_cast(operands[0].getType()) + .getSizes() + .size(); + + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + + Value cstFalse = + rewriter.create(binder.getLoc(), false); + Value cstTrue = + rewriter.create(binder.getLoc(), 
true); + Value modeStrValue; + + auto extract = [&rewriter, &binder](Value x, Value v) { + auto xTy = x.getType().cast(); + Type extractTy = rewriter.getType(); + if (isa(xTy.getDtype())) + extractTy = rewriter.getType(); + + return rewriter.create(binder.getLoc(), extractTy, + v); + }; + + auto getValueList = [&](Value operand) { + SmallVector itemList; + auto sizes = + dyn_cast(operand.getType()).getSizes(); + Torch::BaseTensorType operandType = + operand.getType().cast(); + + SmallVector selectSizes; + selectSizes.push_back(1); + Type selectResultType = operandType.getWithSizesAndDtype( + llvm::ArrayRef(selectSizes), operandType.getOptionalDtype()); + + MLIRContext *context = binder.op->getContext(); + for (int i = 2; i < sizes[0]; i++) { + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value ext = rewriter.create( + binder.getLoc(), selectResultType, operand, zero, selectIndex); + Value item = extract(operand, ext); + itemList.push_back(item); + } + auto xTy = operand.getType().cast(); + Value ValueList; + if (isa(xTy.getDtype())) { + ValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(context)), itemList); + } else { + ValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::FloatType::get(context)), itemList); + } + return ValueList; + }; + + Value scalesValueList = noneVal; + Value sizesValueList = noneVal; + Value alignCorners = + coordTfMode == "align_corners" ? cstTrue : cstFalse; + if (mode == "cubic") { + return rewriter.notifyMatchFailure(binder.op, + "unimplemented: bicubic mode"); + } + // supported modes: + // bilinear (half_pixel), bilinear with align_corners, + // bilinear_pytorch_half_pixel, bilinear_asymmetric nearest + // (asymmetric), nearest with align_corners, nearest_half_pixel, + // nearest_pytorch_half_pixel + if (mode == "linear") { + std::string modeStr; + switch (rank) { + case 3: + modeStr = "linear"; + break; + case 4: + modeStr = "bilinear"; + break; + case 5: + modeStr = "trilinear"; + break; + default: + return failure(); + } + // Confusingly enough, the default coordTfMode for pytorch bilinear + // mode is apparently half_pixel, NOT pytorch_half_pixel + if (coordTfMode != "half_pixel" && coordTfMode != "align_corners") + modeStr = (modeStr + "_") + coordTfMode; + modeStrValue = + rewriter.create(binder.getLoc(), modeStr); + } + if (mode == "nearest") { + std::string modeStr = "nearest"; + // The default coordTfMode for pytorch with mode = nearest is + // apparently asymmetric + if (coordTfMode != "asymmetric" && coordTfMode != "align_corners") + modeStr = (modeStr + "_") + coordTfMode; + if (nearest_mode != "floor") + modeStr = modeStr + "," + nearest_mode; + modeStrValue = + rewriter.create(binder.getLoc(), modeStr); + } + if (operands.size() < 4) { + Value scaleOperand = operands[2]; + scalesValueList = getValueList(scaleOperand); + sizesValueList = noneVal; + } else { + Value sizeOperand = operands[3]; + scalesValueList = noneVal; + sizesValueList = getValueList(sizeOperand); + } + if (scalesValueList.getType().isa() && + sizesValueList.getType().isa()) { + return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode"); + } + rewriter + .replaceOpWithNewOp( + binder.op, resultType, operands[0], sizesValueList, + scalesValueList, modeStrValue, + /* AnyTorchOptionalBoolType:$align_corners */ alignCorners, + /* AnyTorchOptionalBoolType:$recompute_scale_factor */ noneVal, + 
/*Torch_BoolType:$antialias*/ cstFalse); + return success(); + }); } diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 86bc4578178f..25a4f807f7c8 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -2589,6 +2589,340 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { }; } // namespace +static Value NearestInterpolate(OpBuilder &b, Location loc, + SmallVector outputSizes, Value input, + SmallVector inputSizes, + SmallVector scaleValues, + std::string coordStr, std::string nearestMode) { + + auto inputType = input.getType().cast(); + auto inputRank = inputType.getRank(); + + SmallVector indices; + for (unsigned i = 0; i < inputRank; i++) { + indices.push_back(b.create(loc, i)); + } + + for (unsigned i = 2; i < inputRank; i++) { + Value outIndex = indices[i]; + + Value inputSizeFP = + b.create(loc, b.getF32Type(), inputSizes[i - 2]); + + Value outputSizeFP = + b.create(loc, b.getF32Type(), outputSizes[i - 2]); + + // scale = length_resized / length_original + // x_original = x_resized / scale + Value scale; + if (scaleValues.empty()) + scale = b.create(loc, outputSizeFP, inputSizeFP); + else + scale = scaleValues[i - 2]; + + Value outInt = b.create(loc, b.getI64Type(), outIndex); + Value outFP = b.create(loc, b.getF32Type(), outInt); + Value proj; + if (coordStr.empty() || coordStr == "_asymmetric") { + proj = b.create(loc, outFP, scale); + } else if (coordStr == "_half_pixel"){ + Value cstHalf = b.create(loc, b.getF32FloatAttr(0.5)); + Value add = b.create(loc, outFP, cstHalf); + Value div = b.create(loc, add, scale); + proj = b.create(loc, div, cstHalf); + } else { + llvm_unreachable("Unsupported coordination transformation mode"); + } + + Value nearestFP; + // get nearest pixel using floor + if (nearestMode == "floor" || nearestMode == "") { + nearestFP = b.create(loc, proj); + } else if (nearestMode == "round_prefer_floor") { + Value cstHalf = b.create(loc, b.getF32FloatAttr(0.5)); + Value floor = b.create(loc, proj); + Value ceil = b.create(loc, proj); + Value decimal = b.create(loc, proj, floor); + Value cmp = b.create(loc, arith::CmpFPredicate::ULE, + decimal, cstHalf); + nearestFP = b.create(loc, cmp, floor, ceil); + } else if (nearestMode == "round_prefer_ceil") { + Value cstHalf = b.create(loc, b.getF32FloatAttr(0.5)); + Value floor = b.create(loc, proj); + Value ceil = b.create(loc, proj); + Value decimal = b.create(loc, proj, floor); + Value cmp = b.create(loc, arith::CmpFPredicate::UGE, + decimal, cstHalf); + nearestFP = b.create(loc, cmp, ceil, floor); + } else if (nearestMode == "ceil") { + nearestFP = b.create(loc, proj); + } else { + llvm_unreachable("Unsupported nearest mode"); + } + Value nearestInt = + b.create(loc, b.getI64Type(), nearestFP); + Value nearest = + b.create(loc, b.getIndexType(), nearestInt); + + indices[i] = nearest; + } + Value retVal = b.create(loc, input, indices); + return retVal; +} + +static Value BilinearInterpolate(OpBuilder &b, + Aten__InterpolateSizeListScaleListOp op, + Location loc, SmallVector outputSizes, + Value input, SmallVector inputSizes, + SmallVector scaleValues, + std::string coordStr) { + unsigned dimOffset = 2; + auto inputType = input.getType().cast(); + auto inputRank = inputType.getRank(); + + Value cstOneEps = + b.create(loc, b.getF32FloatAttr(1.000001)); + Value cstOneFloat = b.create(loc, b.getF32FloatAttr(1.0)); + Value cstHalf = b.create(loc, b.getF32FloatAttr(0.5)); + Value zero = 
b.create(loc, b.getF32FloatAttr(0.0)); + + bool alignCornersBool; + matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool)); + + SmallVector indices; + for (unsigned i = 0; i < inputRank; i++) { + indices.push_back(b.create(loc, i)); + } + + SmallVector proj, projEps, high, low, highFP, lowFP; + for (unsigned i = 0; i < inputRank - dimOffset; i++) { + // length_original + Value inputFP = + b.create(loc, b.getF32Type(), inputSizes[i]); + // length_resized + Value outputSizeFP = + b.create(loc, b.getF32Type(), outputSizes[i]); + // scale = length_resized/length_original + Value scale; + if (alignCornersBool) { + // x_original = x_resized * (length_original - 1) / (length_resized - 1) + Value inputSubOne = b.create(loc, inputFP, cstOneFloat); + Value outputSizeSubOne = + b.create(loc, outputSizeFP, cstOneFloat); + Value cmp = b.create(loc, arith::CmpFPredicate::UEQ, + outputSizeSubOne, zero); + scale = b.create(loc, inputSubOne, outputSizeSubOne); + scale = b.create(loc, cmp, zero, scale); + coordStr = "_align_corners"; + } else if (scaleValues.empty()) + scale = b.create(loc, outputSizeFP, inputFP); + else + scale = scaleValues[i]; + // y_resized + Value outInt = b.create(loc, b.getI64Type(), + indices[i + dimOffset]); + Value outFP = b.create(loc, b.getF32Type(), outInt); + Value preClip; + if (coordStr == "_align_corners") { + preClip = b.create(loc, outFP, scale); + } + if (coordStr == "_asymmetric") { + preClip = b.create(loc, outFP, scale); + } + if (coordStr == "_pytorch_half_pixel" || coordStr == "") { + // half-pixel modes + // y_resized + 0.5 + Value outPlusHalf = b.create(loc, outFP, cstHalf); + // (y_resized + 0.5) / scale + Value outDivScale = b.create(loc, outPlusHalf, scale); + // _ - 0.5 + preClip = b.create(loc, outDivScale, cstHalf); + } + // for pytorch half pixel , special case for length_resized == 1: + if (coordStr == "_pytorch_half_pixel") { + Value cmp = b.create(loc, arith::CmpFPredicate::UEQ, + outputSizeFP, cstOneFloat); + preClip = b.create(loc, cmp, zero, preClip); + } + // clip to 0,inf + Value max = b.create(loc, preClip, zero); + // length_original - 1.001 + Value inputSubOneEps = b.create(loc, inputFP, cstOneEps); + Value inputSubOne = b.create(loc, inputFP, cstOneFloat); + // clip to [0,length_original - 1.001] + projEps.push_back(b.create(loc, max, inputSubOneEps)); + proj.push_back(b.create(loc, max, inputSubOne)); + + lowFP.push_back(b.create(loc, projEps[i])); + Value projPlusOne = b.create(loc, cstOneFloat, projEps[i]); + highFP.push_back(b.create(loc, projPlusOne)); + + Value lowInt = b.create(loc, b.getI64Type(), lowFP[i]); + low.push_back(b.create(loc, b.getIndexType(), lowInt)); + + Value highInt = b.create(loc, b.getI64Type(), highFP[i]); + high.push_back( + b.create(loc, b.getIndexType(), highInt)); + } + + SmallVector cornerValues; + indices[dimOffset] = low[0]; + indices[dimOffset + 1] = low[1]; + Value p00 = b.create(loc, input, indices); + + indices[dimOffset] = low[0]; + indices[dimOffset + 1] = high[1]; + Value p01 = b.create(loc, input, indices); + + indices[dimOffset] = high[0]; + indices[dimOffset + 1] = low[1]; + Value p10 = b.create(loc, input, indices); + + indices[dimOffset] = high[0]; + indices[dimOffset + 1] = high[1]; + Value p11 = b.create(loc, input, indices); + + // Let Aij := area rect((yProj,xProj) <-> (y_i*,x_j*)), + // where i* = i+1 mod 2 and x_0 = xLow, x_1 = xHigh etc. 
+ // We interpolate via the weighted average of pij by weights Aij + // the formula is retval = Sum(pij*Aij for i and j in range(2)) + // Note: we do not need to divide by total rect area == 1 + + // lengths : Aij == dyi*dxj + Value dy0 = b.create(loc, highFP[0], proj[0]); + Value dy1 = b.create(loc, proj[0], lowFP[0]); + Value dx0 = b.create(loc, highFP[1], proj[1]); + Value dx1 = b.create(loc, proj[1], lowFP[1]); + + // left = A00*p00 + A01*p01 = dy0(dx0p00 + dx1p01) + Value dx0p00 = b.create(loc, dx0, p00); + Value dx1p01 = b.create(loc, dx1, p01); + Value sum = b.create(loc, dx0p00, dx1p01); + Value left = b.create(loc, dy0, sum); + // right = A10*p10 + A11*p11 = dy1(dx0p10 + dx1p11) + Value dx0p10 = b.create(loc, dx0, p10); + Value dx1p11 = b.create(loc, dx1, p11); + sum = b.create(loc, dx0p10, dx1p11); + Value right = b.create(loc, dy1, sum); + + return b.create(loc, left, right); +} + +namespace { +class ConvertInterpolateOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(Aten__InterpolateSizeListScaleListOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + std::string mode; + // note: to support onnx.Resize, we are passing some extra options through + // the mode attribute. For example, onnx.Resize with mode="linear" and + // coordinate_transformation_mode="asymmetric" will lower to an interpolate + // op with the non-standard mode="bilinear_asymmetric". + matchPattern(op.getMode(), m_TorchConstantStr(mode)); + if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest") { + return failure(); + } + + Location loc = op->getLoc(); + Value input = adaptor.getInput(); + auto inputType = input.getType().cast(); + auto inputRank = inputType.getRank(); + if (mode.substr(0, 8) == "bilinear" && inputRank != 4) + return rewriter.notifyMatchFailure( + op, + "cannot perform bilinear interpolation when input spatial dims != 2"); + + SmallVector outputSizeIntValues; + SmallVector inputSizes; + SmallVector ScaleFactorFloatValues; + for (unsigned i = 2; i < inputRank; i++) { + Value inputSize = getDimOp(rewriter, loc, input, i); + inputSizes.push_back(rewriter.create( + loc, rewriter.getIntegerType(64), inputSize)); + } + + if (!op.getScaleFactor().getType().isa()) { + bool recompScale; + if (!matchPattern(op.getRecomputeScaleFactor(), + m_TorchConstantBool(&recompScale))) + recompScale = false; + SmallVector ScaleFactorTorchFloat; + if (!getListConstructElements(op.getScaleFactor(), ScaleFactorTorchFloat)) + return rewriter.notifyMatchFailure( + op, "unimplemented: the output_size is not constructed from " + "ListConstruct"); + ScaleFactorFloatValues = getTypeConvertedValues( + rewriter, loc, getTypeConverter(), ScaleFactorTorchFloat); + for (unsigned i = 0; i < inputRank - 2; i++) { + Value inputSizeFP = rewriter.create( + loc, rewriter.getF32Type(), inputSizes[i]); + ScaleFactorFloatValues[i] = rewriter.create( + loc, inputSizeFP.getType(), ScaleFactorFloatValues[i]); + Value outputSize = rewriter.create( + loc, inputSizeFP, ScaleFactorFloatValues[i]); + outputSize = rewriter.create(loc, outputSize); + outputSize = rewriter.create( + loc, rewriter.getI64Type(), outputSize); + outputSizeIntValues.push_back(outputSize); + } + if (recompScale) + ScaleFactorFloatValues.clear(); + } else { + SmallVector outputSizeTorchInt; + if (!getListConstructElements(op.getSize(), outputSizeTorchInt)) + return rewriter.notifyMatchFailure( + op, "unimplemented: the output_size is not 
constructed from " + "ListConstruct"); + outputSizeIntValues = getTypeConvertedValues( + rewriter, loc, getTypeConverter(), outputSizeTorchInt); + } + SmallVector dims = getTensorSizesUntilDim(rewriter, loc, input, 1); + for (unsigned i = 2; i < inputRank; i++) { + dims.push_back(castIntToIndex(rewriter, loc, outputSizeIntValues[i - 2])); + } + + Value outTensor = rewriter.create( + loc, getAsOpFoldResult(dims), inputType.getElementType()); + AffineMap idMap = rewriter.getMultiDimIdentityMap(inputRank); + SmallVector iteratorTypes( + inputRank, utils::IteratorType::parallel); + Value finalRes = + rewriter + .create( + loc, outTensor.getType(), ValueRange{}, outTensor, + /*indexingMaps=*/idMap, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value retVal; + if (mode.substr(0, 7) == "nearest") { + std::string coordTfMode = + mode.substr(7, mode.find(",") - 7); + std::string nearestMode = + (mode.find(",") == std::string::npos) + ? "" + : mode.substr(mode.find(",") + 1); + retVal = NearestInterpolate( + b, loc, outputSizeIntValues, input, inputSizes, + ScaleFactorFloatValues, coordTfMode, nearestMode); + } else if (mode.substr(0, 8) == "bilinear") { + retVal = BilinearInterpolate( + b, op, loc, outputSizeIntValues, input, inputSizes, + ScaleFactorFloatValues, mode.substr(8)); + } + b.create(loc, retVal); + }) + .getResult(0); + Type newResultType = + getTypeConverter()->convertType(op.getResult().getType()); + rewriter.replaceOpWithNewOp(op, newResultType, finalRes); + return success(); + } +}; +} // namespace void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { @@ -2644,4 +2978,6 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); } diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index fff872b32198..a70e8368720b 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2103,6 +2103,30 @@ OpFoldResult AtenEqStrOp::fold(FoldAdaptor adaptor) { return nullptr; } +//===----------------------------------------------------------------------===// +// Aten__Contains__StrListOp +//===----------------------------------------------------------------------===// + +OpFoldResult Aten__Contains__StrListOp::fold(FoldAdaptor adaptor) { + StringAttr item = dyn_cast_or_null(adaptor.getItem()); + if (!item) + return nullptr; + + if (auto listConstruct = getL().getDefiningOp()) { + if (isListPotentiallyMutated(listConstruct)) + return nullptr; + } + llvm::SmallVector strs; + if (matchPattern(getL(), m_TorchListOfConstantStrs(strs))) { + for (const auto &str : strs) { + if (item.getValue().str() == str) + return getI1IntegerAttr(getContext(), true); + } + return getI1IntegerAttr(getContext(), false); + } + return nullptr; +} + //===----------------------------------------------------------------------===// // AtenLtIntOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 06d36f58d1c8..65aeb6ddad4f 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6608,6 +6608,70 @@ 
StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" " return %4 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.__interpolate.size_list_scale_list\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.str, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.bool) -> !torch.list {\n" +" %str = torch.constant.str \"AssertionError: Either size or scale_factor must be presented\"\n" +" %str_0 = torch.constant.str \"AssertionError: Must specify exactly one of size and scale_factor\"\n" +" %none = torch.constant.none\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %int3 = torch.constant.int 3\n" +" %false = torch.constant.bool false\n" +" %true = torch.constant.bool true\n" +" %0 = torch.prim.Uninitialized : !torch.list\n" +" %1 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %2 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.prim.ListConstruct %1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = torch.aten.__isnot__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %5:2 = torch.prim.If %4 -> (!torch.bool, !torch.list) {\n" +" %7 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list\n" +" %8 = torch.aten.__is__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.aten.__getitem__.t %7, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %10 = torch.aten.append.t %3, %9 : !torch.list, !torch.int -> !torch.list\n" +" %11 = torch.aten.__getitem__.t %7, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.append.t %3, %11 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield %true, %3 : !torch.bool, !torch.list\n" +" } else {\n" +" %7 = torch.aten.__isnot__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %8:2 = torch.prim.If %7 -> (!torch.bool, !torch.list) {\n" +" %9 = torch.prim.unchecked_cast %arg2 : !torch.optional> -> !torch.list\n" +" %10 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" torch.prim.If %10 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %11 = torch.aten.__getitem__.t %9, %int0 : !torch.list, !torch.int -> !torch.float\n" +" %12 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %13 = torch.operator \"aten.mul.float_int\"(%11, %12) : (!torch.float, !torch.int) -> !torch.float \n" +" %14 = torch.aten.Int.float %13 : !torch.float -> !torch.int\n" +" %15 = torch.aten.append.t %3, %14 : !torch.list, !torch.int -> !torch.list\n" +" %16 = torch.aten.__getitem__.t %9, %int1 : !torch.list, !torch.int -> !torch.float\n" +" %17 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int\n" +" %18 = torch.operator \"aten.mul.float_int\"(%16, %17) : (!torch.float, !torch.int) -> !torch.float \n" +" %19 = torch.aten.Int.float %18 : !torch.float -> !torch.int\n" +" %20 = torch.aten.append.t %3, %19 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield %true, %3 : 
!torch.bool, !torch.list\n" +" } else {\n" +" torch.prim.If.yield %false, %0 : !torch.bool, !torch.list\n" +" }\n" +" torch.prim.If.yield %8#0, %8#1 : !torch.bool, !torch.list\n" +" }\n" +" %6 = torch.prim.If %5#0 -> (!torch.list) {\n" +" torch.prim.If.yield %5#1 : !torch.list\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %3 : !torch.list\n" +" }\n" +" return %6 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.prims.collapse\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list {\n" " %true = torch.constant.bool true\n" " %str = torch.constant.str \"AssertionError: start must be less than or equal to end\"\n" @@ -9938,6 +10002,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.__interpolate.size_list_scale_list\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.str, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.reflection_pad1d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: padding size expected to be 2\"\n" diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 671df14b3d34..9c362df4a928 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -29,6 +29,12 @@ "IscloseStaticModule_basic", "IscloseStaticModuleTrue_basic", "SplitWithSizes_Module_basic", + # lowering to torch backend IR fails due to unsupported op: aten.upsample_[mode/dims].vec + # these interpolate tests are added specifically to test onnx.Resize. 
+ "InterpolateDynamicModule_sizes_bilinear", + "InterpolateDynamicModule_sizes_nearest", + "InterpolateStaticModule_scales_bilinear_align_corners", + "InterpolateDynamicModule_scales_recompute_bilinear", } @@ -336,6 +342,11 @@ # Others "GridSamplerBasic1_basic", "GridSamplerBasic2_basic", + + "InterpolateDynamicModule_sizes_bilinear", + "InterpolateDynamicModule_sizes_nearest", + "InterpolateStaticModule_scales_bilinear_align_corners", + "InterpolateDynamicModule_scales_recompute_bilinear", } if torch_version_for_comparison() <= version.parse("2.2.0"): @@ -1814,6 +1825,9 @@ "IndexPutImpl3DFloatAccumulateModule_basic", "IndexPutImpl3DFloatNonAccumulateModule_basic", "IndexPutImplIndexWithNoneModule_basic", + "InterpolateDynamicModule_sizes_bilinear", + "InterpolateDynamicModule_sizes_nearest", + "InterpolateDynamicModule_scales_recompute_bilinear", "IntFloatModule_basic", "IntImplicitModule_basic", "IouOfModule_basic", @@ -1999,6 +2013,8 @@ "UpSampleNearest2dDynamicFactor_basic", "UpSampleNearest2dStaticFactor_basic", "UpSampleNearest2d_basic", + "UpSampleNearest2dDynamicSize_basic", + "UpSampleNearest2dStaticSize_basic", "VarCorrectionEmptyDimModule_basic", "VarDimEmptyDimModule_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic", @@ -2165,10 +2181,6 @@ "ReduceSumSignedIntModule_basic", "ReduceSumUnsignedIntModule_basic", - # Failure - onnx_lowering: onnx.Resize - "UpSampleNearest2dDynamicSize_basic", - "UpSampleNearest2dStaticSize_basic", - # Failure - onnx_lowering: onnx.ScatterElements "ScatterSrcModule_basic", "ScatterSrcStaticModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 31ce183bb7a0..f21d2d57fcb5 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -291,6 +291,26 @@ def aten〇grid_sampler〡shape(input: List[int], grid: List[int], interpolation output = [input[0],input[1],grid[1],grid[2]] return output +def aten〇__interpolate〇size_list_scale_list〡shape(input: List[int], size: Optional[List[int]] = None, scale_factor: Optional[List[float]] = None, mode: str = "nearest", align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False) -> List[int]: + output = [input[0], input[1]] + if size is not None: + assert ( + scale_factor is None + ), "Must specify exactly one of size and scale_factor" + output.append(size[0]) + output.append(size[1]) + return output + elif scale_factor is not None: + assert ( + size is None + ), "Must specify exactly one of size and scale_factor" + output.append(int(scale_factor[0] * input[2])) + output.append(int(scale_factor[1] * input[3])) + return output + assert 0, "Either size or scale_factor must be presented" + return output + + def prims〇collapse〡shape(a: List[int], start: int, end: int) -> List[int]: # Obtained through trial and error on a few examples in PyTorch: assert start < len(a), "start out of bounds" @@ -2217,6 +2237,10 @@ def aten〇grid_sampler〡dtype(input_rank_dtype: Tuple[int, int], grid_rank_dty grid_rank, grid_dtype = input_rank_dtype return input_dtype +def aten〇__interpolate〇size_list_scale_list〡dtype(input_rank_dtype: Tuple[int, int], size: Optional[List[int]] = None, scale_factor: Optional[List[float]] = None, mode: str = "nearest", align_corners: Optional[bool] = None, recompute_scale_factor: 
Optional[bool] = None, antialias: bool = False) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype + @check_dtype_function([ErrorInvocation(TensorOfShape(2, 3, 4), padding=1), ErrorInvocation(TensorOfShape(2, 3, 4), padding=[]), ErrorInvocation(TensorOfShape(2, 3, 4), padding=[2]), diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index a16279c9df78..7db3ea511164 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -501,6 +501,9 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::_log_softmax : (Tensor, int, bool) -> (Tensor)" ) + emit( + "aten::__interpolate.size_list_scale_list : (Tensor, int[]?, float[]?, str, bool?, bool?, bool) -> (Tensor)" + ) emit_with_mutating_variants("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)") emit_with_mutating_variants("aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)") emit_with_mutating_variants("aten::masked_scatter : (Tensor, Tensor, Tensor) -> (Tensor)") @@ -757,6 +760,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::format : (...) -> (str)") emit("aten::join : (str, str[]) -> (str)") emit("aten::warn : (str, int) -> ()") + emit("aten::__contains__.str_list : (str[], str) -> (bool)", has_folder=True) # Type conversion ops. emit("aten::Float.Scalar : (Scalar) -> (float)", has_folder=True) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py index 73371058cf46..a5dabd018cc5 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py @@ -1099,4 +1099,98 @@ def forward(self, tensor1, tensor2): @register_test_case(module_factory=lambda: EinsumStaticContractRhsModule()) def EinsumStaticContractRhsModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 4, 5), tu.rand(4, 5)) \ No newline at end of file + module.forward(tu.rand(3, 4, 5), tu.rand(4, 5)) + + +class InterpolateModule(torch.nn.Module): + def __init__( + self, + size=None, + scale_factor=None, + mode="nearest", + align_corners=None, + recompute_scale_factor=None, + antialias=False, + ): + self.size = size + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + self.recompute_scale_factor = recompute_scale_factor + self.antialias = antialias + super().__init__() + + def _forward(self, input): + return torch.nn.functional.interpolate( + input, + size=self.size, + scale_factor=self.scale_factor, + mode=self.mode, + align_corners=self.align_corners, + recompute_scale_factor=self.recompute_scale_factor, + antialias=self.antialias, + ) + + +class InterpolateStaticModule(InterpolateModule): + @export + @annotate_args( + [ + None, + ([1, 1, 4, 5], torch.float32, True), + ] + ) + def forward(self, input): + return self._forward(input) + + +class InterpolateDynamicModule(InterpolateModule): + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, input): + return self._forward(input) + + +@register_test_case( + module_factory=lambda: InterpolateStaticModule( + scale_factor=0.41, mode="bilinear", align_corners=True + ) +) +def InterpolateStaticModule_scales_bilinear_align_corners(module, tu: TestUtils): + input 
= torch.arange(20).to(dtype=torch.float32) + input = input.reshape((1, 1, 4, 5)) + module.forward(input) + + +@register_test_case( + module_factory=lambda: InterpolateDynamicModule(size=(2, 7), mode="nearest") +) +def InterpolateDynamicModule_sizes_nearest(module, tu: TestUtils): + input = torch.arange(20).to(dtype=torch.float32) + input = input.reshape((1, 1, 4, 5)) + module.forward(input) + + +@register_test_case( + module_factory=lambda: InterpolateDynamicModule(size=(2, 7), mode="bilinear") +) +def InterpolateDynamicModule_sizes_bilinear(module, tu: TestUtils): + input = torch.arange(20).to(dtype=torch.float32) + input = input.reshape((1, 1, 4, 5)) + module.forward(input) + + +@register_test_case( + module_factory=lambda: InterpolateDynamicModule( + scale_factor=(1.9, 2.4), mode="bilinear", recompute_scale_factor=True + ) +) +def InterpolateDynamicModule_scales_recompute_bilinear(module, tu: TestUtils): + input = torch.arange(20).to(dtype=torch.float32) + input = input.reshape((1, 1, 4, 5)) + module.forward(input) diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 508ed55d3337..afc85bccf6de 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -1664,3 +1664,36 @@ func.func @test_size(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[],si return %0 : !torch.vtensor<[],si32> } +// ----- + +// CHECK-LABEL: func.func @test_resize_sizes_nearest + func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %none = torch.constant.none + // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?,?],f32> + } + +// ----- + +// CHECK-LABEL: func.func @test_resize_sizes_nearest +func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %none = torch.constant.none + // CHECK: %[[STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor" + // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) { + torch.onnx.coordinate_transformation_mode = "half_pixel", + torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> 
!torch.vtensor<[?,?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_resize_sizes_linear + func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?], +f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %none = torch.constant.none + // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?,?],f32> + } diff --git a/test/Conversion/TorchToLinalg/resize.mlir b/test/Conversion/TorchToLinalg/resize.mlir new file mode 100644 index 000000000000..4815a4a9211a --- /dev/null +++ b/test/Conversion/TorchToLinalg/resize.mlir @@ -0,0 +1,198 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @test_resize_sizes_linear +func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4] +,si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[generic:.*]] = linalg.generic + // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x1:.*]], %[[x2:.*]], %[[x3:.*]], %[[x4:.*]]] : tensor<1x1x2x4xf32> + // CHECK: %[[extracted_7:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]] + // CHECK: %[[extracted_8:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]] + // CHECK: %[[extracted_9:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]] + // CHECK: %[[dx0p00:.*]] = arith.mulf %[[dx0:.*]], %[[extracted]] + // CHECK: %[[dx1p01:.*]] = arith.mulf %[[dx1:.*]], %[[extracted_7]] + // CHECK: %[[sum:.*]] = arith.addf %[[dx0p00]], %[[dx1p01]] + // CHECK: %[[left:.*]] = arith.mulf %[[dy0:.*]], %[[sum]] + // CHECK: %[[dx0p10:.*]] = arith.mulf %[[dx0]], %[[extracted_8]] + // CHECK: %[[dx1p11:.*]] = arith.mulf %[[dx1]], %[[extracted_9]] + // CHECK: %[[sum2:.*]] = arith.addf %[[dx0p10]], %[[dx1p11]] + // CHECK: %[[right:.*]] = arith.mulf %[[dy1:.*]], %[[sum2]] + // CHECK: %[[retval:.*]] = arith.addf %[[left]], %[[right]] + %none = torch.constant.none + %none_0 = torch.constant.none + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %true = torch.constant.bool true + %str = torch.constant.str "bilinear" + %int2 = torch.constant.int 2 + %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + %int3 = torch.constant.int 3 + %2 = torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %3 = torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int + %4 = torch.prim.ListConstruct %1, %3 : (!torch.int, !torch.int) -> !torch.list + %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : 
!torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + return %5 : !torch.vtensor<[?,?,?,?],f32> + } + +// ----- + +// CHECK-LABEL: func.func @test_resize_sizes_nearest +func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[GENERIC:.*]] = linalg.generic + // CHECK: %[[x11:.*]] = linalg.index 0 : index + // CHECK: %[[x12:.*]] = linalg.index 1 : index + // CHECK: %[[x13:.*]] = linalg.index 2 : index + // CHECK: %[[x14:.*]] = linalg.index 3 : index + // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32 + // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 + // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32 + // CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64 + // CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32 + // CHECK: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32 + // CHECK: %[[x29:.*]] = math.floor %[[x25]] : f32 + // CHECK: %[[x31:.*]] = arith.fptosi %[[x29]] : f32 to i64 + // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index + // CHECK: %[[x16:.*]] = arith.sitofp %[[c4_i64:.*]] : i64 to f32 + // CHECK: %[[x20:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32 + // CHECK: %[[x22:.*]] = arith.divf %[[x20]], %[[x16]] : f32 + // CHECK: %[[x26:.*]] = arith.index_cast %[[x14]] : index to i64 + // CHECK: %[[x27:.*]] = arith.sitofp %[[x26]] : i64 to f32 + // CHECK: %[[x28:.*]] = arith.divf %[[x27]], %[[x22]] : f32 + // CHECK: %[[x30:.*]] = math.floor %[[x28]] : f32 + // CHECK: %[[x33:.*]] = arith.fptosi %[[x30]] : f32 to i64 + // CHECK: %[[x34:.*]] = arith.index_cast %[[x33]] : i64 to index + // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x11]], %[[x12]], %[[x32]], %[[x34]]] : tensor<1x1x2x4xf32> + // CHECK: linalg.yield %[[extracted]] : f32 + %none = torch.constant.none + %none_0 = torch.constant.none + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %true = torch.constant.bool true + %str = torch.constant.str "nearest" + %int2 = torch.constant.int 2 + %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + %int3 = torch.constant.int 3 + %2 = torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %3 = torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int + %4 = torch.prim.ListConstruct %1, %3 : (!torch.int, !torch.int) -> !torch.list + %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + return %5 : !torch.vtensor<[?,?,?,?],f32> + } + +// ----- + +// CHECK-LABEL: func.func @test_resize_nearest_1d +func.func @test_resize_nearest_1d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],f32> { + // CHECK: %[[GENERIC:.*]] = linalg.generic + // CHECK: %[[x11:.*]] = linalg.index 0 : index + // CHECK: %[[x12:.*]] = linalg.index 1 : index + // CHECK: %[[x13:.*]] = linalg.index 2 : index + // CHECK: 
%[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32 + // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 + // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32 + // CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64 + // CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32 + // CHECK: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32 + // CHECK: %[[x29:.*]] = math.floor %[[x25]] : f32 + // CHECK: %[[x31:.*]] = arith.fptosi %[[x29]] : f32 to i64 + // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index + // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x11]], %[[x12]], %[[x32]]] : tensor + // CHECK: linalg.yield %[[extracted]] : f32 + %none = torch.constant.none + %none_0 = torch.constant.none + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %true = torch.constant.bool true + %str = torch.constant.str "nearest,floor" + %int2 = torch.constant.int 2 + %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + %4 = torch.prim.ListConstruct %1 : (!torch.int) -> !torch.list + %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[?,?,?],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?],f32> + return %5 : !torch.vtensor<[?,?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_resize_nearest_3d +func.func @test_resize_nearest_3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[5],si64>) -> !torch.vtensor<[?,?,?,?,?],f32> { + // CHECK: %[[GENERIC:.*]] = linalg.generic + // CHECK: %[[x11:.*]] = linalg.index 0 : index + // CHECK: %[[x12:.*]] = linalg.index 1 : index + // CHECK: %[[x13:.*]] = linalg.index 2 : index + // CHECK: %[[x14:.*]] = linalg.index 3 : index + // CHECK: %[[index4:.*]] = linalg.index 4 : index + // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32 + // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 + // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32 + // CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64 + // CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32 + // CHECK: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32 + // CHECK: %[[x29:.*]] = math.floor %[[x25]] : f32 + // CHECK: %[[x31:.*]] = arith.fptosi %[[x29]] : f32 to i64 + // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index + // CHECK: %[[x34:.*]] = arith.index_cast %[[Wfptosi:.*]] : i64 to index + // CHECK: %[[x35:.*]] = arith.index_cast %[[Dfptosi:.*]] : i64 to index + // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x11]], %[[x12]], %[[x32]], %[[x34]], %[[x35]]] : tensor + // CHECK: linalg.yield %[[extracted]] : f32 + %none = torch.constant.none + %none_0 = torch.constant.none + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %true = torch.constant.bool true + %str = torch.constant.str "nearest" + %int2 = torch.constant.int 2 + %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[5],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + %int3 = torch.constant.int 3 + %2 = torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[5],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %3 = torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int + %int4 = torch.constant.int 4 + %4 = 
torch.aten.select.int %arg1, %int0, %int4 : !torch.vtensor<[5],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %5 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + %6 = torch.prim.ListConstruct %1, %3, %5: (!torch.int, !torch.int, !torch.int) -> !torch.list + %7 = torch.aten.__interpolate.size_list_scale_list %arg0, %6, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[?,?,?,?,?],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?,?],f32> + return %7 : !torch.vtensor<[?,?,?,?,?],f32> + } + +// CHECK-LABEL: func.func @test_resize_nearest_half_pixel +func.func @test_resize_nearest_half_pixel_round_prefer_floor(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],f32> { + // CHECK: %[[GENERIC:.*]] = linalg.generic + // CHECK: %[[x11:.*]] = linalg.index 0 : index + // CHECK: %[[x12:.*]] = linalg.index 1 : index + // CHECK: %[[x13:.*]] = linalg.index 2 : index + // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32 + // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 + // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32 + // CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64 + // CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32 + // CHECK: %[[cst:.*]] = arith.constant 5.000000e-01 : f32 + // CHECK: %[[add:.*]] = arith.addf %[[x24]], %[[cst]] : f32 + // CHECK: %[[x25:.*]] = arith.divf %[[add]], %[[x21]] : f32 + // CHECK: %[[sub:.*]] = arith.subf %[[x25]], %[[cst]] : f32 + // CHECK: %[[cst3:.*]] = arith.constant 5.000000e-01 : f32 + // CHECK: %[[floor:.*]] = math.floor %[[sub]] : f32 + // CHECK: %[[ceil:.*]] = math.ceil %[[sub]] : f32 + // CHECK: %[[sub2:.*]] = arith.subf %[[sub]], %[[floor]] : f32 + // CHECK: %[[cmpf:.*]] = arith.cmpf ule, %[[sub2]], %[[cst3]] : f32 + // CHECK: %[[select:.*]] = arith.select %[[cmpf]], %[[floor]], %[[ceil]] : f32 + // CHECK: %[[x31:.*]] = arith.fptosi %[[select]] : f32 to i64 + // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index + // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x11]], %[[x12]], %[[x32]]] : tensor + // CHECK: linalg.yield %[[extracted]] : f32 + %none = torch.constant.none + %none_0 = torch.constant.none + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %true = torch.constant.bool true + %str = torch.constant.str "nearest_half_pixel,round_prefer_floor" + %int2 = torch.constant.int 2 + %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + %4 = torch.prim.ListConstruct %1 : (!torch.int) -> !torch.list + %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[?,?,?],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?],f32> + return %5 : !torch.vtensor<[?,?,?],f32> +} + +// ----- diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index a607365f4918..a1db60e43c40 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -504,8 +504,8 @@ func.func @torch.aten.eq.str$different_value() -> !torch.bool { // CHECK-LABEL: func.func @torch.aten.eq.str$same_operand( // CHECK-SAME: %{{.*}}: !torch.str) -> !torch.bool { -// CHECK-NEXT: %[[F:.*]] = torch.constant.bool true -// CHECK-NEXT: return %[[F]] : !torch.bool 
+// CHECK-NEXT: %[[TRUE:.*]] = torch.constant.bool true
+// CHECK-NEXT: return %[[TRUE]] : !torch.bool
 func.func @torch.aten.eq.str$same_operand(%arg0: !torch.str) -> !torch.bool {
   %0 = torch.aten.eq.str %arg0, %arg0 : !torch.str, !torch.str -> !torch.bool
   return %0 : !torch.bool
 }
@@ -539,6 +539,30 @@ func.func @torch.aten.len.str$empty() -> !torch.int {
   return %2 : !torch.int
 }
 
+// CHECK-LABEL: func.func @torch.aten.__contains__.str_list$false() -> !torch.bool {
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: return %[[FALSE]] : !torch.bool
+func.func @torch.aten.__contains__.str_list$false() -> !torch.bool {
+  %str = torch.constant.str "c"
+  %str_0 = torch.constant.str "b"
+  %str_1 = torch.constant.str "a"
+  %1 = torch.prim.ListConstruct %str_1, %str_0 : (!torch.str, !torch.str) -> !torch.list<str>
+  %2 = torch.aten.__contains__.str_list %1, %str : !torch.list<str>, !torch.str -> !torch.bool
+  return %2 : !torch.bool
+}
+
+// CHECK-LABEL: func.func @torch.aten.__contains__.str_list$true() -> !torch.bool {
+// CHECK: %[[TRUE:.*]] = torch.constant.bool true
+// CHECK: return %[[TRUE]] : !torch.bool
+func.func @torch.aten.__contains__.str_list$true() -> !torch.bool {
+  %str = torch.constant.str "aa"
+  %str_0 = torch.constant.str "aa"
+  %str_1 = torch.constant.str "ccc"
+  %1 = torch.prim.ListConstruct %str_1, %str_0 : (!torch.str, !torch.str) -> !torch.list<str>
+  %2 = torch.aten.__contains__.str_list %1, %str : !torch.list<str>, !torch.str -> !torch.bool
+  return %2 : !torch.bool
+}
+
 // CHECK-LABEL: func.func @torch.aten.__not__
 // CHECK: %[[TRUE:.*]] = torch.constant.bool true
 // CHECK: return %[[TRUE]] : !torch.bool
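
A note on the mode-string convention used between the onnx.Resize importer in DefaultDomainQtoZ.cpp and ConvertInterpolateOp: a non-default ONNX coordinate_transformation_mode is appended to the base mode with an underscore, and for nearest a non-default nearest_mode follows after a comma (e.g. "bilinear_asymmetric", "nearest_half_pixel,round_prefer_floor"). The sketch below mirrors the substr/find logic in ConvertInterpolateOp; the helper name is illustrative and not part of the patch.

```cpp
#include <string>
#include <utility>

// Illustrative helper (not in the patch): split an encoded nearest-mode
// string into {coordinate transform suffix, nearest rounding mode}.
// Precondition: `mode` starts with "nearest".
std::pair<std::string, std::string> splitNearestMode(const std::string &mode) {
  auto comma = mode.find(',');
  // Keeps the leading underscore, matching the "_half_pixel"/"_asymmetric"
  // comparisons in the lowering; empty means the default (asymmetric).
  std::string coord = mode.substr(
      7, comma == std::string::npos ? std::string::npos : comma - 7);
  std::string rounding =
      comma == std::string::npos ? std::string() : mode.substr(comma + 1);
  return {coord, rounding};
}
```

On the importer side, when only scales are provided (fewer than four operands) operands[2] feeds the scale list and sizes stay none; otherwise operands[3] feeds the size list and scales stay none.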
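
On ConvertInterpolateOp's handling of scale_factor: output sizes are computed as floor(inputSize * scale), and when recompute_scale_factor is true the explicit scales are discarded so the interpolation body falls back to deriving scale as outputSize / inputSize. A scalar sketch of that bookkeeping (helper name is illustrative):

```cpp
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative helper (not in the patch): resolve spatial output sizes from
// per-dimension scales, dropping the scales when they must be recomputed.
std::vector<int64_t> resolveOutputSizes(const std::vector<int64_t> &inputSizes,
                                        std::vector<float> &scales,
                                        bool recomputeScaleFactor) {
  std::vector<int64_t> outputSizes;
  for (size_t i = 0; i < inputSizes.size(); ++i)
    outputSizes.push_back(static_cast<int64_t>(
        std::floor(static_cast<float>(inputSizes[i]) * scales[i])));
  if (recomputeScaleFactor)
    scales.clear(); // downstream code re-derives scale as outputSize/inputSize
  return outputSizes;
}
```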
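
The NearestInterpolate helper builds the per-dimension index computation as arith/math ops inside the linalg.generic body. As a plain scalar reference for the supported coordinate modes and nearest_mode roundings (a sketch with an illustrative function name; the lowering's coordinate strings carry a leading underscore, dropped here):

```cpp
#include <cmath>
#include <cstdint>
#include <string>

// Illustrative scalar model (not in the patch) of the index math that
// NearestInterpolate emits for one spatial dimension.
int64_t nearestSourceIndex(int64_t outIdx, int64_t inSize, int64_t outSize,
                           const std::string &coordMode,
                           const std::string &nearestMode) {
  float scale = static_cast<float>(outSize) / static_cast<float>(inSize);
  float proj;
  if (coordMode.empty() || coordMode == "asymmetric")
    proj = static_cast<float>(outIdx) / scale;
  else // "half_pixel"
    proj = (static_cast<float>(outIdx) + 0.5f) / scale - 0.5f;

  float nearest;
  if (nearestMode.empty() || nearestMode == "floor") {
    nearest = std::floor(proj);
  } else if (nearestMode == "round_prefer_floor") {
    float frac = proj - std::floor(proj);
    nearest = frac <= 0.5f ? std::floor(proj) : std::ceil(proj);
  } else if (nearestMode == "round_prefer_ceil") {
    float frac = proj - std::floor(proj);
    nearest = frac >= 0.5f ? std::ceil(proj) : std::floor(proj);
  } else { // "ceil"
    nearest = std::ceil(proj);
  }
  return static_cast<int64_t>(nearest);
}
```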
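
Similarly for BilinearInterpolate: the generic body computes an area-weighted average of the 2x2 neighborhood around the back-projected coordinate (the Aij weights described in the in-code comment). Below is a standalone scalar sketch for the default half_pixel path on an HxW plane without align_corners; the function name and the explicit border clamp are illustrative, and the epsilon clamp in the lowering achieves the same bounds guarantee.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Illustrative reference (not in the patch) for one output pixel of the
// area-weighted 2x2 interpolation lowered by BilinearInterpolate.
// `in` is HxW row-major; (oy, ox) is the output pixel being produced.
float bilinearAtOutputPixel(const std::vector<float> &in, int64_t inH,
                            int64_t inW, int64_t outH, int64_t outW,
                            int64_t oy, int64_t ox) {
  auto project = [](int64_t outIdx, int64_t inSize, int64_t outSize) {
    float scale = static_cast<float>(outSize) / static_cast<float>(inSize);
    // half_pixel: x_original = (x_resized + 0.5) / scale - 0.5, clamped.
    float proj = (static_cast<float>(outIdx) + 0.5f) / scale - 0.5f;
    return std::clamp(proj, 0.0f, static_cast<float>(inSize) - 1.0f);
  };
  float py = project(oy, inH, outH);
  float px = project(ox, inW, outW);
  int64_t yLow = static_cast<int64_t>(std::floor(py));
  int64_t xLow = static_cast<int64_t>(std::floor(px));
  int64_t yHigh = std::min(yLow + 1, inH - 1);
  int64_t xHigh = std::min(xLow + 1, inW - 1);
  // Weights Aij = dy_i * dx_j; the four weights sum to 1, so no division.
  float dy0 = static_cast<float>(yHigh) - py, dy1 = py - static_cast<float>(yLow);
  float dx0 = static_cast<float>(xHigh) - px, dx1 = px - static_cast<float>(xLow);
  // Degenerate border case where low == high after clamping.
  if (yHigh == yLow) { dy0 = 1.0f; dy1 = 0.0f; }
  if (xHigh == xLow) { dx0 = 1.0f; dx1 = 0.0f; }
  auto at = [&](int64_t y, int64_t x) { return in[y * inW + x]; };
  return dy0 * (dx0 * at(yLow, xLow) + dx1 * at(yLow, xHigh)) +
         dy1 * (dx0 * at(yHigh, xLow) + dx1 * at(yHigh, xHigh));
}
```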