Skip to content

Commit

Permalink
Add property bias UT
Browse files Browse the repository at this point in the history
  • Loading branch information
Chengqian-Zhang committed Dec 15, 2024
1 parent 3249891 commit 8e9bbc5
Show file tree
Hide file tree
Showing 14 changed files with 86 additions and 30 deletions.
3 changes: 2 additions & 1 deletion deepmd/dpmodel/fitting/property_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ class PropertyFittingNet(InvarFitting):
property_name:
The names of fitting properties, which should be consistent with the property names in the dataset.
If the data file is named `humo.npy`, this parameter should be "humo" or ["humo"].
If you want to fit two properties at the same time, supposing that the data files are named `humo.npy` and `lumo.npy`, this parameter should be `["humo", "lumo"]`.
If you want to fit two properties at the same time, supposing that the data files are named `humo.npy` and `lumo.npy`,
this parameter should be `["humo", "lumo"]`.
property_dim:
The dimensions of fitting properties, which should be consistent with the property dimensions in the dataset.
Note that the order here must be the same as the order of `property_name`.
Expand Down
1 change: 1 addition & 0 deletions deepmd/entrypoints/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,6 +784,7 @@ def test_property(
assert isinstance(property_name, list)
assert isinstance(property_dim, list)
assert sum(property_dim) == dp.task_dim
assert len(property_name) == len(property_dim), f"The shape of the `property_name` you provide must be consistent with the `property_dim`, but your `property_name` is {property_name} and your `property_dim` is {property_dim}!"
for name, dim in zip(property_name, property_dim):
data.add(name, dim, atomic=False, must=True, high_prec=True)
if has_atom_property:
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,11 @@ def get_intensive(self) -> bool:
return self.dp.model["Default"].get_intensive()

def get_property_name(self) -> Union[list[str], str]:
"""Get the name of the property."""
"""Get the names of the properties."""
return self.dp.model["Default"].get_property_name()

def get_property_dim(self) -> Union[list[int], int]:
"""Get the dimension of the property."""
"""Get the dimensions of the properties."""
return self.dp.model["Default"].get_property_dim()

@property
Expand Down
29 changes: 17 additions & 12 deletions deepmd/pt/loss/property.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,22 +101,27 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False
label["property"] = torch.cat(concat_property, dim=1)
assert label["property"].shape == (nbz, self.task_dim)

out_std = (
model.atomic_model.out_std[0][0]
if self.out_std is None
else torch.tensor(
if self.out_std is None:
out_std = model.atomic_model.out_std[0][0]
else:
out_std = torch.tensor(
self.out_std, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
)
out_bias = (
model.atomic_model.out_bias[0][0]
if self.out_bias is None
else torch.tensor(
if out_std.shape != (self.task_dim,):
raise ValueError(
f"Expected out_std to have shape ({self.task_dim},), but got {out_std.shape}"
)

if self.out_bias is None:
out_bias = model.atomic_model.out_bias[0][0]
else:
out_bias = torch.tensor(
self.out_bias, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
)
assert len(out_std.shape) == 1
assert out_std.shape[0] == self.task_dim
if out_bias.shape != (self.task_dim,):
raise ValueError(
f"Expected out_bias to have shape ({self.task_dim},), but got {out_bias.shape}"
)

loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
more_loss = {}
Expand Down
16 changes: 14 additions & 2 deletions deepmd/pt/model/task/property.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,15 @@ class PropertyFittingNet(InvarFitting):
dim_descrpt : int
Embedding width per atom.
task_dim : int
The dimension of outputs of fitting net.
The dimension of outputs of fitting net.
property_name:
The names of fitting properties, which should be consistent with the property names in the dataset.
If the data file is named `humo.npy`, this parameter should be "humo" or ["humo"].
If you want to fit two properties at the same time, supposing that the data files are named `humo.npy` and `lumo.npy`,
this parameter should be `["humo", "lumo"]`.
property_dim:
The dimensions of fitting properties, which should be consistent with the property dimensions in the dataset.
Note that the order here must be the same as the order of `property_name`.
neuron : list[int]
Number of neurons in each hidden layers of the fitting net.
bias_atom_p : torch.Tensor, optional
Expand Down Expand Up @@ -94,9 +102,13 @@ def __init__(
self.intensive = intensive
if isinstance(property_name, str):
property_name = [property_name]
self.property_name = property_name
if isinstance(property_dim, int):
property_dim = [property_dim]
assert len(property_name) == len(property_dim), (
f"The number of property names ({len(property_name)}) must match "
f"the number of property dimensions ({len(property_dim)})."
)
self.property_name = property_name
self.property_dim = property_dim
super().__init__(
var_name="property",
Expand Down
3 changes: 1 addition & 2 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1613,8 +1613,7 @@ def fitting_property():
Argument(
"property_name",
[str, list],
optional=True,
default="property",
optional=False,
doc=doc_property_name,
),
Argument(
Expand Down
15 changes: 8 additions & 7 deletions deepmd/utils/out_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,23 +129,22 @@ def compute_stats_from_atomic(
)
return output_bias, output_std


def compute_stats_property(
output_redu: np.ndarray,
natoms: np.ndarray,
assigned_bias: Optional[np.ndarray] = None,
) -> tuple[np.ndarray, np.ndarray]:
"""Compute the output statistics.
"""Compute the mean value and standard deviation of reduced output.
Given the reduced output value and the number of atoms for each atom,
compute the least-squares solution as the atomic output bias and std.
Given the reduced output value, compute the mean value and standard deviation of output.
Parameters
----------
output_redu
The reduced output value, shape is [nframes, *(odim0, odim1, ...)].
natoms
The number of atoms for each atom, shape is [nframes, ntypes].
It is used to generate a fake bias in property fitting.
assigned_bias
The assigned output bias, shape is [ntypes, *(odim0, odim1, ...)].
Set to a tensor of shape (odim0, odim1, ...) filled with nan if the bias
Expand All @@ -154,9 +153,11 @@ def compute_stats_property(
Returns
-------
np.ndarray
The computed output bias, shape is [ntypes, *(odim0, odim1, ...)].
The computed output mean(fake bias), shape is [ntypes, *(odim0, odim1, ...)].
In property fitting, we assume that the atom output is not element-dependent,
i.e., the `bias` is the same for each atom (they are all mean value of reduced output).
np.ndarray
The computed output std, shape is [*(odim0, odim1, ...)].
The computed output standard deviation, shape is [*(odim0, odim1, ...)].
"""
natoms = np.array(natoms) # [nf, ntypes]
nf, ntypes = natoms.shape
Expand All @@ -173,7 +174,7 @@ def compute_stats_property(
)
output_std = np.std(output_redu, axis=0)

computed_output_bias = computed_output_bias.reshape([natoms.shape[1]] + var_shape) # noqa: RUF005
computed_output_bias = computed_output_bias.reshape([natoms.shape[1]] + var_shape)
output_std = output_std.reshape(var_shape)

return computed_output_bias, output_std
3 changes: 2 additions & 1 deletion examples/property/train/input_torch.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
"fitting_net": {
"type": "property",
"intensive": true,
"task_dim": 3,
"property_name": "band_prop",
"property_dim": 3,
"neuron": [
240,
240,
Expand Down
31 changes: 30 additions & 1 deletion source/tests/common/dpmodel/test_output_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,20 @@ def test_model_output_def(self) -> None:
atomic=True,
intensive=True,
),
OutputVariableDef(
"property",
[6],
reducible=True,
r_differentiable=False,
c_differentiable=False,
atomic=True,
intensive=True,
sub_var_name=["foo","bar"],
),
]
# fitting definition
fd = FittingOutputDef(defs)
expected_keys = ["energy", "energy2", "energy3", "dos", "foo", "gap"]
expected_keys = ["energy", "energy2", "energy3", "dos", "foo", "gap", "property"]
self.assertEqual(
set(expected_keys),
set(fd.keys()),
Expand All @@ -101,20 +111,23 @@ def test_model_output_def(self) -> None:
self.assertEqual(fd["dos"].shape, [10])
self.assertEqual(fd["foo"].shape, [3])
self.assertEqual(fd["gap"].shape, [13])
self.assertEqual(fd["property"].shape, [6])
# atomic
self.assertEqual(fd["energy"].atomic, True)
self.assertEqual(fd["energy2"].atomic, True)
self.assertEqual(fd["energy3"].atomic, True)
self.assertEqual(fd["dos"].atomic, True)
self.assertEqual(fd["foo"].atomic, True)
self.assertEqual(fd["gap"].atomic, True)
self.assertEqual(fd["property"].atomic, True)
# reduce
self.assertEqual(fd["energy"].reducible, True)
self.assertEqual(fd["energy2"].reducible, True)
self.assertEqual(fd["energy3"].reducible, True)
self.assertEqual(fd["dos"].reducible, True)
self.assertEqual(fd["foo"].reducible, False)
self.assertEqual(fd["gap"].reducible, True)
self.assertEqual(fd["property"].reducible, True)
# derivative
self.assertEqual(fd["energy"].r_differentiable, True)
self.assertEqual(fd["energy"].c_differentiable, True)
Expand All @@ -128,16 +141,19 @@ def test_model_output_def(self) -> None:
self.assertEqual(fd["dos"].r_differentiable, False)
self.assertEqual(fd["foo"].r_differentiable, False)
self.assertEqual(fd["gap"].r_differentiable, False)
self.assertEqual(fd["property"].r_differentiable, False)
self.assertEqual(fd["dos"].c_differentiable, False)
self.assertEqual(fd["foo"].c_differentiable, False)
self.assertEqual(fd["gap"].c_differentiable, False)
self.assertEqual(fd["property"].c_differentiable, False)
# magnetic
self.assertEqual(fd["energy"].magnetic, False)
self.assertEqual(fd["energy2"].magnetic, False)
self.assertEqual(fd["energy3"].magnetic, True)
self.assertEqual(fd["dos"].magnetic, False)
self.assertEqual(fd["foo"].magnetic, False)
self.assertEqual(fd["gap"].magnetic, False)
self.assertEqual(fd["property"].magnetic, False)
# model definition
md = ModelOutputDef(fd)
expected_keys = [
Expand Down Expand Up @@ -166,6 +182,8 @@ def test_model_output_def(self) -> None:
"mask_mag",
"gap",
"gap_redu",
"property",
"property_redu",
]
self.assertEqual(
set(expected_keys),
Expand All @@ -180,6 +198,7 @@ def test_model_output_def(self) -> None:
self.assertEqual(md["dos"].reducible, True)
self.assertEqual(md["foo"].reducible, False)
self.assertEqual(md["gap"].reducible, True)
self.assertEqual(md["property"].reducible, True)
# derivative
self.assertEqual(md["energy"].r_differentiable, True)
self.assertEqual(md["energy"].c_differentiable, True)
Expand All @@ -193,9 +212,11 @@ def test_model_output_def(self) -> None:
self.assertEqual(md["dos"].r_differentiable, False)
self.assertEqual(md["foo"].r_differentiable, False)
self.assertEqual(md["gap"].r_differentiable, False)
self.assertEqual(md["property"].c_differentiable, False)
self.assertEqual(md["dos"].c_differentiable, False)
self.assertEqual(md["foo"].c_differentiable, False)
self.assertEqual(md["gap"].c_differentiable, False)
self.assertEqual(md["property"].magnetic, False)
# shape
self.assertEqual(md["mask"].shape, [1])
self.assertEqual(md["mask_mag"].shape, [1])
Expand All @@ -220,6 +241,7 @@ def test_model_output_def(self) -> None:
self.assertEqual(md["energy3_derv_c_mag"].shape, [1, 9])
self.assertEqual(md["gap"].shape, [13])
self.assertEqual(md["gap_redu"].shape, [13])
self.assertEqual(md["property"].shape, [6])
# atomic
self.assertEqual(md["energy"].atomic, True)
self.assertEqual(md["energy2"].atomic, True)
Expand All @@ -242,6 +264,8 @@ def test_model_output_def(self) -> None:
self.assertEqual(md["energy3_derv_c_redu"].atomic, False)
self.assertEqual(md["gap"].atomic, True)
self.assertEqual(md["gap_redu"].atomic, False)
self.assertEqual(md["property"].atomic, True)
self.assertEqual(md["property_redu"].atomic, False)
# category
self.assertEqual(md["mask"].category, OutputVariableCategory.OUT)
self.assertEqual(md["mask_mag"].category, OutputVariableCategory.OUT)
Expand Down Expand Up @@ -279,6 +303,8 @@ def test_model_output_def(self) -> None:
)
self.assertEqual(md["gap"].category, OutputVariableCategory.OUT)
self.assertEqual(md["gap_redu"].category, OutputVariableCategory.REDU)
self.assertEqual(md["property"].category, OutputVariableCategory.OUT)
self.assertEqual(md["property_redu"].category, OutputVariableCategory.REDU)
# flag
OVO = OutputVariableOperation
self.assertEqual(md["energy"].category & OVO.REDU, 0)
Expand All @@ -299,6 +325,9 @@ def test_model_output_def(self) -> None:
self.assertEqual(md["gap"].category & OVO.REDU, 0)
self.assertEqual(md["gap"].category & OVO.DERV_R, 0)
self.assertEqual(md["gap"].category & OVO.DERV_C, 0)
self.assertEqual(md["property"].category & OVO.REDU, 0)
self.assertEqual(md["property"].category & OVO.DERV_R, 0)
self.assertEqual(md["property"].category & OVO.DERV_C, 0)
# flag: energy
self.assertEqual(
md["energy_redu"].category & OVO.REDU,
Expand Down
7 changes: 7 additions & 0 deletions source/tests/common/test_out_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from deepmd.utils.out_stat import (
compute_stats_from_atomic,
compute_stats_from_redu,
compute_stats_property,
)


Expand Down Expand Up @@ -89,6 +90,12 @@ def test_compute_stats_from_redu_with_assigned_bias(self) -> None:
rtol=1e-7,
)

def test_compute_stats_property(self) -> None:
bias, std = compute_stats_property(self.output_redu, self.natoms)
for fake_atom_bias in bias:
np.testing.assert_allclose(fake_atom_bias, np.mean(self.output_redu,axis=0), rtol=1e-7)
np.testing.assert_allclose(std, np.std(self.output_redu,axis=0), rtol=1e-7)

def test_compute_stats_from_atomic(self) -> None:
bias, std = compute_stats_from_atomic(self.output, self.atype)
np.testing.assert_allclose(bias, self.mean)
Expand Down
4 changes: 2 additions & 2 deletions source/tests/pt/model/test_permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,8 +331,8 @@
"fitting_net": {
"type": "property",
"task_dim": 3,
"property_name": ["foo", "bar"],
"property_dim": [1, 2],
"property_name": "band_property",
"property_dim": 3,
"neuron": [24, 24, 24],
"resnet_dt": True,
"intensive": True,
Expand Down

0 comments on commit 8e9bbc5

Please sign in to comment.