Skip to content

Commit

Permalink
Update run_inference_algorithm to split initial_position and `ini…
Browse files Browse the repository at this point in the history
…tial_state` (#672)

* UPDATE DOCSTRING

* ADD STREAMING VERSION

* UPDATE TESTS

* ADD DOCSTRING

* ADD TEST

* REFACTOR RUN_INFERENCE_ALGORITHM

* UPDATE DOCSTRING

* Precommit

* CLEAN TESTS

* ADD INITIAL_POSITION

* FIX TEST

* RENAME O

* FIX DOCSTRING

* PUT EXPECTATION AFTER TRANSFORM
  • Loading branch information
reubenharry authored May 20, 2024
1 parent cd91e41 commit e0a7f9e
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 44 deletions.
104 changes: 80 additions & 24 deletions blackjax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
from functools import partial
from typing import Callable, Union

import jax
import jax.numpy as jnp
from jax import jit, lax
from jax.flatten_util import ravel_pytree
from jax.random import normal, split
from jax.tree_util import tree_leaves

from blackjax.base import Info, SamplingAlgorithm, State, VIAlgorithm
from blackjax.base import SamplingAlgorithm, VIAlgorithm
from blackjax.progress_bar import progress_bar_scan
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey

Expand Down Expand Up @@ -142,12 +143,15 @@ def index_pytree(input_pytree: ArrayLikeTree) -> ArrayTree:

def run_inference_algorithm(
rng_key: PRNGKey,
initial_state_or_position: ArrayLikeTree,
inference_algorithm: Union[SamplingAlgorithm, VIAlgorithm],
num_steps: int,
initial_state: ArrayLikeTree = None,
initial_position: ArrayLikeTree = None,
progress_bar: bool = False,
transform: Callable = lambda x: x,
) -> tuple[State, State, Info]:
return_state_history=True,
expectation: Callable = lambda x: x,
) -> tuple:
"""Wrapper to run an inference algorithm.
Note that this utility function does not work for Stochastic Gradient MCMC samplers
Expand All @@ -158,9 +162,10 @@ def run_inference_algorithm(
----------
rng_key
The random state used by JAX's random numbers generator.
initial_state_or_position
The initial state OR the initial position of the inference algorithm. If an initial position
is passed in, the function will automatically convert it into an initial state.
initial_state
The initial state of the inference algorithm.
initial_position
The initial position of the inference algorithm. This is used when the initial state is not provided.
inference_algorithm
One of blackjax's sampling algorithms or variational inference algorithms.
num_steps
Expand All @@ -171,34 +176,85 @@ def run_inference_algorithm(
A transformation of the trace of states to be returned. This is useful for
computing determinstic variables, or returning a subset of the states.
By default, the states are returned as is.
expectation
A function that computes the expectation of the state. This is done incrementally, so doesn't require storing all the states.
return_state_history
if False, `run_inference_algorithm` will only return an expectation of the value of transform, and return that average instead of the full set of samples. This is useful when memory is a bottleneck.
Returns
-------
Tuple[State, State, Info]
1. The final state of the inference algorithm.
2. The trace of states of the inference algorithm (contains the MCMC samples).
If return_state_history is True:
1. The final state.
2. The trace of the state.
3. The trace of the info of the inference algorithm for diagnostics.
If return_state_history is False:
1. This is the expectation of state over the chain. Otherwise the final state.
2. The final state of the inference algorithm.
"""
init_key, sample_key = split(rng_key, 2)
try:
initial_state = inference_algorithm.init(initial_state_or_position, init_key)
except (TypeError, ValueError, AttributeError):
# We assume initial_state is already in the right format.
initial_state = initial_state_or_position

keys = split(sample_key, num_steps)
if initial_state is None and initial_position is None:
raise ValueError("Either initial_state or initial_position must be provided.")
if initial_state is not None and initial_position is not None:
raise ValueError(
"Only one of initial_state or initial_position must be provided."
)

@jit
def _one_step(state, xs):
rng_key, init_key = split(rng_key, 2)
if initial_position is not None:
initial_state = inference_algorithm.init(initial_position, init_key)

keys = split(rng_key, num_steps)

def one_step(average_and_state, xs, return_state):
_, rng_key = xs
average, state = average_and_state
state, info = inference_algorithm.step(rng_key, state)
return state, (transform(state), info)
average = streaming_average(expectation(transform(state)), average)
if return_state:
return (average, state), (transform(state), info)
else:
return (average, state), None

one_step = jax.jit(partial(one_step, return_state=return_state_history))

if progress_bar:
one_step = progress_bar_scan(num_steps)(_one_step)
else:
one_step = _one_step
one_step = progress_bar_scan(num_steps)(one_step)

xs = (jnp.arange(num_steps), keys)
final_state, (state_history, info_history) = lax.scan(one_step, initial_state, xs)
return final_state, state_history, info_history
((_, average), final_state), history = lax.scan(
one_step, ((0, expectation(transform(initial_state))), initial_state), xs
)

if not return_state_history:
return average, transform(final_state)
else:
state_history, info_history = history
return transform(final_state), state_history, info_history


def streaming_average(expectation, streaming_avg, weight=1.0, zero_prevention=0.0):
"""Compute the streaming average of a function O(x) using a weight.
Parameters:
----------
expectation
the value of the expectation at the current timestep
streaming_avg
tuple of (total, average) where total is the sum of weights and average is the current average
weight
weight of the current state
zero_prevention
small value to prevent division by zero
Returns:
----------
new streaming average
"""

flat_expectation, unravel_fn = ravel_pytree(expectation)
total, average = streaming_avg
flat_average, _ = ravel_pytree(average)
average = (total * flat_average + weight * flat_expectation) / (
total + weight + zero_prevention
)
total += weight
streaming_avg = (total, unravel_fn(average))
return streaming_avg
7 changes: 6 additions & 1 deletion tests/adaptation/test_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,12 @@ def test_chees_adaptation(adaptation_filters):

chain_keys = jax.random.split(inference_key, num_chains)
_, _, infos = jax.vmap(
lambda key, state: run_inference_algorithm(key, state, algorithm, num_results)
lambda key, state: run_inference_algorithm(
rng_key=key,
initial_state=state,
inference_algorithm=algorithm,
num_steps=num_results,
)
)(chain_keys, last_states)

harmonic_mean = 1.0 / jnp.mean(1.0 / infos.acceptance_rate)
Expand Down
42 changes: 32 additions & 10 deletions tests/mcmc/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key):

_, samples, _ = run_inference_algorithm(
rng_key=run_key,
initial_state_or_position=blackjax_state_after_tuning,
initial_state=blackjax_state_after_tuning,
inference_algorithm=sampling_alg,
num_steps=num_steps,
transform=lambda x: x.position,
Expand Down Expand Up @@ -187,7 +187,10 @@ def check_attrs(attribute, keyset):
check_attrs(attribute, keysets[i])

_, states, _ = run_inference_algorithm(
inference_key, state, inference_algorithm, case["num_sampling_steps"]
rng_key=inference_key,
initial_state=state,
inference_algorithm=inference_algorithm,
num_steps=case["num_sampling_steps"],
)

coefs_samples = states.position["coefs"]
Expand All @@ -209,7 +212,12 @@ def test_mala(self):

mala = blackjax.mala(logposterior_fn, 1e-5)
state = mala.init({"coefs": 1.0, "log_scale": 1.0})
_, states, _ = run_inference_algorithm(inference_key, state, mala, 10_000)
_, states, _ = run_inference_algorithm(
rng_key=inference_key,
initial_state=state,
inference_algorithm=mala,
num_steps=10_000,
)

coefs_samples = states.position["coefs"][3000:]
scale_samples = np.exp(states.position["log_scale"][3000:])
Expand Down Expand Up @@ -275,7 +283,10 @@ def test_pathfinder_adaptation(
inference_algorithm = algorithm(logposterior_fn, **parameters)

_, states, _ = run_inference_algorithm(
inference_key, state, inference_algorithm, num_sampling_steps
rng_key=inference_key,
initial_state=state,
inference_algorithm=inference_algorithm,
num_steps=num_sampling_steps,
)

coefs_samples = states.position["coefs"]
Expand Down Expand Up @@ -316,7 +327,10 @@ def test_meads(self):
chain_keys = jax.random.split(inference_key, num_chains)
_, states, _ = jax.vmap(
lambda key, state: run_inference_algorithm(
key, state, inference_algorithm, 100
rng_key=key,
initial_state=state,
inference_algorithm=inference_algorithm,
num_steps=100,
)
)(chain_keys, last_states)

Expand Down Expand Up @@ -360,7 +374,10 @@ def test_chees(self, jitter_generator):
chain_keys = jax.random.split(inference_key, num_chains)
_, states, _ = jax.vmap(
lambda key, state: run_inference_algorithm(
key, state, inference_algorithm, 100
rng_key=key,
initial_state=state,
inference_algorithm=inference_algorithm,
num_steps=100,
)
)(chain_keys, last_states)

Expand All @@ -384,7 +401,12 @@ def test_barker(self):
barker = blackjax.barker_proposal(logposterior_fn, 1e-1)
state = barker.init({"coefs": 1.0, "log_scale": 1.0})

_, states, _ = run_inference_algorithm(inference_key, state, barker, 10_000)
_, states, _ = run_inference_algorithm(
rng_key=inference_key,
initial_state=state,
inference_algorithm=barker,
num_steps=10_000,
)

coefs_samples = states.position["coefs"][3000:]
scale_samples = np.exp(states.position["log_scale"][3000:])
Expand Down Expand Up @@ -570,7 +592,7 @@ def test_latent_gaussian(self):
inference_algorithm=inference_algorithm,
num_steps=self.sampling_steps,
),
)(self.key, initial_state)
)(rng_key=self.key, initial_state=initial_state)

np.testing.assert_allclose(
np.var(states.position[self.burnin :]), 1 / (1 + 0.5), rtol=1e-2, atol=1e-2
Expand Down Expand Up @@ -614,7 +636,7 @@ def univariate_normal_test_case(
inference_algorithm=inference_algorithm,
num_steps=num_sampling_steps,
)
)(inference_key, initial_state)
)(rng_key=inference_key, initial_state=initial_state)

# else:
if postprocess_samples:
Expand Down Expand Up @@ -885,7 +907,7 @@ def test_mcse(self, algorithm, parameters, is_mass_matrix_diagonal):
)
)
_, states, _ = inference_loop_multiple_chains(
multi_chain_sample_key, initial_states
rng_key=multi_chain_sample_key, initial_state=initial_states
)

posterior_samples = states.position[:, -1000:]
Expand Down
5 changes: 4 additions & 1 deletion tests/test_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ def run_regression(algorithm, **parameters):
inference_algorithm = algorithm(logdensity_fn, **parameters)

_, states, _ = run_inference_algorithm(
inference_key, state, inference_algorithm, 10_000
rng_key=inference_key,
initial_state=state,
inference_algorithm=inference_algorithm,
num_steps=10_000,
)

return states
Expand Down
63 changes: 55 additions & 8 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,70 @@ def setUp(self):
)
self.num_steps = 10

def check_compatible(self, initial_state_or_position, progress_bar):
def check_compatible(self, initial_state, progress_bar):
"""
Runs 10 steps with `run_inference_algorithm` starting with
`initial_state_or_position` and potentially a progress bar.
`initial_state` and potentially a progress bar.
"""
_ = run_inference_algorithm(
self.key,
initial_state_or_position,
self.algorithm,
self.num_steps,
progress_bar,
rng_key=self.key,
initial_state=initial_state,
inference_algorithm=self.algorithm,
num_steps=self.num_steps,
progress_bar=progress_bar,
transform=lambda x: x.position,
)

def test_streaming(self):
def logdensity_fn(x):
return -0.5 * jnp.sum(jnp.square(x))

initial_position = jnp.ones(
10,
)

init_key, run_key = jax.random.split(self.key, 2)

initial_state = blackjax.mcmc.mclmc.init(
position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key
)

alg = blackjax.mclmc(logdensity_fn=logdensity_fn, L=0.5, step_size=0.1)

_, states, info = run_inference_algorithm(
rng_key=run_key,
initial_state=initial_state,
inference_algorithm=alg,
num_steps=50,
progress_bar=False,
expectation=lambda x: x,
transform=lambda x: x.position,
return_state_history=True,
)

average, _ = run_inference_algorithm(
rng_key=run_key,
initial_state=initial_state,
inference_algorithm=alg,
num_steps=50,
progress_bar=False,
expectation=lambda x: x,
transform=lambda x: x.position,
return_state_history=False,
)

assert jnp.allclose(states.mean(axis=0), average)

@parameterized.parameters([True, False])
def test_compatible_with_initial_pos(self, progress_bar):
self.check_compatible(jnp.array([1.0, 1.0]), progress_bar)
_ = run_inference_algorithm(
rng_key=self.key,
initial_position=jnp.array([1.0, 1.0]),
inference_algorithm=self.algorithm,
num_steps=self.num_steps,
progress_bar=progress_bar,
transform=lambda x: x.position,
)

@parameterized.parameters([True, False])
def test_compatible_with_initial_state(self, progress_bar):
Expand Down

0 comments on commit e0a7f9e

Please sign in to comment.