Add non-reversible parallel tempering #740
Comments
Thank you for the detailed write-up, much appreciated. And yes, a contribution will be very welcome! Regarding the design choice, I have not read your blog post in detail, but I was wondering whether you have compared your implementation with TFP's, which I am a bit more familiar with.
BTW, I am a huge fan of parallel tempering, very excited about this! Looking forward to your PR!
Same here: I had planned to do it at some point but haven't been able to commit the time :D Very happy someone is doing it! Design-wise, I actually don't think it would be a good idea to vmap everything at the lower level, in particular with a view to being able to do proper sharding. Once that's done, the choice of parallelism across the chains' states is very much user-driven, and it's hard to enforce a coherent interface supporting all JAX models.
IIUC, swapping kernels is basically swapping parameters (e.g., step size), which means you update the input parameters with some advanced indexing. The base kernel would remain the same like
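To make the "swap parameters instead of states" idea concrete, here is a minimal sketch in plain JAX. The helper `swap_pair` and the toy arrays are hypothetical, not BlackJAX API. It checks that exchanging the states of chains 1 and 2 yields the same (state, temperature, step size) pairings as leaving the states in place and permuting the per-chain parameter arrays with advanced indexing.

```python
import jax.numpy as jnp


def swap_pair(i, n_chains):
    """Hypothetical helper: index permutation exchanging entries i and i + 1."""
    perm = jnp.arange(n_chains)
    return perm.at[i].set(i + 1).at[i + 1].set(i)


n_chains = 4
states = jnp.array([10.0, 11.0, 12.0, 13.0])  # one scalar state per chain
inverse_temperatures = jnp.array([0.0, 0.25, 0.5, 1.0])
step_sizes = jnp.array([1.0, 0.75, 0.5, 0.25])

perm = swap_pair(1, n_chains)  # exchange chains 1 and 2

# Option 1: move the states, keep the parameters in place.
states_swapped = states[perm]

# Option 2: keep the states in place, move the parameters (advanced indexing).
betas_swapped = inverse_temperatures[perm]
steps_swapped = step_sizes[perm]

# Both options pair each state with the same inverse temperature.
pairs_1 = {(float(s), float(b)) for s, b in zip(states_swapped, inverse_temperatures)}
pairs_2 = {(float(s), float(b)) for s, b in zip(states, betas_swapped)}
print(pairs_1 == pairs_2)
```

A transposition (or any product of disjoint transpositions) is its own inverse, which is why both options agree here. Over many iterations, parameter swapping means the state targeting $\beta = 1$ moves between array slots, so one would need to track the accumulated permutation to know which slot to record.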
Thank you for your kind feedback and all the suggestions!
Thanks, I wasn't aware that a TFP implementation exists! I like it; the major differences seem to be:
I think this is a very good point. It'd be convenient to have a utility function allowing the end user to quickly build a reasonable (even if not optimally sharded) parallel tempering kernel out of an existing one, using …:

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom

import blackjax


def create_tempered_log_p(log_target, log_reference, inverse_temperature):
    """Geometric path: reference at inverse temperature 0, target at 1."""
    def log_p(x):
        return inverse_temperature * log_target(x) + (1.0 - inverse_temperature) * log_reference(x)
    return log_p


def init(init_fn, positions, log_target, log_reference, inverse_temperatures):
    """Initialise one base-kernel state per inverse temperature (vectorised with vmap)."""
    def init_fn_temp(position, inverse_temperature):
        log_p = create_tempered_log_p(log_target, log_reference, inverse_temperature)
        return init_fn(position, log_p)
    return jax.vmap(init_fn_temp)(positions, inverse_temperatures)


def build_kernel(base_kernel_fn, log_target, log_reference, inverse_temperatures, parameters):
    """Wrap a base kernel so that one step updates every tempered chain."""
    def kernel(rng_key, state, inverse_temperature, parameter):
        log_p = create_tempered_log_p(log_target, log_reference, inverse_temperature)
        return base_kernel_fn(rng_key, state, log_p, parameter)

    n_chains = inverse_temperatures.shape[0]

    def step_fn(rng_key, state):
        keys = jrandom.split(rng_key, n_chains)
        # Map over chains: one key, state, inverse temperature and parameter per chain.
        return jax.vmap(kernel)(keys, state, inverse_temperatures, parameters)

    return step_fn


def log_p(x):
    return -(jnp.sum(jnp.square(x)) + jnp.sum(jnp.power(x, 4)))


def log_ref(x):
    return -jnp.sum(jnp.square(x))


n_chains = 10
inverse_temperatures = 0.9 ** jnp.linspace(0, 1, n_chains)
initial_positions = jnp.ones((n_chains, 2))
parameters = jnp.linspace(0.1, 1, n_chains)  # one MALA step size per chain

init_state = init(
    blackjax.mcmc.mala.init,
    initial_positions,
    log_p,
    log_ref,
    inverse_temperatures,
)
kernel = build_kernel(
    blackjax.mcmc.mala.build_kernel(),
    log_p,
    log_ref,
    inverse_temperatures,
    parameters,
)

rng_key = jrandom.PRNGKey(42)
new_state, info = kernel(rng_key, init_state)
```

Please let me know what you think!

Also: how should the parameters be passed optimally to the individual kernels? This solution works only if each kernel has a single parameter. This parameter could be dictionary-valued, though, allowing users to write wrappers around kernel initialisers. For example, one could wrap the HMC kernel builder function so that the step size, mass matrix and number of integration steps are passed as a dictionary, but I'm not sure how convenient this would be for end users.
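On the dictionary-valued parameter question: `jax.vmap` maps over arbitrary pytrees, so a dict whose leaves all carry a leading chain axis can be passed as the single `parameter` argument without extra plumbing. A minimal sketch, assuming a hypothetical random-walk Metropolis kernel `rwm_kernel` (not a BlackJAX function) that reads its settings from such a dict:

```python
import jax
import jax.numpy as jnp


def rwm_kernel(rng_key, position, logdensity_fn, params):
    """Hypothetical random-walk Metropolis step with dictionary-valued parameters.

    `params` holds a scalar "step_size" and a per-dimension "scale" vector.
    """
    key_prop, key_accept = jax.random.split(rng_key)
    noise = jax.random.normal(key_prop, position.shape)
    proposal = position + params["step_size"] * params["scale"] * noise
    log_ratio = logdensity_fn(proposal) - logdensity_fn(position)
    accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
    return jnp.where(accept, proposal, position)


def logdensity(x):
    return -0.5 * jnp.sum(x ** 2)


n_chains, dim = 4, 2
# Every leaf has a leading chain axis, so vmap slices the dict chain by chain.
params = {
    "step_size": jnp.linspace(0.1, 1.0, n_chains),
    "scale": jnp.ones((n_chains, dim)),
}
positions = jnp.zeros((n_chains, dim))
keys = jax.random.split(jax.random.PRNGKey(0), n_chains)

# The log-density is closed over; keys, positions and the params dict are mapped.
step = jax.vmap(lambda k, x, p: rwm_kernel(k, x, logdensity, p))
new_positions = step(keys, positions, params)
print(new_positions.shape)
```

The cost of this design is that users (or the library) must write a small adapter per base kernel that unpacks the dictionary into that kernel's own arguments.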
I see! This could potentially lead to better parallelism, but I think it'd be easier for me to swap the states.

Question: I'm also not sure about the best design choice regarding the composed kernels. Currently each kernel … More generally, …
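Regarding swapping states rather than parameters, here is a minimal sketch assuming states are pytrees whose leaves carry a leading chain axis (as with the vmapped `init_state` in the prototype). `ChainState` and `swap_states` are hypothetical stand-ins. One subtlety: BlackJAX states such as MALA's also cache the log-density and its gradient, and those values were computed at the old temperature, so they need refreshing after a swap.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ChainState(NamedTuple):
    """Hypothetical per-chain state: a position plus a cached log-density."""
    position: jnp.ndarray
    logdensity: jnp.ndarray


def log_target(x):
    return -jnp.sum(x ** 2 + x ** 4)


def log_reference(x):
    return -jnp.sum(x ** 2)


def tempered_logdensity(x, beta):
    return beta * log_target(x) + (1.0 - beta) * log_reference(x)


def swap_states(state, perm, inverse_temperatures):
    # Reorder every leaf of the state pytree along the chain axis...
    swapped = jax.tree_util.tree_map(lambda leaf: leaf[perm], state)
    # ...then recompute the cached log-density, which belonged to the old temperature.
    logdensity = jax.vmap(tempered_logdensity)(swapped.position, inverse_temperatures)
    return swapped._replace(logdensity=logdensity)


betas = jnp.array([0.0, 0.5, 1.0])
positions = jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
state = ChainState(positions, jax.vmap(tempered_logdensity)(positions, betas))

perm = jnp.array([0, 2, 1])  # exchange chains 1 and 2
new_state = swap_states(state, perm, betas)
print(new_state.position)
print(new_state.logdensity)
```

For kernels that cache gradients too, re-running the base kernel's `init` on the swapped positions would be a simple, if slightly wasteful, way to refresh everything.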
I think for simplicity, let's start by building the functionality assuming we use the same base kernel (e.g., HMC) with different parameters (e.g., step_size).
Presentation of the new sampler
Parallel tempering, also known as replica exchange MCMC, maintains $K$ Markov chains at different temperatures, ranging from $\pi_0$ (the reference distribution, for example the prior) to the target distribution $\pi =: \pi_{K-1}$. Apart from local exploration kernels, each targeting one distribution individually, it includes swap kernels, which try to exchange states between chains, hence targeting the distribution
$$(x_0, \dotsc, x_{K-1})\mapsto \prod_{i=0}^{K-1} \pi_i(x_i),$$
defined on the space $\mathcal X^K = \mathcal X\times \cdots \times \mathcal X$. By retaining only the samples from the last coordinate, one can sample from $\pi = \pi_{K-1}$.
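For reference, the swap move between neighbouring chains $i$ and $i+1$ is a Metropolis-Hastings step on this product distribution: it proposes exchanging $x_i$ and $x_{i+1}$ and accepts with probability

$$\alpha_i = \min\left(1,\ \frac{\pi_i(x_{i+1})\,\pi_{i+1}(x_i)}{\pi_i(x_i)\,\pi_{i+1}(x_{i+1})}\right).$$

Each $\pi_i$ appears once in the numerator and once in the denominator, so unnormalised densities suffice. Assuming the geometric path used in the prototype code in this thread, $\log \pi_i = \log\pi_0 + \beta_i\,\ell + \text{const}_i$ with $\ell = \log\pi - \log\pi_0$ and $0 = \beta_0 < \dots < \beta_{K-1} = 1$ (the schedule itself is an assumption here), the ratio simplifies to

$$\log \frac{\pi_i(x_{i+1})\,\pi_{i+1}(x_i)}{\pi_i(x_i)\,\pi_{i+1}(x_{i+1})} = (\beta_{i+1} - \beta_i)\,\bigl(\ell(x_i) - \ell(x_{i+1})\bigr),$$

so a swap only needs $\ell$ evaluated at the current states.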
Similarly to sequential Monte Carlo (SMC) samplers, this strategy can be highly efficient for sampling from multimodal posteriors. A modern variant of parallel tempering, called non-reversible parallel tempering (NRPT), achieves state-of-the-art performance in sampling from complex high-dimensional distributions. NRPT works on both discrete and continuous spaces, and allows one to leverage preliminary runs to tune the tempering schedule.
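The schedule tuning mentioned above can be sketched as follows, assuming the approach of Syed et al.: swap rejection rates $r_i$ from a preliminary run give a piecewise estimate of the cumulative communication barrier, $\Lambda(\beta_i) \approx \sum_{j<i} r_j$, and the new schedule places the inverse temperatures at equally spaced barrier levels. `update_schedule` is hypothetical, and it uses linear interpolation as a simplification; a smoother monotone interpolant may be preferable.

```python
import jax.numpy as jnp


def update_schedule(betas, rejection_rates):
    """Hypothetical NRPT-style schedule update from a preliminary run.

    betas:           (K,) current schedule, increasing from 0 to 1.
    rejection_rates: (K - 1,) estimated swap rejection rate between chains i and i + 1.
    Assumes every rejection rate is positive, so the barrier estimate is increasing.
    """
    # Piecewise-linear estimate of the cumulative communication barrier Lambda(beta).
    barrier = jnp.concatenate([jnp.zeros(1), jnp.cumsum(rejection_rates)])
    # Equally spaced barrier levels between 0 and the total barrier...
    targets = jnp.linspace(0.0, barrier[-1], betas.shape[0])
    # ...mapped back to inverse temperatures by inverting Lambda with linear interpolation.
    return jnp.interp(targets, barrier, betas)


betas = jnp.linspace(0.0, 1.0, 5)
# Hypothetical preliminary-run output: swaps near beta = 1 are rejected far more often.
rejection_rates = jnp.array([0.1, 0.1, 0.1, 0.5])
new_betas = update_schedule(betas, rejection_rates)
print(new_betas)
```

In this toy input the updated schedule packs more chains near $\beta = 1$, where the estimated barrier grows fastest, so that each neighbouring pair carries a similar share of the total barrier.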
Resources
How does it compare to other algorithms in blackjax?
Compared with a single-chain MCMC sampler:
Compared with a tempered SMC sampler:
Where does it fit in blackjax
BlackJAX offers a large collection of MCMC kernels. These can be leveraged to build non-reversible parallel tempering samplers that explore different temperatures simultaneously, which can lead to faster mixing and allows one to sample from multimodal posteriors. NRPT ac
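To illustrate how per-chain kernels and the swap move could fit together, here is a minimal end-to-end sketch in plain JAX. Assumptions: a hypothetical random-walk Metropolis kernel `rwm_step` stands in for a BlackJAX kernel; the tempering path is the geometric one from the prototype in this thread; the schedule is increasing, $\beta_0 = 0 < \dots < \beta_{K-1} = 1$, so the last chain targets $\pi$; and the names `deo_permutation` and `nrpt_sample` are illustrative, not a proposed API. The non-reversible ingredient is the deterministic even/odd (DEO) swap schedule: even iterations propose swaps only for pairs $(0,1), (2,3), \dots$ and odd iterations only for $(1,2), (3,4), \dots$. This sketch swaps states, and no schedule tuning is attempted.

```python
import jax
import jax.numpy as jnp


def log_target(x):
    return -jnp.sum(x ** 2 + x ** 4)


def log_reference(x):
    return -jnp.sum(x ** 2)


def ell_fn(x):
    # log target minus log reference: all the swap move needs on a geometric path
    return log_target(x) - log_reference(x)


def rwm_step(rng_key, x, beta, step_size):
    """Hypothetical local kernel: random-walk Metropolis on the tempered density."""
    def logp(y):
        return beta * log_target(y) + (1.0 - beta) * log_reference(y)
    k1, k2 = jax.random.split(rng_key)
    y = x + step_size * jax.random.normal(k1, x.shape)
    accept = jnp.log(jax.random.uniform(k2)) < logp(y) - logp(x)
    return jnp.where(accept, y, x)


def deo_permutation(rng_key, ell, betas, iteration):
    """Index permutation implementing one deterministic even/odd swap round."""
    n = betas.shape[0]
    # Log acceptance ratio for pair (i, i + 1): (b_{i+1} - b_i) * (l(x_i) - l(x_{i+1})).
    log_ratio = (betas[1:] - betas[:-1]) * (ell[:-1] - ell[1:])
    accept = jnp.log(jax.random.uniform(rng_key, (n - 1,))) < log_ratio
    # Only pairs whose lower index has the parity of the iteration are proposed.
    do_swap = accept & ((jnp.arange(n - 1) % 2) == (iteration % 2))
    up = jnp.concatenate([do_swap, jnp.array([False])])    # chain i takes state i + 1
    down = jnp.concatenate([jnp.array([False]), do_swap])  # chain i + 1 takes state i
    idx = jnp.arange(n)
    return jnp.where(up, idx + 1, jnp.where(down, idx - 1, idx))


def nrpt_sample(rng_key, init_positions, betas, step_sizes, n_iter):
    """Sketch of NRPT: vmapped local moves followed by a DEO swap round each iteration."""
    n_chains = betas.shape[0]

    def one_iteration(positions, inputs):
        key, it = inputs
        key_local, key_swap = jax.random.split(key)
        keys = jax.random.split(key_local, n_chains)
        positions = jax.vmap(rwm_step)(keys, positions, betas, step_sizes)
        ell = jax.vmap(ell_fn)(positions)
        positions = positions[deo_permutation(key_swap, ell, betas, it)]
        # Record only the last chain (beta = 1), which targets pi.
        return positions, positions[-1]

    keys = jax.random.split(rng_key, n_iter)
    _, samples = jax.lax.scan(one_iteration, init_positions, (keys, jnp.arange(n_iter)))
    return samples


n_chains = 8
samples = nrpt_sample(
    jax.random.PRNGKey(0),
    init_positions=jnp.zeros((n_chains, 2)),
    betas=jnp.linspace(0.0, 1.0, n_chains),
    step_sizes=jnp.full((n_chains,), 0.5),
    n_iter=1000,
)
print(samples.shape)
```

Because swapped pairs within one round are disjoint, each round is a single permutation that can be applied to the whole batched state at once, which keeps the swap step cheap relative to the vmapped local moves.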
Are you willing to open a PR?
Yes. I have a prototype implemented in a blog post, which I would be willing to refactor and contribute.
I am however unsure about two design choices:
`jax.vmap`, rather than a `for` loop. I guess this problem is less apparent in tempered SMC samplers, where at temperature