Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pre-conditioning matrix to Barker proposal #731

Merged
merged 26 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
68ede50
Draft pre-conditioning matrix in Barker proposal.
ismael-mendoza Sep 1, 2024
36edad3
Fix typing of inverse_mass_matrix argument
ismael-mendoza Sep 1, 2024
e5e9a49
Fix docstrings.
ismael-mendoza Sep 1, 2024
3276892
Make test for Barker in test_sampling run again
ismael-mendoza Sep 1, 2024
7b31071
Add test to ensure correctness of precond matrix
ismael-mendoza Sep 1, 2024
d04b7a2
Fix dimensionality of identity matrix
ismael-mendoza Sep 1, 2024
a61af36
Add missing mass matrix in missing tests.
ismael-mendoza Sep 1, 2024
c934b6b
Merge branch 'main' into barker-inverse-mm2
ismael-mendoza Sep 18, 2024
26ff6f8
Merge branch 'main' into barker-inverse-mm2
ismael-mendoza Sep 24, 2024
776c83a
added option to transpose the matrix when scaling
ismael-mendoza Sep 24, 2024
be32caf
use the metric scaling function in barker
ismael-mendoza Sep 24, 2024
fd35a51
update test_sampling with barker api
ismael-mendoza Sep 24, 2024
8d580ec
update test_barker so it works with metric.scale
ismael-mendoza Sep 24, 2024
a7e1831
fix tests add trans to scale
ismael-mendoza Sep 24, 2024
f33b8a0
add trans argument to riemannian scaling
ismael-mendoza Sep 24, 2024
9122476
no default
ismael-mendoza Sep 24, 2024
ad59aba
Update barker.py
AdrienCorenflos Oct 2, 2024
6e50160
Update test_barker.py
AdrienCorenflos Oct 2, 2024
bd6ba3d
simplify logic to remove _barker_sample_nd
ismael-mendoza Oct 2, 2024
02e9cb9
fix bug so now everything is tree_mapped in barker
ismael-mendoza Oct 2, 2024
230ff26
fix test to not use _barker_sample_nd
ismael-mendoza Oct 2, 2024
d0a7066
Update blackjax/mcmc/metrics.py
ismael-mendoza Oct 4, 2024
eb9acbf
Update blackjax/mcmc/metrics.py
ismael-mendoza Oct 4, 2024
fa2e70b
propagate changes of inv, trans as required kwarg
ismael-mendoza Oct 4, 2024
ef4b434
fix test metrics
ismael-mendoza Oct 4, 2024
6d9c02d
Merge branch 'main' into barker-inverse-mm2
junpenglao Oct 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 79 additions & 67 deletions blackjax/mcmc/barker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from jax.scipy import stats
from jax.tree_util import tree_leaves, tree_map

import blackjax.mcmc.metrics as metrics
from blackjax.base import SamplingAlgorithm
from blackjax.mcmc.metrics import Metric
from blackjax.mcmc.proposal import static_binomial_sampling
from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey
from blackjax.types import ArrayLikeTree, ArrayTree, Numeric, PRNGKey
from blackjax.util import generate_gaussian_noise

__all__ = ["BarkerState", "BarkerInfo", "init", "build_kernel", "as_top_level_api"]

Expand Down Expand Up @@ -81,44 +83,70 @@ def build_kernel():
"""

def _compute_acceptance_probability(
state: BarkerState,
proposal: BarkerState,
) -> float:
state: BarkerState, proposal: BarkerState, metric: Metric
) -> Numeric:
"""Compute the acceptance probability of the Barker's proposal kernel."""

def ratio_proposal_nd(y, x, log_y, log_x):
num = -_log1pexp(-log_y * (x - y))
den = -_log1pexp(-log_x * (y - x))
x = state.position
y = proposal.position
log_x = state.logdensity_grad
log_y = proposal.logdensity_grad

return jnp.sum(num - den)
y_minus_x = jax.tree_util.tree_map(lambda a, b: a - b, y, x)
x_minus_y = jax.tree_util.tree_map(lambda a: -a, y_minus_x)
z_tilde_x_to_y = metric.scale(x, y_minus_x, inv=True, trans=True)
z_tilde_y_to_x = metric.scale(y, x_minus_y, inv=True, trans=True)

ratios_proposals = tree_map(
ratio_proposal_nd,
proposal.position,
state.position,
proposal.logdensity_grad,
state.logdensity_grad,
c_x_to_y = metric.scale(x, log_x, inv=False, trans=True)
c_y_to_x = metric.scale(y, log_y, inv=False, trans=True)

z_tilde_x_to_y_flat, _ = ravel_pytree(z_tilde_x_to_y)
z_tilde_y_to_x_flat, _ = ravel_pytree(z_tilde_y_to_x)

c_x_to_y_flat, _ = ravel_pytree(c_x_to_y)
c_y_to_x_flat, _ = ravel_pytree(c_y_to_x)

num = metric.kinetic_energy(x_minus_y, y) - _log1pexp(
-z_tilde_y_to_x_flat * c_y_to_x_flat
)
ratio_proposal = sum(tree_leaves(ratios_proposals))
denom = metric.kinetic_energy(y_minus_x, x) - _log1pexp(
-z_tilde_x_to_y_flat * c_x_to_y_flat
)

ratio_proposal = jnp.sum(num - denom)

return proposal.logdensity - state.logdensity + ratio_proposal

def kernel(
rng_key: PRNGKey, state: BarkerState, logdensity_fn: Callable, step_size: float
rng_key: PRNGKey,
state: BarkerState,
logdensity_fn: Callable,
step_size: float,
inverse_mass_matrix: metrics.MetricTypes | None = None,
) -> tuple[BarkerState, BarkerInfo]:
"""Generate a new sample with the MALA kernel."""
"""Generate a new sample with the Barker kernel."""
if inverse_mass_matrix is None:
p, _ = ravel_pytree(state.position)
(m,) = p.shape
inverse_mass_matrix = jnp.ones((m,))
metric = metrics.default_metric(inverse_mass_matrix)
grad_fn = jax.value_and_grad(logdensity_fn)

key_sample, key_rmh = jax.random.split(rng_key)

proposed_pos = _barker_sample(
key_sample, state.position, state.logdensity_grad, step_size
key_sample,
state.position,
state.logdensity_grad,
step_size,
metric,
)

proposed_logdensity, proposed_logdensity_grad = grad_fn(proposed_pos)
proposed_state = BarkerState(
proposed_pos, proposed_logdensity, proposed_logdensity_grad
)

log_p_accept = _compute_acceptance_probability(state, proposed_state)
log_p_accept = _compute_acceptance_probability(state, proposed_state, metric)
accepted_state, info = static_binomial_sampling(
key_rmh, log_p_accept, state, proposed_state
)
Expand All @@ -131,6 +159,7 @@ def kernel(
def as_top_level_api(
logdensity_fn: Callable,
step_size: float,
inverse_mass_matrix: metrics.MetricTypes | None = None,
) -> SamplingAlgorithm:
"""Implements the (basic) user interface for the Barker's proposal :cite:p:`Livingstone2022Barker` kernel with a
Gaussian base kernel.
Expand Down Expand Up @@ -174,7 +203,9 @@ def as_top_level_api(
logdensity_fn
The log-density function we wish to draw samples from.
step_size
The value to use for the step size in the symplectic integrator.
The value of the step_size corresponding to the global scale of the proposal distribution.
inverse_mass_matrix
The inverse mass matrix to use for pre-conditioning (see Appendix G of :cite:p:`Livingstone2022Barker`).

Returns
-------
Expand All @@ -189,74 +220,55 @@ def init_fn(position: ArrayLikeTree, rng_key=None):
return init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
return kernel(rng_key, state, logdensity_fn, step_size)
return kernel(rng_key, state, logdensity_fn, step_size, inverse_mass_matrix)

return SamplingAlgorithm(init_fn, step_fn)


def _barker_sample_nd(key, mean, a, scale):
"""
Sample from a multivariate Barker's proposal distribution. In 1D, this has the following probability density function:

.. math::
p(x; \\mu, a, \\sigma) = 2 \frac{N(x; \\mu, \\sigma^2)}{1 + \\exp(-a (x - \\mu)}
def _generate_bernoulli(
rng_key: PRNGKey, position: ArrayLikeTree, p: ArrayLikeTree
) -> ArrayTree:
pos, unravel_fn = ravel_pytree(position)
p_flat, _ = ravel_pytree(p)
sample = jax.random.bernoulli(rng_key, p=p_flat, shape=pos.shape)
return unravel_fn(sample)

where :math:`N(x; \\mu, \\sigma^2)` is the normal distribution with mean :math:`\\mu` and standard deviation :math:`\\sigma`.
The multivariate Barker's proposal distribution is the product of one-dimensional Barker's proposal distributions.

def _barker_sample(key, mean, a, scale, metric):
r"""
Sample from a multivariate Barker's proposal distribution for PyTrees.

Parameters
----------
key
A PRNG key.
mean
The mean of the normal distribution, an Array. This corresponds to :math:`\\mu` in the equation above.
The mean of the normal distribution, a PyTree. This corresponds to :math:`\mu` in the equation above.
a
The parameter :math:`a` in the equation above, an Array. This is a skewness parameter.
The parameter :math:`a` in the equation above, the same PyTree as `mean`. This is a skewness parameter.
scale
The standard deviation of the normal distribution, a scalar. This corresponds to :math:`\\sigma` in the equation above.
The global scale, a scalar. This corresponds to :math:`\sigma` in the equation above.
It encodes the step size of the proposal.

Returns
-------
A sample from the Barker's multidimensional proposal distribution.

metric
A `metrics.Metric` object encoding the mass matrix information.
"""

key1, key2 = jax.random.split(key)
z = scale * jax.random.normal(key1, shape=mean.shape)

z = generate_gaussian_noise(key1, mean, sigma=scale)
c = metric.scale(mean, a, inv=False, trans=True)

# Sample b=1 with probability p and 0 with probability 1 - p where
# p = 1 / (1 + exp(-a * (z - mean)))
log_p = -_log1pexp(-a * z)
b = jax.random.bernoulli(key2, p=jnp.exp(log_p), shape=mean.shape)

# return mean + z if b == 1 else mean - z
return mean + b * z - (1 - b) * z

log_p = jax.tree_util.tree_map(lambda x, y: -_log1pexp(-x * y), c, z)
p = jax.tree_util.tree_map(lambda x: jnp.exp(x), log_p)
b = _generate_bernoulli(key2, mean, p=p)

def _barker_sample(key, mean, a, scale):
r"""
Sample from a multivariate Barker's proposal distribution for PyTrees.

Parameters
----------
key
A PRNG key.
mean
The mean of the normal distribution, a PyTree. This corresponds to :math:`\mu` in the equation above.
a
The parameter :math:`a` in the equation above, the same PyTree as `mean`. This is a skewness parameter.
scale
The standard deviation of the normal distribution, a scalar. This corresponds to :math:`\sigma` in the equation above.
It encodes the step size of the proposal.

"""
bz = jax.tree_util.tree_map(lambda x, y: x * y - (1 - x) * y, b, z)

flat_mean, unravel_fn = ravel_pytree(mean)
flat_a, _ = ravel_pytree(a)
flat_sample = _barker_sample_nd(key, flat_mean, flat_a, scale)
return unravel_fn(flat_sample)
return jax.tree_util.tree_map(
lambda a, b: a + b, mean, metric.scale(mean, bz, inv=False, trans=False)
)


def _log1pexp(a):
Expand Down
55 changes: 39 additions & 16 deletions blackjax/mcmc/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
"""
from typing import Callable, NamedTuple, Optional, Protocol, Union

import jax
import jax.numpy as jnp
import jax.scipy as jscipy
from jax.flatten_util import ravel_pytree
Expand Down Expand Up @@ -62,7 +61,12 @@ def __call__(

class Scale(Protocol):
def __call__(
self, position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
self,
position: ArrayLikeTree,
element: ArrayLikeTree,
*,
inv: bool,
trans: bool,
) -> ArrayLikeTree:
...

Expand Down Expand Up @@ -187,7 +191,11 @@ def is_turning(
return turning_at_left | turning_at_right

def scale(
position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
position: ArrayLikeTree,
element: ArrayLikeTree,
*,
inv: bool,
trans: bool,
) -> ArrayLikeTree:
"""Scale elements by the mass matrix.

Expand All @@ -197,10 +205,11 @@ def scale(
The current position. Not used in this metric.
element
Elements to scale
invs
inv
Whether to scale the elements by the inverse mass matrix or the mass matrix.
If True, the element is scaled by the inverse square root mass matrix, i.e., elem <- (M^{1/2})^{-1} elem.
Same pytree structure as `elements`.
trans
Whether to transpose the mass matrix when scaling.

Returns
-------
Expand All @@ -209,11 +218,16 @@ def scale(
"""

ravelled_element, unravel_fn = ravel_pytree(element)
scaled = jax.lax.cond(
inv,
lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element),
lambda: linear_map(mass_matrix_sqrt, ravelled_element),
)

if inv:
left_hand_side_matrix = inv_mass_matrix_sqrt
else:
left_hand_side_matrix = mass_matrix_sqrt
if trans:
left_hand_side_matrix = left_hand_side_matrix.T

scaled = linear_map(left_hand_side_matrix, ravelled_element)

return unravel_fn(scaled)

return Metric(momentum_generator, kinetic_energy, is_turning, scale)
Expand Down Expand Up @@ -279,7 +293,11 @@ def is_turning(
# return turning_at_left | turning_at_right

def scale(
position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
position: ArrayLikeTree,
element: ArrayLikeTree,
*,
inv: bool,
trans: bool,
) -> ArrayLikeTree:
"""Scale elements by the mass matrix.

Expand All @@ -298,11 +316,16 @@ def scale(
mass_matrix, is_inv=False
)
ravelled_element, unravel_fn = ravel_pytree(element)
scaled = jax.lax.cond(
inv,
lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element),
lambda: linear_map(mass_matrix_sqrt, ravelled_element),
)

if inv:
left_hand_side_matrix = inv_mass_matrix_sqrt
else:
left_hand_side_matrix = mass_matrix_sqrt
if trans:
left_hand_side_matrix = left_hand_side_matrix.T

scaled = linear_map(left_hand_side_matrix, ravelled_element)

return unravel_fn(scaled)

return Metric(momentum_generator, kinetic_energy, is_turning, scale)
Expand Down
Loading
Loading