Commit
Merge branch 'microsoft:master' into features/fp6_compile_err
Showing 18 changed files with 386 additions and 107 deletions.
@@ -0,0 +1,63 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import torch

from deepspeed.utils import logger
from deepspeed.utils.tensor_fragment import map_to_flat_opt_states
from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank


class DeepSpeedOptimizer(object):
    pass


class ZeROOptimizer(DeepSpeedOptimizer):

    def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str) -> None:
        """Restore optimizer state for each param group from the zero/ subdirectory of checkpoint_dir."""
        checkpoint_dir = os.path.join(checkpoint_dir, "zero")
        optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
        assert os.path.isfile(
            optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
        optim_sd = torch.load(optim_state_path)

        self._load_global_state(optim_sd)

        tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
        if self.mpu is None:
            logger.warn("MPU is not provided, setting tp size to 1 in checkpoint loading.")
            tp_world_size = 1
        else:
            tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
                else self.mpu.get_tensor_model_parallel_world_size()

        for i, (param_group,
                loaded_param_group) in enumerate(zip(self.optimizer.param_groups, optim_sd['param_groups'])):
            # We have an assumption that all params in the same param_group have the same keys
            opt_keys = set()
            steps = []

            # Load per-parameter hp state for every lp that has an hp mapping.
            lp_groups = getattr(self, lp_groups_name)
            for lp in lp_groups[i]:
                if lp._hp_mapping is not None:
                    #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
                    step = lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
                                                       tp_world_size)
                    for key in lp._hp_mapping.get_optim_state_keys():
                        opt_keys.add(key)
                    steps.append(step)

            hp_param = param_group['params'][0]
            assert all(step == steps[0] for step in steps), f"Steps {steps} are not equal"
            if steps[0] is not None:
                self.optimizer.state[hp_param]['step'] = steps[0]

            map_to_flat_opt_states(hp_param, lp_groups[i], self.optimizer.state, opt_keys)

            # Restore checkpointed hyperparameters, leaving the live tensors untouched.
            for key, value in loaded_param_group.items():
                if key == 'params':
                    continue
                param_group[key] = value
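For context, here is a minimal sketch (not part of this commit or of the DeepSpeed API) of the on-disk layout the loader above assumes, based only on the paths it constructs: a zero/ subdirectory holding optimizer_state.pt plus one folder per parameter name from which lp.load_hp_checkpoint_state reads. The helper name describe_zero_checkpoint and the example path are hypothetical.

import os

def describe_zero_checkpoint(checkpoint_dir: str) -> None:
    # Hypothetical helper: list the pieces the loader above expects to find.
    zero_dir = os.path.join(checkpoint_dir, "zero")
    optim_state_path = os.path.join(zero_dir, "optimizer_state.pt")
    # The loader asserts this file exists before doing anything else.
    found = "found" if os.path.isfile(optim_state_path) else "missing"
    print(f"global optimizer state: {optim_state_path} ({found})")
    # Per-parameter hp state is read from zero/<param_name>/ subfolders.
    if os.path.isdir(zero_dir):
        for entry in sorted(os.listdir(zero_dir)):
            if os.path.isdir(os.path.join(zero_dir, entry)):
                print(f"per-parameter state folder: {entry}")

# Example (hypothetical path):
# describe_zero_checkpoint("checkpoints/global_step1000")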
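The tensor-parallel world-size fallback in the method is worth isolating. Below is a small standalone sketch of just that branch, assuming a duck-typed mpu object; ToyMPU and resolve_tp_world_size are hypothetical names, not DeepSpeed code.

class ToyMPU:
    # Hypothetical stand-in exposing only the fallback accessor.
    def get_tensor_model_parallel_world_size(self):
        return 2

def resolve_tp_world_size(mpu):
    # Mirrors the branch in load_hp_checkpoint_state_from_checkpoint_dir:
    # no mpu -> 1; prefer get_slice_parallel_world_size() when present,
    # otherwise fall back to get_tensor_model_parallel_world_size().
    if mpu is None:
        return 1
    return mpu.get_slice_parallel_world_size() if hasattr(mpu, "get_slice_parallel_world_size") \
        else mpu.get_tensor_model_parallel_world_size()

print(resolve_tp_world_size(None))      # 1
print(resolve_tp_world_size(ToyMPU()))  # 2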
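Finally, the closing loop restores checkpointed hyperparameters into the live param group while leaving the live 'params' entry alone. A toy illustration with made-up values:

# Toy dictionaries only; real groups hold tensors under 'params'.
live_group = {"params": ["<flat hp tensor>"], "lr": 1e-3, "weight_decay": 0.0}
loaded_group = {"params": ["<stale refs>"], "lr": 5e-4, "weight_decay": 0.01}

for key, value in loaded_group.items():
    if key == "params":
        continue  # never overwrite the live parameter tensors
    live_group[key] = value

print(live_group)  # lr and weight_decay now come from the checkpoint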