Handle sequence_lens for GRU on CPU #2479
Conversation
Another test case for the initialH.
The result:
Value cond = createMath.sge(
    createMath.cast(sequenceUB.getType(), sequenceIV), sequenceUB);
nextHt = createMath.select(cond, /*padding*/ initial, nextHt);
}
Could we create a common function for this to avoid boilerplate? Then we could call it in other ops like LSTM and RNN.
Changed.
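For reference, a shared helper along these lines could be a thin wrapper over the snippet above. This is only a sketch: the name, signature, and the MathBuilder parameter type are assumptions, and only the sge/cast/select calls already visible in the diff are taken from the PR.

    // Hypothetical helper: once the sequence index has reached this batch
    // element's sequence length, keep emitting the padding value.
    static Value selectPaddingPastSeqLen(MathBuilder &createMath,
        Value sequenceIV, Value sequenceUB, Value initial, Value nextHt) {
      // cond = (sequenceIV >= sequenceUB), i.e. we are past the end of the
      // sequence for this batch element.
      Value cond = createMath.sge(
          createMath.cast(sequenceUB.getType(), sequenceIV), sequenceUB);
      // Replace the freshly computed hidden state with the padding value.
      return createMath.select(cond, /*padding*/ initial, nextHt);
    }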
This reverts commit bcc617c.
Now both the first and second outputs of GRU match the torch GRU example.
LGTM!
Jenkins Linux s390x Build #12569 [push] Handle sequence_lens for... started at 20:04
Jenkins Linux ppc64le Build #11562 [push] Handle sequence_lens for... started at 20:13
Jenkins Linux amd64 Build #12557 [push] Handle sequence_lens for... started at 19:04
Jenkins Linux amd64 Build #12557 [push] Handle sequence_lens for... passed after 1 hr 5 min
Jenkins Linux s390x Build #12569 [push] Handle sequence_lens for... passed after 1 hr 24 min
Jenkins Linux ppc64le Build #11562 [push] Handle sequence_lens for... passed after 1 hr 44 min
This PR is a quick fix for sequence_lens. According to the PyTorch definition, the padding value is emitted after a sequence reaches its sequence length. This PR does not try to save computation. I will open another PR to use scf.if so that all the RNN ops can be handled and computation can be saved.
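As a rough illustration of that semantics (a standalone C++ sketch with made-up shapes and values, not code from this PR): for each batch element b, once the time step t reaches sequence_lens[b], the emitted hidden state is the padding value rather than the freshly computed one.

    #include <cstdio>
    #include <vector>

    int main() {
      // Hypothetical sizes and per-batch sequence lengths, for illustration only.
      const int seqLength = 4, batchSize = 2;
      std::vector<int> seqLens = {4, 2};
      for (int t = 0; t < seqLength; ++t) {
        for (int b = 0; b < batchSize; ++b) {
          // nextHt would be the GRU cell output for (t, b); stubbed out here.
          const char *emitted = (t >= seqLens[b]) ? "padding" : "nextHt";
          std::printf("t=%d b=%d -> %s\n", t, b, emitted);
        }
      }
      return 0;
    }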
The output of my GRU test case seems to match the PyTorch example.
The output is
Limitations: This PR does not use the sequence_lens info to save computation. To do that, I could add an scf.if inside the loops over sequence and batch. However, the existing implementation defines the loop nest for batch and hidden state together, so breaking up that loop nest requires some effort. It is doable, but is it a priority?
Question: should the final result be modified according to sequence_lens as well? For example, should the 2nd output be
[[[-0.001489] [-0.00399583]]]
? I did not find any specification for that.