Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix errors in notebook due to version update #630

Merged
merged 3 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions examples/howto/wrapping_jax_function.ipynb

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions examples/howto/wrapping_jax_function.myst.md
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ x_grad_wrt_emission_signal.eval()
We are now ready to make inferences about our HMM model with PyMC. We will define priors for each model parameter and use {class}`~pymc.Potential` to add the joint log-likelihood term to our model.

```{code-cell} ipython3
with pm.Model(rng_seeder=int(rng.integers(2**30))) as model:
with pm.Model() as model:
emission_signal = pm.Normal("emission_signal", 0, 1)
emission_noise = pm.HalfNormal("emission_noise", 1)

Expand Down Expand Up @@ -515,7 +515,7 @@ pm.model_to_graphviz(model)
Before we start sampling, we check the logp of each variable at the model initial point. Bugs tend to manifest themselves in the form of `nan` or `-inf` for the initial probabilities.

```{code-cell} ipython3
initial_point = model.compute_initial_point()
initial_point = model.initial_point()
initial_point
```

Expand Down Expand Up @@ -604,7 +604,7 @@ jax_fn()
We can also compile a JAX function that computes the log probability of each variable in our PyMC model, similar to {meth}`~pymc.Model.point_logps`. We will use the helper method {meth}`~pymc.Model.compile_fn`.

```{code-cell} ipython3
model_logp_jax_fn = model.compile_fn(model.logpt(sum=False), mode="JAX")
model_logp_jax_fn = model.compile_fn(model.logp(sum=False), mode="JAX")
model_logp_jax_fn(initial_point)
```

Expand All @@ -622,7 +622,7 @@ Now that we know our model logp can be entirely compiled to JAX, we can use the

```{code-cell} ipython3
with model:
idata_numpyro = pm.sampling_jax.sample_numpyro_nuts(chains=2, progress_bar=False)
idata_numpyro = pm.sampling_jax.sample_numpyro_nuts(chains=2, progressbar=False)
```

```{code-cell} ipython3
Expand Down
Loading