Fix numeric instability in LayerNormalization and BatchNormalization.
james77777778 committed Sep 13, 2024
1 parent e7b5a5d commit c2f5651
Showing 2 changed files with 12 additions and 16 deletions.
15 changes: 6 additions & 9 deletions keras/src/layers/normalization/batch_normalization.py
@@ -4,7 +4,6 @@
from keras.src import ops
from keras.src import regularizers
from keras.src.api_export import keras_export
from keras.src.backend import standardize_dtype
from keras.src.layers.input_spec import InputSpec
from keras.src.layers.layer import Layer

@@ -244,11 +243,11 @@ def call(self, inputs, training=None, mask=None):
f"mask.shape={mask.shape}, inputs.shape={inputs.shape}"
)

input_dtype = standardize_dtype(inputs.dtype)
if input_dtype in ("float16", "bfloat16"):
# BN is prone to overflowing for float16/bfloat16 inputs, so we opt
# out BN for mixed precision.
inputs = ops.cast(inputs, "float32")
input_dtype = backend.standardize_dtype(inputs.dtype)
compute_dtype = backend.result_type(input_dtype, "float32")
# BN is prone to overflow with float16/bfloat16 inputs, so we upcast to
# float32 for the subsequent computations.
inputs = ops.cast(inputs, compute_dtype)

moving_mean = ops.cast(self.moving_mean, inputs.dtype)
moving_variance = ops.cast(self.moving_variance, inputs.dtype)
@@ -286,9 +285,7 @@ def call(self, inputs, training=None, mask=None):
scale=gamma,
epsilon=self.epsilon,
)
if input_dtype in ("float16", "bfloat16"):
outputs = ops.cast(outputs, input_dtype)
return outputs
return ops.cast(outputs, input_dtype)

def get_config(self):
base_config = super().get_config()
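Context for the change above (an illustration of mine, not part of the commit): half-precision second-moment statistics overflow easily because float16's largest finite value is about 65504, so squaring activations of magnitude ~256 already produces inf. A minimal sketch, using only NumPy, of why the code now always upcasts before computing the normalization statistics:

# Minimal sketch (not from the commit): why half-precision moments overflow.
import numpy as np

x = np.full((1024,), 256.0, dtype="float16")  # plausible activation magnitude

# 256**2 = 65536 exceeds float16's max finite value (~65504), so the square
# overflows to inf and the mean of squares is inf as well.
print(np.mean(np.square(x)))  # inf

# Upcasting first, as the new code path always does, keeps the result finite.
print(np.mean(np.square(x.astype("float32"))))  # 65536.0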
13 changes: 6 additions & 7 deletions keras/src/layers/normalization/layer_normalization.py
@@ -1,3 +1,4 @@
from keras.src import backend
from keras.src import constraints
from keras.src import initializers
from keras.src import ops
@@ -179,7 +180,6 @@ def build(self, input_shape):
self.built = True

def call(self, inputs):
inputs = ops.cast(inputs, self.compute_dtype)
# Compute the axes along which to reduce the mean / variance
input_shape = inputs.shape
ndims = len(input_shape)
@@ -199,11 +199,11 @@ def _broadcast(v):
return ops.reshape(v, broadcast_shape)
return v

input_dtype = inputs.dtype
if input_dtype in ("float16", "bfloat16") and self.dtype == "float32":
# If mixed precision is used, cast inputs to float32 so that
# this is at least as numerically stable as the fused version.
inputs = ops.cast(inputs, "float32")
input_dtype = backend.standardize_dtype(inputs.dtype)
compute_dtype = backend.result_type(input_dtype, "float32")
# LN is prone to overflow with float16/bfloat16 inputs, so we upcast to
# float32 for the subsequent computations.
inputs = ops.cast(inputs, compute_dtype)

if self.rms_scaling:
# Calculate outputs with only variance and gamma if rms scaling
@@ -231,7 +231,6 @@ def _broadcast(v):
res = res + beta

outputs = inputs * inv + res

return ops.cast(outputs, input_dtype)

def compute_output_shape(self, input_shape):
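A note on why both layers now call backend.result_type instead of hard-coding float32 (my reading, not stated in the commit): result_type promotes to the widest participating dtype, so float16/bfloat16 inputs are upcast for stability while float64 inputs keep their precision rather than being silently downcast. A minimal sketch, assuming the public keras.backend.result_type mirrors the internal keras.src.backend.result_type used in the diff:

# Minimal sketch (assumes the public keras.backend API matches keras.src.backend).
from keras import backend

print(backend.result_type("float16", "float32"))   # float32 -> upcast for stability
print(backend.result_type("bfloat16", "float32"))  # float32 -> upcast for stability
print(backend.result_type("float64", "float32"))   # float64 -> precision preserved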
