Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] pd: add CINN compiler for dpa2 training #4514

Open
wants to merge 2 commits into
base: devel
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions deepmd/pd/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
get_sampler_from_params,
)
from deepmd.pd.utils.env import (
CINN,
DEFAULT_PRECISION,
DEVICE,
JIT,
NUM_WORKERS,
Expand Down Expand Up @@ -631,6 +633,22 @@ def warm_up_linear(step, warmup_steps):
self.profiling_file = training_params.get("profiling_file", "timeline.json")

def run(self):
if CINN:
from paddle import (
jit,
static,
)

build_strategy = static.BuildStrategy()
build_strategy.build_cinn_pass: bool = CINN
self.wrapper.forward = jit.to_static(
full_graph=True, build_strategy=build_strategy
)(self.wrapper.forward)
log.info(
"Enable CINN during training, there may be some additional "
"compilation time in the first traning step."
)

Comment on lines +636 to +651
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Refine the CINN build strategy assignment.

build_strategy.build_cinn_pass: bool = CINN

Here, the colon introduces a PEP 526 variable annotation. For an attribute target the assignment itself still takes effect at runtime, but the annotation is evaluated and discarded, so the `: bool` adds nothing and misleads readers into treating the line as a type declaration. Consider switching to

build_strategy.build_cinn_pass = CINN

to keep the assignment unambiguous and consistent with normal attribute-setting style.

- build_strategy.build_cinn_pass: bool = CINN
+ build_strategy.build_cinn_pass = CINN

fout = (
open(
self.disp_file,
Expand Down Expand Up @@ -670,9 +688,11 @@ def step(_step_id, task_key="Default") -> None:
cur_lr = _lr.value(_step_id)
pref_lr = cur_lr
self.optimizer.clear_grad(set_to_zero=False)
input_dict, label_dict, log_dict = self.get_data(
is_train=True, task_key=task_key
)

with nvprof_context(enable_profiling, "Fetching data"):
input_dict, label_dict, log_dict = self.get_data(
is_train=True, task_key=task_key
)
if SAMPLER_RECORD:
print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n"
fout1.write(print_str)
Expand All @@ -686,7 +706,7 @@ def step(_step_id, task_key="Default") -> None:
with nvprof_context(enable_profiling, "Forward pass"):
model_pred, loss, more_loss = self.wrapper(
**input_dict,
cur_lr=pref_lr,
cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
label=label_dict,
task_key=task_key,
)
Expand Down Expand Up @@ -745,7 +765,7 @@ def log_loss_valid(_task_key="Default"):
return {}
_, loss, more_loss = self.wrapper(
**input_dict,
cur_lr=pref_lr,
cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
label=label_dict,
task_key=_task_key,
)
Expand Down Expand Up @@ -795,7 +815,7 @@ def log_loss_valid(_task_key="Default"):
)
_, loss, more_loss = self.wrapper(
**input_dict,
cur_lr=pref_lr,
cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
label=label_dict,
task_key=_key,
)
Expand Down
58 changes: 50 additions & 8 deletions deepmd/pd/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,40 @@

paddle.device.set_device(DEVICE)

JIT = False

def to_bool(flag: int | bool | str) -> bool:
if isinstance(flag, int):
if flag not in [0, 1]:
raise ValueError(f"flag must be either 0 or 1, but received {flag}")
return bool(flag)

elif isinstance(flag, str):
flag = flag.lower()
if flag not in ["1", "0", "true", "false"]:
raise ValueError(
"flag must be either '0', '1', 'true', 'false', "
f"but received '{flag}'"
)
return flag in ["1", "true"]

elif isinstance(flag, bool):
return flag

else:
raise ValueError(
f"flag must be either int, bool, or str, but received {type(flag).__name__}"
)


JIT = to_bool(os.environ.get("JIT", False))
CINN = to_bool(os.environ.get("CINN", False))
Fixed Show fixed Hide fixed
if CINN:
assert paddle.device.is_compiled_with_cinn(), (
"CINN is set to True, but PaddlePaddle is not compiled with CINN support. "
"Ensure that your PaddlePaddle installation supports CINN by checking your "
"installation or recompiling with CINN enabled."
)

CACHE_PER_SYS = 5 # keep at most so many sets per sys in memory
ENERGY_BIAS_TRAINABLE = True

Expand Down Expand Up @@ -138,14 +171,23 @@ def enable_prim(enable: bool = True):
]
EAGER_COMP_OP_BLACK_LIST = list(set(EAGER_COMP_OP_BLACK_LIST))

"""Enable running program in primitive C++ API in eager/static mode."""
from paddle.framework import (
core,
)
"""Enable running program with primitive operators in eager/static mode."""
if JIT or CINN:
# jit mode
paddle.framework.core._set_prim_all_enabled(enable)
if enable:
# No need to set a blacklist for now in JIT mode.
pass
else:
# eager mode
paddle.framework.core.set_prim_eager_enabled(enable)
if enable:
# Set a blacklist (i.e., disable several composite operators) in eager mode
# to enhance computational performance.
paddle.framework.core._set_prim_backward_blacklist(
*EAGER_COMP_OP_BLACK_LIST
)

core.set_prim_eager_enabled(enable)
if enable:
paddle.framework.core._set_prim_backward_blacklist(*EAGER_COMP_OP_BLACK_LIST)
log = logging.getLogger(__name__)
log.info(f"{'Enable' if enable else 'Disable'} prim in eager and static mode.")

Expand Down
Loading