Skip to content

Commit

Permalink
Algebraic Operations on Functions (#725)
Browse files Browse the repository at this point in the history
* `Function.__{neg,rmul,add,sub}__`

* Switch function arithmetic to singledispatch

* `arithmetic` -> `algebra`

* Additional tests

* Add input checks and error messages

* Add documentation

* Improvements to the tests

* Reorder algebra fallbacks

* Module docstrings

* `pn._function` -> `pn.functions`

* `black`` fix

* Bugfixes

* Reconfigure coverage report

* Docfixes

* Additional tests

* Overload algebraic ops for the Zero function

* black fix

* Apply suggestions from code review

Co-authored-by: Jonathan Wenger <[email protected]>

* Update docs/source/api.rst

Co-authored-by: Jonathan Wenger <[email protected]>

* Bugfix in test

* Add `Function.__mul__` operatot

* PyLint fix

Co-authored-by: Jonathan Wenger <[email protected]>
  • Loading branch information
marvinpfoertner and JonathanWenger authored Sep 1, 2022
1 parent 798c23a commit c9957f3
Show file tree
Hide file tree
Showing 23 changed files with 520 additions and 48 deletions.
3 changes: 3 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ API Reference
+-------------------------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.filtsmooth` | Bayesian filtering and smoothing. |
+-------------------------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.functions` | Callables with in- and output shape information. |
+-------------------------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.linalg` | Probabilistic numerical linear algebra. |
+-------------------------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.linops` | Finite-dimensional linear operators. |
Expand All @@ -39,6 +41,7 @@ API Reference
api/config
api/diffeq
api/filtsmooth
api/functions
api/linalg
api/linops
api/problems
Expand Down
7 changes: 7 additions & 0 deletions docs/source/api/functions.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
*****************
probnum.functions
*****************

.. automodapi:: probnum.functions
:no-heading:
:headings: "="
1 change: 0 additions & 1 deletion docs/source/api/randprocs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,4 @@ probnum.randprocs
:hidden:

randprocs/markov
randprocs/mean_fns
randprocs/kernels
7 changes: 0 additions & 7 deletions docs/source/api/randprocs/mean_fns.rst

This file was deleted.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ exclude_lines = [
# Don't complain if non-runnable code isn't run:
'if 0:',
'if __name__ == .__main__.:',
# Don't complain if operator's are not overloaded
'return NotImplemented'
]

################################################################################
Expand Down
6 changes: 1 addition & 5 deletions src/probnum/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from . import (
diffeq,
filtsmooth,
functions,
linalg,
linops,
problems,
Expand All @@ -34,23 +35,18 @@
randvars,
utils,
)
from ._function import Function, LambdaFunction
from ._version import version as __version__
from .randvars import asrandvar

# Public classes and functions. Order is reflected in documentation.
__all__ = [
"asrandvar",
"Function",
"LambdaFunction",
"ProbabilisticNumericalMethod",
"StoppingCriterion",
"LambdaStoppingCriterion",
]

# Set correct module paths. Corrects links and module paths in documentation.
Function.__module__ = "probnum"
LambdaFunction.__module__ = "probnum"
ProbabilisticNumericalMethod.__module__ = "probnum"
StoppingCriterion.__module__ = "probnum"
LambdaStoppingCriterion.__module__ = "probnum"
6 changes: 6 additions & 0 deletions src/probnum/functions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Callables with in- and output shape information supporting algebraic operations."""

from . import _algebra
from ._algebra_fallbacks import ScaledFunction, SumFunction
from ._function import Function, LambdaFunction
from ._zero import Zero
69 changes: 69 additions & 0 deletions src/probnum/functions/_algebra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
r"""Algebraic operations on :class:`Function`\ s."""

from ._algebra_fallbacks import SumFunction
from ._function import Function
from ._zero import Zero

############
# Function #
############


@Function.__add__.register # pylint: disable=no-member
def _(self, other: Function) -> SumFunction:
return SumFunction(self, other)


@Function.__add__.register # pylint: disable=no-member
def _(self, other: SumFunction) -> SumFunction:
return SumFunction(self, *other.summands)


@Function.__add__.register # pylint: disable=no-member
def _(self, other: Zero) -> Function: # pylint: disable=unused-argument
return self


@Function.__sub__.register # pylint: disable=no-member
def _(self, other: Function) -> SumFunction:
return SumFunction(self, -other)


@Function.__sub__.register # pylint: disable=no-member
def _(self, other: Zero) -> Function: # pylint: disable=unused-argument
return self


###############
# SumFunction #
###############


@SumFunction.__add__.register # pylint: disable=no-member
def _(self, other: Function) -> SumFunction:
return SumFunction(*self.summands, other)


@SumFunction.__add__.register # pylint: disable=no-member
def _(self, other: SumFunction) -> SumFunction:
return SumFunction(*self.summands, *other.summands)


@SumFunction.__sub__.register # pylint: disable=no-member
def _(self, other: Function) -> SumFunction:
return SumFunction(*self.summands, -other)


########
# Zero #
########


@Zero.__add__.register # pylint: disable=no-member
def _(self, other: Function) -> Function: # pylint: disable=unused-argument
return other


@Zero.__sub__.register # pylint: disable=no-member
def _(self, other: Function) -> Function: # pylint: disable=unused-argument
return -other
141 changes: 141 additions & 0 deletions src/probnum/functions/_algebra_fallbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
r"""Fallback implementation for algebraic operations on :class:`Function`\ s."""

from __future__ import annotations

import functools
import operator

import numpy as np

from probnum import utils
from probnum.typing import ScalarLike, ScalarType

from ._function import Function


class SumFunction(Function):
r"""Pointwise sum of :class:`Function`\ s.
Given functions :math:`f_1, \dotsc, f_n \colon \mathbb{R}^n \to \mathbb{R}^m`, this
defines a new function
.. math::
\sum_{i = 1}^n f_i \colon \mathbb{R}^n \to \mathbb{R}^m,
x \mapsto \sum_{i = 1}^n f_i(x).
Parameters
----------
*summands
The functions :math:`f_1, \dotsc, f_n`.
"""

def __init__(self, *summands: Function) -> None:
if not all(isinstance(summand, Function) for summand in summands):
raise TypeError(
"The functions to be added must be objects of type `Function`."
)

if not all(
summand.input_shape == summands[0].input_shape for summand in summands
):
raise ValueError(
"The functions to be added must all have the same input shape."
)

if not all(
summand.output_shape == summands[0].output_shape for summand in summands
):
raise ValueError(
"The functions to be added must all have the same output shape."
)

self._summands = summands

super().__init__(
input_shape=summands[0].input_shape,
output_shape=summands[0].output_shape,
)

@property
def summands(self) -> tuple[SumFunction, ...]:
r"""The functions :math:`f_1, \dotsc, f_n` to be added."""
return self._summands

def _evaluate(self, x: np.ndarray) -> np.ndarray:
return functools.reduce(
operator.add, (summand(x) for summand in self._summands)
)

@functools.singledispatchmethod
def __add__(self, other):
return super().__add__(other)

@functools.singledispatchmethod
def __sub__(self, other):
return super().__sub__(other)


class ScaledFunction(Function):
r"""Function multiplied pointwise with a scalar.
Given a function :math:`f \colon \mathbb{R}^n \to \mathbb{R}^m` and a scalar
:math:`\alpha \in \mathbb{R}`, this defines a new function
.. math::
\alpha f \colon \mathbb{R}^n \to \mathbb{R}^m,
x \mapsto (\alpha f)(x) = \alpha f(x).
Parameters
----------
function
The function :math:`f`.
scalar
The scalar :math:`\alpha`.
"""

def __init__(self, function: Function, scalar: ScalarLike):
if not isinstance(function, Function):
raise TypeError(
"The function to be scaled must be an object of type `Function`."
)

self._function = function
self._scalar = utils.as_numpy_scalar(scalar)

super().__init__(
input_shape=self._function.input_shape,
output_shape=self._function.output_shape,
)

@property
def function(self) -> Function:
r"""The function :math:`f`."""
return self._function

@property
def scalar(self) -> ScalarType:
r"""The scalar :math:`\alpha`."""
return self._scalar

def _evaluate(self, x: np.ndarray) -> np.ndarray:
return self._scalar * self._function(x)

@functools.singledispatchmethod
def __mul__(self, other):
if np.ndim(other) == 0:
return ScaledFunction(
function=self._function,
scalar=self._scalar * np.asarray(other),
)

return super().__mul__(other)

@functools.singledispatchmethod
def __rmul__(self, other):
if np.ndim(other) == 0:
return ScaledFunction(
function=self._function,
scalar=np.asarray(other) * self._scalar,
)

return super().__rmul__(other)
44 changes: 40 additions & 4 deletions src/probnum/_function.py → src/probnum/functions/_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from __future__ import annotations

import abc
import functools
from typing import Callable

import numpy as np

from . import utils
from .typing import ArrayLike, ShapeLike, ShapeType
from probnum import utils
from probnum.typing import ArrayLike, ShapeLike, ShapeType


class Function(abc.ABC):
Expand All @@ -17,6 +18,8 @@ class Function(abc.ABC):
This class represents a, uni- or multivariate, scalar- or tensor-valued,
mathematical function. Hence, the call method should not have any observable
side-effects.
Instances of this class can be added and multiplied by a scalar, which means that
they are elements of a vector space.
Parameters
----------
Expand All @@ -29,7 +32,7 @@ class Function(abc.ABC):
See Also
--------
LambdaFunction : Define a :class:`Function` from an anonymous function.
~probnum.randprocs.mean_fns.Zero : Zero mean function of a random process.
~probnum.functions.Zero : Zero function.
"""

def __init__(self, input_shape: ShapeLike, output_shape: ShapeLike = ()) -> None:
Expand Down Expand Up @@ -112,6 +115,39 @@ def __call__(self, x: ArrayLike) -> np.ndarray:
def _evaluate(self, x: np.ndarray) -> np.ndarray:
pass

def __neg__(self):
return -1.0 * self

@functools.singledispatchmethod
def __add__(self, other):
return NotImplemented

@functools.singledispatchmethod
def __sub__(self, other):
return NotImplemented

@functools.singledispatchmethod
def __mul__(self, other):
if np.ndim(other) == 0:
from ._algebra_fallbacks import ( # pylint: disable=import-outside-toplevel
ScaledFunction,
)

return ScaledFunction(function=self, scalar=other)

return NotImplemented

@functools.singledispatchmethod
def __rmul__(self, other):
if np.ndim(other) == 0:
from ._algebra_fallbacks import ( # pylint: disable=import-outside-toplevel
ScaledFunction,
)

return ScaledFunction(function=self, scalar=other)

return NotImplemented


class LambdaFunction(Function):
"""Define a :class:`Function` from a given :class:`callable`.
Expand All @@ -131,7 +167,7 @@ class LambdaFunction(Function):
Examples
--------
>>> import numpy as np
>>> from probnum import LambdaFunction
>>> from probnum.functions import LambdaFunction
>>> fn = LambdaFunction(fn=lambda x: 2 * x + 1, input_shape=(2,), output_shape=(2,))
>>> fn(np.array([[1, 2], [4, 5]]))
array([[ 3, 5],
Expand Down
Loading

0 comments on commit c9957f3

Please sign in to comment.