run linter on the rest of the files
haileyschoelkopf committed Aug 19, 2024
1 parent ab11a6a commit f26b886
Showing 4 changed files with 22 additions and 10 deletions.
5 changes: 4 additions & 1 deletion megatron/model/__init__.py
@@ -16,5 +16,8 @@
# limitations under the License.

from .gpt2_model import GPT2ModelPipe
-from .utils import get_params_for_weight_decay_optimization, mark_norms_for_sequence_parallel_grad_sync
+from .utils import (
+    get_params_for_weight_decay_optimization,
+    mark_norms_for_sequence_parallel_grad_sync,
+)
from .word_embeddings import SoftEmbedding
19 changes: 12 additions & 7 deletions megatron/mpu/layers.py
@@ -418,7 +418,7 @@ def __init__(
        MOE=False,
        MoE_mp_size=1,
        mup_rescale_parameters=False,
-        seq_dim=0, # Dimension which is the seq_len dimension. final ParallelLinear overrides this to be 1 ; otherwise, the default is used throughout.
+        seq_dim=0,  # Dimension which is the seq_len dimension. final ParallelLinear overrides this to be 1 ; otherwise, the default is used throughout.
    ):
        super(ColumnParallelLinear, self).__init__()

@@ -430,7 +430,7 @@ def __init__(
        world_size = MoE_mp_size if MOE else get_model_parallel_world_size()
        self.output_size_per_partition = divide(output_size, world_size)
        self.skip_bias_add = skip_bias_add

        self.sequence_parallel = neox_args.sequence_parallel
        self.seq_dim = seq_dim

@@ -558,7 +558,7 @@ def set_parallel_output(self, value: bool):
    def forward(self, input_):
        if self.use_mup and self.mup_rescale_parameters:
            input_ /= self.width_mult()

        if self.sequence_parallel:
            input_parallel = input_
        else:
@@ -570,13 +570,17 @@ def forward(self, input_):
            # do an AG in the fwd pass, RS in bwd pass.
            # gather / scatter portion happens across the sequence dim (self.seq_dim)--
            # almost always is [s, b, h] and so dim 0, but for lm_head ParallelLinear it is seq_dim=1 and [b, s, h]
-            input_parallel = gather_from_sequence_parallel_region(input_parallel, seq_dim=self.seq_dim)
+            input_parallel = gather_from_sequence_parallel_region(
+                input_parallel, seq_dim=self.seq_dim
+            )

        bias = self.bias if not self.skip_bias_add else None
        output_parallel = F.linear(input_parallel, self.weight, bias)
        if self.gather_output:
            # All-gather across the partitions.
-            assert not self.sequence_parallel, "sequence_parallel=True and gather_output=True are incompatible!"
+            assert (
+                not self.sequence_parallel
+            ), "sequence_parallel=True and gather_output=True are incompatible!"
            output = gather_from_model_parallel_region(output_parallel)
        else:
            output = output_parallel
@@ -643,8 +647,9 @@ def __init__(
        self.parallel_output = parallel_output

        self.sequence_parallel = neox_args.sequence_parallel
-        assert not (self.sequence_parallel and not self.input_is_parallel), "Cannot have self.input_is_parallel=False and self.sequence_parallel=True."

+        assert not (
+            self.sequence_parallel and not self.input_is_parallel
+        ), "Cannot have self.input_is_parallel=False and self.sequence_parallel=True."

        self.init_method = init_method
        self.stride = stride
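The comments in the forward hunk above describe what the sequence-parallel gather does: activations usually sit in [s, b, h] layout, so the gather runs over dim 0, while the final lm_head ParallelLinear passes seq_dim=1 for a [b, s, h] layout. Below is a minimal single-process sketch of that shape bookkeeping using plain torch.chunk/torch.cat; it is not the repo's distributed gather_from_sequence_parallel_region, and the sizes are made up for illustration.

import torch

# Toy dimensions: sequence length 8, batch 2, hidden 4, and 4 sequence-parallel
# "ranks". Purely illustrative; not the repo's distributed code path.
s, b, h, world_size = 8, 2, 4, 4

full = torch.randn(s, b, h)  # [s, b, h] layout, seq_dim=0

# Each rank holds a slice of the sequence dimension (what the scatter produces).
shards = torch.chunk(full, world_size, dim=0)  # 4 tensors of shape [2, 2, 4]

# "Gathering" across the sequence dim reassembles the full sequence on every rank.
gathered = torch.cat(shards, dim=0)
assert gathered.shape == (s, b, h)

# For the lm_head ParallelLinear the layout is [b, s, h], so the same gather
# must run along dim 1 instead -- hence the seq_dim argument.
full_bsh = full.transpose(0, 1).contiguous()  # [b, s, h]
shards_bsh = torch.chunk(full_bsh, world_size, dim=1)
gathered_bsh = torch.cat(shards_bsh, dim=1)
assert gathered_bsh.shape == (b, s, h)
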
4 changes: 3 additions & 1 deletion megatron/mpu/utils.py
@@ -53,7 +53,9 @@ def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=
    return tensor_list


-def split_tensor_along_any_dim(tensor, num_partitions, seq_dim, contiguous_split_chunks=False):
+def split_tensor_along_any_dim(
+    tensor, num_partitions, seq_dim, contiguous_split_chunks=False
+):
    """Split a tensor along a user-specified dimension.
    Arguments:
        tensor: input tensor.
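The wrapped signature above belongs to split_tensor_along_any_dim, which generalizes split_tensor_along_last_dim to an arbitrary dimension. As a rough sketch of the behavior the docstring describes (plain PyTorch, with a hypothetical helper name; the repo's contiguity handling and error checking may differ):

import torch


def split_along_dim_sketch(tensor, num_partitions, dim, contiguous_split_chunks=False):
    # Divide the requested dimension evenly across the partitions.
    assert tensor.size(dim) % num_partitions == 0, "dimension must divide evenly"
    chunk_size = tensor.size(dim) // num_partitions
    chunks = torch.split(tensor, chunk_size, dim=dim)
    if contiguous_split_chunks:
        # torch.split returns views; copy them if independent storage is wanted.
        return tuple(chunk.contiguous() for chunk in chunks)
    return chunks


x = torch.arange(24).reshape(2, 3, 4)
parts = split_along_dim_sketch(x, num_partitions=2, dim=2)
assert len(parts) == 2 and parts[0].shape == (2, 3, 2)
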
4 changes: 3 additions & 1 deletion megatron/neox_arguments/arguments.py
@@ -1038,7 +1038,9 @@ def calculate_derived(self):
            assert self.zero_optimization["stage"] != 3, "MoE not compatible with zero3"
            assert self.mlp_type == "regular", "MoE not compatible with LLaMA"

-            assert self.sequence_parallel is False, "MoE not compatible with Sequence Parallel"
+            assert (
+                self.sequence_parallel is False
+            ), "MoE not compatible with Sequence Parallel"

        # Attention config
        if self.attention_config is None:
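One detail worth noting about the assert reformatting in this commit: black parenthesizes only the condition and leaves the message outside the parentheses, which preserves semantics. Wrapping the condition and the message together would instead build a two-element tuple, which is always truthy, so the check could never fail. A hand-written illustration of the difference (not from the repo):

sequence_parallel = True

# Safe wrapping (what black produces): only the condition sits inside parentheses.
try:
    assert (
        sequence_parallel is False
    ), "MoE not compatible with Sequence Parallel"
except AssertionError as err:
    print("raised as expected:", err)

# Anti-pattern: parenthesizing condition AND message makes a non-empty tuple,
# which is always truthy; CPython even emits a SyntaxWarning for this.
assert (sequence_parallel is False, "MoE not compatible with Sequence Parallel")
print("tuple-form assert silently passed")
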
