
Commit

address comments
DavidLandup0 committed Nov 10, 2024
1 parent bc1879a commit c6e20f6
Showing 5 changed files with 149 additions and 125 deletions.
1 change: 1 addition & 0 deletions keras_hub/api/layers/__init__.py
@@ -14,6 +14,7 @@
from keras_hub.src.layers.modeling.reversible_embedding import (
    ReversibleEmbedding,
)
from keras_hub.src.layers.modeling.rms_normalization import RMSNormalization
from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.layers.modeling.sine_position_encoding import (
    SinePositionEncoding,
34 changes: 34 additions & 0 deletions keras_hub/src/layers/modeling/rms_normalization.py
@@ -0,0 +1,34 @@
import keras
from keras import ops

from keras_hub.src.api_export import keras_hub_export


@keras_hub_export("keras_hub.layers.RMSNormalization")
class RMSNormalization(keras.layers.Layer):
"""
Root Mean Square (RMS) Normalization layer.
This layer normalizes the input tensor based on its RMS value and applies
a learned scaling factor.
Args:
input_dim: int. The dimensionality of the input tensor.
"""

def __init__(self, input_dim):
super().__init__()
self.scale = self.add_weight(
name="scale", shape=(input_dim,), initializer="ones"
)

    def call(self, x):
        """Applies RMS normalization to the input tensor.

        Args:
            x: KerasTensor. Input tensor of shape (batch_size, input_dim).

        Returns:
            KerasTensor: The RMS-normalized tensor of the same shape
            (batch_size, input_dim), scaled by the learned `scale` parameter.
        """
        x = ops.cast(x, "float32")
        # 1 / sqrt(mean(x^2) + epsilon); the epsilon guards against
        # division by zero for all-zero inputs.
        rrms = ops.rsqrt(ops.mean(ops.square(x), axis=-1, keepdims=True) + 1e-6)
        return (x * rrms) * self.scale
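
For context, a minimal usage sketch of the new layer (not part of this commit). It assumes the public export path keras_hub.layers.RMSNormalization registered by the keras_hub_export decorator above; the shapes and values are illustrative only.

import numpy as np

from keras_hub.layers import RMSNormalization

# Normalize a batch of 2 vectors with feature dimension 8.
layer = RMSNormalization(input_dim=8)
x = np.random.rand(2, 8).astype("float32")
y = layer(x)  # same shape as x: (2, 8); with `scale` initialized to ones, each row has roughly unit RMS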
