diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
index 1f516c798b..ddee42c015 100644
--- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
+++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
@@ -466,18 +466,6 @@ bool checkOpResultIsUsedByGetRef(mlir::memref::AllocOp *allocOp);
 /// 1, 2 and 4 in the MemRef shape respectively
 int64_t getAllocArgIndex(mlir::memref::AllocOp allocOp, int64_t index);
 
-/// This function returns a location with the corresponding ONNX operator name
-/// inside. This is useful when tracing what expanded MLIR instructions
-/// correspond to what ONNX operator.
-///
-///
-template <typename OP_TYPE>
-mlir::Location ONNXLoc(mlir::Operation *op) {
-  return mlir::NameLoc::get(
-      mlir::StringAttr::get(op->getContext(), OP_TYPE::getOperationName()),
-      op->getLoc());
-}
-
 /// This function returns a scalar of type 'dtype' from an optional value.
 /// Optional value must be: NoneType, memref<1xdtype> or memref<dtype>.
 /// Default value is used in case of NoneType.
diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp
index 12575d5afd..06a3479981 100644
--- a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp
+++ b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp
@@ -47,6 +47,18 @@
 
 namespace onnx_mlir {
 
+/// This function returns a location with the corresponding ONNX operator name
+/// inside. This is useful when tracing what expanded MLIR instructions
+/// correspond to what ONNX operator.
+///
+///
+template <typename OP_TYPE>
+mlir::Location ONNXLoc(mlir::Operation *op) {
+  return mlir::NameLoc::get(
+      mlir::StringAttr::get(op->getContext(), OP_TYPE::getOperationName()),
+      op->getLoc());
+}
+
 //===----------------------------------------------------------------------===//
 // ONNX Tensor support.
 
diff --git a/src/Dialect/ONNX/Rewrite.cpp b/src/Dialect/ONNX/Rewrite.cpp
index 0ee519a7fd..4d5a679492 100644
--- a/src/Dialect/ONNX/Rewrite.cpp
+++ b/src/Dialect/ONNX/Rewrite.cpp
@@ -687,7 +687,7 @@ class RNNOpRewriteLayoutPattern : public OpRewritePattern<ONNXOp> {
 
   LogicalResult matchAndRewrite(
       ONNXOp onnxOp, PatternRewriter &rewriter) const override {
     if (onnxOp.getLayout() == 0) {
-      return success();
+      return failure();
     }
     InputOutputTransposer transposer(rewriter, onnxOp.getLoc());
@@ -735,6 +735,65 @@ class RNNOpRewriteLayoutPattern : public OpRewritePattern<ONNXOp> {
   }
 };
 
+// Rewrites sequence_lens from a tensor to none when the batch size is 1. This
+// is valid because, by definition, all batches (here, exactly one) share the
+// same sequence length, so the compiler no longer needs to handle it.
+template <typename ONNXOp>
+class RNNOpRewriteSeqLenPattern : public OpRewritePattern<ONNXOp> {
+public:
+  using OpRewritePattern<ONNXOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(
+      ONNXOp onnxOp, PatternRewriter &rewriter) const override {
+    Operation *op = onnxOp.getOperation();
+    Location loc = ONNXLoc<ONNXOp>(op);
+    Value X = onnxOp.getX();
+    Value initialH = onnxOp.getInitialH();
+    Value seqLen = onnxOp.getSequenceLens();
+
+    // sequence_lens is already none. Pattern does not match.
+    if (isNoneValue(seqLen))
+      return failure();
+
+    // Check if the batch size is 1. The batch size can be found in:
+    // - X: [seq_length, batch_size, input_size],
+    // - initial_h: [num_directions, batch_size, hidden_size], or
+    // - sequence_lens: [batch_size].
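+    // Note that only static (compile-time) dimensions are inspected below: a
+    // dynamic batch dimension never compares equal to 1 and is simply skipped.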
+    bool oneInX = false, oneInSeqLen = false, oneInInitialH = false;
+    if (isRankedShapedType(X.getType())) {
+      ArrayRef<int64_t> shape = getShape(X.getType());
+      oneInX = shape[1] == 1;
+    }
+    if (isRankedShapedType(seqLen.getType())) {
+      ArrayRef<int64_t> shape = getShape(seqLen.getType());
+      oneInSeqLen = (shape.size() == 1) && (shape[0] == 1);
+    }
+    if (!isNoneValue(initialH) && isRankedShapedType(initialH.getType())) {
+      ArrayRef<int64_t> shape = getShape(initialH.getType());
+      oneInInitialH = shape[1] == 1;
+    }
+    if (!oneInX && !oneInInitialH && !oneInSeqLen)
+      return failure();
+
+    // The batch size is known to be 1. Rewrite now.
+    MultiDialectBuilder<OnnxBuilder> create(rewriter, loc);
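+    // The operand is located by comparing values rather than by a hard-coded
+    // position, so the same code works for LSTM, GRU and RNN alike.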
+    // Find the operand index of sequence_lens and update it with none.
+    bool updated = false;
+    for (unsigned i = 0; i < op->getNumOperands(); ++i) {
+      if (op->getOperand(i) != seqLen)
+        continue;
+      op->setOperand(i, create.onnx.none());
+      updated = true;
+      break;
+    }
+    return updated ? success() : failure();
+  }
+};
+
 // =============================================================================
 // Rewrite pattern for Power
 // =============================================================================
@@ -912,6 +971,7 @@ void ONNXGreaterOp::getCanonicalizationPatterns(
 void ONNXGRUOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results.insert<RNNOpRewriteLayoutPattern<ONNXGRUOp>>(context);
+  results.insert<RNNOpRewriteSeqLenPattern<ONNXGRUOp>>(context);
 }
 
 /// on the ONNXIdentityOp.
@@ -943,6 +1003,7 @@ void ONNXLoopOp::getCanonicalizationPatterns(
 void ONNXLSTMOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results.insert<RNNOpRewriteLayoutPattern<ONNXLSTMOp>>(context);
+  results.insert<RNNOpRewriteSeqLenPattern<ONNXLSTMOp>>(context);
 }
 
 /// on the ONNXMulOp.
@@ -977,6 +1038,7 @@ void ONNXResizeOp::getCanonicalizationPatterns(
 void ONNXRNNOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results.insert<RNNOpRewriteLayoutPattern<ONNXRNNOp>>(context);
+  results.insert<RNNOpRewriteSeqLenPattern<ONNXRNNOp>>(context);
 }
 
 /// on the ONNXShapeOp.
diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir
index 1f213b1119..5e6e0f82e5 100644
--- a/test/mlir/onnx/onnx_canonicalization.mlir
+++ b/test/mlir/onnx/onnx_canonicalization.mlir
@@ -938,6 +938,154 @@ func.func @test_lstm_layout1(%arg0: tensor<5x4x2xf32>, %arg1: tensor<1x12x2xf32>
 
 // -----
 
+// Check rewriting sequence_lens from a tensor to none when the batch size is 1.
+// Check with LSTM, GRU and RNN.
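+// For each op, the three tests differ only in which input exposes the static
+// batch size of 1 (X, sequence_lens, or initial_h); a final test keeps every
+// batch dimension dynamic and verifies that no rewrite takes place.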
+
+func.func @test_lstm_seq_lens_bs1_in_X(%X: tensor<7x1x3xf32>, %W: tensor<1x16x3xf32>, %R: tensor<1x16x4xf32>, %B: tensor<1x32xf32>, %seq_lens: tensor<?xi32>, %initial_h: tensor<1x?x4xf32>) -> tensor<*xf32> {
+  %none = "onnx.NoValue"() {value} : () -> none
+  %Y, %Y_h, %Y_c = "onnx.LSTM"(%X, %W, %R, %B, %seq_lens, %initial_h, %none, %none) {hidden_size = 4 : si64} : (tensor<7x1x3xf32>, tensor<1x16x3xf32>, tensor<1x16x4xf32>, tensor<1x32xf32>, tensor<?xi32>, tensor<1x?x4xf32>, none, none) -> (none, tensor<*xf32>, none)
+  return %Y_h : tensor<*xf32>
+
+// CHECK-LABEL: func.func @test_lstm_seq_lens_bs1_in_X
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<7x1x3xf32>, [[PARAM_1_:%.+]]: tensor<1x16x3xf32>, [[PARAM_2_:%.+]]: tensor<1x16x4xf32>, [[PARAM_3_:%.+]]: tensor<1x32xf32>, [[PARAM_4_:%.+]]: tensor<?xi32>, [[PARAM_5_:%.+]]: tensor<1x?x4xf32>) -> tensor<*xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK: [[Y_:%.+]], [[Y_h_:%.+]], [[VAR_Y_c_:%.+]] = "onnx.LSTM"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_0_]], [[PARAM_5_]], [[VAR_0_]], [[VAR_0_]]) {direction = "forward", hidden_size = 4 : si64, input_forget = 0 : si64, layout = 0 : si64} : (tensor<7x1x3xf32>, tensor<1x16x3xf32>, tensor<1x16x4xf32>, tensor<1x32xf32>, none, tensor<1x?x4xf32>, none, none) -> (none, tensor<*xf32>, none)
+// CHECK: return [[Y_h_]] : tensor<*xf32>
+// CHECK: }
+}
+
+// -----
+
+func.func @test_lstm_seq_lens_bs1_in_seq_lens(%X: tensor<*xf32>, %W: tensor<1x16x3xf32>, %R: tensor<1x16x4xf32>, %B: tensor<1x32xf32>, %seq_lens: tensor<1xi32>, %initial_h: tensor<1x?x4xf32>) -> tensor<*xf32> {
+  %none = "onnx.NoValue"() {value} : () -> none
+  %Y, %Y_h, %Y_c = "onnx.LSTM"(%X, %W, %R, %B, %seq_lens, %initial_h, %none, %none) {hidden_size = 4 : si64} : (tensor<*xf32>, tensor<1x16x3xf32>, tensor<1x16x4xf32>, tensor<1x32xf32>, tensor<1xi32>, tensor<1x?x4xf32>, none, none) -> (none, tensor<*xf32>, none)
+  return %Y_h : tensor<*xf32>
+
+// CHECK-LABEL: func.func @test_lstm_seq_lens_bs1_in_seq_lens
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<1x16x3xf32>, [[PARAM_2_:%.+]]: tensor<1x16x4xf32>, [[PARAM_3_:%.+]]: tensor<1x32xf32>, [[PARAM_4_:%.+]]: tensor<1xi32>, [[PARAM_5_:%.+]]: tensor<1x?x4xf32>) -> tensor<*xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK: [[Y_:%.+]], [[Y_h_:%.+]], [[VAR_Y_c_:%.+]] = "onnx.LSTM"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_0_]], [[PARAM_5_]], [[VAR_0_]], [[VAR_0_]]) {direction = "forward", hidden_size = 4 : si64, input_forget = 0 : si64, layout = 0 : si64} : (tensor<*xf32>, tensor<1x16x3xf32>, tensor<1x16x4xf32>, tensor<1x32xf32>, none, tensor<1x?x4xf32>, none, none) -> (none, tensor<*xf32>, none)
+// CHECK: return [[Y_h_]] : tensor<*xf32>
+// CHECK: }
+}
+
+// -----
+
+func.func @test_lstm_seq_lens_bs1_in_initial_h(%X: tensor<*xf32>, %W: tensor<1x16x3xf32>, %R: tensor<1x16x4xf32>, %B: tensor<1x32xf32>, %seq_lens: tensor<?xi32>, %initial_h: tensor<1x1x4xf32>) -> tensor<*xf32> {
+  %none = "onnx.NoValue"() {value} : () -> none
+  %Y, %Y_h, %Y_c = "onnx.LSTM"(%X, %W, %R, %B, %seq_lens, %initial_h, %none, %none) {hidden_size = 4 : si64} : (tensor<*xf32>, tensor<1x16x3xf32>, tensor<1x16x4xf32>, tensor<1x32xf32>, tensor<?xi32>, tensor<1x1x4xf32>, none, none) -> (none, tensor<*xf32>, none)
+  return %Y_h : tensor<*xf32>
+
+// CHECK-LABEL: func.func @test_lstm_seq_lens_bs1_in_initial_h
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<1x16x3xf32>, [[PARAM_2_:%.+]]: tensor<1x16x4xf32>, [[PARAM_3_:%.+]]: tensor<1x32xf32>, [[PARAM_4_:%.+]]: tensor<?xi32>, [[PARAM_5_:%.+]]: tensor<1x1x4xf32>) -> tensor<*xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK: [[Y_:%.+]], [[Y_h_:%.+]], [[VAR_Y_c_:%.+]] = "onnx.LSTM"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_0_]], [[PARAM_5_]], [[VAR_0_]], [[VAR_0_]]) {direction = "forward", hidden_size = 4 : si64, input_forget = 0 : si64, layout = 0 : si64} : (tensor<*xf32>, tensor<1x16x3xf32>, tensor<1x16x4xf32>, tensor<1x32xf32>, none, tensor<1x1x4xf32>, none, none) -> (none, tensor<*xf32>, none)
+// CHECK: return [[Y_h_]] : tensor<*xf32>
+// CHECK: }
+}
+
+// -----
+
+func.func @test_gru_seq_lens_bs1_in_X(%X: tensor<7x1x3xf32>, %W: tensor<1x12x3xf32>, %R: tensor<1x12x4xf32>, %B: tensor<1x24xf32>, %seq_lens: tensor<?xi32>, %initial_h: tensor<1x?x4xf32>) -> tensor<*xf32> {
+  %Y, %Y_h = "onnx.GRU"(%X, %W, %R, %B, %seq_lens, %initial_h) {hidden_size = 4 : si64} : (tensor<7x1x3xf32>, tensor<1x12x3xf32>, tensor<1x12x4xf32>, tensor<1x24xf32>, tensor<?xi32>, tensor<1x?x4xf32>) -> (none, tensor<*xf32>)
+  return %Y_h : tensor<*xf32>
+
+// CHECK-LABEL: func.func @test_gru_seq_lens_bs1_in_X
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<7x1x3xf32>, [[PARAM_1_:%.+]]: tensor<1x12x3xf32>, [[PARAM_2_:%.+]]: tensor<1x12x4xf32>, [[PARAM_3_:%.+]]: tensor<1x24xf32>, [[PARAM_4_:%.+]]: tensor<?xi32>, [[PARAM_5_:%.+]]: tensor<1x?x4xf32>) -> tensor<*xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK: [[Y_:%.+]], [[VAR_Y_h_:%.+]] = "onnx.GRU"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_0_]], [[PARAM_5_]]) {direction = "forward", hidden_size = 4 : si64, layout = 0 : si64, linear_before_reset = 0 : si64} : (tensor<7x1x3xf32>, tensor<1x12x3xf32>, tensor<1x12x4xf32>, tensor<1x24xf32>, none, tensor<1x?x4xf32>) -> (none, tensor<*xf32>)
+// CHECK: return [[VAR_Y_h_]] : tensor<*xf32>
+// CHECK: }
+}
+
+// -----
+
+func.func @test_gru_seq_lens_bs1_in_seq_lens(%X: tensor<*xf32>, %W: tensor<1x12x3xf32>, %R: tensor<1x12x4xf32>, %B: tensor<1x24xf32>, %seq_lens: tensor<1xi32>, %initial_h: tensor<1x?x4xf32>) -> tensor<*xf32> {
+  %Y, %Y_h = "onnx.GRU"(%X, %W, %R, %B, %seq_lens, %initial_h) {hidden_size = 4 : si64} : (tensor<*xf32>, tensor<1x12x3xf32>, tensor<1x12x4xf32>, tensor<1x24xf32>, tensor<1xi32>, tensor<1x?x4xf32>) -> (none, tensor<*xf32>)
+  return %Y_h : tensor<*xf32>
+
+// CHECK-LABEL: func.func @test_gru_seq_lens_bs1_in_seq_lens
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<1x12x3xf32>, [[PARAM_2_:%.+]]: tensor<1x12x4xf32>, [[PARAM_3_:%.+]]: tensor<1x24xf32>, [[PARAM_4_:%.+]]: tensor<1xi32>, [[PARAM_5_:%.+]]: tensor<1x?x4xf32>) -> tensor<*xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK: [[Y_:%.+]], [[VAR_Y_h_:%.+]] = "onnx.GRU"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_0_]], [[PARAM_5_]]) {direction = "forward", hidden_size = 4 : si64, layout = 0 : si64, linear_before_reset = 0 : si64} : (tensor<*xf32>, tensor<1x12x3xf32>, tensor<1x12x4xf32>, tensor<1x24xf32>, none, tensor<1x?x4xf32>) -> (none, tensor<*xf32>)
+// CHECK: return [[VAR_Y_h_]] : tensor<*xf32>
+// CHECK: }
+}
+
+// -----
+
+func.func @test_gru_seq_lens_bs1_in_initial_h(%X: tensor<*xf32>, %W: tensor<1x12x3xf32>, %R: tensor<1x12x4xf32>, %B: tensor<1x24xf32>, %seq_lens: tensor<?xi32>, %initial_h: tensor<1x1x4xf32>) -> tensor<*xf32> {
+  %Y, %Y_h = "onnx.GRU"(%X, %W, %R, %B, %seq_lens, %initial_h) {hidden_size = 4 : si64} : (tensor<*xf32>, tensor<1x12x3xf32>, tensor<1x12x4xf32>, tensor<1x24xf32>, tensor<?xi32>, tensor<1x1x4xf32>) -> (none, tensor<*xf32>)
+  return %Y_h : tensor<*xf32>
+
+// CHECK-LABEL: func.func @test_gru_seq_lens_bs1_in_initial_h
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<1x12x3xf32>, [[PARAM_2_:%.+]]: tensor<1x12x4xf32>, [[PARAM_3_:%.+]]: tensor<1x24xf32>, [[PARAM_4_:%.+]]: tensor<?xi32>, [[PARAM_5_:%.+]]: tensor<1x1x4xf32>) -> tensor<*xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK: [[Y_:%.+]], [[VAR_Y_h_:%.+]] = "onnx.GRU"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_0_]], [[PARAM_5_]]) {direction = "forward", hidden_size = 4 : si64, layout = 0 : si64, linear_before_reset = 0 : si64} : (tensor<*xf32>, tensor<1x12x3xf32>, tensor<1x12x4xf32>, tensor<1x24xf32>, none, tensor<1x1x4xf32>) -> (none, tensor<*xf32>)
+// CHECK: return [[VAR_Y_h_]] : tensor<*xf32>
+// CHECK: }
+}
+
+// -----
+
+func.func @test_rnn_seq_lens_bs1_in_X(%X: tensor<7x1x3xf32>, %W: tensor<1x4x3xf32>, %R: tensor<1x4x4xf32>, %B: tensor<1x8xf32>, %seq_lens: tensor<?xi32>, %initial_h: tensor<1x?x4xf32>) -> tensor<*xf32> {
+  %Y, %Y_h = "onnx.RNN"(%X, %W, %R, %B, %seq_lens, %initial_h) {hidden_size = 4 : si64} : (tensor<7x1x3xf32>, tensor<1x4x3xf32>, tensor<1x4x4xf32>, tensor<1x8xf32>, tensor<?xi32>, tensor<1x?x4xf32>) -> (none, tensor<*xf32>)
+  return %Y_h : tensor<*xf32>
+
+// CHECK-LABEL: func.func @test_rnn_seq_lens_bs1_in_X
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<7x1x3xf32>, [[PARAM_1_:%.+]]: tensor<1x4x3xf32>, [[PARAM_2_:%.+]]: tensor<1x4x4xf32>, [[PARAM_3_:%.+]]: tensor<1x8xf32>, [[PARAM_4_:%.+]]: tensor<?xi32>, [[PARAM_5_:%.+]]: tensor<1x?x4xf32>) -> tensor<*xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK: [[Y_:%.+]], [[VAR_Y_h_:%.+]] = "onnx.RNN"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_0_]], [[PARAM_5_]]) {activations = ["Tanh", "Tanh"], direction = "forward", hidden_size = 4 : si64, layout = 0 : si64} : (tensor<7x1x3xf32>, tensor<1x4x3xf32>, tensor<1x4x4xf32>, tensor<1x8xf32>, none, tensor<1x?x4xf32>) -> (none, tensor<*xf32>)
+// CHECK: return [[VAR_Y_h_]] : tensor<*xf32>
+// CHECK: }
+}
+
+// -----
+
+func.func @test_rnn_seq_lens_bs1_in_seq_lens(%X: tensor<*xf32>, %W: tensor<1x4x3xf32>, %R: tensor<1x4x4xf32>, %B: tensor<1x8xf32>, %seq_lens: tensor<1xi32>, %initial_h: tensor<1x?x4xf32>) -> tensor<*xf32> {
+  %Y, %Y_h = "onnx.RNN"(%X, %W, %R, %B, %seq_lens, %initial_h) {hidden_size = 4 : si64} : (tensor<*xf32>, tensor<1x4x3xf32>, tensor<1x4x4xf32>, tensor<1x8xf32>, tensor<1xi32>, tensor<1x?x4xf32>) -> (none, tensor<*xf32>)
+  return %Y_h : tensor<*xf32>
+
+// CHECK-LABEL: func.func @test_rnn_seq_lens_bs1_in_seq_lens
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<1x4x3xf32>, [[PARAM_2_:%.+]]: tensor<1x4x4xf32>, [[PARAM_3_:%.+]]: tensor<1x8xf32>, [[PARAM_4_:%.+]]: tensor<1xi32>, [[PARAM_5_:%.+]]: tensor<1x?x4xf32>) -> tensor<*xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK: [[Y_:%.+]], [[VAR_Y_h_:%.+]] = "onnx.RNN"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_0_]], [[PARAM_5_]]) {activations = ["Tanh", "Tanh"], direction = "forward", hidden_size = 4 : si64, layout = 0 : si64} : (tensor<*xf32>, tensor<1x4x3xf32>, tensor<1x4x4xf32>, tensor<1x8xf32>, none, tensor<1x?x4xf32>) -> (none, tensor<*xf32>)
+// CHECK: return [[VAR_Y_h_]] : tensor<*xf32>
+// CHECK: }
+}
+
+// -----
+
+func.func @test_rnn_seq_lens_bs1_in_initial_h(%X: tensor<*xf32>, %W: tensor<1x4x3xf32>, %R: tensor<1x4x4xf32>, %B: tensor<1x8xf32>, %seq_lens: tensor<?xi32>, %initial_h: tensor<1x1x4xf32>) -> tensor<*xf32> {
+  %Y, %Y_h = "onnx.RNN"(%X, %W, %R, %B, %seq_lens, %initial_h) {hidden_size = 4 : si64} : (tensor<*xf32>, tensor<1x4x3xf32>, tensor<1x4x4xf32>, tensor<1x8xf32>, tensor<?xi32>, tensor<1x1x4xf32>) -> (none, tensor<*xf32>)
+  return %Y_h : tensor<*xf32>
+
+// CHECK-LABEL: func.func @test_rnn_seq_lens_bs1_in_initial_h
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<1x4x3xf32>, [[PARAM_2_:%.+]]: tensor<1x4x4xf32>, [[PARAM_3_:%.+]]: tensor<1x8xf32>, [[PARAM_4_:%.+]]: tensor<?xi32>, [[PARAM_5_:%.+]]: tensor<1x1x4xf32>) -> tensor<*xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK: [[Y_:%.+]], [[VAR_Y_h_:%.+]] = "onnx.RNN"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_0_]], [[PARAM_5_]]) {activations = ["Tanh", "Tanh"], direction = "forward", hidden_size = 4 : si64, layout = 0 : si64} : (tensor<*xf32>, tensor<1x4x3xf32>, tensor<1x4x4xf32>, tensor<1x8xf32>, none, tensor<1x1x4xf32>) -> (none, tensor<*xf32>)
+// CHECK: return [[VAR_Y_h_]] : tensor<*xf32>
+// CHECK: }
+}
+
+// -----
+
+func.func @test_rnn_seq_lens_not_rewrite(%X: tensor<*xf32>, %W: tensor<1x4x3xf32>, %R: tensor<1x4x4xf32>, %B: tensor<1x8xf32>, %seq_lens: tensor<?xi32>, %initial_h: tensor<1x?x4xf32>) -> tensor<*xf32> {
+  %Y, %Y_h = "onnx.RNN"(%X, %W, %R, %B, %seq_lens, %initial_h) {hidden_size = 4 : si64} : (tensor<*xf32>, tensor<1x4x3xf32>, tensor<1x4x4xf32>, tensor<1x8xf32>, tensor<?xi32>, tensor<1x?x4xf32>) -> (none, tensor<*xf32>)
+  return %Y_h : tensor<*xf32>
+
+// CHECK-LABEL: func.func @test_rnn_seq_lens_not_rewrite
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<1x4x3xf32>, [[PARAM_2_:%.+]]: tensor<1x4x4xf32>, [[PARAM_3_:%.+]]: tensor<1x8xf32>, [[PARAM_4_:%.+]]: tensor<?xi32>, [[PARAM_5_:%.+]]: tensor<1x?x4xf32>) -> tensor<*xf32> {
+// CHECK: [[Y_:%.+]], [[VAR_Y_h_:%.+]] = "onnx.RNN"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[PARAM_4_]], [[PARAM_5_]]) {activations = ["Tanh", "Tanh"], direction = "forward", hidden_size = 4 : si64, layout = 0 : si64} : (tensor<*xf32>, tensor<1x4x3xf32>, tensor<1x4x4xf32>, tensor<1x8xf32>, tensor<?xi32>, tensor<1x?x4xf32>) -> (none, tensor<*xf32>)
+// CHECK: return [[VAR_Y_h_]] : tensor<*xf32>
+// CHECK: }
+}
+
+// -----
+
 func.func @test_dim_to_constant(%arg0: tensor) -> (tensor<1xi64>) {
   %0 = "onnx.Dim"(%arg0) {axis = 1 : si64} : (tensor) -> tensor<1xi64>
   onnx.Return %0 : tensor<1xi64>
@@ -1110,3 +1258,4 @@ func.func @mul_broadcast_axis_unsqueeze(%279: tensor<1x64x112x112xf32>, %138: te
 // CHECK: onnx.Return [[VAR_2_]] : tensor<*xf32>
 // CHECK: }
 }
+