From 651e24eba5fd21eca00450b95f29b2b819007cf0 Mon Sep 17 00:00:00 2001
From: Brandon Yang
Date: Mon, 12 Aug 2024 00:32:25 -0700
Subject: [PATCH 1/3] Fix LayerNorm all reduce gradient hook

---
 megatron/model/utils.py | 32 ++++++--------------------------
 megatron/training.py    |  9 +++------
 2 files changed, 9 insertions(+), 32 deletions(-)

diff --git a/megatron/model/utils.py b/megatron/model/utils.py
index e2539245b..65e4202ec 100644
--- a/megatron/model/utils.py
+++ b/megatron/model/utils.py
@@ -355,28 +355,23 @@ def get_fusion_type(neox_args):
 
 
 def reduce_weight_grads_from_model_parallel_region(input_):
-    """A hook that can be applied to any weight tensor via .register_hook(). 
+    """A hook that can be applied to any weight tensor via .register_hook().
     Allreduces grads for e.g. LN weights, across the model parallel group.
     Needed to keep LNs in sync, despite them getting diff data -> diff gradients when using sequence parallel.
     """
-    print("TRIPPED")
-
     # Bypass the function if no TP -> no comm needed.
     if mpu.get_model_parallel_world_size() == 1:
         return input_
-    print("TRIPPED HOOK") # these never get printed. We *should* see them go off if we actually ever trigger these hooks....
-    # a print in mark_norms_... confirms that that function runs and initially adds these hooks, oddly.
-
     # Bf16 convert
     dt = input_.dtype
     if dt == torch.bfloat16 and mpu.get_fp32_allreduce():
         input_ = input_.float()
 
     # All-reduce.
-    torch.distributed.all_reduce(input_, group=mpu.get_model_parallel_group()) 
+    torch.distributed.all_reduce(input_, group=mpu.get_model_parallel_group())
     # average grads
-    input_ = (input_ / mpu.get_model_parallel_world_size())
+    input_ = input_ / mpu.get_model_parallel_world_size()
 
     # Bf16 convert
     if dt == torch.bfloat16 and mpu.get_fp32_allreduce():
@@ -386,8 +381,8 @@ def reduce_weight_grads_from_model_parallel_region(input_):
 
 
 def mark_norms_for_sequence_parallel_grad_sync(module, neox_args):
-    """Iterate through the modules in our model, and for any "...Norm" classnames, 
-    register a hook on each parameter which will allreduce norms' weights' grads across 
+    """Iterate through the modules in our model, and for any "...Norm" classnames,
+    register a hook on each parameter which will allreduce norms' weights' grads across
     the model (sequence) parallel region.
     """
 
@@ -399,20 +394,5 @@ def mark_norms_for_sequence_parallel_grad_sync(module, neox_args):
         if "norm" in type(module_).__name__.lower():
             # this is a norm, we want to allreduce its grads across sequence parallel region
             for name, param in module_.named_parameters():
-                if param.requires_grad:
-
-                    # copying the helper fn that DeepSpeed uses to add hooks, to see if that fixes our issues.
-                    # it did not seem to.
-                    def wrapper(param):
-                        param_tmp = param.expand_as(param)
-                        grad_acc = param_tmp.grad_fn.next_functions[0][0]
-
-                        grad_acc.register_hook(lambda grad: reduce_weight_grads_from_model_parallel_region(grad))
-
-                    wrapper(param)
-
-                    # hook which will allreduce weight grads across model parallel (seq. parallel) region
-
-                    param.register_hook(reduce_weight_grads_from_model_parallel_region)
-
-    return
+                param.register_hook(reduce_weight_grads_from_model_parallel_region)
diff --git a/megatron/training.py b/megatron/training.py
index 8cf6e4897..d522435c9 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -89,8 +89,6 @@ def save_base_shapes(neox_args, base_shapes, use_cache):
             use_cache=use_cache,
         )
 
-    mark_norms_for_sequence_parallel_grad_sync(model, neox_args) # ensure LN param states remain synced across sequence parallel region
-
     if not neox_args.is_pipe_parallel:
         base_model = base_model.to_sequential()
 
@@ -114,8 +112,6 @@ def save_base_shapes(neox_args, base_shapes, use_cache):
             topology=mpu.get_topology(),
             use_cache=use_cache,
         )
-
-    mark_norms_for_sequence_parallel_grad_sync(model, neox_args) # ensure LN param states remain synced across sequence parallel region
 
     if not neox_args.is_pipe_parallel:
         delta_model = delta_model.to_sequential()
@@ -692,8 +688,6 @@ def get_optimizer(model, neox_args):
     else:
         raise ValueError(f"Optimizer type {neox_args.optimizer_type} not recognized")
 
-    mark_norms_for_sequence_parallel_grad_sync(model, neox_args) # ensure LN param states remain synced across sequence parallel region
-
     if neox_args.deepspeed:
         # fp16 wrapper is not required for DeepSpeed.
         return optimizer, param_groups
@@ -773,6 +767,7 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):
             # config_params=neox_args.deepspeed_config,
             mpu=mpu if not neox_args.is_pipe_parallel else None,
         )
+        mark_norms_for_sequence_parallel_grad_sync(model, neox_args)
         if neox_args.moe_num_experts > 1 and neox_args.moe_type == "megablocks":
             # We need to additionally set this flag to ensure DS parallelism properly handles this foreign MoE.
             model.has_moe_layers = True
@@ -843,6 +838,8 @@ def backward_step(neox_args, timers, optimizer, model, loss):
 
 def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler):
     """Single training step."""
+    # FIXME: Layer norm weights are stuck at 1.0 when sequence_parallel=True
+    modules_dict = dict(model.named_modules())
 
     # Pipeline parallelism schedules forward/backward/step
     if neox_args.is_pipe_parallel:
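Note: the patch above swaps the wrapper-based grad-accumulator hook, whose removed debug comments record that it never fired, for a plain Tensor.register_hook on each norm parameter, and it consolidates mark_norms_for_sequence_parallel_grad_sync into a single call made after deepspeed.initialize() in setup_model_and_optimizer, so the hooks are registered only once the DeepSpeed engine has been set up. The snippet below is a minimal sketch of the mechanism only, not the repository's code: the NeoX-specific mpu.get_model_parallel_group() plumbing and the bf16 -> fp32 round-trip are replaced by the default torch.distributed process group, and every name in it is made up for illustration.

    # Sketch: keep replicated norm weights in sync when each model-parallel
    # rank sees a different slice of the sequence. Assumes torch.distributed
    # is already initialized; group handling is simplified to the default group.
    import torch
    import torch.distributed as dist

    def allreduce_norm_grad(grad: torch.Tensor) -> torch.Tensor:
        """Gradient hook: sum a norm weight's gradient across all ranks."""
        if not dist.is_initialized() or dist.get_world_size() == 1:
            return grad  # single process -> nothing to synchronize
        grad = grad.clone()  # avoid mutating the tensor autograd handed us
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        return grad

    def mark_norms_for_grad_sync(model: torch.nn.Module) -> None:
        """Register the hook on every parameter of every '*Norm' module."""
        for submodule in model.modules():
            if "norm" in type(submodule).__name__.lower():
                for param in submodule.parameters():
                    if param.requires_grad:
                        param.register_hook(allreduce_norm_grad)
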
From 9a4331868a0441ee8684e5a44c064f82d6c28ffe Mon Sep 17 00:00:00 2001
From: Brandon Yang
Date: Wed, 14 Aug 2024 02:23:25 -0700
Subject: [PATCH 2/3] Sum instead of average for LayerNorm gradient all reduce

---
 megatron/model/utils.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/megatron/model/utils.py b/megatron/model/utils.py
index 65e4202ec..162b1e218 100644
--- a/megatron/model/utils.py
+++ b/megatron/model/utils.py
@@ -370,8 +370,6 @@ def reduce_weight_grads_from_model_parallel_region(input_):
 
     # All-reduce.
     torch.distributed.all_reduce(input_, group=mpu.get_model_parallel_group())
-    # average grads
-    input_ = input_ / mpu.get_model_parallel_world_size()
 
     # Bf16 convert
     if dt == torch.bfloat16 and mpu.get_fp32_allreduce():

From c0561d6031394e87f18044f8e0c6344fc3c3c2bc Mon Sep 17 00:00:00 2001
From: github-actions
Date: Wed, 14 Aug 2024 09:24:44 +0000
Subject: [PATCH 3/3] Update NeoXArgs docs automatically

---
 configs/neox_arguments.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md
index 4db753b16..4c60c1c2f 100644
--- a/configs/neox_arguments.md
+++ b/configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments
 
 - **git_hash**: str
 
-  Default = c93c1b4
+  Default = 9a43318
 
   current git hash of repository
 
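Note: the second patch drops the division by the model-parallel world size, so the synchronized norm gradients are now summed rather than averaged. The toy check below (single process, no torch.distributed, illustrative names only) shows the additive structure behind that choice: LayerNorm acts per token, so when each rank's loss covers a disjoint slice of the sequence and the slice losses add up to the full-sequence loss, the gradient of a replicated LayerNorm weight is exactly the sum of the per-slice gradients; dividing that sum by the world size would shrink it.

    # Toy check: the LayerNorm weight gradient from the full sequence equals
    # the SUM of the gradients from its two halves (LayerNorm acts per token,
    # and the toy loss is a plain sum over tokens).
    import torch

    torch.manual_seed(0)
    full = torch.nn.LayerNorm(8)
    shard_a = torch.nn.LayerNorm(8)
    shard_b = torch.nn.LayerNorm(8)
    for shard in (shard_a, shard_b):               # replicate the weights,
        shard.load_state_dict(full.state_dict())   # as on model-parallel ranks

    x = torch.randn(16, 8)            # a 16-token "sequence"
    full(x).sum().backward()          # gradient from the whole sequence
    shard_a(x[:8]).sum().backward()   # each "rank" sees half of the tokens
    shard_b(x[8:]).sum().backward()

    print(torch.allclose(full.weight.grad,
                         shard_a.weight.grad + shard_b.weight.grad))  # True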