Skip to content

Commit

Permalink
Merge pull request #1259 from EleutherAI/fix-ln-hooks
Browse files Browse the repository at this point in the history
Fix LayerNorm all reduce gradient hook
  • Loading branch information
haileyschoelkopf authored Aug 14, 2024
2 parents 9c1e7b9 + c0561d6 commit 9945910
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 34 deletions.
2 changes: 1 addition & 1 deletion configs/neox_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = c93c1b4
Default = 9a43318

current git hash of repository

Expand Down
32 changes: 5 additions & 27 deletions megatron/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,28 +355,21 @@ def get_fusion_type(neox_args):


def reduce_weight_grads_from_model_parallel_region(input_):
"""A hook that can be applied to any weight tensor via .register_hook().
"""A hook that can be applied to any weight tensor via .register_hook().
Allreduces grads for e.g. LN weights, across the model parallel group.
Needed to keep LNs in sync, despite them getting diff data -> diff gradients when using sequence parallel.
"""
print("TRIPPED")

# Bypass the function if no TP -> no comm needed.
if mpu.get_model_parallel_world_size() == 1:
return input_

print("TRIPPED HOOK") # these never get printed. We *should* see them go off if we actually ever trigger these hooks....
# a print in mark_norms_... confirms that that function runs and initially adds these hooks, oddly.

# Bf16 convert
dt = input_.dtype
if dt == torch.bfloat16 and mpu.get_fp32_allreduce():
input_ = input_.float()

# All-reduce.
torch.distributed.all_reduce(input_, group=mpu.get_model_parallel_group())
# average grads
input_ = (input_ / mpu.get_model_parallel_world_size())
torch.distributed.all_reduce(input_, group=mpu.get_model_parallel_group())

# Bf16 convert
if dt == torch.bfloat16 and mpu.get_fp32_allreduce():
Expand All @@ -386,8 +379,8 @@ def reduce_weight_grads_from_model_parallel_region(input_):


def mark_norms_for_sequence_parallel_grad_sync(module, neox_args):
"""Iterate through the modules in our model, and for any "...Norm" classnames,
register a hook on each parameter which will allreduce norms' weights' grads across
"""Iterate through the modules in our model, and for any "...Norm" classnames,
register a hook on each parameter which will allreduce norms' weights' grads across
the model (sequence) parallel region.
"""

Expand All @@ -399,20 +392,5 @@ def mark_norms_for_sequence_parallel_grad_sync(module, neox_args):
if "norm" in type(module_).__name__.lower():
# this is a norm, we want to allreduce its grads across sequence parallel region
for name, param in module_.named_parameters():

if param.requires_grad:

# copying the helper fn that DeepSpeed uses to add hooks, to see if that fixes our issues.
# it did not seem to.
def wrapper(param):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]

grad_acc.register_hook(lambda grad: reduce_weight_grads_from_model_parallel_region(grad))

wrapper(param)

# hook which will allreduce weight grads across model parallel (seq. parallel) region
# param.register_hook(reduce_weight_grads_from_model_parallel_region)

return
param.register_hook(reduce_weight_grads_from_model_parallel_region)
9 changes: 3 additions & 6 deletions megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,6 @@ def save_base_shapes(neox_args, base_shapes, use_cache):
use_cache=use_cache,
)

mark_norms_for_sequence_parallel_grad_sync(model, neox_args) # ensure LN param states remain synced across sequence parallel region

if not neox_args.is_pipe_parallel:
base_model = base_model.to_sequential()

Expand All @@ -114,8 +112,6 @@ def save_base_shapes(neox_args, base_shapes, use_cache):
topology=mpu.get_topology(),
use_cache=use_cache,
)

mark_norms_for_sequence_parallel_grad_sync(model, neox_args) # ensure LN param states remain synced across sequence parallel region

if not neox_args.is_pipe_parallel:
delta_model = delta_model.to_sequential()
Expand Down Expand Up @@ -692,8 +688,6 @@ def get_optimizer(model, neox_args):
else:
raise ValueError(f"Optimizer type {neox_args.optimizer_type} not recognized")

mark_norms_for_sequence_parallel_grad_sync(model, neox_args) # ensure LN param states remain synced across sequence parallel region

if neox_args.deepspeed:
# fp16 wrapper is not required for DeepSpeed.
return optimizer, param_groups
Expand Down Expand Up @@ -773,6 +767,7 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):
# config_params=neox_args.deepspeed_config,
mpu=mpu if not neox_args.is_pipe_parallel else None,
)
mark_norms_for_sequence_parallel_grad_sync(model, neox_args)
if neox_args.moe_num_experts > 1 and neox_args.moe_type == "megablocks":
# We need to additionally set this flag to ensure DS parallelism properly handles this foreign MoE.
model.has_moe_layers = True
Expand Down Expand Up @@ -843,6 +838,8 @@ def backward_step(neox_args, timers, optimizer, model, loss):

def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler):
"""Single training step."""
# FIXME: Layer norm weights are stuck at 1.0 when sequence_parallel=True
modules_dict = dict(model.named_modules())

# Pipeline parallelism schedules forward/backward/step
if neox_args.is_pipe_parallel:
Expand Down

0 comments on commit 9945910

Please sign in to comment.