Commit 5310597: Merge feature/backport_ea1_ops into HEAD

ljfitz committed Jun 25, 2024
2 parents: bd6e22e + 0b089c8
Showing 14 changed files with 58 additions and 52 deletions.
21 changes: 19 additions & 2 deletions lib/Conversion/TorchOnnxToTorch/Patterns.cpp
@@ -8,6 +8,7 @@
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

@@ -24,12 +25,28 @@ LogicalResult OnnxCustomOpConversionPattern::matchAndRewrite(
auto foundIt = namedHandlers.find(op.getNameAttr());
if (foundIt == namedHandlers.end())
return failure();
// domainVersion comes from the function attribute
// torch.onnx_meta.opset_version and defines the opset for every ONNX op the
// function contains. Absent this attribute, domainVersion is 0.
int64_t opDomainVersion = domainVersion;
// If the op carries an individual version (a torch.onnx_meta.version
// attribute), that version overrides the function's domainVersion and is
// used for handler matching below.
if (auto attr = op->getAttrOfType<IntegerAttr>("torch.onnx_meta.version")) {
if (auto type = dyn_cast<IntegerType>(attr.getType())) {
if (type.isSigned()) {
opDomainVersion = attr.getSInt();
}
}
}
auto &reggies = foundIt->second;
for (const HandlerReg &reg : reggies) {
if (domainVersion < reg.sinceVersion) {
if (opDomainVersion < reg.sinceVersion) {
LLVM_DEBUG(dbgs() << ": skipping conversion " << foundIt->first
<< ", sinceVersion=" << reg.sinceVersion
<< ", for domainVersion=" << domainVersion << "\n");
<< ", for domainVersion=" << domainVersion
<< ", opDomainVersion=" << opDomainVersion << "\n");
continue;
}
if (succeeded(reg.callback(OpBinder(op), rewriter))) {
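To make the resolution order concrete, here is a minimal standalone sketch of the same logic, written as a hypothetical helper against the MLIR API used in the diff (it is not part of the patch itself):

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"

using namespace mlir;

// Hypothetical helper mirroring matchAndRewrite above: prefer the op's own
// signed torch.onnx_meta.version attribute; otherwise fall back to the
// function-level domainVersion (0 when torch.onnx_meta.opset_version is
// absent).
static int64_t resolveOpsetVersion(Operation *op, int64_t domainVersion) {
  if (auto attr = op->getAttrOfType<IntegerAttr>("torch.onnx_meta.version"))
    if (auto type = dyn_cast<IntegerType>(attr.getType()))
      if (type.isSigned())
        return attr.getSInt();
  return domainVersion;
}

A handler registered with sinceVersion = 19 can then match an op whose resolved version is 19 even inside a function whose opset_version is only 16, which is exactly the case the new op_wise_version.mlir test below exercises.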
6 changes: 0 additions & 6 deletions lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp
@@ -45,12 +45,6 @@ class ConvertTorchOnnxToTorch

// Populate our patterns for each handled domain.
int64_t defaultOpsetVersion = getDefaultOpsetVersion(getOperation());
if (defaultOpsetVersion == 0) {
emitError(getOperation().getLoc())
<< "function is missing onnx opset version attribute "
"(torch.onnx_meta.opset_version)";
return signalPassFailure();
}

auto defaultDomainPatterns =
std::make_unique<OnnxCustomOpConversionPattern>(
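The net effect of this deletion is that a function without torch.onnx_meta.opset_version is no longer a hard error. A hedged sketch of the resulting behavior, assuming getDefaultOpsetVersion returns 0 in that case (as the new comment in Patterns.cpp states):

// Sketch only: the pass now tolerates a missing function-level opset.
int64_t defaultOpsetVersion = getDefaultOpsetVersion(getOperation());
// defaultOpsetVersion == 0 means "no function-level opset": ops that carry
// their own torch.onnx_meta.version can still satisfy sinceVersion checks,
// while ops without any version information simply fail to legalize, as
// test_no_version in unsupported_simple_ops.mlir below expects.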
1 change: 1 addition & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -3066,6 +3066,7 @@ static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter,
Operation *op, Value x, Type dtype) {
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}, dtype).value();
auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}, dtype).value();

auto loc = op->getLoc();

// buildNormalCdf, mean = zero, sigma = one
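For reference, the unit normal CDF that buildUnitNormalCdf assembles (buildNormalCdf with mean = 0 and sigma = 1) is the standard error-function form:

\Phi(x) = \frac{1}{2}\left(1 + \operatorname{erf}\left(\frac{x - \mu}{\sigma\sqrt{2}}\right)\right),
\qquad \mu = 0,\ \sigma = 1
\;\Longrightarrow\;
\Phi(x) = \frac{1}{2}\left(1 + \operatorname{erf}\left(\frac{x}{\sqrt{2}}\right)\right).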
1 change: 1 addition & 0 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -240,6 +240,7 @@ std::optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,

auto const_op =
rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);

if (dtype) {
return rewriter.createOrFold<tosa::CastOp>(
op->getLoc(), RankedTensorType::get(shape, *dtype), const_op);
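From the caller's side, the optional dtype parameter folds a cast into the constant materialization. A hedged usage sketch, based on the getConstTensor<float> calls visible in buildUnitNormalCdf above (the exact signature is an assumption inferred from those calls):

// Assumed signature, inferred from the calls in this commit:
//   std::optional<Value> tosa::getConstTensor<T>(
//       PatternRewriter &rewriter, Operation *op, ArrayRef<T> vec,
//       ArrayRef<int64_t> shape, std::optional<Type> dtype = {});
// Materialize rank-0 constants 0 and 1, cast (via createOrFold) to dtype.
Value zero = tosa::getConstTensor<float>(rewriter, op, 0, {}, dtype).value();
Value one = tosa::getConstTensor<float>(rewriter, op, 1, {}, dtype).value();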
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -7538,8 +7538,8 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenScalarTensor>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSignOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenScatterValueOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSignOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArcSinCosOp<AtenAsinOp>>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArcSinCosOp<AtenAcosOp>>(
1 change: 0 additions & 1 deletion lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
@@ -70,7 +70,6 @@ class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
newEnd =
rewriter.create<AtenAddIntOp>(op.getLoc(), dimSize, sliceOp.getEnd());
}
newEnd = rewriter.create<PrimMinIntOp>(op.getLoc(), newEnd, dimSize);

newStart = rewriter.create<PrimMinIntOp>(op.getLoc(), newStart, dimSize);
newEnd = rewriter.create<PrimMinIntOp>(op.getLoc(), newEnd, dimSize);
1 change: 0 additions & 1 deletion projects/pt1/e2e_testing/xfail_sets.py
@@ -2328,7 +2328,6 @@
"ElementwiseAcosTensorIntModule_basic",
"ElementwiseAsinTensorIntModule_basic",
"FakeQuantizePerTensorAffineCachemaskModule_basic",
"Im2ColModule_basic",
"IndexPutImpl2DNoneIndexBroadcastStaticModule_basic",
"PrimsSumFloatModule_basic",
"RepeatInterleaveFillModule_basic",
3 changes: 0 additions & 3 deletions projects/pt1/python/torch_mlir/dynamo.py
@@ -65,10 +65,7 @@ def _get_decomposition_table():
aten._native_batch_norm_legit,
aten.squeeze,
aten.cumsum,
aten.im2col,
aten.index_select,
aten.linalg_vector_norm,
aten.eye,
]
# TODO: enable test once 2.1.0 is stable
if torch_version_for_comparison() >= version.parse("2.1.0.dev"):
@@ -14,7 +14,6 @@
"QuantizedMLP_basic",
"ReduceMaxAlongDimUnsignedInt_basic",
"RepeatInterleaveModule_basic",
"Im2ColModule_basic",
"ReduceMinAlongDimUnsignedInt_basic",
"ElementwiseToDtypeI64ToUI8Module_basic",
}
20 changes: 0 additions & 20 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -5141,26 +5141,6 @@ def forward(self, x):
def Add_Module_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3))

# ==============================================================================

class Im2Col_Module(torch.nn.Module):

def __init__(self):
super().__init__()
self.tensor = torch.ones(2, 3)

@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.im2col(x, [9, 1], [1, 1], [4, 0], [1, 1]);

@register_test_case(module_factory=lambda: Im2Col_Module())
def Im2ColModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3,4,5,2))


# ==============================================================================

@@ -1850,22 +1850,6 @@ def NewEmptyStridedModuleDefaultDtype_basic(module, tu: TestUtils):
# ==============================================================================


class EyeStaticModule(torch.nn.Module):
@export
@annotate_args([
None,
])
def forward(self):
return torch.ops.aten.eye(3, 5)


@register_test_case(module_factory=lambda: EyeStaticModule())
def EyeStaticModule_basic(module, tu: TestUtils):
module.forward()

# ==============================================================================


class EmptyStridedModule(torch.nn.Module):

def __init__(self):
@@ -419,4 +419,4 @@ def forward(self, a, b):

@register_test_case(module_factory=lambda: AtenLinalgCrossDynamic())
def AtenLinalgCrossDynamic_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 3, 1, 6), tu.rand(4, 3, 7, 1))
module.forward(tu.rand(4, 3, 1, 6), tu.rand(4, 3, 7, 1))
17 changes: 17 additions & 0 deletions test/Conversion/TorchOnnxToTorch/op_wise_version.mlir
@@ -0,0 +1,17 @@
// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-onnx-to-torch | FileCheck %s

// CHECK-LABEL: @test_quantizelinear_opset_16_op_19
func.func @test_quantizelinear_opset_16_op_19(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si8>) -> !torch.vtensor<[6],si8> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 16 : si64} {
// CHECK-NOT: torch.operator
%0 = torch.operator "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {torch.onnx_meta.version = 19 : si64} : (!torch.vtensor<[6],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.vtensor<[6],si8>
return %0 : !torch.vtensor<[6],si8>
}

// -----

// CHECK-LABEL: @test_quantizelinear_no_opset_op_19
func.func @test_quantizelinear_no_opset_op_19(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si8>) -> !torch.vtensor<[6],si8> attributes {torch.onnx_meta.ir_version = 9 : si64} {
// CHECK-NOT: torch.operator
%0 = torch.operator "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {torch.onnx_meta.version = 19 : si64} : (!torch.vtensor<[6],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.vtensor<[6],si8>
return %0 : !torch.vtensor<[6],si8>
}
18 changes: 18 additions & 0 deletions test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir
@@ -16,3 +16,21 @@ func.func @test_argmin_no_keepdims_example_select_last_index(%arg0: !torch.vtens
%0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64>
return %0 : !torch.vtensor<[2],si64>
}

// -----

// onnx.Less is only handled from opset version 13 onward; this op carries torch.onnx_meta.version = 7, so although the op itself is well-formed, no conversion pattern accepts it.

func.func @test_earlier_version_than_supported(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// expected-error @+1 {{failed to legalize operation 'torch.operator'}}
%0 = torch.operator "onnx.Less"(%arg0, %arg1) { torch.onnx_meta.version = 7 : si64 } : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>
return %0 : !torch.vtensor<[3,4,5],i1>
}

// -----

func.func @test_no_version(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// expected-error @+1 {{failed to legalize operation 'torch.operator'}}
%0 = torch.operator "onnx.Less"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>
return %0 : !torch.vtensor<[3,4,5],i1>
}
