From 39b7bb2baeb4a2fd63d9ca706fcf966c5a3546f9 Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers <56895592+lubbersnick@users.noreply.github.com> Date: Wed, 28 Aug 2024 17:09:53 -0600 Subject: [PATCH] Add units to lammps interface and fix kokkos+cpu bug. (#90) * Add tentative fix for kokkos CPU, and add units to lammps interface * Fixes to mliap to enable usage on other hardware Fixes to triton kernel * Fixed spelling --------- Co-authored-by: Ben Nebgen --- hippynn/custom_kernels/env_triton.py | 9 +-- .../lammps_interface/mliap_interface.py | 78 +++++++++++++++---- 2 files changed, 65 insertions(+), 22 deletions(-) diff --git a/hippynn/custom_kernels/env_triton.py b/hippynn/custom_kernels/env_triton.py index db8665de..266318d7 100644 --- a/hippynn/custom_kernels/env_triton.py +++ b/hippynn/custom_kernels/env_triton.py @@ -10,13 +10,13 @@ # Load backup implementation for CPU tensors. from .env_pytorch import envsum as envsum_alternative, sensesum as sensesum_alternative, featsum as featsum_alternative - -def config_pruner(configs, kwargs): +def config_pruner(configs, nargs, **kwargs): """ Trims the unnecessary config options based on the sens. and feat. 
sizes """ - p2_sens_size = triton.next_power_of_2(kwargs["sens_size"]) - p2_feat_size = triton.next_power_of_2(kwargs["feat_size"]) + #print("For some reason the config pruner also gets arguments:",kwargs) + p2_sens_size = triton.next_power_of_2(nargs["sens_size"]) + p2_feat_size = triton.next_power_of_2(nargs["feat_size"]) used = set() for config in configs: @@ -40,7 +40,6 @@ def config_pruner(configs, kwargs): num_warps=config.num_warps, ) - def get_autotune_config(): """ Create a list of config options for the kernels diff --git a/hippynn/interfaces/lammps_interface/mliap_interface.py b/hippynn/interfaces/lammps_interface/mliap_interface.py index 53833622..2ce813fc 100644 --- a/hippynn/interfaces/lammps_interface/mliap_interface.py +++ b/hippynn/interfaces/lammps_interface/mliap_interface.py @@ -27,12 +27,27 @@ class MLIAPInterface(MLIAPUnified): Class for creating ML-IAP Unified model based on hippynn graphs. """ - def __init__(self, energy_node, element_types, ndescriptors=1, model_device=torch.device("cpu"), compute_dtype=torch.float32): + def __init__( + self, + energy_node, + element_types, + ndescriptors=1, + model_device=torch.device("cpu"), + compute_dtype=torch.float32, + energy_unit: float = None, + distance_unit: float = None, + ): """ :param energy_node: Node for energy :param element_types: list of atomic symbols corresponding to element types :param ndescriptors: the number of descriptors to report to LAMMPS :param model_device: the device to send torch data to (cpu or cuda) + :param energy_unit: If present, multiply the result by the given energy units. + If your model was trained in Hartree and your lammps script will operate in eV, + use energy_unit = ase.units.Ha = 27.211386024367243 + :param distance_unit: If present, multiply input distances by this much and divide the output forces by it. + If your model was trained to accept nm as input and lammps uses Angstroms, + use distance_unit = ase.units.nm = 10. 
""" super().__init__() if hippynn.settings.PYTORCH_GPU_MEM_FRAC < 1.0: @@ -40,6 +55,8 @@ def __init__(self, energy_node, element_types, ndescriptors=1, model_device=torc self.element_types = element_types self.ndescriptors = ndescriptors self.model_device = model_device + self.energy_unit = energy_unit + self.distance_unit = distance_unit # Build the calculator self.rcutfac, self.species_set, self.graph = setup_LAMMPS_graph(energy_node) @@ -56,8 +73,8 @@ def compute_descriptors(self, data): def as_tensor(self, array): return torch.as_tensor(array, device=self.model_device) - def empty_tensor(self,dimentions): - return torch.empty(dimentions,device=self.model_device) + def empty_tensor(self, dimentions): + return torch.empty(dimentions, device=self.model_device) def compute_forces(self, data): """ @@ -67,10 +84,11 @@ def compute_forces(self, data): """ nlocal = self.as_tensor(data.nlistatoms) if nlocal.item() > 0: - #If there are no local atoms, do nothing + # If there are no local atoms, do nothing elems = self.as_tensor(data.elems).type(torch.int64).reshape(1, data.ntotal) z_vals = self.species_set[elems + 1] npairs = data.npairs + if npairs > 0: pair_i = self.as_tensor(data.pair_i).type(torch.int64) pair_j = self.as_tensor(data.pair_j).type(torch.int64) @@ -78,33 +96,54 @@ def compute_forces(self, data): else: pair_i = self.empty_tensor(0).type(torch.int64) pair_j = self.empty_tensor(0).type(torch.int64) - rij = self.empty_tensor([0,3]).type(self.compute_dtype) - + rij = self.empty_tensor([0, 3]).type(self.compute_dtype) + + if self.distance_unit is not None: + rij = self.distance_unit * rij + # note your sign for rij might need to be +1 or -1, depending on how your implementation works inputs = [z_vals, pair_i, pair_j, -rij, nlocal] atom_energy, total_energy, fij = self.graph(*inputs) - # Test if we are using lammps-kokkos or not. Is there a more clear way to do that? 
- if isinstance(data.elems, np.ndarray): - return_device = "cpu" - else: - # Hope that kokkos device and pytorch device are the same (default cuda) + using_kokkos = "kokkos" in data.__class__.__module__.lower() + if using_kokkos: return_device = elems.device - + else: + return_device = "cpu" + + # convert units + if self.energy_unit is not None: + atom_energy = self.energy_unit * atom_energy + total_energy = self.energy_unit * total_energy + fij = self.energy_unit * fij + + if self.distance_unit is not None: + fij = fij / self.distance_unit + atom_energy = atom_energy.squeeze(1).detach().to(return_device) total_energy = total_energy.detach().to(return_device) f = self.as_tensor(data.f) fij = fij.type(f.dtype).detach().to(return_device) - - if return_device == "cpu": + + # hacky way to detect if we are in kokkos or not. + + if not using_kokkos: + # write back to data.eatoms directly. fij = fij.numpy() data.eatoms = atom_energy.numpy().astype(np.double) + if npairs > 0: + data.update_pair_forces(fij) else: + # view to data.eatoms using pytorch, and write into the view. eatoms = torch.as_tensor(data.eatoms, device=return_device) eatoms.copy_(atom_energy) - if npairs > 0: - data.update_pair_forces(fij) + if npairs > 0: + if return_device == "cpu": + data.update_pair_forces_cpu(fij) + else: + data.update_pair_forces_gpu(fij) + data.energy = total_energy.item() def __getstate__(self): @@ -116,11 +155,16 @@ def __setstate__(self, state): self.__dict__.update(state) try: torch.ones(0).to(self.model_device) - except (RuntimeError, AssertionError): + except RuntimeError: fallback = device_fallback() warnings.warn(f"Model device ({self.model_device}) not found, falling back to f{fallback}") self.model_device = fallback + if not hasattr(self, "energy_unit"): + self.energy_unit = None + if not hasattr(self, "distance_unit"): + self.distance_unit = None + self.species_set = self.species_set.to(self.model_device) self.graph.to(self.model_device)