Fix gather and reduce scatter ops on sequence dimension
bclyang committed Aug 17, 2024
1 parent 2c5dc5a commit 9d883de
Showing 2 changed files with 26 additions and 26 deletions.
51 changes: 26 additions & 25 deletions megatron/mpu/mappings.py
@@ -111,22 +111,18 @@ def _reduce_scatter_along_seq_dim(input_, seq_dim):
     if dt == torch.bfloat16 and get_fp32_allreduce():
         input_ = input_.float()
 
-
-    dim_size = list(input_.size())
-    assert isinstance(seq_dim, int) and seq_dim < len(dim_size) and seq_dim >= 0, "seq_dim must be a valid tensor dim"
-    assert (
-        dim_size[seq_dim] % world_size == 0
-    ), "First dimension of the tensor should be divisible by tensor parallel size"
-    dim_size[seq_dim] = dim_size[seq_dim] // world_size
-
-    output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
-    torch.distributed.reduce_scatter_tensor(
-        output, input_.contiguous(), group=get_model_parallel_group()
+    assert input_.shape[seq_dim] % world_size == 0
+    tensor_list = list(
+        torch.split(input_, input_.shape[seq_dim] // world_size, seq_dim)
     )
+    output = torch.empty_like(tensor_list[0])
+    torch.distributed.reduce_scatter(output, tensor_list)
 
     # Bf16 convert
     if dt == torch.bfloat16 and get_fp32_allreduce():
-        output = output.bfloat16()  # TODO: this might screw up if we wanna do async comms w/ this
+        output = (
+            output.bfloat16()
+        )  # TODO: this might screw up if we wanna do async comms w/ this
 
     return output
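
Note on the fix above: torch.distributed.reduce_scatter_tensor always carves its input into world_size contiguous chunks (i.e. along the leading dimension), so the old code only behaved correctly for seq_dim == 0, which appears to be the bug this hunk addresses. The new code splits along the requested seq_dim and uses the list-based reduce_scatter, so each rank ends up with the element-wise sum of its own sequence chunk. A minimal single-process sketch of that semantics, with plain tensor ops standing in for the collective (the helper name simulate_reduce_scatter_seq is made up for illustration):

import torch

def simulate_reduce_scatter_seq(inputs_per_rank, seq_dim):
    # "reduce": sum the full tensors contributed by each rank
    world_size = len(inputs_per_rank)
    summed = torch.stack(inputs_per_rank).sum(dim=0)
    # "scatter": rank r keeps the r-th chunk along the sequence dimension
    return list(torch.split(summed, summed.shape[seq_dim] // world_size, seq_dim))

world_size, b, s, h = 2, 1, 8, 4
inputs = [torch.randn(b, s, h) for _ in range(world_size)]
outputs = simulate_reduce_scatter_seq(inputs, seq_dim=1)
print([tuple(o.shape) for o in outputs])  # [(1, 4, 4), (1, 4, 4)]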

@@ -144,18 +140,18 @@ def _gather_along_seq_dim(input_, seq_dim):
     if dt == torch.bfloat16 and get_fp32_allreduce():
         input_ = input_.float()
 
-    dim_size = list(input_.size())
-    assert isinstance(seq_dim, int) and seq_dim < len(dim_size) and seq_dim >= 0, "seq_dim must be a valid tensor dim"
-    dim_size[seq_dim] = dim_size[seq_dim] * world_size
-
-    output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
-    torch.distributed.all_gather_into_tensor(
-        output, input_.contiguous(), group=get_model_parallel_group()
-    )
+    input_ = input_.contiguous()
+    rank = get_model_parallel_rank()
+    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+    tensor_list[rank] = input_
+    torch.distributed.all_gather(tensor_list, input_, group=get_model_parallel_group())
+    output = torch.cat(tensor_list, dim=seq_dim)
 
     # Bf16 convert
     if dt == torch.bfloat16 and get_fp32_allreduce():
-        output = output.bfloat16()  # TODO: this might screw up if we wanna do async comms w/ this
+        output = (
+            output.bfloat16()
+        )  # TODO: this might screw up if we wanna do async comms w/ this
 
     return output
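
Likewise, all_gather_into_tensor stacks the gathered shards along the leading dimension, whereas gathering along an arbitrary seq_dim needs the list-based all_gather followed by a cat on that dim, as done above. A runnable toy of the same pattern (gloo backend on CPU, 2 processes; a standalone sketch rather than the NeoX code path — the port number and shapes are arbitrary):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, world_size, seq_dim=1):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29533"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    local = torch.full((1, 4, 2), float(rank))  # this rank's [b, s/p, h] shard
    tensor_list = [torch.empty_like(local) for _ in range(world_size)]
    tensor_list[rank] = local
    dist.all_gather(tensor_list, local)
    output = torch.cat(tensor_list, dim=seq_dim)  # [b, s, h] with shards in rank order
    if rank == 0:
        print(output.shape)  # torch.Size([1, 8, 2])
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)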

@@ -244,7 +240,7 @@ def backward(ctx, grad_output):
 
 
 class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
-    """Reduce-Scatter across sequence parallel region (same as model parallel region.)
+    """Reduce-Scatter across sequence parallel region (same as model parallel region.)
     TODO: rename to use ModelParallelRegion? There is not really a separate "SequenceParallelRegion" vs. "ModelParallelRegion"
     """

@@ -275,7 +271,7 @@ def symbolic(graph, input_, seq_dim):
     @staticmethod
     def forward(ctx, input_, seq_dim):
         ctx.seq_dim = seq_dim
-        return _gather_along_seq_dim(input_, seq_dim=seq_dim) # TODO: check this
+        return _gather_along_seq_dim(input_, seq_dim=seq_dim)  # TODO: check this
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -298,7 +294,10 @@ def forward(ctx, input_, seq_dim):
     @staticmethod
     def backward(ctx, grad_output):
         seq_dim = ctx.seq_dim
-        return _gather_along_seq_dim(grad_output, seq_dim=seq_dim), None # TODO: triple-check this is the right bwd
+        return (
+            _gather_along_seq_dim(grad_output, seq_dim=seq_dim),
+            None,
+        )  # TODO: triple-check this is the right bwd
 
 
 # -----------------
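
The backward above relies on local slicing and gathering along the sequence dimension being adjoint: the gradient of "keep your chunk of the sequence" is the upstream gradient placed back at that chunk's positions, so collecting it across ranks is exactly a gather along seq_dim. A single-process sketch of that claim (all names are local to the example):

import torch

b, s, h, world_size = 1, 8, 2, 2
x = torch.randn(b, s, h, requires_grad=True)
for rank in range(world_size):
    local = torch.split(x, s // world_size, dim=1)[rank]  # forward: rank keeps its sequence chunk
    upstream = torch.full_like(local, float(rank + 1))    # pretend gradient arriving at this rank
    (local * upstream).sum().backward()                   # accumulates into x.grad
print(x.grad[0, :, 0])  # tensor([1., 1., 1., 1., 2., 2., 2., 2.])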
@@ -330,5 +329,7 @@ def gather_from_sequence_parallel_region(input_, seq_dim=0):
     return _GatherFromSequenceParallelRegion.apply(input_, seq_dim)
 
 
-def scatter_to_sequence_parallel_region(input_, seq_dim=1): # use this fn in scattering input embeds across TP ranks. There, shape of inps is [b, s, h]
+def scatter_to_sequence_parallel_region(
+    input_, seq_dim=1
+):  # use this fn in scattering input embeds across TP ranks. There, shape of inps is [b, s, h]
     return _ScatterToSequenceParallelRegion.apply(input_, seq_dim)
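
As the trailing comment says, this wrapper is aimed at input embeddings of shape [b, s, h], hence the seq_dim=1 default. A hypothetical usage sketch (assumes NeoX's model-parallel state is already initialized and that an embedding module is in scope; the helper embed_with_sequence_parallelism is made up for illustration):

from megatron.mpu.mappings import (
    gather_from_sequence_parallel_region,
    scatter_to_sequence_parallel_region,
)

def embed_with_sequence_parallelism(embedding, tokens):
    hidden = embedding(tokens)                                       # [b, s, h], replicated on TP ranks
    hidden = scatter_to_sequence_parallel_region(hidden, seq_dim=1)  # [b, s/p, h] shard per rank
    # ... sequence-parallel work (e.g. dropout) on the local shard ...
    return gather_from_sequence_parallel_region(hidden, seq_dim=1)   # back to the full [b, s, h]
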
1 change: 0 additions & 1 deletion megatron/training.py
@@ -838,7 +838,6 @@ def backward_step(neox_args, timers, optimizer, model, loss):
 
 def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler):
     """Single training step."""
-    # FIXME: Layer norm weights are stuck at 1.0 when sequence_parallel=True
     modules_dict = dict(model.named_modules())
 
     # Pipeline parallelism schedules forward/backward/step
