diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md
index 5fe833642..413138597 100644
--- a/configs/neox_arguments.md
+++ b/configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments
 
 - **git_hash**: str
 
-    Default = 28a5a62
+    Default = 53d0ae8
 
     current git hash of repository
 
@@ -1060,9 +1060,9 @@ Parallelism Arguments
 
     Default = False
 
-    flag to determine whether Megatron-style (https://arxiv.org/abs/2205.05198) Sequence-Parallelism
-    (sharding acts along seq. dim among TP group for LNs) will be used. Has no effect when model_parallel_size is 1.
-    **Set by user.**
+    flag to determine whether Megatron-style Sequence Parallelism (https://arxiv.org/abs/2205.05198)
+    (Layernorm inputs and activations are sharded across the model parallel group) will be used. Has no effect when model_parallel_size is 1.
+    **Set by user, in contrast to neox_args.is_pipe_parallel.**
 
diff --git a/megatron/mpu/mappings.py b/megatron/mpu/mappings.py
index d95c5a3df..f11d9e6ab 100644
--- a/megatron/mpu/mappings.py
+++ b/megatron/mpu/mappings.py
@@ -101,16 +101,31 @@ def _reduce_scatter_along_seq_dim(input_, seq_dim):
     if get_fp32_allreduce():
         input_ = input_.float()
 
-    assert input_.shape[seq_dim] % world_size == 0
-    tensor_list = list(
-        torch.split(input_, input_.shape[seq_dim] // world_size, seq_dim)
-    )
-    output = torch.empty_like(tensor_list[0])
-    torch.distributed.reduce_scatter(output, tensor_list)
+    dim_size = list(input_.size())
+    assert (
+        isinstance(seq_dim, int) and seq_dim < len(dim_size) and seq_dim >= 0
+    ), "seq_dim must be a valid tensor dim"
+    assert dim_size[seq_dim] % world_size == 0
+
+    if seq_dim == 0:
+        # reduce_scatter_tensor is faster but only works correctly on dimension 0
+        dim_size[seq_dim] = dim_size[seq_dim] // world_size
+        output = torch.empty(
+            dim_size, dtype=input_.dtype, device=torch.cuda.current_device()
+        )
+        torch.distributed.reduce_scatter_tensor(
+            output, input_.contiguous(), group=get_model_parallel_group()
+        )
+    else:
+        tensor_list = list(
+            torch.split(input_, input_.shape[seq_dim] // world_size, seq_dim)
+        )
+        output = torch.empty_like(tensor_list[0])
+        torch.distributed.reduce_scatter(output, tensor_list)
 
     # reconvert to original Bf16/Fp16 dtype
     if get_fp32_allreduce():
-        input_ = input_.to(dt)
+        output = output.to(dt)
 
     return output
 
@@ -123,12 +138,29 @@ def _gather_along_seq_dim(input_, seq_dim):
     if world_size == 1:
         return input_
 
-    input_ = input_.contiguous()
-    rank = get_model_parallel_rank()
-    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
-    tensor_list[rank] = input_
-    torch.distributed.all_gather(tensor_list, input_, group=get_model_parallel_group())
-    output = torch.cat(tensor_list, dim=seq_dim)
+    dim_size = list(input_.size())
+    assert (
+        isinstance(seq_dim, int) and seq_dim < len(dim_size) and seq_dim >= 0
+    ), "seq_dim must be a valid tensor dim"
+    dim_size[seq_dim] = dim_size[seq_dim] * world_size
+
+    if seq_dim == 0:
+        # all_gather_into_tensor is faster but only works correctly on dimension 0
+        output = torch.empty(
+            dim_size, dtype=input_.dtype, device=torch.cuda.current_device()
+        )
+        torch.distributed.all_gather_into_tensor(
+            output, input_.contiguous(), group=get_model_parallel_group()
+        )
+    else:
+        input_ = input_.contiguous()
+        rank = get_model_parallel_rank()
+        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+        tensor_list[rank] = input_
+        torch.distributed.all_gather(
+            tensor_list, input_, group=get_model_parallel_group()
+        )
+        output = torch.cat(tensor_list, dim=seq_dim)
 
     return output
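
The sketch below is not part of the patch; it only illustrates why the diff adds a dim-0 fast path: torch.distributed.reduce_scatter_tensor and all_gather_into_tensor fuse the collective into one call but shard/concatenate only along dimension 0, whereas the list-based reduce_scatter/all_gather fallbacks work for any seq_dim at the cost of a Python-level split or cat. The file name, tensor shapes, and use of the default (world) process group are illustrative assumptions, not taken from the PR; it assumes 2+ GPUs and launch with: torchrun --nproc_per_node=2 seq_dim_collectives_demo.py

import os

import torch
import torch.distributed as dist


def main():
    dist.init_process_group(backend="nccl")
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
    torch.cuda.set_device(device)

    seq_dim = 0  # [seq, batch, hidden] layout: sequence dimension first
    full = torch.randn(8, 2, 4, device=device) + rank
    assert full.shape[seq_dim] % world_size == 0

    # --- reduce-scatter along the sequence dim ---
    # Fast path: one fused collective; valid only because seq_dim == 0.
    out_shape = list(full.size())
    out_shape[seq_dim] //= world_size
    rs_fast = torch.empty(out_shape, dtype=full.dtype, device=device)
    dist.reduce_scatter_tensor(rs_fast, full.contiguous())

    # Fallback path: split into per-rank chunks, then list-based reduce_scatter.
    chunks = list(torch.split(full, full.shape[seq_dim] // world_size, seq_dim))
    rs_slow = torch.empty_like(chunks[0])
    dist.reduce_scatter(rs_slow, chunks)
    torch.testing.assert_close(rs_fast, rs_slow)

    # --- all-gather along the sequence dim ---
    # Fast path: gather shards directly into one pre-allocated tensor.
    shard = rs_fast
    ag_shape = list(shard.size())
    ag_shape[seq_dim] *= world_size
    ag_fast = torch.empty(ag_shape, dtype=shard.dtype, device=device)
    dist.all_gather_into_tensor(ag_fast, shard.contiguous())

    # Fallback path: gather into a list, then concatenate along seq_dim.
    tensor_list = [torch.empty_like(shard) for _ in range(world_size)]
    dist.all_gather(tensor_list, shard)
    ag_slow = torch.cat(tensor_list, dim=seq_dim)
    torch.testing.assert_close(ag_fast, ag_slow)

    if rank == 0:
        print("dim-0 fused collectives match the list-based fallbacks")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()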