From 55f7ef66486a6122f710496b23267a9f4d5d33ad Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Fri, 27 Dec 2024 12:05:21 +0800 Subject: [PATCH 1/2] support CINN compiler for DPA2 example --- deepmd/pd/train/training.py | 47 +++++++++++++++++++++------------- deepmd/pd/utils/env.py | 51 +++++++++++++++++++++++++++++++------ 2 files changed, 73 insertions(+), 25 deletions(-) diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index 0f3c7a9732..64dc0f9ed9 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -53,6 +53,8 @@ get_sampler_from_params, ) from deepmd.pd.utils.env import ( + CINN, + DEFAULT_PRECISION, DEVICE, JIT, NUM_WORKERS, @@ -397,11 +399,11 @@ def get_lr(lr_params): self.lr_exp = get_lr(config["learning_rate"]) # JIT - if JIT: - raise NotImplementedError( - "JIT is not supported yet when training with Paddle" - ) - self.model = paddle.jit.to_static(self.model) + # if JIT: + # raise NotImplementedError( + # "JIT is not supported yet when training with Paddle" + # ) + # self.model = paddle.jit.to_static(self.model) # Model Wrapper self.wrapper = ModelWrapper(self.model, self.loss, model_params=model_params) @@ -631,6 +633,19 @@ def warm_up_linear(step, warmup_steps): self.profiling_file = training_params.get("profiling_file", "timeline.json") def run(self): + if JIT: + from paddle import ( + jit, + static, + ) + + build_strategy = static.BuildStrategy() + build_strategy.build_cinn_pass: bool = CINN + self.wrapper.forward = jit.to_static( + full_graph=True, build_strategy=build_strategy + )(self.wrapper.forward) + log.info(f"{'*' * 20} Using Jit {'*' * 20}") + fout = ( open( self.disp_file, @@ -670,9 +685,11 @@ def step(_step_id, task_key="Default") -> None: cur_lr = _lr.value(_step_id) pref_lr = cur_lr self.optimizer.clear_grad(set_to_zero=False) - input_dict, label_dict, log_dict = self.get_data( - is_train=True, task_key=task_key - ) + + with nvprof_context(enable_profiling, "Fetching 
data"): + input_dict, label_dict, log_dict = self.get_data( + is_train=True, task_key=task_key + ) if SAMPLER_RECORD: print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n" fout1.write(print_str) @@ -686,7 +703,7 @@ def step(_step_id, task_key="Default") -> None: with nvprof_context(enable_profiling, "Forward pass"): model_pred, loss, more_loss = self.wrapper( **input_dict, - cur_lr=pref_lr, + cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION), label=label_dict, task_key=task_key, ) @@ -745,7 +762,7 @@ def log_loss_valid(_task_key="Default"): return {} _, loss, more_loss = self.wrapper( **input_dict, - cur_lr=pref_lr, + cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION), label=label_dict, task_key=_task_key, ) @@ -795,7 +812,7 @@ def log_loss_valid(_task_key="Default"): ) _, loss, more_loss = self.wrapper( **input_dict, - cur_lr=pref_lr, + cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION), label=label_dict, task_key=_key, ) @@ -905,8 +922,8 @@ def log_loss_valid(_task_key="Default"): else: model_key = "Default" step(step_id, model_key) - if JIT: - break + # if JIT: + # break if self.change_bias_after_training and (self.rank == 0 or dist.get_rank() == 0): if not self.multi_task: @@ -961,10 +978,6 @@ def log_loss_valid(_task_key="Default"): / (elapsed_batch // self.disp_freq * self.disp_freq), ) - if JIT: - raise NotImplementedError( - "Paddle JIT saving during training is not supported yet." 
- ) log.info(f"Trained model has been saved to: {self.save_ckpt}") if fout: diff --git a/deepmd/pd/utils/env.py b/deepmd/pd/utils/env.py index e2abe9a6e5..27f5b2a479 100644 --- a/deepmd/pd/utils/env.py +++ b/deepmd/pd/utils/env.py @@ -32,7 +32,33 @@ paddle.device.set_device(DEVICE) -JIT = False + +def to_bool(flag: int | bool | str) -> bool: + if isinstance(flag, int): + if flag not in [0, 1]: + raise ValueError(f"flag must be either 0 or 1, but received {flag}") + return bool(flag) + + elif isinstance(flag, str): + flag = flag.lower() + if flag not in ["1", "0", "true", "false"]: + raise ValueError( + "flag must be either '0', '1', 'true', 'false', " + f"but received '{flag}'" + ) + return flag in ["1", "true"] + + elif isinstance(flag, bool): + return flag + + else: + raise ValueError( + f"flag must be either int, bool, or str, but received {type(flag).__name__}" + ) + + +JIT = to_bool(os.environ.get("JIT", False)) +CINN = to_bool(os.environ.get("CINN", False)) CACHE_PER_SYS = 5 # keep at most so many sets per sys in memory ENERGY_BIAS_TRAINABLE = True @@ -138,14 +164,23 @@ def enable_prim(enable: bool = True): ] EAGER_COMP_OP_BLACK_LIST = list(set(EAGER_COMP_OP_BLACK_LIST)) - """Enable running program in primitive C++ API in eager/static mode.""" - from paddle.framework import ( - core, - ) + """Enable running program with primitive operators in eager/static mode.""" + if JIT: + # jit mode + paddle.framework.core._set_prim_all_enabled(enable) + if enable: + # No need to set a blacklist for now in JIT mode. + pass + else: + # eager mode + paddle.framework.core.set_prim_eager_enabled(enable) + if enable: + # Set a blacklist (i.e., disable several composite operators) in eager mode + # to enhance computational performance. 
+ paddle.framework.core._set_prim_backward_blacklist( + *EAGER_COMP_OP_BLACK_LIST + ) - core.set_prim_eager_enabled(enable) - if enable: - paddle.framework.core._set_prim_backward_blacklist(*EAGER_COMP_OP_BLACK_LIST) log = logging.getLogger(__name__) log.info(f"{'Enable' if enable else 'Disable'} prim in eager and static mode.") From 7ca2a9ecd203b4c6088dd3d97c92e1cef348441a Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Fri, 27 Dec 2024 13:17:34 +0800 Subject: [PATCH 2/2] refine CINN flag --- deepmd/pd/train/training.py | 25 ++++++++++++++++--------- deepmd/pd/utils/env.py | 9 ++++++++- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index 64dc0f9ed9..a0328942e4 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -399,11 +399,11 @@ def get_lr(lr_params): self.lr_exp = get_lr(config["learning_rate"]) # JIT - # if JIT: - # raise NotImplementedError( - # "JIT is not supported yet when training with Paddle" - # ) - # self.model = paddle.jit.to_static(self.model) + if JIT: + raise NotImplementedError( + "JIT is not supported yet when training with Paddle" + ) + self.model = paddle.jit.to_static(self.model) # Model Wrapper self.wrapper = ModelWrapper(self.model, self.loss, model_params=model_params) @@ -633,7 +633,7 @@ def warm_up_linear(step, warmup_steps): self.profiling_file = training_params.get("profiling_file", "timeline.json") def run(self): - if JIT: + if CINN: from paddle import ( jit, static, ) @@ -644,7 +644,10 @@ def run(self): self.wrapper.forward = jit.to_static( full_graph=True, build_strategy=build_strategy )(self.wrapper.forward) - log.info(f"{'*' * 20} Using Jit {'*' * 20}") + log.info( + "Enable CINN during training, there may be some additional " + "compilation time in the first training step. 
+ ) fout = ( open( @@ -922,8 +925,8 @@ def log_loss_valid(_task_key="Default"): else: model_key = "Default" step(step_id, model_key) - # if JIT: - # break + if JIT: + break if self.change_bias_after_training and (self.rank == 0 or dist.get_rank() == 0): if not self.multi_task: @@ -978,6 +981,10 @@ def log_loss_valid(_task_key="Default"): / (elapsed_batch // self.disp_freq * self.disp_freq), ) + if JIT: + raise NotImplementedError( + "Paddle JIT saving during training is not supported yet." + ) log.info(f"Trained model has been saved to: {self.save_ckpt}") if fout: diff --git a/deepmd/pd/utils/env.py b/deepmd/pd/utils/env.py index 27f5b2a479..87b69c5676 100644 --- a/deepmd/pd/utils/env.py +++ b/deepmd/pd/utils/env.py @@ -59,6 +59,13 @@ def to_bool(flag: int | bool | str) -> bool: JIT = to_bool(os.environ.get("JIT", False)) CINN = to_bool(os.environ.get("CINN", False)) +if CINN: + assert paddle.device.is_compiled_with_cinn(), ( + "CINN is set to True, but PaddlePaddle is not compiled with CINN support. " + "Ensure that your PaddlePaddle installation supports CINN by checking your " + "installation or recompiling with CINN enabled." + ) + CACHE_PER_SYS = 5 # keep at most so many sets per sys in memory ENERGY_BIAS_TRAINABLE = True @@ -165,7 +172,7 @@ def enable_prim(enable: bool = True): EAGER_COMP_OP_BLACK_LIST = list(set(EAGER_COMP_OP_BLACK_LIST)) """Enable running program with primitive operators in eager/static mode.""" - if JIT: + if JIT or CINN: # jit mode paddle.framework.core._set_prim_all_enabled(enable) if enable: