-
Notifications
You must be signed in to change notification settings - Fork 57
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Algebraic Operations on
Functions
(#725)
* `Function.__{neg,rmul,add,sub}__` * Switch function arithmetic to singledispatch * `arithmetic` -> `algebra` * Additional tests * Add input checks and error messages * Add documentation * Improvements to the tests * Reorder algebra fallbacks * Module docstrings * `pn._function` -> `pn.functions` * `black`` fix * Bugfixes * Reconfigure coverage report * Docfixes * Additional tests * Overload algebraic ops for the Zero function * black fix * Apply suggestions from code review Co-authored-by: Jonathan Wenger <[email protected]> * Update docs/source/api.rst Co-authored-by: Jonathan Wenger <[email protected]> * Bugfix in test * Add `Function.__mul__` operatot * PyLint fix Co-authored-by: Jonathan Wenger <[email protected]>
- Loading branch information
1 parent
798c23a
commit c9957f3
Showing
23 changed files
with
520 additions
and
48 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
***************** | ||
probnum.functions | ||
***************** | ||
|
||
.. automodapi:: probnum.functions | ||
:no-heading: | ||
:headings: "=" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,5 +10,4 @@ probnum.randprocs | |
:hidden: | ||
|
||
randprocs/markov | ||
randprocs/mean_fns | ||
randprocs/kernels |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
"""Callables with in- and output shape information supporting algebraic operations.""" | ||
|
||
from . import _algebra | ||
from ._algebra_fallbacks import ScaledFunction, SumFunction | ||
from ._function import Function, LambdaFunction | ||
from ._zero import Zero |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
r"""Algebraic operations on :class:`Function`\ s.""" | ||
|
||
from ._algebra_fallbacks import SumFunction | ||
from ._function import Function | ||
from ._zero import Zero | ||
|
||
############ | ||
# Function # | ||
############ | ||
|
||
|
||
@Function.__add__.register # pylint: disable=no-member | ||
def _(self, other: Function) -> SumFunction: | ||
return SumFunction(self, other) | ||
|
||
|
||
@Function.__add__.register # pylint: disable=no-member | ||
def _(self, other: SumFunction) -> SumFunction: | ||
return SumFunction(self, *other.summands) | ||
|
||
|
||
@Function.__add__.register # pylint: disable=no-member | ||
def _(self, other: Zero) -> Function: # pylint: disable=unused-argument | ||
return self | ||
|
||
|
||
@Function.__sub__.register # pylint: disable=no-member | ||
def _(self, other: Function) -> SumFunction: | ||
return SumFunction(self, -other) | ||
|
||
|
||
@Function.__sub__.register # pylint: disable=no-member | ||
def _(self, other: Zero) -> Function: # pylint: disable=unused-argument | ||
return self | ||
|
||
|
||
############### | ||
# SumFunction # | ||
############### | ||
|
||
|
||
@SumFunction.__add__.register # pylint: disable=no-member | ||
def _(self, other: Function) -> SumFunction: | ||
return SumFunction(*self.summands, other) | ||
|
||
|
||
@SumFunction.__add__.register # pylint: disable=no-member | ||
def _(self, other: SumFunction) -> SumFunction: | ||
return SumFunction(*self.summands, *other.summands) | ||
|
||
|
||
@SumFunction.__sub__.register # pylint: disable=no-member | ||
def _(self, other: Function) -> SumFunction: | ||
return SumFunction(*self.summands, -other) | ||
|
||
|
||
######## | ||
# Zero # | ||
######## | ||
|
||
|
||
@Zero.__add__.register # pylint: disable=no-member | ||
def _(self, other: Function) -> Function: # pylint: disable=unused-argument | ||
return other | ||
|
||
|
||
@Zero.__sub__.register # pylint: disable=no-member | ||
def _(self, other: Function) -> Function: # pylint: disable=unused-argument | ||
return -other |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
r"""Fallback implementation for algebraic operations on :class:`Function`\ s.""" | ||
|
||
from __future__ import annotations | ||
|
||
import functools | ||
import operator | ||
|
||
import numpy as np | ||
|
||
from probnum import utils | ||
from probnum.typing import ScalarLike, ScalarType | ||
|
||
from ._function import Function | ||
|
||
|
||
class SumFunction(Function): | ||
r"""Pointwise sum of :class:`Function`\ s. | ||
Given functions :math:`f_1, \dotsc, f_n \colon \mathbb{R}^n \to \mathbb{R}^m`, this | ||
defines a new function | ||
.. math:: | ||
\sum_{i = 1}^n f_i \colon \mathbb{R}^n \to \mathbb{R}^m, | ||
x \mapsto \sum_{i = 1}^n f_i(x). | ||
Parameters | ||
---------- | ||
*summands | ||
The functions :math:`f_1, \dotsc, f_n`. | ||
""" | ||
|
||
def __init__(self, *summands: Function) -> None: | ||
if not all(isinstance(summand, Function) for summand in summands): | ||
raise TypeError( | ||
"The functions to be added must be objects of type `Function`." | ||
) | ||
|
||
if not all( | ||
summand.input_shape == summands[0].input_shape for summand in summands | ||
): | ||
raise ValueError( | ||
"The functions to be added must all have the same input shape." | ||
) | ||
|
||
if not all( | ||
summand.output_shape == summands[0].output_shape for summand in summands | ||
): | ||
raise ValueError( | ||
"The functions to be added must all have the same output shape." | ||
) | ||
|
||
self._summands = summands | ||
|
||
super().__init__( | ||
input_shape=summands[0].input_shape, | ||
output_shape=summands[0].output_shape, | ||
) | ||
|
||
@property | ||
def summands(self) -> tuple[SumFunction, ...]: | ||
r"""The functions :math:`f_1, \dotsc, f_n` to be added.""" | ||
return self._summands | ||
|
||
def _evaluate(self, x: np.ndarray) -> np.ndarray: | ||
return functools.reduce( | ||
operator.add, (summand(x) for summand in self._summands) | ||
) | ||
|
||
@functools.singledispatchmethod | ||
def __add__(self, other): | ||
return super().__add__(other) | ||
|
||
@functools.singledispatchmethod | ||
def __sub__(self, other): | ||
return super().__sub__(other) | ||
|
||
|
||
class ScaledFunction(Function): | ||
r"""Function multiplied pointwise with a scalar. | ||
Given a function :math:`f \colon \mathbb{R}^n \to \mathbb{R}^m` and a scalar | ||
:math:`\alpha \in \mathbb{R}`, this defines a new function | ||
.. math:: | ||
\alpha f \colon \mathbb{R}^n \to \mathbb{R}^m, | ||
x \mapsto (\alpha f)(x) = \alpha f(x). | ||
Parameters | ||
---------- | ||
function | ||
The function :math:`f`. | ||
scalar | ||
The scalar :math:`\alpha`. | ||
""" | ||
|
||
def __init__(self, function: Function, scalar: ScalarLike): | ||
if not isinstance(function, Function): | ||
raise TypeError( | ||
"The function to be scaled must be an object of type `Function`." | ||
) | ||
|
||
self._function = function | ||
self._scalar = utils.as_numpy_scalar(scalar) | ||
|
||
super().__init__( | ||
input_shape=self._function.input_shape, | ||
output_shape=self._function.output_shape, | ||
) | ||
|
||
@property | ||
def function(self) -> Function: | ||
r"""The function :math:`f`.""" | ||
return self._function | ||
|
||
@property | ||
def scalar(self) -> ScalarType: | ||
r"""The scalar :math:`\alpha`.""" | ||
return self._scalar | ||
|
||
def _evaluate(self, x: np.ndarray) -> np.ndarray: | ||
return self._scalar * self._function(x) | ||
|
||
@functools.singledispatchmethod | ||
def __mul__(self, other): | ||
if np.ndim(other) == 0: | ||
return ScaledFunction( | ||
function=self._function, | ||
scalar=self._scalar * np.asarray(other), | ||
) | ||
|
||
return super().__mul__(other) | ||
|
||
@functools.singledispatchmethod | ||
def __rmul__(self, other): | ||
if np.ndim(other) == 0: | ||
return ScaledFunction( | ||
function=self._function, | ||
scalar=np.asarray(other) * self._scalar, | ||
) | ||
|
||
return super().__rmul__(other) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.