Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pt: Add support for dipole and polar training #3380

Merged
merged 11 commits into from
Mar 3, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,12 @@ def forward_common_atomic(
if self.atom_excl is not None:
atom_mask = self.atom_excl.build_type_exclude_mask(atype)
for kk in ret_dict.keys():
ret_dict[kk] = ret_dict[kk] * atom_mask[:, :, None]
out_shape = ret_dict[kk].shape
ret_dict[kk] = (
ret_dict[kk].reshape([out_shape[0], out_shape[1], -1])
* atom_mask[:, :, None]
).reshape(out_shape)
ret_dict["mask"] = atom_mask

return ret_dict

Expand Down
3 changes: 3 additions & 0 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ def call(
aparam=ap,
do_atomic_virial=do_atomic_virial,
)
nf, nloc = nlist.shape[:2]
if "mask" in model_predict_lower:
model_predict_lower["mask"] = model_predict_lower["mask"][:, :nloc]
model_predict = communicate_extended_output(
model_predict_lower,
self.model_output_def(),
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def fit_output_to_model_output(
"""
model_ret = dict(fit_ret.items())
for kk, vv in fit_ret.items():
if kk in ["mask"]:
continue
iProzd marked this conversation as resolved.
Show resolved Hide resolved
vdef = fit_output_def[kk]
shap = vdef.shape
atom_axis = -(len(shap) + 1)
Expand Down Expand Up @@ -59,6 +61,8 @@ def communicate_extended_output(

"""
new_ret = {}
if "mask" in model_ret:
new_ret["mask"] = model_ret["mask"]
for kk in model_output_def.keys_outp():
vv = model_ret[kk]
vdef = model_output_def[kk]
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/loss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@
from .loss import (
TaskLoss,
)
from .tensor import (
TensorLoss,
)

__all__ = [
"DenoiseLoss",
"EnergyStdLoss",
"TensorLoss",
"TaskLoss",
]
162 changes: 162 additions & 0 deletions deepmd/pt/loss/tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
List,
)

import torch

from deepmd.pt.loss.loss import (
TaskLoss,
)
from deepmd.pt.utils import (
env,
)
from deepmd.utils.data import (
DataRequirementItem,
)


class TensorLoss(TaskLoss):
def __init__(
self,
tensor_name: str,
tensor_size: int,
label_name: str,
pref_atomic: float = 0.0,
pref: float = 0.0,
inference=False,
**kwargs,
):
r"""Construct a loss for local and global tensors.

Parameters
----------
tensor_name : str
The name of the tensor in the model predictions to compute the loss.
tensor_size : int
The size (dimension) of the tensor.
label_name : str
The name of the tensor in the labels to compute the loss.
pref_atomic : float
The prefactor of the weight of atomic loss. It should be larger than or equal to 0.
pref : float
The prefactor of the weight of global loss. It should be larger than or equal to 0.
inference : bool
If true, it will output all losses found in output, ignoring the pre-factors.
**kwargs
Other keyword arguments.
"""
super().__init__()
self.tensor_name = tensor_name
self.tensor_size = tensor_size
self.label_name = label_name
self.local_weight = pref_atomic
self.global_weight = pref
self.inference = inference

assert (
self.local_weight >= 0.0 and self.global_weight >= 0.0
), "Can not assign negative weight to `pref` and `pref_atomic`"
self.has_local_weight = self.local_weight > 0.0 or inference
self.has_global_weight = self.global_weight > 0.0 or inference
assert self.has_local_weight or self.has_global_weight, AssertionError(
"Can not assian zero weight both to `pref` and `pref_atomic`"
)

def forward(self, model_pred, label, natoms, learning_rate=0.0, mae=False):
"""Return loss on local and global tensors.

Parameters
----------
model_pred : dict[str, torch.Tensor]
Model predictions.
label : dict[str, torch.Tensor]
Labels.
natoms : int
The local atom number.

Returns
-------
loss: torch.Tensor
Loss for model to minimize.
more_loss: dict[str, torch.Tensor]
Other losses for display.
"""
del learning_rate, mae
loss = torch.tensor(0.0, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)
more_loss = {}
if (
self.has_local_weight
and self.tensor_name in model_pred
and "atomic_" + self.label_name in label
):
local_tensor_pred = model_pred[self.tensor_name].reshape(
[-1, natoms, self.tensor_size]
)
local_tensor_label = label["atomic_" + self.label_name].reshape(
[-1, natoms, self.tensor_size]
)
diff = (local_tensor_pred - local_tensor_label).reshape(
[-1, self.tensor_size]
)
if "mask" in model_pred:
diff = diff[model_pred["mask"].reshape([-1]).bool()]
l2_local_loss = torch.mean(torch.square(diff))
if not self.inference:
more_loss[f"l2_local_{self.tensor_name}_loss"] = l2_local_loss.detach()
loss += self.local_weight * l2_local_loss
rmse_local = l2_local_loss.sqrt()
more_loss[f"rmse_local_{self.tensor_name}"] = rmse_local.detach()
if (
self.has_global_weight
and "global_" + self.tensor_name in model_pred
and self.label_name in label
):
global_tensor_pred = model_pred["global_" + self.tensor_name].reshape(
[-1, self.tensor_size]
)
global_tensor_label = label[self.label_name].reshape([-1, self.tensor_size])
diff = global_tensor_pred - global_tensor_label
if "mask" in model_pred:
atom_num = model_pred["mask"].sum(-1, keepdim=True)
l2_global_loss = torch.mean(
torch.sum(torch.square(diff) * atom_num, dim=0) / atom_num.sum()
)
atom_num = torch.mean(atom_num.float())
else:
atom_num = natoms
l2_global_loss = torch.mean(torch.square(diff))

Check warning on line 128 in deepmd/pt/loss/tensor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/tensor.py#L127-L128

Added lines #L127 - L128 were not covered by tests
if not self.inference:
more_loss[
f"l2_global_{self.tensor_name}_loss"
] = l2_global_loss.detach()
loss += self.global_weight * l2_global_loss
rmse_global = l2_global_loss.sqrt() / atom_num
more_loss[f"rmse_global_{self.tensor_name}"] = rmse_global.detach()
return loss, more_loss

@property
def label_requirement(self) -> List[DataRequirementItem]:
"""Return data label requirements needed for this loss calculation."""
label_requirement = []
if self.has_local_weight:
label_requirement.append(
DataRequirementItem(
"atomic_" + self.label_name,
ndof=self.tensor_size,
atomic=True,
must=False,
high_prec=False,
)
)
if self.has_global_weight:
label_requirement.append(
DataRequirementItem(
self.label_name,
ndof=self.tensor_size,
atomic=False,
must=False,
high_prec=False,
)
)
return label_requirement
7 changes: 6 additions & 1 deletion deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,12 @@ def forward_common_atomic(
if self.atom_excl is not None:
atom_mask = self.atom_excl(atype)
for kk in ret_dict.keys():
ret_dict[kk] = ret_dict[kk] * atom_mask[:, :, None]
out_shape = ret_dict[kk].shape
ret_dict[kk] = (
ret_dict[kk].reshape([out_shape[0], out_shape[1], -1])
* atom_mask[:, :, None]
).reshape(out_shape)
ret_dict["mask"] = atom_mask

return ret_dict

Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def get_model(model_params):
fitting_net["type"] = fitting_net.get("type", "ener")
fitting_net["ntypes"] = descriptor.get_ntypes()
fitting_net["mixed_types"] = descriptor.mixed_types()
fitting_net["embedding_width"] = descriptor.get_dim_out()
fitting_net["embedding_width"] = descriptor.get_dim_emb()
fitting_net["dim_descrpt"] = descriptor.get_dim_out()
grad_force = "direct" not in fitting_net["type"]
if not grad_force:
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/model/model/dipole_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def forward(
model_predict["atom_virial"] = model_ret["dipole_derv_c"].squeeze(
-3
)
if "mask" in model_ret:
model_predict["mask"] = model_ret["mask"]
else:
model_predict = model_ret
model_predict["updated_coord"] += coord
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/model/model/dp_zbl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def forward(
model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3)
else:
model_predict["force"] = model_ret["dforce"]
if "mask" in model_ret:
model_predict["mask"] = model_ret["mask"]
return model_predict

@torch.jit.export
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/model/model/ener_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def forward(
)
else:
model_predict["force"] = model_ret["dforce"]
if "mask" in model_ret:
model_predict["mask"] = model_ret["mask"]
else:
model_predict = model_ret
model_predict["updated_coord"] += coord
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ def forward_common(
fparam=fp,
aparam=ap,
)
nf, nloc = nlist.shape[:2]
if "mask" in model_predict_lower:
model_predict_lower["mask"] = model_predict_lower["mask"][:, :nloc]
model_predict = communicate_extended_output(
model_predict_lower,
self.model_output_def(),
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/model/model/polar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def forward(
model_predict = {}
model_predict["polar"] = model_ret["polar"]
model_predict["global_polar"] = model_ret["polar_redu"]
if "mask" in model_ret:
model_predict["mask"] = model_ret["mask"]
else:
model_predict = model_ret
model_predict["updated_coord"] += coord
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ def fit_output_to_model_output(
redu_prec = env.GLOBAL_PT_ENER_FLOAT_PRECISION
model_ret = dict(fit_ret.items())
for kk, vv in fit_ret.items():
if kk in ["mask"]:
continue
vdef = fit_output_def[kk]
shap = vdef.shape
atom_axis = -(len(shap) + 1)
Expand Down Expand Up @@ -192,6 +194,8 @@ def communicate_extended_output(
"""
redu_prec = env.GLOBAL_PT_ENER_FLOAT_PRECISION
new_ret = {}
if "mask" in model_ret:
new_ret["mask"] = model_ret["mask"]
for kk in model_output_def.keys_outp():
vv = model_ret[kk]
vdef = model_output_def[kk]
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/task/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def compute_output_stats(
The path to the stat file.

"""
raise NotImplementedError
pass

def forward(
self,
Expand Down
8 changes: 4 additions & 4 deletions deepmd/pt/model/task/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ class EnergyFittingNetDirect(Fitting):
def __init__(
self,
ntypes,
embedding_width,
dim_descrpt,
neuron,
bias_atom_e=None,
out_dim=1,
Expand All @@ -315,7 +315,7 @@ def __init__(
"""
super().__init__()
self.ntypes = ntypes
self.dim_descrpt = embedding_width
self.dim_descrpt = dim_descrpt
self.use_tebd = use_tebd
self.out_dim = out_dim
if bias_atom_e is None:
Expand All @@ -329,7 +329,7 @@ def __init__(
for type_i in range(self.ntypes):
one = ResidualDeep(
type_i,
embedding_width,
dim_descrpt,
neuron,
0.0,
out_dim=out_dim,
Expand All @@ -344,7 +344,7 @@ def __init__(
for type_i in range(self.ntypes):
bias_type = 0.0 if self.use_tebd else bias_atom_e[type_i]
one = ResidualDeep(
type_i, embedding_width, neuron, bias_type, resnet_dt=resnet_dt
type_i, dim_descrpt, neuron, bias_type, resnet_dt=resnet_dt
)
filter_layers.append(one)
self.filter_layers = torch.nn.ModuleList(filter_layers)
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/task/polarizability.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def compute_output_stats(
The path to the stat file.

"""
raise NotImplementedError
pass

def forward(
self,
Expand Down
Loading