How to wrap a JAX function for use in PyMC (the automatic way) #755

Open
jdehning opened this issue Dec 18, 2024 · 0 comments
Labels
proposal New notebook proposal still up for discussion


Notebook proposal

Title: How to wrap a JAX function for use in PyMC (the automatic way)

Why should this notebook be added to pymc-examples?

The new wrapper @as_jax_op (still in a draft PR, but already functional) needs examples that showcase what it can do.

I propose two parts. The first is an example of solving an ODE, similar to what I wrote here, but with diffrax as the only external dependency.

The second rewrites the existing notebook on how to wrap a JAX function, using @as_jax_op instead of defining the Ops manually.
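For context, here is a rough sketch of the two JAX-side pieces that, as I understand it, the manual route has to provide and that the wrapper is meant to derive automatically: a compiled forward pass and a vector-Jacobian product for gradients. The function `scaled_norm` and the helper names are hypothetical, and the sketch uses only JAX, so it shows the ingredients rather than the Op definitions themselves.

```python
import jax
import jax.numpy as jnp


def scaled_norm(x, scale):
    """Plain JAX function we would like to call inside a PyMC model."""
    return scale * jnp.sqrt(jnp.sum(x**2))


# Piece 1: a compiled forward evaluation (what a forward Op would call).
forward = jax.jit(scaled_norm)


# Piece 2: a vector-Jacobian product (what a gradient Op would call).
@jax.jit
def vjp(x, scale, output_grad):
    primal_out, pullback = jax.vjp(scaled_norm, x, scale)
    # Match the cotangent dtype to the output before pulling it back.
    return pullback(jnp.asarray(output_grad, dtype=primal_out.dtype))


x = jnp.array([3.0, 4.0])
value = forward(x, 2.0)
grad_x, grad_scale = vjp(x, 2.0, 1.0)
```

In the manual approach each piece is wired into its own Op by hand; the rewritten notebook would show the decorator taking over that wiring.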

Suggested categories:

  • Level: Intermediate
  • Diataxis type: How-to guide

Related notebooks

Relates to https://www.pymc.io/projects/examples/en/latest/howto/wrapping_jax_function.html, but this proposal simplifies building the Op. I would keep the existing notebook, since it explains in more depth what happens behind the scenes.

jdehning added the proposal label on Dec 18, 2024