Skip to content

Commit

Permalink
Fix LayerNorm all reduce gradient hook
Browse files Browse the repository at this point in the history
  • Loading branch information
bclyang committed Aug 12, 2024
1 parent 9c1e7b9 commit 817039b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 29 deletions.
27 changes: 4 additions & 23 deletions megatron/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,22 +359,17 @@ def reduce_weight_grads_from_model_parallel_region(input_):
Allreduces grads for e.g. LN weights, across the model parallel group.
Needed to keep LNs in sync, despite them getting diff data -> diff gradients when using sequence parallel.
"""
print("TRIPPED")

# Bypass the function if no TP -> no comm needed.
if mpu.get_model_parallel_world_size() == 1:
return input_

print("TRIPPED HOOK") # these never get printed. We *should* see them go off if we actually ever trigger these hooks....
# a print in mark_norms_... confirms that that function runs and initially adds these hooks, oddly.

# Bf16 convert
dt = input_.dtype
if dt == torch.bfloat16 and mpu.get_fp32_allreduce():
input_ = input_.float()

# All-reduce.
torch.distributed.all_reduce(input_, group=mpu.get_model_parallel_group())
torch.distributed.all_reduce(input_, group=mpu.get_model_parallel_group())
# average grads
input_ = (input_ / mpu.get_model_parallel_world_size())

Expand All @@ -397,22 +392,8 @@ def mark_norms_for_sequence_parallel_grad_sync(module, neox_args):

for module_ in module.modules():
if "norm" in type(module_).__name__.lower():
# this is a norm, we want to allreduce its grads across sequence parallel region
# this is a norm, we want to allreduce its grads across sequence parallel region
for name, param in module_.named_parameters():

if param.requires_grad:

# copying the helper fn that DeepSpeed uses to add hooks, to see if that fixes our issues.
# it did not seem to.
def wrapper(param):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]

grad_acc.register_hook(lambda grad: reduce_weight_grads_from_model_parallel_region(grad))

wrapper(param)

# hook which will allreduce weight grads across model parallel (seq. parallel) region
# param.register_hook(reduce_weight_grads_from_model_parallel_region)
param.register_hook(reduce_weight_grads_from_model_parallel_region)

return
15 changes: 9 additions & 6 deletions megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,6 @@ def save_base_shapes(neox_args, base_shapes, use_cache):
use_cache=use_cache,
)

mark_norms_for_sequence_parallel_grad_sync(model, neox_args) # ensure LN param states remain synced across sequence parallel region

if not neox_args.is_pipe_parallel:
base_model = base_model.to_sequential()

Expand All @@ -114,8 +112,6 @@ def save_base_shapes(neox_args, base_shapes, use_cache):
topology=mpu.get_topology(),
use_cache=use_cache,
)

mark_norms_for_sequence_parallel_grad_sync(model, neox_args) # ensure LN param states remain synced across sequence parallel region

if not neox_args.is_pipe_parallel:
delta_model = delta_model.to_sequential()
Expand Down Expand Up @@ -692,8 +688,6 @@ def get_optimizer(model, neox_args):
else:
raise ValueError(f"Optimizer type {neox_args.optimizer_type} not recognized")

mark_norms_for_sequence_parallel_grad_sync(model, neox_args) # ensure LN param states remain synced across sequence parallel region

if neox_args.deepspeed:
# fp16 wrapper is not required for DeepSpeed.
return optimizer, param_groups
Expand Down Expand Up @@ -773,6 +767,7 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):
# config_params=neox_args.deepspeed_config,
mpu=mpu if not neox_args.is_pipe_parallel else None,
)
mark_norms_for_sequence_parallel_grad_sync(model, neox_args)
if neox_args.moe_num_experts > 1 and neox_args.moe_type == "megablocks":
# We need to additionally set this flag to ensure DS parallelism properly handles this foreign MoE.
model.has_moe_layers = True
Expand Down Expand Up @@ -843,6 +838,14 @@ def backward_step(neox_args, timers, optimizer, model, loss):

def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler):
"""Single training step."""
# FIXME: Layer norm weights are stuck at 1.0 when sequence_parallel=True
modules_dict = dict(model.named_modules())
print(f"========= START RANK: {torch.distributed.get_rank()} =========")
print(f"RANK: {torch.distributed.get_rank()}: module.sequential.8.input_layernorm.weight = {modules_dict['module.sequential.8.input_layernorm'].weight.detach().mean()}")
print(f"RANK: {torch.distributed.get_rank()}: module.sequential.8.input_layernorm.bias = {modules_dict['module.sequential.8.input_layernorm'].bias.detach().mean()}")
print(f"RANK: {torch.distributed.get_rank()}: module.sequential.8.post_attention_layernorm.weight = {modules_dict['module.sequential.8.post_attention_layernorm'].weight.detach().mean()}")
print(f"RANK: {torch.distributed.get_rank()}: module.sequential.8.post_attention_layernorm.bias = {modules_dict['module.sequential.8.post_attention_layernorm'].bias.detach().mean()}")
print(f"========= END RANK: {torch.distributed.get_rank()} =========")

# Pipeline parallelism schedules forward/backward/step
if neox_args.is_pipe_parallel:
Expand Down

0 comments on commit 817039b

Please sign in to comment.