From 27afb6c11b0d117b79c2be767e3fdf5988d06bf7 Mon Sep 17 00:00:00 2001 From: chentong319 Date: Tue, 5 Sep 2023 14:11:52 -0400 Subject: [PATCH 01/14] change Ht Signed-off-by: chentong319 --- src/Conversion/ONNXToKrnl/RNN/GRU.cpp | 25 ++++++++++++++++++++++- src/Conversion/ONNXToKrnl/RNN/LSTM.cpp | 2 +- src/Conversion/ONNXToKrnl/RNN/RNN.cpp | 2 +- src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp | 8 +++++--- 4 files changed, 31 insertions(+), 6 deletions(-) diff --git a/src/Conversion/ONNXToKrnl/RNN/GRU.cpp b/src/Conversion/ONNXToKrnl/RNN/GRU.cpp index a99701f3ea..45a7afb593 100644 --- a/src/Conversion/ONNXToKrnl/RNN/GRU.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/GRU.cpp @@ -389,7 +389,8 @@ template <> void calculateState( ConversionPatternRewriter &rewriter, Location loc, Value Xt, GruState state, GruActivationPack activationPack, GruWeightPack weightPack, - GruBiasPack biasPack, Value sequenceIV, Value directionIV, bool isForward) { + GruBiasPack biasPack, Value sequenceIV, Value directionIV, + Value sequenceLens, bool isForward) { // Equations (Default: f=Sigmoid, g=Tanh):" // zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)" // rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)" @@ -498,6 +499,17 @@ void calculateState( Value nextHt = createMath.add(ztht, ztHt); // Store the intermediate Ht. + // Handle sequence_lens + if (!isNoneValue(sequenceLens)) { + Value sequenceUB = createKrnl.load(sequenceLens, {bs}); + Value cond = createMath.sge( + createMath.cast(sequenceUB.getType(), sequenceIV), sequenceUB); + nextHt = createMath.select(cond, + /*should use initialH*/ + createMath.constant(nextHt.getType(), 0.), + nextHt); + } + createKrnl.store(nextHt, Ht, indices); if (!isNoneValue(state.allH)) createKrnl.store( @@ -602,6 +614,17 @@ void calculateState( Value nextHt = createMath.add(ztht, ztHt); // Store the intermediate Ht. + // Handle sequence_lens + if (!isNoneValue(sequenceLens)) { + Value sequenceUB = createKrnl.load(sequenceLens, {bs}); + Value cond = createMath.sge( + createMath.cast(sequenceUB.getType(), sequenceIV), sequenceUB); + nextHt = createMath.select(cond, + /*should use initialH*/ + createMath.constant(nextHt.getType(), 0.), + nextHt); + } + createKrnl.store(nextHt, Ht, indices); if (!isNoneValue(state.allH)) createKrnl.store( diff --git a/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp b/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp index 5a693f988d..959381815d 100644 --- a/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp @@ -443,7 +443,7 @@ void calculateState(ConversionPatternRewriter &rewriter, Location loc, Value Xt, LstmState state, LstmActivationPack activationPack, LstmWeightPack weightPack, LstmBiasPack biasPack, Value sequenceIV, - Value directionIV, bool isForward) { + Value directionIV, Value sequenceLens, bool isForward) { // Equations for LSTM. // it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) // ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf) diff --git a/src/Conversion/ONNXToKrnl/RNN/RNN.cpp b/src/Conversion/ONNXToKrnl/RNN/RNN.cpp index ca7a1993f8..1ae4b0bd7b 100644 --- a/src/Conversion/ONNXToKrnl/RNN/RNN.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/RNN.cpp @@ -300,7 +300,7 @@ template <> void calculateState( ConversionPatternRewriter &rewriter, Location loc, Value Xt, RnnState state, RnnActivationPack activationPack, RnnWeightPack weightPack, - RnnBiasPack biasPack, Value sequenceIV, Value directionIV, bool isForward) { + RnnBiasPack biasPack, Value sequenceIV, Value directionIV, Value sequenceLens, bool isForward) { // Equations for RNN. // Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi) // Shape information: diff --git a/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp b/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp index c2a9327c2b..418da5e87a 100644 --- a/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp +++ b/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp @@ -116,7 +116,8 @@ S allocAndInitializeStates(mlir::ConversionPatternRewriter &rewriter, template void calculateState(mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, mlir::Value Xt, S state, A activationSet, W weight, - B bias, mlir::Value sequenceIV, mlir::Value directionIV, bool isForward); + B bias, mlir::Value sequenceIV, mlir::Value directionIV, + mlir::Value sequenceLens, bool isForward); // Write states to the RNN's outputs. template @@ -136,6 +137,7 @@ struct ONNXRNNOpLowering : public mlir::OpConversionPattern { mlir::Operation *op = rnnOp.getOperation(); mlir::Location loc = ONNXLoc(op); mlir::Value X = adaptor.getX(); + mlir::Value sequenceLens = adaptor.getSequenceLens(); if (hasAllNoneOutput(&rnnOp)) { rewriter.eraseOp(op); @@ -188,7 +190,7 @@ struct ONNXRNNOpLowering : public mlir::OpConversionPattern { // Emit calculation for one RNN step. calculateState(rewriter, loc, Xt, state, activationForward, weightForward, biasForward, sequenceIV, - directionIV, + directionIV, sequenceLens, /*isForward=*/true); }); } @@ -226,7 +228,7 @@ struct ONNXRNNOpLowering : public mlir::OpConversionPattern { // Emit calculation for one RNN step. calculateState(rewriter, loc, Xt, state, activationReverse, weightReverse, biasReverse, - reverseSequenceIV, directionIV, + reverseSequenceIV, directionIV, sequenceLens, /*isForward=*/false); }); } From 994e649f4f895869ccbe48c300cc0520528a71c2 Mon Sep 17 00:00:00 2001 From: chentong319 Date: Tue, 5 Sep 2023 14:33:32 -0400 Subject: [PATCH 02/14] format Signed-off-by: chentong319 --- src/Conversion/ONNXToKrnl/RNN/GRU.cpp | 6 ++---- src/Conversion/ONNXToKrnl/RNN/LSTM.cpp | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/Conversion/ONNXToKrnl/RNN/GRU.cpp b/src/Conversion/ONNXToKrnl/RNN/GRU.cpp index 45a7afb593..07cb573cf8 100644 --- a/src/Conversion/ONNXToKrnl/RNN/GRU.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/GRU.cpp @@ -506,8 +506,7 @@ void calculateState( createMath.cast(sequenceUB.getType(), sequenceIV), sequenceUB); nextHt = createMath.select(cond, /*should use initialH*/ - createMath.constant(nextHt.getType(), 0.), - nextHt); + createMath.constant(nextHt.getType(), 0.), nextHt); } createKrnl.store(nextHt, Ht, indices); @@ -621,8 +620,7 @@ void calculateState( createMath.cast(sequenceUB.getType(), sequenceIV), sequenceUB); nextHt = createMath.select(cond, /*should use initialH*/ - createMath.constant(nextHt.getType(), 0.), - nextHt); + createMath.constant(nextHt.getType(), 0.), nextHt); } createKrnl.store(nextHt, Ht, indices); diff --git a/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp b/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp index 959381815d..6e936fc9b0 100644 --- a/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp @@ -443,7 +443,7 @@ void calculateState(ConversionPatternRewriter &rewriter, Location loc, Value Xt, LstmState state, LstmActivationPack activationPack, LstmWeightPack weightPack, LstmBiasPack biasPack, Value sequenceIV, - Value directionIV, Value sequenceLens, bool isForward) { + Value directionIV, Value sequenceLens, bool isForward) { // Equations for LSTM. // it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) // ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf) From a523dbba0757f33ca7bd424bd9e02b6bb661d979 Mon Sep 17 00:00:00 2001 From: chentong319 Date: Tue, 5 Sep 2023 14:35:36 -0400 Subject: [PATCH 03/14] format Signed-off-by: chentong319 --- src/Conversion/ONNXToKrnl/RNN/RNN.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Conversion/ONNXToKrnl/RNN/RNN.cpp b/src/Conversion/ONNXToKrnl/RNN/RNN.cpp index 1ae4b0bd7b..539cf5821b 100644 --- a/src/Conversion/ONNXToKrnl/RNN/RNN.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/RNN.cpp @@ -300,7 +300,8 @@ template <> void calculateState( ConversionPatternRewriter &rewriter, Location loc, Value Xt, RnnState state, RnnActivationPack activationPack, RnnWeightPack weightPack, - RnnBiasPack biasPack, Value sequenceIV, Value directionIV, Value sequenceLens, bool isForward) { + RnnBiasPack biasPack, Value sequenceIV, Value directionIV, + Value sequenceLens, bool isForward) { // Equations for RNN. // Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi) // Shape information: From 9bb9b0a9fc6030282356a6ddd384be3435f70916 Mon Sep 17 00:00:00 2001 From: chentong319 Date: Tue, 5 Sep 2023 16:23:08 -0400 Subject: [PATCH 04/14] use initial Signed-off-by: chentong319 --- src/Conversion/ONNXToKrnl/RNN/GRU.cpp | 22 +++++++++++++++------- src/Conversion/ONNXToKrnl/RNN/LSTM.cpp | 4 +++- src/Conversion/ONNXToKrnl/RNN/RNN.cpp | 3 ++- src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp | 7 ++++--- 4 files changed, 24 insertions(+), 12 deletions(-) diff --git a/src/Conversion/ONNXToKrnl/RNN/GRU.cpp b/src/Conversion/ONNXToKrnl/RNN/GRU.cpp index 07cb573cf8..aa380c73fa 100644 --- a/src/Conversion/ONNXToKrnl/RNN/GRU.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/GRU.cpp @@ -390,7 +390,7 @@ void calculateState( ConversionPatternRewriter &rewriter, Location loc, Value Xt, GruState state, GruActivationPack activationPack, GruWeightPack weightPack, GruBiasPack biasPack, Value sequenceIV, Value directionIV, - Value sequenceLens, bool isForward) { + Value sequenceLens, Value initialH, bool isForward) { // Equations (Default: f=Sigmoid, g=Tanh):" // zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)" // rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)" @@ -502,11 +502,15 @@ void calculateState( // Handle sequence_lens if (!isNoneValue(sequenceLens)) { Value sequenceUB = createKrnl.load(sequenceLens, {bs}); + Value initial; + if (isNoneValue(initialH)) { + initial = createMath.constant(nextHt.getType(), 0.); + } else { + initial = createKrnl.load(initialH, {directionIV, bs, hs}); + } Value cond = createMath.sge( createMath.cast(sequenceUB.getType(), sequenceIV), sequenceUB); - nextHt = createMath.select(cond, - /*should use initialH*/ - createMath.constant(nextHt.getType(), 0.), nextHt); + nextHt = createMath.select(cond, /*padding*/initial, nextHt); } createKrnl.store(nextHt, Ht, indices); @@ -616,11 +620,15 @@ void calculateState( // Handle sequence_lens if (!isNoneValue(sequenceLens)) { Value sequenceUB = createKrnl.load(sequenceLens, {bs}); + Value initial; + if (isNoneValue(initialH)) { + initial = createMath.constant(nextHt.getType(), 0.); + } else { + initial = createKrnl.load(initialH, {directionIV, bs, hs}); + } Value cond = createMath.sge( createMath.cast(sequenceUB.getType(), sequenceIV), sequenceUB); - nextHt = createMath.select(cond, - /*should use initialH*/ - createMath.constant(nextHt.getType(), 0.), nextHt); + nextHt = createMath.select(cond, /*padding*/initial, nextHt); } createKrnl.store(nextHt, Ht, indices); diff --git a/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp b/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp index 6e936fc9b0..3b5c7687dc 100644 --- a/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp @@ -443,7 +443,7 @@ void calculateState(ConversionPatternRewriter &rewriter, Location loc, Value Xt, LstmState state, LstmActivationPack activationPack, LstmWeightPack weightPack, LstmBiasPack biasPack, Value sequenceIV, - Value directionIV, Value sequenceLens, bool isForward) { + Value directionIV, Value sequenceLens, Value initialH, bool isForward) { // Equations for LSTM. // it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) // ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf) @@ -452,6 +452,8 @@ void calculateState create(rewriter, loc); diff --git a/src/Conversion/ONNXToKrnl/RNN/RNN.cpp b/src/Conversion/ONNXToKrnl/RNN/RNN.cpp index 539cf5821b..1e7b7b904c 100644 --- a/src/Conversion/ONNXToKrnl/RNN/RNN.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/RNN.cpp @@ -301,7 +301,7 @@ void calculateState( ConversionPatternRewriter &rewriter, Location loc, Value Xt, RnnState state, RnnActivationPack activationPack, RnnWeightPack weightPack, RnnBiasPack biasPack, Value sequenceIV, Value directionIV, - Value sequenceLens, bool isForward) { + Value sequenceLens, Value initialH, bool isForward) { // Equations for RNN. // Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi) // Shape information: @@ -312,6 +312,7 @@ void calculateState( // Wbi: [hidden_size] // Rbi: [hidden_size] + assert(isNoneValue(sequenceLens) && "not implemented yet"); MultiDialectBuilder create(rewriter, loc); diff --git a/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp b/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp index 418da5e87a..e505179c22 100644 --- a/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp +++ b/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp @@ -117,7 +117,7 @@ template void calculateState(mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, mlir::Value Xt, S state, A activationSet, W weight, B bias, mlir::Value sequenceIV, mlir::Value directionIV, - mlir::Value sequenceLens, bool isForward); + mlir::Value sequenceLens, mlir::Value initialH, bool isForward); // Write states to the RNN's outputs. template @@ -138,6 +138,7 @@ struct ONNXRNNOpLowering : public mlir::OpConversionPattern { mlir::Location loc = ONNXLoc(op); mlir::Value X = adaptor.getX(); mlir::Value sequenceLens = adaptor.getSequenceLens(); + mlir::Value initialH = adaptor.getInitialH(); if (hasAllNoneOutput(&rnnOp)) { rewriter.eraseOp(op); @@ -190,7 +191,7 @@ struct ONNXRNNOpLowering : public mlir::OpConversionPattern { // Emit calculation for one RNN step. calculateState(rewriter, loc, Xt, state, activationForward, weightForward, biasForward, sequenceIV, - directionIV, sequenceLens, + directionIV, sequenceLens, initialH, /*isForward=*/true); }); } @@ -228,7 +229,7 @@ struct ONNXRNNOpLowering : public mlir::OpConversionPattern { // Emit calculation for one RNN step. calculateState(rewriter, loc, Xt, state, activationReverse, weightReverse, biasReverse, - reverseSequenceIV, directionIV, sequenceLens, + reverseSequenceIV, directionIV, sequenceLens, initialH, /*isForward=*/false); }); } From e52ad92e4cb47ef9271bba1e408058a0be061920 Mon Sep 17 00:00:00 2001 From: chentong319 Date: Tue, 5 Sep 2023 16:28:41 -0400 Subject: [PATCH 05/14] format Signed-off-by: chentong319 --- src/Conversion/ONNXToKrnl/RNN/GRU.cpp | 4 ++-- src/Conversion/ONNXToKrnl/RNN/LSTM.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Conversion/ONNXToKrnl/RNN/GRU.cpp b/src/Conversion/ONNXToKrnl/RNN/GRU.cpp index aa380c73fa..c058192a72 100644 --- a/src/Conversion/ONNXToKrnl/RNN/GRU.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/GRU.cpp @@ -510,7 +510,7 @@ void calculateState( } Value cond = createMath.sge( createMath.cast(sequenceUB.getType(), sequenceIV), sequenceUB); - nextHt = createMath.select(cond, /*padding*/initial, nextHt); + nextHt = createMath.select(cond, /*padding*/ initial, nextHt); } createKrnl.store(nextHt, Ht, indices); @@ -628,7 +628,7 @@ void calculateState( } Value cond = createMath.sge( createMath.cast(sequenceUB.getType(), sequenceIV), sequenceUB); - nextHt = createMath.select(cond, /*padding*/initial, nextHt); + nextHt = createMath.select(cond, /*padding*/ initial, nextHt); } createKrnl.store(nextHt, Ht, indices); diff --git a/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp b/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp index 3b5c7687dc..563191c27b 100644 --- a/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp @@ -443,7 +443,7 @@ void calculateState(ConversionPatternRewriter &rewriter, Location loc, Value Xt, LstmState state, LstmActivationPack activationPack, LstmWeightPack weightPack, LstmBiasPack biasPack, Value sequenceIV, - Value directionIV, Value sequenceLens, Value initialH, bool isForward) { + Value directionIV, Value sequenceLens, Value initialH, bool isForward) { // Equations for LSTM. // it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) // ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf) From 2affba259997347d171125ac3c7c7fbf744d57c2 Mon Sep 17 00:00:00 2001 From: chentong319 Date: Tue, 5 Sep 2023 17:22:08 -0400 Subject: [PATCH 06/14] for test Signed-off-by: chentong319 --- src/Conversion/ONNXToKrnl/RNN/GRU.cpp | 2 +- src/Conversion/ONNXToKrnl/RNN/LSTM.cpp | 3 ++- src/Conversion/ONNXToKrnl/RNN/RNN.cpp | 4 +++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/Conversion/ONNXToKrnl/RNN/GRU.cpp b/src/Conversion/ONNXToKrnl/RNN/GRU.cpp index c058192a72..140a9b0b34 100644 --- a/src/Conversion/ONNXToKrnl/RNN/GRU.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/GRU.cpp @@ -390,7 +390,7 @@ void calculateState( ConversionPatternRewriter &rewriter, Location loc, Value Xt, GruState state, GruActivationPack activationPack, GruWeightPack weightPack, GruBiasPack biasPack, Value sequenceIV, Value directionIV, - Value sequenceLens, Value initialH, bool isForward) { + Value sequenceLens, Value initialH, bool isForward) { // Equations (Default: f=Sigmoid, g=Tanh):" // zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)" // rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)" diff --git a/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp b/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp index 563191c27b..d4f193c85d 100644 --- a/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp @@ -452,7 +452,8 @@ void calculateState diff --git a/src/Conversion/ONNXToKrnl/RNN/RNN.cpp b/src/Conversion/ONNXToKrnl/RNN/RNN.cpp index 1e7b7b904c..7fbcbc5618 100644 --- a/src/Conversion/ONNXToKrnl/RNN/RNN.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/RNN.cpp @@ -312,7 +312,9 @@ void calculateState( // Wbi: [hidden_size] // Rbi: [hidden_size] - assert(isNoneValue(sequenceLens) && "not implemented yet"); + // ToFix: add support of sequenceLens for RNN + //assert(isNoneValue(sequenceLens) && "not implemented yet"); + MultiDialectBuilder create(rewriter, loc); From 9c94df1ef87b730d81283ba833e8b37aedb94408 Mon Sep 17 00:00:00 2001 From: chentong319 Date: Tue, 5 Sep 2023 19:52:20 -0400 Subject: [PATCH 07/14] comment Signed-off-by: chentong319 --- src/Conversion/ONNXToKrnl/RNN/LSTM.cpp | 6 +++++- src/Conversion/ONNXToKrnl/RNN/RNN.cpp | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp b/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp index d4f193c85d..a09b104d73 100644 --- a/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp @@ -453,7 +453,11 @@ void calculateState diff --git a/src/Conversion/ONNXToKrnl/RNN/RNN.cpp b/src/Conversion/ONNXToKrnl/RNN/RNN.cpp index 7fbcbc5618..bbe3940431 100644 --- a/src/Conversion/ONNXToKrnl/RNN/RNN.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/RNN.cpp @@ -313,7 +313,7 @@ void calculateState( // Rbi: [hidden_size] // ToFix: add support of sequenceLens for RNN - //assert(isNoneValue(sequenceLens) && "not implemented yet"); + assert(isNoneValue(sequenceLens) && "not implemented yet"); MultiDialectBuilder create(rewriter, loc); From 0fa9322ebf25511d090bcea85ece41b39e81c733 Mon Sep 17 00:00:00 2001 From: chentong319 Date: Tue, 5 Sep 2023 19:58:04 -0400 Subject: [PATCH 08/14] format Signed-off-by: chentong319 --- src/Conversion/ONNXToKrnl/RNN/LSTM.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp b/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp index a09b104d73..72d7d685b6 100644 --- a/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp @@ -456,7 +456,7 @@ void calculateState Date: Thu, 7 Sep 2023 11:14:27 -0400 Subject: [PATCH 09/14] fix output --- src/Conversion/ONNXToKrnl/RNN/GRU.cpp | 7 ++-- src/Conversion/ONNXToKrnl/RNN/LSTM.cpp | 12 +++--- src/Conversion/ONNXToKrnl/RNN/RNN.cpp | 7 ++-- src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp | 45 +++++++++++++++++++++-- src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp | 9 +++-- 5 files changed, 61 insertions(+), 19 deletions(-) diff --git a/src/Conversion/ONNXToKrnl/RNN/GRU.cpp b/src/Conversion/ONNXToKrnl/RNN/GRU.cpp index 140a9b0b34..bb30a62611 100644 --- a/src/Conversion/ONNXToKrnl/RNN/GRU.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/GRU.cpp @@ -641,7 +641,8 @@ void calculateState( template <> void stateToOutput(ConversionPatternRewriter &rewriter, - Location loc, ONNXGRUOp *op, GruState state, std::vector &outputs) { + Location loc, ONNXGRUOp *op, GruState state, std::vector &outputs, + Value sequenceLens) { auto direction = op->getDirection(); Value noneValue; // First output: all sequences. @@ -650,8 +651,8 @@ void stateToOutput(ConversionPatternRewriter &rewriter, if (isNoneValue(op->getYH())) outputs.emplace_back(noneValue); else { - stateToOutputForHiddenOrCell( - rewriter, loc, state.forwardHt, state.reverseHt, direction, state.ht); + stateToOutputForHiddenOrCell(rewriter, loc, state.forwardHt, + state.reverseHt, direction, state.ht, state.allH, sequenceLens); outputs.emplace_back(state.ht); } } diff --git a/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp b/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp index 72d7d685b6..db9a40aea6 100644 --- a/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp @@ -596,8 +596,8 @@ void calculateState void stateToOutput(ConversionPatternRewriter &rewriter, - Location loc, ONNXLSTMOp *op, LstmState state, - std::vector &outputs) { + Location loc, ONNXLSTMOp *op, LstmState state, std::vector &outputs, + Value sequenceLens) { Value noneValue; auto direction = op->getDirection(); @@ -607,16 +607,16 @@ void stateToOutput(ConversionPatternRewriter &rewriter, if (isNoneValue(op->getYH())) outputs.emplace_back(noneValue); else { - stateToOutputForHiddenOrCell( - rewriter, loc, state.forwardHt, state.reverseHt, direction, state.ht); + stateToOutputForHiddenOrCell(rewriter, loc, state.forwardHt, + state.reverseHt, direction, state.ht, state.allH, sequenceLens); outputs.emplace_back(state.ht); } // Third output: cell. if (isNoneValue(op->getYC())) outputs.emplace_back(noneValue); else { - stateToOutputForHiddenOrCell( - rewriter, loc, state.forwardCt, state.reverseCt, direction, state.ct); + stateToOutputForHiddenOrCell(rewriter, loc, state.forwardCt, + state.reverseCt, direction, state.ct, state.allH, sequenceLens); outputs.emplace_back(state.ct); } } diff --git a/src/Conversion/ONNXToKrnl/RNN/RNN.cpp b/src/Conversion/ONNXToKrnl/RNN/RNN.cpp index bbe3940431..bca07ec96e 100644 --- a/src/Conversion/ONNXToKrnl/RNN/RNN.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/RNN.cpp @@ -365,7 +365,8 @@ void calculateState( template <> void stateToOutput(ConversionPatternRewriter &rewriter, - Location loc, ONNXRNNOp *op, RnnState state, std::vector &outputs) { + Location loc, ONNXRNNOp *op, RnnState state, std::vector &outputs, + Value sequenceLens) { Value noneValue; auto direction = op->getDirection(); @@ -375,8 +376,8 @@ void stateToOutput(ConversionPatternRewriter &rewriter, if (isNoneValue(op->getYH())) outputs.emplace_back(noneValue); else { - stateToOutputForHiddenOrCell( - rewriter, loc, state.forwardHt, state.reverseHt, direction, state.ht); + stateToOutputForHiddenOrCell(rewriter, loc, state.forwardHt, + state.reverseHt, direction, state.ht, state.allH, sequenceLens); outputs.emplace_back(state.ht); } } diff --git a/src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp b/src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp index 3c1a39da4a..7edfc4d3d0 100644 --- a/src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp @@ -12,8 +12,8 @@ // //===----------------------------------------------------------------------===// -#include "src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp" #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" +#include "src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp" using namespace mlir; @@ -218,15 +218,52 @@ void initializeHiddenAndCell(ConversionPatternRewriter &rewriter, Location loc, /// pretended, depending on 'direction'. void stateToOutputForHiddenOrCell(ConversionPatternRewriter &rewriter, Location loc, Value forwardVal, Value reverseVal, StringRef direction, - Value output) { + Value output, Value allH, Value sequenceLens) { + // TODO remove MultiDialectBuilder create( rewriter, loc); if (direction == FORWARD || direction == REVERSE) { Value val = (direction == FORWARD) ? forwardVal : reverseVal; - Value numOfElements = getDynamicMemRefSize(rewriter, loc, val); - create.krnl.memcpy(output, val, numOfElements); + if (isNoneValue(sequenceLens)) { + Value numOfElements = getDynamicMemRefSize(rewriter, loc, val); + create.krnl.memcpy(output, val, numOfElements); + } else { + // How to construct the last Ht according to the sequenceLens + // This behavior is observed from torch GRU, but not explicitly defined + // by onnx document. It doesn't make sense to return the padded value + // at the end of Ht if sequenceLens is defined. + + Value directionIV = create.math.constantIndex(0); + Value one = create.math.constantIndex(1); + + MemRefType matrixType = val.getType().cast(); + unsigned htRank = matrixType.getRank(); + Value iZero = create.math.constantIndex(0); + SmallVector htLbs(htRank, iZero); + SmallVector htUbs; + for (unsigned r = 0; r < htRank; ++r) { + htUbs.emplace_back(create.mem.dim(val, r)); + } + ValueRange loops1 = create.krnl.defineLoops(htRank); + create.krnl.iterate(loops1, loops1, htLbs, htUbs, + [&](KrnlBuilder &createKrnl, ValueRange indices) { + MathBuilder createMath(createKrnl); + IndexExprScope ieScope(createKrnl); + Value bs(indices[0]), hs(indices[1]); + + // The element at sequenceLens[batchIV]-1 in allH is used + Value sequenceIV = create.krnl.load(sequenceLens, bs); + sequenceIV = createMath.castToIndex(sequenceIV); + sequenceIV = createMath.sub(sequenceIV, one); + Value endH = + create.krnl.load(allH, {sequenceIV, directionIV, bs, hs}); + create.krnl.store(endH, output, {directionIV, bs, hs}); + }); + } } else { // BIDIRECTIONAL + // ToFix: sequenceLens is not supported for bidirection yet + assert(isNoneValue(sequenceLens) && "not implemented yet"); unsigned rank = forwardVal.getType().cast().getRank(); Value zero = create.math.constantIndex(0); Value one = create.math.constantIndex(1); diff --git a/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp b/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp index e505179c22..71af7a9869 100644 --- a/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp +++ b/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp @@ -65,7 +65,8 @@ void initializeIntermediateStates(mlir::ConversionPatternRewriter &rewriter, /// pretended, depending on 'direction'. void stateToOutputForHiddenOrCell(mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, mlir::Value forwardVal, mlir::Value reverseVal, - llvm::StringRef direction, mlir::Value output); + llvm::StringRef direction, mlir::Value output, mlir::Value allH, + mlir::Value seqenceLens); /// Apply an activation function on a given operand. mlir::Value applyActivation(mlir::OpBuilder &rewriter, mlir::Location loc, @@ -122,7 +123,8 @@ void calculateState(mlir::ConversionPatternRewriter &rewriter, // Write states to the RNN's outputs. template void stateToOutput(mlir::ConversionPatternRewriter &rewriter, - mlir::Location loc, RNNOp *op, S state, std::vector &outputs); + mlir::Location loc, RNNOp *op, S state, std::vector &outputs, + mlir::Value sequenceLens); // A common template for lowering an RNN operation. template @@ -235,7 +237,8 @@ struct ONNXRNNOpLowering : public mlir::OpConversionPattern { } std::vector outputs; - stateToOutput(rewriter, loc, &rnnOp, state, outputs); + stateToOutput( + rewriter, loc, &rnnOp, state, outputs, sequenceLens); rewriter.replaceOp(op, outputs); return mlir::success(); } From f4813fd81d6a42645036ca54c23be94b9b770583 Mon Sep 17 00:00:00 2001 From: chentong319 Date: Thu, 7 Sep 2023 13:25:00 -0400 Subject: [PATCH 10/14] Revert "fix output" This reverts commit bcc617c0871f103956add2b90d0225cbca7509f3. --- src/Conversion/ONNXToKrnl/RNN/GRU.cpp | 7 ++-- src/Conversion/ONNXToKrnl/RNN/LSTM.cpp | 12 +++--- src/Conversion/ONNXToKrnl/RNN/RNN.cpp | 7 ++-- src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp | 45 ++--------------------- src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp | 9 ++--- 5 files changed, 19 insertions(+), 61 deletions(-) diff --git a/src/Conversion/ONNXToKrnl/RNN/GRU.cpp b/src/Conversion/ONNXToKrnl/RNN/GRU.cpp index bb30a62611..140a9b0b34 100644 --- a/src/Conversion/ONNXToKrnl/RNN/GRU.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/GRU.cpp @@ -641,8 +641,7 @@ void calculateState( template <> void stateToOutput(ConversionPatternRewriter &rewriter, - Location loc, ONNXGRUOp *op, GruState state, std::vector &outputs, - Value sequenceLens) { + Location loc, ONNXGRUOp *op, GruState state, std::vector &outputs) { auto direction = op->getDirection(); Value noneValue; // First output: all sequences. @@ -651,8 +650,8 @@ void stateToOutput(ConversionPatternRewriter &rewriter, if (isNoneValue(op->getYH())) outputs.emplace_back(noneValue); else { - stateToOutputForHiddenOrCell(rewriter, loc, state.forwardHt, - state.reverseHt, direction, state.ht, state.allH, sequenceLens); + stateToOutputForHiddenOrCell( + rewriter, loc, state.forwardHt, state.reverseHt, direction, state.ht); outputs.emplace_back(state.ht); } } diff --git a/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp b/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp index db9a40aea6..72d7d685b6 100644 --- a/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp @@ -596,8 +596,8 @@ void calculateState void stateToOutput(ConversionPatternRewriter &rewriter, - Location loc, ONNXLSTMOp *op, LstmState state, std::vector &outputs, - Value sequenceLens) { + Location loc, ONNXLSTMOp *op, LstmState state, + std::vector &outputs) { Value noneValue; auto direction = op->getDirection(); @@ -607,16 +607,16 @@ void stateToOutput(ConversionPatternRewriter &rewriter, if (isNoneValue(op->getYH())) outputs.emplace_back(noneValue); else { - stateToOutputForHiddenOrCell(rewriter, loc, state.forwardHt, - state.reverseHt, direction, state.ht, state.allH, sequenceLens); + stateToOutputForHiddenOrCell( + rewriter, loc, state.forwardHt, state.reverseHt, direction, state.ht); outputs.emplace_back(state.ht); } // Third output: cell. if (isNoneValue(op->getYC())) outputs.emplace_back(noneValue); else { - stateToOutputForHiddenOrCell(rewriter, loc, state.forwardCt, - state.reverseCt, direction, state.ct, state.allH, sequenceLens); + stateToOutputForHiddenOrCell( + rewriter, loc, state.forwardCt, state.reverseCt, direction, state.ct); outputs.emplace_back(state.ct); } } diff --git a/src/Conversion/ONNXToKrnl/RNN/RNN.cpp b/src/Conversion/ONNXToKrnl/RNN/RNN.cpp index bca07ec96e..bbe3940431 100644 --- a/src/Conversion/ONNXToKrnl/RNN/RNN.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/RNN.cpp @@ -365,8 +365,7 @@ void calculateState( template <> void stateToOutput(ConversionPatternRewriter &rewriter, - Location loc, ONNXRNNOp *op, RnnState state, std::vector &outputs, - Value sequenceLens) { + Location loc, ONNXRNNOp *op, RnnState state, std::vector &outputs) { Value noneValue; auto direction = op->getDirection(); @@ -376,8 +375,8 @@ void stateToOutput(ConversionPatternRewriter &rewriter, if (isNoneValue(op->getYH())) outputs.emplace_back(noneValue); else { - stateToOutputForHiddenOrCell(rewriter, loc, state.forwardHt, - state.reverseHt, direction, state.ht, state.allH, sequenceLens); + stateToOutputForHiddenOrCell( + rewriter, loc, state.forwardHt, state.reverseHt, direction, state.ht); outputs.emplace_back(state.ht); } } diff --git a/src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp b/src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp index 7edfc4d3d0..3c1a39da4a 100644 --- a/src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp @@ -12,8 +12,8 @@ // //===----------------------------------------------------------------------===// -#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" #include "src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp" +#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" using namespace mlir; @@ -218,52 +218,15 @@ void initializeHiddenAndCell(ConversionPatternRewriter &rewriter, Location loc, /// pretended, depending on 'direction'. void stateToOutputForHiddenOrCell(ConversionPatternRewriter &rewriter, Location loc, Value forwardVal, Value reverseVal, StringRef direction, - Value output, Value allH, Value sequenceLens) { - + Value output) { // TODO remove MultiDialectBuilder create( rewriter, loc); if (direction == FORWARD || direction == REVERSE) { Value val = (direction == FORWARD) ? forwardVal : reverseVal; - if (isNoneValue(sequenceLens)) { - Value numOfElements = getDynamicMemRefSize(rewriter, loc, val); - create.krnl.memcpy(output, val, numOfElements); - } else { - // How to construct the last Ht according to the sequenceLens - // This behavior is observed from torch GRU, but not explicitly defined - // by onnx document. It doesn't make sense to return the padded value - // at the end of Ht if sequenceLens is defined. - - Value directionIV = create.math.constantIndex(0); - Value one = create.math.constantIndex(1); - - MemRefType matrixType = val.getType().cast(); - unsigned htRank = matrixType.getRank(); - Value iZero = create.math.constantIndex(0); - SmallVector htLbs(htRank, iZero); - SmallVector htUbs; - for (unsigned r = 0; r < htRank; ++r) { - htUbs.emplace_back(create.mem.dim(val, r)); - } - ValueRange loops1 = create.krnl.defineLoops(htRank); - create.krnl.iterate(loops1, loops1, htLbs, htUbs, - [&](KrnlBuilder &createKrnl, ValueRange indices) { - MathBuilder createMath(createKrnl); - IndexExprScope ieScope(createKrnl); - Value bs(indices[0]), hs(indices[1]); - - // The element at sequenceLens[batchIV]-1 in allH is used - Value sequenceIV = create.krnl.load(sequenceLens, bs); - sequenceIV = createMath.castToIndex(sequenceIV); - sequenceIV = createMath.sub(sequenceIV, one); - Value endH = - create.krnl.load(allH, {sequenceIV, directionIV, bs, hs}); - create.krnl.store(endH, output, {directionIV, bs, hs}); - }); - } + Value numOfElements = getDynamicMemRefSize(rewriter, loc, val); + create.krnl.memcpy(output, val, numOfElements); } else { // BIDIRECTIONAL - // ToFix: sequenceLens is not supported for bidirection yet - assert(isNoneValue(sequenceLens) && "not implemented yet"); unsigned rank = forwardVal.getType().cast().getRank(); Value zero = create.math.constantIndex(0); Value one = create.math.constantIndex(1); diff --git a/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp b/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp index 71af7a9869..e505179c22 100644 --- a/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp +++ b/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp @@ -65,8 +65,7 @@ void initializeIntermediateStates(mlir::ConversionPatternRewriter &rewriter, /// pretended, depending on 'direction'. void stateToOutputForHiddenOrCell(mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, mlir::Value forwardVal, mlir::Value reverseVal, - llvm::StringRef direction, mlir::Value output, mlir::Value allH, - mlir::Value seqenceLens); + llvm::StringRef direction, mlir::Value output); /// Apply an activation function on a given operand. mlir::Value applyActivation(mlir::OpBuilder &rewriter, mlir::Location loc, @@ -123,8 +122,7 @@ void calculateState(mlir::ConversionPatternRewriter &rewriter, // Write states to the RNN's outputs. template void stateToOutput(mlir::ConversionPatternRewriter &rewriter, - mlir::Location loc, RNNOp *op, S state, std::vector &outputs, - mlir::Value sequenceLens); + mlir::Location loc, RNNOp *op, S state, std::vector &outputs); // A common template for lowering an RNN operation. template @@ -237,8 +235,7 @@ struct ONNXRNNOpLowering : public mlir::OpConversionPattern { } std::vector outputs; - stateToOutput( - rewriter, loc, &rnnOp, state, outputs, sequenceLens); + stateToOutput(rewriter, loc, &rnnOp, state, outputs); rewriter.replaceOp(op, outputs); return mlir::success(); } From 9c91b4b3db4216ec4494457a5e2a6605122628a5 Mon Sep 17 00:00:00 2001 From: chentong319 Date: Thu, 7 Sep 2023 14:09:53 -0400 Subject: [PATCH 11/14] new implementation of output Signed-off-by: chentong319 --- src/Conversion/ONNXToKrnl/RNN/GRU.cpp | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/Conversion/ONNXToKrnl/RNN/GRU.cpp b/src/Conversion/ONNXToKrnl/RNN/GRU.cpp index 140a9b0b34..e1a673d142 100644 --- a/src/Conversion/ONNXToKrnl/RNN/GRU.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/GRU.cpp @@ -511,9 +511,14 @@ void calculateState( Value cond = createMath.sge( createMath.cast(sequenceUB.getType(), sequenceIV), sequenceUB); nextHt = createMath.select(cond, /*padding*/ initial, nextHt); + + // Last HT should be the last in sequenceLens or the current result + Value lastHt = createMath.select(cond, createKrnl.load(Ht, indices), nextHt); + createKrnl.store(lastHt, Ht, indices); + } else { + createKrnl.store(nextHt, Ht, indices); } - createKrnl.store(nextHt, Ht, indices); if (!isNoneValue(state.allH)) createKrnl.store( nextHt, state.allH, {sequenceIV, directionIV, bs, hs}); @@ -629,9 +634,14 @@ void calculateState( Value cond = createMath.sge( createMath.cast(sequenceUB.getType(), sequenceIV), sequenceUB); nextHt = createMath.select(cond, /*padding*/ initial, nextHt); + + // Last HT should be the last in sequenceLens or the current result + Value lastHt = createMath.select(cond, createKrnl.load(Ht, indices), nextHt); + createKrnl.store(lastHt, Ht, indices); + } else { + createKrnl.store(nextHt, Ht, indices); } - createKrnl.store(nextHt, Ht, indices); if (!isNoneValue(state.allH)) createKrnl.store( nextHt, state.allH, {sequenceIV, directionIV, bs, hs}); From 31a0885595a473565fc5090cad0d39991a57107d Mon Sep 17 00:00:00 2001 From: chentong319 Date: Thu, 7 Sep 2023 14:12:13 -0400 Subject: [PATCH 12/14] format Signed-off-by: chentong319 --- src/Conversion/ONNXToKrnl/RNN/GRU.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Conversion/ONNXToKrnl/RNN/GRU.cpp b/src/Conversion/ONNXToKrnl/RNN/GRU.cpp index e1a673d142..49b6250012 100644 --- a/src/Conversion/ONNXToKrnl/RNN/GRU.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/GRU.cpp @@ -513,7 +513,8 @@ void calculateState( nextHt = createMath.select(cond, /*padding*/ initial, nextHt); // Last HT should be the last in sequenceLens or the current result - Value lastHt = createMath.select(cond, createKrnl.load(Ht, indices), nextHt); + Value lastHt = + createMath.select(cond, createKrnl.load(Ht, indices), nextHt); createKrnl.store(lastHt, Ht, indices); } else { createKrnl.store(nextHt, Ht, indices); @@ -636,7 +637,8 @@ void calculateState( nextHt = createMath.select(cond, /*padding*/ initial, nextHt); // Last HT should be the last in sequenceLens or the current result - Value lastHt = createMath.select(cond, createKrnl.load(Ht, indices), nextHt); + Value lastHt = + createMath.select(cond, createKrnl.load(Ht, indices), nextHt); createKrnl.store(lastHt, Ht, indices); } else { createKrnl.store(nextHt, Ht, indices); From 5565ec71e9d4fe46e4215c68861be1e910f5e330 Mon Sep 17 00:00:00 2001 From: chentong319 Date: Thu, 7 Sep 2023 17:40:56 -0400 Subject: [PATCH 13/14] function Signed-off-by: chentong319 --- src/Conversion/ONNXToKrnl/RNN/GRU.cpp | 42 +++-------------------- src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp | 31 +++++++++++++++++ src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp | 9 +++++ 3 files changed, 44 insertions(+), 38 deletions(-) diff --git a/src/Conversion/ONNXToKrnl/RNN/GRU.cpp b/src/Conversion/ONNXToKrnl/RNN/GRU.cpp index 49b6250012..2818bbcc5f 100644 --- a/src/Conversion/ONNXToKrnl/RNN/GRU.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/GRU.cpp @@ -500,25 +500,8 @@ void calculateState( // Store the intermediate Ht. // Handle sequence_lens - if (!isNoneValue(sequenceLens)) { - Value sequenceUB = createKrnl.load(sequenceLens, {bs}); - Value initial; - if (isNoneValue(initialH)) { - initial = createMath.constant(nextHt.getType(), 0.); - } else { - initial = createKrnl.load(initialH, {directionIV, bs, hs}); - } - Value cond = createMath.sge( - createMath.cast(sequenceUB.getType(), sequenceIV), sequenceUB); - nextHt = createMath.select(cond, /*padding*/ initial, nextHt); - - // Last HT should be the last in sequenceLens or the current result - Value lastHt = - createMath.select(cond, createKrnl.load(Ht, indices), nextHt); - createKrnl.store(lastHt, Ht, indices); - } else { - createKrnl.store(nextHt, Ht, indices); - } + nextHt = handleSequenceLens(createKrnl, createMath, sequenceLens, + initialH, nextHt, sequenceIV, directionIV, bs, hs, Ht); if (!isNoneValue(state.allH)) createKrnl.store( @@ -624,25 +607,8 @@ void calculateState( // Store the intermediate Ht. // Handle sequence_lens - if (!isNoneValue(sequenceLens)) { - Value sequenceUB = createKrnl.load(sequenceLens, {bs}); - Value initial; - if (isNoneValue(initialH)) { - initial = createMath.constant(nextHt.getType(), 0.); - } else { - initial = createKrnl.load(initialH, {directionIV, bs, hs}); - } - Value cond = createMath.sge( - createMath.cast(sequenceUB.getType(), sequenceIV), sequenceUB); - nextHt = createMath.select(cond, /*padding*/ initial, nextHt); - - // Last HT should be the last in sequenceLens or the current result - Value lastHt = - createMath.select(cond, createKrnl.load(Ht, indices), nextHt); - createKrnl.store(lastHt, Ht, indices); - } else { - createKrnl.store(nextHt, Ht, indices); - } + nextHt = handleSequenceLens(createKrnl, createMath, sequenceLens, + initialH, nextHt, sequenceIV, directionIV, bs, hs, Ht); if (!isNoneValue(state.allH)) createKrnl.store( diff --git a/src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp b/src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp index 3c1a39da4a..2dc5e125f6 100644 --- a/src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp @@ -336,4 +336,35 @@ Value emitXSliceAt(ConversionPatternRewriter &rewriter, Location loc, Value X, return sliceX; } +// Change the nextHt and Ht value if sequenceLens is defined. +// When a sample reachs the limit of its sequence len, nextHt will be padded +// with 0 (or initialH), and Ht will keep the last value at the sequence end +// so that the final value Ht is the last value at their sequence len. +Value handleSequenceLens(KrnlBuilder &createKrnl, MathBuilder &createMath, + Value sequenceLens, Value initialH, Value nextHt, Value sequenceIV, + Value directionIV, Value bs, Value hs, Value Ht) { + // Handle sequence_lens + IndexExprScope ieScope(createKrnl); + if (!isNoneValue(sequenceLens)) { + Value sequenceUB = createKrnl.load(sequenceLens, {bs}); + Value initial; + if (isNoneValue(initialH)) { + initial = createMath.constant(nextHt.getType(), 0.); + } else { + initial = createKrnl.load(initialH, {directionIV, bs, hs}); + } + Value cond = createMath.sge( + createMath.cast(sequenceUB.getType(), sequenceIV), sequenceUB); + nextHt = createMath.select(cond, /*padding*/ initial, nextHt); + + // Last HT should be the last in sequenceLens or the current result + Value lastHt = + createMath.select(cond, createKrnl.load(Ht, {bs, hs}), nextHt); + createKrnl.store(lastHt, Ht, {bs, hs}); + } else { + createKrnl.store(nextHt, Ht, {bs, hs}); + } + return nextHt; +} + } // namespace onnx_mlir diff --git a/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp b/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp index e505179c22..ec0b28c8ec 100644 --- a/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp +++ b/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp @@ -75,6 +75,15 @@ mlir::Value applyActivation(mlir::OpBuilder &rewriter, mlir::Location loc, mlir::Value emitXSliceAt(mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, mlir::Value X, mlir::Value timestep); +// Change the nextHt and Ht value if sequenceLens is defined. +// When a sample reachs the limit of its sequence len, nextHt will be padded +// with 0 (or initialH), and Ht will keep the last value at the sequence end +// so that the final value Ht is the last value at their sequence len. +mlir::Value handleSequenceLens(KrnlBuilder &createKrnl, MathBuilder &createMath, + mlir::Value sequenceLens, mlir::Value initialH, mlir::Value nextHt, + mlir::Value sequenceIV, mlir::Value directionIV, mlir::Value bs, + mlir::Value hs, mlir::Value Ht); + // Override the following methods when lowering an RNN operation: // - hasAllNoneOutput // - getActivationPack From 14a46a6ea9eb51dd9ed949bb9c7063f4b4f0a4d6 Mon Sep 17 00:00:00 2001 From: chentong319 Date: Thu, 7 Sep 2023 18:06:24 -0400 Subject: [PATCH 14/14] clean Signed-off-by: chentong319 --- src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp b/src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp index 2dc5e125f6..b12747e174 100644 --- a/src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp @@ -343,8 +343,6 @@ Value emitXSliceAt(ConversionPatternRewriter &rewriter, Location loc, Value X, Value handleSequenceLens(KrnlBuilder &createKrnl, MathBuilder &createMath, Value sequenceLens, Value initialH, Value nextHt, Value sequenceIV, Value directionIV, Value bs, Value hs, Value Ht) { - // Handle sequence_lens - IndexExprScope ieScope(createKrnl); if (!isNoneValue(sequenceLens)) { Value sequenceUB = createKrnl.load(sequenceLens, {bs}); Value initial;