[AutoTuner] Add memory model #147
Conversation
flagscale/auto_tuner/memory_model.py
Outdated
Please try to reuse this implementation.
The implementation has been reused and the activation section has been refined.
flagscale/auto_tuner/prune/memory.py
Outdated
```python
def prune_by_memory_model_util(config, strategy, history=[]):
    if "modeling_memory" in strategy:
        ...
```
It's better to rename "modeling_memory" to "memory_model", as in the other places.
thx, done
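After the rename, the guard would presumably read as sketched below; only the key check itself comes from the PR, and the function body is elided:

```python
def prune_by_memory_model_util(config, strategy, history=[]):
    # Renamed from "modeling_memory" so the key matches the rest of the auto-tuner.
    if "memory_model" in strategy:
        ...  # pruning logic unchanged
```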
flagscale/auto_tuner/tuner.py
Outdated
```diff
 if os.environ.get("AIRS_ACCELERATOR_COUNT", None):
-    self.config.experiment.auto_tuner.nproc_per_node = int(
-        os.environ["AIRS_ACCELERATOR_COUNT"]
-    )
+    # Set config
+    self.config.experiment.auto_tuner.nproc_per_node = (
+        int(os.environ["AIRS_ACCELERATOR_COUNT"]) * 2
+        if "luvatar_BI" in os.environ["AIRS_ACCELERATOR_MODEL"]
+        else int(os.environ["AIRS_ACCELERATOR_COUNT"])
+    )
     # Set original config
-    self.orig_config.experiment.runner.nproc_per_node = int(
-        os.environ["AIRS_ACCELERATOR_COUNT"]
-    )
+    self.orig_config.experiment.runner.nproc_per_node = (
+        int(os.environ["AIRS_ACCELERATOR_COUNT"]) * 2
+        if "luvatar_BI" in os.environ["AIRS_ACCELERATOR_MODEL"]
+        else int(os.environ["AIRS_ACCELERATOR_COUNT"])
+    )
     # Set config
-    self.config.experiment.runner.nproc_per_node = int(
-        os.environ["AIRS_ACCELERATOR_COUNT"]
-    )
+    self.config.experiment.runner.nproc_per_node = (
+        int(os.environ["AIRS_ACCELERATOR_COUNT"]) * 2
+        if "luvatar_BI" in os.environ["AIRS_ACCELERATOR_MODEL"]
+        else int(os.environ["AIRS_ACCELERATOR_COUNT"])
+    )
```
Is there any way to move this platform-related code into a standalone place? We may support different cloud platforms.
The platform-specific code has been moved to platform.py, and code for other platforms will live in that file as well.
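For context, one possible shape for such a standalone module, sketched under the assumption that the platform only needs to report the per-node process count (the function name and docstring are illustrative, not the actual contents of platform.py):

```python
# platform.py (sketch): keep cloud-platform specifics out of tuner.py.
import os


def detect_nproc_per_node(default=None):
    """Per-node process count as reported by the hosting platform.

    Accelerators whose model string contains "luvatar_BI" (presumably the
    Iluvatar BI series) expose two logical devices per card, so the reported
    count is doubled, mirroring the logic in the diff above.
    """
    count = os.environ.get("AIRS_ACCELERATOR_COUNT", None)
    if count is None:
        return default
    model = os.environ.get("AIRS_ACCELERATOR_MODEL", "")
    return int(count) * (2 if "luvatar_BI" in model else 1)
```

The tuner could then call detect_nproc_per_node() once and assign the result to both the runner and auto-tuner configs, keeping vendor-specific branches out of tuner.py.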
```diff
@@ -0,0 +1,351 @@
+"""
+Computes the theoretical memory footprint for model training, referring to Megatron.
+Activation memory is optimized by adding a block-recompute formula.
```
Please add a reference to the original Megatron implementation.
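For reference, the weight-and-optimizer part of a Megatron-style estimate looks roughly like the sketch below. It follows the usual mixed-precision Adam accounting and is a simplification, not the PR's memory_model.py, which also models activations and block recompute:

```python
def estimate_weight_and_optimizer_bytes(
    num_parameters,
    tensor_parallel_size=1,
    pipeline_parallel_size=1,
    data_parallel_size=1,
    use_distributed_optimizer=False,
):
    """Approximate per-GPU bytes for weights, gradients and Adam states.

    With mixed precision, each parameter costs about 18 bytes before sharding:
    2 (bf16 weight) + 4 (fp32 gradient) + 4 (fp32 master weight)
    + 4 (Adam momentum) + 4 (Adam variance).
    """
    # Rough sharding: parameters split across tensor- and pipeline-parallel ranks.
    params_per_gpu = num_parameters / (tensor_parallel_size * pipeline_parallel_size)
    if use_distributed_optimizer:
        # The 12 bytes of fp32 master weight and Adam states are further
        # sharded across the data-parallel ranks.
        bytes_per_param = 6 + 12 / data_parallel_size
    else:
        bytes_per_param = 18
    return params_per_gpu * bytes_per_param
```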
```diff
@@ -161,3 +167,114 @@ def compare_by_recompute(strategy1, strategy2):
             result = True

     return result
+
+
+def convert_config_to_megatron_args(config, strategy):
```
Is there any simpler way to implement this?
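One direction that might simplify it (a sketch only; the field lists and the config layout below are assumptions, not what the PR does) is to drive the conversion from flat lists of field names and build an argparse-style namespace instead of assigning attributes one by one:

```python
from argparse import Namespace

# Illustrative field lists; a real implementation would list exactly the
# entries that memory_model.py reads.
_MODEL_FIELDS = ("num_layers", "hidden_size", "num_attention_heads", "seq_length")
_STRATEGY_FIELDS = (
    "tensor_model_parallel_size",
    "pipeline_model_parallel_size",
    "micro_batch_size",
    "global_batch_size",
)


def convert_config_to_megatron_args(config, strategy):
    """Collect the fields the memory model needs into one args-like object."""
    model_cfg = config.train.model  # assumed OmegaConf-style nesting
    fields = {name: model_cfg.get(name) for name in _MODEL_FIELDS}
    fields.update({name: strategy.get(name) for name in _STRATEGY_FIELDS})
    return Namespace(**fields)
```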
LGTM
This PR adds a memory model, which can be used to speed up pruning by filtering out strategies that would run out of memory (OOM) as well as strategies with low memory utilization.
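As a usage illustration (the names, the 80 GB limit and the 0.3 lower bound are assumptions for the sketch, not the PR's exact API), pruning with the memory model amounts to dropping candidates at both ends of the predicted footprint range:

```python
def filter_strategies(candidates, gpu_memory_mb=80 * 1024, low_util=0.3):
    """Drop candidates that would OOM or that barely use the GPU.

    `candidates` is a list of (strategy_name, modeled_memory_mb) pairs; the
    80 GB default and the 0.3 lower bound are illustrative assumptions.
    """
    kept = []
    for name, modeled_mb in candidates:
        if modeled_mb >= gpu_memory_mb:
            continue  # the memory model predicts an OOM, skip launching it
        if modeled_mb < low_util * gpu_memory_mb:
            continue  # too little of the device would be used
        kept.append((name, modeled_mb))
    return kept


# Example: keep only the candidate whose modeled footprint fits and is worthwhile.
print(filter_strategies([("tp2_pp1_mbs4", 71_000), ("tp8_pp1_mbs1", 18_000)]))
```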