Replace sequence_lens with none in RNN ops when bs is 1 (#2404)
* Expand's shape inference returns success when shape input is unknown

Signed-off-by: Tung D. Le <[email protected]>
tungld authored Jul 28, 2023
1 parent 8268d43 commit 0d32512
Showing 4 changed files with 217 additions and 13 deletions.
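In short, the new canonicalization replaces the sequence_lens operand of ONNX LSTM/GRU/RNN ops with none whenever the batch size is statically known to be 1: with a single batch there is only one sequence length, so sequence_lens carries no information. As a condensed sketch, the match condition of the RNNOpRewriteSeqLenPattern added in src/Dialect/ONNX/Rewrite.cpp below boils down to the following (the helpers isRankedShapedType, getShape, and isNoneValue are the ones the pattern itself uses; the standalone function is hypothetical):

// Hypothetical condensation of the pattern's batch-size check.
static bool batchSizeIsStaticallyOne(
    mlir::Value X, mlir::Value initialH, mlir::Value seqLen) {
  bool oneInX = false, oneInSeqLen = false, oneInInitialH = false;
  // X: [seq_length, batch_size, input_size]
  if (isRankedShapedType(X.getType()))
    oneInX = getShape(X.getType())[1] == 1;
  // sequence_lens: [batch_size]
  if (isRankedShapedType(seqLen.getType())) {
    llvm::ArrayRef<int64_t> shape = getShape(seqLen.getType());
    oneInSeqLen = (shape.size() == 1) && (shape[0] == 1);
  }
  // initial_h: [num_directions, batch_size, hidden_size]
  if (!isNoneValue(initialH) && isRankedShapedType(initialH.getType()))
    oneInInitialH = getShape(initialH.getType())[1] == 1;
  return oneInX || oneInSeqLen || oneInInitialH;
}

When any of these dimensions is statically 1, the pattern swaps the sequence_lens operand for an onnx.NoValue (none); the new tests in onnx_canonicalization.mlir check this for LSTM, GRU, and RNN.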
12 changes: 0 additions & 12 deletions src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
@@ -466,18 +466,6 @@ bool checkOpResultIsUsedByGetRef(mlir::memref::AllocOp *allocOp);
/// 1, 2 and 4 in the MemRef shape respectively
int64_t getAllocArgIndex(mlir::memref::AllocOp allocOp, int64_t index);

/// This function returns a location with the corresponding ONNX operator name
/// inside. This is useful when tracing what expanded MLIR instructions
/// correspond to what ONNX operator.
///
///
template <typename OP_TYPE>
mlir::Location ONNXLoc(mlir::Operation *op) {
return mlir::NameLoc::get(
mlir::StringAttr::get(op->getContext(), OP_TYPE::getOperationName()),
op->getLoc());
}

/// This function returns a scalar of type 'dtype' from an optional value.
/// Optional value must be: NoneType, memref<1xdtype> or memref<dtype>.
/// Default value is used in case of NoneType.
12 changes: 12 additions & 0 deletions src/Dialect/ONNX/ONNXOps/OpHelper.hpp
@@ -47,6 +47,18 @@

namespace onnx_mlir {

/// This function returns a location with the corresponding ONNX operator name
/// inside. This is useful when tracing what expanded MLIR instructions
/// correspond to what ONNX operator.
///
///
template <typename OP_TYPE>
mlir::Location ONNXLoc(mlir::Operation *op) {
return mlir::NameLoc::get(
mlir::StringAttr::get(op->getContext(), OP_TYPE::getOperationName()),
op->getLoc());
}
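// A minimal usage sketch (illustrative, not part of this header), mirroring
// how the new RNN rewrite in Rewrite.cpp obtains its location; 'lstmOp'
// stands for a matched ONNXLSTMOp:
//
//   Operation *op = lstmOp.getOperation();
//   Location loc = ONNXLoc<ONNXLSTMOp>(op);
//
// 'loc' is then a NameLoc carrying "onnx.LSTM" wrapped around op->getLoc(),
// so MLIR generated at 'loc' can be traced back to the originating ONNX
// operator. Hosting the helper here (rather than only in
// ONNXToKrnl/ONNXToKrnlCommon.hpp) lets dialect-level rewrites such as those
// in Rewrite.cpp use it as well.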

//===----------------------------------------------------------------------===//
// ONNX Tensor support.

60 changes: 59 additions & 1 deletion src/Dialect/ONNX/Rewrite.cpp
@@ -687,7 +687,7 @@ class RNNOpRewriteLayoutPattern : public OpRewritePattern<ONNXOp> {
LogicalResult matchAndRewrite(
ONNXOp onnxOp, PatternRewriter &rewriter) const override {
if (onnxOp.getLayout() == 0) {
- return success();
+ return failure();
}

InputOutputTransposer transposer(rewriter, onnxOp.getLoc());
@@ -735,6 +735,61 @@ class RNNOpRewriteLayoutPattern : public OpRewritePattern<ONNXOp> {
}
};

// Rewrites sequence_lens from tensor<bsxi32> to none when bs = 1. This is
// valid because, with a single batch, all batches trivially share the same
// sequence length. The rewrite spares the compiler from having to handle
// sequence_lens.
template <typename ONNXOp>
class RNNOpRewriteSeqLenPattern : public OpRewritePattern<ONNXOp> {
public:
using OpRewritePattern<ONNXOp>::OpRewritePattern;

LogicalResult matchAndRewrite(
ONNXOp onnxOp, PatternRewriter &rewriter) const override {
Operation *op = onnxOp.getOperation();
Location loc = ONNXLoc<ONNXOp>(op);
Value X = onnxOp.getX();
Value initialH = onnxOp.getInitialH();
Value seqLen = onnxOp.getSequenceLens();

// sequence_lens is already none. Pattern does not match.
if (isNoneValue(seqLen))
return failure();

// Check whether the batch size is 1. The batch size can appear in:
// - X: [seq_length, batch_size, input_size],
// - initial_h: [num_directions, batch_size, hidden_size], or
// - sequence_lens: [batch_size].
bool oneInX = false, oneInSeqLen = false, oneInInitialH = false;
if (isRankedShapedType(X.getType())) {
ArrayRef<int64_t> shape = getShape(X.getType());
oneInX = shape[1] == 1;
}
if (isRankedShapedType(seqLen.getType())) {
ArrayRef<int64_t> shape = getShape(seqLen.getType());
oneInSeqLen = (shape.size() == 1) && (shape[0] == 1);
}
if (!isNoneValue(initialH) && isRankedShapedType(initialH.getType())) {
ArrayRef<int64_t> shape = getShape(initialH.getType());
oneInInitialH = shape[1] == 1;
}
if (!oneInX && !oneInInitialH && !oneInSeqLen)
return failure();

// We know the batch size is 1. Rewrite now.
MultiDialectBuilder<OnnxBuilder> create(rewriter, loc);
// Find the operand index of sequence_lens and update it with none.
bool updated = false;
for (unsigned i = 0; i < op->getNumOperands(); ++i) {
if (op->getOperand(i) != seqLen)
continue;
op->setOperand(i, create.onnx.none());
updated = true;
break;
}
return updated ? success() : failure();
}
};

// =============================================================================
// Rewrite pattern for Power
// =============================================================================
@@ -912,6 +967,7 @@ void ONNXGreaterOp::getCanonicalizationPatterns(
void ONNXGRUOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.insert<RNNOpRewriteLayoutPattern<ONNXGRUOp>>(context);
results.insert<RNNOpRewriteSeqLenPattern<ONNXGRUOp>>(context);
}

/// on the ONNXIdentityOp.
@@ -943,6 +999,7 @@ void ONNXLoopOp::getCanonicalizationPatterns(
void ONNXLSTMOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.insert<RNNOpRewriteLayoutPattern<ONNXLSTMOp>>(context);
results.insert<RNNOpRewriteSeqLenPattern<ONNXLSTMOp>>(context);
}

/// on the ONNXMulOp.
@@ -977,6 +1034,7 @@ void ONNXResizeOp::getCanonicalizationPatterns(
void ONNXRNNOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.insert<RNNOpRewriteLayoutPattern<ONNXRNNOp>>(context);
results.insert<RNNOpRewriteSeqLenPattern<ONNXRNNOp>>(context);
}
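// For context, a rough sketch (not part of this change) of how the patterns
// registered above are exercised: MLIR's canonicalizer gathers each op's
// getCanonicalizationPatterns and applies them greedily, so the sequence_lens
// rewrite runs under an ordinary --canonicalize invocation. A hand-rolled
// driver would look roughly like the following, where 'module' and 'context'
// stand for the ModuleOp being rewritten and its MLIRContext:
//
//   RewritePatternSet patterns(context);
//   ONNXGRUOp::getCanonicalizationPatterns(patterns, context);
//   ONNXLSTMOp::getCanonicalizationPatterns(patterns, context);
//   ONNXRNNOp::getCanonicalizationPatterns(patterns, context);
//   (void)applyPatternsAndFoldGreedily(module, std::move(patterns));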

/// on the ONNXShapeOp.
146 changes: 146 additions & 0 deletions test/mlir/onnx/onnx_canonicalization.mlir
@@ -938,6 +938,151 @@ func.func @test_lstm_layout1(%arg0: tensor<5x4x2xf32>, %arg1: tensor<1x12x2xf32>

// -----

// Check rewriting sequence_lens from tensor<bsxi32> to none when bs = 1.
// Check with LSTM, GRU and RNN.

func.func @test_lstm_seq_lens_bs1_in_X(%X: tensor<7x1x3xf32>, %W: tensor<1x16x3xf32>, %R: tensor<1x16x4xf32>, %B: tensor<1x32xf32>, %seq_lens: tensor<?xi32>, %initial_h: tensor<1x?x4xf32>) -> tensor<*xf32> {
%none = "onnx.NoValue"() {value} : () -> none
%Y, %Y_h, %Y_c = "onnx.LSTM"(%X, %W, %R, %B, %seq_lens, %initial_h, %none, %none) {hidden_size = 4 : si64} : (tensor<7x1x3xf32>, tensor<1x16x3xf32>, tensor<1x16x4xf32>, tensor<1x32xf32>, tensor<?xi32>, tensor<1x?x4xf32>, none, none) -> (none, tensor<*xf32>, none)
return %Y_h : tensor<*xf32>

// CHECK-LABEL: func.func @test_lstm_seq_lens_bs1_in_X
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<7x1x3xf32>, [[PARAM_1_:%.+]]: tensor<1x16x3xf32>, [[PARAM_2_:%.+]]: tensor<1x16x4xf32>, [[PARAM_3_:%.+]]: tensor<1x32xf32>, [[PARAM_4_:%.+]]: tensor<?xi32>, [[PARAM_5_:%.+]]: tensor<1x?x4xf32>) -> tensor<*xf32> {
// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
// CHECK: [[Y_:%.+]], [[Y_h_:%.+]], [[VAR_Y_c_:%.+]] = "onnx.LSTM"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_0_]], [[PARAM_5_]], [[VAR_0_]], [[VAR_0_]]) {direction = "forward", hidden_size = 4 : si64, input_forget = 0 : si64, layout = 0 : si64} : (tensor<7x1x3xf32>, tensor<1x16x3xf32>, tensor<1x16x4xf32>, tensor<1x32xf32>, none, tensor<1x?x4xf32>, none, none) -> (none, tensor<*xf32>, none)
// CHECK: return [[Y_h_]] : tensor<*xf32>
// CHECK: }
}

// -----

func.func @test_lstm_seq_lens_bs1_in_seq_lens(%X: tensor<*xf32>, %W: tensor<1x16x3xf32>, %R: tensor<1x16x4xf32>, %B: tensor<1x32xf32>, %seq_lens: tensor<1xi32>, %initial_h: tensor<1x?x4xf32>) -> tensor<*xf32> {
%none = "onnx.NoValue"() {value} : () -> none
%Y, %Y_h, %Y_c = "onnx.LSTM"(%X, %W, %R, %B, %seq_lens, %initial_h, %none, %none) {hidden_size = 4 : si64} : (tensor<*xf32>, tensor<1x16x3xf32>, tensor<1x16x4xf32>, tensor<1x32xf32>, tensor<1xi32>, tensor<1x?x4xf32>, none, none) -> (none, tensor<*xf32>, none)
return %Y_h : tensor<*xf32>

// CHECK-LABEL: func.func @test_lstm_seq_lens_bs1_in_seq_lens
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<1x16x3xf32>, [[PARAM_2_:%.+]]: tensor<1x16x4xf32>, [[PARAM_3_:%.+]]: tensor<1x32xf32>, [[PARAM_4_:%.+]]: tensor<1xi32>, [[PARAM_5_:%.+]]: tensor<1x?x4xf32>) -> tensor<*xf32> {
// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
// CHECK: [[Y_:%.+]], [[Y_h_:%.+]], [[VAR_Y_c_:%.+]] = "onnx.LSTM"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_0_]], [[PARAM_5_]], [[VAR_0_]], [[VAR_0_]]) {direction = "forward", hidden_size = 4 : si64, input_forget = 0 : si64, layout = 0 : si64} : (tensor<*xf32>, tensor<1x16x3xf32>, tensor<1x16x4xf32>, tensor<1x32xf32>, none, tensor<1x?x4xf32>, none, none) -> (none, tensor<*xf32>, none)
// CHECK: return [[Y_h_]] : tensor<*xf32>
// CHECK: }
}

// -----

func.func @test_lstm_seq_lens_bs1_in_initial_h(%X: tensor<*xf32>, %W: tensor<1x16x3xf32>, %R: tensor<1x16x4xf32>, %B: tensor<1x32xf32>, %seq_lens: tensor<?xi32>, %initial_h: tensor<1x1x4xf32>) -> tensor<*xf32> {
%none = "onnx.NoValue"() {value} : () -> none
%Y, %Y_h, %Y_c = "onnx.LSTM"(%X, %W, %R, %B, %seq_lens, %initial_h, %none, %none) {hidden_size = 4 : si64} : (tensor<*xf32>, tensor<1x16x3xf32>, tensor<1x16x4xf32>, tensor<1x32xf32>, tensor<?xi32>, tensor<1x1x4xf32>, none, none) -> (none, tensor<*xf32>, none)
return %Y_h : tensor<*xf32>

// CHECK-LABEL: func.func @test_lstm_seq_lens_bs1_in_initial_h
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<1x16x3xf32>, [[PARAM_2_:%.+]]: tensor<1x16x4xf32>, [[PARAM_3_:%.+]]: tensor<1x32xf32>, [[PARAM_4_:%.+]]: tensor<?xi32>, [[PARAM_5_:%.+]]: tensor<1x1x4xf32>) -> tensor<*xf32> {
// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
// CHECK: [[Y_:%.+]], [[Y_h_:%.+]], [[VAR_Y_c_:%.+]] = "onnx.LSTM"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_0_]], [[PARAM_5_]], [[VAR_0_]], [[VAR_0_]]) {direction = "forward", hidden_size = 4 : si64, input_forget = 0 : si64, layout = 0 : si64} : (tensor<*xf32>, tensor<1x16x3xf32>, tensor<1x16x4xf32>, tensor<1x32xf32>, none, tensor<1x1x4xf32>, none, none) -> (none, tensor<*xf32>, none)
// CHECK: return [[Y_h_]] : tensor<*xf32>
// CHECK: }
}

// -----

func.func @test_gru_seq_lens_bs1_in_X(%X: tensor<7x1x3xf32>, %W: tensor<1x12x3xf32>, %R: tensor<1x12x4xf32>, %B: tensor<1x24xf32>, %seq_lens: tensor<?xi32>, %initial_h: tensor<1x?x4xf32>) -> tensor<*xf32> {
%Y, %Y_h = "onnx.GRU"(%X, %W, %R, %B, %seq_lens, %initial_h) {hidden_size = 4 : si64} : (tensor<7x1x3xf32>, tensor<1x12x3xf32>, tensor<1x12x4xf32>, tensor<1x24xf32>, tensor<?xi32>, tensor<1x?x4xf32>) -> (none, tensor<*xf32>)
return %Y_h : tensor<*xf32>

// CHECK-LABEL: func.func @test_gru_seq_lens_bs1_in_X
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<7x1x3xf32>, [[PARAM_1_:%.+]]: tensor<1x12x3xf32>, [[PARAM_2_:%.+]]: tensor<1x12x4xf32>, [[PARAM_3_:%.+]]: tensor<1x24xf32>, [[PARAM_4_:%.+]]: tensor<?xi32>, [[PARAM_5_:%.+]]: tensor<1x?x4xf32>) -> tensor<*xf32> {
// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
// CHECK: [[Y_:%.+]], [[VAR_Y_h_:%.+]] = "onnx.GRU"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_0_]], [[PARAM_5_]]) {direction = "forward", hidden_size = 4 : si64, layout = 0 : si64, linear_before_reset = 0 : si64} : (tensor<7x1x3xf32>, tensor<1x12x3xf32>, tensor<1x12x4xf32>, tensor<1x24xf32>, none, tensor<1x?x4xf32>) -> (none, tensor<*xf32>)
// CHECK: return [[VAR_Y_h_]] : tensor<*xf32>
// CHECK: }
}

// -----

func.func @test_gru_seq_lens_bs1_in_seq_lens(%X: tensor<*xf32>, %W: tensor<1x12x3xf32>, %R: tensor<1x12x4xf32>, %B: tensor<1x24xf32>, %seq_lens: tensor<1xi32>, %initial_h: tensor<1x?x4xf32>) -> tensor<*xf32> {
%Y, %Y_h = "onnx.GRU"(%X, %W, %R, %B, %seq_lens, %initial_h) {hidden_size = 4 : si64} : (tensor<*xf32>, tensor<1x12x3xf32>, tensor<1x12x4xf32>, tensor<1x24xf32>, tensor<1xi32>, tensor<1x?x4xf32>) -> (none, tensor<*xf32>)
return %Y_h : tensor<*xf32>

// CHECK-LABEL: func.func @test_gru_seq_lens_bs1_in_seq_lens
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<1x12x3xf32>, [[PARAM_2_:%.+]]: tensor<1x12x4xf32>, [[PARAM_3_:%.+]]: tensor<1x24xf32>, [[PARAM_4_:%.+]]: tensor<1xi32>, [[PARAM_5_:%.+]]: tensor<1x?x4xf32>) -> tensor<*xf32> {
// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
// CHECK: [[Y_:%.+]], [[VAR_Y_h_:%.+]] = "onnx.GRU"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_0_]], [[PARAM_5_]]) {direction = "forward", hidden_size = 4 : si64, layout = 0 : si64, linear_before_reset = 0 : si64} : (tensor<*xf32>, tensor<1x12x3xf32>, tensor<1x12x4xf32>, tensor<1x24xf32>, none, tensor<1x?x4xf32>) -> (none, tensor<*xf32>)
// CHECK: return [[VAR_Y_h_]] : tensor<*xf32>
// CHECK: }
}

// -----

func.func @test_gru_seq_lens_bs1_in_initial_h(%X: tensor<*xf32>, %W: tensor<1x12x3xf32>, %R: tensor<1x12x4xf32>, %B: tensor<1x24xf32>, %seq_lens: tensor<?xi32>, %initial_h: tensor<1x1x4xf32>) -> tensor<*xf32> {
%Y, %Y_h = "onnx.GRU"(%X, %W, %R, %B, %seq_lens, %initial_h) {hidden_size = 4 : si64} : (tensor<*xf32>, tensor<1x12x3xf32>, tensor<1x12x4xf32>, tensor<1x24xf32>, tensor<?xi32>, tensor<1x1x4xf32>) -> (none, tensor<*xf32>)
return %Y_h : tensor<*xf32>

// CHECK-LABEL: func.func @test_gru_seq_lens_bs1_in_initial_h
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<1x12x3xf32>, [[PARAM_2_:%.+]]: tensor<1x12x4xf32>, [[PARAM_3_:%.+]]: tensor<1x24xf32>, [[PARAM_4_:%.+]]: tensor<?xi32>, [[PARAM_5_:%.+]]: tensor<1x1x4xf32>) -> tensor<*xf32> {
// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
// CHECK: [[Y_:%.+]], [[VAR_Y_h_:%.+]] = "onnx.GRU"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_0_]], [[PARAM_5_]]) {direction = "forward", hidden_size = 4 : si64, layout = 0 : si64, linear_before_reset = 0 : si64} : (tensor<*xf32>, tensor<1x12x3xf32>, tensor<1x12x4xf32>, tensor<1x24xf32>, none, tensor<1x1x4xf32>) -> (none, tensor<*xf32>)
// CHECK: return [[VAR_Y_h_]] : tensor<*xf32>
// CHECK: }
}

// -----

func.func @test_rnn_seq_lens_bs1_in_X(%X: tensor<7x1x3xf32>, %W: tensor<1x4x3xf32>, %R: tensor<1x4x4xf32>, %B: tensor<1x8xf32>, %seq_lens: tensor<?xi32>, %initial_h: tensor<1x?x4xf32>) -> tensor<*xf32> {
%Y, %Y_h = "onnx.RNN"(%X, %W, %R, %B, %seq_lens, %initial_h) {hidden_size = 4 : si64} : (tensor<7x1x3xf32>, tensor<1x4x3xf32>, tensor<1x4x4xf32>, tensor<1x8xf32>, tensor<?xi32>, tensor<1x?x4xf32>) -> (none, tensor<*xf32>)
return %Y_h : tensor<*xf32>

// CHECK-LABEL: func.func @test_rnn_seq_lens_bs1_in_X
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<7x1x3xf32>, [[PARAM_1_:%.+]]: tensor<1x4x3xf32>, [[PARAM_2_:%.+]]: tensor<1x4x4xf32>, [[PARAM_3_:%.+]]: tensor<1x8xf32>, [[PARAM_4_:%.+]]: tensor<?xi32>, [[PARAM_5_:%.+]]: tensor<1x?x4xf32>) -> tensor<*xf32> {
// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
// CHECK: [[Y_:%.+]], [[VAR_Y_h_:%.+]] = "onnx.RNN"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_0_]], [[PARAM_5_]]) {activations = ["Tanh", "Tanh"], direction = "forward", hidden_size = 4 : si64, layout = 0 : si64} : (tensor<7x1x3xf32>, tensor<1x4x3xf32>, tensor<1x4x4xf32>, tensor<1x8xf32>, none, tensor<1x?x4xf32>) -> (none, tensor<*xf32>)
// CHECK: return [[VAR_Y_h_]] : tensor<*xf32>
// CHECK: }
}

// -----

func.func @test_rnn_seq_lens_bs1_in_seq_lens(%X: tensor<*xf32>, %W: tensor<1x4x3xf32>, %R: tensor<1x4x4xf32>, %B: tensor<1x8xf32>, %seq_lens: tensor<1xi32>, %initial_h: tensor<1x?x4xf32>) -> tensor<*xf32> {
%Y, %Y_h = "onnx.RNN"(%X, %W, %R, %B, %seq_lens, %initial_h) {hidden_size = 4 : si64} : (tensor<*xf32>, tensor<1x4x3xf32>, tensor<1x4x4xf32>, tensor<1x8xf32>, tensor<1xi32>, tensor<1x?x4xf32>) -> (none, tensor<*xf32>)
return %Y_h : tensor<*xf32>

// CHECK-LABEL: func.func @test_rnn_seq_lens_bs1_in_seq_lens
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<1x4x3xf32>, [[PARAM_2_:%.+]]: tensor<1x4x4xf32>, [[PARAM_3_:%.+]]: tensor<1x8xf32>, [[PARAM_4_:%.+]]: tensor<1xi32>, [[PARAM_5_:%.+]]: tensor<1x?x4xf32>) -> tensor<*xf32> {
// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
// CHECK: [[Y_:%.+]], [[VAR_Y_h_:%.+]] = "onnx.RNN"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_0_]], [[PARAM_5_]]) {activations = ["Tanh", "Tanh"], direction = "forward", hidden_size = 4 : si64, layout = 0 : si64} : (tensor<*xf32>, tensor<1x4x3xf32>, tensor<1x4x4xf32>, tensor<1x8xf32>, none, tensor<1x?x4xf32>) -> (none, tensor<*xf32>)
// CHECK: return [[VAR_Y_h_]] : tensor<*xf32>
// CHECK: }
}

// -----

func.func @test_rnn_seq_lens_bs1_in_initial_h(%X: tensor<*xf32>, %W: tensor<1x4x3xf32>, %R: tensor<1x4x4xf32>, %B: tensor<1x8xf32>, %seq_lens: tensor<?xi32>, %initial_h: tensor<1x1x4xf32>) -> tensor<*xf32> {
%Y, %Y_h = "onnx.RNN"(%X, %W, %R, %B, %seq_lens, %initial_h) {hidden_size = 4 : si64} : (tensor<*xf32>, tensor<1x4x3xf32>, tensor<1x4x4xf32>, tensor<1x8xf32>, tensor<?xi32>, tensor<1x1x4xf32>) -> (none, tensor<*xf32>)
return %Y_h : tensor<*xf32>

// CHECK-LABEL: func.func @test_rnn_seq_lens_bs1_in_initial_h
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<1x4x3xf32>, [[PARAM_2_:%.+]]: tensor<1x4x4xf32>, [[PARAM_3_:%.+]]: tensor<1x8xf32>, [[PARAM_4_:%.+]]: tensor<?xi32>, [[PARAM_5_:%.+]]: tensor<1x1x4xf32>) -> tensor<*xf32> {
// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
// CHECK: [[Y_:%.+]], [[VAR_Y_h_:%.+]] = "onnx.RNN"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_0_]], [[PARAM_5_]]) {activations = ["Tanh", "Tanh"], direction = "forward", hidden_size = 4 : si64, layout = 0 : si64} : (tensor<*xf32>, tensor<1x4x3xf32>, tensor<1x4x4xf32>, tensor<1x8xf32>, none, tensor<1x1x4xf32>) -> (none, tensor<*xf32>)
// CHECK: return [[VAR_Y_h_]] : tensor<*xf32>
// CHECK: }
}

// -----

func.func @test_rnn_seq_lens_not_rewrite(%X: tensor<*xf32>, %W: tensor<1x4x3xf32>, %R: tensor<1x4x4xf32>, %B: tensor<1x8xf32>, %seq_lens: tensor<?xi32>, %initial_h: tensor<1x?x4xf32>) -> tensor<*xf32> {
%Y, %Y_h = "onnx.RNN"(%X, %W, %R, %B, %seq_lens, %initial_h) {hidden_size = 4 : si64} : (tensor<*xf32>, tensor<1x4x3xf32>, tensor<1x4x4xf32>, tensor<1x8xf32>, tensor<?xi32>, tensor<1x?x4xf32>) -> (none, tensor<*xf32>)
return %Y_h : tensor<*xf32>

// CHECK-LABEL: func.func @test_rnn_seq_lens_not_rewrite
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<1x4x3xf32>, [[PARAM_2_:%.+]]: tensor<1x4x4xf32>, [[PARAM_3_:%.+]]: tensor<1x8xf32>, [[PARAM_4_:%.+]]: tensor<?xi32>, [[PARAM_5_:%.+]]: tensor<1x?x4xf32>) -> tensor<*xf32> {
// CHECK: [[Y_:%.+]], [[VAR_Y_h_:%.+]] = "onnx.RNN"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[PARAM_4_]], [[PARAM_5_]]) {activations = ["Tanh", "Tanh"], direction = "forward", hidden_size = 4 : si64, layout = 0 : si64} : (tensor<*xf32>, tensor<1x4x3xf32>, tensor<1x4x4xf32>, tensor<1x8xf32>, tensor<?xi32>, tensor<1x?x4xf32>) -> (none, tensor<*xf32>)
// CHECK: return [[VAR_Y_h_]] : tensor<*xf32>
// CHECK: }
}

// -----

func.func @test_dim_to_constant(%arg0: tensor<?x256xi64>) -> (tensor<1xi64>) {
%0 = "onnx.Dim"(%arg0) {axis = 1 : si64} : (tensor<?x256xi64>) -> tensor<1xi64>
onnx.Return %0 : tensor<1xi64>
@@ -1110,3 +1255,4 @@ func.func @mul_broadcast_axis_unsqueeze(%279: tensor<1x64x112x112xf32>, %138: te
// CHECK: onnx.Return [[VAR_2_]] : tensor<*xf32>
// CHECK: }
}
