From c2f5c532ea30cc74555e6435bb275318358bd013 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sun, 12 Jan 2025 21:46:13 +0100 Subject: [PATCH] Now using jaxtyping.Real for prettier documentation. --- diffrax/_custom_types.py | 19 +++---------------- mkdocs.yml | 2 ++ 2 files changed, 5 insertions(+), 16 deletions(-) diff --git a/diffrax/_custom_types.py b/diffrax/_custom_types.py index 7e08aa1b..a16b4d61 100644 --- a/diffrax/_custom_types.py +++ b/diffrax/_custom_types.py @@ -1,4 +1,3 @@ -import typing from typing import Any, TYPE_CHECKING, Union import equinox as eqx @@ -13,6 +12,7 @@ Float, Int, PyTree, + Real, Shaped, ) @@ -21,27 +21,14 @@ BoolScalarLike = Union[bool, Array, np.ndarray] FloatScalarLike = Union[float, Array, np.ndarray] IntScalarLike = Union[int, Array, np.ndarray] -elif getattr(typing, "GENERATING_DOCUMENTATION", False): - # Skip the union with Array in docs. - BoolScalarLike = bool - FloatScalarLike = float - IntScalarLike = int - - # - # Because they appear in our docstrings, we also monkey-patch some non-Diffrax - # types that have similar defined-in-one-place, exported-in-another behaviour. - # - - jtu.Partial.__module__ = "jax.tree_util" - + RealScalarLike = Union[bool, int, float, Array, np.ndarray] else: BoolScalarLike = Bool[ArrayLike, ""] FloatScalarLike = Float[ArrayLike, ""] IntScalarLike = Int[ArrayLike, ""] + RealScalarLike = Real[ArrayLike, ""] -RealScalarLike = Union[FloatScalarLike, IntScalarLike] - Y = PyTree[Shaped[ArrayLike, "?*y"], "Y"] VF = PyTree[Shaped[ArrayLike, "?*vf"], "VF"] Control = PyTree[Shaped[ArrayLike, "?*control"], "C"] diff --git a/mkdocs.yml b/mkdocs.yml index b399fbd8..067cd458 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -75,6 +75,8 @@ plugins: setup_commands: - import pytkdocs_tweaks - pytkdocs_tweaks.main() + - import jax.tree_util + - jax.tree_util.Partial.__module__ = "jax.tree_util" selection: inherited_members: true # Allow looking up inherited methods