Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: correct exclude_types in descriptors #3841

Merged
merged 11 commits into from
Jun 4, 2024
6 changes: 4 additions & 2 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,8 @@ def call(
nf, nloc, nnei, _ = dmatrix.shape
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
# nfnl x nnei
exclude_mask = exclude_mask.reshape(nf * nloc, nnei)
# nfnl x nnei
nlist = nlist.reshape(nf * nloc, nnei)
iProzd marked this conversation as resolved.
Show resolved Hide resolved
# nfnl x nnei x 4
dmatrix = dmatrix.reshape(nf * nloc, nnei, 4)
Expand All @@ -821,6 +823,8 @@ def call(
atype_embd_nnei = np.tile(atype_embd[:, np.newaxis, :], (1, nnei, 1))
# nfnl x nnei
nlist_mask = nlist != -1
# nfnl x nnei
nlist_mask = nlist_mask * exclude_mask.astype(bool)
iProzd marked this conversation as resolved.
Show resolved Hide resolved
# nfnl x nnei x 1
sw = np.where(nlist_mask[:, :, None], sw, 0.0)
nlist_masked = np.where(nlist_mask, nlist, 0)
Expand All @@ -830,8 +834,6 @@ def call(
nf * nloc, nnei, self.tebd_dim
)
ng = self.neuron[-1]
# nfnl x nnei
exclude_mask = exclude_mask.reshape(nf * nloc, nnei)
# nfnl x nnei x 4
rr = dmatrix.reshape(nf * nloc, nnei, 4)
rr = rr * exclude_mask[:, :, None]
Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def call(
mapping: Optional[np.ndarray] = None,
):
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
nlist = nlist * exclude_mask
nlist = np.where(exclude_mask, nlist, -1)
# nf x nloc x nnei x 4
dmatrix, diff, sw = self.env_mat.call(
coord_ext, atype_ext, nlist, self.mean, self.stddev
Expand Down
5 changes: 5 additions & 0 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ class DescrptSeA(NativeOP, BaseDescriptor):
The precision of the embedding net parameters. Supported options are |PRECISION|
spin
The deepspin object.
ntypes : int
Number of element types.
Not used in this descriptor, only to be compat with input.

Limitations
-----------
Expand Down Expand Up @@ -157,9 +160,11 @@ def __init__(
activation_function: str = "tanh",
precision: str = DEFAULT_PRECISION,
spin: Optional[Any] = None,
ntypes: Optional[int] = None, # to be compat with input
# consistent with argcheck, not used though
seed: Optional[int] = None,
) -> None:
del ntypes
## seed, uniform_seed, not included.
if spin is not None:
raise NotImplementedError("spin is not implemented")
Expand Down
5 changes: 5 additions & 0 deletions deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ class DescrptSeR(NativeOP, BaseDescriptor):
The precision of the embedding net parameters. Supported options are |PRECISION|
spin
The deepspin object.
ntypes : int
Number of element types.
Not used in this descriptor, only to be compat with input.

Limitations
-----------
Expand Down Expand Up @@ -113,9 +116,11 @@ def __init__(
activation_function: str = "tanh",
precision: str = DEFAULT_PRECISION,
spin: Optional[Any] = None,
ntypes: Optional[int] = None, # to be compat with input
# consistent with argcheck, not used though
seed: Optional[int] = None,
) -> None:
del ntypes
## seed, uniform_seed, not included.
if not type_one_side:
raise NotImplementedError("type_one_side == False not implemented")
Expand Down
5 changes: 5 additions & 0 deletions deepmd/dpmodel/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ class DescrptSeT(NativeOP, BaseDescriptor):
If the weights of embedding net are trainable.
seed : int, Optional
Random seed for initializing the network parameters.
ntypes : int
Number of element types.
Not used in this descriptor, only to be compat with input.
"""

def __init__(
Expand All @@ -101,7 +104,9 @@ def __init__(
precision: str = DEFAULT_PRECISION,
trainable: bool = True,
seed: Optional[int] = None,
ntypes: Optional[int] = None, # to be compat with input
) -> None:
del ntypes
self.rcut = rcut
self.rcut_smth = rcut_smth
self.sel = sel
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def forward(
atype = extended_atype[:, :nloc]
# nb x nloc x nnei
exclude_mask = self.emask(nlist, extended_atype)
nlist = nlist * exclude_mask
nlist = torch.where(exclude_mask != 0, nlist, -1)
# nb x nloc x nnei x 4, nb x nloc x nnei x 3, nb x nloc x nnei x 1
dmatrix, diff, sw = prod_env_mat(
extended_coord,
Expand Down
9 changes: 6 additions & 3 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,8 +481,6 @@ def forward(
nlist_mask = nlist != -1
iProzd marked this conversation as resolved.
Show resolved Hide resolved
nlist = torch.where(nlist == -1, 0, nlist)
sw = torch.squeeze(sw, -1)
# beyond the cutoff sw should be 0.0
sw = sw.masked_fill(~nlist_mask, 0.0)
# nf x nloc x nt -> nf x nloc x nnei x nt
atype_tebd = extended_atype_embd[:, :nloc, :]
atype_tebd_nnei = atype_tebd.unsqueeze(2).expand(-1, -1, self.nnei, -1)
Expand All @@ -495,8 +493,13 @@ def forward(
atype_tebd_nlist = torch.gather(atype_tebd_ext, dim=1, index=index)
# nb x nloc x nnei x nt
atype_tebd_nlist = atype_tebd_nlist.view(nb, nloc, nnei, nt)
# nb x nloc x nnei
exclude_mask = self.emask(nlist, extended_atype)
nlist_mask = nlist_mask & (exclude_mask != 0)
# beyond the cutoff sw should be 0.0
sw = sw.masked_fill(~nlist_mask, 0.0)
# (nb x nloc) x nnei
exclude_mask = self.emask(nlist, extended_atype).view(nb * nloc, nnei)
exclude_mask = exclude_mask.view(nb * nloc, nnei)
if self.old_impl:
assert self.filter_layers_old is not None
dmatrix = dmatrix.view(
Expand Down
19 changes: 14 additions & 5 deletions deepmd/tf/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,18 @@ def _pass_filter(
tf.shape(inputs_i)[0],
self.nei_type_vec, # extra input for atten
)
# (nsamples * natoms * nnei, 1)
nei_exclude_mask = tf.slice(
tf.reshape(tf.cast(mask, self.filter_precision), [-1, 4]),
[0, 0],
[-1, 1],
)
# (nsamples * natoms, 1, nnei)
self.nmask *= tf.reshape(
nei_exclude_mask,
[-1, 1, self.sel_all_a[0]],
)
self.negative_mask = -(2 << 32) * (1.0 - self.nmask)
iProzd marked this conversation as resolved.
Show resolved Hide resolved
iProzd marked this conversation as resolved.
Show resolved Hide resolved
if self.smooth:
inputs_i = tf.where(
tf.cast(mask, tf.bool),
Expand All @@ -727,12 +739,9 @@ def _pass_filter(
tf.reshape(self.avg_looked_up, [-1, 1]), [1, self.ndescrpt]
),
)
# (nsamples, natoms, nnei)
iProzd marked this conversation as resolved.
Show resolved Hide resolved
self.recovered_switch *= tf.reshape(
tf.slice(
tf.reshape(tf.cast(mask, self.filter_precision), [-1, 4]),
[0, 0],
[-1, 1],
),
nei_exclude_mask,
[-1, natoms[0], self.sel_all_a[0]],
)
else:
Expand Down
133 changes: 133 additions & 0 deletions source/tests/universal/common/cases/cases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import numpy as np


# originally copied from source/tests/pt/model/test_env_mat.py
class TestCaseSingleFrameWithNlist:
def setUp(self):
# nloc == 3, nall == 4
self.nloc = 3
self.nall = 4
self.nf, self.nt = 2, 2
self.coord_ext = np.array(
[
[0, 0, 0],
[0, 1, 0],
[0, 0, 1],
[0, -2, 0],
],
dtype=np.float64,
).reshape([1, self.nall, 3])
self.atype_ext = np.array([0, 0, 1, 0], dtype=int).reshape([1, self.nall])
self.mapping = np.array([0, 1, 2, 0], dtype=int).reshape([1, self.nall])
# sel = [5, 2]
self.sel = [5, 2]
self.sel_mix = [7]
self.natoms = [3, 3, 2, 1]
self.nlist = np.array(
[
[1, 3, -1, -1, -1, 2, -1],
[0, -1, -1, -1, -1, 2, -1],
[0, 1, -1, -1, -1, -1, -1],
],
dtype=int,
).reshape([1, self.nloc, sum(self.sel)])
self.rcut = 2.2
self.rcut_smth = 0.4
# permutations
self.perm = np.array([2, 0, 1, 3], dtype=np.int32)
inv_perm = np.array([1, 2, 0, 3], dtype=np.int32)
# permute the coord and atype
self.coord_ext = np.concatenate(
[self.coord_ext, self.coord_ext[:, self.perm, :]], axis=0
).reshape(self.nf, self.nall * 3)
self.atype_ext = np.concatenate(
[self.atype_ext, self.atype_ext[:, self.perm]], axis=0
)
self.mapping = np.concatenate(
[self.mapping, self.mapping[:, self.perm]], axis=0
)

# permute the nlist
nlist1 = self.nlist[:, self.perm[: self.nloc], :]
mask = nlist1 == -1
nlist1 = inv_perm[nlist1]
nlist1 = np.where(mask, -1, nlist1)
self.nlist = np.concatenate([self.nlist, nlist1], axis=0)
self.atol = 1e-12


class TestCaseSingleFrameWithNlistWithVirtual:
def setUp(self):
# nloc == 3, nall == 4
self.nloc = 4
self.nall = 5
self.nf, self.nt = 2, 2
self.coord_ext = np.array(
[
[0, 0, 0],
[0, 0, 0],
[0, 1, 0],
[0, 0, 1],
[0, -2, 0],
],
dtype=np.float64,
).reshape([1, self.nall, 3])
self.atype_ext = np.array([0, -1, 0, 1, 0], dtype=int).reshape([1, self.nall])
# sel = [5, 2]
self.sel = [5, 2]
self.sel_mix = [7]
self.natoms = [3, 3, 2, 1]
self.nlist = np.array(
[
[2, 4, -1, -1, -1, 3, -1],
[-1, -1, -1, -1, -1, -1, -1],
[0, -1, -1, -1, -1, 3, -1],
[0, 2, -1, -1, -1, -1, -1],
],
dtype=int,
).reshape([1, self.nloc, sum(self.sel)])
self.rcut = 2.2
self.rcut_smth = 0.4
# permutations
self.perm = np.array([3, 0, 1, 2, 4], dtype=np.int32)
inv_perm = np.argsort(self.perm)
# permute the coord and atype
self.coord_ext = np.concatenate(
[self.coord_ext, self.coord_ext[:, self.perm, :]], axis=0
).reshape(self.nf, self.nall * 3)
self.atype_ext = np.concatenate(
[self.atype_ext, self.atype_ext[:, self.perm]], axis=0
)
# permute the nlist
nlist1 = self.nlist[:, self.perm[: self.nloc], :]
mask = nlist1 == -1
nlist1 = inv_perm[nlist1]
nlist1 = np.where(mask, -1, nlist1)
self.nlist = np.concatenate([self.nlist, nlist1], axis=0)
self.get_real_mapping = np.array([[0, 2, 3], [0, 1, 3]], dtype=np.int32)
self.atol = 1e-12


class TestCaseSingleFrameWithoutNlist:
def setUp(self):
# nloc == 3, nall == 4
self.nloc = 3
self.nf, self.nt = 1, 2
self.coord = np.array(
[
[0, 0, 0],
[0, 1, 0],
[0, 0, 1],
],
dtype=np.float64,
).reshape([1, self.nloc * 3])
self.atype = np.array([0, 0, 1], dtype=int).reshape([1, self.nloc])
self.cell = 2.0 * np.eye(3).reshape([1, 9])
# sel = [5, 2]
self.sel = [16, 8]
self.sel_mix = [24]
self.natoms = [3, 3, 2, 1]
self.rcut = 2.2
self.rcut_smth = 0.4
self.atol = 1e-12
1 change: 1 addition & 0 deletions source/tests/universal/common/cases/descriptor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
11 changes: 11 additions & 0 deletions source/tests/universal/common/cases/descriptor/descriptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# SPDX-License-Identifier: LGPL-3.0-or-later


from .utils import (
DescriptorTestCase,
)


class DescriptorTest(DescriptorTestCase):
def setUp(self) -> None:
DescriptorTestCase.setUp(self)
Loading