
Commit 4a04298: little fix

li126com committed Dec 26, 2024
1 parent: 8aaddaf
Showing 3 changed files with 14 additions and 10 deletions.
configs/7B_internlm2.py: 6 additions & 4 deletions
@@ -38,10 +38,12 @@
     async_upload=True,  # async ckpt upload. (only works for boto3 ckpt)
     async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/",  # path for temporary files during asynchronous upload.
     oss_snapshot_freq=int(CHECKPOINT_EVERY / 2),  # snapshot ckpt save frequency.
-    # control universal ckpt. INFO: Not compatible with the original ckpt
-    # Default to use async_save and not use broadcast_load
-    # as broadcast_load may cause loading performance degradation
-    universal_ckpt=dict(enable=False, aysnc_save=True, broadcast_load=False),
+    # INFO: Universal ckpt is not compatible with the original ckpt.
+    # Default is to use async_save and not to use broadcast_load,
+    # as broadcast_load may cause loading performance degradation.
+    # NOTE: If using aysnc_save, there is a risk of losing the latest ckpt
+    # when there is a sudden training interruption.
+    universal_ckpt=dict(enable=True, aysnc_save=True, broadcast_load=False),
 )
 
 TRAIN_FOLDER = None
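For context: this change flips universal_ckpt on by default while keeping asynchronous save. The sketch below shows how the block sits inside an InternEvo-style training config; the surrounding keys (CHECKPOINT_EVERY, enable_save_ckpt, checkpoint_every) are assumed from the typical 7B config and are illustrative only. Note that the key really is spelled "aysnc_save" in this codebase, so a user config must use that spelling.

    # Minimal sketch, assuming the usual InternEvo ckpt config layout.
    CHECKPOINT_EVERY = 50  # placeholder value

    ckpt = dict(
        enable_save_ckpt=True,              # assumed key: turn checkpoint saving on
        checkpoint_every=CHECKPOINT_EVERY,  # assumed key: save frequency in steps
        # Enabled by default after this commit; async save trades durability
        # for speed, so a sudden interruption can lose the newest snapshot.
        universal_ckpt=dict(enable=True, aysnc_save=True, broadcast_load=False),
    )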
internlm/checkpoint/checkpoint_manager.py: 5 additions & 3 deletions
@@ -93,6 +93,8 @@ def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None,
         universal_load(
             load_ckpt_folder, checkpoint_state, broadcast_checkpoint=gpc.config.ckpt.universal_ckpt.broadcast_load
         )
+        if gpc.is_rank_for_log():
+            logger.warning("Finished loading universal model checkpoint and optimizer checkpoint.")
 
     if not universal_ckpt and load_content.need_load(CheckpointLoadContent.MODEL):
         load_model_checkpoint(folder=load_ckpt_folder, model=ckpt_mm.model)
@@ -107,9 +109,9 @@ def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None,
     if not universal_ckpt and load_content.need_load(CheckpointLoadContent.OPIMIZER):
         load_optimizer_checkpoint(load_ckpt_folder, ckpt_mm.optimizer)
         load_content_str += f"{CheckpointLoadContent.OPIMIZER}, "
-    else:
-        if gpc.is_rank_for_log():
-            logger.warning("CheckpointManager has no 'optimizer', skip reload optim checkpoint!")
+
+    if not load_content.need_load(CheckpointLoadContent.OPIMIZER) and gpc.is_rank_for_log():
+        logger.warning("CheckpointManager has no 'optimizer', skip reload optim checkpoint!")
 
     # load lr scheduler states.
     if load_content.need_load(CheckpointLoadContent.SCHEDULAER):
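The restructuring above matters on the universal-ckpt path: previously, a run that restored the optimizer through universal_load still fell into the else branch and logged the misleading "skip reload optim checkpoint" warning. Below is a standalone sketch of the two control flows, with the real gpc/load_content checks stubbed out as booleans (the names are stand-ins, not the actual signatures):

    def old_flow(universal_ckpt: bool, need_optimizer: bool) -> str:
        if not universal_ckpt and need_optimizer:
            return "load optimizer"
        # fired even when universal_load had already restored the optimizer
        return "warn: skip optim"

    def new_flow(universal_ckpt: bool, need_optimizer: bool) -> str:
        actions = []
        if not universal_ckpt and need_optimizer:
            actions.append("load optimizer")
        if not need_optimizer:  # warn only when an optimizer load was never requested
            actions.append("warn: skip optim")
        return ", ".join(actions) or "restored via universal ckpt, no warning"

    assert old_flow(True, True) == "warn: skip optim"  # the misleading case
    assert new_flow(True, True) == "restored via universal ckpt, no warning"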
internlm/train/pipeline.py: 3 additions & 3 deletions
@@ -175,7 +175,7 @@ def set_param_unique_tracking_name(model):
                 assert global_fqn not in map_layer_attr, f"{map_layer_attr} exists"
                 map_layer_attr[global_fqn] = {
                     "offset": getattr(child, "offset", [0] * len(child.weight.size())),
-                    "complete_size": getattr(child, "complete_size", child.weight.size()),
+                    "complete_size": getattr(child, "complete_size", list(child.weight.size())),
                 }
 
             elif isinstance(child, (RMSNorm)) and uc_enable:
@@ -188,7 +188,7 @@ def set_param_unique_tracking_name(model):
                 )
                 map_layer_attr[global_fqn] = {
                     "offset": getattr(child, "offset", [0] * len(child.weight.size())),
-                    "complete_size": getattr(child, "complete_size", child.weight.size()),
+                    "complete_size": getattr(child, "complete_size", list(child.weight.size())),
                 }
 
             else:
@@ -226,7 +226,7 @@ def set_param_unique_tracking_name(model):

             map_layer_attr[local_fqn] = {
                 "offset": getattr(children, "offset", [0] * len(children.weight.size())),
-                "complete_size": getattr(children, "complete_size", children.weight.size()),
+                "complete_size": getattr(children, "complete_size", list(children.weight.size())),
             }


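All three edits above are the same one-line fix: weight.size() returns a torch.Size, which is a tuple subclass, while the "offset" default right next to it is a plain list. Wrapping the default in list(...) presumably keeps "complete_size" type-consistent with list-valued metadata when it is compared or serialized on the universal-ckpt path (that reading of intent is an assumption). The snippet below only demonstrates the type mismatch itself:

    import torch

    w = torch.empty(4, 8)
    print(type(w.size()))            # <class 'torch.Size'>, a tuple subclass
    print(w.size() == [4, 8])        # False: a tuple never compares equal to a list
    print(list(w.size()) == [4, 8])  # True once converted, matching "offset"'s type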
