Skip to content

Commit

Permalink
Merge pull request pfnet-research#12 from mwata/pad
Browse files Browse the repository at this point in the history
fix cases where --pad_num_foobar are not specified
  • Loading branch information
masakiwatanabe authored and GitHub Enterprise committed Apr 20, 2023
2 parents c5ea81c + fb2bf43 commit 64473bc
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
4 changes: 2 additions & 2 deletions export_static_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,15 @@ def prepare_data(args):
Z = torch.nn.functional.pad(Z, (0, n_pad), mode="constant")
pos = torch.nn.functional.pad(pos, (0, 0, 0, n_pad), mode="constant")
else:
atom_mask = None
atom_mask = torch.ones(len(Z), dtype=bool)

if args.pad_num_cells is not None:
shift_mask = torch.tensor(np.arange(args.pad_num_cells) < len(shift_vecs))
n_pad = args.pad_num_cells - len(shift_vecs)
assert n_pad >= 0
shift_vecs = torch.nn.functional.pad(shift_vecs, (0, 0, 0, n_pad), mode="constant")
else:
shift_mask = None
shift_mask = torch.ones(len(shift_vecs), dtype=bool)

print("n_atoms = ", len(Z), "n_cell = ", len(shift_vecs), file=sys.stderr)
print("atoms = ", atoms, file=sys.stderr)
Expand Down
9 changes: 5 additions & 4 deletions torch_dftd_static/functions/dftd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,16 @@ def edisp( # calculate edisp by all-pair computation
atom_mask: Optional[Tensor] = None,
shift_mask: Optional[Tensor] = None,
):
assert atom_mask is not None
assert shift_mask is not None

n_atoms = len(Z)
#assert torch.all(shift_vecs[0] == 0.0)
#triu_mask = (torch.arange(n_atoms)[:, None] < torch.arange(n_atoms)[None, :])[:, :, None] | ((torch.arange(len(shift_vecs)) > 0)[None, None, :])
triu_mask = (torch.arange(n_atoms)[:, None] < torch.arange(n_atoms)[None, :])[:, :, None] | ((torch.any(shift_vecs != 0.0, axis=-1))[None, None, :])

if atom_mask is not None:
triu_mask = triu_mask & atom_mask[:, None, None] & atom_mask[None, :, None]
if shift_mask is not None:
triu_mask = triu_mask & shift_mask[None, None, :]
triu_mask = triu_mask & atom_mask[:, None, None] & atom_mask[None, :, None]
triu_mask = triu_mask & shift_mask[None, None, :]

# calculate pairwise distances
shifted_pos = pos[:, None, :] + shift_vecs[None, :, :]
Expand Down
2 changes: 1 addition & 1 deletion torch_dftd_static/nn/dftd3_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def calc_energy(
pos: Tensor,
shift_vecs: Tensor,
cell_volume: float,
damping: str = "zero",
damping: str,
atom_mask: Optional[Tensor] = None,
shift_mask: Optional[Tensor] = None,
) -> Tensor:
Expand Down

0 comments on commit 64473bc

Please sign in to comment.