From 8e7400f48d7e39199c7e29fe6553cf2ff42161ae Mon Sep 17 00:00:00 2001 From: Brandon Yang Date: Thu, 22 Aug 2024 00:58:28 -0700 Subject: [PATCH 1/3] Improve performance of sequence parallel gather, scatter, and reduce --- megatron/mpu/mappings.py | 56 ++++++++++++++++++++++++++++++---------- 1 file changed, 43 insertions(+), 13 deletions(-) diff --git a/megatron/mpu/mappings.py b/megatron/mpu/mappings.py index d95c5a3df..2d648c959 100644 --- a/megatron/mpu/mappings.py +++ b/megatron/mpu/mappings.py @@ -101,16 +101,30 @@ def _reduce_scatter_along_seq_dim(input_, seq_dim): if get_fp32_allreduce(): input_ = input_.float() - assert input_.shape[seq_dim] % world_size == 0 - tensor_list = list( - torch.split(input_, input_.shape[seq_dim] // world_size, seq_dim) - ) - output = torch.empty_like(tensor_list[0]) - torch.distributed.reduce_scatter(output, tensor_list) + dim_size = list(input_.size()) + assert ( + isinstance(seq_dim, int) and seq_dim < len(dim_size) and seq_dim >= 0 + ), "seq_dim must be a valid tensor dim" + assert dim_size[seq_dim] % world_size == 0 + + if seq_dim == 0: + dim_size[seq_dim] = dim_size[seq_dim] // world_size + output = torch.empty( + dim_size, dtype=input_.dtype, device=torch.cuda.current_device() + ) + torch.distributed.reduce_scatter_tensor( + output, input_.contiguous(), group=get_model_parallel_group() + ) + else: + tensor_list = list( + torch.split(input_, input_.shape[seq_dim] // world_size, seq_dim) + ) + output = torch.empty_like(tensor_list[0]) + torch.distributed.reduce_scatter(output, tensor_list) # reconvert to original Bf16/Fp16 dtype if get_fp32_allreduce(): - input_ = input_.to(dt) + output = output.to(dt) return output @@ -123,12 +137,28 @@ def _gather_along_seq_dim(input_, seq_dim): if world_size == 1: return input_ - input_ = input_.contiguous() - rank = get_model_parallel_rank() - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ - torch.distributed.all_gather(tensor_list, input_, group=get_model_parallel_group()) - output = torch.cat(tensor_list, dim=seq_dim) + dim_size = list(input_.size()) + assert ( + isinstance(seq_dim, int) and seq_dim < len(dim_size) and seq_dim >= 0 + ), "seq_dim must be a valid tensor dim" + dim_size[seq_dim] = dim_size[seq_dim] * world_size + + if seq_dim == 0: + output = torch.empty( + dim_size, dtype=input_.dtype, device=torch.cuda.current_device() + ) + torch.distributed.all_gather_into_tensor( + output, input_.contiguous(), group=get_model_parallel_group() + ) + else: + input_ = input_.contiguous() + rank = get_model_parallel_rank() + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + torch.distributed.all_gather( + tensor_list, input_, group=get_model_parallel_group() + ) + output = torch.cat(tensor_list, dim=seq_dim) return output From 53d0ae8b79b08881025aca8a67400d8ae345c684 Mon Sep 17 00:00:00 2001 From: Brandon Yang Date: Thu, 22 Aug 2024 01:07:40 -0700 Subject: [PATCH 2/3] Add comment --- megatron/mpu/mappings.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/megatron/mpu/mappings.py b/megatron/mpu/mappings.py index 2d648c959..f11d9e6ab 100644 --- a/megatron/mpu/mappings.py +++ b/megatron/mpu/mappings.py @@ -108,6 +108,7 @@ def _reduce_scatter_along_seq_dim(input_, seq_dim): assert dim_size[seq_dim] % world_size == 0 if seq_dim == 0: + # reduce_scatter_tensor is faster but only works correctly on dimension 0 dim_size[seq_dim] = dim_size[seq_dim] // world_size output = torch.empty( dim_size, dtype=input_.dtype, device=torch.cuda.current_device() @@ -144,6 +145,7 @@ def _gather_along_seq_dim(input_, seq_dim): dim_size[seq_dim] = dim_size[seq_dim] * world_size if seq_dim == 0: + # reduce_gather_tensor is faster but only works correctly on dimension 0 output = torch.empty( dim_size, dtype=input_.dtype, device=torch.cuda.current_device() ) From 05f5cec2ce2a00ace8343236dd5af9ce497b5a4b Mon Sep 17 00:00:00 2001 From: github-actions Date: Thu, 22 Aug 2024 08:08:11 +0000 Subject: [PATCH 3/3] Update NeoXArgs docs automatically --- configs/neox_arguments.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index 5fe833642..413138597 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = 28a5a62 + Default = 53d0ae8 current git hash of repository @@ -1060,9 +1060,9 @@ Parallelism Arguments Default = False - flag to determine whether Megatron-style (https://arxiv.org/abs/2205.05198) Sequence-Parallelism - (sharding acts along seq. dim among TP group for LNs) will be used. Has no effect when model_parallel_size is 1. - **Set by user.** + flag to determine whether Megatron-style Sequence Parallelism (https://arxiv.org/abs/2205.05198) + (Layernorm inputs and activations are sharded across model parallel group) will be used. Has no effect when model_parallel_size is 1. + **Set by user, in contrast to neox_args.is_pipe_parallel.**