Commit
fix(model): fix norm input dtype
huangting4201 committed Dec 3, 2024
1 parent 9d2d76a commit 8081e61
Showing 5 changed files with 8 additions and 8 deletions.
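
Every file in this commit applies the same one-line pattern: instead of unconditionally upcasting the residual to float32 before normalization (_residual.float() or _residual.to(torch.float32)), the residual is cast to the norm module's own weight dtype, so the norm's input matches the precision its parameters are stored in. Below is a minimal, self-contained sketch of the pattern in plain PyTorch; the ToyNorm class is hypothetical and only stands in for the repository's actual norm layers.

import torch
import torch.nn as nn


class ToyNorm(nn.Module):
    """Minimal RMSNorm-style layer, used only to illustrate the dtype-matching pattern."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)


norm = ToyNorm(8).to(torch.bfloat16)
residual = torch.randn(2, 8, dtype=torch.bfloat16)

# Old behaviour: unconditionally upcast the residual, so the norm runs on a
# float32 input even though its weight is bfloat16 (type promotion then
# produces a float32 output).
out_old = norm(residual.float())

# New behaviour: cast the residual to whatever dtype the norm's weight uses,
# so normalization stays in the model's working precision.
out_new = norm(residual.to(norm.weight.dtype))

print(out_old.dtype, out_new.dtype)  # torch.float32 torch.bfloat16

The effect, as the sketch shows, is that each norm now sees an input in the same dtype as its own weight rather than a forced float32 tensor.
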
4 changes: 2 additions & 2 deletions internlm/model/modeling_internlm.py
@@ -195,7 +195,7 @@ def _forward(self, hidden_states, *args, **kwargs):
         def _dropout_and_norm_attn(_hidden_states):
             _dropped = self.dropout1(_hidden_states)
             _residual = _dropped
-            _hidden_states = self.norm1(_residual.float())
+            _hidden_states = self.norm1(_residual.to(self.norm1.weight.dtype))
             return _residual, _hidden_states

         if self.dropout_selective_checkpoint:
@@ -212,7 +212,7 @@ def _dropout_and_norm_attn(_hidden_states):
         def _dropout_and_norm_ffn(_residual, _hidden_states):
             _dropped = self.dropout2(_hidden_states)
             _residual = (_dropped + _residual) if _residual is not None else _dropped
-            _hidden_states = self.norm2(_residual.float())
+            _hidden_states = self.norm2(_residual.to(self.norm2.weight.dtype))
             return _residual, _hidden_states

         if self.dropout_selective_checkpoint:
2 changes: 1 addition & 1 deletion internlm/model/modeling_internlm2.py
@@ -254,7 +254,7 @@ def _dropout_and_norm_attn(_residual, _hidden_states):
         def _dropout_and_norm_ffn(_residual, _hidden_states):
             _dropped = self.dropout2(_hidden_states)
             _residual = (_dropped + _residual) if _residual is not None else _dropped
-            _hidden_states = self.ffn_norm(_residual.to(torch.float32))
+            _hidden_states = self.ffn_norm(_residual.to(self.ffn_norm.weight.dtype))

             return _residual, _hidden_states

2 changes: 1 addition & 1 deletion internlm/model/modeling_llama.py
@@ -246,7 +246,7 @@ def _dropout_and_norm_attn(_residual, _hidden_states):
         def _dropout_and_norm_ffn(_residual, _hidden_states):
             _dropped = self.dropout2(_hidden_states)
             _residual = (_dropped + _residual) if _residual is not None else _dropped
-            _hidden_states = self.ffn_norm(_residual.to(torch.float32))
+            _hidden_states = self.ffn_norm(_residual.to(self.ffn_norm.weight.dtype))

             return _residual, _hidden_states

4 changes: 2 additions & 2 deletions internlm/model/modeling_mixtral.py
@@ -214,7 +214,7 @@ def _forward(self, hidden_states, *args, **kwargs):
         def _dropout_and_norm_attn(_hidden_states):
             _dropped = self.dropout1(_hidden_states)
             _residual = _dropped
-            _hidden_states = self.norm1(_residual.float())
+            _hidden_states = self.norm1(_residual.to(self.norm1.weight.dtype))
             return _residual, _hidden_states

         if self.dropout_selective_checkpoint:
@@ -231,7 +231,7 @@ def _dropout_and_norm_attn(_hidden_states):
         def _dropout_and_norm_ffn(_residual, _hidden_states):
             _dropped = self.dropout2(_hidden_states)
             _residual = (_dropped + _residual) if _residual is not None else _dropped
-            _hidden_states = self.norm2(_residual.float())
+            _hidden_states = self.norm2(_residual.to(self.norm2.weight.dtype))
             return _residual, _hidden_states

         if self.dropout_selective_checkpoint:
4 changes: 2 additions & 2 deletions internlm/model/modeling_moe.py
@@ -205,7 +205,7 @@ def _forward(self, hidden_states, *args, **kwargs):
         def _dropout_and_norm_attn(_hidden_states):
             _dropped = self.dropout1(_hidden_states)
             _residual = _dropped
-            _hidden_states = self.norm1(_residual.float())
+            _hidden_states = self.norm1(_residual.to(self.norm1.weight.dtype))
             return _residual, _hidden_states

         if self.dropout_selective_checkpoint:
@@ -222,7 +222,7 @@ def _dropout_and_norm_attn(_hidden_states):
         def _dropout_and_norm_ffn(_residual, _hidden_states):
             _dropped = self.dropout2(_hidden_states)
             _residual = (_dropped + _residual) if _residual is not None else _dropped
-            _hidden_states = self.norm2(_residual.float())
+            _hidden_states = self.norm2(_residual.to(self.norm2.weight.dtype))
             return _residual, _hidden_states

         if self.dropout_selective_checkpoint:
