From 0e6172da128b7ae84ac0b99829bbbb9aed56ed85 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Fri, 27 Dec 2024 15:39:13 +0800 Subject: [PATCH] refine code --- deepmd/pd/utils/serialization.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/deepmd/pd/utils/serialization.py b/deepmd/pd/utils/serialization.py index 973df9f6a0..a345c87a1d 100644 --- a/deepmd/pd/utils/serialization.py +++ b/deepmd/pd/utils/serialization.py @@ -35,12 +35,15 @@ def deserialize_to_file(model_file: str, data: dict) -> None: """ if not model_file.endswith(".json"): raise ValueError("Paddle backend only supports converting .json file") - model = BaseModel.deserialize(data["model"]) + model: paddle.nn.Layer = BaseModel.deserialize(data["model"]) # JIT will happy in this way... - # model.model_def_script = json.dumps(data["model_def_script"]) if "min_nbor_dist" in data.get("@variables", {}): - model.min_nbor_dist = float(data["@variables"]["min_nbor_dist"]) - # model = paddle.jit.to_static(model) + model.register_buffer( + "buffer_min_nbor_dist", + paddle.to_tensor( + float(data["@variables"]["min_nbor_dist"]), + ), + ) paddle.set_flags( { "FLAGS_save_cf_stack_op": 1,