diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index 64dc0f9ed9..0ef95b2e1b 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -399,11 +399,11 @@ def get_lr(lr_params): self.lr_exp = get_lr(config["learning_rate"]) # JIT - # if JIT: - # raise NotImplementedError( - # "JIT is not supported yet when training with Paddle" - # ) - # self.model = paddle.jit.to_static(self.model) + if JIT: + raise NotImplementedError( + "JIT is not supported yet when training with Paddle" + ) + self.model = paddle.jit.to_static(self.model) # Model Wrapper self.wrapper = ModelWrapper(self.model, self.loss, model_params=model_params) @@ -633,7 +633,7 @@ def warm_up_linear(step, warmup_steps): self.profiling_file = training_params.get("profiling_file", "timeline.json") def run(self): - if JIT: + if CINN: from paddle import ( jit, static, @@ -644,7 +644,10 @@ def run(self): self.wrapper.forward = jit.to_static( full_graph=True, build_strategy=build_strategy )(self.wrapper.forward) - log.info(f"{'*' * 20} Using Jit {'*' * 20}") + log.info( + "Enable CINN during training, there may be some additional " + "compilation time in the first training step." + ) fout = ( open( @@ -922,8 +925,8 @@ def log_loss_valid(_task_key="Default"): else: model_key = "Default" step(step_id, model_key) - # if JIT: - # break + if JIT: + break if self.change_bias_after_training and (self.rank == 0 or dist.get_rank() == 0): if not self.multi_task: diff --git a/deepmd/pd/utils/env.py b/deepmd/pd/utils/env.py index 27f5b2a479..8c52f6b143 100644 --- a/deepmd/pd/utils/env.py +++ b/deepmd/pd/utils/env.py @@ -59,6 +59,13 @@ def to_bool(flag: int | bool | str) -> bool: JIT = to_bool(os.environ.get("JIT", False)) CINN = to_bool(os.environ.get("CINN", False)) +if CINN: + assert paddle.device.is_compiled_with_cinn(), ( + "CINN is set to True, but PaddlePaddle is not compiled with CINN support. 
" + "Ensure that your PaddlePaddle installation supports CINN by checking your " + "installation or recompiling with CINN enabled." + ) + CACHE_PER_SYS = 5 # keep at most so many sets per sys in memory ENERGY_BIAS_TRAINABLE = True