diff --git a/docs/add_ops.md b/docs/add_ops.md index 37dee90817db..b8e5ce37ec45 100644 --- a/docs/add_ops.md +++ b/docs/add_ops.md @@ -2,7 +2,6 @@ Collected links and contacts for how to add ops to torch-mlir. -
Turbine Camp: Start Here This document was previously known as `turbine-camp.md` to Nod.ai. "Turbine Camp" is part of Nod.ai's onboarding process. Welcome to turbine camp. This document originated at Nod.ai as a part of onboardding process, where new nod-ai folks learn about the architecture of our work by adding support for 2 ops to torch-mlir. I decided to put this into torch mlir because a lot of this is about torch-mlir. @@ -27,6 +26,7 @@ The details of how we do it and helpful commands to help you set up each repo is PS: IREE is pronounced Eerie, and hence the ghost icon. ## How to begin +0. Set up torch-mlir according to the instructions here: https://github.com/llvm/torch-mlir/blob/main/docs/development.md 1. You will start by adding support for 2 ops in torch-mlir, to get you familiar with the center of our pipeline. Begin by reading [torch-mlir's documentation on how to implement a new torch op](https://github.com/llvm/torch-mlir/blob/main/docs/Torch-ops-E2E-implementation.md), and set up `llvm/torch_mlir` using https://github.com/llvm/torch-mlir/blob/main/docs/development.md 2. Pick 1 of the yet-unimplemented from the following. You should choose something that looks easy to you. **Make sure you create an issue by clicking the little "target" icon to the right of the op, thereby marking the op as yours** - [TorchToLinalg ops tracking issue](https://github.com/nod-ai/SHARK-Turbine/issues/347) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 8f949d9ba195..c38d0dbbd389 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -4223,6 +4223,52 @@ def Torch_AtenRound_Op : Torch_Op<"aten.round_", [ }]; } +def Torch_AtenTruncOp : Torch_Op<"aten.trunc", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::trunc : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenTruncOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenTruncOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; + let hasFolder = 1; +} + +def Torch_AtenTrunc_Op : Torch_Op<"aten.trunc_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::trunc_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + AnyTorchOptionalNonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenTrunc_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenTrunc_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenSignOp : Torch_Op<"aten.sign", [ AllowsTypeRefinement, HasValueSemantics, @@ -10713,6 +10759,30 @@ def Torch_AtenProdDimIntOp : Torch_Op<"aten.prod.dim_int", [ }]; } +def Torch_AtenProdOp : Torch_Op<"aten.prod", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::prod : (Tensor, int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalIntType:$dtype + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenProdOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenProdOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenMaxOp : Torch_Op<"aten.max", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td index 898c768ae1c2..e7fc4bc976bb 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td @@ -129,6 +129,7 @@ class AnyTorchTensorType | torch.bool | i1 | | torch.qint8 | !torch.qint8 | | torch.quint8 | !torch.quint8 | + | torch.qint32 | !torch.qint32 | | torch.complex64 | complex | | torch.complex128 | complex | |-------------------|--------------------| diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index b16e76e3afe5..14aa41bef349 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -101,17 +101,11 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.s64BoolAttr(selectLastIndex, "select_last_index", false)) return failure(); - if (selectLastIndex) { - // TODO: Figure out how to support this case. Need to add a reverse - // or something. - return rewriter.notifyMatchFailure( - binder.op, "unsupported conversion: select_last_index=true"); - } - // ONNX allows negative axis. + auto operandSizes = + cast(operand.getType()).getSizes(); if (axis < 0) - axis += - cast(operand.getType()).getSizes().size(); + axis += operandSizes.size(); Value constAxis = rewriter.create( binder.getLoc(), rewriter.getType(), @@ -119,6 +113,26 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( Value constKeepDims = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getBoolAttr(keepDims)); + + if (selectLastIndex) { + Value dims = createConstantIntList(binder, rewriter, {axis}); + auto operandTy = dyn_cast(operand.getType()); + operand = rewriter.create( + binder.getLoc(), operandTy, operand, dims); + Value argmax = rewriter.create( + binder.getLoc(), resultType, operand, constAxis, constKeepDims); + Value offset = rewriter.create( + binder.getLoc(), + rewriter.getI64IntegerAttr(operandSizes[axis] - 1)); + Value alpha = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value sub = rewriter.create( + binder.getLoc(), resultType, argmax, offset, alpha); + rewriter.replaceOpWithNewOp(binder.op, resultType, + sub); + return success(); + } + rewriter.replaceOpWithNewOp( binder.op, resultType, operand, constAxis, constKeepDims); return success(); @@ -137,17 +151,11 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.s64BoolAttr(selectLastIndex, "select_last_index", false)) return failure(); - if (selectLastIndex) { - // TODO: Figure out how to support this case. Need to add a reverse - // or something. - return rewriter.notifyMatchFailure( - binder.op, "unsupported conversion: select_last_index=true"); - } - // ONNX allows negative axis. + auto operandSizes = + cast(operand.getType()).getSizes(); if (axis < 0) - axis += - cast(operand.getType()).getSizes().size(); + axis += operandSizes.size(); Value constAxis = rewriter.create( binder.getLoc(), rewriter.getType(), @@ -155,6 +163,26 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( Value constKeepDims = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getBoolAttr(keepDims)); + + if (selectLastIndex) { + Value dims = createConstantIntList(binder, rewriter, {axis}); + auto operandTy = dyn_cast(operand.getType()); + operand = rewriter.create( + binder.getLoc(), operandTy, operand, dims); + Value argmin = rewriter.create( + binder.getLoc(), resultType, operand, constAxis, constKeepDims); + Value offset = rewriter.create( + binder.getLoc(), + rewriter.getI64IntegerAttr(operandSizes[axis] - 1)); + Value alpha = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value sub = rewriter.create( + binder.getLoc(), resultType, argmin, offset, alpha); + rewriter.replaceOpWithNewOp(binder.op, resultType, + sub); + return success(); + } + rewriter.replaceOpWithNewOp( binder.op, resultType, operand, constAxis, constKeepDims); return success(); @@ -981,6 +1009,157 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return success(); }); patterns.onOp( + "ConvInteger", 10, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + std::string autoPad; + if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) + return failure(); + if (autoPad != "NOTSET") + // TODO: Add support for `auto_pad` != "NOTSET" + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: auto_pad != NOTSET"); + + Torch::ValueTensorType resultType; + Value input, weight, inputZp, weightZp; + int64_t group; + if (binder.tensorOperandAtIndex(input, 0) || + binder.tensorOperandAtIndex(weight, 1) || + binder.s64IntegerAttr(group, "group", 1) || + binder.tensorResultType(resultType)) + return failure(); + + auto inputTy = dyn_cast(input.getType()); + auto weightTy = dyn_cast(weight.getType()); + if (!weightTy || !weightTy.hasSizes()) + return rewriter.notifyMatchFailure( + binder.op, "Expected weight type having sizes"); + ArrayRef weightShape = weightTy.getSizes(); + SmallVector kernelShape; + if (binder.s64IntegerArrayAttr(kernelShape, "kernel_shape", {})) + return failure(); + if (kernelShape.size()) { + if (kernelShape.size() != weightShape.size() - 2) { + return rewriter.notifyMatchFailure( + binder.op, + "unsupported conversion: kernel_shape list size should have " + "number of values equal to weight_rank - 2"); + } else { + for (unsigned i = 0; i < kernelShape.size(); i++) { + if (weightShape[i + 2] != kernelShape[i]) + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: kernel_shape value " + "should be equal to the weight tensor shape"); + } + } + } + + // Determine the rank of input tensor. + std::optional maybeRank = Torch::getTensorRank(input); + if (!maybeRank) + return rewriter.notifyMatchFailure(binder.op, + "Unimplemented: unranked tensor"); + unsigned rank = *maybeRank; + + SmallVector padding, strides, dilations; + SmallVector defaultPadding(rank - 2, 0), + defaultStrides(rank - 2, 1), defaultDilations(rank - 2, 1); + // Padding for the beginning and ending along each spatial axis, it can + // take any value greater than or equal to 0. The value represent the + // number of pixels added to the beginning and end part of the + // corresponding axis. pads format should be as follow [x1_begin, + // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added + // at the beginning of axis i and xi_end, the number of pixels added at + // the end of axis i. + if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) + return failure(); + if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) + return rewriter.notifyMatchFailure( + binder.op, "padding list size does not match the number of axes"); + if (binder.s64IntegerArrayAttr(dilations, "dilations", + defaultDilations)) + return failure(); + if (dilations.size() != rank - 2) + return rewriter.notifyMatchFailure( + binder.op, + "dilations list size does not match the number of axes"); + if (binder.s64IntegerArrayAttr(strides, "strides", defaultStrides)) + return failure(); + if (strides.size() != rank - 2) + return rewriter.notifyMatchFailure( + binder.op, "strides list size does not match the number of axes"); + + Value scale = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(1.0)); + if (binder.tensorOperandAtIndex(inputZp, 2)) { + inputZp = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + } else { + inputZp = rewriter.create( + binder.getLoc(), rewriter.getType(), inputZp); + } + if (binder.tensorOperandAtIndex(weightZp, 3)) + weightZp = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + // TODO: support per channel quantization if weightZp is a 1-D tensor + if (auto zpTy = dyn_cast(weightZp.getType())) { + for (auto dim : zpTy.getSizes()) + if (dim != 1) + return failure(); + weightZp = rewriter.create( + binder.getLoc(), rewriter.getType(), weightZp); + } + + SmallVector cstPadding; + if (padding.size() != 2 * (rank - 2)) { + for (int64_t i : padding) { + cstPadding.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + } else { + for (unsigned i = 0; i < padding.size() / 2; i++) { + if (padding[i] != padding[i + (padding.size() / 2)]) + // TODO: Add support for different padding values for the + // beginning and ending along each spatial axis + return rewriter.notifyMatchFailure( + binder.op, + "unsupported conversion: padding values for the beginning " + "and ending along each spatial axis must be equal"); + cstPadding.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); + } + } + + Value paddingList = rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + cstPadding); + Value dilationsList = + createConstantIntList(binder, rewriter, dilations); + Value stridesList = createConstantIntList(binder, rewriter, strides); + Value outputPaddingList = + createConstantIntList(binder, rewriter, {0, 0}); + Value transposed = + rewriter.create(binder.getLoc(), false); + Value bias = rewriter.create(binder.getLoc()); + Value cstGroup = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(group)); + + Type inputQTy = getQTorchTypeFromTorchIntType(inputTy); + Type weightQTy = getQTorchTypeFromTorchIntType(weightTy); + input = rewriter.create( + binder.getLoc(), inputQTy, input, scale, inputZp); + weight = rewriter.create( + binder.getLoc(), weightQTy, weight, scale, weightZp); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, weight, bias, stridesList, + paddingList, dilationsList, transposed, outputPaddingList, + cstGroup); + return success(); + }); + patterns.onOp( "ConvTranspose", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { std::string autoPad; diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index c7d071079119..ad6c91e405c6 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -1319,9 +1319,18 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( "expect 1-d pad tensor"); int64_t padsSize = padsShape[0]; - if (padsSize == Torch::kUnknownSize) - return rewriter.notifyMatchFailure(binder.op, - "pad length is unknown"); + if (padsSize == Torch::kUnknownSize) { + // As per onnx.Pad documentation, padSize = 2*num_data_axes + // (if axes param not passed). Need to be updated when adding + // support for `axes` param. + auto dataOpTy = data.getType().cast(); + TensorType dataTensor = dataOpTy.toBuiltinTensor(); + if (!dataTensor || !dataTensor.hasRank()) + return rewriter.notifyMatchFailure( + binder.op, "pad length unknown and data operand unranked"); + int64_t dataRank = dataTensor.getRank(); + padsSize = 2 * dataRank; + } Value constantValue; if (binder.getNumOperands() >= 3) { diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index ba8f71a6d234..f50bae62fb53 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -891,6 +891,61 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( /*storeValue=*/operand, keepDims, noop_with_empty_axes, false); }); + patterns.onOp( + "ReduceL2", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + int64_t keepDims, noop_with_empty_axes; + if (binder.tensorOperandAtIndex(operand, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", + 0)) + return failure(); + + // A ReduceL2 op is equivalent to the following sequence of operations: + // Mul(x, x) -> ReduceSum -> CastF32 -> Sqrt -> CastLike(resultType) + Value squareOfOperand = rewriter.create( + binder.getLoc(), operand.getType(), operand, operand); + + auto reducedSum = + reducedSumImpl(binder, rewriter, squareOfOperand, resultType, + operand, keepDims, noop_with_empty_axes, true); + if (failed(reducedSum)) + return rewriter.notifyMatchFailure( + binder.op, + "Failed to perform sum operation on square of operand"); + + Value castDType = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(/*Float32Type*/ 6)); + + Value noneVal = rewriter.create(binder.getLoc()); + Value constFalse = + rewriter.create(binder.getLoc(), false); + + // Perform an AtenToDtype op on the squared sum of the operand, stored + // now in operand itself. + auto size = operand.getType() + .dyn_cast() + .getOptionalSizes(); + auto f32ResultType = rewriter.getType( + size, rewriter.getF32Type()); + Value operandCast = rewriter.create( + binder.getLoc(), f32ResultType, operand, castDType, + /*non_blocking=*/constFalse, /*copy=*/constFalse, + /*memory_format=*/noneVal); + + Value operandSqrt = rewriter.create( + binder.getLoc(), f32ResultType, operandCast); + + Value resultDtype = Torch::getDtypeIntValueForType( + rewriter, binder.getLoc(), resultType.getDtype()); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operandSqrt, resultDtype, + /*non_blocking=*/constFalse, /*copy=*/constFalse, + /*memory_format=*/noneVal); + return success(); + }); patterns.onOp("ReduceSum", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/lib/Conversion/TorchToLinalg/Random.cpp b/lib/Conversion/TorchToLinalg/Random.cpp index 3b18844df516..6519a272330e 100644 --- a/lib/Conversion/TorchToLinalg/Random.cpp +++ b/lib/Conversion/TorchToLinalg/Random.cpp @@ -129,6 +129,7 @@ class ConvertAtenUniformOp : public OpConversionPattern { Value generator = adaptor.getGenerator(); RankedTensorType resultType = self.getType().cast(); Type elemTy = resultType.getElementType(); + Type f64Ty = rewriter.getF64Type(); if (!isa(elemTy)) return rewriter.notifyMatchFailure(op, "This op only support float type"); @@ -139,8 +140,8 @@ class ConvertAtenUniformOp : public OpConversionPattern { "generator is supported"); // Get key, min and max used by `linalg.generic` compute payload. Value key = rewriter.create(loc); - Value min = convertScalarToDtype(rewriter, loc, from, elemTy); - Value max = convertScalarToDtype(rewriter, loc, to, elemTy); + Value min = convertScalarToDtype(rewriter, loc, from, f64Ty); + Value max = convertScalarToDtype(rewriter, loc, to, f64Ty); // Construct the `linalg.generic` op. auto resultRank = resultType.getRank(); @@ -179,11 +180,14 @@ class ConvertAtenUniformOp : public OpConversionPattern { // res = cast(F64, tempN) * scale + min Value updateFloat = - b.create(loc, elemTy, randomVal); + b.create(loc, f64Ty, randomVal); Value updateScaled = b.create(loc, updateFloat, scale); Value res = b.create(loc, updateScaled, min); - b.create(loc, res); + Value truncRes = res; + if (elemTy.isa()) + truncRes = b.create(loc, elemTy, res); + b.create(loc, truncRes); }) .getResult(0); diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index bd8b1fc6bfb1..a5238c9b1211 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -298,7 +298,7 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc, if (isa(op)) return b.create(loc, b.getZeroAttr(elementType)); - if (isa(op)) { + if (isa(op)) { if (isa(elementType)) return b.create(loc, b.getFloatAttr(elementType, 1.0)); else if (isa(elementType)) @@ -341,10 +341,14 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc, isa(op)) return b.create(loc, b.getZeroAttr(elementType)); - if (isa(op)) { + if (isa(op)) { return b.create(loc, b.getBoolAttr(true)); } + if (isa(op)) { + return b.create(loc, b.getBoolAttr(false)); + } + op->emitError("unimplemented lowering in createInitElementForReduceOp"); return nullptr; } @@ -362,7 +366,7 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, return b.create(loc, self, result); else if (isa(resultElementType)) return b.create(loc, self, result); - } else if (isa(op)) { + } else if (isa(op)) { Value self = convertScalarToDtype(b, loc, payloadArgs[0], resultElementType); Value result = payloadArgs[1]; @@ -439,11 +443,16 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, auto abs = createAbsOpForNormOps(b, loc, elem, resultElementType); auto pow = b.create(loc, abs, ord); return b.create(loc, pow, result); - } else if (isa(op)) { + } else if (isa(op)) { + Value elem = payloadArgs[0]; + Value result = payloadArgs[1]; + Value self = convertScalarToDtype(b, loc, elem, resultElementType); + return b.create(loc, self, result); + } else if (isa(op)) { Value elem = payloadArgs[0]; Value result = payloadArgs[1]; Value self = convertScalarToDtype(b, loc, elem, resultElementType); - return b.create(loc, self, result); + return b.create(loc, self, result); } op->emitError("unimplemented lowering in createLinalgPayloadForReduceOp"); return nullptr; @@ -510,12 +519,13 @@ class ConvertReductionOp : public ConversionPattern { ConversionPatternRewriter &rewriter) const { auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}}; - if (isa(op)) { + if (isa(op)) { opInfo.tensorOperand = operands[0]; auto inputType = opInfo.tensorOperand.getType().cast(); - // `AtenSumOp`, `AtenMaxOp`, and `AtenMinOp` each reduce along all the - // dimensions of the input tensor. + // `AtenAny`, `AtenAll`, `AtenSumOp`, `AtenProdOp`, `AtenMaxOp`, and + // `AtenMinOp` each reduce along all the dimensions of the input tensor. for (int64_t i = 0; i < inputType.getRank(); i++) opInfo.dimSet.insert(i); @@ -714,7 +724,10 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality( target.addIllegalOp(); patterns.add>(typeConverter, context); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 441c76ce7ea4..3c5d6cfaee07 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -56,6 +56,13 @@ static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type, llvm_unreachable("Unhandled element type for comparison"); } +static Value getZeroPoint(Value value) { + if (auto make = value.getDefiningOp()) { + return make.getZeroPoint(); + } + return nullptr; +} + static Value createGreaterThan(OpBuilder &b, Location loc, Type elementalType, Value lhs, Value rhs) { return createComparisonTemplate(op)) { - if (!relu.getType() - .cast() - .getDtype() - .isa()) { - relu.emitError("unimplemented: non-floating point dtype"); + Value zeroPoint = getZeroPoint(relu.getSelf()); + Value arg = payloadArgs[0]; + auto intType = arg.getType().dyn_cast(); + if (zeroPoint && !intType) { + relu.emitError("unimplemented: non-integer quantized Relu."); return nullptr; } - Type elementType = payloadArgs[0].getType(); - Value constZero = - b.create(loc, b.getZeroAttr(elementType)); - Value pred = b.create(loc, arith::CmpFPredicate::UGT, - payloadArgs[0], constZero); - return b.create(loc, pred, payloadArgs[0], constZero); + auto reluTorchType = cast(relu.getType()); + bool isUnsigned = + torch_to_linalg::isUnsignedTorchType(reluTorchType.getDtype()); + if (zeroPoint) { + int64_t zeroPointInt; + int64_t width = intType.getWidth(); + assert(width < 64); + int64_t minForIntType = isUnsigned ? 0 : -(1 << (width - 1)); + int64_t maxForIntType = + isUnsigned ? (1 << (width + 1)) - 1 : (1 << (width - 1)) - 1; + // check for constant zero point edge-cases: + if (matchPattern(zeroPoint, m_TorchConstantInt(&zeroPointInt))) { + if (zeroPointInt > maxForIntType) { + // TODO: figure out how to handle this case: + // current impl. quantizes output like input. + // If zero point > maxForIntType, ordinary relu should return 0. + // However, 0 isn't represented in such a quantization scheme. + relu.emitError( + "unimplemented: quantized relu for zero-point > max qint"); + return nullptr; + } + if (zeroPointInt < minForIntType) + return arg; + } + zeroPoint = converter->materializeTargetConversion( + b, loc, converter->convertType(zeroPoint.getType()), zeroPoint); + auto minForIntTypeValue = b.create( + loc, b.getIntegerAttr(zeroPoint.getType(), minForIntType)); + auto maxForIntTypeValue = b.create( + loc, b.getIntegerAttr(zeroPoint.getType(), maxForIntType)); + auto zpLtMax = b.create(loc, arith::CmpIPredicate::slt, + zeroPoint, maxForIntTypeValue); + b.create( + loc, zpLtMax, + b.getStringAttr("Invalid Quantization: quantized relu with " + "zero-point > max qint")); + auto zpLtMin = b.create(loc, arith::CmpIPredicate::slt, + zeroPoint, minForIntTypeValue); + zeroPoint = b.create(loc, zpLtMin, minForIntTypeValue, + zeroPoint); + zeroPoint = b.create(loc, arg.getType(), zeroPoint); + } else { + zeroPoint = + b.create(loc, b.getZeroAttr(arg.getType())); + } + Value cmp; + if (intType) { + auto pred = + isUnsigned ? arith::CmpIPredicate::ugt : arith::CmpIPredicate::sgt; + cmp = b.create(loc, pred, arg, zeroPoint); + } else { + cmp = b.create(loc, arith::CmpFPredicate::UGT, arg, + zeroPoint); + } + return b.create(loc, cmp, arg, zeroPoint); } if (auto round = dyn_cast(op)) { if (!round.getType() diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index eca7c30259de..1858b1a6d7ca 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -217,6 +217,37 @@ class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern { }; } // namespace +// These legalizations are for unary ops with promoting to floating point +// datatypes. +namespace { +template +class ConvertAtenUnaryPromoteToFPOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value self = adaptor.getSelf(); + auto selfTy = self.getType().cast(); + if (!selfTy) + return op.emitError("only Tensor types supported in StableHLO"); + auto resultTy = OpConversionPattern::getTypeConverter() + ->convertType(op.getType()) + .template cast(); + + if (resultTy.getElementType().template isa()) { + Value src = hlo::promoteType(rewriter, op.getLoc(), self, resultTy); + rewriter.replaceOpWithNewOp(op, resultTy, src); + return success(); + } else { + return op.emitError( + "only result to be floating-point datatype legalization supported"); + } + } +}; +} // namespace + // aten.ones & aten.zeros // Ref: Error checking based on the Torch to TOSA lowering namespace { @@ -1029,6 +1060,49 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// AtenLog2Op +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenLog2Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.getSelf(); + auto inputTy = input.getType().template dyn_cast(); + if (!inputTy) { + return op.emitError("only ranked tensor type is supported."); + } + auto outTy = getTypeConverter()->convertType(op.getType()).cast(); + input = hlo::promoteType(rewriter, op.getLoc(), input, outTy); + + auto two = getConstantLike(rewriter, op.getLoc(), 2.0, input); + auto log2Op = rewriter.create(op.getLoc(), two); + auto logInputOp = rewriter.create(op.getLoc(), input); + + rewriter.replaceOpWithNewOp(op, outTy, logInputOp, log2Op); + return success(); +} + +// AtenLog10Op +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenLog10Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.getSelf(); + auto inputTy = input.getType().template dyn_cast(); + if (!inputTy) { + return op.emitError("only ranked tensor type is supported."); + } + + auto outTy = getTypeConverter()->convertType(op.getType()).cast(); + input = hlo::promoteType(rewriter, op.getLoc(), input, outTy); + + auto ten = getConstantLike(rewriter, op.getLoc(), 10.0, input); + auto log10Op = rewriter.create(op.getLoc(), ten); + auto logInputOp = rewriter.create(op.getLoc(), input); + + rewriter.replaceOpWithNewOp(op, outTy, logInputOp, log10Op); + return success(); +} + // AtenErfOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -1558,6 +1632,46 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenConstantPadNdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value self = adaptor.getSelf(); + auto selfTy = self.getType().cast(); + auto selfElemTy = selfTy.getElementType(); + int64_t rank = selfTy.getRank(); + + SmallVector padInts; + if (!matchPattern(op.getPad(), m_TorchListOfConstantInts(padInts))) + return rewriter.notifyMatchFailure(op, + "only support constant int pad ranges"); + uint64_t padRank = padInts.size() / 2; + if (padRank * 2 != padInts.size()) + return rewriter.notifyMatchFailure(op, "pad range size is not even"); + if (rank < 0 || padRank > (uint64_t)rank) + return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank"); + + // Initialize low/high paddings with 0 for all the dims. + SmallVector lowPadding(/*Size=*/rank, /*Value=*/0); + SmallVector highPadding(/*Size=*/rank, /*Value=*/0); + // Add the requested padding - note op.pad() is highest dim first ordered + // pairs of low,high. + // Add the requested padding - note op.pad() is highest dim first ordered + // pairs of low,high. + for (uint64_t i = 0; i < padRank; ++i) { + lowPadding[rank - i - 1] = padInts[i * 2]; + highPadding[rank - i - 1] = padInts[i * 2 + 1]; + } + + Value constantValue = hlo::scalarToStablehloTensor( + rewriter, op, adaptor.getValue(), selfElemTy); + + SmallVector interiorPadding(rank, 0); + rewriter.replaceOpWithNewOp( + op, self, constantValue, lowPadding, highPadding, interiorPadding); + return success(); +} + template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenGeluBackwardOp op, OpAdaptor adaptor, @@ -1888,20 +2002,31 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( target.addIllegalOp(); \ patterns.add>(typeConverter, \ context) - INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, stablehlo::LogOp); - INSERT_UNARY_FPONLY_PATTERN(AtenLog1pOp, stablehlo::Log1pOp); - INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, stablehlo::ExpOp); - INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, stablehlo::SqrtOp); - INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, stablehlo::RsqrtOp); - INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, stablehlo::LogisticOp); - INSERT_UNARY_FPONLY_PATTERN(AtenTanhOp, stablehlo::TanhOp); - INSERT_UNARY_FPONLY_PATTERN(AtenSinOp, stablehlo::SineOp); - INSERT_UNARY_FPONLY_PATTERN(AtenCosOp, stablehlo::CosineOp); INSERT_UNARY_FPONLY_PATTERN(AtenCeilOp, stablehlo::CeilOp); INSERT_UNARY_FPONLY_PATTERN(AtenFloorOp, stablehlo::FloorOp); INSERT_UNARY_FPONLY_PATTERN(AtenRoundOp, stablehlo::RoundNearestEvenOp); #undef INSERT_UNARY_FPONLY_PATTERN +#define INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenOp, StablehloOp) \ + target.addIllegalOp(); \ + patterns.add>( \ + typeConverter, context) + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLogOp, stablehlo::LogOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLog1pOp, stablehlo::Log1pOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenExpOp, stablehlo::ExpOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSqrtOp, stablehlo::SqrtOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenRsqrtOp, stablehlo::RsqrtOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSigmoidOp, stablehlo::LogisticOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenTanhOp, stablehlo::TanhOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSinOp, stablehlo::SineOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenCosOp, stablehlo::CosineOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAsinOp, chlo::AsinOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSinhOp, chlo::SinhOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAcosOp, chlo::AcosOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenCoshOp, chlo::CoshOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAtanOp, chlo::AtanOp); +#undef INSERT_UNARY_PROMOTE_TO_FP_PATTERN + #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ target.addIllegalOp(); \ patterns.add>(typeConverter, \ @@ -1985,9 +2110,12 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp); INSERT_ATENOP_PATTERN(AtenScalarImplicitOp); INSERT_ATENOP_PATTERN(AtenContiguousOp); + INSERT_ATENOP_PATTERN(AtenConstantPadNdOp); INSERT_ATENOP_PATTERN(AtenReluOp); INSERT_ATENOP_PATTERN(AtenGeluOp); + INSERT_ATENOP_PATTERN(AtenLog2Op); + INSERT_ATENOP_PATTERN(AtenLog10Op); INSERT_ATENOP_PATTERN(AtenErfOp); INSERT_ATENOP_PATTERN(AtenGeluBackwardOp); diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index c525c8b40de5..fee5cc01e4ae 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -89,6 +89,33 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, } } + if (isa(op)) { + if (isa(elementTy)) { + APFloat one(cast(elementTy).getFloatSemantics(), 1); + auto constAttr = DenseElementsAttr::get(constType, one); + return rewriter.create(op->getLoc(), constType, + constAttr); + } else if (isa(elementTy) && + elementTy.getIntOrFloatBitWidth() != 8) { + APInt one(elementTy.getIntOrFloatBitWidth(), 1); + auto constAttr = DenseElementsAttr::get(constType, one); + return rewriter.create(op->getLoc(), constType, + constAttr); + } + } + + if (isa(op)) { + auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 1)}); + return rewriter.create(op->getLoc(), constType, + constAttr); + } + + if (isa(op)) { + auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 0)}); + return rewriter.create(op->getLoc(), constType, + constAttr); + } + op->emitError("unimplemented lowering in " "createInitialValueForReduceOp"); return nullptr; @@ -448,6 +475,223 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } } // namespace +// AtenAllOp +namespace { +template <> +LogicalResult ConvertAtenReductionOp::matchAndRewrite( + AtenAllOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.getSelf(); + auto inputTy = input.getType().dyn_cast(); + if (!inputTy) { + return rewriter.notifyMatchFailure( + op, "only Tensor types supported in StableHLO"); + } + auto inputElemTy = inputTy.getElementType(); + + // Currently, (u)int8 dtype is not supported + if (isa(inputElemTy) && + inputElemTy.getIntOrFloatBitWidth() == 8) { + return rewriter.notifyMatchFailure( + op, "IntegerType with bitwidth 8 unsupported in convertion from " + "AtenAllOp to StableHLO"); + } + auto outTy = getTypeConverter() + ->convertType(op.getType()) + .template dyn_cast(); + + if (inputElemTy != outTy.getElementType()) { + // Use output bool type as computation type. + auto dstElemTy = outTy.getElementType(); + input = + rewriter.create(op->getLoc(), input, dstElemTy); + inputTy = input.getType().dyn_cast(); + inputElemTy = inputTy.getElementType(); + } + + SmallVector dims; + for (int64_t i = 0; i < inputTy.getRank(); i++) { + dims.push_back(i); + } + + Value initValue = + createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); + if (!initValue) + return failure(); + llvm::sort(dims.begin(), dims.end()); + auto stablehloReduceOp = rewriter.create( + op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue, + rewriter.getDenseI64ArrayAttr(dims)); + + Block &block = stablehloReduceOp.getBody().emplaceBlock(); + auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); + + block.addArgument(blockArgumentTy, op->getLoc()); + block.addArgument(blockArgumentTy, op->getLoc()); + + auto *firstArgument = block.args_begin(); + auto secondArgument = block.args_rbegin(); + + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&block); + Value allResult = rewriter.create( + op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); + rewriter.create(op->getLoc(), allResult); + } + + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), + stablehloReduceOp.getResults()); + return success(); +} +} // namespace + +// AtenAnyOp +namespace { +template <> +LogicalResult ConvertAtenReductionOp::matchAndRewrite( + AtenAnyOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.getSelf(); + auto inputTy = input.getType().dyn_cast(); + if (!inputTy) { + return rewriter.notifyMatchFailure( + op, "only Tensor types supported in StableHLO"); + } + auto inputElemTy = inputTy.getElementType(); + + // Currently, (u)int8 dtype is not supported + if (isa(inputElemTy) && + inputElemTy.getIntOrFloatBitWidth() == 8) { + return rewriter.notifyMatchFailure( + op, "IntegerType with bitwidth 8 unsupported in convertion from " + "AtenAllOp to StableHLO"); + } + auto outTy = getTypeConverter() + ->convertType(op.getType()) + .template dyn_cast(); + + if (inputElemTy != outTy.getElementType()) { + // Use output bool type as computation type. + auto dstElemTy = outTy.getElementType(); + input = + rewriter.create(op->getLoc(), input, dstElemTy); + inputTy = input.getType().dyn_cast(); + inputElemTy = inputTy.getElementType(); + } + + SmallVector dims; + for (int64_t i = 0; i < inputTy.getRank(); i++) { + dims.push_back(i); + } + + Value initValue = + createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); + if (!initValue) + return failure(); + llvm::sort(dims.begin(), dims.end()); + auto stablehloReduceOp = rewriter.create( + op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue, + rewriter.getDenseI64ArrayAttr(dims)); + + Block &block = stablehloReduceOp.getBody().emplaceBlock(); + auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); + + block.addArgument(blockArgumentTy, op->getLoc()); + block.addArgument(blockArgumentTy, op->getLoc()); + + auto *firstArgument = block.args_begin(); + auto secondArgument = block.args_rbegin(); + + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&block); + Value anyResult = rewriter.create( + op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); + rewriter.create(op->getLoc(), anyResult); + } + + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), + stablehloReduceOp.getResults()); + return success(); +} +} // namespace + +// AtenProdOp +namespace { +template <> +LogicalResult ConvertAtenReductionOp::matchAndRewrite( + AtenProdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.getSelf(); + auto inputTy = input.getType().dyn_cast(); + auto outTy = getTypeConverter() + ->convertType(op.getType()) + .template dyn_cast(); + if (!inputTy) { + return rewriter.notifyMatchFailure( + op, "only Tensor types supported in StableHLO"); + } + if (inputTy.getElementType() != outTy.getElementType()) { + // Use output element type as computation type. + auto dstElemTy = outTy.getElementType(); + input = + rewriter.create(op->getLoc(), input, dstElemTy); + inputTy = input.getType().dyn_cast(); + } + auto inputElemTy = inputTy.getElementType(); + if (!inputElemTy.isIntOrFloat()) { + return op.emitError( + "only floating-point or integer datatype legalization supported"); + } + // Currently, (u)int8 dtype is not supported + if (isa(inputElemTy) && + inputElemTy.getIntOrFloatBitWidth() == 8) { + return rewriter.notifyMatchFailure( + op, "IntegerType with bitwidth 8 unsupported in convertion from " + "AtenProdOp to StableHLO"); + } + + SmallVector dims; + for (int64_t i = 0; i < inputTy.getRank(); i++) { + dims.push_back(i); + } + Value initValue = + createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); + if (!initValue) + return failure(); + + llvm::sort(dims.begin(), dims.end()); + auto stablehloReduceOp = rewriter.create( + op.getLoc(), RankedTensorType::get({}, outTy.getElementType()), input, + initValue, rewriter.getDenseI64ArrayAttr(dims)); + + Block &block = stablehloReduceOp.getBody().emplaceBlock(); + auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); + + block.addArgument(blockArgumentTy, op->getLoc()); + block.addArgument(blockArgumentTy, op->getLoc()); + + auto *firstArgument = block.args_begin(); + auto secondArgument = block.args_rbegin(); + + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&block); + Value mulResult = rewriter.create( + op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); + rewriter.create(op->getLoc(), mulResult); + } + + rewriter.replaceOpWithNewOp(op, outTy, + stablehloReduceOp.getResults()); + + return success(); +} +} // namespace + // AtenMaxOp namespace { template <> @@ -612,11 +856,17 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( SmallVector inputDims; SmallVector dims; - if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) { - return rewriter.notifyMatchFailure(op, "non-int dim list unsupported"); - } - if (inputDims.size() == 0) { + + if (failed(checkNotNone(rewriter, op, op.getDim()))) { inputDims = llvm::to_vector<4>(llvm::seq(0, inputTy.getRank())); + } else { + if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) { + return rewriter.notifyMatchFailure( + op, "non-const integer `dim` is not supported"); + } + if (inputDims.size() == 0) { + inputDims = llvm::to_vector<4>(llvm::seq(0, inputTy.getRank())); + } } for (auto d : inputDims) { @@ -957,6 +1207,9 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality( INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp); + INSERT_ATEN_REDUCTION_OP_PATTERN(AtenProdOp); + INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAllOp); + INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAnyOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMinOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp); diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index e768033ac87f..a8769def6585 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -1834,6 +1834,19 @@ OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) { return {}; } +//===----------------------------------------------------------------------===// +// AtenTruncOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenTruncOp::fold(FoldAdaptor adaptor) { + auto resultType = getType().dyn_cast(); + if (resultType && resultType.hasDtype() && + resultType.getDtype().isa()) { + return getSelf(); + } + return {}; +} + //===----------------------------------------------------------------------===// // AtenSignOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index a1cc7ddf6ea6..f4415a480a7c 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6332,6 +6332,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.sinh\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.asin\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6498,6 +6502,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.trunc\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.log\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7030,6 +7038,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = torch.prim.ListConstruct : () -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.prod\"(%arg0: !torch.list, %arg1: !torch.optional) -> !torch.list {\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.mean\"(%arg0: !torch.list, %arg1: !torch.optional) -> !torch.list {\n" " %0 = torch.prim.ListConstruct : () -> !torch.list\n" " return %0 : !torch.list\n" @@ -9699,6 +9711,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " return %1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sinh\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.asin\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" @@ -9990,6 +10007,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.trunc\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.clamp_max\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %int4 = torch.constant.int 4\n" " %int11 = torch.constant.int 11\n" @@ -12231,6 +12252,25 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.prod\"(%arg0: !torch.tuple, %arg1: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %2#1 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.sum.dim_IntList\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.int {\n" " %0 = call @\"__torch_mlir_dtype_fn.aten.sum\"(%arg0, %arg3) : (!torch.tuple, !torch.optional) -> !torch.int\n" " return %0 : !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 87f93ba9c555..49dd5319514b 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -5886,6 +5886,32 @@ class DecomposeAtenCosineSimilarityOp }; } // namespace +namespace { +// decompose `trunc(x)` to `sign(x) * floor(abs(x))` +class DecomposeAtenTruncOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenTruncOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + + auto resultTy = dyn_cast(op.getType()); + if (!resultTy || !resultTy.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result must have dtype"); + } + + if (isa(resultTy.getDtype())) { + Value sign = rewriter.create(loc, resultTy, self); + Value abs = rewriter.create(loc, resultTy, self); + Value floor = rewriter.create(loc, resultTy, abs); + rewriter.replaceOpWithNewOp(op, resultTy, sign, floor); + return success(); + } + return failure(); + } +}; +} // namespace + namespace { // Decompose `aten.baddbmm` op into `aten.bmm`, `aten.mul.Scalar`, and // `aten.add.Tensor` op. @@ -7700,6 +7726,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp index bff463c4cee6..3b30e9424f44 100644 --- a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp +++ b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp @@ -20,6 +20,13 @@ using namespace mlir::torch::Torch; namespace { +template struct QuantInfo { + static constexpr unsigned operandsToQuantize[2] = {0, 1}; +}; + +template <> struct QuantInfo { + static constexpr unsigned operandsToQuantize[1] = {0}; +}; template class QuantizeOperands : public OpRewritePattern { public: @@ -42,8 +49,9 @@ class QuantizeOperands : public OpRewritePattern { return operand; }; - operands[0] = f(operands[0]); - operands[1] = f(operands[1]); + for (unsigned i : QuantInfo::operandsToQuantize) { + operands[i] = f(operands[i]); + } if (!dequanted) { return rewriter.notifyMatchFailure(op, "no dequantizations found"); @@ -259,6 +267,70 @@ class QuantizeAccumulator : public OpRewritePattern { } }; +// Use for ops which do not manipulate scale/zero point of an input. +template +class QuantizeResultLikeOperand : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(SrcOp op, + PatternRewriter &rewriter) const override { + llvm::SmallVector operands(op->getOperands()); + Value input = operands[0]; + + auto inputType = dyn_cast_or_null(input.getType()); + if (!inputType || !inputType.hasDtype()) + return failure(); + auto qDtype = inputType.getDtype(); + + auto resultTy = dyn_cast_or_null(op.getType()); + if (!resultTy || !resultTy.hasDtype()) + return failure(); + + Type resultETy = resultTy.getDtype(); + if (!isa(resultETy)) + return failure(); + + Value inputScale, inputZeroPoint; + Type definingOpInputType; + if (auto defining = input.template getDefiningOp< + Aten_MakePerTensorQuantizedTensorOp>()) { + inputScale = defining.getScale(); + inputZeroPoint = defining.getZeroPoint(); + definingOpInputType = defining.getSelf().getType(); + } + + auto inputIntReprType = + dyn_cast_or_null(definingOpInputType); + if (!inputScale || !inputZeroPoint || !inputIntReprType || + !inputIntReprType.hasDtype()) + return failure(); + auto intReprDtype = inputIntReprType.getDtype(); + + // set SrcOp type to use quantized dtype from input + auto newResultTy = + rewriter.getType(resultTy.getOptionalSizes(), qDtype); + auto newResult = rewriter.create(op.getLoc(), newResultTy, operands); + + // int repr to get non quantized int type result + auto intReprTy = rewriter.getType( + resultTy.getOptionalSizes(), intReprDtype); + auto intRepr = + rewriter.create(op.getLoc(), intReprTy, newResult); + + // requantize so the scale and zero-point info can be attached + auto quantTy = + rewriter.getType(resultTy.getOptionalSizes(), qDtype); + auto quant = rewriter.create( + op.getLoc(), quantTy, intRepr, inputScale, inputZeroPoint); + + // dequant back to original dtype + auto dequant = + rewriter.create(op.getLoc(), resultTy, quant); + rewriter.replaceOp(op, dequant); + return success(); + } +}; + template class RemoveUnused : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -285,11 +357,12 @@ class FuseQuantizedOpsPass : public FuseQuantizedOpsBase { RemoveUnused, RemoveUnused, RemoveUnused, QuantizeOperands, - QuantizeOperands, + QuantizeOperands, QuantizeOperands, QuantizeTransposedOperands, QuantizeAccumulator, QuantizeOperands, QuantizeTransposedOperands, QuantizeAccumulator, - QuantizeBias>(context); + QuantizeResultLikeOperand, QuantizeBias>( + context); GreedyRewriteConfig config; if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 701300fefe43..e1377afce373 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -512,6 +512,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 323a39bf33cb..c3d9d0dfeb09 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -331,6 +331,9 @@ "AtenMatmulQint8VV_basic", "AtenMatmulQint8VM_basic", "AtenMatmulQint8_basic", + "QuantizedReluInt32_basic", + "QuantizedReluInt8_basic", + "QuantizedReluUint8_basic", "Conv2dQInt8Module_basic", # Dynamo not supporting conv_tbc @@ -413,6 +416,9 @@ 'AtenMmQMixedSigni8_basic', 'AtenMmQint8_basic', 'AtenMmQuint8_basic', + "QuantizedReluInt32_basic", + "QuantizedReluInt8_basic", + "QuantizedReluUint8_basic", 'AtenSubFloatModule_basic', 'BincountMinlengthModule_basic', 'BincountModule_basic', @@ -599,10 +605,6 @@ "BroadcastDynamicDimModule_basic", "CeilFloatModule_basic", "ConstantBoolParameterModule_basic", - "ConstantPad2dStaticModule_basic", - "ConstantPadNdModule_basic", - "ConstantPadNdPartialStaticModule_basic", - "ConstantPadNdStaticModule_basic", "ContainsIntList_False", "ContainsIntList_True", "Conv2dQInt8Module_basic", @@ -620,20 +622,14 @@ "DiagonalModule_with_offset", "DivFloatModule_basic", "DivIntModule_basic", - "ElementwiseAcosIntModule_basic", - "ElementwiseAcosModule_basic", "ElementwiseAcoshIntModule_basic", "ElementwiseAcoshModule_basic", "ElementwiseAddScalar_NumToTensorFloat_Module_basic", - "ElementwiseAsinIntModule_basic", - "ElementwiseAsinModule_basic", "ElementwiseAsinhIntModule_basic", "ElementwiseAsinhModule_basic", "ElementwiseAtan2FloatIntModule_basic", "ElementwiseAtan2TensorFloatModule_basic", "ElementwiseAtan2TensorIntModule_basic", - "ElementwiseAtanTensorFloatModule_basic", - "ElementwiseAtanTensorIntModule_basic", "ElementwiseAtanhIntModule_basic", "ElementwiseAtanhModule_basic", "ElementwiseBitwiseLeftShiftInt32Module_basic", @@ -642,35 +638,21 @@ "ElementwiseBitwiseRightShiftInt32Module_basic", "ElementwiseBitwiseRightShiftInt64Module_basic", "ElementwiseBitwiseRightShiftInt8Module_basic", - "ElementwiseCosIntModule_basic", - "ElementwiseCoshIntModule_basic", - "ElementwiseCoshModule_basic", "ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerTensorModule_basic", "ElementwiseErfIntModule_basic", - "ElementwiseLog10IntModule_basic", - "ElementwiseLog10Module_basic", - "ElementwiseLog2IntModule_basic", - "ElementwiseLog2Module_basic", - "ElementwiseLogIntModule_basic", "ElementwiseLogitModule_basic", "ElementwiseMulTensorComplexModule_basic", "ElementwisePowScalarModule_basic", "ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic", "ElementwiseReciprocalIntModule_basic", - "ElementwiseRsqrtIntModule_basic", - "ElementwiseSigmoidIntModule_basic", - "ElementwiseSinIntModule_basic", - "ElementwiseSqrtIntModule_basic", "ElementwiseTanIntModule_basic", "ElementwiseTanModule_basic", "ElementwiseTernaryModule_basic", "ElementwiseToDtypeI64ToUI8Module_basic", - "ElementwiseUnaryIntModule_basic", "EmptyModule_uint8", "EqIntModule_basic", - "ExponentialModule_basic", "FakeQuantizePerTensorAffineDynamicShapeModule_basic", "FakeQuantizePerTensorAffineModule_basic", "FakeQuantizePerTensorAffineRoundToEvenModule_basic", @@ -742,7 +724,6 @@ "MaxPool3dModule_basic", "MaxPool3dStaticCeilModeTrueModule_basic", "MaxPool3dStaticModule_basic", - "MeanDimNoneDimModule_basic", "MseLossMeanReductionModule_basic", "MseLossSumReductionWithDifferentElemTypeModule_basic", "MulFloatModule_basic", @@ -769,8 +750,6 @@ "NumToTensorIntModule_basic", "NumelModule_basic", "NumelZeroRankModule_basic", - "PadModule_basic", - "PadWithNoneValModule_basic", "PixelShuffleModuleFullDynamic_basic", "PixelShuffleModuleSpatiallyDynamic_basic", "PixelShuffleModuleSpatiallyStatic_basic", @@ -997,6 +976,10 @@ "Convolution2DStaticModule_basic", "ConvolutionBackwardModule2DStatic_basic", "ConvolutionModule2DTransposeStridedStatic_basic", + "ConstantPad2dStaticModule_basic", + "ConstantPadNdModule_basic", + "ConstantPadNdPartialStaticModule_basic", + "ConstantPadNdStaticModule_basic", "CosineSimilarityStaticBroadcastModule_basic", "CosineSimilarityStaticModule_basic", "CumsumInputDtypeInt32Module_basic", @@ -1060,6 +1043,10 @@ "ElementwiseGeluModule_basic", "ElementwiseLeakyReluStaticModule_basic", "ElementwiseLogModule_basic", + "ElementwiseLog10Module_basic", + "ElementwiseLog2Module_basic", + "ElementwiseLog10IntModule_basic", + "ElementwiseLog2IntModule_basic", "ElementwiseNanToNumModule_Basic", "ElementwiseNeFloatTensorStaticModule_basic", "ElementwiseNeIntTensorStaticModule_basic", @@ -1150,6 +1137,7 @@ "MaxPool2dStaticModule_basic", "MeanDimAllReduceModule_basic", "MeanDimEmptyDimModule_basic", + "MeanDimNoneDimModule_basic", "MeanDtypeModule_basic", "MeanDynamicSizesModule_basic", "MeanModule_basic", @@ -1219,6 +1207,8 @@ "OnesModuleFalsePinMemory_basic", "OnesModuleFloat_basic", "OnesModuleInt_basic", + "PadModule_basic", + "PadWithNoneValModule_basic", "Permute0RankModule_basic", "PermuteModule_basic", "PermuteNegativeIndexModule_basic", @@ -1237,6 +1227,12 @@ "RandIntLowModule_basic", "RandIntModule_basic", "RandIntPinMemoryModule_basic", + "ReduceAllFloatModule_basic", + "ReduceAllIntModule_basic", + "ReduceAllBoolModule_basic", + "ReduceAnyFloatModule_basic", + "ReduceAnyIntModule_basic", + "ReduceAnyBoolModule_basic", "ReduceAmaxMultiDim_basic", "ReduceAmaxOutOfOrderDim_basic", "ReduceAmaxSingleDim_basic", @@ -1263,6 +1259,12 @@ "ReduceSumFloatModule_basic", "ReduceSumSignedIntModule_basic", "ReduceSumUnsignedIntModule_basic", + "ReduceProdFloatModule_basic", + "ReduceProdDtypeFloatModule_basic", + "ReduceProdElementTypeBoolModule_basic", + "ReduceProdUnsignedIntModule_basic", + "ReduceProdSignedIntModule_basic", + "ReduceProdDtypeIntModule_basic", "RepeatInterleaveSelfIntModule_basic", "RepeatInterleaveSelfIntNoDimModule_basic", "ReturnThreeTensorFloat32_basic", @@ -1465,6 +1467,26 @@ "ElementwiseLog1pModule_basic", "ElementwiseSgnModule_basic", "ElementwiseSignIntModule_basic", + "ElementwiseAcosModule_basic", + "ElementwiseAsinModule_basic", + "ElementwiseAtanTensorFloatModule_basic", + "ElementwiseAcosIntModule_basic", + "ElementwiseAsinIntModule_basic", + "ElementwiseAtanTensorIntModule_basic", + "ElementwiseCosIntModule_basic", + "ElementwiseExpIntModule_basic", + "ElementwiseLogIntModule_basic", + "ElementwiseRsqrtIntModule_basic", + "ElementwiseSigmoidIntModule_basic", + "ElementwiseSinIntModule_basic", + "ElementwiseSqrtIntModule_basic", + "ElementwiseUnaryIntModule_basic", + "ElementwiseCoshIntModule_basic", + "ElementwiseCoshModule_basic", + "ElementwiseSinhIntModule_basic", + "ElementwiseSinhModule_basic", + "ElementwiseTruncIntModule_basic", + "ElementwiseTruncModule_basic", } STABLEHLO_CRASHING_SET = { @@ -1474,6 +1496,8 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { + "ElementwiseTruncModule_basic", + "ElementwiseTruncIntModule_basic", "ElementwiseSgnModule_basic", "ElementwiseSignIntModule_basic", "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", @@ -1795,6 +1819,8 @@ "PrimsSqueezeModule_basic", "PrimsViewOfModule_basic", "PrimsViewOfZeroRankModule_basic", + "ReduceAllBoolModule_basic", + "ReduceAnyBoolModule_basic", "ReduceSumDimIntListFloatModule_basic", "ReduceSumDimIntListIntModule_basic", "ReduceSumDimIntListKeepDimFloatModule_basic", @@ -2326,8 +2352,12 @@ "ElementwiseBitwiseRightShiftInt8Module_basic", "ElementwiseBitwiseXorModule_basic", "ElementwiseBitwiseXorStaticShapeModule_basic", + "ElementwiseSinhIntModule_basic", + "ElementwiseSinhModule_basic", "ElementwiseCoshIntModule_basic", "ElementwiseCoshModule_basic", + "ElementwiseTruncIntModule_basic", + "ElementwiseTruncModule_basic", "ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerTensorModule_basic", "ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic", @@ -2462,6 +2492,9 @@ "PrimsSqueezeModule_basic", "PrimsViewOfModule_basic", "PrimsViewOfZeroRankModule_basic", + "QuantizedReluInt8_basic", + "QuantizedReluInt32_basic", + "QuantizedReluUint8_basic", "RandIntDtypeModule_basic", "RandIntModule_basic", "RandIntPinMemoryModule_basic", @@ -2604,6 +2637,14 @@ # Failure - onnx_lowering: onnx.OneHot "OneHotModule_basic", + + # Failure - onnx_lowering: onnx.ReduceProd + "ReduceProdFloatModule_basic", + "ReduceProdDtypeFloatModule_basic", + "ReduceProdElementTypeBoolModule_basic", + "ReduceProdUnsignedIntModule_basic", + "ReduceProdSignedIntModule_basic", + "ReduceProdDtypeIntModule_basic", # ERROR: dtype (torch.float32) is not equal to golden dtype (torch.float64) "RandnDtypeDeviceModule_basic", @@ -2615,12 +2656,6 @@ "BernoulliPModule_basic", "BernoulliTensorModule_basic", - # Failure - onnx_lowering: onnx.ReduceL2 - "LinalgNormKeepDimModule_basic", - "LinalgNormModule_basic", - "NormalizeModule_basic", - "ReduceL2NormModule_basic", - # Failure - onnx_lowering: onnx.ReduceProd "ReduceProdDimIntFloatModule_basic", @@ -2694,6 +2729,7 @@ "MaskedFillTensorFloatValueModule_basic", "NativeDropoutTrainModule_basic", "NativeDropoutTrainStaticShapeModule_basic", + "ReduceAnyFloatModule_basic", "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", } diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 8b32ff602697..06962010fea8 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -121,6 +121,9 @@ def aten〇fake_quantize_per_tensor_affine〡shape(self: List[int], scale: float def aten〇sin〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇sinh〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇asin〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -242,6 +245,9 @@ def aten〇hardtanh_backward〡shape(grad_output: List[int], self: List[int], mi def aten〇ceil〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇trunc〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇log〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -550,6 +556,9 @@ def aten〇max〇other〡shape(self: List[int], other: List[int]) -> List[int]: def aten〇sum〡shape(self: List[int], dtype: Optional[int] = None) -> List[int]: return [] +def aten〇prod〡shape(self: List[int], dtype: Optional[int] = None) -> List[int]: + return [] + def aten〇mean〡shape(self: List[int], dtype: Optional[int] = None) -> List[int]: return [] @@ -2014,6 +2023,11 @@ def aten〇sin〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return _get_dtype_of_floating_point_op(self_dtype) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇sinh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇asin〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -2216,6 +2230,11 @@ def aten〇ceil〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇trunc〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, max=0)) def aten〇clamp_max〡dtype(self_rank_dtype: Tuple[int, int], max: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype @@ -3973,6 +3992,18 @@ def aten〇sum〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = return torch.int64 return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇prod〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None) -> int: + if dtype is not None: + return dtype + self_rank, self_dtype = self_rank_dtype + if is_integer_dtype(self_dtype): + return torch.int64 + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None) + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None, dtype=torch.float32) + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None, dtype=torch.int32) + diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 5c4e4d214932..e5b219e55e9c 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -359,6 +359,7 @@ def emit_with_mutating_variants(key, **kwargs): emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True) + emit_with_mutating_variants("aten::trunc : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::sign : (Tensor) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) @@ -660,6 +661,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::sum : (Tensor, int?) -> (Tensor)") emit("aten::sum.dim_IntList : (Tensor, int[]?, bool, int?) -> (Tensor)") emit("aten::prod.dim_int : (Tensor, int, bool, int?) -> (Tensor)") + emit("aten::prod : (Tensor, int?) -> (Tensor)") emit("aten::max : (Tensor) -> (Tensor)") emit("aten::max.other : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) emit("aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)") diff --git a/projects/pt1/python/torch_mlir/torchscript.py b/projects/pt1/python/torch_mlir/torchscript.py index 48b9066d2391..acb487319ae9 100644 --- a/projects/pt1/python/torch_mlir/torchscript.py +++ b/projects/pt1/python/torch_mlir/torchscript.py @@ -9,6 +9,7 @@ import sys from io import StringIO import tempfile +import os from torch._functorch.compile_utils import strip_overloads import torch @@ -253,19 +254,20 @@ def _get_for_tracing( } -def _canon_extra_library(extra_library): - extra_library_file_name = "" +def _canon_extra_library(extra_library, extra_library_file_name="custom_op_extra_library.mlir"): if len(extra_library) != 0: extra_library_dict = {} for library_func in extra_library: extra_library_dict[library_func.__name__] = library_func mlir_library = generate_library(extra_library_dict) - extra_library_file_name = \ - tempfile.gettempdir() + "/custom_op_extra_library.mlir" - with open(extra_library_file_name, "w") as f: + extra_library_file = \ + os.path.join(tempfile.gettempdir(), extra_library_file_name) + with open(extra_library_file, "w") as f: f.write(mlir_library) - return extra_library_file_name + return extra_library_file + else: + return "" def _lower_mlir_module(verbose, output_type, module): diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index e4a185189354..3aa8f10ff9dd 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -63,6 +63,50 @@ def ElementwiseUnaryIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseSinhModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.sinh(a) + + +@register_test_case(module_factory=lambda: ElementwiseSinhModule()) +def ElementwiseSinhModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + +class ElementwiseSinhIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.sinh(a) + + +@register_test_case(module_factory=lambda: ElementwiseSinhIntModule()) +def ElementwiseSinhIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + + +# ============================================================================== + + class ElementwiseCoshModule(torch.nn.Module): def __init__(self): @@ -661,6 +705,69 @@ def ElementwiseReluModule_basic(module, tu: TestUtils): # ============================================================================== +class QuantizedReluInt8(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int8, True), + ]) + def forward(self, x): + qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25) + qx = torch.dequantize(qx) + return torch.relu(qx) + +@register_test_case(module_factory=lambda: QuantizedReluInt8()) +def QuantizedReluInt8_basic(module, tu: TestUtils): + module.forward(tu.randint(7, 4, low=-128, high=127).to(torch.int8)) + +# ============================================================================== + +class QuantizedReluUint8(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.uint8, True), + ]) + def forward(self, x): + qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, 190) + qx = torch.dequantize(qx) + return torch.relu(qx) + +@register_test_case(module_factory=lambda: QuantizedReluUint8()) +def QuantizedReluUint8_basic(module, tu: TestUtils): + module.forward(tu.randint(7, 4, low=0, high=255).to(torch.uint8)) + +# ============================================================================== + +class QuantizedReluInt32(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, x): + qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, 190) + qx = torch.dequantize(qx) + return torch.relu(qx) + +@register_test_case(module_factory=lambda: QuantizedReluInt32()) +def QuantizedReluInt32_basic(module, tu: TestUtils): + module.forward(tu.randint(7, 4, low=(-2**31), high=(2**31 - 1)).to(torch.int32)) + +# ============================================================================== + class ElementwiseRelu6Module(torch.nn.Module): @@ -1970,6 +2077,50 @@ def ElementwiseCeilModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseTruncModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 6], torch.float32, True), + ]) + def forward(self, a): + return torch.trunc(a) + + +@register_test_case(module_factory=lambda: ElementwiseTruncModule()) +def ElementwiseTruncModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([[-torch.inf, torch.inf, torch.nan, -2.3, 0.0, 1.5]])) + + +# ============================================================================== + + +class ElementwiseTruncIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.trunc(a) + + +@register_test_case(module_factory=lambda: ElementwiseTruncIntModule()) +def ElementwiseTruncIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-100, high=100).to(torch.int32)) + + +# ============================================================================== + + class ElementwiseSignModule(torch.nn.Module): def __init__(self): diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py index b6178221eb48..076dd4e458a4 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -68,6 +68,176 @@ def ReduceSumElementTypeBoolModule_basic(module, tu: TestUtils): # ============================================================================== +class ReduceProdFloatModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.prod(a) + + +@register_test_case(module_factory=lambda: ReduceProdFloatModule()) +def ReduceProdFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + +# ============================================================================== + +class ReduceProdDtypeFloatModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float64, True), + ]) + def forward(self, a): + return torch.prod(a, dtype=torch.float32) + +@register_test_case(module_factory=lambda: ReduceProdDtypeFloatModule()) +def ReduceProdDtypeFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5).to(torch.float64)) + +# ============================================================================== + +class ReduceProdElementTypeBoolModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.bool, True), + ]) + def forward(self, a): + return torch.prod(a) + + +@register_test_case(module_factory=lambda: ReduceProdElementTypeBoolModule()) +def ReduceProdElementTypeBoolModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, 5, high=2).to(torch.bool)) + +# ============================================================================== + +class ReduceAllFloatModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.all(a) + + +@register_test_case(module_factory=lambda: ReduceAllFloatModule()) +def ReduceAllFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + +# ============================================================================== + +class ReduceAllIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.ops.aten.all(a) + + +@register_test_case(module_factory=lambda: ReduceAllIntModule()) +def ReduceAllIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, 5, high=2).to(torch.int32)) + +# ============================================================================== + +class ReduceAllBoolModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.bool, True), + ]) + def forward(self, a): + return torch.ops.aten.all(a) + + +@register_test_case(module_factory=lambda: ReduceAllBoolModule()) +def ReduceAllBoolModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, high=2).to(torch.bool)) + +# ============================================================================== + +class ReduceAnyFloatModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.any(a) + + +@register_test_case(module_factory=lambda: ReduceAnyFloatModule()) +def ReduceAnyFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + +# ============================================================================== + +class ReduceAnyIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.ops.aten.any(a) + + +@register_test_case(module_factory=lambda: ReduceAnyIntModule()) +def ReduceAnyIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, 5, high=2).to(torch.int32)) + +# ============================================================================== + +class ReduceAnyBoolModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.bool, True), + ]) + def forward(self, a): + return torch.ops.aten.any(a) + + +@register_test_case(module_factory=lambda: ReduceAnyBoolModule()) +def ReduceAnyBoolModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, high=2).to(torch.bool)) + +# ============================================================================== + class ReduceSumDimIntListFloatModule(torch.nn.Module): def __init__(self): super().__init__() @@ -239,6 +409,63 @@ def ReduceSumDtypeIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ReduceProdUnsignedIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.int64, True), + ]) + def forward(self, a): + return torch.prod(a) + + +@register_test_case(module_factory=lambda: ReduceProdUnsignedIntModule()) +def ReduceProdUnsignedIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, 5, low=0, high=100)) + +# ============================================================================== + +class ReduceProdSignedIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.int64, True), + ]) + def forward(self, a): + return torch.prod(a) + + +@register_test_case(module_factory=lambda: ReduceProdSignedIntModule()) +def ReduceProdSignedIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, 5, low=-100, high=100)) + +# ============================================================================== + +class ReduceProdDtypeIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.prod(a, dtype=torch.int64) + + +@register_test_case(module_factory=lambda: ReduceProdDtypeIntModule()) +def ReduceProdDtypeIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, 5, high=100).to(torch.int32)) + +# ============================================================================== + class ReduceSumDimIntListIntModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 85e2832ac392..33d8d8f658b2 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -74,6 +74,24 @@ func.func @test_argmax_negative_axis_keepdims_example(%arg0: !torch.vtensor<[2,2 // ----- +// CHECK-LABEL: @test_argmax_negative_axis_keepdims_random_select_last_index +func.func @test_argmax_negative_axis_keepdims_random_select_last_index(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[TRUE:.*]] = torch.constant.bool true + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[C2_0]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP:.*]] = torch.aten.flip %arg0, %[[DIMS]] : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[2,3,4],f32> + // CHECK: %[[ARGMAX:.*]] = torch.aten.argmax %[[FLIP]], %[[C2]], %[[TRUE]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,3,1],si64> + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[SUB:.*]] = torch.aten.sub.Scalar %[[ARGMAX]], %[[C3]], %[[C1]] : !torch.vtensor<[2,3,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,3,1],si64> + // CHECK: %[[ABS:.*]] = torch.aten.abs %[[SUB]] : !torch.vtensor<[2,3,1],si64> -> !torch.vtensor<[2,3,1],si64> + %0 = torch.operator "onnx.ArgMax"(%arg0) {torch.onnx.axis = -1 : si64, torch.onnx.keepdims = 1 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1],si64> + return %0 : !torch.vtensor<[2,3,1],si64> +} + +// ----- + // CHECK-LABEL: @test_argmax_no_keepdims_example func.func @test_argmax_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 1 @@ -85,6 +103,24 @@ func.func @test_argmax_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> // ----- +// CHECK-LABEL: @test_argmax_no_keepdims_random_select_last_index +func.func @test_argmax_no_keepdims_random_select_last_index(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,4],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[C1_0]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP:.*]] = torch.aten.flip %arg0, %[[DIMS]] : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[2,3,4],f32> + // CHECK: %[[ARGMAX:.*]] = torch.aten.argmax %[[FLIP]], %[[C1]], %[[FALSE]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,4],si64> + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[SUB:.*]] = torch.aten.sub.Scalar %[[ARGMAX]], %[[C2]], %[[C1_1]] : !torch.vtensor<[2,4],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64> + // CHECK: %[[ABS:.*]] = torch.aten.abs %[[SUB]] : !torch.vtensor<[2,4],si64> -> !torch.vtensor<[2,4],si64> + %0 = torch.operator "onnx.ArgMax"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,4],si64> + return %0 : !torch.vtensor<[2,4],si64> +} + +// ----- + // CHECK-LABEL: @test_argmin_default_axis_example func.func @test_argmin_default_axis_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[1,2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 0 @@ -107,6 +143,24 @@ func.func @test_argmin_negative_axis_keepdims_example(%arg0: !torch.vtensor<[2,2 // ----- +// CHECK-LABEL: @test_argmin_negative_axis_keepdims_random_select_last_index +func.func @test_argmin_negative_axis_keepdims_random_select_last_index(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[TRUE:.*]] = torch.constant.bool true + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[C2_0]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP:.*]] = torch.aten.flip %arg0, %[[DIMS]] : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[2,3,4],f32> + // CHECK: %[[ARGMIN:.*]] = torch.aten.argmin %[[FLIP]], %[[C2]], %[[TRUE]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,3,1],si64> + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[SUB:.*]] = torch.aten.sub.Scalar %[[ARGMIN]], %[[C3]], %[[C1]] : !torch.vtensor<[2,3,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,3,1],si64> + // CHECK: %[[ABS:.*]] = torch.aten.abs %[[SUB]] : !torch.vtensor<[2,3,1],si64> -> !torch.vtensor<[2,3,1],si64> + %0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.axis = -1 : si64, torch.onnx.keepdims = 1 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1],si64> + return %0 : !torch.vtensor<[2,3,1],si64> +} + +// ----- + // CHECK-LABEL: @test_argmin_no_keepdims_example func.func @test_argmin_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 1 @@ -118,6 +172,24 @@ func.func @test_argmin_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> // ----- +// CHECK-LABEL: @test_argmin_no_keepdims_example_select_last_index +func.func @test_argmin_no_keepdims_example_select_last_index(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[C1_0]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP:.*]] = torch.aten.flip %arg0, %[[DIMS]] : !torch.vtensor<[2,2],f32>, !torch.list -> !torch.vtensor<[2,2],f32> + // CHECK: %[[ARGMIN:.*]] = torch.aten.argmin %[[FLIP]], %[[C1]], %[[FALSE]] : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2],si64> + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[SUB:.*]] = torch.aten.sub.Scalar %[[ARGMIN]], %[[C1_1]], %[[C1_2]] : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2],si64> + // CHECK: %[[ABS:.*]] = torch.aten.abs %[[SUB]] : !torch.vtensor<[2],si64> -> !torch.vtensor<[2],si64> + %0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> + return %0 : !torch.vtensor<[2],si64> +} + +// ----- + // CHECK-LABEL: @test_atan func.func @test_atan(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.atan %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> @@ -938,6 +1010,68 @@ func.func @test_conv_with_bias_strides_padding(%arg0: !torch.vtensor<[?,?,224,22 // ----- +// CHECK-LABEL: @test_convinteger_without_padding +func.func @test_convinteger_without_padding(%arg0: !torch.vtensor<[1,1,3,3],ui8>, %arg1: !torch.vtensor<[1,1,2,2],ui8>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1],ui8>) -> !torch.vtensor<[1,1,2,2],si32> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00 + // CHECK: %[[INPUT_ZP:.*]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int + // CHECK: %[[WEIGHT_ZP:.*]] = torch.aten.item %arg3 : !torch.vtensor<[1],ui8> -> !torch.int + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_0]], %[[C1_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[C1_3:.*]] = torch.constant.int 1 + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1_2]], %[[C1_3]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[C0_1:.*]] = torch.constant.int 0 + // CHECK: %[[C0_2:.*]] = torch.constant.int 0 + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0_1]], %[[C0_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: %[[INPUT:.*]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[INPUT_ZP]] : !torch.vtensor<[1,1,3,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,3,3],!torch.quint8> + // CHECK: %[[WEIGHT:.*]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[SCALE]], %[[WEIGHT_ZP]] : !torch.vtensor<[1,1,2,2],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,2,2],!torch.quint8> + // CHECK: torch.aten.convolution %[[INPUT]], %[[WEIGHT]], %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],!torch.quint8>, !torch.vtensor<[1,1,2,2],!torch.quint8>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,2,2],si32> + %none = torch.constant.none + %0 = torch.operator "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[1,1,3,3],ui8>, !torch.vtensor<[1,1,2,2],ui8>, !torch.vtensor<[],ui8>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[1,1,2,2],si32> + return %0 : !torch.vtensor<[1,1,2,2],si32> +} + +// ----- + +// CHECK-LABEL: @test_convinteger_with_padding +func.func @test_convinteger_with_padding(%arg0: !torch.vtensor<[1,1,3,3],ui8>, %arg1: !torch.vtensor<[1,1,2,2],ui8>, %arg2: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,4,4],si32> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00 + // CHECK: %[[INPUT_ZP:.*]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int + // CHECK: %[[WEIGHT_ZP:.*]] = torch.constant.int 0 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1_0]], %[[C1_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[C1_3:.*]] = torch.constant.int 1 + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_2]], %[[C1_3]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[C1_4:.*]] = torch.constant.int 1 + // CHECK: %[[C1_5:.*]] = torch.constant.int 1 + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1_4]], %[[C1_5]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: %[[INPUT:.*]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[INPUT_ZP]] : !torch.vtensor<[1,1,3,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,3,3],!torch.quint8> + // CHECK: %[[WEIGHT:.*]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[SCALE]], %[[WEIGHT_ZP]] : !torch.vtensor<[1,1,2,2],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,2,2],!torch.quint8> + // CHECK: torch.aten.convolution %[[INPUT]], %[[WEIGHT]], %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],!torch.quint8>, !torch.vtensor<[1,1,2,2],!torch.quint8>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,4,4],si32> + %none = torch.constant.none + %0 = torch.operator "onnx.ConvInteger"(%arg0, %arg1, %arg2) {torch.onnx.pads = [1 : si64, 1 : si64, 1 : si64, 1 : si64]} : (!torch.vtensor<[1,1,3,3],ui8>, !torch.vtensor<[1,1,2,2],ui8>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,4,4],si32> + return %0 : !torch.vtensor<[1,1,4,4],si32> +} + +// ----- + // CHECK-LABEL: @test_convtranspose_dilations func.func @test_convtranspose_dilations(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,1,2,2],f32>) -> !torch.vtensor<[1,1,5,5],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C0:.*]] = torch.constant.int 0 diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 6a8757134da1..1262fd7fb983 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -761,6 +761,105 @@ func.func @test_reduce_l1_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2,2],f // ----- +// CHECK-LABEL: func.func @test_reduce_l2_default_axes_keepdims_example +func.func @test_reduce_l2_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32> + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[TRUE_0:.+]] = torch.constant.bool true + // CHECK: %[[NONE_0:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE_0]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> + // CHECK: %[[INT6_0:.+]] = torch.constant.int 6 + // CHECK: %[[NONE_1:.+]] = torch.constant.none + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[SUM]], %[[INT6_0]], %[[FALSE]], %[[FALSE]], %[[NONE_1]] : !torch.vtensor<[1,1,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> + // CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[CAST]] : !torch.vtensor<[1,1,1],f32> -> !torch.vtensor<[1,1,1],f32> + // CHECK: %[[INT6_1:.+]] = torch.constant.int 6 + // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[SQRT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE_1]] : !torch.vtensor<[1,1,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> + // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[1,1,1],f32> + %0 = torch.operator "onnx.ReduceL2"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> + return %0 : !torch.vtensor<[1,1,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_l2_do_not_keepdims_example_expanded +func.func @test_reduce_l2_do_not_keepdims_example_expanded(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32> + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_1:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0_0]], %[[INT0_1]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list + // CHECK: %[[FALSE_0:.+]] = torch.constant.bool false + // CHECK: %[[NONE_0:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> + // CHECK: %[[INT6_0:.+]] = torch.constant.int 6 + // CHECK: %[[NONE_1:.+]] = torch.constant.none + // CHECK: %[[FALSE_1:.+]] = torch.constant.bool false + // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[SUM]], %[[INT6_0]], %[[FALSE_1]], %[[FALSE_1]], %[[NONE_1]] : !torch.vtensor<[3,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> + // CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[CAST]] : !torch.vtensor<[3,2],f32> -> !torch.vtensor<[3,2],f32> + // CHECK: %[[INT6_1:.+]] = torch.constant.int 6 + // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[SQRT]], %[[INT6_1]], %[[FALSE_1]], %[[FALSE_1]], %[[NONE_1]] : !torch.vtensor<[3,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> + // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2],f32> + %0 = torch.operator "onnx.ReduceL2"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> + return %0 : !torch.vtensor<[3,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_l2_keep_dims_example +func.func @test_reduce_l2_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32> + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_1:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0_0]], %[[INT0_1]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[NONE_0:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> + // CHECK: %[[INT6_0:.+]] = torch.constant.int 6 + // CHECK: %[[NONE_1:.+]] = torch.constant.none + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[SUM]], %[[INT6_0]], %[[FALSE]], %[[FALSE]], %[[NONE_1]] : !torch.vtensor<[3,2,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> + // CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[CAST]] : !torch.vtensor<[3,2,1],f32> -> !torch.vtensor<[3,2,1],f32> + // CHECK: %[[INT6_1:.+]] = torch.constant.int 6 + // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[SQRT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE_1]] : !torch.vtensor<[3,2,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> + // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2,1],f32> + + %0 = torch.operator "onnx.ReduceL2"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> + return %0 : !torch.vtensor<[3,2,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_l2_keep_dims_int_input_example +func.func @test_reduce_l2_keep_dims_int_input_example(%arg0: !torch.vtensor<[3,2,2],si64>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],si64>, !torch.vtensor<[3,2,2],si64> -> !torch.vtensor<[3,2,2],si64> + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_1:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0_0]], %[[INT0_1]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[NONE_0:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],si64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> + // CHECK: %[[INT6_0:.+]] = torch.constant.int 6 + // CHECK: %[[NONE_1:.+]] = torch.constant.none + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[SUM]], %[[INT6_0]], %[[FALSE]], %[[FALSE]], %[[NONE_1]] : !torch.vtensor<[3,2,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> + // CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[CAST]] : !torch.vtensor<[3,2,1],f32> -> !torch.vtensor<[3,2,1],f32> + // CHECK: %[[INT6_1:.+]] = torch.constant.int 6 + // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[SQRT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE_1]] : !torch.vtensor<[3,2,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> + // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2,1],f32> + + %0 = torch.operator "onnx.ReduceL2"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> + return %0 : !torch.vtensor<[3,2,1],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_reduce_sum_default_axes_keepdims_example func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 diff --git a/test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir b/test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir deleted file mode 100644 index 22d5e2d35183..000000000000 --- a/test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir +++ /dev/null @@ -1,18 +0,0 @@ -// RUN: torch-mlir-opt <%s -split-input-file -verify-diagnostics -convert-torch-onnx-to-torch - -module { - func.func @test_argmax_no_keepdims_random_select_last_index(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,4],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // TODO: Unsupported torch.onnx.select_last_index - // expected-error @+1 {{failed to legalize operation 'torch.operator'}} - %0 = torch.operator "onnx.ArgMax"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,4],si64> - return %0 : !torch.vtensor<[2,4],si64> - } -} - -// ----- -func.func @test_argmin_no_keepdims_example_select_last_index(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // TODO: Unsupported torch.onnx.select_last_index - // expected-error @+1 {{failed to legalize operation 'torch.operator'}} - %0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> - return %0 : !torch.vtensor<[2],si64> -} diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 2f7d5a11a216..4d2a595da43a 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -2308,6 +2308,14 @@ func.func @torch.aten.floor$canonicalize(%arg0: !torch.vtensor<[?,?],si64>) -> ! return %0 : !torch.vtensor<[?,?],si64> } +// CHECK-LABEL: func.func @torch.aten.trunc$canonicalize +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],si64> +// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[?,?],si64> +func.func @torch.aten.trunc$canonicalize(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> { + %0 = torch.aten.trunc %arg0 : !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64> + return %0 : !torch.vtensor<[?,?],si64> +} + // CHECK-LABEL: func.func @torch.aten.numel$canonicalize // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,4],f32> // CHECK-NEXT: %int12 = torch.constant.int 12