From 7d3faf4eb7aa4353b43989da884ce7c1864dc51c Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Mon, 16 Dec 2024 23:55:12 +0000 Subject: [PATCH] Update NNX Module eval docs in module.py --- flax/nnx/module.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/flax/nnx/module.py b/flax/nnx/module.py index 795bb9a088..82b1dd993c 100644 --- a/flax/nnx/module.py +++ b/flax/nnx/module.py @@ -358,12 +358,13 @@ def train(self, **attributes): ) def eval(self, **attributes): - """Sets the Module to evaluation mode. + """Sets the :class:`flax.nnx.Module` to evaluation mode. - ``eval`` uses ``set_attributes`` to recursively set attributes ``deterministic=True`` - and ``use_running_average=True`` of all nested Modules that have these attributes. - Its primarily used to control the runtime behavior of the ``Dropout`` and ``BatchNorm`` - Modules. + ``nnx.Module.eval`` uses :func:`flax.nnx.Module.set_attributes` to recursively set + attributes ``deterministic=True`` and ``use_running_average=True`` of all nested + ``nnx.Module``'s that have these attributes. + It is primarily used to control the runtime behavior of the :class:`flax.nnx.Dropout` + and :class:`flax.nnx.BatchNorm` ``nnx.Module``'s. Example:: @@ -383,7 +384,7 @@ def eval(self, **attributes): (True, True) Args: - **attributes: additional attributes passed to ``set_attributes``. + **attributes: Additional attributes passed to ``set_attributes``. """ return self.set_attributes( deterministic=True,