Skip to content

Commit

Permalink
Add interface to multi-fitting architecture
Browse files Browse the repository at this point in the history
  • Loading branch information
ChiahsinChu committed Jun 11, 2024
1 parent a7ab1af commit 332be00
Show file tree
Hide file tree
Showing 9 changed files with 808 additions and 2 deletions.
3 changes: 2 additions & 1 deletion deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Dict,
List,
Optional,
Union,
)

from deepmd.dpmodel.output_def import (
Expand Down Expand Up @@ -87,7 +88,7 @@ def get_dim_aparam(self) -> int:
"""Get the number (dimension) of atomic parameters of this atomic model."""

@abstractmethod
def get_sel_type(self) -> List[int]:
def get_sel_type(self) -> Union[List[int], List[List[int]]]:
"""Get the selected atom types of this model.
Only atoms with selected atom types have atomic contribution
Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,10 @@ def train(FLAGS):
if multi_task:
config["model"], shared_links = preprocess_shared_params(config["model"])

multi_fitting_net = "fitting_net_dict" in config["model"]

# argcheck
if not multi_task:
if not (multi_task or multi_fitting_net):
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
config = normalize(config)

Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
from .dp_atomic_model import (
DPAtomicModel,
)
from .dp_multi_fitting_atomic_model import (
DPMultiFittingAtomicModel,
)
from .energy_atomic_model import (
DPEnergyAtomicModel,
)
Expand All @@ -50,4 +53,5 @@
"DPPolarAtomicModel",
"DPDipoleAtomicModel",
"DPZBLLinearEnergyAtomicModel",
"DPMultiFittingAtomicModel",
]
282 changes: 282 additions & 0 deletions deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import functools
import logging
from typing import (
Dict,
List,
Optional,
)

import torch

from deepmd.dpmodel import (
FittingOutputDef,
)
from deepmd.pt.model.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.pt.model.task.base_fitting import (
BaseFitting,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)

from .base_atomic_model import (
BaseAtomicModel,
)

log = logging.getLogger(__name__)


@BaseAtomicModel.register("multi_fitting")
class DPMultiFittingAtomicModel(BaseAtomicModel):
"""Model give atomic prediction of some physical property.
Parameters
----------
descriptor
Descriptor
fitting_dict
Dict of Fitting net
type_map
Mapping atom type to the name (str) of the type.
For example `type_map[1]` gives the name of the type 1.
"""

def __init__(
self,
descriptor,
fitting_dict,
type_map: Optional[List[str]],
**kwargs,
):
super().__init__(type_map, **kwargs)
ntypes = len(type_map)
self.type_map = type_map

Check warning

Code scanning / CodeQL

Overwriting attribute in super-class or sub-class Warning

Assignment overwrites attribute type_map, which was previously defined in superclass
BaseAtomicModel
.
self.ntypes = ntypes
self.descriptor = descriptor
self.rcut = self.descriptor.get_rcut()
self.sel = self.descriptor.get_sel()
fitting_dict = copy.deepcopy(fitting_dict)
self.model_type = fitting_dict.pop("type")
self.fitting_net_dict = fitting_dict
self.fitting_net = fitting_dict
super().init_out_stat()

Check warning on line 69 in deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py#L58-L69

Added lines #L58 - L69 were not covered by tests

def fitting_output_def(self) -> FittingOutputDef:
"""Get the output def of the fitting net."""
l = []
for name, fitting_net in self.fitting_net_dict.items():
for vdef in fitting_net.output_def().var_defs.values():
vdef.name = name
l.append(vdef)
return FittingOutputDef(l)

Check warning on line 78 in deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py#L73-L78

Added lines #L73 - L78 were not covered by tests

@torch.jit.export
def get_rcut(self) -> float:
"""Get the cut-off radius."""
return self.rcut

Check warning on line 83 in deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py#L83

Added line #L83 was not covered by tests

@torch.jit.export
def get_type_map(self) -> List[str]:
"""Get the type map."""
return self.type_map

Check warning on line 88 in deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py#L88

Added line #L88 was not covered by tests

def get_sel(self) -> List[int]:
"""Get the neighbor selection."""
return self.sel

Check warning on line 92 in deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py#L92

Added line #L92 was not covered by tests

def mixed_types(self) -> bool:
"""If true, the model
1. assumes total number of atoms aligned across frames;
2. uses a neighbor list that does not distinguish different atomic types.
If false, the model
1. assumes total number of atoms of each atom type aligned across frames;
2. uses a neighbor list that distinguishes different atomic types.
"""
return self.descriptor.mixed_types()

Check warning on line 104 in deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py#L104

Added line #L104 was not covered by tests

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return self.descriptor.has_message_passing()

Check warning on line 108 in deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py#L108

Added line #L108 was not covered by tests

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd.update(

Check warning on line 112 in deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py#L111-L112

Added lines #L111 - L112 were not covered by tests
{
"@class": "Model",
"@version": 2,
"type": "multi_fitting",
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": [
fitting_net.serialize()
for fitting_net in self.fitting_net_dict.values()
],
"fitting_name": self.fitting_net_dict.keys(),
}
)
return dd

Check warning on line 126 in deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py#L126

Added line #L126 was not covered by tests

@classmethod
def deserialize(cls, data) -> "DPMultiFittingAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 2, 1)
data.pop("@class", None)
data.pop("type", None)
descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor"))

Check warning on line 134 in deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py#L130-L134

Added lines #L130 - L134 were not covered by tests

fitting_dict = {}
fitting_names = data["fitting_name"]
for name, fitting in zip(fitting_names, data.pop("fitting")):
fitting_obj = BaseFitting.deserialize(fitting)
fitting_dict[name] = fitting_obj

Check warning on line 140 in deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py#L136-L140

Added lines #L136 - L140 were not covered by tests
# type_map = data.pop("type_map", None)
# obj = cls(descriptor_obj, fitting_dict, type_map=type_map, **data)
data["descriptor"] = descriptor_obj
data["fitting"] = list(fitting_dict.values())
data["fitting_name"] = list(fitting_dict.keys())
obj = super().deserialize(data)
return obj

Check warning on line 147 in deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py#L143-L147

Added lines #L143 - L147 were not covered by tests

def forward_atomic(
self,
extended_coord,
extended_atype,
nlist,
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> Dict[str, torch.Tensor]:
"""Return atomic prediction.
Parameters
----------
extended_coord
coodinates in extended region
extended_atype
atomic type in extended region
nlist
neighbor list. nf x nloc x nsel
mapping
mapps the extended indices to local indices
fparam
frame parameter. nf x ndf
aparam
atomic parameter. nf x nloc x nda
Returns
-------
result_dict
the result dict, defined by the `FittingOutputDef`.
"""
nframes, nloc, nnei = nlist.shape
atype = extended_atype[:, :nloc]
if self.do_grad_r() or self.do_grad_c():
extended_coord.requires_grad_(True)
descriptor, rot_mat, g2, h2, sw = self.descriptor(

Check warning on line 186 in deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py#L182-L186

Added lines #L182 - L186 were not covered by tests
extended_coord,
extended_atype,
nlist,
mapping=mapping,
comm_dict=comm_dict,
)
assert descriptor is not None
fit_ret_dict = {}
for name, fitting_net in self.fitting_net_dict.items():
fitting = fitting_net(

Check warning on line 196 in deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py#L193-L196

Added lines #L193 - L196 were not covered by tests
descriptor,
atype,
gr=rot_mat,
g2=g2,
h2=h2,
fparam=fparam,
aparam=aparam,
)
for v in fitting.values():
fit_ret_dict[name] = v
return fit_ret_dict

Check warning on line 207 in deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py#L205-L207

Added lines #L205 - L207 were not covered by tests

def get_out_bias(self) -> torch.Tensor:
return self.out_bias

Check warning on line 210 in deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py#L210

Added line #L210 was not covered by tests

def compute_or_load_stat(
self,
sampled_func,
stat_file_path: Optional[DPPath] = None,
):
"""
Compute or load the statistics parameters of the model,
such as mean and standard deviation of descriptors or the energy bias of the fitting net.
When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update),
and saved in the `stat_file_path`(s).
When `sampled` is not provided, it will check the existence of `stat_file_path`(s)
and load the calculated statistics parameters.
Parameters
----------
sampled_func
The lazy sampled function to get data frames from different data systems.
stat_file_path
The dictionary of paths to the statistics files.
"""
if stat_file_path is not None and self.type_map is not None:

Check warning on line 232 in deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py#L232

Added line #L232 was not covered by tests
# descriptors and fitting net with different type_map
# should not share the same parameters
stat_file_path /= " ".join(self.type_map)

Check warning on line 235 in deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py#L235

Added line #L235 was not covered by tests

@functools.lru_cache
def wrapped_sampler():
sampled = sampled_func()
if self.pair_excl is not None:
pair_exclude_types = self.pair_excl.get_exclude_types()
for sample in sampled:
sample["pair_exclude_types"] = list(pair_exclude_types)
if self.atom_excl is not None:
atom_exclude_types = self.atom_excl.get_exclude_types()
for sample in sampled:
sample["atom_exclude_types"] = list(atom_exclude_types)
return sampled

Check warning on line 248 in deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py#L237-L248

Added lines #L237 - L248 were not covered by tests

self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)

Check warning on line 251 in deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py#L250-L251

Added lines #L250 - L251 were not covered by tests

@torch.jit.export
def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
return list(self.fitting_net_dict.values())[0].get_dim_fparam()

Check warning on line 256 in deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py#L256

Added line #L256 was not covered by tests

@torch.jit.export
def get_dim_aparam(self) -> int:
"""Get the number (dimension) of atomic parameters of this atomic model."""
return list(self.fitting_net_dict.values())[0].get_dim_aparam()

Check warning on line 261 in deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py#L261

Added line #L261 was not covered by tests

@torch.jit.export
def get_sel_type(self) -> List[List[int]]:
"""Get the selected atom types of this model.
Only atoms with selected atom types have atomic contribution
to the result of the model.
If returning an empty list, all atom types are selected.
"""
sel_type = []
for fitting_net in self.fitting_net_dict.values():
sel_type.append(fitting_net.get_sel_type())
return sel_type

Check warning on line 274 in deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py#L271-L274

Added lines #L271 - L274 were not covered by tests

@torch.jit.export
def is_aparam_nall(self) -> bool:
"""Check whether the shape of atomic parameters is (nframes, nall, ndim).
If False, the shape is (nframes, nloc, ndim).
"""
return False

Check warning on line 282 in deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_multi_fitting_atomic_model.py#L282

Added line #L282 was not covered by tests
Loading

0 comments on commit 332be00

Please sign in to comment.