Skip to content

Commit

Permalink
feat: apply descriptor exclude_types to env mat stat (#3625)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and pre-commit-ci[bot] authored Apr 7, 2024
1 parent 39d027e commit 87d293a
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 4 deletions.
3 changes: 3 additions & 0 deletions deepmd/pt/utils/env_mat_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ def iter(
radial_only,
protection=self.descriptor.env_protection,
)
# apply excluded_types
exclude_mask = self.descriptor.emask(nlist, extended_atype)
env_mat *= exclude_mask.unsqueeze(-1)
# reshape to nframes * nloc at the atom level,
# so nframes/mixed_type do not matter
env_mat = env_mat.view(
Expand Down
3 changes: 2 additions & 1 deletion deepmd/tf/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Dict,
List,
Optional,
Set,
Tuple,
)

Expand Down Expand Up @@ -357,7 +358,7 @@ def pass_tensors_from_frz_model(

def build_type_exclude_mask(
self,
exclude_types: List[Tuple[int, int]],
exclude_types: Set[Tuple[int, int]],
ntypes: int,
sel: List[int],
ndescrpt: int,
Expand Down
12 changes: 12 additions & 0 deletions deepmd/tf/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,18 @@ def __init__(
sel_a=self.sel_a,
sel_r=self.sel_r,
)
if len(self.exclude_types):
# exclude types applied to data stat
mask = self.build_type_exclude_mask(
self.exclude_types,
self.ntypes,
self.sel_a,
self.ndescrpt,
# for data stat, nloc == nall
self.place_holders["type"],
tf.size(self.place_holders["type"]),
)
self.stat_descrpt *= tf.reshape(mask, tf.shape(self.stat_descrpt))
self.sub_sess = tf.Session(graph=sub_graph, config=default_tf_session_config)
self.original_sel = None
self.multi_task = multi_task
Expand Down
20 changes: 17 additions & 3 deletions deepmd/tf/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import (
List,
Optional,
Set,
Tuple,
)

Expand Down Expand Up @@ -282,6 +283,19 @@ def __init__(
sel_a=self.sel_all_a,
sel_r=self.sel_all_r,
)
if len(self.exclude_types):
# exclude types applied to data stat
mask = self.build_type_exclude_mask_mixed(
self.exclude_types,
self.ntypes,
self.sel_a,
self.ndescrpt,
# for data stat, nloc == nall
self.place_holders["type"],
tf.size(self.place_holders["type"]),
self.nei_type_vec_t, # extra input for atten
)
self.stat_descrpt *= tf.reshape(mask, tf.shape(self.stat_descrpt))
self.sub_sess = tf.Session(graph=sub_graph, config=default_tf_session_config)

def compute_input_stats(
Expand Down Expand Up @@ -672,7 +686,7 @@ def _pass_filter(
inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
type_i = -1
if len(self.exclude_types):
mask = self.build_type_exclude_mask(
mask = self.build_type_exclude_mask_mixed(
self.exclude_types,
self.ntypes,
self.sel_a,
Expand Down Expand Up @@ -1367,9 +1381,9 @@ def init_variables(
)
)

def build_type_exclude_mask(
def build_type_exclude_mask_mixed(
self,
exclude_types: List[Tuple[int, int]],
exclude_types: Set[Tuple[int, int]],
ntypes: int,
sel: List[int],
ndescrpt: int,
Expand Down
12 changes: 12 additions & 0 deletions deepmd/tf/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,18 @@ def __init__(
rcut_smth=self.rcut_smth,
sel=self.sel_r,
)
if len(self.exclude_types):
# exclude types applied to data stat
mask = self.build_type_exclude_mask(
self.exclude_types,
self.ntypes,
self.sel_r,
self.ndescrpt,
# for data stat, nloc == nall
self.place_holders["type"],
tf.size(self.place_holders["type"]),
)
self.stat_descrpt *= tf.reshape(mask, tf.shape(self.stat_descrpt))
self.sub_sess = tf.Session(
graph=sub_graph, config=default_tf_session_config
)
Expand Down
38 changes: 38 additions & 0 deletions source/tests/pt/test_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,44 @@ def tf_compute_input_stats(self):
)


class TestExcludeTypes(DatasetTest, unittest.TestCase):
def setup_data(self):
original_data = str(Path(__file__).parent / "water/data/data_0")
picked_data = str(Path(__file__).parent / "picked_data_for_test_stat")
dpdata.LabeledSystem(original_data, fmt="deepmd/npy")[:2].to_deepmd_npy(
picked_data
)
self.mixed_type = False
return picked_data

def setup_tf(self):
return DescrptSeA_tf(
rcut=self.rcut,
rcut_smth=self.rcut_smth,
sel=self.sel,
neuron=self.filter_neuron,
axis_neuron=self.axis_neuron,
exclude_types=[[0, 0], [1, 1]],
)

def setup_pt(self):
return DescrptSeA(
self.rcut,
self.rcut_smth,
self.sel,
self.filter_neuron,
self.axis_neuron,
exclude_types=[[0, 0], [1, 1]],
).sea # get the block who has stat as private vars

def tf_compute_input_stats(self):
coord = self.dp_merged["coord"]
atype = self.dp_merged["type"]
natoms = self.dp_merged["natoms_vec"]
box = self.dp_merged["box"]
self.dp_d.compute_input_stats(coord, box, atype, natoms, self.dp_mesh, {})


class TestOutputStat(unittest.TestCase):
def setUp(self):
self.data_file = [str(Path(__file__).parent / "water/data/data_0")]
Expand Down

0 comments on commit 87d293a

Please sign in to comment.