diff --git a/lib/Conversion/TorchOnnxToTorch/Patterns.cpp b/lib/Conversion/TorchOnnxToTorch/Patterns.cpp
index 6ca7824165d3..a3958d92ead5 100644
--- a/lib/Conversion/TorchOnnxToTorch/Patterns.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/Patterns.cpp
@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -24,12 +25,28 @@ LogicalResult OnnxCustomOpConversionPattern::matchAndRewrite(
   auto foundIt = namedHandlers.find(op.getNameAttr());
   if (foundIt == namedHandlers.end())
     return failure();
+  // The domainVersion comes from the function attribute
+  // torch.onnx_meta.opset_version and defines the opset for all ONNX ops the
+  // function contains. Absent this attribute, domainVersion is 0.
+  int64_t opDomainVersion = domainVersion;
+  // If the op has an individual version (torch.onnx_meta.version attribute),
+  // it overrides the function's domainVersion and will be used for matching
+  // below.
+  if (auto attr = op->getAttrOfType<IntegerAttr>("torch.onnx_meta.version")) {
+    if (auto type = dyn_cast<IntegerType>(attr.getType())) {
+      if (type.isSigned()) {
+        opDomainVersion =
+            op->getAttrOfType<IntegerAttr>("torch.onnx_meta.version").getSInt();
+      }
+    }
+  }
   auto &reggies = foundIt->second;
   for (const HandlerReg &reg : reggies) {
-    if (domainVersion < reg.sinceVersion) {
+    if (opDomainVersion < reg.sinceVersion) {
       LLVM_DEBUG(dbgs() << ": skipping conversion " << foundIt->first
                         << ", sinceVersion=" << reg.sinceVersion
-                        << ", for domainVersion=" << domainVersion << "\n");
+                        << ", for domainVersion=" << domainVersion
+                        << ", opDomainVersion=" << opDomainVersion << "\n");
       continue;
     }
     if (succeeded(reg.callback(OpBinder(op), rewriter))) {
diff --git a/lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp b/lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp
index ea890bf0f4b6..fa2b95c0c29f 100644
--- a/lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp
@@ -45,12 +45,6 @@ class ConvertTorchOnnxToTorch
     // Populate our patterns for each handled domain.
     int64_t defaultOpsetVersion = getDefaultOpsetVersion(getOperation());
-    if (defaultOpsetVersion == 0) {
-      emitError(getOperation().getLoc())
-          << "function is missing onnx opset version attribute "
-             "(torch.onnx_meta.opset_version)";
-      return signalPassFailure();
-    }
 
     auto defaultDomainPatterns =
         std::make_unique<OnnxCustomOpConversionPattern>(
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index ee9fe6e26d44..a25bbe402a73 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -3066,6 +3066,7 @@ static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter,
                                 Operation *op, Value x, Type dtype) {
   auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}, dtype).value();
   auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}, dtype).value();
+  auto loc = op->getLoc();
 
   // buildNormalCdf, mean = zero, sigma = one
diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
index d2fe75390e68..1fcc91991f37 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -240,6 +240,7 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter,
 
   auto const_op =
       rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
+
   if (dtype) {
     return rewriter.createOrFold<tosa::CastOp>(
         op->getLoc(), RankedTensorType::get(shape, *dtype), const_op);
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 66ca5e12c9d4..e6e5200677ff 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -7538,8 +7538,8 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
-    addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal>(
         patterns);
     addPatternIfTargetOpIsIllegal>(
         patterns);
diff --git a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
index 69c8715442a7..f3a589ba4d70 100644
--- a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
@@ -70,7 +70,6 @@ class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
       newEnd = rewriter.create<AtenAddIntOp>(op.getLoc(), dimSize,
                                              sliceOp.getEnd());
     }
-    newEnd = rewriter.create<PrimMinIntOp>(op.getLoc(), newEnd, dimSize);
     newStart = rewriter.create<PrimMinIntOp>(op.getLoc(), newStart, dimSize);
     newEnd = rewriter.create<PrimMinIntOp>(op.getLoc(), newEnd, dimSize);
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 14b30bcd5519..a8e4649a96b8 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -2328,7 +2328,6 @@
     "ElementwiseAcosTensorIntModule_basic",
     "ElementwiseAsinTensorIntModule_basic",
     "FakeQuantizePerTensorAffineCachemaskModule_basic",
-    "Im2ColModule_basic",
     "IndexPutImpl2DNoneIndexBroadcastStaticModule_basic",
     "PrimsSumFloatModule_basic",
     "RepeatInterleaveFillModule_basic",
diff --git a/projects/pt1/python/torch_mlir/dynamo.py b/projects/pt1/python/torch_mlir/dynamo.py
index b9420f1f8d34..7fc887d56bc4 100644
--- a/projects/pt1/python/torch_mlir/dynamo.py
+++ b/projects/pt1/python/torch_mlir/dynamo.py
@@ -65,10 +65,7 @@ def _get_decomposition_table():
         aten._native_batch_norm_legit,
         aten.squeeze,
         aten.cumsum,
-        aten.im2col,
         aten.index_select,
-        aten.linalg_vector_norm,
-        aten.eye,
     ]
     # TODO: enable test once 2.1.0 is stable
     if torch_version_for_comparison() >= version.parse("2.1.0.dev"):
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py
index 0d16158af887..6f492a1eff5c 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py
@@ -14,7 +14,6 @@
     "QuantizedMLP_basic",
     "ReduceMaxAlongDimUnsignedInt_basic",
     "RepeatInterleaveModule_basic",
-    "Im2ColModule_basic",
    "ReduceMinAlongDimUnsignedInt_basic",
     "ElementwiseToDtypeI64ToUI8Module_basic",
 }
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
index e5fab589258f..f1e3700e0a4a 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -5141,26 +5141,6 @@ def forward(self, x):
 def Add_Module_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 3))
 
-# ==============================================================================
-
-class Im2Col_Module(torch.nn.Module):
-
-    def __init__(self):
-        super().__init__()
-        self.tensor = torch.ones(2, 3)
-
-    @export
-    @annotate_args([
-        None,
-        ([-1, -1, -1, -1], torch.float32, True),
-    ])
-    def forward(self, x):
-        return torch.ops.aten.im2col(x, [9, 1], [1, 1], [4, 0], [1, 1]);
-
-@register_test_case(module_factory=lambda: Im2Col_Module())
-def Im2ColModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(3,4,5,2))
-
 # ==============================================================================
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py
index a3cf7d525251..38138d742dc5 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py
@@ -1850,22 +1850,6 @@ def NewEmptyStridedModuleDefaultDtype_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-class EyeStaticModule(torch.nn.Module):
-    @export
-    @annotate_args([
-        None,
-    ])
-    def forward(self):
-        return torch.ops.aten.eye(3, 5)
-
-
-@register_test_case(module_factory=lambda: EyeStaticModule())
-def EyeStaticModule_basic(module, tu: TestUtils):
-    module.forward()
-
-# ==============================================================================
-
 
 class EmptyStridedModule(torch.nn.Module):
 
     def __init__(self):
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
index 2ccd9d9d39c8..9b94ac42c605 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
@@ -419,4 +419,4 @@ def forward(self, a, b):
 
 @register_test_case(module_factory=lambda: AtenLinalgCrossDynamic())
 def AtenLinalgCrossDynamic_basic(module, tu: TestUtils):
-    module.forward(tu.rand(4, 3, 1, 6), tu.rand(4, 3, 7, 1))
+    module.forward(tu.rand(4, 3, 1, 6), tu.rand(4, 3, 7, 1))
\ No newline at end of file
diff --git a/test/Conversion/TorchOnnxToTorch/op_wise_version.mlir b/test/Conversion/TorchOnnxToTorch/op_wise_version.mlir
new file mode 100644
index 000000000000..f35ecf3aeca5
--- /dev/null
+++ b/test/Conversion/TorchOnnxToTorch/op_wise_version.mlir
@@ -0,0 +1,17 @@
+// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-onnx-to-torch | FileCheck %s
+
+// CHECK-LABEL: @test_quantizelinear_opset_16_op_19
+func.func @test_quantizelinear_opset_16_op_19(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si8>) -> !torch.vtensor<[6],si8> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 16 : si64} {
+  // CHECK-NOT: torch.operator
+  %0 = torch.operator "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {torch.onnx_meta.version = 19 : si64} : (!torch.vtensor<[6],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.vtensor<[6],si8>
+  return %0 : !torch.vtensor<[6],si8>
+}
+
+// -----
+
+// CHECK-LABEL: @test_quantizelinear_no_opset_op_19
+func.func @test_quantizelinear_no_opset_op_19(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si8>) -> !torch.vtensor<[6],si8> attributes {torch.onnx_meta.ir_version = 9 : si64} {
+  // CHECK-NOT: torch.operator
+  %0 = torch.operator "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {torch.onnx_meta.version = 19 : si64} : (!torch.vtensor<[6],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.vtensor<[6],si8>
+  return %0 : !torch.vtensor<[6],si8>
+}
diff --git a/test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir b/test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir
index 22d5e2d35183..b55b87912aec 100644
--- a/test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir
+++ b/test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir
@@ -16,3 +16,21 @@ func.func @test_argmin_no_keepdims_example_select_last_index(%arg0: !torch.vtens
   %0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64>
   return %0 : !torch.vtensor<[2],si64>
 }
+
+// -----
+
+// onnx.Less is only supported from opset version 13, so even though this Less op is legal, its op-level version of 7 prevents conversion.
+
+func.func @test_earlier_version_than_supported(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // expected-error @+1 {{failed to legalize operation 'torch.operator'}}
+  %0 = torch.operator "onnx.Less"(%arg0, %arg1) { torch.onnx_meta.version = 7 : si64 } : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>
+  return %0 : !torch.vtensor<[3,4,5],i1>
+}
+
+// -----
+
+func.func @test_no_version(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // expected-error @+1 {{failed to legalize operation 'torch.operator'}}
+  %0 = torch.operator "onnx.Less"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>
+  return %0 : !torch.vtensor<[3,4,5],i1>
+}