pt: support multitask finetune #3480

Merged
merged 17 commits on Mar 22, 2024
122 changes: 122 additions & 0 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
@@ -1,13 +1,16 @@
# SPDX-License-Identifier: LGPL-3.0-or-later


import logging
from typing import (
Callable,
Dict,
List,
Optional,
Tuple,
)

import numpy as np
import torch

from deepmd.dpmodel.atomic_model import (
@@ -21,10 +24,21 @@
AtomExcludeMask,
PairExcludeMask,
)
from deepmd.pt.utils.nlist import (
extend_input_and_build_neighbor_list,
)
from deepmd.pt.utils.stat import (
compute_output_stats,
)
from deepmd.pt.utils.utils import (
to_numpy_array,
)
from deepmd.utils.path import (
DPPath,
)

log = logging.getLogger(__name__)

BaseAtomicModel_ = make_base_atomic_model(torch.Tensor)


@@ -176,6 +190,59 @@
"pair_exclude_types": self.pair_exclude_types,
}

def set_out_bias(self, out_bias: torch.Tensor, add=False) -> None:
"""
Modify the output bias for the atomic model.

Parameters
----------
out_bias : torch.Tensor
The new bias to be applied.
add : bool, optional
Whether to add the new bias to the existing one.
If False, the output bias will be directly replaced by the new bias.
If True, the new bias will be added to the existing one.
"""
raise NotImplementedError
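
A concrete subclass supplies the replace/add semantics described above. A minimal sketch, assuming the model keeps its bias in a tensor attribute (hypothetically named self.out_bias here):

def set_out_bias(self, out_bias: torch.Tensor, add=False) -> None:
    # add=True shifts the existing bias; add=False overwrites it
    self.out_bias = out_bias + self.out_bias if add else out_bias

The DPAtomicModel and PairTabAtomicModel changes below implement exactly this pattern against their own bias storage.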

def get_out_bias(self) -> torch.Tensor:
"""Return the output bias of the atomic model."""
raise NotImplementedError

def get_forward_wrapper_func(self) -> Callable[..., torch.Tensor]:
"""Get a forward wrapper of the atomic model for output bias calculation."""
model_output_type = list(self.atomic_output_def().keys())
if "mask" in model_output_type:
model_output_type.pop(model_output_type.index("mask"))
out_name = model_output_type[0]

def model_forward(coord, atype, box, fparam=None, aparam=None):
with torch.no_grad():  # no_grad is essential for a pure torch forward function used with auto_batchsize
(
extended_coord,
extended_atype,
mapping,
nlist,
) = extend_input_and_build_neighbor_list(
coord,
atype,
self.get_rcut(),
self.get_sel(),
mixed_types=self.mixed_types(),
box=box,
)
atomic_ret = self.forward_common_atomic(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
)
return atomic_ret[out_name].detach()

return model_forward
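
The returned closure hides the neighbor-list construction, so downstream statistics code can treat the model as a plain coordinates-to-outputs function. A hedged usage sketch (shapes follow the usual DeePMD conventions; the variable names are illustrative):

forward = atomic_model.get_forward_wrapper_func()
# coord: (nframes, natoms, 3), atype: (nframes, natoms), box: (nframes, 3, 3) or None
pred = forward(coord, atype, box)  # per-atom values of the first non-mask output, e.g. energy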

def compute_or_load_stat(
self,
sampled_func,
@@ -197,3 +264,58 @@
The path to the statistics files.
"""
raise NotImplementedError

def change_out_bias(
self, merged, origin_type_map, full_type_map, bias_shift="delta"
) -> None:
"""Change the energy bias according to the input data and the pretrained model.

Parameters
----------
merged : Union[Callable[[], List[dict]], List[dict]]
- List[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
origin_type_map : List[str]
The original type_map of the dataset; the energy bias of these types will be changed.
full_type_map : List[str]
The full type_map of the pre-trained model.
bias_shift : str
The mode for changing the energy bias: ['delta', 'statistic'].
'delta' : predict the energies of the target dataset,
and perform a least-squares fit on the errors to obtain the bias shift.
'statistic' : directly use the statistical energy bias of the target dataset.
"""
sorter = np.argsort(full_type_map)
missing_types = [t for t in origin_type_map if t not in full_type_map]
assert (
not missing_types
), f"Some types are not in the pre-trained model: {list(missing_types)} !"
idx_type_map = sorter[
np.searchsorted(full_type_map, origin_type_map, sorter=sorter)
]
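# Illustrative example: full_type_map = ["H", "C", "O"], origin_type_map = ["O", "H"]
# gives sorter = [1, 0, 2]; searchsorted over the sorted view ["C", "H", "O"]
# returns [2, 1]; idx_type_map = sorter[[2, 1]] = [2, 0], i.e. the positions of
# "O" and "H" in full_type_map, used below to slice the per-type bias tensors.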
original_bias = self.get_out_bias()
if bias_shift == "delta":
delta_bias = compute_output_stats(
merged,
self.get_ntypes(),
model_forward=self.get_forward_wrapper_func(),
)
self.set_out_bias(delta_bias, add=True)
elif bias_shift == "statistic":
bias_atom = compute_output_stats(
merged,
self.get_ntypes(),
)
self.set_out_bias(bias_atom)
else:
raise RuntimeError("Unknown bias_shift mode: " + bias_shift)

bias_atom = self.get_out_bias()
log.info(
f"Change output bias of {origin_type_map!s} "
f"from {to_numpy_array(original_bias[idx_type_map]).reshape(-1)!s} "
f"to {to_numpy_array(bias_atom[idx_type_map]).reshape(-1)!s}."
)
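
In 'delta' mode, compute_output_stats receives the forward wrapper and fits a per-type shift to the residuals between the data labels and the pre-trained model's predictions. A minimal numpy sketch of that least-squares step, with illustrative names (the actual implementation lives in deepmd.pt.utils.stat):

import numpy as np

def delta_bias_least_squares(type_count, energy_label, energy_pred):
    # type_count: (nframes, ntypes) atom counts per frame
    # energy_label, energy_pred: (nframes,) total energies
    residual = energy_label - energy_pred
    # solve type_count @ bias ~= residual in the least-squares sense
    bias, *_ = np.linalg.lstsq(type_count, residual, rcond=None)
    return bias  # (ntypes,); applied via set_out_bias(..., add=True)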
99 changes: 15 additions & 84 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
@@ -8,7 +8,6 @@
Optional,
)

import numpy as np
import torch

from deepmd.dpmodel import (
Expand All @@ -20,15 +19,6 @@
from deepmd.pt.model.task.base_fitting import (
BaseFitting,
)
from deepmd.pt.utils.nlist import (
extend_input_and_build_neighbor_list,
)
from deepmd.pt.utils.stat import (
compute_output_stats,
)
from deepmd.pt.utils.utils import (
to_numpy_array,
)
from deepmd.utils.path import (
DPPath,
)
@@ -233,85 +223,26 @@ def wrapped_sampler():
if self.fitting_net is not None:
self.fitting_net.compute_output_stats(wrapped_sampler, stat_file_path)

def change_out_bias(
self, merged, origin_type_map, full_type_map, bias_shift="delta"
) -> None:
"""Change the energy bias according to the input data and the pretrained model.
def set_out_bias(self, out_bias: torch.Tensor, add=False) -> None:
"""
Modify the output bias for the atomic model.

Parameters
----------
merged : Union[Callable[[], List[dict]], List[dict]]
- List[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
origin_type_map : List[str]
The original type_map in dataset, they are targets to change the energy bias.
full_type_map : List[str]
The full type_map in pre-trained model
bias_shift : str
The mode for changing energy bias : ['delta', 'statistic']
'delta' : perform predictions on energies of target dataset,
and do least square on the errors to obtain the target shift as bias.
'statistic' : directly use the statistic energy bias in the target dataset.
out_bias : torch.Tensor
The new bias to be applied.
add : bool, optional
Whether to add the new bias to the existing one.
If False, the output bias will be directly replaced by the new bias.
If True, the new bias will be added to the existing one.
"""
sorter = np.argsort(full_type_map)
missing_types = [t for t in origin_type_map if t not in full_type_map]
assert (
not missing_types
), f"Some types are not in the pre-trained model: {list(missing_types)} !"
idx_type_map = sorter[
np.searchsorted(full_type_map, origin_type_map, sorter=sorter)
]
original_bias = self.fitting_net["bias_atom_e"]
if bias_shift == "delta":

def model_forward(coord, atype, box, fparam=None, aparam=None):
with torch.no_grad(): # it's essential for pure torch forward function to use auto_batchsize
(
extended_coord,
extended_atype,
mapping,
nlist,
) = extend_input_and_build_neighbor_list(
coord,
atype,
self.get_rcut(),
self.get_sel(),
mixed_types=self.mixed_types(),
box=box,
)
atomic_ret = self.forward_common_atomic(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
)
return atomic_ret["energy"].detach()

delta_bias_e = compute_output_stats(
merged,
self.get_ntypes(),
model_forward=model_forward,
)
bias_atom_e = delta_bias_e + original_bias
elif bias_shift == "statistic":
bias_atom_e = compute_output_stats(
merged,
self.get_ntypes(),
)
else:
raise RuntimeError("Unknown bias_shift mode: " + bias_shift)
log.info(
f"Change energy bias of {origin_type_map!s} "
f"from {to_numpy_array(original_bias[idx_type_map]).reshape(-1)!s} "
f"to {to_numpy_array(bias_atom_e[idx_type_map]).reshape(-1)!s}."
self.fitting_net["bias_atom_e"] = (
out_bias + self.fitting_net["bias_atom_e"] if add else out_bias
)
self.fitting_net["bias_atom_e"] = bias_atom_e

def get_out_bias(self) -> torch.Tensor:
"""Return the output bias of the atomic model."""
return self.fitting_net["bias_atom_e"]
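
With both accessors in place, set_out_bias(add=True) composes with get_out_bias as a simple round trip. A small illustrative check (variable names assumed, not part of the diff):

old_bias = model.get_out_bias()        # typically shape (ntypes, 1) for energy models
model.set_out_bias(delta, add=True)    # new bias = old_bias + delta
assert torch.allclose(model.get_out_bias(), old_bias + delta)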

def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
21 changes: 21 additions & 0 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
@@ -289,6 +289,27 @@
for _ in range(nmodels)
]

def set_out_bias(self, out_bias: torch.Tensor, add=False) -> None:
"""
Modify the output bias for all the models in the linear atomic model.

Parameters
----------
out_bias : torch.Tensor
The new bias to be applied.
add : bool, optional
Whether to add the new bias to the existing one.
If False, the output bias will be directly replaced by the new bias.
If True, the new bias will be added to the existing one.
"""
for model in self.models:
model.set_out_bias(out_bias, add=add)

def get_out_bias(self) -> torch.Tensor:
"""Return the weighted output bias of the linear atomic model."""
# TODO add get_out_bias for linear atomic model
raise NotImplementedError

def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
# tricky...
23 changes: 18 additions & 5 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
@@ -234,11 +234,24 @@
torch.tensor(bias_atom_e, device=env.DEVICE).view([self.ntypes, 1])
)

def change_out_bias(
self, merged, origin_type_map, full_type_map, bias_shift="delta"
) -> None:
# need to implement
pass
def set_out_bias(self, out_bias: torch.Tensor, add=False) -> None:
"""
Modify the output bias for the atomic model.

Parameters
----------
out_bias : torch.Tensor
The new bias to be applied.
add : bool, optional
Whether to add the new bias to the existing one.
If False, the output bias will be directly replaced by the new bias.
If True, the new bias will be added to the existing one.
"""
self.bias_atom_e = out_bias + self.bias_atom_e if add else out_bias

def get_out_bias(self) -> torch.Tensor:
"""Return the output bias of the atomic model."""
return self.bias_atom_e

def forward_atomic(
self,
Expand Down