From 0d4bdb965e3bd9eed6df8f5a4f024f9ff310601c Mon Sep 17 00:00:00 2001 From: dtamayo <119006120+dtamayo-nlp@users.noreply.github.com> Date: Sat, 7 Sep 2024 06:17:14 +0200 Subject: [PATCH] Add `intermediate_size` to GPT-NeoX models (#1212) * Update transformer.py -> Add `intermediate_size` * add support for rwkv and mamba and add todos about swiglu * refactor activations and mlps * change llama config to swiglu * fixes gelu fusion * pre-commit run * add assert message to mamba linear * Update 1-3B.yml revert accidental change * Update 1-3B.yml * fixes various issues * add back swiglu check --------- Co-authored-by: jahatef Co-authored-by: Quentin Anthony Co-authored-by: Jacob Hatef <74274091+jahatef@users.noreply.github.com> --- configs/llama/13B.yml | 2 +- configs/llama/30B.yml | 2 +- configs/llama/65B.yml | 2 +- configs/llama/7B.yml | 2 +- megatron/data/helpers.cpp | 12 +- megatron/model/activations.py | 38 +++--- megatron/model/gmlp.py | 2 +- megatron/model/mamba/mamba.py | 9 +- megatron/model/rwkv/v6/rwkv.py | 17 ++- megatron/model/transformer.py | 167 +++++++++------------------ megatron/neox_arguments/neox_args.py | 23 +++- 11 files changed, 117 insertions(+), 159 deletions(-) diff --git a/configs/llama/13B.yml b/configs/llama/13B.yml index 305567be1..7a823a43c 100644 --- a/configs/llama/13B.yml +++ b/configs/llama/13B.yml @@ -22,5 +22,5 @@ "use_bias_in_norms": false, "use_bias_in_attn_linear": false, "mlp_type": "llama", - "activation": "silu", + "activation": "swiglu", } diff --git a/configs/llama/30B.yml b/configs/llama/30B.yml index 450f8da38..2c356cea2 100644 --- a/configs/llama/30B.yml +++ b/configs/llama/30B.yml @@ -22,5 +22,5 @@ "use_bias_in_norms": false, "use_bias_in_attn_linear": false, "mlp_type": "llama", - "activation": "silu", + "activation": "swiglu", } diff --git a/configs/llama/65B.yml b/configs/llama/65B.yml index 85f199ce2..cc22d3734 100644 --- a/configs/llama/65B.yml +++ b/configs/llama/65B.yml @@ -22,5 +22,5 @@ "use_bias_in_norms": false, "use_bias_in_attn_linear": false, "mlp_type": "llama", - "activation": "silu", + "activation": "swiglu", } diff --git a/configs/llama/7B.yml b/configs/llama/7B.yml index ecbf187a8..0b134ae27 100644 --- a/configs/llama/7B.yml +++ b/configs/llama/7B.yml @@ -22,5 +22,5 @@ "use_bias_in_norms": false, "use_bias_in_attn_linear": false, "mlp_type": "llama", - "activation": "silu", + "activation": "swiglu", } diff --git a/megatron/data/helpers.cpp b/megatron/data/helpers.cpp index aca290854..9b062b050 100644 --- a/megatron/data/helpers.cpp +++ b/megatron/data/helpers.cpp @@ -428,9 +428,9 @@ py::array build_mapping_impl(const py::array_t& docs_, } } // for (auto sent_index=sent_index_first; ... - } // if (num_remain_sent > 1) { - } // for (int doc=0; doc < num_docs; ++doc) { - } // for (int epoch=0; epoch < num_epochs; ++epoch) { + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { if (!second) { if (verbose) { @@ -660,9 +660,9 @@ py::array build_blocks_mapping_impl(const py::array_t& docs_, num_sent = 0; } } // for (auto sent_index=sent_index_first; ... - } // if (num_remain_sent > 1) { - } // for (int doc=0; doc < num_docs; ++doc) { - } // for (int epoch=0; epoch < num_epochs; ++epoch) { + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { if (!second) { if (verbose) { diff --git a/megatron/model/activations.py b/megatron/model/activations.py index 7a29b0716..c0b825261 100644 --- a/megatron/model/activations.py +++ b/megatron/model/activations.py @@ -25,9 +25,23 @@ def get_activation(neox_args): - """retrieves the activation function specified in neox_args""" + """retrieves the activation function specified in neox_args and whether or not the activation is gated""" + is_gated = False if neox_args.activation == "geglu": - activation_func = GEGLU(neox_args=neox_args) + is_gated = True + activation_func = F.gelu + elif neox_args.activation == "reglu": + is_gated = True + activation_func = F.relu + elif neox_args.activation == "bilinear": + is_gated = True + activation_func = lambda x: x + elif neox_args.activation == "swiglu": + is_gated = True + activation_func = swish + elif neox_args.activation == "glu": + is_gated = True + activation_func = F.sigmoid elif neox_args.activation == "gelu": if neox_args.onnx_safe and neox_args.bias_gelu_fusion: raise ValueError("onnx_safe + bias_gelu_fusion not compatible") @@ -49,7 +63,7 @@ def get_activation(neox_args): activation_func = F.silu else: raise ValueError(f"Activation function {neox_args.activation} not recognized") - return activation_func + return activation_func, is_gated ###### BIAS GELU FUSION/ NO AUTOGRAD ################ @@ -119,21 +133,3 @@ def swish(x, beta: float = 1.0): @torch.jit.script def mish(x): return x * torch.tanh(F.softplus(x)) - - -class GEGLU(torch.nn.Module): - def __init__(self, neox_args): - super(GEGLU, self).__init__() - if neox_args.onnx_safe: - self.activation_func = erf_gelu - else: - self.activation_func = F.gelu - - def forward(self, x, bias=None): - x, gate = x.chunk(2, dim=-1) - if bias is not None: - bias_1, bias_2 = bias.chunk(2, dim=-1) - x = x + bias_1 - gate = gate + bias_2 - intermediate_parallel = self.activation_func(gate) - return intermediate_parallel * x diff --git a/megatron/model/gmlp.py b/megatron/model/gmlp.py index c3462c651..6400640bd 100644 --- a/megatron/model/gmlp.py +++ b/megatron/model/gmlp.py @@ -112,7 +112,7 @@ def __init__( init_method=init_method, skip_bias_add=True, ) - self.activation_func = get_activation(neox_args) + self.activation_func, _ = get_activation(neox_args) ff_dim_parallel = mpu.divide(ff_dim, mpu.get_model_parallel_world_size()) if neox_args.attention_config[layer_number] == "amlp": d_attn = neox_args.gmlp_attn_dim diff --git a/megatron/model/mamba/mamba.py b/megatron/model/mamba/mamba.py index 3177267cb..b3d9e1549 100644 --- a/megatron/model/mamba/mamba.py +++ b/megatron/model/mamba/mamba.py @@ -44,12 +44,17 @@ def __init__( neox_args.mamba_use_bias_in_linears and neox_args.mamba_inner_func_fusion ), "Mamba fused inner fn and bias in x_proj not compatible!" + assert neox_args.intermediate_size == None or neox_args.expansion_factor == None, "Must pass either the absolute intermediate size or the relative expansion factor for the mamba projections" + # set variables, mostly following mamba defaults self.d_model = neox_args.hidden_size self.d_state = 16 # state dimensions per channel self.d_conv = 4 # convolution width - self.expand = 2 # linear projection expansion factors - self.d_inner = int(self.expand * self.d_model) + if neox_args.intermediate_size: + self.d_inner = neox_args.intermediate_size + else: + self.expand = neox_args.expansion_factor if neox_args.expansion_factor else 2 + self.d_inner = int(self.expand * self.d_model) self.dt_rank = math.ceil(self.d_model / 16) # rank of dt / Delta parameter self.dt_scale = 1.0 diff --git a/megatron/model/rwkv/v6/rwkv.py b/megatron/model/rwkv/v6/rwkv.py index 5d4e0d144..ec8cc1aa6 100644 --- a/megatron/model/rwkv/v6/rwkv.py +++ b/megatron/model/rwkv/v6/rwkv.py @@ -247,11 +247,11 @@ def __init__(self, neox_args, layer_number): self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) - self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_ffn, bias=False) + self.key = nn.Linear(neox_args.hidden_size, neox_args.ffn_dim, bias=False) self.receptance = nn.Linear( neox_args.hidden_size, neox_args.hidden_size, bias=False ) - self.value = nn.Linear(neox_args.dim_ffn, neox_args.hidden_size, bias=False) + self.value = nn.Linear(neox_args.ffn_dim, neox_args.hidden_size, bias=False) def forward(self, x): xx = self.time_shift(x) - x @@ -275,14 +275,19 @@ def __init__(self, neox_args, layer_number): self.layer_number = layer_number self.fp16 = neox_args.precision == "fp16" self.bf16 = neox_args.precision == "bfloat16" + assert neox_args.intermediate_size == None or neox_args.expansion_factor == None, "Must pass either the absolute intermediate size or the relative expansion factor for the mamba projections" if not hasattr(neox_args, "dim_att"): neox_args.dim_att = neox_args.hidden_size - if not hasattr(neox_args, "dim_ffn"): - # Make hidden size 3.5x. Round to nearest multiple of 32 until we add hdim rounding logic - neox_args.dim_ffn = int((neox_args.hidden_size * 3.5) // 32 * 32) + if neox_args.intermediate_size: + neox_args.ffn_dim = neox_args.intermediate_size + else: + self.expand = neox_args.expansion_factor if neox_args.expansion_factor else 3.5 + neox_args.ffn_dim = int(self.expand * neox_args.hidden_size) + # Make hidden size 3.5x by default. Round to nearest multiple of 32 until we add hdim rounding logic + neox_args.ffn_dim = int(neox_args.ffn_dim // 32 * 32) assert neox_args.hidden_size % 32 == 0 assert neox_args.dim_att % 32 == 0 - assert neox_args.dim_ffn % 32 == 0 + assert neox_args.ffn_dim % 32 == 0 self.neox_args.head_size = neox_args.dim_att // neox_args.num_attention_heads self.head_size = self.neox_args.head_size self.num_attention_heads = neox_args.num_attention_heads diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 62e7d3a9c..119676c54 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -93,37 +93,55 @@ def __init__( init_method, output_layer_init_method, parallel_output=False, + multiple_of=256, MOE=False, MoE_mp_size=1, ): super().__init__() + assert neox_args.intermediate_size == None or neox_args.expansion_factor == None, "Must pass either the absolute intermediate size or the relative expansion factor for the mamba projections" - self.activation_func = get_activation(neox_args) + self.activation_func, self.is_gated = get_activation(neox_args) self.activation_type = neox_args.activation self.bias_gelu_fusion = neox_args.bias_gelu_fusion + self.multiple_of = multiple_of - # auto scale so geglu has equal parameters - ff_mult = int(4 * 2 / 3) if self.activation_type == "geglu" else 4 - ff_dim = ( - int(ff_mult * neox_args.hidden_size) * 2 - if self.activation_type == "geglu" - else ff_mult * neox_args.hidden_size + if neox_args.intermediate_size: + ffn_dim = neox_args.intermediate_size + elif neox_args.expansion_factor: + ffn_dim = int(neox_args.expansion_factor * neox_args.hidden_size) + else: + # 4h is default for ffn_dim + ffn_dim = 4 * neox_args.hidden_size + ffn_dim_in = ffn_dim + if self.is_gated: + # set activation function to be gated implementation + self.activation_func = Gated_Activation(self.activation_func) + # auto scale so gated activations has equal parameters + ffn_dim = int(ffn_dim * 2 / 3) + ffn_dim_in = ffn_dim // 2 + # set multiple + ffn_dim = int( + (2 * self.multiple_of) + * ((ffn_dim + (2 * multiple_of) - 1) // (2 * multiple_of)) + ) + ffn_dim_in = int( + self.multiple_of * ((ffn_dim_in + multiple_of - 1) // multiple_of) ) - self.dense_h_to_4h = mpu.ColumnParallelLinear( + + self.linear1 = mpu.ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, - output_size=ff_dim, + output_size=ffn_dim, gather_output=False, init_method=init_method, skip_bias_add=True, MOE=MOE, MoE_mp_size=MoE_mp_size, ) - ff_dim_in = ff_dim // 2 if self.activation_type == "geglu" else ff_dim # Project back to h. - self.dense_4h_to_h = mpu.RowParallelLinear( + self.linear2 = mpu.RowParallelLinear( neox_args=neox_args, - input_size=ff_dim_in, + input_size=ffn_dim_in, output_size=neox_args.hidden_size, input_is_parallel=True, init_method=output_layer_init_method, @@ -134,13 +152,10 @@ def __init__( ) def forward(self, hidden_states): + # [s, b, intermediate_size] + intermediate_parallel, bias_parallel = self.linear1(hidden_states) - # [s, b, 4hp] - intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states) - - if ( - self.activation_type == "gelu" and self.bias_gelu_fusion - ) or self.activation_type == "geglu": + if self.is_gated or (self.activation_type == "gelu" and self.bias_gelu_fusion): intermediate_parallel = self.activation_func( intermediate_parallel, bias_parallel ) @@ -150,84 +165,23 @@ def forward(self, hidden_states): ) # [s, b, h] - output, output_bias = self.dense_4h_to_h(intermediate_parallel) + output, output_bias = self.linear2(intermediate_parallel) return output, output_bias -class LLaMAParallelMLP(nn.Module): - """LLaMA's MLP. - - MLP will take the input with h hidden state, project it to 4*h - hidden dimension, perform nonlinear transformation, and project the - state back into h hidden dimension. At the end, dropout is also - applied. - - Note: multiple_of is used to compute the hidden dimension of the MLP - """ - - def __init__( - self, - neox_args, - init_method, - output_layer_init_method, - parallel_output=False, - multiple_of=256, - MOE=False, - MoE_mp_size=1, - ): +class Gated_Activation(torch.nn.Module): + def __init__(self, activation_func): super().__init__() + self.activation_func = activation_func - self.activation_func = get_activation(neox_args) - self.activation_type = neox_args.activation - - self.multiple_of = multiple_of - - # Allow custom intermediate size, e.g. for Mistral - if neox_args.intermediate_size is not None: - ff_dim = neox_args.intermediate_size - else: - ff_dim = int(2 * neox_args.hidden_size * 4 / 3) - ff_dim = self.multiple_of * ((ff_dim + multiple_of - 1) // multiple_of) - - self.w1 = mpu.ColumnParallelLinear( - neox_args=neox_args, - input_size=neox_args.hidden_size, - output_size=ff_dim, - gather_output=False, - init_method=init_method, - skip_bias_add=True, - bias=False, - MOE=MOE, - MoE_mp_size=MoE_mp_size, - ) - self.w3 = mpu.ColumnParallelLinear( - neox_args=neox_args, - input_size=neox_args.hidden_size, - output_size=ff_dim, - gather_output=False, - init_method=init_method, - skip_bias_add=True, - bias=False, - MOE=MOE, - MoE_mp_size=MoE_mp_size, - ) - self.w2 = mpu.RowParallelLinear( - neox_args=neox_args, - input_size=ff_dim, - output_size=neox_args.hidden_size, - input_is_parallel=True, - init_method=output_layer_init_method, - skip_bias_add=True, - parallel_output=parallel_output, - bias=False, - MOE=MOE, - MoE_mp_size=MoE_mp_size, - ) - - def forward(self, hidden_states): - w1_out, _ = self.w1(hidden_states) - w3_out, _ = self.w3(hidden_states) - return self.w2(self.activation_func(w1_out) * w3_out) + def forward(self, x, bias=None): + x, gate = x.chunk(2, dim=-1) + if bias is not None: + bias_1, bias_2 = bias.chunk(2, dim=-1) + x = x + bias_1 + gate = gate + bias_2 + intermediate_parallel = self.activation_func(gate) + return intermediate_parallel * x class ParallelLinear(nn.Module): @@ -1054,24 +1008,13 @@ def __init__( # MLP def get_mlp(mlp_type, **kw): - if mlp_type == "regular": - return ParallelMLP( - neox_args=neox_args, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - parallel_output=self.gpt_j_residual, - **kw, - ) - elif mlp_type == "llama": - return LLaMAParallelMLP( - neox_args=neox_args, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - parallel_output=self.gpt_j_residual, - **kw, - ) - else: - raise KeyError(mlp_type) + return ParallelMLP( + neox_args=neox_args, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + parallel_output=self.gpt_j_residual, + **kw, + ) self.num_experts = ( neox_args.moe_num_experts @@ -1287,11 +1230,7 @@ def forward(self, x, attention_mask, layer_past=None): raise KeyError(self.moe_type) with torch.enable_grad(): - if ( - self.mlp_type == "llama" - or self.num_experts > 1 - and self.moe_type == "deepspeed" - ): + if self.activation == "swiglu" or self.num_experts > 1 and self.moe_type == "deepspeed": # No dropout either assert mlp_bias is None output = mlp_output + attention_output diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index dd51c7778..818c86d31 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -121,9 +121,12 @@ class NeoXArgsModel(NeoXArgsTemplate): intermediate_size: int = None """ - Transformer intermediate size. Currently only used for "mlp_type": "llama". + Transformer intermediate size. Default = 4h + """ - If not passed, will be set to a reasonable default. + expansion_factor: float = None + """ + Transformer intermediate size. Default = 4 """ num_attention_heads: int = None @@ -278,10 +281,20 @@ class NeoXArgsModel(NeoXArgsTemplate): """ activation: Literal[ - "gelu", "geglu", "relu", "softsign", "swish", "mish", "silu" + "gelu", + "geglu", + "relu", + "softsign", + "swish", + "mish", + "silu", + "reglu", + "swiglu", + "bilinear", + "glu", ] = "gelu" """ - Activation function to use - choose from ["gelu", "geglu", "relu", "softsign", "swish", "mish", "silu"] + Activation function to use - choose from ["gelu", "geglu", "relu", "softsign", "swish", "mish", "silu", "reglu", "swiglu", "bilinear", "glu"] """ scaled_upper_triang_masked_softmax_fusion: bool = False @@ -421,9 +434,9 @@ class NeoXArgsModel(NeoXArgsTemplate): mlp_type: str = "regular" """ + Currently, the only mlp_type is "regular." This behavior is currently deprecated. Types: regular: Megatron implementation - llama: LLaMA MLP (SiLU-gated MLP) """ soft_prompt_tuning: dict = None