From e0a7f9e1adc672d1b23890b1d8481b6ec77acedd Mon Sep 17 00:00:00 2001
From: Reuben
Date: Mon, 20 May 2024 09:36:00 -0400
Subject: [PATCH] Update `run_inference_algorithm` to split `initial_position`
 and `initial_state` (#672)

* UPDATE DOCSTRING
* ADD STREAMING VERSION
* UPDATE TESTS
* ADD DOCSTRING
* ADD TEST
* REFACTOR RUN_INFERENCE_ALGORITHM
* UPDATE DOCSTRING
* Precommit
* CLEAN TESTS
* ADD INITIAL_POSITION
* FIX TEST
* RENAME O
* FIX DOCSTRING
* PUT EXPECTATION AFTER TRANSFORM

---
 blackjax/util.py                    | 104 +++++++++++++++++++++-------
 tests/adaptation/test_adaptation.py |   7 +-
 tests/mcmc/test_sampling.py         |  42 ++++++++---
 tests/test_benchmarks.py            |   5 +-
 tests/test_util.py                  |  63 ++++++++++++++---
 5 files changed, 177 insertions(+), 44 deletions(-)

diff --git a/blackjax/util.py b/blackjax/util.py
index df527ed01..02c27e51c 100644
--- a/blackjax/util.py
+++ b/blackjax/util.py
@@ -2,13 +2,14 @@
 from functools import partial
 from typing import Callable, Union
 
+import jax
 import jax.numpy as jnp
 from jax import jit, lax
 from jax.flatten_util import ravel_pytree
 from jax.random import normal, split
 from jax.tree_util import tree_leaves
 
-from blackjax.base import Info, SamplingAlgorithm, State, VIAlgorithm
+from blackjax.base import SamplingAlgorithm, VIAlgorithm
 from blackjax.progress_bar import progress_bar_scan
 from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey
 
@@ -142,12 +143,15 @@ def index_pytree(input_pytree: ArrayLikeTree) -> ArrayTree:
 
 def run_inference_algorithm(
     rng_key: PRNGKey,
-    initial_state_or_position: ArrayLikeTree,
     inference_algorithm: Union[SamplingAlgorithm, VIAlgorithm],
     num_steps: int,
+    initial_state: ArrayLikeTree = None,
+    initial_position: ArrayLikeTree = None,
     progress_bar: bool = False,
     transform: Callable = lambda x: x,
-) -> tuple[State, State, Info]:
+    return_state_history=True,
+    expectation: Callable = lambda x: x,
+) -> tuple:
     """Wrapper to run an inference algorithm.
 
     Note that this utility function does not work for Stochastic Gradient MCMC samplers
@@ -158,9 +162,10 @@
     ----------
     rng_key
         The random state used by JAX's random numbers generator.
-    initial_state_or_position
-        The initial state OR the initial position of the inference algorithm. If an initial position
-        is passed in, the function will automatically convert it into an initial state.
+    initial_state
+        The initial state of the inference algorithm.
+    initial_position
+        The initial position of the inference algorithm. This is used when the
+        initial state is not provided.
     inference_algorithm
         One of blackjax's sampling algorithms or variational inference algorithms.
     num_steps
@@ -171,34 +176,85 @@
     progress_bar
     transform
         A transformation of the trace of states to be returned. This is useful for
        computing deterministic variables, or returning a subset of the states.
        By default, the states are returned as is.
+    expectation
+        A function that computes the expectation of the (transformed) state. It is
+        computed incrementally as a streaming average, so it does not require
+        storing all the states.
+    return_state_history
+        If False, `run_inference_algorithm` returns only the streaming average of
+        the expectation and the final state, instead of the full trace of states.
+        This is useful when memory is a bottleneck.
 
     Returns
     -------
-    Tuple[State, State, Info]
-        1. The final state of the inference algorithm.
-        2. The trace of states of the inference algorithm (contains the MCMC samples).
+    If return_state_history is True:
+        1. The final state.
+        2. The trace of the (transformed) states.
         3. The trace of the info of the inference algorithm for diagnostics.
+    If return_state_history is False:
+        1. The streaming average of the expectation of the (transformed) states
+           over the chain.
+        2. The (transformed) final state of the inference algorithm.
     """
 
-    init_key, sample_key = split(rng_key, 2)
-    try:
-        initial_state = inference_algorithm.init(initial_state_or_position, init_key)
-    except (TypeError, ValueError, AttributeError):
-        # We assume initial_state is already in the right format.
-        initial_state = initial_state_or_position
-    keys = split(sample_key, num_steps)
+    if initial_state is None and initial_position is None:
+        raise ValueError("Either initial_state or initial_position must be provided.")
+    if initial_state is not None and initial_position is not None:
+        raise ValueError(
+            "Only one of initial_state or initial_position must be provided."
+        )
 
-    @jit
-    def _one_step(state, xs):
+    rng_key, init_key = split(rng_key, 2)
+    if initial_position is not None:
+        initial_state = inference_algorithm.init(initial_position, init_key)
+
+    keys = split(rng_key, num_steps)
+
+    def one_step(average_and_state, xs, return_state):
         _, rng_key = xs
+        average, state = average_and_state
         state, info = inference_algorithm.step(rng_key, state)
-        return state, (transform(state), info)
+        average = streaming_average(expectation(transform(state)), average)
+        if return_state:
+            return (average, state), (transform(state), info)
+        else:
+            return (average, state), None
+
+    one_step = jax.jit(partial(one_step, return_state=return_state_history))
 
     if progress_bar:
-        one_step = progress_bar_scan(num_steps)(_one_step)
-    else:
-        one_step = _one_step
+        one_step = progress_bar_scan(num_steps)(one_step)
 
     xs = (jnp.arange(num_steps), keys)
-    final_state, (state_history, info_history) = lax.scan(one_step, initial_state, xs)
-    return final_state, state_history, info_history
+    ((_, average), final_state), history = lax.scan(
+        one_step, ((0, expectation(transform(initial_state))), initial_state), xs
+    )
+
+    if not return_state_history:
+        return average, transform(final_state)
+    else:
+        state_history, info_history = history
+        return transform(final_state), state_history, info_history
+
+
+def streaming_average(expectation, streaming_avg, weight=1.0, zero_prevention=0.0):
+    """Compute the streaming average of an expectation using a weight.
+
+    Parameters
+    ----------
+    expectation
+        The value of the expectation at the current time step.
+    streaming_avg
+        A tuple (total, average), where total is the sum of the weights so far
+        and average is the current running average.
+    weight
+        The weight of the current state.
+    zero_prevention
+        A small value to prevent division by zero.
+
+    Returns
+    -------
+    The updated (total, average) tuple.
+    """
+
+    flat_expectation, unravel_fn = ravel_pytree(expectation)
+    total, average = streaming_avg
+    flat_average, _ = ravel_pytree(average)
+    average = (total * flat_average + weight * flat_expectation) / (
+        total + weight + zero_prevention
+    )
+    total += weight
+    streaming_avg = (total, unravel_fn(average))
+    return streaming_avg
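A minimal usage sketch of the updated interface (illustrative only, not part of the patch; the standard-normal log-density, the MALA step size, and the step counts are arbitrary choices):

import jax
import jax.numpy as jnp

import blackjax
from blackjax.util import run_inference_algorithm

logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))  # standard normal
algorithm = blackjax.mala(logdensity_fn, 1e-2)
rng_key = jax.random.PRNGKey(0)

# Pass `initial_position`: the state is built internally via the algorithm's init.
final_position, positions, infos = run_inference_algorithm(
    rng_key=rng_key,
    initial_position=jnp.ones(5),
    inference_algorithm=algorithm,
    num_steps=1_000,
    transform=lambda state: state.position,
)

# With `return_state_history=False`, only the streaming average of
# `expectation(transform(state))` and the (transformed) final state are
# returned, so memory use does not grow with `num_steps`.
average_position, final_position = run_inference_algorithm(
    rng_key=rng_key,
    initial_position=jnp.ones(5),
    inference_algorithm=algorithm,
    num_steps=1_000,
    transform=lambda state: state.position,
    return_state_history=False,
)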
diff --git a/tests/adaptation/test_adaptation.py b/tests/adaptation/test_adaptation.py
index 8b0f55a7f..68751bee8 100644
--- a/tests/adaptation/test_adaptation.py
+++ b/tests/adaptation/test_adaptation.py
@@ -91,7 +91,12 @@ def test_chees_adaptation(adaptation_filters):
     chain_keys = jax.random.split(inference_key, num_chains)
 
     _, _, infos = jax.vmap(
-        lambda key, state: run_inference_algorithm(key, state, algorithm, num_results)
+        lambda key, state: run_inference_algorithm(
+            rng_key=key,
+            initial_state=state,
+            inference_algorithm=algorithm,
+            num_steps=num_results,
+        )
     )(chain_keys, last_states)
 
     harmonic_mean = 1.0 / jnp.mean(1.0 / infos.acceptance_rate)
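The lambda wrapper above is needed because the arguments are now passed by keyword, while jax.vmap maps over the positional key and state. A standalone sketch of that multi-chain pattern, using the new `initial_position` path (illustrative only; the log-density, algorithm, and chain/step counts are arbitrary):

import jax
import jax.numpy as jnp

import blackjax
from blackjax.util import run_inference_algorithm

logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))
algorithm = blackjax.mala(logdensity_fn, 1e-2)

num_chains = 4
init_key, run_key = jax.random.split(jax.random.PRNGKey(0))

# One initial position and one PRNG key per chain.
initial_positions = jax.random.normal(init_key, (num_chains, 2))
chain_keys = jax.random.split(run_key, num_chains)

_, states, infos = jax.vmap(
    lambda key, position: run_inference_algorithm(
        rng_key=key,
        initial_position=position,
        inference_algorithm=algorithm,
        num_steps=100,
    )
)(chain_keys, initial_positions)

# states.position has shape (num_chains, num_steps, 2).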
diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py
index e4ac5978d..cccd34c98 100644
--- a/tests/mcmc/test_sampling.py
+++ b/tests/mcmc/test_sampling.py
@@ -126,7 +126,7 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key):
 
         _, samples, _ = run_inference_algorithm(
             rng_key=run_key,
-            initial_state_or_position=blackjax_state_after_tuning,
+            initial_state=blackjax_state_after_tuning,
             inference_algorithm=sampling_alg,
             num_steps=num_steps,
             transform=lambda x: x.position,
@@ -187,7 +187,10 @@ def check_attrs(attribute, keyset):
             check_attrs(attribute, keysets[i])
 
         _, states, _ = run_inference_algorithm(
-            inference_key, state, inference_algorithm, case["num_sampling_steps"]
+            rng_key=inference_key,
+            initial_state=state,
+            inference_algorithm=inference_algorithm,
+            num_steps=case["num_sampling_steps"],
         )
 
         coefs_samples = states.position["coefs"]
@@ -209,7 +212,12 @@ def test_mala(self):
         mala = blackjax.mala(logposterior_fn, 1e-5)
         state = mala.init({"coefs": 1.0, "log_scale": 1.0})
 
-        _, states, _ = run_inference_algorithm(inference_key, state, mala, 10_000)
+        _, states, _ = run_inference_algorithm(
+            rng_key=inference_key,
+            initial_state=state,
+            inference_algorithm=mala,
+            num_steps=10_000,
+        )
 
         coefs_samples = states.position["coefs"][3000:]
         scale_samples = np.exp(states.position["log_scale"][3000:])
@@ -275,7 +283,10 @@ def test_pathfinder_adaptation(
         inference_algorithm = algorithm(logposterior_fn, **parameters)
 
         _, states, _ = run_inference_algorithm(
-            inference_key, state, inference_algorithm, num_sampling_steps
+            rng_key=inference_key,
+            initial_state=state,
+            inference_algorithm=inference_algorithm,
+            num_steps=num_sampling_steps,
         )
 
         coefs_samples = states.position["coefs"]
@@ -316,7 +327,10 @@ def test_meads(self):
         chain_keys = jax.random.split(inference_key, num_chains)
         _, states, _ = jax.vmap(
             lambda key, state: run_inference_algorithm(
-                key, state, inference_algorithm, 100
+                rng_key=key,
+                initial_state=state,
+                inference_algorithm=inference_algorithm,
+                num_steps=100,
             )
         )(chain_keys, last_states)
 
@@ -360,7 +374,10 @@ def test_chees(self, jitter_generator):
         chain_keys = jax.random.split(inference_key, num_chains)
         _, states, _ = jax.vmap(
             lambda key, state: run_inference_algorithm(
-                key, state, inference_algorithm, 100
+                rng_key=key,
+                initial_state=state,
+                inference_algorithm=inference_algorithm,
+                num_steps=100,
             )
         )(chain_keys, last_states)
 
@@ -384,7 +401,12 @@ def test_barker(self):
         barker = blackjax.barker_proposal(logposterior_fn, 1e-1)
         state = barker.init({"coefs": 1.0, "log_scale": 1.0})
 
-        _, states, _ = run_inference_algorithm(inference_key, state, barker, 10_000)
+        _, states, _ = run_inference_algorithm(
+            rng_key=inference_key,
+            initial_state=state,
+            inference_algorithm=barker,
+            num_steps=10_000,
+        )
 
         coefs_samples = states.position["coefs"][3000:]
         scale_samples = np.exp(states.position["log_scale"][3000:])
@@ -570,7 +592,7 @@ def test_latent_gaussian(self):
                 inference_algorithm=inference_algorithm,
                 num_steps=self.sampling_steps,
             ),
-        )(self.key, initial_state)
+        )(rng_key=self.key, initial_state=initial_state)
 
         np.testing.assert_allclose(
             np.var(states.position[self.burnin :]), 1 / (1 + 0.5), rtol=1e-2, atol=1e-2
@@ -614,7 +636,7 @@ def univariate_normal_test_case(
                 inference_algorithm=inference_algorithm,
                 num_steps=num_sampling_steps,
             )
-        )(inference_key, initial_state)
+        )(rng_key=inference_key, initial_state=initial_state)
 
         # else:
         if postprocess_samples:
@@ -885,7 +907,7 @@ def test_mcse(self, algorithm, parameters, is_mass_matrix_diagonal):
             )
         )
         _, states, _ = inference_loop_multiple_chains(
-            multi_chain_sample_key, initial_states
+            rng_key=multi_chain_sample_key, initial_state=initial_states
         )
 
         posterior_samples = states.position[:, -1000:]
diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py
index d8f09cea0..c2295e7e2 100644
--- a/tests/test_benchmarks.py
+++ b/tests/test_benchmarks.py
@@ -49,7 +49,10 @@ def run_regression(algorithm, **parameters):
     inference_algorithm = algorithm(logdensity_fn, **parameters)
 
     _, states, _ = run_inference_algorithm(
-        inference_key, state, inference_algorithm, 10_000
+        rng_key=inference_key,
+        initial_state=state,
+        inference_algorithm=inference_algorithm,
+        num_steps=10_000,
    )
 
     return states
diff --git a/tests/test_util.py b/tests/test_util.py
index a6e023074..1f03498dd 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -19,23 +19,70 @@ def setUp(self):
         )
         self.num_steps = 10
 
-    def check_compatible(self, initial_state_or_position, progress_bar):
+    def check_compatible(self, initial_state, progress_bar):
         """
         Runs 10 steps with `run_inference_algorithm` starting with
-        `initial_state_or_position` and potentially a progress bar.
+        `initial_state` and potentially a progress bar.
""" _ = run_inference_algorithm( - self.key, - initial_state_or_position, - self.algorithm, - self.num_steps, - progress_bar, + rng_key=self.key, + initial_state=initial_state, + inference_algorithm=self.algorithm, + num_steps=self.num_steps, + progress_bar=progress_bar, transform=lambda x: x.position, ) + def test_streaming(self): + def logdensity_fn(x): + return -0.5 * jnp.sum(jnp.square(x)) + + initial_position = jnp.ones( + 10, + ) + + init_key, run_key = jax.random.split(self.key, 2) + + initial_state = blackjax.mcmc.mclmc.init( + position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key + ) + + alg = blackjax.mclmc(logdensity_fn=logdensity_fn, L=0.5, step_size=0.1) + + _, states, info = run_inference_algorithm( + rng_key=run_key, + initial_state=initial_state, + inference_algorithm=alg, + num_steps=50, + progress_bar=False, + expectation=lambda x: x, + transform=lambda x: x.position, + return_state_history=True, + ) + + average, _ = run_inference_algorithm( + rng_key=run_key, + initial_state=initial_state, + inference_algorithm=alg, + num_steps=50, + progress_bar=False, + expectation=lambda x: x, + transform=lambda x: x.position, + return_state_history=False, + ) + + assert jnp.allclose(states.mean(axis=0), average) + @parameterized.parameters([True, False]) def test_compatible_with_initial_pos(self, progress_bar): - self.check_compatible(jnp.array([1.0, 1.0]), progress_bar) + _ = run_inference_algorithm( + rng_key=self.key, + initial_position=jnp.array([1.0, 1.0]), + inference_algorithm=self.algorithm, + num_steps=self.num_steps, + progress_bar=progress_bar, + transform=lambda x: x.position, + ) @parameterized.parameters([True, False]) def test_compatible_with_initial_state(self, progress_bar):