diff --git a/.flexci/build_and_push.sh b/.flexci/build_and_push.sh
index 2e8a018..75a00a5 100644
--- a/.flexci/build_and_push.sh
+++ b/.flexci/build_and_push.sh
@@ -32,10 +32,10 @@ docker_build_and_push() {
 WAIT_PIDS=""
 
 # PyTorch 1.5 + Python 3.6
-docker_build_and_push torch15 \
+docker_build_and_push torch19 \
     --build-arg base_image="nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04" \
-    --build-arg python_version="3.6.12" \
-    --build-arg pip_packages="torch==1.5.* torchvision==0.6.* ${TEST_PIP_PACKAGES}" &
+    --build-arg python_version="3.7.9" \
+    --build-arg pip_packages="torch==1.9.1 ${TEST_PIP_PACKAGES}" &
 WAIT_PIDS="$! ${WAIT_PIDS}"
 
 # Wait until the build complete.
diff --git a/.flexci/pytest_script.sh b/.flexci/pytest_script.sh
index 0a0960b..8ffba7f 100644
--- a/.flexci/pytest_script.sh
+++ b/.flexci/pytest_script.sh
@@ -2,7 +2,7 @@
 set -eu
 
 #IMAGE=pytorch/pytorch:1.5.1-cuda10.1-cudnn7-devel
-IMAGE=asia.gcr.io/pfn-public-ci/torch-dftd-ci:torch15
+IMAGE=asia.gcr.io/pfn-public-ci/torch-dftd-ci:torch19
 
 main() {
@@ -20,7 +20,7 @@ main() {
   docker run --runtime=nvidia --rm --volume="$(pwd)":/workspace -w /workspace \
     ${IMAGE} \
     bash -x -c "pip install flake8 pytest pytest-cov pytest-xdist pytest-benchmark && \
-      pip install cupy-cuda102 pytorch-pfn-extras!=0.5.0 && \
+      pip install cupy-cuda102 pytorch-pfn-extras==0.4.2 && \
       pip install -e .[develop] && \
       pysen run lint && \
       pytest --cov=torch_dftd -n $(nproc) -m 'not slow' tests &&
diff --git a/README.md b/README.md
index 50ecbea..8a5ba44 100644
--- a/README.md
+++ b/README.md
@@ -34,14 +34,14 @@ print(f"forces {forces}")
 ## Dependency
 
 The library is tested under following environment.
- - python: 3.6
+ - python: 3.7
  - CUDA: 10.2
 
 ```bash
-torch==1.5.1
+torch==1.9.1
 ase==3.21.1
 # Below is only for 3-body term
-cupy-cuda102==8.6.0
-pytorch-pfn-extras==0.3.2
+cupy-cuda102==9.5.0
+pytorch-pfn-extras==0.4.2
 ```
 
 ## Development tips
diff --git a/torch_dftd/dftd3_xc_params.py b/torch_dftd/dftd3_xc_params.py
index 00c806c..6789deb 100644
--- a/torch_dftd/dftd3_xc_params.py
+++ b/torch_dftd/dftd3_xc_params.py
@@ -535,7 +535,7 @@ def get_dftd3_default_params(
     rs6 = 1.1
     s18 = 0.0
     alp = 20.0
-    rs18 = None  # Not used.
+    rs18 = 0.0  # It is a DUMMY value. Not used.
     if xc == "b-lyp":
         s6 = 1.2
     elif xc == "b-p":
diff --git a/torch_dftd/functions/dftd2.py b/torch_dftd/functions/dftd2.py
index 13e9347..513bc3a 100644
--- a/torch_dftd/functions/dftd2.py
+++ b/torch_dftd/functions/dftd2.py
@@ -6,6 +6,7 @@
 from torch_dftd.functions.smoothing import poly_smoothing
 
 
+@torch.jit.script
 def edisp_d2(
     Z: Tensor,
     r: Tensor,
@@ -44,7 +45,7 @@ def edisp_d2(
     r2 = r ** 2
     r6 = r2 ** 3
 
-    idx_i, idx_j = edge_index
+    idx_i, idx_j = edge_index[0], edge_index[1]
     # compute all necessary quantities
     Zi = Z[idx_i]  # (n_edges,)
     Zj = Z[idx_j]
@@ -71,6 +72,7 @@ def edisp_d2(
         # (1,)
         g = e6.sum()[None]
     else:
+        assert batch is not None
         # (n_graphs,)
         if batch.size()[0] == 0:
             n_graphs = 1
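Note on the `edge_index[0], edge_index[1]` rewrites and the added `assert ... is not None` statements in the hunks above and below: both are TorchScript requirements rather than behaviour changes. A minimal standalone sketch (not part of this patch; the function name and arguments are invented for illustration) that compiles under `torch.jit.script` for the same reasons:

    import torch
    from typing import Optional
    from torch import Tensor

    @torch.jit.script
    def pair_sum(edge_index: Tensor, values: Tensor, batch: Optional[Tensor] = None) -> Tensor:
        # TorchScript cannot unpack a Tensor ("idx_i, idx_j = edge_index" fails to
        # compile), so the two rows are indexed explicitly instead.
        idx_i, idx_j = edge_index[0], edge_index[1]
        out = values[idx_i] + values[idx_j]
        if batch is not None:
            # An `is not None` check (or an assert) refines Optional[Tensor] to Tensor,
            # which is required before the value can be indexed in scripted code.
            out = out + batch[idx_i].to(out.dtype)
        return out
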
diff --git a/torch_dftd/functions/dftd3.py b/torch_dftd/functions/dftd3.py
index dc5d61d..f69667f 100644
--- a/torch_dftd/functions/dftd3.py
+++ b/torch_dftd/functions/dftd3.py
@@ -9,8 +9,8 @@
 
 # conversion factors used in grimme d3 code
-d3_autoang = 0.52917726  # for converting distance from bohr to angstrom
-d3_autoev = 27.21138505  # for converting a.u. to eV
+d3_autoang: float = 0.52917726  # for converting distance from bohr to angstrom
+d3_autoev: float = 27.21138505  # for converting a.u. to eV
 
 d3_k1 = 16.000
 d3_k2 = 4 / 3
@@ -18,6 +18,7 @@
 d3_maxc = 5  # maximum number of coordination complexes
 
 
+@torch.jit.script
 def _ncoord(
     Z: Tensor,
     r: Tensor,
@@ -53,7 +54,7 @@ def _ncoord(
     Zi = Z[idx_i]
     Zj = Z[idx_j]
     rco = rcov[Zi] + rcov[Zj]  # (n_edges,)
-    rr = rco.type(r.dtype) / r
+    rr = rco.to(r.dtype) / r
     damp = 1.0 / (1.0 + torch.exp(-k1 * (rr - 1.0)))
     if cutoff is not None and cutoff_smoothing == "poly":
         damp *= poly_smoothing(r, cutoff)
@@ -66,6 +67,7 @@ def _ncoord(
     return g  # (n_atoms,)
 
 
+@torch.jit.script
 def _getc6(
     Zi: Tensor, Zj: Tensor, nci: Tensor, ncj: Tensor, c6ab: Tensor, k3: float = d3_k3
 ) -> Tensor:
@@ -84,7 +86,7 @@ def _getc6(
     """
     # gather the relevant entries from the table
     # c6ab (95, 95, 5, 5, 3) --> c6ab_ (n_edges, 5, 5, 3)
-    c6ab_ = c6ab[Zi, Zj].type(nci.dtype)
+    c6ab_ = c6ab[Zi, Zj].to(nci.dtype)
 
     # calculate c6 coefficients
     # cn0, cn1, cn2 (n_edges, 5, 5)
@@ -104,6 +106,7 @@ def _getc6(
     return c6
 
 
+@torch.jit.script
 def edisp(
     Z: Tensor,
     r: Tensor,
@@ -120,12 +123,12 @@ def edisp(
     shift_pos: Optional[Tensor] = None,
     pos: Optional[Tensor] = None,
     cell: Optional[Tensor] = None,
-    r2=None,
-    r6=None,
-    r8=None,
-    k1=d3_k1,
-    k2=d3_k2,
-    k3=d3_k3,
+    r2: Optional[Tensor] = None,
+    r6: Optional[Tensor] = None,
+    r8: Optional[Tensor] = None,
+    k1: float = d3_k1,
+    k2: float = d3_k2,
+    k3: float = d3_k3,
     cutoff_smoothing: str = "none",
     damping: str = "zero",
     bidirectional: bool = False,
@@ -146,7 +149,7 @@ def edisp(
         cnthr (float or None): cutoff distance for coordination number calculation in **bohr**
         batch (Tensor or None): (n_atoms,)
         batch_edge (Tensor or None): (n_edges,)
-        shift_pos (Tensor or None): (n_atoms,) used to calculate 3-body term when abc=True
+        shift_pos (Tensor or None): (n_edges, 3) used to calculate 3-body term when abc=True
         pos (Tensor): (n_atoms, 3) position in **bohr**
         cell (Tensor): (3, 3) cell size in **bohr**
         r2 (Tensor or None):
@@ -171,7 +174,7 @@ def edisp(
     if r8 is None:
         r8 = r6 * r2
 
-    idx_i, idx_j = edge_index
+    idx_i, idx_j = edge_index[0], edge_index[1]
     # compute all necessary quantities
     Zi = Z[idx_i]  # (n_edges,)
     Zj = Z[idx_j]
@@ -192,7 +195,7 @@ def edisp(
     ncj = nc[idx_j]
     c6 = _getc6(Zi, Zj, nci, ncj, c6ab=c6ab, k3=k3)  # c6 coefficients
 
-    c8 = 3 * c6 * r2r4[Zi].type(c6.dtype) * r2r4[Zj].type(c6.dtype)  # c8 coefficient
+    c8 = 3 * c6 * r2r4[Zi].to(c6.dtype) * r2r4[Zj].to(c6.dtype)  # c8 coefficient
 
     s6 = params["s6"]
     s8 = params["s18"]
@@ -250,6 +253,7 @@ def edisp(
         g = e68.to(torch.float64).sum()[None]
     else:
         # (n_graphs,)
+        assert batch is not None
         if batch.size()[0] == 0:
             n_graphs = 1
         else:
@@ -261,6 +265,7 @@ def edisp(
         g *= 2.0
 
     if abc:
+        assert cnthr is not None
         within_cutoff = r <= cnthr
         # r_abc = r[within_cutoff]
         # r2_abc = r2[within_cutoff]
@@ -282,12 +287,7 @@ def edisp(
         # (n_edges, ) -> (n_edges * 2, )
         shift_abc = None if shift_abc is None else torch.cat([shift_abc, -shift_abc], dim=0)
         with torch.no_grad():
-            # triplet_node_index, triplet_edge_index = calc_triplets_cycle(edge_index_abc, n_atoms, shift=shift_abc)
-            # Type hinting
-            triplet_node_index: Tensor
-            multiplicity: Tensor
-            edge_jk: Tensor
-            batch_triplets: Optional[Tensor]
+            assert pos is not None
             triplet_node_index, multiplicity, edge_jk, batch_triplets = calc_triplets(
                 edge_index_abc,
                 shift_pos=shift_abc,
@@ -303,7 +303,6 @@ def edisp(
             )
             r_jk = calc_distances(pos, torch.stack([idx_j, idx_k], dim=0), cell, shift_jk)
             kj_within_cutoff = r_jk <= cnthr
-            del shift_jk
 
             triplet_node_index = triplet_node_index[kj_within_cutoff]
             multiplicity, edge_jk, batch_triplets = (
@@ -355,5 +354,6 @@ def edisp(
             e6abc = e3.to(torch.float64).sum()
             g += e6abc
         else:
+            assert batch_triplets is not None
             g.scatter_add_(0, batch_triplets, e3.to(torch.float64))
     return g  # (n_graphs,)
diff --git a/torch_dftd/functions/distance.py b/torch_dftd/functions/distance.py
index e177bdb..90a4f0b 100644
--- a/torch_dftd/functions/distance.py
+++ b/torch_dftd/functions/distance.py
@@ -4,12 +4,13 @@
 from torch import Tensor
 
 
+@torch.jit.script
 def calc_distances(
     pos: Tensor,
     edge_index: Tensor,
     cell: Optional[Tensor] = None,
     shift_pos: Optional[Tensor] = None,
-    eps=1e-20,
+    eps: float = 1e-20,
 ) -> Tensor:
     """Distance calculation function.
 
@@ -17,6 +18,7 @@ def calc_distances(
         pos (Tensor): (n_atoms, 3) atom positions.
         edge_index (Tensor): (2, n_edges) edge_index for graph.
         cell (Tensor): cell size, None for non periodic system.
+            This is NOT USED now; it is left for backward compatibility.
         shift_pos (Tensor): (n_edges, 3) position shift vectors of edges owing to the periodic boundary.
             It should be length unit.
         eps (float): Small float value to avoid NaN in backward when the distance is 0.
@@ -25,11 +27,11 @@ def calc_distances(
 
     """
 
-    idx_i, idx_j = edge_index
+    idx_i, idx_j = edge_index[0], edge_index[1]
     # calculate interatomic distances
     Ri = pos[idx_i]
     Rj = pos[idx_j]
-    if cell is not None:
+    if shift_pos is not None:
         Rj += shift_pos
     # eps is to avoid Nan in backward when Dij = 0 with sqrt.
     Dij = torch.sqrt(torch.sum((Ri - Rj) ** 2, dim=-1) + eps)
diff --git a/torch_dftd/functions/smoothing.py b/torch_dftd/functions/smoothing.py
index 1f8595f..e993e39 100644
--- a/torch_dftd/functions/smoothing.py
+++ b/torch_dftd/functions/smoothing.py
@@ -2,12 +2,13 @@
 from torch import Tensor
 
 
+@torch.jit.script
 def poly_smoothing(r: Tensor, cutoff: float) -> Tensor:
     """Computes a smooth step from 1 to 0 starting at 1 bohr before the cutoff
 
     Args:
         r (Tensor): (n_edges,)
-        cutoff (float): ()
+        cutoff (float): cutoff length
 
     Returns:
         r (Tensor): Smoothed `r`
diff --git a/torch_dftd/functions/triplets.py b/torch_dftd/functions/triplets.py
index a21be85..50c3297 100644
--- a/torch_dftd/functions/triplets.py
+++ b/torch_dftd/functions/triplets.py
@@ -1,72 +1,22 @@
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 from torch import Tensor
 from torch_dftd.functions.triplets_kernel import _calc_triplets_core_gpu
 
 
-def calc_triplets(
-    edge_index: Tensor,
-    shift_pos: Optional[Tensor] = None,
-    dtype=torch.float32,
-    batch_edge: Optional[Tensor] = None,
+@torch.jit.script
+def _calc_triplets_core(
+    counts: Tensor,
+    unique: Tensor,
+    dst: Tensor,
+    edge_indices: Tensor,
+    batch_edge: Tensor,
+    counts_cumsum: Tensor,
+    dtype: torch.dtype = torch.float32,
 ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
-    """Calculate triplet edge index.
-
-    Args:
-        edge_index (Tensor): (2, n_edges) edge_index for graph. It must be bidirectional edge.
-        shift_pos (Tensor or None): (n_edges, 3) used to calculate unique atoms when pbc=True.
-        dtype: dtype for `multiplicity`
-        batch_edge (Tensor or None): Specify batch indices for `edge_index`.
-
-    Returns:
-        triplet_node_index (Tensor): (3, n_triplets) index for node `i`, `j`, `k` respectively.
-            i.e.: idx_i, idx_j, idx_k = triplet_node_index
-        multiplicity (Tensor): (n_triplets,) multiplicity indicates duplication of same triplet pair.
-            It only takes 1 in non-pbc, but it takes 2 or 3 in pbc case. dtype is specified in the argument.
-        edge_jk (Tensor): (n_triplet_edges, 2=(j, k)) edge indices for j and k.
-            i.e.: idx_j, idx_k = edge_jk[:, 0], edge_jk[:, 1]
-        batch_triplets (Tensor): (n_triplets,) batch indices for each triplets.
-    """
-    dst, src = edge_index
-    is_larger = dst >= src
-    dst = dst[is_larger]
-    src = src[is_larger]
-    # sort `src`
-    sort_inds = torch.argsort(src)
-    src = src[sort_inds]
-    dst = dst[sort_inds]
-
-    if shift_pos is None:
-        edge_indices = torch.arange(src.shape[0], dtype=torch.long, device=edge_index.device)
-    else:
-        edge_indices = torch.arange(shift_pos.shape[0], dtype=torch.long, device=edge_index.device)
-    edge_indices = edge_indices[is_larger][sort_inds]
-
-    if batch_edge is None:
-        batch_edge = torch.zeros(src.shape[0], dtype=torch.long, device=edge_index.device)
-    else:
-        batch_edge = batch_edge[is_larger][sort_inds]
-
-    unique, counts = torch.unique_consecutive(src, return_counts=True)
-    counts_cumsum = torch.cumsum(counts, dim=0)
-    counts_cumsum = torch.cat(
-        [torch.zeros((1,), device=counts.device, dtype=torch.long), counts_cumsum], dim=0
-    )
-
-    if str(unique.device) == "cpu":
-        return _calc_triplets_core(
-            counts, unique, dst, edge_indices, batch_edge, counts_cumsum, dtype=dtype
-        )
-    else:
-        return _calc_triplets_core_gpu(
-            counts, unique, dst, edge_indices, batch_edge, counts_cumsum, dtype=dtype
-        )
-
-
-def _calc_triplets_core(counts, unique, dst, edge_indices, batch_edge, counts_cumsum, dtype):
     device = unique.device
-    n_triplets = torch.sum(counts * (counts - 1) / 2)
+    n_triplets = int(torch.sum(counts * (counts - 1) / 2))
     if n_triplets == 0:
         # (n_triplet_edges, 3)
         triplet_node_index = torch.zeros((0, 3), dtype=torch.long, device=device)
@@ -78,20 +28,26 @@ def _calc_triplets_core(counts, unique, dst, edge_indices, batch_edge, counts_cu
         batch_triplets = torch.zeros((0,), dtype=torch.long, device=device)
         return triplet_node_index, multiplicity, edge_jk, batch_triplets
 
-    triplet_node_index_list = []  # (n_triplet_edges, 3)
-    edge_jk_list = []  # (n_triplet_edges, 2) represents j and k indices
-    multiplicity_list = []  # (n_triplet_edges) represents multiplicity
-    batch_triplets_list = []  # (n_triplet_edges) represents batch index for triplets
+    triplet_node_index_list: List[List[int]] = []  # (n_triplet_edges, 3)
+    edge_jk_list: List[List[int]] = []  # (n_triplet_edges, 2) represents j and k indices
+    multiplicity_list: List[float] = []  # (n_triplet_edges) represents multiplicity
+    batch_triplets_list: List[int] = []  # (n_triplet_edges) represents batch index for triplets
+
+    unique_list: List[int] = unique.tolist()
+    dst_list: List[int] = dst.tolist()
+    counts_list: List[int] = counts.tolist()
+    counts_cumsum_list: List[int] = counts_cumsum.tolist()
+    batch_edge_list: List[int] = batch_edge.tolist()
     for i in range(len(unique)):
-        _src = unique[i].item()
-        _n_edges = counts[i].item()
-        _dst = dst[counts_cumsum[i] : counts_cumsum[i + 1]]
-        _offset = counts_cumsum[i].item()
-        _batch_index = batch_edge[counts_cumsum[i]].item()
+        _src: int = unique_list[i]
+        _n_edges: int = counts_list[i]
+        _dst: List[int] = dst_list[counts_cumsum_list[i] : counts_cumsum_list[i + 1]]
+        _offset = counts_cumsum_list[i]
+        _batch_index = batch_edge_list[counts_cumsum_list[i]]
         for j in range(_n_edges - 1):
             for k in range(j + 1, _n_edges):
-                _dst0 = _dst[j].item()  # _dst0 maybe swapped with _dst1, need to reset here.
-                _dst1 = _dst[k].item()
+                _dst0: int = _dst[j]  # _dst0 maybe swapped with _dst1, need to reset here.
+                _dst1: int = _dst[k]
                 batch_triplets_list.append(_batch_index)
                 # --- triplet_node_index_list & shift_list in sorted way... ---
                # sort order to be _src <= _dst0 <= _dst1, and i <= _j <= _k
@@ -135,3 +91,63 @@ def _calc_triplets_core(counts, unique, dst, edge_indices, batch_edge, counts_cu
     # (n_triplet_edges, 3=(ij, ik, jk), 3=(xyz) )
     batch_triplets = torch.as_tensor(batch_triplets_list, dtype=torch.long, device=device)
     return triplet_node_index, multiplicity, edge_jk, batch_triplets
+
+
+@torch.jit.script
+def calc_triplets(
+    edge_index: Tensor,
+    shift_pos: Optional[Tensor] = None,
+    dtype: torch.dtype = torch.float32,
+    batch_edge: Optional[Tensor] = None,
+) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    """Calculate triplet edge index.
+
+    Args:
+        edge_index (Tensor): (2, n_edges) edge_index for graph. It must be bidirectional edge.
+        shift_pos (Tensor or None): (n_edges, 3) used to calculate unique atoms when pbc=True.
+        dtype: dtype for `multiplicity`
+        batch_edge (Tensor or None): Specify batch indices for `edge_index`.
+
+    Returns:
+        triplet_node_index (Tensor): (3, n_triplets) index for node `i`, `j`, `k` respectively.
+            i.e.: idx_i, idx_j, idx_k = triplet_node_index
+        multiplicity (Tensor): (n_triplets,) multiplicity indicates duplication of same triplet pair.
+            It only takes 1 in non-pbc, but it takes 2 or 3 in pbc case. dtype is specified in the argument.
+        edge_jk (Tensor): (n_triplet_edges, 2=(j, k)) edge indices for j and k.
+            i.e.: idx_j, idx_k = edge_jk[:, 0], edge_jk[:, 1]
+        batch_triplets (Tensor): (n_triplets,) batch indices for each triplets.
+    """
+    dst, src = edge_index[0], edge_index[1]
+    is_larger = dst >= src
+    dst = dst[is_larger]
+    src = src[is_larger]
+    # sort `src`
+    sort_inds = torch.argsort(src)
+    src = src[sort_inds]
+    dst = dst[sort_inds]
+
+    if shift_pos is None:
+        edge_indices = torch.arange(src.shape[0], dtype=torch.long, device=edge_index.device)
+    else:
+        edge_indices = torch.arange(shift_pos.shape[0], dtype=torch.long, device=edge_index.device)
+    edge_indices = edge_indices[is_larger][sort_inds]
+
+    if batch_edge is None:
+        batch_edge = torch.zeros(src.shape[0], dtype=torch.long, device=edge_index.device)
+    else:
+        batch_edge = batch_edge[is_larger][sort_inds]
+
+    unique, counts = torch.unique_consecutive(src, return_counts=True)
+    counts_cumsum = torch.cumsum(counts, dim=0)
+    counts_cumsum = torch.cat(
+        [torch.zeros((1,), device=counts.device, dtype=torch.long), counts_cumsum], dim=0
+    )
+
+    if str(unique.device) == "cpu":
+        return _calc_triplets_core(
+            counts, unique, dst, edge_indices, batch_edge, counts_cumsum, dtype=dtype
+        )
+    else:
+        return _calc_triplets_core_gpu(
+            counts, unique, dst, edge_indices, batch_edge, counts_cumsum, dtype=dtype
+        )
diff --git a/torch_dftd/functions/triplets_kernel.py b/torch_dftd/functions/triplets_kernel.py
index cf4d96a..52c5edc 100644
--- a/torch_dftd/functions/triplets_kernel.py
+++ b/torch_dftd/functions/triplets_kernel.py
@@ -113,15 +113,16 @@ def _cupy2torch(array: cp.ndarray) -> Tensor:
 _calc_triplets_core_gpu_kernel = None
 
 
-def _calc_triplets_core_gpu(
+@torch.jit.ignore
+def _calc_triplets_core_gpu_run_kernel(
     counts: Tensor,
     unique: Tensor,
     dst: Tensor,
     edge_indices: Tensor,
     batch_edge: Tensor,
     counts_cumsum: Tensor,
-    dtype: torch.dtype = torch.float32,
 ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    dtype: torch.dtype = torch.float32
     if not _ppe_available:
         raise ImportError("Please install pytorch_pfn_extras to use `_calc_triplets_core_gpu`!")
     if not _cupy_available:
@@ -156,3 +157,21 @@ def _calc_triplets_core_gpu(
     )
     # torch tensor buffer is already modified in above cupy functions.
     return triplet_node_index, multiplicity, edge_jk, batch_triplets
+
+
+@torch.jit.script
+def _calc_triplets_core_gpu(
+    counts: Tensor,
+    unique: Tensor,
+    dst: Tensor,
+    edge_indices: Tensor,
+    batch_edge: Tensor,
+    counts_cumsum: Tensor,
+    dtype: torch.dtype = torch.float32,
+) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    # dtype cannot be used inside @torch.jit.ignore function...
+    # https://github.com/pytorch/pytorch/issues/51941
+    triplet_node_index, multiplicity, edge_jk, batch_triplets = _calc_triplets_core_gpu_run_kernel(
+        counts, unique, dst, edge_indices, batch_edge, counts_cumsum
+    )
+    return triplet_node_index, multiplicity.to(dtype), edge_jk, batch_triplets
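The `@torch.jit.ignore` / `@torch.jit.script` split above keeps the CuPy-based kernel outside the TorchScript compiler, while the scripted wrapper applies the `dtype` cast that cannot be forwarded into the ignored call (see the pytorch issue referenced in the patch). A rough standalone sketch of the same pattern (illustrative names, not the repository's code):

    import torch
    from torch import Tensor

    @torch.jit.ignore
    def _run_in_python(x: Tensor) -> Tensor:
        # Left to the Python interpreter; free to call CuPy or anything else
        # the TorchScript compiler cannot handle.
        return x * 2

    @torch.jit.script
    def scripted_entry(x: Tensor, dtype: torch.dtype = torch.float32) -> Tensor:
        # The torch.dtype argument stays in the scripted wrapper; the cast is
        # applied after the Python-side call returns.
        return _run_in_python(x).to(dtype)
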
diff --git a/torch_dftd/nn/base_dftd_module.py b/torch_dftd/nn/base_dftd_module.py
index 48248ff..dee3679 100644
--- a/torch_dftd/nn/base_dftd_module.py
+++ b/torch_dftd/nn/base_dftd_module.py
@@ -1,15 +1,14 @@
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 
-import numpy as np
 import torch
-from ase.neighborlist import primitive_neighbor_list
-from ase.units import Bohr
 from torch import Tensor, nn
+from torch_dftd.functions.dftd3 import d3_autoang, d3_autoev
 
 
 class BaseDFTDModule(nn.Module):
     """BaseDFTDModule"""
 
+    @torch.jit.ignore
     def calc_energy_batch(
         self,
         Z: Tensor,
@@ -21,6 +20,8 @@ def calc_energy_batch(
         batch: Optional[Tensor] = None,
         batch_edge: Optional[Tensor] = None,
         damping: str = "zero",
+        autoang: float = d3_autoang,
+        autoev: float = d3_autoev,
     ) -> Tensor:
         """Forward computation to calculate atomic wise dispersion energy.
 
@@ -36,12 +37,15 @@ def calc_energy_batch(
             batch (Tensor): (n_atoms,) Specify which graph this atom belongs to
             batch_edge (Tensor): (n_edges, 3) Specify which graph this edge belongs to
             damping (str):
+            autoang (float):
+            autoev (float):
 
         Returns:
             energy (Tensor): (n_atoms,)
         """
         raise NotImplementedError()
 
+    @torch.jit.export
     def calc_energy(
         self,
         Z: Tensor,
@@ -53,7 +57,9 @@ def calc_energy(
         batch: Optional[Tensor] = None,
         batch_edge: Optional[Tensor] = None,
         damping: str = "zero",
-    ) -> List[Dict[str, Any]]:
+        autoang: float = d3_autoang,
+        autoev: float = d3_autoev,
+    ) -> List[Dict[str, float]]:
         """Forward computation of dispersion energy
 
         Backward computation is skipped for fast computation of only energy.
@@ -68,24 +74,38 @@ def calc_energy(
             batch (Tensor):
             batch_edge (Tensor):
             damping (str): damping method. "zero", "bj", "zerom", "bjm"
+            autoang (float):
+            autoev (float):
 
         Returns:
             results_list (list): calculated result. It contains calculate energy in "energy" key.
""" with torch.no_grad(): E_disp = self.calc_energy_batch( - Z, pos, edge_index, cell, pbc, shift_pos, batch, batch_edge, damping=damping + Z, + pos, + edge_index, + cell, + pbc, + shift_pos, + batch, + batch_edge, + damping=damping, + autoang=autoang, + autoev=autoev, ) + E_disp_list: List[float] = E_disp.tolist() if batch is None: - return [{"energy": E_disp.item()}] + return [{"energy": E_disp_list[0]}] else: if batch.size()[0] == 0: n_graphs = 1 else: n_graphs = int(batch[-1]) + 1 - return [{"energy": E_disp[i].item()} for i in range(n_graphs)] + return [{"energy": E_disp_list[i]} for i in range(n_graphs)] - def calc_energy_and_forces( + @torch.jit.export + def _calc_energy_and_forces_core( self, Z: Tensor, pos: Tensor, @@ -96,23 +116,9 @@ def calc_energy_and_forces( batch: Optional[Tensor] = None, batch_edge: Optional[Tensor] = None, damping: str = "zero", - ) -> List[Dict[str, Any]]: - """Forward computation of dispersion energy, force and stress - - Args: - Z (Tensor): (n_atoms,) atomic numbers. - pos (Tensor): atom positions in angstrom - cell (Tensor): cell size in angstrom, None for non periodic system. - pbc (Tensor): pbc condition, None for non periodic system. - shift_pos (Tensor): (n_atoms, 3) shift vector (length unit). - damping (str): damping method. "zero", "bj", "zerom", "bjm" - - Returns: - results (list): calculated results. Contains following: - "energy": () - "forces": (n_atoms, 3) - "stress": (6,) - """ + autoang: float = d3_autoang, + autoev: float = d3_autoev, + ) -> Tuple[Tensor, Tensor, Optional[Tensor]]: pos.requires_grad_(True) if cell is not None: # pos is depending on `cell` size @@ -123,21 +129,21 @@ def calc_energy_and_forces( shift_pos.requires_grad_(True) E_disp = self.calc_energy_batch( - Z, pos, edge_index, cell, pbc, shift_pos, batch, batch_edge, damping=damping + Z, + pos, + edge_index, + cell, + pbc, + shift_pos, + batch, + batch_edge, + damping=damping, + autoang=autoang, + autoev=autoev, ) E_disp.sum().backward() - forces = -pos.grad # [eV/angstrom] - if batch is None: - results_list = [{"energy": E_disp.item(), "forces": forces.cpu().numpy()}] - else: - if batch.size()[0] == 0: - n_graphs = 1 - else: - n_graphs = int(batch[-1]) + 1 - results_list = [{"energy": E_disp[i].item()} for i in range(n_graphs)] - for i in range(n_graphs): - results_list[i]["forces"] = forces[batch == i].cpu().numpy() + pos_grad = pos.grad if cell is not None: # stress = torch.mm(cell_grad, cell.T) / cell_volume @@ -155,9 +161,14 @@ def calc_energy_and_forces( dim=0, ) stress = cell_grad.to(cell.dtype) / cell_volume - results_list[0]["stress"] = stress.detach().cpu().numpy() else: + assert isinstance(batch, Tensor) assert isinstance(batch_edge, Tensor) + if batch.size()[0] == 0: + n_graphs = 1 + else: + n_graphs = int(batch[-1]) + 1 + # cell (bs, 3, 3) cell_volume = torch.det(cell).abs() cell_grad = pos.new_zeros((n_graphs, 6), dtype=torch.float64) @@ -172,6 +183,69 @@ def calc_energy_and_forces( (shift_pos[:, voigt_left] * shift_pos.grad[:, voigt_right]).to(torch.float64), ) stress = cell_grad.to(cell.dtype) / cell_volume[:, None] + stress = stress + else: + stress = None + return E_disp, pos_grad, stress + + @torch.jit.ignore + def calc_energy_and_forces( + self, + Z: Tensor, + pos: Tensor, + edge_index: Tensor, + cell: Optional[Tensor] = None, + pbc: Optional[Tensor] = None, + shift_pos: Optional[Tensor] = None, + batch: Optional[Tensor] = None, + batch_edge: Optional[Tensor] = None, + damping: str = "zero", + autoang: float = d3_autoang, + autoev: float = d3_autoev, + 
+    ) -> List[Dict[str, Any]]:
+        """Forward computation of dispersion energy, force and stress
+
+        Args:
+            Z (Tensor): (n_atoms,) atomic numbers.
+            pos (Tensor): atom positions in angstrom
+            cell (Tensor): cell size in angstrom, None for non periodic system.
+            pbc (Tensor): pbc condition, None for non periodic system.
+            shift_pos (Tensor): (n_atoms, 3) shift vector (length unit).
+            damping (str): damping method. "zero", "bj", "zerom", "bjm"
+            autoang (float):
+            autoev (float):
+
+        Returns:
+            results (list): calculated results. Contains following:
+                "energy": ()
+                "forces": (n_atoms, 3)
+                "stress": (6,)
+        """
+        E_disp, pos_grad, stress = self._calc_energy_and_forces_core(
+            Z, pos, edge_index, cell, pbc, shift_pos, batch, batch_edge, damping, autoang, autoev
+        )
+
+        forces = (-pos_grad).cpu().numpy()
+        n_graphs = 0  # Just to declare for torch.jit.script.
+        if batch is None:
+            results_list = [{"energy": E_disp.item(), "forces": forces}]
+        else:
+            if batch.size()[0] == 0:
+                n_graphs = 1
+            else:
+                n_graphs = int(batch[-1]) + 1
+            E_disp_list = E_disp.tolist()
+            results_list = [{"energy": E_disp_list[i]} for i in range(n_graphs)]
+            batch_array = batch.cpu().numpy()
+            for i in range(n_graphs):
+                results_list[i]["forces"] = forces[batch_array == i]
+
+        if stress is not None:
+            # stress = torch.mm(cell_grad, cell.T) / cell_volume
+            # Get stress in Voigt notation (xx, yy, zz, yz, xz, xy)
+            if batch is None:
+                results_list[0]["stress"] = stress.detach().cpu().numpy()
+            else:
                 stress = stress.detach().cpu().numpy()
                 for i in range(n_graphs):
                     results_list[i]["stress"] = stress[i]
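The `c6ab: Tensor` style class-level annotations added to the two modules below give the static type checker and the TorchScript compiler a concrete `Tensor` type for the registered buffers, which is presumably why the `# type:ignore` comments on `self.c6ab` etc. can be dropped. A minimal sketch of the pattern (hypothetical module, not from the repository):

    import torch
    from torch import Tensor, nn

    class TableModule(nn.Module):
        # Declares the buffer's type at class level, as done for c6ab/r0ab below.
        table: Tensor

        def __init__(self):
            super().__init__()
            self.register_buffer("table", torch.zeros(95, 95))

        def forward(self, z: Tensor) -> Tensor:
            # self.table is known to be a Tensor, so indexing type-checks cleanly.
            return self.table[z]

    scripted = torch.jit.script(TableModule())
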
diff --git a/torch_dftd/nn/dftd2_module.py b/torch_dftd/nn/dftd2_module.py
index 57fd76d..af9862b 100644
--- a/torch_dftd/nn/dftd2_module.py
+++ b/torch_dftd/nn/dftd2_module.py
@@ -20,6 +20,9 @@ class DFTD2Module(BaseDFTDModule):
         bidirectional (bool): calculated `edge_index` is bidirectional or not.
     """
 
+    c6ab: Tensor
+    r0ab: Tensor
+
     def __init__(
         self,
         params: Dict[str, float],
@@ -35,12 +38,14 @@ def __init__(
         self.dtype = dtype
         self.bidirectional = bidirectional
         self.cutoff_smoothing = cutoff_smoothing
+        self._bohr = Bohr  # For torch.jit, `Bohr` must be local variable.
 
         r0ab, c6ab = get_dftd2_params()
         # atom pair coefficient (87, 87)
         self.register_buffer("c6ab", c6ab)
         # atom pair distance (95, 95)
         self.register_buffer("r0ab", r0ab)
 
+    @torch.jit.export
     def calc_energy_batch(
         self,
         Z: Tensor,
@@ -52,28 +57,30 @@ def calc_energy_batch(
         batch: Optional[Tensor] = None,
         batch_edge: Optional[Tensor] = None,
         damping: str = "zero",
+        autoang: float = d3_autoang,
+        autoev: float = d3_autoev,
     ) -> Tensor:
         """Forward computation to calculate atomic wise dispersion energy"""
-        shift_pos = pos.new_zeros((edge_index.size()[1], 3, 3)) if shift_pos is None else shift_pos
-        pos_bohr = pos / d3_autoang  # angstrom -> bohr
+        shift_pos = pos.new_zeros((edge_index.size()[1], 3)) if shift_pos is None else shift_pos
+        pos_bohr = pos / autoang  # angstrom -> bohr
         if cell is None:
             cell_bohr: Optional[Tensor] = None
         else:
-            cell_bohr = cell / d3_autoang  # angstrom -> bohr
-        shift_bohr = shift_pos / d3_autoang  # angstrom -> bohr
+            cell_bohr = cell / autoang  # angstrom -> bohr
+        shift_bohr = shift_pos / autoang  # angstrom -> bohr
         r = calc_distances(pos_bohr, edge_index, cell_bohr, shift_bohr)
         # E_disp (n_graphs,): Energy in eV unit
-        E_disp = d3_autoev * edisp_d2(
+        E_disp = autoev * edisp_d2(
             Z,
             r,
             edge_index,
-            c6ab=self.c6ab,  # type:ignore
-            r0ab=self.r0ab,  # type:ignore
+            c6ab=self.c6ab,
+            r0ab=self.r0ab,
             params=self.params,
             damping=damping,
             bidirectional=self.bidirectional,
-            cutoff=self.cutoff / Bohr,
+            cutoff=self.cutoff / self._bohr,
             batch=batch,
             batch_edge=batch_edge,
             cutoff_smoothing=self.cutoff_smoothing,
diff --git a/torch_dftd/nn/dftd3_module.py b/torch_dftd/nn/dftd3_module.py
index 6fe0a84..e5fbaac 100644
--- a/torch_dftd/nn/dftd3_module.py
+++ b/torch_dftd/nn/dftd3_module.py
@@ -24,6 +24,11 @@ class DFTD3Module(BaseDFTDModule):
         bidirectional (bool): calculated `edge_index` is bidirectional or not.
     """
 
+    c6ab: Tensor
+    r0ab: Tensor
+    rcov: Tensor
+    r2r4: Tensor
+
     def __init__(
         self,
         params: Dict[str, float],
@@ -62,7 +67,9 @@ def __init__(
         self.dtype = dtype
         self.bidirectional = bidirectional
         self.cutoff_smoothing = cutoff_smoothing
+        self._bohr = Bohr  # For torch.jit, `Bohr` must be local variable.
 
+    @torch.jit.export
     def calc_energy_batch(
         self,
         Z: Tensor,
@@ -74,28 +81,30 @@ def calc_energy_batch(
         batch: Optional[Tensor] = None,
         batch_edge: Optional[Tensor] = None,
         damping: str = "zero",
+        autoang: float = d3_autoang,
+        autoev: float = d3_autoev,
     ) -> Tensor:
         """Forward computation to calculate atomic wise dispersion energy"""
-        shift_pos = pos.new_zeros((edge_index.size()[1], 3, 3)) if shift_pos is None else shift_pos
-        pos_bohr = pos / d3_autoang  # angstrom -> bohr
+        shift_pos = pos.new_zeros((edge_index.size()[1], 3)) if shift_pos is None else shift_pos
+        pos_bohr = pos / autoang  # angstrom -> bohr
         if cell is None:
             cell_bohr: Optional[Tensor] = None
         else:
-            cell_bohr = cell / d3_autoang  # angstrom -> bohr
-        shift_bohr = shift_pos / d3_autoang  # angstrom -> bohr
+            cell_bohr = cell / autoang  # angstrom -> bohr
+        shift_bohr = shift_pos / autoang  # angstrom -> bohr
         r = calc_distances(pos_bohr, edge_index, cell_bohr, shift_bohr)
         # E_disp (n_graphs,): Energy in eV unit
-        E_disp = d3_autoev * edisp(
+        E_disp = autoev * edisp(
             Z,
             r,
             edge_index,
-            c6ab=self.c6ab,  # type:ignore
-            r0ab=self.r0ab,  # type:ignore
-            rcov=self.rcov,  # type:ignore
-            r2r4=self.r2r4,  # type:ignore
+            c6ab=self.c6ab,
+            r0ab=self.r0ab,
+            rcov=self.rcov,
+            r2r4=self.r2r4,
             params=self.params,
-            cutoff=self.cutoff / Bohr,
-            cnthr=self.cnthr / Bohr,
+            cutoff=self.cutoff / self._bohr,
+            cnthr=self.cnthr / self._bohr,
             batch=batch,
             batch_edge=batch_edge,
             shift_pos=shift_bohr,
diff --git a/torch_dftd/torch_dftd3_calculator.py b/torch_dftd/torch_dftd3_calculator.py
index 5c050ad..fa68456 100644
--- a/torch_dftd/torch_dftd3_calculator.py
+++ b/torch_dftd/torch_dftd3_calculator.py
@@ -58,7 +58,7 @@ def __init__(
         self.old = old
         self.device = torch.device(device)
         if old:
-            self.dftd_module: torch.nn.Module = DFTD2Module(
+            dftd_module: torch.nn.Module = DFTD2Module(
                 self.params,
                 cutoff=cutoff,
                 dtype=dtype,
@@ -66,7 +66,7 @@ def __init__(
                 cutoff_smoothing=cutoff_smoothing,
             )
         else:
-            self.dftd_module = DFTD3Module(
+            dftd_module = DFTD3Module(
                 self.params,
                 cutoff=cutoff,
                 cnthr=cnthr,
@@ -75,7 +75,7 @@ def __init__(
                 bidirectional=bidirectional,
                 cutoff_smoothing=cutoff_smoothing,
             )
-        self.dftd_module.to(device)
+        self.dftd_module = torch.jit.script(dftd_module.to(device))
         self.dtype = dtype
         self.cutoff = cutoff
         self.bidirectional = bidirectional
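With `self.dftd_module = torch.jit.script(dftd_module.to(device))`, the module used by the ASE calculator is now TorchScript-compiled, but the calculator's public interface is unchanged. Usage stays as in the README (a sketch following the README example; the device and damping values here are just examples):

    from ase.build import molecule
    from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator

    atoms = molecule("CH3CH2OCH3")
    # Attaches the calculator to `atoms`; the scripted DFTD3Module runs underneath.
    calc = TorchDFTD3Calculator(atoms=atoms, device="cuda:0", damping="bj")

    energy = atoms.get_potential_energy()  # eV
    forces = atoms.get_forces()            # eV/angstrom
    print(f"forces {forces}")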