[WIP] pd: add CINN compiler for dpa2 training #4514
base: devel
Conversation
📝 Walkthrough

The pull request introduces modifications to deepmd/pd/train/training.py and deepmd/pd/utils/env.py.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Env as Environment
    participant Training as Training Module
    participant Model as Model
    Env->>Env: Configure JIT and CINN settings
    Env->>Training: Set precision and compilation strategy
    Training->>Model: Apply JIT compilation
    Model-->>Training: Optimize forward pass
    Training->>Training: Profile performance
```
The sequence diagram illustrates the flow of the configuration and optimization process, showing how environment settings influence the training module and model compilation.
Actionable comments posted: 0
🧹 Nitpick comments (2)
deepmd/pd/train/training.py (2)
402-405: Remove or clarify commented-out code.

These lines comment out a previously raised NotImplementedError and a potential paddle.jit.to_static call. If you no longer need this logic, removing it altogether might reduce confusion and keep the file tidy. Otherwise, add a comment explaining why these lines are kept for future reference.

```diff
-# if JIT:
-#     raise NotImplementedError("JIT is not supported yet when training with Paddle")
-#     self.model = paddle.jit.to_static(self.model)
```

Also applies to: 406-406
925-926: Consider removing the extra commented-out code.

This snippet appears to comment out a JIT debugging break. If it's no longer needed, removing it can avoid potential confusion.

```diff
-# if JIT:
-#     break
```

🧰 Tools
🪛 GitHub Check: CodeQL
[notice] 925-926: Commented-out code
This comment appears to contain commented-out code.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/pd/train/training.py (8 hunks)
- deepmd/pd/utils/env.py (2 hunks)
🧰 Additional context used
🪛 GitHub Check: CodeQL
deepmd/pd/utils/env.py
[notice] 61-61: Unused global variable
The global variable 'CINN' is not used.
deepmd/pd/train/training.py
[notice] 925-926: Commented-out code
This comment appears to contain commented-out code.
🔇 Additional comments (7)
deepmd/pd/train/training.py (4)
56-57: Check the new imports for consistency.

The addition of CINN and DEFAULT_PRECISION is consistent with improvements to JIT compilation and precision handling. Nothing problematic is observed here; just ensure that CINN is successfully imported where used and that DEFAULT_PRECISION is consistently applied.
636-648: JIT and CINN integration logic looks good.

You're conditionally enabling JIT via jit.to_static and setting build_strategy.build_cinn_pass = CINN. This is a clean approach, ensuring that CINN is only used if enabled. Just verify upstream usage to avoid unexpected behavior if CINN is disabled at runtime.
688-692: Validate data fetching performance within profiling context.

Wrapping the data loading with nvprof_context helps profile overhead. Ensure that exceptions thrown within this block are handled properly so that the profiler is closed gracefully.
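As a reference for the exception-safety concern, a context manager built on try/finally always pops the profiling range, even when the wrapped block raises. The sketch below is illustrative only and assumes Paddle's internal nvprof_nvtx_push/nvprof_nvtx_pop hooks; the actual nvprof_context in deepmd may be implemented differently.

```python
# Illustrative sketch only; not the actual deepmd implementation.
# Assumes paddle.framework.core exposes nvprof_nvtx_push/nvprof_nvtx_pop.
from contextlib import contextmanager

import paddle


@contextmanager
def nvprof_context(enable: bool, name: str):
    if enable:
        paddle.framework.core.nvprof_nvtx_push(name)
    try:
        yield
    finally:
        # Pop even if the wrapped block raises (e.g. StopIteration from
        # data fetching), keeping the profiler ranges balanced.
        if enable:
            paddle.framework.core.nvprof_nvtx_pop()
```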
706-706: Precision usage for learning rate.

Using paddle.full([], pref_lr, DEFAULT_PRECISION) enforces consistent floating-point precision for the learning rate. This is beneficial for uniformity, especially in multi-task scenarios. Good practice!

Also applies to: 765-765, 815-815
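For illustration, paddle.full with an empty shape yields a 0-dimensional (scalar) tensor of the requested dtype; "float64" below is a stand-in for whatever DEFAULT_PRECISION actually resolves to.

```python
import paddle

pref_lr = 1.0e-3
DEFAULT_PRECISION = "float64"  # stand-in for the imported constant

lr = paddle.full([], pref_lr, DEFAULT_PRECISION)
print(lr.shape)  # [] -> 0-dimensional scalar tensor
print(lr.dtype)  # paddle.float64
```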
deepmd/pd/utils/env.py (3)
36-50
: Robust input validation in to_bool
.
Excellent job handling integers, booleans, and string values thoroughly, with clear error messages and suitable lowercasing of strings. This ensures minimal confusion for environment variable parsing.
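For context, a validator in this spirit might look like the sketch below; this is an assumed shape based on the review comment, not the exact code in deepmd/pd/utils/env.py.

```python
def to_bool(flag) -> bool:
    """Parse an environment-style flag into a bool (illustrative sketch)."""
    if isinstance(flag, bool):  # check bool before int: bool subclasses int
        return flag
    if isinstance(flag, int):
        if flag in (0, 1):
            return bool(flag)
        raise ValueError(f"Invalid integer for boolean flag: {flag}")
    if isinstance(flag, str):
        lowered = flag.strip().lower()
        if lowered in ("1", "true", "yes", "on"):
            return True
        if lowered in ("0", "false", "no", "off"):
            return False
        raise ValueError(f"Invalid string for boolean flag: {flag!r}")
    raise TypeError(f"Unsupported type for boolean flag: {type(flag)}")
```

An environment flag would then be read as, e.g., CINN = to_bool(os.environ.get("CINN", "0")).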
60-61: CINN usage is valid despite static analysis hint.

While static analysis tools suggest that CINN is "unused," it is in fact used within training.py (via build_strategy.build_cinn_pass). You can safely disregard that hint.
🧰 Tools
🪛 GitHub Check: CodeQL
[notice] 61-61: Unused global variable
The global variable 'CINN' is not used.
167-182: Flexible prim enabling for JIT and eager modes.

Switching between JIT mode and eager mode, while setting blacklists only in eager mode, is a neat way to leverage performance improvements. This design offers more control and clarity to developers using either mode.
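A heavily hedged sketch of the toggle this comment describes; the real deepmd/pd/utils/env.py logic and the internal Paddle hook it calls may differ, and _set_prim_all_enabled is an internal, version-dependent API.

```python
from paddle.framework import core


def enable_prim(enable: bool, jit_mode: bool) -> None:
    # Assumed internal hook: enables composite (prim) operator
    # decomposition globally; the exact API is version-dependent.
    core._set_prim_all_enabled(enable)
    if not jit_mode and enable:
        # In eager mode the reviewed code additionally restricts
        # decomposition through an operator blacklist; the concrete
        # setter varies across Paddle versions, so it is omitted here.
        pass
```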
Actionable comments posted: 1
🧹 Nitpick comments (3)
deepmd/pd/train/training.py (3)
709-709: Confirm precision alignment for learning rate.

Replacing a float with paddle.full([], pref_lr, DEFAULT_PRECISION) ensures consistent floating-point precision. Verify that using a 0-dimensional tensor is intentional and that the rest of the forward pass expects a scalar.
768-768: Avoid code duplication in validation logic.

This pattern repeats the same approach for creating a scalar tensor from pref_lr. Consider encapsulating the logic to reduce repetitive calls and simplify maintenance; a possible helper is sketched after the next comment.
818-818: Maintain a consistent approach for multi-task calls.

Reusing paddle.full([], pref_lr, DEFAULT_PRECISION) across multiple task keys is valid. However, centralizing this snippet could prevent future errors in multi-task scenarios.
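One possible way to centralize the repeated snippet flagged in the two comments above; scalar_lr is a hypothetical helper name, not something introduced by the PR.

```python
import paddle

DEFAULT_PRECISION = "float64"  # stand-in for the imported constant


def scalar_lr(pref_lr: float) -> paddle.Tensor:
    """Wrap a Python float as a 0-d tensor in the default precision."""
    return paddle.full([], pref_lr, DEFAULT_PRECISION)
```

Each call site would then pass scalar_lr(pref_lr) instead of repeating the paddle.full call.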
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/pd/train/training.py (6 hunks)
- deepmd/pd/utils/env.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/pd/utils/env.py
🔇 Additional comments (2)
deepmd/pd/train/training.py (2)
56-57: Use environment-driven feature flags carefully.

Importing CINN and DEFAULT_PRECISION helps toggle experimental optimizations and ensures matching data types throughout training. This is acceptable, but confirm that these environment variables and configurations are properly documented and tested to avoid accidental misconfiguration.
691-695: Validate data fetching within profiler context.

The usage of nvprof_context around get_data() is beneficial for profiling. Make sure exceptions (e.g., StopIteration) in data fetching are handled gracefully to preserve consistent profiling measurements.
```python
if CINN:
    from paddle import (
        jit,
        static,
    )

    build_strategy = static.BuildStrategy()
    build_strategy.build_cinn_pass: bool = CINN
    self.wrapper.forward = jit.to_static(
        full_graph=True, build_strategy=build_strategy
    )(self.wrapper.forward)
    log.info(
        "Enable CINN during training, there may be some additional "
        "compilation time in the first traning step."
    )
```
Refine the CINN build strategy assignment.

```python
build_strategy.build_cinn_pass: bool = CINN
```

Here the colon introduces an annotated assignment. On an attribute target the assignment itself still executes at runtime (only the annotation is discarded), but the annotation adds nothing and the construct is easy to misread as a bare type declaration. Prefer the plain assignment:

```diff
- build_strategy.build_cinn_pass: bool = CINN
+ build_strategy.build_cinn_pass = CINN
```
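For reference, a self-contained version of the corrected pattern; model stands in for self.wrapper, and a CINN-enabled PaddlePaddle build is assumed.

```python
import paddle
from paddle import jit, static

CINN = True  # normally parsed from the CINN environment variable

model = paddle.nn.Linear(4, 4)
if CINN:
    build_strategy = static.BuildStrategy()
    build_strategy.build_cinn_pass = CINN  # plain assignment, no annotation
    model.forward = jit.to_static(
        full_graph=True, build_strategy=build_strategy
    )(model.forward)

out = model(paddle.randn([2, 4]))  # first call triggers compilation
```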
We verified the Paddle CINN compiler on the DPA-2 example (single A100-SXM (40G), CUDA 11.8, Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz x 160).

To enable the CINN compiler in training, add one flag, CINN=1, before the command line, for example:
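```bash
# Enable the CINN compiler for a DPA-2 training run
CINN=1 dp --pd train input_torch_medium.json
```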
Curves:
Performance
We tested with torch==2.6.0.dev20241219+cu118
Accuracy details:
PyTorch:
Paddle (eager mode):
Paddle (CINN compiler):
TODO:
Summary by CodeRabbit

New Features
- New configuration options CINN and DEFAULT_PRECISION.

Improvements

Bug Fixes