Commit

Merge branch 'main' into yuxiang/te
tocean authored Oct 17, 2023
2 parents e605841 + 5b31b70 commit f8f3de4
Showing 1 changed file with 53 additions and 4 deletions.
msamp/megatron/optimizer/distrib_optimizer.py
@@ -10,6 +10,7 @@
 from megatron.core import mpu, tensor_parallel
 from megatron.optimizer.optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
 from megatron.optimizer.distrib_optimizer import DistributedOptimizer, Range
+from megatron.utils import print_rank_0
 
 from msamp.common.dtype import Dtypes
 from msamp.common.tensor import ScalingTensor, ScalingMeta
@@ -350,12 +351,60 @@ def get_model_parallel_group(self):
         return None
 
     def state_dict(self):
-        """Return the optimizer's state dict."""
-        raise NotImplementedError
+        """The state dict must contain the fp32-from-float16 and fp16-from-fp8 shards."""
+        state_dict = {}
+        state_dict['optimizer'] = self.optimizer.state_dict()
+        if self.grad_scaler:
+            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
+        # Sharded master weights.
+        state_dict['shard_fp32_from_float16_groups'] = \
+            self.shard_fp32_from_float16_groups
+        state_dict['shard_hp_from_fp8_groups'] = \
+            self.shard_hp_from_fp8_groups
+        return state_dict
 
     def load_state_dict(self, state_dict):
-        """Load the optimizer's state dict."""
-        raise NotImplementedError
+        """Load the state dict."""
+        optimizer_key = 'optimizer'
+        if optimizer_key not in state_dict:
+            optimizer_key = 'optimizer_state_dict'
+            print_rank_0('***WARNING*** loading optimizer from '
+                         'an old checkpoint ...')
+        # Convert optimizer states.
+        ckpt_state_dict = state_dict[optimizer_key]
+        self.optimizer.load_state_dict(ckpt_state_dict)
+
+        # Grad scaler.
+        if 'grad_scaler' not in state_dict:
+            if self.fp16:
+                print_rank_0('***WARNING*** found an old checkpoint, will not '
+                             'load grad scaler ...')
+        else:
+            if self.grad_scaler:
+                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
+            else:
+                print_rank_0(
+                    '***WARNING*** found the grad scaler in the '
+                    'checkpoint but it is None in the class. '
+                    'Skipping loading grad scaler ...'
+                )
+
+        # Copy data for the main params.
+        for current_group, saved_group in zip(
+            self.shard_fp32_from_float16_groups, state_dict['shard_fp32_from_float16_groups']
+        ):
+            for current_param, saved_param in zip(current_group, saved_group):
+                current_param.data.copy_(saved_param.data)
+
+        for current_group, saved_group in zip(self.shard_hp_from_fp8_groups, state_dict['shard_hp_from_fp8_groups']):
+            for current_param, saved_param in zip(current_group, saved_group):
+                if current_param.data.qtype == saved_param.data.qtype:
+                    current_param.data.copy_(saved_param.data)
+                else:
+                    # The master weight's qtype differs from the checkpoint's, so cast before copying.
+                    current_param.data.copy_(
+                        saved_param.data.to(current_param.data.device).cast(current_param.data.qtype)
+                    )
 
     def zero_grad(self, set_to_none=True):
         """Zero grads.
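For context, a minimal sketch of how the new methods might be exercised when checkpointing. The helper names (save_optimizer_shard, load_optimizer_shard), opt, and ckpt_path are hypothetical; opt is assumed to be an instance of this distributed optimizer, and in practice Megatron-LM's checkpointing utilities would call state_dict()/load_state_dict() for you.

import torch

def save_optimizer_shard(opt, ckpt_path):
    # Each rank serializes its own shard of the optimizer state, including the
    # shard_fp32_from_float16_groups and shard_hp_from_fp8_groups master weights.
    torch.save(opt.state_dict(), ckpt_path)

def load_optimizer_shard(opt, ckpt_path):
    # load_state_dict copies the saved master-weight shards back in place; for the
    # fp8 groups it casts when the checkpoint's qtype differs from the optimizer's.
    opt.load_state_dict(torch.load(ckpt_path))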
