Skip to content

Commit

Permalink
Make quad policies stateless (#744)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmahsereci authored Nov 24, 2022
1 parent 92e59ad commit 8dde541
Show file tree
Hide file tree
Showing 9 changed files with 242 additions and 68 deletions.
15 changes: 8 additions & 7 deletions src/probnum/quad/_bayesquad.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def bayesquad(
var_tol: Optional[FloatLike] = None,
rel_tol: Optional[FloatLike] = None,
batch_size: IntLike = 1,
rng: Optional[np.random.Generator] = np.random.default_rng(),
rng: Optional[np.random.Generator] = None,
jitter: FloatLike = 1.0e-8,
) -> Tuple[Normal, BQIterInfo]:
r"""Infer the solution of the uni- or multivariate integral
Expand Down Expand Up @@ -100,7 +100,7 @@ def bayesquad(
Number of new observations at each update. Defaults to 1.
rng
Random number generator. Used by Bayesian Monte Carlo other random sampling
policies. Optional. Default is `np.random.default_rng()`.
policies.
jitter
Non-negative jitter to numerically stabilise kernel matrix inversion.
Defaults to 1e-8.
Expand Down Expand Up @@ -145,9 +145,9 @@ def bayesquad(
>>> input_dim = 1
>>> domain = (0, 1)
>>> def f(x):
>>> def fun(x):
... return x.reshape(-1, )
>>> F, info = bayesquad(fun=f, input_dim=input_dim, domain=domain)
>>> F, info = bayesquad(fun, input_dim, domain=domain, rng=np.random.default_rng(0))
>>> print(F.mean)
0.5
"""
Expand All @@ -167,12 +167,13 @@ def bayesquad(
var_tol=var_tol,
rel_tol=rel_tol,
batch_size=batch_size,
rng=rng,
jitter=jitter,
)

# Integrate
integral_belief, _, info = bq_method.integrate(fun=fun, nodes=None, fun_evals=None)
integral_belief, _, info = bq_method.integrate(
fun=fun, nodes=None, fun_evals=None, rng=rng
)

return integral_belief, info

Expand Down Expand Up @@ -261,7 +262,7 @@ def bayesquad_from_data(

# Integrate
integral_belief, _, info = bq_method.integrate(
fun=None, nodes=nodes, fun_evals=fun_evals
fun=None, nodes=nodes, fun_evals=fun_evals, rng=None
)

return integral_belief, info
Expand Down
43 changes: 23 additions & 20 deletions src/probnum/quad/solvers/_bayesian_quadrature.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def from_problem(
var_tol: Optional[FloatLike] = None,
rel_tol: Optional[FloatLike] = None,
batch_size: IntLike = 1,
rng: np.random.Generator = None,
jitter: FloatLike = 1.0e-8,
) -> "BayesianQuadrature":

Expand Down Expand Up @@ -112,8 +111,6 @@ def from_problem(
Relative tolerance as stopping criterion.
batch_size
Batch size used in node acquisition. Defaults to 1.
rng
The random number generator.
jitter
Non-negative jitter to numerically stabilise kernel matrix inversion.
Defaults to 1e-8.
Expand All @@ -127,9 +124,6 @@ def from_problem(
------
ValueError
If neither a ``domain`` nor a ``measure`` are given.
ValueError
If Bayesian Monte Carlo ('bmc') is selected as ``policy`` and no random
number generator (``rng``) is given.
NotImplementedError
If an unknown ``policy`` is given.
"""
Expand All @@ -153,15 +147,9 @@ def from_problem(
# require an acquisition loop. The error handling is done in ``integrate``.
pass
elif policy == "bmc":
if rng is None:
errormsg = (
"Policy 'bmc' relies on random sampling, "
"thus requires a random number generator ('rng')."
)
raise ValueError(errormsg)
policy = RandomPolicy(measure.sample, batch_size=batch_size, rng=rng)
policy = RandomPolicy(batch_size, measure.sample)
elif policy == "vdc":
policy = VanDerCorputPolicy(measure=measure, batch_size=batch_size)
policy = VanDerCorputPolicy(batch_size, measure)
else:
raise NotImplementedError(f"The given policy ({policy}) is unknown.")

Expand Down Expand Up @@ -215,6 +203,7 @@ def bq_iterator(
bq_state: BQState,
info: Optional[BQIterInfo],
fun: Optional[Callable],
rng: Optional[np.random.Generator],
) -> Tuple[Normal, BQState, BQIterInfo]:
"""Generator that implements the iteration of the BQ method.
Expand All @@ -231,6 +220,8 @@ def bq_iterator(
fun
Function to be integrated. It needs to accept a shape=(n_eval, input_dim)
``np.ndarray`` and return a shape=(n_eval,) ``np.ndarray``.
rng
The random number generator used for random methods.
Yields
------
Expand Down Expand Up @@ -258,7 +249,7 @@ def bq_iterator(
break

# Select new nodes via policy
new_nodes = self.policy(bq_state=bq_state)
new_nodes = self.policy(bq_state, rng)

# Evaluate the integrand at new nodes
new_fun_evals = fun(new_nodes)
Expand All @@ -278,6 +269,7 @@ def integrate(
fun: Optional[Callable],
nodes: Optional[np.ndarray],
fun_evals: Optional[np.ndarray],
rng: Optional[np.random.Generator] = None,
) -> Tuple[Normal, BQState, BQIterInfo]:
"""Integrates the function ``fun``.
Expand All @@ -297,6 +289,8 @@ def integrate(
fun_evals
*shape=(n_eval,)* -- Optional function evaluations at ``nodes`` available
from the start.
rng
The random number generator used for random methods.
Returns
-------
Expand All @@ -308,14 +302,17 @@ def integrate(
Raises
------
ValueError
If neither the integrand function (``fun``) nor integrand evaluations
(``fun_evals``) are given.
If neither the integrand function ``fun`` nor integrand evaluations
``fun_evals`` are given.
ValueError
If ``nodes`` are not given and no policy is present.
If neither ``nodes`` nor ``policy`` is given.
ValueError
If dimension of ``nodes`` or ``fun_evals`` is incorrect, or if their
shapes do not match.
ValueError
If ``rng`` is not given but ``policy`` requires it.
"""

# no policy given: Integrate on fixed dataset.
if self.policy is None:
# nodes must be provided if no policy is given.
Expand All @@ -325,13 +322,19 @@ def integrate(
# Use fun_evals and disregard fun if both are given
if fun is not None and fun_evals is not None:
warnings.warn(
"No policy available: 'fun_eval' are used instead of 'fun'."
"No policy available: 'fun_evals' are used instead of 'fun'."
)
fun = None

# override stopping condition as no policy is given.
self.stopping_criterion = ImmediateStop()

elif self.policy.requires_rng and rng is None:
raise ValueError(
f"The policy '{self.policy.__class__.__name__}' requires a random "
f"number generator (rng) to be given."
)

# Check if integrand function is provided
if fun is None and fun_evals is None:
raise ValueError(
Expand Down Expand Up @@ -375,7 +378,7 @@ def integrate(
)

info = None
for (_, bq_state, info) in self.bq_iterator(bq_state, info, fun):
for (_, bq_state, info) in self.bq_iterator(bq_state, info, fun, rng):
pass

return bq_state.integral_belief, bq_state, info
21 changes: 18 additions & 3 deletions src/probnum/quad/solvers/policies/_policy.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"""Abstract base class for BQ policies."""

from __future__ import annotations

import abc
from typing import Optional

import numpy as np

from probnum.quad.solvers._bq_state import BQState
from probnum.typing import IntLike

# pylint: disable=too-few-public-methods, fixme

Expand All @@ -18,17 +22,28 @@ class Policy(abc.ABC):
Size of batch of nodes when calling the policy once.
"""

def __init__(self, batch_size: int) -> None:
self.batch_size = batch_size
def __init__(self, batch_size: IntLike) -> None:
self.batch_size = int(batch_size)

@property
@abc.abstractmethod
def __call__(self, bq_state: BQState) -> np.ndarray:
def requires_rng(self) -> bool:
"""Whether the policy requires a random number generator when called."""
raise NotImplementedError

@abc.abstractmethod
def __call__(
self, bq_state: BQState, rng: Optional[np.random.Generator]
) -> np.ndarray:
"""Find nodes according to the policy.
Parameters
----------
bq_state
State of the BQ belief.
rng
A random number generator.
Returns
-------
nodes :
Expand Down
27 changes: 16 additions & 11 deletions src/probnum/quad/solvers/policies/_random_policy.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""Random policy for Bayesian Monte Carlo."""

from typing import Callable
from __future__ import annotations

from typing import Callable, Optional

import numpy as np

from probnum.quad.solvers._bq_state import BQState
from probnum.typing import IntLike

from ._policy import Policy

Expand All @@ -16,25 +19,27 @@ class RandomPolicy(Policy):
Parameters
----------
batch_size
Size of batch of nodes when calling the policy once.
sample_func
The sample function. Needs to have the following interface:
`sample_func(batch_size: int, rng: np.random.Generator)` and return an array of
shape (batch_size, n_dim).
batch_size
Size of batch of nodes when calling the policy once.
rng
A random number generator.
shape (batch_size, input_dim).
"""

def __init__(
self,
batch_size: IntLike,
sample_func: Callable,
batch_size: int,
rng: np.random.Generator = np.random.default_rng(),
) -> None:
super().__init__(batch_size=batch_size)
self.sample_func = sample_func
self.rng = rng

def __call__(self, bq_state: BQState) -> np.ndarray:
return self.sample_func(self.batch_size, rng=self.rng)
@property
def requires_rng(self) -> bool:
return True

def __call__(
self, bq_state: BQState, rng: Optional[np.random.Generator]
) -> np.ndarray:
return self.sample_func(self.batch_size, rng=rng)
17 changes: 13 additions & 4 deletions src/probnum/quad/solvers/policies/_van_der_corput_policy.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""Van der Corput points for integration on 1D intervals."""

from __future__ import annotations

from typing import Optional

import numpy as np

from probnum.quad.integration_measures import IntegrationMeasure
from probnum.quad.solvers._bq_state import BQState
from probnum.typing import IntLike

from ._policy import Policy

Expand All @@ -22,17 +25,17 @@ class VanDerCorputPolicy(Policy):
Parameters
----------
measure
The integration measure with finite domain.
batch_size
Size of batch of nodes when calling the policy once.
measure
The integration measure with finite domain.
References
--------
.. [1] https://en.wikipedia.org/wiki/Van_der_Corput_sequence
"""

def __init__(self, measure: IntegrationMeasure, batch_size: int) -> None:
def __init__(self, batch_size: IntLike, measure: IntegrationMeasure) -> None:
super().__init__(batch_size=batch_size)

if int(measure.input_dim) > 1:
Expand All @@ -46,7 +49,13 @@ def __init__(self, measure: IntegrationMeasure, batch_size: int) -> None:
self.domain_a = domain_a
self.domain_b = domain_b

def __call__(self, bq_state: BQState) -> np.ndarray:
@property
def requires_rng(self) -> bool:
return False

def __call__(
self, bq_state: BQState, rng: Optional[np.random.Generator]
) -> np.ndarray:
n_nodes = bq_state.nodes.shape[0]
vdc_seq = VanDerCorputPolicy.van_der_corput_sequence(
n_nodes + 1, n_nodes + 1 + self.batch_size
Expand Down
2 changes: 1 addition & 1 deletion src/probnum/quad/solvers/stopping_criteria/_max_nevals.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class MaxNevals(BQStoppingCriterion):
"""

def __init__(self, max_nevals: IntLike):
self.max_nevals = max_nevals
self.max_nevals = int(max_nevals)

def __call__(self, bq_state: BQState, info: BQIterInfo) -> bool:
return info.nevals >= self.max_nevals
Loading

0 comments on commit 8dde541

Please sign in to comment.