diff --git a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py
index 2b47cd81e6..cec2c7c839 100644
--- a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py
+++ b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py
@@ -7,6 +7,7 @@
     Dict,
     List,
     Optional,
+    Union,
 )
 
 from deepmd.dpmodel.output_def import (
@@ -87,7 +88,7 @@ def get_dim_aparam(self) -> int:
             """Get the number (dimension) of atomic parameters of this atomic model."""
 
         @abstractmethod
-        def get_sel_type(self) -> List[int]:
+        def get_sel_type(self) -> Union[List[int], List[List[int]]]:
             """Get the selected atom types of this model.
 
             Only atoms with selected atom types have atomic contribution
diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py
index 8e37dbf09b..8e0dfb96b8 100644
--- a/deepmd/pt/entrypoints/main.py
+++ b/deepmd/pt/entrypoints/main.py
@@ -243,8 +243,10 @@ def train(FLAGS):
     if multi_task:
         config["model"], shared_links = preprocess_shared_params(config["model"])
 
+    multi_fitting_net = "fitting_net_dict" in config["model"]
+
     # argcheck
-    if not multi_task:
+    if not (multi_task or multi_fitting_net):
         config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
         config = normalize(config)
diff --git a/deepmd/pt/model/atomic_model/__init__.py b/deepmd/pt/model/atomic_model/__init__.py
index 3e94449057..4e267b672f 100644
--- a/deepmd/pt/model/atomic_model/__init__.py
+++ b/deepmd/pt/model/atomic_model/__init__.py
@@ -26,6 +26,9 @@ from .dp_atomic_model import (
     DPAtomicModel,
 )
+from .dp_multi_fitting_atomic_model import (
+    DPMultiFittingAtomicModel,
+)
 from .energy_atomic_model import (
     DPEnergyAtomicModel,
 )
@@ -50,4 +53,5 @@
     "DPPolarAtomicModel",
     "DPDipoleAtomicModel",
     "DPZBLLinearEnergyAtomicModel",
+    "DPMultiFittingAtomicModel",
 ]
diff --git a/deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py b/deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py
new file mode 100644
index 0000000000..aa133f0162
--- /dev/null
+++ b/deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py
@@ -0,0 +1,282 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import copy
+import functools
+import logging
+from typing import (
+    Dict,
+    List,
+    Optional,
+)
+
+import torch
+
+from deepmd.dpmodel import (
+    FittingOutputDef,
+)
+from deepmd.pt.model.descriptor.base_descriptor import (
+    BaseDescriptor,
+)
+from deepmd.pt.model.task.base_fitting import (
+    BaseFitting,
+)
+from deepmd.utils.path import (
+    DPPath,
+)
+from deepmd.utils.version import (
+    check_version_compatibility,
+)
+
+from .base_atomic_model import (
+    BaseAtomicModel,
+)
+
+log = logging.getLogger(__name__)
+
+
+@BaseAtomicModel.register("multi_fitting")
+class DPMultiFittingAtomicModel(BaseAtomicModel):
+    """Model giving atomic predictions of several physical properties, one per fitting net.
+
+    Parameters
+    ----------
+    descriptor
+        Descriptor
+    fitting_dict
+        Dictionary of fitting nets, keyed by fitting name
+    type_map
+        Mapping atom type to the name (str) of the type.
+        For example `type_map[1]` gives the name of the type 1.
+
+    """
+
+    def __init__(
+        self,
+        descriptor,
+        fitting_dict,
+        type_map: Optional[List[str]],
+        **kwargs,
+    ):
+        super().__init__(type_map, **kwargs)
+        ntypes = len(type_map)
+        self.type_map = type_map
+        self.ntypes = ntypes
+        self.descriptor = descriptor
+        self.rcut = self.descriptor.get_rcut()
+        self.sel = self.descriptor.get_sel()
+        fitting_dict = copy.deepcopy(fitting_dict)
+        self.model_type = fitting_dict.pop("type", None)
+        self.fitting_net_dict = fitting_dict
+        self.fitting_net = fitting_dict
+        super().init_out_stat()
+
+    def fitting_output_def(self) -> FittingOutputDef:
+        """Get the output def of all fitting nets, with output variables renamed to the fitting names."""
+        out_defs = []
+        for name, fitting_net in self.fitting_net_dict.items():
+            for vdef in fitting_net.output_def().var_defs.values():
+                vdef.name = name
+                out_defs.append(vdef)
+        return FittingOutputDef(out_defs)
+
+    @torch.jit.export
+    def get_rcut(self) -> float:
+        """Get the cut-off radius."""
+        return self.rcut
+
+    @torch.jit.export
+    def get_type_map(self) -> List[str]:
+        """Get the type map."""
+        return self.type_map
+
+    def get_sel(self) -> List[int]:
+        """Get the neighbor selection."""
+        return self.sel
+
+    def mixed_types(self) -> bool:
+        """If true, the model
+        1. assumes total number of atoms aligned across frames;
+        2. uses a neighbor list that does not distinguish different atomic types.
+
+        If false, the model
+        1. assumes total number of atoms of each atom type aligned across frames;
+        2. uses a neighbor list that distinguishes different atomic types.
+
+        """
+        return self.descriptor.mixed_types()
+
+    def has_message_passing(self) -> bool:
+        """Returns whether the atomic model has message passing."""
+        return self.descriptor.has_message_passing()
+
+    def serialize(self) -> dict:
+        dd = BaseAtomicModel.serialize(self)
+        dd.update(
+            {
+                "@class": "Model",
+                "@version": 2,
+                "type": "multi_fitting",
+                "type_map": self.type_map,
+                "descriptor": self.descriptor.serialize(),
+                "fitting": [
+                    fitting_net.serialize()
+                    for fitting_net in self.fitting_net_dict.values()
+                ],
+                "fitting_name": list(self.fitting_net_dict.keys()),
+            }
+        )
+        return dd
+
+    @classmethod
+    def deserialize(cls, data) -> "DPMultiFittingAtomicModel":
+        data = copy.deepcopy(data)
+        check_version_compatibility(data.pop("@version", 1), 2, 1)
+        data.pop("@class", None)
+        data.pop("type", None)
+        descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor"))
+
+        fitting_dict = {}
+        fitting_names = data.pop("fitting_name")
+        for name, fitting in zip(fitting_names, data.pop("fitting")):
+            fitting_obj = BaseFitting.deserialize(fitting)
+            fitting_dict[name] = fitting_obj
+        data["descriptor"] = descriptor_obj
+        data["fitting_dict"] = fitting_dict
+        obj = super().deserialize(data)
+        return obj
+
+    def forward_atomic(
+        self,
+        extended_coord,
+        extended_atype,
+        nlist,
+        mapping: Optional[torch.Tensor] = None,
+        fparam: Optional[torch.Tensor] = None,
+        aparam: Optional[torch.Tensor] = None,
+        comm_dict: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> Dict[str, torch.Tensor]:
+        """Return atomic prediction.
+
+        Parameters
+        ----------
+        extended_coord
+            coordinates in extended region
+        extended_atype
+            atomic type in extended region
+        nlist
+            neighbor list. nf x nloc x nsel
+        mapping
+            maps the extended indices to local indices
+        fparam
+            frame parameter. nf x ndf
+        aparam
+            atomic parameter. nf x nloc x nda
+
+        Returns
+        -------
+        result_dict
+            the result dict, defined by the `FittingOutputDef`.
+
+        """
+        nframes, nloc, nnei = nlist.shape
+        atype = extended_atype[:, :nloc]
+        if self.do_grad_r() or self.do_grad_c():
+            extended_coord.requires_grad_(True)
+        descriptor, rot_mat, g2, h2, sw = self.descriptor(
+            extended_coord,
+            extended_atype,
+            nlist,
+            mapping=mapping,
+            comm_dict=comm_dict,
+        )
+        assert descriptor is not None
+        fit_ret_dict = {}
+        for name, fitting_net in self.fitting_net_dict.items():
+            fitting = fitting_net(
+                descriptor,
+                atype,
+                gr=rot_mat,
+                g2=g2,
+                h2=h2,
+                fparam=fparam,
+                aparam=aparam,
+            )
+            for v in fitting.values():
+                fit_ret_dict[name] = v
+        return fit_ret_dict
+
+    def get_out_bias(self) -> torch.Tensor:
+        return self.out_bias
+
+    def compute_or_load_stat(
+        self,
+        sampled_func,
+        stat_file_path: Optional[DPPath] = None,
+    ):
+        """
+        Compute or load the statistics parameters of the model,
+        such as mean and standard deviation of descriptors or the energy bias of the fitting net.
+        When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update),
+        and saved in the `stat_file_path`(s).
+        When `sampled` is not provided, it will check the existence of `stat_file_path`(s)
+        and load the calculated statistics parameters.
+
+        Parameters
+        ----------
+        sampled_func
+            The lazy sampled function to get data frames from different data systems.
+        stat_file_path
+            The dictionary of paths to the statistics files.
+        """
+        if stat_file_path is not None and self.type_map is not None:
+            # descriptors and fitting net with different type_map
+            # should not share the same parameters
+            stat_file_path /= " ".join(self.type_map)
+
+        @functools.lru_cache
+        def wrapped_sampler():
+            sampled = sampled_func()
+            if self.pair_excl is not None:
+                pair_exclude_types = self.pair_excl.get_exclude_types()
+                for sample in sampled:
+                    sample["pair_exclude_types"] = list(pair_exclude_types)
+            if self.atom_excl is not None:
+                atom_exclude_types = self.atom_excl.get_exclude_types()
+                for sample in sampled:
+                    sample["atom_exclude_types"] = list(atom_exclude_types)
+            return sampled
+
+        self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
+        self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
+
+    @torch.jit.export
+    def get_dim_fparam(self) -> int:
+        """Get the number (dimension) of frame parameters of this atomic model."""
+        return list(self.fitting_net_dict.values())[0].get_dim_fparam()
+
+    @torch.jit.export
+    def get_dim_aparam(self) -> int:
+        """Get the number (dimension) of atomic parameters of this atomic model."""
+        return list(self.fitting_net_dict.values())[0].get_dim_aparam()
+
+    @torch.jit.export
+    def get_sel_type(self) -> List[List[int]]:
+        """Get the selected atom types of this model.
+
+        Only atoms with selected atom types have atomic contribution
+        to the result of the model.
+        If returning an empty list, all atom types are selected.
+        """
+        sel_type = []
+        for fitting_net in self.fitting_net_dict.values():
+            sel_type.append(fitting_net.get_sel_type())
+        return sel_type
+
+    @torch.jit.export
+    def is_aparam_nall(self) -> bool:
+        """Check whether the shape of atomic parameters is (nframes, nall, ndim).
+
+        If False, the shape is (nframes, nloc, ndim).
+        """
+        return False
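For reference, a minimal standalone sketch (plain PyTorch, not the DeePMD-kit API; the fitting names "energy_a"/"energy_b" and the toy fitting callables are made up) of how forward_atomic above keys each fitting net's output by its name in fitting_net_dict:

import torch

def toy_fitting_a(descriptor):  # stands in for a fitting net returning {"energy": ...}
    return {"energy": descriptor.sum(dim=-1, keepdim=True)}

def toy_fitting_b(descriptor):
    return {"energy": descriptor.mean(dim=-1, keepdim=True)}

fitting_net_dict = {"energy_a": toy_fitting_a, "energy_b": toy_fitting_b}
descriptor = torch.rand(2, 4, 8)  # nf x nloc x dim_descrpt

fit_ret_dict = {}
for name, fitting_net in fitting_net_dict.items():
    for v in fitting_net(descriptor).values():
        fit_ret_dict[name] = v  # each output is stored under its fitting name

print({k: tuple(v.shape) for k, v in fit_ret_dict.items()})
# {'energy_a': (2, 4, 1), 'energy_b': (2, 4, 1)}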
diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py
index 1d46720af2..ba884e73a9 100644
--- a/deepmd/pt/model/model/__init__.py
+++ b/deepmd/pt/model/model/__init__.py
@@ -54,6 +54,12 @@
+from .dp_multi_fitting_model import (
+    DPMultiFittingModel,
+)
 from .make_model import (
     make_model,
 )
+from .make_multi_fitting_model import (
+    make_multi_fitting_model,
+)
 from .model import (
     BaseModel,
 )
@@ -148,6 +151,49 @@ def get_zbl_model(model_params):
     )
 
 
+def get_multi_fitting_model(model_params):
+    model_params_old = model_params
+    model_params = copy.deepcopy(model_params)
+    ntypes = len(model_params["type_map"])
+    # descriptor
+    model_params["descriptor"]["ntypes"] = ntypes
+    descriptor = BaseDescriptor(**model_params["descriptor"])
+    # fitting_net_dict
+    fitting_dict = {}
+    fitting_net_dict = model_params.get("fitting_net_dict", {})
+    fitting_dict["type"] = fitting_net_dict.pop("type", "pme_ener")
+    for k, fitting_net in fitting_net_dict.items():
+        fitting_net["type"] = fitting_net.get("type", "ener")
+        fitting_net["ntypes"] = descriptor.get_ntypes()
+        fitting_net["mixed_types"] = descriptor.mixed_types()
+        fitting_net["embedding_width"] = descriptor.get_dim_emb()
+        fitting_net["dim_descrpt"] = descriptor.get_dim_out()
+        grad_force = "direct" not in fitting_net["type"]
+        if not grad_force:
+            fitting_net["out_dim"] = descriptor.get_dim_emb()
+        if "ener" in fitting_net["type"]:
+            fitting_net["return_energy"] = True
+        fitting = BaseFitting(**fitting_net)
+        fitting_dict[k] = fitting
+    atom_exclude_types = model_params.get("atom_exclude_types", [])
+    pair_exclude_types = model_params.get("pair_exclude_types", [])
+
+    if fitting_dict["type"] == "pme_ener":
+        modelcls = DPMultiFittingModel
+    else:
+        raise RuntimeError(f"Unknown fitting type: {fitting_dict['type']}")
+
+    model = modelcls(
+        descriptor=descriptor,
+        fitting_dict=fitting_dict,
+        type_map=model_params["type_map"],
+        atom_exclude_types=atom_exclude_types,
+        pair_exclude_types=pair_exclude_types,
+    )
+    model.model_def_script = json.dumps(model_params_old)
+    return model
+
+
 def get_standard_model(model_params):
     model_params_old = model_params
     model_params = copy.deepcopy(model_params)
@@ -201,6 +247,8 @@ def get_model(model_params):
         return get_spin_model(model_params)
     elif "use_srtab" in model_params:
         return get_zbl_model(model_params)
+    elif "fitting_net_dict" in model_params:
+        return get_multi_fitting_model(model_params)
     else:
         return get_standard_model(model_params)
 
@@ -216,4 +264,5 @@
     "DPZBLModel",
     "make_model",
     "make_hessian_model",
+    "make_multi_fitting_model",
 ]
diff --git a/deepmd/pt/model/model/dp_multi_fitting_model.py b/deepmd/pt/model/model/dp_multi_fitting_model.py
new file mode 100644
index 0000000000..03cd51accf
--- /dev/null
+++ b/deepmd/pt/model/model/dp_multi_fitting_model.py
@@ -0,0 +1,27 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from deepmd.pt.model.atomic_model import (
+    DPMultiFittingAtomicModel,
+)
+from deepmd.pt.model.model.model import (
+    BaseModel,
+)
+
+from .dp_model import (
+    DPModelCommon,
+)
+from .make_multi_fitting_model import (
+    make_multi_fitting_model,
+)
+
+DPMultiFittingModel_ = make_multi_fitting_model(DPMultiFittingAtomicModel)
+
+
+@BaseModel.register("multi_fitting")
+class DPMultiFittingModel(DPModelCommon, DPMultiFittingModel_):
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        DPModelCommon.__init__(self)
+        DPMultiFittingModel_.__init__(self, *args, **kwargs)
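For reference, a hypothetical model_params dict that the get_multi_fitting_model branch above would accept; the descriptor and fitting hyperparameters are illustrative only, and the fitting names "fitting_a"/"fitting_b" are made up:

# Hypothetical model section for the multi-fitting dispatch in get_model();
# values follow the usual se_e2_a / ener argument conventions but are not prescriptive.
model_params = {
    "type_map": ["O", "H"],
    "descriptor": {
        "type": "se_e2_a",
        "rcut": 6.0,
        "rcut_smth": 0.5,
        "sel": [46, 92],
        "neuron": [25, 50, 100],
    },
    "fitting_net_dict": {
        "type": "pme_ener",  # dict-level type, dispatched to DPMultiFittingModel
        "fitting_a": {"type": "ener", "neuron": [240, 240, 240]},
        "fitting_b": {"type": "ener", "neuron": [240, 240, 240]},
    },
}

# from deepmd.pt.model.model import get_model
# model = get_model(model_params)  # would return a DPMultiFittingModel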
diff --git a/deepmd/pt/model/model/make_multi_fitting_model.py b/deepmd/pt/model/model/make_multi_fitting_model.py
new file mode 100644
index 0000000000..0f1fdc9399
--- /dev/null
+++ b/deepmd/pt/model/model/make_multi_fitting_model.py
@@ -0,0 +1,415 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Type,
+)
+
+import torch
+
+from deepmd.dpmodel import (
+    ModelOutputDef,
+)
+from deepmd.dpmodel.output_def import (
+    FittingOutputDef,
+    OutputVariableCategory,
+    OutputVariableOperation,
+    check_operation_applied,
+)
+from deepmd.pt.model.atomic_model.base_atomic_model import (
+    BaseAtomicModel,
+)
+from deepmd.pt.model.model.model import (
+    BaseModel,
+)
+from deepmd.pt.utils.env import (
+    GLOBAL_PT_ENER_FLOAT_PRECISION,
+    GLOBAL_PT_FLOAT_PRECISION,
+    PRECISION_DICT,
+    RESERVED_PRECISON_DICT,
+)
+from deepmd.pt.utils.nlist import (
+    nlist_distinguish_types,
+)
+from deepmd.utils.path import (
+    DPPath,
+)
+
+
+def make_multi_fitting_model(T_AtomicModel: Type[BaseAtomicModel]):
+    class CM(BaseModel):
+        def __init__(
+            self,
+            *args,
+            # underscore to prevent conflict with normal inputs
+            atomic_model_: Optional[T_AtomicModel] = None,
+            **kwargs,
+        ):
+            super().__init__(*args, **kwargs)
+            if atomic_model_ is not None:
+                self.atomic_model: T_AtomicModel = atomic_model_
+            else:
+                self.atomic_model: T_AtomicModel = T_AtomicModel(*args, **kwargs)
+            self.precision_dict = PRECISION_DICT
+            self.reverse_precision_dict = RESERVED_PRECISON_DICT
+            self.global_pt_float_precision = GLOBAL_PT_FLOAT_PRECISION
+            self.global_pt_ener_float_precision = GLOBAL_PT_ENER_FLOAT_PRECISION
+
+        def model_output_def(self):
+            """Get the output def for the model."""
+            return ModelOutputDef(self.atomic_output_def())
+
+        @torch.jit.export
+        def model_output_type(self) -> List[str]:
+            """Get the output type for the model."""
+            output_def = self.model_output_def()
+            var_defs = output_def.var_defs
+            # jit: Comprehension ifs are not supported yet
+            # type hint is critical for JIT
+            vars: List[str] = []
+            for kk, vv in var_defs.items():
+                # .value is critical for JIT
+                if vv.category == OutputVariableCategory.OUT.value:
+                    vars.append(kk)
+            return vars
+
+        def get_out_bias(self) -> torch.Tensor:
+            return self.atomic_model.get_out_bias()
+
+        def change_out_bias(
+            self,
+            merged,
+            bias_adjust_mode="change-by-statistic",
+        ) -> None:
+            """Change the output bias of atomic model according to the input data and the pretrained model.
+
+            Parameters
+            ----------
+            merged : Union[Callable[[], List[dict]], List[dict]]
+                - List[dict]: A list of data samples from various data systems.
+                    Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
+                    originating from the `i`-th data system.
+                - Callable[[], List[dict]]: A lazy function that returns data samples in the above format
+                    only when needed. Since the sampling process can be slow and memory-intensive,
+                    the lazy function helps by only sampling once.
+            bias_adjust_mode : str
+                The mode for changing output bias : ['change-by-statistic', 'set-by-statistic']
+                'change-by-statistic' : perform predictions on labels of target dataset,
+                    and do least square on the errors to obtain the target shift as bias.
+                'set-by-statistic' : directly use the statistic output bias in the target dataset.
+            """
+            self.atomic_model.change_out_bias(
+                merged,
+                bias_adjust_mode=bias_adjust_mode,
+            )
+
+        def input_type_cast(
+            self,
+            coord: torch.Tensor,
+            box: Optional[torch.Tensor] = None,
+            fparam: Optional[torch.Tensor] = None,
+            aparam: Optional[torch.Tensor] = None,
+        ) -> Tuple[
+            torch.Tensor,
+            Optional[torch.Tensor],
+            Optional[torch.Tensor],
+            Optional[torch.Tensor],
+            str,
+        ]:
+            """Cast the input data to global float type."""
+            input_prec = self.reverse_precision_dict[coord.dtype]
+            ###
+            ### type checking would not pass jit, convert to coord prec anyway
+            ###
+            # for vv, kk in zip([fparam, aparam], ["frame", "atomic"]):
+            #     if vv is not None and self.reverse_precision_dict[vv.dtype] != input_prec:
+            #         log.warning(
+            #             f"type of {kk} parameter {self.reverse_precision_dict[vv.dtype]}"
+            #             " does not match"
+            #             f" that of the coordinate {input_prec}"
+            #         )
+            _lst: List[Optional[torch.Tensor]] = [
+                vv.to(coord.dtype) if vv is not None else None
+                for vv in [box, fparam, aparam]
+            ]
+            box, fparam, aparam = _lst
+            if (
+                input_prec
+                == self.reverse_precision_dict[self.global_pt_float_precision]
+            ):
+                return coord, box, fparam, aparam, input_prec
+            else:
+                pp = self.global_pt_float_precision
+                return (
+                    coord.to(pp),
+                    box.to(pp) if box is not None else None,
+                    fparam.to(pp) if fparam is not None else None,
+                    aparam.to(pp) if aparam is not None else None,
+                    input_prec,
+                )
+
+        def output_type_cast(
+            self,
+            model_ret: Dict[str, torch.Tensor],
+            input_prec: str,
+        ) -> Dict[str, torch.Tensor]:
+            """Convert the model output to the input prec."""
+            do_cast = (
+                input_prec
+                != self.reverse_precision_dict[self.global_pt_float_precision]
+            )
+            pp = self.precision_dict[input_prec]
+            odef = self.model_output_def()
+            for kk in odef.keys():
+                if kk not in model_ret.keys():
+                    # do not return energy_derv_c if not do_atomic_virial
+                    continue
+                if check_operation_applied(odef[kk], OutputVariableOperation.REDU):
+                    model_ret[kk] = (
+                        model_ret[kk].to(self.global_pt_ener_float_precision)
+                        if model_ret[kk] is not None
+                        else None
+                    )
+                elif do_cast:
+                    model_ret[kk] = (
+                        model_ret[kk].to(pp) if model_ret[kk] is not None else None
+                    )
+            return model_ret
+
+        def format_nlist(
+            self,
+            extended_coord: torch.Tensor,
+            extended_atype: torch.Tensor,
+            nlist: torch.Tensor,
+        ):
+            """Format the neighbor list.
+
+            1. If the number of neighbors in the `nlist` is equal to sum(self.sel),
+            it does nothing.
+
+            2. If the number of neighbors in the `nlist` is smaller than sum(self.sel),
+            the `nlist` is padded with -1.
+
+            3. If the number of neighbors in the `nlist` is larger than sum(self.sel),
+            the nearest sum(sel) neighbors will be preserved.
+
+            Known limitations:
+
+            When self.mixed_types() is False, the nlist is always formatted.
+            This may have a side effect on efficiency.
+
+            Parameters
+            ----------
+            extended_coord
+                coordinates in extended region. nf x nall x 3
+            extended_atype
+                atomic type in extended region. nf x nall
+            nlist
+                neighbor list. nf x nloc x nsel
+
+            Returns
+            -------
+            formatted_nlist
+                the formatted nlist.
+
+            """
+            mixed_types = self.mixed_types()
+            nlist = self._format_nlist(extended_coord, nlist, sum(self.get_sel()))
+            if not mixed_types:
+                nlist = nlist_distinguish_types(nlist, extended_atype, self.get_sel())
+            return nlist
+
+        def _format_nlist(
+            self,
+            extended_coord: torch.Tensor,
+            nlist: torch.Tensor,
+            nnei: int,
+        ):
+            n_nf, n_nloc, n_nnei = nlist.shape
+            # nf x nall x 3
+            extended_coord = extended_coord.view([n_nf, -1, 3])
+            rcut = self.get_rcut()
+
+            if n_nnei < nnei:
+                nlist = torch.cat(
+                    [
+                        nlist,
+                        -1
+                        * torch.ones(
+                            [n_nf, n_nloc, nnei - n_nnei],
+                            dtype=nlist.dtype,
+                            device=nlist.device,
+                        ),
+                    ],
+                    dim=-1,
+                )
+            elif n_nnei > nnei:
+                m_real_nei = nlist >= 0
+                nlist = torch.where(m_real_nei, nlist, 0)
+                # nf x nloc x 3
+                coord0 = extended_coord[:, :n_nloc, :]
+                # nf x (nloc x nnei) x 3
+                index = nlist.view(n_nf, n_nloc * n_nnei, 1).expand(-1, -1, 3)
+                coord1 = torch.gather(extended_coord, 1, index)
+                # nf x nloc x nnei x 3
+                coord1 = coord1.view(n_nf, n_nloc, n_nnei, 3)
+                # nf x nloc x nnei
+                rr = torch.linalg.norm(coord0[:, :, None, :] - coord1, dim=-1)
+                rr = torch.where(m_real_nei, rr, float("inf"))
+                rr, nlist_mapping = torch.sort(rr, dim=-1)
+                nlist = torch.gather(nlist, 2, nlist_mapping)
+                nlist = torch.where(rr > rcut, -1, nlist)
+                nlist = nlist[..., :nnei]
+            else:  # n_nnei == nnei:
+                pass  # great!
+            assert nlist.shape[-1] == nnei
+            return nlist
+
+        def do_grad_r(
+            self,
+            var_name: Optional[str] = None,
+        ) -> bool:
+            """Tell if the output variable `var_name` is r_differentiable.
+            if var_name is None, returns if any of the variables is r_differentiable.
+            """
+            return self.atomic_model.do_grad_r(var_name)
+
+        def do_grad_c(
+            self,
+            var_name: Optional[str] = None,
+        ) -> bool:
+            """Tell if the output variable `var_name` is c_differentiable.
+            if var_name is None, returns if any of the variables is c_differentiable.
+            """
+            return self.atomic_model.do_grad_c(var_name)
+
+        def serialize(self) -> dict:
+            return self.atomic_model.serialize()
+
+        @classmethod
+        def deserialize(cls, data) -> "CM":
+            return cls(atomic_model_=T_AtomicModel.deserialize(data))
+
+        @torch.jit.export
+        def get_dim_fparam(self) -> int:
+            """Get the number (dimension) of frame parameters of this atomic model."""
+            return self.atomic_model.get_dim_fparam()
+
+        @torch.jit.export
+        def get_dim_aparam(self) -> int:
+            """Get the number (dimension) of atomic parameters of this atomic model."""
+            return self.atomic_model.get_dim_aparam()
+
+        @torch.jit.export
+        def get_sel_type(self) -> List[List[int]]:
+            """Get the selected atom types of this model.
+
+            Only atoms with selected atom types have atomic contribution
+            to the result of the model.
+            If returning an empty list, all atom types are selected.
+            """
+            return self.atomic_model.get_sel_type()
+
+        @torch.jit.export
+        def is_aparam_nall(self) -> bool:
+            """Check whether the shape of atomic parameters is (nframes, nall, ndim).
+
+            If False, the shape is (nframes, nloc, ndim).
+            """
+            return self.atomic_model.is_aparam_nall()
+
+        @torch.jit.export
+        def get_rcut(self) -> float:
+            """Get the cut-off radius."""
+            return self.atomic_model.get_rcut()
+
+        @torch.jit.export
+        def get_type_map(self) -> List[str]:
+            """Get the type map."""
+            return self.atomic_model.get_type_map()
+
+        @torch.jit.export
+        def get_nsel(self) -> int:
+            """Returns the total number of selected neighboring atoms in the cut-off radius."""
+            return self.atomic_model.get_nsel()
+
+        @torch.jit.export
+        def get_nnei(self) -> int:
+            """Returns the total number of selected neighboring atoms in the cut-off radius."""
+            return self.atomic_model.get_nnei()
+
+        def atomic_output_def(self) -> FittingOutputDef:
+            """Get the output def of the atomic model."""
+            return self.atomic_model.atomic_output_def()
+
+        def compute_or_load_stat(
+            self,
+            sampled_func,
+            stat_file_path: Optional[DPPath] = None,
+        ):
+            """Compute or load the statistics."""
+            return self.atomic_model.compute_or_load_stat(sampled_func, stat_file_path)
+
+        def get_sel(self) -> List[int]:
+            """Returns the number of selected atoms for each type."""
+            return self.atomic_model.get_sel()
+
+        def mixed_types(self) -> bool:
+            """If true, the model
+            1. assumes total number of atoms aligned across frames;
+            2. uses a neighbor list that does not distinguish different atomic types.
+
+            If false, the model
+            1. assumes total number of atoms of each atom type aligned across frames;
+            2. uses a neighbor list that distinguishes different atomic types.
+
+            """
+            return self.atomic_model.mixed_types()
+
+        @torch.jit.export
+        def has_message_passing(self) -> bool:
+            """Returns whether the model has message passing."""
+            return self.atomic_model.has_message_passing()
+
+        @staticmethod
+        def make_pairs(nlist, mapping):
+            """Return the pairs from nlist and mapping.
+
+            pairs:
+                [[i1, j1, 0], [i2, j2, 0], ...],
+            in which i and j are the local indices of the atoms
+            """
+            nframes, nloc, nsel = nlist.shape
+            assert nframes == 1
+            nlist_reshape = torch.reshape(nlist, [nframes, nloc * nsel, 1])
+            mask = nlist_reshape.ge(0)
+
+            ii = torch.arange(nloc, dtype=torch.int64, device=nlist.device)
+            ii = torch.tile(ii.reshape(-1, 1), [1, nsel])
+            ii = torch.reshape(ii, [nframes, nloc * nsel, 1])
+            sel_ii = torch.masked_select(ii, mask)
+            sel_ii = torch.reshape(sel_ii, [nframes, -1, 1])
+
+            # nf x (nloc x nsel)
+            sel_nlist = torch.masked_select(nlist_reshape, mask)
+            sel_jj = torch.gather(mapping, 1, sel_nlist.reshape(nframes, -1))
+            sel_jj = torch.reshape(sel_jj, [nframes, -1, 1])
+
+            # nframes x (nloc x nsel) x 1, zero placeholder for the pair label column
+            pairs = torch.zeros(
+                nframes, nloc * nsel, 1, dtype=torch.int64, device=nlist.device
+            )
+            pairs = torch.masked_select(pairs, mask)
+            pairs = torch.reshape(pairs, [nframes, -1, 1])
+
+            pairs = torch.concat([sel_ii, sel_jj, pairs], -1)
+
+            # select the pairs with jj > ii
+            mask = pairs[..., 1] > pairs[..., 0]
+            pairs = torch.masked_select(pairs, mask.reshape(nframes, -1, 1))
+            pairs = torch.reshape(pairs, [nframes, -1, 3])
+            return pairs
+
+    return CM
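For reference, a minimal standalone illustration (toy tensors, not the DeePMD-kit API; shapes chosen arbitrarily) of the padding branch in _format_nlist above: a neighbor list narrower than the requested nnei is right-padded with -1.

import torch

# toy neighbor list: 1 frame, 2 local atoms, 3 neighbors each; -1 marks an empty slot
nlist = torch.tensor([[[1, 2, -1], [0, 2, -1]]])
nnei = 5  # requested width, e.g. sum(self.get_sel())
n_nf, n_nloc, n_nnei = nlist.shape
if n_nnei < nnei:
    pad = -torch.ones([n_nf, n_nloc, nnei - n_nnei], dtype=nlist.dtype)
    nlist = torch.cat([nlist, pad], dim=-1)
print(tuple(nlist.shape))  # (1, 2, 5); the two new columns are all -1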
diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index cceadb38d2..e1e0c26e5d 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -33,6 +33,7 @@ from deepmd.pt.model.model import (
     EnergyModel,
     get_model,
+    get_multi_fitting_model,
     get_zbl_model,
 )
 from deepmd.pt.optimizer import (
@@ -268,6 +269,8 @@ def get_single_model(
         ):
             if "use_srtab" in _model_params:
                 model = get_zbl_model(deepcopy(_model_params)).to(DEVICE)
+            elif "fitting_net_dict" in _model_params:
+                model = get_multi_fitting_model(deepcopy(_model_params)).to(DEVICE)
             else:
                 model = get_model(deepcopy(_model_params)).to(DEVICE)
             return model
diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py
index bbb203eea9..1b007af79e 100644
--- a/deepmd/utils/argcheck.py
+++ b/deepmd/utils/argcheck.py
@@ -1613,6 +1613,7 @@ def model_args(exclude_hybrid=False):
                 standard_model_args(),
                 frozen_model_args(),
                 pairtab_model_args(),
+                multi_fitting_model_args(),
                 *hybrid_models,
             ],
             optional=True,
@@ -1697,6 +1698,28 @@ def pairtab_model_args() -> Argument:
     return ca
 
 
+def multi_fitting_model_args() -> Argument:
+    doc_descrpt = "The descriptor of atomic environment."
+    doc_fitting_dict = "The dictionary of fitting nets for different physical properties."
+
+    ca = Argument(
+        "multi_fitting",
+        dict,
+        [
+            Argument(
+                "descriptor", dict, [], [descrpt_variant_type_args()], doc=doc_descrpt
+            ),
+            Argument(
+                "fitting_net_dict",
+                dict,
+                [],
+                doc=doc_fitting_dict,
+            ),
+        ],
+    )
+    return ca
+
+
 def linear_ener_model_args() -> Argument:
     doc_weights = (
         "If the type is list of float, a list of weights for each model. "