diff --git a/flax/nnx/module.py b/flax/nnx/module.py index 795bb9a088..82b1dd993c 100644 --- a/flax/nnx/module.py +++ b/flax/nnx/module.py @@ -358,12 +358,13 @@ def train(self, **attributes): ) def eval(self, **attributes): - """Sets the Module to evaluation mode. + """Sets the :class:`flax.nnx.Module` to evaluation mode. - ``eval`` uses ``set_attributes`` to recursively set attributes ``deterministic=True`` - and ``use_running_average=True`` of all nested Modules that have these attributes. - Its primarily used to control the runtime behavior of the ``Dropout`` and ``BatchNorm`` - Modules. + ``nnx.Module.eval`` uses :func:`flax.nnx.Module.set_attributes` to recursively set + attributes ``deterministic=True`` and ``use_running_average=True`` of all nested + ``nnx.Module``'s that have these attributes. + It is primarily used to control the runtime behavior of the :class:`flax.nnx.Dropout` + and :class:`flax.nnx.BatchNorm` ``nnx.Module``'s. Example:: @@ -383,7 +384,7 @@ def eval(self, **attributes): (True, True) Args: - **attributes: additional attributes passed to ``set_attributes``. + **attributes: Additional attributes passed to ``set_attributes``. """ return self.set_attributes( deterministic=True,