Skip to content

Commit

Permalink
Deploying to gh-pages from @ 960dea0 🚀
Browse files Browse the repository at this point in the history
  • Loading branch information
junpenglao committed Jan 3, 2025
1 parent 49bdb75 commit 74bc35d
Show file tree
Hide file tree
Showing 18 changed files with 292 additions and 30 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
151 changes: 136 additions & 15 deletions _sources/algorithms/mclmc.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ MCLMC in Blackjax comes with a tuning algorithm which attempts to find optimal v

An example is given below, of tuning and running a chain for a 1000 dimensional Gaussian target (of which a 2 dimensional marginal is plotted):

```{code-cell} ipython3
```{code-cell}
:tags: [hide-cell]
import matplotlib.pyplot as plt
Expand All @@ -66,7 +66,7 @@ from numpyro.infer.util import initialize_model
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))
```

```{code-cell} ipython3
```{code-cell}
def run_mclmc(logdensity_fn, num_steps, initial_position, key, transform, desired_energy_variance= 5e-4):
init_key, tune_key, run_key = jax.random.split(key, 3)
Expand Down Expand Up @@ -115,7 +115,7 @@ def run_mclmc(logdensity_fn, num_steps, initial_position, key, transform, desire
return samples, blackjax_state_after_tuning, blackjax_mclmc_sampler_params, run_key
```

```{code-cell} ipython3
```{code-cell}
# run the algorithm on a high dimensional gaussian, and show two of the dimensions
logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))
Expand All @@ -134,13 +134,13 @@ samples, initial_state, params, chain_key = run_mclmc(
samples.mean()
```

```{code-cell} ipython3
```{code-cell}
plt.scatter(x=samples[:, 0], y=samples[:, 1], alpha=0.1)
plt.axis("equal")
plt.title("Scatter Plot of Samples")
```

```{code-cell} ipython3
```{code-cell}
def visualize_results_gauss(samples, label, color):
x1 = samples[:, 0]
plt.hist(x1, bins= 30, density= True, histtype= 'step', lw= 4, color= color, label= label)
Expand All @@ -165,12 +165,12 @@ ground_truth_gauss()

A natural sanity check is to see if reducing $\epsilon$ changes the inferred distribution to an extent you care about. For example, we can inspect the 1D marginal with a stepsize $\epsilon$ as above, and compare it to a stepsize $\epsilon/2$ (and double the number of steps). We show this comparison below:

```{code-cell} ipython3
```{code-cell}
new_params = params._replace(step_size= params.step_size / 2)
new_num_steps = num_steps * 2
```

```{code-cell} ipython3
```{code-cell}
sampling_alg = blackjax.mclmc(
logdensity_fn,
L=new_params.L,
Expand Down Expand Up @@ -211,7 +211,7 @@ Our task is to find the posterior of the parameters $\{R_n\}_{n =1}^N$, $\sigma$

First, we get the data, define a model using NumPyro, and draw samples:

```{code-cell} ipython3
```{code-cell}
import matplotlib.dates as mdates
from numpyro.examples.datasets import SP500, load_dataset
from numpyro.distributions import StudentT
Expand Down Expand Up @@ -243,7 +243,7 @@ def setup():
setup()
```

```{code-cell} ipython3
```{code-cell}
def from_numpyro(model, rng_key, model_args):
init_params, potential_fn_gen, *_ = initialize_model(
rng_key,
Expand Down Expand Up @@ -272,13 +272,13 @@ rng_key = jax.random.key(42)
logp_sv, x_init = from_numpyro(stochastic_volatility, rng_key, model_args)
```

```{code-cell} ipython3
```{code-cell}
num_steps = 20000
samples, initial_state, params, chain_key = run_mclmc(logdensity_fn= logp_sv, num_steps= num_steps, initial_position= x_init, key= sample_key, transform=lambda state, info: state.position)
```

```{code-cell} ipython3
```{code-cell}
def visualize_results_sv(samples, color, label):
R = np.exp(np.array(samples['s'])) # take an exponent to get R
Expand All @@ -297,7 +297,7 @@ plt.legend()
plt.show()
```

```{code-cell} ipython3
```{code-cell}
new_params = params._replace(step_size = params.step_size/2)
new_num_steps = num_steps * 2
Expand All @@ -318,10 +318,9 @@ _, new_samples = blackjax.util.run_inference_algorithm(
transform=lambda state, info : state.position,
progress_bar=True,
)
```

```{code-cell} ipython3
```{code-cell}
setup()
visualize_results_sv(new_samples,'red', 'MCLMC', )
visualize_results_sv(samples,'teal', 'MCLMC (stepsize/2)', )
Expand All @@ -332,7 +331,7 @@ plt.show()

Here, we have again inspected the effect of halving $\epsilon$. This looks OK, but suppose we are interested in the hierarchial parameters in particular, which tend to be harder to infer. We now inspect the marginal of a hierarchical parameter:

```{code-cell} ipython3
```{code-cell}
def visualize_results_sv_marginal(samples, color, label):
# plt.subplot(1, 2, 1)
# plt.hist(samples['nu'], bins = 20, histtype= 'step', lw= 4, density= True, color= color, label= label)
Expand All @@ -354,9 +353,131 @@ If we care about this parameter in particular, we should reduce step size furthe

+++

## Adjusted MCLMC

Blackjax also provides an adjusted version of the algorithm. This also has two hyperparameters, `step_size` and `L`. `L` is related to the `L` parameter of the unadjusted version, but not identical. The tuning algorithm is also similar, but uses a dual averaging scheme to tune the step size. We find in practice that a target MH acceptance rate of 0.9 is a good choice.

```{code-cell}
from blackjax.mcmc.adjusted_mclmc import rescale
from blackjax.util import run_inference_algorithm
def run_adjusted_mclmc(
logdensity_fn,
num_steps,
initial_position,
key,
transform=lambda state, _ : state.position,
diagonal_preconditioning=False,
random_trajectory_length=True,
L_proposal_factor=jnp.inf
):
init_key, tune_key, run_key = jax.random.split(key, 3)
initial_state = blackjax.mcmc.adjusted_mclmc.init(
position=initial_position,
logdensity_fn=logdensity_fn,
random_generator_arg=init_key,
)
if random_trajectory_length:
integration_steps_fn = lambda avg_num_integration_steps: lambda k: jnp.ceil(
jax.random.uniform(k) * rescale(avg_num_integration_steps))
else:
integration_steps_fn = lambda avg_num_integration_steps: lambda _: jnp.ceil(avg_num_integration_steps)
kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc.build_kernel(
integration_steps_fn=integration_steps_fn(avg_num_integration_steps),
sqrt_diag_cov=sqrt_diag_cov,
)(
rng_key=rng_key,
state=state,
step_size=step_size,
logdensity_fn=logdensity_fn,
L_proposal_factor=L_proposal_factor,
)
target_acc_rate = 0.9 # our recommendation
(
blackjax_state_after_tuning,
blackjax_mclmc_sampler_params,
) = blackjax.adjusted_mclmc_find_L_and_step_size(
mclmc_kernel=kernel,
num_steps=num_steps,
state=initial_state,
rng_key=tune_key,
target=target_acc_rate,
frac_tune1=0.1,
frac_tune2=0.1,
frac_tune3=0.0, # our recommendation
diagonal_preconditioning=diagonal_preconditioning,
)
step_size = blackjax_mclmc_sampler_params.step_size
L = blackjax_mclmc_sampler_params.L
alg = blackjax.adjusted_mclmc(
logdensity_fn=logdensity_fn,
step_size=step_size,
integration_steps_fn=lambda key: jnp.ceil(
jax.random.uniform(key) * rescale(L / step_size)
),
sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov,
L_proposal_factor=L_proposal_factor,
)
_, out = run_inference_algorithm(
rng_key=run_key,
initial_state=blackjax_state_after_tuning,
inference_algorithm=alg,
num_steps=num_steps,
transform=transform,
progress_bar=False,
)
return out
```

```{code-cell}
# run the algorithm on a high dimensional gaussian, and show two of the dimensions
sample_key, rng_key = jax.random.split(rng_key)
samples = run_adjusted_mclmc(
logdensity_fn=lambda x: -0.5 * jnp.sum(jnp.square(x)),
num_steps=1000,
initial_position=jnp.ones((1000,)),
key=sample_key,
)
plt.scatter(x=samples[:, 0], y=samples[:, 1], alpha=0.1)
plt.axis("equal")
plt.title("Scatter Plot of Samples")
```

```{code-cell}
# run the algorithm on a high dimensional gaussian, and show two of the dimensions
sample_key, rng_key = jax.random.split(rng_key)
samples = run_adjusted_mclmc(
logdensity_fn=lambda x: -0.5 * jnp.sum(jnp.square(x)),
num_steps=1000,
initial_position=jnp.ones((1000,)),
key=sample_key,
random_trajectory_length=False,
L_proposal_factor=1.25,
)
plt.scatter(x=samples[:, 0], y=samples[:, 1], alpha=0.1)
plt.axis("equal")
plt.title("Scatter Plot of Samples")
```

```{bibliography}
:filter: docname in docnames
```


```
```{code-cell}
```
Loading

0 comments on commit 74bc35d

Please sign in to comment.