[WIP] Add custom triton kernels #84

Merged
merged 24 commits into from
Jul 29, 2024
24 commits
1ed39dd
add custom triton kernels
Jul 22, 2024
b628310
add CPU compatibility
lubbersnick Jul 22, 2024
727c502
move call site for CPU version
lubbersnick Jul 22, 2024
934413f
call the right functions
lubbersnick Jul 22, 2024
28619e0
fix GPU memory limits for tests
lubbersnick Jul 22, 2024
6931aa2
set the seed for repro. and change triton loading logic
Jul 22, 2024
e2b8901
update compare_against
Jul 23, 2024
c88086b
update tester
lubbersnick Jul 23, 2024
298b258
small tweaks to testing arguments
Jul 24, 2024
5df6edd
integrate triton kernels to options, apply formatter
lubbersnick Jul 24, 2024
ecf706b
adjust configparser
Jul 24, 2024
4f9b389
try refactor triton code (cosmetic)
lubbersnick Jul 25, 2024
ab9d47a
formatter
lubbersnick Jul 25, 2024
dec9fdd
add ultra-size test
Jul 25, 2024
1ecd758
remove explicit numba dependency from custom kernel tests
lubbersnick Jul 25, 2024
4aab176
update triton to use numba on CPU
lubbersnick Jul 25, 2024
fc7cb0a
more formatting and name changes
lubbersnick Jul 25, 2024
d1c4e37
fix lack of forward correctness checks!
lubbersnick Jul 25, 2024
bd5c7df
actually do what the last commit says
lubbersnick Jul 25, 2024
286b379
update for todos
lubbersnick Jul 25, 2024
4265297
split feat. and sense vectors to chunks to lower register pressure, a…
Jul 26, 2024
2a7fa6d
raise better errors in old custom kernels. partially update documenta…
lubbersnick Jul 27, 2024
fa5845b
update documentation for custom kernels
lubbersnick Jul 29, 2024
aa2c833
update requirements/docs
lubbersnick Jul 29, 2024
1 change: 1 addition & 0 deletions AUTHORS.txt
Expand Up @@ -18,6 +18,7 @@ Sakib Matin (LANL)
Emily Shinkle (LANL)
Michael G. Taylor (LANL)
Jan Janssen (LANL)
Cagri Kaymak (LANL)

Also thanks to testing and feedback from:

Expand Down
17 changes: 17 additions & 0 deletions CHANGELOG.rst
@@ -1,3 +1,20 @@


Breaking changes:
-----------------

New Features:
-------------

- Added a new custom CUDA kernel implementation using triton. These kernels are highly performant and are now the default implementation.

Improvements:
-------------


Bug Fixes:
----------

0.0.3
=======

Expand Down
3 changes: 2 additions & 1 deletion conda_requirements.txt
@@ -1,5 +1,6 @@
numpy
pytorch >= 1.6
pytorch >= 1.9
torchtriton
matplotlib
numba
cupy
Expand Down
10 changes: 6 additions & 4 deletions docs/source/installation.rst
Expand Up @@ -11,20 +11,22 @@ Requirements:
* pytorch_ >= 1.9
* numpy_
Optional Dependencies:
* numba_ (recommended, for accelerating performance)
* cupy_ (also for accelerating performance)
* triton_ (recommended, for improved GPU performance)
* numba_ (recommended for improved CPU performance)
* cupy_ (Alternative for accelerating GPU performance)
* ASE_ (for usage with ase)
* matplotlib_ (for plotting)
* tqdm_ (for progress bars)
* graphviz_ (for viewing model graphs as figures)
* h5py_ (for ani-h5 datasets)
* pyanitools_ (for ani-h5 datasets)
* h5py_ (for loading ani-h5 datasets)
* pyanitools_ (for loading ani-h5 datasets)

Interfacing codes:
* ASE_
* PYSEQM_
* LAMMPS_

.. _triton: https://triton-lang.org/
.. _numpy: https://numpy.org/
.. _Python: http://www.python.org
.. _pytorch: http://www.pytorch.org
Expand Down
47 changes: 39 additions & 8 deletions docs/source/user_guide/ckernels.rst
@@ -1,9 +1,22 @@
Custom Kernels
==============


Bottom line up front
--------------------

We use custom kernels in `hippynn` to accelerate the HIP-NN neural network message passing.
On the GPU, the best implementation to select is ``triton``, followed by ``cupy``,
followed by ``numba``. On the CPU, only ``numba`` is available. In general, these
custom kernels are very useful, and the only reasons to turn them off are if the packages
are not available for installation in your environment, or to diagnose whether a bug could
be related to potential misconfiguration of these additional packages.
``triton`` comes with recent versions of ``pytorch``, so optimistically you may already be
configured to use the custom kernels.
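
As a quick check of what your environment supports, you can inspect the module-level
flags in :mod:`~hippynn.custom_kernels` (a minimal sketch; the attribute names
``CUSTOM_KERNELS_AVAILABLE`` and ``CUSTOM_KERNELS_ACTIVE`` are taken from the
``hippynn/custom_kernels/__init__.py`` changes in this PR)::

    import hippynn.custom_kernels as custom_kernels

    # Implementations detected at import time, e.g. ["numba", "cupy", "triton"]
    print(custom_kernels.CUSTOM_KERNELS_AVAILABLE)

    # Whether a custom (non-pytorch) implementation is currently active
    print(custom_kernels.CUSTOM_KERNELS_ACTIVE)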

Detailed Explanation
--------------------
Analogs of convolutional layers that apply to continuously variable points in space, such as the
`HIP-NN` interaction layer, can be awkward to write in pure-pytorch.
HIP-NN interaction layer, can be awkward to write in pure-pytorch.

The :mod:`~hippynn.custom_kernels` subpackage implements some more efficient kernels for both the forward
and backward pass of the sum over neighbors. This is implemented, more or less, as a CSR-type
Expand All @@ -12,8 +25,8 @@ mixture of inner products and outer products on the remaining "feature" and "sen
This behavior can be switched off (and is off by default if the dependencies are not installed)
to revert to a pure pytorch implementation.

The custom kernels provide `much` better memory footprint than the pure pytorch implementation,
and a decent amount of speedup on those core operations. The memory footprint of the pytorch
The custom kernels provide a *much* better memory footprint than the pure pytorch implementation,
and a very good amount of speedup on those core operations. The memory footprint of the pytorch
implementation is approximately:

.. math::
Expand All @@ -27,19 +40,33 @@ whereas the memory footprint of the custom kernels is approximately
O(N_\mathrm{pairs}N_\mathrm{sensitivities} +
N_\mathrm{atoms}N_\mathrm{features}N_\mathrm{sensitivities}).
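
As a rough illustration of the difference, here is a back-of-the-envelope sketch; the
pure-pytorch estimate assumes the
:math:`O(N_\mathrm{pairs}N_\mathrm{sensitivities}N_\mathrm{features})` scaling implied
above (that equation is collapsed in this diff), and the sizes are arbitrary examples::

    # Rough element counts for a medium-sized batch
    n_pairs, n_atoms = 200_000, 5_000
    n_sensitivities, n_features = 20, 128

    pure_pytorch = n_pairs * n_sensitivities * n_features
    custom = n_pairs * n_sensitivities + n_atoms * n_sensitivities * n_features

    # ~512M elements vs ~16.8M elements in this example, roughly a 30x reduction
    print(pure_pytorch / custom)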

The custom kernels are implemented using ``numba`` and/or ``cupy``, depending
The custom kernels are implemented using ``triton``, ``cupy`` and/or ``numba``, depending
on what is installed in your python environment.
However, there are certain overheads in using them.
In particular, if you are using a GPU and your batch size is small,
the pytorch implementations may actually be faster, because they launch more quickly.
This is especially true if you use a shallower model (one interaction layer) with
This is especially true if you use a shallow HIP-NN type model (one interaction layer)
with a small number of elements, because the memory waste in a pure pytorch
implementation is proportional to the number of input features.
If you are using a CPU, the custom kernels are recommended at all times.
Nonetheless for most practical purposes, keeping custom kernels
on at all times is computationally recommended.
If you are using a CPU, the custom kernels are provided only using ``numba``, but they
do not come with any large overheads, and so provide computational benefits at all times.
The only reason to turn custom kernels off, in general, is to diagnose whether there are
issues with how they are being deployed; if ``numba`` or ``cupy`` is not correctly installed,
then we have found that sometimes the kernels may silently fail.

The three custom kernels correspond to the interaction sum in hip-nn:

.. math::

    a'_{i,a} = \sum_{\nu,b} V^\nu_{a,b} e^{\nu}_{i,b}

    e^{\nu}_{i,a} = \sum_p s^\nu_{p} z_{p_j,a}

Where :math:`a'` is the pre-activation for an interaction layer using input features :math:`z`.

For envsum, sensesum, featsum:

.. math::

Expand All @@ -49,6 +76,10 @@ For envsum, sensum, featsum:

    f_{j,a} = \sum_{\nu,p} e_{p_i,\nu,a} s_{p,\nu}

These three functions form a closed system under automatic differentiation, and are linked to each
other in pytorch's autograd, thereby supporting custom kernels in backwards passes and in
double-backwards passes associated with Force training or similar features.
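
The double-backward pattern these kernels must support can be seen in a generic
force-training loop. This is a schematic pytorch-only sketch; the tiny ``model``,
``positions``, and ``true_forces`` below are hypothetical stand-ins, not part of the
hippynn API::

    import torch

    # Hypothetical stand-ins: any differentiable energy model and reference forces would do.
    model = torch.nn.Sequential(torch.nn.Linear(3, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1))
    positions = torch.rand(10, 3, requires_grad=True)
    true_forces = torch.zeros(10, 3)

    energy = model(positions).sum()

    # First backward: forces as the negative gradient of the energy w.r.t. positions.
    # create_graph=True keeps this graph so the force loss can be differentiated again.
    forces = -torch.autograd.grad(energy, positions, create_graph=True)[0]

    # Second backward: gradient of the force loss w.r.t. model parameters (double-backward).
    loss = torch.nn.functional.mse_loss(forces, true_forces)
    loss.backward()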

Custom kernels can be set ahead of time using :doc:`/user_guide/settings` and dynamically
using :func:`~hippynn.custom_kernels.set_custom_kernels`.
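
For example, to explicitly request the triton implementation at runtime (a small sketch
based on the :func:`~hippynn.custom_kernels.set_custom_kernels` signature shown later in
this PR)::

    from hippynn.custom_kernels import set_custom_kernels

    set_custom_kernels("triton")  # or "cupy", "numba", "pytorch", "auto", True, False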

4 changes: 2 additions & 2 deletions docs/source/user_guide/settings.rst
Expand Up @@ -45,8 +45,8 @@ The following settings are available:
- false
- Yes
* - USE_CUSTOM_KERNELS
- Use custom kernels with numba or cupy. Auto tries to detect the installation of numba or cupy. For more info see :doc:`/user_guide/ckernels`.
- auto, true, false, pytorch, numba, cupy
- Use custom kernels with triton, numba or cupy. Auto tries to detect the installation. For more info see :doc:`/user_guide/ckernels`.
- auto, true, false, pytorch, numba, cupy, triton
- auto
- Not directly, use :func:`~hippynn.custom_kernels.set_custom_kernels`
* - WARN_LOW_DISTANCES
Expand Down
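
A minimal way to pick the implementation ahead of time from Python (assuming the
``HIPPYNN_``-prefixed environment-variable convention used by hippynn's settings module;
the variable must be set before ``hippynn`` is first imported)::

    import os

    os.environ["HIPPYNN_USE_CUSTOM_KERNELS"] = "triton"  # or "auto", "pytorch", "numba", "cupy"

    import hippynn  # settings are read at import time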
4 changes: 2 additions & 2 deletions hippynn/_settings_setup.py
Expand Up @@ -54,7 +54,7 @@ def kernel_handler(kernel_string):
"true": True,
}.get(kernel_string, kernel_string)

if kernel not in [True, False, "auto", "cupy", "numba"]:
if kernel not in [True, False, "auto", "triton", "cupy", "numba"]:
warnings.warn(f"Unrecognized custom kernel option: {kernel_string}. Setting custom kernels to 'auto'")
kernel = "auto"

Expand Down Expand Up @@ -86,7 +86,7 @@ def kernel_handler(kernel_string):

rc_name = os.path.expanduser("~/.hippynnrc")
if os.path.exists(rc_name) and os.path.isfile(rc_name):
config = configparser.ConfigParser()
config = configparser.ConfigParser(inline_comment_prefixes="#")
config.read(rc_name)
config_sources["~/.hippynnrc"] = config["GLOBALS"]

Expand Down
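
The ``inline_comment_prefixes="#"`` change means an rc file like the following now parses
cleanly (a hypothetical example; only the ``[GLOBALS]`` section name is taken from the code
above, and the setting names follow the settings table in this PR)::

    [GLOBALS]
    USE_CUSTOM_KERNELS = triton   # inline comments like this are now stripped
    WARN_LOW_DISTANCES = true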
30 changes: 24 additions & 6 deletions hippynn/custom_kernels/__init__.py
Expand Up @@ -34,8 +34,16 @@
except ImportError:
pass

try:
import triton

CUSTOM_KERNELS_AVAILABLE.append("triton")
except ImportError:
pass

if not CUSTOM_KERNELS_AVAILABLE:
warnings.warn("Numba or cupy not available: Custom Kernels will be disabled.")
warnings.warn(
"Triton, cupy and numba are not available: Custom kernels will be disabled and performance may be degraded.")

CUSTOM_KERNELS_ACTIVE = False

Expand Down Expand Up @@ -75,46 +83,56 @@ def set_custom_kernels(active: Union[bool, str] = True):
Activate or deactivate custom kernels for interaction.

:param active: If true, set custom kernels to the best available. If False, turn them off and default to pytorch.
If "numba" or "cupy", use those implementations explicitly. If "auto", use best available.
If "triton", "numba" or "cupy", use those implementations explicitly. If "auto", use best available.
:return: None
"""
global envsum, sensesum, featsum, CUSTOM_KERNELS_ACTIVE

if isinstance(active, str):
active = active.lower()

if active not in [True, False, "numba", "cupy", "pytorch", "auto"]:
if active not in [True, False, "triton", "numba", "cupy", "pytorch", "auto"]:
raise ValueError(f"Unrecognized custom kernel implementation: {active}")

active_map = {"auto": True, "pytorch": False}
if not CUSTOM_KERNELS_AVAILABLE:
if active == "auto" or active == "pytorch":
active = False
elif active:
raise RuntimeError("Numba or cupy was not found. Custom kernels are not available.")
raise RuntimeError(
"Triton, numba and cupy were not found. Custom kernels are not available, but they were required by library settings.")
else:
active = active_map.get(active, active)

# Handle fallback to pytorch kernels.
if not active:
envsum = env_pytorch.envsum
sensesum = env_pytorch.sensesum
featsum = env_pytorch.featsum
CUSTOM_KERNELS_ACTIVE = False
return

# Select custom kernel implementation

if not CUSTOM_KERNELS_AVAILABLE:
raise RuntimeError("Numba was not found. Custom kernels are not available.")

if active is True:
if "cupy" in CUSTOM_KERNELS_AVAILABLE:
if "triton" in CUSTOM_KERNELS_AVAILABLE:
active = "triton"
elif "cupy" in CUSTOM_KERNELS_AVAILABLE:
active = "cupy"
else:
active = "numba"

if active not in CUSTOM_KERNELS_AVAILABLE:
raise RuntimeError(f"Unavailable custom kernel implementation: {active}")

if active == "cupy":
if active == "triton":
from .env_triton import envsum as triton_envsum, sensesum as triton_sensesum, featsum as triton_featsum

envsum, sensesum, featsum = autograd_wrapper.wrap_envops(triton_envsum, triton_sensesum, triton_featsum)
elif active == "cupy":
_check_numba()
_check_cupy()
from .env_cupy import cupy_envsum, cupy_featsum, cupy_sensesum
Expand Down
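
The selection order for ``active=True`` (or ``"auto"`` with at least one backend installed)
reduces to a simple priority list. The helper below is a standalone sketch of that same
logic, not a function from the library::

    def pick_implementation(available):
        """Mirror of the priority used above: triton > cupy > numba."""
        for choice in ("triton", "cupy", "numba"):
            if choice in available:
                return choice
        return "pytorch"  # fall back to the pure-pytorch kernels

    assert pick_implementation(["numba", "triton"]) == "triton"
    assert pick_implementation([]) == "pytorch"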
12 changes: 6 additions & 6 deletions hippynn/custom_kernels/autograd_wrapper.py
Expand Up @@ -19,9 +19,9 @@ def forward(ctx, sense, feat, pfirst, psecond):
if pfirst.shape[0] == 0:
n_pair, n_nu = sense.shape
n_atom, n_feat = feat.shape
if n_pair!=0 or psecond.shape[0]!=0:
if n_pair != 0 or psecond.shape[0] != 0:
raise ValueError("Inconsistent shapes for envsum.")
return torch.zeros((n_atom,n_nu,n_feat),dtype=feat.dtype,device=feat.device)
return torch.zeros((n_atom, n_nu, n_feat), dtype=feat.dtype, device=feat.device)
env = envsum_impl(sense, feat, pfirst, psecond)
return env

Expand Down Expand Up @@ -49,9 +49,9 @@ def forward(ctx, env, feat, pfirst, psecond):
if pfirst.shape[0] == 0:
n_atom0, n_nu, n_feat0 = env.shape
n_atom1, n_feat1 = feat.shape
if psecond.shape[0] !=0 or n_atom0!=n_atom1 or n_feat0 != n_feat1:
if psecond.shape[0] != 0 or n_atom0 != n_atom1 or n_feat0 != n_feat1:
raise ValueError("Inconsistent shapes for sensesum")
return torch.zeros((0,n_nu),dtype=feat.dtype,device=feat.device)
return torch.zeros((0, n_nu), dtype=feat.dtype, device=feat.device)
sense = sensesum_impl(env, feat, pfirst, psecond)
return sense

Expand All @@ -72,9 +72,9 @@ def forward(ctx, env, sense, pfirst, psecond):
if pfirst.shape[0] == 0:
n_atom, n_nu0, n_feat = env.shape
n_pair, n_nu1 = sense.shape
if psecond.shape[0] !=0 or n_nu0!=n_nu1:
if psecond.shape[0] != 0 or n_nu0 != n_nu1:
raise ValueError("Inconsistent shapes for featsum")
return torch.zeros((n_atom,n_feat),dtype=env.dtype,device=env.device)
return torch.zeros((n_atom, n_feat), dtype=env.dtype, device=env.device)
feat = featsum_impl(env, sense, pfirst, psecond)
return feat

Expand Down
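
The zero-pair early returns above also document the expected shapes. Below is a small shape
check against the pure-pytorch reference, assuming the ``env_pytorch`` module referenced in
``custom_kernels/__init__.py`` is importable and that ``envsum(sense, feat, pair_first,
pair_second)`` matches the wrapped signature shown above::

    import torch
    from hippynn.custom_kernels import env_pytorch

    n_pair, n_nu, n_atom, n_feat = 50, 20, 10, 8
    sense = torch.rand(n_pair, n_nu)
    feat = torch.rand(n_atom, n_feat)
    pair_first = torch.randint(0, n_atom, (n_pair,))
    pair_second = torch.randint(0, n_atom, (n_pair,))

    env = env_pytorch.envsum(sense, feat, pair_first, pair_second)
    assert env.shape == (n_atom, n_nu, n_feat)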