From f32614fd30100e44544a7f2ae73962a4ec564634 Mon Sep 17 00:00:00 2001
From: Anyang Peng <137014849+anyangml@users.noreply.github.com>
Date: Sun, 7 Apr 2024 20:29:42 +0800
Subject: [PATCH] feat: add UTs

---
 .../atomic_model/linear_atomic_model.py       |   3 +-
 .../model/atomic_model/linear_atomic_model.py |  75 ++++--
 .../pt/model/test_linear_atomic_model_stat.py | 229 ++++++++++++++++++
 3 files changed, 282 insertions(+), 25 deletions(-)
 create mode 100644 source/tests/pt/model/test_linear_atomic_model_stat.py

diff --git a/deepmd/dpmodel/atomic_model/linear_atomic_model.py b/deepmd/dpmodel/atomic_model/linear_atomic_model.py
index e6296316a5..e4a85d7bc2 100644
--- a/deepmd/dpmodel/atomic_model/linear_atomic_model.py
+++ b/deepmd/dpmodel/atomic_model/linear_atomic_model.py
@@ -248,7 +248,8 @@ def _compute_weight(
     ) -> List[np.ndarray]:
         """This should be a list of user defined weights that matches the number of models to be combined."""
         nmodels = len(self.models)
-        return [np.ones(1) / nmodels for _ in range(nmodels)]
+        nframes, nloc, _ = nlists_[0].shape
+        return [np.ones((nframes, nloc, 1)) / nmodels for _ in range(nmodels)]
 
     def get_dim_fparam(self) -> int:
         """Get the number (dimension) of frame parameters of this atomic model."""
diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py
index 87b58388f2..9bf033953e 100644
--- a/deepmd/pt/model/atomic_model/linear_atomic_model.py
+++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py
@@ -5,6 +5,8 @@
     List,
     Optional,
     Tuple,
+    Union,
+    Callable,
 )
 
 import torch
@@ -293,8 +295,9 @@ def _compute_weight(
     ) -> List[torch.Tensor]:
         """This should be a list of user defined weights that matches the number of models to be combined."""
         nmodels = len(self.models)
+        nframes, nloc, _ = nlists_[0].shape
         return [
-            torch.ones(1, dtype=torch.float64, device=env.DEVICE) / nmodels
+            torch.ones((nframes, nloc, 1), dtype=torch.float64, device=env.DEVICE) / nmodels
             for _ in range(nmodels)
         ]
 
@@ -333,6 +336,53 @@ def is_aparam_nall(self) -> bool:
         If False, the shape is (nframes, nloc, ndim).
         """
         return False
+
+    def compute_or_load_out_stat(
+        self,
+        merged: Union[Callable[[], List[dict]], List[dict]],
+        stat_file_path: Optional[DPPath] = None,
+    ):
+        """
+        Compute the output statistics (e.g. energy bias) for the fitting net from packed data.
+
+        Parameters
+        ----------
+        merged : Union[Callable[[], List[dict]], List[dict]]
+            - List[dict]: A list of data samples from various data systems.
+                Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
+                originating from the `i`-th data system.
+            - Callable[[], List[dict]]: A lazy function that returns data samples in the above format
+                only when needed. Since the sampling process can be slow and memory-intensive,
+                the lazy function helps by only sampling once.
+        stat_file_path : Optional[DPPath]
+            The path to the stat file.
+
+        """
+        for md in self.models:
+            md.compute_or_load_out_stat(merged, stat_file_path)
+
+    def compute_or_load_stat(
+        self,
+        sampled_func,
+        stat_file_path: Optional[DPPath] = None,
+    ):
+        """
+        Compute or load the statistics parameters of the model,
+        such as mean and standard deviation of descriptors or the energy bias of the fitting net.
+        When `sampled_func` is provided, all the statistics parameters will be calculated (or re-calculated for update)
+        and saved in the `stat_file_path`(s).
+        When `sampled_func` is not provided, it will check the existence of the `stat_file_path`(s)
+        and load the calculated statistics parameters.
+
+        Parameters
+        ----------
+        sampled_func
+            The lazy sampled function to get data frames from different data systems.
+        stat_file_path
+            The dictionary of paths to the statistics files.
+        """
+        for md in self.models:
+            md.compute_or_load_stat(sampled_func, stat_file_path)
 
 
 class DPZBLLinearEnergyAtomicModel(LinearEnergyAtomicModel):
@@ -376,29 +426,6 @@ def __init__(
         # this is a placeholder being updated in _compute_weight, to handle Jit attribute init error.
         self.zbl_weight = torch.empty(0, dtype=torch.float64, device=env.DEVICE)
 
-    def compute_or_load_stat(
-        self,
-        sampled_func,
-        stat_file_path: Optional[DPPath] = None,
-    ):
-        """
-        Compute or load the statistics parameters of the model,
-        such as mean and standard deviation of descriptors or the energy bias of the fitting net.
-        When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update),
-        and saved in the `stat_file_path`(s).
-        When `sampled` is not provided, it will check the existence of `stat_file_path`(s)
-        and load the calculated statistics parameters.
-
-        Parameters
-        ----------
-        sampled_func
-            The lazy sampled function to get data frames from different data systems.
-        stat_file_path
-            The dictionary of paths to the statistics files.
-        """
-        self.models[0].compute_or_load_stat(sampled_func, stat_file_path)
-        self.models[1].compute_or_load_stat(sampled_func, stat_file_path)
-
     def serialize(self) -> dict:
         dd = BaseAtomicModel.serialize(self)
         dd.update(
diff --git a/source/tests/pt/model/test_linear_atomic_model_stat.py b/source/tests/pt/model/test_linear_atomic_model_stat.py
new file mode 100644
index 0000000000..ae1ca84419
--- /dev/null
+++ b/source/tests/pt/model/test_linear_atomic_model_stat.py
@@ -0,0 +1,229 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import tempfile
+import unittest
+from pathlib import (
+    Path,
+)
+from typing import (
+    Optional,
+)
+
+import h5py
+import numpy as np
+import torch
+
+from deepmd.dpmodel.output_def import (
+    FittingOutputDef,
+    OutputVariableDef,
+)
+from deepmd.pt.model.atomic_model import (
+    BaseAtomicModel,
+    DPAtomicModel,
+    LinearEnergyAtomicModel,
+)
+from deepmd.pt.model.descriptor.dpa1 import (
+    DescrptDPA1,
+)
+from deepmd.pt.model.task.base_fitting import (
+    BaseFitting,
+)
+from deepmd.pt.utils import (
+    env,
+)
+from deepmd.pt.utils.utils import (
+    to_numpy_array,
+    to_torch_tensor,
+)
+from deepmd.utils.path import (
+    DPPath,
+)
+
+from .test_env_mat import (
+    TestCaseSingleFrameWithNlist,
+)
+
+dtype = env.GLOBAL_PT_FLOAT_PRECISION
+
+
+class FooFittingA(torch.nn.Module, BaseFitting):
+    def output_def(self):
+        return FittingOutputDef(
+            [
+                OutputVariableDef(
+                    "energy",
+                    [1],
+                    reduciable=True,
+                    r_differentiable=True,
+                    c_differentiable=True,
+                ),
+            ]
+        )
+
+    def serialize(self) -> dict:
+        raise NotImplementedError
+
+    def forward(
+        self,
+        descriptor: torch.Tensor,
+        atype: torch.Tensor,
+        gr: Optional[torch.Tensor] = None,
+        g2: Optional[torch.Tensor] = None,
+        h2: Optional[torch.Tensor] = None,
+        fparam: Optional[torch.Tensor] = None,
+        aparam: Optional[torch.Tensor] = None,
+    ):
+        nf, nloc, _ = descriptor.shape
+        ret = {}
+        # constant per-atom energies: one value per atom for each of the two frames
+        ret["energy"] = (
+            torch.Tensor(
+                [
+                    [1.0, 2.0, 3.0],
+                    [4.0, 5.0, 6.0],
+                ]
+            )
+            .view([nf, nloc] + self.output_def()["energy"].shape)
+            .to(env.GLOBAL_PT_FLOAT_PRECISION)
+            .to(env.DEVICE)
+        )
+        return ret
+
+
+class FooFittingB(torch.nn.Module, BaseFitting):
+    def output_def(self):
+        return FittingOutputDef(
+            [
+                OutputVariableDef(
+                    "energy",
+                    [1],
+                    reduciable=True,
+                    r_differentiable=True,
+                    c_differentiable=True,
+                ),
+            ]
+        )
+
+    def serialize(self) -> dict:
+        raise NotImplementedError
+
+    def forward(
+        self,
+        descriptor: torch.Tensor,
+        atype: torch.Tensor,
+        gr: Optional[torch.Tensor] = None,
+        g2: Optional[torch.Tensor] = None,
+        h2: Optional[torch.Tensor] = None,
+        fparam: Optional[torch.Tensor] = None,
+        aparam: Optional[torch.Tensor] = None,
+    ):
+        nf, nloc, _ = descriptor.shape
+        ret = {}
+        # different constants from FooFittingA, so the two sub-models are distinguishable
+        ret["energy"] = (
+            torch.Tensor(
+                [
+                    [7.0, 8.0, 9.0],
+                    [10.0, 11.0, 12.0],
+                ]
+            )
+            .view([nf, nloc] + self.output_def()["energy"].shape)
+            .to(env.GLOBAL_PT_FLOAT_PRECISION)
+            .to(env.DEVICE)
+        )
+        return ret
+
+
+class TestAtomicModelStat(unittest.TestCase, TestCaseSingleFrameWithNlist):
+    def tearDown(self):
+        self.tempdir.cleanup()
+
+    def setUp(self):
+        TestCaseSingleFrameWithNlist.setUp(self)
+        nf, nloc, nnei = self.nlist.shape
+        self.merged_output_stat = [
+            {
+                "coord": to_torch_tensor(np.zeros([2, 3, 3])),
+                "atype": to_torch_tensor(
+                    np.array([[0, 0, 1], [0, 1, 1]], dtype=np.int32)
+                ),
+                "atype_ext": to_torch_tensor(
+                    np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32)
+                ),
+                "box": to_torch_tensor(np.zeros([2, 3, 3])),
+                "natoms": to_torch_tensor(
+                    np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32)
+                ),
+                # implied energy bias: foo = 1, bar = 3 (2*1 + 1*3 = 5; 1*1 + 2*3 = 7)
+                "energy": to_torch_tensor(np.array([5.0, 7.0]).reshape(2, 1)),
+            }
+        ]
+        self.tempdir = tempfile.TemporaryDirectory()
+        h5file = str((Path(self.tempdir.name) / "testcase.h5").resolve())
+        with h5py.File(h5file, "w") as f:
+            pass
+        self.stat_file_path = DPPath(h5file, "a")
+
+    def test_linear_atomic_model_stat_with_bias(self):
+        nf, nloc, nnei = self.nlist.shape
+        ds = DescrptDPA1(
+            self.rcut,
+            self.rcut_smth,
+            sum(self.sel),
+            self.nt,
+        ).to(env.DEVICE)
+        ft_a = FooFittingA().to(env.DEVICE)
+        ft_b = FooFittingB().to(env.DEVICE)
+        type_map = ["foo", "bar"]
+        md0 = DPAtomicModel(
+            ds,
+            ft_a,
+            type_map=type_map,
+        ).to(env.DEVICE)
+        md1 = DPAtomicModel(
+            ds,
+            ft_b,
+            type_map=type_map,
+        ).to(env.DEVICE)
+        linear_model = LinearEnergyAtomicModel([md0, md1], type_map=type_map).to(
+            env.DEVICE
+        )
+
+        args = [
+            to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist]
+        ]
+        # nf x nloc
+        at = self.atype_ext[:, :nloc]
+
+        # 1. test run without bias
+        # nf x na x odim
+        ret0 = linear_model.forward_common_atomic(*args)
+        ret0 = to_numpy_array(ret0["energy"])
+        ret_no_bias = []
+        for md in linear_model.models:
+            ret_no_bias.append(
+                to_numpy_array(md.forward_common_atomic(*args)["energy"])
+            )
+        # equal weights: the linear model returns the mean of the sub-model outputs
+        expected_ret0 = np.array(
+            [
+                [4.0, 5.0, 6.0],
+                [7.0, 8.0, 9.0],
+            ]
+        ).reshape([nf, nloc] + linear_model.fitting_output_def()["energy"].shape)
+        np.testing.assert_almost_equal(ret0, expected_ret0)
+
+        # 2. test that the bias is applied
+        linear_model.compute_or_load_out_stat(
+            self.merged_output_stat, stat_file_path=self.stat_file_path
+        )
+        # the bias is applied inside each sub atomic model
+        ener_bias = np.array([1.0, 3.0]).reshape(2, 1)
+        linear_ret = []
+        for idx, md in enumerate(linear_model.models):
+            ret = md.forward_common_atomic(*args)
+            ret = to_numpy_array(ret["energy"])
+            linear_ret.append(ret_no_bias[idx] + ener_bias[at])
+            np.testing.assert_almost_equal(ret_no_bias[idx] + ener_bias[at], ret)
+
+        # the linear model must not add the bias a second time
+        ret1 = linear_model.forward_common_atomic(*args)
+        ret1 = to_numpy_array(ret1["energy"])
+        np.testing.assert_almost_equal(np.mean(np.stack(linear_ret), axis=0), ret1)
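
Reviewer note (not part of the patch): a minimal usage sketch of the new
`LinearEnergyAtomicModel.compute_or_load_out_stat` hook added above. It assumes a
`linear_model` and a `merged_output_stat` list constructed exactly as in the new test
file; the stat-file name "stats.h5" and the helper name `sampled_func` are
illustrative only, not part of the library API.

    # Hypothetical driver; `linear_model` and `merged_output_stat` are assumed to
    # be built as in test_linear_atomic_model_stat.py above.
    from deepmd.utils.path import DPPath

    # HDF5-backed statistics store, opened in append mode as in the test
    stat_file_path = DPPath("stats.h5", "a")

    def sampled_func():
        # Lazy sampler: called at most once, and only when the statistics
        # are not already present in the stat file.
        return merged_output_stat

    # A single call delegates the output-bias computation to every sub-model;
    # the linear combination itself applies no additional bias afterwards.
    linear_model.compute_or_load_out_stat(sampled_func, stat_file_path=stat_file_path)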