diff --git a/docs/source/development/adding_to_the_api_documentation.ipynb b/docs/source/development/adding_to_the_api_documentation.ipynb
index ccfa012ea6..4dc1a8b143 100644
--- a/docs/source/development/adding_to_the_api_documentation.ipynb
+++ b/docs/source/development/adding_to_the_api_documentation.ipynb
@@ -45,27 +45,27 @@
     "import probnum # pylint: disable=unused-import\n",
     "from probnum import linops, randvars, utils\n",
     "from probnum.linalg.solvers.matrixbased import SymmetricMatrixBasedSolver\n",
-    "from probnum.typing import LinearOperatorArgType\n",
+    "from probnum.typing import LinearOperatorLike\n",
     "\n",
     "# pylint: disable=too-many-branches\n",
     "\n",
     "\n",
     "def problinsolve(\n",
     "    A: Union[\n",
-    "        LinearOperatorArgType,\n",
-    "        \"randvars.RandomVariable[LinearOperatorArgType]\",\n",
+    "        LinearOperatorLike,\n",
+    "        \"randvars.RandomVariable[LinearOperatorLike]\",\n",
     "    ],\n",
     "    b: Union[np.ndarray, \"randvars.RandomVariable[np.ndarray]\"],\n",
     "    A0: Optional[\n",
     "        Union[\n",
-    "            LinearOperatorArgType,\n",
-    "            \"randvars.RandomVariable[LinearOperatorArgType]\",\n",
+    "            LinearOperatorLike,\n",
+    "            \"randvars.RandomVariable[LinearOperatorLike]\",\n",
     "        ]\n",
     "    ] = None,\n",
     "    Ainv0: Optional[\n",
     "        Union[\n",
-    "            LinearOperatorArgType,\n",
-    "            \"randvars.RandomVariable[LinearOperatorArgType]\",\n",
+    "            LinearOperatorLike,\n",
+    "            \"randvars.RandomVariable[LinearOperatorLike]\",\n",
     "        ]\n",
     "    ] = None,\n",
     "    x0: Optional[Union[np.ndarray, \"randvars.RandomVariable[np.ndarray]\"]] = None,\n",
diff --git a/docs/source/development/implementing_a_probnum_method.ipynb b/docs/source/development/implementing_a_probnum_method.ipynb
index b55da12496..8f27e95496 100644
--- a/docs/source/development/implementing_a_probnum_method.ipynb
+++ b/docs/source/development/implementing_a_probnum_method.ipynb
@@ -51,7 +51,7 @@
    "source": [
     "### Method `probsolve_qp`\n",
     "\n",
-    "We will now take a closer look at the interface of our 1D noisy quadratic optimization method. At a basic level `probsolve_qp` takes a function of the type `Callable[[FloatArgType], FloatArgType]`. This hints that the optimization objective is a 1D function. Our prior knowledge about the parameters $(a,b,c)$ is encoded in the random variable `fun_params0`. However, we want to also give a user the option to not specify any prior knowledge or just a guess about the parameter values, hence this argument is optional or can be an `np.ndarray`. \n",
+    "We will now take a closer look at the interface of our 1D noisy quadratic optimization method. At a basic level `probsolve_qp` takes a function of the type `Callable[[FloatLike], FloatLike]`. This hints that the optimization objective is a 1D function. Our prior knowledge about the parameters $(a,b,c)$ is encoded in the random variable `fun_params0`. However, we also want to give the user the option of specifying no prior knowledge at all, or only a point estimate of the parameter values; hence this argument is optional and may also be an `np.ndarray`. \n",
    "\n",
    "The interface also has an `assume_fun` argument, which allows specification of the variant of the probabilistic numerical method to use, based on the assumptions about the problem. For convenience, this can be inferred from the problem itself. The actual implementation of the PN method variant, which is initialized in a modular fashion, is separate from the interface and will be explained later. Finally, the actual optimization routine is called and the result is returned."
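To make the interface concrete before diving into the implementation, here is a minimal usage sketch. It assumes the `quadopt_example` package laid out later in this diff is importable; the objective's coefficients, the noise level, and the names given to the unpacked results are illustrative only.

```python
import numpy as np

from quadopt_example._probsolve_qp import probsolve_qp  # package layout as in this diff

rng = np.random.default_rng(seed=42)


def fun(x):
    """Noisy 1D quadratic objective f(x) = 2x^2 - 3x + 1 (coefficients made up)."""
    return 2.0 * x ** 2 - 3.0 * x + 1.0 + 0.01 * rng.standard_normal()


# Plain Python numbers are accepted wherever FloatLike / IntLike is annotated.
x_opt, fun_opt, fun_params, info = probsolve_qp(rng=rng, fun=fun, tol=1e-6, maxiter=1000)
```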
] @@ -67,7 +67,7 @@ "\n", "import probnum as pn\n", "from probnum import randvars, linops\n", - "from probnum.typing import FloatArgType, IntArgType\n", + "from probnum.typing import FloatLike, IntLike\n", "\n", "rng = np.random.default_rng(seed=123)" ] @@ -83,14 +83,14 @@ "# %load -s probsolve_qp quadopt_example/_probsolve_qp\n", "def probsolve_qp(\n", " rng: np.random.Generator,\n", - " fun: Callable[[FloatArgType], FloatArgType],\n", + " fun: Callable[[FloatLike], FloatLike],\n", " fun_params0: Optional[Union[np.ndarray, randvars.RandomVariable]] = None,\n", " assume_fun: Optional[str] = None,\n", - " tol: FloatArgType = 10 ** -5,\n", - " maxiter: IntArgType = 10 ** 4,\n", + " tol: FloatLike = 10 ** -5,\n", + " maxiter: IntLike = 10 ** 4,\n", " noise_cov: Optional[Union[np.ndarray, linops.LinearOperator]] = None,\n", " callback: Optional[\n", - " Callable[[FloatArgType, FloatArgType, randvars.RandomVariable], None]\n", + " Callable[[FloatLike, FloatLike, randvars.RandomVariable], None]\n", " ] = None,\n", ") -> Tuple[float, randvars.RandomVariable, randvars.RandomVariable, Dict]:\n", " \"\"\"Probabilistic 1D Quadratic Optimization.\n", @@ -316,24 +316,24 @@ "# Type aliases for quadratic optimization\n", "QuadOptPolicyType = Callable[\n", " [\n", - " Callable[[FloatArgType], FloatArgType],\n", + " Callable[[FloatLike], FloatLike],\n", " randvars.RandomVariable,\n", " ],\n", - " FloatArgType,\n", + " FloatLike,\n", "]\n", "QuadOptObservationOperatorType = Callable[\n", - " [Callable[[FloatArgType], FloatArgType], FloatArgType], FloatArgType\n", + " [Callable[[FloatLike], FloatLike], FloatLike], FloatLike\n", "]\n", "QuadOptBeliefUpdateType = Callable[\n", " [\n", " randvars.RandomVariable,\n", - " FloatArgType,\n", - " FloatArgType,\n", + " FloatLike,\n", + " FloatLike,\n", " ],\n", " randvars.RandomVariable,\n", "]\n", "QuadOptStoppingCriterionType = Callable[\n", - " [Callable[[FloatArgType], FloatArgType], randvars.RandomVariable, IntArgType],\n", + " [Callable[[FloatLike], FloatLike], randvars.RandomVariable, IntLike],\n", " Tuple[bool, Union[str, None]],\n", "]\n", "\n", @@ -430,7 +430,7 @@ " self.stopping_criteria = stopping_criteria\n", "\n", " def has_converged(\n", - " self, fun: Callable[[FloatArgType], FloatArgType], iteration: IntArgType\n", + " self, fun: Callable[[FloatLike], FloatLike], iteration: IntLike\n", " ) -> Tuple[bool, Union[str, None]]:\n", " \"\"\"Check whether the optimizer has converged.\n", "\n", @@ -451,7 +451,7 @@ "\n", " def optim_iterator(\n", " self,\n", - " fun: Callable[[FloatArgType], FloatArgType],\n", + " fun: Callable[[FloatLike], FloatLike],\n", " ) -> Tuple[float, float, randvars.RandomVariable]:\n", " \"\"\"Generator implementing the optimization iteration.\n", "\n", @@ -486,7 +486,7 @@ "\n", " def optimize(\n", " self,\n", - " fun: Callable[[FloatArgType], FloatArgType],\n", + " fun: Callable[[FloatLike], FloatLike],\n", " callback: Optional[\n", " Callable[[float, float, randvars.RandomVariable], None]\n", " ] = None,\n", @@ -584,10 +584,10 @@ "internal representation of those same objects. Canonical examples are different kinds of integer or float types, which might be passed by a user. 
These are all unified internally.\n", "\n", "```python\n", - "IntArgType = Union[int, numbers.Integral, np.integer]\n", - "FloatArgType = Union[float, numbers.Real, np.floating]\n", + "IntLike = Union[int, numbers.Integral, np.integer]\n", + "FloatLike = Union[float, numbers.Real, np.floating]\n", "\n", - "ShapeArgType = Union[IntArgType, Iterable[IntArgType]]\n", + "ShapeLike = Union[IntLike, Iterable[IntLike]]\n", "\"\"\"Type of a public API argument for supplying a shape. Values of this type should\n", "always be converted into :class:`ShapeType` using the function\n", ":func:`probnum.utils.as_shape` before further internal processing.\"\"\"\n", @@ -602,11 +602,11 @@ "metadata": {}, "outputs": [], "source": [ - "from probnum.typing import ShapeType, IntArgType, ShapeArgType\n", + "from probnum.typing import ShapeType, IntLike, ShapeLike\n", "from probnum.utils import as_shape\n", "\n", "\n", - "def extend_shape(shape: ShapeArgType, extension: IntArgType) -> ShapeType:\n", + "def extend_shape(shape: ShapeLike, extension: IntLike) -> ShapeType:\n", " return as_shape(shape) + as_shape(extension)" ] }, @@ -674,7 +674,7 @@ "source": [ "# %load -s explore_exploit_policy quadopt_example/policies\n", "def explore_exploit_policy(\n", - " fun: Callable[[FloatArgType], FloatArgType],\n", + " fun: Callable[[FloatLike], FloatLike],\n", " fun_params0: randvars.RandomVariable,\n", " rng: np.random.Generator,\n", ") -> float:\n", @@ -704,16 +704,16 @@ "```python\n", "QuadOptPolicyType = Callable[\n", " [\n", - " Callable[[FloatArgType], FloatArgType],\n", + " Callable[[FloatLike], FloatLike],\n", " randvars.RandomVariable\n", " ],\n", - " FloatArgType,\n", + " FloatLike,\n", "]\n", "```\n", "The observation process for this problem is very simple. It just evaluates the objective function. \n", "```python\n", "QuadOptObservationOperatorType = Callable[\n", - " [Callable[[FloatArgType], FloatArgType], FloatArgType], FloatArgType\n", + " [Callable[[FloatLike], FloatLike], FloatLike], FloatLike\n", "]\n", "```\n", "One can imagine a different probabilistic optimization method which evaluates the gradient as well. In this case the different observation processes would all get the function, its gradient and an evaluation point / action as arguments." 
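To make that remark concrete, a gradient-aware observation process could look like the following sketch. It is hypothetical and not part of the example package; `dfun` and the returned value/gradient pair are inventions for illustration.

```python
from typing import Callable, Tuple

import numpy as np

from probnum.typing import FloatLike


def function_and_gradient_evaluation(
    fun: Callable[[FloatLike], FloatLike],
    dfun: Callable[[FloatLike], FloatLike],
    action: FloatLike,
) -> Tuple[np.float_, np.float_]:
    """Observe a (noisy) function value and its gradient at the given action."""
    return np.float_(fun(action)), np.float_(dfun(action))
```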
@@ -727,7 +727,7 @@ "source": [ "# %load -s function_evaluation quadopt_example/observation_operators\n", "def function_evaluation(\n", - " fun: Callable[[FloatArgType], FloatArgType], action: FloatArgType\n", + " fun: Callable[[FloatLike], FloatLike], action: FloatLike\n", ") -> np.float_:\n", " \"\"\"Observe a (noisy) function evaluation of the quadratic objective.\n", "\n", @@ -758,8 +758,8 @@ "QuadOptBeliefUpdateType = Callable[\n", " [\n", " randvars.RandomVariable,\n", - " FloatArgType,\n", - " FloatArgType,\n", + " FloatLike,\n", + " FloatLike,\n", " ],\n", " randvars.RandomVariable,\n", "]\n", @@ -776,8 +776,8 @@ "# %load -s gaussian_belief_update quadopt_example/belief_updates\n", "def gaussian_belief_update(\n", " fun_params0: randvars.RandomVariable,\n", - " action: FloatArgType,\n", - " observation: FloatArgType,\n", + " action: FloatLike,\n", + " observation: FloatLike,\n", " noise_cov: Union[np.ndarray, linops.LinearOperator],\n", ") -> randvars.RandomVariable:\n", " \"\"\"Update the belief over the parameters with an observation.\n", @@ -823,7 +823,7 @@ "The stopping criteria are also implemented as simple methods, which return a `bool` determining convergence and a string giving the name of the criterion.\n", "```python\n", "QuadOptStoppingCriterionType = Callable[\n", - " [Callable[[FloatArgType], FloatArgType], randvars.RandomVariable, IntArgType],\n", + " [Callable[[FloatLike], FloatLike], randvars.RandomVariable, IntLike],\n", " Tuple[bool, Union[str, None]],\n", "]\n", "```\n", @@ -838,11 +838,11 @@ "source": [ "# %load -s parameter_uncertainty quadopt_example/stopping_criteria\n", "def parameter_uncertainty(\n", - " fun: Callable[[FloatArgType], FloatArgType],\n", + " fun: Callable[[FloatLike], FloatLike],\n", " fun_params0: randvars.RandomVariable,\n", - " current_iter: IntArgType,\n", - " abstol: FloatArgType,\n", - " reltol: FloatArgType,\n", + " current_iter: IntLike,\n", + " abstol: FloatLike,\n", + " reltol: FloatLike,\n", ") -> Tuple[bool, Union[str, None]]:\n", " \"\"\"Termination based on numerical uncertainty about the parameters.\n", "\n", diff --git a/docs/source/development/quadopt_example/_probsolve_qp.py b/docs/source/development/quadopt_example/_probsolve_qp.py index 22b022f78e..0ea7e69383 100644 --- a/docs/source/development/quadopt_example/_probsolve_qp.py +++ b/docs/source/development/quadopt_example/_probsolve_qp.py @@ -6,7 +6,7 @@ import probnum as pn import probnum.utils as _utils from probnum import linops, randvars -from probnum.typing import FloatArgType, IntArgType +from probnum.typing import FloatLike, IntLike from .belief_updates import gaussian_belief_update from .observation_operators import function_evaluation @@ -17,14 +17,14 @@ def probsolve_qp( rng: np.random.Generator, - fun: Callable[[FloatArgType], FloatArgType], + fun: Callable[[FloatLike], FloatLike], fun_params0: Optional[Union[np.ndarray, randvars.RandomVariable]] = None, assume_fun: Optional[str] = None, - tol: FloatArgType = 10 ** -5, - maxiter: IntArgType = 10 ** 4, + tol: FloatLike = 10 ** -5, + maxiter: IntLike = 10 ** 4, noise_cov: Optional[Union[np.ndarray, linops.LinearOperator]] = None, callback: Optional[ - Callable[[FloatArgType, FloatArgType, randvars.RandomVariable], None] + Callable[[FloatLike, FloatLike, randvars.RandomVariable], None] ] = None, ) -> Tuple[float, randvars.RandomVariable, randvars.RandomVariable, Dict]: """Probabilistic 1D Quadratic Optimization. 
diff --git a/docs/source/development/quadopt_example/belief_updates.py b/docs/source/development/quadopt_example/belief_updates.py index 1071f0392e..d2572ea263 100644 --- a/docs/source/development/quadopt_example/belief_updates.py +++ b/docs/source/development/quadopt_example/belief_updates.py @@ -7,13 +7,13 @@ import probnum as pn from probnum import linops, randvars -from probnum.typing import FloatArgType +from probnum.typing import FloatLike def gaussian_belief_update( fun_params0: randvars.RandomVariable, - action: FloatArgType, - observation: FloatArgType, + action: FloatLike, + observation: FloatLike, noise_cov: Union[np.ndarray, linops.LinearOperator], ) -> randvars.RandomVariable: """Update the belief over the parameters with an observation. diff --git a/docs/source/development/quadopt_example/observation_operators.py b/docs/source/development/quadopt_example/observation_operators.py index 74b684ceed..a08e25cf4c 100644 --- a/docs/source/development/quadopt_example/observation_operators.py +++ b/docs/source/development/quadopt_example/observation_operators.py @@ -5,11 +5,11 @@ import numpy as np from probnum import utils -from probnum.typing import FloatArgType +from probnum.typing import FloatLike def function_evaluation( - fun: Callable[[FloatArgType], FloatArgType], action: FloatArgType + fun: Callable[[FloatLike], FloatLike], action: FloatLike ) -> np.float_: """Observe a (noisy) function evaluation of the quadratic objective. diff --git a/docs/source/development/quadopt_example/policies.py b/docs/source/development/quadopt_example/policies.py index 2b1b2621e5..45e95adbee 100644 --- a/docs/source/development/quadopt_example/policies.py +++ b/docs/source/development/quadopt_example/policies.py @@ -5,11 +5,11 @@ import numpy as np from probnum import randvars -from probnum.typing import FloatArgType +from probnum.typing import FloatLike def explore_exploit_policy( - fun: Callable[[FloatArgType], FloatArgType], + fun: Callable[[FloatLike], FloatLike], fun_params0: randvars.RandomVariable, rng: np.random.Generator, ) -> float: @@ -31,7 +31,7 @@ def explore_exploit_policy( def stochastic_policy( - fun: Callable[[FloatArgType], FloatArgType], + fun: Callable[[FloatLike], FloatLike], fun_params0: randvars.RandomVariable, rng: np.random.Generator, ) -> float: diff --git a/docs/source/development/quadopt_example/probabilistic_quadratic_optimizer.py b/docs/source/development/quadopt_example/probabilistic_quadratic_optimizer.py index 60b2e980f0..20d76fad47 100644 --- a/docs/source/development/quadopt_example/probabilistic_quadratic_optimizer.py +++ b/docs/source/development/quadopt_example/probabilistic_quadratic_optimizer.py @@ -7,7 +7,7 @@ import probnum as pn import probnum.utils as _utils from probnum import linops, randvars -from probnum.typing import FloatArgType, IntArgType +from probnum.typing import FloatLike, IntLike from .belief_updates import gaussian_belief_update from .observation_operators import function_evaluation @@ -17,24 +17,24 @@ # Type aliases for quadratic optimization QuadOptPolicyType = Callable[ [ - Callable[[FloatArgType], FloatArgType], + Callable[[FloatLike], FloatLike], randvars.RandomVariable, ], - FloatArgType, + FloatLike, ] QuadOptObservationOperatorType = Callable[ - [Callable[[FloatArgType], FloatArgType], FloatArgType], FloatArgType + [Callable[[FloatLike], FloatLike], FloatLike], FloatLike ] QuadOptBeliefUpdateType = Callable[ [ randvars.RandomVariable, - FloatArgType, - FloatArgType, + FloatLike, + FloatLike, ], randvars.RandomVariable, ] 
QuadOptStoppingCriterionType = Callable[
-    [Callable[[FloatArgType], FloatArgType], randvars.RandomVariable, IntArgType],
+    [Callable[[FloatLike], FloatLike], randvars.RandomVariable, IntLike],
     Tuple[bool, Union[str, None]],
 ]

@@ -131,7 +131,7 @@ def __init__(
         self.stopping_criteria = stopping_criteria

     def has_converged(
-        self, fun: Callable[[FloatArgType], FloatArgType], iteration: IntArgType
+        self, fun: Callable[[FloatLike], FloatLike], iteration: IntLike
     ) -> Tuple[bool, Union[str, None]]:
         """Check whether the optimizer has converged.

@@ -152,7 +152,7 @@ def has_converged(
     def optim_iterator(
         self,
-        fun: Callable[[FloatArgType], FloatArgType],
+        fun: Callable[[FloatLike], FloatLike],
     ) -> Tuple[float, float, randvars.RandomVariable]:
         """Generator implementing the optimization iteration.

@@ -187,7 +187,7 @@ def optim_iterator(
     def optimize(
         self,
-        fun: Callable[[FloatArgType], FloatArgType],
+        fun: Callable[[FloatLike], FloatLike],
         callback: Optional[
             Callable[[float, float, randvars.RandomVariable], None]
         ] = None,
diff --git a/docs/source/development/quadopt_example/stopping_criteria.py b/docs/source/development/quadopt_example/stopping_criteria.py
index 1fc50af0d9..dad3bfc047 100644
--- a/docs/source/development/quadopt_example/stopping_criteria.py
+++ b/docs/source/development/quadopt_example/stopping_criteria.py
@@ -5,15 +5,15 @@
 import numpy as np

 from probnum import randvars
-from probnum.typing import FloatArgType, IntArgType
+from probnum.typing import FloatLike, IntLike


 def parameter_uncertainty(
-    fun: Callable[[FloatArgType], FloatArgType],
+    fun: Callable[[FloatLike], FloatLike],
     fun_params0: randvars.RandomVariable,
-    current_iter: IntArgType,
-    abstol: FloatArgType,
-    reltol: FloatArgType,
+    current_iter: IntLike,
+    abstol: FloatLike,
+    reltol: FloatLike,
 ) -> Tuple[bool, Union[str, None]]:
     """Termination based on numerical uncertainty about the parameters.

@@ -41,10 +41,10 @@ def parameter_uncertainty(


 def maximum_iterations(
-    fun: Callable[[FloatArgType], FloatArgType],
+    fun: Callable[[FloatLike], FloatLike],
     fun_params0: randvars.RandomVariable,
-    current_iter: IntArgType,
-    maxiter: IntArgType,
+    current_iter: IntLike,
+    maxiter: IntLike,
 ) -> Tuple[bool, Union[str, None]]:
     """Termination based on maximum number of iterations.

@@ -66,11 +66,11 @@ def maximum_iterations(


 def residual(
-    fun: Callable[[FloatArgType], FloatArgType],
+    fun: Callable[[FloatLike], FloatLike],
     fun_params0: randvars.RandomVariable,
-    current_iter: IntArgType,
-    abstol: FloatArgType,
-    reltol: FloatArgType,
+    current_iter: IntLike,
+    abstol: FloatLike,
+    reltol: FloatLike,
 ) -> Tuple[bool, Union[str, None]]:
     """Termination based on the residual.

diff --git a/docs/source/development/styleguide.md b/docs/source/development/styleguide.md
index c5eb6e09c2..9ec31cecd6 100644
--- a/docs/source/development/styleguide.md
+++ b/docs/source/development/styleguide.md
@@ -41,7 +41,7 @@ Exceptions to these rules are type-related modules, which include `typing` and `probnum.typing`.
 Types are always imported directly.

 - `from typing import Optional, Callable`
-- `from probnum.typing import FloatArgType`
+- `from probnum.typing import FloatLike`

 Please do not abbreviate import paths unnecessarily.
We do **not** use the following imports: - `import probnum.random_variables as pnrv` or `import probnum.filtsmooth as pnfs` (correct would be `from probnum import randvars, filtsmooth`) diff --git a/src/probnum/backend/_core/__init__.py b/src/probnum/backend/_core/__init__.py index 6cf990c267..02cc9fe70b 100644 --- a/src/probnum/backend/_core/__init__.py +++ b/src/probnum/backend/_core/__init__.py @@ -1,5 +1,5 @@ from probnum import backend as _backend -from probnum.typing import ArrayType, DTypeArgType, ScalarArgType +from probnum.typing import ArrayType, DTypeArgType, ScalarLike if _backend.BACKEND is _backend.Backend.NUMPY: from . import _numpy as _core @@ -79,7 +79,7 @@ jit_method = _core.jit_method -def as_scalar(x: ScalarArgType, dtype: DTypeArgType = None) -> ArrayType: +def as_scalar(x: ScalarLike, dtype: DTypeArgType = None) -> ArrayType: """Convert a scalar into a NumPy scalar. Parameters diff --git a/src/probnum/backend/random/_jax.py b/src/probnum/backend/random/_jax.py index 8759a1754a..d98bcb7046 100644 --- a/src/probnum/backend/random/_jax.py +++ b/src/probnum/backend/random/_jax.py @@ -5,7 +5,7 @@ import jax from jax import numpy as jnp -from probnum.typing import DTypeArgType, FloatArgType, ShapeArgType +from probnum.typing import DTypeArgType, FloatLike, ShapeLike def seed(seed: Optional[int]) -> jnp.ndarray: @@ -28,9 +28,9 @@ def standard_normal(seed: jnp.ndarray, shape=(), dtype=jnp.double): def gamma( seed: jnp.ndarray, - shape_param: FloatArgType, - scale_param: FloatArgType = 1.0, - shape: ShapeArgType = (), + shape_param: FloatLike, + scale_param: FloatLike = 1.0, + shape: ShapeLike = (), dtype: DTypeArgType = jnp.double, ): return ( @@ -43,7 +43,7 @@ def gamma( def uniform_so_group( seed: jnp.ndarray, n: int, - shape: ShapeArgType = (), + shape: ShapeLike = (), dtype: DTypeArgType = jnp.double, ) -> jnp.ndarray: if n == 1: diff --git a/src/probnum/backend/random/_numpy.py b/src/probnum/backend/random/_numpy.py index ec11e98827..dae2971f12 100644 --- a/src/probnum/backend/random/_numpy.py +++ b/src/probnum/backend/random/_numpy.py @@ -3,7 +3,7 @@ import numpy as np -from probnum.typing import DTypeArgType, FloatArgType, ShapeArgType +from probnum.typing import DTypeArgType, FloatLike, ShapeLike def seed(seed: Optional[int]) -> np.random.SeedSequence: @@ -21,7 +21,7 @@ def split( def standard_normal( seed: np.random.SeedSequence, - shape: ShapeArgType = (), + shape: ShapeLike = (), dtype: DTypeArgType = np.double, ) -> np.ndarray: return _make_rng(seed).standard_normal(size=shape, dtype=dtype) @@ -29,9 +29,9 @@ def standard_normal( def gamma( seed: np.random.SeedSequence, - shape_param: FloatArgType, - scale_param: FloatArgType = 1.0, - shape: ShapeArgType = (), + shape_param: FloatLike, + scale_param: FloatLike = 1.0, + shape: ShapeLike = (), dtype: DTypeArgType = np.double, ) -> np.ndarray: return ( @@ -43,7 +43,7 @@ def gamma( def uniform_so_group( seed: np.random.SeedSequence, n: int, - shape: ShapeArgType = (), + shape: ShapeLike = (), dtype: DTypeArgType = np.double, ) -> np.ndarray: if n == 1: diff --git a/src/probnum/backend/random/_torch.py b/src/probnum/backend/random/_torch.py index 4e85d5c90f..968885ffbb 100644 --- a/src/probnum/backend/random/_torch.py +++ b/src/probnum/backend/random/_torch.py @@ -4,7 +4,7 @@ import torch from torch.distributions.utils import broadcast_all -from probnum.typing import DTypeArgType, ShapeArgType +from probnum.typing import DTypeArgType, ShapeLike _RNG_STATE_SIZE = torch.Generator().get_state().shape[0] @@ -51,7 +51,7 @@ 
def gamma( def uniform_so_group( seed: np.random.SeedSequence, n: int, - shape: ShapeArgType = (), + shape: ShapeLike = (), dtype: DTypeArgType = torch.double, ) -> torch.Tensor: if n == 1: diff --git a/src/probnum/diffeq/_odesolution.py b/src/probnum/diffeq/_odesolution.py index 0c7e913cfd..3077adf58d 100644 --- a/src/probnum/diffeq/_odesolution.py +++ b/src/probnum/diffeq/_odesolution.py @@ -10,8 +10,8 @@ import numpy as np from probnum import filtsmooth, randvars -from probnum.filtsmooth._timeseriesposterior import DenseOutputLocationArgType -from probnum.typing import FloatArgType, IntArgType, ShapeArgType +from probnum.filtsmooth._timeseriesposterior import ArrayLike +from probnum.typing import FloatLike, IntLike, ShapeLike class ODESolution(filtsmooth.TimeSeriesPosterior): @@ -43,9 +43,9 @@ def __init__( def interpolate( self, - t: FloatArgType, - previous_index: Optional[IntArgType] = None, - next_index: Optional[IntArgType] = None, + t: FloatLike, + previous_index: Optional[IntLike] = None, + next_index: Optional[IntLike] = None, ) -> randvars.RandomVariable: raise NotImplementedError("Dense output is not implemented.") @@ -60,8 +60,8 @@ def __getitem__(self, idx: int) -> randvars.RandomVariable: def sample( self, rng: np.random.Generator, - t: Optional[DenseOutputLocationArgType] = None, - size: Optional[ShapeArgType] = (), + t: Optional[ArrayLike] = None, + size: Optional[ShapeLike] = (), ) -> np.ndarray: """Sample from the ODE solution. @@ -83,7 +83,7 @@ def sample( def transform_base_measure_realizations( self, base_measure_realizations: np.ndarray, - t: DenseOutputLocationArgType, + t: ArrayLike, ) -> np.ndarray: raise NotImplementedError( "Transforming base measure realizations is not implemented." diff --git a/src/probnum/diffeq/_odesolver.py b/src/probnum/diffeq/_odesolver.py index d32fe18370..777430dbd9 100644 --- a/src/probnum/diffeq/_odesolver.py +++ b/src/probnum/diffeq/_odesolver.py @@ -8,7 +8,7 @@ from probnum import problems from probnum.diffeq import callbacks -from probnum.typing import FloatArgType +from probnum.typing import FloatLike CallbackType = Union[callbacks.ODESolverCallback, Iterable[callbacks.ODESolverCallback]] """Callback interface type.""" @@ -29,7 +29,7 @@ def __init__( def solve( self, ivp: problems.InitialValueProblem, - stop_at: Iterable[FloatArgType] = None, + stop_at: Iterable[FloatLike] = None, callbacks: Optional[CallbackType] = None, ): """Solve an IVP. 
@@ -54,7 +54,7 @@ def solve( def solution_generator( self, ivp: problems.InitialValueProblem, - stop_at: Iterable[FloatArgType] = None, + stop_at: Iterable[FloatLike] = None, callbacks: Optional[CallbackType] = None, ): """Generate ODE solver steps.""" diff --git a/src/probnum/diffeq/odefilter/_odefilter_solution.py b/src/probnum/diffeq/odefilter/_odefilter_solution.py index 80c2b24b01..239cd989bf 100644 --- a/src/probnum/diffeq/odefilter/_odefilter_solution.py +++ b/src/probnum/diffeq/odefilter/_odefilter_solution.py @@ -6,8 +6,8 @@ from probnum import filtsmooth, randvars, utils from probnum.diffeq import _odesolution -from probnum.filtsmooth._timeseriesposterior import DenseOutputLocationArgType -from probnum.typing import FloatArgType, IntArgType, ShapeArgType +from probnum.filtsmooth._timeseriesposterior import ArrayLike +from probnum.typing import FloatLike, IntLike, ShapeLike class ODEFilterSolution(_odesolution.ODESolution): @@ -90,9 +90,9 @@ def __init__(self, kalman_posterior: filtsmooth.gaussian.KalmanPosterior): def interpolate( self, - t: FloatArgType, - previous_index: Optional[IntArgType] = None, - next_index: Optional[IntArgType] = None, + t: FloatLike, + previous_index: Optional[IntLike] = None, + next_index: Optional[IntLike] = None, ) -> randvars.RandomVariable: out_rv = self.kalman_posterior.interpolate( t, previous_index=previous_index, next_index=next_index @@ -102,8 +102,8 @@ def interpolate( def sample( self, rng: np.random.Generator, - t: Optional[DenseOutputLocationArgType] = None, - size: Optional[ShapeArgType] = (), + t: Optional[ArrayLike] = None, + size: Optional[ShapeLike] = (), ) -> np.ndarray: samples = self.kalman_posterior.sample(rng=rng, t=t, size=size) @@ -116,7 +116,7 @@ def sample( def transform_base_measure_realizations( self, base_measure_realizations: np.ndarray, - t: DenseOutputLocationArgType = None, + t: ArrayLike = None, ) -> np.ndarray: errormsg = ( "The ODEFilterSolution does not implement transformation of realizations of a base measure." diff --git a/src/probnum/diffeq/odefilter/information_operators/_information_operator.py b/src/probnum/diffeq/odefilter/information_operators/_information_operator.py index 05cbe74f88..33e11d8936 100644 --- a/src/probnum/diffeq/odefilter/information_operators/_information_operator.py +++ b/src/probnum/diffeq/odefilter/information_operators/_information_operator.py @@ -6,7 +6,7 @@ import numpy as np from probnum import problems, randprocs -from probnum.typing import FloatArgType, IntArgType +from probnum.typing import FloatLike, IntLike __all__ = ["InformationOperator", "ODEInformationOperator"] @@ -36,22 +36,22 @@ class InformationOperator(abc.ABC): Therefore, they are one important component in a probabilistic ODE solver. 
""" - def __init__(self, input_dim: IntArgType, output_dim: IntArgType): + def __init__(self, input_dim: IntLike, output_dim: IntLike): self.input_dim = input_dim self.output_dim = output_dim @abc.abstractmethod - def __call__(self, t: FloatArgType, x: np.ndarray) -> np.ndarray: + def __call__(self, t: FloatLike, x: np.ndarray) -> np.ndarray: raise NotImplementedError - def jacobian(self, t: FloatArgType, x: np.ndarray) -> np.ndarray: + def jacobian(self, t: FloatLike, x: np.ndarray) -> np.ndarray: raise NotImplementedError def as_transition( self, - measurement_cov_fun: Optional[Callable[[FloatArgType], np.ndarray]] = None, + measurement_cov_fun: Optional[Callable[[FloatLike], np.ndarray]] = None, measurement_cov_cholesky_fun: Optional[ - Callable[[FloatArgType], np.ndarray] + Callable[[FloatLike], np.ndarray] ] = None, ): @@ -84,7 +84,7 @@ class ODEInformationOperator(InformationOperator): :class:`InitialValueProblem`. Not all information operators that are used in ODE solvers do. """ - def __init__(self, input_dim: IntArgType, output_dim: IntArgType): + def __init__(self, input_dim: IntLike, output_dim: IntLike): super().__init__(input_dim=input_dim, output_dim=output_dim) # Initialized once the ODE can be seen @@ -103,9 +103,9 @@ def ode_has_been_incorporated(self) -> bool: def as_transition( self, - measurement_cov_fun: Optional[Callable[[FloatArgType], np.ndarray]] = None, + measurement_cov_fun: Optional[Callable[[FloatLike], np.ndarray]] = None, measurement_cov_cholesky_fun: Optional[ - Callable[[FloatArgType], np.ndarray] + Callable[[FloatLike], np.ndarray] ] = None, ): if not self.ode_has_been_incorporated: diff --git a/src/probnum/diffeq/odefilter/information_operators/_ode_residual.py b/src/probnum/diffeq/odefilter/information_operators/_ode_residual.py index 01f953aba6..c77b65b9fd 100644 --- a/src/probnum/diffeq/odefilter/information_operators/_ode_residual.py +++ b/src/probnum/diffeq/odefilter/information_operators/_ode_residual.py @@ -6,7 +6,7 @@ from probnum import problems, randprocs from probnum.diffeq.odefilter.information_operators import _information_operator -from probnum.typing import FloatArgType, IntArgType +from probnum.typing import FloatLike, IntLike __all__ = ["ODEResidual"] @@ -14,7 +14,7 @@ class ODEResidual(_information_operator.ODEInformationOperator): """Information operator that measures the residual of an explicit ODE.""" - def __init__(self, num_prior_derivatives: IntArgType, ode_dimension: IntArgType): + def __init__(self, num_prior_derivatives: IntLike, ode_dimension: IntLike): integrator_dimension = ode_dimension * (num_prior_derivatives + 1) super().__init__(input_dim=integrator_dimension, output_dim=ode_dimension) # Store remaining attributes @@ -47,7 +47,7 @@ def incorporate_ode(self, ode: problems.InitialValueProblem): self._residual, self._residual_jacobian = res, res_jac def _match_residual_and_jacobian_to_ode_order( - self, ode_order: IntArgType + self, ode_order: IntLike ) -> Tuple[Callable, Callable]: """Choose the correct residual (and Jacobian) implementation based on the order of the ODE.""" @@ -56,20 +56,20 @@ def _match_residual_and_jacobian_to_ode_order( } return choose_implementation[ode_order] - def __call__(self, t: FloatArgType, x: np.ndarray) -> np.ndarray: + def __call__(self, t: FloatLike, x: np.ndarray) -> np.ndarray: return self._residual(t, x) - def jacobian(self, t: FloatArgType, x: np.ndarray) -> np.ndarray: + def jacobian(self, t: FloatLike, x: np.ndarray) -> np.ndarray: return self._residual_jacobian(t, x) # Implementation 
of different residuals - def _residual_first_order_ode(self, t: FloatArgType, x: np.ndarray) -> np.ndarray: + def _residual_first_order_ode(self, t: FloatLike, x: np.ndarray) -> np.ndarray: h0, h1 = self.projection_matrices return h1 @ x - self.ode.f(t, h0 @ x) def _residual_first_order_ode_jacobian( - self, t: FloatArgType, x: np.ndarray + self, t: FloatLike, x: np.ndarray ) -> np.ndarray: h0, h1 = self.projection_matrices return h1 - self.ode.df(t, h0 @ x) @ h0 diff --git a/src/probnum/diffeq/odefilter/initialization_routines/_runge_kutta.py b/src/probnum/diffeq/odefilter/initialization_routines/_runge_kutta.py index f3d92a5a57..cd30fa4eb4 100644 --- a/src/probnum/diffeq/odefilter/initialization_routines/_runge_kutta.py +++ b/src/probnum/diffeq/odefilter/initialization_routines/_runge_kutta.py @@ -8,7 +8,7 @@ from probnum import filtsmooth, problems, randprocs, randvars from probnum.diffeq.odefilter.initialization_routines import _initialization_routine -from probnum.typing import FloatArgType +from probnum.typing import FloatLike class RungeKuttaInitialization(_initialization_routine.InitializationRoutine): @@ -52,7 +52,7 @@ class RungeKuttaInitialization(_initialization_routine.InitializationRoutine): """ def __init__( - self, dt: Optional[FloatArgType] = 1e-2, method: Optional[str] = "DOP853" + self, dt: Optional[FloatLike] = 1e-2, method: Optional[str] = "DOP853" ): self.dt = dt self.method = method diff --git a/src/probnum/diffeq/odefilter/utils/_problem_utils.py b/src/probnum/diffeq/odefilter/utils/_problem_utils.py index d81f54aa0d..00fde3c54f 100644 --- a/src/probnum/diffeq/odefilter/utils/_problem_utils.py +++ b/src/probnum/diffeq/odefilter/utils/_problem_utils.py @@ -6,7 +6,7 @@ from probnum import problems, randprocs from probnum.diffeq.odefilter import approx_strategies, information_operators -from probnum.typing import FloatArgType +from probnum.typing import FloatLike __all__ = ["ivp_to_regression_problem"] @@ -19,7 +19,7 @@ def ivp_to_regression_problem( locations: Union[Sequence, np.ndarray], ode_information_operator: information_operators.InformationOperator, approx_strategy: Optional[approx_strategies.ApproximationStrategy] = None, - ode_measurement_variance: Optional[FloatArgType] = 0.0, + ode_measurement_variance: Optional[FloatLike] = 0.0, exclude_initial_condition=False, ): """Transform an initial value problem into a regression problem. diff --git a/src/probnum/diffeq/perturbed/scipy_wrapper/_wrapped_scipy_odesolution.py b/src/probnum/diffeq/perturbed/scipy_wrapper/_wrapped_scipy_odesolution.py index 147dcc5657..768bf77f71 100644 --- a/src/probnum/diffeq/perturbed/scipy_wrapper/_wrapped_scipy_odesolution.py +++ b/src/probnum/diffeq/perturbed/scipy_wrapper/_wrapped_scipy_odesolution.py @@ -5,7 +5,7 @@ from probnum import randvars from probnum.diffeq import _odesolution from probnum.filtsmooth._timeseriesposterior import DenseOutputValueType -from probnum.typing import DenseOutputLocationArgType +from probnum.typing import ArrayLike class WrappedScipyODESolution(_odesolution.ODESolution): @@ -19,7 +19,7 @@ def __init__(self, scipy_solution: OdeSolution, rvs: list): rv_states = randvars._RandomVariableList(rvs) super().__init__(locations=scipy_solution.ts, states=rv_states) - def __call__(self, t: DenseOutputLocationArgType) -> DenseOutputValueType: + def __call__(self, t: ArrayLike) -> DenseOutputValueType: """Evaluate the time-continuous solution at time t. 
Parameters diff --git a/src/probnum/diffeq/perturbed/scipy_wrapper/_wrapped_scipy_solver.py b/src/probnum/diffeq/perturbed/scipy_wrapper/_wrapped_scipy_solver.py index d92c2347a2..32a4472120 100644 --- a/src/probnum/diffeq/perturbed/scipy_wrapper/_wrapped_scipy_solver.py +++ b/src/probnum/diffeq/perturbed/scipy_wrapper/_wrapped_scipy_solver.py @@ -10,7 +10,7 @@ from probnum import randvars from probnum.diffeq import _odesolver, _odesolver_state from probnum.diffeq.perturbed.scipy_wrapper import _wrapped_scipy_odesolution -from probnum.typing import FloatArgType +from probnum.typing import FloatLike class WrappedScipyRungeKutta(_odesolver.ODESolver): @@ -62,7 +62,7 @@ def initialize(self, ivp): ) return state - def attempt_step(self, state: _odesolver_state.ODESolverState, dt: FloatArgType): + def attempt_step(self, state: _odesolver_state.ODESolverState, dt: FloatLike): """Perform one ODE-step from start to stop and set variables to the corresponding values. diff --git a/src/probnum/diffeq/perturbed/step/_perturbation_functions.py b/src/probnum/diffeq/perturbed/step/_perturbation_functions.py index 6544635acc..2e56b0fd95 100644 --- a/src/probnum/diffeq/perturbed/step/_perturbation_functions.py +++ b/src/probnum/diffeq/perturbed/step/_perturbation_functions.py @@ -4,15 +4,15 @@ import numpy as np import scipy -from probnum.typing import FloatArgType, IntArgType, ShapeArgType +from probnum.typing import FloatLike, IntLike, ShapeLike def perturb_uniform( rng: np.random.Generator, - step: FloatArgType, - solver_order: IntArgType, - noise_scale: FloatArgType, - size: Optional[ShapeArgType] = (), + step: FloatLike, + solver_order: IntLike, + noise_scale: FloatLike, + size: Optional[ShapeLike] = (), ) -> Union[float, np.ndarray]: """Perturb the step with uniformly distributed noise. @@ -50,10 +50,10 @@ def perturb_uniform( def perturb_lognormal( rng: np.random.Generator, - step: FloatArgType, - solver_order: IntArgType, - noise_scale: FloatArgType, - size: Optional[ShapeArgType] = (), + step: FloatLike, + solver_order: IntLike, + noise_scale: FloatLike, + size: Optional[ShapeLike] = (), ) -> Union[float, np.ndarray]: """Perturb the step with log-normally distributed noise. diff --git a/src/probnum/diffeq/perturbed/step/_perturbedstepsolution.py b/src/probnum/diffeq/perturbed/step/_perturbedstepsolution.py index 3febdbed2c..bfe0fd22f3 100644 --- a/src/probnum/diffeq/perturbed/step/_perturbedstepsolution.py +++ b/src/probnum/diffeq/perturbed/step/_perturbedstepsolution.py @@ -7,7 +7,7 @@ from probnum import randvars from probnum.diffeq import _odesolution -from probnum.typing import FloatArgType +from probnum.typing import FloatLike class PerturbedStepSolution(_odesolution.ODESolution): @@ -26,9 +26,9 @@ def __init__( def interpolate( self, - t: FloatArgType, - previous_index: Optional[FloatArgType] = None, - next_index: Optional[FloatArgType] = None, + t: FloatLike, + previous_index: Optional[FloatLike] = None, + next_index: Optional[FloatLike] = None, ): # For the first state, no interpolation has to be performed. 
if t == self.locations[0]:
diff --git a/src/probnum/diffeq/perturbed/step/_perturbedstepsolver.py b/src/probnum/diffeq/perturbed/step/_perturbedstepsolver.py
index feb5634bcb..4ff9447da2 100644
--- a/src/probnum/diffeq/perturbed/step/_perturbedstepsolver.py
+++ b/src/probnum/diffeq/perturbed/step/_perturbedstepsolver.py
@@ -11,7 +11,7 @@
     _perturbation_functions,
     _perturbedstepsolution,
 )
-from probnum.typing import FloatArgType
+from probnum.typing import FloatLike


 class PerturbedStepSolver(_odesolver.ODESolver):
@@ -44,7 +44,7 @@ def __init__(
         self,
         rng: np.random.Generator,
         solver: scipy_wrapper.WrappedScipyRungeKutta,
-        noise_scale: FloatArgType,
+        noise_scale: FloatLike,
         perturb_function: Callable,
     ):
         def perturb_step(rng, step):
@@ -67,7 +67,7 @@ def construct_with_lognormal_perturbation(
         cls,
         rng: np.random.Generator,
         solver: scipy_wrapper.WrappedScipyRungeKutta,
-        noise_scale: FloatArgType,
+        noise_scale: FloatLike,
     ):
         pertfun = _perturbation_functions.perturb_lognormal
         return cls(
@@ -82,7 +82,7 @@ def construct_with_uniform_perturbation(
         cls,
         rng: np.random.Generator,
         solver: scipy_wrapper.WrappedScipyRungeKutta,
-        noise_scale: FloatArgType,
+        noise_scale: FloatLike,
     ):
         pertfun = _perturbation_functions.perturb_uniform
         return cls(
@@ -97,7 +97,7 @@ def initialize(self, ivp):
         self.scales = []
         return self.solver.initialize(ivp)

-    def attempt_step(self, state: _odesolver_state.ODESolverState, dt: FloatArgType):
+    def attempt_step(self, state: _odesolver_state.ODESolverState, dt: FloatLike):
         """Perturb the original stopping point.

         Perform one perturbed step and project the solution back to the original
diff --git a/src/probnum/diffeq/stepsize/_steprule.py b/src/probnum/diffeq/stepsize/_steprule.py
index 9a0feb24fe..fd33f5c9f1 100644
--- a/src/probnum/diffeq/stepsize/_steprule.py
+++ b/src/probnum/diffeq/stepsize/_steprule.py
@@ -5,27 +5,27 @@
 import numpy as np

-from probnum.typing import FloatArgType, IntArgType, ToleranceDiffusionType
+from probnum.typing import ArrayLike, FloatLike, IntLike


 class StepRule(ABC):
     """Step-size selection rules for ODE solvers."""

-    def __init__(self, firststep: FloatArgType):
+    def __init__(self, firststep: FloatLike):
         self.firststep = firststep

     @abstractmethod
     def suggest(
         self,
-        laststep: FloatArgType,
-        scaled_error: FloatArgType,
-        localconvrate: Optional[IntArgType] = None,
+        laststep: FloatLike,
+        scaled_error: FloatLike,
+        localconvrate: Optional[IntLike] = None,
     ):
         """Suggest a new step h_{n+1} given error estimate e_n at step h_n."""
         raise NotImplementedError

     @abstractmethod
-    def is_accepted(self, scaled_error: FloatArgType):
+    def is_accepted(self, scaled_error: FloatLike):
         """Check if the proposed step should be accepted or not.

         Variable "proposedstep" not used yet, but may be important in
@@ -35,9 +35,7 @@ def is_accepted(self, scaled_error: FloatArgType):
         raise NotImplementedError

     @abstractmethod
-    def errorest_to_norm(
-        self, errorest: ToleranceDiffusionType, reference_state: np.ndarray
-    ):
+    def errorest_to_norm(self, errorest: ArrayLike, reference_state: np.ndarray):
         """Computes the norm of error per tolerance (usually referred to as 'E').
The norm is usually the current error estimate normalised with
@@ -50,25 +48,23 @@


 class ConstantSteps(StepRule):
     """Constant step-sizes."""

-    def __init__(self, stepsize: FloatArgType):
+    def __init__(self, stepsize: FloatLike):
         self.step = stepsize
         super().__init__(firststep=stepsize)

     def suggest(
         self,
-        laststep: FloatArgType,
-        scaled_error: FloatArgType,
-        localconvrate: Optional[IntArgType] = None,
+        laststep: FloatLike,
+        scaled_error: FloatLike,
+        localconvrate: Optional[IntLike] = None,
     ):
         return self.step

-    def is_accepted(self, scaled_error: FloatArgType):
+    def is_accepted(self, scaled_error: FloatLike):
         """Always True."""
         return True

-    def errorest_to_norm(
-        self, errorest: ToleranceDiffusionType, reference_state: np.ndarray
-    ):
+    def errorest_to_norm(self, errorest: ArrayLike, reference_state: np.ndarray):
         pass


@@ -92,13 +88,13 @@ class AdaptiveSteps(StepRule):

     def __init__(
         self,
-        firststep: FloatArgType,
-        atol: ToleranceDiffusionType,
-        rtol: ToleranceDiffusionType,
-        limitchange: Optional[Tuple[FloatArgType]] = (0.2, 10.0),
-        safetyscale: Optional[FloatArgType] = 0.95,
-        minstep: Optional[FloatArgType] = 1e-15,
-        maxstep: Optional[FloatArgType] = 1e15,
+        firststep: FloatLike,
+        atol: ArrayLike,
+        rtol: ArrayLike,
+        limitchange: Optional[Tuple[FloatLike, FloatLike]] = (0.2, 10.0),
+        safetyscale: Optional[FloatLike] = 0.95,
+        minstep: Optional[FloatLike] = 1e-15,
+        maxstep: Optional[FloatLike] = 1e15,
     ):
         self.safetyscale = safetyscale
         self.limitchange = limitchange
@@ -110,9 +106,9 @@ def __init__(

     def suggest(
         self,
-        laststep: FloatArgType,
-        scaled_error: FloatArgType,
-        localconvrate: Optional[IntArgType] = None,
+        laststep: FloatLike,
+        scaled_error: FloatLike,
+        localconvrate: Optional[IntLike] = None,
     ):
         small, large = self.limitchange

@@ -133,12 +129,10 @@ def suggest(
         raise RuntimeError("Step-size larger than maximum step-size")
         return step

-    def is_accepted(self, scaled_error: FloatArgType):
+    def is_accepted(self, scaled_error: FloatLike):
         return scaled_error < 1

-    def errorest_to_norm(
-        self, errorest: ToleranceDiffusionType, reference_state: np.ndarray
-    ):
+    def errorest_to_norm(self, errorest: ArrayLike, reference_state: np.ndarray):
         tolerance = self.atol + self.rtol * reference_state
         ratio = errorest / tolerance
         dim = len(ratio) if ratio.ndim > 0 else 1
diff --git a/src/probnum/filtsmooth/_timeseriesposterior.py b/src/probnum/filtsmooth/_timeseriesposterior.py
index 5fb0ac51c1..b3a47771f9 100644
--- a/src/probnum/filtsmooth/_timeseriesposterior.py
+++ b/src/probnum/filtsmooth/_timeseriesposterior.py
@@ -7,11 +7,11 @@

 from probnum import randvars
 from probnum.typing import (
-    ArrayLikeGetitemArgType,
-    DenseOutputLocationArgType,
-    FloatArgType,
-    IntArgType,
-    ShapeArgType,
+    ArrayIndicesLike,
+    ArrayLike,
+    FloatLike,
+    IntLike,
+    ShapeLike,
 )

 DenseOutputValueType = Union[randvars.RandomVariable, randvars._RandomVariableList]
@@ -33,14 +33,14 @@ class TimeSeriesPosterior(abc.ABC):

     def __init__(
         self,
-        locations: Optional[Iterable[FloatArgType]] = None,
+        locations: Optional[Iterable[FloatLike]] = None,
         states: Optional[Iterable[randvars.RandomVariable]] = None,
     ) -> None:
         self._locations = list(locations) if locations is not None else []
         self._states = list(states) if states is not None else []
         self._frozen = False

-    def _check_location(self, location: FloatArgType) -> FloatArgType:
+    def _check_location(self, location: FloatLike) -> FloatLike:
         if len(self._locations) > 0 and location <= self._locations[-1]:
             _err_msg = "Locations have to
be strictly ascending. " _err_msg += f"Received {location} <= {self._locations[-1]}." @@ -49,7 +49,7 @@ def _check_location(self, location: FloatArgType) -> FloatArgType: def append( self, - location: FloatArgType, + location: FloatLike, state: randvars.RandomVariable, ) -> None: @@ -81,10 +81,10 @@ def __len__(self) -> int: """ return len(self.locations) - def __getitem__(self, idx: ArrayLikeGetitemArgType) -> randvars.RandomVariable: + def __getitem__(self, idx: ArrayIndicesLike) -> randvars.RandomVariable: return self.states[idx] - def __call__(self, t: DenseOutputLocationArgType) -> DenseOutputValueType: + def __call__(self, t: ArrayLike) -> DenseOutputValueType: """Evaluate the time-continuous posterior at location `t` Algorithm: @@ -159,9 +159,9 @@ def __call__(self, t: DenseOutputLocationArgType) -> DenseOutputValueType: @abc.abstractmethod def interpolate( self, - t: FloatArgType, - previous_index: Optional[IntArgType] = None, - next_index: Optional[IntArgType] = None, + t: FloatLike, + previous_index: Optional[IntLike] = None, + next_index: Optional[IntLike] = None, ) -> randvars.RandomVariable: """Evaluate the posterior at a measurement-free point. @@ -176,8 +176,8 @@ def interpolate( def sample( self, rng: np.random.Generator, - t: Optional[DenseOutputLocationArgType] = None, - size: Optional[ShapeArgType] = (), + t: Optional[ArrayLike] = None, + size: Optional[ShapeLike] = (), ) -> np.ndarray: """Draw samples from the filtering/smoothing posterior. @@ -213,7 +213,7 @@ def sample( def transform_base_measure_realizations( self, base_measure_realizations: np.ndarray, - t: Optional[DenseOutputLocationArgType], + t: Optional[ArrayLike], ) -> np.ndarray: """Transform a set of realizations from a base measure into realizations from the posterior. 
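After these renames, the dense-output contract of `TimeSeriesPosterior` reads naturally: evaluation locations are `ArrayLike`, indices are `ArrayIndicesLike`, and sample sizes are `ShapeLike`. A minimal sketch of how such a posterior is evaluated and sampled, assuming `posterior` is produced by some filtering or smoothing run elsewhere:

```python
import numpy as np

from probnum.filtsmooth import TimeSeriesPosterior  # public re-export assumed


def dense_output_demo(posterior: TimeSeriesPosterior) -> np.ndarray:
    """Evaluate and sample a posterior at array-like locations."""
    rng = np.random.default_rng(seed=1)
    grid = np.linspace(0.0, 1.0, 10)

    rv_single = posterior(0.5)  # a FloatLike location yields a single RandomVariable
    rv_batch = posterior(grid)  # ArrayLike locations yield a list of random variables
    del rv_single, rv_batch  # evaluated here only to illustrate the call contract

    return posterior.sample(rng=rng, t=grid, size=(3,))  # size is a ShapeLike
```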
diff --git a/src/probnum/filtsmooth/gaussian/_kalmanposterior.py b/src/probnum/filtsmooth/gaussian/_kalmanposterior.py index 1f41c21324..ed2d867a32 100644 --- a/src/probnum/filtsmooth/gaussian/_kalmanposterior.py +++ b/src/probnum/filtsmooth/gaussian/_kalmanposterior.py @@ -13,10 +13,10 @@ from probnum.filtsmooth import _timeseriesposterior from probnum.filtsmooth.gaussian import approx from probnum.typing import ( - DenseOutputLocationArgType, - FloatArgType, - IntArgType, - ShapeArgType, + ArrayLike, + FloatLike, + IntLike, + ShapeLike, ) GaussMarkovPriorTransitionArgType = Union[ @@ -46,7 +46,7 @@ class KalmanPosterior(_timeseriesposterior.TimeSeriesPosterior, abc.ABC): def __init__( self, transition: GaussMarkovPriorTransitionArgType, - locations: Optional[Iterable[FloatArgType]] = None, + locations: Optional[Iterable[FloatLike]] = None, states: Optional[Iterable[randvars.RandomVariable]] = None, diffusion_model=None, ) -> None: @@ -60,17 +60,17 @@ def __init__( @abc.abstractmethod def interpolate( self, - t: FloatArgType, - previous_index: Optional[IntArgType] = None, - next_index: Optional[IntArgType] = None, + t: FloatLike, + previous_index: Optional[IntLike] = None, + next_index: Optional[IntLike] = None, ) -> randvars.RandomVariable: raise NotImplementedError def sample( self, rng: np.random.Generator, - t: Optional[DenseOutputLocationArgType] = None, - size: Optional[ShapeArgType] = (), + t: Optional[ArrayLike] = None, + size: Optional[ShapeLike] = (), ) -> np.ndarray: size = utils.as_shape(size) @@ -108,7 +108,7 @@ def sample( def transform_base_measure_realizations( self, base_measure_realizations: np.ndarray, - t: DenseOutputLocationArgType, + t: ArrayLike, ) -> np.ndarray: """Transform samples from a base measure to samples from the KalmanPosterior. @@ -157,7 +157,7 @@ def __init__( self, filtering_posterior: _timeseriesposterior.TimeSeriesPosterior, transition: GaussMarkovPriorTransitionArgType, - locations: Iterable[FloatArgType], + locations: Iterable[FloatLike], states: Iterable[randvars.RandomVariable], diffusion_model=None, ): @@ -171,9 +171,9 @@ def __init__( def interpolate( self, - t: FloatArgType, - previous_index: Optional[IntArgType] = None, - next_index: Optional[IntArgType] = None, + t: FloatLike, + previous_index: Optional[IntLike] = None, + next_index: Optional[IntLike] = None, ) -> randvars.RandomVariable: # Assert either previous_location or next_location is not None @@ -364,9 +364,9 @@ class FilteringPosterior(KalmanPosterior): def interpolate( self, - t: FloatArgType, - previous_index: Optional[IntArgType] = None, - next_index: Optional[IntArgType] = None, + t: FloatLike, + previous_index: Optional[IntLike] = None, + next_index: Optional[IntLike] = None, ) -> randvars.RandomVariable: # Assert either previous_location or next_location is not None @@ -427,8 +427,8 @@ def interpolate( def sample( self, rng: np.random.Generator, - t: Optional[DenseOutputLocationArgType] = None, - size: Optional[ShapeArgType] = (), + t: Optional[ArrayLike] = None, + size: Optional[ShapeLike] = (), ) -> np.ndarray: # If this error would not be thrown here, trying to sample from a FilteringPosterior # would call FilteringPosterior.transform_base_measure_realizations which is not implemented. 
@@ -441,7 +441,7 @@ def sample( def transform_base_measure_realizations( self, base_measure_realizations: np.ndarray, - t: Optional[DenseOutputLocationArgType] = None, + t: Optional[ArrayLike] = None, ) -> np.ndarray: raise NotImplementedError( "Transforming base measure realizations is not implemented." diff --git a/src/probnum/filtsmooth/gaussian/approx/_unscentedkalman.py b/src/probnum/filtsmooth/gaussian/approx/_unscentedkalman.py index acc083f35f..6e1a373e1a 100644 --- a/src/probnum/filtsmooth/gaussian/approx/_unscentedkalman.py +++ b/src/probnum/filtsmooth/gaussian/approx/_unscentedkalman.py @@ -11,7 +11,7 @@ from probnum import randprocs, randvars from probnum.filtsmooth.gaussian.approx import _unscentedtransform -from probnum.typing import FloatArgType +from probnum.typing import FloatLike class UKFComponent: @@ -20,9 +20,9 @@ class UKFComponent: def __init__( self, non_linear_model, - spread: Optional[FloatArgType] = 1e-4, - priorpar: Optional[FloatArgType] = 2.0, - special_scale: Optional[FloatArgType] = 0.0, + spread: Optional[FloatLike] = 1e-4, + priorpar: Optional[FloatLike] = 2.0, + special_scale: Optional[FloatLike] = 0.0, ) -> None: self.non_linear_model = non_linear_model self.ut = _unscentedtransform.UnscentedTransform( @@ -57,11 +57,11 @@ class ContinuousUKFComponent(UKFComponent, randprocs.markov.continuous.SDE): def __init__( self, non_linear_model, - spread: Optional[FloatArgType] = 1e-4, - priorpar: Optional[FloatArgType] = 2.0, - special_scale: Optional[FloatArgType] = 0.0, - mde_atol: Optional[FloatArgType] = 1e-6, - mde_rtol: Optional[FloatArgType] = 1e-6, + spread: Optional[FloatLike] = 1e-4, + priorpar: Optional[FloatLike] = 2.0, + special_scale: Optional[FloatLike] = 0.0, + mde_atol: Optional[FloatLike] = 1e-6, + mde_rtol: Optional[FloatLike] = 1e-6, mde_solver: Optional[str] = "LSODA", ) -> None: @@ -153,9 +153,9 @@ class DiscreteUKFComponent(UKFComponent, randprocs.markov.discrete.NonlinearGaus def __init__( self, non_linear_model, - spread: Optional[FloatArgType] = 1e-4, - priorpar: Optional[FloatArgType] = 2.0, - special_scale: Optional[FloatArgType] = 0.0, + spread: Optional[FloatLike] = 1e-4, + priorpar: Optional[FloatLike] = 2.0, + special_scale: Optional[FloatLike] = 0.0, ) -> None: UKFComponent.__init__( self, diff --git a/src/probnum/filtsmooth/particle/_particle_filter.py b/src/probnum/filtsmooth/particle/_particle_filter.py index 4332dc14e7..1ce628d137 100644 --- a/src/probnum/filtsmooth/particle/_particle_filter.py +++ b/src/probnum/filtsmooth/particle/_particle_filter.py @@ -10,7 +10,7 @@ _importance_distributions, _particle_filter_posterior, ) -from probnum.typing import FloatArgType, IntArgType +from probnum.typing import FloatLike, IntLike # Terribly long variable names, but internal only, so no worries. 
ParticleFilterMeasurementModelArgType = Union[ @@ -65,10 +65,10 @@ def __init__( self, prior_process: randprocs.markov.MarkovProcess, importance_distribution: _importance_distributions.ImportanceDistribution, - num_particles: IntArgType, + num_particles: IntLike, rng: np.random.Generator, with_resampling: bool = True, - resampling_percentage_threshold: FloatArgType = 0.1, + resampling_percentage_threshold: FloatLike = 0.1, ) -> None: super().__init__( prior_process=prior_process, diff --git a/src/probnum/filtsmooth/particle/_particle_filter_posterior.py b/src/probnum/filtsmooth/particle/_particle_filter_posterior.py index 2cc54125f6..f19c6fa52c 100644 --- a/src/probnum/filtsmooth/particle/_particle_filter_posterior.py +++ b/src/probnum/filtsmooth/particle/_particle_filter_posterior.py @@ -6,7 +6,7 @@ from probnum import randvars from probnum.filtsmooth import _timeseriesposterior -from probnum.typing import DenseOutputLocationArgType, FloatArgType, ShapeArgType +from probnum.typing import ArrayLike, FloatLike, ShapeLike class ParticleFilterPosterior(_timeseriesposterior.TimeSeriesPosterior): @@ -17,21 +17,21 @@ def __call__(self, t): # The methods below are not implemented (yet?). - def interpolate(self, t: FloatArgType) -> randvars.RandomVariable: + def interpolate(self, t: FloatLike) -> randvars.RandomVariable: raise NotImplementedError def sample( self, rng: np.random.Generator, - t: Optional[DenseOutputLocationArgType] = None, - size: Optional[ShapeArgType] = (), + t: Optional[ArrayLike] = None, + size: Optional[ShapeLike] = (), ) -> np.ndarray: raise NotImplementedError("Sampling is not implemented.") def transform_base_measure_realizations( self, base_measure_realizations: np.ndarray, - t: Optional[DenseOutputLocationArgType] = None, + t: Optional[ArrayLike] = None, ) -> np.ndarray: raise NotImplementedError( "Transforming base measure realizations is not implemented." 
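The same convention applies here: `num_particles` is an `IntLike` and the resampling threshold a `FloatLike`, so callers pass plain Python numbers. A constructor sketch, assuming `prior_process` and `importance_distribution` are built elsewhere and that `ParticleFilter` is re-exported under `probnum.filtsmooth` (the module layout in this diff suggests so, but the public name is an assumption):

```python
import numpy as np

from probnum import filtsmooth, randprocs


def build_particle_filter(
    prior_process: randprocs.markov.MarkovProcess,
    importance_distribution,  # an ImportanceDistribution, constructed elsewhere
):
    """Construct a particle filter; plain Python numbers satisfy the *Like aliases."""
    return filtsmooth.ParticleFilter(  # public name assumed from the module layout
        prior_process=prior_process,
        importance_distribution=importance_distribution,
        num_particles=500,  # IntLike: a plain int is fine
        rng=np.random.default_rng(seed=3),
        with_resampling=True,
        resampling_percentage_threshold=0.1,  # FloatLike: a plain float is fine
    )
```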
diff --git a/src/probnum/linalg/_problinsolve.py b/src/probnum/linalg/_problinsolve.py index 9b53ffb0da..a4f7ac3e81 100644 --- a/src/probnum/linalg/_problinsolve.py +++ b/src/probnum/linalg/_problinsolve.py @@ -15,27 +15,27 @@ import probnum # pylint: disable=unused-import from probnum import linops, randvars from probnum.linalg.solvers.matrixbased import SymmetricMatrixBasedSolver -from probnum.typing import LinearOperatorArgType +from probnum.typing import LinearOperatorLike # pylint: disable=too-many-branches def problinsolve( A: Union[ - LinearOperatorArgType, - "randvars.RandomVariable[LinearOperatorArgType]", + LinearOperatorLike, + "randvars.RandomVariable[LinearOperatorLike]", ], b: Union[np.ndarray, "randvars.RandomVariable[np.ndarray]"], A0: Optional[ Union[ - LinearOperatorArgType, - "randvars.RandomVariable[LinearOperatorArgType]", + LinearOperatorLike, + "randvars.RandomVariable[LinearOperatorLike]", ] ] = None, Ainv0: Optional[ Union[ - LinearOperatorArgType, - "randvars.RandomVariable[LinearOperatorArgType]", + LinearOperatorLike, + "randvars.RandomVariable[LinearOperatorLike]", ] ] = None, x0: Optional[Union[np.ndarray, "randvars.RandomVariable[np.ndarray]"]] = None, diff --git a/src/probnum/linalg/solvers/belief_updates/solution_based/_solution_based_proj_rhs_belief_update.py b/src/probnum/linalg/solvers/belief_updates/solution_based/_solution_based_proj_rhs_belief_update.py index 6327c38925..e926a88a27 100644 --- a/src/probnum/linalg/solvers/belief_updates/solution_based/_solution_based_proj_rhs_belief_update.py +++ b/src/probnum/linalg/solvers/belief_updates/solution_based/_solution_based_proj_rhs_belief_update.py @@ -5,7 +5,7 @@ import probnum # pylint: disable="unused-import" from probnum import randvars from probnum.linalg.solvers.beliefs import LinearSystemBelief -from probnum.typing import FloatArgType +from probnum.typing import FloatLike from .._linear_solver_belief_update import LinearSolverBeliefUpdate @@ -35,7 +35,7 @@ class SolutionBasedProjectedRHSBeliefUpdate(LinearSolverBeliefUpdate): Analysis*, 2019, 14, 937-1012 """ - def __init__(self, noise_var: FloatArgType = 0.0) -> None: + def __init__(self, noise_var: FloatLike = 0.0) -> None: if noise_var < 0.0: raise ValueError(f"Noise variance {noise_var} must be non-negative.") self._noise_var = noise_var diff --git a/src/probnum/linalg/solvers/stopping_criteria/_posterior_contraction.py b/src/probnum/linalg/solvers/stopping_criteria/_posterior_contraction.py index cc51722245..401f4f115d 100644 --- a/src/probnum/linalg/solvers/stopping_criteria/_posterior_contraction.py +++ b/src/probnum/linalg/solvers/stopping_criteria/_posterior_contraction.py @@ -3,7 +3,7 @@ import numpy as np import probnum # pylint: disable="unused-import" -from probnum.typing import ScalarArgType +from probnum.typing import ScalarLike from ._linear_solver_stopping_criterion import LinearSolverStoppingCriterion @@ -29,8 +29,8 @@ class PosteriorContractionStoppingCriterion(LinearSolverStoppingCriterion): def __init__( self, qoi: str = "x", - atol: ScalarArgType = 10 ** -5, - rtol: ScalarArgType = 10 ** -5, + atol: ScalarLike = 10 ** -5, + rtol: ScalarLike = 10 ** -5, ): self.qoi = qoi self.atol = probnum.utils.as_numpy_scalar(atol) diff --git a/src/probnum/linalg/solvers/stopping_criteria/_residual_norm.py b/src/probnum/linalg/solvers/stopping_criteria/_residual_norm.py index 57026417dd..484db7a185 100644 --- a/src/probnum/linalg/solvers/stopping_criteria/_residual_norm.py +++ b/src/probnum/linalg/solvers/stopping_criteria/_residual_norm.py 
@@ -3,7 +3,7 @@ import numpy as np import probnum -from probnum.typing import ScalarArgType +from probnum.typing import ScalarLike from ._linear_solver_stopping_criterion import LinearSolverStoppingCriterion @@ -25,8 +25,8 @@ class ResidualNormStoppingCriterion(LinearSolverStoppingCriterion): def __init__( self, - atol: ScalarArgType = 10 ** -5, - rtol: ScalarArgType = 10 ** -5, + atol: ScalarLike = 10 ** -5, + rtol: ScalarLike = 10 ** -5, ): self.atol = probnum.utils.as_numpy_scalar(atol) self.rtol = probnum.utils.as_numpy_scalar(rtol) diff --git a/src/probnum/linops/_arithmetic.py b/src/probnum/linops/_arithmetic.py index 7135426731..7a88eb1a6d 100644 --- a/src/probnum/linops/_arithmetic.py +++ b/src/probnum/linops/_arithmetic.py @@ -5,7 +5,7 @@ import scipy.sparse from probnum import config, utils -from probnum.typing import NotImplementedType, ScalarArgType, ShapeArgType +from probnum.typing import NotImplementedType, ScalarLike, ShapeLike from ._arithmetic_fallbacks import ( NegatedLinearOperator, @@ -95,14 +95,14 @@ def matmul(op1: LinearOperator, op2: LinearOperator) -> LinearOperator: ######################################################################################## # Scaling -def _mul_scalar_scaling(scalar: ScalarArgType, scaling: Scaling) -> Scaling: +def _mul_scalar_scaling(scalar: ScalarLike, scaling: Scaling) -> Scaling: if scaling.is_isotropic: return Scaling(scalar * scaling.scalar, shape=scaling.shape) return Scaling(scalar * scaling.factors, shape=scaling.shape) -def _mul_scaling_scalar(scaling: Scaling, scalar: ScalarArgType) -> Scaling: +def _mul_scaling_scalar(scaling: Scaling, scalar: ScalarLike) -> Scaling: if scaling.is_isotropic: return Scaling(scalar * scaling.scalar, shape=scaling.shape) @@ -157,14 +157,14 @@ def _matmul_kronecker_scaling(kronecker: Kronecker, scaling: Scaling) -> Kroneck return NotImplemented -def _mul_scalar_kronecker(scalar: ScalarArgType, kronecker: Kronecker) -> Kronecker: +def _mul_scalar_kronecker(scalar: ScalarLike, kronecker: Kronecker) -> Kronecker: if scalar < 0.0: return NotImplemented sqrt_scalar = np.sqrt(scalar) return Kronecker(A=sqrt_scalar * kronecker.A, B=sqrt_scalar * kronecker.B) -def _mul_kronecker_scalar(kronecker: Kronecker, scalar: ScalarArgType) -> Kronecker: +def _mul_kronecker_scalar(kronecker: Kronecker, scalar: ScalarLike) -> Kronecker: if scalar < 0.0: return NotImplemented sqrt_scalar = np.sqrt(scalar) @@ -213,7 +213,7 @@ def _matmul_idkronecker_scaling( def _mul_scalar_idkronecker( - scalar: ScalarArgType, idkronecker: IdentityKronecker + scalar: ScalarLike, idkronecker: IdentityKronecker ) -> IdentityKronecker: return IdentityKronecker( @@ -222,7 +222,7 @@ def _mul_scalar_idkronecker( def _mul_idkronecker_scalar( - idkronecker: IdentityKronecker, scalar: ScalarArgType + idkronecker: IdentityKronecker, scalar: ScalarLike ) -> IdentityKronecker: return IdentityKronecker( @@ -425,7 +425,7 @@ def _apply( ######################################################################################## -def _operand_to_linop(operand: Any, shape: ShapeArgType) -> Optional[LinearOperator]: +def _operand_to_linop(operand: Any, shape: ShapeLike) -> Optional[LinearOperator]: if isinstance(operand, LinearOperator): pass elif np.ndim(operand) == 0: diff --git a/src/probnum/linops/_arithmetic_fallbacks.py b/src/probnum/linops/_arithmetic_fallbacks.py index 0b1a3c49e4..ea521c862f 100644 --- a/src/probnum/linops/_arithmetic_fallbacks.py +++ b/src/probnum/linops/_arithmetic_fallbacks.py @@ -6,7 +6,7 @@ import numpy as np import 
probnum.utils -from probnum.typing import NotImplementedType, ScalarArgType +from probnum.typing import NotImplementedType, ScalarLike from ._linear_operator import BinaryOperandType, LinearOperator @@ -18,7 +18,7 @@ class ScaledLinearOperator(LinearOperator): """Linear operator scaled with a scalar.""" - def __init__(self, linop: LinearOperator, scalar: ScalarArgType): + def __init__(self, linop: LinearOperator, scalar: ScalarLike): if not isinstance(linop, LinearOperator): raise TypeError("`linop` must be a `LinearOperator`") diff --git a/src/probnum/linops/_linear_operator.py b/src/probnum/linops/_linear_operator.py index 3c21da6cc6..6570593be0 100644 --- a/src/probnum/linops/_linear_operator.py +++ b/src/probnum/linops/_linear_operator.py @@ -8,10 +8,10 @@ import probnum.utils from probnum import config -from probnum.typing import DTypeArgType, ScalarArgType, ShapeArgType +from probnum.typing import DTypeArgType, ScalarLike, ShapeLike BinaryOperandType = Union[ - "LinearOperator", ScalarArgType, np.ndarray, scipy.sparse.spmatrix + "LinearOperator", ScalarLike, np.ndarray, scipy.sparse.spmatrix ] # pylint: disable="too-many-lines" @@ -101,7 +101,7 @@ class LinearOperator: def __init__( self, - shape: ShapeArgType, + shape: ShapeLike, dtype: DTypeArgType, *, matmul: Callable[[np.ndarray], np.ndarray], @@ -979,7 +979,7 @@ class Identity(LinearOperator): def __init__( self, - shape: ShapeArgType, + shape: ShapeLike, dtype: DTypeArgType = np.double, ): shape = probnum.utils.as_shape(shape) diff --git a/src/probnum/linops/_scaling.py b/src/probnum/linops/_scaling.py index 562d31e61a..0079f2c66f 100644 --- a/src/probnum/linops/_scaling.py +++ b/src/probnum/linops/_scaling.py @@ -4,7 +4,7 @@ import numpy as np import probnum.utils -from probnum.typing import DTypeArgType, ScalarArgType, ShapeArgType +from probnum.typing import DTypeArgType, ScalarLike, ShapeLike from . 
import _linear_operator @@ -37,8 +37,8 @@ class Scaling(_linear_operator.LinearOperator): def __init__( self, - factors: Union[np.ndarray, ScalarArgType], - shape: Optional[ShapeArgType] = None, + factors: Union[np.ndarray, ScalarLike], + shape: Optional[ShapeLike] = None, dtype: Optional[DTypeArgType] = None, ): self._factors = None diff --git a/src/probnum/problems/_problems.py b/src/probnum/problems/_problems.py index 4bf260e3a7..69f1727a1e 100644 --- a/src/probnum/problems/_problems.py +++ b/src/probnum/problems/_problems.py @@ -8,7 +8,7 @@ import scipy.sparse from probnum import linops, randvars -from probnum.typing import FloatArgType +from probnum.typing import FloatLike @dataclasses.dataclass @@ -147,7 +147,7 @@ class InitialValueProblem: f: Callable[[float, np.ndarray], np.ndarray] t0: float tmax: float - y0: Union[FloatArgType, np.ndarray] + y0: Union[FloatLike, np.ndarray] df: Optional[Callable[[float, np.ndarray], np.ndarray]] = None ddf: Optional[Callable[[float, np.ndarray], np.ndarray]] = None @@ -248,8 +248,8 @@ class QuadratureProblem: """ integrand: Callable[[np.ndarray], Union[float, np.ndarray]] - lower_bd: Union[FloatArgType, np.ndarray] - upper_bd: Union[FloatArgType, np.ndarray] + lower_bd: Union[FloatLike, np.ndarray] + upper_bd: Union[FloatLike, np.ndarray] output_dim: Optional[int] = 1 # For testing and benchmarking diff --git a/src/probnum/problems/zoo/filtsmooth/_filtsmooth_problems.py b/src/probnum/problems/zoo/filtsmooth/_filtsmooth_problems.py index 947f9f5d48..725622c687 100644 --- a/src/probnum/problems/zoo/filtsmooth/_filtsmooth_problems.py +++ b/src/probnum/problems/zoo/filtsmooth/_filtsmooth_problems.py @@ -4,7 +4,7 @@ from probnum import diffeq, filtsmooth, problems, randprocs, randvars from probnum.problems.zoo import diffeq as diffeq_zoo -from probnum.typing import FloatArgType, IntArgType +from probnum.typing import FloatLike, IntLike __all__ = [ "benes_daum", @@ -17,11 +17,11 @@ def car_tracking( rng: np.random.Generator, - measurement_variance: FloatArgType = 0.5, - process_diffusion: FloatArgType = 1.0, - num_prior_derivatives: IntArgType = 1, - timespan: Tuple[FloatArgType, FloatArgType] = (0.0, 20.0), - step: FloatArgType = 0.2, + measurement_variance: FloatLike = 0.5, + process_diffusion: FloatLike = 1.0, + num_prior_derivatives: IntLike = 1, + timespan: Tuple[FloatLike, FloatLike] = (0.0, 20.0), + step: FloatLike = 0.2, initrv: Optional[randvars.RandomVariable] = None, forward_implementation: str = "classic", backward_implementation: str = "classic", @@ -149,9 +149,9 @@ def car_tracking( def ornstein_uhlenbeck( rng: np.random.Generator, - measurement_variance: FloatArgType = 0.1, - driftspeed: FloatArgType = 0.21, - process_diffusion: FloatArgType = 0.5, + measurement_variance: FloatLike = 0.1, + driftspeed: FloatLike = 0.21, + process_diffusion: FloatLike = 0.5, time_grid: Optional[np.ndarray] = None, initrv: Optional[randvars.RandomVariable] = None, forward_implementation: str = "classic", @@ -252,9 +252,9 @@ def ornstein_uhlenbeck( def pendulum( rng: np.random.Generator, - measurement_variance: FloatArgType = 0.1024, - timespan: Tuple[FloatArgType, FloatArgType] = (0.0, 4.0), - step: FloatArgType = 0.0075, + measurement_variance: FloatLike = 0.1024, + timespan: Tuple[FloatLike, FloatLike] = (0.0, 4.0), + step: FloatLike = 0.0075, initrv: Optional[randvars.RandomVariable] = None, initarg: Optional[float] = None, ): @@ -400,8 +400,8 @@ def dh(t, x): def benes_daum( rng: np.random.Generator, - measurement_variance: FloatArgType = 0.1, - 
process_diffusion: FloatArgType = 1.0, + measurement_variance: FloatLike = 0.1, + process_diffusion: FloatLike = 1.0, time_grid: Optional[np.ndarray] = None, initrv: Optional[randvars.RandomVariable] = None, ): @@ -506,15 +506,15 @@ def l(t, x): def logistic_ode( - y0: Optional[Union[np.ndarray, FloatArgType]] = None, - timespan: Tuple[FloatArgType, FloatArgType] = (0.0, 2.0), - step: FloatArgType = 0.1, - params: Tuple[FloatArgType, FloatArgType] = (6.0, 1.0), + y0: Optional[Union[np.ndarray, FloatLike]] = None, + timespan: Tuple[FloatLike, FloatLike] = (0.0, 2.0), + step: FloatLike = 0.1, + params: Tuple[FloatLike, FloatLike] = (6.0, 1.0), initrv: Optional[randvars.RandomVariable] = None, - evlvar: Optional[Union[np.ndarray, FloatArgType]] = None, - ek0_or_ek1: IntArgType = 1, + evlvar: Optional[Union[np.ndarray, FloatLike]] = None, + ek0_or_ek1: IntLike = 1, exclude_initial_condition: bool = True, - order: IntArgType = 3, + order: IntLike = 3, forward_implementation: str = "classic", backward_implementation: str = "classic", ): diff --git a/src/probnum/problems/zoo/linalg/_random_linear_system.py b/src/probnum/problems/zoo/linalg/_random_linear_system.py index ec0d81649d..e963ca758e 100644 --- a/src/probnum/problems/zoo/linalg/_random_linear_system.py +++ b/src/probnum/problems/zoo/linalg/_random_linear_system.py @@ -6,13 +6,13 @@ import scipy.sparse from probnum import backend, linops, problems, randvars -from probnum.typing import LinearOperatorArgType, SeedLike +from probnum.typing import LinearOperatorLike, SeedLike def random_linear_system( seed: SeedLike, matrix: Union[ - LinearOperatorArgType, + LinearOperatorLike, Callable[ [np.random.Generator, Optional[Any]], Union[np.ndarray, scipy.sparse.spmatrix, linops.LinearOperator], diff --git a/src/probnum/problems/zoo/linalg/_random_spd_matrix.py b/src/probnum/problems/zoo/linalg/_random_spd_matrix.py index 7c2f08c2d2..cf88543ddd 100644 --- a/src/probnum/problems/zoo/linalg/_random_spd_matrix.py +++ b/src/probnum/problems/zoo/linalg/_random_spd_matrix.py @@ -6,12 +6,12 @@ import scipy.stats from probnum import backend -from probnum.typing import IntArgType, SeedType +from probnum.typing import IntLike, SeedType def random_spd_matrix( seed: SeedType, - dim: IntArgType, + dim: IntLike, spectrum: Sequence = None, ) -> np.ndarray: r"""Random symmetric positive definite matrix. 
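The `dim: IntLike` annotation above illustrates the convention behind the renamed aliases: integer-like values (builtin `int`, any `numbers.Integral`, NumPy integer scalars) are accepted at the public API boundary and normalized to builtin `int`s internally. A minimal sketch of that convention; `accept_dim` is a hypothetical helper, not part of probnum:

import numbers

import numpy as np

from probnum.typing import IntLike


def accept_dim(dim: IntLike) -> int:
    # Per the probnum.typing convention, IntLike values are converted to
    # builtin ints before any further internal processing.
    if not isinstance(dim, numbers.Integral):
        raise TypeError(f"`dim` must be integer-like, got {type(dim)}.")
    return int(dim)


assert accept_dim(np.int64(5)) == 5  # NumPy integer scalars are IntLike, too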
@@ -86,7 +86,7 @@ def random_spd_matrix( def random_sparse_spd_matrix( rng: np.random.Generator, - dim: IntArgType, + dim: IntLike, density: float, chol_entry_min: float = 0.1, chol_entry_max: float = 1.0, diff --git a/src/probnum/quad/_bayesquad.py b/src/probnum/quad/_bayesquad.py index f9d8e38635..27b227cedb 100644 --- a/src/probnum/quad/_bayesquad.py +++ b/src/probnum/quad/_bayesquad.py @@ -14,7 +14,7 @@ from probnum.randprocs.kernels import Kernel from probnum.randvars import Normal -from probnum.typing import FloatArgType, IntArgType +from probnum.typing import FloatLike, IntLike from ._integration_measures import GaussianMeasure, IntegrationMeasure, LebesgueMeasure from .solvers import BayesianQuadrature @@ -26,14 +26,14 @@ def bayesquad( input_dim: int, kernel: Optional[Kernel] = None, domain: Optional[ - Union[Tuple[FloatArgType, FloatArgType], Tuple[np.ndarray, np.ndarray]] + Union[Tuple[FloatLike, FloatLike], Tuple[np.ndarray, np.ndarray]] ] = None, measure: Optional[IntegrationMeasure] = None, policy: Optional[str] = "bmc", - max_evals: Optional[IntArgType] = None, - var_tol: Optional[FloatArgType] = None, - rel_tol: Optional[FloatArgType] = None, - batch_size: Optional[IntArgType] = 1, + max_evals: Optional[IntLike] = None, + var_tol: Optional[FloatLike] = None, + rel_tol: Optional[FloatLike] = None, + batch_size: Optional[IntLike] = 1, rng: Optional[np.random.Generator] = np.random.default_rng(), ) -> Tuple[Normal, Dict]: r"""Infer the solution of the uni- or multivariate integral :math:`\int_\Omega f(x) d \mu(x)` @@ -162,7 +162,7 @@ def bayesquad_from_data( fun_evals: np.ndarray, kernel: Optional[Kernel] = None, domain: Optional[ - Tuple[Union[np.ndarray, FloatArgType], Union[np.ndarray, FloatArgType]] + Tuple[Union[np.ndarray, FloatLike], Union[np.ndarray, FloatLike]] ] = None, measure: Optional[IntegrationMeasure] = None, ) -> Tuple[Normal, Dict]: diff --git a/src/probnum/quad/_integration_measures.py b/src/probnum/quad/_integration_measures.py index 6d66a05dcd..b6693edfcc 100644 --- a/src/probnum/quad/_integration_measures.py +++ b/src/probnum/quad/_integration_measures.py @@ -7,7 +7,7 @@ import scipy.stats from probnum.randvars import Normal -from probnum.typing import FloatArgType, IntArgType +from probnum.typing import FloatLike, IntLike class IntegrationMeasure(abc.ABC): @@ -27,13 +27,13 @@ class IntegrationMeasure(abc.ABC): def __init__( self, - domain: Union[Tuple[FloatArgType, FloatArgType], Tuple[np.ndarray, np.ndarray]], - input_dim: IntArgType, + domain: Union[Tuple[FloatLike, FloatLike], Tuple[np.ndarray, np.ndarray]], + input_dim: IntLike, ) -> None: self._set_dimension_domain(input_dim, domain) - def __call__(self, points: Union[FloatArgType, np.ndarray]) -> np.ndarray: + def __call__(self, points: Union[FloatLike, np.ndarray]) -> np.ndarray: """Evaluate the density function of the integration measure. Parameters @@ -51,7 +51,7 @@ def __call__(self, points: Union[FloatArgType, np.ndarray]) -> np.ndarray: def sample( self, - n_sample: IntArgType, + n_sample: IntLike, rng: Optional[np.random.Generator] = np.random.default_rng(), ) -> np.ndarray: """Sample ``n_sample`` points from the integration measure. @@ -76,8 +76,8 @@ def sample( def _set_dimension_domain( self, - input_dim: IntArgType, - domain: Union[Tuple[FloatArgType, FloatArgType], Tuple[np.ndarray, np.ndarray]], + input_dim: IntLike, + domain: Union[Tuple[FloatLike, FloatLike], Tuple[np.ndarray, np.ndarray]], ) -> None: """Sets the integration domain and input_dimension. 
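The `domain` arguments above accept either a pair of `FloatLike` scalars (shared bounds for every dimension) or a pair of arrays (one bound per input dimension). A rough sketch of how such a domain could be normalized internally; `normalize_domain` is hypothetical and only illustrates the annotation:

from typing import Tuple, Union

import numpy as np

from probnum.typing import FloatLike

_DomainLike = Union[Tuple[FloatLike, FloatLike], Tuple[np.ndarray, np.ndarray]]


def normalize_domain(domain: _DomainLike, input_dim: int) -> Tuple[np.ndarray, np.ndarray]:
    # Scalar bounds are converted to floats (the FloatLike convention) and
    # broadcast to one bound per input dimension.
    lower = np.broadcast_to(np.asarray(domain[0], dtype=np.double), (input_dim,))
    upper = np.broadcast_to(np.asarray(domain[1], dtype=np.double), (input_dim,))
    if np.any(lower >= upper):
        raise ValueError("Lower bounds must lie strictly below upper bounds.")
    return lower, upper


lower, upper = normalize_domain((0.0, 1.0), input_dim=3)  # scalars broadcast to shape (3,)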
@@ -150,8 +150,8 @@ class LebesgueMeasure(IntegrationMeasure): def __init__( self, - domain: Union[Tuple[FloatArgType, FloatArgType], Tuple[np.ndarray, np.ndarray]], - input_dim: Optional[IntArgType] = None, + domain: Union[Tuple[FloatLike, FloatLike], Tuple[np.ndarray, np.ndarray]], + input_dim: Optional[IntLike] = None, normalized: Optional[bool] = False, ) -> None: super().__init__(input_dim=input_dim, domain=domain) @@ -181,7 +181,7 @@ def __call__(self, points: np.ndarray) -> np.ndarray: def sample( self, - n_sample: IntArgType, + n_sample: IntLike, rng: Optional[np.random.Generator] = np.random.default_rng(), ) -> np.ndarray: return self.random_variable.rvs( @@ -211,7 +211,7 @@ def __init__( self, mean: Union[float, np.floating, np.ndarray], cov: Union[float, np.floating, np.ndarray], - input_dim: Optional[IntArgType] = None, + input_dim: Optional[IntLike] = None, ) -> None: # Extend scalar mean and covariance to higher dimensions if input_dim has been diff --git a/src/probnum/quad/solvers/bayesian_quadrature.py b/src/probnum/quad/solvers/bayesian_quadrature.py index bfe5042d70..9729ccc6c3 100644 --- a/src/probnum/quad/solvers/bayesian_quadrature.py +++ b/src/probnum/quad/solvers/bayesian_quadrature.py @@ -13,7 +13,7 @@ ) from probnum.randprocs.kernels import ExpQuad, Kernel from probnum.randvars import Normal -from probnum.typing import FloatArgType, IntArgType +from probnum.typing import FloatLike, IntLike from .._integration_measures import IntegrationMeasure, LebesgueMeasure from ..kernel_embeddings import KernelEmbedding @@ -64,13 +64,13 @@ def from_problem( kernel: Optional[Kernel] = None, measure: Optional[IntegrationMeasure] = None, domain: Optional[ - Union[Tuple[FloatArgType, FloatArgType], Tuple[np.ndarray, np.ndarray]] + Union[Tuple[FloatLike, FloatLike], Tuple[np.ndarray, np.ndarray]] ] = None, policy: str = "bmc", - max_evals: Optional[IntArgType] = None, - var_tol: Optional[FloatArgType] = None, - rel_tol: Optional[FloatArgType] = None, - batch_size: IntArgType = 1, + max_evals: Optional[IntLike] = None, + var_tol: Optional[FloatLike] = None, + rel_tol: Optional[FloatLike] = None, + batch_size: IntLike = 1, rng: np.random.Generator = None, ) -> "BayesianQuadrature": diff --git a/src/probnum/quad/solvers/stopping_criteria/_integral_variance_tol.py b/src/probnum/quad/solvers/stopping_criteria/_integral_variance_tol.py index bd863ac111..5276892cd4 100644 --- a/src/probnum/quad/solvers/stopping_criteria/_integral_variance_tol.py +++ b/src/probnum/quad/solvers/stopping_criteria/_integral_variance_tol.py @@ -2,7 +2,7 @@ from probnum.quad.solvers.bq_state import BQState from probnum.quad.solvers.stopping_criteria import BQStoppingCriterion -from probnum.typing import FloatArgType +from probnum.typing import FloatLike # pylint: disable=too-few-public-methods, fixme @@ -16,7 +16,7 @@ class IntegralVarianceTolerance(BQStoppingCriterion): Tolerance value of the variance. 
""" - def __init__(self, var_tol: FloatArgType): + def __init__(self, var_tol: FloatLike): self.var_tol = var_tol def __call__(self, bq_state: BQState) -> bool: diff --git a/src/probnum/quad/solvers/stopping_criteria/_max_nevals.py b/src/probnum/quad/solvers/stopping_criteria/_max_nevals.py index fb40e3f326..59c9a8ce1b 100644 --- a/src/probnum/quad/solvers/stopping_criteria/_max_nevals.py +++ b/src/probnum/quad/solvers/stopping_criteria/_max_nevals.py @@ -2,7 +2,7 @@ from probnum.quad.solvers.bq_state import BQState from probnum.quad.solvers.stopping_criteria import BQStoppingCriterion -from probnum.typing import IntArgType +from probnum.typing import IntLike # pylint: disable=too-few-public-methods @@ -16,7 +16,7 @@ class MaxNevals(BQStoppingCriterion): Maximum number of integrand evaluations. """ - def __init__(self, max_nevals: IntArgType): + def __init__(self, max_nevals: IntLike): self.max_nevals = max_nevals def __call__(self, bq_state: BQState) -> bool: diff --git a/src/probnum/quad/solvers/stopping_criteria/_rel_mean_change.py b/src/probnum/quad/solvers/stopping_criteria/_rel_mean_change.py index f74fb6262e..cb32b89d3e 100644 --- a/src/probnum/quad/solvers/stopping_criteria/_rel_mean_change.py +++ b/src/probnum/quad/solvers/stopping_criteria/_rel_mean_change.py @@ -4,7 +4,7 @@ from probnum.quad.solvers.bq_state import BQState from probnum.quad.solvers.stopping_criteria import BQStoppingCriterion -from probnum.typing import FloatArgType +from probnum.typing import FloatLike # pylint: disable=too-few-public-methods @@ -23,7 +23,7 @@ class RelativeMeanChange(BQStoppingCriterion): Relative error tolerance on consecutive integral mean values. """ - def __init__(self, rel_tol: FloatArgType): + def __init__(self, rel_tol: FloatLike): self.rel_tol = rel_tol def __call__(self, bq_state: BQState) -> bool: diff --git a/src/probnum/randprocs/_gaussian_process.py b/src/probnum/randprocs/_gaussian_process.py index e7de2caff1..4ec885073a 100644 --- a/src/probnum/randprocs/_gaussian_process.py +++ b/src/probnum/randprocs/_gaussian_process.py @@ -5,7 +5,7 @@ import numpy as np from probnum import randvars -from probnum.typing import ShapeArgType +from probnum.typing import ShapeLike from . import _random_process, kernels @@ -105,7 +105,7 @@ def _sample_at_input( self, rng: np.random.Generator, args: _InputType, - size: ShapeArgType = (), + size: ShapeLike = (), ) -> _OutputType: gaussian_rv = self.__call__(args) return gaussian_rv.sample(rng=rng, size=size) diff --git a/src/probnum/randprocs/_random_process.py b/src/probnum/randprocs/_random_process.py index c6e570e1a9..3537554a35 100644 --- a/src/probnum/randprocs/_random_process.py +++ b/src/probnum/randprocs/_random_process.py @@ -6,7 +6,7 @@ import numpy as np from probnum import randvars, utils as _utils -from probnum.typing import DTypeArgType, IntArgType, ShapeArgType +from probnum.typing import DTypeArgType, IntLike, ShapeLike _InputType = TypeVar("InputType") _OutputType = TypeVar("OutputType") @@ -46,8 +46,8 @@ class RandomProcess(Generic[_InputType, _OutputType], abc.ABC): def __init__( self, - input_dim: IntArgType, - output_dim: Optional[IntArgType], + input_dim: IntLike, + output_dim: Optional[IntLike], dtype: DTypeArgType, ): self._input_dim = np.int_(_utils.as_numpy_scalar(input_dim)) @@ -295,7 +295,7 @@ def sample( self, rng: np.random.Generator, args: _InputType = None, - size: ShapeArgType = (), + size: ShapeLike = (), ) -> Union[Callable[[_InputType], _OutputType], _OutputType]: """Sample paths from the random process. 
@@ -324,7 +324,7 @@ def _sample_at_input( self, rng: np.random.Generator, args: _InputType, - size: ShapeArgType = (), + size: ShapeLike = (), ) -> _OutputType: """Evaluate a set of sample paths at the given inputs. diff --git a/src/probnum/randprocs/kernels/_exponentiated_quadratic.py b/src/probnum/randprocs/kernels/_exponentiated_quadratic.py index ccfebabe55..d7a7e9324c 100644 --- a/src/probnum/randprocs/kernels/_exponentiated_quadratic.py +++ b/src/probnum/randprocs/kernels/_exponentiated_quadratic.py @@ -4,7 +4,7 @@ from typing import Optional from probnum import backend -from probnum.typing import ArrayType, IntArgType, ScalarArgType +from probnum.typing import ArrayType, IntLike, ScalarLike from ._kernel import IsotropicMixin, Kernel @@ -45,7 +45,7 @@ class ExpQuad(Kernel, IsotropicMixin): [1.92874985e-22, 3.72665317e-06, 1.00000000e+00]]) """ - def __init__(self, input_dim: IntArgType, lengthscale: ScalarArgType = 1.0): + def __init__(self, input_dim: IntLike, lengthscale: ScalarLike = 1.0): self.lengthscale = backend.as_scalar(lengthscale) super().__init__(input_dim=input_dim) diff --git a/src/probnum/randprocs/kernels/_kernel.py b/src/probnum/randprocs/kernels/_kernel.py index 167d72a26f..3fc14d57c4 100644 --- a/src/probnum/randprocs/kernels/_kernel.py +++ b/src/probnum/randprocs/kernels/_kernel.py @@ -5,7 +5,7 @@ from typing import Optional from probnum import backend, utils as _pn_utils -from probnum.typing import ArrayLike, ArrayType, IntArgType, ShapeArgType, ShapeType +from probnum.typing import ArrayLike, ArrayType, IntLike, ShapeLike, ShapeType class Kernel(abc.ABC): @@ -133,8 +133,8 @@ class Kernel(abc.ABC): def __init__( self, - input_dim: IntArgType, - shape: ShapeArgType = (), + input_dim: IntLike, + shape: ShapeLike = (), ): self._input_dim = int(input_dim) diff --git a/src/probnum/randprocs/kernels/_linear.py b/src/probnum/randprocs/kernels/_linear.py index 968eaa35a1..47adffdfce 100644 --- a/src/probnum/randprocs/kernels/_linear.py +++ b/src/probnum/randprocs/kernels/_linear.py @@ -3,7 +3,7 @@ from typing import Optional from probnum import backend -from probnum.typing import ArrayType, IntArgType, ScalarArgType +from probnum.typing import ArrayType, IntLike, ScalarLike from ._kernel import Kernel @@ -38,7 +38,7 @@ class Linear(Kernel): [ 8., 13.]]) """ - def __init__(self, input_dim: IntArgType, constant: ScalarArgType = 0.0): + def __init__(self, input_dim: IntLike, constant: ScalarLike = 0.0): self.constant = backend.as_scalar(constant) super().__init__(input_dim=input_dim) diff --git a/src/probnum/randprocs/kernels/_matern.py b/src/probnum/randprocs/kernels/_matern.py index 14bf025fe1..26069eb389 100644 --- a/src/probnum/randprocs/kernels/_matern.py +++ b/src/probnum/randprocs/kernels/_matern.py @@ -3,7 +3,7 @@ from typing import Optional from probnum import backend -from probnum.typing import ArrayType, IntArgType, ScalarArgType +from probnum.typing import ArrayType, IntLike, ScalarLike from ._kernel import IsotropicMixin, Kernel @@ -59,8 +59,8 @@ class Matern(Kernel, IsotropicMixin): def __init__( self, - input_dim: IntArgType, - lengthscale: ScalarArgType = 1.0, + input_dim: IntLike, + lengthscale: ScalarLike = 1.0, nu: float = 1.5, ): self.lengthscale = backend.as_scalar(lengthscale) diff --git a/src/probnum/randprocs/kernels/_polynomial.py b/src/probnum/randprocs/kernels/_polynomial.py index a6fa60f010..ff976de34f 100644 --- a/src/probnum/randprocs/kernels/_polynomial.py +++ b/src/probnum/randprocs/kernels/_polynomial.py @@ -3,7 +3,7 @@ from typing import 
Optional from probnum import backend -from probnum.typing import ArrayType, IntArgType, ScalarArgType +from probnum.typing import ArrayType, IntLike, ScalarLike from ._kernel import Kernel @@ -42,9 +42,9 @@ class Polynomial(Kernel): def __init__( self, - input_dim: IntArgType, - constant: ScalarArgType = 0.0, - exponent: IntArgType = 1.0, + input_dim: IntLike, + constant: ScalarLike = 0.0, + exponent: IntLike = 1.0, ): self.constant = backend.as_scalar(constant) self.exponent = backend.as_scalar(exponent) diff --git a/src/probnum/randprocs/kernels/_rational_quadratic.py b/src/probnum/randprocs/kernels/_rational_quadratic.py index 15e0998254..0ed7479e29 100644 --- a/src/probnum/randprocs/kernels/_rational_quadratic.py +++ b/src/probnum/randprocs/kernels/_rational_quadratic.py @@ -3,7 +3,7 @@ from typing import Optional from probnum import backend -from probnum.typing import ArrayType, IntArgType, ScalarArgType +from probnum.typing import ArrayType, IntLike, ScalarLike from ._kernel import IsotropicMixin, Kernel @@ -56,9 +56,9 @@ class RatQuad(Kernel, IsotropicMixin): def __init__( self, - input_dim: IntArgType, - lengthscale: ScalarArgType = 1.0, - alpha: ScalarArgType = 1.0, + input_dim: IntLike, + lengthscale: ScalarLike = 1.0, + alpha: ScalarLike = 1.0, ): self.lengthscale = backend.as_scalar(lengthscale) self.alpha = backend.as_scalar(alpha) diff --git a/src/probnum/randprocs/kernels/_white_noise.py b/src/probnum/randprocs/kernels/_white_noise.py index 81c1d1dc72..733713d9a3 100644 --- a/src/probnum/randprocs/kernels/_white_noise.py +++ b/src/probnum/randprocs/kernels/_white_noise.py @@ -3,7 +3,7 @@ from typing import Optional from probnum import backend -from probnum.typing import ArrayType, IntArgType, ScalarArgType +from probnum.typing import ArrayType, IntLike, ScalarLike from ._kernel import Kernel @@ -24,7 +24,7 @@ class WhiteNoise(Kernel): Noise level :math:`\sigma`. """ - def __init__(self, input_dim: IntArgType, sigma: ScalarArgType = 1.0): + def __init__(self, input_dim: IntLike, sigma: ScalarLike = 1.0): self.sigma = backend.as_scalar(sigma) self._sigma_sq = self.sigma ** 2 super().__init__(input_dim=input_dim) diff --git a/src/probnum/randprocs/markov/_markov_process.py b/src/probnum/randprocs/markov/_markov_process.py index bf38f30553..9b588de7f4 100644 --- a/src/probnum/randprocs/markov/_markov_process.py +++ b/src/probnum/randprocs/markov/_markov_process.py @@ -8,7 +8,7 @@ from probnum import randvars, utils from probnum.randprocs import _random_process from probnum.randprocs.markov import _transition -from probnum.typing import ShapeArgType +from probnum.typing import ShapeLike _InputType = Union[np.floating, np.ndarray] _OutputType = Union[np.floating, np.ndarray] @@ -69,7 +69,7 @@ def _sample_at_input( self, rng: np.random.Generator, args: _InputType, - size: ShapeArgType = (), + size: ShapeLike = (), ) -> _OutputType: size = utils.as_shape(size) diff --git a/src/probnum/randprocs/markov/_transition.py b/src/probnum/randprocs/markov/_transition.py index 0135826a49..ed0bbf253a 100644 --- a/src/probnum/randprocs/markov/_transition.py +++ b/src/probnum/randprocs/markov/_transition.py @@ -5,7 +5,7 @@ import numpy as np from probnum import randvars -from probnum.typing import FloatArgType, IntArgType +from probnum.typing import FloatLike, IntLike class Transition(abc.ABC): @@ -46,7 +46,7 @@ class Transition(abc.ABC): Markov-chains and general discrete-time transitions (likelihoods). 
""" - def __init__(self, input_dim: IntArgType, output_dim: IntArgType): + def __init__(self, input_dim: IntLike, output_dim: IntLike): self.input_dim = input_dim self.output_dim = output_dim @@ -300,7 +300,7 @@ def smooth_list( def jointly_transform_base_measure_realization_list_backward( self, base_measure_realizations: np.ndarray, - t: FloatArgType, + t: FloatLike, rv_list: randvars._RandomVariableList, _diffusion_list: np.ndarray, _previous_posterior=None, @@ -367,7 +367,7 @@ def jointly_transform_base_measure_realization_list_backward( def jointly_transform_base_measure_realization_list_forward( self, base_measure_realizations: np.ndarray, - t: FloatArgType, + t: FloatLike, initrv: randvars.RandomVariable, _diffusion_list: np.ndarray, _previous_posterior=None, diff --git a/src/probnum/randprocs/markov/continuous/_diffusions.py b/src/probnum/randprocs/markov/continuous/_diffusions.py index 45febd2517..99a73bfa6c 100644 --- a/src/probnum/randprocs/markov/continuous/_diffusions.py +++ b/src/probnum/randprocs/markov/continuous/_diffusions.py @@ -9,10 +9,10 @@ from probnum import randvars from probnum.typing import ( - ArrayLikeGetitemArgType, - DenseOutputLocationArgType, - FloatArgType, - ToleranceDiffusionType, + ArrayIndicesLike, + ArrayLike, + FloatLike, + ArrayLike, ) @@ -23,16 +23,12 @@ def __repr__(self): raise NotImplementedError @abc.abstractmethod - def __call__( - self, t: DenseOutputLocationArgType - ) -> Union[ToleranceDiffusionType, np.ndarray]: + def __call__(self, t: ArrayLike) -> Union[ArrayLike, np.ndarray]: r"""Evaluate the diffusion :math:`\sigma(t)` at :math:`t`.""" raise NotImplementedError @abc.abstractmethod - def __getitem__( - self, idx: ArrayLikeGetitemArgType - ) -> Union[ToleranceDiffusionType, np.ndarray]: + def __getitem__(self, idx: ArrayIndicesLike) -> Union[ArrayLike, np.ndarray]: raise NotImplementedError @abc.abstractmethod @@ -40,8 +36,8 @@ def estimate_locally( self, meas_rv: randvars.RandomVariable, meas_rv_assuming_zero_previous_cov: randvars.RandomVariable, - t: FloatArgType, - ) -> ToleranceDiffusionType: + t: FloatLike, + ) -> ArrayLike: r"""Estimate the (local) diffusion and update current (global) estimation in- place. @@ -64,18 +60,14 @@ def __init__(self): def __repr__(self): return f"ConstantDiffusion({self.diffusion})" - def __call__( - self, t: DenseOutputLocationArgType - ) -> Union[ToleranceDiffusionType, np.ndarray]: + def __call__(self, t: ArrayLike) -> Union[ArrayLike, np.ndarray]: if self.diffusion is None: raise NotImplementedError( "No diffusions seen yet. Call estimate_locally_and_update_in_place first." ) return self.diffusion * np.ones_like(t) - def __getitem__( - self, idx: ArrayLikeGetitemArgType - ) -> Union[ToleranceDiffusionType, np.ndarray]: + def __getitem__(self, idx: ArrayIndicesLike) -> Union[ArrayLike, np.ndarray]: if self.diffusion is None: raise NotImplementedError( "No diffusions seen yet. Call estimate_locally_and_update_in_place first." 
@@ -87,8 +79,8 @@ def estimate_locally( self, meas_rv: randvars.RandomVariable, meas_rv_assuming_zero_previous_cov: randvars.RandomVariable, - t: FloatArgType, - ) -> ToleranceDiffusionType: + t: FloatLike, + ) -> ArrayLike: new_increment = _compute_local_quasi_mle(meas_rv) return new_increment @@ -135,9 +127,7 @@ def __init__(self, t0): def __repr__(self): return f"PiecewiseConstantDiffusion({self.diffusions})" - def __call__( - self, t: DenseOutputLocationArgType - ) -> Union[ToleranceDiffusionType, np.ndarray]: + def __call__(self, t: ArrayLike) -> Union[ArrayLike, np.ndarray]: if len(self._locations) <= 1: raise NotImplementedError( "No diffusions seen yet. Call estimate_locally_and_update_in_place first." @@ -158,9 +148,7 @@ def __call__( return self[indices] - def __getitem__( - self, idx: ArrayLikeGetitemArgType - ) -> Union[ToleranceDiffusionType, np.ndarray]: + def __getitem__(self, idx: ArrayIndicesLike) -> Union[ArrayLike, np.ndarray]: if len(self._locations) <= 1: raise NotImplementedError( "No diffusions seen yet. Call estimate_locally_and_update_in_place first." @@ -171,8 +159,8 @@ def estimate_locally( self, meas_rv: randvars.RandomVariable, meas_rv_assuming_zero_previous_cov: randvars.RandomVariable, - t: FloatArgType, - ) -> ToleranceDiffusionType: + t: FloatLike, + ) -> ArrayLike: if not t >= self.tmax: raise ValueError( "This time-point is not right of the current rightmost time-point." diff --git a/src/probnum/randprocs/markov/continuous/_linear_sde.py b/src/probnum/randprocs/markov/continuous/_linear_sde.py index c0594fbe0d..ed3b82e9f3 100644 --- a/src/probnum/randprocs/markov/continuous/_linear_sde.py +++ b/src/probnum/randprocs/markov/continuous/_linear_sde.py @@ -8,7 +8,7 @@ from probnum import randvars from probnum.randprocs.markov.continuous import _sde -from probnum.typing import FloatArgType, IntArgType +from probnum.typing import FloatLike, IntLike from probnum.utils.linalg import tril_to_positive_tril @@ -46,13 +46,13 @@ class LinearSDE(_sde.SDE): def __init__( self, - state_dimension: IntArgType, - wiener_process_dimension: IntArgType, - drift_matrix_function: Callable[[FloatArgType], np.ndarray], - force_vector_function: Callable[[FloatArgType], np.ndarray], - dispersion_matrix_function: Callable[[FloatArgType], np.ndarray], - mde_atol: Optional[FloatArgType] = 1e-6, - mde_rtol: Optional[FloatArgType] = 1e-6, + state_dimension: IntLike, + wiener_process_dimension: IntLike, + drift_matrix_function: Callable[[FloatLike], np.ndarray], + force_vector_function: Callable[[FloatLike], np.ndarray], + dispersion_matrix_function: Callable[[FloatLike], np.ndarray], + mde_atol: Optional[FloatLike] = 1e-6, + mde_rtol: Optional[FloatLike] = 1e-6, mde_solver: Optional[str] = "RK45", forward_implementation: Optional[str] = "classic", ): diff --git a/src/probnum/randprocs/markov/continuous/_sde.py b/src/probnum/randprocs/markov/continuous/_sde.py index d0e71bb095..06896d10a2 100644 --- a/src/probnum/randprocs/markov/continuous/_sde.py +++ b/src/probnum/randprocs/markov/continuous/_sde.py @@ -5,7 +5,7 @@ import numpy as np from probnum.randprocs.markov import _transition -from probnum.typing import FloatArgType, IntArgType +from probnum.typing import FloatLike, IntLike class SDE(_transition.Transition): @@ -18,11 +18,11 @@ class SDE(_transition.Transition): def __init__( self, - state_dimension: IntArgType, - wiener_process_dimension: IntArgType, - drift_function: Callable[[FloatArgType, np.ndarray], np.ndarray], - dispersion_function: Callable[[FloatArgType, np.ndarray], 
np.ndarray], - drift_jacobian: Optional[Callable[[FloatArgType, np.ndarray], np.ndarray]], + state_dimension: IntLike, + wiener_process_dimension: IntLike, + drift_function: Callable[[FloatLike, np.ndarray], np.ndarray], + dispersion_function: Callable[[FloatLike, np.ndarray], np.ndarray], + drift_jacobian: Optional[Callable[[FloatLike, np.ndarray], np.ndarray]], ): super().__init__(input_dim=state_dimension, output_dim=state_dimension) diff --git a/src/probnum/randprocs/markov/discrete/_linear_gaussian.py b/src/probnum/randprocs/markov/discrete/_linear_gaussian.py index d6c3a3830e..0c21fbbad6 100644 --- a/src/probnum/randprocs/markov/discrete/_linear_gaussian.py +++ b/src/probnum/randprocs/markov/discrete/_linear_gaussian.py @@ -8,7 +8,7 @@ from probnum import config, linops, randvars from probnum.randprocs.markov.discrete import _nonlinear_gaussian -from probnum.typing import FloatArgType, IntArgType +from probnum.typing import FloatLike, IntLike from probnum.utils.linalg import cholesky_update, tril_to_positive_tril @@ -38,14 +38,12 @@ class LinearGaussian(_nonlinear_gaussian.NonlinearGaussian): def __init__( self, - input_dim: IntArgType, - output_dim: IntArgType, - state_trans_mat_fun: Callable[[FloatArgType], np.ndarray], - shift_vec_fun: Callable[[FloatArgType], np.ndarray], - proc_noise_cov_mat_fun: Callable[[FloatArgType], np.ndarray], - proc_noise_cov_cholesky_fun: Optional[ - Callable[[FloatArgType], np.ndarray] - ] = None, + input_dim: IntLike, + output_dim: IntLike, + state_trans_mat_fun: Callable[[FloatLike], np.ndarray], + shift_vec_fun: Callable[[FloatLike], np.ndarray], + proc_noise_cov_mat_fun: Callable[[FloatLike], np.ndarray], + proc_noise_cov_cholesky_fun: Optional[Callable[[FloatLike], np.ndarray]] = None, forward_implementation="classic", backward_implementation="classic", ): diff --git a/src/probnum/randprocs/markov/discrete/_nonlinear_gaussian.py b/src/probnum/randprocs/markov/discrete/_nonlinear_gaussian.py index d118861d5c..ce676273c5 100644 --- a/src/probnum/randprocs/markov/discrete/_nonlinear_gaussian.py +++ b/src/probnum/randprocs/markov/discrete/_nonlinear_gaussian.py @@ -8,7 +8,7 @@ from probnum import randvars from probnum.randprocs.markov import _transition from probnum.randprocs.markov.discrete import _condition_state -from probnum.typing import FloatArgType, IntArgType +from probnum.typing import FloatLike, IntLike class NonlinearGaussian(_transition.Transition): @@ -43,16 +43,14 @@ class NonlinearGaussian(_transition.Transition): def __init__( self, - input_dim: IntArgType, - output_dim: IntArgType, - state_trans_fun: Callable[[FloatArgType, np.ndarray], np.ndarray], - proc_noise_cov_mat_fun: Callable[[FloatArgType], np.ndarray], + input_dim: IntLike, + output_dim: IntLike, + state_trans_fun: Callable[[FloatLike, np.ndarray], np.ndarray], + proc_noise_cov_mat_fun: Callable[[FloatLike], np.ndarray], jacob_state_trans_fun: Optional[ - Callable[[FloatArgType, np.ndarray], np.ndarray] - ] = None, - proc_noise_cov_cholesky_fun: Optional[ - Callable[[FloatArgType], np.ndarray] + Callable[[FloatLike, np.ndarray], np.ndarray] ] = None, + proc_noise_cov_cholesky_fun: Optional[Callable[[FloatLike], np.ndarray]] = None, ): self.state_trans_fun = state_trans_fun self.proc_noise_cov_mat_fun = proc_noise_cov_mat_fun @@ -152,10 +150,10 @@ def proc_noise_cov_cholesky_fun(self, t): @classmethod def from_callable( cls, - input_dim: IntArgType, - output_dim: IntArgType, - state_trans_fun: Callable[[FloatArgType, np.ndarray], np.ndarray], - jacob_state_trans_fun: 
Callable[[FloatArgType, np.ndarray], np.ndarray], + input_dim: IntLike, + output_dim: IntLike, + state_trans_fun: Callable[[FloatLike, np.ndarray], np.ndarray], + jacob_state_trans_fun: Callable[[FloatLike, np.ndarray], np.ndarray], ): """Turn a callable into a deterministic transition.""" diff --git a/src/probnum/randprocs/markov/integrator/convert/_convert.py b/src/probnum/randprocs/markov/integrator/convert/_convert.py index d80b0da6de..97fa440dfc 100644 --- a/src/probnum/randprocs/markov/integrator/convert/_convert.py +++ b/src/probnum/randprocs/markov/integrator/convert/_convert.py @@ -3,11 +3,11 @@ import numpy as np from probnum.randprocs.markov.integrator import _integrator -from probnum.typing import IntArgType +from probnum.typing import IntLike def convert_derivwise_to_coordwise( - state: np.ndarray, num_derivatives: IntArgType, wiener_process_dimension: IntArgType + state: np.ndarray, num_derivatives: IntLike, wiener_process_dimension: IntLike ) -> np.ndarray: """Convert coordinate-wise representation to derivative-wise representation. @@ -29,7 +29,7 @@ def convert_derivwise_to_coordwise( def convert_coordwise_to_derivwise( - state: np.ndarray, num_derivatives: IntArgType, wiener_process_dimension: IntArgType + state: np.ndarray, num_derivatives: IntLike, wiener_process_dimension: IntLike ) -> np.ndarray: """Convert coordinate-wise representation to derivative-wise representation. diff --git a/src/probnum/randvars/_constant.py b/src/probnum/randvars/_constant.py index 7d3cc138ad..1177e0a14f 100644 --- a/src/probnum/randvars/_constant.py +++ b/src/probnum/randvars/_constant.py @@ -6,10 +6,10 @@ from probnum import backend, config, linops, utils as _utils from probnum.typing import ( - ArrayLikeGetitemArgType, + ArrayIndicesLike, ArrayType, SeedType, - ShapeArgType, + ShapeLike, ShapeType, ) @@ -120,7 +120,7 @@ def support(self) -> ArrayType: """Constant value taken by the random variable.""" return self._support - def __getitem__(self, key: ArrayLikeGetitemArgType) -> "Constant": + def __getitem__(self, key: ArrayIndicesLike) -> "Constant": """(Advanced) indexing, masking and slicing. 
This method supports all modes of array indexing presented in @@ -145,7 +145,7 @@ def transpose(self, *axes: int) -> "Constant": support=self._support.transpose(*axes), ) - def _sample(self, seed: SeedType, sample_shape: ShapeArgType = ()) -> ArrayType: + def _sample(self, seed: SeedType, sample_shape: ShapeLike = ()) -> ArrayType: # pylint: disable=unused-argument if sample_shape == (): diff --git a/src/probnum/randvars/_normal.py b/src/probnum/randvars/_normal.py index 604da0ae1f..8ee7065375 100644 --- a/src/probnum/randvars/_normal.py +++ b/src/probnum/randvars/_normal.py @@ -7,13 +7,13 @@ from probnum import backend, config, linops from probnum.typing import ( ArrayLike, - ArrayLikeGetitemArgType, + ArrayIndicesLike, ArrayType, - FloatArgType, + FloatLike, ScalarType, SeedLike, SeedType, - ShapeArgType, + ShapeLike, ShapeType, ) @@ -217,7 +217,7 @@ def _cov_op_cholesky(self) -> ArrayType: def compute_cov_cholesky( self, - damping_factor: Optional[FloatArgType] = None, + damping_factor: Optional[FloatLike] = None, ) -> None: """Compute Cholesky factor (careful: in-place operation!).""" if damping_factor is None: @@ -276,7 +276,7 @@ def cov_cholesky_is_precomputed(self) -> bool: """ return self._cov_cholesky is not None or self.__cov_op_cholesky is not None - def __getitem__(self, key: ArrayLikeGetitemArgType) -> "Normal": + def __getitem__(self, key: ArrayIndicesLike) -> "Normal": """Marginalization in multi- and matrixvariate normal random variables, expressed as (advanced) indexing, masking and slicing. @@ -313,7 +313,7 @@ def __getitem__(self, key: ArrayLikeGetitemArgType) -> "Normal": cov=cov, ) - def reshape(self, newshape: ShapeArgType) -> "Normal": + def reshape(self, newshape: ShapeLike) -> "Normal": try: reshaped_mean = self.dense_mean.reshape(newshape) except ValueError as exc: @@ -434,7 +434,7 @@ def _scalar_logcdf(self, x: ArrayType) -> ArrayType: return backend.log(self._scalar_cdf(x)) @backend.jit_method - def _scalar_quantile(self, p: FloatArgType) -> ArrayType: + def _scalar_quantile(self, p: FloatLike) -> ArrayType: return self.mean + self.std * backend.special.ndtri(p) @backend.jit_method diff --git a/src/probnum/randvars/_random_variable.py b/src/probnum/randvars/_random_variable.py index 4495e79c36..74435798bf 100644 --- a/src/probnum/randvars/_random_variable.py +++ b/src/probnum/randvars/_random_variable.py @@ -8,12 +8,12 @@ from probnum import backend, utils as _utils from probnum.typing import ( - ArrayLikeGetitemArgType, + ArrayIndicesLike, ArrayType, DTypeArgType, ScalarType, SeedType, - ShapeArgType, + ShapeLike, ShapeType, ) @@ -100,7 +100,7 @@ class RandomVariable: def __init__( self, - shape: ShapeArgType, + shape: ShapeLike, dtype: DTypeArgType, parameters: Optional[Dict[str, Any]] = None, sample: Optional[Callable[[SeedType, ShapeType], ArrayType]] = None, @@ -401,7 +401,7 @@ def in_support(self, x: ArrayType) -> ArrayType: return in_support - def sample(self, seed: SeedType, sample_shape: ShapeArgType = ()) -> ArrayType: + def sample(self, seed: SeedType, sample_shape: ShapeLike = ()) -> ArrayType: """Draw realizations from a random variable. 
Parameters @@ -516,7 +516,7 @@ def quantile(self, p: ArrayType) -> ArrayType: return quantile - def __getitem__(self, key: ArrayLikeGetitemArgType) -> "RandomVariable": + def __getitem__(self, key: ArrayIndicesLike) -> "RandomVariable": # Shape inference # For simplicity, this should not be computed using backend, but rather in numpy shape = np.broadcast_to(np.empty(()), self.shape)[key].shape @@ -532,7 +532,7 @@ def __getitem__(self, key: ArrayLikeGetitemArgType) -> "RandomVariable": entropy=lambda: self.entropy, ) - def reshape(self, newshape: ShapeArgType) -> "RandomVariable": + def reshape(self, newshape: ShapeLike) -> "RandomVariable": """Give a new shape to a random variable. Parameters @@ -894,7 +894,7 @@ class DiscreteRandomVariable(RandomVariable): def __init__( self, - shape: ShapeArgType, + shape: ShapeLike, dtype: DTypeArgType, parameters: Optional[Dict[str, Any]] = None, sample: Optional[Callable[[SeedType, ShapeType], ArrayType]] = None, @@ -1103,7 +1103,7 @@ class ContinuousRandomVariable(RandomVariable): def __init__( self, - shape: ShapeArgType, + shape: ShapeLike, dtype: DTypeArgType, parameters: Optional[Dict[str, Any]] = None, sample: Optional[Callable[[SeedType, ShapeType], ArrayType]] = None, diff --git a/src/probnum/typing.py b/src/probnum/typing.py index 28e84131ad..43f0cdfbd9 100644 --- a/src/probnum/typing.py +++ b/src/probnum/typing.py @@ -18,85 +18,97 @@ import numpy as np import scipy.sparse -from numpy.typing import ( # pylint: disable=unused-import - ArrayLike as _NumPyArrayLike, - DTypeLike as DTypeArgType, -) +from numpy.typing import ArrayLike as _NumPyArrayLike, DTypeLike as _NumPyDTypeLike ######################################################################################## # API Types ######################################################################################## +# Array Utilities ShapeType = Tuple[int, ...] -# Backend Types -ArrayType = Union[np.ndarray, "jax.numpy.ndarray", "torch.Tensor"] -ScalarType = Union[np.ndarray, "jax.numpy.ndarray", "torch.Tensor"] +ScalarType = "probnum.backend.ndarray" +MatrixType = Union["probnum.backend.ndarray", "probnum.linops.LinearOperator"] SeedType = Union[np.random.SeedSequence, "jax.random.PRNGKey"] -# ProbNum Types -MatrixType = Union[ArrayType, "probnum.linops.LinearOperator"] ######################################################################################## # Argument Types ######################################################################################## -# Backend Types -ArrayLike = Union[_NumPyArrayLike, "jax.numpy.ndarray", "torch.Tensor"] -ScalarLike = Union[ScalarType, int, float, complex, numbers.Number, np.number] +# Python Numbers +IntLike = Union[int, numbers.Integral, np.integer] +"""Type of a public API argument for supplying an integer. -SeedLike = Optional[int] +Values of this type should always be converted into :class:`int`\\ s before further +internal processing.""" -IntArgType = Union[int, numbers.Integral, np.integer] -FloatArgType = Union[float, numbers.Real, np.floating] +FloatLike = Union[float, numbers.Real, np.floating] +"""Type of a public API argument for supplying a float. -ShapeArgType = Union[IntArgType, Iterable[IntArgType]] -"""Type of a public API argument for supplying a shape. 
Values of this type should -always be converted into :class:`ShapeType` using the function -:func:`probnum.utils.as_shape` before further internal processing.""" +Values of this type should always be converted into :class:`float`\\ s before further internal processing.""" -ScalarArgType = Union[int, float, complex, numbers.Number, np.number] -"""Type of a public API argument for supplying a scalar value. Values of this type -should always be converted into :class:`np.generic` using the function -:func:`probnum.backend.as_scalar` before further internal processing.""" +# Array Utilities +ShapeLike = Union[IntLike, Iterable[IntLike]] +"""Type of a public API argument for supplying a shape. -LinearOperatorArgType = Union[ - np.ndarray, - scipy.sparse.spmatrix, - "probnum.linops.LinearOperator", -] -"""Type of a public API argument for supplying a matrix or finite-dimensional linear operator.""" +Values of this type should always be converted into :class:`ShapeType` using the +function :func:`probnum.backend.as_shape` before further internal processing.""" + +DTypeLike = Union[_NumPyDTypeLike, "jax.numpy.dtype", "torch.dtype"] +"""Type of a public API argument for supplying an array's dtype. -ArrayLikeGetitemArgType = Union[ +Values of this type should always be converted into :class:`backend.dtype`\\ s using the +function :func:`probnum.backend.as_dtype` before further internal processing.""" + +_ArrayIndexLike = Union[ int, slice, - np.ndarray, - np.newaxis, - None, type(Ellipsis), - Tuple[Union[int, slice, np.ndarray, np.newaxis, None, type(Ellipsis)], ...], + None, + "probnum.backend.newaxis", + "probnum.backend.ndarray", ] +ArrayIndicesLike = Union[_ArrayIndexLike, Tuple[_ArrayIndexLike, ...]] +"""Type of the argument to the :meth:`__getitem__` method of a NumPy-like array type +such as :class:`probnum.backend.ndarray`, :class:`probnum.linops.LinearOperator` or +:class:`probnum.randvars.RandomVariable`.""" -######################################################################################## -# Other Types -######################################################################################## +# Scalars, Arrays and Matrices +ScalarLike = Union[ScalarType, int, float, complex, numbers.Number, np.number] +"""Type of a public API argument for supplying a scalar value. -ToleranceDiffusionType = Union[FloatArgType, np.ndarray] -r"""Type of a quantity that describes tolerances, errors, and diffusions. +Values of this type should always be converted into :class:`backend.ndarray`\\ s using +the function :func:`probnum.backend.as_scalar` before further internal processing.""" -Used for absolute (atol) and relative tolerances (rtol), local error estimates, as well as -(the diagonal entries of diagonal matrices representing) diffusion models. -atol, rtol, and diffusion are usually floats, but can be generalized to arrays -- essentially, -to every :math:`\tau` that allows arithmetic operations such as +ArrayLike = Union[_NumPyArrayLike, "jax.numpy.ndarray", "torch.Tensor"] +"""Type of a public API argument for supplying an array. -.. math:: \tau + tau * \text{vec}, \text{ or } L \otimes \text{diag}(\tau) +Values of this type should always be converted into :class:`backend.ndarray`\\ s using +the function :func:`probnum.backend.as_array` before further internal processing.""" -respectively. Currently, the array-support for diffusions is experimental (at best).
-""" +LinearOperatorLike = Union[ + ArrayLike, + scipy.sparse.spmatrix, + "probnum.linops.LinearOperator", +] +"""Type of a public API argument for supplying a finite-dimensional linear operator. + +Values of this type should always be converted into :class:`probnum.linops.\\ +LinearOperator`\\ s using the function :func:`probnum.backend.as_linop` before further +internal processing.""" -DenseOutputLocationArgType = Union[FloatArgType, np.ndarray] -"""TimeSeriesPosteriors and derived classes can be evaluated at a single location 't' -or an array of locations.""" +# Random Number Generation +SeedLike = Optional[int] +"""Type of a public API argument for supplying the seed of a random number generator. + +Values of this type should always be converted to :class:`SeedType` using the function +:func:`probnum.backend.random.seed` before further internal processing.""" + +######################################################################################## +# Other Types +######################################################################################## NotImplementedType = type(NotImplemented) diff --git a/src/probnum/utils/argutils.py b/src/probnum/utils/argutils.py index 25668536d8..24deaf9f8a 100644 --- a/src/probnum/utils/argutils.py +++ b/src/probnum/utils/argutils.py @@ -5,12 +5,12 @@ import numpy as np -from probnum.typing import DTypeArgType, ScalarArgType, ShapeArgType, ShapeType +from probnum.typing import DTypeArgType, ScalarLike, ShapeLike, ShapeType __all__ = ["as_shape", "as_numpy_scalar"] -def as_shape(x: ShapeArgType, ndim: Optional[numbers.Integral] = None) -> ShapeType: +def as_shape(x: ShapeLike, ndim: Optional[numbers.Integral] = None) -> ShapeType: """Convert a shape representation into a shape defined as a tuple of ints. Parameters @@ -42,7 +42,7 @@ def as_shape(x: ShapeArgType, ndim: Optional[numbers.Integral] = None) -> ShapeT return shape -def as_numpy_scalar(x: ScalarArgType, dtype: DTypeArgType = None) -> np.generic: +def as_numpy_scalar(x: ScalarLike, dtype: DTypeArgType = None) -> np.generic: """Convert a scalar into a NumPy scalar. 
Parameters diff --git a/tests/test_quad/util.py b/tests/test_quad/util.py index 3b9a7af6c5..51e14b0f66 100644 --- a/tests/test_quad/util.py +++ b/tests/test_quad/util.py @@ -5,15 +5,15 @@ from scipy.linalg import sqrtm from scipy.special import roots_legendre -from probnum.typing import FloatArgType, IntArgType +from probnum.typing import FloatLike, IntLike # Auxiliary functions for quadrature tests def gauss_hermite_tensor( - n_points: IntArgType, - input_dim: IntArgType, - mean: Union[np.ndarray, FloatArgType], - cov: Union[np.ndarray, FloatArgType], + n_points: IntLike, + input_dim: IntLike, + mean: Union[np.ndarray, FloatLike], + cov: Union[np.ndarray, FloatLike], ): """Returns the points and weights of a tensor-product Gauss-Hermite rule for integration w.r.t a Gaussian measure.""" @@ -31,9 +31,9 @@ def gauss_hermite_tensor( def gauss_legendre_tensor( - n_points: IntArgType, - input_dim: IntArgType, - domain: Tuple[Union[np.ndarray, FloatArgType], Union[np.ndarray, FloatArgType]], + n_points: IntLike, + input_dim: IntLike, + domain: Tuple[Union[np.ndarray, FloatLike], Union[np.ndarray, FloatLike]], normalized: Optional[bool] = False, ): """Returns the points and weights of a tensor-product Gauss-Legendre rule for diff --git a/tests/test_randvars/test_arithmetic/conftest.py b/tests/test_randvars/test_arithmetic/conftest.py index a6e3f84d8b..b5c2da1179 100644 --- a/tests/test_randvars/test_arithmetic/conftest.py +++ b/tests/test_randvars/test_arithmetic/conftest.py @@ -4,12 +4,12 @@ from probnum import backend, linops, randvars from probnum.problems.zoo.linalg import random_spd_matrix -from probnum.typing import ShapeArgType +from probnum.typing import ShapeLike from tests.testing import seed_from_args @pytest.fixture -def constant(shape_const: ShapeArgType) -> randvars.Constant: +def constant(shape_const: ShapeLike) -> randvars.Constant: seed = seed_from_args(shape_const, 19836) return randvars.Constant( @@ -19,7 +19,7 @@ def constant(shape_const: ShapeArgType) -> randvars.Constant: @pytest.fixture def multivariate_normal( - shape: ShapeArgType, precompute_cov_cholesky: bool + shape: ShapeLike, precompute_cov_cholesky: bool ) -> randvars.Normal: seed = seed_from_args(shape, precompute_cov_cholesky, 1908) seed_mean, seed_cov = backend.random.split(seed) @@ -35,7 +35,7 @@ def multivariate_normal( @pytest.fixture def matrixvariate_normal( - shape: ShapeArgType, precompute_cov_cholesky: bool + shape: ShapeLike, precompute_cov_cholesky: bool ) -> randvars.Normal: seed = seed_from_args(shape, precompute_cov_cholesky, 354) seed_mean, seed_cov_A, seed_cov_B = backend.random.split(seed, num=3) @@ -54,7 +54,7 @@ def matrixvariate_normal( @pytest.fixture def symmetric_matrixvariate_normal( - shape: ShapeArgType, precompute_cov_cholesky: bool + shape: ShapeLike, precompute_cov_cholesky: bool ) -> randvars.Normal: seed = seed_from_args(shape, precompute_cov_cholesky, 246) seed_mean, seed_cov = backend.random.split(seed) diff --git a/tests/test_randvars/test_arithmetic/test_generic.py b/tests/test_randvars/test_arithmetic/test_generic.py index 1e9c1b1fbd..b9c308492f 100644 --- a/tests/test_randvars/test_arithmetic/test_generic.py +++ b/tests/test_randvars/test_arithmetic/test_generic.py @@ -5,11 +5,11 @@ from numpy.typing import DTypeLike from probnum import randvars -from probnum.typing import ShapeArgType +from probnum.typing import ShapeLike @pytest.mark.parametrize("shape,dtype", [((5,), np.single), ((2, 3), np.double)]) -def test_generic_randvar_dtype_shape_inference(shape: ShapeArgType, dtype: 
DTypeLike): +def test_generic_randvar_dtype_shape_inference(shape: ShapeLike, dtype: DTypeLike): x = randvars.RandomVariable( shape=shape, dtype=dtype,