test each tp dim individually set to 2
siddharth9820 committed Nov 28, 2023
1 parent 8b1a9b2 commit 20d4228
Showing 4 changed files with 156 additions and 44 deletions.
8 changes: 7 additions & 1 deletion configs/125M.yml
@@ -3,7 +3,13 @@
# parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages
# across the node boundaries )
"pipe_parallel_size": 1,
"model_parallel_size": 1,
"model_parallel_size": 2,

"use_axonn_model_parallelism": true,
# the three dimensions of AxoNN's tensor parallelism (TP)
"depth_model_parallel_size": 1,
"row_model_parallel_size": 1,
"column_model_parallel_size": 2,

# model settings
"num_layers": 12,
22 changes: 21 additions & 1 deletion megatron/initialize.py
@@ -29,7 +29,7 @@

import deepspeed
import inspect

from axonn import axonn as ax

def initialize_megatron(neox_args, allow_no_cuda=False):
"""Set initialize distributed and set autoresume and random seeds.
@@ -188,6 +188,26 @@ def _initialize_distributed(neox_args):
fp32_allreduce=neox_args.fp32_allreduce,
)



if neox_args.use_axonn_model_parallelism:
    row_mp = neox_args.row_model_parallel_size
    column_mp = neox_args.column_model_parallel_size
    depth_mp = neox_args.depth_model_parallel_size
    assert row_mp * column_mp * depth_mp == neox_args.model_parallel_size, (
        "the product of row-model-parallel-size, column-model-parallel-size, "
        "and depth-model-parallel-size should equal model-parallel-size"
    )
    ax.init(
        G_inter=pp,
        G_data=dp,
        G_intra_r=row_mp,
        G_intra_c=column_mp,
        G_intra_d=depth_mp,
    )
    print(
        f"> initialized AxoNN with G_intra_r={row_mp}, "
        f"G_intra_c={column_mp}, "
        f"G_intra_d={depth_mp}"
    )

# Init DeepSpeed Activation Checkpointing Features
setup_deepspeed_random_and_activation_checkpointing(neox_args=neox_args)

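ax.init() receives five group sizes: G_inter and G_data reuse the pipeline and data parallel degrees already computed in _initialize_distributed (the pp and dp variables above), while the three G_intra_* sizes are the new AxoNN dimensions. Assuming, as is usual for AxoNN, that these five factors tile the whole job, the world size decomposes as in this illustrative arithmetic (example numbers, not from the commit):

# hypothetical 16-GPU job
pipe_parallel_size = 2                     # G_inter
data_parallel_size = 2                     # G_data
row_mp, column_mp, depth_mp = 2, 2, 1      # G_intra_r, G_intra_c, G_intra_d
model_parallel_size = row_mp * column_mp * depth_mp            # 4, must match the config value
world_size = pipe_parallel_size * data_parallel_size * model_parallel_size
print(world_size)                          # 16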
165 changes: 123 additions & 42 deletions megatron/model/transformer.py
@@ -40,6 +40,8 @@
)
from megatron.model.utils import configure_sparse_attention

from axonn.intra_layer import Linear, drop, gather

# flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
@@ -93,30 +95,57 @@ def __init__(
if self.activation_type == "geglu"
else ff_mult * neox_args.hidden_size
)
self.dense_h_to_4h = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=ff_dim,
gather_output=False,
init_method=init_method,
skip_bias_add=True,
)
if neox_args.use_axonn_model_parallelism:
    self.dense_h_to_4h = Linear(
        in_features=neox_args.hidden_size,
        out_features=ff_dim,
        init_method=init_method,
        skip_bias_add=True,
    )
else:
self.dense_h_to_4h = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=ff_dim,
gather_output=False,
init_method=init_method,
skip_bias_add=True,
)
ff_dim_in = ff_dim // 2 if self.activation_type == "geglu" else ff_dim
# Project back to h.
self.dense_4h_to_h = mpu.RowParallelLinear(
neox_args=neox_args,
input_size=ff_dim_in,
output_size=neox_args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
parallel_output=parallel_output,
)

if neox_args.use_axonn_model_parallelism:
    self.dense_4h_to_h = Linear(
        in_features=ff_dim_in,
        out_features=neox_args.hidden_size,
        init_method=output_layer_init_method,
        skip_bias_add=True,
        transpose=True,
    )
    assert not parallel_output, "TODO: implement AxoNN support for parallel_output=True (GPT-J residual)"

else:
self.dense_4h_to_h = mpu.RowParallelLinear(
neox_args=neox_args,
input_size=ff_dim_in,
output_size=neox_args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
parallel_output=parallel_output,
)


self.use_axonn_model_parallelism = neox_args.use_axonn_model_parallelism

def forward(self, hidden_states):

# [s, b, 4hp]
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
if self.use_axonn_model_parallelism:
    intermediate_parallel, bias_parallel = self.dense_h_to_4h(
        hidden_states, scatter_input=False, gather_output=False
    )
else:
    intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

if (
self.activation_type == "gelu" and self.bias_gelu_fusion
@@ -130,7 +159,11 @@ def forward(self, hidden_states):
)

# [s, b, h]
output, output_bias = self.dense_4h_to_h(intermediate_parallel)
if self.use_axonn_model_parallelism:
    output, output_bias = self.dense_4h_to_h(
        intermediate_parallel, scatter_input=False, gather_output=False
    )
else:
    output, output_bias = self.dense_4h_to_h(intermediate_parallel)
return output, output_bias
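
Every linear call touched in this file follows the same pattern: the AxoNN path passes scatter_input=False and gather_output=False so activations stay sharded between consecutive tensor-parallel layers, while the Megatron path is unchanged. A possible follow-up refactor (hypothetical helper, not part of this commit) that collapses the repeated branch:

def _tp_linear(layer, inputs, use_axonn):
    # Hypothetical helper. With scatter_input/gather_output disabled, the shards produced
    # by one AxoNN layer are consumed directly by the next; the only explicit resharding
    # is the drop()/gather() added at the first/last transformer layer further below.
    if use_axonn:
        return layer(inputs, scatter_input=False, gather_output=False)
    return layer(inputs)

# usage, e.g. in ParallelMLP.forward:
#   output, output_bias = _tp_linear(
#       self.dense_4h_to_h, intermediate_parallel, self.use_axonn_model_parallelism
#   )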


@@ -162,6 +195,9 @@ def __init__(

ff_dim = int(2 * neox_args.hidden_size * 4 / 3)
ff_dim = self.multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)

assert not neox_args.use_axonn_model_parallelism, "TODO: implement AxoNN TP for LLaMAParallelMLP"

self.w1 = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
@@ -275,7 +311,10 @@ def __init__(
self.attention_softmax_in_fp32 = True
self.layer_number = layer_number
# Per attention head and per partition values.
world_size = mpu.get_model_parallel_world_size()
if neox_args.use_axonn_model_parallelism:
world_size = neox_args.row_model_parallel_size
else:
world_size = mpu.get_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size)
self.hidden_size_per_attention_head = mpu.divide(
neox_args.hidden_size, neox_args.num_attention_heads
@@ -286,14 +325,24 @@
self.pos_emb = neox_args.pos_emb

# Strided linear layer.
self.query_key_value = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=3 * neox_args.hidden_size,
gather_output=False,
init_method=init_method,
bias=neox_args.use_bias_in_attn_linear,
)
self.use_axonn_model_parallelism = neox_args.use_axonn_model_parallelism
if neox_args.use_axonn_model_parallelism:
self.query_key_value = Linear(
in_features=neox_args.hidden_size,
out_features=3 * neox_args.hidden_size,
init_method=init_method,
bias=neox_args.use_bias_in_attn_linear,
skip_bias_add=True
)
else:
self.query_key_value = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=3 * neox_args.hidden_size,
gather_output=False,
init_method=init_method,
bias=neox_args.use_bias_in_attn_linear,
)

coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
@@ -377,16 +426,27 @@ def __init__(
self.attention_dropout = nn.Dropout(self.dropout_p)

# Output.
self.dense = mpu.RowParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=neox_args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
parallel_output=parallel_output,
bias=neox_args.use_bias_in_attn_linear,
)
if neox_args.use_axonn_model_parallelism:
self.dense = Linear(
in_features=neox_args.hidden_size,
out_features=neox_args.hidden_size,
init_method=output_layer_init_method,
skip_bias_add=True,
bias=neox_args.use_bias_in_attn_linear,
transpose=True
)
assert not parallel_output, "TODO: implement AxoNN support for parallel_output=True (GPT-J residual)"
else:
self.dense = mpu.RowParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=neox_args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
parallel_output=parallel_output,
bias=neox_args.use_bias_in_attn_linear,
)

def attention(
self, query_layer, key_layer, value_layer, layer_past, attention_mask
@@ -625,7 +685,10 @@ def forward(self, hidden_states, attention_mask, layer_past=None):
# =====================

# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)
if self.use_axonn_model_parallelism:
mixed_x_layer, _ = self.query_key_value(hidden_states, scatter_input=False, gather_output=False)
else:
mixed_x_layer, _ = self.query_key_value(hidden_states)

# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
@@ -710,7 +773,10 @@ def forward(self, hidden_states, attention_mask, layer_past=None):
# Output. [sq, b, h]
# =================

output, bias = self.dense(context_layer)
if self.use_axonn_model_parallelism:
output, bias = self.dense(context_layer, scatter_input=False, gather_output=False)
else:
output, bias = self.dense(context_layer)

if self.use_cache:
output = [output, present]
@@ -739,11 +805,17 @@ def __init__(

super().__init__()
self.layer_number = layer_number
self.is_first_layer = layer_number == 0
self.is_last_layer = layer_number == neox_args.num_layers - 1

norm, eps = get_norm(neox_args)

# Layernorm on the input data.
self.input_layernorm = norm(neox_args.hidden_size, eps=eps)
if neox_args.use_axonn_model_parallelism:
    self.input_layernorm = norm(
        mpu.divide(neox_args.hidden_size, neox_args.column_model_parallel_size), eps=eps
    )
else:
    self.input_layernorm = norm(neox_args.hidden_size, eps=eps)
self.use_cache = use_cache

self.hidden_dropout = neox_args.hidden_dropout
@@ -771,7 +843,11 @@ def __init__(
# Layernorm on the output of the attention layer.
# If GPT-J residuals are used, this is superfluous but leaving it in
# leads to cleaner code
self.post_attention_layernorm = norm(neox_args.hidden_size, eps=eps)
if neox_args.use_axonn_model_parallelism:
    self.post_attention_layernorm = norm(
        mpu.divide(neox_args.hidden_size, neox_args.column_model_parallel_size), eps=eps
    )
else:
    self.post_attention_layernorm = norm(neox_args.hidden_size, eps=eps)

# MLP
if neox_args.mlp_type == "regular":
@@ -807,6 +883,9 @@ def _get_bias_dropout(self):
def forward(self, x, attention_mask, layer_past=None):
layer_past = layer_past if layer_past is not None else self.layer_past
bias_dropout_fn = self._get_bias_dropout()

if self.is_first_layer:
    # entering the AxoNN tensor-parallel region: partition activations along the batch dim
    x = drop(x, batch_dim=1)
# x: [b, s, h]
if self.gpt_j_residual:
# pseudocode:
Expand Down Expand Up @@ -904,6 +983,8 @@ def forward(self, x, attention_mask, layer_past=None):
prob=self.hidden_dropout,
)

if self.is_last_layer:
    # leaving the AxoNN tensor-parallel region: reassemble activations along the batch dim
    output = gather(output, batch_dim=1)
return output


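Taken together, the transformer.py changes imply three sharding axes: attention partitions its heads over the row dimension (world_size = row_model_parallel_size), the layernorms see hidden states whose feature dimension is divided by the column dimension, and drop()/gather() split and reassemble the batch dimension at the first and last layer. A rough shape summary under assumed 125M-style settings (hidden_size=768 and the batch size are assumptions; only num_layers appears in the config diff above):

hidden_size, micro_batch = 768, 4          # assumed example values
row_mp, column_mp, depth_mp = 1, 2, 1      # the configs/125M.yml split

attention_partition = hidden_size // row_mp       # hidden_size_per_partition -> 768
layernorm_dim = hidden_size // column_mp          # input/post-attention layernorm size -> 384
batch_per_rank = micro_batch // depth_mp          # after drop(x, batch_dim=1), assuming the
                                                  # batch is split over the depth group -> 4
print(attention_partition, layernorm_dim, batch_per_rank)  # 768 384 4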
5 changes: 5 additions & 0 deletions megatron/neox_arguments/neox_args.py
@@ -61,6 +61,11 @@ class NeoXArgsParallelism(NeoXArgsTemplate):
"""

model_parallel_size: int = 1
"""
Size of the model parallelism.
"""

use_axonn_model_parallelism: bool = False
row_model_parallel_size: int = 1
column_model_parallel_size: int = 1
depth_model_parallel_size: int = 1
