Handle sequence_lens for GRU on CPU #2479

Merged: 18 commits, Sep 8, 2023
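In terms of the ONNX GRU outputs, the behavior added here can be illustrated with a small example (the numbers are illustrative, not taken from a test in this PR): with seq_length = 4 and sequence_lens = [4, 2], the second batch entry's per-step output Y[t, :, 1, :] is padded for t >= 2 (with zeros, or with the matching initial_h slice when initial_h is given), while Y_h[:, 1, :] holds the hidden state computed at t = 1, the last valid step for that entry.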
13 changes: 10 additions & 3 deletions src/Conversion/ONNXToKrnl/RNN/GRU.cpp
@@ -389,7 +389,8 @@ template <>
void calculateState<GruState, GruActivationPack, GruWeightPack, GruBiasPack>(
ConversionPatternRewriter &rewriter, Location loc, Value Xt, GruState state,
GruActivationPack activationPack, GruWeightPack weightPack,
GruBiasPack biasPack, Value sequenceIV, Value directionIV, bool isForward) {
GruBiasPack biasPack, Value sequenceIV, Value directionIV,
Value sequenceLens, Value initialH, bool isForward) {
// Equations (Default: f=Sigmoid, g=Tanh):"
// zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)"
// rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)"
@@ -498,7 +499,10 @@ void calculateState<GruState, GruActivationPack, GruWeightPack, GruBiasPack>(
Value nextHt = createMath.add(ztht, ztHt);

// Store the intermediate Ht.
createKrnl.store(nextHt, Ht, indices);
// Handle sequence_lens
nextHt = handleSequenceLens(createKrnl, createMath, sequenceLens,
initialH, nextHt, sequenceIV, directionIV, bs, hs, Ht);

if (!isNoneValue(state.allH))
createKrnl.store(
nextHt, state.allH, {sequenceIV, directionIV, bs, hs});
@@ -602,7 +606,10 @@ void calculateState<GruState, GruActivationPack, GruWeightPack, GruBiasPack>(
Value nextHt = createMath.add(ztht, ztHt);

// Store the intermediate Ht.
createKrnl.store(nextHt, Ht, indices);
// Handle sequence_lens
nextHt = handleSequenceLens(createKrnl, createMath, sequenceLens,
initialH, nextHt, sequenceIV, directionIV, bs, hs, Ht);

if (!isNoneValue(state.allH))
createKrnl.store(
nextHt, state.allH, {sequenceIV, directionIV, bs, hs});
9 changes: 8 additions & 1 deletion src/Conversion/ONNXToKrnl/RNN/LSTM.cpp
@@ -443,7 +443,7 @@ void calculateState<LstmState, LstmActivationPack, LstmWeightPack,
LstmBiasPack>(ConversionPatternRewriter &rewriter, Location loc, Value Xt,
LstmState state, LstmActivationPack activationPack,
LstmWeightPack weightPack, LstmBiasPack biasPack, Value sequenceIV,
Value directionIV, bool isForward) {
Value directionIV, Value sequenceLens, Value initialH, bool isForward) {
// Equations for LSTM.
// it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
// ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
@@ -452,6 +452,13 @@ void calculateState<LstmState, LstmActivationPack, LstmWeightPack,
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
// Ht = ot (.) h(Ct)

// ToFix: add support for sequence_lens for LSTM.
// The assert below would fail test_lstm_with_peephole: in that test case,
// the length of the input is used as sequence_lens, so onnx-mlir can pass
// the test by ignoring the sequence_lens parameter.
// assert(isNoneValue(sequenceLens) && "not implemented yet");

// TODO remove scope
MultiDialectBuilder<KrnlBuilder, MathBuilder, MemRefBuilder, OnnxBuilder>
create(rewriter, loc);
6 changes: 5 additions & 1 deletion src/Conversion/ONNXToKrnl/RNN/RNN.cpp
@@ -300,7 +300,8 @@ template <>
void calculateState<RnnState, RnnActivationPack, RnnWeightPack, RnnBiasPack>(
ConversionPatternRewriter &rewriter, Location loc, Value Xt, RnnState state,
RnnActivationPack activationPack, RnnWeightPack weightPack,
RnnBiasPack biasPack, Value sequenceIV, Value directionIV, bool isForward) {
RnnBiasPack biasPack, Value sequenceIV, Value directionIV,
Value sequenceLens, Value initialH, bool isForward) {
// Equations for RNN.
// Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
// Shape information:
@@ -311,6 +312,9 @@ void calculateState<RnnState, RnnActivationPack, RnnWeightPack, RnnBiasPack>(
// Wbi: [hidden_size]
// Rbi: [hidden_size]

// ToFix: add support for sequenceLens for RNN.
assert(isNoneValue(sequenceLens) && "not implemented yet");

MultiDialectBuilder<KrnlBuilder, MathBuilder, MemRefBuilder, OnnxBuilder>
create(rewriter, loc);

29 changes: 29 additions & 0 deletions src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp
@@ -336,4 +336,33 @@ Value emitXSliceAt(ConversionPatternRewriter &rewriter, Location loc, Value X,
return sliceX;
}

// Adjust nextHt and Ht when sequenceLens is defined.
// When a sample reaches the end of its sequence length, nextHt is padded
// with 0 (or with initialH), and Ht keeps the value from the last valid
// step, so the final Ht is the hidden state at the end of each sample's
// sequence.
Value handleSequenceLens(KrnlBuilder &createKrnl, MathBuilder &createMath,
Value sequenceLens, Value initialH, Value nextHt, Value sequenceIV,
Value directionIV, Value bs, Value hs, Value Ht) {
if (!isNoneValue(sequenceLens)) {
Value sequenceUB = createKrnl.load(sequenceLens, {bs});
Value initial;
if (isNoneValue(initialH)) {
initial = createMath.constant(nextHt.getType(), 0.);
} else {
initial = createKrnl.load(initialH, {directionIV, bs, hs});
}
Value cond = createMath.sge(
createMath.cast(sequenceUB.getType(), sequenceIV), sequenceUB);
nextHt = createMath.select(cond, /*padding*/ initial, nextHt);

// Past the sequence end, Ht keeps its old value; otherwise it takes nextHt.
Value lastHt =
createMath.select(cond, createKrnl.load(Ht, {bs, hs}), nextHt);
createKrnl.store(lastHt, Ht, {bs, hs});
} else {
createKrnl.store(nextHt, Ht, {bs, hs});
}
return nextHt;
}

} // namespace onnx_mlir
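
For intuition, below is a minimal stand-alone C++ sketch of the per-element runtime behavior that the krnl code emitted by handleSequenceLens is intended to produce; it is not part of the patch, and all names in it are illustrative.

#include <cstdint>
#include <cstdio>

// Scalar form of the masking rule for one (batch, hidden) element at time
// step t: once t passes the sample's sequence length, the per-step output is
// the padding value and Ht is left unchanged; otherwise Ht advances normally.
static float maskStep(float candidateHt, float &Ht, int64_t t, int64_t seqLen,
    float initial /* 0.0f when initial_h is absent */) {
  if (t >= seqLen)
    return initial;   // pad the per-step output, keep Ht frozen
  Ht = candidateHt;   // still inside the sequence: update the running state
  return candidateHt;
}

int main() {
  float Ht = 0.0f;
  const int64_t seqLen = 2; // this sample contributes only 2 of 4 steps
  for (int64_t t = 0; t < 4; ++t) {
    float candidate = 0.1f * float(t + 1); // stand-in for the GRU update
    float y = maskStep(candidate, Ht, t, seqLen, /*initial=*/0.0f);
    std::printf("t=%lld y=%.1f Ht=%.1f\n", static_cast<long long>(t), y, Ht);
  }
  // Prints y = 0.1, 0.2, 0.0, 0.0 while Ht stays at 0.2 after the second
  // step, mirroring how Y is padded and Y_h keeps the last valid state.
  return 0;
}

In the lowering itself the same decision is expressed with createMath.select on the loaded sequenceUB rather than with control flow.
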
18 changes: 15 additions & 3 deletions src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp
@@ -75,6 +75,15 @@ mlir::Value applyActivation(mlir::OpBuilder &rewriter, mlir::Location loc,
mlir::Value emitXSliceAt(mlir::ConversionPatternRewriter &rewriter,
mlir::Location loc, mlir::Value X, mlir::Value timestep);

// Adjust nextHt and Ht when sequenceLens is defined.
// When a sample reaches the end of its sequence length, nextHt is padded
// with 0 (or with initialH), and Ht keeps the value from the last valid
// step, so the final Ht is the hidden state at the end of each sample's
// sequence.
mlir::Value handleSequenceLens(KrnlBuilder &createKrnl, MathBuilder &createMath,
mlir::Value sequenceLens, mlir::Value initialH, mlir::Value nextHt,
mlir::Value sequenceIV, mlir::Value directionIV, mlir::Value bs,
mlir::Value hs, mlir::Value Ht);

// Override the following methods when lowering an RNN operation:
// - hasAllNoneOutput
// - getActivationPack
@@ -116,7 +125,8 @@ S allocAndInitializeStates(mlir::ConversionPatternRewriter &rewriter,
template <typename S, typename A, typename W, typename B>
void calculateState(mlir::ConversionPatternRewriter &rewriter,
mlir::Location loc, mlir::Value Xt, S state, A activationSet, W weight,
B bias, mlir::Value sequenceIV, mlir::Value directionIV, bool isForward);
B bias, mlir::Value sequenceIV, mlir::Value directionIV,
mlir::Value sequenceLens, mlir::Value initialH, bool isForward);

// Write states to the RNN's outputs.
template <typename RNNOp, typename S>
@@ -136,6 +146,8 @@ struct ONNXRNNOpLowering : public mlir::OpConversionPattern<RNNOp> {
mlir::Operation *op = rnnOp.getOperation();
mlir::Location loc = ONNXLoc<RNNOp>(op);
mlir::Value X = adaptor.getX();
mlir::Value sequenceLens = adaptor.getSequenceLens();
mlir::Value initialH = adaptor.getInitialH();

if (hasAllNoneOutput<RNNOp>(&rnnOp)) {
rewriter.eraseOp(op);
@@ -188,7 +200,7 @@ struct ONNXRNNOpLowering : public mlir::OpConversionPattern<RNNOp> {
// Emit calculation for one RNN step.
calculateState<S, A, W, B>(rewriter, loc, Xt, state,
activationForward, weightForward, biasForward, sequenceIV,
directionIV,
directionIV, sequenceLens, initialH,
/*isForward=*/true);
});
}
@@ -226,7 +238,7 @@ struct ONNXRNNOpLowering : public mlir::OpConversionPattern<RNNOp> {
// Emit calculation for one RNN step.
calculateState<S, A, W, B>(rewriter, loc, Xt, state,
activationReverse, weightReverse, biasReverse,
reverseSequenceIV, directionIV,
reverseSequenceIV, directionIV, sequenceLens, initialH,
/*isForward=*/false);
});
}