diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index ca068d4821..876062cce6 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -389,9 +389,14 @@ def set_stat_mean_and_stddev( mean: np.ndarray, stddev: np.ndarray, ) -> None: + """Update mean and stddev for descriptor.""" self.se_atten.mean = mean self.se_atten.stddev = stddev + def get_stat_mean_and_stddev(self) -> Tuple[np.ndarray, np.ndarray]: + """Get mean and stddev for descriptor.""" + return self.se_atten.mean, self.se_atten.stddev + def change_type_map( self, type_map: List[str], model_with_new_type_stat=None ) -> None: diff --git a/deepmd/dpmodel/descriptor/dpa2.py b/deepmd/dpmodel/descriptor/dpa2.py index 353dcd6128..766fe19302 100644 --- a/deepmd/dpmodel/descriptor/dpa2.py +++ b/deepmd/dpmodel/descriptor/dpa2.py @@ -606,6 +606,23 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None) """Update mean and stddev for descriptor elements.""" raise NotImplementedError + def set_stat_mean_and_stddev( + self, + mean: List[np.ndarray], + stddev: List[np.ndarray], + ) -> None: + """Update mean and stddev for descriptor.""" + for ii, descrpt in enumerate([self.repinit, self.repformers]): + descrpt.mean = mean[ii] + descrpt.stddev = stddev[ii] + + def get_stat_mean_and_stddev(self) -> Tuple[List[np.ndarray], List[np.ndarray]]: + """Get mean and stddev for descriptor.""" + return [self.repinit.mean, self.repformers.mean], [ + self.repinit.stddev, + self.repformers.stddev, + ] + def call( self, coord_ext: np.ndarray, diff --git a/deepmd/dpmodel/descriptor/hybrid.py b/deepmd/dpmodel/descriptor/hybrid.py index b0db8c4354..3b08426b13 100644 --- a/deepmd/dpmodel/descriptor/hybrid.py +++ b/deepmd/dpmodel/descriptor/hybrid.py @@ -183,6 +183,30 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None) for descrpt in self.descrpt_list: descrpt.compute_input_stats(merged, path) + def set_stat_mean_and_stddev( + self, + mean: List[Union[np.ndarray, List[np.ndarray]]], + stddev: List[Union[np.ndarray, List[np.ndarray]]], + ) -> None: + """Update mean and stddev for descriptor.""" + for ii, descrpt in enumerate(self.descrpt_list): + descrpt.set_stat_mean_and_stddev(mean[ii], stddev[ii]) + + def get_stat_mean_and_stddev( + self, + ) -> Tuple[ + List[Union[np.ndarray, List[np.ndarray]]], + List[Union[np.ndarray, List[np.ndarray]]], + ]: + """Get mean and stddev for descriptor.""" + mean_list = [] + stddev_list = [] + for ii, descrpt in enumerate(self.descrpt_list): + mean_item, stddev_item = descrpt.get_stat_mean_and_stddev() + mean_list.append(mean_item) + stddev_list.append(stddev_item) + return mean_list, stddev_list + def call( self, coord_ext, diff --git a/deepmd/dpmodel/descriptor/make_base_descriptor.py b/deepmd/dpmodel/descriptor/make_base_descriptor.py index 215b13577b..49bf000248 100644 --- a/deepmd/dpmodel/descriptor/make_base_descriptor.py +++ b/deepmd/dpmodel/descriptor/make_base_descriptor.py @@ -127,6 +127,16 @@ def change_type_map( """ pass + @abstractmethod + def set_stat_mean_and_stddev(self, mean, stddev) -> None: + """Update mean and stddev for descriptor.""" + pass + + @abstractmethod + def get_stat_mean_and_stddev(self): + """Get mean and stddev for descriptor.""" + pass + def compute_input_stats( self, merged: Union[Callable[[], List[dict]], List[dict]], diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index a7d237ac41..504e357aeb 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -296,6 +296,19 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None) """Update mean and stddev for descriptor elements.""" raise NotImplementedError + def set_stat_mean_and_stddev( + self, + mean: np.ndarray, + stddev: np.ndarray, + ) -> None: + """Update mean and stddev for descriptor.""" + self.davg = mean + self.dstd = stddev + + def get_stat_mean_and_stddev(self) -> Tuple[np.ndarray, np.ndarray]: + """Get mean and stddev for descriptor.""" + return self.davg, self.dstd + def cal_g( self, ss, diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py index a8916df220..938826d16c 100644 --- a/deepmd/dpmodel/descriptor/se_r.py +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -254,6 +254,19 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None) """Update mean and stddev for descriptor elements.""" raise NotImplementedError + def set_stat_mean_and_stddev( + self, + mean: np.ndarray, + stddev: np.ndarray, + ) -> None: + """Update mean and stddev for descriptor.""" + self.davg = mean + self.dstd = stddev + + def get_stat_mean_and_stddev(self) -> Tuple[np.ndarray, np.ndarray]: + """Get mean and stddev for descriptor.""" + return self.davg, self.dstd + def cal_g( self, ss, diff --git a/deepmd/dpmodel/descriptor/se_t.py b/deepmd/dpmodel/descriptor/se_t.py index 5d9dcb5ecc..b91f9a6c6e 100644 --- a/deepmd/dpmodel/descriptor/se_t.py +++ b/deepmd/dpmodel/descriptor/se_t.py @@ -234,6 +234,19 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None) """Update mean and stddev for descriptor elements.""" raise NotImplementedError + def set_stat_mean_and_stddev( + self, + mean: np.ndarray, + stddev: np.ndarray, + ) -> None: + """Update mean and stddev for descriptor.""" + self.davg = mean + self.dstd = stddev + + def get_stat_mean_and_stddev(self) -> Tuple[np.ndarray, np.ndarray]: + """Get mean and stddev for descriptor.""" + return self.davg, self.dstd + def reinit_exclude( self, exclude_types: List[Tuple[int, int]] = [], diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 824c0615f8..ff29d14e1d 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -419,9 +419,14 @@ def set_stat_mean_and_stddev( mean: torch.Tensor, stddev: torch.Tensor, ) -> None: + """Update mean and stddev for descriptor.""" self.se_atten.mean = mean self.se_atten.stddev = stddev + def get_stat_mean_and_stddev(self) -> Tuple[torch.Tensor, torch.Tensor]: + """Get mean and stddev for descriptor.""" + return self.se_atten.mean, self.se_atten.stddev + def change_type_map( self, type_map: List[str], model_with_new_type_stat=None ) -> None: diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index b26919c581..ae8c924e9a 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -429,6 +429,23 @@ def compute_input_stats( for ii, descrpt in enumerate([self.repinit, self.repformers]): descrpt.compute_input_stats(merged, path) + def set_stat_mean_and_stddev( + self, + mean: List[torch.Tensor], + stddev: List[torch.Tensor], + ) -> None: + """Update mean and stddev for descriptor.""" + for ii, descrpt in enumerate([self.repinit, self.repformers]): + descrpt.mean = mean[ii] + descrpt.stddev = stddev[ii] + + def get_stat_mean_and_stddev(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """Get mean and stddev for descriptor.""" + return [self.repinit.mean, self.repformers.mean], [ + self.repinit.stddev, + self.repformers.stddev, + ] + def serialize(self) -> dict: repinit = self.repinit repformers = self.repformers diff --git a/deepmd/pt/model/descriptor/hybrid.py b/deepmd/pt/model/descriptor/hybrid.py index 26bbe7b199..d486cda399 100644 --- a/deepmd/pt/model/descriptor/hybrid.py +++ b/deepmd/pt/model/descriptor/hybrid.py @@ -197,6 +197,30 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None) for descrpt in self.descrpt_list: descrpt.compute_input_stats(merged, path) + def set_stat_mean_and_stddev( + self, + mean: List[Union[torch.Tensor, List[torch.Tensor]]], + stddev: List[Union[torch.Tensor, List[torch.Tensor]]], + ) -> None: + """Update mean and stddev for descriptor.""" + for ii, descrpt in enumerate(self.descrpt_list): + descrpt.set_stat_mean_and_stddev(mean[ii], stddev[ii]) + + def get_stat_mean_and_stddev( + self, + ) -> Tuple[ + List[Union[torch.Tensor, List[torch.Tensor]]], + List[Union[torch.Tensor, List[torch.Tensor]]], + ]: + """Get mean and stddev for descriptor.""" + mean_list = [] + stddev_list = [] + for ii, descrpt in enumerate(self.descrpt_list): + mean_item, stddev_item = descrpt.get_stat_mean_and_stddev() + mean_list.append(mean_item) + stddev_list.append(stddev_item) + return mean_list, stddev_list + def forward( self, coord_ext: torch.Tensor, diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index f5b83aa81f..e771c03e52 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -273,9 +273,14 @@ def set_stat_mean_and_stddev( mean: torch.Tensor, stddev: torch.Tensor, ) -> None: + """Update mean and stddev for descriptor.""" self.sea.mean = mean self.sea.stddev = stddev + def get_stat_mean_and_stddev(self) -> Tuple[torch.Tensor, torch.Tensor]: + """Get mean and stddev for descriptor.""" + return self.sea.mean, self.sea.stddev + def serialize(self) -> dict: obj = self.sea return { diff --git a/deepmd/pt/model/descriptor/se_r.py b/deepmd/pt/model/descriptor/se_r.py index 78112e6e87..e6ebe53c26 100644 --- a/deepmd/pt/model/descriptor/se_r.py +++ b/deepmd/pt/model/descriptor/se_r.py @@ -389,9 +389,14 @@ def set_stat_mean_and_stddev( mean: torch.Tensor, stddev: torch.Tensor, ) -> None: + """Update mean and stddev for descriptor.""" self.mean = mean self.stddev = stddev + def get_stat_mean_and_stddev(self) -> Tuple[torch.Tensor, torch.Tensor]: + """Get mean and stddev for descriptor.""" + return self.mean, self.stddev + def serialize(self) -> dict: return { "@class": "Descriptor", diff --git a/deepmd/pt/model/descriptor/se_t.py b/deepmd/pt/model/descriptor/se_t.py index c2f31eb9f3..caa4c9ce45 100644 --- a/deepmd/pt/model/descriptor/se_t.py +++ b/deepmd/pt/model/descriptor/se_t.py @@ -303,9 +303,14 @@ def set_stat_mean_and_stddev( mean: torch.Tensor, stddev: torch.Tensor, ) -> None: + """Update mean and stddev for descriptor.""" self.seat.mean = mean self.seat.stddev = stddev + def get_stat_mean_and_stddev(self) -> Tuple[torch.Tensor, torch.Tensor]: + """Get mean and stddev for descriptor.""" + return self.seat.mean, self.seat.stddev + def serialize(self) -> dict: obj = self.seat return { diff --git a/source/tests/universal/common/backend.py b/source/tests/universal/common/backend.py index d5747b77b7..44532a4d68 100644 --- a/source/tests/universal/common/backend.py +++ b/source/tests/universal/common/backend.py @@ -21,3 +21,13 @@ def modules_to_test(self) -> list: @abstractmethod def forward_wrapper(self, x): pass + + @classmethod + @abstractmethod + def convert_to_numpy(cls, xx): + pass + + @classmethod + @abstractmethod + def convert_from_numpy(cls, xx): + pass diff --git a/source/tests/universal/common/cases/descriptor/utils.py b/source/tests/universal/common/cases/descriptor/utils.py index 450b0c73cb..e1c2b80c15 100644 --- a/source/tests/universal/common/cases/descriptor/utils.py +++ b/source/tests/universal/common/cases/descriptor/utils.py @@ -9,6 +9,9 @@ from deepmd.dpmodel.utils import ( PairExcludeMask, ) +from deepmd.utils.finetune import ( + get_index_between_two_maps, +) from .....seed import ( GLOBAL_SEED, @@ -184,6 +187,164 @@ def test_change_type_map(self): ) np.testing.assert_allclose(rd_old_tm, rd_new_tm) + def test_change_type_map_extend_stat(self): + if ( + not self.module.mixed_types() + or getattr(self.module, "sel_no_mixed_types", None) is not None + ): + # skip if not mixed_types + return + full_type_map_test = [ + "H", + "He", + "Li", + "Be", + "B", + "C", + "N", + "O", + "F", + "Ne", + "Na", + "Mg", + "Al", + "Si", + "P", + "S", + "Cl", + "Ar", + ] # 18 elements + rng = np.random.default_rng(GLOBAL_SEED) + for small_tm, large_tm in itertools.product( + [ + full_type_map_test[:8], # 8 elements, tebd default first dim + ["H", "O"], # slimmed types + ], # small_tm + [ + full_type_map_test[:], # 18 elements + full_type_map_test[ + :16 + ], # 16 elements, double of tebd default first dim + full_type_map_test[:8], # 8 elements, tebd default first dim + ], # large_tm + ): + # use shuffled type_map + rng.shuffle(small_tm) + rng.shuffle(large_tm) + small_tm_input = update_input_type_map(self.input_dict, small_tm) + small_tm_module = self.module_class(**small_tm_input) + + large_tm_input = update_input_type_map(self.input_dict, large_tm) + large_tm_module = self.module_class(**large_tm_input) + + # set random stat + mean_small_tm, std_small_tm = small_tm_module.get_stat_mean_and_stddev() + mean_large_tm, std_large_tm = large_tm_module.get_stat_mean_and_stddev() + if "list" not in self.input_dict: + mean_rand_small_tm, std_rand_small_tm = self.get_rand_stat( + rng, mean_small_tm, std_small_tm + ) + mean_rand_large_tm, std_rand_large_tm = self.get_rand_stat( + rng, mean_large_tm, std_large_tm + ) + else: + # for hybrid + mean_rand_small_tm, std_rand_small_tm = [], [] + mean_rand_large_tm, std_rand_large_tm = [], [] + for ii in range(len(mean_small_tm)): + mean_rand_item_small_tm, std_rand_item_small_tm = ( + self.get_rand_stat(rng, mean_small_tm[ii], std_small_tm[ii]) + ) + mean_rand_small_tm.append(mean_rand_item_small_tm) + std_rand_small_tm.append(std_rand_item_small_tm) + mean_rand_item_large_tm, std_rand_item_large_tm = ( + self.get_rand_stat(rng, mean_large_tm[ii], std_large_tm[ii]) + ) + mean_rand_large_tm.append(mean_rand_item_large_tm) + std_rand_large_tm.append(std_rand_item_large_tm) + + small_tm_module.set_stat_mean_and_stddev( + mean_rand_small_tm, std_rand_small_tm + ) + large_tm_module.set_stat_mean_and_stddev( + mean_rand_large_tm, std_rand_large_tm + ) + + # extend the type map + small_tm_module.change_type_map( + large_tm, model_with_new_type_stat=large_tm_module + ) + + # check the stat + mean_result, std_result = small_tm_module.get_stat_mean_and_stddev() + type_index_map = get_index_between_two_maps(small_tm, large_tm)[0] + + if "list" not in self.input_dict: + self.check_expect_stat( + type_index_map, mean_rand_small_tm, mean_rand_large_tm, mean_result + ) + self.check_expect_stat( + type_index_map, std_rand_small_tm, std_rand_large_tm, std_result + ) + else: + # for hybrid + for ii in range(len(mean_small_tm)): + self.check_expect_stat( + type_index_map, + mean_rand_small_tm[ii], + mean_rand_large_tm[ii], + mean_result[ii], + ) + self.check_expect_stat( + type_index_map, + std_rand_small_tm[ii], + std_rand_large_tm[ii], + std_result[ii], + ) + + def get_rand_stat(self, rng, mean, std): + if not isinstance(mean, list): + mean_rand, std_rand = self.get_rand_stat_item(rng, mean, std) + else: + mean_rand, std_rand = [], [] + for ii in range(len(mean)): + mean_rand_item, std_rand_item = self.get_rand_stat_item( + rng, mean[ii], std[ii] + ) + mean_rand.append(mean_rand_item) + std_rand.append(std_rand_item) + return mean_rand, std_rand + + def get_rand_stat_item(self, rng, mean, std): + mean = self.convert_to_numpy(mean) + std = self.convert_to_numpy(std) + mean_rand = rng.random(size=mean.shape) + std_rand = rng.random(size=std.shape) + mean_rand = self.convert_from_numpy(mean_rand) + std_rand = self.convert_from_numpy(std_rand) + return mean_rand, std_rand + + def check_expect_stat(self, type_index_map, stat_small, stat_large, stat_result): + if not isinstance(stat_small, list): + self.check_expect_stat_item( + type_index_map, stat_small, stat_large, stat_result + ) + else: + for ii in range(len(stat_small)): + self.check_expect_stat_item( + type_index_map, stat_small[ii], stat_large[ii], stat_result[ii] + ) + + def check_expect_stat_item( + self, type_index_map, stat_small, stat_large, stat_result + ): + stat_small = self.convert_to_numpy(stat_small) + stat_large = self.convert_to_numpy(stat_large) + stat_result = self.convert_to_numpy(stat_result) + full_stat = np.concatenate([stat_small, stat_large], axis=0) + expected_stat = full_stat[type_index_map] + np.testing.assert_allclose(expected_stat, stat_result) + def update_input_type_map(input_dict, type_map): updated_input_dict = deepcopy(input_dict) diff --git a/source/tests/universal/dpmodel/backend.py b/source/tests/universal/dpmodel/backend.py index 61982fea98..aff009b71b 100644 --- a/source/tests/universal/dpmodel/backend.py +++ b/source/tests/universal/dpmodel/backend.py @@ -1,4 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import numpy as np + from deepmd.dpmodel.common import ( NativeOP, ) @@ -17,6 +19,14 @@ class DPTestCase(BackendTestCase): def forward_wrapper(self, x): return x + @classmethod + def convert_to_numpy(cls, xx: np.ndarray) -> np.ndarray: + return xx + + @classmethod + def convert_from_numpy(cls, xx: np.ndarray) -> np.ndarray: + return xx + @property def deserialized_module(self): return self.module.deserialize(self.module.serialize()) diff --git a/source/tests/universal/pt/backend.py b/source/tests/universal/pt/backend.py index 61110a0cc6..5ee4791ec8 100644 --- a/source/tests/universal/pt/backend.py +++ b/source/tests/universal/pt/backend.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import numpy as np import torch from deepmd.pt.utils.utils import ( @@ -32,6 +33,14 @@ def modules_to_test(self): def test_jit(self): self.script_module + @classmethod + def convert_to_numpy(cls, xx: torch.Tensor) -> np.ndarray: + return to_numpy_array(xx) + + @classmethod + def convert_from_numpy(cls, xx: np.ndarray) -> torch.Tensor: + return to_torch_tensor(xx) + def forward_wrapper(self, module): def create_wrapper_method(method): def wrapper_method(self, *args, **kwargs):