diff --git a/_sources/algorithms/mclmc.md b/_sources/algorithms/mclmc.md
index 09d7e9b..77889ea 100644
--- a/_sources/algorithms/mclmc.md
+++ b/_sources/algorithms/mclmc.md
@@ -44,7 +44,7 @@ MCLMC in Blackjax comes with a tuning algorithm which attempts to find optimal v
 
 An example is given below, of tuning and running a chain for a 1000 dimensional Gaussian target (of which a 2 dimensional marginal is plotted):
 
-```{code-cell} ipython3
+```{code-cell}
 :tags: [hide-cell]
 
 import matplotlib.pyplot as plt
@@ -66,7 +66,7 @@ from numpyro.infer.util import initialize_model
 
 rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))
 ```
 
-```{code-cell} ipython3
+```{code-cell}
 def run_mclmc(logdensity_fn, num_steps, initial_position, key, transform, desired_energy_variance= 5e-4):
     init_key, tune_key, run_key = jax.random.split(key, 3)
@@ -115,7 +115,7 @@ def run_mclmc(logdensity_fn, num_steps, initial_position, key, transform, desire
     return samples, blackjax_state_after_tuning, blackjax_mclmc_sampler_params, run_key
 ```
 
-```{code-cell} ipython3
+```{code-cell}
 # run the algorithm on a high dimensional gaussian, and show two of the dimensions
 
 logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))
@@ -134,13 +134,13 @@ samples, initial_state, params, chain_key = run_mclmc(
 samples.mean()
 ```
 
-```{code-cell} ipython3
+```{code-cell}
 plt.scatter(x=samples[:, 0], y=samples[:, 1], alpha=0.1)
 plt.axis("equal")
 plt.title("Scatter Plot of Samples")
 ```
 
-```{code-cell} ipython3
+```{code-cell}
 def visualize_results_gauss(samples, label, color):
     x1 = samples[:, 0]
     plt.hist(x1, bins= 30, density= True, histtype= 'step', lw= 4, color= color, label= label)
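A note on the object the tuner hands back: in the cells above, `run_mclmc` returns the tuned hyperparameters as `params`, which behaves like a NamedTuple, so individual fields can be read or overridden without re-running the tuning stage; the sanity check below relies on exactly this. A minimal sketch, assuming `params` is the tuple returned by the `run_mclmc` call above:

```python
# `params` holds the tuned MCLMC hyperparameters as a NamedTuple.
print(params.L, params.step_size)  # decoherence scale L and integrator step size

# Copy with one field overridden, leaving the tuned L untouched.
halved = params._replace(step_size=params.step_size / 2)
```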
@@ -165,12 +165,12 @@ ground_truth_gauss()
 
 A natural sanity check is to see if reducing $\epsilon$ changes the inferred distribution to an extent you care about. For example, we can inspect the 1D marginal with a stepsize $\epsilon$ as above, and compare it to a stepsize $\epsilon/2$ (and double the number of steps). We show this comparison below:
 
-```{code-cell} ipython3
+```{code-cell}
 new_params = params._replace(step_size= params.step_size / 2)
 new_num_steps = num_steps * 2
 ```
 
-```{code-cell} ipython3
+```{code-cell}
 sampling_alg = blackjax.mclmc(
     logdensity_fn,
     L=new_params.L,
@@ -211,7 +211,7 @@ Our task is to find the posterior of the parameters $\{R_n\}_{n =1}^N$, $\sigma$
 
 First, we get the data, define a model using NumPyro, and draw samples:
 
-```{code-cell} ipython3
+```{code-cell}
 import matplotlib.dates as mdates
 from numpyro.examples.datasets import SP500, load_dataset
 from numpyro.distributions import StudentT
@@ -243,7 +243,7 @@ def setup():
 setup()
 ```
 
-```{code-cell} ipython3
+```{code-cell}
 def from_numpyro(model, rng_key, model_args):
     init_params, potential_fn_gen, *_ = initialize_model(
         rng_key,
@@ -272,13 +272,13 @@ rng_key = jax.random.key(42)
 logp_sv, x_init = from_numpyro(stochastic_volatility, rng_key, model_args)
 ```
 
-```{code-cell} ipython3
+```{code-cell}
 num_steps = 20000
 
 samples, initial_state, params, chain_key = run_mclmc(logdensity_fn= logp_sv, num_steps= num_steps, initial_position= x_init, key= sample_key, transform=lambda state, info: state.position)
 ```
 
-```{code-cell} ipython3
+```{code-cell}
 def visualize_results_sv(samples, color, label):
     R = np.exp(np.array(samples['s'])) # take an exponent to get R
@@ -297,7 +297,7 @@
 plt.legend()
 plt.show()
 ```
 
-```{code-cell} ipython3
+```{code-cell}
 new_params = params._replace(step_size = params.step_size/2)
 new_num_steps = num_steps * 2
@@ -318,10 +318,9 @@ _, new_samples = blackjax.util.run_inference_algorithm(
     transform=lambda state, info : state.position,
     progress_bar=True,
 )
-
 ```
 
-```{code-cell} ipython3
+```{code-cell}
 setup()
 visualize_results_sv(new_samples,'red', 'MCLMC (stepsize/2)', )
 visualize_results_sv(samples,'teal', 'MCLMC', )
@@ -332,7 +331,7 @@
 
 Here, we have again inspected the effect of halving $\epsilon$. This looks OK, but suppose we are interested in the hierarchical parameters in particular, which tend to be harder to infer. We now inspect the marginal of a hierarchical parameter:
 
-```{code-cell} ipython3
+```{code-cell}
 def visualize_results_sv_marginal(samples, color, label):
     # plt.subplot(1, 2, 1)
     # plt.hist(samples['nu'], bins = 20, histtype= 'step', lw= 4, density= True, color= color, label= label)
@@ -354,9 +353,131 @@ If we care about this parameter in particular, we should reduce step size furthe
 
 +++
 
+## Adjusted MCLMC
+
+Blackjax also provides an adjusted version of the algorithm. This also has two hyperparameters, `step_size` and `L`. `L` is related to the `L` parameter of the unadjusted version, but not identical. The tuning algorithm is also similar, but uses a dual averaging scheme to tune the step size. We find in practice that a target MH acceptance rate of 0.9 is a good choice.
+
+```{code-cell}
+from blackjax.mcmc.adjusted_mclmc import rescale
+from blackjax.util import run_inference_algorithm
+
+def run_adjusted_mclmc(
+    logdensity_fn,
+    num_steps,
+    initial_position,
+    key,
+    transform=lambda state, _ : state.position,
+    diagonal_preconditioning=False,
+    random_trajectory_length=True,
+    L_proposal_factor=jnp.inf
+):
+
+    init_key, tune_key, run_key = jax.random.split(key, 3)
+
+    initial_state = blackjax.mcmc.adjusted_mclmc.init(
+        position=initial_position,
+        logdensity_fn=logdensity_fn,
+        random_generator_arg=init_key,
+    )
+
+    if random_trajectory_length:
+        integration_steps_fn = lambda avg_num_integration_steps: lambda k: jnp.ceil(
+            jax.random.uniform(k) * rescale(avg_num_integration_steps))
+    else:
+        integration_steps_fn = lambda avg_num_integration_steps: lambda _: jnp.ceil(avg_num_integration_steps)
+
+    kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc.build_kernel(
+        integration_steps_fn=integration_steps_fn(avg_num_integration_steps),
+        sqrt_diag_cov=sqrt_diag_cov,
+    )(
+        rng_key=rng_key,
+        state=state,
+        step_size=step_size,
+        logdensity_fn=logdensity_fn,
+        L_proposal_factor=L_proposal_factor,
+    )
+
+    target_acc_rate = 0.9 # our recommendation
+
+    (
+        blackjax_state_after_tuning,
+        blackjax_mclmc_sampler_params,
+    ) = blackjax.adjusted_mclmc_find_L_and_step_size(
+        mclmc_kernel=kernel,
+        num_steps=num_steps,
+        state=initial_state,
+        rng_key=tune_key,
+        target=target_acc_rate,
+        frac_tune1=0.1,
+        frac_tune2=0.1,
+        frac_tune3=0.0, # our recommendation
+        diagonal_preconditioning=diagonal_preconditioning,
+    )
+
+    step_size = blackjax_mclmc_sampler_params.step_size
+    L = blackjax_mclmc_sampler_params.L
+
+    alg = blackjax.adjusted_mclmc(
+        logdensity_fn=logdensity_fn,
+        step_size=step_size,
+        integration_steps_fn=lambda key: jnp.ceil(
+            jax.random.uniform(key) * rescale(L / step_size)
+        ),
+        sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov,
+        L_proposal_factor=L_proposal_factor,
+    )
+
+    _, out = run_inference_algorithm(
+        rng_key=run_key,
+        initial_state=blackjax_state_after_tuning,
+        inference_algorithm=alg,
+        num_steps=num_steps,
+        transform=transform,
+        progress_bar=False,
+    )
+
+    return out
+```
+
+```{code-cell}
+# run the algorithm on a high dimensional gaussian, and show two of the dimensions
+
+sample_key, rng_key = jax.random.split(rng_key)
+samples = run_adjusted_mclmc(
+    logdensity_fn=lambda x: -0.5 * jnp.sum(jnp.square(x)),
+    num_steps=1000,
+    initial_position=jnp.ones((1000,)),
+    key=sample_key,
+)
+plt.scatter(x=samples[:, 0], y=samples[:, 1], alpha=0.1)
+plt.axis("equal")
+plt.title("Scatter Plot of Samples")
+```
+
+```{code-cell}
+# run the algorithm on a high dimensional gaussian, and show two of the dimensions
+
+sample_key, rng_key = jax.random.split(rng_key)
+samples = run_adjusted_mclmc(
+    logdensity_fn=lambda x: -0.5 * jnp.sum(jnp.square(x)),
+    num_steps=1000,
+    initial_position=jnp.ones((1000,)),
+    key=sample_key,
+    random_trajectory_length=False,
+    L_proposal_factor=1.25,
+)
+plt.scatter(x=samples[:, 0], y=samples[:, 1], alpha=0.1)
+plt.axis("equal")
+plt.title("Scatter Plot of Samples")
+```
+
 ```{bibliography}
 :filter: docname in docnames
 ```
diff --git a/algorithms/mclmc.html b/algorithms/mclmc.html
index b053ee6..2916e34 100644
--- a/algorithms/mclmc.html
+++ b/algorithms/mclmc.html
@@ -386,6 +386,7 @@
/opt/hostedtoolcache/Python/3.11.10/x64/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
+/opt/hostedtoolcache/Python/3.11.11/x64/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
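(The `TqdmWarning` above is cosmetic: as the message itself says, it disappears once jupyter and ipywidgets are current in the build environment. If upgrading is not an option, the warning can also be filtered explicitly; a small sketch:)

```python
# Either: pip install --upgrade jupyter ipywidgets
# Or silence the cosmetic IProgress warning in the build environment:
import warnings
from tqdm import TqdmWarning

warnings.filterwarnings("ignore", category=TqdmWarning)
```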
@@ -557,7 +558,7 @@ How to run MCLMC in BlackJax
-Array(0.00264662, dtype=float32)
+Array(0.00752303, dtype=float32)
Text(0.5, 1.0, 'Scatter Plot of Samples')
So the change has little effect in this case.
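To put a number on "little effect", one can also run a two-sample test on the marginal of interest. A rough sketch with scipy (an extra dependency); here `samples` and `new_samples` stand for the draws at step size $\epsilon$ and $\epsilon/2$, so adapt the names to the variables in your notebook:

```python
import numpy as np
from scipy.stats import ks_2samp

# Kolmogorov-Smirnov test on the first marginal of the two runs. MCMC draws
# are autocorrelated, so thin the chains and read the p-value as a rough
# guide rather than an exact test.
stat, pvalue = ks_2samp(np.array(samples[::10, 0]), np.array(new_samples[::10, 0]))
print(f"KS statistic: {stat:.3f}, p-value: {pvalue:.3f}")
```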
@@ -671,8 +672,8 @@ We now consider a more complex model, of stock volatility.
The returns \(r_n\) are modeled by a Student’s-t distribution whose scale (volatility) \(R_n\) is time varying and unknown. The prior for \(\log R_n\) is a Gaussian random walk, with an exponential distribution of the random walk step-size \(\sigma\). An exponential prior is also taken for the Student’s-t degrees of freedom \(\nu\). The generative process of the data is:
-Downloading - https://d2hg8soec8ck9v.cloudfront.net/datasets/SP500.csv.
+Download complete.
Download complete.
-
Here, we have again inspected the effect of halving \(\epsilon\). This looks OK, but suppose we are interested in the hierarchical parameters in particular, which tend to be harder to infer. We now inspect the marginal of a hierarchical parameter:
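For concreteness, the generative process described above translates to NumPyro along the following lines. This is an illustrative sketch, not the exact `stochastic_volatility` model used in these cells; in particular the exponential rate parameters (50 and 0.1) are placeholder choices:

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def stochastic_volatility_sketch(returns):
    # Exponential priors on the random-walk step size and the Student's-t dof.
    sigma = numpyro.sample("sigma", dist.Exponential(50.0))
    nu = numpyro.sample("nu", dist.Exponential(0.1))
    # log R_n follows a Gaussian random walk with step size sigma.
    s = numpyro.sample("s", dist.GaussianRandomWalk(scale=sigma, num_steps=returns.shape[0]))
    # Returns are Student's-t with time-varying scale R_n = exp(s_n).
    numpyro.sample("r", dist.StudentT(df=nu, loc=0.0, scale=jnp.exp(s)), obs=returns)
```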
@@ -895,10 +894,149 @@ If we care about this parameter in particular, we should reduce step size further, until the difference disappears.
+
+Adjusted MCLMC
+
+Blackjax also provides an adjusted version of the algorithm. This also has two hyperparameters, `step_size` and `L`. `L` is related to the `L` parameter of the unadjusted version, but not identical. The tuning algorithm is also similar, but uses a dual averaging scheme to tune the step size. We find in practice that a target MH acceptance rate of 0.9 is a good choice.
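The step-size adaptation can be caricatured in a few lines: grow the step size whenever the observed acceptance rate sits above the target, shrink it otherwise. The sketch below is plain stochastic approximation rather than Blackjax's actual dual-averaging code (which additionally averages the iterates); it is only meant to show the feedback loop being described:

```python
import jax.numpy as jnp

def adapt_step_size(step_size, acc_rate, target=0.9, eta=0.05):
    # Multiplicative update on the step size: acceptance above the target
    # means the integrator is being conservative, so the step grows.
    return step_size * jnp.exp(eta * (acc_rate - target))

eps = 0.5
for acc in (0.99, 0.95, 0.70, 0.88, 0.91):  # made-up acceptance rates
    eps = adapt_step_size(eps, acc)
print(eps)
```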
+from blackjax.mcmc.adjusted_mclmc import rescale
+from blackjax.util import run_inference_algorithm
+
+def run_adjusted_mclmc(
+ logdensity_fn,
+ num_steps,
+ initial_position,
+ key,
+ transform=lambda state, _ : state.position,
+ diagonal_preconditioning=False,
+ random_trajectory_length=True,
+ L_proposal_factor=jnp.inf
+):
+
+ init_key, tune_key, run_key = jax.random.split(key, 3)
+
+ initial_state = blackjax.mcmc.adjusted_mclmc.init(
+ position=initial_position,
+ logdensity_fn=logdensity_fn,
+ random_generator_arg=init_key,
+ )
+
+ if random_trajectory_length:
+ integration_steps_fn = lambda avg_num_integration_steps: lambda k: jnp.ceil(
+ jax.random.uniform(k) * rescale(avg_num_integration_steps))
+ else:
+ integration_steps_fn = lambda avg_num_integration_steps: lambda _: jnp.ceil(avg_num_integration_steps)
+
+ kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc.build_kernel(
+ integration_steps_fn=integration_steps_fn(avg_num_integration_steps),
+ sqrt_diag_cov=sqrt_diag_cov,
+ )(
+ rng_key=rng_key,
+ state=state,
+ step_size=step_size,
+ logdensity_fn=logdensity_fn,
+ L_proposal_factor=L_proposal_factor,
+ )
+
+ target_acc_rate = 0.9 # our recommendation
+
+ (
+ blackjax_state_after_tuning,
+ blackjax_mclmc_sampler_params,
+ ) = blackjax.adjusted_mclmc_find_L_and_step_size(
+ mclmc_kernel=kernel,
+ num_steps=num_steps,
+ state=initial_state,
+ rng_key=tune_key,
+ target=target_acc_rate,
+ frac_tune1=0.1,
+ frac_tune2=0.1,
+ frac_tune3=0.0, # our recommendation
+ diagonal_preconditioning=diagonal_preconditioning,
+ )
+
+ step_size = blackjax_mclmc_sampler_params.step_size
+ L = blackjax_mclmc_sampler_params.L
+
+ alg = blackjax.adjusted_mclmc(
+ logdensity_fn=logdensity_fn,
+ step_size=step_size,
+ integration_steps_fn=lambda key: jnp.ceil(
+ jax.random.uniform(key) * rescale(L / step_size)
+ ),
+ sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov,
+ L_proposal_factor=L_proposal_factor,
+ )
+
+ _, out = run_inference_algorithm(
+ rng_key=run_key,
+ initial_state=blackjax_state_after_tuning,
+ inference_algorithm=alg,
+ num_steps=num_steps,
+ transform=transform,
+ progress_bar=False,
+ )
+
+ return out
+
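Two details of `run_adjusted_mclmc` deserve a comment. First, drawing the number of integration steps as `ceil(U * rescale(avg))` with `U ~ Uniform(0, 1)` randomizes the trajectory length, a standard trick to avoid the periodicity pathologies of a fixed-length integrator. Second, the role we are assuming for `rescale` is to compensate the upward bias of the `ceil`/uniform draw so that the average number of steps comes out as requested; a quick empirical check of that assumption:

```python
import jax
import jax.numpy as jnp
from blackjax.mcmc.adjusted_mclmc import rescale

avg_steps = 7.3  # hypothetical average number of integration steps
keys = jax.random.split(jax.random.key(0), 10_000)
draws = jax.vmap(lambda k: jnp.ceil(jax.random.uniform(k) * rescale(avg_steps)))(keys)
print(draws.mean())  # should sit close to avg_steps if the assumption holds
```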
+# run the algorithm on a high dimensional gaussian, and show two of the dimensions
+
+sample_key, rng_key = jax.random.split(rng_key)
+samples = run_adjusted_mclmc(
+ logdensity_fn=lambda x: -0.5 * jnp.sum(jnp.square(x)),
+ num_steps=1000,
+ initial_position=jnp.ones((1000,)),
+ key=sample_key,
+)
+plt.scatter(x=samples[:, 0], y=samples[:, 1], alpha=0.1)
+plt.axis("equal")
+plt.title("Scatter Plot of Samples")
+
+Text(0.5, 1.0, 'Scatter Plot of Samples')
+
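Since the target here is a standard Gaussian, a cheap correctness probe is to look at the first two moments of the draws from the cell above (up to Monte Carlo error, the mean should be near 0 and the per-coordinate variance near 1):

```python
print(samples.mean())              # near 0 for a standard Gaussian target
print(samples.var(axis=0).mean())  # near 1, averaged over coordinates
```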
+# run the algorithm on a high dimensional gaussian, and show two of the dimensions
+
+sample_key, rng_key = jax.random.split(rng_key)
+samples = run_adjusted_mclmc(
+ logdensity_fn=lambda x: -0.5 * jnp.sum(jnp.square(x)),
+ num_steps=1000,
+ initial_position=jnp.ones((1000,)),
+ key=sample_key,
+ random_trajectory_length=False,
+ L_proposal_factor=1.25,
+)
+plt.scatter(x=samples[:, 0], y=samples[:, 1], alpha=0.1)
+plt.axis("equal")
+plt.title("Scatter Plot of Samples")
+
+Text(0.5, 1.0, 'Scatter Plot of Samples')
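When weighing this fixed-length, finite `L_proposal_factor` variant against the default above, scatter plots only go so far; an effective-sample-size comparison is a more honest yardstick. A sketch using Blackjax's diagnostics, assuming `samples` of shape `(num_steps, dim)` from either of the two cells:

```python
from blackjax.diagnostics import effective_sample_size

# The diagnostic expects a leading chain axis, so add a singleton one.
ess = effective_sample_size(samples[None, ...])
print(ess[0])  # effective sample size of the first coordinate
```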