diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index c00522a763fc..0de85f4eebe5 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -97,6 +97,19 @@ struct OpBinder { return success(); } + // Operand matches of different arities. + ParseResult tensorListOperand(Value &value0) { + if (op->getNumOperands() != 1) + return failure(); + value0 = op->getOperand(0); + auto tt = dyn_cast(value0.getType()); + if (!tt) + return failure(); + if (!toValidTensorType(tt.getContainedType())) + return failure(); + return success(); + } + ParseResult tensorListResultType(Torch::ListType &type0) { if (op->getNumResults() != 1) return failure(); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 037633490d93..0f445b5944df 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -532,6 +532,30 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, resultType, operands); return success(); }); + patterns.onOp( + "SequenceLength", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // onnx.SequenceLength takes a sequence(list) of tensors, and returns + // a zero rank tensor with the length. + Torch::ValueTensorType resultType; + Value x; + if (binder.tensorListOperand(x) || binder.tensorResultType(resultType)) + return failure(); + + Value cstFalse = + rewriter.create(binder.getLoc(), false); + Value none = rewriter.create(binder.getLoc()); + + Value len = rewriter.create( + binder.getLoc(), rewriter.getType(), x); + + // AtenLenTOp returns a torch.int, so we have to + // put that in a tensor. + rewriter.replaceOpWithNewOp( + binder.op, resultType, len, none, none, cstFalse); + + return success(); + }); patterns.onOp( "Sigmoid", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 9432702b6b12..0fc82da74f46 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -2099,6 +2099,21 @@ module { // ----- +// CHECK-LABEL: func.func @test_sequence_length +module { + func.func @test_sequence_length(%arg0: !torch.list>) -> !torch.vtensor<[],si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +// CHECK: %[[FALSE:.+]] = torch.constant.bool false +// CHECK: %[[NONE:.+]] = torch.constant.none +// CHECK: %[[LEN:.+]] = torch.aten.len.t %arg0 : !torch.list> -> !torch.int +// CHECK: %[[LEN_AS_TEN:.+]] = torch.aten.tensor.int %[[LEN]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],si64> +// CHECK: return %[[LEN_AS_TEN]] : !torch.vtensor<[],si64> + %0 = torch.operator "onnx.SequenceLength"(%arg0) : (!torch.list>) -> !torch.vtensor<[],si64> + return %0 : !torch.vtensor<[],si64> + } +} + +// ----- + // CHECK-LABEL: func.func @test_sce_mean_3d func.func @test_sce_mean_3d(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[NONE:.+]] = torch.constant.none