Merge pull request pfnet-research#11 from mwata/pad
add atom_mask and shift_mask
masakiwatanabe authored and GitHub Enterprise committed Apr 20, 2023
2 parents 141103a + 118c93b commit c5ea81c
Showing 3 changed files with 39 additions and 4 deletions.
27 changes: 25 additions & 2 deletions export_static_onnx.py
@@ -72,13 +72,15 @@ def __init__(
)
self.damping = damping

def forward(self, Z, pos, shift_vecs, cell_volume):
def forward(self, Z, pos, shift_vecs, cell_volume, atom_mask, shift_mask):
r = self.dftd_module.calc_energy(
Z,
pos,
shift_vecs,
cell_volume,
damping=self.damping,
atom_mask=atom_mask,
shift_mask=shift_mask,
)
return r[0]["energy"]

@@ -88,6 +90,8 @@ def parse_args():
parser.add_argument("--repeat", type=str, help="number of repeats inside cell")
parser.add_argument("--clip_num_atoms", type=int, help="max number of atoms (exceeded atoms are trashed)")
parser.add_argument("--out_dir", type=str, help="onnx output dir", required=True)
parser.add_argument("--pad_num_atoms", type=int, help="num_atoms after padding", required=False)
parser.add_argument("--pad_num_cells", type=int, help="num_cells after padding", required=False)
return parser.parse_args()

def prepare_data(args):
@@ -120,6 +124,23 @@ def prepare_data(args):
shift_vecs = calc_shift_vecs(cell, pbc, cutoff=cutoff)
shift_vecs = torch.tensor(shift_vecs)

if args.pad_num_atoms is not None:
atom_mask = torch.tensor(np.arange(args.pad_num_atoms) < len(Z))
n_pad = args.pad_num_atoms - len(Z)
assert n_pad >= 0
Z = torch.nn.functional.pad(Z, (0, n_pad), mode="constant")
pos = torch.nn.functional.pad(pos, (0, 0, 0, n_pad), mode="constant")
else:
atom_mask = None

if args.pad_num_cells is not None:
shift_mask = torch.tensor(np.arange(args.pad_num_cells) < len(shift_vecs))
n_pad = args.pad_num_cells - len(shift_vecs)
assert n_pad >= 0
shift_vecs = torch.nn.functional.pad(shift_vecs, (0, 0, 0, n_pad), mode="constant")
else:
shift_mask = None

print("n_atoms = ", len(Z), "n_cell = ", len(shift_vecs), file=sys.stderr)
print("atoms = ", atoms, file=sys.stderr)

@@ -128,6 +149,8 @@ def prepare_data(args):
"pos": pos.type(torch.float32),
"shift_vecs": shift_vecs.type(torch.float32),
"cell_volume": cell_volume,
"atom_mask": atom_mask,
"shift_mask": shift_mask,
}

exporter = ExportONNX(cutoff=cutoff, damping="bj")
@@ -142,4 +165,4 @@ def prepare_data(args):

print("out_dir = ", out_dir, file=sys.stderr)
ppe_onnx.export_testcase(exporter, tuple(args.values()), out_dir, verbose=True,
input_names=["Z","pos","shift_vecs","cell_volume"])
input_names=["Z","pos","shift_vecs","cell_volume","atom_mask","shift_mask"])
10 changes: 9 additions & 1 deletion torch_dftd_static/functions/dftd3.py
@@ -24,12 +24,19 @@ def edisp( # calculate edisp by all-pair computation
k3=d3_k3,
cutoff_smoothing: str = "none",
damping: str = "zero",
atom_mask: Optional[Tensor] = None,
shift_mask: Optional[Tensor] = None,
):
n_atoms = len(Z)
#assert torch.all(shift_vecs[0] == 0.0)
#triu_mask = (torch.arange(n_atoms)[:, None] < torch.arange(n_atoms)[None, :])[:, :, None] | ((torch.arange(len(shift_vecs)) > 0)[None, None, :])
triu_mask = (torch.arange(n_atoms)[:, None] < torch.arange(n_atoms)[None, :])[:, :, None] | ((torch.any(shift_vecs != 0.0, axis=-1))[None, None, :])

if atom_mask is not None:
triu_mask = triu_mask & atom_mask[:, None, None] & atom_mask[None, :, None]
if shift_mask is not None:
triu_mask = triu_mask & shift_mask[None, None, :]

# calculate pairwise distances
shifted_pos = pos[:, None, :] + shift_vecs[None, :, :]
r2 = torch.sum((pos[:, None, None, :] - shifted_pos[None, :, :, :]) ** 2, axis=-1)
@@ -86,8 +93,9 @@ def edisp( # calculate edisp by all-pair computation
e68 *= poly_smoothing(r, cutoff)

e68 = torch.where(r <= cutoff, e68, torch.tensor(0.0))

e68 = torch.where(triu_mask, e68, torch.tensor(0.0))

return torch.sum(e68.to(torch.float64).sum()) * 2.0

#e68_same_cell = e68[:, :, 0]
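
The new masks slot into the existing triu_mask, which already counts each pair once (i < j within the home cell, all pairs for non-zero shift vectors) before the pair energies are summed and doubled. Padded atoms are removed on both sides of the pair and padded shift vectors along the last axis; without shift_mask, a zero-filled pad row would otherwise look like a second copy of the home cell. A minimal sketch of the combined mask, with toy shapes chosen for illustration:

import torch

n_atoms, n_shifts = 4, 3  # one padded atom, one padded shift vector
idx = torch.arange(n_atoms)

# Shift 0 is the home cell, shift 1 a real periodic image, shift 2 a zero-filled pad row.
shift_vecs = torch.tensor([[0.0, 0.0, 0.0],
                           [1.0, 0.0, 0.0],
                           [0.0, 0.0, 0.0]])

# Count each home-cell pair once (i < j); for non-zero shifts keep all (i, j) pairs.
triu_mask = (idx[:, None] < idx[None, :])[:, :, None] | torch.any(shift_vecs != 0.0, dim=-1)[None, None, :]

atom_mask = torch.tensor([True, True, True, False])  # atom 3 is padding
shift_mask = torch.tensor([True, True, False])       # shift 2 is padding
triu_mask = triu_mask & atom_mask[:, None, None] & atom_mask[None, :, None]
triu_mask = triu_mask & shift_mask[None, None, :]
# triu_mask now has shape (4, 4, 3) and selects each real pair exactly once.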
6 changes: 5 additions & 1 deletion torch_dftd_static/nn/dftd3_module.py
@@ -1,6 +1,6 @@
import os
from pathlib import Path
from typing import Dict
from typing import Dict, Optional

import numpy as np
import torch
@@ -67,6 +67,8 @@ def calc_energy(
shift_vecs: Tensor,
cell_volume: float,
damping: str = "zero",
atom_mask: Optional[Tensor] = None,
shift_mask: Optional[Tensor] = None,
) -> Tensor:
"""Forward computation to calculate atomic wise dispersion energy"""

@@ -85,5 +87,7 @@ def calc_energy(
cnthr=self.cnthr / Bohr,
cutoff_smoothing=self.cutoff_smoothing,
damping=damping,
atom_mask=atom_mask,
shift_mask=shift_mask,
)
return [{"energy": E_disp}]
