pd: support dpa2 #4418

Merged
merged 81 commits into from
Dec 25, 2024
Changes from 77 commits
48f77f3
add core modules of paddle backend and water/se_e2_a example
HydrogenSulfate Nov 2, 2024
2082a59
add paddle code in consistent test
HydrogenSulfate Nov 2, 2024
2ae45b8
clean env and training
HydrogenSulfate Nov 2, 2024
7f03a04
add more test files
HydrogenSulfate Nov 2, 2024
4d1c44c
Merge branch 'devel' into add_paddle_backend_core_and_water_se_e2_a
HydrogenSulfate Nov 2, 2024
72c9b4e
fix pt->pd
HydrogenSulfate Nov 2, 2024
3b1c348
update test_python.yml
HydrogenSulfate Nov 2, 2024
a46dcb5
restore .pre-commit-config.yaml
HydrogenSulfate Nov 3, 2024
90f9ff9
remove redundant file
HydrogenSulfate Nov 3, 2024
0a6baa6
Skip bfloat16 for some cases
HydrogenSulfate Nov 3, 2024
4b77e55
enable prim by default in unitest
HydrogenSulfate Nov 3, 2024
6e139a2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 3, 2024
9437957
fix env code
HydrogenSulfate Nov 5, 2024
f1d762f
Merge branch 'devel' into add_paddle_backend_core_and_water_se_e2_a
HydrogenSulfate Nov 5, 2024
8534597
Merge branch 'devel' into add_paddle_backend_core_and_water_se_e2_a
HydrogenSulfate Nov 6, 2024
c22b45d
update test_ener.py
HydrogenSulfate Nov 6, 2024
39842ff
add missing pd_class
HydrogenSulfate Nov 6, 2024
07cd98e
use paddle Tensor instead of numpy array in pd/test_auto_batch_size.p…
HydrogenSulfate Nov 6, 2024
bb2d547
add training test and remove ase_calc.py
HydrogenSulfate Nov 7, 2024
5fb6d8e
add training test and remove ase_calc.py
HydrogenSulfate Nov 7, 2024
91066f8
Merge branch 'devel' into add_paddle_backend_core_and_water_se_e2_a
HydrogenSulfate Nov 7, 2024
90c9c03
upload missing json
HydrogenSulfate Nov 7, 2024
eb7384e
restore pt/test_auto_batch_size.py
HydrogenSulfate Nov 7, 2024
9faf54f
rerun CI for network problem
HydrogenSulfate Nov 7, 2024
4e3a121
add multitask unitest
HydrogenSulfate Nov 7, 2024
18333ab
add more unitest
HydrogenSulfate Nov 7, 2024
f9c6da8
Merge branch 'devel' into add_paddle_backend_core_and_water_se_e2_a
HydrogenSulfate Nov 7, 2024
3fd979d
remove redundant file and fix typo
HydrogenSulfate Nov 7, 2024
5922e84
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 7, 2024
f5a17a9
update unitest
HydrogenSulfate Nov 8, 2024
8bea1bf
delete record
HydrogenSulfate Nov 8, 2024
8a7875f
remove more unused code and files
HydrogenSulfate Nov 8, 2024
df9f887
Merge branch 'devel' into add_paddle_backend_core_and_water_se_e2_a
HydrogenSulfate Nov 8, 2024
67b81e1
Merge branch 'devel' into add_paddle_backend_core_and_water_se_e2_a
HydrogenSulfate Nov 8, 2024
b0bf733
Merge branch 'add_paddle_backend_core_and_water_se_e2_a' of https://g…
HydrogenSulfate Nov 8, 2024
71a3c0a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 8, 2024
ede5047
remove redundant annotations
HydrogenSulfate Nov 8, 2024
d11bf4d
Merge branch 'devel' into add_paddle_backend_core_and_water_se_e2_a
HydrogenSulfate Nov 8, 2024
b7a8cec
add nvtx profiler code in training, which is more accurate and detailed
HydrogenSulfate Nov 8, 2024
7567cf8
Merge branch 'devel' into add_paddle_backend_core_and_water_se_e2_a
HydrogenSulfate Nov 8, 2024
416fec8
update code as devel and fix typo
HydrogenSulfate Nov 9, 2024
1c0161c
fix pth -> json
HydrogenSulfate Nov 9, 2024
02a6f84
Merge branch 'devel' into add_paddle_backend_core_and_water_se_e2_a
HydrogenSulfate Nov 9, 2024
3354e5c
update unitest and training
HydrogenSulfate Nov 9, 2024
0d3f8cf
install paddle when test_cuda
HydrogenSulfate Nov 9, 2024
18215ff
Merge branch 'devel' into add_paddle_backend_core_and_water_se_e2_a
HydrogenSulfate Nov 9, 2024
859b94d
fix unitest
HydrogenSulfate Nov 9, 2024
74ee1c2
add eta in logging message for convenient
HydrogenSulfate Nov 9, 2024
f176309
remove hybrid code and enable one unitest
HydrogenSulfate Nov 9, 2024
4935e7b
Merge branch 'devel' into add_paddle_backend_core_and_water_se_e2_a
HydrogenSulfate Nov 9, 2024
fac51d3
add pd/__init__.py
HydrogenSulfate Nov 11, 2024
d3ca1f0
Merge branch 'devel' into add_paddle_backend_core_and_water_se_e2_a
HydrogenSulfate Nov 11, 2024
db1cd76
fix enable_prim
HydrogenSulfate Nov 11, 2024
36512fd
remove unused layernorm
HydrogenSulfate Nov 11, 2024
351bf7a
update dpa1 code
HydrogenSulfate Nov 25, 2024
01d4179
Merge branch 'devel' into add_dpa1
HydrogenSulfate Nov 25, 2024
701926a
update dpa2 code
HydrogenSulfate Nov 25, 2024
bc1cb38
update code of dpa1
HydrogenSulfate Nov 28, 2024
b4bc9db
Merge branch 'devel' into add_dpa1
HydrogenSulfate Nov 28, 2024
c944b82
restore decomp to paddle function
HydrogenSulfate Nov 28, 2024
4c925f9
remove redundant files
HydrogenSulfate Nov 28, 2024
dd3191a
update unitest and codes
HydrogenSulfate Nov 29, 2024
7df0e2f
fix
HydrogenSulfate Nov 29, 2024
ac479ed
update code
HydrogenSulfate Nov 29, 2024
3e64196
Merge branch 'devel' into add_dpa1
HydrogenSulfate Nov 29, 2024
56e079c
update typos
HydrogenSulfate Nov 29, 2024
3d70e7c
update code
HydrogenSulfate Nov 29, 2024
8b5b4a8
fix coverage
HydrogenSulfate Nov 29, 2024
e74d272
update consistent check of dpa1
HydrogenSulfate Nov 30, 2024
1a79a06
Merge branch 'add_dpa1' into add_dpa2_v2
HydrogenSulfate Dec 1, 2024
4520e97
add unittests code of dpa2 and replace several decomp. API to paddle.…
HydrogenSulfate Dec 1, 2024
79ed0f0
fix typos
HydrogenSulfate Dec 1, 2024
ed80c6d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 2, 2024
89d82f9
update UT code in test_se_t_tebd
HydrogenSulfate Dec 2, 2024
984b242
Merge branch 'add_dpa2_v22' into add_dpa2_v2
HydrogenSulfate Dec 2, 2024
617a258
update __init__
HydrogenSulfate Dec 2, 2024
6e5ebb3
solve code QL
HydrogenSulfate Dec 3, 2024
3cfb90f
Merge branch 'devel' into add_dpa2_v2
HydrogenSulfate Dec 18, 2024
575726a
fix unitest and typo
HydrogenSulfate Dec 19, 2024
befc111
Merge branch 'devel' into add_dpa2_v2
HydrogenSulfate Dec 23, 2024
abeae6d
update reference in dpa2.py
HydrogenSulfate Dec 23, 2024
18 changes: 13 additions & 5 deletions deepmd/pd/entrypoints/main.py
@@ -92,8 +92,8 @@
# Initialize DDP
world_size = dist.get_world_size()
if world_size > 1:
assert paddle.version.nccl() != "0"
fleet.init(is_collective=True)


def prepare_trainer_input_single(
model_params_single, data_dict_single, rank=0, seed=None
@@ -111,15 +111,15 @@
# stat files
stat_file_path_single = data_dict_single.get("stat_file", None)
if rank != 0:
stat_file_path_single = None

elif stat_file_path_single is not None:
if not Path(stat_file_path_single).exists():
if stat_file_path_single.endswith((".h5", ".hdf5")):
with h5py.File(stat_file_path_single, "w") as f:
pass

else:
Path(stat_file_path_single).mkdir()
stat_file_path_single = DPPath(stat_file_path_single, "a")


# validation and training data
# avoid the same batch sequence among devices
@@ -160,9 +160,9 @@
seed=data_seed,
)
else:
train_data, validation_data, stat_file_path = {}, {}, {}
for model_key in config["model"]["model_dict"]:
(

train_data[model_key],
validation_data[model_key],
stat_file_path[model_key],
@@ -194,24 +194,24 @@

def is_built_with_cuda(self) -> bool:
"""Check if the backend is built with CUDA."""
return paddle.device.is_compiled_with_cuda()


def is_built_with_rocm(self) -> bool:
"""Check if the backend is built with ROCm."""
return paddle.device.is_compiled_with_rocm()


def get_compute_device(self) -> str:
"""Get Compute device."""
return str(DEVICE)


def get_ngpus(self) -> int:
"""Get the number of GPUs."""
return paddle.device.cuda.device_count()


def get_backend_info(self) -> dict:
"""Get backend information."""
op_info = {}
return {

"Backend": "Paddle",
"PD ver": f"v{paddle.__version__}-g{paddle.version.commit[:11]}",
"Enable custom OP": False,
@@ -230,85 +230,85 @@
use_pretrain_script: bool = False,
force_load: bool = False,
output: str = "out.json",
- ):
+ ) -> None:
log.info("Configuration path: %s", input_file)
SummaryPrinter()()
with open(input_file) as fin:
config = json.load(fin)

# ensure suffix, as in the command line help, we say "path prefix of checkpoint files"
if init_model is not None and not init_model.endswith(".pd"):
init_model += ".pd"
if restart is not None and not restart.endswith(".pd"):
restart += ".pd"


# update multitask config
multi_task = "model_dict" in config["model"]
shared_links = None
if multi_task:
config["model"], shared_links = preprocess_shared_params(config["model"])

# handle the special key
assert (

"RANDOM" not in config["model"]["model_dict"]
), "Model name can not be 'RANDOM' in multi-task mode!"

# update fine-tuning config
finetune_links = None
if finetune is not None:
config["model"], finetune_links = get_finetune_rules(

finetune,
config["model"],
model_branch=model_branch,
change_model_params=use_pretrain_script,
)
# update init_model or init_frz_model config if necessary
if (init_model is not None or init_frz_model is not None) and use_pretrain_script:
if init_model is not None:
init_state_dict = paddle.load(init_model)
if "model" in init_state_dict:
init_state_dict = init_state_dict["model"]
config["model"] = init_state_dict["_extra_state"]["model_params"]

else:
raise NotImplementedError("init_frz_model is not supported yet")


# argcheck
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
config = normalize(config, multi_task=multi_task)


# do neighbor stat
min_nbor_dist = None
if not skip_neighbor_stat:
log.info(

"Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)"
)

if not multi_task:
type_map = config["model"].get("type_map")
train_data = get_data(

config["training"]["training_data"], 0, type_map, None
)
config["model"], min_nbor_dist = BaseModel.update_sel(

train_data, type_map, config["model"]
)
else:
min_nbor_dist = {}
for model_item in config["model"]["model_dict"]:
type_map = config["model"]["model_dict"][model_item].get("type_map")
train_data = get_data(

config["training"]["data_dict"][model_item]["training_data"],
0,
type_map,
None,
)
config["model"]["model_dict"][model_item], min_nbor_dist[model_item] = (

BaseModel.update_sel(
train_data, type_map, config["model"]["model_dict"][model_item]
)
)

with open(output, "w") as fp:
json.dump(config, fp, indent=4)


trainer = get_trainer(

config,
init_model,
restart,
@@ -319,41 +319,49 @@
finetune_links=finetune_links,
)
# save min_nbor_dist
if min_nbor_dist is not None:
if not multi_task:
- trainer.model.min_nbor_dist = min_nbor_dist
+ trainer.model.min_nbor_dist = paddle.to_tensor(
+ min_nbor_dist,
+ dtype=paddle.float64,
+ place=DEVICE,
+ )
else:
for model_item in min_nbor_dist:
- trainer.model[model_item].min_nbor_dist = min_nbor_dist[model_item]
+ trainer.model[model_item].min_nbor_dist = paddle.to_tensor(
+ min_nbor_dist[model_item],
+ dtype=paddle.float64,
+ place=DEVICE,
+ )
trainer.run()



def freeze(
model: str,
output: str = "frozen_model.json",
head: Optional[str] = None,
- ):
+ ) -> None:
paddle.set_flags(

{
"FLAGS_save_cf_stack_op": 1,
"FLAGS_prim_enable_dynamic": 1,
"FLAGS_enable_pir_api": 1,
}
)
model = inference.Tester(model, head=head).model
model.eval()
from paddle.static import (

InputSpec,
)

"""

** coord [None, natoms, 3] paddle.float64
** atype [None, natoms] paddle.int64
** nlist [None, natoms, nnei] paddle.int32
"""
# NOTE: 'FLAGS_save_cf_stack_op', 'FLAGS_prim_enable_dynamic' and
# 'FLAGS_enable_pir_api' should be enabled when freezing the model.
jit_model = paddle.jit.to_static(

model.forward_lower,
full_graph=True,
input_spec=[
@@ -362,14 +370,14 @@
InputSpec([-1, -1, -1], dtype="int32", name="nlist"),
],
)
if output.endswith(".json"):
output = output[:-5]
paddle.jit.save(

jit_model,
path=output,
skip_prune_program=True,
)
log.info(

f"Paddle inference model has been exported to: {output}.json and {output}.pdiparams"
)

@@ -383,27 +391,27 @@
numb_batch: int = 0,
model_branch: Optional[str] = None,
output: Optional[str] = None,
- ):
+ ) -> None:
if input_file.endswith(".pd"):
old_state_dict = paddle.load(input_file)
model_state_dict = copy.deepcopy(old_state_dict.get("model", old_state_dict))
model_params = model_state_dict["_extra_state"]["model_params"]
else:
raise RuntimeError(

"Paddle does not yet support changing the bias directly from a frozen model file. "
"Please provide a checkpoint file with a .pd extension."
)
multi_task = "model_dict" in model_params
bias_adjust_mode = "change-by-statistic" if mode == "change" else "set-by-statistic"
if multi_task:
assert (

model_branch is not None
), "For multitask model, the model branch must be set!"
assert model_branch in model_params["model_dict"], (

f"For multitask model, the model branch must be in the 'model_dict'! "
f"Available options are : {list(model_params['model_dict'].keys())}."
)
log.info(f"Changing out bias for model {model_branch}.")

model = training.get_model_for_wrapper(model_params)
type_map = (
model_params["type_map"]
@@ -415,7 +423,7 @@
wrapper = ModelWrapper(model)
wrapper.set_state_dict(old_state_dict["model"])
else:
raise NotImplementedError("Only support .pd file")


if bias_value is not None:
# use user-defined bias
@@ -468,7 +476,7 @@
if not multi_task:
model = updated_model
else:
model[model_branch] = updated_model


if input_file.endswith(".pd"):
output_path = (
@@ -479,18 +487,18 @@
old_state_dict["model"] = wrapper.state_dict()
old_state_dict["model"]["_extra_state"] = model_state_dict["_extra_state"]
else:
old_state_dict = wrapper.state_dict()
old_state_dict["_extra_state"] = model_state_dict["_extra_state"]

paddle.save(old_state_dict, output_path)
else:
raise NotImplementedError("Only support .pd file now")


log.info(f"Saved model to {output_path}")


def main(args: Optional[Union[list[str], argparse.Namespace]] = None):
if not isinstance(args, argparse.Namespace):
FLAGS = parse_args(args=args)

else:
FLAGS = args

@@ -503,7 +511,7 @@
log.info("DeePMD version: %s", __version__)

if FLAGS.command == "train":
train(

input_file=FLAGS.INPUT,
init_model=FLAGS.init_model,
restart=FLAGS.restart,
@@ -516,14 +524,14 @@
output=FLAGS.output,
)
elif FLAGS.command == "freeze":
if Path(FLAGS.checkpoint_folder).is_dir():
checkpoint_path = Path(FLAGS.checkpoint_folder)
latest_ckpt_file = (checkpoint_path / "checkpoint").read_text()
FLAGS.model = str(checkpoint_path.joinpath(latest_ckpt_file))

else:
FLAGS.model = FLAGS.checkpoint_folder
FLAGS.output = str(Path(FLAGS.output).with_suffix(".json"))
freeze(model=FLAGS.model, output=FLAGS.output, head=FLAGS.head)

elif FLAGS.command == "change-bias":
change_bias(
input_file=FLAGS.INPUT,
@@ -536,8 +544,8 @@
output=FLAGS.output,
)
else:
raise RuntimeError(f"Invalid command {FLAGS.command}!")



if __name__ == "__main__":
main()
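
The path handling spread across `train` and the `freeze` branch of `main` can be summarized in a small standalone sketch. It uses only `pathlib`. The helper names `ensure_pd_suffix`, `resolve_checkpoint`, and `frozen_output_stem` are hypothetical and do not exist in `deepmd`; they mirror, respectively, the `.pd` suffix normalization applied to `init_model` and `restart`, the lookup of the latest checkpoint through the `checkpoint` text file, and the removal of `.json` before the path is handed to `paddle.jit.save`.

```python
from pathlib import Path
from typing import Optional


def ensure_pd_suffix(path: Optional[str]) -> Optional[str]:
    """Append ".pd" when a checkpoint prefix is given without the suffix."""
    if path is not None and not path.endswith(".pd"):
        path += ".pd"
    return path


def resolve_checkpoint(checkpoint_folder: str) -> str:
    """Return the model path used by the freeze command.

    If the argument is a directory, read the file name stored in its
    "checkpoint" text file; otherwise treat the argument as the model path.
    """
    folder = Path(checkpoint_folder)
    if folder.is_dir():
        latest = (folder / "checkpoint").read_text()
        return str(folder.joinpath(latest))
    return checkpoint_folder


def frozen_output_stem(output: str) -> str:
    """Force a ".json" suffix, then strip it to get the prefix passed to the saver."""
    output = str(Path(output).with_suffix(".json"))
    return output[:-5] if output.endswith(".json") else output
```

Note that `read_text()` returns the file content verbatim, so a trailing newline in `checkpoint` would end up in the resolved path; the sketch keeps that behavior to match the diff.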

6 changes: 1 addition & 5 deletions deepmd/pd/loss/ener.py
@@ -10,7 +10,6 @@
TaskLoss,
)
from deepmd.pd.utils import (
- decomp,
env,
)
from deepmd.pd.utils.env import (
@@ -114,7 +113,7 @@
self.enable_atom_ener_coeff = enable_atom_ener_coeff
self.numb_generalized_coord = numb_generalized_coord
if self.has_gf and self.numb_generalized_coord < 1:
raise RuntimeError(

"When generalized force loss is used, the dimension of generalized coordinates should be larger than 0"
)
self.use_l1_all = use_l1_all
@@ -188,13 +187,13 @@
)
# more_loss['log_keys'].append('rmse_e')
else: # use l1 and for all atoms
l1_ener_loss = F.l1_loss(

energy_pred.reshape([-1]),
energy_label.reshape([-1]),
reduction="sum",
)
loss += pref_e * l1_ener_loss
more_loss["mae_e"] = self.display_if_exist(

F.l1_loss(
energy_pred.reshape([-1]),
energy_label.reshape([-1]),
@@ -204,10 +203,10 @@
)
# more_loss['log_keys'].append('rmse_e')
if mae:
mae_e = paddle.mean(paddle.abs(energy_pred - energy_label)) * atom_norm
more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy)
mae_e_all = paddle.mean(paddle.abs(energy_pred - energy_label))
more_loss["mae_e_all"] = self.display_if_exist(

mae_e_all.detach(), find_energy
)

@@ -224,10 +223,7 @@

if self.relative_f is not None:
force_label_3 = force_label.reshape([-1, 3])
- # norm_f = force_label_3.norm(axis=1, keepdim=True) + self.relative_f
- norm_f = (
- decomp.norm(force_label_3, axis=1, keepdim=True) + self.relative_f
- )
+ norm_f = force_label_3.norm(axis=1, keepdim=True) + self.relative_f
diff_f_3 = diff_f.reshape([-1, 3])
diff_f_3 = diff_f_3 / norm_f
diff_f = diff_f_3.reshape([-1])
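
The `relative_f` branch above divides each atom's force error by the norm of its label force plus a constant offset. A minimal pure-Python sketch of that normalization follows. The function name `relative_force_diff` and the list-of-triples layout are assumptions made for illustration; the real code works on `paddle.Tensor` objects through `force_label_3.norm(axis=1, keepdim=True)`. The sign convention of the difference (label minus prediction) is also an assumption, since the definition of `diff_f` is not visible in this hunk; it does not affect squared or absolute losses.

```python
import math
from typing import List, Sequence


def relative_force_diff(
    force_pred: Sequence[Sequence[float]],
    force_label: Sequence[Sequence[float]],
    relative_f: float,
) -> List[List[float]]:
    """Scale each atom's force error by (|f_label| + relative_f)."""
    scaled = []
    for pred, label in zip(force_pred, force_label):
        # Per-atom L2 norm of the label force (the norm over axis=1 in the diff),
        # shifted by relative_f so that near-zero forces do not blow up the ratio.
        norm_f = math.sqrt(sum(c * c for c in label)) + relative_f
        scaled.append([(l - p) / norm_f for p, l in zip(pred, label)])
    return scaled


# One atom with label force (3, 4, 0): norm 5, plus relative_f = 5 gives 10.
result = relative_force_diff([[0.0, 0.0, 0.0]], [[3.0, 4.0, 0.0]], relative_f=5.0)
```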
@@ -245,15 +241,15 @@
rmse_f.detach(), find_force
)
else:
l1_force_loss = F.l1_loss(force_label, force_pred, reduction="none")
more_loss["mae_f"] = self.display_if_exist(

l1_force_loss.mean().detach(), find_force
)
l1_force_loss = l1_force_loss.sum(-1).mean(-1).sum()
loss += (pref_f * l1_force_loss).to(GLOBAL_PD_FLOAT_PRECISION)

if mae:
mae_f = paddle.mean(paddle.abs(diff_f))
more_loss["mae_f"] = self.display_if_exist(

mae_f.detach(), find_force
)

@@ -322,8 +318,8 @@
rmse_v = l2_virial_loss.sqrt() * atom_norm
more_loss["rmse_v"] = self.display_if_exist(rmse_v.detach(), find_virial)
if mae:
mae_v = paddle.mean(paddle.abs(diff_v)) * atom_norm
more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial)


if self.has_ae and "atom_energy" in model_pred and "atom_ener" in label:
atom_ener = model_pred["atom_energy"]
@@ -405,7 +401,7 @@
)
)
if self.has_gf > 0:
label_requirement.append(

DataRequirementItem(
"drdq",
ndof=self.numb_generalized_coord * 3,
@@ -415,7 +411,7 @@
)
)
if self.enable_atom_ener_coeff:
label_requirement.append(

DataRequirementItem(
"atom_ener_coeff",
ndof=1,
38 changes: 34 additions & 4 deletions deepmd/pd/model/atomic_model/dp_atomic_model.py
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
- import copy
import functools
import logging
from typing import (
@@ -52,7 +51,7 @@
fitting,
type_map: list[str],
**kwargs,
- ):
+ ) -> None:
super().__init__(type_map, **kwargs)
ntypes = len(type_map)
self.type_map = type_map
@@ -116,12 +115,12 @@

def set_eval_descriptor_hook(self, enable: bool) -> None:
"""Set the hook for evaluating descriptor and clear the cache for descriptor list."""
self.enable_eval_descriptor_hook = enable
self.eval_descriptor_list = []


def eval_descriptor(self) -> paddle.Tensor:
"""Evaluate the descriptor."""
return paddle.concat(self.eval_descriptor_list)


def fitting_output_def(self) -> FittingOutputDef:
"""Get the output def of the fitting net."""
@@ -144,7 +143,7 @@
Set the case embedding of this atomic model by the given case_idx,
typically concatenated with the output of the descriptor and fed into the fitting net.
"""
self.fitting_net.set_case_embd(case_idx)


def mixed_types(self) -> bool:
"""If true, the model
@@ -179,7 +178,7 @@

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return self.descriptor.has_message_passing()

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the atomic model needs sorted nlist when using `forward_lower`."""
@@ -201,7 +200,7 @@

@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
-data = copy.deepcopy(data)
+data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 1)
data.pop("@class", None)
data.pop("type", None)
@@ -212,6 +211,37 @@
obj = super().deserialize(data)
return obj

def enable_compression(
self,
min_nbor_dist: float,
table_extrapolate: float = 5,
table_stride_1: float = 0.01,
table_stride_2: float = 0.1,
check_frequency: int = -1,
) -> None:
"""Call descriptor enable_compression().

Parameters
----------
min_nbor_dist
The nearest distance between atoms
table_extrapolate
The scale of model extrapolation
table_stride_1
The uniform stride of the first table
table_stride_2
The uniform stride of the second table
check_frequency
The overflow check frequency
"""
self.descriptor.enable_compression(
min_nbor_dist,
table_extrapolate,
table_stride_1,
table_stride_2,
check_frequency,
)

def forward_atomic(
self,
extended_coord,
@@ -258,7 +288,7 @@
)
assert descriptor is not None
if self.enable_eval_descriptor_hook:
self.eval_descriptor_list.append(descriptor)
# energy, force
fit_ret = self.fitting_net(
descriptor,
@@ -278,7 +308,7 @@
self,
sampled_func,
stat_file_path: Optional[DPPath] = None,
-):
+) -> None:
"""
Compute or load the statistics parameters of the model,
such as mean and standard deviation of descriptors or the energy bias of the fitting net.
@@ -297,19 +327,19 @@
if stat_file_path is not None and self.type_map is not None:
# descriptors and fitting net with different type_map
# should not share the same parameters
stat_file_path /= " ".join(self.type_map)

@functools.lru_cache
def wrapped_sampler():
sampled = sampled_func()
if self.pair_excl is not None:
pair_exclude_types = self.pair_excl.get_exclude_types()
for sample in sampled:
sample["pair_exclude_types"] = list(pair_exclude_types)

Check warning on line 338 in deepmd/pd/model/atomic_model/dp_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pd/model/atomic_model/dp_atomic_model.py#L336-L338

Added lines #L336 - L338 were not covered by tests
if self.atom_excl is not None:
atom_exclude_types = self.atom_excl.get_exclude_types()
for sample in sampled:
sample["atom_exclude_types"] = list(atom_exclude_types)

Check warning on line 342 in deepmd/pd/model/atomic_model/dp_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pd/model/atomic_model/dp_atomic_model.py#L340-L342

Added lines #L340 - L342 were not covered by tests
return sampled

self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
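
The `@functools.lru_cache` decorator on the zero-argument `wrapped_sampler` means the sampling function runs at most once, and every later call returns the same cached list. A self-contained sketch of that behavior; `sampled_func`, `make_wrapped_sampler`, the `calls` counter, and the sample data are hypothetical stand-ins, and the exclusion list is passed in directly rather than read from `self.pair_excl`:

```python
import functools

calls = {"n": 0}


def sampled_func():
    # Stand-in for an expensive data-sampling step.
    calls["n"] += 1
    return [{"coord": [0.0, 0.0, 0.0]}, {"coord": [1.0, 0.0, 0.0]}]


def make_wrapped_sampler(sampled_func, pair_exclude_types=None):
    @functools.lru_cache
    def wrapped_sampler():
        sampled = sampled_func()
        if pair_exclude_types is not None:
            # Annotate every sample so downstream statistics can see the exclusions.
            for sample in sampled:
                sample["pair_exclude_types"] = list(pair_exclude_types)
        return sampled

    return wrapped_sampler


wrapped = make_wrapped_sampler(sampled_func, pair_exclude_types=[(0, 1)])
first = wrapped()
second = wrapped()  # served from the cache; sampled_func is not called again
```

Since the cache hands back the same list object each time, a consumer that mutates the samples in place would affect every other consumer of the same sampler.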
@@ -330,11 +360,11 @@
to the result of the model.
If returning an empty list, all atom types are selected.
"""
return self.fitting_net.get_sel_type()

def is_aparam_nall(self) -> bool:
"""Check whether the shape of atomic parameters is (nframes, nall, ndim).

If False, the shape is (nframes, nloc, ndim).
"""
return False

20 changes: 20 additions & 0 deletions deepmd/pd/model/descriptor/__init__.py
@@ -5,18 +5,38 @@
from .descriptor import (
DescriptorBlock,
)
from .dpa1 import (
DescrptBlockSeAtten,
DescrptDPA1,
)
from .dpa2 import (
DescrptDPA2,
)
from .env_mat import (
prod_env_mat,
)
from .repformers import (
DescrptBlockRepformers,
)
from .se_a import (
DescrptBlockSeA,
DescrptSeA,
)
from .se_t_tebd import (
DescrptBlockSeTTebd,
DescrptSeTTebd,
)

__all__ = [
"BaseDescriptor",
"DescriptorBlock",
"DescrptBlockRepformers",
"DescrptBlockSeA",
"DescrptBlockSeAtten",
"DescrptBlockSeTTebd",
"DescrptDPA1",
"DescrptDPA2",
"DescrptSeA",
"DescrptSeTTebd",
"prod_env_mat",
]