Skip to content

Commit

Permalink
Feat: add se_atten_v2to PyTorch and DP (#3840)
Browse files Browse the repository at this point in the history
Solve #3831 and #3139
- add `se_atten_v2` to PyTorch and DP
- add document equation for `se_attn_v2`

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced a new descriptor class with enhanced configuration options
and methods for serialization and deserialization.
- Added new configurable parameters to the descriptor setup for improved
flexibility.

- **Documentation**
- Updated function documentation to reflect new arguments and usage
instructions.

- **Bug Fixes**
- Refined serialization logic to handle new parameters and class types
more accurately.
- Improved error messages for better clarity during serialization
processes.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Chenqqian Zhang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Duo <[email protected]>
  • Loading branch information
3 people authored Jun 6, 2024
1 parent c1a294b commit c4ac5b5
Show file tree
Hide file tree
Showing 10 changed files with 1,021 additions and 10 deletions.
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from .make_base_descriptor import (
make_base_descriptor,
)
from .se_atten_v2 import (
DescrptSeAttenV2,
)
from .se_e2_a import (
DescrptSeA,
)
Expand All @@ -26,6 +29,7 @@
"DescrptSeR",
"DescrptSeT",
"DescrptDPA1",
"DescrptSeAttenV2",
"DescrptDPA2",
"DescrptHybrid",
"make_base_descriptor",
Expand Down
180 changes: 180 additions & 0 deletions deepmd/dpmodel/descriptor/se_atten_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
List,
Optional,
Tuple,
Union,
)

import numpy as np

from deepmd.dpmodel import (
DEFAULT_PRECISION,
PRECISION_DICT,
)
from deepmd.dpmodel.utils import (
NetworkCollection,
)
from deepmd.dpmodel.utils.type_embed import (
TypeEmbedNet,
)
from deepmd.utils.version import (
check_version_compatibility,
)

from .base_descriptor import (
BaseDescriptor,
)
from .dpa1 import (
DescrptDPA1,
NeighborGatedAttention,
)


@BaseDescriptor.register("se_atten_v2")
class DescrptSeAttenV2(DescrptDPA1):
def __init__(
self,
rcut: float,
rcut_smth: float,
sel: Union[List[int], int],
ntypes: int,
neuron: List[int] = [25, 50, 100],
axis_neuron: int = 8,
tebd_dim: int = 8,
resnet_dt: bool = False,
trainable: bool = True,
type_one_side: bool = False,
attn: int = 128,
attn_layer: int = 2,
attn_dotr: bool = True,
attn_mask: bool = False,
exclude_types: List[Tuple[int, int]] = [],
env_protection: float = 0.0,
set_davg_zero: bool = False,
activation_function: str = "tanh",
precision: str = DEFAULT_PRECISION,
scaling_factor=1.0,
normalize: bool = True,
temperature: Optional[float] = None,
trainable_ln: bool = True,
ln_eps: Optional[float] = 1e-5,
concat_output_tebd: bool = True,
spin: Optional[Any] = None,
stripped_type_embedding: Optional[bool] = None,
use_econf_tebd: bool = False,
type_map: Optional[List[str]] = None,
# consistent with argcheck, not used though
seed: Optional[int] = None,
) -> None:
DescrptDPA1.__init__(
self,
rcut,
rcut_smth,
sel,
ntypes,
neuron=neuron,
axis_neuron=axis_neuron,
tebd_dim=tebd_dim,
tebd_input_mode="strip",
resnet_dt=resnet_dt,
trainable=trainable,
type_one_side=type_one_side,
attn=attn,
attn_layer=attn_layer,
attn_dotr=attn_dotr,
attn_mask=attn_mask,
exclude_types=exclude_types,
env_protection=env_protection,
set_davg_zero=set_davg_zero,
activation_function=activation_function,
precision=precision,
scaling_factor=scaling_factor,
normalize=normalize,
temperature=temperature,
trainable_ln=trainable_ln,
ln_eps=ln_eps,
smooth_type_embedding=True,
concat_output_tebd=concat_output_tebd,
spin=spin,
stripped_type_embedding=stripped_type_embedding,
use_econf_tebd=use_econf_tebd,
type_map=type_map,
# consistent with argcheck, not used though
seed=seed,
)

def serialize(self) -> dict:
"""Serialize the descriptor to dict."""
obj = self.se_atten
data = {
"@class": "Descriptor",
"type": "se_atten_v2",
"@version": 1,
"rcut": obj.rcut,
"rcut_smth": obj.rcut_smth,
"sel": obj.sel,
"ntypes": obj.ntypes,
"neuron": obj.neuron,
"axis_neuron": obj.axis_neuron,
"tebd_dim": obj.tebd_dim,
"set_davg_zero": obj.set_davg_zero,
"attn": obj.attn,
"attn_layer": obj.attn_layer,
"attn_dotr": obj.attn_dotr,
"attn_mask": False,
"activation_function": obj.activation_function,
"resnet_dt": obj.resnet_dt,
"scaling_factor": obj.scaling_factor,
"normalize": obj.normalize,
"temperature": obj.temperature,
"trainable_ln": obj.trainable_ln,
"ln_eps": obj.ln_eps,
"type_one_side": obj.type_one_side,
"concat_output_tebd": self.concat_output_tebd,
"use_econf_tebd": self.use_econf_tebd,
"type_map": self.type_map,
# make deterministic
"precision": np.dtype(PRECISION_DICT[obj.precision]).name,
"embeddings": obj.embeddings.serialize(),
"embeddings_strip": obj.embeddings_strip.serialize(),
"attention_layers": obj.dpa1_attention.serialize(),
"env_mat": obj.env_mat.serialize(),
"type_embedding": self.type_embedding.serialize(),
"exclude_types": obj.exclude_types,
"env_protection": obj.env_protection,
"@variables": {
"davg": obj["davg"],
"dstd": obj["dstd"],
},
## to be updated when the options are supported.
"trainable": self.trainable,
"spin": None,
}
return data

@classmethod
def deserialize(cls, data: dict) -> "DescrptSeAttenV2":
"""Deserialize from dict."""
data = data.copy()
check_version_compatibility(data.pop("@version"), 1, 1)
data.pop("@class")
data.pop("type")
variables = data.pop("@variables")
embeddings = data.pop("embeddings")
type_embedding = data.pop("type_embedding")
attention_layers = data.pop("attention_layers")
data.pop("env_mat")
embeddings_strip = data.pop("embeddings_strip")
obj = cls(**data)

obj.se_atten["davg"] = variables["davg"]
obj.se_atten["dstd"] = variables["dstd"]
obj.se_atten.embeddings = NetworkCollection.deserialize(embeddings)
obj.se_atten.embeddings_strip = NetworkCollection.deserialize(embeddings_strip)
obj.type_embedding = TypeEmbedNet.deserialize(type_embedding)
obj.se_atten.dpa1_attention = NeighborGatedAttention.deserialize(
attention_layers
)
return obj
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
DescrptBlockSeA,
DescrptSeA,
)
from .se_atten_v2 import (
DescrptSeAttenV2,
)
from .se_r import (
DescrptSeR,
)
Expand All @@ -42,6 +45,7 @@
"make_default_type_embedding",
"DescrptBlockSeA",
"DescrptBlockSeAtten",
"DescrptSeAttenV2",
"DescrptSeA",
"DescrptSeR",
"DescrptSeT",
Expand Down
Loading

0 comments on commit c4ac5b5

Please sign in to comment.