[WIP] Add custom triton kernels #84
Conversation
Oh, never mind, it is working in the GPU tests but then failing in the CPU tests because Triton doesn't handle those.
Yeah, I am getting some weird results here:
Basically this means that sometimes the result is coming back as NaN when it should be a finite value. This isn't happening every time I test, so I guess some more digging is needed.
I added a basic short-circuit to default back to PyTorch on CPU.
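A minimal sketch of what such a short-circuit could look like, assuming an envsum-style operation; the function name, signature, and the `cuda_impl` hook are illustrative, not the PR's actual API:

```python
import torch

def envsum(sense, features, pair_first, pair_second, cuda_impl=None):
    """Hypothetical dispatch wrapper: Triton kernels only run on CUDA tensors,
    so short-circuit back to a pure-PyTorch path for CPU inputs."""
    if cuda_impl is not None and sense.is_cuda:
        return cuda_impl(sense, features, pair_first, pair_second)
    # Pure-PyTorch fallback: accumulate each pair's outer product of
    # sensitivities and neighbor features onto the central atom.
    n_atoms, n_feat = features.shape
    n_sense = sense.shape[1]
    env = features.new_zeros(n_atoms, n_sense, n_feat)
    pair_contrib = sense.unsqueeze(2) * features[pair_second].unsqueeze(1)
    env.index_add_(0, pair_first, pair_contrib)
    return env
```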
@lubbersnick Interesting, I am not sure how anything leads to NaN given that the kernels only do add and multiply. Is this error triggered by "test_env_triton"?
Yes, this is where it happened, during the large-system correctness checks. I wonder if some sparsity in the input could be responsible, like a test particle with no neighbors, or a pair of particles with no sensitivities. This would be infrequent but possible. I upped the number of large-system tests to 1000 and it is happening again...
It is sensesum that fails, by the way.
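For context, a pure-PyTorch reference for a sensesum-style operation (names and shapes here are illustrative, not necessarily the library's exact signature) makes the suspected edge case concrete: an atom that appears in no pairs simply never contributes, so the reference stays finite and any NaN would have to come from how the kernel handles empty or partially filled blocks rather than from the arithmetic itself.

```python
import torch

def sensesum_reference(env_grad, features, pair_first, pair_second):
    """Illustrative reference: out[p, s] = sum_f env_grad[pair_first[p], s, f]
    * features[pair_second[p], f] for each pair p."""
    return torch.einsum("psf,pf->ps", env_grad[pair_first], features[pair_second])

# Edge case from the discussion: atom 2 has no neighbors at all.
n_atoms, n_sense, n_feat = 3, 4, 8
env_grad = torch.randn(n_atoms, n_sense, n_feat)
features = torch.randn(n_atoms, n_feat)
pair_first = torch.tensor([0, 1, 1, 0])   # atom 2 never appears as a center
pair_second = torch.tensor([1, 0, 0, 1])  # ...or as a neighbor
out = sensesum_reference(env_grad, features, pair_first, pair_second)
assert torch.isfinite(out).all()
```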
I confirmed that it still fails with the new PyTorch (2.3.1), so it wasn't due to being on an old version. However, I will try your fixes and see.
I am having trouble triggering the issue. I fixed the seed for reproducibility and made some fixes in the code. Not sure if that resolves it, as I cannot reproduce it.
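A minimal sketch of the kind of seed-fixing that makes an intermittent failure like this replayable; the helper name and where it would be called are my own, not taken from the PR:

```python
import random
import numpy as np
import torch

def set_seed(seed: int = 0) -> None:
    # Fix every RNG involved in test-system generation so a failing run can be replayed.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
```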
I am running the tests with your fixes now.
I think it's working now. By the way, you can directly see the improvement in the compiler over time: PyTorch 2.3.1 vs PyTorch 2.0.0.
I think I can take things from here: integrate this into the settings structure, run a long training test to make sure we get similar results and the NaNs don't appear at all, and so on.
Sounds great! It is nice to see that Triton works well for sparse operations as well, not just dense ones.
My tests were with CUDA 11.8, so that does still work fine.
I think there is still some more cleaning up to do, but I can train a model using this version, and it now defaults to Triton kernels when they are available.
Good news: we are getting 30% faster training of HIP-NN-TS on ANI-1x (GPU: Titan V). I still want to test a longer run.
That is great to hear! Also, what are the typical number of pairs per batch, sensitivity size, and feature size for these runs? We can definitely increase the amount of parallelism exposed in the kernels if the typical workload is not large enough.
I think the main problem is that the PTX Triton produces by default uses functionality that requires a GPU driver compatible with CUDA 12. So, based on the CUDA compatibility matrix, the GPU driver needs to be >= 525.60.13.
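A hedged sketch of how one could check the installed driver against that minimum; the helper and the exact minimum tuple are assumptions based on the compatibility matrix figure quoted above, not code from the PR:

```python
import subprocess

MIN_DRIVER = (525, 60, 13)  # assumed minimum Linux driver for CUDA 12 compatibility

def driver_ok() -> bool:
    """Query the NVIDIA driver version via nvidia-smi and compare to MIN_DRIVER."""
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
        capture_output=True, text=True, check=True,
    ).stdout.strip().splitlines()[0]
    version = tuple(int(x) for x in out.split("."))
    return version >= MIN_DRIVER
```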
OK, re: the versions, you're talking about the GPU driver version whereas I was talking about CUDA library versions, so I think the story makes sense.
Re: typical workloads, every time I try to address this, the kinds of datasets and simulations turn out not to be so easily classified. For now it would actually make more sense to go back and do whole-program profiling to see whether this operation is even the bottleneck right now.
I committed a new version of the code based on our discussion. Now the feature and sensitivity vectors are split into chunks. I also added the autotune option. The current options I put there may not be ideal, so I need to do some benchmarking to find reasonable config options. Performance-wise, everything looks great now, even for the ultra case.
Here are the benchmark numbers on V100:
And on A100:
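To make the chunking-plus-autotune idea above concrete, here is a minimal illustrative sketch, not the PR's actual kernels: it chunks only along the feature dimension with a scalar weight per pair (the real kernels also chunk sensitivities), and all names, shapes, and autotune configs are assumptions.

```python
import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"CHUNK": 32}, num_warps=1),
        triton.Config({"CHUNK": 64}, num_warps=2),
        triton.Config({"CHUNK": 128}, num_warps=4),
    ],
    key=["n_feat"],  # re-tune when the feature width changes
)
@triton.jit
def scatter_pairs_kernel(out_ptr, feat_ptr, sense_ptr,
                         pair_first_ptr, pair_second_ptr,
                         n_feat, CHUNK: tl.constexpr):
    pair_id = tl.program_id(0)              # which pair this program handles
    chunk_id = tl.program_id(1)             # which slice of the feature vector
    i = tl.load(pair_first_ptr + pair_id)   # destination atom index
    j = tl.load(pair_second_ptr + pair_id)  # source (neighbor) atom index
    s = tl.load(sense_ptr + pair_id)        # per-pair weight
    offs = chunk_id * CHUNK + tl.arange(0, CHUNK)
    mask = offs < n_feat
    f = tl.load(feat_ptr + j * n_feat + offs, mask=mask, other=0.0)
    # Several pairs can target the same atom, so the scatter must be atomic.
    tl.atomic_add(out_ptr + i * n_feat + offs, s * f, mask=mask)


def scatter_pairs(features, sense, pair_first, pair_second):
    """out[i] += sense[p] * features[j] summed over pairs p = (i, j)."""
    n_atoms, n_feat = features.shape
    out = torch.zeros_like(features)
    # 2D launch grid: one axis over pairs, one over feature chunks; the chunk
    # size comes from whichever autotune config wins for this n_feat.
    grid = lambda meta: (pair_first.shape[0], triton.cdiv(n_feat, meta["CHUNK"]))
    scatter_pairs_kernel[grid](out, features, sense,
                               pair_first, pair_second, n_feat)
    return out
```

Splitting the feature axis over a second grid dimension is what exposes more parallelism when the number of pairs alone is small, which is the point raised earlier about small workloads.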
This is great and working for me too. I am doing a small amount of cleanup and doc updates, but I think we can essentially merge this after that housekeeping. Is there any additional algorithmic strategy you would want to try? It seems like we have the obvious ideas out of the way.
I think we have all the low-hanging fruit implemented. There might be some other tricks to increase data reuse, as I expect memory bandwidth to be the main bottleneck (profiling is needed to verify that). I think we can focus on merging what we have right now.
Add custom Triton kernels.
For large test cases, they are faster than both the Numba and pure PyTorch implementations.
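A sketch of how such a GPU comparison can be timed; the harness below uses CUDA events so kernel launches are not mistaken for execution time, and the callables to compare (Triton, Numba, pure torch) are placeholders to be supplied by the caller.

```python
import torch

def time_gpu(fn, *args, warmup=10, iters=100):
    """Return average milliseconds per call for a CUDA-side function."""
    for _ in range(warmup):
        fn(*args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters
```

Running each implementation on identically shaped inputs and comparing the returned ms/call is how a claim like "faster for large test cases" would be quantified.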