From 9038bb2664a6d5249025c4590feb7d430e509253 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 22 Oct 2024 08:41:58 -0700 Subject: [PATCH] Better documentation for jnp.indices --- jax/_src/numpy/lax_numpy.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 714f21577b28..b92e5e250d1c 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -6913,6 +6913,7 @@ def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, A length-N list of grid arrays. See also: + - :func:`jax.numpy.indices`: generate a grid of indices. - :obj:`jax.numpy.mgrid`: create a meshgrid using indexing syntax. - :obj:`jax.numpy.ogrid`: create an open meshgrid using indexing syntax. @@ -7085,9 +7086,38 @@ def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, @overload def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, sparse: bool = False) -> Array | tuple[Array, ...]: ... -@util.implements(np.indices) def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, sparse: bool = False) -> Array | tuple[Array, ...]: + """Generate arrays of grid indices. + + JAX implementation of :func:`numpy.indices`. + + Args: + dimensions: the shape of the grid. + dtype: the dtype of the indices (defaults to integer). + sparse: if True, then return sparse indices. Default is False, which + returns dense indices. + + Returns: + An array of shape ``(len(dimensions), *dimensions)`` If ``sparse`` is False, + or a sequence of arrays of the same length as ``dimensions`` if ``sparse`` is True. + + See also: + - :func:`jax.numpy.meshgrid`: generate a grid from arbitrary input arrays. + - :obj:`jax.numpy.mgrid`: generate dense indices using a slicing syntax. + - :obj:`jax.numpy.ogrid`: generate sparse indices using a slicing syntax. + + Examples: + >>> jnp.indices((2, 3)) + Array([[[0, 0, 0], + [1, 1, 1]], + + [[0, 1, 2], + [0, 1, 2]]], dtype=int32) + >>> jnp.indices((2, 3), sparse=True) + (Array([[0], + [1]], dtype=int32), Array([[0, 1, 2]], dtype=int32)) + """ dtypes.check_user_dtype_supported(dtype, "indices") dtype = dtype or dtypes.canonicalize_dtype(int_) dimensions = tuple(