feat(pt): Add command to check the available model branches in multi-task pre-trained model (Issue #3742) (#3796)

Solve #3742
1. Situation one (the right way to use it):
`dp --pt show multitask_model.pt model-branch type-map descriptor
fitting-net`
`[2024-05-22 10:38:16,678] DEEPMD INFO    This is a multitask model`
`[2024-05-22 10:38:16,678] DEEPMD INFO Available model branches are
['MPtraj_v026_01-mix-Utype', 'MPtraj_v026_02-mix-Utype',
'MPtraj_v026_03-mix-Utype', 'MPtraj_v026_04-mix-Utype',
'MPtraj_v026_05-mix-Utype', 'MPtraj_v026_06-mix-Utype',
'MPtraj_v026_07-mix-Utype', 'MPtraj_v026_08-mix-Utype',
'MPtraj_v026_09-mix-Utype', 'MPtraj_v026_10-mix-Utype',
'MPtraj_v026_11-mix-Utype']`
`[2024-05-22 10:38:16,679] DEEPMD INFO The type_map of branch
MPtraj_v026_01-mix-Utype is ['H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O',
'F', 'Ne', 'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', 'K', 'Ca',
'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge',
'As', 'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru',
'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe', 'Cs', 'Ba',
'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er',
'Tm', 'Yb', 'Lu', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg',
'Tl', 'Pb', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Ac', 'Th', 'Pa', 'U',
'Np', 'Pu', 'Am', 'Cm', 'Bk', 'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr', 'Rf',
'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds', 'Rg', 'Cn', 'Nh', 'Fl', 'Mc', 'Lv',
'Ts', 'Og', 'Co_U', 'Cr_U', 'Fe_U', 'Mn_U', 'Mo_U', 'Ni_U', 'V_U',
'W_U']`
(skip other branches' output)
`[2024-05-22 10:38:16,679] DEEPMD INFO The descriptor parameter of
branch MPtraj_v026_04-mix-Utype is {'type': 'dpa2', 'repinit':
{'tebd_dim': 256, 'rcut': 9.0, 'rcut_smth': 8.0, 'nsel': 120, 'neuron':
[25, 50, 100], 'axis_neuron': 12, 'activation_function': 'tanh'},
'repformer': {'rcut': 4.0, 'rcut_smth': 3.5, 'nsel': 40, 'nlayers': 12,
'g1_dim': 128, 'g2_dim': 32, 'attn2_hidden': 32, 'attn2_nhead': 4,
'attn1_hidden': 128, 'attn1_nhead': 4, 'axis_neuron': 4,
'activation_function': 'tanh', 'update_h2': False, 'update_g1_has_conv':
True, 'update_g1_has_grrg': True, 'update_g1_has_drrd': True,
'update_g1_has_attn': True, 'update_g2_has_g1g1': False,
'update_g2_has_attn': True, 'update_style': 'res_residual',
'update_residual': 0.01, 'update_residual_init': 'norm',
'attn2_has_gate': True}, 'add_tebd_to_repinit_out': False}`
(skip other branches' output)
`[2024-05-22 10:38:16,679] DEEPMD INFO The fitting_net parameter of
branch MPtraj_v026_01-mix-Utype is {'neuron': [240, 240, 240],
'activation_function': 'tanh', 'resnet_dt': True, 'seed': 1, '_comment':
" that's all"}`
(skip other branches' output)

2. Situation two (`singletask_model.pt` is not a multi-task pre-trained
model)
`dp --pt show singletask_model.pt model-branch type-map descriptor
fitting-net`
`[2024-05-22 10:43:11,642] DEEPMD INFO    This is a singletask model`
`RuntimeError: The 'model-branch' option requires a multitask model. The
provided model does not meet this criterion.`

3. Situation three (using the TF backend)
`dp show multitask_model.pt model-branch`
`RuntimeError: unknown command list-model-branch`

4. Frozen model files with a .pth extension are used in the same way as
checkpoint files with a .pt extension.
`dp --pt show frozen_model.pth type-map descriptor fitting-net`
`[2024-05-22 10:46:26,365] DEEPMD INFO    This is a singletask model`
`[2024-05-22 10:46:26,365] DEEPMD INFO The type_map is ['H', 'He', 'Li',
'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', 'Mg', 'Al', 'Si', 'P', 'S',
'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni',
'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr',
'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te',
'I', 'Xe', 'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd',
'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', 'Lu', 'Hf', 'Ta', 'W', 'Re', 'Os',
'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra',
'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', 'Cf', 'Es', 'Fm',
'Md', 'No', 'Lr', 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds', 'Rg', 'Cn',
'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og', 'Co_U', 'Cr_U', 'Fe_U', 'Mn_U',
'Mo_U', 'Ni_U', 'V_U', 'W_U']`
`[2024-05-22 10:46:26,365] DEEPMD INFO The descriptor parameter is
{'type': 'dpa2', 'repinit': {'tebd_dim': 256, 'rcut': 9.0, 'rcut_smth':
8.0, 'nsel': 120, 'neuron': [25, 50, 100], 'axis_neuron': 12,
'activation_function': 'tanh'}, 'repformer': {'rcut': 4.0, 'rcut_smth':
3.5, 'nsel': 40, 'nlayers': 12, 'g1_dim': 128, 'g2_dim': 32,
'attn2_hidden': 32, 'attn2_nhead': 4, 'attn1_hidden': 128,
'attn1_nhead': 4, 'axis_neuron': 4, 'activation_function': 'tanh',
'update_h2': False, 'update_g1_has_conv': True, 'update_g1_has_grrg':
True, 'update_g1_has_drrd': True, 'update_g1_has_attn': True,
'update_g2_has_g1g1': False, 'update_g2_has_attn': True, 'update_style':
'res_residual', 'update_residual': 0.01, 'update_residual_init': 'norm',
'attn2_has_gate': True}, 'add_tebd_to_repinit_out': False}`
`[2024-05-22 10:46:26,365] DEEPMD INFO The fitting_net parameter is
{'neuron': [240, 240, 240], 'activation_function': 'tanh', 'resnet_dt':
True, 'seed': 1, '_comment': " that's all"}`
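
The behaviour in the situations above can be sketched without PyTorch as a function over the `model_params` dictionary. This is a minimal sketch, not the committed implementation: `describe_model` and the toy parameters are hypothetical names introduced for illustration, and the only assumption taken from the diff is that a multitask model stores its branches under the `model_dict` key.

```python
from typing import Any, Dict, List


def describe_model(model_params: Dict[str, Any], attributes: List[str]) -> List[str]:
    """Return the messages a `show`-like command would log (hypothetical helper)."""
    is_multi_task = "model_dict" in model_params
    messages = [
        "This is a multitask model" if is_multi_task else "This is a singletask model"
    ]
    if "model-branch" in attributes:
        if not is_multi_task:
            raise RuntimeError(
                "The 'model-branch' option requires a multitask model."
                " The provided model does not meet this criterion."
            )
        branches = list(model_params["model_dict"])
        messages.append(f"Available model branches are {branches}")
    # (command-line attribute, key in model_params, label used in the message)
    table = [
        ("type-map", "type_map", "type_map"),
        ("descriptor", "descriptor", "descriptor parameter"),
        ("fitting-net", "fitting_net", "fitting_net parameter"),
    ]
    for attribute, key, label in table:
        if attribute not in attributes:
            continue
        if is_multi_task:
            for branch, params in model_params["model_dict"].items():
                messages.append(f"The {label} of branch {branch} is {params[key]}")
        else:
            messages.append(f"The {label} is {model_params[key]}")
    return messages


# Toy parameters, invented for illustration only.
branch_1 = {
    "type_map": ["O", "H"],
    "descriptor": {"type": "se_e2_a"},
    "fitting_net": {"neuron": [1, 2, 3]},
}
branch_2 = {
    "type_map": ["O", "H"],
    "descriptor": {"type": "se_e2_a"},
    "fitting_net": {"neuron": [9, 8, 7]},
}
multi = {"model_dict": {"model_1": branch_1, "model_2": branch_2}}
single = branch_1

for line in describe_model(multi, ["model-branch", "type-map"]):
    print(line)
```

Requesting `model-branch` for `single` raises the same `RuntimeError` message as in situation two.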

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced the `show` command (for example, `dp --pt show model.pt model-branch`) for listing
model branches of a multitask pretrained model.
- Added functionality to display model information based on specified
attributes.

- **Documentation**
- Updated documentation to include the new command `dp --pt show multitask_pretrained.pt model-branch` for
checking available model branches in a multitask pre-trained model.

- **Tests**
- Added test cases for single-task and multi-task models, including
model configurations, training, and displaying model information for
checkpointed and frozen models.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Chenqqian Zhang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: Jinzhe Zeng <[email protected]>
Co-authored-by: Duo <[email protected]>
5 people authored May 29, 2024
1 parent d4cceda commit a5d8a21
Showing 5 changed files with 323 additions and 4 deletions.
24 changes: 24 additions & 0 deletions deepmd/main.py
@@ -745,6 +745,29 @@ def main_parser() -> argparse.ArgumentParser:
    )
    parser_convert_backend.add_argument("INPUT", help="The input model file.")
    parser_convert_backend.add_argument("OUTPUT", help="The output model file.")

    # * show model ******************************************************************
    parser_show = subparsers.add_parser(
        "show",
        parents=[parser_log],
        help="(Supported backend: PyTorch) Show the information of a model",
        formatter_class=RawTextArgumentDefaultsHelpFormatter,
        epilog=textwrap.dedent(
            """\
        examples:
            dp --pt show model.pt model-branch type-map descriptor fitting-net
            dp --pt show frozen_model.pth type-map descriptor fitting-net
        """
        ),
    )
    parser_show.add_argument(
        "INPUT", help="The input checkpoint file or frozen model file"
    )
    parser_show.add_argument(
        "ATTRIBUTES",
        choices=["model-branch", "type-map", "descriptor", "fitting-net"],
        nargs="+",
    )
    return parser
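
To illustrate how the `ATTRIBUTES` positional behaves, here is a standalone sketch using only the standard-library `argparse`. It is a simplification under stated assumptions: the real parser also inherits `parser_log` and handles the backend flag (`--pt`), both of which are omitted here, and `dest="command"` is chosen for the example.

```python
import argparse

# Standalone parser for illustration; not the DeePMD-kit parser itself.
parser = argparse.ArgumentParser(prog="dp")
subparsers = parser.add_subparsers(dest="command")
parser_show = subparsers.add_parser("show", help="Show the information of a model")
parser_show.add_argument(
    "INPUT", help="The input checkpoint file or frozen model file"
)
parser_show.add_argument(
    "ATTRIBUTES",
    choices=["model-branch", "type-map", "descriptor", "fitting-net"],
    nargs="+",  # one or more attributes, each validated against `choices`
)

args = parser.parse_args(["show", "model.pt", "model-branch", "type-map"])
print(args.command, args.INPUT, args.ATTRIBUTES)
```

With `nargs="+"`, at least one attribute is required, and any value outside `choices` makes `argparse` exit with a usage error before the backend entry point is reached.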


@@ -802,6 +825,7 @@ def main():
        "compress",
        "convert-from",
        "train-nvnmd",
        "show",
    ):
        deepmd_main = BACKENDS[args.backend]().entry_point_hook
    elif args.command is None:
66 changes: 66 additions & 0 deletions deepmd/pt/entrypoints/main.py
@@ -44,6 +44,9 @@
from deepmd.pt.train import (
    training,
)
from deepmd.pt.utils import (
    env,
)
from deepmd.pt.utils.dataloader import (
    DpLoaderSet,
)
@@ -297,6 +300,67 @@ def freeze(FLAGS):
    )


def show(FLAGS):
    if FLAGS.INPUT.split(".")[-1] == "pt":
        state_dict = torch.load(FLAGS.INPUT, map_location=env.DEVICE)
        if "model" in state_dict:
            state_dict = state_dict["model"]
        model_params = state_dict["_extra_state"]["model_params"]
    elif FLAGS.INPUT.split(".")[-1] == "pth":
        model_params_string = torch.jit.load(
            FLAGS.INPUT, map_location=env.DEVICE
        ).model_def_script
        model_params = json.loads(model_params_string)
    else:
        raise RuntimeError(
            "The model provided must be a checkpoint file with a .pt extension "
            "or a frozen model with a .pth extension"
        )
    model_is_multi_task = "model_dict" in model_params
    log.info("This is a multitask model") if model_is_multi_task else log.info(
        "This is a singletask model"
    )

    if "model-branch" in FLAGS.ATTRIBUTES:
        # The model must be multitask mode
        if not model_is_multi_task:
            raise RuntimeError(
                "The 'model-branch' option requires a multitask model."
                " The provided model does not meet this criterion."
            )
        model_branches = list(model_params["model_dict"].keys())
        log.info(f"Available model branches are {model_branches}")
    if "type-map" in FLAGS.ATTRIBUTES:
        if model_is_multi_task:
            model_branches = list(model_params["model_dict"].keys())
            for branch in model_branches:
                type_map = model_params["model_dict"][branch]["type_map"]
                log.info(f"The type_map of branch {branch} is {type_map}")
        else:
            type_map = model_params["type_map"]
            log.info(f"The type_map is {type_map}")
    if "descriptor" in FLAGS.ATTRIBUTES:
        if model_is_multi_task:
            model_branches = list(model_params["model_dict"].keys())
            for branch in model_branches:
                descriptor = model_params["model_dict"][branch]["descriptor"]
                log.info(f"The descriptor parameter of branch {branch} is {descriptor}")
        else:
            descriptor = model_params["descriptor"]
            log.info(f"The descriptor parameter is {descriptor}")
    if "fitting-net" in FLAGS.ATTRIBUTES:
        if model_is_multi_task:
            model_branches = list(model_params["model_dict"].keys())
            for branch in model_branches:
                fitting_net = model_params["model_dict"][branch]["fitting_net"]
                log.info(
                    f"The fitting_net parameter of branch {branch} is {fitting_net}"
                )
        else:
            fitting_net = model_params["fitting_net"]
            log.info(f"The fitting_net parameter is {fitting_net}")


@record
def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
    if not isinstance(args, argparse.Namespace):
@@ -319,6 +383,8 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
            FLAGS.model = FLAGS.checkpoint_folder
        FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth"))
        freeze(FLAGS)
    elif FLAGS.command == "show":
        show(FLAGS)
    else:
        raise RuntimeError(f"Invalid command {FLAGS.command}!")
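
The `show` function in this file picks the loading path from the file extension via `FLAGS.INPUT.split(".")[-1]`. As a hedged alternative sketch (not part of this commit), the same dispatch could be written with `pathlib`; `model_file_kind` is a hypothetical helper name, and the returned labels are invented for the example.

```python
from pathlib import Path


def model_file_kind(path: str) -> str:
    """Classify a model file by its extension (hypothetical helper)."""
    suffix = Path(path).suffix  # final extension including the dot, e.g. ".pt"
    if suffix == ".pt":
        return "checkpoint"
    if suffix == ".pth":
        return "frozen"
    raise RuntimeError(
        "The model provided must be a checkpoint file with a .pt extension "
        "or a frozen model with a .pth extension"
    )


print(model_file_kind("model.ckpt.pt"))
print(model_file_kind("frozen_model.pth"))
```

One small difference: a file named exactly `pt` (no dot) is rejected here, whereas `"pt".split(".")[-1] == "pt"` would treat it as a checkpoint.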

9 changes: 5 additions & 4 deletions doc/train/finetuning.md
@@ -99,11 +99,12 @@ $ dp --pt train input.json --finetune multitask_pretrained.pt --model-branch CHO
 ```
 
 :::{note}
-To check the available model branches, you can typically refer to the documentation of the pre-trained model.
-If you're still unsure about the available branches, you can try inputting an arbitrary branch name.
-This will prompt an error message that displays a list of all the available model branches.
+One can check the available model branches in a multi-task pre-trained model by referring to the documentation of the pre-trained model or by using the following command:
+
+```bash
+$ dp --pt show multitask_pretrained.pt model-branch
+```
 
-Please note that this feature will be improved in the upcoming version to provide a more user-friendly experience.
 :::
 
 This command will start fine-tuning based on the pre-trained model's descriptor and the selected branch's fitting net.
29 changes: 29 additions & 0 deletions source/tests/pt/common.py
@@ -0,0 +1,29 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.pt.entrypoints.main import (
    main,
)


def run_dp(cmd: str) -> int:
    """Run DP directly from the entry point instead of the subprocess.

    It is quite slow to start DeePMD-kit with subprocess.

    Parameters
    ----------
    cmd : str
        The command to run.

    Returns
    -------
    int
        Always returns 0.
    """
    cmds = cmd.split()
    if cmds[0] == "dp":
        cmds = cmds[1:]
    else:
        raise RuntimeError("The command is not dp")

    main(cmds)
    return 0
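
`run_dp` splits the command on whitespace, which is sufficient for the commands used in these tests. If a test ever needed a quoted argument (for example, a path containing a space), a variant based on the standard-library `shlex.split` could be used. The sketch below is an assumption-labelled alternative, not part of this commit; `split_dp_command` is a hypothetical name, and it only performs the splitting step, without calling `main`.

```python
import shlex
from typing import List


def split_dp_command(cmd: str) -> List[str]:
    """Split a `dp ...` command line into arguments, honouring shell quoting."""
    cmds = shlex.split(cmd)
    # Also guards against an empty command, which would otherwise fail on cmds[0].
    if not cmds or cmds[0] != "dp":
        raise RuntimeError("The command is not dp")
    return cmds[1:]


print(split_dp_command('dp --pt show "my model.pt" type-map'))
# ['--pt', 'show', 'my model.pt', 'type-map']
```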
199 changes: 199 additions & 0 deletions source/tests/pt/test_dp_show.py
@@ -0,0 +1,199 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import io
import json
import os
import shutil
import unittest
from contextlib import (
    redirect_stderr,
)
from copy import (
    deepcopy,
)
from pathlib import (
    Path,
)

from deepmd.pt.entrypoints.main import (
    get_trainer,
)
from deepmd.pt.utils.multi_task import (
    preprocess_shared_params,
)

from .common import (
    run_dp,
)
from .model.test_permutation import (
    model_se_e2_a,
)


class TestSingleTaskModel(unittest.TestCase):
    def setUp(self):
        input_json = str(Path(__file__).parent / "water/se_atten.json")
        with open(input_json) as f:
            self.config = json.load(f)
        self.config["training"]["numb_steps"] = 1
        self.config["training"]["save_freq"] = 1
        data_file = [str(Path(__file__).parent / "water/data/single")]
        self.config["training"]["training_data"]["systems"] = data_file
        self.config["training"]["validation_data"]["systems"] = data_file
        self.config["model"] = deepcopy(model_se_e2_a)
        self.config["model"]["type_map"] = ["O", "H", "Au"]
        trainer = get_trainer(deepcopy(self.config))
        trainer.run()
        run_dp("dp --pt freeze")

    def test_checkpoint(self):
        INPUT = "model.pt"
        ATTRIBUTES = "type-map descriptor fitting-net"
        with redirect_stderr(io.StringIO()) as f:
            run_dp(f"dp --pt show {INPUT} {ATTRIBUTES}")
        results = f.getvalue().split("\n")[:-1]
        assert "This is a singletask model" in results[-4]
        assert "The type_map is ['O', 'H', 'Au']" in results[-3]
        assert (
            "{'type': 'se_e2_a'" and "'sel': [46, 92, 4]" and "'rcut': 4.0"
        ) in results[-2]
        assert (
            "The fitting_net parameter is {'neuron': [24, 24, 24], 'resnet_dt': True, 'seed': 1}"
            in results[-1]
        )

    def test_frozen_model(self):
        INPUT = "frozen_model.pth"
        ATTRIBUTES = "type-map descriptor fitting-net"
        with redirect_stderr(io.StringIO()) as f:
            run_dp(f"dp --pt show {INPUT} {ATTRIBUTES}")
        results = f.getvalue().split("\n")[:-1]
        assert "This is a singletask model" in results[-4]
        assert "The type_map is ['O', 'H', 'Au']" in results[-3]
        assert (
            "{'type': 'se_e2_a'" and "'sel': [46, 92, 4]" and "'rcut': 4.0"
        ) in results[-2]
        assert (
            "The fitting_net parameter is {'neuron': [24, 24, 24], 'resnet_dt': True, 'seed': 1}"
            in results[-1]
        )

    def test_checkpoint_error(self):
        INPUT = "model.pt"
        ATTRIBUTES = "model-branch type-map descriptor fitting-net"
        with self.assertRaisesRegex(
            RuntimeError, "The 'model-branch' option requires a multitask model"
        ):
            run_dp(f"dp --pt show {INPUT} {ATTRIBUTES}")

    def tearDown(self):
        for f in os.listdir("."):
            if f.startswith("model") and f.endswith("pt"):
                os.remove(f)
            if f in ["lcurve.out", "frozen_model.pth", "output.txt", "checkpoint"]:
                os.remove(f)
            if f in ["stat_files"]:
                shutil.rmtree(f)


class TestMultiTaskModel(unittest.TestCase):
    def setUp(self):
        input_json = str(Path(__file__).parent / "water/multitask.json")
        with open(input_json) as f:
            self.config = json.load(f)
        self.config["model"]["shared_dict"]["my_descriptor"] = model_se_e2_a[
            "descriptor"
        ]
        data_file = [str(Path(__file__).parent / "water/data/data_0")]
        self.stat_files = "se_e2_a"
        os.makedirs(self.stat_files, exist_ok=True)
        self.config["training"]["data_dict"]["model_1"]["training_data"]["systems"] = (
            data_file
        )
        self.config["training"]["data_dict"]["model_1"]["validation_data"][
            "systems"
        ] = data_file
        self.config["training"]["data_dict"]["model_1"]["stat_file"] = (
            f"{self.stat_files}/model_1"
        )
        self.config["training"]["data_dict"]["model_2"]["training_data"]["systems"] = (
            data_file
        )
        self.config["training"]["data_dict"]["model_2"]["validation_data"][
            "systems"
        ] = data_file
        self.config["training"]["data_dict"]["model_2"]["stat_file"] = (
            f"{self.stat_files}/model_2"
        )
        self.config["model"]["model_dict"]["model_1"]["fitting_net"] = {
            "neuron": [1, 2, 3],
            "seed": 678,
        }
        self.config["model"]["model_dict"]["model_2"]["fitting_net"] = {
            "neuron": [9, 8, 7],
            "seed": 1111,
        }
        self.config["training"]["numb_steps"] = 1
        self.config["training"]["save_freq"] = 1
        self.origin_config = deepcopy(self.config)
        self.config["model"], self.shared_links = preprocess_shared_params(
            self.config["model"]
        )
        trainer = get_trainer(deepcopy(self.config), shared_links=self.shared_links)
        trainer.run()
        run_dp("dp --pt freeze --head model_1")

    def test_checkpoint(self):
        INPUT = "model.ckpt.pt"
        ATTRIBUTES = "model-branch type-map descriptor fitting-net"
        with redirect_stderr(io.StringIO()) as f:
            run_dp(f"dp --pt show {INPUT} {ATTRIBUTES}")
        results = f.getvalue().split("\n")[:-1]
        assert "This is a multitask model" in results[-8]
        assert "Available model branches are ['model_1', 'model_2']" in results[-7]
        assert "The type_map of branch model_1 is ['O', 'H', 'B']" in results[-6]
        assert "The type_map of branch model_2 is ['O', 'H', 'B']" in results[-5]
        assert (
            "model_1"
            and "'type': 'se_e2_a'"
            and "'sel': [46, 92, 4]"
            and "'rcut_smth': 0.5"
        ) in results[-4]
        assert (
            "model_2"
            and "'type': 'se_e2_a'"
            and "'sel': [46, 92, 4]"
            and "'rcut_smth': 0.5"
        ) in results[-3]
        assert (
            "The fitting_net parameter of branch model_1 is {'neuron': [1, 2, 3], 'seed': 678}"
            in results[-2]
        )
        assert (
            "The fitting_net parameter of branch model_2 is {'neuron': [9, 8, 7], 'seed': 1111}"
            in results[-1]
        )

    def test_frozen_model(self):
        INPUT = "frozen_model.pth"
        ATTRIBUTES = "type-map descriptor fitting-net"
        with redirect_stderr(io.StringIO()) as f:
            run_dp(f"dp --pt show {INPUT} {ATTRIBUTES}")
        results = f.getvalue().split("\n")[:-1]
        assert "This is a singletask model" in results[-4]
        assert "The type_map is ['O', 'H', 'B']" in results[-3]
        assert (
            "'type': 'se_e2_a'" and "'sel': [46, 92, 4]" and "'rcut_smth': 0.5"
        ) in results[-2]
        assert (
            "The fitting_net parameter is {'neuron': [1, 2, 3], 'seed': 678}"
            in results[-1]
        )

    def tearDown(self):
        for f in os.listdir("."):
            if f.startswith("model") and f.endswith("pt"):
                os.remove(f)
            if f in ["lcurve.out", "frozen_model.pth", "checkpoint", "output.txt"]:
                os.remove(f)
            if f in ["stat_files", self.stat_files]:
                shutil.rmtree(f)
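
A note on the assertions of the form `("a" and "b" and "c") in results[-2]` used in these tests: in Python, `and` between non-empty strings evaluates to the last operand, so only the final fragment is actually checked against the log line. The sketch below illustrates this with an invented log line and shows an `all(...)` form that checks every fragment. It is illustration only, not part of this commit, and whether the stricter form would pass against the real `dp show` output has not been verified here.

```python
# Invented log line for illustration only.
line = "The descriptor parameter is {'type': 'se_e2_a', 'rcut': 6.0}"

# `a and b` on two non-empty strings evaluates to `b`,
# so this tests only "'rcut': 6.0" for membership.
chained = ("'type': 'dpa2'" and "'rcut': 6.0") in line
print(chained)  # True, although "'type': 'dpa2'" does not occur in `line`

# Checking every fragment explicitly.
expected = ["'type': 'dpa2'", "'rcut': 6.0"]
strict = all(fragment in line for fragment in expected)
print(strict)  # False
```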
