Commit
Add intermediate_size to GPT-NeoX models (#1212)
* Update transformer.py -> Add `intermediate_size`

* add support for rwkv and mamba and add todos about swiglu

* refactor activations and mlps

* change llama config to swiglu

* fixes gelu fusion

* pre-commit run

* add assert message to mamba linear

* Update 1-3B.yml

revert accidental change

* Update 1-3B.yml

* fixes various issues

* add back swiglu check

---------

Co-authored-by: jahatef <[email protected]>
Co-authored-by: Quentin Anthony <[email protected]>
Co-authored-by: Jacob Hatef <[email protected]>
4 people authored Sep 7, 2024
1 parent 7548a8b commit 0d4bdb9
Showing 11 changed files with 117 additions and 159 deletions.
2 changes: 1 addition & 1 deletion configs/llama/13B.yml
@@ -22,5 +22,5 @@
   "use_bias_in_norms": false,
   "use_bias_in_attn_linear": false,
   "mlp_type": "llama",
-  "activation": "silu",
+  "activation": "swiglu",
 }
2 changes: 1 addition & 1 deletion configs/llama/30B.yml
@@ -22,5 +22,5 @@
   "use_bias_in_norms": false,
   "use_bias_in_attn_linear": false,
   "mlp_type": "llama",
-  "activation": "silu",
+  "activation": "swiglu",
 }
2 changes: 1 addition & 1 deletion configs/llama/65B.yml
@@ -22,5 +22,5 @@
   "use_bias_in_norms": false,
   "use_bias_in_attn_linear": false,
   "mlp_type": "llama",
-  "activation": "silu",
+  "activation": "swiglu",
 }
2 changes: 1 addition & 1 deletion configs/llama/7B.yml
@@ -22,5 +22,5 @@
   "use_bias_in_norms": false,
   "use_bias_in_attn_linear": false,
   "mlp_type": "llama",
-  "activation": "silu",
+  "activation": "swiglu",
 }
12 changes: 6 additions & 6 deletions megatron/data/helpers.cpp
@@ -428,9 +428,9 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
}

} // for (auto sent_index=sent_index_first; ...
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {

if (!second) {
if (verbose) {
@@ -660,9 +660,9 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
num_sent = 0;
}
} // for (auto sent_index=sent_index_first; ...
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {

if (!second) {
if (verbose) {
38 changes: 17 additions & 21 deletions megatron/model/activations.py
@@ -25,9 +25,23 @@


 def get_activation(neox_args):
-    """retrieves the activation function specified in neox_args"""
+    """retrieves the activation function specified in neox_args and whether or not the activation is gated"""
+    is_gated = False
     if neox_args.activation == "geglu":
-        activation_func = GEGLU(neox_args=neox_args)
+        is_gated = True
+        activation_func = F.gelu
+    elif neox_args.activation == "reglu":
+        is_gated = True
+        activation_func = F.relu
+    elif neox_args.activation == "bilinear":
+        is_gated = True
+        activation_func = lambda x: x
+    elif neox_args.activation == "swiglu":
+        is_gated = True
+        activation_func = swish
+    elif neox_args.activation == "glu":
+        is_gated = True
+        activation_func = F.sigmoid
     elif neox_args.activation == "gelu":
         if neox_args.onnx_safe and neox_args.bias_gelu_fusion:
             raise ValueError("onnx_safe + bias_gelu_fusion not compatible")
@@ -49,7 +63,7 @@ def get_activation(neox_args):
         activation_func = F.silu
     else:
         raise ValueError(f"Activation function {neox_args.activation} not recognized")
-    return activation_func
+    return activation_func, is_gated


 ###### BIAS GELU FUSION/ NO AUTOGRAD ################
@@ -119,21 +133,3 @@ def swish(x, beta: float = 1.0):
 @torch.jit.script
 def mish(x):
     return x * torch.tanh(F.softplus(x))
-
-
-class GEGLU(torch.nn.Module):
-    def __init__(self, neox_args):
-        super(GEGLU, self).__init__()
-        if neox_args.onnx_safe:
-            self.activation_func = erf_gelu
-        else:
-            self.activation_func = F.gelu
-
-    def forward(self, x, bias=None):
-        x, gate = x.chunk(2, dim=-1)
-        if bias is not None:
-            bias_1, bias_2 = bias.chunk(2, dim=-1)
-            x = x + bias_1
-            gate = gate + bias_2
-        intermediate_parallel = self.activation_func(gate)
-        return intermediate_parallel * x
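
For context on how the refactor is meant to be consumed: get_activation now returns a plain element-wise function together with an is_gated flag, and the chunk-and-gate product that previously lived in the removed GEGLU class moves into the MLP. A minimal sketch (the helper name apply_gated_activation is hypothetical, not code from this commit); with the llama configs above set to "swiglu", the gate passes through swish:

    import torch
    import torch.nn.functional as F

    def swish(x, beta: float = 1.0):
        # same form as the swish defined earlier in activations.py
        return x * torch.sigmoid(beta * x)

    def apply_gated_activation(up_proj: torch.Tensor, activation_func) -> torch.Tensor:
        # Mirrors the removed GEGLU.forward (bias-free case): split the up-projection
        # into a value half and a gate half, activate the gate, multiply element-wise.
        x, gate = up_proj.chunk(2, dim=-1)
        return activation_func(gate) * x

    # usage sketch:
    # activation_func, is_gated = get_activation(neox_args)  # e.g. (swish, True) for "swiglu"
    # out = apply_gated_activation(h, activation_func) if is_gated else activation_func(h)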
2 changes: 1 addition & 1 deletion megatron/model/gmlp.py
@@ -112,7 +112,7 @@ def __init__(
             init_method=init_method,
             skip_bias_add=True,
         )
-        self.activation_func = get_activation(neox_args)
+        self.activation_func, _ = get_activation(neox_args)
         ff_dim_parallel = mpu.divide(ff_dim, mpu.get_model_parallel_world_size())
         if neox_args.attention_config[layer_number] == "amlp":
             d_attn = neox_args.gmlp_attn_dim
9 changes: 7 additions & 2 deletions megatron/model/mamba/mamba.py
@@ -44,12 +44,17 @@ def __init__(
             neox_args.mamba_use_bias_in_linears and neox_args.mamba_inner_func_fusion
         ), "Mamba fused inner fn and bias in x_proj not compatible!"

+        assert neox_args.intermediate_size == None or neox_args.expansion_factor == None, "Must pass either the absolute intermediate size or the relative expansion factor for the mamba projections"
+
         # set variables, mostly following mamba defaults
         self.d_model = neox_args.hidden_size
         self.d_state = 16  # state dimensions per channel
         self.d_conv = 4  # convolution width
-        self.expand = 2  # linear projection expansion factors
-        self.d_inner = int(self.expand * self.d_model)
+        if neox_args.intermediate_size:
+            self.d_inner = neox_args.intermediate_size
+        else:
+            self.expand = neox_args.expansion_factor if neox_args.expansion_factor else 2
+            self.d_inner = int(self.expand * self.d_model)
         self.dt_rank = math.ceil(self.d_model / 16)  # rank of dt / Delta parameter
         self.dt_scale = 1.0

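Condensed, the sizing rule the Mamba block now follows: an explicit intermediate_size wins outright, otherwise expansion_factor (defaulting to 2) scales the hidden size, and passing both trips the new assertion. A standalone sketch of that rule (mamba_inner_dim is a hypothetical helper name, not part of the diff):

    def mamba_inner_dim(hidden_size: int, intermediate_size=None, expansion_factor=None) -> int:
        # Only one of the two knobs may be set, mirroring the assert in mamba.py above.
        assert intermediate_size is None or expansion_factor is None
        if intermediate_size:
            return intermediate_size
        expand = expansion_factor if expansion_factor else 2  # mamba default
        return int(expand * hidden_size)

    # e.g. mamba_inner_dim(2048) == 4096
    #      mamba_inner_dim(2048, intermediate_size=5120) == 5120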
17 changes: 11 additions & 6 deletions megatron/model/rwkv/v6/rwkv.py
@@ -247,11 +247,11 @@ def __init__(self, neox_args, layer_number):
         self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
         self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))

-        self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_ffn, bias=False)
+        self.key = nn.Linear(neox_args.hidden_size, neox_args.ffn_dim, bias=False)
         self.receptance = nn.Linear(
             neox_args.hidden_size, neox_args.hidden_size, bias=False
         )
-        self.value = nn.Linear(neox_args.dim_ffn, neox_args.hidden_size, bias=False)
+        self.value = nn.Linear(neox_args.ffn_dim, neox_args.hidden_size, bias=False)

     def forward(self, x):
         xx = self.time_shift(x) - x
@@ -275,14 +275,19 @@ def __init__(self, neox_args, layer_number):
         self.layer_number = layer_number
         self.fp16 = neox_args.precision == "fp16"
         self.bf16 = neox_args.precision == "bfloat16"
+        assert neox_args.intermediate_size == None or neox_args.expansion_factor == None, "Must pass either the absolute intermediate size or the relative expansion factor for the mamba projections"
         if not hasattr(neox_args, "dim_att"):
             neox_args.dim_att = neox_args.hidden_size
-        if not hasattr(neox_args, "dim_ffn"):
-            # Make hidden size 3.5x. Round to nearest multiple of 32 until we add hdim rounding logic
-            neox_args.dim_ffn = int((neox_args.hidden_size * 3.5) // 32 * 32)
+        if neox_args.intermediate_size:
+            neox_args.ffn_dim = neox_args.intermediate_size
+        else:
+            self.expand = neox_args.expansion_factor if neox_args.expansion_factor else 3.5
+            neox_args.ffn_dim = int(self.expand * neox_args.hidden_size)
+        # Make hidden size 3.5x by default. Round to nearest multiple of 32 until we add hdim rounding logic
+        neox_args.ffn_dim = int(neox_args.ffn_dim // 32 * 32)
         assert neox_args.hidden_size % 32 == 0
         assert neox_args.dim_att % 32 == 0
-        assert neox_args.dim_ffn % 32 == 0
+        assert neox_args.ffn_dim % 32 == 0
         self.neox_args.head_size = neox_args.dim_att // neox_args.num_attention_heads
         self.head_size = self.neox_args.head_size
         self.num_attention_heads = neox_args.num_attention_heads
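The RWKV channel-mix width follows the same pattern with a 3.5x default expansion, and the computed width is rounded down to a multiple of 32 so the % 32 assertion holds. A standalone sketch under those assumptions (rwkv_ffn_dim is a hypothetical helper name, not part of the diff):

    def rwkv_ffn_dim(hidden_size: int, intermediate_size=None, expansion_factor=None) -> int:
        # Only one of the two knobs may be set, mirroring the assert in rwkv.py above.
        assert intermediate_size is None or expansion_factor is None
        if intermediate_size:
            ffn_dim = intermediate_size
        else:
            expand = expansion_factor if expansion_factor else 3.5  # default 3.5x expansion
            ffn_dim = int(expand * hidden_size)
        # Round down to the nearest multiple of 32, pending proper hdim rounding logic.
        return int(ffn_dim // 32 * 32)

    # e.g. rwkv_ffn_dim(2048) == 7168  # 2048 * 3.5 is already a multiple of 32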
(2 more changed files not shown)
