support CINN compiler for DPA2 example

HydrogenSulfate committed Dec 27, 2024
1 parent bf79cc6 commit 55f7ef6
Showing 2 changed files with 73 additions and 25 deletions.
47 changes: 30 additions & 17 deletions deepmd/pd/train/training.py
@@ -53,6 +53,8 @@
get_sampler_from_params,
)
from deepmd.pd.utils.env import (
+    CINN,
+    DEFAULT_PRECISION,
DEVICE,
JIT,
NUM_WORKERS,
@@ -397,11 +399,11 @@ def get_lr(lr_params):
self.lr_exp = get_lr(config["learning_rate"])

# JIT
-        if JIT:
-            raise NotImplementedError(
-                "JIT is not supported yet when training with Paddle"
-            )
-            self.model = paddle.jit.to_static(self.model)
+        # if JIT:
+        #     raise NotImplementedError(
+        #         "JIT is not supported yet when training with Paddle"
+        #     )
+        #     self.model = paddle.jit.to_static(self.model)

# Model Wrapper
self.wrapper = ModelWrapper(self.model, self.loss, model_params=model_params)
@@ -631,6 +633,19 @@ def warm_up_linear(step, warmup_steps):
self.profiling_file = training_params.get("profiling_file", "timeline.json")

def run(self):
+        if JIT:
+            from paddle import (
+                jit,
+                static,
+            )
+
+            build_strategy = static.BuildStrategy()
+            build_strategy.build_cinn_pass: bool = CINN
+            self.wrapper.forward = jit.to_static(
+                full_graph=True, build_strategy=build_strategy
+            )(self.wrapper.forward)
+            log.info(f"{'*' * 20} Using Jit {'*' * 20}")
+
fout = (
open(
self.disp_file,
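
For context, a minimal standalone sketch of what the decoration above does. Only jit.to_static, static.BuildStrategy, build_cinn_pass, and full_graph come from the diff; the toy layer and shapes are assumptions, and full_graph requires a Paddle version that supports it (as this commit assumes):

    import paddle
    from paddle import jit, static

    net = paddle.nn.Linear(4, 1)  # hypothetical stand-in for ModelWrapper

    build_strategy = static.BuildStrategy()
    build_strategy.build_cinn_pass = True  # hand the captured graph to the CINN compiler

    # Wrap only the forward method, as the diff does for self.wrapper.forward.
    net.forward = jit.to_static(full_graph=True, build_strategy=build_strategy)(net.forward)

    out = net.forward(paddle.randn([2, 4]))  # first call triggers tracing and compilation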
@@ -670,9 +685,11 @@ def step(_step_id, task_key="Default") -> None:
cur_lr = _lr.value(_step_id)
pref_lr = cur_lr
self.optimizer.clear_grad(set_to_zero=False)
-            input_dict, label_dict, log_dict = self.get_data(
-                is_train=True, task_key=task_key
-            )
+
+            with nvprof_context(enable_profiling, "Fetching data"):
+                input_dict, label_dict, log_dict = self.get_data(
+                    is_train=True, task_key=task_key
+                )
if SAMPLER_RECORD:
print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n"
fout1.write(print_str)
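
nvprof_context is the repo's profiling helper. A plausible minimal implementation, assuming it wraps Paddle's NVTX bindings (a sketch, not the actual deepmd.pd code):

    from contextlib import contextmanager

    import paddle

    @contextmanager
    def nvprof_context(enable_profiling: bool, name: str):
        # Push a named NVTX range so the region shows up as a labeled
        # span in nvprof/Nsight timelines; pop it on exit.
        if enable_profiling:
            paddle.framework.core.nvprof_nvtx_push(name)
        try:
            yield
        finally:
            if enable_profiling:
                paddle.framework.core.nvprof_nvtx_pop()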
@@ -686,7 +703,7 @@
with nvprof_context(enable_profiling, "Forward pass"):
model_pred, loss, more_loss = self.wrapper(
**input_dict,
-                        cur_lr=pref_lr,
+                        cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
label=label_dict,
task_key=task_key,
)
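
A likely reason for wrapping pref_lr (my reading; the commit message does not say): under to_static, a plain Python float argument is baked into the captured program as a constant, so a changing learning rate could force repeated retracing, whereas a 0-D tensor keeps the graph signature stable across steps. For illustration:

    import paddle

    DEFAULT_PRECISION = "float64"  # assumption: mirrors deepmd.pd.utils.env

    lr = paddle.full([], 1e-3, DEFAULT_PRECISION)  # 0-D tensor, shape []
    print(lr.shape, float(lr))  # prints: [] 0.001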
@@ -745,7 +762,7 @@ def log_loss_valid(_task_key="Default"):
return {}
_, loss, more_loss = self.wrapper(
**input_dict,
-                    cur_lr=pref_lr,
+                    cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
label=label_dict,
task_key=_task_key,
)
@@ -795,7 +812,7 @@ def log_loss_valid(_task_key="Default"):
)
_, loss, more_loss = self.wrapper(
**input_dict,
-                    cur_lr=pref_lr,
+                    cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
label=label_dict,
task_key=_key,
)
@@ -905,8 +922,8 @@ def log_loss_valid(_task_key="Default"):
else:
model_key = "Default"
step(step_id, model_key)
-                if JIT:
-                    break
+                # if JIT:
+                #     break

Check notice (Code scanning / CodeQL): Commented-out code. This comment appears to contain commented-out code.

if self.change_bias_after_training and (self.rank == 0 or dist.get_rank() == 0):
if not self.multi_task:
@@ -961,10 +978,6 @@ def log_loss_valid(_task_key="Default"):
/ (elapsed_batch // self.disp_freq * self.disp_freq),
)

-        if JIT:
-            raise NotImplementedError(
-                "Paddle JIT saving during training is not supported yet."
-            )
log.info(f"Trained model has been saved to: {self.save_ckpt}")

if fout:
51 changes: 43 additions & 8 deletions deepmd/pd/utils/env.py
@@ -32,7 +32,33 @@

paddle.device.set_device(DEVICE)

-JIT = False
+
+
+def to_bool(flag: int | bool | str) -> bool:
+    if isinstance(flag, int):
+        if flag not in [0, 1]:
+            raise ValueError(f"flag must be either 0 or 1, but received {flag}")
+        return bool(flag)
+
+    elif isinstance(flag, str):
+        flag = flag.lower()
+        if flag not in ["1", "0", "true", "false"]:
+            raise ValueError(
+                "flag must be either '0', '1', 'true', 'false', "
+                f"but received '{flag}'"
+            )
+        return flag in ["1", "true"]
+
+    elif isinstance(flag, bool):
+        return flag
+
+    else:
+        raise ValueError(
+            f"flag must be either int, bool, or str, but received {type(flag).__name__}"
+        )
+
+
+JIT = to_bool(os.environ.get("JIT", False))
+CINN = to_bool(os.environ.get("CINN", False))

Check notice (Code scanning / CodeQL): Unused global variable. The global variable 'CINN' is not used.
CACHE_PER_SYS = 5 # keep at most so many sets per sys in memory
ENERGY_BIAS_TRAINABLE = True
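
Together these make JIT and CINN opt-in via environment variables. A hedged usage sketch (the dp --pd train invocation is assumed to be how deepmd-kit's Paddle backend is launched):

    # Accepted spellings, per to_bool above; note that bool inputs hit the
    # int branch first (bool is a subclass of int) and still convert correctly.
    assert to_bool(1) and to_bool("1") and to_bool("true") and to_bool(True)
    assert not (to_bool(0) or to_bool("0") or to_bool("false") or to_bool(False))

    # Typical shell usage before launching training:
    #   JIT=1 CINN=1 dp --pd train input.json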

@@ -138,14 +164,23 @@ def enable_prim(enable: bool = True):
]
EAGER_COMP_OP_BLACK_LIST = list(set(EAGER_COMP_OP_BLACK_LIST))

"""Enable running program in primitive C++ API in eager/static mode."""
from paddle.framework import (
core,
)
"""Enable running program with primitive operators in eager/static mode."""
if JIT:
# jit mode
paddle.framework.core._set_prim_all_enabled(enable)
if enable:
# No need to set a blacklist for now in JIT mode.
pass
else:
# eager mode
paddle.framework.core.set_prim_eager_enabled(enable)
if enable:
# Set a blacklist (i.e., disable several composite operators) in eager mode
# to enhance computational performance.
paddle.framework.core._set_prim_backward_blacklist(
*EAGER_COMP_OP_BLACK_LIST
)

-    core.set_prim_eager_enabled(enable)
-    if enable:
-        paddle.framework.core._set_prim_backward_blacklist(*EAGER_COMP_OP_BLACK_LIST)
log = logging.getLogger(__name__)
log.info(f"{'Enable' if enable else 'Disable'} prim in eager and static mode.")
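
A short usage sketch of the updated helper (the call site is assumed, not part of this commit): with JIT enabled, the static-graph switch _set_prim_all_enabled decomposes operators in both forward and backward so CINN can fuse them; otherwise only eager-mode backward decomposition is enabled, with the blacklist keeping a few fused kernels intact.

    from deepmd.pd.utils.env import enable_prim

    enable_prim(True)  # logs: "Enable prim in eager and static mode."
    # Build the model/trainer afterwards so decomposition is in effect.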

