diff --git a/megatron/model/__init__.py b/megatron/model/__init__.py
index 889f8b18c..23be28936 100755
--- a/megatron/model/__init__.py
+++ b/megatron/model/__init__.py
@@ -16,5 +16,8 @@
 # limitations under the License.

 from .gpt2_model import GPT2ModelPipe
-from .utils import get_params_for_weight_decay_optimization, mark_norms_for_sequence_parallel_grad_sync
+from .utils import (
+    get_params_for_weight_decay_optimization,
+    mark_norms_for_sequence_parallel_grad_sync,
+)
 from .word_embeddings import SoftEmbedding
diff --git a/megatron/mpu/layers.py b/megatron/mpu/layers.py
index 3af7f94c4..d59edab94 100644
--- a/megatron/mpu/layers.py
+++ b/megatron/mpu/layers.py
@@ -418,7 +418,7 @@ def __init__(
         MOE=False,
         MoE_mp_size=1,
         mup_rescale_parameters=False,
-        seq_dim=0, # Dimension which is the seq_len dimension. final ParallelLinear overrides this to be 1 ; otherwise, the default is used throughout.
+        seq_dim=0,  # Dimension which is the seq_len dimension. final ParallelLinear overrides this to be 1 ; otherwise, the default is used throughout.
     ):
         super(ColumnParallelLinear, self).__init__()
@@ -430,7 +430,7 @@ def __init__(
         world_size = MoE_mp_size if MOE else get_model_parallel_world_size()
         self.output_size_per_partition = divide(output_size, world_size)
         self.skip_bias_add = skip_bias_add
-
+
         self.sequence_parallel = neox_args.sequence_parallel
         self.seq_dim = seq_dim
@@ -558,7 +558,7 @@ def set_parallel_output(self, value: bool):
     def forward(self, input_):
         if self.use_mup and self.mup_rescale_parameters:
             input_ /= self.width_mult()
-
+
         if self.sequence_parallel:
             input_parallel = input_
         else:
@@ -570,13 +570,17 @@ def forward(self, input_):
             # do an AG in the fwd pass, RS in bwd pass.
             # gather / scatter portion happens across the sequence dim (self.seq_dim)--
            # almost always is [s, b, h] and so dim 0, but for lm_head ParallelLinear it is seq_dim=1 and [b, s, h]
-            input_parallel = gather_from_sequence_parallel_region(input_parallel, seq_dim=self.seq_dim)
+            input_parallel = gather_from_sequence_parallel_region(
+                input_parallel, seq_dim=self.seq_dim
+            )
         bias = self.bias if not self.skip_bias_add else None
         output_parallel = F.linear(input_parallel, self.weight, bias)
         if self.gather_output:
             # All-gather across the partitions.
-            assert not self.sequence_parallel, "sequence_parallel=True and gather_output=True are incompatible!"
+            assert (
+                not self.sequence_parallel
+            ), "sequence_parallel=True and gather_output=True are incompatible!"
             output = gather_from_model_parallel_region(output_parallel)
         else:
             output = output_parallel
@@ -643,8 +647,9 @@ def __init__(
         self.parallel_output = parallel_output

         self.sequence_parallel = neox_args.sequence_parallel
-        assert not (self.sequence_parallel and not self.input_is_parallel), "Cannot have self.input_is_parallel=False and self.sequence_parallel=True."
-
+        assert not (
+            self.sequence_parallel and not self.input_is_parallel
+        ), "Cannot have self.input_is_parallel=False and self.sequence_parallel=True."
         self.init_method = init_method
         self.stride = stride
diff --git a/megatron/mpu/utils.py b/megatron/mpu/utils.py
index 9e9bb6151..1f97e0e76 100644
--- a/megatron/mpu/utils.py
+++ b/megatron/mpu/utils.py
@@ -53,7 +53,9 @@ def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=
     return tensor_list


-def split_tensor_along_any_dim(tensor, num_partitions, seq_dim, contiguous_split_chunks=False):
+def split_tensor_along_any_dim(
+    tensor, num_partitions, seq_dim, contiguous_split_chunks=False
+):
     """Split a tensor along a user-specified dimension.
     Arguments:
         tensor: input tensor.
diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py
index 6803870a3..6a84df6c7 100644
--- a/megatron/neox_arguments/arguments.py
+++ b/megatron/neox_arguments/arguments.py
@@ -1038,7 +1038,9 @@ def calculate_derived(self):
             assert self.zero_optimization["stage"] != 3, "MoE not compatible with zero3"
             assert self.mlp_type == "regular", "MoE not compatible with LLaMA"

-            assert self.sequence_parallel is False, "MoE not compatible with Sequence Parallel"
+            assert (
+                self.sequence_parallel is False
+            ), "MoE not compatible with Sequence Parallel"

         # Attention config
         if self.attention_config is None:
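
For context on the "do an AG in the fwd pass, RS in bwd pass" comment in `ColumnParallelLinear.forward`, below is a minimal sketch (not the actual NeoX/Megatron implementation) of what `gather_from_sequence_parallel_region` does as an autograd op: all-gather the per-rank sequence chunks along `seq_dim` on the way forward, and reduce-scatter the gradient on the way back so each rank keeps only its own sequence slice. Process-group handling is simplified here; the default group stands in for the model-parallel group.

```python
import torch
import torch.distributed as dist


class _GatherFromSequenceParallelRegion(torch.autograd.Function):
    """All-gather along seq_dim in forward, reduce-scatter in backward (illustrative sketch)."""

    @staticmethod
    def forward(ctx, input_, seq_dim):
        ctx.seq_dim = seq_dim
        world_size = dist.get_world_size()
        # Forward: gather every rank's sequence chunk and concatenate along
        # seq_dim (dim 0 for [s, b, h]; dim 1 for the lm_head's [b, s, h]).
        gathered = [torch.empty_like(input_) for _ in range(world_size)]
        dist.all_gather(gathered, input_.contiguous())
        return torch.cat(gathered, dim=seq_dim)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        # Backward: reduce-scatter so each rank receives the summed gradient
        # for its own sequence chunk only.
        grad_chunks = [
            g.contiguous()
            for g in torch.chunk(grad_output, world_size, dim=ctx.seq_dim)
        ]
        grad_input = torch.empty_like(grad_chunks[rank])
        dist.reduce_scatter(grad_input, grad_chunks)
        return grad_input, None


def gather_from_sequence_parallel_region(input_, seq_dim=0):
    return _GatherFromSequenceParallelRegion.apply(input_, seq_dim)
```

This also shows why the patch threads `seq_dim` through `ColumnParallelLinear`: the gather/scatter must happen on whichever axis holds the sequence, which differs between the default `[s, b, h]` layout and the final `[b, s, h]` ParallelLinear.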