From fb2bf4348d95faa6ae3682994d56cf17a044ea82 Mon Sep 17 00:00:00 2001 From: Masaki Watanabe Date: Thu, 20 Apr 2023 13:44:38 +0900 Subject: [PATCH] fix cases where --pad_num_foobar are not specified --- export_static_onnx.py | 4 ++-- torch_dftd_static/functions/dftd3.py | 9 +++++---- torch_dftd_static/nn/dftd3_module.py | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/export_static_onnx.py b/export_static_onnx.py index 186c681..c34574c 100644 --- a/export_static_onnx.py +++ b/export_static_onnx.py @@ -131,7 +131,7 @@ def prepare_data(args): Z = torch.nn.functional.pad(Z, (0, n_pad), mode="constant") pos = torch.nn.functional.pad(pos, (0, 0, 0, n_pad), mode="constant") else: - atom_mask = None + atom_mask = torch.ones(len(Z), dtype=bool) if args.pad_num_cells is not None: shift_mask = torch.tensor(np.arange(args.pad_num_cells) < len(shift_vecs)) @@ -139,7 +139,7 @@ def prepare_data(args): assert n_pad >= 0 shift_vecs = torch.nn.functional.pad(shift_vecs, (0, 0, 0, n_pad), mode="constant") else: - shift_mask = None + shift_mask = torch.ones(len(shift_vecs), dtype=bool) print("n_atoms = ", len(Z), "n_cell = ", len(shift_vecs), file=sys.stderr) print("atoms = ", atoms, file=sys.stderr) diff --git a/torch_dftd_static/functions/dftd3.py b/torch_dftd_static/functions/dftd3.py index 3359050..3782e52 100644 --- a/torch_dftd_static/functions/dftd3.py +++ b/torch_dftd_static/functions/dftd3.py @@ -27,15 +27,16 @@ def edisp( # calculate edisp by all-pair computation atom_mask: Optional[Tensor] = None, shift_mask: Optional[Tensor] = None, ): + assert atom_mask is not None + assert shift_mask is not None + n_atoms = len(Z) #assert torch.all(shift_vecs[0] == 0.0) #triu_mask = (torch.arange(n_atoms)[:, None] < torch.arange(n_atoms)[None, :])[:, :, None] | ((torch.arange(len(shift_vecs)) > 0)[None, None, :]) triu_mask = (torch.arange(n_atoms)[:, None] < torch.arange(n_atoms)[None, :])[:, :, None] | ((torch.any(shift_vecs != 0.0, axis=-1))[None, None, :]) - if atom_mask is not None: - triu_mask = triu_mask & atom_mask[:, None, None] & atom_mask[None, :, None] - if shift_mask is not None: - triu_mask = triu_mask & shift_mask[None, None, :] + triu_mask = triu_mask & atom_mask[:, None, None] & atom_mask[None, :, None] + triu_mask = triu_mask & shift_mask[None, None, :] # calculate pairwise distances shifted_pos = pos[:, None, :] + shift_vecs[None, :, :] diff --git a/torch_dftd_static/nn/dftd3_module.py b/torch_dftd_static/nn/dftd3_module.py index 6244d37..f458f09 100644 --- a/torch_dftd_static/nn/dftd3_module.py +++ b/torch_dftd_static/nn/dftd3_module.py @@ -66,7 +66,7 @@ def calc_energy( pos: Tensor, shift_vecs: Tensor, cell_volume: float, - damping: str = "zero", + damping: str, atom_mask: Optional[Tensor] = None, shift_mask: Optional[Tensor] = None, ) -> Tensor: