From 4c21e20caa69a02cd604c69160a9ec03c77a11ea Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Thu, 18 Apr 2024 11:32:31 -0700
Subject: [PATCH 01/34] [torch] Support rank-0 index for torch index select
 (#3182)

Need to perform an expand in the case where the indices tensor is rank-0.
---
 .../TorchToLinalg/IndirectDataMovement.cpp    |  8 ++++++++
 .../test_suite/index_select.py                | 18 ++++++++++++++++++
 2 files changed, 26 insertions(+)

diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
index b8754a306711..011978a68a66 100644
--- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
@@ -470,6 +470,7 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
     Location loc = op.getLoc();
     Value input = adaptor.getSelf();
     Value indices = adaptor.getIndex();
+    auto indicesTy = cast<RankedTensorType>(indices.getType());
     RankedTensorType inputType = input.getType().cast<RankedTensorType>();
     RankedTensorType resultType = getTypeConverter()
                                       ->convertType(op->getResult(0).getType())
                                       .cast<RankedTensorType>();
@@ -484,6 +485,13 @@
     if (!isValidDim(dimInt, inputRank))
       return rewriter.notifyMatchFailure(op, "dim is statically invalid");
 
+    if (indicesTy.getRank() == 0) {
+      llvm::SmallVector<ReassociationIndices> reassociations;
+      indicesTy = RankedTensorType::get({1}, indicesTy.getElementType());
+      indices = rewriter.create<tensor::ExpandShapeOp>(loc, indicesTy, indices,
+                                                       reassociations);
+    }
+
     SmallVector<Value> resultShape = getTensorSizes(rewriter, loc, input);
     resultShape[dimInt] = getTensorSizes(rewriter, loc, indices)[0];
     Value initTensor = rewriter.create<tensor::EmptyOp>(
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/index_select.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/index_select.py
index 0fdda62a13a0..c25b563aa9ca 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/index_select.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/index_select.py
@@ -31,6 +31,24 @@ def IndexSelectSingleIdxModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 5, 6), torch.tensor([2]))
 
 
+class IndexSelectRank0IdxModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([4, 5, 6], torch.float32, True),
+        ([], torch.int64, True),
+    ])
+    def forward(self, input, indices):
+        return torch.index_select(input, 1, indices)
+
+@register_test_case(module_factory=lambda: IndexSelectRank0IdxModule())
+def IndexSelectRank0IdxModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 5, 6), torch.tensor(2))
+
 class IndexSelectNegativeDimModule(torch.nn.Module):
     def __init__(self):
         super().__init__()

From 0e77de996aa715fb75aff4c3dd3d10c8c9c01853 Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Thu, 18 Apr 2024 11:47:19 -0700
Subject: [PATCH 02/34] [torch] Add support for `torch.view` with dynamic
 shapes (#3164)

We can map to `tensor.reshape` for handling multiple output dynamic shapes.
Later we can perform a more complex analysis for identifying expand/collapse
cases from the tensor.reshape. Initially we planned to handle this
identification at the `torch` level; however, it will be easier to handle once
converted to the core MLIR dialects.
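As a rough sketch of the intended lowering (the function below is illustrative
and hand-written for this message, not taken from the test suite), a view whose
output has more than one unknown extent can be expressed directly with
`tensor.reshape`, which takes the target shape as a runtime value:

```mlir
// Illustrative only: both output extents are unknown, so no static
// expand/collapse grouping exists; tensor.reshape takes the shape as data.
func.func @view_dynamic(%t: tensor<?x6x?xf32>, %shape: tensor<4xi64>) -> tensor<?x2x3x?xf32> {
  %r = tensor.reshape %t(%shape) : (tensor<?x6x?xf32>, tensor<4xi64>) -> tensor<?x2x3x?xf32>
  return %r : tensor<?x2x3x?xf32>
}
```

In the pattern below, the shape operand is built with `tensor.from_elements`
from the converted size list, and a negative (inferred) size is replaced by the
quotient of the total element count and the product of the known sizes.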
---
 lib/Conversion/TorchToLinalg/DataMovement.cpp | 95 ++++++++++++++++++-
 projects/pt1/e2e_testing/main.py              |  3 +-
 projects/pt1/e2e_testing/xfail_sets.py        | 15 +--
 .../test_suite/reshape_like.py                | 24 ++++-
 test/Conversion/TorchToLinalg/view.mlir       | 12 +--
 5 files changed, 127 insertions(+), 22 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index 5a47a247abbe..a94f8882edef 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -1003,8 +1003,14 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
     // collapsed. Note this may technically not always be true.
     // TODO: think of a way better way to at least detect when this assumption
     // is violated for the cases of dynamic dimensions.
-    bool inputHasOneDynDim = llvm::count(inputShape, kUnknownSize) == 1;
-    bool outputHasOneDynDim = llvm::count(outputShape, kUnknownSize) == 1;
+    int64_t inputDynDim = llvm::count(inputShape, kUnknownSize);
+    int64_t outputDynDim = llvm::count(outputShape, kUnknownSize);
+    if (outputDynDim > 1)
+      return rewriter.notifyMatchFailure(
+          op, "Cannot support more than one output dynamic dimension");
+
+    bool inputHasOneDynDim = inputDynDim == 1;
+    bool outputHasOneDynDim = outputDynDim == 1;
     bool singleDynDimsAreEqual =
         inputHasOneDynDim && outputHasOneDynDim &&
         productReduce(inputShape) == productReduce(outputShape);
@@ -1271,6 +1277,85 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
 };
 } // namespace
 
+namespace {
+class ConvertAtenViewOpToReshape : public OpConversionPattern<AtenViewOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenViewOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> sizes;
+    if (!getListConstructElements(op.getSize(), sizes))
+      return op.emitError(
+          "unimplemented: the tensor size list is not from list construct");
+
+    auto loc = op.getLoc();
+    ImplicitLocOpBuilder b(loc, rewriter);
+    auto self = adaptor.getSelf();
+    const TypeConverter *typeConverter = getTypeConverter();
+
+    // Convert to the `linalg` types, count the number of negative values,
+    // and determine the product of non-negative values. This lets us compute
+    // the inferred dimension's size.
+    auto sizeTy =
+        cast<IntegerType>(typeConverter->convertType(sizes.front().getType()));
+    Value one =
+        b.create<arith::ConstantOp>(sizeTy, rewriter.getIntegerAttr(sizeTy, 1));
+    Value zero =
+        b.create<arith::ConstantOp>(sizeTy, rewriter.getIntegerAttr(sizeTy, 0));
+    Value count = zero;
+    Value knownSize = one;
+    for (auto &size : sizes) {
+      Value convert = typeConverter->materializeTargetConversion(rewriter, loc,
+                                                                 sizeTy, size);
+
+      Value mul = b.create<arith::MulIOp>(knownSize, convert);
+      Value add = b.create<arith::AddIOp>(count, one);
+      Value isNeg =
+          b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, convert, zero);
+
+      knownSize = b.create<arith::SelectOp>(isNeg, knownSize, mul);
+      count = b.create<arith::SelectOp>(isNeg, add, count);
+      size = convert;
+    }
+
+    // Check we are only inferring one dimension:
+    Value countPred =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, count, one);
+    b.create<cf::AssertOp>(
+        loc, countPred,
+        b.getStringAttr("must have at most one inferred (negative) dimension"));
+
+    // Determine the total size of the inferred dimension and update the
+    // inferred dimension:
+    auto selfTy = cast<RankedTensorType>(self.getType());
+    Value totalSize = one;
+    for (int i = 0, s = selfTy.getRank(); i < s; ++i) {
+      Value index = b.create<arith::ConstantIndexOp>(i);
+      Value dim = b.create<tensor::DimOp>(self, index);
+      dim = b.create<arith::IndexCastOp>(sizeTy, dim);
+      totalSize = b.create<arith::MulIOp>(totalSize, dim);
+    }
+
+    Value inferredSize = b.create<arith::DivSIOp>(totalSize, knownSize);
+    for (auto &size : sizes) {
+      Value isNeg =
+          b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, size, zero);
+      size = b.create<arith::SelectOp>(isNeg, inferredSize, size);
+    }
+
+    auto ty = RankedTensorType::get(sizes.size(), sizes.front().getType());
+    auto outputDims = b.create<tensor::FromElementsOp>(ty, sizes);
+
+    auto resultType =
+        typeConverter->convertType(op.getType()).cast<RankedTensorType>();
+    rewriter.replaceOpWithNewOp<tensor::ReshapeOp>(op, resultType, self,
+                                                   outputDims);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenSqueezeOp : public OpConversionPattern<AtenSqueezeOp> {
 public:
@@ -2348,10 +2433,12 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
   patterns.add(typeConverter, context);
   target.addIllegalOp();
   patterns.add(typeConverter, context);
-  target.addIllegalOp();
   patterns.add(typeConverter, context);
   target.addIllegalOp();
-  patterns.add(typeConverter, context);
+  target.addIllegalOp<AtenViewOp>();
+  patterns.add<ConvertAtenViewOp>(typeConverter, context, /*benefit=*/200);
+  patterns.add<ConvertAtenViewOpToReshape>(typeConverter, context,
+                                           /*benefit=*/100);
   target.addIllegalOp();
   patterns.add(typeConverter, context);
   target.addIllegalOp();
diff --git a/projects/pt1/e2e_testing/main.py b/projects/pt1/e2e_testing/main.py
index 9f2323793579..d2c381d654bc 100644
--- a/projects/pt1/e2e_testing/main.py
+++ b/projects/pt1/e2e_testing/main.py
@@ -32,6 +32,7 @@
 from .xfail_sets import (
     LINALG_XFAIL_SET,
+    LINALG_CRASHING_SET,
     MAKE_FX_TOSA_PASS_SET,
     STABLEHLO_PASS_SET,
     STABLEHLO_CRASHING_SET,
@@ -99,7 +100,7 @@ def main():
     if args.config == "linalg":
         config = LinalgOnTensorsBackendTestConfig(RefBackendLinalgOnTensorsBackend())
         xfail_set = LINALG_XFAIL_SET
-        crashing_set = set()
+        crashing_set = LINALG_CRASHING_SET
     elif args.config == "stablehlo":
         config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend())
         xfail_set = all_test_unique_names - STABLEHLO_PASS_SET
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index f6949abcfa8d..72b6678024e8 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -24,6 +24,11 @@
     "SplitWithSizes_Module_basic",
 }
 
+LINALG_CRASHING_SET = {
+    # Crashes due to copy to a smaller destination buffer than the source buffer.
+ "SliceCopyStartGreaterThanDimSize_Module_basic", +} + TORCHDYNAMO_XFAIL_SET = { #### General TorchDynamo/PyTorch errors @@ -2280,15 +2285,6 @@ "ElementwiseToDtypeI64ToUI8Module_basic", # Failure - torch.aten.view lower - "IndexTensorDyanmicInputContiguousWithNoneModule_basic", - "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic", - "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic", - "IndexTensorMultiInputContiguousCenter_basic", - "IndexTensorMultiInputNonContiguousMultipleStaticDims_basic", - "IndexTensorMultiInputNonContiguous_basic", - "IndexTensorMultiInputOneDim_basic", - "IndexTensorMultiInputThreeIndexers_basic", - "IndexTensorMultiInput_basic", "IndexTensorMultiInputContiguousOneDimDynamic_basic", "IndexTensorMultiInputNonContiguousOneDimDynamic_basic", @@ -2327,7 +2323,6 @@ "EmbeddingModuleF16_basic", "EmbeddingModuleI32_basic", "EmbeddingModuleI64_basic", - "FlattenDynamicModule_basic", "GluStaticModule_basic", "GroupNormModule_basic", "IndexTensorHackedTwinModule3dInput_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py index 73b15afe93b7..8aa3e2c1ff58 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py @@ -992,6 +992,28 @@ def forward(self, a): def ReshapeAliasExpandModule_basic(module, tu: TestUtils): module.forward(tu.rand(384)) + +# ============================================================================== + +class ReshapeDynamicModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + + def forward(self, a): + return a.view(a.size(1), a.size(0)) + +@register_test_case(module_factory=lambda: ReshapeDynamicModule()) +def ReshapeDynamicModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3,4)) + + + # ============================================================================== class ReshapeAliasCollapseModule(torch.nn.Module): @@ -1153,4 +1175,4 @@ def forward(self, tensor1, tensor2): @register_test_case(module_factory=lambda: EinsumStaticWithEllipsisSlicingAndBroadcastModule()) def EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 6, 4, 5), tu.rand(6, 5)) \ No newline at end of file + module.forward(tu.rand(2, 6, 4, 5), tu.rand(6, 5)) diff --git a/test/Conversion/TorchToLinalg/view.mlir b/test/Conversion/TorchToLinalg/view.mlir index 4f9c1f867ee4..7cad9ffe33f6 100644 --- a/test/Conversion/TorchToLinalg/view.mlir +++ b/test/Conversion/TorchToLinalg/view.mlir @@ -23,7 +23,8 @@ func.func @torch.aten.view$twotothree(%arg0: !torch.vtensor<[3,2],f32>) -> !torc // CHECK-LABEL: func.func @torch.aten.view$dynamictest( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[BUILTIN_TENSOR]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: %[[RESHAPE:.*]] = tensor.reshape %[[BUILTIN_TENSOR]] +// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[RESHAPE]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ 
-31,7 +32,7 @@ func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>) -> !tor %int0 = torch.constant.int 0 %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int - %2 = torch.prim.ListConstruct %0, %1 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %1, %0 : (!torch.int, !torch.int) -> !torch.list %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?],f32>, !torch.list -> !torch.vtensor<[?,?],f32> return %3 : !torch.vtensor<[?,?],f32> } @@ -41,7 +42,7 @@ func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>) -> !tor // CHECK-LABEL: func.func @torch.aten.view$dynamictest2( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,6,?],f32>) -> !torch.vtensor<[?,2,3,?],f32> { // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,6,?],f32> -> tensor -// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1, 2], [3]] : tensor into tensor +// CHECK: %[[EXPAND:.*]] = tensor.reshape %[[BUILTIN_TENSOR]] // CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND]] : tensor -> !torch.vtensor<[?,2,3,?],f32> // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[?,2,3,?],f32> @@ -174,9 +175,8 @@ func.func @torch.aten.view$singleUnknownMatches0(%arg0: !torch.vtensor<[10,3,?,2 // CHECK: func.func @torch.aten.view$combineConcepts( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[8,?,?,?,2,1,3],f32>) -> !torch.vtensor<[2,2,2,?,?,?,6],f32> { // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[8,?,?,?,2,1,3],f32> -> tensor<8x?x?x?x2x1x3xf32> -// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2], [3], [4, 5, 6]] : tensor<8x?x?x?x2x1x3xf32> into tensor<8x?x?x?x6xf32> -// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[COLLAPSE]] {{\[\[}}0, 1, 2], [3], [4], [5], [6]] : tensor<8x?x?x?x6xf32> into tensor<2x2x2x?x?x?x6xf32> -// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND]] : tensor<2x2x2x?x?x?x6xf32> -> !torch.vtensor<[2,2,2,?,?,?,6],f32> +// CHECK: %[[RESHAPE:.*]] = tensor.reshape %[[BUILTIN_TENSOR]] +// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[RESHAPE]] : tensor<2x2x2x?x?x?x6xf32> -> !torch.vtensor<[2,2,2,?,?,?,6],f32> // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,2,2,?,?,?,6],f32> func.func @torch.aten.view$combineConcepts(%arg0 : !torch.vtensor<[8,?,?,?,2,1,3], f32>) -> !torch.vtensor<[2,2,2,?,?,?,6], f32> { From be742a937d18a71c0ddeddeb8a2fb5338e41c431 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 18 Apr 2024 14:58:13 -0700 Subject: [PATCH 03/34] [onnx] Update the failure triage for onnx (#3186) Reclassifying what the source of failures are for various bugs so we can reprioritize what failures are common. 
--- projects/pt1/e2e_testing/xfail_sets.py | 240 ++++++++++++------------- 1 file changed, 110 insertions(+), 130 deletions(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 72b6678024e8..a65dab6e7273 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1724,12 +1724,27 @@ } ONNX_XFAIL_SET = { - # Failure - cast error "PermuteNegativeIndexModule_basic", - + + # Failure - expand multiple dynamic dims + "EmbeddingModuleF16_basic", + "EmbeddingModuleI32_basic", + "EmbeddingModuleI64_basic", + "IndexTensorHackedTwinModule3dInput_basic", + "IndexTensorHackedTwinModule_basic", + "IndexTensorModule3dInput_basic", + "IndexTensorModule_basic", + "IndexTensorMultiInputContiguousOneDimDynamic_basic", + "IndexTensorMultiInputNonContiguousOneDimDynamic_basic", + "IndexTensorSelectDimModule_basic", + # Failure - incorrect numerics + "AvgPool2dDivisorOverrideModule_basic", + "BroadcastDynamicDimModule_basic", "ElementwiseAtan2TensorIntModule_basic", + "ElementwiseAtenFloorDivideScalarNegativeModule_basic", + "ElementwiseAtenFloorDivideTensorNegativeModule_basic", "ElementwiseLog10IntModule_basic", "ElementwiseLog2IntModule_basic", "ElementwiseSeluModule_basic", @@ -1738,43 +1753,58 @@ "HardsigmoidModule_basic", "HardsigmoidRandomModule_basic", "PixelShuffleModuleStaticRank4Float32_basic", + "ReflectionPad1dModule2dInput_Right", + "ReflectionPad1dModule2dInput_basic", + "ReflectionPad1dModule3dInput_Left", + "ReflectionPad1dModule3dInput_basic", + "ReflectionPad2dModule_Bottom", + "ReflectionPad2dModule_Left", + "ReflectionPad2dModule_Right", + "ReflectionPad2dModule_Top", + "ReflectionPad2dModule_basic", + "ReplicationPad2dModule_basic", + "ReplicationPad2dModule_bottom0", + "ReplicationPad2dModule_left0", + "ReplicationPad2dModule_right0", + "ReplicationPad2dModule_top0", "SliceCopyEndGreaterThanDimSize_Module_basic", "SliceCopyNegative_Module_basic", "SliceCopyNonZeroDim_Module_basic", "SliceCopy_Module_basic", + "StdCorrectionLargeInputModule_basic", "TupleModule_basic", - + "VarCorrectionLargeInputModule_basic", + # Failure - incorrect shape "ArangeStartOutDtypeModule_basic", "ArangeStartOutViewModule_basic", - "BroadcastDynamicDimModule_basic", "MoveDimIntNegativeIndexModule_basic", + "ReduceL3NormKeepDimModule_basic", "ViewSizeFromOtherTensor_basic", - + # Failure - onnx_export - "ElementwiseSgnModule_basic", "AdaptiveAvgPool1dGeneralDynamic_basic", "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic", "AdaptiveAvgPool1dStaticLargerOutput_basic", + "AdaptiveAvgPool2dDynamicNoBatch_basic", + "AdaptiveAvgPool2dDynamic_basic", "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic", "AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic", + "AdaptiveAvgPool3dDynamicNoBatch_basic", + "AdaptiveAvgPool3dDynamic_basic", + "AdaptiveMaxPool1dDynamicNoBatch_basic", + "AdaptiveMaxPool1dDynamic_basic", + "AdaptiveMaxPool1dStatic_basic", + "AdaptiveMaxPool2dDynamicNoBatch_basic", "AdaptiveMaxPool2dDynamicWithIndices_basic", "AdaptiveMaxPool2dDynamic_basic", "AdaptiveMaxPool2dStaticWithIndices_basic", "AdaptiveMaxPool2dStatic_basic", - "AdaptiveMaxPool3dStatic_basic", - "AdaptiveMaxPool3dStaticWithIndices_basic", - "AdaptiveMaxPool3dDynamic_basic", - "AdaptiveMaxPool3dDynamicWithIndices_basic", "AdaptiveMaxPool3dDynamicNoBatch_basic", - "AdaptiveMaxPool2dDynamicNoBatch_basic", - "AdaptiveMaxPool1dStatic_basic", - "AdaptiveMaxPool1dDynamic_basic", - "AdaptiveMaxPool1dDynamicNoBatch_basic", - 
"AdaptiveAvgPool3dDynamic_basic", - "AdaptiveAvgPool3dDynamicNoBatch_basic", - "AdaptiveAvgPool2dDynamic_basic", - "AdaptiveAvgPool2dDynamicNoBatch_basic", + "AdaptiveMaxPool3dDynamicWithIndices_basic", + "AdaptiveMaxPool3dDynamic_basic", + "AdaptiveMaxPool3dStaticWithIndices_basic", + "AdaptiveMaxPool3dStatic_basic", "AddCDivModule_basic", "AddIntModule_basic", "Add_Module_basic", @@ -1786,6 +1816,12 @@ "AtenComplexImagModule_basic", "AtenComplexRealModule_basic", "AtenComplexViewModule_basic", + "AtenDiagEmbedDefaultDiag_basic", + "AtenDiagEmbedDimDiag_basic", + "AtenDiagEmbedNegOffsetDiag_basic", + "AtenDiagEmbedNonDefault4DDiag_basic", + "AtenDiagEmbedOffsetDiag_basic", + "AtenDiagEmbedRevDimDiag_basic", "AtenEmbeddingBagStaticModule_basic", "AtenEmbeddingBagSumExample_basic", "AtenFloatScalarModule_basic", @@ -1796,15 +1832,16 @@ "AtenIntTensorCharDtypeModule_basic", "AtenItemFpOpModule_basic", "AtenItemIntOpModule_basic", - "AtenMmQint8_basic", - "AtenMmQuint8_basic", - "AtenMmQMixedSigni8_basic", + "AtenLinalgCrossDynamic_basic", "AtenMatmulQMixedSigni8Transpose_basic", "AtenMatmulQMixedSigni8_basic", "AtenMatmulQint8MV_basic", - "AtenMatmulQint8VV_basic", "AtenMatmulQint8VM_basic", + "AtenMatmulQint8VV_basic", "AtenMatmulQint8_basic", + "AtenMmQMixedSigni8_basic", + "AtenMmQint8_basic", + "AtenMmQuint8_basic", "AtenRealView128Module_basic", "AtenRealView64Module_basic", "AtenSubFloatModule_basic", @@ -1886,16 +1923,19 @@ "ElementwiseCoshModule_basic", "ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerTensorModule_basic", + "ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic", + "ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic", "ElementwiseEluNonDefaultModule_basic", "ElementwiseExpm1IntModule_basic", "ElementwiseExpm1Module_basic", + "ElementwiseFmodTensor_Int_basic", "ElementwiseMulTensorComplexModule_basic", "ElementwiseOrTensorModule_basic", "ElementwiseOrTensorStaticShapeModule_basic", "ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic", "ElementwiseRemainderTensorModule_Int_basic", - "ElementwiseFmodTensor_Int_basic", + "ElementwiseSgnModule_basic", "EmptyStridedModule_basic", "EmptyStridedSizeIntStrideModule_basic", "EqIntModule_basic", @@ -1910,11 +1950,11 @@ "HardtanhBackward_basic", "IndexPutImpl1DFloatAccumulateModule_basic", "IndexPutImpl1DFloatNonAccumulateModule_basic", - "IndexPutImpl2DImplicitModule_basic", "IndexPutImpl1DIntAccumulateModule_basic", "IndexPutImpl1DIntNonAccumulateModule_basic", "IndexPutImpl2DFloatAccumulateModule_basic", "IndexPutImpl2DFloatNonAccumulateModule_basic", + "IndexPutImpl2DImplicitModule_basic", "IndexPutImpl2DIndexModule_basic", "IndexPutImpl2DNoneIndexStaticModule_basic", "IndexPutImpl3DFloatAccumulateModule_basic", @@ -1931,6 +1971,8 @@ "LeakyReluBackwardStaticModule_basic", "LenStrModule_basic", "LiftFreshCopyModule_basic", + "LinalgNormKeepDimComplexModule_basic", + "LinalgVectorNormComplexModule_basic", "LogSoftmaxBackwardModule_basic", "MaxPool2dCeilModeTrueModule_basic", "MaxPool2dModule_basic", @@ -1990,14 +2032,11 @@ "NllLossModule_ignore_index_out_of_bounds_basic", "NllLossModule_mean_basic", "NllLossModule_sum_basic", - "NormScalarModule_basic", "NormScalarComplexModule_basic", + "NormScalarModule_basic", + "NormScalarOptDimKeepDimComplexModule_basic", "NormScalarOptDimKeepDimModule_basic", "NormScalarOptDimModule_basic", - "NormScalarOptDimKeepDimComplexModule_basic", - "LinalgNormKeepDimComplexModule_basic", - 
"LinalgVectorNormComplexModule_basic", - "ReduceFrobeniusNormComplexModule_basic", "NormalFunctionalModule_basic", "NumToTensorFloatModule_basic", "NumToTensorIntModule_basic", @@ -2019,6 +2058,10 @@ "RandIntDtypeModule_basic", "RandIntModule_basic", "RandIntPinMemoryModule_basic", + "ReduceFrobeniusNormComplexModule_basic", + "ReduceL1NormComplexModule_basic", + "ReduceL2NormComplexModule_basic", + "ReduceL3NormKeepDimComplexModule_basic", "ReshapeAliasCollapseModule_basic", "ReshapeAliasExpandModule_basic", "ReshapeExpandModule_basic", @@ -2132,8 +2175,11 @@ "_ConvolutionDeprecated2DCudnnModule_basic", "_ConvolutionDeprecated2DDeterministicModule_basic", "_SoftmaxModule_basic", - - # Failure - onnx_import + + # Failure - onnx_lowering: onnx.AveragePool + "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", + + # Failure - onnx_lowering: onnx.If "DiagonalModule_basic", "DiagonalModule_nonsquare", "DiagonalModule_transposed", @@ -2141,70 +2187,31 @@ "DiagonalModule_with_dims_and_offset", "DiagonalModule_with_negative_dims", "DiagonalModule_with_offset", - "AtenDiagEmbedDefaultDiag_basic", - "AtenDiagEmbedDimDiag_basic", - "AtenDiagEmbedOffsetDiag_basic", - "AtenDiagEmbedRevDimDiag_basic", - "AtenDiagEmbedNegOffsetDiag_basic", - "AtenDiagEmbedNonDefault4DDiag_basic", - "ScatterReduceFloatMaxModuleIncludeSelf", - "ScatterReduceFloatMinModuleIncludeSelf", - "ScatterReduceFloatProdModuleIncludeSelf", - "ScatterReduceFloatSumModuleIncludeSelf", - "ScatterReduceIntMaxModuleIncludeSelf", - "ScatterReduceIntMinModuleIncludeSelf", - "ScatterReduceIntProdModuleIncludeSelf", - "ScatterReduceIntSumModuleIncludeSelf", "TileBigDimsSizeModule_basic", "TileSmallDimsSizeModule_basic", - "LinalgNormKeepDimModule_basic", - "LinalgNormModule_basic", - - # Failure - onnx_lowering: onnx.AveragePool - "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", - "AvgPool2dDivisorOverrideModule_basic", - - # Failure - onnx_lowering: onnx.Clip - "NormalizeModule_basic", - + # Failure - onnx_lowering: onnx.MaxPool "MaxPool2dWithIndicesAllNegativeValuesModule_basic", "MaxPool2dWithIndicesNonDefaultPaddingModule_basic", "MaxPool2dWithIndicesStaticModule_basic", - + # Failure - onnx_lowering: onnx.OneHot "OneHotModule_basic", - - # Failure - onnx_lowering: onnx.Pad - "ReflectionPad1dModule2dInput_Right", - "ReflectionPad1dModule2dInput_basic", - "ReflectionPad1dModule3dInput_Left", - "ReflectionPad1dModule3dInput_basic", - "ReflectionPad2dModule_Bottom", - "ReflectionPad2dModule_Left", - "ReflectionPad2dModule_Right", - "ReflectionPad2dModule_Top", - "ReflectionPad2dModule_basic", - "ReplicationPad2dModule_basic", - "ReplicationPad2dModule_bottom0", - "ReplicationPad2dModule_left0", - "ReplicationPad2dModule_right0", - "ReplicationPad2dModule_top0", - + # Failure - onnx_lowering: onnx.RandomNormal "RandnDtypeDeviceModule_basic", "RandnGeneratorF64Module_basic", "RandnGeneratorModule_basic", "RandnModule_basic", - + # Failure - onnx_lowering: onnx.RandomNormalLike "RandnLikeDtypeModule_basic", "RandnLikeModule_basic", - + # Failure - onnx_lowering: onnx.RandomUniform "RandIntLowDtypeModule_basic", "RandIntLowModule_basic", - + # Failure - onnx_lowering: onnx.RandomUniformLike "BernoulliFloatModule_basic", "BernoulliPModule_basic", @@ -2212,39 +2219,34 @@ "RandLikeDtypeModule_basic", "RandLikeModule_basic", "RandModule_basic", - - # Failure - onnx_lowering: onnx.ReduceL1 - "ReduceL1NormComplexModule_basic", - + # Failure - onnx_lowering: onnx.ReduceL2 + "LinalgNormKeepDimModule_basic", + "LinalgNormModule_basic", + 
"NormalizeModule_basic", "ReduceL2NormModule_basic", - "ReduceL2NormComplexModule_basic", - - # Failure - onnx_lowering: onnx.ReduceL3 - "ReduceL3NormKeepDimModule_basic", - "ReduceL3NormKeepDimComplexModule_basic", - - + # Failure - onnx_lowering: onnx.ReduceProd - "BernoulliModule_basic", - "DropoutTrainModule_basic", - "DropoutTrainStaticShapeModule_basic", - "NativeDropoutTrainModule_basic", - "NativeDropoutTrainStaticShapeModule_basic", "ReduceProdDimIntFloatModule_basic", - "StdCorrectionLargeInputModule_basic", - "VarCorrectionLargeInputModule_basic", - + # Failure - onnx_lowering: onnx.Resize "UpSampleNearest2dDynamicSize_basic", "UpSampleNearest2dStaticSize_basic", - + # Failure - onnx_lowering: onnx.ScatterElements + "ScatterReduceFloatMaxModuleIncludeSelf", + "ScatterReduceFloatMinModuleIncludeSelf", + "ScatterReduceFloatProdModuleIncludeSelf", + "ScatterReduceFloatSumModuleIncludeSelf", + "ScatterReduceIntMaxModuleIncludeSelf", + "ScatterReduceIntMinModuleIncludeSelf", + "ScatterReduceIntProdModuleIncludeSelf", + "ScatterReduceIntSumModuleIncludeSelf", "ScatterSrcModule_basic", "ScatterSrcStaticModule_basic", "ScatterValueFloatModule_basic", "ScatterValueIntModule_basic", - + # Failure - onnx_lowering: onnx.ScatterND "IndexPut1DFloatAccumulateModule_basic", "IndexPut1DFloatNonAccumulateModule_basic", @@ -2270,30 +2272,21 @@ "IndexPutHackedTwin3DFloatNonAccumulateModule_basic", "IndexPutHackedTwin3DIntAccumulateModule_basic", "IndexPutHackedTwin3DIntNonAccumulateModule_basic", - + # Failure - onnx_lowering: onnx.SoftmaxCrossEntropyLoss "CrossEntropyLossModule_basic", "CrossEntropyLossNoReductionModule_basic", - + # Failure - onnx_lowering: onnx.Squeeze "SqueezeModule_allUnitDim", "SqueezeModule_broadcast", "SqueezeModule_static", - - # Failure - incorrect dtype - "ReduceMaxAlongDimUnsignedInt_basic", - "ElementwiseToDtypeI64ToUI8Module_basic", - - # Failure - torch.aten.view lower - "IndexTensorMultiInputContiguousOneDimDynamic_basic", - "IndexTensorMultiInputNonContiguousOneDimDynamic_basic", - - # Failure - torch.aten.squeeze lower - "BucketizeTensorOutInt32RightModule_basic", # unsupported by backend contract: tensor with unknown rank - + # Failure - unknown + "BernoulliModule_basic", "BucketizeTensorFloatModule_basic", "BucketizeTensorModule_basic", + "BucketizeTensorOutInt32RightModule_basic", "BucketizeTensorStaticFloatModule_basic", "BucketizeTensorStaticModule_basic", "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", @@ -2301,16 +2294,15 @@ "CopyWithDifferentDTypesModule_basic", "CosineSimilarityStaticBroadcastModule_basic", "CumsumInputDtypeInt32Module_basic", + "DropoutTrainModule_basic", + "DropoutTrainStaticShapeModule_basic", "ElementwiseAcosIntModule_basic", "ElementwiseAsinIntModule_basic", "ElementwiseAtanTensorIntModule_basic", "ElementwiseCosIntModule_basic", - "ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic", - "ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic", "ElementwiseDivTensorRoundingModeTruncModule_basic", "ElementwiseDivTensorRoundingModeTruncStaticModule_basic", "ElementwiseErfIntModule_basic", - "ElementwiseErfIntModule_basic", "ElementwiseExpIntModule_basic", "ElementwiseLogIntModule_basic", "ElementwisePreluModule_basic", @@ -2318,33 +2310,21 @@ "ElementwiseSigmoidIntModule_basic", "ElementwiseSinIntModule_basic", "ElementwiseTanIntModule_basic", + "ElementwiseToDtypeI64ToUI8Module_basic", "ElementwiseUnaryIntModule_basic", "ElementwiseUnsqueezeNegDimsModule_basic", - "EmbeddingModuleF16_basic", - 
"EmbeddingModuleI32_basic", - "EmbeddingModuleI64_basic", "GluStaticModule_basic", "GroupNormModule_basic", - "IndexTensorHackedTwinModule3dInput_basic", - "IndexTensorHackedTwinModule_basic", - "IndexTensorModule3dInput_basic", - "IndexTensorModule_basic", - "IndexTensorSelectDimModule_basic", "MaskedFillTensorFloatValueModule_basic", + "NativeDropoutTrainModule_basic", + "NativeDropoutTrainStaticShapeModule_basic", "ReduceAllDimEmpty_basic", "ReduceAllDimFloat_basic", "ReduceAllDimInt_basic", + "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", "TensorsStackNegativeDimModule_basic", "TensorsStackPromoteDTypeModule_basic", - - # Failure - "RuntimeError: linalg.cross: inputs dimension 1 must have length 3. Got 1 and 1" - "AtenLinalgCrossDynamic_basic", - - # Failure - value not close to golden value (op is incorrectly truncating) - "ElementwiseAtenFloorDivideTensorNegativeModule_basic", - "ElementwiseAtenFloorDivideScalarNegativeModule_basic", - } if torch_version_for_comparison() >= version.parse("2.4.0.dev"): From 6c4f7deebb89308b8916648a9d3bd1d0ce4edea6 Mon Sep 17 00:00:00 2001 From: penguin_wwy <940375606@qq.com> Date: Fri, 19 Apr 2024 10:55:27 +0800 Subject: [PATCH 04/34] [stablehlo] add aten.clamp.Tensor op conversion support (#3185) --- lib/Conversion/TorchToStablehlo/Basic.cpp | 36 +++++++++++++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 4 +++ 2 files changed, 40 insertions(+) diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 136a3352c4c2..0b9e1291efa6 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -1470,6 +1470,41 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// AtenClampTensorOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenClampTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.getSelf(); + auto inputType = cast(input.getType()); + auto inputElemType = inputType.getElementType(); + Value minValue = adaptor.getMin(); + Value maxValue = adaptor.getMax(); + auto minIsNotNone = checkNotNone(rewriter, op, minValue); + auto maxIsNotNone = checkNotNone(rewriter, op, maxValue); + if (failed(minIsNotNone) && failed(maxIsNotNone)) { + return rewriter.notifyMatchFailure( + op, "this op should be folded as its `min` and `max` both are none"); + } else if (failed(minIsNotNone)) { + auto minInfo = getMinValueOfDtype(op, inputElemType, rewriter); + if (failed(minInfo)) { + return rewriter.notifyMatchFailure( + op, "failed to generate min value of dtype"); + } + minValue = *minInfo; + } else if (failed(maxIsNotNone)) { + auto maxInfo = getMaxValueOfDtype(op, inputElemType, rewriter); + if (failed(maxInfo)) { + return rewriter.notifyMatchFailure( + op, "failed to generate max value of dtype"); + } + maxValue = *maxInfo; + } + rewriter.replaceOpWithNewOp(op, minValue, input, + maxValue); + return success(); +} + // AtenArangeStartStepOp // aten.arange.start_step = range(ceil((end-start)/step)) * step + start. 
template <> @@ -1906,6 +1941,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenCatOp); INSERT_ATENOP_PATTERN(AtenClampOp); + INSERT_ATENOP_PATTERN(AtenClampTensorOp); INSERT_ATENOP_PATTERN(AtenArangeStartStepOp); INSERT_ATENOP_PATTERN(AtenBatchNormOp); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index a65dab6e7273..50c909936450 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -639,7 +639,11 @@ "ElementwiseCeilModule_basic", "ElementwiseClampMaxModule_basic", "ElementwiseClampMinModule_basic", + "ElementwiseClampMinTensorFloatModule_basic", + "ElementwiseClampMinTensorIntModule_basic", "ElementwiseClampModule_basic", + "ElementwiseClampTensorFloatModule_basic", + "ElementwiseClampTensorIntModule_basic", "ElementwiseClampTensorInt8Module_basic", "ElementwiseCloneChannelsLastMemoryFormatModule_basic", "ElementwiseCloneContiguousModule_basic", From 0a6073414db2eb0860e5256175ee2d8bdeea457a Mon Sep 17 00:00:00 2001 From: penguin_wwy <940375606@qq.com> Date: Fri, 19 Apr 2024 12:29:17 +0800 Subject: [PATCH 05/34] [FxImporter] Add fx importer to stablehlo e2e test config (#3183) --- projects/pt1/e2e_testing/main.py | 15 +- projects/pt1/e2e_testing/xfail_sets.py | 404 +++++++++++++++++- .../configs/fx_importer_backend.py | 11 +- 3 files changed, 421 insertions(+), 9 deletions(-) diff --git a/projects/pt1/e2e_testing/main.py b/projects/pt1/e2e_testing/main.py index d2c381d654bc..5a61b50db730 100644 --- a/projects/pt1/e2e_testing/main.py +++ b/projects/pt1/e2e_testing/main.py @@ -43,8 +43,10 @@ TORCHDYNAMO_CRASHING_SET, ONNX_CRASHING_SET, ONNX_XFAIL_SET, - FX_IMPORT_XFAIL_SET, + FX_IMPORTER_XFAIL_SET, FX_IMPORTER_CRASHING_SET, + FX_IMPORTER_STABLEHLO_XFAIL_SET, + FX_IMPORTER_STABLEHLO_CRASHING_SET, ) # Import tests to register them in the global registry. @@ -52,7 +54,8 @@ register_all_tests() def _get_argparse(): - config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo", "onnx", "fx_importer"] + config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "make_fx_tosa", "tosa", "lazy_tensor_core", + "torchdynamo", "onnx", "fx_importer", "fx_importer_stablehlo"] parser = argparse.ArgumentParser(description="Run torchscript e2e tests.") parser.add_argument("-c", "--config", choices=config_choices, @@ -67,6 +70,8 @@ def _get_argparse(): "lazy_tensor_core": run the model through the Lazy Tensor Core frontend and execute the traced graph. "torchdynamo": run the model through the TorchDynamo frontend and execute the graph using Linalg-on-Tensors. "onnx": export to the model via onnx and reimport using the torch-onnx-to-torch path. +"fx_importer": run the model through the fx importer frontend and execute the graph using Linalg-on-Tensors. +"fx_importer_stablehlo": run the model through the fx importer frontend and execute the graph using Stablehlo backend. """) parser.add_argument("-f", "--filter", default=".*", help=""" Regular expression specifying which tests to include in this run. 
@@ -127,8 +132,12 @@ def main(): crashing_set = LTC_CRASHING_SET elif args.config == "fx_importer": config = FxImporterTestConfig(RefBackendLinalgOnTensorsBackend()) - xfail_set = FX_IMPORT_XFAIL_SET + xfail_set = FX_IMPORTER_XFAIL_SET crashing_set = FX_IMPORTER_CRASHING_SET + elif args.config == "fx_importer_stablehlo": + config = FxImporterTestConfig(LinalgOnTensorsStablehloBackend(), "stablehlo") + xfail_set = FX_IMPORTER_STABLEHLO_XFAIL_SET + crashing_set = FX_IMPORTER_STABLEHLO_CRASHING_SET elif args.config == "torchdynamo": config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend()) xfail_set = TORCHDYNAMO_XFAIL_SET diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 50c909936450..ed1f5eb8dd7a 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -391,7 +391,7 @@ "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", } -FX_IMPORT_XFAIL_SET = { +FX_IMPORTER_XFAIL_SET = { 'AllBoolFalseModule_basic', 'AllBoolTrueModule_basic', 'AnyBoolFalseModule_basic', @@ -525,6 +525,408 @@ "HBC_basic", } +FX_IMPORTER_STABLEHLO_XFAIL_SET = { + "AdaptiveAvgPool3dDynamicNoBatch_basic", + "AdaptiveAvgPool3dDynamic_basic", + "AdaptiveMaxPool1dDynamicNoBatch_basic", + "AdaptiveMaxPool1dDynamic_basic", + "AdaptiveMaxPool1dStatic_basic", + "AdaptiveMaxPool2dDynamicNoBatch_basic", + "AdaptiveMaxPool2dDynamicWithIndices_basic", + "AdaptiveMaxPool2dDynamic_basic", + "AdaptiveMaxPool2dStaticWithIndices_basic", + "AdaptiveMaxPool2dStatic_basic", + "AdaptiveMaxPool3dDynamicNoBatch_basic", + "AdaptiveMaxPool3dDynamicWithIndices_basic", + "AdaptiveMaxPool3dDynamic_basic", + "AdaptiveMaxPool3dStaticWithIndices_basic", + "AdaptiveMaxPool3dStatic_basic", + "AllBoolFalseModule_basic", + "AllBoolTrueModule_basic", + "AnyBoolFalseModule_basic", + "AnyBoolTrueModule_basic", + "ArangeStartOutViewModule_basic", + "ArgminIntModule_basic", + "ArgminIntModule_multiple_mins", + "ArgminModule_basic", + "ArgminModule_keepDim", + "ArgminModule_with_dim", + "AtenComplexImagModule_basic", + "AtenComplexRealModule_basic", + "AtenComplexViewModule_basic", + "AtenDiagEmbedDefaultDiag_basic", + "AtenDiagEmbedDimDiag_basic", + "AtenDiagEmbedNegOffsetDiag_basic", + "AtenDiagEmbedNonDefault4DDiag_basic", + "AtenDiagEmbedOffsetDiag_basic", + "AtenDiagEmbedRevDimDiag_basic", + "AtenEmbeddingBagStaticModule_basic", + "AtenEmbeddingBagSumExample_basic", + "AtenFloatScalarModule_basic", + "AtenIntBoolOpConstFalseModule_basic", + "AtenIntBoolOpConstTrueModule_basic", + "AtenIntBoolOpModule_basic", + "AtenItemFpOpModule_basic", + "AtenMatmulQMixedSigni8Transpose_basic", + "AtenMatmulQMixedSigni8_basic", + "AtenMatmulQint8MV_basic", + "AtenMatmulQint8VM_basic", + "AtenMatmulQint8VV_basic", + "AtenMatmulQint8_basic", + "AtenMmQMixedSigni8_basic", + "AtenMmQint8_basic", + "AtenMmQuint8_basic", + "AtenRealView128Module_basic", + "AtenRealView64Module_basic", + "AtenSubFloatModule_basic", + "AtenTopKModule_basic", + "AtenTopKSmallestModule_basic", + "AtenTrilModule_basic", + "AtenTrilWithNegDiagonalModule_basic", + "AtenTrilWithPosDiagonalModule_basic", + "Aten_EmbeddingBagExample_basic", + "AvgPool2dDivisorOverrideModule_basic", + "BernoulliTensorModule_basic", + "BincountMinlengthModule_basic", + "BincountModule_basic", + "BincountStaticSizeModule_basic", + "BoolFloatConstantModule_basic", + "BoolFloatFalseModule_basic", + "BoolFloatTrueModule_basic", + "BoolIntConstantModule_basic", + "BoolIntFalseModule_basic", + "BoolIntTrueModule_basic", + 
"BroadcastDynamicDimModule_basic", + "CeilFloatModule_basic", + "ConstantBoolParameterModule_basic", + "ConstantPad2dStaticModule_basic", + "ConstantPadNdModule_basic", + "ConstantPadNdPartialStaticModule_basic", + "ConstantPadNdStaticModule_basic", + "ContainsIntList_False", + "ContainsIntList_True", + "Conv2dQInt8Module_basic", + "ConvTbcModule_basic", + "ConvolutionBackwardModule2DPadded_basic", + "ConvolutionBackwardModule2DStrided_basic", + "ConvolutionBackwardModule2D_basic", + "CumsumModule_basic", + "DiagonalModule_basic", + "DiagonalModule_nonsquare", + "DiagonalModule_transposed", + "DiagonalModule_with_dims", + "DiagonalModule_with_dims_and_offset", + "DiagonalModule_with_negative_dims", + "DiagonalModule_with_offset", + "DivFloatModule_basic", + "DivIntModule_basic", + "ElementwiseAcosIntModule_basic", + "ElementwiseAcosModule_basic", + "ElementwiseAcoshIntModule_basic", + "ElementwiseAcoshModule_basic", + "ElementwiseAddScalar_NumToTensorFloat_Module_basic", + "ElementwiseAsinIntModule_basic", + "ElementwiseAsinModule_basic", + "ElementwiseAsinhIntModule_basic", + "ElementwiseAsinhModule_basic", + "ElementwiseAtan2FloatIntModule_basic", + "ElementwiseAtan2TensorFloatModule_basic", + "ElementwiseAtan2TensorIntModule_basic", + "ElementwiseAtanTensorFloatModule_basic", + "ElementwiseAtanTensorIntModule_basic", + "ElementwiseAtanhIntModule_basic", + "ElementwiseAtanhModule_basic", + "ElementwiseBitwiseLeftShiftInt32Module_basic", + "ElementwiseBitwiseLeftShiftInt64Module_basic", + "ElementwiseBitwiseLeftShiftInt8Module_basic", + "ElementwiseBitwiseRightShiftInt32Module_basic", + "ElementwiseBitwiseRightShiftInt64Module_basic", + "ElementwiseBitwiseRightShiftInt8Module_basic", + "ElementwiseClampMinTensorFloatModule_basic", + "ElementwiseClampMinTensorIntModule_basic", + "ElementwiseClampTensorFloatModule_basic", + "ElementwiseClampTensorIntModule_basic", + "ElementwiseCosIntModule_basic", + "ElementwiseCoshIntModule_basic", + "ElementwiseCoshModule_basic", + "ElementwiseDequantizePerChannelModule_basic", + "ElementwiseDequantizePerTensorModule_basic", + "ElementwiseErfIntModule_basic", + "ElementwiseExpIntModule_basic", + "ElementwiseExpm1IntModule_basic", + "ElementwiseExpm1Module_basic", + "ElementwiseFmodTensor_Float_basic", + "ElementwiseFmodTensor_Int_Float_basic", + "ElementwiseFmodTensor_Int_basic", + "ElementwiseLog10IntModule_basic", + "ElementwiseLog10Module_basic", + "ElementwiseLog2IntModule_basic", + "ElementwiseLog2Module_basic", + "ElementwiseLogIntModule_basic", + "ElementwiseLogitModule_basic", + "ElementwiseMulTensorComplexModule_basic", + "ElementwisePowScalarModule_basic", + "ElementwiseQuantizePerTensorModule_basic", + "ElementwiseQuantizePerTensorUIntModule_basic", + "ElementwiseReciprocalIntModule_basic", + "ElementwiseRemainderTensorModule_Float_basic", + "ElementwiseRemainderTensorModule_Int_Float_basic", + "ElementwiseRemainderTensorModule_Int_basic", + "ElementwiseRsqrtIntModule_basic", + "ElementwiseSigmoidIntModule_basic", + "ElementwiseSinIntModule_basic", + "ElementwiseSqrtIntModule_basic", + "ElementwiseTanIntModule_basic", + "ElementwiseTanModule_basic", + "ElementwiseTernaryModule_basic", + "ElementwiseToDtypeI64ToUI8Module_basic", + "ElementwiseUnaryIntModule_basic", + "EmptyModule_uint8", + "EqIntModule_basic", + "ExponentialModule_basic", + "FakeQuantizePerTensorAffineDynamicShapeModule_basic", + "FakeQuantizePerTensorAffineModule_basic", + "FakeQuantizePerTensorAffineRoundToEvenModule_basic", + "Fill_TensorFloat32WithFloat32_basic", + 
"Fill_TensorFloat32WithFloat64_basic", + "Fill_TensorFloat32WithInt64_basic", + "FloatImplicitModule_basic", + "GeFloatIntModule_basic", + "GeFloatModule_basic", + "GeIntModule_basic", + "GtFloatIntModule_basic", + "GtIntModule_basic", + "HBC_basic", + "HardtanhBackward_basic", + "IndexPut1DFloatAccumulateModule_basic", + "IndexPut1DFloatNonAccumulateModule_basic", + "IndexPut1DIntAccumulateModule_basic", + "IndexPut1DIntNonAccumulateModule_basic", + "IndexPut2DFloatAccumulateModule_basic", + "IndexPut2DFloatNonAccumulateModule_basic", + "IndexPut2DIntAccumulateModule_basic", + "IndexPut2DIntNonAccumulateModule_basic", + "IndexPut3DFloatAccumulateModule_basic", + "IndexPut3DFloatNonAccumulateModule_basic", + "IndexPut3DIntAccumulateModule_basic", + "IndexPut3DIntNonAccumulateModule_basic", + "IndexPutHackedTwin1DFloatAccumulateModule_basic", + "IndexPutHackedTwin1DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin1DIntAccumulateModule_basic", + "IndexPutHackedTwin1DIntNonAccumulateModule_basic", + "IndexPutHackedTwin2DFloatAccumulateModule_basic", + "IndexPutHackedTwin2DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin2DIntAccumulateModule_basic", + "IndexPutHackedTwin2DIntNonAccumulateModule_basic", + "IndexPutHackedTwin3DFloatAccumulateModule_basic", + "IndexPutHackedTwin3DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin3DIntAccumulateModule_basic", + "IndexPutHackedTwin3DIntNonAccumulateModule_basic", + "IndexPutImpl1DFloatAccumulateModule_basic", + "IndexPutImpl1DFloatNonAccumulateModule_basic", + "IndexPutImpl1DIntAccumulateModule_basic", + "IndexPutImpl1DIntNonAccumulateModule_basic", + "IndexPutImpl2DFloatAccumulateModule_basic", + "IndexPutImpl2DFloatNonAccumulateModule_basic", + "IndexPutImpl2DImplicitModule_basic", + "IndexPutImpl2DIndexModule_basic", + "IndexPutImpl2DNoneIndexStaticModule_basic", + "IndexPutImpl3DFloatAccumulateModule_basic", + "IndexPutImpl3DFloatNonAccumulateModule_basic", + "IndexPutImplIndexWithNoneModule_basic", + "IndexTensorNegativeIndexModule_basic", + "IntFloatModule_basic", + "IntImplicitModule_basic", + "IsFloatingPointFloat_True", + "IsFloatingPointInt_False", + "LenStrModule_basic", + "MaxPool2dCeilModeTrueModule_basic", + "MaxPool2dEmptyStrideStaticModule_basic", + "MaxPool2dStaticCeilModeTrueModule_basic", + "MaxPool2dWithIndicesBackwardDynamic3DModule_basic", + "MaxPool2dWithIndicesBackwardDynamic4DModule_basic", + "MaxPool2dWithIndicesBackwardStatic3DModule_basic", + "MaxPool2dWithIndicesBackwardStatic4DModule_basic", + "MaxPool3dCeilModeTrueModule_basic", + "MaxPool3dEmptyStrideStaticModule_basic", + "MaxPool3dLargeDatadModule_basic", + "MaxPool3dModuleRandomSimple_basic", + "MaxPool3dModule_basic", + "MaxPool3dStaticCeilModeTrueModule_basic", + "MaxPool3dStaticModule_basic", + "MeanDimNoneDimModule_basic", + "MseLossMeanReductionModule_basic", + "MseLossSumReductionWithDifferentElemTypeModule_basic", + "MulFloatModule_basic", + "NativeGroupNormBackwardModule_basic", + "NeFloatIntModule_basic", + "NeIntModule_basic", + "NllLossModuleBackward1DMeanWeight_basic", + "NllLossModuleBackward1DMean_basic", + "NllLossModuleBackward1DSumWeight_basic", + "NllLossModuleBackward1DSum_basic", + "NllLossModuleBackward1DWeight_basic", + "NllLossModuleBackward1D_basic", + "NllLossModuleBackwardMeanWeight_basic", + "NllLossModuleBackwardMean_basic", + "NllLossModuleBackwardSumWeight_basic", + "NllLossModuleBackwardSum_basic", + "NllLossModuleBackwardWeight_basic", + "NllLossModuleBackward_basic", + "NllLossModuleBackward_ignore_index", + 
"NormScalarComplexModule_basic", + "NormScalarModule_basic", + "NormalFunctionalModule_basic", + "NumToTensorFloatModule_basic", + "NumToTensorIntModule_basic", + "NumelModule_basic", + "NumelZeroRankModule_basic", + "PadModule_basic", + "PadWithNoneValModule_basic", + "PixelShuffleModuleFullDynamic_basic", + "PixelShuffleModuleSpatiallyDynamic_basic", + "PixelShuffleModuleSpatiallyStatic_basic", + "PixelShuffleModuleStaticRank3Int64_basic", + "PixelShuffleModuleStaticRank4Float32_basic", + "PowIntFloatModule_basic", + "PrimMaxIntModule_basic", + "PrimMinIntDynamicModule_basic", + "PrimMinIntModule_basic", + "PrimsSqueezeEmptyDimensionsModule_basic", + "PrimsSqueezeModule_basic", + "PrimsViewOfModule_basic", + "PrimsViewOfZeroRankModule_basic", + "QuantizedBatchedInputSingleLayer_basic", + "QuantizedMLP_basic", + "QuantizedNoLayer_basic", + "QuantizedSingleLayer_basic", + "RandnDtypeDeviceModule_basic", + "RandnGeneratorF64Module_basic", + "RandnGeneratorModule_basic", + "RandnLikeDtypeModule_basic", + "RandnLikeModule_basic", + "RandnModule_basic", + "ReduceAllDimBool_basic", + "ReduceAllDimEmpty_basic", + "ReduceAllDimFloat_basic", + "ReduceAllDimInt_basic", + "ReduceMaxAlongDimUnsignedInt_basic", + "ReduceMinAlongDimNegative_basic", + "ReduceMinAlongDimSignedInt_basic", + "ReduceMinAlongDimUnsignedInt_basic", + "ReduceMinAlongDim_basic", + "ReduceMinKeepDimReturnBoth_basic", + "ReduceMinKeepDim_basic", + "ReduceProdDimIntFloatModule_basic", + "ReflectionPad1dModule2dInput_Right", + "ReflectionPad1dModule2dInput_basic", + "ReflectionPad1dModule3dInput_Left", + "ReflectionPad1dModule3dInput_basic", + "ReflectionPad2dModule_Bottom", + "ReflectionPad2dModule_Left", + "ReflectionPad2dModule_Right", + "ReflectionPad2dModule_Top", + "ReflectionPad2dModule_basic", + "ReplicationPad2dModule_basic", + "ReplicationPad2dModule_bottom0", + "ReplicationPad2dModule_left0", + "ReplicationPad2dModule_right0", + "ReplicationPad2dModule_top0", + "RsubInt0d_NumToTensor_Module_basic", + "ScalarConstantTupleModule_basic", + "ScalarImplicitFloatModule_basic", + "ScaledDotProductAttentionDifferentModule_basic", + "ScatterReduceFloatMaxModule", + "ScatterReduceFloatMaxModuleIncludeSelf", + "ScatterReduceFloatMeanModule", + "ScatterReduceFloatMeanModuleIncludeSelf", + "ScatterReduceFloatMinModule", + "ScatterReduceFloatMinModuleIncludeSelf", + "ScatterReduceFloatProdModule", + "ScatterReduceFloatProdModuleIncludeSelf", + "ScatterReduceFloatSumModule", + "ScatterReduceFloatSumModuleIncludeSelf", + "ScatterReduceIntMaxModule", + "ScatterReduceIntMaxModuleIncludeSelf", + "ScatterReduceIntMeanModule", + "ScatterReduceIntMeanModuleIncludeSelf", + "ScatterReduceIntMinModule", + "ScatterReduceIntMinModuleIncludeSelf", + "ScatterReduceIntProdModule", + "ScatterReduceIntProdModuleIncludeSelf", + "ScatterReduceIntSumModule", + "ScatterReduceIntSumModuleIncludeSelf", + "ScatterSrcModule_basic", + "ScatterSrcStaticModule_basic", + "ScatterValueFloatModule_basic", + "ScatterValueIntModule_basic", + "SliceOutOfLowerBoundEndIndexModule_basic", + "SortIntListReverse_basic", + "SortIntList_basic", + "SortTensorDescending_basic", + "SortTensorInteger_basic", + "SortTensorNegativeDimension_basic", + "SortTensorSpecificDimension_basic", + "SortTensor_basic", + "SplitDimDynamicModule_basic", + "SplitDimStaticModule_basic", + "SqrtIntConstantModule_basic", + "SqrtIntModule_basic", + "SubFloatModule_basic", + "TModuleRank0_basic", + "TensorToBoolZeroRank_basic", + "TensorToBool_basic", + "TensorToFloatZeroRank_basic", + 
"TensorToFloat_basic", + "TensorToInt_basic", + "TestMultipleTensorAndPrimitiveTypesReturn_basic", + "Threshold1dFloatModule_basic", + "Threshold1dIntI32Module_basic", + "Threshold1dIntModule_basic", + "Threshold2dFloatModule_basic", + "Threshold2dIntModule_basic", + "Threshold3dFloatModule_basic", + "Threshold3dIntModule_basic", + "ThresholdBackward1dFloatModule_basic", + "ThresholdBackward1dIntModule_basic", + "ThresholdBackward1dMixedModule_basic", + "ThresholdBackward2dFloatModule_basic", + "ThresholdBackward2dIntModule_basic", + "ThresholdBackward2dMixedModule_basic", + "ThresholdBackward3dFloatModule_basic", + "ThresholdBackward3dIntModule_basic", + "ThresholdBackward3dMixedModule_basic", + "TorchPrimLoopForLikeModule_basic", + "TorchPrimLoopWhileLikeModule_basic", + "TraceModule_basic", + "TraceModule_empty", + "TraceModule_nonsquare", + "TraceSignedIntModule_basic", + "TraceUnsignedIntModule_basic", + "TraceUnsignedIntModule_empty", + "UnbindIntGetItem_Module_basic", + "UnbindIntListUnpack_Module_basic", + "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic", + "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", + "UpSampleNearest2dBackwardScalesNone_basic", + "UpSampleNearest2dBackward_basic", + "VarMeanBiasedModule_basic", + "VarMeanCorrectionNoneModule_basic", + "VarMeanUnbiasedModule_basic", + "ViewCollapseDynamicWithAtenSizeIntModule_basic", + "ViewSizeFromOtherTensor_basic", +} + +FX_IMPORTER_STABLEHLO_CRASHING_SET = { + "BatchNorm1DModule_basic", + "BatchNorm2DModule_basic", + "BatchNorm3DModule_basic", + "ResNet18Module_basic", + "ResNet18StaticModule_basic", + "MobilenetV3Module_basic", + "Conv2dBiasNoPaddingModule_basic", +} + STABLEHLO_PASS_SET = { "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py index 4d8ca8dc4968..0d75fe2ad3f0 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py @@ -82,9 +82,10 @@ def jit( class FxImporterTestConfig(TestConfig): """TestConfig that runs the torch.nn.Module with Fx Importer""" - def __init__(self, backend): + def __init__(self, backend, output_type="linalg-on-tensors"): super().__init__() - self.backend = backend + self._backend = backend + self._output_type = output_type def compile(self, program: torch.nn.Module) -> torch.nn.Module: return program @@ -95,9 +96,9 @@ def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace: prog = torch.export.export(artifact, tuple(item.inputs)) module = jit(prog, func_name=artifact.__class__.__name__, - output_type="linalg-on-tensors") - module = self.backend.compile(module) - backend_module = self.backend.load(module) + output_type=self._output_type) + module = self._backend.compile(module) + backend_module = self._backend.load(module) params = { # **dict(artifact.named_parameters(remove_duplicate=False)), **dict(artifact.named_buffers(remove_duplicate=False)), From 5a98c72c7f33a4bcc88071aa146c2e012d70fec9 Mon Sep 17 00:00:00 2001 From: penguin_wwy <940375606@qq.com> Date: Fri, 19 Apr 2024 17:08:29 +0800 Subject: [PATCH 06/34] [StableHLO] Fix aten.clamp.Tensor in FxImporter2StableHLO (#3190) The FX importer will pass static shapes to the Torch dialect, so it needs to generate a StableHLO that satisfies shape inference. 
---
 lib/Conversion/TorchToStablehlo/Basic.cpp | 4 ++++
 projects/pt1/e2e_testing/xfail_sets.py    | 6 +-----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp
index 0b9e1291efa6..9bef7de7aafa 100644
--- a/lib/Conversion/TorchToStablehlo/Basic.cpp
+++ b/lib/Conversion/TorchToStablehlo/Basic.cpp
@@ -1500,6 +1500,10 @@ LogicalResult ConvertAtenOp<AtenClampTensorOp>::matchAndRewrite(
     }
     maxValue = *maxInfo;
   }
+  if (inputType.hasStaticShape()) {
+    minValue = hlo::promoteAndBroadcast(rewriter, minValue, inputType);
+    maxValue = hlo::promoteAndBroadcast(rewriter, maxValue, inputType);
+  }
   rewriter.replaceOpWithNewOp<stablehlo::ClampOp>(op, minValue, input,
                                                   maxValue);
   return success();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index ed1f5eb8dd7a..f49be4d5a943 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -642,17 +642,12 @@
     "ElementwiseBitwiseRightShiftInt32Module_basic",
     "ElementwiseBitwiseRightShiftInt64Module_basic",
     "ElementwiseBitwiseRightShiftInt8Module_basic",
-    "ElementwiseClampMinTensorFloatModule_basic",
-    "ElementwiseClampMinTensorIntModule_basic",
-    "ElementwiseClampTensorFloatModule_basic",
-    "ElementwiseClampTensorIntModule_basic",
    "ElementwiseCosIntModule_basic",
     "ElementwiseCoshIntModule_basic",
     "ElementwiseCoshModule_basic",
     "ElementwiseDequantizePerChannelModule_basic",
     "ElementwiseDequantizePerTensorModule_basic",
     "ElementwiseErfIntModule_basic",
-    "ElementwiseExpIntModule_basic",
     "ElementwiseExpm1IntModule_basic",
     "ElementwiseExpm1Module_basic",
     "ElementwiseFmodTensor_Float_basic",
@@ -734,6 +729,7 @@
     "IndexPutImpl3DFloatAccumulateModule_basic",
     "IndexPutImpl3DFloatNonAccumulateModule_basic",
     "IndexPutImplIndexWithNoneModule_basic",
+    "IndexSelectRank0IdxModule_basic",
     "IndexTensorNegativeIndexModule_basic",
     "IntFloatModule_basic",
     "IntImplicitModule_basic",

From 790a697245881e6306aeda7cca2218dc65d64447 Mon Sep 17 00:00:00 2001
From: Xinyu Yang
Date: Fri, 19 Apr 2024 22:17:06 +0800
Subject: [PATCH 07/34] [Torch] Add folder for AtenIntOp, AtenFloatOp (#3189)

See unit test below:
```
// CHECK-LABEL:   func.func @torch.aten.tensor.float(
// CHECK-NEXT:      torch.vtensor.literal(dense<1.000000e+01> : tensor<f32>) : !torch.vtensor<[],f32>
func.func @torch.aten.tensor.float() -> !torch.vtensor<[],f32> {
  %none = torch.constant.none
  %false = torch.constant.bool false
  %float1.000000e01 = torch.constant.float 1.000000e+01
  %67 = torch.aten.tensor.float %float1.000000e01, %none, %none, %false : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],f32>
  return %67 : !torch.vtensor<[],f32>
}

// CHECK-LABEL:   func.func @torch.aten.tensor.int(
// CHECK-NEXT:      torch.vtensor.literal(dense<45> : tensor<si32>) : !torch.vtensor<[],si32>
func.func @torch.aten.tensor.int() -> !torch.vtensor<[],si32> {
  %none = torch.constant.none
  %false = torch.constant.bool false
  %int45 = torch.constant.int 45
  %67 = torch.aten.tensor.int %int45, %none, %none, %false : !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],si32>
  return %67 : !torch.vtensor<[],si32>
}
```
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td |  2 +
 lib/Dialect/Torch/IR/TorchOps.cpp         | 44 ++++++++++++++++++-
 projects/pt1/e2e_testing/xfail_sets.py    |  3 ++
 .../build_tools/torch_ods_gen.py          |  4 +-
 test/Dialect/Torch/canonicalize.mlir      | 20 +++++++++
 5 files changed, 70 insertions(+), 3 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 8d00e0f96899..432a3ad6ebc4 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -9092,6 +9092,7 @@ def Torch_AtenTensorIntOp : Torch_Op<"aten.tensor.int", [ printDefaultTorchOp(printer, *this, 4, 1); } }]; + let hasFolder = 1; } def Torch_AtenScalarTensorOp : Torch_Op<"aten.scalar_tensor", [ @@ -11577,6 +11578,7 @@ def Torch_AtenTensorFloatOp : Torch_Op<"aten.tensor.float", [ printDefaultTorchOp(printer, *this, 4, 1); } }]; + let hasFolder = 1; } def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [ diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 878e6e7e44a5..e768033ac87f 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -3747,6 +3747,8 @@ OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) { // If a torch.aten.tensor op is initialized by a list with a constant, single // element, fold it into a torch.vtensor.literal auto resultTy = dyn_cast(getType()); + if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype()) + return nullptr; Type eTy = resultTy.getDtype(); ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy); @@ -3761,7 +3763,47 @@ OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) { } //===----------------------------------------------------------------------===// -// AtenTensorOp +// AtenTensorIntOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenTensorIntOp::fold(FoldAdaptor adaptor) { + auto resultTy = dyn_cast(getType()); + if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype()) + return nullptr; + Type eTy = resultTy.getDtype(); + ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy); + + int64_t data; + if (matchPattern(getT(), m_TorchConstantInt(&data))) { + Attribute attribute = IntegerAttr::get(eTy, data); + return DenseElementsAttr::get(shapedTy, attribute); + } + + return nullptr; +} + +//===----------------------------------------------------------------------===// +// AtenTensorFloatOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenTensorFloatOp::fold(FoldAdaptor adaptor) { + auto resultTy = dyn_cast(getType()); + if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype()) + return nullptr; + Type eTy = resultTy.getDtype(); + ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy); + + double data; + if (matchPattern(getT(), m_TorchConstantFloat(&data))) { + Attribute attribute = FloatAttr::get(eTy, data); + return DenseElementsAttr::get(shapedTy, attribute); + } + + return nullptr; +} + +//===----------------------------------------------------------------------===// +// Aten_ShapeAsTensorOp //===----------------------------------------------------------------------===// OpFoldResult Aten_ShapeAsTensorOp::fold(FoldAdaptor adaptor) { diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index f49be4d5a943..93ceb1a7998f 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1307,6 +1307,7 @@ "TModuleRank0_basic", "TModuleRank1_basic", "TModuleRank2_basic", + "TensorFloatModule_basic", "TensorIntModule_basic", "TensorLiteralModule_basic", "TensorOpaqueLiteralModule_basic", @@ -1838,6 +1839,8 @@ "TModuleRank1_basic", "TModuleRank2_basic", "TanhBackward_basic", + "TensorFloatModule_basic", + "TensorIntModule_basic", 
"TensorLiteralModule_basic", "TensorOpaqueLiteralModule_basic", "TensorsConcatNegativeDimStaticModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 78861202caba..e6258ece8fb5 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -597,7 +597,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::eye.m : (int, int, int?, int?, Device?, bool?) -> (Tensor)") emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)", has_folder=True) emit("aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)") - emit("aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)") + emit("aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)", has_folder=True) emit("aten::scalar_tensor : (Scalar, int?, int?, Device?, bool?) -> (Tensor)") emit("aten::_shape_as_tensor : (Tensor) -> (Tensor)", has_folder=True) emit("aten::isnan : (Tensor) -> (Tensor)") @@ -691,7 +691,7 @@ def emit_with_mutating_variants(key, **kwargs): emit_with_mutating_variants("aten::scatter_reduce.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)") emit("aten::IntImplicit : (Tensor) -> (int)", has_canonicalizer=True) emit("aten::FloatImplicit : (Tensor) -> (float)", has_canonicalizer=True) - emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)") + emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)", has_folder=True) emit("aten::Int.Tensor : (Tensor) -> (int)", has_canonicalizer=True) emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True) emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)") diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 9558e897ae0f..2f7d5a11a216 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -1481,6 +1481,26 @@ func.func @torch.aten.tensor$one_elem() -> (!torch.vtensor<[1],si64>) { return %67 : !torch.vtensor<[1],si64> } +// CHECK-LABEL: func.func @torch.aten.tensor.float( +// CHECK-NEXT: torch.vtensor.literal(dense<1.000000e+01> : tensor) : !torch.vtensor<[],f32> +func.func @torch.aten.tensor.float() -> !torch.vtensor<[],f32> { + %none = torch.constant.none + %false = torch.constant.bool false + %float1.000000e01 = torch.constant.float 1.000000e+01 + %67 = torch.aten.tensor.float %float1.000000e01, %none, %none, %false : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],f32> + return %67 : !torch.vtensor<[],f32> +} + +// CHECK-LABEL: func.func @torch.aten.tensor.int( +// CHECK-NEXT: torch.vtensor.literal(dense<45> : tensor) : !torch.vtensor<[],si32> +func.func @torch.aten.tensor.int() -> !torch.vtensor<[],si32> { + %none = torch.constant.none + %false = torch.constant.bool false + %int45 = torch.constant.int 45 + %67 = torch.aten.tensor.int %int45, %none, %none, %false : !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],si32> + return %67 : !torch.vtensor<[],si32> +} + // CHECK-LABEL: func.func @torch.aten.to.dtype$same_dtype( // CHECK-SAME: %[[ARG:.*]]: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> { // CHECK-NEXT: return %[[ARG]] : !torch.tensor<*,f32> From b01245c0e821323b06802dfac110f36a7a9fa960 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 19 Apr 2024 11:32:24 -0700 Subject: [PATCH 08/34] [onnx] Fix `onnx.Not` for non-bool 
inputs (#3187) Need to perform a bool cast to support `onnx.Not` on non-bool inputs. --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 24 +++++++++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 3 --- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 92ef81390979..c7d071079119 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -636,6 +636,30 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.tensorResultType(resultType)) { return failure(); } + + auto loc = binder.getLoc(); + auto operandTy = + cast(operand.getType()); + auto eTy = operandTy.getDtype(); + + if (!eTy.isInteger(1)) { + auto i1ty = rewriter.getI1Type(); + auto ty = rewriter.getType( + operandTy.getSizes(), i1ty); + auto torchqTy = Torch::getScalarTypeForType(i1ty); + Value tyConst = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr( + rewriter.getIntegerType(64), + static_cast(torchqTy))); + Value none = rewriter.create(loc); + Value cstFalse = + rewriter.create(loc, false); + operand = rewriter.create( + loc, ty, operand, tyConst, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/none); + } rewriter.replaceOpWithNewOp( binder.op, resultType, operand); return success(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 93ceb1a7998f..037771b7494b 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2723,9 +2723,6 @@ "MaskedFillTensorFloatValueModule_basic", "NativeDropoutTrainModule_basic", "NativeDropoutTrainStaticShapeModule_basic", - "ReduceAllDimEmpty_basic", - "ReduceAllDimFloat_basic", - "ReduceAllDimInt_basic", "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", "TensorsStackNegativeDimModule_basic", From ea0ecb67be3187d82d0860da117d73284c16efb1 Mon Sep 17 00:00:00 2001 From: penguin_wwy <940375606@qq.com> Date: Sun, 21 Apr 2024 00:03:37 +0800 Subject: [PATCH 09/34] [stablehlo] add aten.remainder.Tensor op conversion support (#3197) --- lib/Conversion/TorchToStablehlo/Basic.cpp | 17 +++++++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 6 +++--- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 9bef7de7aafa..fa1b7cc530c0 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -1817,6 +1817,22 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// AtenRemainderTensorOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenRemainderTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value lhs = adaptor.getSelf(); + Value rhs = adaptor.getOther(); + + auto resultType = + cast(getTypeConverter()->convertType(op.getType())); + lhs = hlo::promoteType(rewriter, op->getLoc(), lhs, resultType); + rhs = hlo::promoteType(rewriter, op->getLoc(), rhs, resultType); + rewriter.replaceOpWithNewOp(op, lhs, rhs); + return success(); +} + void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, const TorchToStablehloOptions &options) { @@ -1959,6 +1975,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( 
INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp); INSERT_ATENOP_PATTERN(AtenFillScalarOp); INSERT_ATENOP_PATTERN(AtenFlipOp); + INSERT_ATENOP_PATTERN(AtenRemainderTensorOp); #undef INSERT_ATENOP_PATTERN #define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \ diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 037771b7494b..ce33b10caa7c 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -664,9 +664,6 @@ "ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic", "ElementwiseReciprocalIntModule_basic", - "ElementwiseRemainderTensorModule_Float_basic", - "ElementwiseRemainderTensorModule_Int_Float_basic", - "ElementwiseRemainderTensorModule_Int_basic", "ElementwiseRsqrtIntModule_basic", "ElementwiseSigmoidIntModule_basic", "ElementwiseSinIntModule_basic", @@ -1074,6 +1071,9 @@ "ElementwisePreluStaticModule_basic", "ElementwiseReciprocalModule_basic", "ElementwiseReluModule_basic", + "ElementwiseRemainderTensorModule_Float_basic", + "ElementwiseRemainderTensorModule_Int_Float_basic", + "ElementwiseRemainderTensorModule_Int_basic", "ElementwiseRsqrtModule_basic", "ElementwiseSigmoidModule_basic", "ElementwiseSinModule_basic", From b6b01602d3ef358621be7e597c9b2b7b21cb8040 Mon Sep 17 00:00:00 2001 From: penguin_wwy <940375606@qq.com> Date: Sun, 21 Apr 2024 08:39:36 +0800 Subject: [PATCH 10/34] [stablehlo] add aten.fmod.Tensor op conversion support (#3198) --- lib/Conversion/TorchToStablehlo/Basic.cpp | 32 +++++++++++++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 6 ++--- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index fa1b7cc530c0..0358e79868d2 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -1833,6 +1833,37 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// AtenFmodTensorOp +// torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenFmodTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op->getLoc(); + Value lhs = adaptor.getSelf(); + Value rhs = adaptor.getOther(); + + auto resultType = + cast(getTypeConverter()->convertType(op.getType())); + lhs = hlo::promoteType(rewriter, op->getLoc(), lhs, resultType); + rhs = hlo::promoteType(rewriter, op->getLoc(), rhs, resultType); + + stablehlo::MulOp mul; + auto div = rewriter.create(loc, lhs, rhs); + if (isa(resultType.getElementType())) { + // rounding mode is trunc + auto sign = rewriter.create(loc, div); + auto abs = rewriter.create(loc, div); + auto floor = rewriter.create(loc, abs); + auto trunc = rewriter.create(loc, sign, floor); + mul = rewriter.create(loc, trunc, rhs); + } else { + mul = rewriter.create(loc, div, rhs); + } + rewriter.replaceOpWithNewOp(op, lhs, mul); + return success(); +} + void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, const TorchToStablehloOptions &options) { @@ -1976,6 +2007,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenFillScalarOp); INSERT_ATENOP_PATTERN(AtenFlipOp); INSERT_ATENOP_PATTERN(AtenRemainderTensorOp); + INSERT_ATENOP_PATTERN(AtenFmodTensorOp); #undef INSERT_ATENOP_PATTERN #define 
INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index ce33b10caa7c..35327f367fe2 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -650,9 +650,6 @@
 "ElementwiseErfIntModule_basic",
 "ElementwiseExpm1IntModule_basic",
 "ElementwiseExpm1Module_basic",
- "ElementwiseFmodTensor_Float_basic",
- "ElementwiseFmodTensor_Int_Float_basic",
- "ElementwiseFmodTensor_Int_basic",
 "ElementwiseLog10IntModule_basic",
 "ElementwiseLog10Module_basic",
 "ElementwiseLog2IntModule_basic",
@@ -1056,6 +1053,9 @@
 "ElementwiseExpModule_basic",
 "ElementwiseFloorIntModule_basic",
 "ElementwiseFloorModule_basic",
+ "ElementwiseFmodTensor_Float_basic",
+ "ElementwiseFmodTensor_Int_Float_basic",
+ "ElementwiseFmodTensor_Int_basic",
 "ElementwiseGeluApproximateTanhModule_basic",
 "ElementwiseGeluModule_basic",
 "ElementwiseLeakyReluStaticModule_basic",

From 733cace1dfd69adae99e7895c0c28955f21fb3b0 Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Sun, 21 Apr 2024 09:31:56 -0700
Subject: [PATCH 11/34] [onnx] Fix `onnx.split` by directly handling slicing (#3194)

Previous implementation erroneously mixed up num_outputs with
slice_size. New version correctly computes the slice size and directly
performs slicing rather than leveraging `aten.split.tensor`. This is
due to `onnx` supporting a fixed number of splits, making the size
computation more easily computable when lowering to `aten` rather than
deferring to `aten.split.tensor`.

---------

Co-authored-by: Robert Suderman
---
 .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 74 +++++++++++--------
 projects/pt1/e2e_testing/xfail_sets.py | 1 -
 .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 20 +++--
 3 files changed, 59 insertions(+), 36 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index b4bd102f152f..8f6788620018 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -1379,7 +1379,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
 "Split", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
 Value self;
 int64_t axis;
- int64_t num_outputs;
+ int64_t numOutputs;
 if (binder.tensorOperand(self))
 return rewriter.notifyMatchFailure(
 binder.op, "Not converting to AtenSplitTensorOp due to input "
@@ -1387,49 +1387,65 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
 if (binder.s64IntegerAttr(axis, "axis", 0))
 return rewriter.notifyMatchFailure(binder.op,
 "Failed to get axis attribute");
- if (binder.s64IntegerAttr(num_outputs, "num_outputs", 0))
+ if (binder.s64IntegerAttr(numOutputs, "num_outputs", 2))
 return rewriter.notifyMatchFailure(
 binder.op, "Failed to get num_outputs attribute");
+ auto loc = binder.getLoc();
 auto result0Ty = binder.op->getResult(0).getType().cast();
+ auto resultNTy = binder.op->getResults()
+ .back()
+ .getType()
+ .cast();
 auto selfTy = self.getType().cast();
 int64_t dim = axis;
 if (dim < 0)
 dim += selfTy.getSizes().size();
- // set intermediate shape to the shape of the first result
- // if the results are of different shapes
- // set the splitted axis to variable shape
- llvm::SmallVector intermediateShape(result0Ty.getSizes());
- for (auto result : binder.op->getResultTypes()) {
- int64_t d = cast(result).getSizes()[dim];
- intermediateShape[dim] = d == intermediateShape[dim] ?
d : -1; - } - Value dimValue = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), dim)); + loc, rewriter.getType(), + rewriter.getI64IntegerAttr(dim)); - Value splitSize = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), num_outputs)); + Value vNumOutputs = rewriter.create( + loc, rewriter.getType(), + rewriter.getI64IntegerAttr(numOutputs)); - // TODO: Attempting to use the shape expected by the ONNX mlir as ground - // truth. For now just use dynamic shapes. - auto resultOuterType = - Torch::ListType::get(rewriter.getType( - /*std::optional>=*/intermediateShape, - result0Ty.getOptionalDtype())); - Torch::AtenSplitTensorOp new_op = - rewriter.create( - binder.getLoc(), resultOuterType, self, splitSize, dimValue); + Value one = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + Value zero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + + Value vDimSize = rewriter.create( + loc, rewriter.getType(), self, dimValue); + + Value addNumOutputs = + rewriter.create(loc, vDimSize, vNumOutputs); + Value subOne = + rewriter.create(loc, addNumOutputs, one); + Value splitSize = + rewriter.create(loc, subOne, vNumOutputs); + + llvm::SmallVector outputs; + Value step = one; + Value start = zero; + + for (int i = 0; i < numOutputs - 1; ++i) { + Value end = + rewriter.create(loc, start, splitSize); + Value slice = rewriter.create( + loc, result0Ty, self, dimValue, start, end, step); + start = end; + outputs.push_back(slice); + } - // the onnx op is variadic with multiple results, but AtenSplitWithSizes - // outputs a list so we need to unpack the list - rewriter.replaceOpWithNewOp( - binder.op, binder.op->getResults().getType(), new_op.getResult()); + Value end = vDimSize; + Value lastSlice = rewriter.create( + loc, resultNTy, self, dimValue, start, end, step); + outputs.push_back(lastSlice); + + rewriter.replaceOp(binder.op, outputs); return success(); }); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 35327f367fe2..ec4d3a8daa20 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2718,7 +2718,6 @@ "ElementwiseToDtypeI64ToUI8Module_basic", "ElementwiseUnaryIntModule_basic", "ElementwiseUnsqueezeNegDimsModule_basic", - "GluStaticModule_basic", "GroupNormModule_basic", "MaskedFillTensorFloatValueModule_basic", "NativeDropoutTrainModule_basic", diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 0fdecd68481e..47497d5ea5ba 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -1288,12 +1288,20 @@ func.func @test_split_variable_parts_2d_opset18(%arg0: !torch.vtensor<[2,6],f32> // CHECK-LABEL: func.func @test_split_2d_uneven_split_opset18( // CHECK-SAME: %[[INPUT_TENSOR:.*]]: !torch.vtensor<[2,8],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>) attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { -// CHECK: %[[AXIS:.*]] = torch.constant.int 1 -// CHECK: %[[SPLIT_SIZE:.*]] = torch.constant.int 3 -// CHECK: %[[SPLIT_RESULT:.*]] = torch.aten.split.Tensor %[[INPUT_TENSOR]], %[[SPLIT_SIZE]], %[[AXIS]] : !torch.vtensor<[2,8],f32>, !torch.int, !torch.int 
-> !torch.list>
-// CHECK: %[[UNPACKED_TENSORS:.*]]:3 = torch.prim.ListUnpack %[[SPLIT_RESULT]] : !torch.list> -> !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>
-// CHECK: return %[[UNPACKED_TENSORS]]#0, %[[UNPACKED_TENSORS]]#1, %[[UNPACKED_TENSORS]]#2 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>
-// CHECK: }
+// CHECK-DAG: %[[DIM:.+]] = torch.constant.int 1
+// CHECK-DAG: %[[SPLITS:.+]] = torch.constant.int 3
+// CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1
+// CHECK-DAG: %[[ZERO:.+]] = torch.constant.int 0
+// CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[DIM]]
+// CHECK-DAG: %[[ADD:.+]] = torch.aten.add.int %[[SZ1]], %[[SPLITS]]
+// CHECK-DAG: %[[SUB:.+]] = torch.aten.sub.int %[[ADD]], %[[ONE]]
+// CHECK-DAG: %[[SLICESZ:.+]] = torch.aten.floordiv.int %[[SUB]], %[[SPLITS]]
+// CHECK-DAG: %[[START1:.+]] = torch.aten.add.int %[[ZERO]], %[[SLICESZ]] : !torch.int, !torch.int -> !torch.int
+// CHECK-DAG: %[[SLICE0:.+]] = torch.aten.slice.Tensor %arg0, %[[DIM]], %[[ZERO]], %[[START1]], %[[ONE]] : !torch.vtensor<[2,8],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32>
+// CHECK-DAG: %[[START2:.+]] = torch.aten.add.int %[[START1]], %[[SLICESZ]] : !torch.int, !torch.int -> !torch.int
+// CHECK-DAG: %[[SLICE1:.+]] = torch.aten.slice.Tensor %arg0, %[[DIM]], %[[START1]], %[[START2]], %[[ONE]] : !torch.vtensor<[2,8],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32>
+// CHECK-DAG: %[[SLICE2:.+]] = torch.aten.slice.Tensor %arg0, %[[DIM]], %[[START2]], %[[SZ1]], %[[ONE]] : !torch.vtensor<[2,8],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,2],f32>
+// CHECK: return %[[SLICE0]], %[[SLICE1]], %[[SLICE2]]
 func.func @test_split_2d_uneven_split_opset18(%arg0: !torch.vtensor<[2,8],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>) attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
 %0:3 = torch.operator "onnx.Split"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.num_outputs = 3 : si64} : (!torch.vtensor<[2,8],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>)
 return %0#0, %0#1, %0#2 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>

From 8222637159dc98f003f1e70f183dc11fc81e6b48 Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Sun, 21 Apr 2024 09:32:18 -0700
Subject: [PATCH 12/34] [onnx] Extend op version number of `onnx.ScatterElements` (#3195)

The version number was set too high. Lowering it to support more cases
allows more tests to pass.
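As a rough illustration (the shapes and attribute values below are made
up, not taken from the test suite), an op imported from an opset 11
model now binds to the existing `ScatterElements` lowering instead of
failing to legalize:

```
// Hypothetical opset-11-era op; with the minimum pattern version
// lowered from 18 to 1, this matches the same conversion path.
%0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[3,3],f32>, !torch.vtensor<[2,3],si64>, !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[3,3],f32>
```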
Co-authored-by: Robert Suderman --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 2 +- projects/pt1/e2e_testing/xfail_sets.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 8f6788620018..65bfb6257774 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -478,7 +478,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return success(); }); patterns.onOp( - "ScatterElements", 18, + "ScatterElements", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; SmallVector valList; diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index ec4d3a8daa20..68fdbb961b26 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2647,10 +2647,7 @@ "ScatterReduceIntMinModuleIncludeSelf", "ScatterReduceIntProdModuleIncludeSelf", "ScatterReduceIntSumModuleIncludeSelf", - "ScatterSrcModule_basic", - "ScatterSrcStaticModule_basic", "ScatterValueFloatModule_basic", - "ScatterValueIntModule_basic", # Failure - onnx_lowering: onnx.ScatterND "IndexPut1DFloatAccumulateModule_basic", From a60e84e5ee73676ac4d1eef24c1a8a5a0f4ae493 Mon Sep 17 00:00:00 2001 From: penguin_wwy <940375606@qq.com> Date: Mon, 22 Apr 2024 10:20:49 +0800 Subject: [PATCH 13/34] [stablehlo] add aten.expm1 op conversion support (#3199) --- lib/Conversion/TorchToStablehlo/Basic.cpp | 1 + projects/pt1/e2e_testing/xfail_sets.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 0358e79868d2..eca7c30259de 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -1881,6 +1881,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_UNARY_PATTERN(AtenLogicalNotOp, stablehlo::NotOp); INSERT_UNARY_PATTERN(AtenBitwiseNotOp, stablehlo::NotOp); INSERT_UNARY_PATTERN(AtenAbsOp, stablehlo::AbsOp); + INSERT_UNARY_PATTERN(AtenExpm1Op, stablehlo::Expm1Op); #undef INSERT_UNARY_PATTERN #define INSERT_UNARY_FPONLY_PATTERN(AtenOp, StablehloOp) \ diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 68fdbb961b26..1fd88ce6e238 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -648,8 +648,6 @@ "ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerTensorModule_basic", "ElementwiseErfIntModule_basic", - "ElementwiseExpm1IntModule_basic", - "ElementwiseExpm1Module_basic", "ElementwiseLog10IntModule_basic", "ElementwiseLog10Module_basic", "ElementwiseLog2IntModule_basic", @@ -1051,6 +1049,8 @@ "ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic", "ElementwiseErfModule_basic", "ElementwiseExpModule_basic", + "ElementwiseExpm1IntModule_basic", + "ElementwiseExpm1Module_basic", "ElementwiseFloorIntModule_basic", "ElementwiseFloorModule_basic", "ElementwiseFmodTensor_Float_basic", From e5bdd71baf1728bd76a63813c5debda40296d84e Mon Sep 17 00:00:00 2001 From: penguin_wwy <940375606@qq.com> Date: Mon, 22 Apr 2024 10:45:01 +0800 Subject: [PATCH 14/34] [Torch] Emit and decompose prims.iota op (#3132) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 28 +++++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 7 +++++ 
.../Torch/Transforms/DecomposeComplexOps.cpp | 30 +++++++++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 5 ++++ .../build_tools/abstract_interp_lib_gen.py | 6 ++++ .../build_tools/torch_ods_gen.py | 1 + .../torch_mlir_e2e_test/test_suite/arange.py | 17 +++++++++++ 7 files changed, 94 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 432a3ad6ebc4..8f949d9ba195 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -15909,6 +15909,34 @@ def Torch_PrimsViewOfOp : Torch_Op<"prims.view_of", [ let hasFolder = 1; } +def Torch_PrimsIotaOp : Torch_Op<"prims.iota", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `prims::iota : (int, int, int, int, Device, bool) -> (Tensor)`"; + let arguments = (ins + Torch_IntType:$length, + Torch_IntType:$start, + Torch_IntType:$step, + Torch_IntType:$dtype, + Torch_DeviceType:$device, + Torch_BoolType:$requires_grad + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult PrimsIotaOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); + } + void PrimsIotaOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); + } + }]; +} + def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [ HasValueSemantics, AllowsTypeRefinement, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 5ad9b8216ca7..a1cc7ddf6ea6 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -8653,6 +8653,13 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.prims.iota\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.Device, %arg5: !torch.bool) -> !torch.list {\n" +" %0 = torch.prim.ListConstruct %arg0 : (!torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.prims.iota\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.Device, %arg5: !torch.bool) -> !torch.int {\n" +" return %arg3 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.prim.NumToTensor.Scalar\"(%arg0: !torch.float) -> !torch.list {\n" " %0 = torch.prim.ListConstruct : () -> !torch.list\n" " return %0 : !torch.list\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 97bd85063e68..87f93ba9c555 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -4789,6 +4789,35 @@ class DecomposeAtenArangeStartOp : public OpRewritePattern { }; } // namespace +namespace { +// The `prims.iota` op is converted to `aten.arange.startStep` op. 
+class DecomposePrimsIotaOp : public OpRewritePattern {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(PrimsIotaOp op,
+ PatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ int64_t length, start, step;
+ if (!matchPattern(op.getLength(), m_TorchConstantInt(&length)))
+ return rewriter.notifyMatchFailure(
+ op, "unimplemented: length must be a constant integer");
+ if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
+ return rewriter.notifyMatchFailure(
+ op, "unimplemented: start must be a constant integer");
+ if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
+ return rewriter.notifyMatchFailure(
+ op, "unimplemented: step must be a constant integer");
+ auto endVal = rewriter.create(
+ loc, rewriter.getI64IntegerAttr(start + length * step));
+ auto none = rewriter.create(loc);
+ rewriter.replaceOpWithNewOp(
+ op, op.getType(), op.getStart(), endVal, op.getStep(), op.getDtype(),
+ none, op.getDevice(), none);
+ return success();
+ }
+};
+} // namespace
+
 namespace {
 // Decompose constant tensor full like ops.
 template
@@ -7605,6 +7634,7 @@ class DecomposeComplexOpsPass
 addPatternIfTargetOpIsIllegal(patterns);
 addPatternIfTargetOpIsIllegal(patterns);
 addPatternIfTargetOpIsIllegal(patterns);
+ addPatternIfTargetOpIsIllegal(patterns);
 addPatternIfTargetOpIsIllegal(patterns);
 addPatternIfTargetOpIsIllegal<
 DecomposeAtenArgMinMaxOp>(patterns);
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 1fd88ce6e238..0607a972055d 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1228,6 +1228,7 @@
 "PrimMinIntDynamicModule_basic",
 "PrimMinIntModule_basic",
 "PrimsConvertElementTypeModule_basic",
+ "PrimsIotaModule_basic",
 "PrimsSqueezeEmptyDimensionsModule_basic",
 "PrimsViewOfModule_basic",
 "PrimsViewOfZeroRankModule_basic",
@@ -1789,6 +1790,7 @@
 "PermuteModule_basic",
 "PermuteNegativeIndexModule_basic",
 "PrimListUnpackNumMismatchModule_basic",
+ "PrimsIotaModule_basic",
 "PrimsSqueezeEmptyDimensionsModule_basic",
 "PrimsSqueezeModule_basic",
 "PrimsViewOfModule_basic",
@@ -2683,6 +2685,9 @@
 "SqueezeModule_allUnitDim",
 "SqueezeModule_broadcast",
 "SqueezeModule_static",
+
+ # RuntimeError: unsupported input type: Device
+ "PrimsIotaModule_basic",

 # Failure - unknown
 "BernoulliModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 5d63d7a7db0c..8b32ff602697 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -1319,6 +1319,12 @@ def prims〇view_of〡dtype(a_rank_dtype: Tuple[int, int]) -> int:
 _, a_dtype = a_rank_dtype
 return a_dtype

+def prims〇iota〡shape(length: int, start: int, step: int, dtype: int, device: device, requires_grad: bool) -> List[int]:
+ return [length]
+
+def prims〇iota〡dtype(length: int, start: int, step: int, dtype: int, device: device, requires_grad: bool) -> int:
+ return dtype
+
 def prim〇NumToTensor〇Scalar〡shape(a: float) -> List[int]:
 return []

diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index e6258ece8fb5..5c4e4d214932 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++
b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -897,6 +897,7 @@ def emit_with_mutating_variants(key, **kwargs):
 emit("prims::split_dim : (Tensor, int, int) -> (Tensor)")
 emit("prims::squeeze : (Tensor, int[]) -> (Tensor)")
 emit("prims::view_of : (Tensor) -> (Tensor)", has_folder=True)
+ emit("prims::iota : (int, int, int, int, Device, bool) -> (Tensor)")

 # ==========================================================================
 # `quantized::` namespace.
 # ==========================================================================

diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py
index fff3e60c4605..9489013076d7 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py
@@ -380,3 +380,20 @@ def forward(self):
 @register_test_case(module_factory=lambda: LinspaceTwoSizeModule())
 def LinspaceTwoSizeModule_basic(module, tu: TestUtils):
 module.forward()
+
+
+class PrimsIotaModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ @export
+ @annotate_args([
+ None,
+ ])
+ def forward(self):
+ return torch.ops.prims.iota(77, start=0, step=1, dtype=torch.int64, device='cpu',
+ requires_grad=False)
+
+@register_test_case(module_factory=lambda: PrimsIotaModule())
+def PrimsIotaModule_basic(module, tu: TestUtils):
+ module.forward()

From 6abc7371c8279a678a757af2fbf085502bd7e696 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Mon, 22 Apr 2024 14:22:42 +0530
Subject: [PATCH 15/34] [MLIR][TORCH] Fix OnnxToLinalg lowering issue for Squeeze and Unsqueeze op (#2991)

This commit also cleans up the OnnxToTorch lowering for the Squeeze and
Unsqueeze op and adds support for handling edge cases.

Signed-Off By: Vivek Khandelwal
---
 .../Conversion/TorchOnnxToTorch/Utils.h | 59 ++++
 .../torch-mlir/Dialect/Torch/IR/TorchOps.h | 2 +-
 .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 253 +++++++++---------
 projects/pt1/e2e_testing/xfail_sets.py | 28 +-
 .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 231 +++-------------
 5 files changed, 227 insertions(+), 346 deletions(-)

diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
index b62f9dbaf4b5..8e9de1ff5940 100644
--- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
+++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
@@ -10,11 +10,26 @@
 #ifndef TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
 #define TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H

+#include "mlir/IR/DialectResourceBlobManager.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"

+class Endian {
+private:
+ static constexpr uint32_t uint32_ = 0x01020304;
+ static constexpr uint8_t magic_ = (const uint8_t &)uint32_;
+
+public:
+ static constexpr bool little = magic_ == 0x04;
+ static constexpr bool big = magic_ == 0x01;
+ static_assert(little || big, "Cannot determine endianness!");
+
+private:
+ Endian() = delete;
+};
+
 namespace mlir::torch::onnx_c {

 Value createConstantIntList(OpBinder binder,
@@ -28,6 +43,50 @@ LogicalResult OnnxLstmExpander(OpBinder binder,

 bool areAllElementsDistinct(SmallVector array);

+namespace detail {
+/// Matches the constant integers stored in an `onnx.Constant`.
+struct onnx_list_of_constant_ints_op_binder {
+ SmallVectorImpl &bind_values;
+
+ /// Creates a matcher instance that binds the value to bvs if match succeeds.
+ onnx_list_of_constant_ints_op_binder(SmallVectorImpl &bvs)
+ : bind_values(bvs) {}
+
+ bool match(Operation *op) {
+ auto constOp = dyn_cast(op);
+ if (!constOp || !constOp.getName().equals("onnx.Constant"))
+ return false;
+
+ if (DenseResourceElementsAttr attr =
+ constOp->getAttr("torch.onnx.value")
+ .dyn_cast_or_null()) {
+ // Bytes are stored in little endian order. Big endian support will
+ // require swizzling.
+ if (!Endian::little) {
+ op->emitError("unimplemented: importing on big endian systems");
+ return false;
+ }
+
+ auto ty = cast(attr.getType());
+ ElementsAttr denseAttr;
+ auto ptr = attr.getRawHandle().getBlob()->getData();
+ denseAttr = DenseElementsAttr::getFromRawBuffer(ty, ptr);
+ for (auto axis : denseAttr.getValues()) {
+ bind_values.push_back(axis.getSExtValue());
+ }
+ return true;
+ }
+ return false;
+ }
+};
+} // namespace detail
+
+/// Matches the constant integers stored in an `onnx.Constant`.
+inline detail::onnx_list_of_constant_ints_op_binder
+m_OnnxListOfConstantInts(SmallVectorImpl &bind_values) {
+ return detail::onnx_list_of_constant_ints_op_binder(bind_values);
+}
+
 } // namespace mlir::torch::onnx_c

 #endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h
index e6a9e1622cc1..4508518bf297 100644
--- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h
+++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h
@@ -142,7 +142,7 @@ m_TorchConstantBool(bool *bind_value) {
 }

 namespace detail {
-/// Matches the constant integers stored in a `torch.ListConstruct`.
+/// Matches the constant integers stored in a `torch.prim.ListConstruct`.
 struct torch_list_of_constant_ints_op_binder {
 SmallVectorImpl &bind_values;

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index 65bfb6257774..7630fcfa1108 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -661,57 +661,86 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
 patterns.onOp(
 "Squeeze", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
 Torch::ValueTensorType resultType;
- Value data;
- Value axes;
- if (binder.tensorOperands(data, axes) ||
+ SmallVector inputOperands;
+ if (binder.tensorOperands(inputOperands, binder.op->getNumOperands()) ||
 binder.tensorResultType(resultType))
 return failure();
- Torch::BaseTensorType axesType =
- axes.getType().cast();
- SmallVector dimList;
- SmallVector selectSizes;
- selectSizes.push_back(1);
- Type selectResultType = axesType.getWithSizesAndDtype(
- llvm::ArrayRef(selectSizes), axesType.getOptionalDtype());
- auto sizes =
- dyn_cast(axes.getType()).getSizes();
- if (sizes.size() == 0) {
+
+ Value data = inputOperands[0];
+ auto inputType = data.getType().dyn_cast();
+ if (!inputType.hasSizes() || !resultType.hasSizes())
+ return rewriter.notifyMatchFailure(
+ binder.op,
+ "unimplemented: expected input and result to have shapes");
+
+ int64_t inputRank = inputType.getSizes().size();
+ int64_t resultRank = resultType.getSizes().size();
+ int64_t rankDiff = inputRank - resultRank;
+ if (rankDiff == 0) {
+ // In this case, no dimension is squeezed. Hence just replace the op
+ // with input.
+ rewriter.replaceOp(binder.op, data);
+ return success();
+ }
+
+ if (inputOperands.size() == 1) {
+ // Case: `axes` value is not present, which means squeeze all the
+ // dimensions with shape value 1.
 rewriter.replaceOpWithNewOp(binder.op, resultType,
 data);
 return success();
 }
- Value zero = rewriter.create(
- binder.getLoc(), rewriter.getType(),
- rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
- int64_t adjustmentInt =
- cast(data.getType()).getSizes().size();
- Value adjustment = rewriter.create(
- binder.getLoc(), rewriter.getType(),
- rewriter.getIntegerAttr(rewriter.getIntegerType(64),
- adjustmentInt));
- for (int i = 0; i < sizes[0]; i++) {
- // Go through the axes list and get each dim in the list
- Value selectIndex = rewriter.create(
+
+ SmallVector dimList;
+ if (inputType.areAllSizesKnown() && resultType.areAllSizesKnown()) {
+ // If the input shape and result shape are statically known then the
+ // list of dims to be squeezed can be derived from those shapes. As a
+ // result, we don't have to wait for the dim values to be known at
+ // runtime which is also expected by the downstream pipeline.
+ SmallVector inputShape(inputType.getSizes());
+ SmallVector resultShape(resultType.getSizes());
+ SmallVector squeezeDims;
+ unsigned resultShapeCounter = 0;
+ for (unsigned i = 0; i < inputRank; i++) {
+ if (resultShapeCounter < resultRank &&
+ inputShape[i] == resultShape[resultShapeCounter]) {
+ resultShapeCounter++;
+ } else {
+ squeezeDims.push_back(i);
+ }
+ }
+ for (auto i : squeezeDims) {
+ dimList.push_back(rewriter.create(
+ binder.getLoc(), rewriter.getI64IntegerAttr(i)));
+ }
+ }
+
+ if (dimList.empty()) {
+ Value axes = inputOperands[1];
+ Torch::BaseTensorType axesType =
+ axes.getType().cast();
+ SmallVector selectSizes{1};
+ Type selectResultType = axesType.getWithSizesAndDtype(
+ selectSizes, axesType.getOptionalDtype());
+ Value zero = rewriter.create(
 binder.getLoc(), rewriter.getType(),
- rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
- Value extract = rewriter.create(
- binder.getLoc(), selectResultType, axes, zero, selectIndex);
- Value dim = rewriter.create(
- binder.getLoc(), rewriter.getType(), extract);
- // deal with neg axis: if (axis < 0) axis += rank
- Value isNegative =
- rewriter.create(binder.getLoc(), dim, zero);
- isNegative = rewriter.create(binder.getLoc(),
- isNegative);
- Value finalOffset = rewriter.create(
- binder.getLoc(), isNegative, adjustment);
- Value finalDim = rewriter.create(
- binder.getLoc(), dim, finalOffset);
- dimList.push_back(finalDim);
+ rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+ for (int i = 0; i < rankDiff; i++) {
+ // Go through the axes list and get each dim in the list
+ Value selectIndex = rewriter.create(
+ binder.getLoc(), rewriter.getType(),
+ rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
+ Value extract = rewriter.create(
+ binder.getLoc(), selectResultType, axes, zero, selectIndex);
+ Value dim = rewriter.create(
+ binder.getLoc(), rewriter.getType(), extract);
+ dimList.push_back(dim);
+ }
 }
 Value dimValueList = rewriter.create(
 binder.getLoc(),
- Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+ rewriter.getType(
+ rewriter.getType()),
 dimList);
 rewriter.replaceOpWithNewOp(
 binder.op, resultType, data, dimValueList);
@@ -725,103 +754,67 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
 // discussion can be found here:
 // https://github.com/pytorch/pytorch/issues/9410
 // So, for now, we unroll into multiple unsqueezes.
+ Location loc = binder.getLoc(); Torch::ValueTensorType resultType; - Value data; - Value axes; + Value data, axes; if (binder.tensorOperands(data, axes) || binder.tensorResultType(resultType)) return failure(); - Torch::BaseTensorType axesType = - axes.getType().cast(); - SmallVector dimList; - SmallVector selectSizes; - selectSizes.push_back(1); - Type selectResultType = axesType.getWithSizesAndDtype( - llvm::ArrayRef(selectSizes), axesType.getOptionalDtype()); - auto sizes = - dyn_cast(axes.getType()).getSizes(); - if (sizes.size() == 0) { + auto inputType = data.getType().dyn_cast(); + if (!inputType.hasSizes() || !resultType.hasSizes()) + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented: expected input and result to have shapes"); + + int64_t inputRank = inputType.getSizes().size(); + int64_t resultRank = resultType.getSizes().size(); + int64_t rankDiff = resultRank - inputRank; + if (rankDiff == 0) { + // In this case, no dimension is unsqueezed. Hence just replace the op + // with input. rewriter.replaceOp(binder.op, data); return success(); } - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); - int64_t adjustmentInt = - cast(data.getType()).getSizes().size(); - Value adjustment = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), - adjustmentInt)); - for (int i = 0; i < sizes[0]; i++) { - // Go through the axes list and get each dim in the list - Value selectIndex = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - Value extract = rewriter.create( - binder.getLoc(), selectResultType, axes, zero, selectIndex); - Value dim = rewriter.create( - binder.getLoc(), rewriter.getType(), extract); - // deal with neg axis: if (axis < 0) axis += rank - Value isNegative = - rewriter.create(binder.getLoc(), dim, zero); - isNegative = rewriter.create(binder.getLoc(), - isNegative); - Value finalOffset = rewriter.create( - binder.getLoc(), isNegative, adjustment); - Value finalDim = rewriter.create( - binder.getLoc(), dim, finalOffset); - dimList.push_back(finalDim); - } - Value dimValueList = rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), - dimList); - Value cstFalse = - rewriter.create(binder.getLoc(), false); - Value noneVal = rewriter.create(binder.getLoc()); - Value updatedAxes = rewriter.create( - binder.getLoc(), - axesType.getWithSizesAndDtype(sizes, axesType.getOptionalDtype()), - dimValueList, /*dtype=*/noneVal, /*device=*/noneVal, cstFalse); - // Sort the list of dims, so we don't run into this situation: - // data.sizes = [2, 3, 4] - // dims = [4, 0] - // index 4 will be invalid to add a singleton dimension because - // data.sizes.size == 3 We have to work with sorted dims to avoid this - // situation. - auto sortIndicesType = axesType.getWithSizesAndDtype( - axesType.getOptionalSizes(), - IntegerType::get(binder.op->getContext(), 64, IntegerType::Signed)); - auto sortOpResult = rewriter.create( - binder.getLoc(), axes.getType(), sortIndicesType, updatedAxes, zero, - cstFalse); - Value result; - auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation( - binder.op->getContext()); - // Go through the updated, sorted axes. Do unsqueeze for each dim. 
- for (int i = 0; i < sizes[0]; i++) {
- Value selectIndex = rewriter.create(
- binder.getLoc(), rewriter.getType(),
- rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
- Value extract = rewriter.create(
- binder.getLoc(), selectResultType, sortOpResult->getResult(0),
- zero, selectIndex);
- Value dim = rewriter.create(
- binder.getLoc(), rewriter.getType(), extract);
- if (sizes[0] == 1) {
- result = rewriter.create(
- binder.getLoc(), resultType, data, dim);
- } else if (i == 0) {
- result = rewriter.create(
- binder.getLoc(), baseType, data, dim);
- } else if (i == sizes[0] - 1) {
- result = rewriter.create(
- binder.getLoc(), resultType, result, dim);
- } else {
- result = rewriter.create(
- binder.getLoc(), baseType, result, dim);
+
+ SmallVector unsqueezeDims;
+ SmallVector inputShape(inputType.getSizes());
+ if (inputType.areAllSizesKnown() && resultType.areAllSizesKnown()) {
+ // If the input shape and result shape are statically known then the
+ // list of dims to be unsqueezed can be derived from those shapes. As a
+ // result, we don't have to wait for the dim values to be known at
+ // runtime which is also expected by the downstream pipeline.
+ SmallVector resultShape(resultType.getSizes());
+ unsigned inputShapeCounter = 0;
+ for (unsigned i = 0; i < resultRank; i++) {
+ if (inputShapeCounter < inputRank &&
+ inputShape[inputShapeCounter] == resultShape[i]) {
+ inputShapeCounter++;
+ } else {
+ unsqueezeDims.push_back(i);
+ }
 }
+ } else {
+ SmallVector unsqueezeDimsInts;
+ if (!matchPattern(axes, m_OnnxListOfConstantInts(unsqueezeDimsInts)))
+ return rewriter.notifyMatchFailure(
+ binder.op, "only support constant int axes values");
+
+ for (auto dim : unsqueezeDimsInts)
+ unsqueezeDims.push_back(dim < 0 ? dim + resultRank : dim);
+ // If we don't sort, unsqueezing first on 4 and then on 0 would fail
+ // for shape = {x,y,z}, and axes [4,0]
+ llvm::sort(unsqueezeDims.begin(), unsqueezeDims.end());
+ }
+ Value result = data;
+ SmallVector unsqueezeShape = inputShape;
+ for (auto dim : unsqueezeDims) {
+ unsqueezeShape.insert(unsqueezeShape.begin() + dim, 1);
+ Type unsqueezeType = resultType.getWithSizesAndDtype(
+ unsqueezeShape, resultType.getOptionalDtype());
+ Value cstDim = rewriter.create(
+ loc, rewriter.getI64IntegerAttr(dim));
+ result = rewriter.create(loc, unsqueezeType,
+ result, cstDim);
 }
 rewriter.replaceOp(binder.op, result);
 return success();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 0607a972055d..64a9d3bb6169 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -2643,12 +2643,8 @@
 # Failure - onnx_lowering: onnx.ScatterElements
 "ScatterReduceFloatMaxModuleIncludeSelf",
 "ScatterReduceFloatMinModuleIncludeSelf",
- "ScatterReduceFloatProdModuleIncludeSelf",
- "ScatterReduceFloatSumModuleIncludeSelf",
 "ScatterReduceIntMaxModuleIncludeSelf",
 "ScatterReduceIntMinModuleIncludeSelf",
- "ScatterReduceIntProdModuleIncludeSelf",
- "ScatterReduceIntSumModuleIncludeSelf",
 "ScatterValueFloatModule_basic",

 # Failure - onnx_lowering: onnx.ScatterND
 "IndexPut1DFloatAccumulateModule_basic",
@@ -2680,22 +2676,12 @@
 # Failure - onnx_lowering: onnx.SoftmaxCrossEntropyLoss
 "CrossEntropyLossModule_basic",
 "CrossEntropyLossNoReductionModule_basic",
-
- # Failure - onnx_lowering: onnx.Squeeze
- "SqueezeModule_allUnitDim",
- "SqueezeModule_broadcast",
- "SqueezeModule_static",

 # RuntimeError: unsupported input type: Device
 "PrimsIotaModule_basic",
-
+
 # Failure - unknown
 "BernoulliModule_basic",
-
"BucketizeTensorFloatModule_basic", - "BucketizeTensorModule_basic", - "BucketizeTensorOutInt32RightModule_basic", - "BucketizeTensorStaticFloatModule_basic", - "BucketizeTensorStaticModule_basic", "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "CopyWithDifferentDTypesAndSizesModule_basic", "CopyWithDifferentDTypesModule_basic", @@ -2712,22 +2698,16 @@ "ElementwiseErfIntModule_basic", "ElementwiseExpIntModule_basic", "ElementwiseLogIntModule_basic", - "ElementwisePreluModule_basic", - "ElementwisePreluStaticModule_basic", "ElementwiseSigmoidIntModule_basic", "ElementwiseSinIntModule_basic", "ElementwiseTanIntModule_basic", "ElementwiseToDtypeI64ToUI8Module_basic", "ElementwiseUnaryIntModule_basic", - "ElementwiseUnsqueezeNegDimsModule_basic", - "GroupNormModule_basic", "MaskedFillTensorFloatValueModule_basic", "NativeDropoutTrainModule_basic", "NativeDropoutTrainStaticShapeModule_basic", "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", - "TensorsStackNegativeDimModule_basic", - "TensorsStackPromoteDTypeModule_basic", } if torch_version_for_comparison() >= version.parse("2.4.0.dev"): @@ -2746,6 +2726,10 @@ ONNX_CRASHING_SET = { "FakeQuantizePerTensorAffineModule_basic", "FakeQuantizePerTensorAffineDynamicShapeModule_basic", - + "ElementwisePreluModule_basic", "ViewDynamicExpandCollapseWithParallelUnknownDimModule_basic", + "ScatterReduceFloatProdModuleIncludeSelf", + "ScatterReduceFloatSumModuleIncludeSelf", + "ScatterReduceIntProdModuleIncludeSelf", + "ScatterReduceIntSumModuleIncludeSelf", } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 47497d5ea5ba..de3e796f4e5d 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -424,19 +424,34 @@ func.func @test_xor_bcast4v4d(%arg0: !torch.vtensor<[1,4,1,6],i1>, %arg1: !torch // ----- +// CHECK-LABEL: func.func @test_squeeze_no_axes +func.func @test_squeeze_no_axes(%arg0: !torch.vtensor<[1,3,1,4,1,5,1,1],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.squeeze %arg0 : !torch.vtensor<[1,3,1,4,1,5,1,1],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Squeeze"(%arg0) : (!torch.vtensor<[1,3,1,4,1,5,1,1],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_squeeze_five_axes +func.func @test_squeeze_five_axes(%arg0: !torch.vtensor<[1,3,1,4,1,5,1,1],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[3,1,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT4:.*]] = torch.constant.int 4 + // CHECK: %[[INT6:.*]] = torch.constant.int 6 + // CHECK: %[[INT7:.*]] = torch.constant.int 7 + // CHECK: %[[SQUEEZE_DIMS:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT4]], %[[INT6]], %[[INT7]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: torch.prims.squeeze %arg0, %[[SQUEEZE_DIMS]] : !torch.vtensor<[1,3,1,4,1,5,1,1],f32>, !torch.list -> !torch.vtensor<[3,1,4,5],f32> + %0 = torch.operator "onnx.Squeeze"(%arg0, %arg1) : 
(!torch.vtensor<[1,3,1,4,1,5,1,1],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[3,1,4,5],f32> + return %0 : !torch.vtensor<[3,1,4,5],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_squeeze func.func @test_squeeze(%arg0: !torch.vtensor<[1,3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.*]] = torch.constant.int 0 - // CHECK: %[[INT4:.*]] = torch.constant.int 4 - // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list - // CHECK: torch.prims.squeeze %arg0, %6 : !torch.vtensor<[1,3,4,5],f32>, !torch.list -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[SQUEEZE_DIMS:.*]] = torch.prim.ListConstruct %[[INT0]] : (!torch.int) -> !torch.list + // CHECK: torch.prims.squeeze %arg0, %[[SQUEEZE_DIMS]] : !torch.vtensor<[1,3,4,5],f32>, !torch.list -> !torch.vtensor<[3,4,5],f32> %0 = torch.operator "onnx.Squeeze"(%arg0, %arg1) : (!torch.vtensor<[1,3,4,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> } @@ -445,24 +460,10 @@ func.func @test_squeeze(%arg0: !torch.vtensor<[1,3,4,5],f32>, %arg1: !torch.vten // CHECK-LABEL: func.func @test_squeeze_two_axes func.func @test_squeeze_two_axes(%arg0: !torch.vtensor<[3,1,4,5,1],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.*]] = torch.constant.int 0 - // CHECK: %[[INT5:.*]] = torch.constant.int 5 - // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int5 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int // CHECK: %[[INT1:.*]] = torch.constant.int 1 - // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %9, %int5 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5, %11 : (!torch.int, !torch.int) -> !torch.list - // CHECK: 
torch.prims.squeeze %arg0, %12 : !torch.vtensor<[3,1,4,5,1],f32>, !torch.list -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[INT4:.*]] = torch.constant.int 4 + // CHECK: %[[SQUEEZE_DIMS:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT4]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: torch.prims.squeeze %arg0, %[[SQUEEZE_DIMS]] : !torch.vtensor<[3,1,4,5,1],f32>, !torch.list -> !torch.vtensor<[3,4,5],f32> %0 = torch.operator "onnx.Squeeze"(%arg0, %arg1) : (!torch.vtensor<[3,1,4,5,1],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> } @@ -472,23 +473,7 @@ func.func @test_squeeze_two_axes(%arg0: !torch.vtensor<[3,1,4,5,1],f32>, %arg1: // CHECK-LABEL: func.func @test_unsqueeze_axis_0 func.func @test_unsqueeze_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.*]] = torch.constant.int 0 - // CHECK: %[[INT3:.*]] = torch.constant.int 3 - // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list - // CHECK: torch.constant.bool false - // CHECK: torch.constant.none - // CHECK: torch.aten.tensor %6, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.sort %7, %int0, %false : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64> - // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %8 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.unsqueeze %arg0, %9 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[1,3,4,5],f32> + // CHECK: torch.aten.unsqueeze %arg0, %[[INT0:.*]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[1,3,4,5],f32> %0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,4,5],f32> return %0 : !torch.vtensor<[1,3,4,5],f32> } @@ -497,24 +482,8 @@ func.func @test_unsqueeze_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !tor // CHECK-LABEL: func.func @test_unsqueeze_axis_1 func.func @test_unsqueeze_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.*]] = torch.constant.int 0 - // CHECK: %[[INT3:.*]] = torch.constant.int 3 - // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, 
!torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list - // CHECK: %[[FALSE:.*]] = torch.constant.bool false - // CHECK: %[[NONE:.*]] = torch.constant.none - // CHECK: torch.aten.tensor %6, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.sort %7, %int0, %false : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64> - // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %8 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.unsqueeze %arg0, %9 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,1,4,5],f32> + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.unsqueeze %arg0, %[[INT1]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,1,4,5],f32> %0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,4,5],f32> return %0 : !torch.vtensor<[3,1,4,5],f32> } @@ -523,146 +492,22 @@ func.func @test_unsqueeze_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !tor // CHECK-LABEL: func.func @test_unsqueeze_axis_2 func.func @test_unsqueeze_axis_2(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,1,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.*]] = torch.constant.int 0 - // CHECK: %[[INT3:.*]] = torch.constant.int 3 - // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list - // CHECK: %[[FALSE:.*]] = torch.constant.bool false - // CHECK: %[[NONE:.*]] = torch.constant.none - // CHECK: torch.aten.tensor %6, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.sort %7, %int0, %false : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64> - // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %8 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.unsqueeze %arg0, %9 : !torch.vtensor<[3,4,5],f32>, !torch.int -> 
!torch.vtensor<[3,4,1,5],f32> + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: torch.aten.unsqueeze %arg0, %[[INT2]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,1,5],f32> %0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,1,5],f32> return %0 : !torch.vtensor<[3,4,1,5],f32> } // ----- -// CHECK-LABEL: func.func @test_unsqueeze_negative_axes -func.func @test_unsqueeze_negative_axes(%arg0: !torch.vtensor<[1,3,1,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,1,1,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.*]] = torch.constant.int 0 - // CHECK: %[[INT4:.*]] = torch.constant.int 4 - // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list - // CHECK: %[[FALSE:.*]] = torch.constant.bool false - // CHECK: %[[NONE:.*]] = torch.constant.none - // CHECK: torch.aten.tensor %6, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.sort %7, %int0, %false : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64> - // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %8 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.unsqueeze %arg0, %9 : !torch.vtensor<[1,3,1,5],f32>, !torch.int -> !torch.vtensor<[1,3,1,1,5],f32> - %0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[1,3,1,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,1,1,5],f32> - return %0 : !torch.vtensor<[1,3,1,1,5],f32> -} - -// ----- - // CHECK-LABEL: func.func @test_unsqueeze_three_axes func.func @test_unsqueeze_three_axes(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.*]] = torch.constant.int 0 - // CHECK: %[[INT3:.*]] = torch.constant.int 3 - // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: %[[INT1:.*]] = 
torch.constant.int 1 - // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int - // CHECK: %[[INT2:.*]] = torch.constant.int 2 - // CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %13, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %15, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5, %11, %17 : (!torch.int, !torch.int, !torch.int) -> !torch.list - // CHECK: %[[FALSE:.*]] = torch.constant.bool false - // CHECK: %[[NONE:.*]] = torch.constant.none - // CHECK: torch.aten.tensor %18, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[3],si64> - // CHECK: torch.aten.sort %19, %int0, %false : !torch.vtensor<[3],si64>, !torch.int, !torch.bool -> !torch.vtensor<[3],si64>, !torch.vtensor<[3],si64> - // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %20 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.unsqueeze %arg0, %21 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor - // CHECK: %[[INT1_2:.*]] = torch.constant.int 1 - // CHECK: torch.aten.select.int %values, %int0, %int1_2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %23 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.unsqueeze %22, %24 : !torch.vtensor, !torch.int -> !torch.vtensor - // CHECK: %[[INT2_3:.*]] = torch.constant.int 2 - // CHECK: torch.aten.select.int %values, %int0, %int2_3 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %26 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.unsqueeze %25, %27 : !torch.vtensor, !torch.int -> !torch.vtensor<[3,4,1,5,1,1],f32> - %0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32> - return %0 : !torch.vtensor<[3,4,1,5,1,1],f32> -} - -// ----- - -// CHECK-LABEL: func.func @test_unsqueeze_unsorted_axes -func.func @test_unsqueeze_unsorted_axes(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.*]] = torch.constant.int 0 - // CHECK: %[[INT3:.*]] = torch.constant.int 3 - // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: 
torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: %[[INT1:.*]] = torch.constant.int 1 - // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int // CHECK: %[[INT2:.*]] = torch.constant.int 2 - // CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %13, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %15, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5, %11, %17 : (!torch.int, !torch.int, !torch.int) -> !torch.list - // CHECK: %[[FALSE:.*]] = torch.constant.bool false - // CHECK: %[[NONE:.*]] = torch.constant.none - // CHECK: torch.aten.tensor %18, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[3],si64> - // CHECK: torch.aten.sort %19, %int0, %false : !torch.vtensor<[3],si64>, !torch.int, !torch.bool -> !torch.vtensor<[3],si64>, !torch.vtensor<[3],si64> - // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %20 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.unsqueeze %arg0, %21 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor - // CHECK: %[[INT1_2:.*]] = torch.constant.int 1 - // CHECK: torch.aten.select.int %values, %int0, %int1_2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %23 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.unsqueeze %22, %24 : !torch.vtensor, !torch.int -> !torch.vtensor - // CHECK: %[[INT2_3:.*]] = torch.constant.int 2 - // CHECK: torch.aten.select.int %values, %int0, %int2_3 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %26 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.unsqueeze %25, %27 : !torch.vtensor, !torch.int -> !torch.vtensor<[3,4,1,5,1,1],f32> + // CHECK: %[[UNSQUEEZE:.*]] = torch.aten.unsqueeze %arg0, %[[INT2]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,1,5],f32> + // CHECK: %[[INT4:.*]] = torch.constant.int 4 + // CHECK: %[[UNSQUEEZE_1:.*]] = torch.aten.unsqueeze %[[UNSQUEEZE]], %[[INT4]] : !torch.vtensor<[3,4,1,5],f32>, !torch.int -> !torch.vtensor<[3,4,1,5,1],f32> + // CHECK: %[[INT5:.*]] = torch.constant.int 5 + // CHECK: torch.aten.unsqueeze %[[UNSQUEEZE_1]], %[[INT5]] : !torch.vtensor<[3,4,1,5,1],f32>, !torch.int 
-> !torch.vtensor<[3,4,1,5,1,1],f32>
  %0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32>
  return %0 : !torch.vtensor<[3,4,1,5,1,1],f32>
}

From 3c252cdd44f411ef67e3a759319be53b46396d44 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Mon, 22 Apr 2024 22:28:07 +0530
Subject: [PATCH 16/34] [onnx] Add `onnx-to-torch` lowering for random ops (#3193)

This commit adds the OnnxToTorch lowering for ONNX's RandomNormal,
RandomNormalLike, RandomUniform, and RandomUniformLike ops.

---
 .../Conversion/TorchOnnxToTorch/Utils.h       |   2 +
 .../TorchOnnxToTorch/DefaultDomainAtoF.cpp    |  92 ++------
 .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp    | 214 ++++++++++++++++++
 lib/Conversion/TorchOnnxToTorch/Utils.cpp     |  38 ++++
 projects/pt1/e2e_testing/xfail_sets.py        |  14 +-
 .../TorchOnnxToTorch/simple_ops_q_to_z.mlir   |  62 +++++
 6 files changed, 339 insertions(+), 83 deletions(-)

diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
index 8e9de1ff5940..d4ace352a9bd 100644
--- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
+++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
@@ -87,6 +87,8 @@ m_OnnxListOfConstantInts(SmallVectorImpl<int64_t> &bind_values) {
   return detail::onnx_list_of_constant_ints_op_binder(bind_values);
 }
 
+std::optional<int64_t> onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx);
+
 } // namespace mlir::torch::onnx_c
 
 #endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index 7032ddcd208e..b16e76e3afe5 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -9,6 +9,7 @@
 
 #include "mlir/IR/DialectResourceBlobManager.h"
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
+#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -17,56 +18,6 @@ using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::onnx_c;
 
-class Endian {
-private:
-  static constexpr uint32_t uint32_ = 0x01020304;
-  static constexpr uint8_t magic_ = (const uint8_t &)uint32_;
-
-public:
-  static constexpr bool little = magic_ == 0x04;
-  static constexpr bool big = magic_ == 0x01;
-  static_assert(little || big, "Cannot determine endianness!");
-
-private:
-  Endian() = delete;
-};
-
-static int64_t onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) {
-  // TODO: Add complete mapping.
-  // Where are the ONNX and PyTorch dtype enums defined?
- // ONNX: - // https://github.com/shouxieai/tensorRT_Pro/blob/main/onnx/onnx-ml.proto - // PyTorch: - // https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h#L88 - - int64_t dtypeIntTorch = [dtypeIntOnnx]() { - switch (dtypeIntOnnx) { - case 1: - return 6; // float - case 2: - return 0; // uint8 - case 3: - return 1; // int8 - case 6: - return 3; // int32 - case 7: - return 4; // int64 - case 9: - return 11; // bool - case 10: - return 5; // half - case 11: - return 7; // double - case 16: - return 15; // bfloat16 - default: - return -1; // No dtype - } - }(); - - return dtypeIntTorch; -} - static LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter, Location loc, Value input, int64_t dimA, int64_t dimB, @@ -428,7 +379,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value input; - int64_t dtypeIntOnnx, dtypeIntTorch; + int64_t dtypeIntOnnx; if (binder.tensorOperand(input) || binder.s64IntegerAttr(dtypeIntOnnx, "dtype", -1) || binder.tensorResultType(resultType)) @@ -452,16 +403,15 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( rewriter.replaceOp(binder.op, bernoulli); return success(); } - dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); - if (dtypeIntTorch == -1) { + std::optional dtypeIntTorch = + onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); + if (!dtypeIntTorch.has_value()) { return rewriter.notifyMatchFailure( binder.op, "unimplemented support for the given dtype conversion"); } Value constDtype = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), - dtypeIntTorch)); + binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value())); Value cstFalse = rewriter.create(binder.getLoc(), false); rewriter.replaceOpWithNewOp( @@ -539,25 +489,21 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( "Cast", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; - int64_t dtypeIntOnnx, dtypeIntTorch; + int64_t dtypeIntOnnx; if (binder.tensorOperand(operand) || binder.s64IntegerAttr(dtypeIntOnnx, "to") || binder.tensorResultType(resultType)) return failure(); - dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); - if (dtypeIntTorch == -1) { - auto message = llvm::formatv("unimplemented support for the given " - "dtype conversion (onnx 'type' = {0})", - dtypeIntOnnx); - auto y = rewriter.notifyMatchFailure(binder.op, message); - - return y; + std::optional dtypeIntTorch = + onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); + if (!dtypeIntTorch.has_value()) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented support for the given dtype conversion"); } Value constDtype = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), - dtypeIntTorch)); + binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value())); Value none = rewriter.create(binder.getLoc()); Value cstFalse = rewriter.create(binder.getLoc(), false); @@ -1768,9 +1714,15 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( Value mVal = rewriter.create(binder.getLoc(), operand, cst1); Value noneVal = rewriter.create(binder.getLoc()); - int64_t dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); + std::optional dtypeIntTorch = + onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); + if (!dtypeIntTorch.has_value()) { + return rewriter.notifyMatchFailure( + 
binder.op, + "unimplemented support for the given dtype conversion"); + } Value dtypeVal = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch)); + binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value())); // diagonalIndex = 0 populates the main diagonal // diagonalIndex > 0 populates an upper diagonal diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 7630fcfa1108..6c86ecb92789 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2274,4 +2274,218 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, resultType, input, cstAlpha, value); return success(); }); + patterns.onOp( + "RandomNormal", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + SmallString<64> name("torch.onnx.seed"); + auto seedAttr = binder.op->getAttr(name); + if (seedAttr) + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented: support not present for seed attribute"); + + Torch::ValueTensorType resultType; + int64_t dtypeIntOnnx; + float mean, scale; + SmallVector shape; + if (binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) || + binder.f32FloatAttr(mean, "mean", 0.0) || + binder.f32FloatAttr(scale, "scale", 1.0) || + binder.s64IntegerArrayAttr(shape, "shape", {}) || + binder.tensorResultType(resultType)) { + return failure(); + } + + std::optional dtypeIntTorch = + onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); + if (!dtypeIntTorch.has_value()) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented support for the given dtype conversion"); + } + Value constDtype = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value())); + + Value shapeList = createConstantIntList(binder, rewriter, shape); + Value cstNone = rewriter.create(binder.getLoc()); + + Value self = rewriter.create( + binder.op->getLoc(), resultType, shapeList, + /*dtype=*/constDtype, + /*layout=*/cstNone, + /*device=*/cstNone, /*pinMemory=*/cstNone, + /*memoryFormat=*/cstNone); + + Value cstMean = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getFloatAttr(rewriter.getF64Type(), mean)); + Value cstStd = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getFloatAttr(rewriter.getF64Type(), scale)); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, self, cstMean, cstStd, + /*generator=*/cstNone); + return success(); + }); + patterns.onOp( + "RandomNormalLike", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + SmallString<64> name("torch.onnx.seed"); + auto seedAttr = binder.op->getAttr(name); + if (seedAttr) + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented: support not present for seed attribute"); + + Torch::ValueTensorType resultType; + int64_t dtypeIntOnnx; + float mean, scale; + SmallVector shape; + Value input; + if (binder.tensorOperand(input) || + binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) || + binder.f32FloatAttr(mean, "mean", 0.0) || + binder.f32FloatAttr(scale, "scale", 1.0) || + binder.tensorResultType(resultType)) { + return failure(); + } + + std::optional dtypeIntTorch = + onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); + if (!dtypeIntTorch.has_value()) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented support for the given dtype conversion"); + } + Value constDtype = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value())); + + Value cstNone = 
rewriter.create(binder.getLoc()); + Value cstFalse = + rewriter.create(binder.getLoc(), false); + input = rewriter.create( + binder.op->getLoc(), resultType, input, constDtype, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/cstNone); + + Value cstMean = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getFloatAttr(rewriter.getF64Type(), mean)); + Value cstStd = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getFloatAttr(rewriter.getF64Type(), scale)); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, cstMean, cstStd, + /*generator=*/cstNone); + return success(); + }); + patterns.onOp( + "RandomUniform", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + SmallString<64> name("torch.onnx.seed"); + auto seedAttr = binder.op->getAttr(name); + if (seedAttr) + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented: support not present for seed attribute"); + + Torch::ValueTensorType resultType; + int64_t dtypeIntOnnx; + float high, low; + SmallVector shape; + if (binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) || + binder.f32FloatAttr(high, "high", 1.0) || + binder.f32FloatAttr(low, "low", 0.0) || + binder.s64IntegerArrayAttr(shape, "shape", {}) || + binder.tensorResultType(resultType)) { + return failure(); + } + + std::optional dtypeIntTorch = + onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); + if (!dtypeIntTorch.has_value()) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented support for the given dtype conversion"); + } + Value constDtype = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value())); + + Value shapeList = createConstantIntList(binder, rewriter, shape); + Value cstNone = rewriter.create(binder.getLoc()); + + Value self = rewriter.create( + binder.op->getLoc(), resultType, shapeList, + /*dtype=*/constDtype, + /*layout=*/cstNone, + /*device=*/cstNone, /*pinMemory=*/cstNone, + /*memoryFormat=*/cstNone); + + Value cstHigh = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getFloatAttr(rewriter.getF64Type(), high)); + Value cstLow = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getFloatAttr(rewriter.getF64Type(), low)); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, self, cstLow, cstHigh, + /*generator=*/cstNone); + return success(); + }); + patterns.onOp( + "RandomUniformLike", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + SmallString<64> name("torch.onnx.seed"); + auto seedAttr = binder.op->getAttr(name); + if (seedAttr) + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented: support not present for seed attribute"); + + Torch::ValueTensorType resultType; + int64_t dtypeIntOnnx; + float high, low; + SmallVector shape; + Value input; + if (binder.tensorOperand(input) || + binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) || + binder.f32FloatAttr(high, "high", 1.0) || + binder.f32FloatAttr(low, "low", 0.0) || + binder.tensorResultType(resultType)) { + return failure(); + } + + std::optional dtypeIntTorch = + onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); + if (!dtypeIntTorch.has_value()) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented support for the given dtype conversion"); + } + Value constDtype = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value())); + + Value cstNone = rewriter.create(binder.getLoc()); + Value cstFalse = + rewriter.create(binder.getLoc(), false); + input = rewriter.create( + 
binder.op->getLoc(), resultType, input, constDtype, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/cstNone); + + Value cstHigh = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getFloatAttr(rewriter.getF64Type(), high)); + Value cstLow = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getFloatAttr(rewriter.getF64Type(), low)); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, cstLow, cstHigh, + /*generator=*/cstNone); + return success(); + }); } diff --git a/lib/Conversion/TorchOnnxToTorch/Utils.cpp b/lib/Conversion/TorchOnnxToTorch/Utils.cpp index 2d24303394dd..dec13490666e 100644 --- a/lib/Conversion/TorchOnnxToTorch/Utils.cpp +++ b/lib/Conversion/TorchOnnxToTorch/Utils.cpp @@ -59,3 +59,41 @@ bool mlir::torch::onnx_c::areAllElementsDistinct(SmallVector array) { // as array's size. return (set.size() == array.size()); } + +std::optional +mlir::torch::onnx_c::onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) { + // TODO: Add complete mapping. + // Where are the ONNX and PyTorch dtype enums defined? + // ONNX: + // https://github.com/shouxieai/tensorRT_Pro/blob/main/onnx/onnx-ml.proto + // PyTorch: + // https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h#L88 + + std::optional dtypeIntTorch = + [dtypeIntOnnx]() -> std::optional { + switch (dtypeIntOnnx) { + case 1: + return 6; // float + case 2: + return 0; // uint8 + case 3: + return 1; // int8 + case 6: + return 3; // int32 + case 7: + return 4; // int64 + case 9: + return 11; // bool + case 10: + return 5; // half + case 11: + return 7; // double + case 16: + return 15; // bfloat16 + default: + return std::nullopt; // No dtype + } + }(); + + return dtypeIntTorch; +} diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 64a9d3bb6169..323a39bf33cb 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2605,27 +2605,15 @@ # Failure - onnx_lowering: onnx.OneHot "OneHotModule_basic", - # Failure - onnx_lowering: onnx.RandomNormal + # ERROR: dtype (torch.float32) is not equal to golden dtype (torch.float64) "RandnDtypeDeviceModule_basic", "RandnGeneratorF64Module_basic", "RandnGeneratorModule_basic", "RandnModule_basic", - - # Failure - onnx_lowering: onnx.RandomNormalLike - "RandnLikeDtypeModule_basic", "RandnLikeModule_basic", - - # Failure - onnx_lowering: onnx.RandomUniform - "RandIntLowDtypeModule_basic", - "RandIntLowModule_basic", - - # Failure - onnx_lowering: onnx.RandomUniformLike "BernoulliFloatModule_basic", "BernoulliPModule_basic", "BernoulliTensorModule_basic", - "RandLikeDtypeModule_basic", - "RandLikeModule_basic", - "RandModule_basic", # Failure - onnx_lowering: onnx.ReduceL2 "LinalgNormKeepDimModule_basic", diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index de3e796f4e5d..43849fbbd06e 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -1679,3 +1679,65 @@ func.func @test_triu_zero(%arg0: !torch.vtensor<[0,5],si64>, %arg1: !torch.vtens %0 = torch.operator "onnx.Trilu"(%arg0, %arg1) : (!torch.vtensor<[0,5],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[0,5],si64> return %0 : !torch.vtensor<[0,5],si64> } + +// ----- + +// CHECK-LABEL: func.func @test_random_normal +func.func @test_random_normal() -> !torch.vtensor<[10],f32> attributes 
{torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[I6:.+]] = torch.constant.int 6 + // CHECK-DAG: %[[I10:.+]] = torch.constant.int 10 + // CHECK: %[[SHAPE:.+]] = torch.prim.ListConstruct %[[I10]] : (!torch.int) -> !torch.list + // CHECK-DAG: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[EMPTY_TENSOR:.+]] = torch.aten.empty.memory_format %[[SHAPE]], %[[I6]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00 + // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1.000000e+00 + // CHECK: torch.aten.normal_functional %[[EMPTY_TENSOR]], %[[F0]], %[[F1]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[10],f32> + %0 = torch.operator "onnx.RandomNormal"() {torch.onnx.dtype = 1 : si64, torch.onnx.mean = 0.000000e+00 : f32, torch.onnx.scale = 1.000000e+00 : f32, torch.onnx.shape = [10 : si64]} : () -> !torch.vtensor<[10],f32> + return %0 : !torch.vtensor<[10],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_random_normal_like +func.func @test_random_normal_like(%arg0: !torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[I6:.+]] = torch.constant.int 6 + // CHECK-DAG: %[[NONE:.+]] = torch.constant.none + // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[I6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00 + // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1.000000e+00 + // CHECK: torch.aten.normal_functional %[[CAST]], %[[F0]], %[[F1]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[10],f32> + %0 = torch.operator "onnx.RandomNormalLike"(%arg0) {torch.onnx.dtype = 1 : si64, torch.onnx.mean = 0.000000e+00 : f32, torch.onnx.scale = 1.000000e+00 : f32} : (!torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f32> + return %0 : !torch.vtensor<[10],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_random_uniform +func.func @test_random_uniform() -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[I6:.+]] = torch.constant.int 6 + // CHECK-DAG: %[[I10:.+]] = torch.constant.int 10 + // CHECK: %[[SHAPE:.+]] = torch.prim.ListConstruct %[[I10]] : (!torch.int) -> !torch.list + // CHECK-DAG: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[EMPTY_TENSOR:.+]] = torch.aten.empty.memory_format %[[SHAPE]], %[[I6]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1.000000e+00 + // CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00 + // CHECK: torch.aten.uniform %[[EMPTY_TENSOR]], %[[F0]], %[[F1]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float, !torch.none -> 
!torch.vtensor<[10],f32> + %0 = torch.operator "onnx.RandomUniform"() {torch.onnx.dtype = 1 : si64, torch.onnx.high = 1.000000e+00 : f32, torch.onnx.low = 0.000000e+00 : f32, torch.onnx.shape = [10 : si64]} : () -> !torch.vtensor<[10],f32> + return %0 : !torch.vtensor<[10],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_random_uniform_like +func.func @test_random_uniform_like(%arg0: !torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[I6:.+]] = torch.constant.int 6 + // CHECK-DAG: %[[NONE:.+]] = torch.constant.none + // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[I6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00 + // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1.000000e+00 + // CHECK: torch.aten.uniform %[[CAST]], %[[F0]], %[[F1]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[10],f32> + %0 = torch.operator "onnx.RandomUniformLike"(%arg0) {torch.onnx.dtype = 1 : si64, torch.onnx.high = 1.000000e+00 : f32, torch.onnx.low = 0.000000e+00 : f32} : (!torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f32> + return %0 : !torch.vtensor<[10],f32> +} From cff2f084d4ea23ca59884f3b4d35b3b171fff18e Mon Sep 17 00:00:00 2001 From: Vinayak Dev <104419489+vinayakdsci@users.noreply.github.com> Date: Tue, 23 Apr 2024 11:33:05 +0530 Subject: [PATCH 17/34] [torch] Add OnnxToTorch lowering for `onnx.ReduceL2` (#3175) Adds OnnxToTorch lowering for the ReduceL2 op. 
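For intuition, the emitted sequence (Mul(x, x) -> ReduceSum -> cast to f32 ->
Sqrt -> cast to the result dtype) matches this minimal eager-PyTorch sketch.
The helper name `reduce_l2` and its arguments are illustrative assumptions of
the sketch, not APIs added by this patch, which builds torch-dialect ops
instead of eager calls:

import torch

# Hypothetical reference model of the lowering's op sequence.
def reduce_l2(x: torch.Tensor, dims, keepdim: bool = True,
              result_dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Mul(x, x) followed by ReduceSum over the requested dims.
    squared_sum = torch.sum(x * x, dim=dims, keepdim=keepdim)
    # Route the sum through f32 before the sqrt (the CastF32 step),
    # then cast to the requested result dtype (the CastLike step).
    return torch.sqrt(squared_sum.to(torch.float32)).to(result_dtype)

# Sanity check against PyTorch's own L2 norm.
x = torch.arange(12, dtype=torch.float32).reshape(3, 2, 2)
assert torch.allclose(
    reduce_l2(x, dims=[1]),
    torch.linalg.vector_norm(x, ord=2, dim=1, keepdim=True))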
---
 .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp    | 55 +++++++++++
 projects/pt1/e2e_testing/xfail_sets.py        |  6 --
 .../TorchOnnxToTorch/simple_ops_q_to_z.mlir   | 99 +++++++++++++++
 3 files changed, 154 insertions(+), 6 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index 6c86ecb92789..d26601c0de8d 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -891,6 +891,61 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             /*storeValue=*/operand, keepDims, noop_with_empty_axes, false);
       });
+  patterns.onOp(
+      "ReduceL2", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value operand;
+        int64_t keepDims, noop_with_empty_axes;
+        if (binder.tensorOperandAtIndex(operand, 0) ||
+            binder.tensorResultType(resultType) ||
+            binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
+            binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
+                                  0))
+          return failure();
+
+        // A ReduceL2 op is equivalent to the following sequence of operations:
+        // Mul(x, x) -> ReduceSum -> CastF32 -> Sqrt -> CastLike(resultType)
+        Value squareOfOperand = rewriter.create<Torch::AtenMulTensorOp>(
+            binder.getLoc(), operand.getType(), operand, operand);
+
+        auto reducedSum =
+            reducedSumImpl(binder, rewriter, squareOfOperand, resultType,
+                           operand, keepDims, noop_with_empty_axes, true);
+        if (failed(reducedSum))
+          return rewriter.notifyMatchFailure(
+              binder.op,
+              "Failed to perform sum operation on square of operand");
+
+        Value castDType = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(/*Float32Type*/ 6));
+
+        Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        Value constFalse =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
+
+        // Perform an AtenToDtype op on the squared sum of the operand, which
+        // is now stored in operand itself.
+        auto size = operand.getType()
+                        .dyn_cast<Torch::ValueTensorType>()
+                        .getOptionalSizes();
+        auto f32ResultType = rewriter.getType<Torch::ValueTensorType>(
+            size, rewriter.getF32Type());
+        Value operandCast = rewriter.create<Torch::AtenToDtypeOp>(
+            binder.getLoc(), f32ResultType, operand, castDType,
+            /*non_blocking=*/constFalse, /*copy=*/constFalse,
+            /*memory_format=*/noneVal);
+
+        Value operandSqrt = rewriter.create<Torch::AtenSqrtOp>(
+            binder.getLoc(), f32ResultType, operandCast);
+
+        Value resultDtype = Torch::getDtypeIntValueForType(
+            rewriter, binder.getLoc(), resultType.getDtype());
+        rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
+            binder.op, resultType, operandSqrt, resultDtype,
+            /*non_blocking=*/constFalse, /*copy=*/constFalse,
+            /*memory_format=*/noneVal);
+        return success();
+      });
   patterns.onOp("ReduceSum", 1,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                   Torch::ValueTensorType resultType;
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 323a39bf33cb..e426e998ebe0 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -2615,12 +2615,6 @@
     "BernoulliPModule_basic",
     "BernoulliTensorModule_basic",
 
-    # Failure - onnx_lowering: onnx.ReduceL2
-    "LinalgNormKeepDimModule_basic",
-    "LinalgNormModule_basic",
-    "NormalizeModule_basic",
-    "ReduceL2NormModule_basic",
-
     # Failure - onnx_lowering: onnx.ReduceProd
     "ReduceProdDimIntFloatModule_basic",
 
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index 43849fbbd06e..c8d513a31d21 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -761,6 +761,105 @@ func.func @test_reduce_l1_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2,2],f
 
 // -----
 
+// CHECK-LABEL: func.func @test_reduce_l2_default_axes_keepdims_example
+func.func @test_reduce_l2_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
+  // CHECK: %[[TRUE_0:.+]] = torch.constant.bool true
+  // CHECK: %[[NONE_0:.+]] = torch.constant.none
+  // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE_0]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
+  // CHECK: %[[INT6_0:.+]] = torch.constant.int 6
+  // CHECK: %[[NONE_1:.+]] = torch.constant.none
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[SUM]], %[[INT6_0]], %[[FALSE]], %[[FALSE]], %[[NONE_1]] : !torch.vtensor<[1,1,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
+  // CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[CAST]] : !torch.vtensor<[1,1,1],f32> -> !torch.vtensor<[1,1,1],f32>
+  // CHECK: %[[INT6_1:.+]] = torch.constant.int 6
+  // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[SQRT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE_1]] : !torch.vtensor<[1,1,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
+  // CHECK: return
%[[CASTLIKE]] : !torch.vtensor<[1,1,1],f32> + %0 = torch.operator "onnx.ReduceL2"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> + return %0 : !torch.vtensor<[1,1,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_l2_do_not_keepdims_example_expanded +func.func @test_reduce_l2_do_not_keepdims_example_expanded(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32> + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_1:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0_0]], %[[INT0_1]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list + // CHECK: %[[FALSE_0:.+]] = torch.constant.bool false + // CHECK: %[[NONE_0:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> + // CHECK: %[[INT6_0:.+]] = torch.constant.int 6 + // CHECK: %[[NONE_1:.+]] = torch.constant.none + // CHECK: %[[FALSE_1:.+]] = torch.constant.bool false + // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[SUM]], %[[INT6_0]], %[[FALSE_1]], %[[FALSE_1]], %[[NONE_1]] : !torch.vtensor<[3,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> + // CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[CAST]] : !torch.vtensor<[3,2],f32> -> !torch.vtensor<[3,2],f32> + // CHECK: %[[INT6_1:.+]] = torch.constant.int 6 + // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[SQRT]], %[[INT6_1]], %[[FALSE_1]], %[[FALSE_1]], %[[NONE_1]] : !torch.vtensor<[3,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> + // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2],f32> + %0 = torch.operator "onnx.ReduceL2"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> + return %0 : !torch.vtensor<[3,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_l2_keep_dims_example +func.func @test_reduce_l2_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32> + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_1:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0_0]], %[[INT0_1]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // 
CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[NONE_0:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> + // CHECK: %[[INT6_0:.+]] = torch.constant.int 6 + // CHECK: %[[NONE_1:.+]] = torch.constant.none + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[SUM]], %[[INT6_0]], %[[FALSE]], %[[FALSE]], %[[NONE_1]] : !torch.vtensor<[3,2,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> + // CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[CAST]] : !torch.vtensor<[3,2,1],f32> -> !torch.vtensor<[3,2,1],f32> + // CHECK: %[[INT6_1:.+]] = torch.constant.int 6 + // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[SQRT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE_1]] : !torch.vtensor<[3,2,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> + // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2,1],f32> + + %0 = torch.operator "onnx.ReduceL2"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> + return %0 : !torch.vtensor<[3,2,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_l2_keep_dims_int_input_example +func.func @test_reduce_l2_keep_dims_int_input_example(%arg0: !torch.vtensor<[3,2,2],si64>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],si64>, !torch.vtensor<[3,2,2],si64> -> !torch.vtensor<[3,2,2],si64> + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_1:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0_0]], %[[INT0_1]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[NONE_0:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],si64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> + // CHECK: %[[INT6_0:.+]] = torch.constant.int 6 + // CHECK: %[[NONE_1:.+]] = torch.constant.none + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[SUM]], %[[INT6_0]], %[[FALSE]], %[[FALSE]], %[[NONE_1]] : !torch.vtensor<[3,2,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> + // CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[CAST]] : !torch.vtensor<[3,2,1],f32> -> !torch.vtensor<[3,2,1],f32> + // CHECK: %[[INT6_1:.+]] = torch.constant.int 6 + // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[SQRT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE_1]] : !torch.vtensor<[3,2,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> + // CHECK: return %[[CASTLIKE]] : 
!torch.vtensor<[3,2,1],f32> + + %0 = torch.operator "onnx.ReduceL2"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> + return %0 : !torch.vtensor<[3,2,1],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_reduce_sum_default_axes_keepdims_example func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 From 797e4cd395a96d5b9a56efaa02003e7cb8c9bcc6 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Tue, 23 Apr 2024 16:24:53 +0800 Subject: [PATCH 18/34] [Stablehlo] lowering asin, acos, atan (#3207) * lowering asin, acos and atan to chlo ops. --- lib/Conversion/TorchToStablehlo/Basic.cpp | 3 +++ projects/pt1/e2e_testing/xfail_sets.py | 7 +++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index eca7c30259de..89afc80816dd 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -1900,6 +1900,9 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_UNARY_FPONLY_PATTERN(AtenCeilOp, stablehlo::CeilOp); INSERT_UNARY_FPONLY_PATTERN(AtenFloorOp, stablehlo::FloorOp); INSERT_UNARY_FPONLY_PATTERN(AtenRoundOp, stablehlo::RoundNearestEvenOp); + INSERT_UNARY_FPONLY_PATTERN(AtenAsinOp, chlo::AsinOp); + INSERT_UNARY_FPONLY_PATTERN(AtenAcosOp, chlo::AcosOp); + INSERT_UNARY_FPONLY_PATTERN(AtenAtanOp, chlo::AtanOp); #undef INSERT_UNARY_FPONLY_PATTERN #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index e426e998ebe0..a11e36060763 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -621,18 +621,15 @@ "DivFloatModule_basic", "DivIntModule_basic", "ElementwiseAcosIntModule_basic", - "ElementwiseAcosModule_basic", "ElementwiseAcoshIntModule_basic", "ElementwiseAcoshModule_basic", "ElementwiseAddScalar_NumToTensorFloat_Module_basic", "ElementwiseAsinIntModule_basic", - "ElementwiseAsinModule_basic", "ElementwiseAsinhIntModule_basic", "ElementwiseAsinhModule_basic", "ElementwiseAtan2FloatIntModule_basic", "ElementwiseAtan2TensorFloatModule_basic", "ElementwiseAtan2TensorIntModule_basic", - "ElementwiseAtanTensorFloatModule_basic", "ElementwiseAtanTensorIntModule_basic", "ElementwiseAtanhIntModule_basic", "ElementwiseAtanhModule_basic", @@ -670,7 +667,6 @@ "ElementwiseUnaryIntModule_basic", "EmptyModule_uint8", "EqIntModule_basic", - "ExponentialModule_basic", "FakeQuantizePerTensorAffineDynamicShapeModule_basic", "FakeQuantizePerTensorAffineModule_basic", "FakeQuantizePerTensorAffineRoundToEvenModule_basic", @@ -1465,6 +1461,9 @@ "ElementwiseLog1pModule_basic", "ElementwiseSgnModule_basic", "ElementwiseSignIntModule_basic", + "ElementwiseAcosModule_basic", + "ElementwiseAsinModule_basic", + "ElementwiseAtanTensorFloatModule_basic", } STABLEHLO_CRASHING_SET = { From 1f8123b5f0df37216c96a80f21f8d1a38a38513b Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Tue, 23 Apr 2024 17:57:12 +0800 Subject: [PATCH 19/34] [Stablehlo] support unary ops which promote to floating point (#3209) * 
promote input to output element-type when lowering to stablehlo, so that it could satisfy stablehlo's type constraints. * split promote-to-fp unary ops from fp-only unary ops. --- lib/Conversion/TorchToStablehlo/Basic.cpp | 61 ++++++++++++++++++----- projects/pt1/e2e_testing/xfail_sets.py | 21 ++++---- 2 files changed, 60 insertions(+), 22 deletions(-) diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 89afc80816dd..1c4bc27537c0 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -217,6 +217,37 @@ class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern { }; } // namespace +// These legalizations are for unary ops with promoting to floating point +// datatypes. +namespace { +template +class ConvertAtenUnaryPromoteToFPOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value self = adaptor.getSelf(); + auto selfTy = self.getType().cast(); + if (!selfTy) + return op.emitError("only Tensor types supported in StableHLO"); + auto resultTy = OpConversionPattern::getTypeConverter() + ->convertType(op.getType()) + .template cast(); + + if (resultTy.getElementType().template isa()) { + Value src = hlo::promoteType(rewriter, op.getLoc(), self, resultTy); + rewriter.replaceOpWithNewOp(op, resultTy, src); + return success(); + } else { + return op.emitError( + "only result to be floating-point datatype legalization supported"); + } + } +}; +} // namespace + // aten.ones & aten.zeros // Ref: Error checking based on the Torch to TOSA lowering namespace { @@ -1888,23 +1919,29 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( target.addIllegalOp(); \ patterns.add>(typeConverter, \ context) - INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, stablehlo::LogOp); - INSERT_UNARY_FPONLY_PATTERN(AtenLog1pOp, stablehlo::Log1pOp); - INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, stablehlo::ExpOp); - INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, stablehlo::SqrtOp); - INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, stablehlo::RsqrtOp); - INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, stablehlo::LogisticOp); - INSERT_UNARY_FPONLY_PATTERN(AtenTanhOp, stablehlo::TanhOp); - INSERT_UNARY_FPONLY_PATTERN(AtenSinOp, stablehlo::SineOp); - INSERT_UNARY_FPONLY_PATTERN(AtenCosOp, stablehlo::CosineOp); INSERT_UNARY_FPONLY_PATTERN(AtenCeilOp, stablehlo::CeilOp); INSERT_UNARY_FPONLY_PATTERN(AtenFloorOp, stablehlo::FloorOp); INSERT_UNARY_FPONLY_PATTERN(AtenRoundOp, stablehlo::RoundNearestEvenOp); - INSERT_UNARY_FPONLY_PATTERN(AtenAsinOp, chlo::AsinOp); - INSERT_UNARY_FPONLY_PATTERN(AtenAcosOp, chlo::AcosOp); - INSERT_UNARY_FPONLY_PATTERN(AtenAtanOp, chlo::AtanOp); #undef INSERT_UNARY_FPONLY_PATTERN +#define INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenOp, StablehloOp) \ + target.addIllegalOp(); \ + patterns.add>( \ + typeConverter, context) + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLogOp, stablehlo::LogOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLog1pOp, stablehlo::Log1pOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenExpOp, stablehlo::ExpOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSqrtOp, stablehlo::SqrtOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenRsqrtOp, stablehlo::RsqrtOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSigmoidOp, stablehlo::LogisticOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenTanhOp, stablehlo::TanhOp); + 
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSinOp, stablehlo::SineOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenCosOp, stablehlo::CosineOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAsinOp, chlo::AsinOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAcosOp, chlo::AcosOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAtanOp, chlo::AtanOp); +#undef INSERT_UNARY_PROMOTE_TO_FP_PATTERN + #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ target.addIllegalOp(); \ patterns.add>(typeConverter, \ diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index a11e36060763..80ab03566bf0 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -620,17 +620,14 @@ "DiagonalModule_with_offset", "DivFloatModule_basic", "DivIntModule_basic", - "ElementwiseAcosIntModule_basic", "ElementwiseAcoshIntModule_basic", "ElementwiseAcoshModule_basic", "ElementwiseAddScalar_NumToTensorFloat_Module_basic", - "ElementwiseAsinIntModule_basic", "ElementwiseAsinhIntModule_basic", "ElementwiseAsinhModule_basic", "ElementwiseAtan2FloatIntModule_basic", "ElementwiseAtan2TensorFloatModule_basic", "ElementwiseAtan2TensorIntModule_basic", - "ElementwiseAtanTensorIntModule_basic", "ElementwiseAtanhIntModule_basic", "ElementwiseAtanhModule_basic", "ElementwiseBitwiseLeftShiftInt32Module_basic", @@ -639,7 +636,6 @@ "ElementwiseBitwiseRightShiftInt32Module_basic", "ElementwiseBitwiseRightShiftInt64Module_basic", "ElementwiseBitwiseRightShiftInt8Module_basic", - "ElementwiseCosIntModule_basic", "ElementwiseCoshIntModule_basic", "ElementwiseCoshModule_basic", "ElementwiseDequantizePerChannelModule_basic", @@ -649,22 +645,16 @@ "ElementwiseLog10Module_basic", "ElementwiseLog2IntModule_basic", "ElementwiseLog2Module_basic", - "ElementwiseLogIntModule_basic", "ElementwiseLogitModule_basic", "ElementwiseMulTensorComplexModule_basic", "ElementwisePowScalarModule_basic", "ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic", "ElementwiseReciprocalIntModule_basic", - "ElementwiseRsqrtIntModule_basic", - "ElementwiseSigmoidIntModule_basic", - "ElementwiseSinIntModule_basic", - "ElementwiseSqrtIntModule_basic", "ElementwiseTanIntModule_basic", "ElementwiseTanModule_basic", "ElementwiseTernaryModule_basic", "ElementwiseToDtypeI64ToUI8Module_basic", - "ElementwiseUnaryIntModule_basic", "EmptyModule_uint8", "EqIntModule_basic", "FakeQuantizePerTensorAffineDynamicShapeModule_basic", @@ -1464,6 +1454,17 @@ "ElementwiseAcosModule_basic", "ElementwiseAsinModule_basic", "ElementwiseAtanTensorFloatModule_basic", + "ElementwiseAcosIntModule_basic", + "ElementwiseAsinIntModule_basic", + "ElementwiseAtanTensorIntModule_basic", + "ElementwiseCosIntModule_basic", + "ElementwiseExpIntModule_basic", + "ElementwiseLogIntModule_basic", + "ElementwiseRsqrtIntModule_basic", + "ElementwiseSigmoidIntModule_basic", + "ElementwiseSinIntModule_basic", + "ElementwiseSqrtIntModule_basic", + "ElementwiseUnaryIntModule_basic", } STABLEHLO_CRASHING_SET = { From c1967b607fa567990b2658a8b6db8ded65109613 Mon Sep 17 00:00:00 2001 From: Xinyu Yang Date: Tue, 23 Apr 2024 19:06:55 +0800 Subject: [PATCH 20/34] [Stablehlo] add AtenLog10Op, AtenLog2Op lowering to stablehlo (#3208) --- lib/Conversion/TorchToStablehlo/Basic.cpp | 45 +++++++++++++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 8 ++-- 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 
1c4bc27537c0..0c3cc85b7dd4 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -1060,6 +1060,49 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// AtenLog2Op +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenLog2Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.getSelf(); + auto inputTy = input.getType().template dyn_cast(); + if (!inputTy) { + return op.emitError("only ranked tensor type is supported."); + } + auto outTy = getTypeConverter()->convertType(op.getType()).cast(); + input = hlo::promoteType(rewriter, op.getLoc(), input, outTy); + + auto two = getConstantLike(rewriter, op.getLoc(), 2.0, input); + auto log2Op = rewriter.create(op.getLoc(), two); + auto logInputOp = rewriter.create(op.getLoc(), input); + + rewriter.replaceOpWithNewOp(op, outTy, logInputOp, log2Op); + return success(); +} + +// AtenLog10Op +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenLog10Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.getSelf(); + auto inputTy = input.getType().template dyn_cast(); + if (!inputTy) { + return op.emitError("only ranked tensor type is supported."); + } + + auto outTy = getTypeConverter()->convertType(op.getType()).cast(); + input = hlo::promoteType(rewriter, op.getLoc(), input, outTy); + + auto ten = getConstantLike(rewriter, op.getLoc(), 10.0, input); + auto log10Op = rewriter.create(op.getLoc(), ten); + auto logInputOp = rewriter.create(op.getLoc(), input); + + rewriter.replaceOpWithNewOp(op, outTy, logInputOp, log10Op); + return success(); +} + // AtenErfOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -2028,6 +2071,8 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenReluOp); INSERT_ATENOP_PATTERN(AtenGeluOp); + INSERT_ATENOP_PATTERN(AtenLog2Op); + INSERT_ATENOP_PATTERN(AtenLog10Op); INSERT_ATENOP_PATTERN(AtenErfOp); INSERT_ATENOP_PATTERN(AtenGeluBackwardOp); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 80ab03566bf0..11f3d5a839a0 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -641,10 +641,6 @@ "ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerTensorModule_basic", "ElementwiseErfIntModule_basic", - "ElementwiseLog10IntModule_basic", - "ElementwiseLog10Module_basic", - "ElementwiseLog2IntModule_basic", - "ElementwiseLog2Module_basic", "ElementwiseLogitModule_basic", "ElementwiseMulTensorComplexModule_basic", "ElementwisePowScalarModule_basic", @@ -1046,6 +1042,10 @@ "ElementwiseGeluModule_basic", "ElementwiseLeakyReluStaticModule_basic", "ElementwiseLogModule_basic", + "ElementwiseLog10Module_basic", + "ElementwiseLog2Module_basic", + "ElementwiseLog10IntModule_basic", + "ElementwiseLog2IntModule_basic", "ElementwiseNanToNumModule_Basic", "ElementwiseNeFloatTensorStaticModule_basic", "ElementwiseNeIntTensorStaticModule_basic", From db3842f2e80fd070925543065e0b10277ae9c227 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Tue, 23 Apr 2024 19:54:58 +0800 Subject: [PATCH 21/34] [Stablehlo] support lowering sinh & cosh to stablehlo (#3213) --- lib/Conversion/TorchToStablehlo/Basic.cpp | 2 + .../Transforms/AbstractInterpLibrary.cpp | 9 ++++ projects/pt1/e2e_testing/xfail_sets.py | 8 +++- .../build_tools/abstract_interp_lib_gen.py | 8 ++++ .../test_suite/elementwise.py | 44 
+++++++++++++++++++ 5 files changed, 69 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 0c3cc85b7dd4..84bccae83d81 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -1981,7 +1981,9 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSinOp, stablehlo::SineOp); INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenCosOp, stablehlo::CosineOp); INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAsinOp, chlo::AsinOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSinhOp, chlo::SinhOp); INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAcosOp, chlo::AcosOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenCoshOp, chlo::CoshOp); INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAtanOp, chlo::AtanOp); #undef INSERT_UNARY_PROMOTE_TO_FP_PATTERN diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index a1cc7ddf6ea6..5df01076f846 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6332,6 +6332,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.sinh\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.asin\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -9699,6 +9703,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " return %1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sinh\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.asin\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 11f3d5a839a0..ccef6e1060d8 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -636,8 +636,6 @@ "ElementwiseBitwiseRightShiftInt32Module_basic", "ElementwiseBitwiseRightShiftInt64Module_basic", "ElementwiseBitwiseRightShiftInt8Module_basic", - "ElementwiseCoshIntModule_basic", - "ElementwiseCoshModule_basic", "ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerTensorModule_basic", "ElementwiseErfIntModule_basic", @@ -1465,6 +1463,10 @@ "ElementwiseSinIntModule_basic", "ElementwiseSqrtIntModule_basic", "ElementwiseUnaryIntModule_basic", + "ElementwiseCoshIntModule_basic", + "ElementwiseCoshModule_basic", + "ElementwiseSinhIntModule_basic", + "ElementwiseSinhModule_basic", } STABLEHLO_CRASHING_SET = { @@ -2326,6 +2328,8 @@ 
"ElementwiseBitwiseRightShiftInt8Module_basic", "ElementwiseBitwiseXorModule_basic", "ElementwiseBitwiseXorStaticShapeModule_basic", + "ElementwiseSinhIntModule_basic", + "ElementwiseSinhModule_basic", "ElementwiseCoshIntModule_basic", "ElementwiseCoshModule_basic", "ElementwiseDequantizePerChannelModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 8b32ff602697..2e1c64d8a737 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -121,6 +121,9 @@ def aten〇fake_quantize_per_tensor_affine〡shape(self: List[int], scale: float def aten〇sin〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇sinh〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇asin〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -2014,6 +2017,11 @@ def aten〇sin〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return _get_dtype_of_floating_point_op(self_dtype) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇sinh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇asin〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index e4a185189354..5010c8b9936f 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -63,6 +63,50 @@ def ElementwiseUnaryIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseSinhModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.sinh(a) + + +@register_test_case(module_factory=lambda: ElementwiseSinhModule()) +def ElementwiseSinhModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + +class ElementwiseSinhIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.sinh(a) + + +@register_test_case(module_factory=lambda: ElementwiseSinhIntModule()) +def ElementwiseSinhIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + + +# ============================================================================== + + class ElementwiseCoshModule(torch.nn.Module): def __init__(self): From ddb29c2c026980ea9684787ca2f940d6e5147714 Mon Sep 17 00:00:00 2001 From: jinchen <49575973+jinchen62@users.noreply.github.com> Date: Tue, 23 Apr 2024 09:42:02 -0700 Subject: [PATCH 22/34] [onnx] Add OnnxToTorch support for `onnx.ConvInteger` (#3179) All e2e iree tests compiled, but they have the run issue of 
mismatch of dtype like the following ``` expected: 1x1x2x2xsi32=[[[12 16][24 28]]] actual: 1x1x2x2xi32=[[[12 16][24 28]]] ``` --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 151 ++++++++++++++++++ .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 62 +++++++ 2 files changed, 213 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index b16e76e3afe5..e13597d0e8ee 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -981,6 +981,157 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return success(); }); patterns.onOp( + "ConvInteger", 10, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + std::string autoPad; + if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) + return failure(); + if (autoPad != "NOTSET") + // TODO: Add support for `auto_pad` != "NOTSET" + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: auto_pad != NOTSET"); + + Torch::ValueTensorType resultType; + Value input, weight, inputZp, weightZp; + int64_t group; + if (binder.tensorOperandAtIndex(input, 0) || + binder.tensorOperandAtIndex(weight, 1) || + binder.s64IntegerAttr(group, "group", 1) || + binder.tensorResultType(resultType)) + return failure(); + + auto inputTy = dyn_cast(input.getType()); + auto weightTy = dyn_cast(weight.getType()); + if (!weightTy || !weightTy.hasSizes()) + return rewriter.notifyMatchFailure( + binder.op, "Expected weight type having sizes"); + ArrayRef weightShape = weightTy.getSizes(); + SmallVector kernelShape; + if (binder.s64IntegerArrayAttr(kernelShape, "kernel_shape", {})) + return failure(); + if (kernelShape.size()) { + if (kernelShape.size() != weightShape.size() - 2) { + return rewriter.notifyMatchFailure( + binder.op, + "unsupported conversion: kernel_shape list size should have " + "number of values equal to weight_rank - 2"); + } else { + for (unsigned i = 0; i < kernelShape.size(); i++) { + if (weightShape[i + 2] != kernelShape[i]) + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: kernel_shape value " + "should be equal to the weight tensor shape"); + } + } + } + + // Determine the rank of input tensor. + std::optional maybeRank = Torch::getTensorRank(input); + if (!maybeRank) + return rewriter.notifyMatchFailure(binder.op, + "Unimplemented: unranked tensor"); + unsigned rank = *maybeRank; + + SmallVector padding, strides, dilations; + SmallVector defaultPadding(rank - 2, 0), + defaultStrides(rank - 2, 1), defaultDilations(rank - 2, 1); + // Padding for the beginning and ending along each spatial axis, it can + // take any value greater than or equal to 0. The value represent the + // number of pixels added to the beginning and end part of the + // corresponding axis. pads format should be as follow [x1_begin, + // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added + // at the beginning of axis i and xi_end, the number of pixels added at + // the end of axis i. 
+ if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) + return failure(); + if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) + return rewriter.notifyMatchFailure( + binder.op, "padding list size does not match the number of axes"); + if (binder.s64IntegerArrayAttr(dilations, "dilations", + defaultDilations)) + return failure(); + if (dilations.size() != rank - 2) + return rewriter.notifyMatchFailure( + binder.op, + "dilations list size does not match the number of axes"); + if (binder.s64IntegerArrayAttr(strides, "strides", defaultStrides)) + return failure(); + if (strides.size() != rank - 2) + return rewriter.notifyMatchFailure( + binder.op, "strides list size does not match the number of axes"); + + Value scale = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(1.0)); + if (binder.tensorOperandAtIndex(inputZp, 2)) { + inputZp = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + } else { + inputZp = rewriter.create( + binder.getLoc(), rewriter.getType(), inputZp); + } + if (binder.tensorOperandAtIndex(weightZp, 3)) + weightZp = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + // TODO: support per channel quantization if weightZp is a 1-D tensor + if (auto zpTy = dyn_cast(weightZp.getType())) { + for (auto dim : zpTy.getSizes()) + if (dim != 1) + return failure(); + weightZp = rewriter.create( + binder.getLoc(), rewriter.getType(), weightZp); + } + + SmallVector cstPadding; + if (padding.size() != 2 * (rank - 2)) { + for (int64_t i : padding) { + cstPadding.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + } else { + for (unsigned i = 0; i < padding.size() / 2; i++) { + if (padding[i] != padding[i + (padding.size() / 2)]) + // TODO: Add support for different padding values for the + // beginning and ending along each spatial axis + return rewriter.notifyMatchFailure( + binder.op, + "unsupported conversion: padding values for the beginning " + "and ending along each spatial axis must be equal"); + cstPadding.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); + } + } + + Value paddingList = rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + cstPadding); + Value dilationsList = + createConstantIntList(binder, rewriter, dilations); + Value stridesList = createConstantIntList(binder, rewriter, strides); + Value outputPaddingList = + createConstantIntList(binder, rewriter, {0, 0}); + Value transposed = + rewriter.create(binder.getLoc(), false); + Value bias = rewriter.create(binder.getLoc()); + Value cstGroup = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(group)); + + Type inputQTy = getQTorchTypeFromTorchIntType(inputTy); + Type weightQTy = getQTorchTypeFromTorchIntType(weightTy); + input = rewriter.create( + binder.getLoc(), inputQTy, input, scale, inputZp); + weight = rewriter.create( + binder.getLoc(), weightQTy, weight, scale, weightZp); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, weight, bias, stridesList, + paddingList, dilationsList, transposed, outputPaddingList, + cstGroup); + return success(); + }); + patterns.onOp( "ConvTranspose", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { std::string autoPad; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 85e2832ac392..95a74656c5c5 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir 
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -938,6 +938,68 @@ func.func @test_conv_with_bias_strides_padding(%arg0: !torch.vtensor<[?,?,224,22 // ----- +// CHECK-LABEL: @test_convinteger_without_padding +func.func @test_convinteger_without_padding(%arg0: !torch.vtensor<[1,1,3,3],ui8>, %arg1: !torch.vtensor<[1,1,2,2],ui8>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1],ui8>) -> !torch.vtensor<[1,1,2,2],si32> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00 + // CHECK: %[[INPUT_ZP:.*]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int + // CHECK: %[[WEIGHT_ZP:.*]] = torch.aten.item %arg3 : !torch.vtensor<[1],ui8> -> !torch.int + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_0]], %[[C1_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[C1_3:.*]] = torch.constant.int 1 + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1_2]], %[[C1_3]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[C0_1:.*]] = torch.constant.int 0 + // CHECK: %[[C0_2:.*]] = torch.constant.int 0 + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0_1]], %[[C0_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: %[[INPUT:.*]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[INPUT_ZP]] : !torch.vtensor<[1,1,3,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,3,3],!torch.quint8> + // CHECK: %[[WEIGHT:.*]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[SCALE]], %[[WEIGHT_ZP]] : !torch.vtensor<[1,1,2,2],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,2,2],!torch.quint8> + // CHECK: torch.aten.convolution %[[INPUT]], %[[WEIGHT]], %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],!torch.quint8>, !torch.vtensor<[1,1,2,2],!torch.quint8>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,2,2],si32> + %none = torch.constant.none + %0 = torch.operator "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[1,1,3,3],ui8>, !torch.vtensor<[1,1,2,2],ui8>, !torch.vtensor<[],ui8>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[1,1,2,2],si32> + return %0 : !torch.vtensor<[1,1,2,2],si32> +} + +// ----- + +// CHECK-LABEL: @test_convinteger_with_padding +func.func @test_convinteger_with_padding(%arg0: !torch.vtensor<[1,1,3,3],ui8>, %arg1: !torch.vtensor<[1,1,2,2],ui8>, %arg2: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,4,4],si32> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00 
+ // CHECK: %[[INPUT_ZP:.*]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int
+ // CHECK: %[[WEIGHT_ZP:.*]] = torch.constant.int 0
+ // CHECK: %[[C1_0:.*]] = torch.constant.int 1
+ // CHECK: %[[C1_1:.*]] = torch.constant.int 1
+ // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1_0]], %[[C1_1]] : (!torch.int, !torch.int) -> !torch.list
+ // CHECK: %[[C1_2:.*]] = torch.constant.int 1
+ // CHECK: %[[C1_3:.*]] = torch.constant.int 1
+ // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_2]], %[[C1_3]] : (!torch.int, !torch.int) -> !torch.list
+ // CHECK: %[[C1_4:.*]] = torch.constant.int 1
+ // CHECK: %[[C1_5:.*]] = torch.constant.int 1
+ // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1_4]], %[[C1_5]] : (!torch.int, !torch.int) -> !torch.list
+ // CHECK: %[[C0:.*]] = torch.constant.int 0
+ // CHECK: %[[C0_0:.*]] = torch.constant.int 0
+ // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list
+ // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false
+ // CHECK: %[[BIAS:.*]] = torch.constant.none
+ // CHECK: %[[GROUPS:.*]] = torch.constant.int 1
+ // CHECK: %[[INPUT:.*]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[INPUT_ZP]] : !torch.vtensor<[1,1,3,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,3,3],!torch.quint8>
+ // CHECK: %[[WEIGHT:.*]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[SCALE]], %[[WEIGHT_ZP]] : !torch.vtensor<[1,1,2,2],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,2,2],!torch.quint8>
+ // CHECK: torch.aten.convolution %[[INPUT]], %[[WEIGHT]], %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],!torch.quint8>, !torch.vtensor<[1,1,2,2],!torch.quint8>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,4,4],si32>
+ %none = torch.constant.none
+ %0 = torch.operator "onnx.ConvInteger"(%arg0, %arg1, %arg2) {torch.onnx.pads = [1 : si64, 1 : si64, 1 : si64, 1 : si64]} : (!torch.vtensor<[1,1,3,3],ui8>, !torch.vtensor<[1,1,2,2],ui8>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,4,4],si32>
+ return %0 : !torch.vtensor<[1,1,4,4],si32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_convtranspose_dilations
 func.func @test_convtranspose_dilations(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,1,2,2],f32>) -> !torch.vtensor<[1,1,5,5],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
 // CHECK: %[[C0:.*]] = torch.constant.int 0

From 61e6312c87d33a4d3d333e9658606d9a733fd5f5 Mon Sep 17 00:00:00 2001
From: jinchen <49575973+jinchen62@users.noreply.github.com>
Date: Tue, 23 Apr 2024 10:16:08 -0700
Subject: [PATCH 23/34] Support select_last_index attribute of onnx argmax op
 (#3192)

The tests listed in https://github.com/nod-ai/SHARK-Turbine/issues/635 all
compiled, but at runtime they hit a dtype mismatch between the signless (`i`)
and signed (`si`) integer types.
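
For reference, the rewrite used below can be sketched in plain PyTorch (a
sketch only, not code from this patch; the helper name is made up, and it
assumes `torch.argmax` breaks ties toward the first index, which is what the
emitted `flip` + `argmax` + `sub` + `abs` sequence relies on):

```
import torch

def argmax_select_last_index(x: torch.Tensor, axis: int, keepdim: bool = False):
    # Reverse along the reduction axis so the last maximum becomes the first.
    flipped = torch.flip(x, dims=[axis])
    idx = torch.argmax(flipped, dim=axis, keepdim=keepdim)
    # Index k in the flipped tensor maps to (size - 1 - k) in the original;
    # since k <= size - 1, abs(k - (size - 1)) == size - 1 - k, matching the
    # torch.aten.sub.Scalar + torch.aten.abs pair emitted by the conversion.
    return torch.abs(idx - (x.shape[axis] - 1))
```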
--- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 32 ++++++++++++----- .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 36 +++++++++++++++++++ .../unsupported_simple_ops.mlir | 10 ------ 3 files changed, 59 insertions(+), 19 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index e13597d0e8ee..9559e28ee2ac 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -101,17 +101,11 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.s64BoolAttr(selectLastIndex, "select_last_index", false)) return failure(); - if (selectLastIndex) { - // TODO: Figure out how to support this case. Need to add a reverse - // or something. - return rewriter.notifyMatchFailure( - binder.op, "unsupported conversion: select_last_index=true"); - } - // ONNX allows negative axis. + auto operandSizes = + cast(operand.getType()).getSizes(); if (axis < 0) - axis += - cast(operand.getType()).getSizes().size(); + axis += operandSizes.size(); Value constAxis = rewriter.create( binder.getLoc(), rewriter.getType(), @@ -119,6 +113,26 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( Value constKeepDims = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getBoolAttr(keepDims)); + + if (selectLastIndex) { + Value dims = createConstantIntList(binder, rewriter, {axis}); + auto operandTy = dyn_cast(operand.getType()); + operand = rewriter.create( + binder.getLoc(), operandTy, operand, dims); + Value argmax = rewriter.create( + binder.getLoc(), resultType, operand, constAxis, constKeepDims); + Value offset = rewriter.create( + binder.getLoc(), + rewriter.getI64IntegerAttr(operandSizes[axis] - 1)); + Value alpha = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value sub = rewriter.create( + binder.getLoc(), resultType, argmax, offset, alpha); + rewriter.replaceOpWithNewOp(binder.op, resultType, + sub); + return success(); + } + rewriter.replaceOpWithNewOp( binder.op, resultType, operand, constAxis, constKeepDims); return success(); diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 95a74656c5c5..b776145834d4 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -74,6 +74,24 @@ func.func @test_argmax_negative_axis_keepdims_example(%arg0: !torch.vtensor<[2,2 // ----- +// CHECK-LABEL: @test_argmax_negative_axis_keepdims_random_select_last_index +func.func @test_argmax_negative_axis_keepdims_random_select_last_index(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[TRUE:.*]] = torch.constant.bool true + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[C2_0]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP:.*]] = torch.aten.flip %arg0, %[[DIMS]] : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[2,3,4],f32> + // CHECK: %[[ARGMAX:.*]] = torch.aten.argmax %[[FLIP]], %[[C2]], %[[TRUE]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,3,1],si64> + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[C1:.*]] = torch.constant.int 
1 + // CHECK: %[[SUB:.*]] = torch.aten.sub.Scalar %[[ARGMAX]], %[[C3]], %[[C1]] : !torch.vtensor<[2,3,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,3,1],si64> + // CHECK: %[[ABS:.*]] = torch.aten.abs %[[SUB]] : !torch.vtensor<[2,3,1],si64> -> !torch.vtensor<[2,3,1],si64> + %0 = torch.operator "onnx.ArgMax"(%arg0) {torch.onnx.axis = -1 : si64, torch.onnx.keepdims = 1 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1],si64> + return %0 : !torch.vtensor<[2,3,1],si64> +} + +// ----- + // CHECK-LABEL: @test_argmax_no_keepdims_example func.func @test_argmax_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 1 @@ -85,6 +103,24 @@ func.func @test_argmax_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> // ----- +// CHECK-LABEL: @test_argmax_no_keepdims_random_select_last_index +func.func @test_argmax_no_keepdims_random_select_last_index(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,4],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[C1_0]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP:.*]] = torch.aten.flip %arg0, %[[DIMS]] : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[2,3,4],f32> + // CHECK: %[[ARGMAX:.*]] = torch.aten.argmax %[[FLIP]], %[[C1]], %[[FALSE]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,4],si64> + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[SUB:.*]] = torch.aten.sub.Scalar %[[ARGMAX]], %[[C2]], %[[C1_1]] : !torch.vtensor<[2,4],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64> + // CHECK: %[[ABS:.*]] = torch.aten.abs %[[SUB]] : !torch.vtensor<[2,4],si64> -> !torch.vtensor<[2,4],si64> + %0 = torch.operator "onnx.ArgMax"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,4],si64> + return %0 : !torch.vtensor<[2,4],si64> +} + +// ----- + // CHECK-LABEL: @test_argmin_default_axis_example func.func @test_argmin_default_axis_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[1,2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 0 diff --git a/test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir b/test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir index 22d5e2d35183..480a7dbb2dac 100644 --- a/test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir +++ b/test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir @@ -1,15 +1,5 @@ // RUN: torch-mlir-opt <%s -split-input-file -verify-diagnostics -convert-torch-onnx-to-torch -module { - func.func @test_argmax_no_keepdims_random_select_last_index(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,4],si64> attributes 
{torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // TODO: Unsupported torch.onnx.select_last_index
-  // expected-error @+1 {{failed to legalize operation 'torch.operator'}}
-  %0 = torch.operator "onnx.ArgMax"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,4],si64>
-  return %0 : !torch.vtensor<[2,4],si64>
-  }
-}
-
-// -----

 func.func @test_argmin_no_keepdims_example_select_last_index(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // TODO: Unsupported torch.onnx.select_last_index
   // expected-error @+1 {{failed to legalize operation 'torch.operator'}}
   %0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64>
   return %0 : !torch.vtensor<[2],si64>
 }

From 09d42044b4d9d2cd5399d9d5ff5ae97501314db5 Mon Sep 17 00:00:00 2001
From: jinchen <49575973+jinchen62@users.noreply.github.com>
Date: Tue, 23 Apr 2024 10:43:38 -0700
Subject: [PATCH 24/34] Support select_last_index attribute of onnx argmin op
 (#3212)

The tests listed in https://github.com/nod-ai/SHARK-Turbine/issues/648 all
compiled and the result values match, but at runtime they hit the same dtype
mismatch between the signless (`i`) and signed (`si`) integer types.

---
 .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 32 ++++++++++++-----
 .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 36 +++++++++++++++++++
 .../unsupported_simple_ops.mlir | 8 -----
 3 files changed, 59 insertions(+), 17 deletions(-)
 delete mode 100644 test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index 9559e28ee2ac..14aa41bef349 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -151,17 +151,11 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         binder.s64BoolAttr(selectLastIndex, "select_last_index", false))
       return failure();

-    if (selectLastIndex) {
-      // TODO: Figure out how to support this case. Need to add a reverse
-      // or something.
-      return rewriter.notifyMatchFailure(
-          binder.op, "unsupported conversion: select_last_index=true");
-    }
-
     // ONNX allows negative axis.
+ auto operandSizes = + cast(operand.getType()).getSizes(); if (axis < 0) - axis += - cast(operand.getType()).getSizes().size(); + axis += operandSizes.size(); Value constAxis = rewriter.create( binder.getLoc(), rewriter.getType(), @@ -169,6 +163,26 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( Value constKeepDims = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getBoolAttr(keepDims)); + + if (selectLastIndex) { + Value dims = createConstantIntList(binder, rewriter, {axis}); + auto operandTy = dyn_cast(operand.getType()); + operand = rewriter.create( + binder.getLoc(), operandTy, operand, dims); + Value argmin = rewriter.create( + binder.getLoc(), resultType, operand, constAxis, constKeepDims); + Value offset = rewriter.create( + binder.getLoc(), + rewriter.getI64IntegerAttr(operandSizes[axis] - 1)); + Value alpha = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value sub = rewriter.create( + binder.getLoc(), resultType, argmin, offset, alpha); + rewriter.replaceOpWithNewOp(binder.op, resultType, + sub); + return success(); + } + rewriter.replaceOpWithNewOp( binder.op, resultType, operand, constAxis, constKeepDims); return success(); diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index b776145834d4..33d8d8f658b2 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -143,6 +143,24 @@ func.func @test_argmin_negative_axis_keepdims_example(%arg0: !torch.vtensor<[2,2 // ----- +// CHECK-LABEL: @test_argmin_negative_axis_keepdims_random_select_last_index +func.func @test_argmin_negative_axis_keepdims_random_select_last_index(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[TRUE:.*]] = torch.constant.bool true + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[C2_0]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP:.*]] = torch.aten.flip %arg0, %[[DIMS]] : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[2,3,4],f32> + // CHECK: %[[ARGMIN:.*]] = torch.aten.argmin %[[FLIP]], %[[C2]], %[[TRUE]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,3,1],si64> + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[SUB:.*]] = torch.aten.sub.Scalar %[[ARGMIN]], %[[C3]], %[[C1]] : !torch.vtensor<[2,3,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,3,1],si64> + // CHECK: %[[ABS:.*]] = torch.aten.abs %[[SUB]] : !torch.vtensor<[2,3,1],si64> -> !torch.vtensor<[2,3,1],si64> + %0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.axis = -1 : si64, torch.onnx.keepdims = 1 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1],si64> + return %0 : !torch.vtensor<[2,3,1],si64> +} + +// ----- + // CHECK-LABEL: @test_argmin_no_keepdims_example func.func @test_argmin_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = 
torch.constant.int 1 @@ -154,6 +172,24 @@ func.func @test_argmin_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> // ----- +// CHECK-LABEL: @test_argmin_no_keepdims_example_select_last_index +func.func @test_argmin_no_keepdims_example_select_last_index(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[C1_0]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP:.*]] = torch.aten.flip %arg0, %[[DIMS]] : !torch.vtensor<[2,2],f32>, !torch.list -> !torch.vtensor<[2,2],f32> + // CHECK: %[[ARGMIN:.*]] = torch.aten.argmin %[[FLIP]], %[[C1]], %[[FALSE]] : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2],si64> + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[SUB:.*]] = torch.aten.sub.Scalar %[[ARGMIN]], %[[C1_1]], %[[C1_2]] : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2],si64> + // CHECK: %[[ABS:.*]] = torch.aten.abs %[[SUB]] : !torch.vtensor<[2],si64> -> !torch.vtensor<[2],si64> + %0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> + return %0 : !torch.vtensor<[2],si64> +} + +// ----- + // CHECK-LABEL: @test_atan func.func @test_atan(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.atan %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> diff --git a/test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir b/test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir deleted file mode 100644 index 480a7dbb2dac..000000000000 --- a/test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir +++ /dev/null @@ -1,8 +0,0 @@ -// RUN: torch-mlir-opt <%s -split-input-file -verify-diagnostics -convert-torch-onnx-to-torch - -func.func @test_argmin_no_keepdims_example_select_last_index(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // TODO: Unsupported torch.onnx.select_last_index - // expected-error @+1 {{failed to legalize operation 'torch.operator'}} - %0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> - return %0 : !torch.vtensor<[2],si64> -} From a8ba865fcab6475ff58c2beb14a1823fc25314c2 Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Tue, 23 Apr 2024 13:01:36 -0500 Subject: [PATCH 25/34] [torch] Adds Quantization Support for `aten.relu` (#3177) A choice was made to quantize the return type of Relu with a scale and zero point copied from the input's quantization scheme. 
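
This is numerically sound: with a positive scale, dequantize(q) = scale * (q - zp),
so relu(dequantize(q)) = scale * (max(q, zp) - zp), which is exactly
dequantize(max(q, zp)) under the same scale and zero point. A quick standalone
check of that identity (plain PyTorch, reusing constants from the new e2e
tests; not part of the patch itself):

```
import torch

x = torch.randint(-128, 128, (7, 4), dtype=torch.int8)
scale, zp = 0.0215, -25
dq = torch.dequantize(torch._make_per_tensor_quantized_tensor(x, scale, zp))
ref = torch.relu(dq)
# max(q, zp) on the integer representation, then dequantize as usual:
q_relu = torch.maximum(x.to(torch.int64), torch.tensor(zp, dtype=torch.int64))
assert torch.allclose(ref, (scale * (q_relu - zp)).float())
```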
With this choice, the torch-to-linalg conversion of quantized Relu essentially
computes max(input, zeroPoint) in the elementwise payload.

---
 .../TorchToLinalg/Uncategorized.cpp | 78 +++++++++++++++---
 .../Torch/Transforms/FuseQuantizedOps.cpp | 81 ++++++++++++++++++-
 projects/pt1/e2e_testing/xfail_sets.py | 9 +++
 .../test_suite/elementwise.py | 63 +++++++++++++++
 4 files changed, 216 insertions(+), 15 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 441c76ce7ea4..3c5d6cfaee07 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -56,6 +56,13 @@ static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type,
   llvm_unreachable("Unhandled element type for comparison");
 }

+static Value getZeroPoint(Value value) {
+  if (auto make = value.getDefiningOp()) {
+    return make.getZeroPoint();
+  }
+  return nullptr;
+}
+
 static Value createGreaterThan(OpBuilder &b, Location loc, Type elementalType,
                                Value lhs, Value rhs) {
   return createComparisonTemplate(op)) {
-    if (!relu.getType()
-             .cast()
-             .getDtype()
-             .isa()) {
-      relu.emitError("unimplemented: non-floating point dtype");
+    Value zeroPoint = getZeroPoint(relu.getSelf());
+    Value arg = payloadArgs[0];
+    auto intType = arg.getType().dyn_cast();
+    if (zeroPoint && !intType) {
+      relu.emitError("unimplemented: non-integer quantized Relu.");
       return nullptr;
     }
-    Type elementType = payloadArgs[0].getType();
-    Value constZero =
-        b.create(loc, b.getZeroAttr(elementType));
-    Value pred = b.create(loc, arith::CmpFPredicate::UGT,
-                                           payloadArgs[0], constZero);
-    return b.create(loc, pred, payloadArgs[0], constZero);
+    auto reluTorchType = cast(relu.getType());
+    bool isUnsigned =
+        torch_to_linalg::isUnsignedTorchType(reluTorchType.getDtype());
+    if (zeroPoint) {
+      int64_t zeroPointInt;
+      int64_t width = intType.getWidth();
+      assert(width < 64);
+      int64_t minForIntType = isUnsigned ? 0 : -((int64_t)1 << (width - 1));
+      int64_t maxForIntType = isUnsigned ? ((int64_t)1 << width) - 1
+                                         : ((int64_t)1 << (width - 1)) - 1;
+      // check for constant zero point edge-cases:
+      if (matchPattern(zeroPoint, m_TorchConstantInt(&zeroPointInt))) {
+        if (zeroPointInt > maxForIntType) {
+          // TODO: figure out how to handle this case:
+          // current impl. quantizes output like input.
+          // If zero point > maxForIntType, ordinary relu should return 0.
+          // However, 0 isn't represented in such a quantization scheme.
+          relu.emitError(
+              "unimplemented: quantized relu for zero-point > max qint");
+          return nullptr;
+        }
+        if (zeroPointInt < minForIntType)
+          return arg;
+      }
+      zeroPoint = converter->materializeTargetConversion(
+          b, loc, converter->convertType(zeroPoint.getType()), zeroPoint);
+      auto minForIntTypeValue = b.create(
+          loc, b.getIntegerAttr(zeroPoint.getType(), minForIntType));
+      auto maxForIntTypeValue = b.create(
+          loc, b.getIntegerAttr(zeroPoint.getType(), maxForIntType));
+      auto zpLtMax = b.create(loc, arith::CmpIPredicate::slt,
+                              zeroPoint, maxForIntTypeValue);
+      b.create(
+          loc, zpLtMax,
+          b.getStringAttr("Invalid Quantization: quantized relu with "
+                          "zero-point > max qint"));
+      auto zpLtMin = b.create(loc, arith::CmpIPredicate::slt,
+                              zeroPoint, minForIntTypeValue);
+      zeroPoint = b.create(loc, zpLtMin, minForIntTypeValue,
+                           zeroPoint);
+      zeroPoint = b.create(loc, arg.getType(), zeroPoint);
+    } else {
+      zeroPoint =
+          b.create(loc, b.getZeroAttr(arg.getType()));
+    }
+    Value cmp;
+    if (intType) {
+      auto pred =
+          isUnsigned ?
arith::CmpIPredicate::ugt : arith::CmpIPredicate::sgt; + cmp = b.create(loc, pred, arg, zeroPoint); + } else { + cmp = b.create(loc, arith::CmpFPredicate::UGT, arg, + zeroPoint); + } + return b.create(loc, cmp, arg, zeroPoint); } if (auto round = dyn_cast(op)) { if (!round.getType() diff --git a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp index bff463c4cee6..3b30e9424f44 100644 --- a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp +++ b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp @@ -20,6 +20,13 @@ using namespace mlir::torch::Torch; namespace { +template struct QuantInfo { + static constexpr unsigned operandsToQuantize[2] = {0, 1}; +}; + +template <> struct QuantInfo { + static constexpr unsigned operandsToQuantize[1] = {0}; +}; template class QuantizeOperands : public OpRewritePattern { public: @@ -42,8 +49,9 @@ class QuantizeOperands : public OpRewritePattern { return operand; }; - operands[0] = f(operands[0]); - operands[1] = f(operands[1]); + for (unsigned i : QuantInfo::operandsToQuantize) { + operands[i] = f(operands[i]); + } if (!dequanted) { return rewriter.notifyMatchFailure(op, "no dequantizations found"); @@ -259,6 +267,70 @@ class QuantizeAccumulator : public OpRewritePattern { } }; +// Use for ops which do not manipulate scale/zero point of an input. +template +class QuantizeResultLikeOperand : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(SrcOp op, + PatternRewriter &rewriter) const override { + llvm::SmallVector operands(op->getOperands()); + Value input = operands[0]; + + auto inputType = dyn_cast_or_null(input.getType()); + if (!inputType || !inputType.hasDtype()) + return failure(); + auto qDtype = inputType.getDtype(); + + auto resultTy = dyn_cast_or_null(op.getType()); + if (!resultTy || !resultTy.hasDtype()) + return failure(); + + Type resultETy = resultTy.getDtype(); + if (!isa(resultETy)) + return failure(); + + Value inputScale, inputZeroPoint; + Type definingOpInputType; + if (auto defining = input.template getDefiningOp< + Aten_MakePerTensorQuantizedTensorOp>()) { + inputScale = defining.getScale(); + inputZeroPoint = defining.getZeroPoint(); + definingOpInputType = defining.getSelf().getType(); + } + + auto inputIntReprType = + dyn_cast_or_null(definingOpInputType); + if (!inputScale || !inputZeroPoint || !inputIntReprType || + !inputIntReprType.hasDtype()) + return failure(); + auto intReprDtype = inputIntReprType.getDtype(); + + // set SrcOp type to use quantized dtype from input + auto newResultTy = + rewriter.getType(resultTy.getOptionalSizes(), qDtype); + auto newResult = rewriter.create(op.getLoc(), newResultTy, operands); + + // int repr to get non quantized int type result + auto intReprTy = rewriter.getType( + resultTy.getOptionalSizes(), intReprDtype); + auto intRepr = + rewriter.create(op.getLoc(), intReprTy, newResult); + + // requantize so the scale and zero-point info can be attached + auto quantTy = + rewriter.getType(resultTy.getOptionalSizes(), qDtype); + auto quant = rewriter.create( + op.getLoc(), quantTy, intRepr, inputScale, inputZeroPoint); + + // dequant back to original dtype + auto dequant = + rewriter.create(op.getLoc(), resultTy, quant); + rewriter.replaceOp(op, dequant); + return success(); + } +}; + template class RemoveUnused : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -285,11 +357,12 @@ class FuseQuantizedOpsPass : public FuseQuantizedOpsBase { RemoveUnused, 
RemoveUnused, RemoveUnused, QuantizeOperands, - QuantizeOperands, + QuantizeOperands, QuantizeOperands, QuantizeTransposedOperands, QuantizeAccumulator, QuantizeOperands, QuantizeTransposedOperands, QuantizeAccumulator, - QuantizeBias>(context); + QuantizeResultLikeOperand, QuantizeBias>( + context); GreedyRewriteConfig config; if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index ccef6e1060d8..93269b0653dd 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -331,6 +331,9 @@ "AtenMatmulQint8VV_basic", "AtenMatmulQint8VM_basic", "AtenMatmulQint8_basic", + "QuantizedReluInt32_basic", + "QuantizedReluInt8_basic", + "QuantizedReluUint8_basic", "Conv2dQInt8Module_basic", # Dynamo not supporting conv_tbc @@ -413,6 +416,9 @@ 'AtenMmQMixedSigni8_basic', 'AtenMmQint8_basic', 'AtenMmQuint8_basic', + "QuantizedReluInt32_basic", + "QuantizedReluInt8_basic", + "QuantizedReluUint8_basic", 'AtenSubFloatModule_basic', 'BincountMinlengthModule_basic', 'BincountModule_basic', @@ -2466,6 +2472,9 @@ "PrimsSqueezeModule_basic", "PrimsViewOfModule_basic", "PrimsViewOfZeroRankModule_basic", + "QuantizedReluInt8_basic", + "QuantizedReluInt32_basic", + "QuantizedReluUint8_basic", "RandIntDtypeModule_basic", "RandIntModule_basic", "RandIntPinMemoryModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 5010c8b9936f..b365ac54f3d4 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -705,6 +705,69 @@ def ElementwiseReluModule_basic(module, tu: TestUtils): # ============================================================================== +class QuantizedReluInt8(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int8, True), + ]) + def forward(self, x): + qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25) + qx = torch.dequantize(qx) + return torch.relu(qx) + +@register_test_case(module_factory=lambda: QuantizedReluInt8()) +def QuantizedReluInt8_basic(module, tu: TestUtils): + module.forward(tu.randint(7, 4, low=-128, high=127).to(torch.int8)) + +# ============================================================================== + +class QuantizedReluUint8(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.uint8, True), + ]) + def forward(self, x): + qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, 190) + qx = torch.dequantize(qx) + return torch.relu(qx) + +@register_test_case(module_factory=lambda: QuantizedReluUint8()) +def QuantizedReluUint8_basic(module, tu: TestUtils): + module.forward(tu.randint(7, 4, low=0, high=255).to(torch.uint8)) + +# ============================================================================== + +class QuantizedReluInt32(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, x): + qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, 190) + qx = torch.dequantize(qx) + return torch.relu(qx) + +@register_test_case(module_factory=lambda: QuantizedReluInt32()) +def QuantizedReluInt32_basic(module, tu: TestUtils): + module.forward(tu.randint(7, 4, low=(-2**31), 
high=(2**31 - 1)).to(torch.int32)) + +# ============================================================================== + class ElementwiseRelu6Module(torch.nn.Module): From 4da3d714cc0c1f5623931293513aaa3df0bdd00d Mon Sep 17 00:00:00 2001 From: Xinyu Yang Date: Wed, 24 Apr 2024 11:14:04 +0800 Subject: [PATCH 26/34] [Torch] Support AtenProdOp on linalg and stablehlo (#3215) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 24 ++++ lib/Conversion/TorchToLinalg/Reduction.cpp | 12 +- lib/Conversion/TorchToStablehlo/Reduction.cpp | 89 ++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 23 ++++ projects/pt1/e2e_testing/xfail_sets.py | 14 +++ .../build_tools/abstract_interp_lib_gen.py | 15 +++ .../build_tools/torch_ods_gen.py | 1 + .../test_suite/reduction.py | 113 ++++++++++++++++++ 8 files changed, 286 insertions(+), 5 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 8f949d9ba195..0e9318753b33 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -10713,6 +10713,30 @@ def Torch_AtenProdDimIntOp : Torch_Op<"aten.prod.dim_int", [ }]; } +def Torch_AtenProdOp : Torch_Op<"aten.prod", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::prod : (Tensor, int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalIntType:$dtype + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenProdOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenProdOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenMaxOp : Torch_Op<"aten.max", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index bd8b1fc6bfb1..60d631e82aa3 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -298,7 +298,7 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc, if (isa(op)) return b.create(loc, b.getZeroAttr(elementType)); - if (isa(op)) { + if (isa(op)) { if (isa(elementType)) return b.create(loc, b.getFloatAttr(elementType, 1.0)); else if (isa(elementType)) @@ -362,7 +362,7 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, return b.create(loc, self, result); else if (isa(resultElementType)) return b.create(loc, self, result); - } else if (isa(op)) { + } else if (isa(op)) { Value self = convertScalarToDtype(b, loc, payloadArgs[0], resultElementType); Value result = payloadArgs[1]; @@ -510,12 +510,13 @@ class ConvertReductionOp : public ConversionPattern { ConversionPatternRewriter &rewriter) const { auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}}; - if (isa(op)) { + if (isa( + op)) { opInfo.tensorOperand = operands[0]; auto inputType = opInfo.tensorOperand.getType().cast(); - // `AtenSumOp`, `AtenMaxOp`, and `AtenMinOp` each reduce along all the - // dimensions of the input tensor. + // `AtenSumOp`, `AtenProdOp`, `AtenMaxOp`, and `AtenMinOp` each reduce + // along all the dimensions of the input tensor. 
for (int64_t i = 0; i < inputType.getRank(); i++) opInfo.dimSet.insert(i); @@ -715,6 +716,7 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality( patterns.add>(typeConverter, context); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index c525c8b40de5..367acefc9210 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -89,6 +89,21 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, } } + if (isa(op)) { + if (isa(elementTy)) { + APFloat one(cast(elementTy).getFloatSemantics(), 1); + auto constAttr = DenseElementsAttr::get(constType, one); + return rewriter.create(op->getLoc(), constType, + constAttr); + } else if (isa(elementTy) && + elementTy.getIntOrFloatBitWidth() != 8) { + APInt one(elementTy.getIntOrFloatBitWidth(), 1); + auto constAttr = DenseElementsAttr::get(constType, one); + return rewriter.create(op->getLoc(), constType, + constAttr); + } + } + op->emitError("unimplemented lowering in " "createInitialValueForReduceOp"); return nullptr; @@ -448,6 +463,79 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } } // namespace +// AtenProdOp +namespace { +template <> +LogicalResult ConvertAtenReductionOp::matchAndRewrite( + AtenProdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.getSelf(); + auto inputTy = input.getType().dyn_cast(); + auto outTy = getTypeConverter() + ->convertType(op.getType()) + .template dyn_cast(); + if (!inputTy) { + return rewriter.notifyMatchFailure( + op, "only Tensor types supported in StableHLO"); + } + if (inputTy.getElementType() != outTy.getElementType()) { + // Use output element type as computation type. 
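+    // (Editor's note, not in the original patch: the aten.prod dtype
+    // function registered later in this series promotes integer inputs to
+    // i64, so for narrow integer inputs the accumulating multiplications
+    // here run in the wider output type rather than the input type.)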
+    auto dstElemTy = outTy.getElementType();
+    input =
+        rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
+    inputTy = input.getType().dyn_cast<RankedTensorType>();
+  }
+  auto inputElemTy = inputTy.getElementType();
+  if (!inputElemTy.isIntOrFloat()) {
+    return op.emitError(
+        "only floating-point or integer datatype legalization supported");
+  }
+  // Currently, (u)int8 dtype is not supported
+  if (isa<mlir::IntegerType>(inputElemTy) &&
+      inputElemTy.getIntOrFloatBitWidth() == 8) {
+    return rewriter.notifyMatchFailure(
+        op, "IntegerType with bitwidth 8 unsupported in conversion from "
+            "AtenProdOp to StableHLO");
+  }
+
+  SmallVector<int64_t> dims;
+  for (int64_t i = 0; i < inputTy.getRank(); i++) {
+    dims.push_back(i);
+  }
+  Value initValue =
+      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
+  if (!initValue)
+    return failure();
+
+  llvm::sort(dims.begin(), dims.end());
+  auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
+      op.getLoc(), RankedTensorType::get({}, outTy.getElementType()), input,
+      initValue, rewriter.getDenseI64ArrayAttr(dims));
+
+  Block &block = stablehloReduceOp.getBody().emplaceBlock();
+  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
+
+  block.addArgument(blockArgumentTy, op->getLoc());
+  block.addArgument(blockArgumentTy, op->getLoc());
+
+  auto *firstArgument = block.args_begin();
+  auto secondArgument = block.args_rbegin();
+
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&block);
+    Value mulResult = rewriter.create<stablehlo::MulOp>(
+        op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
+    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), mulResult);
+  }
+
+  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy,
+                                              stablehloReduceOp.getResults());
+
+  return success();
+}
+} // namespace
+
 // AtenMaxOp
 namespace {
 template <>
@@ -957,6 +1045,7 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp);
+  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenProdOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMinOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp);
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 5df01076f846..3ac56e933363 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7034,6 +7034,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.prod\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>) -> !torch.list<int> {\n"
+"    %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.mean\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>) -> !torch.list<int> {\n"
 "    %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -12240,6 +12244,25 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.prod\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>) -> !torch.int {\n"
+"    %int4 = torch.constant.int 4\n"
+"    %none = torch.constant.none\n"
+"    %0 = torch.aten.__isnot__ %arg1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
+"    %1 = torch.prim.If %0 -> (!torch.int) {\n"
+"      %2 =
torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %2#1 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.sum.dim_IntList\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.int {\n" " %0 = call @\"__torch_mlir_dtype_fn.aten.sum\"(%arg0, %arg3) : (!torch.tuple, !torch.optional) -> !torch.int\n" " return %0 : !torch.int\n" diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 93269b0653dd..4657bc0a4b42 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1253,6 +1253,12 @@ "ReduceSumFloatModule_basic", "ReduceSumSignedIntModule_basic", "ReduceSumUnsignedIntModule_basic", + "ReduceProdFloatModule_basic", + "ReduceProdDtypeFloatModule_basic", + "ReduceProdElementTypeBoolModule_basic", + "ReduceProdUnsignedIntModule_basic", + "ReduceProdSignedIntModule_basic", + "ReduceProdDtypeIntModule_basic", "RepeatInterleaveSelfIntModule_basic", "RepeatInterleaveSelfIntNoDimModule_basic", "ReturnThreeTensorFloat32_basic", @@ -2617,6 +2623,14 @@ # Failure - onnx_lowering: onnx.OneHot "OneHotModule_basic", + + # Failure - onnx_lowering: onnx.ReduceProd + "ReduceProdFloatModule_basic", + "ReduceProdDtypeFloatModule_basic", + "ReduceProdElementTypeBoolModule_basic", + "ReduceProdUnsignedIntModule_basic", + "ReduceProdSignedIntModule_basic", + "ReduceProdDtypeIntModule_basic", # ERROR: dtype (torch.float32) is not equal to golden dtype (torch.float64) "RandnDtypeDeviceModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 2e1c64d8a737..1bdcfdbe98b8 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -553,6 +553,9 @@ def aten〇max〇other〡shape(self: List[int], other: List[int]) -> List[int]: def aten〇sum〡shape(self: List[int], dtype: Optional[int] = None) -> List[int]: return [] +def aten〇prod〡shape(self: List[int], dtype: Optional[int] = None) -> List[int]: + return [] + def aten〇mean〡shape(self: List[int], dtype: Optional[int] = None) -> List[int]: return [] @@ -3981,6 +3984,18 @@ def aten〇sum〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = return torch.int64 return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇prod〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None) -> int: + if dtype is not None: + return dtype + self_rank, self_dtype = self_rank_dtype + if is_integer_dtype(self_dtype): + return torch.int64 + return self_dtype + 
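+# (Editor's illustration, not part of the upstream patch: the dtype rule above
+# mirrors eager PyTorch, where torch.prod promotes integer inputs to int64
+# unless an explicit dtype is passed.)
+import torch
+assert torch.prod(torch.ones(2, 3, dtype=torch.int32)).dtype == torch.int64
+assert torch.prod(torch.ones(2, 3), dtype=torch.float64).dtype == torch.float64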
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None) + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None, dtype=torch.float32) + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None, dtype=torch.int32) + diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 5c4e4d214932..638ec1dd8d4c 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -660,6 +660,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::sum : (Tensor, int?) -> (Tensor)") emit("aten::sum.dim_IntList : (Tensor, int[]?, bool, int?) -> (Tensor)") emit("aten::prod.dim_int : (Tensor, int, bool, int?) -> (Tensor)") + emit("aten::prod : (Tensor, int?) -> (Tensor)") emit("aten::max : (Tensor) -> (Tensor)") emit("aten::max.other : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) emit("aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py index b6178221eb48..5fe2db5ff441 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -68,6 +68,62 @@ def ReduceSumElementTypeBoolModule_basic(module, tu: TestUtils): # ============================================================================== +class ReduceProdFloatModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.prod(a) + + +@register_test_case(module_factory=lambda: ReduceProdFloatModule()) +def ReduceProdFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + +# ============================================================================== + +class ReduceProdDtypeFloatModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float64, True), + ]) + def forward(self, a): + return torch.prod(a, dtype=torch.float32) + +@register_test_case(module_factory=lambda: ReduceProdDtypeFloatModule()) +def ReduceProdDtypeFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5).to(torch.float64)) + +# ============================================================================== + +class ReduceProdElementTypeBoolModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.bool, True), + ]) + def forward(self, a): + return torch.prod(a) + + +@register_test_case(module_factory=lambda: ReduceProdElementTypeBoolModule()) +def ReduceProdElementTypeBoolModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, 5, high=2).to(torch.bool)) + +# ============================================================================== + class ReduceSumDimIntListFloatModule(torch.nn.Module): def __init__(self): super().__init__() @@ -239,6 +295,63 @@ def ReduceSumDtypeIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ReduceProdUnsignedIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.int64, True), + ]) 
+    def forward(self, a):
+        return torch.prod(a)
+
+
+@register_test_case(module_factory=lambda: ReduceProdUnsignedIntModule())
+def ReduceProdUnsignedIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, 5, low=0, high=100))
+
+# ==============================================================================
+
+class ReduceProdSignedIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.int64, True),
+    ])
+    def forward(self, a):
+        return torch.prod(a)
+
+
+@register_test_case(module_factory=lambda: ReduceProdSignedIntModule())
+def ReduceProdSignedIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, 5, low=-100, high=100))
+
+# ==============================================================================
+
+class ReduceProdDtypeIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.int32, True),
+    ])
+    def forward(self, a):
+        return torch.prod(a, dtype=torch.int64)
+
+
+@register_test_case(module_factory=lambda: ReduceProdDtypeIntModule())
+def ReduceProdDtypeIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, 5, high=100).to(torch.int32))
+
+# ==============================================================================
+
 class ReduceSumDimIntListIntModule(torch.nn.Module):
     def __init__(self):
         super().__init__()

From 42b9eccdb3f51811ea9c1ec2f89379c3e64824f7 Mon Sep 17 00:00:00 2001
From: Xinyu Yang
Date: Wed, 24 Apr 2024 11:25:46 +0800
Subject: [PATCH 27/34] [Stablehlo] Fix AtenSumDimIntListOp when dim==None (#3216)

as title

---
 lib/Conversion/TorchToStablehlo/Reduction.cpp | 14 ++++++++++----
 projects/pt1/e2e_testing/xfail_sets.py        |  2 +-
 2 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp
index 367acefc9210..1e494f4337c0 100644
--- a/lib/Conversion/TorchToStablehlo/Reduction.cpp
+++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp
@@ -700,11 +700,17 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
   SmallVector<int64_t> inputDims;
   SmallVector<int64_t> dims;
-  if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) {
-    return rewriter.notifyMatchFailure(op, "non-int dim list unsupported");
-  }
-  if (inputDims.size() == 0) {
+
+  if (failed(checkNotNone(rewriter, op, op.getDim()))) {
     inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
+  } else {
+    if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) {
+      return rewriter.notifyMatchFailure(
+          op, "non-const integer `dim` is not supported");
+    }
+    if (inputDims.size() == 0) {
+      inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
+    }
   }
 
   for (auto d : inputDims) {
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 4657bc0a4b42..1f16a25a9555 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -728,7 +728,6 @@
     "MaxPool3dModule_basic",
     "MaxPool3dStaticCeilModeTrueModule_basic",
     "MaxPool3dStaticModule_basic",
-    "MeanDimNoneDimModule_basic",
     "MseLossMeanReductionModule_basic",
     "MseLossSumReductionWithDifferentElemTypeModule_basic",
     "MulFloatModule_basic",
@@ -1140,6 +1139,7 @@
     "MaxPool2dStaticModule_basic",
     "MeanDimAllReduceModule_basic",
     "MeanDimEmptyDimModule_basic",
+    "MeanDimNoneDimModule_basic",
     "MeanDtypeModule_basic",
     "MeanDynamicSizesModule_basic",
     "MeanModule_basic",
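[Editor's note] To make the `dim==None` fix above concrete, here is a minimal eager-PyTorch check of the semantics the StableHLO lowering now matches: with `dim` set to `None`, `aten.sum.dim_IntList` reduces over every dimension. This snippet is illustrative only and is not part of the patch:

```python
import torch

x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
# dim=None performs a full reduction, just like torch.sum(x).
assert torch.equal(torch.sum(x, dim=None), torch.sum(x))
# The result is a rank-0 tensor, which is why the lowering reduces all dims.
assert torch.sum(x, dim=None).shape == torch.Size([])
```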
From f77d88390a8e3c4bdc2172a0f0342d0df21c598d Mon Sep 17 00:00:00 2001
From: Phaneesh Barwaria
Date: Wed, 24 Apr 2024 09:01:37 +0530
Subject: [PATCH 28/34] [onnx] handle dynamic padSize tensor in onnx.Pad (#3214)

- Fix the pad size to 2 * data_rank for a dynamic padding-size tensor.
- This fix is in accordance with the [input
  specification](https://onnx.ai/onnx/operators/onnx__Pad.html#inputs) for
  onnx.Pad.
- The implementation will need to be updated for a dynamic padSize once
  support for `axes` is added.
---
 .../TorchOnnxToTorch/DefaultDomainGtoP.cpp         | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index c7d071079119..ad6c91e405c6 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -1319,9 +1319,18 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
                                              "expect 1-d pad tensor");
 
         int64_t padsSize = padsShape[0];
-        if (padsSize == Torch::kUnknownSize)
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "pad length is unknown");
+        if (padsSize == Torch::kUnknownSize) {
+          // As per the onnx.Pad documentation, padSize = 2 * num_data_axes
+          // (if the axes param is not passed). Needs to be updated when
+          // adding support for the `axes` param.
+          auto dataOpTy = data.getType().cast<Torch::ValueTensorType>();
+          TensorType dataTensor = dataOpTy.toBuiltinTensor();
+          if (!dataTensor || !dataTensor.hasRank())
+            return rewriter.notifyMatchFailure(
+                binder.op, "pad length unknown and data operand unranked");
+          int64_t dataRank = dataTensor.getRank();
+          padsSize = 2 * dataRank;
+        }
 
         Value constantValue;
         if (binder.getNumOperands() >= 3) {

From 8a1dbbd597fd52199802513eb1fe0abbb2cf36ef Mon Sep 17 00:00:00 2001
From: Yuanqiang Liu
Date: Wed, 24 Apr 2024 11:34:02 +0800
Subject: [PATCH 29/34] [torchscript] export extra library file name to user
 (#3203)

* so that it can be specified by the user.
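[Editor's note] A minimal sketch of the new behavior, using only the Python standard library. The helper below is a hypothetical stand-in for what `_canon_extra_library` now does: it joins a caller-supplied file name against the system temp directory instead of hard-coding the path:

```python
import os
import tempfile

def write_extra_library(mlir_text: str,
                        file_name: str = "custom_op_extra_library.mlir") -> str:
    # Mirrors _canon_extra_library: os.path.join against gettempdir() rather
    # than string concatenation with "/", so the path is portable.
    path = os.path.join(tempfile.gettempdir(), file_name)
    with open(path, "w") as f:
        f.write(mlir_text)
    return path

# A caller can now choose a distinct name, e.g. to avoid collisions between
# concurrent compilations (hypothetical usage):
print(write_extra_library("module {}", "my_extra_library.mlir"))
```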
--- projects/pt1/python/torch_mlir/torchscript.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/projects/pt1/python/torch_mlir/torchscript.py b/projects/pt1/python/torch_mlir/torchscript.py index 48b9066d2391..acb487319ae9 100644 --- a/projects/pt1/python/torch_mlir/torchscript.py +++ b/projects/pt1/python/torch_mlir/torchscript.py @@ -9,6 +9,7 @@ import sys from io import StringIO import tempfile +import os from torch._functorch.compile_utils import strip_overloads import torch @@ -253,19 +254,20 @@ def _get_for_tracing( } -def _canon_extra_library(extra_library): - extra_library_file_name = "" +def _canon_extra_library(extra_library, extra_library_file_name="custom_op_extra_library.mlir"): if len(extra_library) != 0: extra_library_dict = {} for library_func in extra_library: extra_library_dict[library_func.__name__] = library_func mlir_library = generate_library(extra_library_dict) - extra_library_file_name = \ - tempfile.gettempdir() + "/custom_op_extra_library.mlir" - with open(extra_library_file_name, "w") as f: + extra_library_file = \ + os.path.join(tempfile.gettempdir(), extra_library_file_name) + with open(extra_library_file, "w") as f: f.write(mlir_library) - return extra_library_file_name + return extra_library_file + else: + return "" def _lower_mlir_module(verbose, output_type, module): From dc470e65c8188eefc99f9a1156cd493155986810 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Wed, 24 Apr 2024 11:49:26 +0800 Subject: [PATCH 30/34] add torch.qint32 to dtype-spec in TorchTypes.td (#3206) --- include/torch-mlir/Dialect/Torch/IR/TorchTypes.td | 1 + 1 file changed, 1 insertion(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td index 898c768ae1c2..e7fc4bc976bb 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td @@ -129,6 +129,7 @@ class AnyTorchTensorType | torch.bool | i1 | | torch.qint8 | !torch.qint8 | | torch.quint8 | !torch.quint8 | + | torch.qint32 | !torch.qint32 | | torch.complex64 | complex | | torch.complex128 | complex | |-------------------|--------------------| From e18bf42d0ec428c1dbd82211cfc922087c8ae994 Mon Sep 17 00:00:00 2001 From: Xinyu Yang Date: Wed, 24 Apr 2024 14:15:11 +0800 Subject: [PATCH 31/34] [stablehlo] Support ConstantPadNdOp in stablehlo (#3211) as title --- lib/Conversion/TorchToStablehlo/Basic.cpp | 41 +++++++++++++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 12 +++---- 2 files changed, 47 insertions(+), 6 deletions(-) diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 84bccae83d81..1858b1a6d7ca 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -1632,6 +1632,46 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenConstantPadNdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value self = adaptor.getSelf(); + auto selfTy = self.getType().cast(); + auto selfElemTy = selfTy.getElementType(); + int64_t rank = selfTy.getRank(); + + SmallVector padInts; + if (!matchPattern(op.getPad(), m_TorchListOfConstantInts(padInts))) + return rewriter.notifyMatchFailure(op, + "only support constant int pad ranges"); + uint64_t padRank = padInts.size() / 2; + if (padRank * 2 != padInts.size()) + return rewriter.notifyMatchFailure(op, "pad range size is not even"); + if (rank < 0 
|| padRank > (uint64_t)rank)
+    return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank");
+
+  // Initialize low/high paddings with 0 for all the dims.
+  SmallVector<int64_t> lowPadding(/*Size=*/rank, /*Value=*/0);
+  SmallVector<int64_t> highPadding(/*Size=*/rank, /*Value=*/0);
+  // Add the requested padding - note op.pad() is highest dim first ordered
+  // pairs of low,high.
+  for (uint64_t i = 0; i < padRank; ++i) {
+    lowPadding[rank - i - 1] = padInts[i * 2];
+    highPadding[rank - i - 1] = padInts[i * 2 + 1];
+  }
+
+  Value constantValue = hlo::scalarToStablehloTensor(
+      rewriter, op, adaptor.getValue(), selfElemTy);
+
+  SmallVector<int64_t> interiorPadding(rank, 0);
+  rewriter.replaceOpWithNewOp<stablehlo::PadOp>(
+      op, self, constantValue, lowPadding, highPadding, interiorPadding);
+  return success();
+}
+
 template <>
 LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
     AtenGeluBackwardOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
@@ -2070,6 +2110,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
   INSERT_ATENOP_PATTERN(AtenScalarImplicitOp);
   INSERT_ATENOP_PATTERN(AtenContiguousOp);
+  INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
 
   INSERT_ATENOP_PATTERN(AtenReluOp);
   INSERT_ATENOP_PATTERN(AtenGeluOp);
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 1f16a25a9555..ba13b73600f4 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -605,10 +605,6 @@
     "BroadcastDynamicDimModule_basic",
     "CeilFloatModule_basic",
     "ConstantBoolParameterModule_basic",
-    "ConstantPad2dStaticModule_basic",
-    "ConstantPadNdModule_basic",
-    "ConstantPadNdPartialStaticModule_basic",
-    "ConstantPadNdStaticModule_basic",
     "ContainsIntList_False",
     "ContainsIntList_True",
     "Conv2dQInt8Module_basic",
@@ -754,8 +750,6 @@
     "NumToTensorIntModule_basic",
     "NumelModule_basic",
     "NumelZeroRankModule_basic",
-    "PadModule_basic",
-    "PadWithNoneValModule_basic",
     "PixelShuffleModuleFullDynamic_basic",
     "PixelShuffleModuleSpatiallyDynamic_basic",
     "PixelShuffleModuleSpatiallyStatic_basic",
@@ -982,6 +976,10 @@
     "Convolution2DStaticModule_basic",
     "ConvolutionBackwardModule2DStatic_basic",
     "ConvolutionModule2DTransposeStridedStatic_basic",
+    "ConstantPad2dStaticModule_basic",
+    "ConstantPadNdModule_basic",
+    "ConstantPadNdPartialStaticModule_basic",
+    "ConstantPadNdStaticModule_basic",
     "CosineSimilarityStaticBroadcastModule_basic",
     "CosineSimilarityStaticModule_basic",
     "CumsumInputDtypeInt32Module_basic",
@@ -1209,6 +1207,8 @@
     "OnesModuleFalsePinMemory_basic",
     "OnesModuleFloat_basic",
     "OnesModuleInt_basic",
+    "PadModule_basic",
+    "PadWithNoneValModule_basic",
     "Permute0RankModule_basic",
     "PermuteModule_basic",
     "PermuteNegativeIndexModule_basic",

From fab2696489d31c32418ace3f6af9f76ba17f8a0d Mon Sep 17 00:00:00 2001
From: Yuanqiang Liu
Date: Wed, 24 Apr 2024 14:32:33 +0800
Subject: [PATCH 32/34] [Torch] support aten.trunc (#3219)

decompose `trunc(x)` to `sign(x) * floor(abs(x))`
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 46 +++++++++++++++++++
 lib/Dialect/Torch/IR/TorchOps.cpp             | 13 ++++++
 .../Transforms/AbstractInterpLibrary.cpp      |  8 ++++
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 27 +++++++++++
 .../Transforms/LowerToBackendContract.cpp     |  1 +
 projects/pt1/e2e_testing/xfail_sets.py        |  6 +++
 .../build_tools/abstract_interp_lib_gen.py    |  8 ++++
 .../build_tools/torch_ods_gen.py              |  1 +
 .../test_suite/elementwise.py                 | 44 ++++++++++++++++++
test/Dialect/Torch/canonicalize.mlir | 8 ++++ 10 files changed, 162 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 0e9318753b33..c38d0dbbd389 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -4223,6 +4223,52 @@ def Torch_AtenRound_Op : Torch_Op<"aten.round_", [ }]; } +def Torch_AtenTruncOp : Torch_Op<"aten.trunc", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::trunc : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenTruncOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenTruncOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; + let hasFolder = 1; +} + +def Torch_AtenTrunc_Op : Torch_Op<"aten.trunc_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::trunc_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + AnyTorchOptionalNonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenTrunc_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenTrunc_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenSignOp : Torch_Op<"aten.sign", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index e768033ac87f..a8769def6585 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -1834,6 +1834,19 @@ OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) { return {}; } +//===----------------------------------------------------------------------===// +// AtenTruncOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenTruncOp::fold(FoldAdaptor adaptor) { + auto resultType = getType().dyn_cast(); + if (resultType && resultType.hasDtype() && + resultType.getDtype().isa()) { + return getSelf(); + } + return {}; +} + //===----------------------------------------------------------------------===// // AtenSignOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 3ac56e933363..f4415a480a7c 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6502,6 +6502,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.trunc\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.log\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call 
@__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -10003,6 +10007,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.trunc\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.clamp_max\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %int4 = torch.constant.int 4\n" " %int11 = torch.constant.int 11\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 87f93ba9c555..49dd5319514b 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -5886,6 +5886,32 @@ class DecomposeAtenCosineSimilarityOp }; } // namespace +namespace { +// decompose `trunc(x)` to `sign(x) * floor(abs(x))` +class DecomposeAtenTruncOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenTruncOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + + auto resultTy = dyn_cast(op.getType()); + if (!resultTy || !resultTy.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result must have dtype"); + } + + if (isa(resultTy.getDtype())) { + Value sign = rewriter.create(loc, resultTy, self); + Value abs = rewriter.create(loc, resultTy, self); + Value floor = rewriter.create(loc, resultTy, abs); + rewriter.replaceOpWithNewOp(op, resultTy, sign, floor); + return success(); + } + return failure(); + } +}; +} // namespace + namespace { // Decompose `aten.baddbmm` op into `aten.bmm`, `aten.mul.Scalar`, and // `aten.add.Tensor` op. @@ -7700,6 +7726,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 701300fefe43..e1377afce373 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -512,6 +512,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index ba13b73600f4..45a4b94a51ce 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1479,6 +1479,8 @@ "ElementwiseCoshModule_basic", "ElementwiseSinhIntModule_basic", "ElementwiseSinhModule_basic", + "ElementwiseTruncIntModule_basic", + "ElementwiseTruncModule_basic", } STABLEHLO_CRASHING_SET = { @@ -1488,6 +1490,8 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. 
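# (Editor's aside, not part of the xfail bookkeeping: a quick eager-PyTorch
# check of the aten.trunc decomposition introduced above —
# trunc(x) == sign(x) * floor(abs(x)) for finite floats, including negative
# fractions where floor alone would round the wrong way.)
import torch
_x = torch.tensor([-2.7, -0.5, 0.0, 0.5, 2.7])
assert torch.equal(torch.trunc(_x), torch.sign(_x) * torch.floor(torch.abs(_x)))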
TOSA_PASS_SET = { + "ElementwiseTruncModule_basic", + "ElementwiseTruncIntModule_basic", "ElementwiseSgnModule_basic", "ElementwiseSignIntModule_basic", "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", @@ -2344,6 +2348,8 @@ "ElementwiseSinhModule_basic", "ElementwiseCoshIntModule_basic", "ElementwiseCoshModule_basic", + "ElementwiseTruncIntModule_basic", + "ElementwiseTruncModule_basic", "ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerTensorModule_basic", "ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 1bdcfdbe98b8..06962010fea8 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -245,6 +245,9 @@ def aten〇hardtanh_backward〡shape(grad_output: List[int], self: List[int], mi def aten〇ceil〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇trunc〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇log〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -2227,6 +2230,11 @@ def aten〇ceil〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇trunc〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, max=0)) def aten〇clamp_max〡dtype(self_rank_dtype: Tuple[int, int], max: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 638ec1dd8d4c..e5b219e55e9c 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -359,6 +359,7 @@ def emit_with_mutating_variants(key, **kwargs): emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True) + emit_with_mutating_variants("aten::trunc : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::sign : (Tensor) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index b365ac54f3d4..3aa8f10ff9dd 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -2077,6 +2077,50 @@ def ElementwiseCeilModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseTruncModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 6], torch.float32, 
True),
+    ])
+    def forward(self, a):
+        return torch.trunc(a)
+
+
+@register_test_case(module_factory=lambda: ElementwiseTruncModule())
+def ElementwiseTruncModule_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([[-torch.inf, torch.inf, torch.nan, -2.3, 0.0, 1.5]]))
+
+
+# ==============================================================================
+
+
+class ElementwiseTruncIntModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int32, True),
+    ])
+    def forward(self, a):
+        return torch.trunc(a)
+
+
+@register_test_case(module_factory=lambda: ElementwiseTruncIntModule())
+def ElementwiseTruncIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=-100, high=100).to(torch.int32))
+
+
+# ==============================================================================
+
+
 class ElementwiseSignModule(torch.nn.Module):
 
     def __init__(self):
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index 2f7d5a11a216..4d2a595da43a 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -2308,6 +2308,14 @@ func.func @torch.aten.floor$canonicalize(%arg0: !torch.vtensor<[?,?],si64>) -> !
   return %0 : !torch.vtensor<[?,?],si64>
 }
 
+// CHECK-LABEL:  func.func @torch.aten.trunc$canonicalize
+// CHECK-SAME:   %[[ARG:.*]]: !torch.vtensor<[?,?],si64>
+// CHECK-NEXT:   return %[[ARG]] : !torch.vtensor<[?,?],si64>
+func.func @torch.aten.trunc$canonicalize(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
+  %0 = torch.aten.trunc %arg0 : !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64>
+  return %0 : !torch.vtensor<[?,?],si64>
+}
+
 // CHECK-LABEL:  func.func @torch.aten.numel$canonicalize
 // CHECK-SAME:   %[[ARG:.*]]: !torch.vtensor<[3,4],f32>
 // CHECK-NEXT:   %int12 = torch.constant.int 12

From 678c03b76240279f788b2b5d441625077022648c Mon Sep 17 00:00:00 2001
From: Avinash Sharma
Date: Tue, 23 Apr 2024 23:58:08 -0700
Subject: [PATCH 33/34] Fix nan issue for fp16 torch.randn/randn_like in
 ConvertAtenUniformOp (#3184)

For ops that use ConvertAtenUniformOp (e.g. torch.randn/randn_like), the fp16
datatype returns NaN values. Trying to lower [this
repro](https://gist.github.com/aviator19941/1c65e658241dea6906ca423f9abaee69)
results in NaNs; this PR fixes the issue.
---
 lib/Conversion/TorchToLinalg/Random.cpp | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Random.cpp b/lib/Conversion/TorchToLinalg/Random.cpp
index 3b18844df516..6519a272330e 100644
--- a/lib/Conversion/TorchToLinalg/Random.cpp
+++ b/lib/Conversion/TorchToLinalg/Random.cpp
@@ -129,6 +129,7 @@ class ConvertAtenUniformOp : public OpConversionPattern<AtenUniformOp> {
     Value generator = adaptor.getGenerator();
     RankedTensorType resultType = self.getType().cast<RankedTensorType>();
     Type elemTy = resultType.getElementType();
+    Type f64Ty = rewriter.getF64Type();
 
     if (!isa<mlir::FloatType>(elemTy))
       return rewriter.notifyMatchFailure(op, "This op only supports float type");
@@ -139,8 +140,8 @@ class ConvertAtenUniformOp : public OpConversionPattern<AtenUniformOp> {
                                          "generator is supported");
     // Get key, min and max used by `linalg.generic` compute payload.
     Value key = rewriter.create<TorchConversion::GetNextSeedOp>(loc);
-    Value min = convertScalarToDtype(rewriter, loc, from, elemTy);
-    Value max = convertScalarToDtype(rewriter, loc, to, elemTy);
+    Value min = convertScalarToDtype(rewriter, loc, from, f64Ty);
+    Value max = convertScalarToDtype(rewriter, loc, to, f64Ty);
 
     // Construct the `linalg.generic` op.
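     // (Editor's note, not in the original patch: min/max above are now
     // materialized as f64, and the payload below also scales in f64,
     // truncating to the element type only as the final step; doing this
     // arithmetic directly in f16 is what produced the NaN results that
     // this patch fixes.)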
auto resultRank = resultType.getRank(); @@ -179,11 +180,14 @@ class ConvertAtenUniformOp : public OpConversionPattern { // res = cast(F64, tempN) * scale + min Value updateFloat = - b.create(loc, elemTy, randomVal); + b.create(loc, f64Ty, randomVal); Value updateScaled = b.create(loc, updateFloat, scale); Value res = b.create(loc, updateScaled, min); - b.create(loc, res); + Value truncRes = res; + if (elemTy.isa()) + truncRes = b.create(loc, elemTy, res); + b.create(loc, truncRes); }) .getResult(0); From 7be22bb26009e1f158a5735cba1d7bd914d7af94 Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Wed, 24 Apr 2024 13:03:41 -0400 Subject: [PATCH 34/34] Update add_ops.md to link torch mlir get started instructions prominently (#3222) --- docs/add_ops.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/add_ops.md b/docs/add_ops.md index 37dee90817db..b8e5ce37ec45 100644 --- a/docs/add_ops.md +++ b/docs/add_ops.md @@ -2,7 +2,6 @@ Collected links and contacts for how to add ops to torch-mlir. -
 Turbine Camp: Start Here
 This document was previously known as `turbine-camp.md` to Nod.ai. "Turbine Camp" is part of Nod.ai's onboarding process.
 Welcome to Turbine Camp. This document originated at Nod.ai as part of the onboarding process, where new nod-ai folks learn about the architecture of our work by adding support for 2 ops to torch-mlir. I decided to put this into torch-mlir because a lot of this is about torch-mlir.
@@ -27,6 +26,7 @@ The details of how we do it and helpful commands to help you set up each repo is
 PS: IREE is pronounced Eerie, and hence the ghost icon.
 
 ## How to begin
+0. Set up torch-mlir according to the instructions here: https://github.com/llvm/torch-mlir/blob/main/docs/development.md
 1. You will start by adding support for 2 ops in torch-mlir, to get you familiar with the center of our pipeline. Begin by reading [torch-mlir's documentation on how to implement a new torch op](https://github.com/llvm/torch-mlir/blob/main/docs/Torch-ops-E2E-implementation.md), and set up `llvm/torch_mlir` using https://github.com/llvm/torch-mlir/blob/main/docs/development.md
 2. Pick 1 of the yet-unimplemented ops from the following. You should choose something that looks easy to you. **Make sure you create an issue by clicking the little "target" icon to the right of the op, thereby marking the op as yours**
    - [TorchToLinalg ops tracking issue](https://github.com/nod-ai/SHARK-Turbine/issues/347)