From 817039b331b6245fe2c32b28b04dafa7f78a53ef Mon Sep 17 00:00:00 2001
From: Brandon Yang
Date: Mon, 12 Aug 2024 00:32:25 -0700
Subject: [PATCH] Fix LayerNorm all reduce gradient hook

---
 megatron/model/utils.py | 27 ++++-----------------------
 megatron/training.py    | 15 +++++++++------
 2 files changed, 13 insertions(+), 29 deletions(-)

diff --git a/megatron/model/utils.py b/megatron/model/utils.py
index e2539245b..fc41d3107 100644
--- a/megatron/model/utils.py
+++ b/megatron/model/utils.py
@@ -359,22 +359,17 @@ def reduce_weight_grads_from_model_parallel_region(input_):
     Allreduces grads for e.g. LN weights, across the model parallel group.
     Needed to keep LNs in sync, despite them getting diff data -> diff gradients when using sequence parallel.
     """
-    print("TRIPPED")
-
     # Bypass the function if no TP -> no comm needed.
     if mpu.get_model_parallel_world_size() == 1:
         return input_
-    print("TRIPPED HOOK") # these never get printed. We *should* see them go off if we actually ever trigger these hooks....
-    # a print in mark_norms_... confirms that that function runs and initially adds these hooks, oddly.
-
     # Bf16 convert
     dt = input_.dtype
     if dt == torch.bfloat16 and mpu.get_fp32_allreduce():
         input_ = input_.float()
-
+
     # All-reduce.
-    torch.distributed.all_reduce(input_, group=mpu.get_model_parallel_group())
+    torch.distributed.all_reduce(input_, group=mpu.get_model_parallel_group())
 
     # average grads
     input_ = (input_ / mpu.get_model_parallel_world_size())
 
@@ -397,22 +392,8 @@ def mark_norms_for_sequence_parallel_grad_sync(module, neox_args):
 
     for module_ in module.modules():
         if "norm" in type(module_).__name__.lower():
-            # this is a norm, we want to allreduce its grads across sequence parallel region
+            # this is a norm, we want to allreduce its grads across sequence parallel region
             for name, param in module_.named_parameters():
-                if param.requires_grad:
-
-                    # copying the helper fn that DeepSpeed uses to add hooks, to see if that fixes our issues.
-                    # it did not seem to.
-                    def wrapper(param):
-                        param_tmp = param.expand_as(param)
-                        grad_acc = param_tmp.grad_fn.next_functions[0][0]
-
-                        grad_acc.register_hook(lambda grad: reduce_weight_grads_from_model_parallel_region(grad))
-
-                    wrapper(param)
-
-                    # hook which will allreduce weight grads across model parallel (seq. parallel) region
-                    # param.register_hook(reduce_weight_grads_from_model_parallel_region)
+                param.register_hook(reduce_weight_grads_from_model_parallel_region)
-
     return
 
diff --git a/megatron/training.py b/megatron/training.py
index 8cf6e4897..578ef3d79 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -89,8 +89,6 @@ def save_base_shapes(neox_args, base_shapes, use_cache):
         use_cache=use_cache,
     )
 
-    mark_norms_for_sequence_parallel_grad_sync(model, neox_args)  # ensure LN param states remain synced across sequence parallel region
-
     if not neox_args.is_pipe_parallel:
         base_model = base_model.to_sequential()
 
@@ -114,8 +112,6 @@ def save_base_shapes(neox_args, base_shapes, use_cache):
         topology=mpu.get_topology(),
         use_cache=use_cache,
     )
-
-    mark_norms_for_sequence_parallel_grad_sync(model, neox_args)  # ensure LN param states remain synced across sequence parallel region
 
     if not neox_args.is_pipe_parallel:
         delta_model = delta_model.to_sequential()
@@ -692,8 +688,6 @@ def get_optimizer(model, neox_args):
     else:
         raise ValueError(f"Optimizer type {neox_args.optimizer_type} not recognized")
 
-    mark_norms_for_sequence_parallel_grad_sync(model, neox_args)  # ensure LN param states remain synced across sequence parallel region
-
     if neox_args.deepspeed:
         # fp16 wrapper is not required for DeepSpeed.
         return optimizer, param_groups
@@ -773,6 +767,7 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):
             # config_params=neox_args.deepspeed_config,
             mpu=mpu if not neox_args.is_pipe_parallel else None,
         )
+        mark_norms_for_sequence_parallel_grad_sync(model, neox_args)
         if neox_args.moe_num_experts > 1 and neox_args.moe_type == "megablocks":
             # We need to additionally set this flag to ensure DS parallelism properly handles this foreign MoE.
             model.has_moe_layers = True
@@ -843,6 +838,14 @@ def backward_step(neox_args, timers, optimizer, model, loss):
 
 
 def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler):
     """Single training step."""
+    # FIXME: Layer norm weights are stuck at 1.0 when sequence_parallel=True
+    modules_dict = dict(model.named_modules())
+    print(f"========= START RANK: {torch.distributed.get_rank()} =========")
+    print(f"RANK: {torch.distributed.get_rank()}: module.sequential.8.input_layernorm.weight = {modules_dict['module.sequential.8.input_layernorm'].weight.detach().mean()}")
+    print(f"RANK: {torch.distributed.get_rank()}: module.sequential.8.input_layernorm.bias = {modules_dict['module.sequential.8.input_layernorm'].bias.detach().mean()}")
+    print(f"RANK: {torch.distributed.get_rank()}: module.sequential.8.post_attention_layernorm.weight = {modules_dict['module.sequential.8.post_attention_layernorm'].weight.detach().mean()}")
+    print(f"RANK: {torch.distributed.get_rank()}: module.sequential.8.post_attention_layernorm.bias = {modules_dict['module.sequential.8.post_attention_layernorm'].bias.detach().mean()}")
+    print(f"========= END RANK: {torch.distributed.get_rank()} =========")
     # Pipeline parallelism schedules forward/backward/step
     if neox_args.is_pipe_parallel:
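For reference, the mechanism the patch falls back to is plain PyTorch: `Tensor.register_hook` fires during backward with that parameter's gradient, and the hook all-reduces and averages the gradient across the model-parallel group so norm weights stay in sync under sequence parallelism. The following is a minimal standalone sketch of that idea, not the repository's code: the function names are hypothetical, and the `group` argument stands in for NeoX's `mpu.get_model_parallel_group()`.

```python
# Hypothetical sketch of hook-based norm-gradient sync; assumes torch.distributed
# has already been initialized and `group` is the model-parallel process group.
import torch
import torch.distributed as dist


def make_grad_allreduce_hook(group=None):
    """Build a hook that all-reduces and averages a parameter's gradient."""

    def hook(grad: torch.Tensor) -> torch.Tensor:
        # No tensor/model parallelism -> nothing to synchronize.
        if not dist.is_initialized() or dist.get_world_size(group=group) == 1:
            return grad
        grad = grad.clone()  # avoid mutating the tensor autograd handed us
        dist.all_reduce(grad, group=group)
        return grad / dist.get_world_size(group=group)

    return hook


def mark_norms_for_grad_sync(model: torch.nn.Module, group=None) -> None:
    """Register the hook on every parameter of every *Norm-named module."""
    for module in model.modules():
        if "norm" in type(module).__name__.lower():
            for param in module.parameters():
                if param.requires_grad:
                    # Runs during backward with this parameter's gradient.
                    param.register_hook(make_grad_allreduce_hook(group))
```

Note that the patch registers the hooks in one place, immediately after the `deepspeed.initialize(...)` call in `setup_model_and_optimizer`, instead of inside `save_base_shapes` and `get_optimizer`.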