Releases: blackjax-devs/blackjax
BlackJAX v0.6.0
What's Changed
- Specify custom gradients. by @FedericoV in #205
- Add Elliptical slice sampler by @albcab in #183
New Contributors
- @FedericoV made their first contribution in #205
Full Changelog: 0.5.0...0.6.0
BlackJAX v0.5.0
What's Changed
- Fix tests and pre-commits on local devices by @albcab in #188
- Add the MALA sampling algorithm by @rlouf in #189
- Fix install instructions in README by @FredericWantiez in #196
- Adding progress bars to window adaptation by @zaxtax in #190
- Add the Orbital HMC sampler by @albcab
New Contributors
- @FredericWantiez made their first contribution in #196
BlackJAX 0.4.0
Breaking changes
This release simplifies the high-level API for samplers. For instance, to initialize and use an HMC kernel:
import blackjax
hmc = blackjax.hmc(logprob_fn, step_size, inverse_mass_matrix, num_integration_steps)
state = hmc.init(position)
new_state, info = hmc.step(rng_key, state)
`hmc` is now a namedtuple with an `init` and a `step` function; you only need to pass `logprob_fn` at initialization, unlike in previous versions. The internals were simplified substantially, and the hierarchy is now flatter. For instance, to use the base HMC kernel directly:
import blackjax.mcmc.integrators as integrators
import blackjax.mcmc.hmc as hmc
kernel = hmc.kernel(integrators.mclachlan)
state = hmc.init(position, logprob_fn)
state, info = kernel(rng_key, state, logprob_fn, step_size, inverse_mass_matrix, num_integration_steps)
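In both forms, sampling proceeds by calling the step function repeatedly, threading the state (and a fresh random key) through each call. Here is a minimal sketch of that loop using a hypothetical stand-in kernel (plain Python, no BlackJAX), assuming only the `(new_state, info) = step(rng_key, state)` calling convention shown above; `dummy_step` and `run_chain` are illustrative names, not library functions:

```python
import random

# Hypothetical stand-in for an MCMC step that mimics the calling
# convention step(rng_key, state) -> (new_state, info).
def dummy_step(rng_key, state):
    rng = random.Random(rng_key)  # derive randomness from the key
    new_state = state + rng.gauss(0.0, 1.0)
    info = {"proposal": new_state}
    return new_state, info

def run_chain(seed, initial_state, num_steps):
    state = initial_state
    samples = []
    for i in range(num_steps):
        # A fresh "key" per step, analogous to splitting a JAX PRNG key.
        state, info = dummy_step(seed + i, state)
        samples.append(state)
    return samples

samples = run_chain(seed=0, initial_state=0.0, num_steps=100)
```

With real BlackJAX kernels the same loop is typically expressed with `jax.lax.scan` over split PRNG keys, but the state-threading pattern is identical.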
The API of the base kernels has also been changed to be more flexible.
Performance improvements
Thanks to the work of @zaxtax, @junpenglao, and @rlouf, the performance of the NUTS sampler (especially the warmup) has been greatly improved and is now at least on par with NumPyro.
What's Changed
No new algorithms in this release, but important work was done on the API, the internals, and the examples.
- Fix SMC notebook as per issue #148 by @AdrienCorenflos in #149
- Add a logistic regression example by @gerdm in #103
- Update to fold in all trajectory building into a single while_loop by @junpenglao in #164
- Update use_with_numpyro.ipynb by @junpenglao in #165
- Updated trajectory.py by @Gautam-Hegde in #167
- Simplify the user API & create a more general kernel by @rlouf in #159
- Make the PyMC example run on version 4 by @rlouf in #178
- Moved dual_averaging.py to /adaptation and renamed it to optimizers.py by @Bnux256 in #179
- Fix example notebook to use pymc 4 by @zaxtax in #180
New Contributors
- @gerdm made their first contribution in #103
- @Gautam-Hegde made their first contribution in #167
- @Bnux256 made their first contribution in #179
- @zaxtax made their first contribution in #180
Full Changelog: 0.3.0...0.4.0
BlackJAX 0.3.0
What's Changed
Breaking changes
To build an HMC or NUTS kernel in 0.2.1 and earlier versions, one needed to provide a `potential_fn` function:
kernel = nuts.kernel(potential_fn, step_size, inverse_mass_matrix)
Instead, we now ask users to provide the more commonly used log-probability function:
kernel = nuts.kernel(logprob_fn, step_size, inverse_mass_matrix)
where `logprob_fn = lambda x: -potential_fn(*x)`
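For a concrete, hypothetical illustration of the sign relationship between the two conventions (using a scalar argument rather than the unpacked form above), with a standard-Gaussian potential:

```python
# Hypothetical potential: negative unnormalized log-density of a
# 1D standard Gaussian.
def potential_fn(x):
    return 0.5 * x**2

# The log-probability convention now expected by the kernels is
# simply the negated potential.
logprob_fn = lambda x: -potential_fn(x)

print(logprob_fn(2.0))  # -2.0
```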
New features
- Tempered Sequential Monte Carlo (@AdrienCorenflos #40 )
- Rosenbluth Metropolis Hastings algorithm (@AdrienCorenflos #74 )
- Effective Sample Size, RHat (@junpenglao #66 )
- Higher-order integrators for HMC (@rlouf #59 )
BlackJAX 0.2.1
What's Changed
`momentum` and `position` were passed to the kinetic energy function in the wrong order, leading to biased sampling, as noticed in #46. We corrected this behavior and added a new test.
BlackJAX 0.2
What's Changed
- The Stan warmup adaptation scheme, including dual averaging, covariance estimation with Welford's algorithm, and the adaptation schedule (@rlouf)
- Recursive implementation of NUTS (@junpenglao)
- Many BUG fixes on NUTS (@junpenglao)
BlackJAX 0.1
New features
- `hmc` kernel
- `nuts` kernel
- Notebook with examples of how to sample one or multiple chains with HMC and NUTS