
Specify Module's PyTree-Representation for jit/grad separately, i.e. how to freeze state.variables #22

Open
simon-bachhuber opened this issue Oct 15, 2022 · 0 comments


@simon-bachhuber

Disclaimer: I have not used Oryx yet. Also, this is not really an issue but rather a question for discussion.

Suppose I want to define a recurrent network whose initial hidden state is not a parameter, i.e. it should be exposed to jax.jit but not to jax.grad. How can this be done?

E.g.

# syntax might be slightly wrong, think of it as pseudo-code
def network_def(x):
  s = state.variable(..., name="hidden-state")
  p = state.variable(..., name="parameters")
  s, y = f(s, p, x)
  state.assign(s, name="hidden-state")
  return y 

network = state.init(network_def)(x)

@jax.jit # <- this should "see" hidden-state
@jax.grad # <- this should not "see" hidden-state
def loss_fn(network, x, y):
  ...
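For comparison, here is one common workaround in plain JAX. It does not use Oryx, and whether Oryx has a built-in equivalent is exactly what this issue asks. The idea is to keep the trainable parameters and the hidden state in separate pytrees and pass both to the jitted function. Gradients are then taken only with respect to the parameters through the `argnums` argument of `jax.grad`. The names `step`, `params` and `hidden` and the tanh cell are hypothetical and chosen only for illustration. Wrapping the hidden state in `jax.lax.stop_gradient` inside the loss would be another option.

```python
import jax
import jax.numpy as jnp

# Hypothetical one-step recurrent cell: new_hidden = tanh(W x + U h).
def step(params, hidden, x):
    new_hidden = jnp.tanh(params["W"] @ x + params["U"] @ hidden)
    return new_hidden, new_hidden.sum()

def loss_fn(params, hidden, x, y):
    _, pred = step(params, hidden, x)
    return (pred - y) ** 2

# jit traces both params and hidden; grad with argnums=0 differentiates
# only with respect to params, so hidden is treated as a constant input.
grad_fn = jax.jit(jax.grad(loss_fn, argnums=0))

params = {"W": jnp.ones((3, 2)), "U": jnp.eye(3)}
hidden = jnp.zeros(3)
x = jnp.array([0.5, -0.5])
y = 1.0
grads = grad_fn(params, hidden, x, y)
# grads has the same dict structure as params; hidden gets no gradient.
```

The trade-off is that the split between "seen by jit" and "seen by grad" lives in the function signature rather than in the module itself.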

Is there an elegant way of doing that?
Thank you!

Also, are all JAX transformations supported? The README mentions jit, grad and vmap. What about pmap, scan (and all the others)?
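For reference, plain JAX lets jit, grad, vmap and lax.scan compose freely. This sketch says nothing about which transformations Oryx modules support, which remains the open question here. It threads a hidden state through time with `jax.lax.scan` and differentiates only the parameters. It then vectorizes over a batch of sequences with `jax.vmap` and compiles the result with `jax.jit`. The cell, the names and the shapes are hypothetical assumptions.

```python
import jax
import jax.numpy as jnp

def cell(params, hidden, x):
    # Hypothetical recurrent cell; returns (new carry, per-step output).
    new_hidden = jnp.tanh(params["W"] @ x + params["U"] @ hidden)
    return new_hidden, new_hidden.sum()

def sequence_loss(params, hidden0, xs, ys):
    # lax.scan threads the hidden state along the time axis of xs.
    def body(h, x):
        return cell(params, h, x)
    final_hidden, preds = jax.lax.scan(body, hidden0, xs)
    return jnp.mean((preds - ys) ** 2), final_hidden

# has_aux=True also returns the final hidden state; argnums=0 keeps
# hidden0 out of differentiation.
grad_fn = jax.grad(sequence_loss, argnums=0, has_aux=True)

# vmap over a batch of sequences (params shared via in_axes=None), then jit.
batched = jax.jit(jax.vmap(grad_fn, in_axes=(None, 0, 0, 0)))

params = {"W": jnp.ones((3, 2)), "U": jnp.eye(3)}
hidden0 = jnp.zeros((4, 3))  # batch of 4 initial hidden states
xs = jnp.ones((4, 5, 2))     # 4 sequences, 5 time steps, 2 features
ys = jnp.zeros((4, 5))
grads, final_hidden = batched(params, hidden0, xs, ys)
```

Because params is broadcast with `in_axes=None` but outputs are stacked, the gradients come back per sequence. Their leading batch axis would need to be averaged before an update.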
