Skip to content

Commit

Permalink
Fix the benchmark (#380)
Browse files Browse the repository at this point in the history
  • Loading branch information
junpenglao committed Mar 12, 2024
1 parent df6c5e7 commit 9c3af23
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/test_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def regression_logprob(scale, coefs, preds, x):


def inference_loop(kernel, num_samples, rng_key, initial_state):
@jax.jit
def one_step(state, rng_key):
state, _ = kernel(rng_key, state)
return state, state
Expand All @@ -49,11 +50,10 @@ def run_regression(algorithm, **parameters):
warmup = blackjax.window_adaptation(
algorithm,
logposterior_fn,
1000,
False,
is_mass_matrix_diagonal=False,
**parameters,
)
state, kernel, _ = warmup.run(warmup_key, {"scale": 1.0, "coefs": 2.0})
state, kernel, _ = warmup.run(warmup_key, {"scale": 1.0, "coefs": 2.0}, 1000)

states = inference_loop(kernel, 10_000, inference_key, state)

Expand Down

0 comments on commit 9c3af23

Please sign in to comment.