Commit

[Feature] Add exclude_frozen_parameters for DeepSpeedStrategy (#1415)
LZHgrla authored Nov 2, 2023
1 parent 2a563f4 commit 27ab6a6
Showing 1 changed file with 31 additions and 5 deletions.
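
In short, the commit threads a new exclude_frozen_parameters option through DeepSpeedStrategy: when it is set (and DeepSpeed >= 0.10.1 is installed), frozen parameters (those with requires_grad=False) are omitted from saved checkpoints and tolerated as missing keys on load. A minimal usage sketch; the strategy class and the new option come from this commit, every other value is illustrative:

    from mmengine._strategy import DeepSpeedStrategy

    strategy = DeepSpeedStrategy(
        fp16=dict(enabled=True),  # illustrative DeepSpeed setting
        # New in this commit; requires DeepSpeed >= 0.10.1.
        exclude_frozen_parameters=True,
    )
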
36 changes: 31 additions & 5 deletions mmengine/_strategy/deepspeed.py
@@ -18,7 +18,7 @@
 from mmengine.optim import BaseOptimWrapper, _ParamScheduler
 from mmengine.registry import (MODEL_WRAPPERS, OPTIM_WRAPPERS, OPTIMIZERS,
                                STRATEGIES)
-from mmengine.utils import get_git_hash
+from mmengine.utils import digit_version, get_git_hash
 from .base import BaseStrategy


@@ -245,6 +245,8 @@ class DeepSpeedStrategy(BaseStrategy):
         gradient_accumulation_steps (int, optional): Number of training steps
             to accumulate gradients before averaging and applying them.
             Defaults to None.
+        exclude_frozen_parameters (bool, optional): Exclude frozen parameters
+            from the saved checkpoint. Defaults to None.
     """

     def __init__(
@@ -265,6 +267,7 @@ def __init__(
         # disable the log printed by deepspeed
         steps_per_print: int = 10000000000000,
         # the following args are for BaseStrategy
+        exclude_frozen_parameters: Optional[bool] = None,
         **kwargs,
     ):
         assert deepspeed is not None, \
@@ -298,6 +301,11 @@ def __init__(
         self.config.setdefault('gradient_accumulation_steps', 1)
         self.config['steps_per_print'] = steps_per_print
         self._inputs_to_half = inputs_to_half
+        assert (exclude_frozen_parameters is None or
+                digit_version(deepspeed.__version__) >= digit_version('0.10.1')
+                ), ('DeepSpeed >= 0.10.1 is required to enable '
+                    'exclude_frozen_parameters')
+        self.exclude_frozen_parameters = exclude_frozen_parameters

         register_deepspeed_optimizers()
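
The guard above leans on mmengine.utils.digit_version, which parses a version string into a numerically comparable tuple. A quick sanity check of why that matters for '0.10.1' (a sketch, assuming only that digit_version compares versions numerically, as it does elsewhere in mmengine):

    from mmengine.utils import digit_version

    # Numeric, not lexicographic: 10 > 9 as a release number even though
    # the plain strings compare the other way around.
    assert digit_version('0.10.1') > digit_version('0.9.0')
    assert '0.10.1' < '0.9.0'  # the string-comparison pitfall being avoided
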

@@ -413,8 +421,15 @@ def load_checkpoint(
         self.logger.info(f'Load checkpoint from {filename}')

         dirname, basename = osp.split(filename)
-        _, extra_ckpt = self.model.load_checkpoint(
-            dirname, tag=basename, load_optimizer_states=False)
+        if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
+            _, extra_ckpt = self.model.load_checkpoint(
+                dirname,
+                tag=basename,
+                load_optimizer_states=False,
+                load_module_strict=not self.exclude_frozen_parameters)
+        else:
+            _, extra_ckpt = self.model.load_checkpoint(
+                dirname, tag=basename, load_optimizer_states=False)

         return extra_ckpt
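
The reason for relaxing load_module_strict: a checkpoint written with exclude_frozen_parameters=True simply lacks the frozen weights, so a strict module load would fail on the missing keys. A plain-PyTorch analogy of the same strict-versus-lenient behavior (the concept only, not DeepSpeed's internals):

    from torch import nn

    model = nn.Linear(4, 2)
    state = model.state_dict()
    del state['bias']  # simulate a checkpoint saved without a frozen param

    # strict=True would raise RuntimeError on the missing key; strict=False
    # tolerates it, which is what load_module_strict=not
    # self.exclude_frozen_parameters buys in the code above.
    missing, unexpected = model.load_state_dict(state, strict=False)
    assert missing == ['bias'] and unexpected == []
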

@@ -510,5 +525,16 @@ def save_checkpoint(
             extra_ckpt['param_schedulers'] = self.scheduler_state_dict()

         dirname, basename = osp.split(filename)
-        self.model.save_checkpoint(
-            dirname, tag=basename, client_state=extra_ckpt, save_latest=False)
+        if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
+            self.model.save_checkpoint(
+                dirname,
+                tag=basename,
+                client_state=extra_ckpt,
+                save_latest=False,
+                exclude_frozen_parameters=self.exclude_frozen_parameters)
+        else:
+            self.model.save_checkpoint(
+                dirname,
+                tag=basename,
+                client_state=extra_ckpt,
+                save_latest=False)
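
On the save side, the flag is passed straight through to DeepSpeed's engine-level save_checkpoint, which can shrink checkpoints considerably when most of the model is frozen (e.g. parameter-efficient fine-tuning). A hypothetical config fragment showing how one might enable it through mmengine's FlexibleRunner; only the exclude_frozen_parameters key comes from this commit:

    strategy = dict(
        type='DeepSpeedStrategy',
        fp16=dict(enabled=True),  # illustrative
        # Asserts DeepSpeed >= 0.10.1 at strategy init.
        exclude_frozen_parameters=True,
    )
    # e.g. FlexibleRunner(model=..., strategy=strategy, ...)
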
