Skip to content

Commit

Permalink
Now using jaxtyping.Real for prettier documentation.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Jan 13, 2025
1 parent d98338e commit dd9ac56
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 16 deletions.
19 changes: 3 additions & 16 deletions diffrax/_custom_types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import typing
from typing import Any, TYPE_CHECKING, Union

import equinox as eqx
Expand All @@ -13,6 +12,7 @@
Float,
Int,
PyTree,
Real,
Shaped,
)

Expand All @@ -21,27 +21,14 @@
BoolScalarLike = Union[bool, Array, np.ndarray]
FloatScalarLike = Union[float, Array, np.ndarray]
IntScalarLike = Union[int, Array, np.ndarray]
elif getattr(typing, "GENERATING_DOCUMENTATION", False):
# Skip the union with Array in docs.
BoolScalarLike = bool
FloatScalarLike = float
IntScalarLike = int

#
# Because they appear in our docstrings, we also monkey-patch some non-Diffrax
# types that have similar defined-in-one-place, exported-in-another behaviour.
#

jtu.Partial.__module__ = "jax.tree_util"

RealScalarLike = Union[bool, int, float, Array, np.ndarray]
else:
BoolScalarLike = Bool[ArrayLike, ""]
FloatScalarLike = Float[ArrayLike, ""]
IntScalarLike = Int[ArrayLike, ""]
RealScalarLike = Real[ArrayLike, ""]


RealScalarLike = Union[FloatScalarLike, IntScalarLike]

Y = PyTree[Shaped[ArrayLike, "?*y"], "Y"]
VF = PyTree[Shaped[ArrayLike, "?*vf"], "VF"]
Control = PyTree[Shaped[ArrayLike, "?*control"], "C"]
Expand Down
2 changes: 2 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ plugins:
setup_commands:
- import pytkdocs_tweaks
- pytkdocs_tweaks.main()
- import jax.tree_util
- jax.tree_util.Partial.__module__ = "jax.tree_util"

selection:
inherited_members: true # Allow looking up inherited methods
Expand Down

0 comments on commit dd9ac56

Please sign in to comment.