[WIP] Add custom triton kernels #84

Merged
24 commits merged into lanl:development on Jul 29, 2024

Conversation

cagrikymk
Contributor

Add custom triton kernels.
For large test cases, they are faster than both the numba and the pure torch implementations.

@lubbersnick
Collaborator

ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
ValueError: envsum failed grad check.
I got this set of tracebacks stochastically: I ran once and got this error, then ran again and got no error. Also, I may be using too old a version of Triton right now.

@lubbersnick
Collaborator

Oh never mind, it is working on the GPU tests but then failing on the CPU tests because Triton doesn't handle those.

@lubbersnick
Collaborator

lubbersnick commented Jul 22, 2024

Yeah, I am getting some weird results here:

Checking forward methods on large data 3 times...
1 bad tolerances (of 8572360)
Allowed tolerance:
[0.00010032]
Observed deviations:
[nan]
Ratio:
[nan]
Desired result: [-9.032157]
Observed result: [nan]
36311.402 nan
[  87.8035       0.9533726  -14.046047  ... -149.72322   -124.96163
   55.055904 ] [  87.80349      0.9533751  -14.046047  ... -149.72324   -124.961624
   55.055904 ]
[ 2984.9521   2590.1042   5114.351    1407.1326   4234.9707   2854.0742
  4787.0854          nan   631.5906    165.11407  -922.35425 -1296.7333
   618.398    3709.275    4717.3447    401.191    2153.7969   2135.7646
  1789.038   -1044.5712 ] [ 2984.9746   2590.0986   5114.355    1407.1372   4234.969    2854.0762
  4787.084    -719.533     631.59283   165.11786  -922.3617  -1296.7377
   618.4122   3709.2612   4717.354     401.19986  2153.8032   2135.77
  1789.0298  -1044.5677 ]
Locations: (array([308]), array([7]))
Violation stats: median: nan, max: nan

Basically this means that sometimes the result is coming back as NaN when it should be a finite value. This isn't happening every time I test, so I guess some more digging is needed.

@lubbersnick
Collaborator

I added a basic short-circuit to default back to pytorch on CPU.
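
A minimal sketch of what such a short-circuit can look like, with hypothetical function names rather than the PR's actual ones: dispatch on the tensor's device, since Triton can only touch GPU memory.

```python
import torch

def envsum(sense: torch.Tensor, feat: torch.Tensor, *indices):
    # Triton raises "Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)"
    # for CPU inputs, so route those to the pure-PyTorch implementation instead.
    if sense.is_cuda:
        return envsum_triton(sense, feat, *indices)   # hypothetical Triton-backed path
    return envsum_pytorch(sense, feat, *indices)      # hypothetical pure-torch fallback
```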

@cagrikymk
Contributor Author

@lubbersnick Interesting, I am not sure how anything leads to NaN given that the kernels only do add and multiply. Is this error triggered by "test_env_triton"?

@lubbersnick
Collaborator

Yes, this is where it happened, during the large system correctness checks. I wonder if some sparsity in the input could be responsible. Like if there is a test particle with no neighbors or something, or a pair of particles with no sensitivities. This would be infrequent but possible. I upped the number of large system tests to 1000 and it is happening again...

@lubbersnick
Collaborator

It is sensesum that fails, by the way.

@lubbersnick
Collaborator

I confirmed that it still fails with new pytorch (2.3.1), so it wasn't due to being on an old version. However, I will try your fixes and see.

@cagrikymk
Contributor Author

I am having trouble triggering the issue. I fixed the seed for reproducibility and made some fixes in the code. Not sure if that resolves it, as I cannot reproduce the failure.

@lubbersnick
Collaborator

I am running the tests with your fixes now.

@lubbersnick
Collaborator

I think it's working now. By the way, you can directly see the improvement in the compiler over time:

pytorch2.3.1:

--------------------------------------------------------------------------------
Repetitions: 20
....................
Mean Pytorch_Envsum time: 0.015448415279388427 Median: 0.015448212623596191
Mean Pytorch_Sensesum time: 0.01885749101638794 Median: 0.01886439323425293
Mean Pytorch_Featsum time: 0.01822214126586914 Median: 0.018230915069580078
Mean Envsum time: 0.002795732021331787 Median: 0.0027532577514648438
Mean Sensesum time: 0.0023035645484924317 Median: 0.0023040771484375
Mean Featsum time: 0.003322196006774902 Median: 0.003315567970275879
Envsum Speedup: 5.610885001731901
Sensesum Speedup: 8.187396523178808
Featsum Speedup: 5.4985798008125695
Overall Pytorch time: 0.0525435209274292
Overall time now: 0.008372902870178223
Overall speedup: 6.275424633740052

vs pytorch 2.0.0

--------------------------------------------------------------------------------
Repetitions: 20
....................
Mean Pytorch_Envsum time: 0.015489602088928222 Median: 0.015532255172729492
Mean Pytorch_Sensesum time: 0.018367350101470947 Median: 0.018375158309936523
Mean Pytorch_Featsum time: 0.017868351936340333 Median: 0.017873525619506836
Mean Envsum time: 0.0030530691146850586 Median: 0.0030149221420288086
Mean Sensesum time: 0.0047687172889709474 Median: 0.004762530326843262
Mean Featsum time: 0.010993754863739014 Median: 0.011021256446838379
Envsum Speedup: 5.151793127990194
Sensesum Speedup: 3.8582763885760056
Featsum Speedup: 1.6217321233491613
Overall Pytorch time: 0.05178093910217285
Overall time now: 0.01879870891571045
Overall speedup: 2.754494435460858

I think I can take things from here to integrate into the settings structure, run a long-running test of some training to make sure we get similar results and the NaNs don't appear at all, stuff like that.

@cagrikymk
Contributor Author

Sounds great!
I am not sure how PyTorch handles the installation of Triton, but the up-to-date version of Triton requires CUDA 12. We might need to guard the Triton code in case Triton is not available and fall back to another option.

It is nice to see that Triton works well for sparse operations as well, not just dense ones.
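
A hedged sketch of the kind of guard meant here; the module and function names are assumptions, not the PR's actual layout:

```python
import torch

try:
    import triton  # bundled with recent CUDA builds of PyTorch, but not always usable
    HAVE_TRITON = torch.cuda.is_available()
except ImportError:
    HAVE_TRITON = False

if HAVE_TRITON:
    from .env_triton import envsum, sensesum, featsum    # hypothetical Triton kernels
else:
    from .env_pytorch import envsum, sensesum, featsum   # hypothetical fallback implementation
```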

@lubbersnick
Collaborator

My tests were with CUDA 11.8, so that does still work fine.

@lubbersnick
Collaborator

I think there is still some more cleaning up to do, but I can train a model using this version, and now it defaults to triton kernels when they are available.

@lubbersnick
Collaborator

Good news: we are getting 30% faster training of HIP-NN-TS on ANI-1x (GPU: Titan V). Still want to test a longer run.

@cagrikymk
Contributor Author

That is great to hear!

Also, what are the typical number of pairs per batch, the sensitivity size, and the feature size for these runs? We can definitely increase the amount of parallelism exposed in the kernels if the typical workload is not large enough.

@cagrikymk
Contributor Author

> My tests were with CUDA 11.8, so that does still work fine

I think the main problem is that the PTX Triton produces by default uses functionality that requires a GPU driver compatible with CUDA 12. So, based on the CUDA compatibility matrix, the GPU driver needs to be >= 525.60.13.
I wasn't able to use Triton on machines with older GPU drivers. I can dig deeper, but this is my current understanding.

[image: CUDA compatibility matrix screenshot]
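
For what it's worth, a quick way to check the installed driver against that minimum; this is just an illustrative snippet, not part of the PR:

```python
import subprocess

MIN_DRIVER = (525, 60, 13)  # CUDA 12 minimum from the compatibility matrix above

def driver_version():
    # Query the installed NVIDIA driver version via nvidia-smi.
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
        capture_output=True, text=True, check=True,
    )
    return tuple(int(x) for x in out.stdout.strip().splitlines()[0].split("."))

found = driver_version()
print("Driver:", found, "meets the CUDA 12 minimum:", found >= MIN_DRIVER)
```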

@lubbersnick
Collaborator

OK, re: the versions, you're talking about the GPU driver version whereas I was talking about the CUDA library version, so I think the story makes sense.

@lubbersnick
Collaborator

Re: typical workloads, every time I try to address this, the kinds of datasets and simulations turn out not to be so easily classified. For now it would actually make more sense to go back and do whole-program profiling to see whether this operation is even the bottleneck right now.
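
A minimal whole-program profiling sketch with torch.profiler; `model` and `batch` are placeholders for an actual training step. Sorting by CUDA time shows whether envsum/sensesum/featsum actually dominate.

```python
import torch
from torch.profiler import profile, ProfilerActivity

# model and batch stand in for a real HIP-NN training step.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    loss = model(batch)
    loss.backward()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
```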

@cagrikymk
Contributor Author

cagrikymk commented Jul 26, 2024

I committed a new version of the code based on our discussion. Now, the feature and sensitivity vectors are split into chunks.

Also, I added the autotune option. The current options I put there may not be ideal, so I need to do some benchmarking to find reasonable config options.

Performance-wise, everything looks great now, even for the ultra case.
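
For illustration only (not the PR's actual kernel), this is the shape of the two ideas: `triton.autotune` selects block sizes from a list of configs, and a 2D launch grid splits the sensitivity and feature dimensions into chunks so each program instance handles one tile, using only adds and multiplies:

```python
import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"SENS_BLOCK": 16, "FEAT_BLOCK": 32}, num_warps=2),
        triton.Config({"SENS_BLOCK": 32, "FEAT_BLOCK": 64}, num_warps=4),
        triton.Config({"SENS_BLOCK": 32, "FEAT_BLOCK": 128}, num_warps=8),
    ],
    key=["n_sens", "n_feat"],  # re-tune when the problem shape changes
)
@triton.jit
def outer_tile_kernel(sens_ptr, feat_ptr, out_ptr, n_sens, n_feat,
                      SENS_BLOCK: tl.constexpr, FEAT_BLOCK: tl.constexpr):
    # Each (pid_s, pid_f) program computes one SENS_BLOCK x FEAT_BLOCK tile of
    # out[s, f] = sens[s] * feat[f].
    pid_s = tl.program_id(0)
    pid_f = tl.program_id(1)
    s_offs = pid_s * SENS_BLOCK + tl.arange(0, SENS_BLOCK)
    f_offs = pid_f * FEAT_BLOCK + tl.arange(0, FEAT_BLOCK)
    s_mask = s_offs < n_sens
    f_mask = f_offs < n_feat
    sens = tl.load(sens_ptr + s_offs, mask=s_mask, other=0.0)
    feat = tl.load(feat_ptr + f_offs, mask=f_mask, other=0.0)
    tile = sens[:, None] * feat[None, :]
    out_offs = s_offs[:, None] * n_feat + f_offs[None, :]
    tl.store(out_ptr + out_offs, tile, mask=s_mask[:, None] & f_mask[None, :])

def outer_tile(sens, feat):
    out = torch.empty(sens.shape[0], feat.shape[0], device=sens.device, dtype=sens.dtype)
    grid = lambda meta: (triton.cdiv(sens.shape[0], meta["SENS_BLOCK"]),
                         triton.cdiv(feat.shape[0], meta["FEAT_BLOCK"]))
    outer_tile_kernel[grid](sens, feat, out, sens.shape[0], feat.shape[0])
    return out
```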

@cagrikymk
Contributor Author

Here are the benchmark numbers on V100:

Ultra systems: {'n_molecules': 500, 'n_atoms': 30, 'atom_prob': 0.7, 'n_features': 128, 'n_nu': 320}
Repetitions: 20
....................
Mean cupy_Envsum time: 0.0488245964050293 Median: 0.01654231548309326
Mean cupy_Sensesum time: 0.07521920204162598 Median: 0.07400453090667725
Mean cupy_Featsum time: 0.025140464305877686 Median: 0.025070905685424805
Mean Envsum time: 0.00803889036178589 Median: 0.00780332088470459
Mean Sensesum time: 0.031079304218292237 Median: 0.010969281196594238
Mean Featsum time: 0.012728917598724365 Median: 0.01262056827545166
Envsum Speedup: 2.1199071174322857
Sensesum Speedup: 6.7465250986230805
Featsum Speedup: 1.9865116323002956
Overall cupy time: 0.11561775207519531
Overall time now: 0.03139317035675049
Overall speedup: 3.6828950616111946
--------------------------------------------------------------------------------
Mega systems: {'n_molecules': 500, 'n_atoms': 30, 'atom_prob': 0.7, 'n_features': 128, 'n_nu': 100}
Repetitions: 20
....................
Mean cupy_Envsum time: 0.0069476008415222164 Median: 0.006906867027282715
Mean cupy_Sensesum time: 0.011604201793670655 Median: 0.011866092681884766
Mean cupy_Featsum time: 0.008078169822692872 Median: 0.008179187774658203
Mean Envsum time: 0.028499603271484375 Median: 0.0035256147384643555
Mean Sensesum time: 0.024009811878204345 Median: 0.004373788833618164
Mean Featsum time: 0.014723169803619384 Median: 0.005219936370849609
Envsum Speedup: 1.9590532544378698
Sensesum Speedup: 2.713000817661488
Featsum Speedup: 1.566913309582534
Overall cupy time: 0.026952147483825684
Overall time now: 0.013119339942932129
Overall speedup: 2.054382888244755
--------------------------------------------------------------------------------
Large systems: {'n_molecules': 1000, 'n_atoms': 30, 'atom_prob': 0.7, 'n_features': 80, 'n_nu': 20}
Repetitions: 20
....................
Mean cupy_Envsum time: 0.002829885482788086 Median: 0.0027687549591064453
Mean cupy_Sensesum time: 0.003365147113800049 Median: 0.003361821174621582
Mean cupy_Featsum time: 0.0022456526756286623 Median: 0.0022395849227905273
Mean Envsum time: 0.021705222129821778 Median: 0.002555251121520996
Mean Sensesum time: 0.01661339998245239 Median: 0.0015361309051513672
Mean Featsum time: 0.02126336097717285 Median: 0.002826094627380371
Envsum Speedup: 1.083554933519944
Sensesum Speedup: 2.188499146360391
Featsum Speedup: 0.7924663601467921
Overall cupy time: 0.008370161056518555
Overall time now: 0.006917476654052734
Overall speedup: 1.2100020679671883

And on A100:

--------------------------------------------------------------------------------
Ultra systems: {'n_molecules': 500, 'n_atoms': 30, 'atom_prob': 0.7, 'n_features': 128, 'n_nu': 320}
Repetitions: 20
....................
Mean cupy_Envsum time: 0.046938538551330566 Median: 0.014365196228027344
Mean cupy_Sensesum time: 0.09621376991271972 Median: 0.10970163345336914
Mean cupy_Featsum time: 0.032917892932891844 Median: 0.03567230701446533
Mean Envsum time: 0.008868896961212158 Median: 0.009229421615600586
Mean Sensesum time: 0.029375433921813965 Median: 0.011646270751953125
Mean Featsum time: 0.012893128395080566 Median: 0.012567400932312012
Envsum Speedup: 1.5564568210586138
Sensesum Speedup: 9.419464461185719
Featsum Speedup: 2.8384792692296745
Overall cupy time: 0.15973913669586182
Overall time now: 0.03344309329986572
Overall speedup: 4.776446223546647
--------------------------------------------------------------------------------
Mega systems: {'n_molecules': 500, 'n_atoms': 30, 'atom_prob': 0.7, 'n_features': 128, 'n_nu': 100}
Repetitions: 20
....................
Mean cupy_Envsum time: 0.00874859094619751 Median: 0.007819175720214844
Mean cupy_Sensesum time: 0.013895213603973389 Median: 0.014153361320495605
Mean cupy_Featsum time: 0.010239267349243164 Median: 0.010371923446655273
Mean Envsum time: 0.025943386554718017 Median: 0.003808259963989258
Mean Sensesum time: 0.020059311389923097 Median: 0.004188418388366699
Mean Featsum time: 0.014934170246124267 Median: 0.004977107048034668
Envsum Speedup: 2.053214799974958
Sensesum Speedup: 3.3791660737156684
Featsum Speedup: 2.083926133505784
Overall cupy time: 0.03234446048736572
Overall time now: 0.012973785400390625
Overall speedup: 2.4930627021464273
--------------------------------------------------------------------------------
Large systems: {'n_molecules': 1000, 'n_atoms': 30, 'atom_prob': 0.7, 'n_features': 80, 'n_nu': 20}
Repetitions: 20
....................
Mean cupy_Envsum time: 0.0033119559288024903 Median: 0.0033206939697265625
Mean cupy_Sensesum time: 0.0037001729011535644 Median: 0.003951311111450195
Mean cupy_Featsum time: 0.0030521392822265626 Median: 0.003204345703125
Mean Envsum time: 0.019997811317443846 Median: 0.0030335187911987305
Mean Sensesum time: 0.014393055438995361 Median: 0.0018552541732788086
Mean Featsum time: 0.019430005550384523 Median: 0.00400388240814209
Envsum Speedup: 1.0946673478209612
Sensesum Speedup: 2.12979502666581
Featsum Speedup: 0.8003096436121119
Overall cupy time: 0.010476350784301758
Overall time now: 0.008892655372619629
Overall speedup: 1.178090271726745

@lubbersnick
Collaborator

This is great and working for me too. I am doing a small amount of cleaning and doc updates and stuff like that, but I think we can essentially merge this after those housekeeping things. Is there any additional algorithmic strategy you would want to try? It seems like we have the obvious ideas out of the way.

@cagrikymk
Contributor Author

I think we have all the low-hanging fruit implemented. There might be some other tricks to increase data reuse, as I expect memory bandwidth to be the main bottleneck (profiling is needed to verify that).

I think we can focus on merging what we have right now.

@lubbersnick lubbersnick marked this pull request as ready for review July 29, 2024 19:19
@lubbersnick lubbersnick merged commit 14c8205 into lanl:development Jul 29, 2024
1 check passed