Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add more methods to DeepPot #175

Merged
merged 2 commits into from
Jan 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion deepmd_pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Callable, Optional, Tuple, Union, List
from deepmd_pt.utils import env
from deepmd_pt.utils.auto_batch_size import AutoBatchSize
from deepmd_utils.infer.deep_pot import DeepPot as DeepPotBase


class DeepEval:
Expand Down Expand Up @@ -53,7 +54,7 @@ def eval(
raise NotImplementedError


class DeepPot(DeepEval):
class DeepPot(DeepEval, DeepPotBase):
def __init__(
self,
model_file: "Path",
Expand Down Expand Up @@ -177,6 +178,22 @@ def _eval_model(
else:
return energy_out, force_out, virial_out, atomic_energy_out, atomic_virial_out

def get_ntypes(self) -> int:
"""Get the number of atom types of this model."""
return len(self.type_map)

def get_type_map(self) -> List[str]:
"""Get the type map (element name of the atom types) of this model."""
return self.type_map

def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this DP."""
return 0

def get_dim_aparam(self) -> int:
"""Get the number (dimension) of atomic parameters of this DP."""
return 0


# For tests only
def eval_model(
Expand Down