From a0cad459172957c2621ec96a87de8979de6457ea Mon Sep 17 00:00:00 2001
From: Steven Kothen-Hill
Date: Tue, 9 Jul 2024 15:43:44 -0700
Subject: [PATCH] remove hardcoded gelu from BERT models.

---
 megatron/core/models/bert/bert_lm_head.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/megatron/core/models/bert/bert_lm_head.py b/megatron/core/models/bert/bert_lm_head.py
index 548c0460dc..56e7052ecc 100644
--- a/megatron/core/models/bert/bert_lm_head.py
+++ b/megatron/core/models/bert/bert_lm_head.py
@@ -50,10 +50,10 @@ def __init__(
             eps=config.layernorm_epsilon,
         )
 
-        self.gelu = torch.nn.functional.gelu
+        self.activation = config.activation_func
 
     def forward(self, hidden_states: Tensor) -> Tensor:
         hidden_states = self.dense(hidden_states)
-        hidden_states = self.gelu(hidden_states)
+        hidden_states = self.activation(hidden_states)
         hidden_states = self.layer_norm(hidden_states)
         return hidden_states
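
Usage sketch (not part of the patch): this assumes Megatron-LM's TransformerConfig exposes an activation_func field, as the patch implies, and that the BERT LM head is constructed from that config. With the hardcoded gelu removed, the head's dense -> activation -> layer_norm sequence follows whatever callable the config specifies. The field values below are illustrative only, not a tested configuration.

    import torch.nn.functional as F
    from megatron.core.transformer.transformer_config import TransformerConfig

    # Hypothetical configuration; sizes chosen only for illustration.
    config = TransformerConfig(
        num_layers=12,
        hidden_size=768,
        num_attention_heads=12,
        activation_func=F.silu,  # any callable; the LM head previously applied gelu unconditionally
    )

    # A BERT LM head built from this config now applies F.silu between its
    # dense projection and layer norm, consistent with the rest of the model.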