Merge pull request #1260 from EleutherAI/fix-seq-dim-reducegatherdactter
Fix gather and reduce scatter ops on sequence dimension
haileyschoelkopf authored Aug 19, 2024
2 parents 2c5dc5a + 5427d9d commit b0d9398
Showing 5 changed files with 54 additions and 34 deletions.
configs/neox_arguments.md (2 changes: 1 addition & 1 deletion)
@@ -111,7 +111,7 @@ Logging Arguments
 
 - **git_hash**: str
 
-    Default = 9945910
+    Default = 28a5a62
 
     current git hash of repository
 
megatron/model/gpt2_model.py (5 changes: 4 additions & 1 deletion)
@@ -308,7 +308,10 @@ def _logits_helper(embedding, lm_output):
             )
 
             logits = parallel_lm_logits(
-                lm_output, embedding.word_embeddings_weight, self.parallel_output
+                lm_output,
+                embedding.word_embeddings_weight,
+                self.parallel_output,
+                seq_parallel=self.neox_args.sequence_parallel,
             )
             return logits
 
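Note (illustrative, not part of the commit): under sequence parallelism each tensor-parallel rank holds only a shard of the sequence dimension, so the hidden states have to be re-gathered along that dimension before the vocab-parallel logits matmul; the new seq_parallel argument lets parallel_lm_logits do that gather instead of the usual copy-to-model-parallel identity. A minimal single-process sketch of the shape bookkeeping, with the per-rank shards simulated by torch.chunk rather than real communication:

    import torch

    world_size, batch, seq, hidden, vocab = 4, 2, 16, 8, 32
    hidden_states = torch.randn(batch, seq, hidden)  # [b, s, h], matching seq_dim=1
    word_embedding = torch.randn(vocab, hidden)      # tied embedding / output weight

    # Under sequence parallelism each rank holds one shard of the sequence dim.
    shards = torch.chunk(hidden_states, world_size, dim=1)  # each [b, s/tp, h]
    # parallel_lm_logits(seq_parallel=True) all-gathers the shards along seq_dim
    # before the matmul against the embedding weight.
    gathered = torch.cat(shards, dim=1)                      # back to [b, s, h]
    logits = gathered @ word_embedding.t()                   # [b, s, vocab]
    assert torch.allclose(logits, hidden_states @ word_embedding.t())
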
megatron/model/transformer.py (29 changes: 23 additions & 6 deletions)
@@ -113,7 +113,7 @@ def __init__(
             neox_args=neox_args,
             input_size=neox_args.hidden_size,
             output_size=ff_dim,
-            gather_output=False, # TODO: add a parallel-input check? need to AG in fwd pass to reshard, here?
+            gather_output=False,  # TODO: add a parallel-input check? need to AG in fwd pass to reshard, here?
             init_method=init_method,
             skip_bias_add=True,
             MOE=MOE,
@@ -127,7 +127,7 @@ def __init__(
             output_size=neox_args.hidden_size,
             input_is_parallel=True,
             init_method=output_layer_init_method,
-            parallel_output=parallel_output, # seqpar should do parallel_output?
+            parallel_output=parallel_output,  # seqpar should do parallel_output?
             skip_bias_add=True,
             MOE=MOE,
             MoE_mp_size=MoE_mp_size,
@@ -254,7 +254,7 @@ def __init__(
                 gather_output=not parallel_output,
                 skip_bias_add=False,
                 mup_rescale_parameters=is_last_layer,  # rescale params only called if neox_args.use_mup = True, despite it not being included here
-                seq_dim=1, # important: must mark that this layer receives shape [b, s, h] not [s, b, h] and so Seq. Parallel comms must gather along dim=1
+                seq_dim=1,  # important: must mark that this layer receives shape [b, s, h] not [s, b, h] and so Seq. Parallel comms must gather along dim=1
             )
 
             # else:
@@ -1028,7 +1028,11 @@ def __init__(
         # GPT-J style layers allow us to defer the reduction of results across TP ranks until the end of the two sublayers.
         # the reduction we use is a simple allreduce for pure Tensor Parallel,
         # but needs to be a reduce-scatter when using sequence parallel (LN sharding.)
-        self.reduce = mpu.mappings.reduce_from_model_parallel_region if not neox_args.sequence_parallel else mpu.mappings.reduce_scatter_to_sequence_parallel_region
+        self.reduce = (
+            mpu.mappings.reduce_from_model_parallel_region
+            if not neox_args.sequence_parallel
+            else mpu.mappings.reduce_scatter_to_sequence_parallel_region
+        )
 
         # Self attention.
         self.attention = ParallelSelfAttention(
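Aside (illustrative, not part of the commit): the swap above works because a reduce-scatter along the sequence dimension hands each rank exactly its sequence shard of what a full all-reduce would have produced, which is all the sharded LayerNorm needs to see. A single-process sketch of that identity, simulating the per-rank partial sums with plain tensors:

    import torch

    world_size, seq, batch, hidden = 4, 16, 2, 8
    # Partial sublayer outputs as they would exist on each tensor-parallel rank,
    # in the transformer-internal [s, b, h] layout.
    partials = [torch.randn(seq, batch, hidden) for _ in range(world_size)]

    # Pure tensor parallel: all-reduce, so every rank ends up with the full sum.
    allreduced = torch.stack(partials).sum(dim=0)

    # Sequence parallel: reduce-scatter along the sequence dim, so rank r keeps
    # only its own sequence shard of that same sum.
    shard = seq // world_size
    for rank in range(world_size):
        rs_out = sum(p[rank * shard:(rank + 1) * shard] for p in partials)
        assert torch.allclose(rs_out, allreduced[rank * shard:(rank + 1) * shard])
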
@@ -1343,10 +1347,23 @@ def forward(self, args):
         return self.norm(args)
 
 
-def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
+def parallel_lm_logits(
+    input_,
+    word_embeddings_weight,
+    parallel_output,
+    seq_parallel=False,
+    seq_dim=1,
+    bias=None,
+):
     """LM logits using word embedding weights."""
     # Parallel logits.
-    input_parallel = mpu.copy_to_model_parallel_region(input_)
+    if seq_parallel:
+        input_parallel = mpu.gather_from_sequence_parallel_region(
+            input_, seq_dim=seq_dim
+        )
+    else:
+        # Set up backprop all-reduce.
+        input_parallel = mpu.copy_to_model_parallel_region(input_)
 
     # Matrix multiply.
     if bias is None:
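Aside (illustrative, not part of the commit): the seq_dim plumbing exists because the sequence axis is not always in the same position; inside the transformer stack activations are [s, b, h] (gather/scatter along dim 0), while the final ColumnParallelLinear and parallel_lm_logits see [b, s, h] (dim 1), as the comments in this diff note. A small sketch of why picking the wrong dim fails silently rather than loudly:

    import torch

    world_size, seq, batch, hidden = 4, 8, 2, 3
    full_sbh = torch.arange(seq * batch * hidden, dtype=torch.float32).reshape(seq, batch, hidden)
    full_bsh = full_sbh.transpose(0, 1).contiguous()  # same data laid out as [b, s, h]

    shards_sbh = list(torch.chunk(full_sbh, world_size, dim=0))  # [s, b, h] sharded on dim 0
    shards_bsh = list(torch.chunk(full_bsh, world_size, dim=1))  # [b, s, h] sharded on dim 1

    # Gathering along the dimension that was actually sharded reconstructs the input.
    assert torch.equal(torch.cat(shards_sbh, dim=0), full_sbh)
    assert torch.equal(torch.cat(shards_bsh, dim=1), full_bsh)
    # Concatenating the [b, s, h] shards along dim 0 raises no error but yields a
    # differently-shaped, incorrectly ordered tensor.
    wrong = torch.cat(shards_bsh, dim=0)
    assert wrong.shape == (batch * world_size, seq // world_size, hidden)
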
megatron/mpu/mappings.py (51 changes: 26 additions & 25 deletions)
@@ -111,22 +111,18 @@ def _reduce_scatter_along_seq_dim(input_, seq_dim):
     if dt == torch.bfloat16 and get_fp32_allreduce():
         input_ = input_.float()
 
-
-    dim_size = list(input_.size())
-    assert isinstance(seq_dim, int) and seq_dim < len(dim_size) and seq_dim >= 0, "seq_dim must be a valid tensor dim"
-    assert (
-        dim_size[seq_dim] % world_size == 0
-    ), "First dimension of the tensor should be divisible by tensor parallel size"
-    dim_size[seq_dim] = dim_size[seq_dim] // world_size
-
-    output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
-    torch.distributed.reduce_scatter_tensor(
-        output, input_.contiguous(), group=get_model_parallel_group()
+    assert input_.shape[seq_dim] % world_size == 0
+    tensor_list = list(
+        torch.split(input_, input_.shape[seq_dim] // world_size, seq_dim)
     )
+    output = torch.empty_like(tensor_list[0])
+    torch.distributed.reduce_scatter(output, tensor_list)
 
     # Bf16 convert
     if dt == torch.bfloat16 and get_fp32_allreduce():
-        output = output.bfloat16() # TODO: this might screw up if we wanna do async comms w/ this
+        output = (
+            output.bfloat16()
+        )  # TODO: this might screw up if we wanna do async comms w/ this
 
     return output
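Aside (illustrative, not part of the commit): torch.distributed.reduce_scatter takes the input pre-split into world_size chunks and leaves each rank holding the element-wise sum of its own chunk across all ranks, which is what the torch.split along seq_dim above prepares. A single-process simulation of that contract, with a hypothetical world size of 4:

    import torch

    world_size, seq_dim = 4, 0
    # One full [s, b, h] partial result per rank, waiting to be reduced.
    per_rank_inputs = [torch.randn(16, 2, 8) for _ in range(world_size)]

    def simulated_reduce_scatter(rank):
        # Every rank splits its input along seq_dim into world_size chunks ...
        chunks = [list(torch.split(x, x.shape[seq_dim] // world_size, seq_dim))
                  for x in per_rank_inputs]
        # ... and ends up with the sum of chunk[rank] taken over all ranks.
        return sum(c[rank] for c in chunks)

    full_sum = torch.stack(per_rank_inputs).sum(dim=0)
    for rank in range(world_size):
        expected = torch.chunk(full_sum, world_size, dim=seq_dim)[rank]
        assert torch.allclose(simulated_reduce_scatter(rank), expected)
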

@@ -144,18 +140,18 @@ def _gather_along_seq_dim(input_, seq_dim):
     if dt == torch.bfloat16 and get_fp32_allreduce():
         input_ = input_.float()
 
-    dim_size = list(input_.size())
-    assert isinstance(seq_dim, int) and seq_dim < len(dim_size) and seq_dim >= 0, "seq_dim must be a valid tensor dim"
-    dim_size[seq_dim] = dim_size[seq_dim] * world_size
-
-    output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
-    torch.distributed.all_gather_into_tensor(
-        output, input_.contiguous(), group=get_model_parallel_group()
-    )
+    input_ = input_.contiguous()
+    rank = get_model_parallel_rank()
+    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+    tensor_list[rank] = input_
+    torch.distributed.all_gather(tensor_list, input_, group=get_model_parallel_group())
+    output = torch.cat(tensor_list, dim=seq_dim)
 
     # Bf16 convert
     if dt == torch.bfloat16 and get_fp32_allreduce():
-        output = output.bfloat16() # TODO: this might screw up if we wanna do async comms w/ this
+        output = (
+            output.bfloat16()
+        )  # TODO: this might screw up if we wanna do async comms w/ this
 
     return output
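Aside (illustrative, not part of the commit): the rewritten gather is the mirror image of the reduce-scatter above; every rank contributes its sequence shard to the all-gathered list and reconstructs the full sequence by concatenating that list along seq_dim (in autograd terms, the backward of this gather is exactly that reduce-scatter). A single-process round-trip check of the pairing:

    import torch

    world_size, seq_dim = 4, 0
    full = torch.randn(16, 2, 8)  # the full [s, b, h] activation

    # scatter_to_sequence_parallel_region: rank r keeps shard r.
    shards = list(torch.chunk(full, world_size, dim=seq_dim))  # each [s/tp, b, h]

    # _gather_along_seq_dim: all_gather the shards into a list, then cat along seq_dim.
    tensor_list = [torch.empty_like(shards[0]) for _ in range(world_size)]
    for rank in range(world_size):
        tensor_list[rank] = shards[rank]  # what all_gather would fill in
    gathered = torch.cat(tensor_list, dim=seq_dim)

    assert torch.equal(gathered, full)  # gather undoes the scatter exactly
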

@@ -244,7 +240,7 @@ def backward(ctx, grad_output):
 
 
 class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
-    """Reduce-Scatter across sequence parallel region (same as model parallel region.) 
+    """Reduce-Scatter across sequence parallel region (same as model parallel region.)
     TODO: rename to use ModelParallelRegion? There is not really a separate "SequenceParallelRegion" vs. "ModelParallelRegion"
     """
 
@@ -275,7 +271,7 @@ def symbolic(graph, input_, seq_dim):
     @staticmethod
     def forward(ctx, input_, seq_dim):
         ctx.seq_dim = seq_dim
-        return _gather_along_seq_dim(input_, seq_dim=seq_dim) # TODO: check this
+        return _gather_along_seq_dim(input_, seq_dim=seq_dim)  # TODO: check this
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -298,7 +294,10 @@ def forward(ctx, input_, seq_dim):
     @staticmethod
     def backward(ctx, grad_output):
         seq_dim = ctx.seq_dim
-        return _gather_along_seq_dim(grad_output, seq_dim=seq_dim), None # TODO: triple-check this is the right bwd
+        return (
+            _gather_along_seq_dim(grad_output, seq_dim=seq_dim),
+            None,
+        )  # TODO: triple-check this is the right bwd
 
 
 # -----------------
@@ -330,5 +329,7 @@ def gather_from_sequence_parallel_region(input_, seq_dim=0):
     return _GatherFromSequenceParallelRegion.apply(input_, seq_dim)
 
 
-def scatter_to_sequence_parallel_region(input_, seq_dim=1): # use this fn in scattering input embeds across TP ranks. There, shape of inps is [b, s, h]
+def scatter_to_sequence_parallel_region(
+    input_, seq_dim=1
+):  # use this fn in scattering input embeds across TP ranks. There, shape of inps is [b, s, h]
     return _ScatterToSequenceParallelRegion.apply(input_, seq_dim)
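Aside (illustrative, not part of the commit): a rough sketch of how the public wrappers are meant to fit together in a sequence-parallel forward pass. It assumes a gpt-neox setup with an initialized model-parallel group (e.g. launched under torchrun); the tensors and exact call sites are invented for illustration and differ from the real model code:

    import torch
    from megatron import mpu

    # Input embeddings, [b, s, h], identical on every tensor-parallel rank.
    embeddings = torch.randn(2, 16, 8, device="cuda")

    # Shard the sequence across TP ranks before the transformer stack (dim 1 for [b, s, h]).
    hidden = mpu.mappings.scatter_to_sequence_parallel_region(embeddings, seq_dim=1)

    # ... transformer layers operate on the local sequence shard; GPT-J style blocks
    # end in reduce_scatter_to_sequence_parallel_region instead of an all-reduce ...

    # Re-assemble the full sequence before the vocab-parallel logits matmul,
    # which is what parallel_lm_logits now does when seq_parallel=True.
    full_hidden = mpu.mappings.gather_from_sequence_parallel_region(hidden, seq_dim=1)
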
megatron/training.py (1 change: 0 additions & 1 deletion)
@@ -838,7 +838,6 @@ def backward_step(neox_args, timers, optimizer, model, loss):
 
 def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler):
     """Single training step."""
-    # FIXME: Layer norm weights are stuck at 1.0 when sequence_parallel=True
     modules_dict = dict(model.named_modules())
 
     # Pipeline parallelism schedules forward/backward/step
