Skip to content

Commit

Permalink
Better documentation for jnp.indices
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Oct 22, 2024
1 parent f8a1f02 commit 9038bb2
Showing 1 changed file with 31 additions and 1 deletion.
32 changes: 31 additions & 1 deletion jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6913,6 +6913,7 @@ def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False,
A length-N list of grid arrays.
See also:
- :func:`jax.numpy.indices`: generate a grid of indices.
- :obj:`jax.numpy.mgrid`: create a meshgrid using indexing syntax.
- :obj:`jax.numpy.ogrid`: create an open meshgrid using indexing syntax.
Expand Down Expand Up @@ -7085,9 +7086,38 @@ def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None,
@overload
def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None,
sparse: bool = False) -> Array | tuple[Array, ...]: ...
@util.implements(np.indices)
def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None,
sparse: bool = False) -> Array | tuple[Array, ...]:
"""Generate arrays of grid indices.
JAX implementation of :func:`numpy.indices`.
Args:
dimensions: the shape of the grid.
dtype: the dtype of the indices (defaults to integer).
sparse: if True, then return sparse indices. Default is False, which
returns dense indices.
Returns:
An array of shape ``(len(dimensions), *dimensions)`` If ``sparse`` is False,
or a sequence of arrays of the same length as ``dimensions`` if ``sparse`` is True.
See also:
- :func:`jax.numpy.meshgrid`: generate a grid from arbitrary input arrays.
- :obj:`jax.numpy.mgrid`: generate dense indices using a slicing syntax.
- :obj:`jax.numpy.ogrid`: generate sparse indices using a slicing syntax.
Examples:
>>> jnp.indices((2, 3))
Array([[[0, 0, 0],
[1, 1, 1]],
<BLANKLINE>
[[0, 1, 2],
[0, 1, 2]]], dtype=int32)
>>> jnp.indices((2, 3), sparse=True)
(Array([[0],
[1]], dtype=int32), Array([[0, 1, 2]], dtype=int32))
"""
dtypes.check_user_dtype_supported(dtype, "indices")
dtype = dtype or dtypes.canonicalize_dtype(int_)
dimensions = tuple(
Expand Down

0 comments on commit 9038bb2

Please sign in to comment.