From c9957f3537a97d32ff1f70b59cbb3752d3f37a1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marvin=20Pf=C3=B6rtner?= Date: Thu, 1 Sep 2022 10:22:29 +0200 Subject: [PATCH] Algebraic Operations on `Functions` (#725) * `Function.__{neg,rmul,add,sub}__` * Switch function arithmetic to singledispatch * `arithmetic` -> `algebra` * Additional tests * Add input checks and error messages * Add documentation * Improvements to the tests * Reorder algebra fallbacks * Module docstrings * `pn._function` -> `pn.functions` * `black`` fix * Bugfixes * Reconfigure coverage report * Docfixes * Additional tests * Overload algebraic ops for the Zero function * black fix * Apply suggestions from code review Co-authored-by: Jonathan Wenger * Update docs/source/api.rst Co-authored-by: Jonathan Wenger * Bugfix in test * Add `Function.__mul__` operatot * PyLint fix Co-authored-by: Jonathan Wenger --- docs/source/api.rst | 3 + docs/source/api/functions.rst | 7 + docs/source/api/randprocs.rst | 1 - docs/source/api/randprocs/mean_fns.rst | 7 - pyproject.toml | 2 + src/probnum/__init__.py | 6 +- src/probnum/functions/__init__.py | 6 + src/probnum/functions/_algebra.py | 69 +++++++++ src/probnum/functions/_algebra_fallbacks.py | 141 ++++++++++++++++++ src/probnum/{ => functions}/_function.py | 44 +++++- .../mean_fns.py => functions/_zero.py} | 16 +- src/probnum/randprocs/__init__.py | 2 +- src/probnum/randprocs/_gaussian_process.py | 8 +- src/probnum/randprocs/_random_process.py | 8 +- src/probnum/randprocs/markov/_markov.py | 4 +- tests/test_functions/__init__.py | 0 tests/test_functions/conftest.py | 6 + tests/test_functions/test_algebra.py | 118 +++++++++++++++ .../test_functions/test_algebra_fallbacks.py | 86 +++++++++++ tests/{ => test_functions}/test_function.py | 2 +- tests/test_randprocs/conftest.py | 12 +- tests/test_randprocs/test_gaussian_process.py | 14 +- tests/test_randprocs/test_random_process.py | 6 +- 23 files changed, 520 insertions(+), 48 deletions(-) create mode 100644 docs/source/api/functions.rst delete mode 100644 docs/source/api/randprocs/mean_fns.rst create mode 100644 src/probnum/functions/__init__.py create mode 100644 src/probnum/functions/_algebra.py create mode 100644 src/probnum/functions/_algebra_fallbacks.py rename src/probnum/{ => functions}/_function.py (76%) rename src/probnum/{randprocs/mean_fns.py => functions/_zero.py} (53%) create mode 100644 tests/test_functions/__init__.py create mode 100644 tests/test_functions/conftest.py create mode 100644 tests/test_functions/test_algebra.py create mode 100644 tests/test_functions/test_algebra_fallbacks.py rename tests/{ => test_functions}/test_function.py (86%) diff --git a/docs/source/api.rst b/docs/source/api.rst index f634b0dbf..0bfbf8d68 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -13,6 +13,8 @@ API Reference +-------------------------------------------------+--------------------------------------------------------------+ | :mod:`~probnum.filtsmooth` | Bayesian filtering and smoothing. | +-------------------------------------------------+--------------------------------------------------------------+ + | :mod:`~probnum.functions` | Callables with in- and output shape information. | + +-------------------------------------------------+--------------------------------------------------------------+ | :mod:`~probnum.linalg` | Probabilistic numerical linear algebra. | +-------------------------------------------------+--------------------------------------------------------------+ | :mod:`~probnum.linops` | Finite-dimensional linear operators. | @@ -39,6 +41,7 @@ API Reference api/config api/diffeq api/filtsmooth + api/functions api/linalg api/linops api/problems diff --git a/docs/source/api/functions.rst b/docs/source/api/functions.rst new file mode 100644 index 000000000..cc32b7dbd --- /dev/null +++ b/docs/source/api/functions.rst @@ -0,0 +1,7 @@ +***************** +probnum.functions +***************** + +.. automodapi:: probnum.functions + :no-heading: + :headings: "=" diff --git a/docs/source/api/randprocs.rst b/docs/source/api/randprocs.rst index 3d8cb9f2a..bb3bc7c2e 100644 --- a/docs/source/api/randprocs.rst +++ b/docs/source/api/randprocs.rst @@ -10,5 +10,4 @@ probnum.randprocs :hidden: randprocs/markov - randprocs/mean_fns randprocs/kernels diff --git a/docs/source/api/randprocs/mean_fns.rst b/docs/source/api/randprocs/mean_fns.rst deleted file mode 100644 index ba329c6a2..000000000 --- a/docs/source/api/randprocs/mean_fns.rst +++ /dev/null @@ -1,7 +0,0 @@ -************************** -probnum.randprocs.mean_fns -************************** - -.. automodapi:: probnum.randprocs.mean_fns - :no-heading: - :headings: "=" diff --git a/pyproject.toml b/pyproject.toml index cc3db6bb2..3b05cbd90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,6 +145,8 @@ exclude_lines = [ # Don't complain if non-runnable code isn't run: 'if 0:', 'if __name__ == .__main__.:', + # Don't complain if operator's are not overloaded + 'return NotImplemented' ] ################################################################################ diff --git a/src/probnum/__init__.py b/src/probnum/__init__.py index 5fc963d9e..64bd49d24 100644 --- a/src/probnum/__init__.py +++ b/src/probnum/__init__.py @@ -26,6 +26,7 @@ from . import ( diffeq, filtsmooth, + functions, linalg, linops, problems, @@ -34,23 +35,18 @@ randvars, utils, ) -from ._function import Function, LambdaFunction from ._version import version as __version__ from .randvars import asrandvar # Public classes and functions. Order is reflected in documentation. __all__ = [ "asrandvar", - "Function", - "LambdaFunction", "ProbabilisticNumericalMethod", "StoppingCriterion", "LambdaStoppingCriterion", ] # Set correct module paths. Corrects links and module paths in documentation. -Function.__module__ = "probnum" -LambdaFunction.__module__ = "probnum" ProbabilisticNumericalMethod.__module__ = "probnum" StoppingCriterion.__module__ = "probnum" LambdaStoppingCriterion.__module__ = "probnum" diff --git a/src/probnum/functions/__init__.py b/src/probnum/functions/__init__.py new file mode 100644 index 000000000..8bb8aec6d --- /dev/null +++ b/src/probnum/functions/__init__.py @@ -0,0 +1,6 @@ +"""Callables with in- and output shape information supporting algebraic operations.""" + +from . import _algebra +from ._algebra_fallbacks import ScaledFunction, SumFunction +from ._function import Function, LambdaFunction +from ._zero import Zero diff --git a/src/probnum/functions/_algebra.py b/src/probnum/functions/_algebra.py new file mode 100644 index 000000000..4b757e802 --- /dev/null +++ b/src/probnum/functions/_algebra.py @@ -0,0 +1,69 @@ +r"""Algebraic operations on :class:`Function`\ s.""" + +from ._algebra_fallbacks import SumFunction +from ._function import Function +from ._zero import Zero + +############ +# Function # +############ + + +@Function.__add__.register # pylint: disable=no-member +def _(self, other: Function) -> SumFunction: + return SumFunction(self, other) + + +@Function.__add__.register # pylint: disable=no-member +def _(self, other: SumFunction) -> SumFunction: + return SumFunction(self, *other.summands) + + +@Function.__add__.register # pylint: disable=no-member +def _(self, other: Zero) -> Function: # pylint: disable=unused-argument + return self + + +@Function.__sub__.register # pylint: disable=no-member +def _(self, other: Function) -> SumFunction: + return SumFunction(self, -other) + + +@Function.__sub__.register # pylint: disable=no-member +def _(self, other: Zero) -> Function: # pylint: disable=unused-argument + return self + + +############### +# SumFunction # +############### + + +@SumFunction.__add__.register # pylint: disable=no-member +def _(self, other: Function) -> SumFunction: + return SumFunction(*self.summands, other) + + +@SumFunction.__add__.register # pylint: disable=no-member +def _(self, other: SumFunction) -> SumFunction: + return SumFunction(*self.summands, *other.summands) + + +@SumFunction.__sub__.register # pylint: disable=no-member +def _(self, other: Function) -> SumFunction: + return SumFunction(*self.summands, -other) + + +######## +# Zero # +######## + + +@Zero.__add__.register # pylint: disable=no-member +def _(self, other: Function) -> Function: # pylint: disable=unused-argument + return other + + +@Zero.__sub__.register # pylint: disable=no-member +def _(self, other: Function) -> Function: # pylint: disable=unused-argument + return -other diff --git a/src/probnum/functions/_algebra_fallbacks.py b/src/probnum/functions/_algebra_fallbacks.py new file mode 100644 index 000000000..d511cdef5 --- /dev/null +++ b/src/probnum/functions/_algebra_fallbacks.py @@ -0,0 +1,141 @@ +r"""Fallback implementation for algebraic operations on :class:`Function`\ s.""" + +from __future__ import annotations + +import functools +import operator + +import numpy as np + +from probnum import utils +from probnum.typing import ScalarLike, ScalarType + +from ._function import Function + + +class SumFunction(Function): + r"""Pointwise sum of :class:`Function`\ s. + + Given functions :math:`f_1, \dotsc, f_n \colon \mathbb{R}^n \to \mathbb{R}^m`, this + defines a new function + + .. math:: + \sum_{i = 1}^n f_i \colon \mathbb{R}^n \to \mathbb{R}^m, + x \mapsto \sum_{i = 1}^n f_i(x). + + Parameters + ---------- + *summands + The functions :math:`f_1, \dotsc, f_n`. + """ + + def __init__(self, *summands: Function) -> None: + if not all(isinstance(summand, Function) for summand in summands): + raise TypeError( + "The functions to be added must be objects of type `Function`." + ) + + if not all( + summand.input_shape == summands[0].input_shape for summand in summands + ): + raise ValueError( + "The functions to be added must all have the same input shape." + ) + + if not all( + summand.output_shape == summands[0].output_shape for summand in summands + ): + raise ValueError( + "The functions to be added must all have the same output shape." + ) + + self._summands = summands + + super().__init__( + input_shape=summands[0].input_shape, + output_shape=summands[0].output_shape, + ) + + @property + def summands(self) -> tuple[SumFunction, ...]: + r"""The functions :math:`f_1, \dotsc, f_n` to be added.""" + return self._summands + + def _evaluate(self, x: np.ndarray) -> np.ndarray: + return functools.reduce( + operator.add, (summand(x) for summand in self._summands) + ) + + @functools.singledispatchmethod + def __add__(self, other): + return super().__add__(other) + + @functools.singledispatchmethod + def __sub__(self, other): + return super().__sub__(other) + + +class ScaledFunction(Function): + r"""Function multiplied pointwise with a scalar. + + Given a function :math:`f \colon \mathbb{R}^n \to \mathbb{R}^m` and a scalar + :math:`\alpha \in \mathbb{R}`, this defines a new function + + .. math:: + \alpha f \colon \mathbb{R}^n \to \mathbb{R}^m, + x \mapsto (\alpha f)(x) = \alpha f(x). + + Parameters + ---------- + function + The function :math:`f`. + scalar + The scalar :math:`\alpha`. + """ + + def __init__(self, function: Function, scalar: ScalarLike): + if not isinstance(function, Function): + raise TypeError( + "The function to be scaled must be an object of type `Function`." + ) + + self._function = function + self._scalar = utils.as_numpy_scalar(scalar) + + super().__init__( + input_shape=self._function.input_shape, + output_shape=self._function.output_shape, + ) + + @property + def function(self) -> Function: + r"""The function :math:`f`.""" + return self._function + + @property + def scalar(self) -> ScalarType: + r"""The scalar :math:`\alpha`.""" + return self._scalar + + def _evaluate(self, x: np.ndarray) -> np.ndarray: + return self._scalar * self._function(x) + + @functools.singledispatchmethod + def __mul__(self, other): + if np.ndim(other) == 0: + return ScaledFunction( + function=self._function, + scalar=self._scalar * np.asarray(other), + ) + + return super().__mul__(other) + + @functools.singledispatchmethod + def __rmul__(self, other): + if np.ndim(other) == 0: + return ScaledFunction( + function=self._function, + scalar=np.asarray(other) * self._scalar, + ) + + return super().__rmul__(other) diff --git a/src/probnum/_function.py b/src/probnum/functions/_function.py similarity index 76% rename from src/probnum/_function.py rename to src/probnum/functions/_function.py index 051f2e2e1..12573c6ec 100644 --- a/src/probnum/_function.py +++ b/src/probnum/functions/_function.py @@ -3,12 +3,13 @@ from __future__ import annotations import abc +import functools from typing import Callable import numpy as np -from . import utils -from .typing import ArrayLike, ShapeLike, ShapeType +from probnum import utils +from probnum.typing import ArrayLike, ShapeLike, ShapeType class Function(abc.ABC): @@ -17,6 +18,8 @@ class Function(abc.ABC): This class represents a, uni- or multivariate, scalar- or tensor-valued, mathematical function. Hence, the call method should not have any observable side-effects. + Instances of this class can be added and multiplied by a scalar, which means that + they are elements of a vector space. Parameters ---------- @@ -29,7 +32,7 @@ class Function(abc.ABC): See Also -------- LambdaFunction : Define a :class:`Function` from an anonymous function. - ~probnum.randprocs.mean_fns.Zero : Zero mean function of a random process. + ~probnum.functions.Zero : Zero function. """ def __init__(self, input_shape: ShapeLike, output_shape: ShapeLike = ()) -> None: @@ -112,6 +115,39 @@ def __call__(self, x: ArrayLike) -> np.ndarray: def _evaluate(self, x: np.ndarray) -> np.ndarray: pass + def __neg__(self): + return -1.0 * self + + @functools.singledispatchmethod + def __add__(self, other): + return NotImplemented + + @functools.singledispatchmethod + def __sub__(self, other): + return NotImplemented + + @functools.singledispatchmethod + def __mul__(self, other): + if np.ndim(other) == 0: + from ._algebra_fallbacks import ( # pylint: disable=import-outside-toplevel + ScaledFunction, + ) + + return ScaledFunction(function=self, scalar=other) + + return NotImplemented + + @functools.singledispatchmethod + def __rmul__(self, other): + if np.ndim(other) == 0: + from ._algebra_fallbacks import ( # pylint: disable=import-outside-toplevel + ScaledFunction, + ) + + return ScaledFunction(function=self, scalar=other) + + return NotImplemented + class LambdaFunction(Function): """Define a :class:`Function` from a given :class:`callable`. @@ -131,7 +167,7 @@ class LambdaFunction(Function): Examples -------- >>> import numpy as np - >>> from probnum import LambdaFunction + >>> from probnum.functions import LambdaFunction >>> fn = LambdaFunction(fn=lambda x: 2 * x + 1, input_shape=(2,), output_shape=(2,)) >>> fn(np.array([[1, 2], [4, 5]])) array([[ 3, 5], diff --git a/src/probnum/randprocs/mean_fns.py b/src/probnum/functions/_zero.py similarity index 53% rename from src/probnum/randprocs/mean_fns.py rename to src/probnum/functions/_zero.py index f1ec572ea..2c6ae01ea 100644 --- a/src/probnum/randprocs/mean_fns.py +++ b/src/probnum/functions/_zero.py @@ -1,10 +1,10 @@ -"""Mean functions of random processes.""" +"""The zero function.""" -import numpy as np +import functools -from .. import _function +import numpy as np -__all__ = ["Zero"] +from . import _function class Zero(_function.Function): @@ -15,3 +15,11 @@ def _evaluate(self, x: np.ndarray) -> np.ndarray: x, shape=x.shape[: x.ndim - self._input_ndim] + self._output_shape, ) + + @functools.singledispatchmethod + def __add__(self, other): + return super().__add__(other) + + @functools.singledispatchmethod + def __sub__(self, other): + return super().__sub__(other) diff --git a/src/probnum/randprocs/__init__.py b/src/probnum/randprocs/__init__.py index c5a2ace73..99c66126a 100644 --- a/src/probnum/randprocs/__init__.py +++ b/src/probnum/randprocs/__init__.py @@ -6,7 +6,7 @@ functions with stochastic output. """ -from . import kernels, mean_fns +from . import kernels from ._gaussian_process import GaussianProcess from ._random_process import RandomProcess diff --git a/src/probnum/randprocs/_gaussian_process.py b/src/probnum/randprocs/_gaussian_process.py index 866114c28..4844e2cb3 100644 --- a/src/probnum/randprocs/_gaussian_process.py +++ b/src/probnum/randprocs/_gaussian_process.py @@ -8,7 +8,7 @@ from probnum.typing import ArrayLike from . import _random_process, kernels -from .. import _function +from .. import functions class GaussianProcess(_random_process.RandomProcess[ArrayLike, np.ndarray]): @@ -35,7 +35,7 @@ class GaussianProcess(_random_process.RandomProcess[ArrayLike, np.ndarray]): Define a Gaussian process with a zero mean function and RBF kernel. >>> import numpy as np - >>> from probnum.randprocs.mean_fns import Zero + >>> from probnum.functions import Zero >>> from probnum.randprocs.kernels import ExpQuad >>> from probnum.randprocs import GaussianProcess >>> mu = Zero(input_shape=()) # zero-mean function @@ -58,10 +58,10 @@ class GaussianProcess(_random_process.RandomProcess[ArrayLike, np.ndarray]): def __init__( self, - mean: _function.Function, + mean: functions.Function, cov: kernels.Kernel, ): - if not isinstance(mean, _function.Function): + if not isinstance(mean, functions.Function): raise TypeError("The mean function must have type `probnum.Function`.") super().__init__( diff --git a/src/probnum/randprocs/_random_process.py b/src/probnum/randprocs/_random_process.py index c87d18f04..559c25e00 100644 --- a/src/probnum/randprocs/_random_process.py +++ b/src/probnum/randprocs/_random_process.py @@ -7,7 +7,7 @@ import numpy as np -from probnum import _function, randvars, utils as _utils +from probnum import functions, randvars, utils as _utils from probnum.randprocs import kernels from probnum.typing import DTypeLike, ShapeLike, ShapeType @@ -56,7 +56,7 @@ def __init__( input_shape: ShapeLike, output_shape: ShapeLike, dtype: DTypeLike, - mean: Optional[_function.Function] = None, + mean: Optional[functions.Function] = None, cov: Optional[kernels.Kernel] = None, ): self._input_shape = _utils.as_shape(input_shape) @@ -75,7 +75,7 @@ def __init__( # Mean function if mean is not None: - if not isinstance(mean, _function.Function): + if not isinstance(mean, functions.Function): raise TypeError("The mean function must have type `probnum.Function`.") if mean.input_shape != self._input_shape: @@ -177,7 +177,7 @@ def marginal(self, args: InputType) -> randvars._RandomVariableList: raise NotImplementedError @property - def mean(self) -> _function.Function: + def mean(self) -> functions.Function: r"""Mean function :math:`m(x) := \mathbb{E}[f(x)]` of the random process.""" if self._mean is None: raise NotImplementedError diff --git a/src/probnum/randprocs/markov/_markov.py b/src/probnum/randprocs/markov/_markov.py index f0320e0f5..bb0d89dbf 100644 --- a/src/probnum/randprocs/markov/_markov.py +++ b/src/probnum/randprocs/markov/_markov.py @@ -5,7 +5,7 @@ import numpy as np import scipy.stats -from probnum import _function, randvars, utils +from probnum import functions, randvars, utils from probnum.randprocs import _random_process, kernels from probnum.randprocs.markov import _transition, continuous, discrete from probnum.typing import ShapeLike @@ -31,7 +31,7 @@ def __init__( input_shape=input_shape, output_shape=output_shape, dtype=np.dtype(np.float_), - mean=_function.LambdaFunction( + mean=functions.LambdaFunction( lambda x: self.__call__(args=x).mean, input_shape=input_shape, output_shape=output_shape, diff --git a/tests/test_functions/__init__.py b/tests/test_functions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_functions/conftest.py b/tests/test_functions/conftest.py new file mode 100644 index 000000000..ff86e4bf2 --- /dev/null +++ b/tests/test_functions/conftest.py @@ -0,0 +1,6 @@ +import pytest + + +@pytest.fixture(scope="module") +def seed() -> int: + return 234 diff --git a/tests/test_functions/test_algebra.py b/tests/test_functions/test_algebra.py new file mode 100644 index 000000000..873d8c646 --- /dev/null +++ b/tests/test_functions/test_algebra.py @@ -0,0 +1,118 @@ +import numpy as np +import pytest +from pytest_cases import param_fixture, param_fixtures + +from probnum import functions +from probnum.typing import ScalarLike, ShapeType + +lambda_fn_0 = functions.LambdaFunction( + lambda xs: ( + np.sin( + np.linspace(0.5, 2.0, 6).reshape((3, 2)) + * np.sum(xs**2, axis=-1)[..., None, None] + ) + ), + input_shape=(2,), + output_shape=(3, 2), +) + +lambda_fn_1 = functions.LambdaFunction( + lambda xs: ( + np.linspace(0.5, 2.0, 6).reshape((3, 2)) + * np.exp(-0.5 * np.sum(xs**2, axis=-1))[..., None, None] + ), + input_shape=(2,), + output_shape=(3, 2), +) + +op0, op1 = param_fixtures( + "op0, op1", + ( + pytest.param( + lambda_fn_0, + lambda_fn_1, + id="LambdaFunction-LambdaFunction", + ), + pytest.param( + lambda_fn_0, + functions.Zero(lambda_fn_0.input_shape, lambda_fn_1.output_shape), + id="LambdaFunction-Zero", + ), + pytest.param( + functions.Zero(lambda_fn_0.input_shape, lambda_fn_1.output_shape), + lambda_fn_0, + id="Zero-LambdaFunction", + ), + pytest.param( + functions.Zero((3, 3), ()), + functions.Zero((3, 3), ()), + id="Zero-Zero", + ), + ), +) + +batch_shape = param_fixture("batch_shape", ((), (3,), (2, 1, 2))) + + +def test_add_evaluation( + op0: functions.Function, op1: functions.Function, batch_shape: ShapeType, seed: int +): + fn_add = op0 + op1 + + rng = np.random.default_rng(seed) + xs = rng.uniform(-1.0, 1.0, batch_shape + op0.input_shape) + + np.testing.assert_array_equal( + fn_add(xs), + op0(xs) + op1(xs), + ) + + +def test_sub_evaluation( + op0: functions.Function, op1: functions.Function, batch_shape: ShapeType, seed: int +): + fn_sub = op0 - op1 + + rng = np.random.default_rng(seed) + xs = rng.uniform(-1.0, 1.0, batch_shape + op0.input_shape) + + np.testing.assert_array_equal( + fn_sub(xs), + op0(xs) - op1(xs), + ) + + +@pytest.mark.parametrize("scalar", [1.0, 3, 1000.0]) +def test_mul_scalar_evaluation( + op0: functions.Function, + scalar: ScalarLike, + batch_shape: ShapeType, + seed: int, +): + fn_scaled = op0 * scalar + + rng = np.random.default_rng(seed) + xs = rng.uniform(-1.0, 1.0, batch_shape + op0.input_shape) + + np.testing.assert_array_equal( + fn_scaled(xs), + op0(xs) * scalar, + ) + + +@pytest.mark.parametrize("scalar", [1.0, 3, 1000.0]) +def test_rmul_scalar_evaluation( + op0: functions.Function, + scalar: ScalarLike, + batch_shape: ShapeType, + seed: int, +): + fn_scaled = scalar * op0 + + rng = np.random.default_rng(seed) + xs = rng.uniform(-1.0, 1.0, batch_shape + op0.input_shape) + + np.testing.assert_array_equal( + fn_scaled(xs), + scalar * op0(xs), + ) diff --git a/tests/test_functions/test_algebra_fallbacks.py b/tests/test_functions/test_algebra_fallbacks.py new file mode 100644 index 000000000..b3bc3f51b --- /dev/null +++ b/tests/test_functions/test_algebra_fallbacks.py @@ -0,0 +1,86 @@ +import numpy as np +import pytest + +from probnum import functions + + +@pytest.fixture(scope="module") +def fn0() -> functions.LambdaFunction: + return functions.LambdaFunction( + lambda xs: ((np.linspace(0.5, 2.0, 6).reshape(3, 2) @ xs[..., None])[..., 0]), + input_shape=2, + output_shape=3, + ) + + +@pytest.fixture(scope="module") +def fn1() -> functions.LambdaFunction: + return functions.LambdaFunction( + lambda xs: np.zeros(shape=xs.shape[:-1]), + input_shape=2, + output_shape=3, + ) + + +def test_scaling_lambda_raises_error(): + with pytest.raises(TypeError): + functions.ScaledFunction(lambda x: 2.0 * x, scalar=2.0) + + +def test_sum_lambda_raises_error(fn1: functions.Function): + with pytest.raises(TypeError): + functions.SumFunction(lambda x: 2.0 * x, fn1) + + +def test_sum_function_contracts(fn0: functions.Function, fn1: functions.Function): + sum_fn = (fn0 + (fn1 + fn0)) - fn1 + fn0 + (fn0 + fn1) + + assert isinstance(sum_fn, functions.SumFunction) + assert len(sum_fn.summands) == 7 + assert sum_fn.summands[0] is fn0 + assert sum_fn.summands[1] is fn1 + assert sum_fn.summands[2] is fn0 + assert ( + isinstance(sum_fn.summands[3], functions.ScaledFunction) + and sum_fn.summands[3].function is fn1 + and sum_fn.summands[3].scalar == -1 + ) + assert sum_fn.summands[4] is fn0 + assert sum_fn.summands[5] is fn0 + assert sum_fn.summands[6] is fn1 + + +def test_sum_function_input_shape_mismatch_raises_error(fn0: functions.Function): + fn_err = functions.LambdaFunction( + lambda x: np.zeros(fn0.output_shape), + input_shape=(), + output_shape=fn0.output_shape, + ) + + with pytest.raises(ValueError): + fn0 + fn_err # pylint: disable=pointless-statement + + +def test_sum_function_output_shape_mismatch_raises_error(fn0: functions.Function): + fn_err = functions.LambdaFunction( + lambda x: np.zeros(()), + input_shape=fn0.input_shape, + output_shape=(), + ) + + with pytest.raises(ValueError): + fn0 + fn_err # pylint: disable=pointless-statement + + +def test_scaled_function_contracts(fn0: functions.Function): + scaled_fn_mul = -fn0 * 2.0 + + assert isinstance(scaled_fn_mul, functions.ScaledFunction) + assert scaled_fn_mul.function is fn0 + assert scaled_fn_mul.scalar == -2.0 + + scaled_fn_rmul = 2.0 * -fn0 + + assert isinstance(scaled_fn_rmul, functions.ScaledFunction) + assert scaled_fn_rmul.function is fn0 + assert scaled_fn_rmul.scalar == -2.0 diff --git a/tests/test_function.py b/tests/test_functions/test_function.py similarity index 86% rename from tests/test_function.py rename to tests/test_functions/test_function.py index ee39878b1..e0242ace1 100644 --- a/tests/test_function.py +++ b/tests/test_functions/test_function.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from probnum import LambdaFunction +from probnum.functions import LambdaFunction def test_input_shape_mismatch_raises_error(): diff --git a/tests/test_randprocs/conftest.py b/tests/test_randprocs/conftest.py index f38c0b62f..a0f56efa0 100644 --- a/tests/test_randprocs/conftest.py +++ b/tests/test_randprocs/conftest.py @@ -6,8 +6,8 @@ import numpy as np import pytest -from probnum import LambdaFunction, randprocs -from probnum.randprocs import kernels, mean_fns +from probnum import functions, randprocs +from probnum.randprocs import kernels @pytest.fixture( @@ -44,10 +44,12 @@ def output_dim(request) -> int: params=[ pytest.param(mu, id=mu[0]) for mu in [ - ("zero", mean_fns.Zero), + ("zero", functions.Zero), ( "lin", - functools.partial(LambdaFunction, lambda x: 2 * x.sum(axis=1) + 1.0), + functools.partial( + functions.LambdaFunction, lambda x: 2 * x.sum(axis=1) + 1.0 + ), ), ] ], @@ -82,7 +84,7 @@ def fixture_cov(request, input_dim: int) -> kernels.Kernel: ( "gp", randprocs.GaussianProcess( - mean=mean_fns.Zero(input_shape=(1,)), + mean=functions.Zero(input_shape=(1,)), cov=kernels.Matern(input_shape=(1,)), ), ), diff --git a/tests/test_randprocs/test_gaussian_process.py b/tests/test_randprocs/test_gaussian_process.py index 61e8723ff..885457047 100644 --- a/tests/test_randprocs/test_gaussian_process.py +++ b/tests/test_randprocs/test_gaussian_process.py @@ -3,8 +3,8 @@ import numpy as np import pytest -from probnum import randprocs, randvars -from probnum.randprocs import kernels, mean_fns +from probnum import functions, randprocs, randvars +from probnum.randprocs import kernels def test_mean_not_function_raises_error(): @@ -20,20 +20,20 @@ def test_cov_not_kernel_raises_error(): TypeError.""" with pytest.raises(TypeError): randprocs.GaussianProcess( - mean=mean_fns.Zero(input_shape=(1,), output_shape=(1,)), cov=np.dot + mean=functions.Zero(input_shape=(1,), output_shape=(1,)), cov=np.dot ) def test_mean_kernel_shape_mismatch_raises_error(): with pytest.raises(ValueError): randprocs.GaussianProcess( - mean=mean_fns.Zero(input_shape=(2,), output_shape=(1,)), + mean=functions.Zero(input_shape=(2,), output_shape=(1,)), cov=kernels.ExpQuad(input_shape=(3,)), ) with pytest.raises(ValueError): randprocs.GaussianProcess( - mean=mean_fns.Zero(input_shape=(2,), output_shape=(2,)), + mean=functions.Zero(input_shape=(2,), output_shape=(2,)), cov=kernels.ExpQuad(input_shape=(2,)), ) @@ -41,13 +41,13 @@ def test_mean_kernel_shape_mismatch_raises_error(): def test_mean_wrong_input_shape_raises_error(): with pytest.raises(ValueError): randprocs.GaussianProcess( - mean=mean_fns.Zero(input_shape=(2, 2), output_shape=(1,)), + mean=functions.Zero(input_shape=(2, 2), output_shape=(1,)), cov=kernels.ExpQuad(input_shape=(2,)), ) with pytest.raises(ValueError): randprocs.GaussianProcess( - mean=mean_fns.Zero(input_shape=(2,), output_shape=(2, 1)), + mean=functions.Zero(input_shape=(2,), output_shape=(2, 1)), cov=kernels.ExpQuad(input_shape=(2,)), ) diff --git a/tests/test_randprocs/test_random_process.py b/tests/test_randprocs/test_random_process.py index d502cac87..f7d76a2d2 100644 --- a/tests/test_randprocs/test_random_process.py +++ b/tests/test_randprocs/test_random_process.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from probnum import randprocs, randvars +from probnum import functions, randprocs, randvars # pylint: disable=invalid-name @@ -129,7 +129,7 @@ def test_inconsistent_mean_shape_errors(): input_shape=(42,), output_shape=(), dtype=np.double, - mean=randprocs.mean_fns.Zero( + mean=functions.Zero( input_shape=(3,), output_shape=(3,), ), @@ -140,7 +140,7 @@ def test_inconsistent_mean_shape_errors(): input_shape=(), output_shape=(1,), dtype=np.double, - mean=randprocs.mean_fns.Zero( + mean=functions.Zero( input_shape=(), output_shape=(3,), ),