-
Notifications
You must be signed in to change notification settings - Fork 242
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
bc1879a
commit c6e20f6
Showing
5 changed files
with
149 additions
and
125 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import keras | ||
from keras import ops | ||
|
||
from keras_hub.src.api_export import keras_hub_export | ||
|
||
|
||
@keras_hub_export("keras_hub.layers.RMSNormalization") | ||
class RMSNormalization(keras.layers.Layer): | ||
""" | ||
Root Mean Square (RMS) Normalization layer. | ||
This layer normalizes the input tensor based on its RMS value and applies | ||
a learned scaling factor. | ||
Args: | ||
input_dim: int. The dimensionality of the input tensor. | ||
""" | ||
|
||
def __init__(self, input_dim): | ||
super().__init__() | ||
self.scale = self.add_weight( | ||
name="scale", shape=(input_dim,), initializer="ones" | ||
) | ||
|
||
def call(self, x): | ||
""" | ||
Applies RMS normalization to the input tensor. | ||
Args: | ||
x: KerasTensor. Input tensor of shape (batch_size, input_dim). | ||
Returns: | ||
KerasTensor: The RMS-normalized tensor of the same shape (batch_size, input_dim), | ||
scaled by the learned `scale` parameter. | ||
""" | ||
x = ops.cast(x, float) | ||
rrms = ops.rsqrt(ops.mean(ops.square(x), axis=-1, keepdims=True) + 1e-6) | ||
return (x * rrms) * self.scale |
Oops, something went wrong.