From 1ed39dd02708d2aa26c1a4a07c69db631d7017b9 Mon Sep 17 00:00:00 2001 From: Mehmet Cagri Kaymak Date: Sun, 21 Jul 2024 22:41:59 -0600 Subject: [PATCH 01/24] add custom triton kernels --- hippynn/custom_kernels/env_triton.py | 200 ++++++++++++++++++++++ hippynn/custom_kernels/test_env_triton.py | 12 ++ 2 files changed, 212 insertions(+) create mode 100644 hippynn/custom_kernels/env_triton.py create mode 100644 hippynn/custom_kernels/test_env_triton.py diff --git a/hippynn/custom_kernels/env_triton.py b/hippynn/custom_kernels/env_triton.py new file mode 100644 index 00000000..ef9aefc6 --- /dev/null +++ b/hippynn/custom_kernels/env_triton.py @@ -0,0 +1,200 @@ +import torch +import triton +import triton.language as tl +from .utils import resort_pairs_cached + +@triton.jit +def envsum_kernel(out_env_ptr, + sens_ptr, + feat_ptr, + psecond_ptr, + atom_ids_ptr, + atom_starts_ptr, + atom_size, + sens_size: tl.constexpr, + feat_size: tl.constexpr, + p2_sens_size: tl.constexpr, + p2_feat_size: tl.constexpr, + dtype: tl.constexpr = tl.float32): + atom_id = tl.program_id(axis=0) + start = tl.load(atom_starts_ptr + atom_id, mask=atom_id < atom_size, other=0) + end = tl.load(atom_starts_ptr + atom_id + 1, mask=atom_id < atom_size, other=0) + target_id = tl.load(atom_ids_ptr + atom_id, mask=atom_id < atom_size, other=0) + sens_block_ids = tl.arange(0, p2_sens_size) + feat_block_ids = tl.arange(0, p2_feat_size) + tmp = tl.zeros((p2_sens_size, p2_feat_size), dtype=dtype) + for ind in range(start, end): + # [p2_sens_size,], coming from the pair sensitivity + s = tl.load(sens_ptr + (ind * sens_size) + sens_block_ids, + mask=sens_block_ids < sens_size, other=0.0) + pair_ind = tl.load(psecond_ptr + ind) # TODO do we need mask here + # [p2_feat_size,], coming from the neighbor feature + feat = tl.load(feat_ptr + (pair_ind * feat_size) + feat_block_ids, + mask=feat_block_ids < feat_size, other=0.0) + # temp_mat and tmp is [p2_sens_size, p2_feat_size] + temp_mat = s[:, None] * feat[None, :] + tmp = tmp + temp_mat + mask = (sens_block_ids[:, None] < sens_size) & (feat_block_ids[None, :] < feat_size) + block_ids = sens_block_ids[:, None] * feat_size + feat_block_ids[None, :] + tl.store(out_env_ptr + (target_id * sens_size * feat_size) + block_ids, tmp, mask=mask) + +def envsum_triton(sensitivities, features, pair_first, pair_second, atom_ids, atom_starts, out_env_fetures=None): + n_pairs, n_nu = sensitivities.shape + n_atom, n_feat = features.shape + (n_atom_with_pairs,) = atom_ids.shape + if out_env_fetures == None: + out_env_fetures = torch.zeros((n_atom, n_nu, n_feat), dtype=features.dtype, device=features.device) + dtype = tl.float32 + if features.dtype == torch.float64: + dtype = tl.float64 + p2_sens_size = triton.next_power_of_2(n_nu) + p2_feat_size = triton.next_power_of_2(n_feat) + envsum_kernel[(n_atom_with_pairs,)]( + out_env_fetures, + sensitivities, + features, + pair_second, + atom_ids, + atom_starts, + n_atom_with_pairs, + n_nu, + n_feat, + p2_sens_size, + p2_feat_size, + dtype=dtype) + return out_env_fetures + +def envsum(sense, features, pfirst, psecond): + psecond_hold = psecond + argsort, atom1_ids, atom1_starts, pfirst, (sense, psecond) = resort_pairs_cached(pfirst, [sense, psecond]) + resort_pairs_cached(psecond_hold, []) + return envsum_triton(sense, features, pfirst, psecond, atom1_ids, atom1_starts, out_env_fetures=None) + +@triton.jit +def sensesum_kernel(out_sense_ptr, + env_ptr, + feat_ptr, + pfirst_ptr, + psecond_ptr, + pair_size, + sens_size: tl.constexpr, + feat_size: tl.constexpr, + 
p2_sens_size: tl.constexpr, + p2_feat_size: tl.constexpr, + dtype: tl.constexpr = tl.float32): + pair_id = tl.program_id(axis=0) + first = tl.load(pfirst_ptr + pair_id, mask=pair_id < pair_size, other=0) + second = tl.load(psecond_ptr + pair_id, mask=pair_id < pair_size, other=0) + sens_block_ids = tl.arange(0, p2_sens_size) + feat_block_ids = tl.arange(0, p2_feat_size) + mask = (sens_block_ids[:, None] < sens_size) & (feat_block_ids[None, :] < feat_size) + block_ids = sens_block_ids[:, None] * feat_size + feat_block_ids[None, :] + # [p2_sens_size, p2_feat_size] + env = tl.load(env_ptr + (first * sens_size * feat_size) + block_ids, mask=mask) + # [p2_feat_size, ] + feat = tl.load(feat_ptr + (second * feat_size) + feat_block_ids, + mask=feat_block_ids < feat_size, other=0.0) + ''' + type_f32: tl.constexpr = tl.float32 + type_check: tl.constexpr = (dtype == type_f32) + if type_check: + res = tl.dot(env, feat[:, None]) + else: + res = tl.sum(env * feat[None, :], axis=1) + ''' + res = tl.sum(env * feat[None, :], axis=1) + tl.store(out_sense_ptr + (pair_id * sens_size) + sens_block_ids, res, + mask=sens_block_ids < sens_size) + +def sensesum(env, features, pair_first, pair_second, out_sense=None): + _, n_nu, _ = env.shape + n_atom, n_feat = features.shape + n_pairs = len(pair_first) + if out_sense == None: + out_sense = torch.zeros((n_pairs, n_nu), dtype=features.dtype, device=features.device) + dtype = tl.float32 + if features.dtype == torch.float64: + dtype = tl.float64 + p2_sens_size = triton.next_power_of_2(n_nu) + p2_feat_size = triton.next_power_of_2(n_feat) + sensesum_kernel[(n_pairs,)]( + out_sense, + env, + features, + pair_first, + pair_second, + n_pairs, + n_nu, + n_feat, + p2_sens_size, + p2_feat_size, + dtype=dtype) + return out_sense + +@triton.jit +def featsum_kernel(out_feat, + env_ptr, + sens_ptr, + pfirst_ptr, + psecond_ptr, + atom2_ids_ptr, + atom2_starts_ptr, + atom_size, + sens_size: tl.constexpr, + feat_size: tl.constexpr, + p2_sens_size: tl.constexpr, + p2_feat_size: tl.constexpr, + dtype: tl.constexpr = tl.float32): + atom_id = tl.program_id(axis=0) + start = tl.load(atom2_starts_ptr + atom_id, mask=atom_id < atom_size, other=0) + end = tl.load(atom2_starts_ptr + atom_id + 1, mask=atom_id < atom_size, other=0) + target_id = tl.load(atom2_ids_ptr + atom_id, mask=atom_id < atom_size, other=0) + sens_block_ids = tl.arange(0, p2_sens_size) + feat_block_ids = tl.arange(0, p2_feat_size) + tmp = tl.zeros((p2_feat_size,), dtype=dtype) + for ind in range(start, end): + # [p2_sens_size,], coming from the pair sensitivity + sense = tl.load(sens_ptr + (ind * sens_size) + sens_block_ids, + mask=sens_block_ids < sens_size, other=0.0) + pair_ind = tl.load(pfirst_ptr + ind) # TODO do we need mask here + mask = (sens_block_ids[:, None] < sens_size) & (feat_block_ids[None, :] < feat_size) + block_ids = sens_block_ids[:, None] * feat_size + feat_block_ids[None, :] + # [p2_sens_size, p2_feat_size] + env = tl.load(env_ptr + (pair_ind * sens_size * feat_size) + block_ids, mask=mask) + # temp_mat and tmp is [p2_feat_size,] + temp_mat = tl.sum(env * sense[:, None], axis=0) + tmp = tmp + temp_mat + tl.store(out_feat + (target_id * feat_size) + feat_block_ids, tmp, mask=feat_block_ids < feat_size) + +def featsum_triton(env, sense, pair_first, pair_second, atom2_ids, atom2_starts, out_feat=None): + n_atom, n_nu, n_feat = env.shape + (n_pairs,) = pair_first.shape + (n_atoms_with_pairs,) = atom2_ids.shape + if out_feat == None: + out_feat = torch.zeros((n_atom, n_feat), dtype=env.dtype, 
device=env.device) + dtype = tl.float32 + if env.dtype == torch.float64: + dtype = tl.float64 + p2_sens_size = triton.next_power_of_2(n_nu) + p2_feat_size = triton.next_power_of_2(n_feat) + featsum_kernel[(n_atoms_with_pairs,)]( + out_feat, + env, + sense, + pair_first, + pair_second, + atom2_ids, + atom2_starts, + n_atoms_with_pairs, + n_nu, + n_feat, + p2_sens_size, + p2_feat_size, + dtype=dtype) + return out_feat + +def featsum(env, sense, pfirst, psecond): + pfirst_hold = pfirst + argsort, atom2_ids, atom2_starts, psecond, (sense, pfirst) = resort_pairs_cached(psecond, [sense, pfirst]) + resort_pairs_cached(pfirst_hold, []) + return featsum_triton(env, sense, pfirst, psecond, atom2_ids, atom2_starts, out_feat=None) diff --git a/hippynn/custom_kernels/test_env_triton.py b/hippynn/custom_kernels/test_env_triton.py new file mode 100644 index 00000000..4239d2e0 --- /dev/null +++ b/hippynn/custom_kernels/test_env_triton.py @@ -0,0 +1,12 @@ + +import torch +from .env_triton import envsum, sensesum, featsum +from .test_env_numba import Envops_tester, main, get_simulated_data +from .test_env_numba import TEST_MEGA_PARAMS, TEST_LARGE_PARAMS, TEST_MEDIUM_PARAMS, TEST_SMALL_PARAMS +from .utils import resort_pairs_cached +if __name__ == "__main__": + + main(envsum, + sensesum, + featsum, + ) \ No newline at end of file From b62831032ed070b0c7292b3c79f04e8c3c6efa93 Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers Date: Mon, 22 Jul 2024 14:28:47 -0600 Subject: [PATCH 02/24] add CPU compatibility --- hippynn/custom_kernels/env_triton.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/hippynn/custom_kernels/env_triton.py b/hippynn/custom_kernels/env_triton.py index ef9aefc6..d89f63e6 100644 --- a/hippynn/custom_kernels/env_triton.py +++ b/hippynn/custom_kernels/env_triton.py @@ -3,6 +3,8 @@ import triton.language as tl from .utils import resort_pairs_cached +from .env_pytorch import envsum as envsum_pt, sensesum as sensesum_pt, featsum as featsum_pt + @triton.jit def envsum_kernel(out_env_ptr, sens_ptr, @@ -39,6 +41,9 @@ def envsum_kernel(out_env_ptr, tl.store(out_env_ptr + (target_id * sens_size * feat_size) + block_ids, tmp, mask=mask) def envsum_triton(sensitivities, features, pair_first, pair_second, atom_ids, atom_starts, out_env_fetures=None): + if sensitivities.device == torch.device('cpu'): + return featsum_pt(sensitivities,features,pair_first,pair_second) + n_pairs, n_nu = sensitivities.shape n_atom, n_feat = features.shape (n_atom_with_pairs,) = atom_ids.shape @@ -107,6 +112,9 @@ def sensesum_kernel(out_sense_ptr, mask=sens_block_ids < sens_size) def sensesum(env, features, pair_first, pair_second, out_sense=None): + if env.device == torch.device('cpu'): + return featsum_pt(env,features,pair_first,pair_second) + _, n_nu, _ = env.shape n_atom, n_feat = features.shape n_pairs = len(pair_first) @@ -167,6 +175,8 @@ def featsum_kernel(out_feat, tl.store(out_feat + (target_id * feat_size) + feat_block_ids, tmp, mask=feat_block_ids < feat_size) def featsum_triton(env, sense, pair_first, pair_second, atom2_ids, atom2_starts, out_feat=None): + if env.device == torch.device('cpu'): + return featsum_pt(env,sense,pair_first,pair_second) n_atom, n_nu, n_feat = env.shape (n_pairs,) = pair_first.shape (n_atoms_with_pairs,) = atom2_ids.shape From 727c502ed2755ebccd9d5df333d0c2227ba46c62 Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers Date: Mon, 22 Jul 2024 14:32:02 -0600 Subject: [PATCH 03/24] move call site for CPU version --- hippynn/custom_kernels/env_triton.py | 10 ++++------ 1 file 
changed, 4 insertions(+), 6 deletions(-) diff --git a/hippynn/custom_kernels/env_triton.py b/hippynn/custom_kernels/env_triton.py index d89f63e6..e7090fc9 100644 --- a/hippynn/custom_kernels/env_triton.py +++ b/hippynn/custom_kernels/env_triton.py @@ -41,9 +41,6 @@ def envsum_kernel(out_env_ptr, tl.store(out_env_ptr + (target_id * sens_size * feat_size) + block_ids, tmp, mask=mask) def envsum_triton(sensitivities, features, pair_first, pair_second, atom_ids, atom_starts, out_env_fetures=None): - if sensitivities.device == torch.device('cpu'): - return featsum_pt(sensitivities,features,pair_first,pair_second) - n_pairs, n_nu = sensitivities.shape n_atom, n_feat = features.shape (n_atom_with_pairs,) = atom_ids.shape @@ -70,6 +67,8 @@ def envsum_triton(sensitivities, features, pair_first, pair_second, atom_ids, at return out_env_fetures def envsum(sense, features, pfirst, psecond): + if sense.device == torch.device('cpu'): + return featsum_pt(sense,features,pfirst,psecond) psecond_hold = psecond argsort, atom1_ids, atom1_starts, pfirst, (sense, psecond) = resort_pairs_cached(pfirst, [sense, psecond]) resort_pairs_cached(psecond_hold, []) @@ -114,7 +113,6 @@ def sensesum_kernel(out_sense_ptr, def sensesum(env, features, pair_first, pair_second, out_sense=None): if env.device == torch.device('cpu'): return featsum_pt(env,features,pair_first,pair_second) - _, n_nu, _ = env.shape n_atom, n_feat = features.shape n_pairs = len(pair_first) @@ -175,8 +173,6 @@ def featsum_kernel(out_feat, tl.store(out_feat + (target_id * feat_size) + feat_block_ids, tmp, mask=feat_block_ids < feat_size) def featsum_triton(env, sense, pair_first, pair_second, atom2_ids, atom2_starts, out_feat=None): - if env.device == torch.device('cpu'): - return featsum_pt(env,sense,pair_first,pair_second) n_atom, n_nu, n_feat = env.shape (n_pairs,) = pair_first.shape (n_atoms_with_pairs,) = atom2_ids.shape @@ -204,6 +200,8 @@ def featsum_triton(env, sense, pair_first, pair_second, atom2_ids, atom2_starts, return out_feat def featsum(env, sense, pfirst, psecond): + if env.device == torch.device('cpu'): + return featsum_pt(env,sense,pfirst,psecond) pfirst_hold = pfirst argsort, atom2_ids, atom2_starts, psecond, (sense, pfirst) = resort_pairs_cached(psecond, [sense, pfirst]) resort_pairs_cached(pfirst_hold, []) From 934413f285c6359fe82a0182c9579f3b45380187 Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers Date: Mon, 22 Jul 2024 14:33:57 -0600 Subject: [PATCH 04/24] call the right functions --- hippynn/custom_kernels/env_triton.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hippynn/custom_kernels/env_triton.py b/hippynn/custom_kernels/env_triton.py index e7090fc9..df33d79f 100644 --- a/hippynn/custom_kernels/env_triton.py +++ b/hippynn/custom_kernels/env_triton.py @@ -68,7 +68,7 @@ def envsum_triton(sensitivities, features, pair_first, pair_second, atom_ids, at def envsum(sense, features, pfirst, psecond): if sense.device == torch.device('cpu'): - return featsum_pt(sense,features,pfirst,psecond) + return envsum_pt(sense,features,pfirst,psecond) psecond_hold = psecond argsort, atom1_ids, atom1_starts, pfirst, (sense, psecond) = resort_pairs_cached(pfirst, [sense, psecond]) resort_pairs_cached(psecond_hold, []) @@ -112,7 +112,7 @@ def sensesum_kernel(out_sense_ptr, def sensesum(env, features, pair_first, pair_second, out_sense=None): if env.device == torch.device('cpu'): - return featsum_pt(env,features,pair_first,pair_second) + return sensesum_pt(env,features,pair_first,pair_second) _, n_nu, _ = env.shape n_atom, 
n_feat = features.shape n_pairs = len(pair_first) From 28619e0e6be162ce109105568c26ae927d95f879 Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers Date: Mon, 22 Jul 2024 14:36:06 -0600 Subject: [PATCH 05/24] fix GPU memory limits for tests --- hippynn/custom_kernels/test_env_numba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hippynn/custom_kernels/test_env_numba.py b/hippynn/custom_kernels/test_env_numba.py index f798560d..364e0c71 100644 --- a/hippynn/custom_kernels/test_env_numba.py +++ b/hippynn/custom_kernels/test_env_numba.py @@ -370,7 +370,7 @@ def main(env_impl,sense_impl,feat_impl): print("Running GPU tests") meminfo = numba.cuda.current_context().get_memory_info() use_large_gpu = meminfo.free > 2 ** 31 - use_verylarge_gpu = meminfo.free > 2**3 + use_verylarge_gpu = meminfo.free > 2**35 n_large = 3 if use_large_gpu else 0 tester.check_correctness(device=torch.device("cuda"),n_large=n_large) From 6931aa226d17c9f64fa83bc37279adc530d01b37 Mon Sep 17 00:00:00 2001 From: Mehmet Cagri Kaymak Date: Mon, 22 Jul 2024 15:57:27 -0600 Subject: [PATCH 06/24] set the seed for repro. and change triton loading logic --- hippynn/custom_kernels/env_triton.py | 4 ++-- hippynn/custom_kernels/test_env_numba.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/hippynn/custom_kernels/env_triton.py b/hippynn/custom_kernels/env_triton.py index df33d79f..e1173d86 100644 --- a/hippynn/custom_kernels/env_triton.py +++ b/hippynn/custom_kernels/env_triton.py @@ -94,7 +94,7 @@ def sensesum_kernel(out_sense_ptr, mask = (sens_block_ids[:, None] < sens_size) & (feat_block_ids[None, :] < feat_size) block_ids = sens_block_ids[:, None] * feat_size + feat_block_ids[None, :] # [p2_sens_size, p2_feat_size] - env = tl.load(env_ptr + (first * sens_size * feat_size) + block_ids, mask=mask) + env = tl.load(env_ptr + (first * sens_size * feat_size) + block_ids, mask=mask, other=0.0) # [p2_feat_size, ] feat = tl.load(feat_ptr + (second * feat_size) + feat_block_ids, mask=feat_block_ids < feat_size, other=0.0) @@ -166,7 +166,7 @@ def featsum_kernel(out_feat, mask = (sens_block_ids[:, None] < sens_size) & (feat_block_ids[None, :] < feat_size) block_ids = sens_block_ids[:, None] * feat_size + feat_block_ids[None, :] # [p2_sens_size, p2_feat_size] - env = tl.load(env_ptr + (pair_ind * sens_size * feat_size) + block_ids, mask=mask) + env = tl.load(env_ptr + (pair_ind * sens_size * feat_size) + block_ids, mask=mask, other=0.0) # temp_mat and tmp is [p2_feat_size,] temp_mat = tl.sum(env * sense[:, None], axis=0) tmp = tmp + temp_mat diff --git a/hippynn/custom_kernels/test_env_numba.py b/hippynn/custom_kernels/test_env_numba.py index 364e0c71..471171e3 100644 --- a/hippynn/custom_kernels/test_env_numba.py +++ b/hippynn/custom_kernels/test_env_numba.py @@ -8,7 +8,8 @@ from . import env_pytorch from . import autograd_wrapper from . 
import env_numba - +# set the seed for reproducibility +np.random.seed(0) def get_simulated_data( n_molecules, n_atoms, atom_prob, n_features, n_nu, printinfo=False, dtype=None, device=torch.device("cpu") From e2b8901548e8f5e526d5df864a044ec783712a08 Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers Date: Tue, 23 Jul 2024 13:24:02 -0600 Subject: [PATCH 07/24] update compare_against --- hippynn/custom_kernels/test_env_numba.py | 34 +++++++++++++++--------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/hippynn/custom_kernels/test_env_numba.py b/hippynn/custom_kernels/test_env_numba.py index 471171e3..02640179 100644 --- a/hippynn/custom_kernels/test_env_numba.py +++ b/hippynn/custom_kernels/test_env_numba.py @@ -8,8 +8,8 @@ from . import env_pytorch from . import autograd_wrapper from . import env_numba -# set the seed for reproducibility -np.random.seed(0) +from . import env_cupy + def get_simulated_data( n_molecules, n_atoms, atom_prob, n_features, n_nu, printinfo=False, dtype=None, device=torch.device("cpu") @@ -243,7 +243,8 @@ def all_close_witherror(self, r1, r2): return np.max(np.abs(diff_arr / tol_arr)) # max violation of bounds def check_allclose(self, repeats=30, use_large=False, device=torch.device("cpu")): - for i in range(repeats): + from tqdm.auto import tqdm + for i in tqdm(range(repeats),leave=False): try: self.check_allclose_once(use_large, device=device) except Exception as ee: @@ -273,6 +274,10 @@ def check_speed( comp_envsum = env_numba.new_envsum comp_sensesum = env_numba.new_sensesum comp_featsum = env_numba.new_featsum + elif compare_against == "Cupy": + comp_envsum = env_cupy.cupy_envsum + comp_sensesum = env_cupy.cupy_sensesum + comp_featsum = env_cupy.cupy_featsum else: raise ValueError("Unknown implementation to comapre against:'{}'".format(compare_against)) @@ -360,6 +365,8 @@ def elapsed(self): def main(env_impl,sense_impl,feat_impl): + + np.random.seed(0) tester = Envops_tester( env_impl, sense_impl, @@ -371,31 +378,32 @@ def main(env_impl,sense_impl,feat_impl): print("Running GPU tests") meminfo = numba.cuda.current_context().get_memory_info() use_large_gpu = meminfo.free > 2 ** 31 - use_verylarge_gpu = meminfo.free > 2**35 - - n_large = 3 if use_large_gpu else 0 + use_verylarge_gpu = meminfo.free > 2**34 + + n_large = 10 if use_large_gpu else 0 tester.check_correctness(device=torch.device("cuda"),n_large=n_large) + compare_against = "Pytorch" if use_verylarge_gpu: print("-" * 80) print("Mega systems:", TEST_MEGA_PARAMS) - tester.check_speed(n_repetitions=20,data_size=TEST_MEGA_PARAMS, device=torch.device("cuda"), compare_against="Pytorch") + tester.check_speed(n_repetitions=20,data_size=TEST_MEGA_PARAMS, device=torch.device("cuda"), compare_against=compare_against) else: print("Numba indicates less than 32GB free GPU memory -- skipping mega system test") if use_large_gpu: print("-" * 80) - tester.check_speed(n_repetitions=20,data_size=TEST_LARGE_PARAMS, device=torch.device("cuda"), compare_against="Pytorch") + tester.check_speed(n_repetitions=20,data_size=TEST_LARGE_PARAMS, device=torch.device("cuda"), compare_against=compare_against) else: print("Numba indicates less than 2GB free GPU memory -- skipping large system test") print("-" * 80) print("Medium systems:", TEST_MEDIUM_PARAMS) tester.check_speed( - n_repetitions=100, data_size=TEST_MEDIUM_PARAMS, device=torch.device("cuda"), compare_against="Pytorch" + n_repetitions=100, data_size=TEST_MEDIUM_PARAMS, device=torch.device("cuda"), compare_against=compare_against ) print("-" * 80) 
print("Small systems:", TEST_SMALL_PARAMS) tester.check_speed( - n_repetitions=100, data_size=TEST_SMALL_PARAMS, device=torch.device("cuda"), compare_against="Pytorch" + n_repetitions=100, data_size=TEST_SMALL_PARAMS, device=torch.device("cuda"), compare_against=compare_against ) else: @@ -405,13 +413,13 @@ def main(env_impl,sense_impl,feat_impl): tester.check_correctness() print("-" * 80) print("Large systems:", TEST_LARGE_PARAMS) - tester.check_speed(n_repetitions=10, compare_against="Pytorch") + tester.check_speed(n_repetitions=10, compare_against=compare_against) print("-" * 80) print("Medium systems:", TEST_MEDIUM_PARAMS) - tester.check_speed(n_repetitions=100, data_size=TEST_MEDIUM_PARAMS, compare_against="Pytorch") + tester.check_speed(n_repetitions=100, data_size=TEST_MEDIUM_PARAMS, compare_against=compare_against) print("-" * 80) print("Small systems:", TEST_SMALL_PARAMS) - tester.check_speed(n_repetitions=100, compare_against="Pytorch", data_size=TEST_SMALL_PARAMS) + tester.check_speed(n_repetitions=100, compare_against=compare_against, data_size=TEST_SMALL_PARAMS) if __name__ == "__main__": From c88086bcccca5f09e427f1fbe1856cc5d663b47a Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers Date: Tue, 23 Jul 2024 14:07:42 -0600 Subject: [PATCH 08/24] update tester --- hippynn/custom_kernels/test_env_numba.py | 188 +++++++++++++++-------- hippynn/custom_kernels/utils.py | 3 + 2 files changed, 131 insertions(+), 60 deletions(-) diff --git a/hippynn/custom_kernels/test_env_numba.py b/hippynn/custom_kernels/test_env_numba.py index 02640179..b2d4a30a 100644 --- a/hippynn/custom_kernels/test_env_numba.py +++ b/hippynn/custom_kernels/test_env_numba.py @@ -7,13 +7,28 @@ from . import env_pytorch from . import autograd_wrapper -from . import env_numba -from . import env_cupy - - -def get_simulated_data( - n_molecules, n_atoms, atom_prob, n_features, n_nu, printinfo=False, dtype=None, device=torch.device("cpu") -): +from .utils import clear_pair_cache + +import warnings + +try: + from . import env_numba +except ImportError: + warnings.warn("numba implementation not importable.") + env_numba = None +try: + from . import env_cupy +except ImportError: + warnings.warn("cupy implementation not importable.") + env_cupy = None +try: + from . import env_triton +except ImportError: + warnings.warn("triton implementation not importable.") + env_triton = None + + +def get_simulated_data(n_molecules, n_atoms, atom_prob, n_features, n_nu, printinfo=False, dtype=None, device=torch.device("cpu")): """ Get semi-realistic test data for hipnn. n_molecules : number of molecules in the batch @@ -80,7 +95,7 @@ def get_simulated_data( n_pairs = len(pair_first) # NOTE: These fake sensitivities are NONSYMMETRIC. - # Current HIPNN does not do that, but a future one could. + # Current HIP-NN does not do that, but a future one could. 
on_sensitivites = np.random.choice([True, False], p=[3 / n_nu, 1 - 3 / n_nu], size=(n_pairs, n_nu)) pair_sensitivites = np.random.random(size=(n_pairs, n_nu)) * on_sensitivites assert not (pair_first == pair_second).any() @@ -101,7 +116,7 @@ def get_simulated_data( return pair_sense, features, pair_first, pair_second -TEST_TINY_PARAMS = dict(n_molecules=2, n_atoms=3, atom_prob=1., n_features=5, n_nu=7) +TEST_TINY_PARAMS = dict(n_molecules=2, n_atoms=3, atom_prob=1.0, n_features=5, n_nu=7) TEST_SMALL_PARAMS = dict(n_molecules=10, n_atoms=30, atom_prob=0.7, n_features=10, n_nu=20) TEST_MEDIUM_PARAMS = dict(n_molecules=100, n_atoms=30, atom_prob=0.7, n_features=20, n_nu=20) TEST_LARGE_PARAMS = dict(n_molecules=1000, n_atoms=30, atom_prob=0.7, n_features=80, n_nu=20) @@ -195,16 +210,16 @@ def check_allclose_once(self, use_large=False, device=torch.device("cpu")): if max_deviation > self.suspicious_deviation: print("Closeness check for {} by suspicious amount".format(name), max_deviation) - def check_empty(self,device=torch.device('cpu')): + def check_empty(self, device=torch.device("cpu")): sense, feat, pfirst, psecond = get_simulated_data(**TEST_TINY_PARAMS, dtype=torch.float64, device=device) - pfirst = psecond = torch.zeros((0,),dtype=torch.long,device=pfirst.device) - sense = torch.zeros((0,sense.shape[1]),dtype=sense.dtype,device=sense.device) - + pfirst = psecond = torch.zeros((0,), dtype=torch.long, device=pfirst.device) + sense = torch.zeros((0, sense.shape[1]), dtype=sense.dtype, device=sense.device) + try: - env = self.envsum(sense,feat,pfirst,psecond) - sense_g = self.sensesum(env,feat,pfirst,psecond) - feat_g = self.featsum(env,sense,pfirst,psecond) + env = self.envsum(sense, feat, pfirst, psecond) + sense_g = self.sensesum(env, feat, pfirst, psecond) + feat_g = self.featsum(env, sense, pfirst, psecond) except Exception as ee: raise ValueError("Failed an operation on data with zero pairs") from ee print("Passed zero-pair check") @@ -244,7 +259,8 @@ def all_close_witherror(self, r1, r2): def check_allclose(self, repeats=30, use_large=False, device=torch.device("cpu")): from tqdm.auto import tqdm - for i in tqdm(range(repeats),leave=False): + + for i in tqdm(range(repeats), leave=False): try: self.check_allclose_once(use_large, device=device) except Exception as ee: @@ -262,31 +278,45 @@ def check_correctness(self, n_grad=1, n_small=100, n_large=3, device=torch.devic self.check_allclose(repeats=n_large, use_large=True, device=device) print("Passed large tensor forward checks!") - def check_speed( - self, n_repetitions=10, device=torch.device("cpu"), data_size=TEST_LARGE_PARAMS, compare_against="Pytorch" - ): + def check_speed(self, n_repetitions=10, device=torch.device("cpu"), data_size=TEST_LARGE_PARAMS, compare_against="pytorch"): - if compare_against == "Pytorch": + if compare_against.lower() == "pytorch": comp_envsum = env_pytorch.envsum comp_sensesum = env_pytorch.sensesum comp_featsum = env_pytorch.featsum - elif compare_against == "Numba": + elif compare_against.lower() == "numba": comp_envsum = env_numba.new_envsum comp_sensesum = env_numba.new_sensesum comp_featsum = env_numba.new_featsum - elif compare_against == "Cupy": + elif compare_against.lower() == "cupy": comp_envsum = env_cupy.cupy_envsum comp_sensesum = env_cupy.cupy_sensesum comp_featsum = env_cupy.cupy_featsum + elif compare_against.lower() == "triton": + comp_envsum = env_triton.envsum + comp_sensesum = env_triton.featsum + comp_featsum = env_triton.featsum + else: raise ValueError("Unknown implementation to 
comapre against:'{}'".format(compare_against)) te, ts, tf = (TimerHolder(name) for name in ("Envsum", "Sensesum", "Featsum")) - tne, tns, tnf = ( - TimerHolder("{}_{}".format(compare_against, name)) for name in ("Envsum", "Sensesum", "Featsum") - ) + tne, tns, tnf = (TimerHolder("{}_{}".format(compare_against, name)) for name in ("Envsum", "Sensesum", "Featsum")) + print("Repetitions: {}".format(n_repetitions)) with torch.autograd.no_grad(): + # Warming up by running on data of this specific size + sense, feat, pfirst, psecond = get_simulated_data(**data_size, dtype=torch.float32, device=device) + env = comp_envsum(sense, feat, pfirst, psecond) + comp_sensesum(env, feat, pfirst, psecond) + comp_featsum(env, sense, pfirst, psecond) + self.envsum(sense, feat, pfirst, psecond) + self.sensesum(env, feat, pfirst, psecond) + self.featsum(env, sense, pfirst, psecond) + + # Note: in this implementation we clear the pair cache for each run. + # In real conditions speedups could be greater due to caching of pairs. + # with torch.autograd.profiler.profile() as prof: for i in range(n_repetitions): print(".", end="", flush=True) @@ -297,6 +327,7 @@ def check_speed( comp_sensesum(env, feat, pfirst, psecond) with tnf.add(): comp_featsum(env, sense, pfirst, psecond) + clear_pair_cache() with te.add(): self.envsum(sense, feat, pfirst, psecond) with ts.add(): @@ -364,66 +395,103 @@ def elapsed(self): return self.end - self.start -def main(env_impl,sense_impl,feat_impl): +def main(env_impl, sense_impl, feat_impl, args=None): - np.random.seed(0) + if args is None: + # calling without arguments looks for them from command line + args = parse_args() + print("Got args:", args) + np.random.seed(args.seed) tester = Envops_tester( env_impl, sense_impl, feat_impl, ) - # % time - if torch.cuda.is_available(): + compare_against = args.compare_against + test_gpu = not args.no_test_gpu + test_cpu = not args.no_test_cpu + correctness = not args.no_correctness + + if torch.cuda.is_available() and not args.no_test_gpu: print("Running GPU tests") meminfo = numba.cuda.current_context().get_memory_info() - use_large_gpu = meminfo.free > 2 ** 31 - use_verylarge_gpu = meminfo.free > 2**34 - - n_large = 10 if use_large_gpu else 0 - tester.check_correctness(device=torch.device("cuda"),n_large=n_large) - compare_against = "Pytorch" + use_large_gpu = meminfo.free > 2**31 + use_verylarge_gpu = meminfo.free > 30 * (2**30) + + n_large = args.n_large if use_large_gpu else 0 + if correctness: + tester.check_correctness(device=torch.device("cuda"), n_large=n_large) + if use_verylarge_gpu: print("-" * 80) print("Mega systems:", TEST_MEGA_PARAMS) - tester.check_speed(n_repetitions=20,data_size=TEST_MEGA_PARAMS, device=torch.device("cuda"), compare_against=compare_against) + tester.check_speed(n_repetitions=20, data_size=TEST_MEGA_PARAMS, device=torch.device("cuda"), compare_against=compare_against) else: - print("Numba indicates less than 32GB free GPU memory -- skipping mega system test") + print("Numba indicates less than 30GB free GPU memory -- skipping mega system test") if use_large_gpu: print("-" * 80) - tester.check_speed(n_repetitions=20,data_size=TEST_LARGE_PARAMS, device=torch.device("cuda"), compare_against=compare_against) + tester.check_speed(n_repetitions=20, data_size=TEST_LARGE_PARAMS, device=torch.device("cuda"), compare_against=compare_against) else: print("Numba indicates less than 2GB free GPU memory -- skipping large system test") print("-" * 80) print("Medium systems:", TEST_MEDIUM_PARAMS) - tester.check_speed( - 
n_repetitions=100, data_size=TEST_MEDIUM_PARAMS, device=torch.device("cuda"), compare_against=compare_against - ) + tester.check_speed(n_repetitions=100, data_size=TEST_MEDIUM_PARAMS, device=torch.device("cuda"), compare_against=compare_against) print("-" * 80) print("Small systems:", TEST_SMALL_PARAMS) - tester.check_speed( - n_repetitions=100, data_size=TEST_SMALL_PARAMS, device=torch.device("cuda"), compare_against=compare_against - ) + tester.check_speed(n_repetitions=100, data_size=TEST_SMALL_PARAMS, device=torch.device("cuda"), compare_against=compare_against) + + else: + if not args.no_test_gpu: + print("Cuda not available, not running GPU tests.") + else: + print("Skipped GPU tests.") + if test_cpu: + print("Running CPU tests") + if test_gpu: + tester.check_correctness() + + print("-" * 80) + print("Large systems:", TEST_LARGE_PARAMS) + tester.check_speed(n_repetitions=10, compare_against=compare_against) + print("-" * 80) + print("Medium systems:", TEST_MEDIUM_PARAMS) + tester.check_speed(n_repetitions=100, data_size=TEST_MEDIUM_PARAMS, compare_against=compare_against) + print("-" * 80) + print("Small systems:", TEST_SMALL_PARAMS) + tester.check_speed(n_repetitions=100, compare_against=compare_against, data_size=TEST_SMALL_PARAMS) else: - print("Cuda not available, not running GPU tests.") + print("Skipped CPU tests.") + + +def parse_args(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=0, help="name for run") + + parser.add_argument( + "--compare_against", type=str, default="pytorch", help=""" + implementation to compare speed with. Options are: pytorch, numba, cupy, triton""",) + + parser.add_argument( + "--n_large", type=int, default=0, help=""" + Number of times to check correctness of forward pass. Set this to a large number (e.g. 
200) to + stress-test a new implementation against corner-cases.""",) + + parser.add_argument("--no-test-cpu", action="store_true", default=False, help="Set to false to skip CPU tests.") + parser.add_argument("--no-test-gpu", action="store_true", default=False, help="Set to false to skip GPU tests.") + parser.add_argument("--no-correctness", action="store_true", default=False, help="Set to false to skip GPU tests.") - print("Running CPU tests") - tester.check_correctness() - print("-" * 80) - print("Large systems:", TEST_LARGE_PARAMS) - tester.check_speed(n_repetitions=10, compare_against=compare_against) - print("-" * 80) - print("Medium systems:", TEST_MEDIUM_PARAMS) - tester.check_speed(n_repetitions=100, data_size=TEST_MEDIUM_PARAMS, compare_against=compare_against) - print("-" * 80) - print("Small systems:", TEST_SMALL_PARAMS) - tester.check_speed(n_repetitions=100, compare_against=compare_against, data_size=TEST_SMALL_PARAMS) + args = parser.parse_args() + return args if __name__ == "__main__": - main(env_numba.new_envsum, - env_numba.new_sensesum, - env_numba.new_featsum, + main( + env_numba.new_envsum, + env_numba.new_sensesum, + env_numba.new_featsum, ) diff --git a/hippynn/custom_kernels/utils.py b/hippynn/custom_kernels/utils.py index e3668752..cf704af4 100644 --- a/hippynn/custom_kernels/utils.py +++ b/hippynn/custom_kernels/utils.py @@ -104,6 +104,9 @@ def _make_cache(): # Dict mapping device to key cache info _CACHE_STORE = collections.defaultdict(_make_cache) +def clear_pair_cache(): + _CACHE_STORE.clear() + CACHE_LOCK_MISSES = 0 def resort_pairs_cached(key, others): global CACHE_LOCK_MISSES From 298b258b6b4200cf50f22c52b7b54a59ef3ffcf4 Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers Date: Wed, 24 Jul 2024 12:33:35 -0600 Subject: [PATCH 09/24] small tweaks to testing arguments --- hippynn/custom_kernels/test_env_numba.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/hippynn/custom_kernels/test_env_numba.py b/hippynn/custom_kernels/test_env_numba.py index b2d4a30a..d1feb78c 100644 --- a/hippynn/custom_kernels/test_env_numba.py +++ b/hippynn/custom_kernels/test_env_numba.py @@ -321,6 +321,7 @@ def check_speed(self, n_repetitions=10, device=torch.device("cpu"), data_size=TE for i in range(n_repetitions): print(".", end="", flush=True) sense, feat, pfirst, psecond = get_simulated_data(**data_size, dtype=torch.float32, device=device) + torch.cuda.synchronize() with tne.add(): env = comp_envsum(sense, feat, pfirst, psecond) with tns.add(): @@ -400,6 +401,12 @@ def main(env_impl, sense_impl, feat_impl, args=None): if args is None: # calling without arguments looks for them from command line args = parse_args() + + if isinstance(args,dict): + from types import SimpleNamespace + args = SimpleNamespace(**args) + + print("Got args:", args) np.random.seed(args.seed) tester = Envops_tester( @@ -473,7 +480,7 @@ def parse_args(): parser.add_argument("--seed", type=int, default=0, help="name for run") parser.add_argument( - "--compare_against", type=str, default="pytorch", help=""" + "--compare-against", type=str, default="pytorch", help=""" implementation to compare speed with. 
Options are: pytorch, numba, cupy, triton""",) parser.add_argument( From 5df6edd15b06b220ebdd8de816957b8bb30c45a9 Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers Date: Wed, 24 Jul 2024 12:41:25 -0600 Subject: [PATCH 10/24] integrate triton kernels to options, apply formatter --- hippynn/custom_kernels/__init__.py | 24 +++- hippynn/custom_kernels/autograd_wrapper.py | 12 +- hippynn/custom_kernels/env_cupy.py | 133 ++++++++++-------- hippynn/custom_kernels/env_numba.py | 14 +- hippynn/custom_kernels/env_triton.py | 152 +++++++++++---------- hippynn/custom_kernels/fast_convert.py | 19 ++- hippynn/custom_kernels/tensor_wrapper.py | 17 ++- hippynn/custom_kernels/test_env_cupy.py | 9 +- hippynn/custom_kernels/test_env_numba.py | 20 ++- hippynn/custom_kernels/test_env_triton.py | 11 +- hippynn/custom_kernels/utils.py | 65 +++++---- 11 files changed, 263 insertions(+), 213 deletions(-) diff --git a/hippynn/custom_kernels/__init__.py b/hippynn/custom_kernels/__init__.py index 614367a2..87f3b2f6 100644 --- a/hippynn/custom_kernels/__init__.py +++ b/hippynn/custom_kernels/__init__.py @@ -34,6 +34,13 @@ except ImportError: pass +try: + import triton + + CUSTOM_KERNELS_AVAILABLE.append("triton") +except ImportError: + pass + if not CUSTOM_KERNELS_AVAILABLE: warnings.warn("Numba or cupy not available: Custom Kernels will be disabled.") @@ -83,7 +90,7 @@ def set_custom_kernels(active: Union[bool, str] = True): if isinstance(active, str): active = active.lower() - if active not in [True, False, "numba", "cupy", "pytorch", "auto"]: + if active not in [True, False, "triton", "numba", "cupy", "pytorch", "auto"]: raise ValueError(f"Unrecognized custom kernel implementation: {active}") active_map = {"auto": True, "pytorch": False} @@ -91,10 +98,11 @@ def set_custom_kernels(active: Union[bool, str] = True): if active == "auto" or active == "pytorch": active = False elif active: - raise RuntimeError("Numba or cupy was not found. Custom kernels are not available.") + raise RuntimeError("Numba or cupy was not found. Custom kernels are not available, but they were required by library settings.") else: active = active_map.get(active, active) + # Handle fallback to pytorch kernels. if not active: envsum = env_pytorch.envsum sensesum = env_pytorch.sensesum @@ -102,11 +110,15 @@ def set_custom_kernels(active: Union[bool, str] = True): CUSTOM_KERNELS_ACTIVE = False return + # Select custom kernel implementation + if not CUSTOM_KERNELS_AVAILABLE: raise RuntimeError("Numba was not found. 
Custom kernels are not available.") if active is True: - if "cupy" in CUSTOM_KERNELS_AVAILABLE: + if "triton" in CUSTOM_KERNELS_AVAILABLE: + active = "triton" + elif "cupy" in CUSTOM_KERNELS_AVAILABLE: active = "cupy" else: active = "numba" @@ -114,7 +126,11 @@ def set_custom_kernels(active: Union[bool, str] = True): if active not in CUSTOM_KERNELS_AVAILABLE: raise RuntimeError(f"Unavailable custom kernel implementation: {active}") - if active == "cupy": + if active == "triton": + from .env_triton import envsum as triton_envsum, sensesum as triton_sensesum, featsum as triton_featsum + + envsum, sensesum, featsum = autograd_wrapper.wrap_envops(triton_envsum, triton_sensesum, triton_featsum) + elif active == "cupy": _check_numba() _check_cupy() from .env_cupy import cupy_envsum, cupy_featsum, cupy_sensesum diff --git a/hippynn/custom_kernels/autograd_wrapper.py b/hippynn/custom_kernels/autograd_wrapper.py index c283eb95..9f124d51 100644 --- a/hippynn/custom_kernels/autograd_wrapper.py +++ b/hippynn/custom_kernels/autograd_wrapper.py @@ -19,9 +19,9 @@ def forward(ctx, sense, feat, pfirst, psecond): if pfirst.shape[0] == 0: n_pair, n_nu = sense.shape n_atom, n_feat = feat.shape - if n_pair!=0 or psecond.shape[0]!=0: + if n_pair != 0 or psecond.shape[0] != 0: raise ValueError("Inconsistent shapes for envsum.") - return torch.zeros((n_atom,n_nu,n_feat),dtype=feat.dtype,device=feat.device) + return torch.zeros((n_atom, n_nu, n_feat), dtype=feat.dtype, device=feat.device) env = envsum_impl(sense, feat, pfirst, psecond) return env @@ -49,9 +49,9 @@ def forward(ctx, env, feat, pfirst, psecond): if pfirst.shape[0] == 0: n_atom0, n_nu, n_feat0 = env.shape n_atom1, n_feat1 = feat.shape - if psecond.shape[0] !=0 or n_atom0!=n_atom1 or n_feat0 != n_feat1: + if psecond.shape[0] != 0 or n_atom0 != n_atom1 or n_feat0 != n_feat1: raise ValueError("Inconsistent shapes for sensesum") - return torch.zeros((0,n_nu),dtype=feat.dtype,device=feat.device) + return torch.zeros((0, n_nu), dtype=feat.dtype, device=feat.device) sense = sensesum_impl(env, feat, pfirst, psecond) return sense @@ -72,9 +72,9 @@ def forward(ctx, env, sense, pfirst, psecond): if pfirst.shape[0] == 0: n_atom, n_nu0, n_feat = env.shape n_pair, n_nu1 = sense.shape - if psecond.shape[0] !=0 or n_nu0!=n_nu1: + if psecond.shape[0] != 0 or n_nu0 != n_nu1: raise ValueError("Inconsistent shapes for featsum") - return torch.zeros((n_atom,n_feat),dtype=env.dtype,device=env.device) + return torch.zeros((n_atom, n_feat), dtype=env.dtype, device=env.device) feat = featsum_impl(env, sense, pfirst, psecond) return feat diff --git a/hippynn/custom_kernels/env_cupy.py b/hippynn/custom_kernels/env_cupy.py index 830c7e51..bd0048c0 100644 --- a/hippynn/custom_kernels/env_cupy.py +++ b/hippynn/custom_kernels/env_cupy.py @@ -11,7 +11,7 @@ from hippynn.custom_kernels.env_numba import WrappedEnvsum, WrappedSensesum, WrappedFeatsum from hippynn.custom_kernels.utils import resort_pairs_cached -CUPY_KERNEL_CODE=r""" +CUPY_KERNEL_CODE = r""" extern "C" __global__ void cupy_envsum( const FLOATX* sense, @@ -110,62 +110,73 @@ } """ -_CUPY_MODULES = { - dtype: cupy.RawModule(code=CUPY_KERNEL_CODE.replace("FLOATX",dtype)) - for dtype in ('float','double') - } - -def _cupy_gpu_not_found(*args,**kwargs): - raise RuntimeError("Error: CuPy could not find the GPU." 
- "Verify that your numba installation is able to find cuda toolkit, as this \n" - "error condition likely indicates that torch can find the GPU, but cupy can't.\n" - "Alternatively, disable custom kernels.") +_CUPY_MODULES = {dtype: cupy.RawModule(code=CUPY_KERNEL_CODE.replace("FLOATX", dtype)) for dtype in ("float", "double")} + + +def _cupy_gpu_not_found(*args, **kwargs): + raise RuntimeError( + "Error: CuPy could not find the GPU." + "Verify that your numba installation is able to find cuda toolkit, as this \n" + "error condition likely indicates that torch can find the GPU, but cupy can't.\n" + "Alternatively, disable custom kernels." + ) + class CupyGPUKernel: _cupy_name = None + def __init__(self): - if not cupy.cuda.is_available(): - self.kernel32 = _cupy_gpu_not_found - self.kernel64 = _cupy_gpu_not_found - else: - self.kernel32 = _CUPY_MODULES['float'].get_function(self._cupy_name) - self.kernel64 = _CUPY_MODULES['double'].get_function(self._cupy_name) - - def __call__(self,dtype,BPG,TPB,array_args,shape_args): - + if not cupy.cuda.is_available(): + self.kernel32 = _cupy_gpu_not_found + self.kernel64 = _cupy_gpu_not_found + else: + self.kernel32 = _CUPY_MODULES["float"].get_function(self._cupy_name) + self.kernel64 = _CUPY_MODULES["double"].get_function(self._cupy_name) + + def __call__(self, dtype, BPG, TPB, array_args, shape_args): + out_array = array_args[-1] array_args = [cupy.asarray(a.detach().contiguous()).ravel() for a in array_args] args = (*array_args, *shape_args) - + if dtype == torch.float32: - self.kernel32(BPG,TPB,args) + self.kernel32(BPG, TPB, args) elif dtype == torch.float64: - self.kernel64(BPG,TPB,args) + self.kernel64(BPG, TPB, args) else: raise ValueError("Bad dtype: {}".format(dtype)) return out_array - -class CupyEnvsum(CupyGPUKernel,WrappedEnvsum): - _cupy_name = 'cupy_envsum' - def __call__(self,sense,feat,pfirst,psecond): + +class CupyEnvsum(CupyGPUKernel, WrappedEnvsum): + _cupy_name = "cupy_envsum" + + def __call__(self, sense, feat, pfirst, psecond): psecond_hold = psecond argsort, atom1_ids, atom1_starts, pfirst, (sense, psecond) = resort_pairs_cached(pfirst, [sense, psecond]) - resort_pairs_cached(psecond_hold,[]) + resort_pairs_cached(psecond_hold, []) dev = sense.device if dev.type == "cpu": return self.cpu_kernel(sense, feat, pfirst, psecond, atom1_ids, atom1_starts) n_pairs, n_nu = sense.shape n_atoms, n_feat = feat.shape - n_interact, = atom1_ids.shape + (n_interact,) = atom1_ids.shape dtype = sense.dtype - env_out = torch.zeros((n_atoms,n_nu,n_feat,), device=dev, dtype=dtype) - array_args = sense,feat,psecond,atom1_ids,atom1_starts,env_out - shape_args = n_nu,n_feat,n_interact - + env_out = torch.zeros( + ( + n_atoms, + n_nu, + n_feat, + ), + device=dev, + dtype=dtype, + ) + array_args = sense, feat, psecond, atom1_ids, atom1_starts, env_out + shape_args = n_nu, n_feat, n_interact + TPB_MAX = 512 TPB_X = n_feat TPB_Y = TPB_MAX // n_feat @@ -173,54 +184,57 @@ def __call__(self,sense,feat,pfirst,psecond): BPG_X = (n_interact + TPB_Y - 1) // TPB_Y BPG_Y = 1 BPG = (BPG_X, BPG_Y) - - args = *array_args,*shape_args - return super().__call__(dtype,BPG,TPB,array_args,shape_args) -class CupySensesum(CupyGPUKernel,WrappedSensesum): - _cupy_name = 'cupy_sensesum' - def __call__(self,env,feat,pfirst,psecond): + args = *array_args, *shape_args + return super().__call__(dtype, BPG, TPB, array_args, shape_args) + + +class CupySensesum(CupyGPUKernel, WrappedSensesum): + _cupy_name = "cupy_sensesum" + + def __call__(self, env, feat, pfirst, psecond): dev = 
env.device if dev.type == "cpu": return self.cpu_kernel(env, feat, pfirst, psecond) - n_pairs, = pfirst.shape + (n_pairs,) = pfirst.shape n_atoms, n_nu, n_feat = env.shape dtype = env.dtype - sense_out = torch.zeros((n_pairs,n_nu), device=dev, dtype=dtype) + sense_out = torch.zeros((n_pairs, n_nu), device=dev, dtype=dtype) array_args = env, feat, pfirst, psecond, sense_out shape_args = n_pairs, n_nu, n_feat - + TPB_MAX = 512 TPB_Y = n_nu - TPB_X = TPB_MAX//TPB_Y - TPB = (TPB_X,TPB_Y) - BPG_X = (n_pairs + TPB_X - 1 )//TPB_X - BPG = (BPG_X,1) + TPB_X = TPB_MAX // TPB_Y + TPB = (TPB_X, TPB_Y) + BPG_X = (n_pairs + TPB_X - 1) // TPB_X + BPG = (BPG_X, 1) - return super().__call__(dtype,BPG,TPB,array_args,shape_args) + return super().__call__(dtype, BPG, TPB, array_args, shape_args) -class CupyFeatsum(CupyGPUKernel,WrappedFeatsum): - _cupy_name = 'cupy_featsum' - def __call__(self,env,sense,pfirst,psecond): +class CupyFeatsum(CupyGPUKernel, WrappedFeatsum): + _cupy_name = "cupy_featsum" + + def __call__(self, env, sense, pfirst, psecond): pfirst_hold = pfirst argsort, atom2_ids, atom2_starts, psecond, (sense, pfirst) = resort_pairs_cached(psecond, [sense, pfirst]) - resort_pairs_cached(pfirst_hold,[]) + resort_pairs_cached(pfirst_hold, []) dev = env.device if dev.type == "cpu": return self.cpu_kernel(env, sense, pfirst, psecond, atom2_ids, atom2_starts) - n_pairs, = pfirst.shape + (n_pairs,) = pfirst.shape n_atoms, n_nu, n_feat = env.shape - n_interact, = atom2_ids.shape + (n_interact,) = atom2_ids.shape dtype = env.dtype - feat_out = torch.zeros((n_atoms,n_feat), device=dev, dtype=dtype) + feat_out = torch.zeros((n_atoms, n_feat), device=dev, dtype=dtype) array_args = env, sense, pfirst, atom2_ids, atom2_starts, feat_out - shape_args = n_nu, n_feat ,n_interact - + shape_args = n_nu, n_feat, n_interact + TPB_max = 512 if n_feat > 32: TPB_x = ((n_feat + 31) // 32) * 32 @@ -234,9 +248,10 @@ def __call__(self,env,sense,pfirst,psecond): TPB = (TPB_x, TPB_y) BPG_x = (n_atoms + TPB_y - 1) // TPB_y BPG = (BPG_x, BPG_y) - - return super().__call__(dtype,BPG,TPB,array_args,shape_args) - + + return super().__call__(dtype, BPG, TPB, array_args, shape_args) + + cupy_envsum = CupyEnvsum() cupy_sensesum = CupySensesum() cupy_featsum = CupyFeatsum() diff --git a/hippynn/custom_kernels/env_numba.py b/hippynn/custom_kernels/env_numba.py index 7581732e..f2fe918d 100644 --- a/hippynn/custom_kernels/env_numba.py +++ b/hippynn/custom_kernels/env_numba.py @@ -50,9 +50,7 @@ def launch_bounds(self, sense_shape, fs, pfs, pss, atom1_ids_shape, *other_shape @staticmethod def make_kernel(KERNEL_DTYPE): - sig = "void({DTYPE}[:,:,],{DTYPE}[:,:],int64[:],int64[:],int64[:],int64[:],{DTYPE}[:,:,:])".format( - DTYPE=KERNEL_DTYPE - ) + sig = "void({DTYPE}[:,:,],{DTYPE}[:,:],int64[:],int64[:],int64[:],int64[:],{DTYPE}[:,:,:])".format(DTYPE=KERNEL_DTYPE) @numba.cuda.jit( sig, @@ -123,9 +121,9 @@ def launch_bounds(self, env_shape, feat_shape, pfirst_shape, psecond_shape): n_atoms, n_nu, n_feat = env_shape TPB_MAX = 512 TPB_Y = n_nu - TPB_X = TPB_MAX//TPB_Y - TPB = (TPB_X,TPB_Y) - BPG = (n_pairs + TPB_X -1 )//TPB_X + TPB_X = TPB_MAX // TPB_Y + TPB = (TPB_X, TPB_Y) + BPG = (n_pairs + TPB_X - 1) // TPB_X return BPG, TPB @staticmethod @@ -209,9 +207,7 @@ def launch_bounds(self, env_shape, sense_shape, pfirst_shape, psecond_shape, ato @staticmethod def make_kernel(KERNEL_DTYPE): - sig = "void({DTYPE}[:,:,:],{DTYPE}[:,:],int64[:],int64[:],int64[:],int64[:],{DTYPE}[:,:])".format( - DTYPE=KERNEL_DTYPE - ) + sig = 
"void({DTYPE}[:,:,:],{DTYPE}[:,:],int64[:],int64[:],int64[:],int64[:],{DTYPE}[:,:])".format(DTYPE=KERNEL_DTYPE) @numba.cuda.jit( sig, diff --git a/hippynn/custom_kernels/env_triton.py b/hippynn/custom_kernels/env_triton.py index e1173d86..fd7cfa6f 100644 --- a/hippynn/custom_kernels/env_triton.py +++ b/hippynn/custom_kernels/env_triton.py @@ -5,34 +5,35 @@ from .env_pytorch import envsum as envsum_pt, sensesum as sensesum_pt, featsum as featsum_pt + @triton.jit -def envsum_kernel(out_env_ptr, - sens_ptr, - feat_ptr, - psecond_ptr, - atom_ids_ptr, - atom_starts_ptr, - atom_size, - sens_size: tl.constexpr, - feat_size: tl.constexpr, - p2_sens_size: tl.constexpr, - p2_feat_size: tl.constexpr, - dtype: tl.constexpr = tl.float32): +def envsum_kernel( + out_env_ptr, + sens_ptr, + feat_ptr, + psecond_ptr, + atom_ids_ptr, + atom_starts_ptr, + atom_size, + sens_size: tl.constexpr, + feat_size: tl.constexpr, + p2_sens_size: tl.constexpr, + p2_feat_size: tl.constexpr, + dtype: tl.constexpr = tl.float32, +): atom_id = tl.program_id(axis=0) start = tl.load(atom_starts_ptr + atom_id, mask=atom_id < atom_size, other=0) end = tl.load(atom_starts_ptr + atom_id + 1, mask=atom_id < atom_size, other=0) - target_id = tl.load(atom_ids_ptr + atom_id, mask=atom_id < atom_size, other=0) + target_id = tl.load(atom_ids_ptr + atom_id, mask=atom_id < atom_size, other=0) sens_block_ids = tl.arange(0, p2_sens_size) feat_block_ids = tl.arange(0, p2_feat_size) tmp = tl.zeros((p2_sens_size, p2_feat_size), dtype=dtype) - for ind in range(start, end): + for ind in range(start, end): # [p2_sens_size,], coming from the pair sensitivity - s = tl.load(sens_ptr + (ind * sens_size) + sens_block_ids, - mask=sens_block_ids < sens_size, other=0.0) - pair_ind = tl.load(psecond_ptr + ind) # TODO do we need mask here + s = tl.load(sens_ptr + (ind * sens_size) + sens_block_ids, mask=sens_block_ids < sens_size, other=0.0) + pair_ind = tl.load(psecond_ptr + ind) # TODO do we need mask here # [p2_feat_size,], coming from the neighbor feature - feat = tl.load(feat_ptr + (pair_ind * feat_size) + feat_block_ids, - mask=feat_block_ids < feat_size, other=0.0) + feat = tl.load(feat_ptr + (pair_ind * feat_size) + feat_block_ids, mask=feat_block_ids < feat_size, other=0.0) # temp_mat and tmp is [p2_sens_size, p2_feat_size] temp_mat = s[:, None] * feat[None, :] tmp = tmp + temp_mat @@ -40,6 +41,7 @@ def envsum_kernel(out_env_ptr, block_ids = sens_block_ids[:, None] * feat_size + feat_block_ids[None, :] tl.store(out_env_ptr + (target_id * sens_size * feat_size) + block_ids, tmp, mask=mask) + def envsum_triton(sensitivities, features, pair_first, pair_second, atom_ids, atom_starts, out_env_fetures=None): n_pairs, n_nu = sensitivities.shape n_atom, n_feat = features.shape @@ -63,29 +65,34 @@ def envsum_triton(sensitivities, features, pair_first, pair_second, atom_ids, at n_feat, p2_sens_size, p2_feat_size, - dtype=dtype) + dtype=dtype, + ) return out_env_fetures + def envsum(sense, features, pfirst, psecond): - if sense.device == torch.device('cpu'): - return envsum_pt(sense,features,pfirst,psecond) + if sense.device == torch.device("cpu"): + return envsum_pt(sense, features, pfirst, psecond) psecond_hold = psecond argsort, atom1_ids, atom1_starts, pfirst, (sense, psecond) = resort_pairs_cached(pfirst, [sense, psecond]) resort_pairs_cached(psecond_hold, []) return envsum_triton(sense, features, pfirst, psecond, atom1_ids, atom1_starts, out_env_fetures=None) -@triton.jit -def sensesum_kernel(out_sense_ptr, - env_ptr, - feat_ptr, - pfirst_ptr, - 
psecond_ptr, - pair_size, - sens_size: tl.constexpr, - feat_size: tl.constexpr, - p2_sens_size: tl.constexpr, - p2_feat_size: tl.constexpr, - dtype: tl.constexpr = tl.float32): + +@triton.jit +def sensesum_kernel( + out_sense_ptr, + env_ptr, + feat_ptr, + pfirst_ptr, + psecond_ptr, + pair_size, + sens_size: tl.constexpr, + feat_size: tl.constexpr, + p2_sens_size: tl.constexpr, + p2_feat_size: tl.constexpr, + dtype: tl.constexpr = tl.float32, +): pair_id = tl.program_id(axis=0) first = tl.load(pfirst_ptr + pair_id, mask=pair_id < pair_size, other=0) second = tl.load(psecond_ptr + pair_id, mask=pair_id < pair_size, other=0) @@ -96,23 +103,22 @@ def sensesum_kernel(out_sense_ptr, # [p2_sens_size, p2_feat_size] env = tl.load(env_ptr + (first * sens_size * feat_size) + block_ids, mask=mask, other=0.0) # [p2_feat_size, ] - feat = tl.load(feat_ptr + (second * feat_size) + feat_block_ids, - mask=feat_block_ids < feat_size, other=0.0) - ''' + feat = tl.load(feat_ptr + (second * feat_size) + feat_block_ids, mask=feat_block_ids < feat_size, other=0.0) + """ type_f32: tl.constexpr = tl.float32 type_check: tl.constexpr = (dtype == type_f32) if type_check: res = tl.dot(env, feat[:, None]) else: res = tl.sum(env * feat[None, :], axis=1) - ''' + """ res = tl.sum(env * feat[None, :], axis=1) - tl.store(out_sense_ptr + (pair_id * sens_size) + sens_block_ids, res, - mask=sens_block_ids < sens_size) + tl.store(out_sense_ptr + (pair_id * sens_size) + sens_block_ids, res, mask=sens_block_ids < sens_size) + def sensesum(env, features, pair_first, pair_second, out_sense=None): - if env.device == torch.device('cpu'): - return sensesum_pt(env,features,pair_first,pair_second) + if env.device == torch.device("cpu"): + return sensesum_pt(env, features, pair_first, pair_second) _, n_nu, _ = env.shape n_atom, n_feat = features.shape n_pairs = len(pair_first) @@ -124,45 +130,38 @@ def sensesum(env, features, pair_first, pair_second, out_sense=None): p2_sens_size = triton.next_power_of_2(n_nu) p2_feat_size = triton.next_power_of_2(n_feat) sensesum_kernel[(n_pairs,)]( - out_sense, - env, - features, - pair_first, - pair_second, - n_pairs, - n_nu, - n_feat, - p2_sens_size, - p2_feat_size, - dtype=dtype) + out_sense, env, features, pair_first, pair_second, n_pairs, n_nu, n_feat, p2_sens_size, p2_feat_size, dtype=dtype + ) return out_sense -@triton.jit -def featsum_kernel(out_feat, - env_ptr, - sens_ptr, - pfirst_ptr, - psecond_ptr, - atom2_ids_ptr, - atom2_starts_ptr, - atom_size, - sens_size: tl.constexpr, - feat_size: tl.constexpr, - p2_sens_size: tl.constexpr, - p2_feat_size: tl.constexpr, - dtype: tl.constexpr = tl.float32): + +@triton.jit +def featsum_kernel( + out_feat, + env_ptr, + sens_ptr, + pfirst_ptr, + psecond_ptr, + atom2_ids_ptr, + atom2_starts_ptr, + atom_size, + sens_size: tl.constexpr, + feat_size: tl.constexpr, + p2_sens_size: tl.constexpr, + p2_feat_size: tl.constexpr, + dtype: tl.constexpr = tl.float32, +): atom_id = tl.program_id(axis=0) start = tl.load(atom2_starts_ptr + atom_id, mask=atom_id < atom_size, other=0) end = tl.load(atom2_starts_ptr + atom_id + 1, mask=atom_id < atom_size, other=0) - target_id = tl.load(atom2_ids_ptr + atom_id, mask=atom_id < atom_size, other=0) + target_id = tl.load(atom2_ids_ptr + atom_id, mask=atom_id < atom_size, other=0) sens_block_ids = tl.arange(0, p2_sens_size) feat_block_ids = tl.arange(0, p2_feat_size) tmp = tl.zeros((p2_feat_size,), dtype=dtype) - for ind in range(start, end): + for ind in range(start, end): # [p2_sens_size,], coming from the pair sensitivity - 
sense = tl.load(sens_ptr + (ind * sens_size) + sens_block_ids, - mask=sens_block_ids < sens_size, other=0.0) - pair_ind = tl.load(pfirst_ptr + ind) # TODO do we need mask here + sense = tl.load(sens_ptr + (ind * sens_size) + sens_block_ids, mask=sens_block_ids < sens_size, other=0.0) + pair_ind = tl.load(pfirst_ptr + ind) # TODO do we need mask here mask = (sens_block_ids[:, None] < sens_size) & (feat_block_ids[None, :] < feat_size) block_ids = sens_block_ids[:, None] * feat_size + feat_block_ids[None, :] # [p2_sens_size, p2_feat_size] @@ -172,6 +171,7 @@ def featsum_kernel(out_feat, tmp = tmp + temp_mat tl.store(out_feat + (target_id * feat_size) + feat_block_ids, tmp, mask=feat_block_ids < feat_size) + def featsum_triton(env, sense, pair_first, pair_second, atom2_ids, atom2_starts, out_feat=None): n_atom, n_nu, n_feat = env.shape (n_pairs,) = pair_first.shape @@ -196,12 +196,14 @@ def featsum_triton(env, sense, pair_first, pair_second, atom2_ids, atom2_starts, n_feat, p2_sens_size, p2_feat_size, - dtype=dtype) + dtype=dtype, + ) return out_feat + def featsum(env, sense, pfirst, psecond): - if env.device == torch.device('cpu'): - return featsum_pt(env,sense,pfirst,psecond) + if env.device == torch.device("cpu"): + return featsum_pt(env, sense, pfirst, psecond) pfirst_hold = pfirst argsort, atom2_ids, atom2_starts, psecond, (sense, pfirst) = resort_pairs_cached(psecond, [sense, pfirst]) resort_pairs_cached(pfirst_hold, []) diff --git a/hippynn/custom_kernels/fast_convert.py b/hippynn/custom_kernels/fast_convert.py index 649d4d41..119e5f04 100644 --- a/hippynn/custom_kernels/fast_convert.py +++ b/hippynn/custom_kernels/fast_convert.py @@ -7,6 +7,7 @@ import numpy as np from numba.cuda.cudadrv import devices, devicearray + try: from numba.cuda.api_util import prepare_shape_strides_dtype except ImportError: @@ -34,14 +35,15 @@ torch.int32: " 0 else 0 data = (data_ptr, False) # read-only is false - shape = shape strides = strides - shape, strides, dtype = prepare_shape_strides_dtype( - shape, strides, dtype, order='C') + shape, strides, dtype = prepare_shape_strides_dtype(shape, strides, dtype, order="C") size = driver.memory_size_from_info(shape, strides, dtype.itemsize) devptr = driver.get_devptr_for_active_ctx(data_ptr) - data = driver.MemoryPointer( - current_context(), devptr, size=size, owner=tensor) + data = driver.MemoryPointer(current_context(), devptr, size=size, owner=tensor) stream = 0 # No "Numba default stream", not the CUDA default stream - da = devicearray.DeviceNDArray(shape=shape, strides=strides, - dtype=dtype, gpu_data=data, - stream=stream) + da = devicearray.DeviceNDArray(shape=shape, strides=strides, dtype=dtype, gpu_data=data, stream=stream) out.append(da) return out diff --git a/hippynn/custom_kernels/tensor_wrapper.py b/hippynn/custom_kernels/tensor_wrapper.py index d32f5aeb..ade8ddcf 100644 --- a/hippynn/custom_kernels/tensor_wrapper.py +++ b/hippynn/custom_kernels/tensor_wrapper.py @@ -7,6 +7,7 @@ import torch from .fast_convert import batch_convert_torch_to_numba + def via_numpy(func): """Decorator for piping a function through numpy arrays, and then giving the result back to torch. 
@@ -23,12 +24,16 @@ def wrapped(*args): return wrapped -def _numba_gpu_not_found(*args,**kwargs): - raise RuntimeError("Error: Numba not configured to run on GPU.\n" - "numba.cuda.is_available() returned False; numba was not able to find a GPU.\n" - "Verify that your numba installation is able to find cuda toolkit, as this \n" - "error condition likely indicates that torch can find the GPU, but numba can't.\n" - "Alternatively, disable custom kernels.") + +def _numba_gpu_not_found(*args, **kwargs): + raise RuntimeError( + "Error: Numba not configured to run on GPU.\n" + "numba.cuda.is_available() returned False; numba was not able to find a GPU.\n" + "Verify that your numba installation is able to find cuda toolkit, as this \n" + "error condition likely indicates that torch can find the GPU, but numba can't.\n" + "Alternatively, disable custom kernels." + ) + class NumbaCompatibleTensorFunction: def __init__(self): diff --git a/hippynn/custom_kernels/test_env_cupy.py b/hippynn/custom_kernels/test_env_cupy.py index c3215f68..638dc642 100644 --- a/hippynn/custom_kernels/test_env_cupy.py +++ b/hippynn/custom_kernels/test_env_cupy.py @@ -1,4 +1,3 @@ - import torch import numba from . import env_cupy @@ -6,8 +5,8 @@ from .test_env_numba import TEST_MEGA_PARAMS, TEST_LARGE_PARAMS, TEST_MEDIUM_PARAMS, TEST_SMALL_PARAMS if __name__ == "__main__": - main(env_cupy.cupy_envsum, - env_cupy.cupy_sensesum, - env_cupy.cupy_featsum, + main( + env_cupy.cupy_envsum, + env_cupy.cupy_sensesum, + env_cupy.cupy_featsum, ) - diff --git a/hippynn/custom_kernels/test_env_numba.py b/hippynn/custom_kernels/test_env_numba.py index d1feb78c..06821164 100644 --- a/hippynn/custom_kernels/test_env_numba.py +++ b/hippynn/custom_kernels/test_env_numba.py @@ -402,10 +402,10 @@ def main(env_impl, sense_impl, feat_impl, args=None): # calling without arguments looks for them from command line args = parse_args() - if isinstance(args,dict): + if isinstance(args, dict): from types import SimpleNamespace - args = SimpleNamespace(**args) + args = SimpleNamespace(**args) print("Got args:", args) np.random.seed(args.seed) @@ -480,13 +480,21 @@ def parse_args(): parser.add_argument("--seed", type=int, default=0, help="name for run") parser.add_argument( - "--compare-against", type=str, default="pytorch", help=""" - implementation to compare speed with. Options are: pytorch, numba, cupy, triton""",) + "--compare-against", + type=str, + default="pytorch", + help=""" + implementation to compare speed with. Options are: pytorch, numba, cupy, triton""", + ) parser.add_argument( - "--n_large", type=int, default=0, help=""" + "--n_large", + type=int, + default=0, + help=""" Number of times to check correctness of forward pass. Set this to a large number (e.g. 
200) to - stress-test a new implementation against corner-cases.""",) + stress-test a new implementation against corner-cases.""", + ) parser.add_argument("--no-test-cpu", action="store_true", default=False, help="Set to false to skip CPU tests.") parser.add_argument("--no-test-gpu", action="store_true", default=False, help="Set to false to skip GPU tests.") diff --git a/hippynn/custom_kernels/test_env_triton.py b/hippynn/custom_kernels/test_env_triton.py index 4239d2e0..95788068 100644 --- a/hippynn/custom_kernels/test_env_triton.py +++ b/hippynn/custom_kernels/test_env_triton.py @@ -1,12 +1,13 @@ - import torch from .env_triton import envsum, sensesum, featsum from .test_env_numba import Envops_tester, main, get_simulated_data from .test_env_numba import TEST_MEGA_PARAMS, TEST_LARGE_PARAMS, TEST_MEDIUM_PARAMS, TEST_SMALL_PARAMS from .utils import resort_pairs_cached + if __name__ == "__main__": - main(envsum, - sensesum, - featsum, - ) \ No newline at end of file + main( + envsum, + sensesum, + featsum, + ) diff --git a/hippynn/custom_kernels/utils.py b/hippynn/custom_kernels/utils.py index cf704af4..7623977f 100644 --- a/hippynn/custom_kernels/utils.py +++ b/hippynn/custom_kernels/utils.py @@ -4,27 +4,30 @@ import collections from typing import List + @torch.jit.script def get_id_and_starts(key): n_items = key.shape[0] key_diff = key[:-1] - key[1:] - key_start = torch.nonzero(key_diff)[:,0] + 1 - start = torch.zeros((1,),device=key.device,dtype=torch.long) - end = n_items*torch.ones((1,),device=key.device,dtype=torch.long) - key_start = torch.cat([start,key_start,end]) + key_start = torch.nonzero(key_diff)[:, 0] + 1 + start = torch.zeros((1,), device=key.device, dtype=torch.long) + end = n_items * torch.ones((1,), device=key.device, dtype=torch.long) + key_start = torch.cat([start, key_start, end]) key_ids = key[key_start[:-1]] return key_ids, key_start + @torch.jit.script -def resort_pairs(key,others: List[torch.Tensor]): +def resort_pairs(key, others: List[torch.Tensor]): keysort, argsort = torch.sort(key) others = [o[argsort] for o in others] key_ids, key_starts = get_id_and_starts(keysort) return argsort, key_ids, key_starts, keysort, others -class _CacheEntry(): - def __init__(self,key,cache): - # Cache key is stored this way because sometimes + +class _CacheEntry: + def __init__(self, key, cache): + # Cache key is stored this way because sometimes # pytorch 'resurrects' a Tensor from C++ to pytorch # during an autograd calculation as a completely # new pyobject, so id(key) is not a completely safe way @@ -43,35 +46,35 @@ def __init__(self,key,cache): self.key = key self.cache = cache self.computed = False - - def __eq__(self,other): - if not isinstance(other,type(self)): - return False + + def __eq__(self, other): + if not isinstance(other, type(self)): + return False if self.key is other.key: return True if self.cache_key == other.cache_key: # For now raise an error since this feature is experimental. # Most of the benefits seem to work from the `is` check. - if not torch.equal(self.key,other.key): + if not torch.equal(self.key, other.key): raise RuntimeError("Caching by key does not work! Please raise an issue.") - return True + return True return False - + def find(self): if self in self.cache: - other = self.cache[self.cache.index(self)] + other = self.cache[self.cache.index(self)] return other for other in self.cache: # Expensive, but not as expensive as sorting + nonzero, # which is also a blocking operation. 
- # Note that pytorch is smart enough not to generate + # Note that pytorch is smart enough not to generate # a blocking call unless the actual memory needs to be # checked, i.e. it will skip if arrays have different sizes. - if torch.equal(self.key,other.key): + if torch.equal(self.key, other.key): return other else: - return self - + return self + def compute_and_store(self): if self.computed: # This shouldn't happen! The guard is here because @@ -85,39 +88,47 @@ def compute_and_store(self): self.key_starts = key_starts self.computed = True self.cache.append(self) - - def retrieve(self,others): + + def retrieve(self, others): keysort = self.key[self.argsort] others = [o[self.argsort] for o in others] return self.argsort, self.key_ids, self.key_starts, keysort, others @classmethod - def lookup_key(cls,key,cache): - entry = cls(key,cache) + def lookup_key(cls, key, cache): + entry = cls(key, cache) return entry.find() + N_CACHED_KEYS_PER_DEVICE = 2 + + def _make_cache(): deque = collections.deque(maxlen=N_CACHED_KEYS_PER_DEVICE) lock = threading.Lock() return deque, lock + + # Dict mapping device to key cache info _CACHE_STORE = collections.defaultdict(_make_cache) + def clear_pair_cache(): _CACHE_STORE.clear() + CACHE_LOCK_MISSES = 0 + + def resort_pairs_cached(key, others): global CACHE_LOCK_MISSES deque, lock = _CACHE_STORE[key.device] got_lock = lock.acquire(blocking=False) if got_lock: - entry = _CacheEntry.lookup_key(key,deque) + entry = _CacheEntry.lookup_key(key, deque) if not entry.computed: entry.compute_and_store() lock.release() return entry.retrieve(others) CACHE_LOCK_MISSES += 1 - return resort_pairs(key,others) - + return resort_pairs(key, others) From ecf706b30bc332275515381f6ce5154e4bcc95d2 Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers Date: Wed, 24 Jul 2024 12:57:19 -0600 Subject: [PATCH 11/24] adjust configparser --- hippynn/_settings_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hippynn/_settings_setup.py b/hippynn/_settings_setup.py index 6869c93f..bd5af2a2 100644 --- a/hippynn/_settings_setup.py +++ b/hippynn/_settings_setup.py @@ -86,7 +86,7 @@ def kernel_handler(kernel_string): rc_name = os.path.expanduser("~/.hippynnrc") if os.path.exists(rc_name) and os.path.isfile(rc_name): - config = configparser.ConfigParser() + config = configparser.ConfigParser(inline_comment_prefixes="#") config.read(rc_name) config_sources["~/.hippynnrc"] = config["GLOBALS"] From 4f9b3895e560b632b05eb909eb68c42b5d7a1813 Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers Date: Thu, 25 Jul 2024 13:50:54 -0600 Subject: [PATCH 12/24] try refactor triton code (cosmetic) --- hippynn/_settings_setup.py | 2 +- hippynn/custom_kernels/env_triton.py | 83 +++++++++++++++++++--------- 2 files changed, 58 insertions(+), 27 deletions(-) diff --git a/hippynn/_settings_setup.py b/hippynn/_settings_setup.py index bd5af2a2..3c1c5ef2 100644 --- a/hippynn/_settings_setup.py +++ b/hippynn/_settings_setup.py @@ -54,7 +54,7 @@ def kernel_handler(kernel_string): "true": True, }.get(kernel_string, kernel_string) - if kernel not in [True, False, "auto", "cupy", "numba"]: + if kernel not in [True, False, "auto", "triton", "cupy", "numba"]: warnings.warn(f"Unrecognized custom kernel option: {kernel_string}. 
Setting custom kernels to 'auto'") kernel = "auto" diff --git a/hippynn/custom_kernels/env_triton.py b/hippynn/custom_kernels/env_triton.py index fd7cfa6f..b035f09d 100644 --- a/hippynn/custom_kernels/env_triton.py +++ b/hippynn/custom_kernels/env_triton.py @@ -22,24 +22,37 @@ def envsum_kernel( dtype: tl.constexpr = tl.float32, ): atom_id = tl.program_id(axis=0) - start = tl.load(atom_starts_ptr + atom_id, mask=atom_id < atom_size, other=0) - end = tl.load(atom_starts_ptr + atom_id + 1, mask=atom_id < atom_size, other=0) - target_id = tl.load(atom_ids_ptr + atom_id, mask=atom_id < atom_size, other=0) + + valid_atom_id = atom_id < atom_size + + start = tl.load(atom_starts_ptr + atom_id, mask=valid_atom_id, other=0) + end = tl.load(atom_starts_ptr + atom_id + 1, mask=valid_atom_id, other=0) + target_id = tl.load(atom_ids_ptr + atom_id, mask=valid_atom_id, other=0) + sens_block_ids = tl.arange(0, p2_sens_size) feat_block_ids = tl.arange(0, p2_feat_size) + + valid_sens = sens_block_ids < sens_size + valid_feat = feat_block_ids < feat_size + valid_env = valid_sens[:, None] & valid_feat[None, :] + tmp = tl.zeros((p2_sens_size, p2_feat_size), dtype=dtype) + for ind in range(start, end): # [p2_sens_size,], coming from the pair sensitivity - s = tl.load(sens_ptr + (ind * sens_size) + sens_block_ids, mask=sens_block_ids < sens_size, other=0.0) - pair_ind = tl.load(psecond_ptr + ind) # TODO do we need mask here + s = tl.load(sens_ptr + (ind * sens_size) + sens_block_ids, mask=valid_sens, other=0.0) + atom2_id = tl.load(psecond_ptr + ind) # TODO C: do we need mask here # N: I don't think so # [p2_feat_size,], coming from the neighbor feature - feat = tl.load(feat_ptr + (pair_ind * feat_size) + feat_block_ids, mask=feat_block_ids < feat_size, other=0.0) + feat = tl.load(feat_ptr + (atom2_id * feat_size) + feat_block_ids, mask=valid_feat, other=0.0) # temp_mat and tmp is [p2_sens_size, p2_feat_size] temp_mat = s[:, None] * feat[None, :] tmp = tmp + temp_mat - mask = (sens_block_ids[:, None] < sens_size) & (feat_block_ids[None, :] < feat_size) - block_ids = sens_block_ids[:, None] * feat_size + feat_block_ids[None, :] - tl.store(out_env_ptr + (target_id * sens_size * feat_size) + block_ids, tmp, mask=mask) + + atom_offset = (target_id * sens_size * feat_size) + env_block_id = sens_block_ids[:, None] * feat_size + feat_block_ids[None, :] + + # TODO: use sparsity of sensitivities to reduce workload? (see numba envsum implementation) + tl.store(out_env_ptr + atom_offset + env_block_id, tmp, mask=valid_env) def envsum_triton(sensitivities, features, pair_first, pair_second, atom_ids, atom_starts, out_env_fetures=None): @@ -75,7 +88,7 @@ def envsum(sense, features, pfirst, psecond): return envsum_pt(sense, features, pfirst, psecond) psecond_hold = psecond argsort, atom1_ids, atom1_starts, pfirst, (sense, psecond) = resort_pairs_cached(pfirst, [sense, psecond]) - resort_pairs_cached(psecond_hold, []) + resort_pairs_cached(psecond_hold, []) # Preemptively sort for backwards pass. 
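
A quick sketch of what the cached pair resort provides to these wrappers (toy numbers chosen for illustration; this uses the resort_pairs helper from hippynn/custom_kernels/utils.py shown above): pairs are grouped by their first atom, and each kernel program later walks one contiguous segment of the sorted pair list.

import torch
from hippynn.custom_kernels.utils import resort_pairs

pair_first = torch.tensor([2, 0, 2, 1, 0])
pair_second = torch.tensor([3, 4, 5, 6, 7])
argsort, atom_ids, atom_starts, sorted_first, (sorted_second,) = resort_pairs(pair_first, [pair_second])
# sorted_first -> tensor([0, 0, 1, 2, 2])
# atom_ids     -> tensor([0, 1, 2])      the atoms that actually appear as pair_first
# atom_starts  -> tensor([0, 2, 3, 5])   atom_ids[i] owns sorted pairs atom_starts[i]:atom_starts[i+1]
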
return envsum_triton(sense, features, pfirst, psecond, atom1_ids, atom1_starts, out_env_fetures=None) @@ -94,16 +107,24 @@ def sensesum_kernel( dtype: tl.constexpr = tl.float32, ): pair_id = tl.program_id(axis=0) - first = tl.load(pfirst_ptr + pair_id, mask=pair_id < pair_size, other=0) - second = tl.load(psecond_ptr + pair_id, mask=pair_id < pair_size, other=0) + valid_pair = pair_id < pair_size + + first = tl.load(pfirst_ptr + pair_id, mask=valid_pair, other=0) + second = tl.load(psecond_ptr + pair_id, mask=valid_pair, other=0) + sens_block_ids = tl.arange(0, p2_sens_size) feat_block_ids = tl.arange(0, p2_feat_size) - mask = (sens_block_ids[:, None] < sens_size) & (feat_block_ids[None, :] < feat_size) + + valid_sens = sens_block_ids < sens_size + valid_feat = feat_block_ids < feat_size + valid_env = valid_sens[:, None] & valid_feat[None, :] + block_ids = sens_block_ids[:, None] * feat_size + feat_block_ids[None, :] # [p2_sens_size, p2_feat_size] - env = tl.load(env_ptr + (first * sens_size * feat_size) + block_ids, mask=mask, other=0.0) + env = tl.load(env_ptr + (first * sens_size * feat_size) + block_ids, mask=valid_env, other=0.0) # [p2_feat_size, ] - feat = tl.load(feat_ptr + (second * feat_size) + feat_block_ids, mask=feat_block_ids < feat_size, other=0.0) + feat = tl.load(feat_ptr + (second * feat_size) + feat_block_ids, mask=valid_feat, other=0.0) + # N: What is going on in this string? """ type_f32: tl.constexpr = tl.float32 type_check: tl.constexpr = (dtype == type_f32) @@ -113,7 +134,8 @@ def sensesum_kernel( res = tl.sum(env * feat[None, :], axis=1) """ res = tl.sum(env * feat[None, :], axis=1) - tl.store(out_sense_ptr + (pair_id * sens_size) + sens_block_ids, res, mask=sens_block_ids < sens_size) + # TODO: use sparsity of sensitivities to reduce workload? 
(see numba envsum implementation) + tl.store(out_sense_ptr + (pair_id * sens_size) + sens_block_ids, res, mask=valid_sens) def sensesum(env, features, pair_first, pair_second, out_sense=None): @@ -152,24 +174,33 @@ def featsum_kernel( dtype: tl.constexpr = tl.float32, ): atom_id = tl.program_id(axis=0) - start = tl.load(atom2_starts_ptr + atom_id, mask=atom_id < atom_size, other=0) - end = tl.load(atom2_starts_ptr + atom_id + 1, mask=atom_id < atom_size, other=0) - target_id = tl.load(atom2_ids_ptr + atom_id, mask=atom_id < atom_size, other=0) + valid_atom = atom_id < atom_size + + start = tl.load(atom2_starts_ptr + atom_id, mask=valid_atom, other=0) + end = tl.load(atom2_starts_ptr + atom_id + 1, mask=valid_atom, other=0) + target_id = tl.load(atom2_ids_ptr + atom_id, mask=valid_atom, other=0) + sens_block_ids = tl.arange(0, p2_sens_size) feat_block_ids = tl.arange(0, p2_feat_size) + valid_feat = feat_block_ids < feat_size + valid_sens = sens_block_ids < sens_size + + valid_env = valid_sens[:, None] * valid_feat[None, :] + + env_block_ids = sens_block_ids[:, None] * feat_size + feat_block_ids[None, :] + tmp = tl.zeros((p2_feat_size,), dtype=dtype) + for ind in range(start, end): # [p2_sens_size,], coming from the pair sensitivity - sense = tl.load(sens_ptr + (ind * sens_size) + sens_block_ids, mask=sens_block_ids < sens_size, other=0.0) - pair_ind = tl.load(pfirst_ptr + ind) # TODO do we need mask here - mask = (sens_block_ids[:, None] < sens_size) & (feat_block_ids[None, :] < feat_size) - block_ids = sens_block_ids[:, None] * feat_size + feat_block_ids[None, :] + sense = tl.load(sens_ptr + (ind * sens_size) + sens_block_ids, mask=valid_sens, other=0.0) + atom1_ind = tl.load(pfirst_ptr + ind) # C: TODO do we need mask here #N: Don't think so # [p2_sens_size, p2_feat_size] - env = tl.load(env_ptr + (pair_ind * sens_size * feat_size) + block_ids, mask=mask, other=0.0) + env = tl.load(env_ptr + (atom1_ind * sens_size * feat_size) + env_block_ids, mask=valid_env, other=0.0) # temp_mat and tmp is [p2_feat_size,] temp_mat = tl.sum(env * sense[:, None], axis=0) tmp = tmp + temp_mat - tl.store(out_feat + (target_id * feat_size) + feat_block_ids, tmp, mask=feat_block_ids < feat_size) + tl.store(out_feat + (target_id * feat_size) + feat_block_ids, tmp, mask=valid_feat) def featsum_triton(env, sense, pair_first, pair_second, atom2_ids, atom2_starts, out_feat=None): @@ -206,5 +237,5 @@ def featsum(env, sense, pfirst, psecond): return featsum_pt(env, sense, pfirst, psecond) pfirst_hold = pfirst argsort, atom2_ids, atom2_starts, psecond, (sense, pfirst) = resort_pairs_cached(psecond, [sense, pfirst]) - resort_pairs_cached(pfirst_hold, []) + resort_pairs_cached(pfirst_hold, []) # preemptively sort (probably no-op) return featsum_triton(env, sense, pfirst, psecond, atom2_ids, atom2_starts, out_feat=None) From ab9d47ac4cc72a1628f61804cfbf1347b826661e Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers Date: Thu, 25 Jul 2024 13:51:35 -0600 Subject: [PATCH 13/24] formatter --- hippynn/custom_kernels/env_triton.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hippynn/custom_kernels/env_triton.py b/hippynn/custom_kernels/env_triton.py index b035f09d..095c4c32 100644 --- a/hippynn/custom_kernels/env_triton.py +++ b/hippynn/custom_kernels/env_triton.py @@ -48,7 +48,7 @@ def envsum_kernel( temp_mat = s[:, None] * feat[None, :] tmp = tmp + temp_mat - atom_offset = (target_id * sens_size * feat_size) + atom_offset = target_id * sens_size * feat_size env_block_id = sens_block_ids[:, None] * 
feat_size + feat_block_ids[None, :] # TODO: use sparsity of sensitivities to reduce workload? (see numba envsum implementation) @@ -88,7 +88,7 @@ def envsum(sense, features, pfirst, psecond): return envsum_pt(sense, features, pfirst, psecond) psecond_hold = psecond argsort, atom1_ids, atom1_starts, pfirst, (sense, psecond) = resort_pairs_cached(pfirst, [sense, psecond]) - resort_pairs_cached(psecond_hold, []) # Preemptively sort for backwards pass. + resort_pairs_cached(psecond_hold, []) # Preemptively sort for backwards pass. return envsum_triton(sense, features, pfirst, psecond, atom1_ids, atom1_starts, out_env_fetures=None) @@ -237,5 +237,5 @@ def featsum(env, sense, pfirst, psecond): return featsum_pt(env, sense, pfirst, psecond) pfirst_hold = pfirst argsort, atom2_ids, atom2_starts, psecond, (sense, pfirst) = resort_pairs_cached(psecond, [sense, pfirst]) - resort_pairs_cached(pfirst_hold, []) # preemptively sort (probably no-op) + resort_pairs_cached(pfirst_hold, []) # preemptively sort (probably no-op) return featsum_triton(env, sense, pfirst, psecond, atom2_ids, atom2_starts, out_feat=None) From dec9fdd9a015003a4a1cc4d47a0701a5fd56a5a0 Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers Date: Thu, 25 Jul 2024 14:06:15 -0600 Subject: [PATCH 14/24] add ultra-size test --- hippynn/custom_kernels/test_env_numba.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/hippynn/custom_kernels/test_env_numba.py b/hippynn/custom_kernels/test_env_numba.py index 06821164..d55a16e5 100644 --- a/hippynn/custom_kernels/test_env_numba.py +++ b/hippynn/custom_kernels/test_env_numba.py @@ -121,6 +121,7 @@ def get_simulated_data(n_molecules, n_atoms, atom_prob, n_features, n_nu, printi TEST_MEDIUM_PARAMS = dict(n_molecules=100, n_atoms=30, atom_prob=0.7, n_features=20, n_nu=20) TEST_LARGE_PARAMS = dict(n_molecules=1000, n_atoms=30, atom_prob=0.7, n_features=80, n_nu=20) TEST_MEGA_PARAMS = dict(n_molecules=500, n_atoms=30, atom_prob=0.7, n_features=128, n_nu=100) +TEST_ULTRA_PARAMS = dict(n_molecules=500, n_atoms=30, atom_prob=0.7, n_features=128, n_nu=320) # reference implementation @@ -407,7 +408,6 @@ def main(env_impl, sense_impl, feat_impl, args=None): args = SimpleNamespace(**args) - print("Got args:", args) np.random.seed(args.seed) tester = Envops_tester( env_impl, @@ -426,11 +426,18 @@ def main(env_impl, sense_impl, feat_impl, args=None): use_large_gpu = meminfo.free > 2**31 use_verylarge_gpu = meminfo.free > 30 * (2**30) + use_ultra = (not correctness) and use_verylarge_gpu and (compare_against.lower() != "pytorch") + n_large = args.n_large if use_large_gpu else 0 if correctness: tester.check_correctness(device=torch.device("cuda"), n_large=n_large) if use_verylarge_gpu: + if use_ultra: + + print("-" * 80) + print("Ultra systems:", TEST_ULTRA_PARAMS) + tester.check_speed(n_repetitions=20, data_size=TEST_ULTRA_PARAMS, device=torch.device("cuda"), compare_against=compare_against) print("-" * 80) print("Mega systems:", TEST_MEGA_PARAMS) tester.check_speed(n_repetitions=20, data_size=TEST_MEGA_PARAMS, device=torch.device("cuda"), compare_against=compare_against) From 1ecd7589d08cdd19fd5f6a384ec88c1d403d79da Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers Date: Thu, 25 Jul 2024 14:58:14 -0600 Subject: [PATCH 15/24] remove explicit numba dependency from custom kernel tests --- hippynn/custom_kernels/test_env_numba.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/hippynn/custom_kernels/test_env_numba.py b/hippynn/custom_kernels/test_env_numba.py 
index d55a16e5..65d75640 100644 --- a/hippynn/custom_kernels/test_env_numba.py +++ b/hippynn/custom_kernels/test_env_numba.py @@ -1,7 +1,6 @@ """ Test module for verifying implementation correctness against pytorch. """ -import numba import numpy as np import torch @@ -11,6 +10,7 @@ import warnings + try: from . import env_numba except ImportError: @@ -383,13 +383,11 @@ def __init__(self): def __enter__(self): if torch.cuda.is_available(): torch.cuda.synchronize() - numba.cuda.synchronize() self.start = time.time() def __exit__(self, exc_type, exc_value, exc_tb): if torch.cuda.is_available(): torch.cuda.synchronize() - numba.cuda.synchronize() self.end = time.time() @property @@ -422,9 +420,10 @@ def main(env_impl, sense_impl, feat_impl, args=None): if torch.cuda.is_available() and not args.no_test_gpu: print("Running GPU tests") - meminfo = numba.cuda.current_context().get_memory_info() - use_large_gpu = meminfo.free > 2**31 - use_verylarge_gpu = meminfo.free > 30 * (2**30) + free_mem, total_mem = torch.cuda.memory.mem_get_info() + + use_large_gpu = free_mem > 2**31 + use_verylarge_gpu = free_mem > 30 * (2**30) use_ultra = (not correctness) and use_verylarge_gpu and (compare_against.lower() != "pytorch") @@ -434,7 +433,6 @@ def main(env_impl, sense_impl, feat_impl, args=None): if use_verylarge_gpu: if use_ultra: - print("-" * 80) print("Ultra systems:", TEST_ULTRA_PARAMS) tester.check_speed(n_repetitions=20, data_size=TEST_ULTRA_PARAMS, device=torch.device("cuda"), compare_against=compare_against) @@ -445,6 +443,7 @@ def main(env_impl, sense_impl, feat_impl, args=None): print("Numba indicates less than 30GB free GPU memory -- skipping mega system test") if use_large_gpu: print("-" * 80) + print("Large systems:", TEST_LARGE_PARAMS) tester.check_speed(n_repetitions=20, data_size=TEST_LARGE_PARAMS, device=torch.device("cuda"), compare_against=compare_against) else: print("Numba indicates less than 2GB free GPU memory -- skipping large system test") From 4aab176c9945f1dd0a6c3a0c1125778ffad9e924 Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers Date: Thu, 25 Jul 2024 15:04:51 -0600 Subject: [PATCH 16/24] update triton to use numba on CPU --- hippynn/custom_kernels/env_triton.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/hippynn/custom_kernels/env_triton.py b/hippynn/custom_kernels/env_triton.py index 095c4c32..2498e556 100644 --- a/hippynn/custom_kernels/env_triton.py +++ b/hippynn/custom_kernels/env_triton.py @@ -3,8 +3,14 @@ import triton.language as tl from .utils import resort_pairs_cached -from .env_pytorch import envsum as envsum_pt, sensesum as sensesum_pt, featsum as featsum_pt +# Load backup implementation for CPU tensors. +from .env_pytorch import envsum as envsum_alternative, sensesum as sensesum_alternative, featsum as featsum_alternative +# If numba is available, this implementation will default to numba on CPU. If not, use vanilla pytorch. 
+try: + from .env_numba import new_envsum as envsum_alternative, new_sensesum as sensesum_alternative, new_featsum as featsum_alternative +except ImportError: + pass @triton.jit def envsum_kernel( @@ -85,7 +91,7 @@ def envsum_triton(sensitivities, features, pair_first, pair_second, atom_ids, at def envsum(sense, features, pfirst, psecond): if sense.device == torch.device("cpu"): - return envsum_pt(sense, features, pfirst, psecond) + return envsum_alternative(sense, features, pfirst, psecond) psecond_hold = psecond argsort, atom1_ids, atom1_starts, pfirst, (sense, psecond) = resort_pairs_cached(pfirst, [sense, psecond]) resort_pairs_cached(psecond_hold, []) # Preemptively sort for backwards pass. @@ -124,7 +130,7 @@ def sensesum_kernel( env = tl.load(env_ptr + (first * sens_size * feat_size) + block_ids, mask=valid_env, other=0.0) # [p2_feat_size, ] feat = tl.load(feat_ptr + (second * feat_size) + feat_block_ids, mask=valid_feat, other=0.0) - # N: What is going on in this string? + # TODO N: What is going on in this string? """ type_f32: tl.constexpr = tl.float32 type_check: tl.constexpr = (dtype == type_f32) @@ -140,7 +146,7 @@ def sensesum_kernel( def sensesum(env, features, pair_first, pair_second, out_sense=None): if env.device == torch.device("cpu"): - return sensesum_pt(env, features, pair_first, pair_second) + return sensesum_alternative(env, features, pair_first, pair_second) _, n_nu, _ = env.shape n_atom, n_feat = features.shape n_pairs = len(pair_first) @@ -234,7 +240,7 @@ def featsum_triton(env, sense, pair_first, pair_second, atom2_ids, atom2_starts, def featsum(env, sense, pfirst, psecond): if env.device == torch.device("cpu"): - return featsum_pt(env, sense, pfirst, psecond) + return featsum_alternative(env, sense, pfirst, psecond) pfirst_hold = pfirst argsort, atom2_ids, atom2_starts, psecond, (sense, pfirst) = resort_pairs_cached(psecond, [sense, pfirst]) resort_pairs_cached(pfirst_hold, []) # preemptively sort (probably no-op) From fc7cb0aae8f513df66d66e00991e2256976466bd Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers Date: Thu, 25 Jul 2024 15:13:16 -0600 Subject: [PATCH 17/24] more formatting and name changes --- hippynn/custom_kernels/env_triton.py | 33 ++++++++++++++-------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/hippynn/custom_kernels/env_triton.py b/hippynn/custom_kernels/env_triton.py index 2498e556..fa178bc3 100644 --- a/hippynn/custom_kernels/env_triton.py +++ b/hippynn/custom_kernels/env_triton.py @@ -37,6 +37,7 @@ def envsum_kernel( sens_block_ids = tl.arange(0, p2_sens_size) feat_block_ids = tl.arange(0, p2_feat_size) + env_block_ids = sens_block_ids[:, None] * feat_size + feat_block_ids[None, :] valid_sens = sens_block_ids < sens_size valid_feat = feat_block_ids < feat_size @@ -55,25 +56,24 @@ def envsum_kernel( tmp = tmp + temp_mat atom_offset = target_id * sens_size * feat_size - env_block_id = sens_block_ids[:, None] * feat_size + feat_block_ids[None, :] # TODO: use sparsity of sensitivities to reduce workload? 
(see numba envsum implementation) - tl.store(out_env_ptr + atom_offset + env_block_id, tmp, mask=valid_env) + tl.store(out_env_ptr + atom_offset + env_block_ids, tmp, mask=valid_env) -def envsum_triton(sensitivities, features, pair_first, pair_second, atom_ids, atom_starts, out_env_fetures=None): +def envsum_triton(sensitivities, features, pair_first, pair_second, atom_ids, atom_starts, out_env=None): n_pairs, n_nu = sensitivities.shape n_atom, n_feat = features.shape (n_atom_with_pairs,) = atom_ids.shape - if out_env_fetures == None: - out_env_fetures = torch.zeros((n_atom, n_nu, n_feat), dtype=features.dtype, device=features.device) + if out_env is None: + out_env = torch.zeros((n_atom, n_nu, n_feat), dtype=features.dtype, device=features.device) dtype = tl.float32 if features.dtype == torch.float64: dtype = tl.float64 p2_sens_size = triton.next_power_of_2(n_nu) p2_feat_size = triton.next_power_of_2(n_feat) envsum_kernel[(n_atom_with_pairs,)]( - out_env_fetures, + out_env, sensitivities, features, pair_second, @@ -86,7 +86,7 @@ def envsum_triton(sensitivities, features, pair_first, pair_second, atom_ids, at p2_feat_size, dtype=dtype, ) - return out_env_fetures + return out_env def envsum(sense, features, pfirst, psecond): @@ -95,7 +95,7 @@ def envsum(sense, features, pfirst, psecond): psecond_hold = psecond argsort, atom1_ids, atom1_starts, pfirst, (sense, psecond) = resort_pairs_cached(pfirst, [sense, psecond]) resort_pairs_cached(psecond_hold, []) # Preemptively sort for backwards pass. - return envsum_triton(sense, features, pfirst, psecond, atom1_ids, atom1_starts, out_env_fetures=None) + return envsum_triton(sense, features, pfirst, psecond, atom1_ids, atom1_starts, out_env=None) @triton.jit @@ -120,14 +120,14 @@ def sensesum_kernel( sens_block_ids = tl.arange(0, p2_sens_size) feat_block_ids = tl.arange(0, p2_feat_size) + env_block_ids = sens_block_ids[:, None] * feat_size + feat_block_ids[None, :] valid_sens = sens_block_ids < sens_size valid_feat = feat_block_ids < feat_size valid_env = valid_sens[:, None] & valid_feat[None, :] - block_ids = sens_block_ids[:, None] * feat_size + feat_block_ids[None, :] # [p2_sens_size, p2_feat_size] - env = tl.load(env_ptr + (first * sens_size * feat_size) + block_ids, mask=valid_env, other=0.0) + env = tl.load(env_ptr + (first * sens_size * feat_size) + env_block_ids, mask=valid_env, other=0.0) # [p2_feat_size, ] feat = tl.load(feat_ptr + (second * feat_size) + feat_block_ids, mask=valid_feat, other=0.0) # TODO N: What is going on in this string? 
@@ -150,7 +150,7 @@ def sensesum(env, features, pair_first, pair_second, out_sense=None): _, n_nu, _ = env.shape n_atom, n_feat = features.shape n_pairs = len(pair_first) - if out_sense == None: + if out_sense is None: out_sense = torch.zeros((n_pairs, n_nu), dtype=features.dtype, device=features.device) dtype = tl.float32 if features.dtype == torch.float64: @@ -188,13 +188,12 @@ def featsum_kernel( sens_block_ids = tl.arange(0, p2_sens_size) feat_block_ids = tl.arange(0, p2_feat_size) - valid_feat = feat_block_ids < feat_size - valid_sens = sens_block_ids < sens_size - - valid_env = valid_sens[:, None] * valid_feat[None, :] - env_block_ids = sens_block_ids[:, None] * feat_size + feat_block_ids[None, :] + valid_sens = sens_block_ids < sens_size + valid_feat = feat_block_ids < feat_size + valid_env = valid_sens[:, None] & valid_feat[None, :] + tmp = tl.zeros((p2_feat_size,), dtype=dtype) for ind in range(start, end): @@ -213,7 +212,7 @@ def featsum_triton(env, sense, pair_first, pair_second, atom2_ids, atom2_starts, n_atom, n_nu, n_feat = env.shape (n_pairs,) = pair_first.shape (n_atoms_with_pairs,) = atom2_ids.shape - if out_feat == None: + if out_feat is None: out_feat = torch.zeros((n_atom, n_feat), dtype=env.dtype, device=env.device) dtype = tl.float32 if env.dtype == torch.float64: From d1c4e372a6279cb960a7fe10401f496a0f0e1fd0 Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers Date: Thu, 25 Jul 2024 15:19:11 -0600 Subject: [PATCH 18/24] fix lack of forward correctness checks! --- hippynn/custom_kernels/env_triton.py | 1 + hippynn/custom_kernels/test_env_numba.py | 8 +++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/hippynn/custom_kernels/env_triton.py b/hippynn/custom_kernels/env_triton.py index fa178bc3..e7071e6e 100644 --- a/hippynn/custom_kernels/env_triton.py +++ b/hippynn/custom_kernels/env_triton.py @@ -12,6 +12,7 @@ except ImportError: pass + @triton.jit def envsum_kernel( out_env_ptr, diff --git a/hippynn/custom_kernels/test_env_numba.py b/hippynn/custom_kernels/test_env_numba.py index 65d75640..8a910163 100644 --- a/hippynn/custom_kernels/test_env_numba.py +++ b/hippynn/custom_kernels/test_env_numba.py @@ -426,8 +426,8 @@ def main(env_impl, sense_impl, feat_impl, args=None): use_verylarge_gpu = free_mem > 30 * (2**30) use_ultra = (not correctness) and use_verylarge_gpu and (compare_against.lower() != "pytorch") - - n_large = args.n_large if use_large_gpu else 0 + + n_large = args.n_large if use_large_gpu else 5 if correctness: tester.check_correctness(device=torch.device("cuda"), n_large=n_large) @@ -435,7 +435,9 @@ def main(env_impl, sense_impl, feat_impl, args=None): if use_ultra: print("-" * 80) print("Ultra systems:", TEST_ULTRA_PARAMS) - tester.check_speed(n_repetitions=20, data_size=TEST_ULTRA_PARAMS, device=torch.device("cuda"), compare_against=compare_against) + tester.check_speed( + n_repetitions=20, data_size=TEST_ULTRA_PARAMS, device=torch.device("cuda"), compare_against=compare_against + ) print("-" * 80) print("Mega systems:", TEST_MEGA_PARAMS) tester.check_speed(n_repetitions=20, data_size=TEST_MEGA_PARAMS, device=torch.device("cuda"), compare_against=compare_against) From bd5c7df7a54124eb361728d7890ff6c882d2007e Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers Date: Thu, 25 Jul 2024 15:25:11 -0600 Subject: [PATCH 19/24] actually do what the last commit says --- hippynn/custom_kernels/test_env_numba.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/hippynn/custom_kernels/test_env_numba.py 
b/hippynn/custom_kernels/test_env_numba.py index 8a910163..58a0cb3e 100644 --- a/hippynn/custom_kernels/test_env_numba.py +++ b/hippynn/custom_kernels/test_env_numba.py @@ -427,9 +427,10 @@ def main(env_impl, sense_impl, feat_impl, args=None): use_ultra = (not correctness) and use_verylarge_gpu and (compare_against.lower() != "pytorch") - n_large = args.n_large if use_large_gpu else 5 + n_large_gpu = args.n_large if use_large_gpu else 0 + if correctness: - tester.check_correctness(device=torch.device("cuda"), n_large=n_large) + tester.check_correctness(device=torch.device("cuda"), n_large=n_large_gpu) if use_verylarge_gpu: if use_ultra: @@ -465,8 +466,8 @@ def main(env_impl, sense_impl, feat_impl, args=None): if test_cpu: print("Running CPU tests") - if test_gpu: - tester.check_correctness() + if correctness: + tester.check_correctness(n_large=args.n_large) print("-" * 80) print("Large systems:", TEST_LARGE_PARAMS) @@ -498,7 +499,7 @@ def parse_args(): parser.add_argument( "--n_large", type=int, - default=0, + default=5, help=""" Number of times to check correctness of forward pass. Set this to a large number (e.g. 200) to stress-test a new implementation against corner-cases.""", From 286b3795a2114b01678b921b702a72916a0ad0b0 Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers Date: Thu, 25 Jul 2024 17:32:23 -0600 Subject: [PATCH 20/24] update for todos --- hippynn/custom_kernels/env_pytorch.py | 1 - hippynn/custom_kernels/env_triton.py | 9 ++++++--- hippynn/custom_kernels/test_env_numba.py | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/hippynn/custom_kernels/env_pytorch.py b/hippynn/custom_kernels/env_pytorch.py index 5a37b33c..d802f386 100644 --- a/hippynn/custom_kernels/env_pytorch.py +++ b/hippynn/custom_kernels/env_pytorch.py @@ -22,7 +22,6 @@ def sensesum(env, features, pair_first, pair_second): sense = (pair_env * pair_feat.unsqueeze(1)).sum(dim=2) return sense - def featsum(env, sense, pair_first, pair_second): n_atoms, n_nu, n_feat = env.shape pair_env = env[pair_first] diff --git a/hippynn/custom_kernels/env_triton.py b/hippynn/custom_kernels/env_triton.py index e7071e6e..5c57bdaf 100644 --- a/hippynn/custom_kernels/env_triton.py +++ b/hippynn/custom_kernels/env_triton.py @@ -49,7 +49,7 @@ def envsum_kernel( for ind in range(start, end): # [p2_sens_size,], coming from the pair sensitivity s = tl.load(sens_ptr + (ind * sens_size) + sens_block_ids, mask=valid_sens, other=0.0) - atom2_id = tl.load(psecond_ptr + ind) # TODO C: do we need mask here # N: I don't think so + atom2_id = tl.load(psecond_ptr + ind) # [p2_feat_size,], coming from the neighbor feature feat = tl.load(feat_ptr + (atom2_id * feat_size) + feat_block_ids, mask=valid_feat, other=0.0) # temp_mat and tmp is [p2_sens_size, p2_feat_size] @@ -131,7 +131,9 @@ def sensesum_kernel( env = tl.load(env_ptr + (first * sens_size * feat_size) + env_block_ids, mask=valid_env, other=0.0) # [p2_feat_size, ] feat = tl.load(feat_ptr + (second * feat_size) + feat_block_ids, mask=valid_feat, other=0.0) - # TODO N: What is going on in this string? + # TODO: Here we use outer product followed by sum b/c built-in triton dot needs batches and FP<64. + # Can we make this better then? + # For future reference: """ type_f32: tl.constexpr = tl.float32 type_check: tl.constexpr = (dtype == type_f32) @@ -140,6 +142,7 @@ def sensesum_kernel( else: res = tl.sum(env * feat[None, :], axis=1) """ + res = tl.sum(env * feat[None, :], axis=1) # TODO: use sparsity of sensitivities to reduce workload? 
(see numba envsum implementation) tl.store(out_sense_ptr + (pair_id * sens_size) + sens_block_ids, res, mask=valid_sens) @@ -200,7 +203,7 @@ def featsum_kernel( for ind in range(start, end): # [p2_sens_size,], coming from the pair sensitivity sense = tl.load(sens_ptr + (ind * sens_size) + sens_block_ids, mask=valid_sens, other=0.0) - atom1_ind = tl.load(pfirst_ptr + ind) # C: TODO do we need mask here #N: Don't think so + atom1_ind = tl.load(pfirst_ptr + ind) # [p2_sens_size, p2_feat_size] env = tl.load(env_ptr + (atom1_ind * sens_size * feat_size) + env_block_ids, mask=valid_env, other=0.0) # temp_mat and tmp is [p2_feat_size,] diff --git a/hippynn/custom_kernels/test_env_numba.py b/hippynn/custom_kernels/test_env_numba.py index 58a0cb3e..d9a117c1 100644 --- a/hippynn/custom_kernels/test_env_numba.py +++ b/hippynn/custom_kernels/test_env_numba.py @@ -428,7 +428,7 @@ def main(env_impl, sense_impl, feat_impl, args=None): use_ultra = (not correctness) and use_verylarge_gpu and (compare_against.lower() != "pytorch") n_large_gpu = args.n_large if use_large_gpu else 0 - + if correctness: tester.check_correctness(device=torch.device("cuda"), n_large=n_large_gpu) From 4265297a97c6e667ef340bd83c36c5f92b17f61c Mon Sep 17 00:00:00 2001 From: Mehmet Cagri Kaymak Date: Thu, 25 Jul 2024 22:36:48 -0600 Subject: [PATCH 21/24] split feat. and sense vectors to chunks to lower register pressure, add autotune --- hippynn/custom_kernels/env_triton.py | 228 +++++++++++++++++++-------- 1 file changed, 160 insertions(+), 68 deletions(-) diff --git a/hippynn/custom_kernels/env_triton.py b/hippynn/custom_kernels/env_triton.py index 5c57bdaf..c7112440 100644 --- a/hippynn/custom_kernels/env_triton.py +++ b/hippynn/custom_kernels/env_triton.py @@ -2,6 +2,7 @@ import triton import triton.language as tl from .utils import resort_pairs_cached +import math # Load backup implementation for CPU tensors. from .env_pytorch import envsum as envsum_alternative, sensesum as sensesum_alternative, featsum as featsum_alternative @@ -12,7 +13,77 @@ except ImportError: pass - +def config_pruner(configs, kwargs): + ''' + Trims the unnecessary config options based on the sens. and feat. 
sizes + ''' + p2_sens_size = triton.next_power_of_2(kwargs["sens_size"]) + p2_feat_size = triton.next_power_of_2(kwargs["feat_size"]) + + used = set() + for config in configs: + sense_block_size = min(p2_sens_size, config.kwargs["SENS_BLOCK_SIZE"]) + feat_block_size = min(p2_feat_size, config.kwargs["FEAT_BLOCK_SIZE"]) + + if (sense_block_size, + feat_block_size, + config.num_stages, + config.num_warps) in used: + continue + used.add((sense_block_size, + feat_block_size, + config.num_stages, + config.num_warps)) + yield triton.Config( + { + "SENS_BLOCK_SIZE": sense_block_size, + "FEAT_BLOCK_SIZE": feat_block_size, + }, + num_stages=config.num_stages, + num_warps=config.num_warps, + ) + +def get_autotune_config(): + ''' + Create a list of config options for the kernels + TODO: Need to spend time actually figuring out more reasonable options + targeted for modern GPUs + ''' + return [ + triton.Config({'SENS_BLOCK_SIZE': 16, 'FEAT_BLOCK_SIZE': 16}), + triton.Config({'SENS_BLOCK_SIZE': 16, 'FEAT_BLOCK_SIZE': 32}), + triton.Config({'SENS_BLOCK_SIZE': 16, 'FEAT_BLOCK_SIZE': 64}), + triton.Config({'SENS_BLOCK_SIZE': 16, 'FEAT_BLOCK_SIZE': 128}), + triton.Config({'SENS_BLOCK_SIZE': 16, 'FEAT_BLOCK_SIZE': 256}), + + triton.Config({'SENS_BLOCK_SIZE': 32, 'FEAT_BLOCK_SIZE': 32}), + triton.Config({'SENS_BLOCK_SIZE': 32, 'FEAT_BLOCK_SIZE': 64}), + triton.Config({'SENS_BLOCK_SIZE': 32, 'FEAT_BLOCK_SIZE': 128}), + triton.Config({'SENS_BLOCK_SIZE': 32, 'FEAT_BLOCK_SIZE': 128}, num_warps=8), + + triton.Config({'SENS_BLOCK_SIZE': 32, 'FEAT_BLOCK_SIZE': 256}), + triton.Config({'SENS_BLOCK_SIZE': 32, 'FEAT_BLOCK_SIZE': 256}, num_warps=8), + + + triton.Config({'SENS_BLOCK_SIZE': 64, 'FEAT_BLOCK_SIZE': 32}), + triton.Config({'SENS_BLOCK_SIZE': 64, 'FEAT_BLOCK_SIZE': 64}), + triton.Config({'SENS_BLOCK_SIZE': 64, 'FEAT_BLOCK_SIZE': 128}), + triton.Config({'SENS_BLOCK_SIZE': 64, 'FEAT_BLOCK_SIZE': 256}), + + triton.Config({'SENS_BLOCK_SIZE': 128, 'FEAT_BLOCK_SIZE': 32}), + triton.Config({'SENS_BLOCK_SIZE': 128, 'FEAT_BLOCK_SIZE': 64}), + triton.Config({'SENS_BLOCK_SIZE': 128, 'FEAT_BLOCK_SIZE': 64}, num_warps=8), + + triton.Config({'SENS_BLOCK_SIZE': 256, 'FEAT_BLOCK_SIZE': 32}), + triton.Config({'SENS_BLOCK_SIZE': 256, 'FEAT_BLOCK_SIZE': 64}), + triton.Config({'SENS_BLOCK_SIZE': 256, 'FEAT_BLOCK_SIZE': 64}, num_warps=8), + ] + +@triton.autotune( + configs=get_autotune_config(), + key=['sens_size', 'feat_size'], + prune_configs_by={ "early_config_prune": config_pruner} +) @triton.jit def envsum_kernel( out_env_ptr, @@ -24,11 +95,14 @@ def envsum_kernel( atom_size, sens_size: tl.constexpr, feat_size: tl.constexpr, - p2_sens_size: tl.constexpr, - p2_feat_size: tl.constexpr, + SENS_BLOCK_SIZE: tl.constexpr, + FEAT_BLOCK_SIZE: tl.constexpr, + dtype: tl.constexpr = tl.float32, ): atom_id = tl.program_id(axis=0) + sens_id = tl.program_id(axis=1) + feat_id = tl.program_id(axis=2) valid_atom_id = atom_id < atom_size @@ -36,23 +110,23 @@ def envsum_kernel( end = tl.load(atom_starts_ptr + atom_id + 1, mask=valid_atom_id, other=0) target_id = tl.load(atom_ids_ptr + atom_id, mask=valid_atom_id, other=0) - sens_block_ids = tl.arange(0, p2_sens_size) - feat_block_ids = tl.arange(0, p2_feat_size) + sens_block_ids = tl.arange(0, SENS_BLOCK_SIZE) + (sens_id * SENS_BLOCK_SIZE) + feat_block_ids = tl.arange(0, FEAT_BLOCK_SIZE) + (feat_id * FEAT_BLOCK_SIZE) env_block_ids = sens_block_ids[:, None] * feat_size + feat_block_ids[None, :] valid_sens = sens_block_ids < sens_size valid_feat = feat_block_ids < feat_size valid_env = 
valid_sens[:, None] & valid_feat[None, :] - tmp = tl.zeros((p2_sens_size, p2_feat_size), dtype=dtype) + tmp = tl.zeros((SENS_BLOCK_SIZE, FEAT_BLOCK_SIZE), dtype=dtype) for ind in range(start, end): - # [p2_sens_size,], coming from the pair sensitivity + # [SENS_BLOCK_SIZE,], coming from the pair sensitivity s = tl.load(sens_ptr + (ind * sens_size) + sens_block_ids, mask=valid_sens, other=0.0) atom2_id = tl.load(psecond_ptr + ind) - # [p2_feat_size,], coming from the neighbor feature + # [FEAT_BLOCK_SIZE,], coming from the neighbor feature feat = tl.load(feat_ptr + (atom2_id * feat_size) + feat_block_ids, mask=valid_feat, other=0.0) - # temp_mat and tmp is [p2_sens_size, p2_feat_size] + # temp_mat and tmp is [SENS_BLOCK_SIZE, FEAT_BLOCK_SIZE] temp_mat = s[:, None] * feat[None, :] tmp = tmp + temp_mat @@ -61,7 +135,6 @@ def envsum_kernel( # TODO: use sparsity of sensitivities to reduce workload? (see numba envsum implementation) tl.store(out_env_ptr + atom_offset + env_block_ids, tmp, mask=valid_env) - def envsum_triton(sensitivities, features, pair_first, pair_second, atom_ids, atom_starts, out_env=None): n_pairs, n_nu = sensitivities.shape n_atom, n_feat = features.shape @@ -71,9 +144,10 @@ def envsum_triton(sensitivities, features, pair_first, pair_second, atom_ids, at dtype = tl.float32 if features.dtype == torch.float64: dtype = tl.float64 - p2_sens_size = triton.next_power_of_2(n_nu) - p2_feat_size = triton.next_power_of_2(n_feat) - envsum_kernel[(n_atom_with_pairs,)]( + + grid = lambda META: (n_atom_with_pairs, triton.cdiv(n_nu, META['SENS_BLOCK_SIZE']), triton.cdiv(n_feat, META['FEAT_BLOCK_SIZE'])) + + envsum_kernel[grid]( out_env, sensitivities, features, @@ -83,10 +157,10 @@ def envsum_triton(sensitivities, features, pair_first, pair_second, atom_ids, at n_atom_with_pairs, n_nu, n_feat, - p2_sens_size, - p2_feat_size, dtype=dtype, ) + #print('best config') + #print(envsum_kernel.best_config) return out_env @@ -98,7 +172,11 @@ def envsum(sense, features, pfirst, psecond): resort_pairs_cached(psecond_hold, []) # Preemptively sort for backwards pass. 
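
A small worked example of the 3-D launch grid introduced here (sizes are assumed for illustration, not taken from a specific benchmark): each program instance covers one (atom, sensitivity block, feature block) tile, so the grid shrinks as the autotuned block sizes grow.

import triton

n_atom_with_pairs, n_nu, n_feat = 1000, 20, 80
grid = lambda META: (n_atom_with_pairs, triton.cdiv(n_nu, META["SENS_BLOCK_SIZE"]), triton.cdiv(n_feat, META["FEAT_BLOCK_SIZE"]))
assert grid({"SENS_BLOCK_SIZE": 16, "FEAT_BLOCK_SIZE": 32}) == (1000, 2, 3)
assert grid({"SENS_BLOCK_SIZE": 32, "FEAT_BLOCK_SIZE": 128}) == (1000, 1, 1)
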
return envsum_triton(sense, features, pfirst, psecond, atom1_ids, atom1_starts, out_env=None) - +@triton.autotune( + configs=get_autotune_config(), + key=['sens_size', 'feat_size'], + prune_configs_by={ "early_config_prune": config_pruner} +) @triton.jit def sensesum_kernel( out_sense_ptr, @@ -109,43 +187,49 @@ def sensesum_kernel( pair_size, sens_size: tl.constexpr, feat_size: tl.constexpr, - p2_sens_size: tl.constexpr, - p2_feat_size: tl.constexpr, + SENS_BLOCK_SIZE: tl.constexpr, + FEAT_BLOCK_SIZE: tl.constexpr, dtype: tl.constexpr = tl.float32, ): pair_id = tl.program_id(axis=0) + sense_id = tl.program_id(axis=1) + num_feat_blocks: tl.constexpr = tl.cdiv(feat_size, FEAT_BLOCK_SIZE) valid_pair = pair_id < pair_size first = tl.load(pfirst_ptr + pair_id, mask=valid_pair, other=0) second = tl.load(psecond_ptr + pair_id, mask=valid_pair, other=0) - sens_block_ids = tl.arange(0, p2_sens_size) - feat_block_ids = tl.arange(0, p2_feat_size) - env_block_ids = sens_block_ids[:, None] * feat_size + feat_block_ids[None, :] + sens_block_ids = tl.arange(0, SENS_BLOCK_SIZE) + (sense_id * SENS_BLOCK_SIZE) + feat_block_ids = tl.arange(0, FEAT_BLOCK_SIZE) valid_sens = sens_block_ids < sens_size - valid_feat = feat_block_ids < feat_size - valid_env = valid_sens[:, None] & valid_feat[None, :] - - # [p2_sens_size, p2_feat_size] - env = tl.load(env_ptr + (first * sens_size * feat_size) + env_block_ids, mask=valid_env, other=0.0) - # [p2_feat_size, ] - feat = tl.load(feat_ptr + (second * feat_size) + feat_block_ids, mask=valid_feat, other=0.0) - # TODO: Here we use outer product followed by sum b/c built-in triton dot needs batches and FP<64. - # Can we make this better then? - # For future reference: - """ - type_f32: tl.constexpr = tl.float32 - type_check: tl.constexpr = (dtype == type_f32) - if type_check: - res = tl.dot(env, feat[:, None]) - else: - res = tl.sum(env * feat[None, :], axis=1) - """ - - res = tl.sum(env * feat[None, :], axis=1) + + + tmp = tl.zeros((SENS_BLOCK_SIZE,), dtype=dtype) + for feat_id in range(num_feat_blocks): + valid_feat = feat_block_ids < feat_size + env_block_ids = sens_block_ids[:, None] * feat_size + feat_block_ids[None, :] + valid_env = valid_sens[:, None] & valid_feat[None, :] + # [SENS_BLOCK_SIZE, FEAT_BLOCK_SIZE] + env = tl.load(env_ptr + (first * sens_size * feat_size) + env_block_ids, mask=valid_env, other=0.0) + # [FEAT_BLOCK_SIZE, ] + feat = tl.load(feat_ptr + (second * feat_size) + feat_block_ids, mask=valid_feat, other=0.0) + # TODO: Here we use outer product followed by sum b/c built-in triton dot needs batches and FP<64. + # Can we make this better then? + # For future reference: + """ + type_f32: tl.constexpr = tl.float32 + type_check: tl.constexpr = (dtype == type_f32) + if type_check: + res = tl.dot(env, feat[:, None]) + else: + res = tl.sum(env * feat[None, :], axis=1) + """ + tmp += tl.sum(env * feat[None, :], axis=1) + # increment the feat block id + feat_block_ids += FEAT_BLOCK_SIZE # TODO: use sparsity of sensitivities to reduce workload? 
(see numba envsum implementation) - tl.store(out_sense_ptr + (pair_id * sens_size) + sens_block_ids, res, mask=valid_sens) + tl.store(out_sense_ptr + (pair_id * sens_size) + sens_block_ids, tmp, mask=valid_sens) def sensesum(env, features, pair_first, pair_second, out_sense=None): @@ -159,14 +243,18 @@ def sensesum(env, features, pair_first, pair_second, out_sense=None): dtype = tl.float32 if features.dtype == torch.float64: dtype = tl.float64 - p2_sens_size = triton.next_power_of_2(n_nu) - p2_feat_size = triton.next_power_of_2(n_feat) - sensesum_kernel[(n_pairs,)]( - out_sense, env, features, pair_first, pair_second, n_pairs, n_nu, n_feat, p2_sens_size, p2_feat_size, dtype=dtype + + grid = lambda META: (n_pairs, triton.cdiv(n_nu, META['SENS_BLOCK_SIZE'])) + sensesum_kernel[grid]( + out_sense, env, features, pair_first, pair_second, n_pairs, n_nu, n_feat, dtype=dtype ) return out_sense - +@triton.autotune( + configs=get_autotune_config(), + key=['sens_size', 'feat_size'], + prune_configs_by={ "early_config_prune": config_pruner} +) @triton.jit def featsum_kernel( out_feat, @@ -179,36 +267,42 @@ def featsum_kernel( atom_size, sens_size: tl.constexpr, feat_size: tl.constexpr, - p2_sens_size: tl.constexpr, - p2_feat_size: tl.constexpr, + SENS_BLOCK_SIZE: tl.constexpr, + FEAT_BLOCK_SIZE: tl.constexpr, dtype: tl.constexpr = tl.float32, ): atom_id = tl.program_id(axis=0) + feat_id = tl.program_id(axis=1) + num_sense_blocks: tl.constexpr = tl.cdiv(sens_size, SENS_BLOCK_SIZE) valid_atom = atom_id < atom_size start = tl.load(atom2_starts_ptr + atom_id, mask=valid_atom, other=0) end = tl.load(atom2_starts_ptr + atom_id + 1, mask=valid_atom, other=0) target_id = tl.load(atom2_ids_ptr + atom_id, mask=valid_atom, other=0) - sens_block_ids = tl.arange(0, p2_sens_size) - feat_block_ids = tl.arange(0, p2_feat_size) - env_block_ids = sens_block_ids[:, None] * feat_size + feat_block_ids[None, :] - valid_sens = sens_block_ids < sens_size + feat_block_ids = tl.arange(0, FEAT_BLOCK_SIZE) + (feat_id * FEAT_BLOCK_SIZE) + valid_feat = feat_block_ids < feat_size - valid_env = valid_sens[:, None] & valid_feat[None, :] - tmp = tl.zeros((p2_feat_size,), dtype=dtype) + tmp = tl.zeros((FEAT_BLOCK_SIZE,), dtype=dtype) for ind in range(start, end): - # [p2_sens_size,], coming from the pair sensitivity - sense = tl.load(sens_ptr + (ind * sens_size) + sens_block_ids, mask=valid_sens, other=0.0) - atom1_ind = tl.load(pfirst_ptr + ind) - # [p2_sens_size, p2_feat_size] - env = tl.load(env_ptr + (atom1_ind * sens_size * feat_size) + env_block_ids, mask=valid_env, other=0.0) - # temp_mat and tmp is [p2_feat_size,] - temp_mat = tl.sum(env * sense[:, None], axis=0) - tmp = tmp + temp_mat + sens_block_ids = tl.arange(0, SENS_BLOCK_SIZE) + for sense_id in range(num_sense_blocks): + valid_sens = sens_block_ids < sens_size + # [SENS_BLOCK_SIZE,], coming from the pair sensitivity + sense = tl.load(sens_ptr + (ind * sens_size) + sens_block_ids, mask=valid_sens, other=0.0) + atom1_ind = tl.load(pfirst_ptr + ind) + # [SENS_BLOCK_SIZE, FEAT_BLOCK_SIZE] + env_block_ids = sens_block_ids[:, None] * feat_size + feat_block_ids[None, :] + valid_env = valid_sens[:, None] & valid_feat[None, :] + env = tl.load(env_ptr + (atom1_ind * sens_size * feat_size) + env_block_ids, mask=valid_env, other=0.0) + # temp_mat and tmp is [FEAT_BLOCK_SIZE,] + temp_mat = tl.sum(env * sense[:, None], axis=0) + tmp = tmp + temp_mat + # increment the sense block id + sens_block_ids += SENS_BLOCK_SIZE tl.store(out_feat + (target_id * feat_size) + feat_block_ids, tmp, 
mask=valid_feat) @@ -221,9 +315,9 @@ def featsum_triton(env, sense, pair_first, pair_second, atom2_ids, atom2_starts, dtype = tl.float32 if env.dtype == torch.float64: dtype = tl.float64 - p2_sens_size = triton.next_power_of_2(n_nu) - p2_feat_size = triton.next_power_of_2(n_feat) - featsum_kernel[(n_atoms_with_pairs,)]( + grid = lambda META: (n_atoms_with_pairs, triton.cdiv(n_feat, META['FEAT_BLOCK_SIZE'])) + + featsum_kernel[grid]( out_feat, env, sense, @@ -234,8 +328,6 @@ def featsum_triton(env, sense, pair_first, pair_second, atom2_ids, atom2_starts, n_atoms_with_pairs, n_nu, n_feat, - p2_sens_size, - p2_feat_size, dtype=dtype, ) return out_feat From 2a7fa6dffcc5809a4f90b9df4f275be098928064 Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers Date: Fri, 26 Jul 2024 18:28:12 -0600 Subject: [PATCH 22/24] raise better errors in old custom kernels. partially update documentation --- AUTHORS.txt | 1 + CHANGELOG.rst | 17 ++++ hippynn/custom_kernels/env_cupy.py | 10 ++ hippynn/custom_kernels/env_numba.py | 7 ++ hippynn/custom_kernels/env_triton.py | 136 ++++++++++++--------------- 5 files changed, 96 insertions(+), 75 deletions(-) diff --git a/AUTHORS.txt b/AUTHORS.txt index a5d6d24b..0379e2ef 100644 --- a/AUTHORS.txt +++ b/AUTHORS.txt @@ -18,6 +18,7 @@ Sakib Matin (LANL) Emily Shinkle (LANL) Michael G. Taylor (LANL) Jan Janssen (LANL) +Cagri Kaymak (LANL) Also thanks to testing and feedback from: diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a1d0065f..9f033278 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,20 @@ + + +Breaking changes: +----------------- + +New Features: +------------- + +- Added a new custom cuda kernel implementation using triton. These are highly performant and now the default implementation. + +Improvements: +------------- + + +Bug Fixes: +---------- + 0.0.3 ======= diff --git a/hippynn/custom_kernels/env_cupy.py b/hippynn/custom_kernels/env_cupy.py index bd0048c0..38c0ad49 100644 --- a/hippynn/custom_kernels/env_cupy.py +++ b/hippynn/custom_kernels/env_cupy.py @@ -177,6 +177,10 @@ def __call__(self, sense, feat, pfirst, psecond): array_args = sense, feat, psecond, atom1_ids, atom1_starts, env_out shape_args = n_nu, n_feat, n_interact + if n_feat > 512: + raise ValueError(f"Numba GPU custom kernels are not compatible with feature sizes greater than 512 (got {n_feat})") + + TPB_MAX = 512 TPB_X = n_feat TPB_Y = TPB_MAX // n_feat @@ -205,6 +209,9 @@ def __call__(self, env, feat, pfirst, psecond): array_args = env, feat, pfirst, psecond, sense_out shape_args = n_pairs, n_nu, n_feat + if n_nu > 512: + raise ValueError(f"Numba GPU custom kernels are not compatible with sensitivity sizes greater than 512 (got {n_nu})") + TPB_MAX = 512 TPB_Y = n_nu TPB_X = TPB_MAX // TPB_Y @@ -235,6 +242,9 @@ def __call__(self, env, sense, pfirst, psecond): array_args = env, sense, pfirst, atom2_ids, atom2_starts, feat_out shape_args = n_nu, n_feat, n_interact + if n_feat > 512: + raise ValueError(f"Cupy GPU custom kernels are not compatible with feature sizes greater than 512 (got {n_feat})") + TPB_max = 512 if n_feat > 32: TPB_x = ((n_feat + 31) // 32) * 32 diff --git a/hippynn/custom_kernels/env_numba.py b/hippynn/custom_kernels/env_numba.py index f2fe918d..1b840767 100644 --- a/hippynn/custom_kernels/env_numba.py +++ b/hippynn/custom_kernels/env_numba.py @@ -38,6 +38,8 @@ def out_shape(self, sense_shape, feat_shape, *other_shapes): def launch_bounds(self, sense_shape, fs, pfs, pss, atom1_ids_shape, *other_shapes): n_pairs, n_nu = sense_shape n_atom, n_feat = fs + if n_feat > 512: + 
raise ValueError(f"Numba GPU custom kernels are not compatible with feature sizes greater than 512 (got {n_feat})") (n_atoms_interacting,) = atom1_ids_shape TPB_MAX = 512 TPB_X = n_feat @@ -119,6 +121,9 @@ def out_shape(self, env_shape, feat_shape, pfirst_shape, psecond_shape): def launch_bounds(self, env_shape, feat_shape, pfirst_shape, psecond_shape): (n_pairs,) = pfirst_shape n_atoms, n_nu, n_feat = env_shape + if n_nu > 512: + raise ValueError(f"Numba GPU custom kernels are not compatible with sensitivity sizes greater than 512 (got {n_nu})") + TPB_MAX = 512 TPB_Y = n_nu TPB_X = TPB_MAX // TPB_Y @@ -191,6 +196,8 @@ def launch_bounds(self, env_shape, sense_shape, pfirst_shape, psecond_shape, ato n_pairs, n_nu = sense_shape n_atom, n_nu, n_feat = env_shape TPB_max = 512 + if n_feat > 512: + raise ValueError(f"Numba GPU custom kernels are not compatible with feature sizes greater than 512 (got {n_feat})") if n_feat > 32: TPB_x = ((n_feat + 31) // 32) * 32 TPB_y = TPB_max // TPB_x diff --git a/hippynn/custom_kernels/env_triton.py b/hippynn/custom_kernels/env_triton.py index c7112440..db8665de 100644 --- a/hippynn/custom_kernels/env_triton.py +++ b/hippynn/custom_kernels/env_triton.py @@ -2,88 +2,77 @@ import triton import triton.language as tl from .utils import resort_pairs_cached -import math - -# Load backup implementation for CPU tensors. -from .env_pytorch import envsum as envsum_alternative, sensesum as sensesum_alternative, featsum as featsum_alternative # If numba is available, this implementation will default to numba on CPU. If not, use vanilla pytorch. try: from .env_numba import new_envsum as envsum_alternative, new_sensesum as sensesum_alternative, new_featsum as featsum_alternative except ImportError: - pass + # Load backup implementation for CPU tensors. + from .env_pytorch import envsum as envsum_alternative, sensesum as sensesum_alternative, featsum as featsum_alternative + def config_pruner(configs, kwargs): - ''' + """ Trims the unnecessary config options based on the sens. and feat. sizes - ''' + """ p2_sens_size = triton.next_power_of_2(kwargs["sens_size"]) p2_feat_size = triton.next_power_of_2(kwargs["feat_size"]) used = set() for config in configs: + + # Don't use block sizes bigger than p2_sens_size or p2_feat_size; they will give the same result + # because there will only be one block. 
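        # Worked example (hypothetical sizes): sens_size=20 and feat_size=80 give next powers of two
        # 32 and 128, so a config asking for SENS_BLOCK_SIZE=256, FEAT_BLOCK_SIZE=256 clamps to
        # (32, 128); any other config that clamps to the same (block sizes, stages, warps) tuple is
        # skipped via the `used` set below, keeping the autotuning sweep small for small layers.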
sense_block_size = min(p2_sens_size, config.kwargs["SENS_BLOCK_SIZE"]) feat_block_size = min(p2_feat_size, config.kwargs["FEAT_BLOCK_SIZE"]) - if (sense_block_size, - feat_block_size, - config.num_stages, - config.num_warps) in used: + if (sense_block_size, feat_block_size, config.num_stages, config.num_warps) in used: continue - used.add((sense_block_size, - feat_block_size, - config.num_stages, - config.num_warps)) + + used.add((sense_block_size, feat_block_size, config.num_stages, config.num_warps)) + yield triton.Config( - { + { "SENS_BLOCK_SIZE": sense_block_size, "FEAT_BLOCK_SIZE": feat_block_size, - }, - num_stages=config.num_stages, - num_warps=config.num_warps, + }, + num_stages=config.num_stages, + num_warps=config.num_warps, ) + def get_autotune_config(): - ''' + """ Create a list of config options for the kernels TODO: Need to spend time actually figuring out more reasonable options targeted for modern GPUs - ''' + """ return [ - triton.Config({'SENS_BLOCK_SIZE': 16, 'FEAT_BLOCK_SIZE': 16}), - triton.Config({'SENS_BLOCK_SIZE': 16, 'FEAT_BLOCK_SIZE': 32}), - triton.Config({'SENS_BLOCK_SIZE': 16, 'FEAT_BLOCK_SIZE': 64}), - triton.Config({'SENS_BLOCK_SIZE': 16, 'FEAT_BLOCK_SIZE': 128}), - triton.Config({'SENS_BLOCK_SIZE': 16, 'FEAT_BLOCK_SIZE': 256}), - - triton.Config({'SENS_BLOCK_SIZE': 32, 'FEAT_BLOCK_SIZE': 32}), - triton.Config({'SENS_BLOCK_SIZE': 32, 'FEAT_BLOCK_SIZE': 64}), - triton.Config({'SENS_BLOCK_SIZE': 32, 'FEAT_BLOCK_SIZE': 128}), - triton.Config({'SENS_BLOCK_SIZE': 32, 'FEAT_BLOCK_SIZE': 128}, num_warps=8), - - triton.Config({'SENS_BLOCK_SIZE': 32, 'FEAT_BLOCK_SIZE': 256}), - triton.Config({'SENS_BLOCK_SIZE': 32, 'FEAT_BLOCK_SIZE': 256}, num_warps=8), - - - triton.Config({'SENS_BLOCK_SIZE': 64, 'FEAT_BLOCK_SIZE': 32}), - triton.Config({'SENS_BLOCK_SIZE': 64, 'FEAT_BLOCK_SIZE': 64}), - triton.Config({'SENS_BLOCK_SIZE': 64, 'FEAT_BLOCK_SIZE': 128}), - triton.Config({'SENS_BLOCK_SIZE': 64, 'FEAT_BLOCK_SIZE': 256}), - - triton.Config({'SENS_BLOCK_SIZE': 128, 'FEAT_BLOCK_SIZE': 32}), - triton.Config({'SENS_BLOCK_SIZE': 128, 'FEAT_BLOCK_SIZE': 64}), - triton.Config({'SENS_BLOCK_SIZE': 128, 'FEAT_BLOCK_SIZE': 64}, num_warps=8), - - triton.Config({'SENS_BLOCK_SIZE': 256, 'FEAT_BLOCK_SIZE': 32}), - triton.Config({'SENS_BLOCK_SIZE': 256, 'FEAT_BLOCK_SIZE': 64}), - triton.Config({'SENS_BLOCK_SIZE': 256, 'FEAT_BLOCK_SIZE': 64}, num_warps=8), + triton.Config({"SENS_BLOCK_SIZE": 16, "FEAT_BLOCK_SIZE": 16}), + triton.Config({"SENS_BLOCK_SIZE": 16, "FEAT_BLOCK_SIZE": 32}), + triton.Config({"SENS_BLOCK_SIZE": 16, "FEAT_BLOCK_SIZE": 64}), + triton.Config({"SENS_BLOCK_SIZE": 16, "FEAT_BLOCK_SIZE": 128}), + triton.Config({"SENS_BLOCK_SIZE": 16, "FEAT_BLOCK_SIZE": 256}), + triton.Config({"SENS_BLOCK_SIZE": 32, "FEAT_BLOCK_SIZE": 32}), + triton.Config({"SENS_BLOCK_SIZE": 32, "FEAT_BLOCK_SIZE": 64}), + triton.Config({"SENS_BLOCK_SIZE": 32, "FEAT_BLOCK_SIZE": 128}), + triton.Config({"SENS_BLOCK_SIZE": 32, "FEAT_BLOCK_SIZE": 128}, num_warps=8), + triton.Config({"SENS_BLOCK_SIZE": 32, "FEAT_BLOCK_SIZE": 256}), + triton.Config({"SENS_BLOCK_SIZE": 32, "FEAT_BLOCK_SIZE": 256}, num_warps=8), + triton.Config({"SENS_BLOCK_SIZE": 64, "FEAT_BLOCK_SIZE": 32}), + triton.Config({"SENS_BLOCK_SIZE": 64, "FEAT_BLOCK_SIZE": 64}), + triton.Config({"SENS_BLOCK_SIZE": 64, "FEAT_BLOCK_SIZE": 128}), + triton.Config({"SENS_BLOCK_SIZE": 64, "FEAT_BLOCK_SIZE": 256}), + triton.Config({"SENS_BLOCK_SIZE": 128, "FEAT_BLOCK_SIZE": 32}), + triton.Config({"SENS_BLOCK_SIZE": 128, "FEAT_BLOCK_SIZE": 64}), + 
triton.Config({"SENS_BLOCK_SIZE": 128, "FEAT_BLOCK_SIZE": 64}, num_warps=8), + triton.Config({"SENS_BLOCK_SIZE": 256, "FEAT_BLOCK_SIZE": 32}), + triton.Config({"SENS_BLOCK_SIZE": 256, "FEAT_BLOCK_SIZE": 64}), + triton.Config({"SENS_BLOCK_SIZE": 256, "FEAT_BLOCK_SIZE": 64}, num_warps=8), ] -@triton.autotune( - configs=get_autotune_config(), - key=['sens_size', 'feat_size'], - prune_configs_by={ "early_config_prune": config_pruner} -) + +@triton.autotune(configs=get_autotune_config(), key=["sens_size", "feat_size"], prune_configs_by={"early_config_prune": config_pruner}) @triton.jit def envsum_kernel( out_env_ptr, @@ -97,7 +86,6 @@ def envsum_kernel( feat_size: tl.constexpr, SENS_BLOCK_SIZE: tl.constexpr, FEAT_BLOCK_SIZE: tl.constexpr, - dtype: tl.constexpr = tl.float32, ): atom_id = tl.program_id(axis=0) @@ -135,17 +123,20 @@ def envsum_kernel( # TODO: use sparsity of sensitivities to reduce workload? (see numba envsum implementation) tl.store(out_env_ptr + atom_offset + env_block_ids, tmp, mask=valid_env) + def envsum_triton(sensitivities, features, pair_first, pair_second, atom_ids, atom_starts, out_env=None): n_pairs, n_nu = sensitivities.shape n_atom, n_feat = features.shape (n_atom_with_pairs,) = atom_ids.shape + if out_env is None: out_env = torch.zeros((n_atom, n_nu, n_feat), dtype=features.dtype, device=features.device) + dtype = tl.float32 if features.dtype == torch.float64: dtype = tl.float64 - grid = lambda META: (n_atom_with_pairs, triton.cdiv(n_nu, META['SENS_BLOCK_SIZE']), triton.cdiv(n_feat, META['FEAT_BLOCK_SIZE'])) + grid = lambda META: (n_atom_with_pairs, triton.cdiv(n_nu, META["SENS_BLOCK_SIZE"]), triton.cdiv(n_feat, META["FEAT_BLOCK_SIZE"])) envsum_kernel[grid]( out_env, @@ -159,8 +150,6 @@ def envsum_triton(sensitivities, features, pair_first, pair_second, atom_ids, at n_feat, dtype=dtype, ) - #print('best config') - #print(envsum_kernel.best_config) return out_env @@ -172,11 +161,8 @@ def envsum(sense, features, pfirst, psecond): resort_pairs_cached(psecond_hold, []) # Preemptively sort for backwards pass. 
return envsum_triton(sense, features, pfirst, psecond, atom1_ids, atom1_starts, out_env=None) -@triton.autotune( - configs=get_autotune_config(), - key=['sens_size', 'feat_size'], - prune_configs_by={ "early_config_prune": config_pruner} -) + +@triton.autotune(configs=get_autotune_config(), key=["sens_size", "feat_size"], prune_configs_by={"early_config_prune": config_pruner}) @triton.jit def sensesum_kernel( out_sense_ptr, @@ -203,8 +189,7 @@ def sensesum_kernel( feat_block_ids = tl.arange(0, FEAT_BLOCK_SIZE) valid_sens = sens_block_ids < sens_size - - + tmp = tl.zeros((SENS_BLOCK_SIZE,), dtype=dtype) for feat_id in range(num_feat_blocks): valid_feat = feat_block_ids < feat_size @@ -235,26 +220,24 @@ def sensesum_kernel( def sensesum(env, features, pair_first, pair_second, out_sense=None): if env.device == torch.device("cpu"): return sensesum_alternative(env, features, pair_first, pair_second) + _, n_nu, _ = env.shape n_atom, n_feat = features.shape n_pairs = len(pair_first) + if out_sense is None: out_sense = torch.zeros((n_pairs, n_nu), dtype=features.dtype, device=features.device) + dtype = tl.float32 if features.dtype == torch.float64: dtype = tl.float64 - grid = lambda META: (n_pairs, triton.cdiv(n_nu, META['SENS_BLOCK_SIZE'])) - sensesum_kernel[grid]( - out_sense, env, features, pair_first, pair_second, n_pairs, n_nu, n_feat, dtype=dtype - ) + grid = lambda META: (n_pairs, triton.cdiv(n_nu, META["SENS_BLOCK_SIZE"])) + sensesum_kernel[grid](out_sense, env, features, pair_first, pair_second, n_pairs, n_nu, n_feat, dtype=dtype) return out_sense -@triton.autotune( - configs=get_autotune_config(), - key=['sens_size', 'feat_size'], - prune_configs_by={ "early_config_prune": config_pruner} -) + +@triton.autotune(configs=get_autotune_config(), key=["sens_size", "feat_size"], prune_configs_by={"early_config_prune": config_pruner}) @triton.jit def featsum_kernel( out_feat, @@ -280,7 +263,6 @@ def featsum_kernel( end = tl.load(atom2_starts_ptr + atom_id + 1, mask=valid_atom, other=0) target_id = tl.load(atom2_ids_ptr + atom_id, mask=valid_atom, other=0) - feat_block_ids = tl.arange(0, FEAT_BLOCK_SIZE) + (feat_id * FEAT_BLOCK_SIZE) valid_feat = feat_block_ids < feat_size @@ -307,15 +289,19 @@ def featsum_kernel( def featsum_triton(env, sense, pair_first, pair_second, atom2_ids, atom2_starts, out_feat=None): + n_atom, n_nu, n_feat = env.shape (n_pairs,) = pair_first.shape (n_atoms_with_pairs,) = atom2_ids.shape + if out_feat is None: out_feat = torch.zeros((n_atom, n_feat), dtype=env.dtype, device=env.device) + dtype = tl.float32 if env.dtype == torch.float64: dtype = tl.float64 - grid = lambda META: (n_atoms_with_pairs, triton.cdiv(n_feat, META['FEAT_BLOCK_SIZE'])) + + grid = lambda META: (n_atoms_with_pairs, triton.cdiv(n_feat, META["FEAT_BLOCK_SIZE"])) featsum_kernel[grid]( out_feat, From fa5845b37eac1a9ad71561917684159d69c97eab Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers Date: Mon, 29 Jul 2024 13:11:57 -0600 Subject: [PATCH 23/24] update documentation for custom kernels --- docs/source/installation.rst | 10 ++++--- docs/source/user_guide/ckernels.rst | 45 ++++++++++++++++++++++++----- docs/source/user_guide/settings.rst | 4 +-- hippynn/custom_kernels/__init__.py | 8 +++-- 4 files changed, 50 insertions(+), 17 deletions(-) diff --git a/docs/source/installation.rst b/docs/source/installation.rst index 3f3a3a27..54384e44 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -11,20 +11,22 @@ Requirements: * pytorch_ >= 1.9 * numpy_ Optional Dependencies: - * numba_ 
(recommended, for accelerating performance)
-  * cupy_ (also for accelerating performance)
+  * triton_ (recommended, for improved GPU performance)
+  * numba_ (recommended, for improved CPU performance)
+  * cupy_ (alternative for accelerating GPU performance)
  * ASE_ (for usage with ase)
  * matplotlib_ (for plotting)
  * tqdm_ (for progress bars)
  * graphviz_ (for viewing model graphs as figures)
-  * h5py_ (for ani-h5 datasets)
-  * pyanitools_ (for ani-h5 datasets)
+  * h5py_ (for loading ani-h5 datasets)
+  * pyanitools_ (for loading ani-h5 datasets)

Interfacing codes:
  * ASE_
  * PYSEQM_
  * LAMMPS_

+.. _triton: https://triton-lang.org/
.. _numpy: https://numpy.org/
.. _Python: http://www.python.org
.. _pytorch: http://www.pytorch.org
diff --git a/docs/source/user_guide/ckernels.rst b/docs/source/user_guide/ckernels.rst
index a285de10..61fa1537 100644
--- a/docs/source/user_guide/ckernels.rst
+++ b/docs/source/user_guide/ckernels.rst
@@ -1,9 +1,20 @@
Custom Kernels
==============
-
+Bottom line up front
+--------------------
+
+We use custom kernels in `hippynn` to accelerate the HIP-NN neural network message passing.
+On the GPU, the best implementation to select is ``triton``, followed by ``cupy``,
+followed by ``numba``. On the CPU, only ``numba`` is available. In general, these
+custom kernels are very useful, and the only reasons to turn them off are that the packages
+are not available for installation in your environment, or that you are diagnosing
+whether or not a bug could be related to potential misconfiguration of these additional packages.
+
+Detailed Explanation
+--------------------
Analogs of convolutional layers that apply to continously variable points in space, such as the
-`HIP-NN` interaction layer, can be awkward to write in pure-pytorch.
+HIP-NN interaction layer, can be awkward to write in pure-pytorch.
The :mod:`~hippynn.custom_kernels` subpackage implements some more efficient kernels
for both the forward and backward pass of the sum over neighbors.
This is implemented, more or less, as a CSR-type
@@ -12,8 +23,8 @@ mixture of inner products and outer products on the remaining "feature" and "sen
This behavior can be switched off (and is off by default if the dependencies are not installed)
to revert to a pure pytorch implementation.

-The custom kernels provide `much` better memory footprint than the pure pytorch implementation,
-and a decent amount of speedup on those core operations. The memory footprint of the pytorch
+The custom kernels provide *much* better memory footprint than the pure pytorch implementation,
+and a very good amount of speedup on those core operations. The memory footprint of the pytorch
implementation is approximately:

.. math::
@@ -27,19 +38,33 @@ whereas the memory footprint of the custom kernels is approximately

    O(N_\mathrm{pairs}N_\mathrm{sensitivities} + N_\mathrm{atoms}N_\mathrm{features}N_\mathrm{sensitivities}).

-The custom kernels are implemented using ``numba`` and/or ``cupy``, depending
+The custom kernels are implemented using ``triton``, ``cupy`` and/or ``numba``, depending
on what is installed in your python environment.
However, there are certain overheads in using them.
In particular, if you are using a GPU and your batch size is small,
the pytorch implementations may actually be faster, because they launch more quickly.
-This is especially true if you use a shallower model (one interaction layer) with
+This is especially true if you use a shallow HIP-NN type model (one interaction layer) with
a small number of elements, because the memory waste in a pure pytorch implementation is
proportional to the number of input features.
-If you are using a CPU, the custom kernels are recommended at all times.
+Nonetheless, for most practical purposes, keeping custom kernels
+on at all times is computationally recommended.
+If you are using a CPU, the custom kernels are provided only using ``numba``, but they
+do not come with any large overheads, and so provide computational benefits at all times.
+The only reason to turn custom kernels off, in general, is to diagnose whether there are
+issues with how they are being deployed; if ``numba`` or ``cupy`` is not correctly installed,
+then we have found that sometimes the kernels may silently fail.

The three custom kernels correspond to the interaction sum in hip-nn:

-For envsum, sensum, featsum:
+.. math::
+
+    a'_{i,a} = \sum_{\nu,b} V^\nu_{a,b} e^{\nu}_{i,b}
+
+    e^{\nu}_{i,a} = \sum_p s^\nu_{p} z_{p_j,a}
+
+Where :math:`a'` is the pre-activation for an interaction layer using input features :math:`z`.
+
+For envsum, sensesum, featsum:
.. math::

@@ -49,6 +74,10 @@ For envsum, sensum, featsum:

    f_{j,a} = \sum_{\nu,i} e_{p_i,\nu,a} s_{p_i,a}

+These three functions form a closed system under automatic differentiation, and are linked to each
+other in pytorch's autograd, thereby supporting custom kernels in backwards passes and in
+double-backwards passes associated with force training or similar features.
+
Custom kernels can be set ahead of time using :doc:`/user_guide/settings` and dynamically
using :func:`~hippynn.custom_kernels.set_custom_kernels`.

diff --git a/docs/source/user_guide/settings.rst b/docs/source/user_guide/settings.rst
index 949f2377..a9692981 100644
--- a/docs/source/user_guide/settings.rst
+++ b/docs/source/user_guide/settings.rst
@@ -45,8 +45,8 @@ The following settings are available:
     - false
     - Yes
   * - USE_CUSTOM_KERNELS
-     - Use custom kernels with numba or cupy. Auto tries to detect the installation of numba or cupy. For more info see :doc:`/user_guide/ckernels`.
-     - auto, true, false, pytorch, numba, cupy
+     - Use custom kernels with triton, numba or cupy. Auto tries to detect the installation. For more info see :doc:`/user_guide/ckernels`.
+     - auto, true, false, pytorch, numba, cupy, triton
     - auto
     - Not directly, use :func:`~hippynn.custom_kernels.set_custom_kernels`
   * - WARN_LOW_DISTANCES
diff --git a/hippynn/custom_kernels/__init__.py b/hippynn/custom_kernels/__init__.py
index 87f3b2f6..cd7ca9d3 100644
--- a/hippynn/custom_kernels/__init__.py
+++ b/hippynn/custom_kernels/__init__.py
@@ -42,7 +42,8 @@
    pass

if not CUSTOM_KERNELS_AVAILABLE:
-    warnings.warn("Numba or cupy not available: Custom Kernels will be disabled.")
+    warnings.warn(
+        "Triton, cupy and numba are not available: Custom kernels will be disabled and performance may be degraded.")

CUSTOM_KERNELS_ACTIVE = False

@@ -82,7 +83,7 @@ def set_custom_kernels(active: Union[bool, str] = True):
    Activate or deactivate custom kernels for interaction.

    :param active: If true, set custom kernels to the best available. If False, turn them off and default to pytorch.
-        If "numba" or "cupy", use those implementations explicitly. If "auto", use best available.
+        If "triton", "numba" or "cupy", use those implementations explicitly. If "auto", use best available.
    :return: None
    """
    global envsum, sensesum, featsum, CUSTOM_KERNELS_ACTIVE
@@ -98,7 +99,8 @@ def set_custom_kernels(active: Union[bool, str] = True):
        if active == "auto" or active == "pytorch":
            active = False
        elif active:
-            raise RuntimeError("Numba or cupy was not found. Custom kernels are not available, but they were required by library settings.")
+            raise RuntimeError(
+                "Triton, numba and cupy were not found. Custom kernels are not available, but they were required by library settings.")
    else:
        active = active_map.get(active, active)

From aa2c83371af9311070fd4aa4faf969b95f1409de Mon Sep 17 00:00:00 2001
From: Nicholas Lubbers
Date: Mon, 29 Jul 2024 13:18:48 -0600
Subject: [PATCH 24/24] update requirements/docs

---
 conda_requirements.txt              | 3 ++-
 docs/source/user_guide/ckernels.rst | 2 ++
 optional_dependencies.txt           | 1 +
 3 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/conda_requirements.txt b/conda_requirements.txt
index b64587dc..590e5ca3 100644
--- a/conda_requirements.txt
+++ b/conda_requirements.txt
@@ -1,5 +1,6 @@
numpy
-pytorch >= 1.6
+pytorch >= 1.9
+torchtriton
matplotlib
numba
cupy
diff --git a/docs/source/user_guide/ckernels.rst b/docs/source/user_guide/ckernels.rst
index 61fa1537..c810bbcd 100644
--- a/docs/source/user_guide/ckernels.rst
+++ b/docs/source/user_guide/ckernels.rst
@@ -10,6 +10,8 @@ followed by ``numba``. On the CPU, only ``numba`` is available. In general, thes
custom kernels are very useful, and the only reasons to turn them off are that the packages
are not available for installation in your environment, or that you are diagnosing
whether or not a bug could be related to potential misconfiguration of these additional packages.
+``triton`` comes with recent versions of ``pytorch``, so you may already be
+configured to use the custom kernels.

Detailed Explanation
--------------------
diff --git a/optional_dependencies.txt b/optional_dependencies.txt
index 2b99196c..7235c67c 100644
--- a/optional_dependencies.txt
+++ b/optional_dependencies.txt
@@ -5,3 +5,4 @@ matplotlib
tqdm
graphviz
h5py
+triton
\ No newline at end of file
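
For readers who want the semantics of the three operations without the triton or numba machinery, the
following pure-PyTorch sketch is consistent with the shapes and equations used throughout these patches.
It is only an illustration (the ``*_reference`` names are hypothetical, not the library's ``env_pytorch``
module), and it makes explicit the intermediate pair tensors whose
O(N_pairs * N_sensitivities * N_features) memory the custom kernels avoid::

    import torch

    def envsum_reference(sense, features, pfirst, psecond):
        # sense: (n_pairs, n_nu); features: (n_atom, n_feat); pfirst/psecond: (n_pairs,) int64 atom indices.
        n_pairs, n_nu = sense.shape
        n_atom, n_feat = features.shape
        # Outer product per pair: this (n_pairs, n_nu, n_feat) tensor is the memory cost the kernels avoid.
        pair_env = sense.unsqueeze(2) * features[psecond].unsqueeze(1)
        env = torch.zeros(n_atom, n_nu, n_feat, dtype=features.dtype, device=features.device)
        env.index_add_(0, pfirst, pair_env)  # scatter-add onto the first atom of each pair
        return env

    def sensesum_reference(env, features, pfirst, psecond):
        # Contract the feature axis for each pair -> (n_pairs, n_nu).
        return (env[pfirst] * features[psecond].unsqueeze(1)).sum(dim=2)

    def featsum_reference(env, sense, pfirst, psecond):
        n_atom, n_nu, n_feat = env.shape
        # Contract the sensitivity axis per pair -> (n_pairs, n_feat).
        pair_feat = (env[pfirst] * sense.unsqueeze(2)).sum(dim=1)
        out = torch.zeros(n_atom, n_feat, dtype=env.dtype, device=env.device)
        out.index_add_(0, psecond, pair_feat)  # scatter-add onto the second atom of each pair
        return out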
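
Selecting an implementation at runtime, as described in the updated ckernels.rst and the
``set_custom_kernels`` docstring above, can look like the following minimal sketch (the accepted
values are those listed in that docstring)::

    from hippynn.custom_kernels import set_custom_kernels

    set_custom_kernels("triton")  # explicitly use the triton implementation
    set_custom_kernels("numba")   # or the numba implementation
    set_custom_kernels(False)     # fall back to pure pytorch, e.g. while diagnosing a kernel issue
    set_custom_kernels("auto")    # let hippynn pick the best available backend

Ahead-of-time selection goes through the USE_CUSTOM_KERNELS setting described in settings.rst.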