diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py index 4752a9b7bd..11e3f33d01 100644 --- a/flax/nnx/variablelib.py +++ b/flax/nnx/variablelib.py @@ -54,10 +54,10 @@ class VariableMetadata(tp.Generic[A]): class Variable(tp.Generic[A], reprlib.Representable): - """The base class for all ``Variable`` types. Create custom ``Variable`` - types by subclassing this class. Numerous NNX graph functions can filter - for specific ``Variable`` types, for example, :func:`split`, :func:`state`, - :func:`pop`, and :func:`State.filter`. + """The base class for all :class:`flax.nnx.Variable` types. Creates custom + ``nnx.Variable`` types by subclassing this class. Numerous NNX graph functions + can filter for specific ``nnx.Variable`` types, for example, :func:`flax.nnx.split`, + :func:`flax.nnx.state`, :func:`flax.nnx.pop`, and :func:`flax.nnx.State.filter`. Example usage::