Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle sequence_lens for GRU on CPU #2479

Merged
merged 18 commits into from
Sep 8, 2023
Merged
31 changes: 30 additions & 1 deletion src/Conversion/ONNXToKrnl/RNN/GRU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,8 @@ template <>
void calculateState<GruState, GruActivationPack, GruWeightPack, GruBiasPack>(
ConversionPatternRewriter &rewriter, Location loc, Value Xt, GruState state,
GruActivationPack activationPack, GruWeightPack weightPack,
GruBiasPack biasPack, Value sequenceIV, Value directionIV, bool isForward) {
GruBiasPack biasPack, Value sequenceIV, Value directionIV,
Value sequenceLens, Value initialH, bool isForward) {
// Equations (Default: f=Sigmoid, g=Tanh):"
// zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)"
// rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)"
Expand Down Expand Up @@ -498,6 +499,20 @@ void calculateState<GruState, GruActivationPack, GruWeightPack, GruBiasPack>(
Value nextHt = createMath.add(ztht, ztHt);

// Store the intermediate Ht.
// Handle sequence_lens
if (!isNoneValue(sequenceLens)) {
Value sequenceUB = createKrnl.load(sequenceLens, {bs});
Value initial;
if (isNoneValue(initialH)) {
initial = createMath.constant(nextHt.getType(), 0.);
} else {
initial = createKrnl.load(initialH, {directionIV, bs, hs});
}
Value cond = createMath.sge(
createMath.cast(sequenceUB.getType(), sequenceIV), sequenceUB);
nextHt = createMath.select(cond, /*padding*/ initial, nextHt);
}

createKrnl.store(nextHt, Ht, indices);
if (!isNoneValue(state.allH))
createKrnl.store(
Expand Down Expand Up @@ -602,6 +617,20 @@ void calculateState<GruState, GruActivationPack, GruWeightPack, GruBiasPack>(
Value nextHt = createMath.add(ztht, ztHt);

// Store the intermediate Ht.
// Handle sequence_lens
if (!isNoneValue(sequenceLens)) {
Value sequenceUB = createKrnl.load(sequenceLens, {bs});
Value initial;
if (isNoneValue(initialH)) {
initial = createMath.constant(nextHt.getType(), 0.);
} else {
initial = createKrnl.load(initialH, {directionIV, bs, hs});
}
Value cond = createMath.sge(
createMath.cast(sequenceUB.getType(), sequenceIV), sequenceUB);
nextHt = createMath.select(cond, /*padding*/ initial, nextHt);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we create a common function for this to avoid boilerplate? and we can call it in other ops like LSTM and RNN.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed.


createKrnl.store(nextHt, Ht, indices);
if (!isNoneValue(state.allH))
createKrnl.store(
Expand Down
9 changes: 8 additions & 1 deletion src/Conversion/ONNXToKrnl/RNN/LSTM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ void calculateState<LstmState, LstmActivationPack, LstmWeightPack,
LstmBiasPack>(ConversionPatternRewriter &rewriter, Location loc, Value Xt,
LstmState state, LstmActivationPack activationPack,
LstmWeightPack weightPack, LstmBiasPack biasPack, Value sequenceIV,
Value directionIV, bool isForward) {
Value directionIV, Value sequenceLens, Value initialH, bool isForward) {
// Equations for LSTM.
// it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
// ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
Expand All @@ -452,6 +452,13 @@ void calculateState<LstmState, LstmActivationPack, LstmWeightPack,
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
// Ht = ot (.) h(Ct)

// ToFix: add support for sequence_lens in LSTM.
// Enabling the assert below would fail test_lstm_with_peephole.
// In that test case, the length of the input is used as sequence_lens,
// so onnx-mlir can pass the test by ignoring the sequence_lens
// parameter.
// assert(isNoneValue(sequenceLens) && "not implemented yet");

// TODO remove scope
MultiDialectBuilder<KrnlBuilder, MathBuilder, MemRefBuilder, OnnxBuilder>
create(rewriter, loc);
Expand Down
6 changes: 5 additions & 1 deletion src/Conversion/ONNXToKrnl/RNN/RNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,8 @@ template <>
void calculateState<RnnState, RnnActivationPack, RnnWeightPack, RnnBiasPack>(
ConversionPatternRewriter &rewriter, Location loc, Value Xt, RnnState state,
RnnActivationPack activationPack, RnnWeightPack weightPack,
RnnBiasPack biasPack, Value sequenceIV, Value directionIV, bool isForward) {
RnnBiasPack biasPack, Value sequenceIV, Value directionIV,
Value sequenceLens, Value initialH, bool isForward) {
// Equations for RNN.
// Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
// Shape information:
Expand All @@ -311,6 +312,9 @@ void calculateState<RnnState, RnnActivationPack, RnnWeightPack, RnnBiasPack>(
// Wbi: [hidden_size]
// Rbi: [hidden_size]

// ToFix: add support for sequenceLens in RNN.
assert(isNoneValue(sequenceLens) && "not implemented yet");

MultiDialectBuilder<KrnlBuilder, MathBuilder, MemRefBuilder, OnnxBuilder>
create(rewriter, loc);

Expand Down
9 changes: 6 additions & 3 deletions src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ S allocAndInitializeStates(mlir::ConversionPatternRewriter &rewriter,
template <typename S, typename A, typename W, typename B>
void calculateState(mlir::ConversionPatternRewriter &rewriter,
mlir::Location loc, mlir::Value Xt, S state, A activationSet, W weight,
B bias, mlir::Value sequenceIV, mlir::Value directionIV, bool isForward);
B bias, mlir::Value sequenceIV, mlir::Value directionIV,
mlir::Value sequenceLens, mlir::Value initialH, bool isForward);

// Write states to the RNN's outputs.
template <typename RNNOp, typename S>
Expand All @@ -136,6 +137,8 @@ struct ONNXRNNOpLowering : public mlir::OpConversionPattern<RNNOp> {
mlir::Operation *op = rnnOp.getOperation();
mlir::Location loc = ONNXLoc<RNNOp>(op);
mlir::Value X = adaptor.getX();
mlir::Value sequenceLens = adaptor.getSequenceLens();
mlir::Value initialH = adaptor.getInitialH();

if (hasAllNoneOutput<RNNOp>(&rnnOp)) {
rewriter.eraseOp(op);
Expand Down Expand Up @@ -188,7 +191,7 @@ struct ONNXRNNOpLowering : public mlir::OpConversionPattern<RNNOp> {
// Emit calculation for one RNN step.
calculateState<S, A, W, B>(rewriter, loc, Xt, state,
activationForward, weightForward, biasForward, sequenceIV,
directionIV,
directionIV, sequenceLens, initialH,
/*isForward=*/true);
});
}
Expand Down Expand Up @@ -226,7 +229,7 @@ struct ONNXRNNOpLowering : public mlir::OpConversionPattern<RNNOp> {
// Emit calculation for one RNN step.
calculateState<S, A, W, B>(rewriter, loc, Xt, state,
activationReverse, weightReverse, biasReverse,
reverseSequenceIV, directionIV,
reverseSequenceIV, directionIV, sequenceLens, initialH,
/*isForward=*/false);
});
}
Expand Down
Loading