diff --git a/.git_archival.txt b/.git_archival.txt new file mode 100644 index 0000000000..8fb235d704 --- /dev/null +++ b/.git_archival.txt @@ -0,0 +1,4 @@ +node: $Format:%H$ +node-date: $Format:%cI$ +describe-name: $Format:%(describe:tags=true,match=*[0-9]*)$ +ref-names: $Format:%D$ diff --git a/.gitattributes b/.gitattributes index 82d852900b..776405a339 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,3 +1,4 @@ # do not show up detailed difference on GitHub source/3rdparty/* linguist-generated=true source/3rdparty/README.md linguist-generated=false +.git_archival.txt export-subst diff --git a/.github/workflows/package_c.yml b/.github/workflows/package_c.yml index e11f773b3a..c5a3a3a7b0 100644 --- a/.github/workflows/package_c.yml +++ b/.github/workflows/package_c.yml @@ -26,6 +26,8 @@ jobs: filename: libdeepmd_c_cu11.tar.gz steps: - uses: actions/checkout@v4 + with: + fetch-depth: 0 - name: Package C library run: ./source/install/docker_package_c.sh env: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7e0b0a4a76..e552895589 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ # See https://pre-commit.com/hooks.html for more hooks repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v4.6.0 hooks: - id: trailing-whitespace exclude: "^.+\\.pbtxt$" @@ -52,7 +52,7 @@ repos: - id: blacken-docs # C++ - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v18.1.2 + rev: v18.1.3 hooks: - id: clang-format exclude: ^source/3rdparty|source/lib/src/gpu/cudart/.+\.inc diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index dbb344d5ca..9e43851157 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy from typing import ( Dict, List, @@ -30,11 +31,44 @@ def __init__( type_map: List[str], atom_exclude_types: List[int] = [], pair_exclude_types: List[Tuple[int, int]] = [], + rcond: Optional[float] = None, + preset_out_bias: Optional[Dict[str, np.ndarray]] = None, ): super().__init__() self.type_map = type_map self.reinit_atom_exclude(atom_exclude_types) self.reinit_pair_exclude(pair_exclude_types) + self.rcond = rcond + self.preset_out_bias = preset_out_bias + + def init_out_stat(self): + """Initialize the output bias.""" + ntypes = self.get_ntypes() + self.bias_keys: List[str] = list(self.fitting_output_def().keys()) + self.max_out_size = max( + [self.atomic_output_def()[kk].size for kk in self.bias_keys] + ) + self.n_out = len(self.bias_keys) + out_bias_data = np.zeros([self.n_out, ntypes, self.max_out_size]) + out_std_data = np.ones([self.n_out, ntypes, self.max_out_size]) + self.out_bias = out_bias_data + self.out_std = out_std_data + + def __setitem__(self, key, value): + if key in ["out_bias"]: + self.out_bias = value + elif key in ["out_std"]: + self.out_std = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ["out_bias"]: + return self.out_bias + elif key in ["out_std"]: + return self.out_std + else: + raise KeyError(key) def get_type_map(self) -> List[str]: """Get the type map.""" @@ -132,6 +166,7 @@ def forward_common_atomic( fparam=fparam, aparam=aparam, ) + ret_dict = self.apply_out_stat(ret_dict, atype) # nf x nloc atom_mask = ext_atom_mask[:, :nloc].astype(np.int32) @@ -150,6 +185,84 @@ def forward_common_atomic( def serialize(self) -> dict: return { + "type_map": 
self.type_map, "atom_exclude_types": self.atom_exclude_types, "pair_exclude_types": self.pair_exclude_types, + "rcond": self.rcond, + "preset_out_bias": self.preset_out_bias, + "@variables": { + "out_bias": self.out_bias, + "out_std": self.out_std, + }, } + + @classmethod + def deserialize(cls, data: dict) -> "BaseAtomicModel": + data = copy.deepcopy(data) + variables = data.pop("@variables") + obj = cls(**data) + for kk in variables.keys(): + obj[kk] = variables[kk] + return obj + + def apply_out_stat( + self, + ret: Dict[str, np.ndarray], + atype: np.ndarray, + ): + """Apply the stat to each atomic output. + The developer may override the method to define how the bias is applied + to the atomic output of the model. + + Parameters + ---------- + ret + The returned dict by the forward_atomic method + atype + The atom types. nf x nloc + + """ + out_bias, out_std = self._fetch_out_stat(self.bias_keys) + for kk in self.bias_keys: + # nf x nloc x odims, out_bias: ntypes x odims + ret[kk] = ret[kk] + out_bias[kk][atype] + return ret + + def _varsize( + self, + shape: List[int], + ) -> int: + output_size = 1 + len_shape = len(shape) + for i in range(len_shape): + output_size *= shape[i] + return output_size + + def _get_bias_index( + self, + kk: str, + ) -> int: + res: List[int] = [] + for i, e in enumerate(self.bias_keys): + if e == kk: + res.append(i) + assert len(res) == 1 + return res[0] + + def _fetch_out_stat( + self, + keys: List[str], + ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]: + ret_bias = {} + ret_std = {} + ntypes = self.get_ntypes() + for kk in keys: + idx = self._get_bias_index(kk) + isize = self._varsize(self.atomic_output_def()[kk].shape) + ret_bias[kk] = self.out_bias[idx, :, :isize].reshape( + [ntypes] + list(self.atomic_output_def()[kk].shape) # noqa: RUF005 + ) + ret_std[kk] = self.out_std[idx, :, :isize].reshape( + [ntypes] + list(self.atomic_output_def()[kk].shape) # noqa: RUF005 + ) + return ret_bias, ret_std diff --git a/deepmd/dpmodel/atomic_model/dp_atomic_model.py b/deepmd/dpmodel/atomic_model/dp_atomic_model.py index cca46d3710..b13bfc17ba 100644 --- a/deepmd/dpmodel/atomic_model/dp_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dp_atomic_model.py @@ -49,11 +49,12 @@ def __init__( type_map: List[str], **kwargs, ): + super().__init__(type_map, **kwargs) self.type_map = type_map self.descriptor = descriptor self.fitting = fitting self.type_map = type_map - super().__init__(type_map, **kwargs) + super().init_out_stat() def fitting_output_def(self) -> FittingOutputDef: """Get the output def of the fitting net.""" @@ -79,27 +80,6 @@ def mixed_types(self) -> bool: """ return self.descriptor.mixed_types() - def set_out_bias(self, out_bias: np.ndarray, add=False) -> None: - """ - Modify the output bias for the atomic model. - - Parameters - ---------- - out_bias : np.ndarray - The new bias to be applied. - add : bool, optional - Whether to add the new bias to the existing one. - If False, the output bias will be directly replaced by the new bias. - If True, the new bias will be added to the existing one. 
- """ - self.fitting["bias_atom_e"] = ( - out_bias + self.fitting["bias_atom_e"] if add else out_bias - ) - - def get_out_bias(self) -> np.ndarray: - """Return the output bias of the atomic model.""" - return self.fitting["bias_atom_e"] - def forward_atomic( self, extended_coord: np.ndarray, @@ -157,7 +137,7 @@ def serialize(self) -> dict: { "@class": "Model", "type": "standard", - "@version": 1, + "@version": 2, "type_map": self.type_map, "descriptor": self.descriptor.serialize(), "fitting": self.fitting.serialize(), @@ -168,13 +148,14 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data) -> "DPAtomicModel": data = copy.deepcopy(data) - check_version_compatibility(data.pop("@version", 1), 1, 1) + check_version_compatibility(data.pop("@version", 1), 2, 2) data.pop("@class") data.pop("type") descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor")) fitting_obj = BaseFitting.deserialize(data.pop("fitting")) - type_map = data.pop("type_map") - obj = cls(descriptor_obj, fitting_obj, type_map=type_map, **data) + data["descriptor"] = descriptor_obj + data["fitting"] = fitting_obj + obj = super().deserialize(data) return obj def get_dim_fparam(self) -> int: diff --git a/deepmd/dpmodel/atomic_model/linear_atomic_model.py b/deepmd/dpmodel/atomic_model/linear_atomic_model.py index 71e4aa542a..b38d309fd7 100644 --- a/deepmd/dpmodel/atomic_model/linear_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/linear_atomic_model.py @@ -52,6 +52,8 @@ def __init__( type_map: List[str], **kwargs, ): + super().__init__(type_map, **kwargs) + super().init_out_stat() self.models = models sub_model_type_maps = [md.get_type_map() for md in models] err_msg = [] @@ -66,7 +68,6 @@ def __init__( self.mapping_list.append(self.remap_atype(tpmp, self.type_map)) assert len(err_msg) == 0, "\n".join(err_msg) self.mixed_types_list = [model.mixed_types() for model in self.models] - super().__init__(type_map, **kwargs) def mixed_types(self) -> bool: """If true, the model @@ -86,7 +87,7 @@ def get_rcut(self) -> float: def get_type_map(self) -> List[str]: """Get the type map.""" - raise self.type_map + return self.type_map def get_model_rcuts(self) -> List[float]: """Get the cut-off radius for each individual models.""" @@ -162,7 +163,6 @@ def forward_atomic( ) ] ener_list = [] - for i, model in enumerate(self.models): mapping = self.mapping_list[i] ener_list.append( @@ -176,13 +176,10 @@ def forward_atomic( )["energy"] ) self.weights = self._compute_weight(extended_coord, extended_atype, nlists_) - self.atomic_bias = None - if self.atomic_bias is not None: - raise NotImplementedError("Need to add bias in a future PR.") - else: - fit_ret = { - "energy": np.sum(np.stack(ener_list) * np.stack(self.weights), axis=0), - } # (nframes, nloc, 1) + + fit_ret = { + "energy": np.sum(np.stack(ener_list) * np.stack(self.weights), axis=0), + } # (nframes, nloc, 1) return fit_ret @staticmethod @@ -222,27 +219,30 @@ def fitting_output_def(self) -> FittingOutputDef: ) def serialize(self) -> dict: - return { - "@class": "Model", - "type": "linear", - "@version": 1, - "models": [model.serialize() for model in self.models], - "type_map": self.type_map, - } + dd = super().serialize() + dd.update( + { + "@class": "Model", + "@version": 2, + "type": "linear", + "models": [model.serialize() for model in self.models], + "type_map": self.type_map, + } + ) + return dd @classmethod def deserialize(cls, data: dict) -> "LinearEnergyAtomicModel": data = copy.deepcopy(data) - check_version_compatibility(data.pop("@version", 1), 1, 1) - 
data.pop("@class") - data.pop("type") - type_map = data.pop("type_map") + check_version_compatibility(data.pop("@version", 2), 2, 2) + data.pop("@class", None) + data.pop("type", None) models = [ BaseAtomicModel.get_class_by_type(model["type"]).deserialize(model) for model in data["models"] ] - data.pop("models") - return cls(models, type_map, **data) + data["models"] = models + return super().deserialize(data) def _compute_weight( self, @@ -252,7 +252,8 @@ def _compute_weight( ) -> List[np.ndarray]: """This should be a list of user defined weights that matches the number of models to be combined.""" nmodels = len(self.models) - return [np.ones(1) / nmodels for _ in range(nmodels)] + nframes, nloc, _ = nlists_[0].shape + return [np.ones((nframes, nloc, 1)) / nmodels for _ in range(nmodels)] def get_dim_fparam(self) -> int: """Get the number (dimension) of frame parameters of this atomic model.""" @@ -275,27 +276,6 @@ def get_sel_type(self) -> List[int]: # join all the selected types return list(set().union(*[model.get_sel_type() for model in self.models])) - def set_out_bias(self, out_bias: np.ndarray, add=False) -> None: - """ - Modify the output bias for all the models in the linear atomic model. - - Parameters - ---------- - out_bias : torch.Tensor - The new bias to be applied. - add : bool, optional - Whether to add the new bias to the existing one. - If False, the output bias will be directly replaced by the new bias. - If True, the new bias will be added to the existing one. - """ - for model in self.models: - model.set_out_bias(out_bias, add=add) - - def get_out_bias(self) -> np.ndarray: - """Return the weighted output bias of the linear atomic model.""" - # TODO add get_out_bias for linear atomic model - raise NotImplementedError - def is_aparam_nall(self) -> bool: """Check whether the shape of atomic parameters is (nframes, nall, ndim). 
@@ -336,24 +316,21 @@ def __init__( **kwargs, ): models = [dp_model, zbl_model] - super().__init__(models, type_map, **kwargs) - self.dp_model = dp_model - self.zbl_model = zbl_model + kwargs["models"] = models + kwargs["type_map"] = type_map + super().__init__(**kwargs) self.sw_rmin = sw_rmin self.sw_rmax = sw_rmax self.smin_alpha = smin_alpha def serialize(self) -> dict: - dd = BaseAtomicModel.serialize(self) + dd = super().serialize() dd.update( { "@class": "Model", - "type": "zbl", "@version": 2, - "models": LinearEnergyAtomicModel( - models=[self.models[0], self.models[1]], type_map=self.type_map - ).serialize(), + "type": "zbl", "sw_rmin": self.sw_rmin, "sw_rmax": self.sw_rmax, "smin_alpha": self.smin_alpha, @@ -364,25 +341,15 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data) -> "DPZBLLinearEnergyAtomicModel": data = copy.deepcopy(data) - check_version_compatibility(data.pop("@version", 1), 2, 1) - data.pop("@class") - data.pop("type") - sw_rmin = data.pop("sw_rmin") - sw_rmax = data.pop("sw_rmax") - smin_alpha = data.pop("smin_alpha") - linear_model = LinearEnergyAtomicModel.deserialize(data.pop("models")) - dp_model, zbl_model = linear_model.models - type_map = linear_model.type_map - - return cls( - dp_model=dp_model, - zbl_model=zbl_model, - sw_rmin=sw_rmin, - sw_rmax=sw_rmax, - type_map=type_map, - smin_alpha=smin_alpha, - **data, - ) + check_version_compatibility(data.pop("@version", 1), 2, 2) + models = [ + BaseAtomicModel.get_class_by_type(model["type"]).deserialize(model) + for model in data["models"] + ] + data["dp_model"], data["zbl_model"] = models[0], models[1] + data.pop("@class", None) + data.pop("type", None) + return super().deserialize(data) def _compute_weight( self, diff --git a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py index 3e02a5d076..936c2b0943 100644 --- a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py @@ -95,25 +95,6 @@ def get_sel_type(self) -> List[int]: If returning an empty list, all atom types are selected. """ - @abstractmethod - def set_out_bias(self, out_bias: t_tensor, add=False) -> None: - """ - Modify the output bias for the atomic model. - - Parameters - ---------- - out_bias : t_tensor - The new bias to be applied. - add : bool, optional - Whether to add the new bias to the existing one. - If False, the output bias will be directly replaced by the new bias. - If True, the new bias will be added to the existing one. - """ - - @abstractmethod - def get_out_bias(self) -> t_tensor: - """Return the output bias of the atomic model.""" - @abstractmethod def is_aparam_nall(self) -> bool: """Check whether the shape of atomic parameters is (nframes, nall, ndim). diff --git a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py index c970278bcf..d3d179e6e2 100644 --- a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py @@ -64,6 +64,7 @@ def __init__( **kwargs, ): super().__init__(type_map, **kwargs) + super().init_out_stat() self.tab_file = tab_file self.rcut = rcut self.type_map = type_map @@ -130,32 +131,13 @@ def mixed_types(self) -> bool: # to match DPA1 and DPA2. return True - def set_out_bias(self, out_bias: np.ndarray, add=False) -> None: - """ - Modify the output bias for the atomic model. - - Parameters - ---------- - out_bias : torch.Tensor - The new bias to be applied. 
- add : bool, optional - Whether to add the new bias to the existing one. - If False, the output bias will be directly replaced by the new bias. - If True, the new bias will be added to the existing one. - """ - self.bias_atom_e = out_bias + self.bias_atom_e if add else out_bias - - def get_out_bias(self) -> np.ndarray: - """Return the output bias of the atomic model.""" - return self.bias_atom_e - def serialize(self) -> dict: dd = BaseAtomicModel.serialize(self) dd.update( { "@class": "Model", "type": "pairtab", - "@version": 1, + "@version": 2, "tab": self.tab.serialize(), "rcut": self.rcut, "sel": self.sel, @@ -167,14 +149,13 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data) -> "PairTabAtomicModel": data = copy.deepcopy(data) - check_version_compatibility(data.pop("@version", 1), 1, 1) + check_version_compatibility(data.pop("@version", 1), 2, 2) data.pop("@class") data.pop("type") - rcut = data.pop("rcut") - sel = data.pop("sel") - type_map = data.pop("type_map") tab = PairTab.deserialize(data.pop("tab")) - tab_model = cls(None, rcut, sel, type_map, **data) + data["tab_file"] = None + tab_model = super().deserialize(data) + tab_model.tab = tab tab_model.tab_info = tab_model.tab.tab_info nspline, ntypes = tab_model.tab_info[-2:].astype(int) diff --git a/deepmd/dpmodel/fitting/dipole_fitting.py b/deepmd/dpmodel/fitting/dipole_fitting.py index 6d6324770c..98325f41ee 100644 --- a/deepmd/dpmodel/fitting/dipole_fitting.py +++ b/deepmd/dpmodel/fitting/dipole_fitting.py @@ -36,8 +36,6 @@ class DipoleFitting(GeneralFitting): Parameters ---------- - var_name - The name of the output variable. ntypes The number of atom types. dim_descrpt @@ -86,7 +84,6 @@ class DipoleFitting(GeneralFitting): def __init__( self, - var_name: str, ntypes: int, dim_descrpt: int, embedding_width: int, @@ -124,7 +121,7 @@ def __init__( self.r_differentiable = r_differentiable self.c_differentiable = c_differentiable super().__init__( - var_name=var_name, + var_name="dipole", ntypes=ntypes, dim_descrpt=dim_descrpt, neuron=neuron, @@ -161,6 +158,8 @@ def serialize(self) -> dict: def deserialize(cls, data: dict) -> "GeneralFitting": data = copy.deepcopy(data) check_version_compatibility(data.pop("@version", 1), 1, 1) + var_name = data.pop("var_name", None) + assert var_name == "dipole" return super().deserialize(data) def output_def(self): diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 5d75037137..2a691e963d 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -39,8 +39,6 @@ class PolarFitting(GeneralFitting): Parameters ---------- - var_name - The name of the output variable. ntypes The number of atom types. 
dim_descrpt @@ -88,7 +86,6 @@ class PolarFitting(GeneralFitting): def __init__( self, - var_name: str, ntypes: int, dim_descrpt: int, embedding_width: int, @@ -145,7 +142,7 @@ def __init__( self.shift_diag = shift_diag self.constant_matrix = np.zeros(ntypes, dtype=GLOBAL_NP_FLOAT_PRECISION) super().__init__( - var_name=var_name, + var_name="polar", ntypes=ntypes, dim_descrpt=dim_descrpt, neuron=neuron, @@ -201,6 +198,8 @@ def serialize(self) -> dict: def deserialize(cls, data: dict) -> "GeneralFitting": data = copy.deepcopy(data) check_version_compatibility(data.pop("@version", 1), 2, 1) + var_name = data.pop("var_name", None) + assert var_name == "polar" return super().deserialize(data) def output_def(self): diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 488f0f7a22..eafce67e84 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -159,7 +159,7 @@ def prepare_trainer_input_single( stat_file_path_single, ) - rank = dist.get_rank() if dist.is_initialized() else 0 + rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0 if not multi_task: ( train_data, diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py index cee9ab87c1..127a16fd42 100644 --- a/deepmd/pt/model/atomic_model/base_atomic_model.py +++ b/deepmd/pt/model/atomic_model/base_atomic_model.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later - +import copy import logging from typing import ( Callable, @@ -31,6 +31,10 @@ from deepmd.pt.utils.stat import ( compute_output_stats, ) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) from deepmd.utils.path import ( DPPath, ) @@ -89,12 +93,8 @@ def init_out_stat(self): [self.atomic_output_def()[kk].size for kk in self.bias_keys] ) self.n_out = len(self.bias_keys) - out_bias_data = torch.zeros( - [self.n_out, ntypes, self.max_out_size], dtype=dtype, device=device - ) - out_std_data = torch.ones( - [self.n_out, ntypes, self.max_out_size], dtype=dtype, device=device - ) + out_bias_data = self._default_bias() + out_std_data = self._default_std() self.register_buffer("out_bias", out_bias_data) self.register_buffer("out_std", out_std_data) @@ -256,10 +256,37 @@ def forward_common_atomic( def serialize(self) -> dict: return { + "type_map": self.type_map, "atom_exclude_types": self.atom_exclude_types, "pair_exclude_types": self.pair_exclude_types, + "rcond": self.rcond, + "preset_out_bias": self.preset_out_bias, + "@variables": { + "out_bias": to_numpy_array(self.out_bias), + "out_std": to_numpy_array(self.out_std), + }, } + @classmethod + def deserialize(cls, data: dict) -> "BaseAtomicModel": + data = copy.deepcopy(data) + variables = data.pop("@variables", None) + variables = ( + {"out_bias": None, "out_std": None} if variables is None else variables + ) + obj = cls(**data) + obj["out_bias"] = ( + to_torch_tensor(variables["out_bias"]) + if variables["out_bias"] is not None + else obj._default_bias() + ) + obj["out_std"] = ( + to_torch_tensor(variables["out_std"]) + if variables["out_std"] is not None + else obj._default_std() + ) + return obj + def compute_or_load_stat( self, merged: Union[Callable[[], List[dict]], List[dict]], @@ -368,7 +395,6 @@ def change_out_bias( rcond=self.rcond, preset_bias=self.preset_out_bias, ) - # self.set_out_bias(delta_bias, add=True) self._store_out_stat(delta_bias, out_std, add=True) elif bias_adjust_mode == "set-by-statistic": bias_out, std_out = compute_output_stats( @@ -379,7 
+405,6 @@ def change_out_bias( rcond=self.rcond, preset_bias=self.preset_out_bias, ) - # self.set_out_bias(bias_out) self._store_out_stat(bias_out, std_out) else: raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode) @@ -414,6 +439,18 @@ def model_forward(coord, atype, box, fparam=None, aparam=None): return model_forward + def _default_bias(self): + ntypes = self.get_ntypes() + return torch.zeros( + [self.n_out, ntypes, self.max_out_size], dtype=dtype, device=device + ) + + def _default_std(self): + ntypes = self.get_ntypes() + return torch.ones( + [self.n_out, ntypes, self.max_out_size], dtype=dtype, device=device + ) + def _varsize( self, shape: List[int], diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index c2000decc7..3d9a57bf70 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -100,7 +100,7 @@ def serialize(self) -> dict: dd.update( { "@class": "Model", - "@version": 1, + "@version": 2, "type": "standard", "type_map": self.type_map, "descriptor": self.descriptor.serialize(), @@ -112,13 +112,14 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data) -> "DPAtomicModel": data = copy.deepcopy(data) - check_version_compatibility(data.pop("@version", 1), 1, 1) + check_version_compatibility(data.pop("@version", 1), 2, 1) data.pop("@class", None) data.pop("type", None) descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor")) fitting_obj = BaseFitting.deserialize(data.pop("fitting")) - type_map = data.pop("type_map", None) - obj = cls(descriptor_obj, fitting_obj, type_map=type_map, **data) + data["descriptor"] = descriptor_obj + data["fitting"] = fitting_obj + obj = super().deserialize(data) return obj def forward_atomic( @@ -178,6 +179,9 @@ def forward_atomic( ) return fit_ret + def get_out_bias(self) -> torch.Tensor: + return self.out_bias + def compute_or_load_stat( self, sampled_func, @@ -219,27 +223,6 @@ def wrapped_sampler(): self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path) self.compute_or_load_out_stat(wrapped_sampler, stat_file_path) - def set_out_bias(self, out_bias: torch.Tensor, add=False) -> None: - """ - Modify the output bias for the atomic model. - - Parameters - ---------- - out_bias : torch.Tensor - The new bias to be applied. - add : bool, optional - Whether to add the new bias to the existing one. - If False, the output bias will be directly replaced by the new bias. - If True, the new bias will be added to the existing one. 
- """ - self.fitting_net["bias_atom_e"] = ( - out_bias + self.fitting_net["bias_atom_e"] if add else out_bias - ) - - def get_out_bias(self) -> torch.Tensor: - """Return the output bias of the atomic model.""" - return self.fitting_net["bias_atom_e"] - def get_dim_fparam(self) -> int: """Get the number (dimension) of frame parameters of this atomic model.""" return self.fitting_net.get_dim_fparam() diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index f9fc97dea4..b58594d3ce 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -1,10 +1,12 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import copy from typing import ( + Callable, Dict, List, Optional, Tuple, + Union, ) import torch @@ -73,7 +75,6 @@ def __init__( self.mapping_list.append(self.remap_atype(tpmp, self.type_map)) assert len(err_msg) == 0, "\n".join(err_msg) - self.atomic_bias = None self.mixed_types_list = [model.mixed_types() for model in self.models] self.rcuts = torch.tensor( self.get_model_rcuts(), dtype=torch.float64, device=env.DEVICE @@ -92,6 +93,9 @@ def mixed_types(self) -> bool: """ return True + def get_out_bias(self) -> torch.Tensor: + return self.out_bias + def get_rcut(self) -> float: """Get the cut-off radius.""" return max(self.get_model_rcuts()) @@ -188,8 +192,9 @@ def forward_atomic( for i, model in enumerate(self.models): mapping = self.mapping_list[i] + # apply bias to each individual model ener_list.append( - model.forward_atomic( + model.forward_common_atomic( extended_coord, mapping[extended_atype], nlists_[i], @@ -198,26 +203,32 @@ def forward_atomic( aparam, )["energy"] ) - weights = self._compute_weight(extended_coord, extended_atype, nlists_) - atype = extended_atype[:, :nloc] - for idx, model in enumerate(self.models): - # TODO: provide interfaces for atomic models to access bias_atom_e - if isinstance(model, DPAtomicModel): - bias_atom_e = model.fitting_net.bias_atom_e - elif isinstance(model, PairTabAtomicModel): - bias_atom_e = model.bias_atom_e - else: - bias_atom_e = None - if bias_atom_e is not None: - ener_list[idx] += bias_atom_e[atype] - fit_ret = { "energy": torch.sum(torch.stack(ener_list) * torch.stack(weights), dim=0), } # (nframes, nloc, 1) return fit_ret + def apply_out_stat( + self, + ret: Dict[str, torch.Tensor], + atype: torch.Tensor, + ): + """Apply the stat to each atomic output. + The developer may override the method to define how the bias is applied + to the atomic output of the model. + + Parameters + ---------- + ret + The returned dict by the forward_atomic method + atype + The atom types. 
nf x nloc + + """ + return ret + @staticmethod def remap_atype(ori_map: List[str], new_map: List[str]) -> torch.Tensor: """ @@ -257,59 +268,43 @@ def fitting_output_def(self) -> FittingOutputDef: ) def serialize(self) -> dict: - return { - "@class": "Model", - "@version": 1, - "type": "linear", - "models": [model.serialize() for model in self.models], - "type_map": self.type_map, - } + dd = super().serialize() + dd.update( + { + "@class": "Model", + "@version": 2, + "type": "linear", + "models": [model.serialize() for model in self.models], + "type_map": self.type_map, + } + ) + return dd @classmethod def deserialize(cls, data: dict) -> "LinearEnergyAtomicModel": data = copy.deepcopy(data) - check_version_compatibility(data.pop("@version", 1), 1, 1) - data.pop("@class") - data.pop("type") - type_map = data.pop("type_map") + check_version_compatibility(data.get("@version", 2), 2, 1) + data.pop("@class", None) + data.pop("type", None) models = [ BaseAtomicModel.get_class_by_type(model["type"]).deserialize(model) for model in data["models"] ] - data.pop("models") - return cls(models, type_map, **data) + data["models"] = models + return super().deserialize(data) def _compute_weight( self, extended_coord, extended_atype, nlists_ ) -> List[torch.Tensor]: """This should be a list of user defined weights that matches the number of models to be combined.""" nmodels = len(self.models) + nframes, nloc, _ = nlists_[0].shape return [ - torch.ones(1, dtype=torch.float64, device=env.DEVICE) / nmodels + torch.ones((nframes, nloc, 1), dtype=torch.float64, device=env.DEVICE) + / nmodels for _ in range(nmodels) ] - def set_out_bias(self, out_bias: torch.Tensor, add=False) -> None: - """ - Modify the output bias for all the models in the linear atomic model. - - Parameters - ---------- - out_bias : torch.Tensor - The new bias to be applied. - add : bool, optional - Whether to add the new bias to the existing one. - If False, the output bias will be directly replaced by the new bias. - If True, the new bias will be added to the existing one. - """ - for model in self.models: - model.set_out_bias(out_bias, add=add) - - def get_out_bias(self) -> torch.Tensor: - """Return the weighted output bias of the linear atomic model.""" - # TODO add get_out_bias for linear atomic model - raise NotImplementedError - def get_dim_fparam(self) -> int: """Get the number (dimension) of frame parameters of this atomic model.""" # tricky... @@ -346,6 +341,53 @@ def is_aparam_nall(self) -> bool: """ return False + def compute_or_load_out_stat( + self, + merged: Union[Callable[[], List[dict]], List[dict]], + stat_file_path: Optional[DPPath] = None, + ): + """ + Compute the output statistics (e.g. energy bias) for the fitting net from packed data. + + Parameters + ---------- + merged : Union[Callable[[], List[dict]], List[dict]] + - List[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` + originating from the `i`-th data system. + - Callable[[], List[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + stat_file_path : Optional[DPPath] + The path to the stat file. 
+ + """ + for md in self.models: + md.compute_or_load_out_stat(merged, stat_file_path) + + def compute_or_load_stat( + self, + sampled_func, + stat_file_path: Optional[DPPath] = None, + ): + """ + Compute or load the statistics parameters of the model, + such as mean and standard deviation of descriptors or the energy bias of the fitting net. + When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update), + and saved in the `stat_file_path`(s). + When `sampled` is not provided, it will check the existence of `stat_file_path`(s) + and load the calculated statistics parameters. + + Parameters + ---------- + sampled_func + The lazy sampled function to get data frames from different data systems. + stat_file_path + The dictionary of paths to the statistics files. + """ + for md in self.models: + md.compute_or_load_stat(sampled_func, stat_file_path) + class DPZBLLinearEnergyAtomicModel(LinearEnergyAtomicModel): """Model linearly combine a list of AtomicModels. @@ -379,7 +421,9 @@ def __init__( **kwargs, ): models = [dp_model, zbl_model] - super().__init__(models, type_map, **kwargs) + kwargs["models"] = models + kwargs["type_map"] = type_map + super().__init__(**kwargs) self.sw_rmin = sw_rmin self.sw_rmax = sw_rmax @@ -388,39 +432,13 @@ def __init__( # this is a placeholder being updated in _compute_weight, to handle Jit attribute init error. self.zbl_weight = torch.empty(0, dtype=torch.float64, device=env.DEVICE) - def compute_or_load_stat( - self, - sampled_func, - stat_file_path: Optional[DPPath] = None, - ): - """ - Compute or load the statistics parameters of the model, - such as mean and standard deviation of descriptors or the energy bias of the fitting net. - When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update), - and saved in the `stat_file_path`(s). - When `sampled` is not provided, it will check the existence of `stat_file_path`(s) - and load the calculated statistics parameters. - - Parameters - ---------- - sampled_func - The lazy sampled function to get data frames from different data systems. - stat_file_path - The dictionary of paths to the statistics files. 
- """ - self.models[0].compute_or_load_stat(sampled_func, stat_file_path) - self.models[1].compute_or_load_stat(sampled_func, stat_file_path) - def serialize(self) -> dict: - dd = BaseAtomicModel.serialize(self) + dd = super().serialize() dd.update( { "@class": "Model", "@version": 2, "type": "zbl", - "models": LinearEnergyAtomicModel( - models=[self.models[0], self.models[1]], type_map=self.type_map - ).serialize(), "sw_rmin": self.sw_rmin, "sw_rmax": self.sw_rmax, "smin_alpha": self.smin_alpha, @@ -432,24 +450,14 @@ def serialize(self) -> dict: def deserialize(cls, data) -> "DPZBLLinearEnergyAtomicModel": data = copy.deepcopy(data) check_version_compatibility(data.pop("@version", 1), 2, 1) - sw_rmin = data.pop("sw_rmin") - sw_rmax = data.pop("sw_rmax") - smin_alpha = data.pop("smin_alpha") - linear_model = LinearEnergyAtomicModel.deserialize(data.pop("models")) - dp_model, zbl_model = linear_model.models - type_map = linear_model.type_map - + models = [ + BaseAtomicModel.get_class_by_type(model["type"]).deserialize(model) + for model in data["models"] + ] + data["dp_model"], data["zbl_model"] = models[0], models[1] data.pop("@class", None) data.pop("type", None) - return cls( - dp_model=dp_model, - zbl_model=zbl_model, - sw_rmin=sw_rmin, - sw_rmax=sw_rmax, - type_map=type_map, - smin_alpha=smin_alpha, - **data, - ) + return super().deserialize(data) def _compute_weight( self, diff --git a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py index 627dffd620..4f8bce78e1 100644 --- a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py +++ b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py @@ -71,8 +71,6 @@ def __init__( rcut: float, sel: Union[int, List[int]], type_map: List[str], - rcond: Optional[float] = None, - atom_ener: Optional[List[float]] = None, **kwargs, ): super().__init__(type_map, **kwargs) @@ -81,8 +79,6 @@ def __init__( self.rcut = rcut self.tab = self._set_pairtab(tab_file, rcut) - self.rcond = rcond - self.atom_ener = atom_ener self.type_map = type_map self.ntypes = len(type_map) @@ -136,6 +132,9 @@ def fitting_output_def(self) -> FittingOutputDef: ] ) + def get_out_bias(self) -> torch.Tensor: + return self.out_bias + def get_rcut(self) -> float: return self.rcut @@ -166,14 +165,12 @@ def serialize(self) -> dict: dd.update( { "@class": "Model", - "@version": 1, + "@version": 2, "type": "pairtab", "tab": self.tab.serialize(), "rcut": self.rcut, "sel": self.sel, "type_map": self.type_map, - "rcond": self.rcond, - "atom_ener": self.atom_ener, } ) return dd @@ -181,16 +178,12 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data) -> "PairTabAtomicModel": data = copy.deepcopy(data) - check_version_compatibility(data.pop("@version", 1), 1, 1) - rcut = data.pop("rcut") - sel = data.pop("sel") - type_map = data.pop("type_map") - rcond = data.pop("rcond") - atom_ener = data.pop("atom_ener") + check_version_compatibility(data.pop("@version", 1), 2, 1) tab = PairTab.deserialize(data.pop("tab")) data.pop("@class", None) data.pop("type", None) - tab_model = cls(None, rcut, sel, type_map, rcond, atom_ener, **data) + data["tab_file"] = None + tab_model = super().deserialize(data) tab_model.tab = tab tab_model.register_buffer("tab_info", torch.from_numpy(tab_model.tab.tab_info)) @@ -226,25 +219,6 @@ def compute_or_load_stat( """ self.compute_or_load_out_stat(merged, stat_file_path) - def set_out_bias(self, out_bias: torch.Tensor, add=False) -> None: - """ - Modify the output bias for the atomic model. 
- - Parameters - ---------- - out_bias : torch.Tensor - The new bias to be applied. - add : bool, optional - Whether to add the new bias to the existing one. - If False, the output bias will be directly replaced by the new bias. - If True, the new bias will be added to the existing one. - """ - self.bias_atom_e = out_bias + self.bias_atom_e if add else out_bias - - def get_out_bias(self) -> torch.Tensor: - """Return the output bias of the atomic model.""" - return self.bias_atom_e - def forward_atomic( self, extended_coord: torch.Tensor, diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py index ca445c8588..cddbbf5291 100644 --- a/deepmd/pt/model/task/dipole.py +++ b/deepmd/pt/model/task/dipole.py @@ -39,8 +39,6 @@ class DipoleFittingNet(GeneralFitting): Parameters ---------- - var_name : str - The atomic property to fit, 'dipole'. ntypes : int Element count. dim_descrpt : int @@ -97,7 +95,7 @@ def __init__( self.r_differentiable = r_differentiable self.c_differentiable = c_differentiable super().__init__( - var_name=kwargs.pop("var_name", "dipole"), + var_name="dipole", ntypes=ntypes, dim_descrpt=dim_descrpt, neuron=neuron, @@ -131,6 +129,7 @@ def serialize(self) -> dict: def deserialize(cls, data: dict) -> "GeneralFitting": data = copy.deepcopy(data) check_version_compatibility(data.pop("@version", 1), 1, 1) + data.pop("var_name", None) return super().deserialize(data) def output_def(self) -> FittingOutputDef: diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index 544d23555c..cd944996be 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -47,8 +47,6 @@ class PolarFittingNet(GeneralFitting): Parameters ---------- - var_name : str - The atomic property to fit, 'polar'. ntypes : int Element count. 
dim_descrpt : int @@ -127,7 +125,7 @@ def __init__( ntypes, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE ) super().__init__( - var_name=kwargs.pop("var_name", "polar"), + var_name="polar", ntypes=ntypes, dim_descrpt=dim_descrpt, neuron=neuron, @@ -180,6 +178,7 @@ def serialize(self) -> dict: def deserialize(cls, data: dict) -> "GeneralFitting": data = copy.deepcopy(data) check_version_compatibility(data.pop("@version", 1), 2, 1) + data.pop("var_name", None) return super().deserialize(data) def output_def(self) -> FittingOutputDef: diff --git a/deepmd/pt/optimizer/LKF.py b/deepmd/pt/optimizer/LKF.py index 06b341d987..6196414243 100644 --- a/deepmd/pt/optimizer/LKF.py +++ b/deepmd/pt/optimizer/LKF.py @@ -47,7 +47,7 @@ def __init__( # the first param, because this helps with casting in load_state_dict self._state = self.state[self._params[0]] self._state.setdefault("kalman_lambda", kalman_lambda) - self.dist_init = dist.is_initialized() + self.dist_init = dist.is_available() and dist.is_initialized() self.rank = dist.get_rank() if self.dist_init else 0 self.dindex = [] self.remainder = 0 diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 73404b0c83..fe9b432fb7 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -122,8 +122,14 @@ def __init__( self.model_keys = ( list(model_params["model_dict"]) if self.multi_task else ["Default"] ) - self.rank = dist.get_rank() if dist.is_initialized() else 0 - self.world_size = dist.get_world_size() if dist.is_initialized() else 1 + self.rank = ( + dist.get_rank() if dist.is_available() and dist.is_initialized() else 0 + ) + self.world_size = ( + dist.get_world_size() + if dist.is_available() and dist.is_initialized() + else 1 + ) self.num_model = len(self.model_keys) # Iteration config @@ -169,7 +175,9 @@ def get_dataloader_and_buffer(_data, _params): _data, sampler=_sampler, batch_size=None, - num_workers=NUM_WORKERS, # setting to 0 diverges the behavior of its iterator; should be >=1 + num_workers=NUM_WORKERS + if dist.is_available() + else 0, # setting to 0 diverges the behavior of its iterator; should be >=1 drop_last=False, pin_memory=True, ) @@ -607,7 +615,7 @@ def single_model_finetune( if shared_links is not None: self.wrapper.share_params(shared_links, resume=resuming or self.rank != 0) - if dist.is_initialized(): + if dist.is_available() and dist.is_initialized(): torch.cuda.set_device(LOCAL_RANK) # DDP will guarantee the model parameters are identical across all processes self.wrapper = DDP( @@ -673,7 +681,7 @@ def run(self): record_file = f"Sample_rank_{self.rank}.txt" fout1 = open(record_file, mode="w", buffering=1) log.info("Start to train %d steps.", self.num_steps) - if dist.is_initialized(): + if dist.is_available() and dist.is_initialized(): log.info(f"Rank: {dist.get_rank()}/{dist.get_world_size()}") if self.enable_tensorboard: from torch.utils.tensorboard import ( @@ -734,7 +742,11 @@ def step(_step_id, task_key="Default"): elif self.opt_type == "LKF": if isinstance(self.loss, EnergyStdLoss): KFOptWrapper = KFOptimizerWrapper( - self.wrapper, self.optimizer, 24, 6, dist.is_initialized() + self.wrapper, + self.optimizer, + 24, + 6, + dist.is_available() and dist.is_initialized(), ) pref_e = self.opt_param["kf_start_pref_e"] * ( self.opt_param["kf_limit_pref_e"] @@ -753,7 +765,9 @@ def step(_step_id, task_key="Default"): # [coord, atype, natoms, mapping, shift, nlist, box] model_pred = {"energy": p_energy, "force": p_force} module = ( - self.wrapper.module if 
dist.is_initialized() else self.wrapper + self.wrapper.module + if dist.is_available() and dist.is_initialized() + else self.wrapper ) def fake_model(): @@ -768,10 +782,16 @@ def fake_model(): ) elif isinstance(self.loss, DenoiseLoss): KFOptWrapper = KFOptimizerWrapper( - self.wrapper, self.optimizer, 24, 6, dist.is_initialized() + self.wrapper, + self.optimizer, + 24, + 6, + dist.is_available() and dist.is_initialized(), ) module = ( - self.wrapper.module if dist.is_initialized() else self.wrapper + self.wrapper.module + if dist.is_available() and dist.is_initialized() + else self.wrapper ) model_pred = KFOptWrapper.update_denoise_coord( input_dict, @@ -924,7 +944,11 @@ def log_loss_valid(_task_key="Default"): # Handle the case if rank 0 aborted and re-assigned self.latest_model = Path(self.save_ckpt + f"-{_step_id + 1}.pt") - module = self.wrapper.module if dist.is_initialized() else self.wrapper + module = ( + self.wrapper.module + if dist.is_available() and dist.is_initialized() + else self.wrapper + ) self.save_model(self.latest_model, lr=cur_lr, step=_step_id) log.info(f"Saved model to {self.latest_model}") symlink_prefix_files(self.latest_model.stem, self.save_ckpt) @@ -990,7 +1014,11 @@ def log_loss_valid(_task_key="Default"): prof.stop() def save_model(self, save_path, lr=0.0, step=0): - module = self.wrapper.module if dist.is_initialized() else self.wrapper + module = ( + self.wrapper.module + if dist.is_available() and dist.is_initialized() + else self.wrapper + ) module.train_infos["lr"] = lr module.train_infos["step"] = step torch.save( @@ -1168,7 +1196,7 @@ def _model_change_out_bias( idx_type_map = sorter[np.searchsorted(model_type_map, new_type_map, sorter=sorter)] log.info( f"Change output bias of {new_type_map!s} " - f"from {to_numpy_array(old_bias[idx_type_map]).reshape(-1)!s} " - f"to {to_numpy_array(new_bias[idx_type_map]).reshape(-1)!s}." + f"from {to_numpy_array(old_bias[:,idx_type_map]).reshape(-1)!s} " + f"to {to_numpy_array(new_bias[:,idx_type_map]).reshape(-1)!s}." 
) return _model diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index 361bc4b0b6..8ebe75868e 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -97,7 +97,11 @@ def construct_dataset(system): with Pool( os.cpu_count() - // (int(os.environ["LOCAL_WORLD_SIZE"]) if dist.is_initialized() else 1) + // ( + int(os.environ["LOCAL_WORLD_SIZE"]) + if dist.is_available() and dist.is_initialized() + else 1 + ) ) as pool: self.systems = pool.map(construct_dataset, systems) @@ -127,7 +131,7 @@ def construct_dataset(system): self.batch_sizes = batch_size * np.ones(len(systems), dtype=int) assert len(self.systems) == len(self.batch_sizes) for system, batch_size in zip(self.systems, self.batch_sizes): - if dist.is_initialized(): + if dist.is_available() and dist.is_initialized(): system_sampler = DistributedSampler(system) self.sampler_list.append(system_sampler) else: @@ -138,7 +142,8 @@ def construct_dataset(system): num_workers=0, # Should be 0 to avoid too many threads forked sampler=system_sampler, collate_fn=collate_batch, - shuffle=(not dist.is_initialized()) and shuffle, + shuffle=(not (dist.is_available() and dist.is_initialized())) + and shuffle, ) self.dataloaders.append(system_dataloader) self.index.append(len(system_dataloader)) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index d85741b231..77da1e01f1 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -1,5 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging +from collections import ( + defaultdict, +) from typing import ( Callable, Dict, @@ -23,6 +26,7 @@ to_torch_tensor, ) from deepmd.utils.out_stat import ( + compute_stats_from_atomic, compute_stats_from_redu, ) from deepmd.utils.path import ( @@ -171,10 +175,9 @@ def model_forward_auto_batch_size(*args, **kwargs): for kk in keys: model_predict[kk].append( to_numpy_array( - torch.sum(sample_predict[kk], dim=1) # nf x nloc x odims + sample_predict[kk] # nf x nloc x odims ) ) - model_predict = {kk: np.concatenate(model_predict[kk]) for kk in keys} return model_predict @@ -203,6 +206,31 @@ def _make_preset_out_bias( return np.array(nbias) +def _fill_stat_with_global( + atomic_stat: Union[np.ndarray, None], + global_stat: np.ndarray, +): + """This function is used to fill atomic stat with global stat. + + Parameters + ---------- + atomic_stat : Union[np.ndarray, None] + The atomic stat. + global_stat : np.ndarray + The global stat. + if the atomic stat is None, use global stat. + if the atomic stat is not None, but has nan values (missing atypes), fill with global stat. + """ + if atomic_stat is None: + return global_stat + else: + return np.nan_to_num( + np.where( + np.isnan(atomic_stat) & ~np.isnan(global_stat), global_stat, atomic_stat + ) + ) + + def compute_output_stats( merged: Union[Callable[[], List[dict]], List[dict]], ntypes: int, @@ -246,87 +274,294 @@ def compute_output_stats( # failed to restore the bias from stat file. 
compute if bias_atom_e is None: - # only get data for once + # only get data once, sampled is a list of dict[str, torch.Tensor] sampled = merged() if callable(merged) else merged + if model_forward is not None: + model_pred = _compute_model_predict(sampled, keys, model_forward) + else: + model_pred = None + # remove the keys that are not in the sample keys = [keys] if isinstance(keys, str) else keys assert isinstance(keys, list) - new_keys = [ii for ii in keys if ii in sampled[0].keys()] + new_keys = [ + ii + for ii in keys + if (ii in sampled[0].keys()) or ("atom_" + ii in sampled[0].keys()) + ] del keys keys = new_keys - # get label dict from sample - outputs = {kk: [item[kk] for item in sampled] for kk in keys} - data_mixed_type = "real_natoms_vec" in sampled[0] - natoms_key = "natoms" if not data_mixed_type else "real_natoms_vec" - for system in sampled: - if "atom_exclude_types" in system: - type_mask = AtomExcludeMask( - ntypes, system["atom_exclude_types"] - ).get_type_mask() - system[natoms_key][:, 2:] *= type_mask.unsqueeze(0) - input_natoms = [item[natoms_key] for item in sampled] - # shape: (nframes, ndim) - merged_output = {kk: to_numpy_array(torch.cat(outputs[kk])) for kk in keys} - # shape: (nframes, ntypes) - merged_natoms = to_numpy_array(torch.cat(input_natoms)[:, 2:]) - nf = merged_natoms.shape[0] - if preset_bias is not None: - assigned_atom_ener = { - kk: _make_preset_out_bias(ntypes, preset_bias[kk]) - if kk in preset_bias.keys() - else None - for kk in keys - } - else: - assigned_atom_ener = {kk: None for kk in keys} - - if model_forward is None: - stats_input = merged_output - else: - # subtract the model bias and output the delta bias - model_predict = _compute_model_predict(sampled, keys, model_forward) - stats_input = {kk: merged_output[kk] - model_predict[kk] for kk in keys} + # split system based on label + atomic_sampled_idx = defaultdict(list) + global_sampled_idx = defaultdict(list) - bias_atom_e = {} - std_atom_e = {} for kk in keys: - bias_atom_e[kk], std_atom_e[kk] = compute_stats_from_redu( - stats_input[kk], - merged_natoms, - assigned_bias=assigned_atom_ener[kk], - rcond=rcond, - ) - bias_atom_e, std_atom_e = _post_process_stat(bias_atom_e, std_atom_e) + for idx, system in enumerate(sampled): + if (("find_atom_" + kk) in system) and ( + system["find_atom_" + kk] > 0.0 + ): + atomic_sampled_idx[kk].append(idx) + elif (("find_" + kk) in system) and (system["find_" + kk] > 0.0): + global_sampled_idx[kk].append(idx) + + else: + continue + + # use index to gather model predictions for the corresponding systems. 
+ + model_pred_g = ( + { + kk: [vv[idx] for idx in global_sampled_idx[kk]] + for kk, vv in model_pred.items() + } + if model_pred + else None + ) + model_pred_a = ( + { + kk: [vv[idx] for idx in atomic_sampled_idx[kk]] + for kk, vv in model_pred.items() + } + if model_pred + else None + ) - # unbias_e is only used for print rmse - if model_forward is None: - unbias_e = { - kk: merged_natoms @ bias_atom_e[kk].reshape(ntypes, -1) for kk in keys + # concat all frames within those systmes + model_pred_g = ( + { + kk: np.concatenate(model_pred_g[kk]) + for kk in model_pred_g.keys() + if len(model_pred_g[kk]) > 0 } - else: - unbias_e = { - kk: model_predict[kk].reshape(nf, -1) - + merged_natoms @ bias_atom_e[kk].reshape(ntypes, -1) - for kk in keys + if model_pred + else None + ) + model_pred_a = ( + { + kk: np.concatenate(model_pred_a[kk]) + for kk in model_pred_a.keys() + if len(model_pred_a[kk]) > 0 } - atom_numbs = merged_natoms.sum(-1) + if model_pred + else None + ) - def rmse(x): - return np.sqrt(np.mean(np.square(x))) + # compute stat + bias_atom_g, std_atom_g = compute_output_stats_global( + sampled, + ntypes, + keys, + rcond, + preset_bias, + model_pred_g, + ) + bias_atom_a, std_atom_a = compute_output_stats_atomic( + sampled, + ntypes, + keys, + model_pred_a, + ) + # merge global/atomic bias + bias_atom_e, std_atom_e = {}, {} for kk in keys: - rmse_ae = rmse( - (unbias_e[kk].reshape(nf, -1) - merged_output[kk].reshape(nf, -1)) - / atom_numbs[:, None] - ) - log.info( - f"RMSE of {kk} per atom after linear regression is: {rmse_ae} in the unit of {kk}." - ) + # use atomic bias whenever available + if kk in bias_atom_a: + bias_atom_e[kk] = bias_atom_a[kk] + std_atom_e[kk] = std_atom_a[kk] + else: + bias_atom_e[kk] = None + std_atom_e[kk] = None + # use global bias to fill missing atomic bias + if kk in bias_atom_g: + bias_atom_e[kk] = _fill_stat_with_global( + bias_atom_e[kk], bias_atom_g[kk] + ) + std_atom_e[kk] = _fill_stat_with_global(std_atom_e[kk], std_atom_g[kk]) + if (bias_atom_e[kk] is None) or (std_atom_e[kk] is None): + raise RuntimeError("Fail to compute stat.") if stat_file_path is not None: _save_to_file(stat_file_path, bias_atom_e, std_atom_e) - ret_bias = {kk: to_torch_tensor(vv) for kk, vv in bias_atom_e.items()} - ret_std = {kk: to_torch_tensor(vv) for kk, vv in std_atom_e.items()} + bias_atom_e = {kk: to_torch_tensor(vv) for kk, vv in bias_atom_e.items()} + std_atom_e = {kk: to_torch_tensor(vv) for kk, vv in std_atom_e.items()} + return bias_atom_e, std_atom_e - return ret_bias, ret_std + +def compute_output_stats_global( + sampled: List[dict], + ntypes: int, + keys: List[str], + rcond: Optional[float] = None, + preset_bias: Optional[Dict[str, List[Optional[torch.Tensor]]]] = None, + model_pred: Optional[Dict[str, np.ndarray]] = None, +): + """This function only handle stat computation from reduced global labels.""" + # get label dict from sample; for each key, only picking the system with global labels. 
+ outputs = { + kk: [ + system[kk] + for system in sampled + if kk in system and system.get(f"find_{kk}", 0) > 0 + ] + for kk in keys + } + + data_mixed_type = "real_natoms_vec" in sampled[0] + natoms_key = "natoms" if not data_mixed_type else "real_natoms_vec" + for system in sampled: + if "atom_exclude_types" in system: + type_mask = AtomExcludeMask( + ntypes, system["atom_exclude_types"] + ).get_type_mask() + system[natoms_key][:, 2:] *= type_mask.unsqueeze(0) + + input_natoms = { + kk: [ + item[natoms_key] + for item in sampled + if kk in item and item.get(f"find_{kk}", 0) > 0 + ] + for kk in keys + } + # shape: (nframes, ndim) + merged_output = { + kk: to_numpy_array(torch.cat(outputs[kk])) + for kk in keys + if len(outputs[kk]) > 0 + } + # shape: (nframes, ntypes) + + merged_natoms = { + kk: to_numpy_array(torch.cat(input_natoms[kk])[:, 2:]) + for kk in keys + if len(input_natoms[kk]) > 0 + } + nf = {kk: merged_natoms[kk].shape[0] for kk in keys if kk in merged_natoms} + if preset_bias is not None: + assigned_atom_ener = { + kk: _make_preset_out_bias(ntypes, preset_bias[kk]) + if kk in preset_bias.keys() + else None + for kk in keys + } + else: + assigned_atom_ener = {kk: None for kk in keys} + + if model_pred is None: + stats_input = merged_output + else: + # subtract the model bias and output the delta bias + + model_pred = {kk: np.sum(model_pred[kk], axis=1) for kk in keys} + stats_input = { + kk: merged_output[kk] - model_pred[kk] for kk in keys if kk in merged_output + } + + bias_atom_e = {} + std_atom_e = {} + for kk in keys: + if kk in stats_input: + bias_atom_e[kk], std_atom_e[kk] = compute_stats_from_redu( + stats_input[kk], + merged_natoms[kk], + assigned_bias=assigned_atom_ener[kk], + rcond=rcond, + ) + else: + # this key does not have global labels, skip it. + continue + bias_atom_e, std_atom_e = _post_process_stat(bias_atom_e, std_atom_e) + + # unbias_e is only used for print rmse + + if model_pred is None: + unbias_e = { + kk: merged_natoms[kk] @ bias_atom_e[kk].reshape(ntypes, -1) + for kk in bias_atom_e.keys() + } + else: + unbias_e = { + kk: model_pred[kk].reshape(nf[kk], -1) + + merged_natoms[kk] @ bias_atom_e[kk].reshape(ntypes, -1) + for kk in bias_atom_e.keys() + } + atom_numbs = {kk: merged_natoms[kk].sum(-1) for kk in bias_atom_e.keys()} + + def rmse(x): + return np.sqrt(np.mean(np.square(x))) + + for kk in bias_atom_e.keys(): + rmse_ae = rmse( + (unbias_e[kk].reshape(nf[kk], -1) - merged_output[kk].reshape(nf[kk], -1)) + / atom_numbs[kk][:, None] + ) + log.info( + f"RMSE of {kk} per atom after linear regression is: {rmse_ae} in the unit of {kk}." + ) + return bias_atom_e, std_atom_e + + +def compute_output_stats_atomic( + sampled: List[dict], + ntypes: int, + keys: List[str], + model_pred: Optional[Dict[str, np.ndarray]] = None, +): + # get label dict from sample; for each key, only picking the system with atomic labels. 
+ outputs = { + kk: [ + system["atom_" + kk] + for system in sampled + if ("atom_" + kk) in system and system.get(f"find_atom_{kk}", 0) > 0 + ] + for kk in keys + } + natoms = { + kk: [ + system["atype"] + for system in sampled + if ("atom_" + kk) in system and system.get(f"find_atom_{kk}", 0) > 0 + ] + for kk in keys + } + # shape: (nframes, nloc, ndim) + merged_output = { + kk: to_numpy_array(torch.cat(outputs[kk])) + for kk in keys + if len(outputs[kk]) > 0 + } + merged_natoms = { + kk: to_numpy_array(torch.cat(natoms[kk])) for kk in keys if len(natoms[kk]) > 0 + } + + if model_pred is None: + stats_input = merged_output + else: + # subtract the model bias and output the delta bias + stats_input = { + kk: merged_output[kk] - model_pred[kk] for kk in keys if kk in merged_output + } + + bias_atom_e = {} + std_atom_e = {} + + for kk in keys: + if kk in stats_input: + bias_atom_e[kk], std_atom_e[kk] = compute_stats_from_atomic( + stats_input[kk], + merged_natoms[kk], + ) + # correction for missing types + missing_types = ntypes - merged_natoms[kk].max() - 1 + if missing_types > 0: + nan_padding = np.empty((missing_types, bias_atom_e[kk].shape[1])) + nan_padding.fill(np.nan) + bias_atom_e[kk] = np.concatenate([bias_atom_e[kk], nan_padding], axis=0) + std_atom_e[kk] = np.concatenate([std_atom_e[kk], nan_padding], axis=0) + else: + # this key does not have atomic labels, skip it. + continue + bias_atom_e, std_atom_e = _post_process_stat(bias_atom_e, std_atom_e) + return bias_atom_e, std_atom_e diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index 0b26d83732..3ca763870b 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -704,6 +704,10 @@ def _pass_filter( tf.reshape(self.avg_looked_up, [-1, 1]), [1, self.ndescrpt] ), ) + self.recovered_switch *= tf.reshape( + tf.slice(tf.reshape(mask, [-1, 4]), [0, 0], [-1, 1]), + [-1, natoms[0], self.sel_all_a[0]], + ) else: inputs_i *= mask if nvnmd_cfg.enable and nvnmd_cfg.quantize_descriptor: diff --git a/deepmd/tf/fit/dipole.py b/deepmd/tf/fit/dipole.py index f98d52c7bd..d99c793415 100644 --- a/deepmd/tf/fit/dipole.py +++ b/deepmd/tf/fit/dipole.py @@ -362,7 +362,6 @@ def serialize(self, suffix: str) -> dict: "@class": "Fitting", "type": "dipole", "@version": 1, - "var_name": "dipole", "ntypes": self.ntypes, "dim_descrpt": self.dim_descrpt, "embedding_width": self.dim_rot_mat_1, diff --git a/deepmd/tf/fit/polar.py b/deepmd/tf/fit/polar.py index 473b57ff54..c124bd3ef4 100644 --- a/deepmd/tf/fit/polar.py +++ b/deepmd/tf/fit/polar.py @@ -555,7 +555,6 @@ def serialize(self, suffix: str) -> dict: "@class": "Fitting", "type": "polar", "@version": 1, - "var_name": "polar", "ntypes": self.ntypes, "dim_descrpt": self.dim_descrpt, "embedding_width": self.dim_rot_mat_1, diff --git a/deepmd/tf/model/model.py b/deepmd/tf/model/model.py index 76bcc6072b..fc8f862e3b 100644 --- a/deepmd/tf/model/model.py +++ b/deepmd/tf/model/model.py @@ -14,6 +14,8 @@ Union, ) +import numpy as np + from deepmd.common import ( j_get_type, ) @@ -785,11 +787,16 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor": The deserialized descriptor """ data = copy.deepcopy(data) - check_version_compatibility(data.pop("@version", 1), 1, 1) + check_version_compatibility(data.pop("@version", 2), 2, 1) descriptor = Descriptor.deserialize(data.pop("descriptor"), suffix=suffix) fitting = Fitting.deserialize(data.pop("fitting"), suffix=suffix) + # BEGIN not supported keys data.pop("atom_exclude_types")
data.pop("pair_exclude_types") + data.pop("rcond", None) + data.pop("preset_out_bias", None) + data.pop("@variables", None) + # END not supported keys return cls( descriptor=descriptor, fitting_net=fitting, @@ -813,14 +820,23 @@ def serialize(self, suffix: str = "") -> dict: raise NotImplementedError("type embedding is not supported") if self.spin is not None: raise NotImplementedError("spin is not supported") + + ntypes = len(self.get_type_map()) + dict_fit = self.fitting.serialize(suffix=suffix) return { "@class": "Model", "type": "standard", - "@version": 1, + "@version": 2, "type_map": self.type_map, "descriptor": self.descrpt.serialize(suffix=suffix), - "fitting": self.fitting.serialize(suffix=suffix), + "fitting": dict_fit, # not supported yet "atom_exclude_types": [], "pair_exclude_types": [], + "rcond": None, + "preset_out_bias": None, + "@variables": { + "out_bias": np.zeros([1, ntypes, dict_fit["dim_out"]]), + "out_std": np.ones([1, ntypes, dict_fit["dim_out"]]), + }, } diff --git a/deepmd/utils/out_stat.py b/deepmd/utils/out_stat.py index 1dcbcb1280..9678f8ed72 100644 --- a/deepmd/utils/out_stat.py +++ b/deepmd/utils/out_stat.py @@ -112,6 +112,7 @@ def compute_stats_from_atomic( assert output.ndim == 3 assert atype.ndim == 2 assert output.shape[:2] == atype.shape + # compute output bias nframes, nloc, ndim = output.shape ntypes = atype.max() + 1 diff --git a/source/api_c/tests/CMakeLists.txt b/source/api_c/tests/CMakeLists.txt index 1b035b1a6c..c42055ba6f 100644 --- a/source/api_c/tests/CMakeLists.txt +++ b/source/api_c/tests/CMakeLists.txt @@ -8,8 +8,16 @@ set_target_properties( add_executable(runUnitTests_c ${TEST_SRC}) target_link_libraries(runUnitTests_c PRIVATE GTest::gtest_main ${LIB_DEEPMD_C} - rt coverage_config) + coverage_config) target_link_libraries(runUnitTests_c PRIVATE ${LIB_DEEPMD} ${LIB_DEEPMD_CC}) + +if(UNIX AND NOT APPLE) + find_library(RT_LIBRARY rt) + if(RT_LIBRARY) + target_link_libraries(runUnitTests_c PRIVATE ${RT_LIBRARY}) + endif() +endif() + target_precompile_headers(runUnitTests_c PRIVATE test_utils.h [["deepmd.hpp"]]) add_test( NAME runUnitTests_c diff --git a/source/api_cc/tests/CMakeLists.txt b/source/api_cc/tests/CMakeLists.txt index 1511dbe3bc..5599b63243 100644 --- a/source/api_cc/tests/CMakeLists.txt +++ b/source/api_cc/tests/CMakeLists.txt @@ -4,8 +4,16 @@ project(deepmd_api_test) file(GLOB TEST_SRC test_*.cc) add_executable(runUnitTests_cc ${TEST_SRC}) -target_link_libraries(runUnitTests_cc GTest::gtest_main ${LIB_DEEPMD_CC} rt +target_link_libraries(runUnitTests_cc GTest::gtest_main ${LIB_DEEPMD_CC} coverage_config) + +if(UNIX AND NOT APPLE) + find_library(RT_LIBRARY rt) + if(RT_LIBRARY) + target_link_libraries(runUnitTests_cc ${RT_LIBRARY}) + endif() +endif() + target_precompile_headers(runUnitTests_cc PRIVATE test_utils.h) add_test( NAME runUnitTest_cc diff --git a/source/install/docker_package_c.sh b/source/install/docker_package_c.sh index 544c175a0a..3846daf93b 100755 --- a/source/install/docker_package_c.sh +++ b/source/install/docker_package_c.sh @@ -5,6 +5,7 @@ SCRIPT_PATH=$(dirname $(realpath -s $0)) docker run --rm -v ${SCRIPT_PATH}/../..:/root/deepmd-kit -w /root/deepmd-kit \ tensorflow/build:${TENSORFLOW_BUILD_VERSION:-2.15}-python3.11 \ /bin/sh -c "pip install \"tensorflow${TENSORFLOW_VERSION}\" cmake \ + && git config --global --add safe.directory /root/deepmd-kit \ && cd /root/deepmd-kit/source/install \ && CC=/dt9/usr/bin/gcc \ CXX=/dt9/usr/bin/g++ \ diff --git a/source/tests/consistent/fitting/test_dipole.py 
b/source/tests/consistent/fitting/test_dipole.py index 18a29934ca..4f33d58c10 100644 --- a/source/tests/consistent/fitting/test_dipole.py +++ b/source/tests/consistent/fitting/test_dipole.py @@ -94,7 +94,6 @@ def addtional_data(self) -> dict: "ntypes": self.ntypes, "dim_descrpt": self.inputs.shape[-1], "mixed_types": mixed_types, - "var_name": "dipole", "embedding_width": 30, } diff --git a/source/tests/consistent/fitting/test_polar.py b/source/tests/consistent/fitting/test_polar.py index 5b55c6d333..a6e0e07784 100644 --- a/source/tests/consistent/fitting/test_polar.py +++ b/source/tests/consistent/fitting/test_polar.py @@ -94,7 +94,6 @@ def addtional_data(self) -> dict: "ntypes": self.ntypes, "dim_descrpt": self.inputs.shape[-1], "mixed_types": mixed_types, - "var_name": "polar", "embedding_width": 30, } diff --git a/source/tests/pt/model/test_atomic_model_atomic_stat.py b/source/tests/pt/model/test_atomic_model_atomic_stat.py new file mode 100644 index 0000000000..8f365a09fe --- /dev/null +++ b/source/tests/pt/model/test_atomic_model_atomic_stat.py @@ -0,0 +1,406 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import tempfile +import unittest +from pathlib import ( + Path, +) +from typing import ( + Optional, +) + +import h5py +import numpy as np +import torch + +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + OutputVariableDef, +) +from deepmd.pt.model.atomic_model import ( + BaseAtomicModel, + DPAtomicModel, +) +from deepmd.pt.model.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.pt.model.task.base_fitting import ( + BaseFitting, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) +from deepmd.utils.path import ( + DPPath, +) + +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +class FooFitting(torch.nn.Module, BaseFitting): + def output_def(self): + return FittingOutputDef( + [ + OutputVariableDef( + "foo", + [1], + reduciable=True, + r_differentiable=True, + c_differentiable=True, + ), + OutputVariableDef( + "bar", + [1, 2], + reduciable=True, + r_differentiable=True, + c_differentiable=True, + ), + ] + ) + + def serialize(self) -> dict: + raise NotImplementedError + + def forward( + self, + descriptor: torch.Tensor, + atype: torch.Tensor, + gr: Optional[torch.Tensor] = None, + g2: Optional[torch.Tensor] = None, + h2: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + ): + nf, nloc, _ = descriptor.shape + ret = {} + ret["foo"] = ( + torch.Tensor( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ) + .view([nf, nloc, *self.output_def()["foo"].shape]) + .to(env.GLOBAL_PT_FLOAT_PRECISION) + .to(env.DEVICE) + ) + ret["bar"] = ( + torch.Tensor( + [ + [1.0, 2.0, 3.0, 7.0, 8.0, 9.0], + [4.0, 5.0, 6.0, 10.0, 11.0, 12.0], + ] + ) + .view([nf, nloc, *self.output_def()["bar"].shape]) + .to(env.GLOBAL_PT_FLOAT_PRECISION) + .to(env.DEVICE) + ) + return ret + + +class TestAtomicModelStat(unittest.TestCase, TestCaseSingleFrameWithNlist): + def tearDown(self): + self.tempdir.cleanup() + + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + self.merged_output_stat = [ + { + "coord": to_torch_tensor(np.zeros([2, 3, 3])), + "atype": to_torch_tensor( + np.array([[0, 0, 1], [0, 1, 1]], dtype=np.int32) + ), + "atype_ext": to_torch_tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32) + ), + "box": to_torch_tensor(np.zeros([2, 3, 3])), + "natoms": to_torch_tensor( + 
np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32) + ), + # bias of foo: 5, 6 + "atom_foo": to_torch_tensor( + np.array([[5.0, 5.0, 5.0], [5.0, 6.0, 7.0]]).reshape(2, 3, 1) + ), + # bias of bar: [1, 5], [3, 2] + "bar": to_torch_tensor( + np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2) + ), + "find_atom_foo": np.float32(1.0), + "find_bar": np.float32(1.0), + }, + { + "coord": to_torch_tensor(np.zeros([2, 3, 3])), + "atype": to_torch_tensor( + np.array([[0, 0, 1], [0, 1, 1]], dtype=np.int32) + ), + "atype_ext": to_torch_tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32) + ), + "box": to_torch_tensor(np.zeros([2, 3, 3])), + "natoms": to_torch_tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32) + ), + # bias of foo: 5, 6 from atomic label. + "foo": to_torch_tensor(np.array([5.0, 7.0]).reshape(2, 1)), + # bias of bar: [1, 5], [3, 2] + "bar": to_torch_tensor( + np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2) + ), + "find_foo": np.float32(1.0), + "find_bar": np.float32(1.0), + }, + ] + self.tempdir = tempfile.TemporaryDirectory() + h5file = str((Path(self.tempdir.name) / "testcase.h5").resolve()) + with h5py.File(h5file, "w") as f: + pass + self.stat_file_path = DPPath(h5file, "a") + + def test_output_stat(self): + nf, nloc, nnei = self.nlist.shape + ds = DescrptDPA1( + self.rcut, + self.rcut_smth, + sum(self.sel), + self.nt, + ).to(env.DEVICE) + ft = FooFitting().to(env.DEVICE) + type_map = ["foo", "bar"] + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(env.DEVICE) + args = [ + to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + # nf x nloc + at = self.atype_ext[:, :nloc] + + def cvt_ret(x): + return {kk: to_numpy_array(vv) for kk, vv in x.items()} + + # 1. test run without bias + # nf x na x odim + ret0 = md0.forward_common_atomic(*args) + ret0 = cvt_ret(ret0) + expected_ret0 = {} + expected_ret0["foo"] = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["foo"].shape]) + expected_ret0["bar"] = np.array( + [ + [1.0, 2.0, 3.0, 7.0, 8.0, 9.0], + [4.0, 5.0, 6.0, 10.0, 11.0, 12.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["bar"].shape]) + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret0[kk], expected_ret0[kk]) + + # 2. test bias is applied + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + ret1 = md0.forward_common_atomic(*args) + ret1 = cvt_ret(ret1) + # nt x odim + foo_bias = np.array([5.0, 6.0]).reshape(2, 1) + bar_bias = np.array([1.0, 5.0, 3.0, 2.0]).reshape(2, 1, 2) + expected_ret1 = {} + expected_ret1["foo"] = ret0["foo"] + foo_bias[at] + expected_ret1["bar"] = ret0["bar"] + bar_bias[at] + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk]) + + # 3. test bias load from file + def raise_error(): + raise RuntimeError + + md0.compute_or_load_out_stat(raise_error, stat_file_path=self.stat_file_path) + ret2 = md0.forward_common_atomic(*args) + ret2 = cvt_ret(ret2) + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret1[kk], ret2[kk]) + + # 4. 
test change bias + BaseAtomicModel.change_out_bias( + md0, self.merged_output_stat, bias_adjust_mode="change-by-statistic" + ) + args = [ + to_torch_tensor(ii) + for ii in [ + self.coord_ext, + to_numpy_array(self.merged_output_stat[0]["atype_ext"]), + self.nlist, + ] + ] + ret3 = md0.forward_common_atomic(*args) + ret3 = cvt_ret(ret3) + + expected_ret3 = {} + # new bias [2.666, 1.333] + expected_ret3["foo"] = np.array( + [[3.6667, 4.6667, 4.3333], [6.6667, 6.3333, 7.3333]] + ).reshape(2, 3, 1) + for kk in ["foo"]: + np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk], decimal=4) + + +class TestAtomicModelStatMergeGlobalAtomic( + unittest.TestCase, TestCaseSingleFrameWithNlist +): + def tearDown(self): + self.tempdir.cleanup() + + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + self.merged_output_stat = [ + { + "coord": to_torch_tensor(np.zeros([2, 3, 3])), + "atype": to_torch_tensor( + np.array([[0, 0, 0], [0, 0, 0]], dtype=np.int32) + ), + "atype_ext": to_torch_tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32) + ), + "box": to_torch_tensor(np.zeros([2, 3, 3])), + "natoms": to_torch_tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32) + ), + # bias of foo: 5.5, nan + "atom_foo": to_torch_tensor( + np.array([[5.0, 5.0, 5.0], [5.0, 6.0, 7.0]]).reshape(2, 3, 1) + ), + # bias of bar: [1, 5], [3, 2] + "bar": to_torch_tensor( + np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2) + ), + "find_atom_foo": np.float32(1.0), + "find_bar": np.float32(1.0), + }, + { + "coord": to_torch_tensor(np.zeros([2, 3, 3])), + "atype": to_torch_tensor( + np.array([[0, 0, 1], [0, 1, 1]], dtype=np.int32) + ), + "atype_ext": to_torch_tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32) + ), + "box": to_torch_tensor(np.zeros([2, 3, 3])), + "natoms": to_torch_tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32) + ), + # bias of foo: 5.5, 3 from atomic label. + "foo": to_torch_tensor(np.array([5.0, 7.0]).reshape(2, 1)), + # bias of bar: [1, 5], [3, 2] + "bar": to_torch_tensor( + np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2) + ), + "find_foo": np.float32(1.0), + "find_bar": np.float32(1.0), + }, + ] + self.tempdir = tempfile.TemporaryDirectory() + h5file = str((Path(self.tempdir.name) / "testcase.h5").resolve()) + with h5py.File(h5file, "w") as f: + pass + self.stat_file_path = DPPath(h5file, "a") + + def test_output_stat(self): + nf, nloc, nnei = self.nlist.shape + ds = DescrptDPA1( + self.rcut, + self.rcut_smth, + sum(self.sel), + self.nt, + ).to(env.DEVICE) + ft = FooFitting().to(env.DEVICE) + type_map = ["foo", "bar"] + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(env.DEVICE) + args = [ + to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + # nf x nloc + at = self.atype_ext[:, :nloc] + + def cvt_ret(x): + return {kk: to_numpy_array(vv) for kk, vv in x.items()} + + # 1. test run without bias + # nf x na x odim + ret0 = md0.forward_common_atomic(*args) + ret0 = cvt_ret(ret0) + expected_ret0 = {} + expected_ret0["foo"] = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["foo"].shape]) + expected_ret0["bar"] = np.array( + [ + [1.0, 2.0, 3.0, 7.0, 8.0, 9.0], + [4.0, 5.0, 6.0, 10.0, 11.0, 12.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["bar"].shape]) + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret0[kk], expected_ret0[kk]) + + # 2. 
test bias is applied + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + ret1 = md0.forward_common_atomic(*args) + ret1 = cvt_ret(ret1) + # nt x odim + foo_bias = np.array([5.5, 3.0]).reshape(2, 1) + bar_bias = np.array([1.0, 5.0, 3.0, 2.0]).reshape(2, 1, 2) + expected_ret1 = {} + expected_ret1["foo"] = ret0["foo"] + foo_bias[at] + expected_ret1["bar"] = ret0["bar"] + bar_bias[at] + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk]) + + # 3. test bias load from file + def raise_error(): + raise RuntimeError + + md0.compute_or_load_out_stat(raise_error, stat_file_path=self.stat_file_path) + ret2 = md0.forward_common_atomic(*args) + ret2 = cvt_ret(ret2) + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret1[kk], ret2[kk]) + + # 4. test change bias + BaseAtomicModel.change_out_bias( + md0, self.merged_output_stat, bias_adjust_mode="change-by-statistic" + ) + args = [ + to_torch_tensor(ii) + for ii in [ + self.coord_ext, + to_numpy_array(self.merged_output_stat[0]["atype_ext"]), + self.nlist, + ] + ] + ret3 = md0.forward_common_atomic(*args) + ret3 = cvt_ret(ret3) + expected_ret3 = {} + # new bias [2, -5] + expected_ret3["foo"] = np.array([[3, 4, -2], [6, 0, 1]]).reshape(2, 3, 1) + for kk in ["foo"]: + np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk], decimal=4) diff --git a/source/tests/pt/model/test_atomic_model_stat.py b/source/tests/pt/model/test_atomic_model_global_stat.py similarity index 89% rename from source/tests/pt/model/test_atomic_model_stat.py rename to source/tests/pt/model/test_atomic_model_global_stat.py index e266cf215a..ca71b604ce 100644 --- a/source/tests/pt/model/test_atomic_model_stat.py +++ b/source/tests/pt/model/test_atomic_model_global_stat.py @@ -12,6 +12,7 @@ import numpy as np import torch +from deepmd.dpmodel.atomic_model import DPAtomicModel as DPDPAtomicModel from deepmd.dpmodel.output_def import ( FittingOutputDef, OutputVariableDef, @@ -20,12 +21,16 @@ BaseAtomicModel, DPAtomicModel, ) -from deepmd.pt.model.descriptor.dpa1 import ( +from deepmd.pt.model.descriptor import ( DescrptDPA1, + DescrptSeA, ) from deepmd.pt.model.task.base_fitting import ( BaseFitting, ) +from deepmd.pt.model.task.ener import ( + InvarFitting, +) from deepmd.pt.utils import ( env, ) @@ -150,6 +155,8 @@ def setUp(self): "bar": to_torch_tensor( np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2) ), + "find_foo": np.float32(1.0), + "find_bar": np.float32(1.0), } ] self.tempdir = tempfile.TemporaryDirectory() @@ -441,3 +448,50 @@ def cvt_ret(x): expected_ret1["bar"] = ret0["bar"] + bar_bias[at] for kk in ["foo", "pix", "bar"]: np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk]) + + def test_serialize(self): + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + ft = InvarFitting( + "foo", + self.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + ).to(env.DEVICE) + type_map = ["A", "B"] + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(env.DEVICE) + args = [ + to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + # nf x nloc + at = self.atype_ext[:, :nloc] + + def cvt_ret(x): + return {kk: to_numpy_array(vv) for kk, vv in x.items()} + + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + ret0 = md0.forward_common_atomic(*args) + ret0 = cvt_ret(ret0) + md1 = DPAtomicModel.deserialize(md0.serialize()) + ret1 = 
md1.forward_common_atomic(*args) + ret1 = cvt_ret(ret1) + + for kk in ["foo"]: + np.testing.assert_almost_equal(ret0[kk], ret1[kk]) + + md2 = DPDPAtomicModel.deserialize(md0.serialize()) + args = [self.coord_ext, self.atype_ext, self.nlist] + ret2 = md2.forward_common_atomic(*args) + for kk in ["foo"]: + np.testing.assert_almost_equal(ret0[kk], ret2[kk]) diff --git a/source/tests/pt/model/test_linear_atomic_model_stat.py b/source/tests/pt/model/test_linear_atomic_model_stat.py new file mode 100644 index 0000000000..f7feeda550 --- /dev/null +++ b/source/tests/pt/model/test_linear_atomic_model_stat.py @@ -0,0 +1,232 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import tempfile +import unittest +from pathlib import ( + Path, +) +from typing import ( + Optional, +) + +import h5py +import numpy as np +import torch + +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + OutputVariableDef, +) +from deepmd.pt.model.atomic_model import ( + DPAtomicModel, + LinearEnergyAtomicModel, +) +from deepmd.pt.model.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.pt.model.task.base_fitting import ( + BaseFitting, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) +from deepmd.utils.path import ( + DPPath, +) + +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +class FooFittingA(torch.nn.Module, BaseFitting): + def output_def(self): + return FittingOutputDef( + [ + OutputVariableDef( + "energy", + [1], + reduciable=True, + r_differentiable=True, + c_differentiable=True, + ), + ] + ) + + def serialize(self) -> dict: + raise NotImplementedError + + def forward( + self, + descriptor: torch.Tensor, + atype: torch.Tensor, + gr: Optional[torch.Tensor] = None, + g2: Optional[torch.Tensor] = None, + h2: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + ): + nf, nloc, _ = descriptor.shape + ret = {} + ret["energy"] = ( + torch.Tensor( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ) + .view([nf, nloc, *self.output_def()["energy"].shape]) + .to(env.GLOBAL_PT_FLOAT_PRECISION) + .to(env.DEVICE) + ) + + return ret + + +class FooFittingB(torch.nn.Module, BaseFitting): + def output_def(self): + return FittingOutputDef( + [ + OutputVariableDef( + "energy", + [1], + reduciable=True, + r_differentiable=True, + c_differentiable=True, + ), + ] + ) + + def serialize(self) -> dict: + raise NotImplementedError + + def forward( + self, + descriptor: torch.Tensor, + atype: torch.Tensor, + gr: Optional[torch.Tensor] = None, + g2: Optional[torch.Tensor] = None, + h2: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + ): + nf, nloc, _ = descriptor.shape + ret = {} + ret["energy"] = ( + torch.Tensor( + [ + [7.0, 8.0, 9.0], + [10.0, 11.0, 12.0], + ] + ) + .view([nf, nloc, *self.output_def()["energy"].shape]) + .to(env.GLOBAL_PT_FLOAT_PRECISION) + .to(env.DEVICE) + ) + + return ret + + +class TestAtomicModelStat(unittest.TestCase, TestCaseSingleFrameWithNlist): + def tearDown(self): + self.tempdir.cleanup() + + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + nf, nloc, nnei = self.nlist.shape + self.merged_output_stat = [ + { + "coord": to_torch_tensor(np.zeros([2, 3, 3])), + "atype": to_torch_tensor( + np.array([[0, 0, 1], [0, 1, 1]], dtype=np.int32) + ), + "atype_ext": to_torch_tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32) + ), + 
"box": to_torch_tensor(np.zeros([2, 3, 3])), + "natoms": to_torch_tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32) + ), + # bias of foo: 1, 3 + "energy": to_torch_tensor(np.array([5.0, 7.0]).reshape(2, 1)), + "find_energy": np.float32(1.0), + } + ] + self.tempdir = tempfile.TemporaryDirectory() + h5file = str((Path(self.tempdir.name) / "testcase.h5").resolve()) + with h5py.File(h5file, "w") as f: + pass + self.stat_file_path = DPPath(h5file, "a") + + def test_linear_atomic_model_stat_with_bias(self): + nf, nloc, nnei = self.nlist.shape + ds = DescrptDPA1( + self.rcut, + self.rcut_smth, + sum(self.sel), + self.nt, + ).to(env.DEVICE) + ft_a = FooFittingA().to(env.DEVICE) + ft_b = FooFittingB().to(env.DEVICE) + type_map = ["foo", "bar"] + md0 = DPAtomicModel( + ds, + ft_a, + type_map=type_map, + ).to(env.DEVICE) + md1 = DPAtomicModel( + ds, + ft_b, + type_map=type_map, + ).to(env.DEVICE) + linear_model = LinearEnergyAtomicModel([md0, md1], type_map=type_map).to( + env.DEVICE + ) + + args = [ + to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + # nf x nloc + at = self.atype_ext[:, :nloc] + + # 1. test run without bias + # nf x na x odim + ret0 = linear_model.forward_common_atomic(*args) + + ret0 = to_numpy_array(ret0["energy"]) + ret_no_bias = [] + for md in linear_model.models: + ret_no_bias.append( + to_numpy_array(md.forward_common_atomic(*args)["energy"]) + ) + expected_ret0 = np.array( + [ + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0], + ] + ).reshape(nf, nloc, *linear_model.fitting_output_def()["energy"].shape) + + np.testing.assert_almost_equal(ret0, expected_ret0) + + # 2. test bias is applied + linear_model.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + # bias applied to sub atomic models. 
+ ener_bias = np.array([1.0, 3.0]).reshape(2, 1) + linear_ret = [] + for idx, md in enumerate(linear_model.models): + ret = md.forward_common_atomic(*args) + ret = to_numpy_array(ret["energy"]) + linear_ret.append(ret_no_bias[idx] + ener_bias[at]) + np.testing.assert_almost_equal((ret_no_bias[idx] + ener_bias[at]), ret) + + # linear model not adding bias again + ret1 = linear_model.forward_common_atomic(*args) + ret1 = to_numpy_array(ret1["energy"]) + np.testing.assert_almost_equal(np.mean(np.stack(linear_ret), axis=0), ret1) diff --git a/source/tests/pt/model/test_smooth.py b/source/tests/pt/model/test_smooth.py index 4f5be912cf..1a75caebdc 100644 --- a/source/tests/pt/model/test_smooth.py +++ b/source/tests/pt/model/test_smooth.py @@ -39,7 +39,9 @@ def test( natoms = 10 cell = 8.6 * torch.eye(3, dtype=dtype, device=env.DEVICE) - atype = torch.randint(0, 3, [natoms], device=env.DEVICE) + atype0 = torch.arange(3, dtype=dtype, device=env.DEVICE) + atype1 = torch.randint(0, 3, [natoms - 3], device=env.DEVICE) + atype = torch.cat([atype0, atype1]).view([natoms]) coord0 = torch.tensor( [ 0.0, @@ -148,7 +150,6 @@ def setUp(self): self.epsilon, self.aprec = None, None -# @unittest.skip("dpa-1 not smooth at the moment") class TestEnergyModelDPA1(unittest.TestCase, SmoothTest): def setUp(self): model_params = copy.deepcopy(model_dpa1) @@ -160,6 +161,30 @@ def setUp(self): self.aprec = 1e-5 +class TestEnergyModelDPA1Excl1(unittest.TestCase, SmoothTest): + def setUp(self): + model_params = copy.deepcopy(model_dpa1) + model_params["pair_exclude_types"] = [[0, 1]] + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) + # less degree of smoothness, + # error can be systematically removed by reducing epsilon + self.epsilon = 1e-5 + self.aprec = 1e-5 + + +class TestEnergyModelDPA1Excl12(unittest.TestCase, SmoothTest): + def setUp(self): + model_params = copy.deepcopy(model_dpa1) + model_params["pair_exclude_types"] = [[0, 1], [0, 2]] + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) + # less degree of smoothness, + # error can be systematically removed by reducing epsilon + self.epsilon = 1e-5 + self.aprec = 1e-5 + + class TestEnergyModelDPA2(unittest.TestCase, SmoothTest): def setUp(self): model_params = copy.deepcopy(model_dpa2) diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py index 1635ad56ea..f0a988607e 100644 --- a/source/tests/pt/test_training.py +++ b/source/tests/pt/test_training.py @@ -106,6 +106,7 @@ def setUp(self): self.config["training"]["training_data"]["systems"] = data_file self.config["training"]["validation_data"]["systems"] = data_file self.config["model"] = deepcopy(model_dos) + self.config["model"]["type_map"] = ["H"] self.config["training"]["numb_steps"] = 1 self.config["training"]["save_freq"] = 1 self.not_all_grad = True diff --git a/source/tests/tf/test_model_se_atten.py b/source/tests/tf/test_model_se_atten.py index 1a5094c743..36cf4887c0 100644 --- a/source/tests/tf/test_model_se_atten.py +++ b/source/tests/tf/test_model_se_atten.py @@ -890,3 +890,133 @@ def test_smoothness_of_stripped_type_embedding_smooth_model(self): np.testing.assert_allclose(de[0], de[1], rtol=0, atol=deltae) np.testing.assert_allclose(df[0], df[1], rtol=0, atol=deltad) np.testing.assert_allclose(dv[0], dv[1], rtol=0, atol=deltad) + + def test_smoothness_of_stripped_type_embedding_smooth_model_excluded_types(self): + """test: auto-diff, continuity of e,f,v.""" + jfile = "water_se_atten.json" + jdata = j_loader(jfile) + + 
systems = j_must_have(jdata, "systems") + set_pfx = j_must_have(jdata, "set_prefix") + batch_size = 1 + test_size = 1 + rcut = j_must_have(jdata["model"]["descriptor"], "rcut") + + data = DataSystem(systems, set_pfx, batch_size, test_size, rcut, run_opt=None) + + test_data = data.get_test() + numb_test = 1 + + jdata["model"]["descriptor"].pop("type", None) + jdata["model"]["descriptor"]["ntypes"] = 2 + jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["smooth_type_embdding"] = True + jdata["model"]["descriptor"]["attn_layer"] = 1 + jdata["model"]["descriptor"]["rcut"] = 6.0 + jdata["model"]["descriptor"]["rcut_smth"] = 4.0 + jdata["model"]["descriptor"]["exclude_types"] = [[0, 0], [0, 1]] + jdata["model"]["descriptor"]["set_davg_zero"] = False + descrpt = DescrptSeAtten(**jdata["model"]["descriptor"], uniform_seed=True) + jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes() + jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out() + jdata["model"]["fitting_net"]["dim_rot_mat_1"] = descrpt.get_dim_rot_mat_1() + fitting = EnerFitting(**jdata["model"]["fitting_net"], uniform_seed=True) + typeebd_param = jdata["model"]["type_embedding"] + typeebd = TypeEmbedNet( + ntypes=descrpt.get_ntypes(), + neuron=typeebd_param["neuron"], + activation_function=None, + resnet_dt=typeebd_param["resnet_dt"], + seed=typeebd_param["seed"], + uniform_seed=True, + padding=True, + ) + model = EnerModel(descrpt, fitting, typeebd) + + input_data = { + "coord": [test_data["coord"]], + "box": [test_data["box"]], + "type": [test_data["type"]], + "natoms_vec": [test_data["natoms_vec"]], + "default_mesh": [test_data["default_mesh"]], + } + model._compute_input_stat(input_data) + model.descrpt.bias_atom_e = data.compute_energy_shift() + # make the original implementation failed + model.descrpt.davg[:] += 1e-1 + + t_prop_c = tf.placeholder(tf.float32, [5], name="t_prop_c") + t_energy = tf.placeholder(GLOBAL_ENER_FLOAT_PRECISION, [None], name="t_energy") + t_force = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="t_force") + t_virial = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="t_virial") + t_atom_ener = tf.placeholder( + GLOBAL_TF_FLOAT_PRECISION, [None], name="t_atom_ener" + ) + t_coord = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="i_coord") + t_type = tf.placeholder(tf.int32, [None], name="i_type") + t_natoms = tf.placeholder(tf.int32, [model.ntypes + 2], name="i_natoms") + t_box = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None, 9], name="i_box") + t_mesh = tf.placeholder(tf.int32, [None], name="i_mesh") + is_training = tf.placeholder(tf.bool) + inputs_dict = {} + + model_pred = model.build( + t_coord, + t_type, + t_natoms, + t_box, + t_mesh, + inputs_dict, + suffix=self.filename + + "-" + + inspect.stack()[0][3] + + "test_model_se_atten_model_compressible_excluded_types", + reuse=False, + ) + energy = model_pred["energy"] + force = model_pred["force"] + virial = model_pred["virial"] + + feed_dict_test = { + t_prop_c: test_data["prop_c"], + t_energy: test_data["energy"][:numb_test], + t_force: np.reshape(test_data["force"][:numb_test, :], [-1]), + t_virial: np.reshape(test_data["virial"][:numb_test, :], [-1]), + t_atom_ener: np.reshape(test_data["atom_ener"][:numb_test, :], [-1]), + t_coord: np.reshape(test_data["coord"][:numb_test, :], [-1]), + t_box: test_data["box"][:numb_test, :], + t_type: np.reshape(test_data["type"][:numb_test, :], [-1]), + t_natoms: test_data["natoms_vec"], + t_mesh: test_data["default_mesh"], + 
is_training: False, + } + sess = self.cached_session().__enter__() + sess.run(tf.global_variables_initializer()) + [pe, pf, pv] = sess.run([energy, force, virial], feed_dict=feed_dict_test) + pf, pv = pf.reshape(-1), pv.reshape(-1) + + eps = 1e-4 + delta = 1e-6 + fdf, fdv = finite_difference_fv( + sess, energy, feed_dict_test, t_coord, t_box, delta=eps + ) + np.testing.assert_allclose(pf, fdf, delta) + np.testing.assert_allclose(pv, fdv, delta) + + tested_eps = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7] + for eps in tested_eps: + deltae = 1e-15 + deltad = 1e-15 + de, df, dv = check_smooth_efv( + sess, + energy, + force, + virial, + feed_dict_test, + t_coord, + jdata["model"]["descriptor"]["rcut"], + delta=eps, + ) + np.testing.assert_allclose(de[0], de[1], rtol=0, atol=deltae) + np.testing.assert_allclose(df[0], df[1], rtol=0, atol=deltad) + np.testing.assert_allclose(dv[0], dv[1], rtol=0, atol=deltad)