
add zero3 module_granularity_threshold to zero optimization. #6649

Merged
merged 36 commits into from
Nov 12, 2024
Changes from all commits
Commits (36)
a2610d8
z3 coalesced fetch
inkcherry Oct 21, 2024
4e8be08
fix format
inkcherry Oct 21, 2024
7641994
fix default value
inkcherry Oct 21, 2024
805a820
fix default
inkcherry Oct 21, 2024
ce7dfb7
Merge branch 'master' into z3_coalesced_fetch
delock Oct 23, 2024
810353b
fix ut
inkcherry Oct 23, 2024
a8dd8fe
fix ut
inkcherry Oct 23, 2024
53584ca
Merge branch 'master' into z3_coalesced_fetch
loadams Oct 25, 2024
4d86198
Merge branch 'master' into z3_coalesced_fetch
tjruwase Oct 31, 2024
7b94377
add ut(usage)
inkcherry Nov 4, 2024
cd31a0d
use int type config
inkcherry Nov 4, 2024
ea50964
fix format
inkcherry Nov 4, 2024
b068118
Merge remote-tracking branch 'origin/z3_coalesced_fetch' into z3_coal…
inkcherry Nov 4, 2024
600d9c7
fix note
inkcherry Nov 4, 2024
4477077
Merge branch 'master' into z3_coalesced_fetch
tjruwase Nov 4, 2024
c2c434b
refine code
inkcherry Nov 5, 2024
e5f9430
remove debug code
inkcherry Nov 5, 2024
c2b020a
update
inkcherry Nov 5, 2024
511ace0
Merge remote-tracking branch 'origin/z3_coalesced_fetch' into z3_coal…
inkcherry Nov 5, 2024
3680109
don't set leaf for container module
inkcherry Nov 5, 2024
f2752f8
Merge branch 'master' into z3_coalesced_fetch
inkcherry Nov 5, 2024
22c0f81
update ut
inkcherry Nov 6, 2024
f773258
udpate
inkcherry Nov 6, 2024
c31ad02
change config name, refine doc
inkcherry Nov 6, 2024
40ceeac
fix rjust size
inkcherry Nov 6, 2024
73e5bd5
fix merge
inkcherry Nov 6, 2024
c31c8d2
format
inkcherry Nov 6, 2024
619cbe6
always print info if the config is enabled
inkcherry Nov 7, 2024
3c0a183
Merge branch 'master' into z3_coalesced_fetch
inkcherry Nov 7, 2024
a6e5a39
update
inkcherry Nov 7, 2024
e7e5cdf
Merge branch 'z3_coalesced_fetch' of https://github.com/inkcherry/Dee…
inkcherry Nov 7, 2024
00ac4eb
Merge remote-tracking branch 'upstream/master' into z3_coalesced_fetch
inkcherry Nov 11, 2024
25df962
use mark parametrize for test
inkcherry Nov 11, 2024
fdd6fbf
Merge branch 'master' into z3_coalesced_fetch
tjruwase Nov 11, 2024
b860961
Merge branch 'master' into z3_coalesced_fetch
tjruwase Nov 12, 2024
dfb2b45
Merge branch 'master' into z3_coalesced_fetch
delock Nov 12, 2024
5 changes: 5 additions & 0 deletions deepspeed/runtime/engine.py
@@ -811,6 +811,9 @@ def zero_max_reuse_distance(self):
def zero_prefetch_bucket_size(self):
return self._config.zero_config.prefetch_bucket_size

def zero_module_granularity_threshold(self):
return self._config.zero_config.module_granularity_threshold

def zero_param_persistence_threshold(self):
return self._config.zero_config.param_persistence_threshold

@@ -1611,6 +1614,7 @@ def _configure_zero_optimizer(self, optimizer):
zero_param_parallel_group=zero_param_parallel_group,
zero_quantized_weights=self.zero_quantized_weights(),
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),
zero_module_granularity_threshold=self.zero_module_granularity_threshold(),
)
else:
log_dist(
@@ -1657,6 +1661,7 @@ def _configure_zero_optimizer(self, optimizer):
zero_hpz_partition_size=self.zero_hpz_partition_size(),
zero_quantized_weights=self.zero_quantized_weights(),
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),
zero_module_granularity_threshold=self.zero_module_granularity_threshold(),
)

else:
9 changes: 9 additions & 0 deletions deepspeed/runtime/zero/config.py
@@ -21,6 +21,7 @@
"stage3_max_live_parameters" : 1000000000,
"stage3_max_reuse_distance" : 1000000000,
"stage3_use_all_reduce_for_fetch_params": [true|false],
"stage3_module_granularity_threshold": 0,
"allgather_partitions": [true|false],
"use_multi_rank_bucket_allreduce": [true|false],
"allgather_bucket_size": 500000000,
@@ -245,6 +246,14 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
this option is enabled and then saves the fp16 model weights.
"""

module_granularity_threshold: int = Field(pp_int(0), alias="stage3_module_granularity_threshold")
"""
The granularity of a module is determined by the ratio of "parameter_count / (1 + descendant_count)".
ZeRO3 classifies modules with a granularity below the threshold as fine-grained and treats them
as integral units during parameter fetching. This reduces the host-side overhead and the separate
allgather overhead introduced by hooks on fine-grained layers.
"""

use_all_reduce_for_fetch_params: bool = Field(False, alias="stage3_use_all_reduce_for_fetch_params")
"""
Use all_reduce op when fetching module parameters at stage3. This improves performance by reducing
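The ratio in the docstring above can be illustrated with a small standalone sketch. The `Node`, `count`, and `granularity` names below are illustrative only (they are not DeepSpeed APIs), and the traversal ignores the special cases the PR handles, such as container modules and parameterless layers.

```python
# Sketch of the granularity metric: parameter_count / (1 + descendant module count).

class Node:
    """A toy stand-in for an nn.Module: directly-owned params plus children."""

    def __init__(self, params, children=()):
        self.params = params
        self.children = list(children)

def count(node):
    """Return (descendant module count, total parameter count) for a subtree."""
    modules, params = 0, node.params
    for child in node.children:
        m, p = count(child)
        modules += 1 + m  # the child itself plus its own descendants
        params += p
    return modules, params

def granularity(node):
    descendants, params = count(node)
    return params // (1 + descendants)  # integer division, as in the config docs

# A block containing two small linear-like children of 1000 params each:
block = Node(0, [Node(1000), Node(1000)])
print(granularity(block))  # 2000 params / (1 + 2 descendants) -> 666
```

A low value like this marks the block as fine-grained once the threshold is at or above it, so ZeRO3 would fetch the whole block as one unit instead of hooking each child.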
91 changes: 90 additions & 1 deletion deepspeed/runtime/zero/parameter_offload.py
@@ -6,14 +6,15 @@
import sys
import torch
from collections import OrderedDict
from deepspeed.utils import z3_leaf_module
from deepspeed.utils import z3_leaf_module, set_z3_leaf_module
from deepspeed.runtime.utils import see_memory_usage
from deepspeed.runtime.zero.utils import apply_to_tensors_only, is_zero_param
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.runtime.zero.partition_parameters import _init_external_params
from deepspeed.runtime.zero.partition_parameters import *
from deepspeed.runtime.zero.partitioned_param_coordinator import PartitionedParameterCoordinator, InflightParamRegistry, iter_params
from deepspeed.accelerator import get_accelerator
from deepspeed import utils

FWD_MODULE_STACK = list()

@@ -101,6 +102,7 @@ def __init__(
zero_param_parallel_group=None,
zero_quantized_weights=False,
zero_quantized_nontrainable_weights=False,
zero_module_granularity_threshold=0,
):

see_memory_usage("DeepSpeedZeRoOffload initialize [begin]", force=True)
@@ -155,8 +157,16 @@ def __init__(
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights,
)

if zero_module_granularity_threshold > 0:
self.min_granularity_value = sys.maxsize
self.min_granularity_layer = None
self.granularity_info = set()
self.z3_leaf_layers = []
self._set_z3_leaf_modules_by_threshold(module, zero_module_granularity_threshold)

self.forward_hooks = []
self.backward_hooks = []

self.setup_zero_stage3_hooks()
print_rank_0(
f'Created module hooks: forward = {len(self.forward_hooks)}, backward = {len(self.backward_hooks)}',
@@ -482,3 +492,82 @@ def post_sub_module_backward_function(self, sub_module):
see_memory_usage(
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
force=False)

def _set_z3_leaf_modules_by_threshold(self, module, zero_module_granularity_threshold):

self._get_granularity_recursively(module)
print_rank_0(f"{'MODULE NAME'.ljust(30)}|{'GRANULARITY VALUE'.rjust(20)}", force=True)
for granularity in self.granularity_info:
print_rank_0(granularity, force=True)

if self.min_granularity_value <= zero_module_granularity_threshold:
self._set_leaf_by_threshold_preorder(module, zero_module_granularity_threshold)
utils.logger.info(
f"z3_leaf_module was set by stage3_module_granularity_threshold:{zero_module_granularity_threshold}")
for layer in self.z3_leaf_layers:
print_rank_0(f"{layer.__class__.__name__}:{layer.ds_model_granularity}", force=True)
else:
utils.logger.warning(
f"The smallest module granularity is [{self.min_granularity_layer}:{self.min_granularity_value}]. "\
f"To make stage3_module_granularity_threshold effective, you need to set stage3_module_granularity_threshold >= {self.min_granularity_value}. "\
f"Current Value:{zero_module_granularity_threshold}"
)

def _get_granularity_recursively(self, module):
"""This function is used to recursively obtain the granularity of each module."""

# Avoid marking particularly large modules as leaves, even if their granularity is small:
# an oversized leaf module increases the number of live parameters, introducing memory overhead.
Z3_MAX_LEAF_SIZE = 1e9

if not list(module.parameters()):
# skip Modules without parameters, such as GELU, etc.
module.ds_model_granularity = sys.maxsize
return 0, 0

num_layers = 0
num_params = 0
num_params += sum(p.ds_numel for p in module.parameters(recurse=False))
if not any(module.children()):
# torch leaf module
module.ds_model_granularity = sys.maxsize
return 1, num_params

for child in module.children():
layers_in_child, params_in_child = self._get_granularity_recursively(child)
num_layers += layers_in_child
num_params += params_in_child

if module.__class__.__name__ in torch.nn.modules.container.__all__:
# Do not set container modules like ModuleList as leaf modules,
# as this would prevent hooks from being set on their children,
# and container modules may not invoke the forward method themselves.
module.ds_model_granularity = sys.maxsize
return num_layers, num_params

num_layers += 1
ds_model_granularity = (num_params // num_layers) if num_params <= Z3_MAX_LEAF_SIZE else sys.maxsize
module.ds_model_granularity = ds_model_granularity
# module.ds_model_num_layers = num_layers
# module.ds_model_num_params = num_params
if self.min_granularity_value > ds_model_granularity:
self.min_granularity_value = ds_model_granularity
self.min_granularity_layer = module.__class__.__name__
self.granularity_info.add(f"{module.__class__.__name__.ljust(30)}|{str(ds_model_granularity).rjust(20)}")

return num_layers, num_params

def _set_leaf_by_threshold_preorder(self, module, granularity_threshold):
'''Set modules as leaf modules based on the threshold, prioritizing parent nodes.'''

num_params = sum(p.ds_numel for p in module.parameters())
if num_params == 0:
# skip Modules without parameters, such as GELU, etc.
return
if module.ds_model_granularity <= granularity_threshold:
set_z3_leaf_module(module, True)
self.z3_leaf_layers.append(module)
return

for sub_module in module.children():
self._set_leaf_by_threshold_preorder(sub_module, granularity_threshold)
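The preorder strategy in `_set_leaf_by_threshold_preorder` can be sketched without any DeepSpeed or torch dependencies. The `Mod` class and helper below are illustrative stand-ins: the point is that the walk is top-down, so the *outermost* module at or below the threshold is marked as the leaf, and its subtree is never descended into.

```python
# Minimal sketch of preorder leaf selection under a granularity threshold.

class Mod:
    def __init__(self, name, granularity, children=()):
        self.name = name
        self.granularity = granularity  # precomputed, as in the recursive pass above
        self.children = list(children)
        self.is_leaf = False

def set_leaves_preorder(mod, threshold, leaves):
    if mod.granularity <= threshold:
        mod.is_leaf = True      # parent wins over its fine-grained children
        leaves.append(mod.name)
        return                  # do not recurse into a subtree already marked as leaf
    for child in mod.children:
        set_leaves_preorder(child, threshold, leaves)

fine = Mod("FineGrained", 50)
block = Mod("Block", 80, [fine])              # the parent is already below threshold
root = Mod("Root", 10_000, [block, Mod("Coarse", 5_000)])

leaves = []
set_leaves_preorder(root, 100, leaves)
print(leaves)  # only "Block" is marked; "FineGrained" is covered by its parent
```

Marking `Block` rather than `FineGrained` is what removes the per-child hook and allgather overhead: the whole block is fetched in one coalesced operation.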
8 changes: 6 additions & 2 deletions deepspeed/runtime/zero/stage3.py
@@ -157,6 +157,7 @@ def __init__(
zero_hpz_partition_size=1,
zero_quantized_weights=False,
zero_quantized_nontrainable_weights=False,
zero_module_granularity_threshold=0,
):
see_memory_usage("Stage 3 initialize beginning", force=True)

@@ -227,7 +228,8 @@ def __init__(
mpu=mpu,
zero_param_parallel_group=zero_param_parallel_group,
zero_quantized_weights=zero_quantized_weights,
zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights)
zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights,
zero_module_granularity_threshold=zero_module_granularity_threshold)

self.persistent_parameters = self.parameter_offload.persistent_parameters
self._configure_offloading(offload_optimizer_config, offload_param_config)
@@ -458,6 +460,7 @@ def initialize_ds_offload(
zero_param_parallel_group,
zero_quantized_weights,
zero_quantized_nontrainable_weights,
zero_module_granularity_threshold,
):
return DeepSpeedZeRoOffload(module=module,
timers=timers,
@@ -473,7 +476,8 @@ def initialize_ds_offload(
mpu=mpu,
zero_param_parallel_group=zero_param_parallel_group,
zero_quantized_weights=zero_quantized_weights,
zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights)
zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights,
zero_module_granularity_threshold=zero_module_granularity_threshold)

def _get_trainable_parameter_groups(self):
param_groups = []
2 changes: 1 addition & 1 deletion deepspeed/utils/__init__.py
@@ -16,7 +16,7 @@
from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state, safe_set_full_grad
from .tensor_fragment import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state
from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_grad, safe_set_local_optimizer_state
from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, z3_leaf_parameter
from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, z3_leaf_parameter, set_z3_leaf_module
from .mixed_precision_linkage import link_hp_params, lazy_init_hp_params_optimizer_state
from deepspeed.runtime.dataloader import RepeatingLoader
from .numa import get_numactl_cmd
27 changes: 17 additions & 10 deletions deepspeed/utils/z3_leaf_module.py
@@ -4,7 +4,7 @@
# DeepSpeed Team

import torch
from typing import List, Type
from typing import List, Type, Union


def z3_leaf_module(model: torch.nn.Module) -> bool:
@@ -40,18 +40,24 @@ def get_z3_leaf_modules(model: torch.nn.Module) -> List[torch.nn.Module]:
return [module for module in model.modules() if z3_leaf_module(module)]


def _do_set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type],
def set_z3_leaf_module(model: torch.nn.Module, flag: bool):
model._z3_leaf = flag


def _do_set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: Union[List[Type], List[str]],
flag: bool) -> List[torch.nn.Module]:
assert all(isinstance(module_class, type) for module_class in leaf_module_classes), \
f'leaf_module_classes must be a list of types, got {leaf_module_classes}'
assert all(isinstance(module_class, (type, str) ) for module_class in leaf_module_classes), \
f'leaf_module_classes must be a list of types or names, got {leaf_module_classes}'

leaf_modules = []

def _set_z3_leaf_flag(model: torch.nn.Module):
nonlocal leaf_modules
if model.__class__ in leaf_module_classes:
model._z3_leaf = flag
leaf_modules.append(model)
for module in leaf_module_classes:
if (isinstance(module, type) and model.__class__ == module) or \
(isinstance(module, str) and model.__class__.__name__ == module):
model._z3_leaf = flag
leaf_modules.append(model)

model.apply(_set_z3_leaf_flag)

@@ -61,13 +67,14 @@ def _set_z3_leaf_flag(model: torch.nn.Module):
return leaf_modules


def set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type]) -> List[torch.nn.Module]:
def set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: Union[List[Type],
List[str]]) -> List[torch.nn.Module]:
"""Sets a flag within a module in `model` to instruct ZeRO3 to stop setting hooks recursively when it encounters a module class listed in `leaf_module_classes`.
This is particularly useful in the context of Mixture of Experts (MoE) models. In MoE models, the computation order of experts varies across forward passes. This variability can disrupt ZeRO3's functionality, as ZeRO3 relies on tracking the computation order of modules to prefetch parameters efficiently. By designating a module as a 'leaf' node, ZeRO3 will prefetch parameters for all child modules upon entering the module.
Another scenario where this functionality is beneficial is in models with excessively fine-grained nested modules, where it helps to avoid the overhead associated with hooks.
Args:
model (torch.nn.Module): The model to which the leaf module flag will be applied.
leaf_module_classes (List[Type]): A list of module classes that should be flagged as 'leaf' modules.
leaf_module_classes (Union[List[Type], List[str]]): A list of module classes that should be flagged as 'leaf' modules.
Returns:
List[torch.nn.Module]: A list of modules that match the module classes in `leaf_module_classes`.
"""
@@ -79,7 +86,7 @@ def unset_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type
See `set_z3_leaf_modules` for more details.
Args:
model (torch.nn.Module): The model to which the leaf module flag will be applied.
leaf_module_classes (List[Type]): A list of module classes that should be flagged as 'leaf' modules.
leaf_module_classes (Union[List[Type], List[str]]): A list of module classes that should be flagged as 'leaf' modules.
Returns:
List[torch.nn.Module]: A list of modules that match the module classes in `leaf_module_classes`.
"""
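The widened `Union[List[Type], List[str]]` signature above means `leaf_module_classes` may now mix class objects and class-name strings. The helper below mirrors only the membership test from `_set_z3_leaf_flag`; it is a hedged sketch, not the DeepSpeed implementation, and `MoELayer` is a made-up stand-in for a user model class.

```python
# Sketch of the type-or-name matching rule used when flagging leaf modules.
from typing import List, Type, Union

def matches_leaf_class(instance_cls: type, leaf_module_classes: Union[List[Type], List[str]]) -> bool:
    for entry in leaf_module_classes:
        # exact class match, as in `model.__class__ == module`
        if isinstance(entry, type) and instance_cls == entry:
            return True
        # name match, as in `model.__class__.__name__ == module`
        if isinstance(entry, str) and instance_cls.__name__ == entry:
            return True
    return False

class MoELayer:  # hypothetical user-defined module class
    pass

print(matches_leaf_class(MoELayer, [MoELayer]))     # True: matched by type
print(matches_leaf_class(MoELayer, ["MoELayer"]))   # True: matched by name
print(matches_leaf_class(MoELayer, ["Attention"]))  # False: neither form matches
```

Accepting strings is convenient when the class is defined deep inside a third-party library and importing it just to pass to `set_z3_leaf_modules` would be awkward.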
5 changes: 5 additions & 0 deletions docs/_pages/config-json.md
@@ -489,6 +489,11 @@ Enabling and configuring ZeRO memory optimizations
|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| ------- |
| Consolidate the weights before saving the model by `save_16bit_model()`. Since the weights are partitioned across GPUs, they aren't part of `state_dict`, so this function automatically gathers the weights when this option is enabled and then saves the fp16 model weights. | `False` |

***stage3_module_granularity_threshold***: [integer]
| Description | Default |
|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| ------- |
| The granularity of a module is determined by the ratio of `parameter_count` / `(1 + descendant_count)`. ZeRO3 classifies modules with a granularity below the threshold as fine-grained, treating them as integral units during parameter fetching. This reduces host and communication overhead from separate hooks. | `0` |
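The option documented above is enabled from the `zero_optimization` block of the DeepSpeed config. A minimal sketch follows; the threshold value of 10000 is purely illustrative, not a recommended setting, and the surrounding keys are the usual stage-3 scaffolding.

```python
# Hedged example: a ZeRO stage-3 config enabling the granularity threshold.
ds_config = {
    "train_batch_size": 8,
    "zero_optimization": {
        "stage": 3,
        # Modules whose granularity falls at or below this value are
        # fetched as integral units (illustrative value, not a recommendation).
        "stage3_module_granularity_threshold": 10000,
    },
}
```

With the default of `0` the feature is disabled; any positive value turns on the granularity scan and leaf-module selection described in `parameter_offload.py`.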

***zero_hpz_partition_size***: [integer]

| Description | Default |