Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make quad policies stateless #744

Merged
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/probnum/quad/_bayesquad.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,13 @@ def bayesquad(
var_tol=var_tol,
rel_tol=rel_tol,
batch_size=batch_size,
rng=rng,
jitter=jitter,
)

# Integrate
integral_belief, _, info = bq_method.integrate(fun=fun, nodes=None, fun_evals=None)
integral_belief, _, info = bq_method.integrate(
fun=fun, nodes=None, fun_evals=None, rng=rng
)

return integral_belief, info

Expand Down Expand Up @@ -261,7 +262,7 @@ def bayesquad_from_data(

# Integrate
integral_belief, _, info = bq_method.integrate(
fun=None, nodes=nodes, fun_evals=fun_evals
fun=None, nodes=nodes, fun_evals=fun_evals, rng=None
)

return integral_belief, info
Expand Down
32 changes: 15 additions & 17 deletions src/probnum/quad/solvers/_bayesian_quadrature.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Callable, Optional, Tuple
from typing import Callable, Optional, Tuple, Union, get_args
import warnings

import numpy as np
Expand Down Expand Up @@ -83,7 +83,6 @@ def from_problem(
var_tol: Optional[FloatLike] = None,
rel_tol: Optional[FloatLike] = None,
batch_size: IntLike = 1,
rng: np.random.Generator = None,
jitter: FloatLike = 1.0e-8,
) -> "BayesianQuadrature":

Expand Down Expand Up @@ -112,8 +111,6 @@ def from_problem(
Relative tolerance as stopping criterion.
batch_size
Batch size used in node acquisition. Defaults to 1.
rng
The random number generator.
jitter
Non-negative jitter to numerically stabilise kernel matrix inversion.
Defaults to 1e-8.
Expand All @@ -127,9 +124,6 @@ def from_problem(
------
ValueError
If neither a ``domain`` nor a ``measure`` are given.
ValueError
If Bayesian Monte Carlo ('bmc') is selected as ``policy`` and no random
number generator (``rng``) is given.
NotImplementedError
If an unknown ``policy`` is given.
"""
Expand All @@ -153,15 +147,9 @@ def from_problem(
# require an acquisition loop. The error handling is done in ``integrate``.
pass
elif policy == "bmc":
if rng is None:
errormsg = (
"Policy 'bmc' relies on random sampling, "
"thus requires a random number generator ('rng')."
)
raise ValueError(errormsg)
policy = RandomPolicy(measure.sample, batch_size=batch_size, rng=rng)
policy = RandomPolicy(batch_size, measure.sample)
elif policy == "vdc":
policy = VanDerCorputPolicy(measure=measure, batch_size=batch_size)
policy = VanDerCorputPolicy(batch_size, measure)
else:
raise NotImplementedError(f"The given policy ({policy}) is unknown.")

Expand Down Expand Up @@ -215,6 +203,7 @@ def bq_iterator(
bq_state: BQState,
info: Optional[BQIterInfo],
fun: Optional[Callable],
rng: np.random.Generator,
) -> Tuple[Normal, BQState, BQIterInfo]:
"""Generator that implements the iteration of the BQ method.

Expand All @@ -231,6 +220,8 @@ def bq_iterator(
fun
Function to be integrated. It needs to accept a shape=(n_eval, input_dim)
``np.ndarray`` and return a shape=(n_eval,) ``np.ndarray``.
rng
The random number generator used for random methods.

Yields
------
Expand Down Expand Up @@ -258,7 +249,7 @@ def bq_iterator(
break

# Select new nodes via policy
new_nodes = self.policy(bq_state=bq_state)
new_nodes = self.policy(bq_state, rng)

# Evaluate the integrand at new nodes
new_fun_evals = fun(new_nodes)
Expand All @@ -278,6 +269,7 @@ def integrate(
fun: Optional[Callable],
nodes: Optional[np.ndarray],
fun_evals: Optional[np.ndarray],
rng: Union[IntLike, np.random.Generator] = np.random.default_rng(),
mmahsereci marked this conversation as resolved.
Show resolved Hide resolved
) -> Tuple[Normal, BQState, BQIterInfo]:
"""Integrates the function ``fun``.

Expand All @@ -297,6 +289,8 @@ def integrate(
fun_evals
*shape=(n_eval,)* -- Optional function evaluations at ``nodes`` available
from the start.
rng
The random number generator used for random methods, or a seed.

Returns
-------
Expand All @@ -316,6 +310,10 @@ def integrate(
If dimension of ``nodes`` or ``fun_evals`` is incorrect, or if their
shapes do not match.
"""
# Get the rng
if isinstance(rng, get_args(IntLike)):
rng = np.random.default_rng(int(rng))

mmahsereci marked this conversation as resolved.
Show resolved Hide resolved
# no policy given: Integrate on fixed dataset.
if self.policy is None:
# nodes must be provided if no policy is given.
Expand Down Expand Up @@ -375,7 +373,7 @@ def integrate(
)

info = None
for (_, bq_state, info) in self.bq_iterator(bq_state, info, fun):
for (_, bq_state, info) in self.bq_iterator(bq_state, info, fun, rng):
pass

return bq_state.integral_belief, bq_state, info
15 changes: 12 additions & 3 deletions src/probnum/quad/solvers/policies/_policy.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"""Abstract base class for BQ policies."""

from __future__ import annotations

import abc
from typing import Optional

import numpy as np

from probnum.quad.solvers._bq_state import BQState
from probnum.typing import IntLike

# pylint: disable=too-few-public-methods, fixme

Expand All @@ -18,17 +22,22 @@ class Policy(abc.ABC):
Size of batch of nodes when calling the policy once.
"""

def __init__(self, batch_size: int) -> None:
self.batch_size = batch_size
def __init__(self, batch_size: IntLike) -> None:
self.batch_size = int(batch_size)

@abc.abstractmethod
def __call__(self, bq_state: BQState) -> np.ndarray:
def __call__(
self, bq_state: BQState, rng: Optional[np.random.Generator]
) -> np.ndarray:
"""Find nodes according to the policy.

Parameters
----------
bq_state
State of the BQ belief.
rng
A random number generator.

Returns
-------
nodes :
Expand Down
23 changes: 12 additions & 11 deletions src/probnum/quad/solvers/policies/_random_policy.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""Random policy for Bayesian Monte Carlo."""

from typing import Callable
from __future__ import annotations

from typing import Callable, Optional

import numpy as np

from probnum.quad.solvers._bq_state import BQState
from probnum.typing import IntLike

from ._policy import Policy

Expand All @@ -16,25 +19,23 @@ class RandomPolicy(Policy):

Parameters
----------
batch_size
Size of batch of nodes when calling the policy once.
sample_func
The sample function. Needs to have the following interface:
`sample_func(batch_size: int, rng: np.random.Generator)` and return an array of
shape (batch_size, n_dim).
batch_size
Size of batch of nodes when calling the policy once.
rng
A random number generator.
shape (batch_size, input_dim).
"""

def __init__(
self,
batch_size: IntLike,
sample_func: Callable,
batch_size: int,
rng: np.random.Generator = np.random.default_rng(),
) -> None:
super().__init__(batch_size=batch_size)
self.sample_func = sample_func
self.rng = rng

def __call__(self, bq_state: BQState) -> np.ndarray:
return self.sample_func(self.batch_size, rng=self.rng)
def __call__(
self, bq_state: BQState, rng: Optional[np.random.Generator]
) -> np.ndarray:
return self.sample_func(self.batch_size, rng=rng)
13 changes: 9 additions & 4 deletions src/probnum/quad/solvers/policies/_van_der_corput_policy.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""Van der Corput points for integration on 1D intervals."""

from __future__ import annotations

from typing import Optional

import numpy as np

from probnum.quad.integration_measures import IntegrationMeasure
from probnum.quad.solvers._bq_state import BQState
from probnum.typing import IntLike

from ._policy import Policy

Expand All @@ -22,17 +25,17 @@ class VanDerCorputPolicy(Policy):

Parameters
----------
measure
The integration measure with finite domain.
batch_size
Size of batch of nodes when calling the policy once.
measure
The integration measure with finite domain.

References
--------
.. [1] https://en.wikipedia.org/wiki/Van_der_Corput_sequence
"""

def __init__(self, measure: IntegrationMeasure, batch_size: int) -> None:
def __init__(self, batch_size: IntLike, measure: IntegrationMeasure) -> None:
super().__init__(batch_size=batch_size)

if int(measure.input_dim) > 1:
Expand All @@ -46,7 +49,9 @@ def __init__(self, measure: IntegrationMeasure, batch_size: int) -> None:
self.domain_a = domain_a
self.domain_b = domain_b

def __call__(self, bq_state: BQState) -> np.ndarray:
def __call__(
self, bq_state: BQState, rng: Optional[np.random.Generator]
) -> np.ndarray:
n_nodes = bq_state.nodes.shape[0]
vdc_seq = VanDerCorputPolicy.van_der_corput_sequence(
n_nodes + 1, n_nodes + 1 + self.batch_size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class MaxNevals(BQStoppingCriterion):
"""

def __init__(self, max_nevals: IntLike):
self.max_nevals = max_nevals
self.max_nevals = int(max_nevals)

def __call__(self, bq_state: BQState, info: BQIterInfo) -> bool:
return info.nevals >= self.max_nevals
43 changes: 38 additions & 5 deletions tests/test_quad/test_bayesian_quadrature.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
from probnum.quad.integration_measures import LebesgueMeasure
from probnum.quad.solvers import BayesianQuadrature
from probnum.quad.solvers.policies import RandomPolicy, VanDerCorputPolicy
from probnum.quad.solvers.stopping_criteria import ImmediateStop
from probnum.quad.solvers.stopping_criteria import (
ImmediateStop,
IntegralVarianceTolerance,
MaxNevals,
RelativeMeanChange,
)
from probnum.randprocs.kernels import ExpQuad


Expand All @@ -31,7 +36,6 @@ def bq(input_dim):
return BayesianQuadrature.from_problem(
input_dim=input_dim,
domain=(np.zeros(input_dim), np.ones(input_dim)),
rng=np.random.default_rng(),
)


Expand All @@ -56,9 +60,7 @@ def test_bq_from_problem_wrong_inputs(input_dim):
)
def test_bq_from_problem_policy_assignment(policy, policy_type):
"""Test if correct policy is assigned from string identifier."""
bq = BayesianQuadrature.from_problem(
input_dim=1, domain=(0, 1), policy=policy, rng=np.random.default_rng()
)
bq = BayesianQuadrature.from_problem(input_dim=1, domain=(0, 1), policy=policy)
assert isinstance(bq.policy, policy_type)


Expand All @@ -81,6 +83,30 @@ def test_bq_from_problem_defaults(bq_no_policy, bq):
assert isinstance(bq.kernel, ExpQuad)


@pytest.mark.parametrize(
"max_evals, var_tol, rel_tol, t",
[
(None, None, None, LambdaStoppingCriterion),
(1000, None, None, MaxNevals),
(None, 1e-5, None, IntegralVarianceTolerance),
(None, None, 1e-5, RelativeMeanChange),
(None, 1e-5, 1e-5, LambdaStoppingCriterion),
(1000, None, 1e-5, LambdaStoppingCriterion),
(1000, 1e-5, None, LambdaStoppingCriterion),
(1000, 1e-5, 1e-5, LambdaStoppingCriterion),
],
)
def test_bq_from_problem_stopping_condition_assignment(max_evals, var_tol, rel_tol, t):
mmahsereci marked this conversation as resolved.
Show resolved Hide resolved
bq = BayesianQuadrature.from_problem(
input_dim=2,
domain=(0, 1),
max_evals=max_evals,
var_tol=var_tol,
rel_tol=rel_tol,
)
assert isinstance(bq.stopping_criterion, t)


def test_integrate_no_policy_wrong_input(bq_no_policy, data):
nodes, fun_evals, fun = data

Expand Down Expand Up @@ -120,3 +146,10 @@ def test_integrate_wrong_input(bq, bq_no_policy, data):
bq.integrate(fun=fun, nodes=wrong_nodes, fun_evals=fun_evals)
with pytest.raises(ValueError):
bq_no_policy.integrate(fun=None, nodes=wrong_nodes, fun_evals=fun_evals)


@pytest.mark.parametrize("rng", [np.random.default_rng(42), 42])
def test_integrate_runs_with_integer_rng(bq, data, rng):
# make sure integrate runs with both a rn generator and a seed.
mmahsereci marked this conversation as resolved.
Show resolved Hide resolved
nodes, fun_evals, fun = data
bq.integrate(fun=fun, nodes=None, fun_evals=None, rng=rng)
12 changes: 9 additions & 3 deletions tests/test_quad/test_bayesquad/test_bq.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,15 @@ def rng():


@pytest.mark.parametrize("input_dim", [1], ids=["dim1"])
def test_type_1d(f1d, kernel, measure, input_dim):
def test_type_1d(f1d, kernel, measure, input_dim, rng):
"""Test that BQ outputs normal random variables for 1D integrands."""
integral, _ = bayesquad(
fun=f1d, input_dim=input_dim, kernel=kernel, measure=measure, max_evals=10
fun=f1d,
input_dim=input_dim,
kernel=kernel,
measure=measure,
max_evals=10,
rng=rng,
)
assert isinstance(integral, Normal)

Expand All @@ -43,7 +48,7 @@ def test_type_1d(f1d, kernel, measure, input_dim):
@pytest.mark.parametrize("scale_estimation", [None, "mle"])
@pytest.mark.parametrize("jitter", [1e-6, 1e-7])
def test_integral_values_1d(
f1d, kernel, domain, input_dim, scale_estimation, var_tol, rel_tol, jitter
f1d, kernel, domain, input_dim, scale_estimation, var_tol, rel_tol, jitter, rng
):
"""Test numerically that BQ computes 1D integrals correctly for a number of
different parameters.
Expand All @@ -70,6 +75,7 @@ def integrand(x):
var_tol=var_tol,
rel_tol=rel_tol,
jitter=jitter,
rng=rng,
)
domain = measure.domain
num_integral, _ = scipyquad(integrand, domain[0], domain[1])
Expand Down
Loading