-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ROCm] Implement RNN support #25755
base: main
Are you sure you want to change the base?
[ROCm] Implement RNN support #25755
Conversation
@dfm and @superbobry could you please take a look? |
0b07837
to
36d037e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dfm want to have a look as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good overall - thanks! My main high level comment is that it would be useful to move as much of the #ifdef JAX_GPU_HIP
logic into vendor.h
rather than in rnn_kernels.cc
directly. It's ok to have some, but the more we can move, the better. Can you look into redefining some of the macros in vendor.h
to consolidate the logic there?
jax/experimental/rnn.py
Outdated
mlir.register_lowering(rnn_fwd_p, gpu_rnn.cudnn_rnn_fwd_lowering, platform='cuda') | ||
mlir.register_lowering(rnn_fwd_p, gpu_rnn.miopen_rnn_fwd_lowering, platform='rocm') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since gpu_rnn
is in jaxlib, these changes will cause problems with version skew. JAX always needs to work with the most recent stable release of jaxlib. Perhaps you could protect this using hasattr(gpu_rnn, "miopen_rnn_fwd_lowering")
?
jax/experimental/rnn.py
Outdated
mlir.register_lowering( | ||
rnn_bwd_p, gpu_rnn.miopen_rnn_bwd_lowering, platform='rocm') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similarly, this needs to be protected against old version of jaxlib.
36d037e
to
4f1f486
Compare
4f1f486
to
1b871e2
Compare
Created from: ROCm#171