
GemmaMLP uses `tanh` approximation for GeLU activation (Lightning-AI#…
Andrei-Aksionov authored Mar 5, 2024
1 parent cc17394 commit f241d94
Showing 4 changed files with 7 additions and 1 deletion.
2 changes: 2 additions & 0 deletions lit_gpt/config.py
@@ -866,6 +866,7 @@ def norm_class(self) -> Type:
bias=False,
_norm_class="RMSNorm",
_mlp_class="GemmaMLP",
gelu_approximate="tanh",
intermediate_size=16384,
),
# https://huggingface.co/google/gemma-7b/blob/main/config.json
@@ -884,6 +885,7 @@ def norm_class(self) -> Type:
bias=False,
_norm_class="RMSNorm",
_mlp_class="GemmaMLP",
gelu_approximate="tanh",
intermediate_size=24576,
),
]
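For context: the two new gelu_approximate="tanh" entries mean the Gemma-2b and Gemma-7b configs now carry the approximation mode as data, and the MLP reads it at construction/runtime. A minimal sketch of how such a flag can be declared and consumed (the MiniConfig class and its default value "none" are illustrative assumptions, not the actual lit_gpt Config):

# Hypothetical, simplified sketch of a `gelu_approximate` flag being
# declared on a config object and consumed by an MLP branch. The real
# lit_gpt Config has many more fields; names below are illustrative.
from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class MiniConfig:
    n_embd: int = 2048
    intermediate_size: int = 16384
    bias: bool = False
    gelu_approximate: str = "none"  # the Gemma entries set this to "tanh"


cfg = MiniConfig(gelu_approximate="tanh")
fc = nn.Linear(cfg.n_embd, cfg.intermediate_size, bias=cfg.bias)
x = torch.randn(1, 4, cfg.n_embd)
# torch.nn.functional.gelu accepts approximate="none" (exact) or "tanh".
y = torch.nn.functional.gelu(fc(x), approximate=cfg.gelu_approximate)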
4 changes: 3 additions & 1 deletion lit_gpt/model.py
@@ -287,6 +287,8 @@ def __init__(self, config: Config) -> None:
self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)

self.config = config

def forward(self, x: torch.Tensor) -> torch.Tensor:
x_fc_1 = self.fc_1(x)
x_fc_2 = self.fc_2(x)
@@ -298,7 +300,7 @@ class GemmaMLP(LLaMAMLP):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_fc_1 = self.fc_1(x)
x_fc_2 = self.fc_2(x)
x = torch.nn.functional.gelu(x_fc_1) * x_fc_2
x = torch.nn.functional.gelu(x_fc_1, approximate=self.config.gelu_approximate) * x_fc_2
return self.proj(x)


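The model.py change above is the heart of the commit: GemmaMLP is a gated MLP in which only the gate branch passes through GELU, and that call now routes the config's gelu_approximate value into torch.nn.functional.gelu instead of relying on the exact (erf-based) default. A self-contained sketch of the resulting forward pass, mirroring the diff rather than reproducing lit_gpt verbatim (layer sizes are arbitrary):

# Standalone sketch of the GemmaMLP forward shown in the diff above.
# Only the structure matters: a gated GELU MLP with a configurable
# approximation mode.
import torch
import torch.nn as nn


class GemmaMLPSketch(nn.Module):
    def __init__(self, n_embd: int, intermediate_size: int, gelu_approximate: str = "tanh") -> None:
        super().__init__()
        self.fc_1 = nn.Linear(n_embd, intermediate_size, bias=False)
        self.fc_2 = nn.Linear(n_embd, intermediate_size, bias=False)
        self.proj = nn.Linear(intermediate_size, n_embd, bias=False)
        self.gelu_approximate = gelu_approximate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_fc_1 = self.fc_1(x)  # gate branch, passed through GELU
        x_fc_2 = self.fc_2(x)  # value branch
        x = torch.nn.functional.gelu(x_fc_1, approximate=self.gelu_approximate) * x_fc_2
        return self.proj(x)


mlp = GemmaMLPSketch(n_embd=16, intermediate_size=64)
out = mlp(torch.randn(2, 3, 16))  # -> shape (2, 3, 16)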
1 change: 1 addition & 0 deletions tests/test_convert_lit_checkpoint.py
@@ -414,6 +414,7 @@ def test_against_original_gemma(model_name, device, dtype):
rope_theta=ours_config.rope_base,
attention_bias=ours_config.bias,
tie_word_embeddings=True,
hidden_act="gelu_pytorch_tanh",
)
assert ours_config.intermediate_size == theirs_config.intermediate_size

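In the Hugging Face reference config, hidden_act="gelu_pytorch_tanh" selects the tanh-approximated GELU, i.e. the same function torch.nn.functional.gelu(..., approximate="tanh") computes, so the test now compares like with like. A quick numeric check of that equivalence against the standard tanh-approximation formula (tolerances are illustrative):

# Check that PyTorch's approximate="tanh" matches the usual tanh-based
# GELU formula: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
import math

import torch

x = torch.linspace(-4, 4, steps=101)
tanh_gelu = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))
assert torch.allclose(torch.nn.functional.gelu(x, approximate="tanh"), tanh_gelu, atol=1e-5)
# The exact (erf-based) GELU differs slightly, which is why the flag matters
# when comparing against the reference Gemma weights:
print((torch.nn.functional.gelu(x) - tanh_gelu).abs().max())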
1 change: 1 addition & 0 deletions tests/test_model.py
@@ -596,6 +596,7 @@ def test_against_original_gemma(model_name, device, dtype):
rope_theta=ours_config.rope_base,
attention_bias=ours_config.bias,
tie_word_embeddings=True,
hidden_act="gelu_pytorch_tanh",
)
assert ours_config.intermediate_size == theirs_config.intermediate_size

