Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jun 11, 2024
1 parent da5ca3d commit 91664b3
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 22 deletions.
2 changes: 1 addition & 1 deletion deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def train(FLAGS):
config["model"], shared_links = preprocess_shared_params(config["model"])

multi_fitting_net = "fitting_net_dict" in config["model"]

# argcheck
if not (multi_task or multi_fitting_net):
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
Expand Down
6 changes: 3 additions & 3 deletions deepmd/pt/model/atomic_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
from .dp_atomic_model import (
DPAtomicModel,
)
from .dp_multi_fitting_atomic_model import (
DPMultiFittingAtomicModel,
)
from .energy_atomic_model import (
DPEnergyAtomicModel,
)
Expand All @@ -39,9 +42,6 @@
from .polar_atomic_model import (
DPPolarAtomicModel,
)
from .dp_multi_fitting_atomic_model import (
DPMultiFittingAtomicModel,
)

__all__ = [
"BaseAtomicModel",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def forward_atomic(

def get_out_bias(self) -> torch.Tensor:
return self.out_bias

def compute_or_load_stat(
self,
sampled_func,
Expand Down
7 changes: 4 additions & 3 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@
from .frozen import (
FrozenModel,
)
from .make_multi_fitting_model import (
make_multi_fitting_model,
)
from .make_hessian_model import (
make_hessian_model,
)
from .make_model import (
make_model,
)
from .make_multi_fitting_model import (
make_multi_fitting_model,
)
from .model import (
BaseModel,
)
Expand Down Expand Up @@ -193,6 +193,7 @@ def get_multi_fitting_model(model_params):
model.model_def_script = json.dumps(model_params_old)
return model


def get_standard_model(model_params):
model_params_old = model_params
model_params = copy.deepcopy(model_params)
Expand Down
1 change: 0 additions & 1 deletion deepmd/pt/model/model/dp_multi_fitting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

@BaseModel.register("multi_fitting")
class DPMultiFittingModel(DPModelCommon, DPMultiFittingModel_):

def __init__(
self,
*args,
Expand Down
25 changes: 13 additions & 12 deletions deepmd/pt/model/model/make_multi_fitting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@


def make_multi_fitting_model(T_AtomicModel: Type[BaseAtomicModel]):

class CM(BaseModel):
def __init__(
self,
Expand All @@ -48,7 +47,7 @@ def __init__(
atomic_model_: Optional[T_AtomicModel] = None,
**kwargs,
):
super().__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
if atomic_model_ is not None:
self.atomic_model: T_AtomicModel = atomic_model_
else:
Expand All @@ -75,7 +74,7 @@ def model_output_type(self) -> List[str]:
if vv.category == OutputVariableCategory.OUT.value:
vars.append(kk)
return vars

def get_out_bias(self) -> torch.Tensor:
return self.atomic_model.get_out_bias()

Expand Down Expand Up @@ -267,7 +266,7 @@ def _format_nlist(
pass # great!
assert nlist.shape[-1] == nnei
return nlist

def do_grad_r(
self,
var_name: Optional[str] = None,
Expand Down Expand Up @@ -373,38 +372,40 @@ def mixed_types(self) -> bool:
def has_message_passing(self) -> bool:
"""Returns whether the model has message passing."""
return self.atomic_model.has_message_passing()

@staticmethod
def make_pairs(nlist, mapping):
"""
return the pairs from nlist and mapping
pairs:
[[i1, j1, 0], [i2, j2, 0], ...],
[[i1, j1, 0], [i2, j2, 0], ...],
in which i and j are the local indices of the atoms
"""
nframes, nloc, nsel = nlist.shape
assert nframes == 1
nlist_reshape = torch.reshape(nlist, [nframes, nloc * nsel, 1])
mask = nlist_reshape.ge(0)

ii = torch.arange(nloc, dtype=torch.int64, device=nlist.device)
ii = torch.tile(ii.reshape(-1, 1), [1, nsel])
ii = torch.reshape(ii, [nframes, nloc * nsel, 1])
sel_ii = torch.masked_select(ii, mask)
sel_ii = torch.reshape(sel_ii, [nframes, -1, 1])

# nf x (nloc x nsel)
sel_nlist = torch.masked_select(nlist_reshape, mask)
sel_jj = torch.gather(mapping, 1, sel_nlist.reshape(nframes, -1))
sel_jj = torch.reshape(sel_jj, [nframes, -1, 1])

# nframes x (nloc x nsel) x 3
pairs = torch.zeros(nframes, nloc * nsel, 1, dtype=torch.int64, device=nlist.device)
pairs = torch.zeros(
nframes, nloc * nsel, 1, dtype=torch.int64, device=nlist.device
)
pairs = torch.masked_select(pairs, mask)
pairs = torch.reshape(pairs, [nframes, -1, 1])

pairs = torch.concat([sel_ii, sel_jj, pairs], -1)

# select the pair with jj > ii
mask = pairs[..., 1] > pairs[..., 0]
pairs = torch.masked_select(pairs, mask.reshape(nframes, -1, 1))
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
from deepmd.pt.model.model import (
EnergyModel,
get_model,
get_zbl_model,
get_multi_fitting_model,
get_zbl_model,
)
from deepmd.pt.optimizer import (
KFOptimizerWrapper,
Expand Down

0 comments on commit 91664b3

Please sign in to comment.