Skip to content

Commit

Permalink
Merge pull request #24537 from jakevdp:doc-examples
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 690156005
  • Loading branch information
Google-ML-Automation committed Oct 26, 2024
2 parents 56dc89f + adf1492 commit 2b01aff
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 1 deletion.
32 changes: 31 additions & 1 deletion jax/_src/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,14 +651,44 @@ def _least_upper_bound(jax_numpy_dtype_promotion: str, *nodes: JAXType) -> JAXTy
def promote_types(a: DTypeLike, b: DTypeLike) -> DType:
"""Returns the type to which a binary operation should cast its arguments.
For details of JAX's type promotion semantics, see :ref:`type-promotion`.
JAX implementation of :func:`numpy.promote_types`. For details of JAX's
type promotion semantics, see :ref:`type-promotion`.
Args:
a: a :class:`numpy.dtype` or a dtype specifier.
b: a :class:`numpy.dtype` or a dtype specifier.
Returns:
A :class:`numpy.dtype` object.
Examples:
Type specifiers may be strings, dtypes, or scalar types, and the return
value is always a dtype:
>>> jnp.promote_types('int32', 'float32') # strings
dtype('float32')
>>> jnp.promote_types(jnp.dtype('int32'), jnp.dtype('float32')) # dtypes
dtype('float32')
>>> jnp.promote_types(jnp.int32, jnp.float32) # scalar types
dtype('float32')
Built-in scalar types (:type:`int`, :type:`float`, or :type:`complex`) are
treated as weakly-typed and will not change the bit width of a strongly-typed
counterpart (see discussion in :ref:`type-promotion`):
>>> jnp.promote_types('uint8', int)
dtype('uint8')
>>> jnp.promote_types('float16', float)
dtype('float16')
This differs from the NumPy version of this function, which treats built-in scalar
types as equivalent to 64-bit types:
>>> import numpy
>>> numpy.promote_types('uint8', int)
dtype('int64')
>>> numpy.promote_types('float16', float)
dtype('float64')
"""
# Note: we deliberately avoid `if a in _weak_types` here because we want to check
# object identity, not object equality, due to the behavior of np.dtype.__eq__
Expand Down
6 changes: 6 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,12 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
meta = _ScalarMeta(np_scalar_type.__name__, (object,),
{"dtype": np.dtype(np_scalar_type)})
meta.__module__ = _PUBLIC_MODULE_NAME
meta.__doc__ =\
f"""A JAX scalar constructor of type {np_scalar_type.__name__}.
While NumPy defines scalar types for each data type, JAX represents
scalars as zero-dimensional arrays.
"""
return meta

bool_ = _make_scalar_type(np.bool_)
Expand Down
23 changes: 23 additions & 0 deletions jax/_src/numpy/ufunc_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,5 +598,28 @@ def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int,
Returns:
wrapped : jax.numpy.ufunc wrapper of func.
Examples:
Here is an example of creating a ufunc similar to :obj:`jax.numpy.add`:
>>> import operator
>>> add = frompyfunc(operator.add, nin=2, nout=1, identity=0)
Now all the standard :class:`jax.numpy.ufunc` methods are available:
>>> x = jnp.arange(4)
>>> add(x, 10)
Array([10, 11, 12, 13], dtype=int32)
>>> add.outer(x, x)
Array([[0, 1, 2, 3],
[1, 2, 3, 4],
[2, 3, 4, 5],
[3, 4, 5, 6]], dtype=int32)
>>> add.reduce(x)
Array(6, dtype=int32)
>>> add.accumulate(x)
Array([0, 1, 3, 6], dtype=int32)
>>> add.at(x, 1, 10, inplace=False)
Array([ 0, 11, 2, 3], dtype=int32)
"""
return ufunc(func, nin, nout, identity=identity)

0 comments on commit 2b01aff

Please sign in to comment.