Skip to content

Commit

Permalink
refine CINN flag
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Dec 27, 2024
1 parent 55f7ef6 commit 95b201d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
21 changes: 12 additions & 9 deletions deepmd/pd/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,11 +399,11 @@ def get_lr(lr_params):
self.lr_exp = get_lr(config["learning_rate"])

# JIT
# if JIT:
# raise NotImplementedError(
# "JIT is not supported yet when training with Paddle"
# )
# self.model = paddle.jit.to_static(self.model)
if JIT:
raise NotImplementedError(
"JIT is not supported yet when training with Paddle"
)
self.model = paddle.jit.to_static(self.model)

# Model Wrapper
self.wrapper = ModelWrapper(self.model, self.loss, model_params=model_params)
Expand Down Expand Up @@ -633,7 +633,7 @@ def warm_up_linear(step, warmup_steps):
self.profiling_file = training_params.get("profiling_file", "timeline.json")

def run(self):
if JIT:
if CINN:
from paddle import (
jit,
static,
Expand All @@ -644,7 +644,10 @@ def run(self):
self.wrapper.forward = jit.to_static(
full_graph=True, build_strategy=build_strategy
)(self.wrapper.forward)
log.info(f"{'*' * 20} Using Jit {'*' * 20}")
log.info(
"Enable CINN during training, there may be some additional "
"compilation time in the first traning step."
)

fout = (
open(
Expand Down Expand Up @@ -922,8 +925,8 @@ def log_loss_valid(_task_key="Default"):
else:
model_key = "Default"
step(step_id, model_key)
# if JIT:
# break
if JIT:
break

if self.change_bias_after_training and (self.rank == 0 or dist.get_rank() == 0):
if not self.multi_task:
Expand Down
7 changes: 7 additions & 0 deletions deepmd/pd/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ def to_bool(flag: int | bool | str) -> bool:

# Feature switches read from the environment; `to_bool` normalizes the raw value
# (the default `False` is used when the variable is unset).
JIT = to_bool(os.environ.get("JIT", False))
CINN = to_bool(os.environ.get("CINN", False))
if CINN and not paddle.device.is_compiled_with_cinn():
    # Raise explicitly instead of using `assert`: assertions are stripped when
    # Python runs with -O, which would let an unsupported CINN request pass
    # this check unnoticed.
    raise RuntimeError(
        "CINN is set to True, but PaddlePaddle is not compiled with CINN support. "
        "Ensure that your PaddlePaddle installation supports CINN by checking your "
        "installation or recompiling with CINN enabled."
    )

CACHE_PER_SYS = 5  # keep at most so many sets per sys in memory
ENERGY_BIAS_TRAINABLE = True

Expand Down

0 comments on commit 95b201d

Please sign in to comment.