Adjusted mclmc (#61)
* WIP

* Updated run_inference_algorithm

* UPDATE EXAMPLE

* mams

* adjusted
reubenharry authored Jan 3, 2025
1 parent d053e23 commit 960dea0
Showing 1 changed file with 136 additions and 15 deletions.
book/algorithms/mclmc.md (151 changes: 136 additions & 15 deletions)
@@ -44,7 +44,7 @@ MCLMC in Blackjax comes with a tuning algorithm which attempts to find optimal v

An example of tuning and running a chain for a 1000-dimensional Gaussian target (of which a 2-dimensional marginal is plotted) is given below:

-```{code-cell} ipython3
+```{code-cell}
:tags: [hide-cell]
import matplotlib.pyplot as plt
@@ -66,7 +66,7 @@ from numpyro.infer.util import initialize_model
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))
```

-```{code-cell} ipython3
+```{code-cell}
def run_mclmc(logdensity_fn, num_steps, initial_position, key, transform, desired_energy_variance= 5e-4):
    init_key, tune_key, run_key = jax.random.split(key, 3)
@@ -115,7 +115,7 @@ def run_mclmc(logdensity_fn, num_steps, initial_position, key, transform, desire
    return samples, blackjax_state_after_tuning, blackjax_mclmc_sampler_params, run_key
```

-```{code-cell} ipython3
+```{code-cell}
# run the algorithm on a high dimensional gaussian, and show two of the dimensions
logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))
@@ -134,13 +134,13 @@ samples, initial_state, params, chain_key = run_mclmc(
samples.mean()
```

-```{code-cell} ipython3
+```{code-cell}
plt.scatter(x=samples[:, 0], y=samples[:, 1], alpha=0.1)
plt.axis("equal")
plt.title("Scatter Plot of Samples")
```

-```{code-cell} ipython3
+```{code-cell}
def visualize_results_gauss(samples, label, color):
    x1 = samples[:, 0]
    plt.hist(x1, bins= 30, density= True, histtype= 'step', lw= 4, color= color, label= label)
@@ -165,12 +165,12 @@ ground_truth_gauss()

A natural sanity check is to see whether reducing $\epsilon$ changes the inferred distribution to an extent you care about. For example, we can inspect the 1D marginal obtained with the stepsize $\epsilon$ found above, and compare it to the result with stepsize $\epsilon/2$ (and double the number of steps). We show this comparison below:

-```{code-cell} ipython3
+```{code-cell}
new_params = params._replace(step_size= params.step_size / 2)
new_num_steps = num_steps * 2
```

-```{code-cell} ipython3
+```{code-cell}
sampling_alg = blackjax.mclmc(
    logdensity_fn,
    L=new_params.L,
@@ -211,7 +211,7 @@ Our task is to find the posterior of the parameters $\{R_n\}_{n=1}^N$, $\sigma$

First, we get the data, define a model using NumPyro, and draw samples:

-```{code-cell} ipython3
+```{code-cell}
import matplotlib.dates as mdates
from numpyro.examples.datasets import SP500, load_dataset
from numpyro.distributions import StudentT
@@ -243,7 +243,7 @@ def setup():
setup()
```

-```{code-cell} ipython3
+```{code-cell}
def from_numpyro(model, rng_key, model_args):
    init_params, potential_fn_gen, *_ = initialize_model(
        rng_key,
@@ -272,13 +272,13 @@ rng_key = jax.random.key(42)
logp_sv, x_init = from_numpyro(stochastic_volatility, rng_key, model_args)
```

-```{code-cell} ipython3
+```{code-cell}
num_steps = 20000
samples, initial_state, params, chain_key = run_mclmc(logdensity_fn= logp_sv, num_steps= num_steps, initial_position= x_init, key= sample_key, transform=lambda state, info: state.position)
```

-```{code-cell} ipython3
+```{code-cell}
def visualize_results_sv(samples, color, label):
    R = np.exp(np.array(samples['s'])) # take an exponent to get R
@@ ... @@
plt.show()
```

-```{code-cell} ipython3
+```{code-cell}
new_params = params._replace(step_size = params.step_size/2)
new_num_steps = num_steps * 2
@@ ... @@ _, new_samples = blackjax.util.run_inference_algorithm(
    transform=lambda state, info : state.position,
    progress_bar=True,
)
```

-```{code-cell} ipython3
+```{code-cell}
setup()
visualize_results_sv(new_samples, 'red', 'MCLMC (stepsize/2)')
visualize_results_sv(samples, 'teal', 'MCLMC')
@@ ... @@

Here, we have again inspected the effect of halving $\epsilon$. This looks OK, but suppose we are particularly interested in the hierarchical parameters, which tend to be harder to infer. We now inspect the marginal of one such parameter:

-```{code-cell} ipython3
+```{code-cell}
def visualize_results_sv_marginal(samples, color, label):
    # plt.subplot(1, 2, 1)
    # plt.hist(samples['nu'], bins = 20, histtype= 'step', lw= 4, density= True, color= color, label= label)
@@ ... @@
If we care about this parameter in particular, we should reduce step size further.
+++

## Adjusted MCLMC

Blackjax also provides an adjusted (Metropolis-adjusted) version of the algorithm. It too has two hyperparameters, `step_size` and `L`, where `L` is related to, but not identical to, the `L` parameter of the unadjusted version. The tuning algorithm is similar as well, but uses a dual averaging scheme to tune the step size; in practice we find that a target MH acceptance rate of 0.9 is a good choice.

```{code-cell}
from blackjax.mcmc.adjusted_mclmc import rescale
from blackjax.util import run_inference_algorithm
def run_adjusted_mclmc(
    logdensity_fn,
    num_steps,
    initial_position,
    key,
    transform=lambda state, _: state.position,
    diagonal_preconditioning=False,
    random_trajectory_length=True,
    L_proposal_factor=jnp.inf,
):
    init_key, tune_key, run_key = jax.random.split(key, 3)

    initial_state = blackjax.mcmc.adjusted_mclmc.init(
        position=initial_position,
        logdensity_fn=logdensity_fn,
        random_generator_arg=init_key,
    )

    # number of integration steps per proposal: either random with the given
    # average, or fixed at the average
    if random_trajectory_length:
        integration_steps_fn = lambda avg_num_integration_steps: lambda k: jnp.ceil(
            jax.random.uniform(k) * rescale(avg_num_integration_steps)
        )
    else:
        integration_steps_fn = lambda avg_num_integration_steps: lambda _: jnp.ceil(
            avg_num_integration_steps
        )

    kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc.build_kernel(
        integration_steps_fn=integration_steps_fn(avg_num_integration_steps),
        sqrt_diag_cov=sqrt_diag_cov,
    )(
        rng_key=rng_key,
        state=state,
        step_size=step_size,
        logdensity_fn=logdensity_fn,
        L_proposal_factor=L_proposal_factor,
    )

    target_acc_rate = 0.9  # our recommendation

    # tune L and step_size (the step size via dual averaging towards the target
    # acceptance rate)
    (
        blackjax_state_after_tuning,
        blackjax_mclmc_sampler_params,
    ) = blackjax.adjusted_mclmc_find_L_and_step_size(
        mclmc_kernel=kernel,
        num_steps=num_steps,
        state=initial_state,
        rng_key=tune_key,
        target=target_acc_rate,
        frac_tune1=0.1,
        frac_tune2=0.1,
        frac_tune3=0.0,  # our recommendation
        diagonal_preconditioning=diagonal_preconditioning,
    )

    step_size = blackjax_mclmc_sampler_params.step_size
    L = blackjax_mclmc_sampler_params.L

    alg = blackjax.adjusted_mclmc(
        logdensity_fn=logdensity_fn,
        step_size=step_size,
        integration_steps_fn=lambda key: jnp.ceil(
            jax.random.uniform(key) * rescale(L / step_size)
        ),
        sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov,
        L_proposal_factor=L_proposal_factor,
    )

    # run the adjusted chain from the tuned state
    _, out = run_inference_algorithm(
        rng_key=run_key,
        initial_state=blackjax_state_after_tuning,
        inference_algorithm=alg,
        num_steps=num_steps,
        transform=transform,
        progress_bar=False,
    )

    return out
```

```{code-cell}
# run the algorithm on a high dimensional gaussian, and show two of the dimensions
sample_key, rng_key = jax.random.split(rng_key)
samples = run_adjusted_mclmc(
    logdensity_fn=lambda x: -0.5 * jnp.sum(jnp.square(x)),
    num_steps=1000,
    initial_position=jnp.ones((1000,)),
    key=sample_key,
)
plt.scatter(x=samples[:, 0], y=samples[:, 1], alpha=0.1)
plt.axis("equal")
plt.title("Scatter Plot of Samples")
```
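As a quick numerical sanity check (a sketch mirroring the `samples.mean()` check in the unadjusted example above): the target is a standard Gaussian, so the pooled mean should be close to 0 and the pooled variance close to 1.

```{code-cell}
# Sketch: pooled moments over all dimensions and draws of the adjusted chain.
samples.mean(), samples.var()
```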

```{code-cell}
# run the adjusted algorithm again, this time with a fixed trajectory length and L_proposal_factor = 1.25
sample_key, rng_key = jax.random.split(rng_key)
samples = run_adjusted_mclmc(
    logdensity_fn=lambda x: -0.5 * jnp.sum(jnp.square(x)),
    num_steps=1000,
    initial_position=jnp.ones((1000,)),
    key=sample_key,
    random_trajectory_length=False,
    L_proposal_factor=1.25,
)
plt.scatter(x=samples[:, 0], y=samples[:, 1], alpha=0.1)
plt.axis("equal")
plt.title("Scatter Plot of Samples")
```
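As with the unadjusted sampler, we can also compare a 1D marginal of these samples against the known target. The sketch below reuses the plotting style of `visualize_results_gauss` from earlier and overlays the exact standard normal density; it assumes `np` (NumPy) is available from the hidden setup cell.

```{code-cell}
# Sketch: empirical marginal of the first coordinate vs. the exact N(0, 1)
# density of the standard Gaussian target.
x1 = np.array(samples[:, 0])
plt.hist(x1, bins=30, density=True, histtype='step', lw=4, color='teal', label='adjusted MCLMC')

t = np.linspace(-4, 4, 200)
plt.plot(t, np.exp(-0.5 * t**2) / np.sqrt(2 * np.pi), color='black', label='exact N(0, 1)')

plt.xlabel('$x_1$')
plt.legend()
plt.show()
```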

```{bibliography}
:filter: docname in docnames
```


