diff --git a/deepmd/main.py b/deepmd/main.py index 5dab029d83..870a04a088 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -240,7 +240,7 @@ def main_parser() -> argparse.ArgumentParser: "--output", type=str, default="out.json", - help="(Supported backend: TensorFlow) The output file of the parameters used in training.", + help="The output file of the parameters used in training.", ) parser_train.add_argument( "--skip-neighbor-stat", diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 023bc5305e..736e8dde09 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -78,10 +78,6 @@ def get_trainer( shared_links=None, ): multi_task = "model_dict" in config.get("model", {}) - # argcheck - if not multi_task: - config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json") - config = normalize(config) # Initialize DDP local_rank = os.environ.get("LOCAL_RANK") @@ -236,6 +232,11 @@ def train(FLAGS): if multi_task: config["model"], shared_links = preprocess_shared_params(config["model"]) + # argcheck + if not multi_task: + config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json") + config = normalize(config) + # do neighbor stat if not FLAGS.skip_neighbor_stat: log.info( @@ -257,6 +258,9 @@ def train(FLAGS): fake_global_jdata, config["model"]["model_dict"][model_item] ) + with open(FLAGS.output, "w") as fp: + json.dump(config, fp, indent=4) + trainer = get_trainer( config, FLAGS.init_model,