Merge pull request #1263 from EleutherAI/improve-seq-parallel-perf
Improve performance of sequence parallel gather, scatter, and reduce
haileyschoelkopf authored Aug 22, 2024
2 parents f26b886 + 05f5cec commit 1661db6
Showing 2 changed files with 49 additions and 17 deletions.
8 changes: 4 additions & 4 deletions configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments

 - **git_hash**: str

-    Default = 28a5a62
+    Default = 53d0ae8

     current git hash of repository
@@ -1060,9 +1060,9 @@ Parallelism Arguments

     Default = False

-    flag to determine whether Megatron-style (https://arxiv.org/abs/2205.05198) Sequence-Parallelism
-    (sharding acts along seq. dim among TP group for LNs) will be used. Has no effect when model_parallel_size is 1.
-    **Set by user.**
+    flag to determine whether Megatron-style Sequence Parallelism (https://arxiv.org/abs/2205.05198)
+    (Layernorm inputs and activations are sharded across model parallel group) will be used. Has no effect when model_parallel_size is 1.
+    **Set by user, in contrast to neox_args.is_pipe_parallel.**
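To make the sharding concrete: with sequence parallelism, each rank in the model-parallel group holds a seq_len / world_size slice of the LayerNorm inputs and activations along the sequence dimension. A toy, single-process shape check (illustrative only, not part of the commit):

```python
# Toy illustration of sequence-parallel sharding; shapes follow the
# [seq, batch, hidden] layout Megatron-style models use. No distributed
# setup is needed here.
import torch

seq_len, batch, hidden = 2048, 4, 512
world_size = 4  # size of the model (tensor) parallel group

full = torch.randn(seq_len, batch, hidden)
shards = torch.split(full, seq_len // world_size, dim=0)

assert len(shards) == world_size
assert shards[0].shape == (seq_len // world_size, batch, hidden)
```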
58 changes: 45 additions & 13 deletions megatron/mpu/mappings.py
@@ -101,16 +101,31 @@ def _reduce_scatter_along_seq_dim(input_, seq_dim):
     if get_fp32_allreduce():
         input_ = input_.float()

-    assert input_.shape[seq_dim] % world_size == 0
-    tensor_list = list(
-        torch.split(input_, input_.shape[seq_dim] // world_size, seq_dim)
-    )
-    output = torch.empty_like(tensor_list[0])
-    torch.distributed.reduce_scatter(output, tensor_list)
+    dim_size = list(input_.size())
+    assert (
+        isinstance(seq_dim, int) and seq_dim < len(dim_size) and seq_dim >= 0
+    ), "seq_dim must be a valid tensor dim"
+    assert dim_size[seq_dim] % world_size == 0
+
+    if seq_dim == 0:
+        # reduce_scatter_tensor is faster but only works correctly on dimension 0
+        dim_size[seq_dim] = dim_size[seq_dim] // world_size
+        output = torch.empty(
+            dim_size, dtype=input_.dtype, device=torch.cuda.current_device()
+        )
+        torch.distributed.reduce_scatter_tensor(
+            output, input_.contiguous(), group=get_model_parallel_group()
+        )
+    else:
+        tensor_list = list(
+            torch.split(input_, input_.shape[seq_dim] // world_size, seq_dim)
+        )
+        output = torch.empty_like(tensor_list[0])
+        torch.distributed.reduce_scatter(output, tensor_list)

     # reconvert to original Bf16/Fp16 dtype
     if get_fp32_allreduce():
-        input_ = input_.to(dt)
+        output = output.to(dt)

     return output
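For orientation, here is a standalone sketch of the seq_dim == 0 fast path above, using only public torch.distributed APIs (PyTorch 1.13+). The function name and the use of the default process group are illustrative; the patched code targets the model-parallel group:

```python
# Minimal sketch of the fused reduce-scatter fast path. Assumes
# torch.distributed is already initialized (e.g. via torchrun) and every
# rank passes a tensor of the same shape whose dim 0 is divisible by
# world_size.
import torch
import torch.distributed as dist

def reduce_scatter_dim0(input_: torch.Tensor) -> torch.Tensor:
    """Sum input_ across ranks; return this rank's 1/world_size slice of dim 0."""
    world_size = dist.get_world_size()
    assert input_.shape[0] % world_size == 0
    output = torch.empty(
        (input_.shape[0] // world_size, *input_.shape[1:]),
        dtype=input_.dtype,
        device=input_.device,
    )
    # One fused collective: no Python-side torch.split and no list of
    # per-rank tensors, which is what the else branch above requires.
    dist.reduce_scatter_tensor(output, input_.contiguous())
    return output
```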

@@ -123,12 +138,29 @@ def _gather_along_seq_dim(input_, seq_dim):
     if world_size == 1:
         return input_

-    input_ = input_.contiguous()
-    rank = get_model_parallel_rank()
-    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
-    tensor_list[rank] = input_
-    torch.distributed.all_gather(tensor_list, input_, group=get_model_parallel_group())
-    output = torch.cat(tensor_list, dim=seq_dim)
+    dim_size = list(input_.size())
+    assert (
+        isinstance(seq_dim, int) and seq_dim < len(dim_size) and seq_dim >= 0
+    ), "seq_dim must be a valid tensor dim"
+    dim_size[seq_dim] = dim_size[seq_dim] * world_size
+
+    if seq_dim == 0:
+        # all_gather_into_tensor is faster but only works correctly on dimension 0
+        output = torch.empty(
+            dim_size, dtype=input_.dtype, device=torch.cuda.current_device()
+        )
+        torch.distributed.all_gather_into_tensor(
+            output, input_.contiguous(), group=get_model_parallel_group()
+        )
+    else:
+        input_ = input_.contiguous()
+        rank = get_model_parallel_rank()
+        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+        tensor_list[rank] = input_
+        torch.distributed.all_gather(
+            tensor_list, input_, group=get_model_parallel_group()
+        )
+        output = torch.cat(tensor_list, dim=seq_dim)

     return output
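Similarly, a minimal sketch of the dim-0 gather fast path (again assuming an initialized default process group; names are illustrative, not part of the patch):

```python
# Minimal sketch of the fused all-gather fast path. Assumes
# torch.distributed is initialized and all ranks pass same-shaped tensors.
import torch
import torch.distributed as dist

def gather_dim0(input_: torch.Tensor) -> torch.Tensor:
    """Concatenate every rank's input_ along dim 0 into one tensor."""
    world_size = dist.get_world_size()
    output = torch.empty(
        (input_.shape[0] * world_size, *input_.shape[1:]),
        dtype=input_.dtype,
        device=input_.device,
    )
    # Each rank's contribution is written directly into the preallocated
    # buffer, avoiding the per-rank tensor list plus the extra torch.cat
    # copy that the else branch performs.
    dist.all_gather_into_tensor(output, input_.contiguous())
    return output
```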
