Skip to content

Commit

Permalink
Merge branch 'devel' into rf_finetune
Browse files Browse the repository at this point in the history
Signed-off-by: Duo <[email protected]>
  • Loading branch information
iProzd authored Jun 6, 2024
2 parents 5664240 + 1a02e56 commit 9f1d473
Show file tree
Hide file tree
Showing 46 changed files with 359 additions and 116 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ repos:
exclude: ^source/3rdparty
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.4.5
rev: v0.4.7
hooks:
- id: ruff
args: ["--fix"]
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def mixed_types(self) -> bool:
"""
return self.descriptor.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return self.descriptor.has_message_passing()

def forward_atomic(
self,
extended_coord: np.ndarray,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ def mixed_types(self) -> bool:
"""
return True

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return any(model.has_message_passing() for model in self.models)

def get_rcut(self) -> float:
"""Get the cut-off radius."""
return max(self.get_model_rcuts())
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ def mixed_types(self) -> bool:
"""
pass

@abstractmethod
def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""

@abstractmethod
def fwd(
self,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ def mixed_types(self) -> bool:
# to match DPA1 and DPA2.
return True

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return False

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd.update(
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,7 @@ def call(
):
"""Calculate DescriptorBlock."""
pass

@abstractmethod
def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
8 changes: 8 additions & 0 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,10 @@ def mixed_types(self) -> bool:
"""
return self.se_atten.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return self.se_atten.has_message_passing()

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.se_atten.get_env_protection()
Expand Down Expand Up @@ -906,6 +910,10 @@ def call(
sw,
)

def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return False


class NeighborGatedAttention(NativeOP):
def __init__(
Expand Down
6 changes: 6 additions & 0 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,12 @@ def mixed_types(self) -> bool:
"""
return True

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return any(
[self.repinit.has_message_passing(), self.repformers.has_message_passing()]
)

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ def mixed_types(self):
"""
return any(descrpt.mixed_types() for descrpt in self.descrpt_list)

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return any(descrpt.has_message_passing() for descrpt in self.descrpt_list)

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix. All descriptors should be the same."""
all_protection = [descrpt.get_env_protection() for descrpt in self.descrpt_list]
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ def mixed_types(self) -> bool:
"""
pass

@abstractmethod
def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""

@abstractmethod
def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,10 @@ def call(
rot_mat = np.transpose(h2g2, (0, 1, 3, 2))
return g1, g2, h2, rot_mat.reshape(-1, nloc, self.dim_emb, 3), sw

def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return True


# translated by GPT and modified
def get_residual(
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,10 @@ def mixed_types(self):
"""
return False

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,10 @@ def mixed_types(self):
"""
return False

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,10 @@ def mixed_types(self):
"""
return False

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,10 @@ def mixed_types(self) -> bool:
"""
return self.atomic_model.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the model has message passing."""
return self.atomic_model.has_message_passing()

def atomic_output_def(self) -> FittingOutputDef:
"""Get the output def of the atomic model."""
return self.atomic_model.atomic_output_def()
Expand Down
3 changes: 3 additions & 0 deletions deepmd/entrypoints/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from deepmd.utils.argcheck import (
gen_doc,
gen_json,
gen_json_schema,
)

__all__ = ["doc_train_input"]
Expand All @@ -15,6 +16,8 @@ def doc_train_input(*, out_type: str = "rst", **kwargs):
doc_str = gen_doc(make_anchor=True)
elif out_type == "json":
doc_str = gen_json()
elif out_type == "json_schema":
doc_str = gen_json_schema()
else:
raise RuntimeError(f"Unsupported out type {out_type}")
print(doc_str) # noqa: T201
2 changes: 1 addition & 1 deletion deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ def main_parser() -> argparse.ArgumentParser:
parsers_doc.add_argument(
"--out-type",
default="rst",
choices=["rst", "json"],
choices=["rst", "json", "json_schema"],
type=str,
help="The output type",
)
Expand Down
5 changes: 1 addition & 4 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,7 @@ def train(FLAGS):

def freeze(FLAGS):
model = torch.jit.script(inference.Tester(FLAGS.model, head=FLAGS.head).model)
if '"type": "dpa2"' in model.get_model_def_script():
extra_files = {"type": "dpa2"}
else:
extra_files = {"type": "else"}
extra_files = {}
torch.jit.save(
model,
FLAGS.output,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ def slim_type_map(self, type_map: List[str]) -> None:
self.descriptor.slim_type_map(type_map=type_map)
self.fitting_net.slim_type_map(type_map=type_map)

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return self.descriptor.has_message_passing()

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd.update(
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ def mixed_types(self) -> bool:
"""
return True

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return any(model.has_message_passing() for model in self.models)

def get_out_bias(self) -> torch.Tensor:
return self.out_bias

Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ def mixed_types(self) -> bool:
# to match DPA1 and DPA2.
return True

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return False

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd.update(
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ def forward(
"""Calculate DescriptorBlock."""
pass

@abstractmethod
def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""


def make_default_type_embedding(
ntypes,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,10 @@ def mixed_types(self) -> bool:
"""
return self.se_atten.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return self.se_atten.has_message_passing()

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.se_atten.get_env_protection()
Expand Down
6 changes: 6 additions & 0 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,12 @@ def mixed_types(self) -> bool:
"""
return True

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return any(
[self.repinit.has_message_passing(), self.repformers.has_message_passing()]
)

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
# the env_protection of repinit is the same as that of the repformer
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ def mixed_types(self):
"""
return any(descrpt.mixed_types() for descrpt in self.descrpt_list)

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return any(descrpt.has_message_passing() for descrpt in self.descrpt_list)

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix. All descriptors should be the same."""
all_protection = [descrpt.get_env_protection() for descrpt in self.descrpt_list]
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,3 +548,7 @@ def get_stats(self) -> Dict[str, StatItem]:
"The statistics of the descriptor has not been computed."
)
return self.stats

def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return True
8 changes: 8 additions & 0 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ def mixed_types(self):
"""
return self.sea.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return self.sea.has_message_passing()

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.sea.get_env_protection()
Expand Down Expand Up @@ -687,3 +691,7 @@ def forward(
None,
sw,
)

def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return False
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,10 @@ def forward(
sw,
)

def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return False


class NeighborGatedAttention(nn.Module):
def __init__(
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ def mixed_types(self) -> bool:
"""
return False

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
8 changes: 8 additions & 0 deletions deepmd/pt/model/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,10 @@ def mixed_types(self):
"""
return self.seat.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return self.seat.has_message_passing()

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.seat.get_env_protection()
Expand Down Expand Up @@ -702,3 +706,7 @@ def forward(
None,
sw,
)

def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return False
5 changes: 5 additions & 0 deletions deepmd/pt/model/model/frozen.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ def mixed_types(self) -> bool:
"""
return self.model.mixed_types()

@torch.jit.export
def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return self.model.has_message_passing()

@torch.jit.export
def forward(
self,
Expand Down
5 changes: 5 additions & 0 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,11 @@ def mixed_types(self) -> bool:
"""
return self.atomic_model.mixed_types()

@torch.jit.export
def has_message_passing(self) -> bool:
"""Returns whether the model has message passing."""
return self.atomic_model.has_message_passing()

def forward(
self,
coord,
Expand Down
Loading

0 comments on commit 9f1d473

Please sign in to comment.