Problems Running Jax-Triton with an Nvidia 4090 #114
We don't yet support the latest Triton dev version, but we soon will!
Am I correct in thinking that Nvidia 4090s aren't supported by any earlier version of Triton, so for the moment I can't use JAX-ML until further updates?
There's an open PR that rebases jax-triton on top of Triton at HEAD.
Link here: #50
After seeing that the PR was merged into the main branch, I revisited this issue. The original error is gone; however, I now receive the following error when I run the provided examples, e.g. https://github.com/jax-ml/jax-triton/blob/main/examples/add.py
This suggests that the function doesn't exist or isn't installed, but it is definitely there, and I can run the Triton/PyTorch tutorial examples without error.
You'll need to use Triton installed from HEAD or a nightly build.
I have; I'm using the Triton nightly.
A total reinstall from scratch seems to have fixed the issue.
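When an example reports that a Triton function doesn't exist even though Triton is installed, one thing worth checking before a full reinstall is which `triton` package Python actually resolves, and which version its metadata reports. A minimal sketch using only the standard library; the helper `describe_package` is hypothetical and not part of Triton or jax-triton:

```python
import importlib.metadata
import importlib.util


def describe_package(name: str) -> dict:
    """Hypothetical helper: report a package's installed version and import path.

    'version' comes from the distribution metadata and 'location' from the
    module spec; either is None when it cannot be found.
    """
    try:
        version = importlib.metadata.version(name)
    except importlib.metadata.PackageNotFoundError:
        version = None
    spec = importlib.util.find_spec(name)
    location = spec.origin if spec is not None else None
    return {"version": version, "location": location}


if __name__ == "__main__":
    # Assumes the distribution and module names match for these three packages.
    for pkg in ("triton", "jax", "jaxlib"):
        print(pkg, describe_package(pkg))
```

If `location` points somewhere unexpected (for example, an older site-packages directory), or `version` doesn't match the nightly you installed, a stale install is a likely explanation.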
When running the quick-start example on an Nvidia 4090 with the suggested Triton version (2.0.0.dev20221202), I receive the following error:
After upgrading to the latest development version of Triton, the PyTorch-based examples run, but JAX-ML produces the following error: