Disclaimer: I have not used Oryx yet. Also, this is not an issue but rather a question/discussion.
Suppose I want to define a recurrent network whose initial hidden state is not a parameter, i.e. it should be visible to jax.jit but not to jax.grad. How can this be done?
E.g.
# syntax might be slightly wrong, think of it as pseudo-code
def network_def(x):
    s = state.variable(..., name="hidden-state")
    p = state.variable(..., name="parameters")
    s, y = f(s, p, x)
    state.assign(s, name="hidden-state")
    return y

network = state.init(network_def)(x)

@jax.jit   # <- this should "see" hidden-state
@jax.grad  # <- this should not "see" hidden-state
def loss_fn(network, x, y):
    ...
Is there an elegant way of doing that?
Thank you!
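For comparison, here is a minimal sketch of the same separation in plain JAX, not Oryx's `state` API: the hidden state is an ordinary argument that `jax.jit` traces, while `jax.grad(..., argnums=0)` differentiates only with respect to the parameters. The toy `cell`, the parameter names `W`/`U`/`V`, and the shapes are hypothetical. The recurrence is unrolled with `jax.lax.scan`, and `has_aux=True` returns the final hidden state so it can be carried into the next call.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy RNN cell: `params` holds the weights, `h` is the hidden state.
def cell(params, h, x_t):
    h = jnp.tanh(params["W"] @ h + params["U"] @ x_t)
    return h, params["V"] @ h  # (new carry, per-step output)

def loss_fn(params, h0, xs, ys):
    # Unroll the recurrence over the time axis of `xs` with lax.scan.
    def step(h, x_t):
        return cell(params, h, x_t)
    h_final, preds = jax.lax.scan(step, h0, xs)
    # Return the loss plus the final hidden state as auxiliary output.
    return jnp.mean((preds - ys) ** 2), h_final

# argnums=0: gradients only w.r.t. `params`; `h0` is traced by jit but not differentiated.
grad_fn = jax.jit(jax.grad(loss_fn, argnums=0, has_aux=True))

# Usage with random toy data (hypothetical sizes).
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
hidden, inp, out, T = 4, 3, 2, 5
params = {
    "W": 0.1 * jax.random.normal(k1, (hidden, hidden)),
    "U": 0.1 * jax.random.normal(k2, (hidden, inp)),
    "V": 0.1 * jax.random.normal(k3, (out, hidden)),
}
h0 = jnp.zeros(hidden)
xs = jax.random.normal(k4, (T, inp))
ys = jnp.zeros((T, out))

grads, h_final = grad_fn(params, h0, xs, ys)
# `grads` mirrors the structure of `params`; `h_final` can be fed back in as the next `h0`.
```

Stacking `@jax.jit` on top of `@jax.grad`, as in the pseudocode, is equivalent to `jax.jit(jax.grad(...))` here. The open question is whether Oryx's named `state.variable` entries can be split this way without threading the state by hand.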
Also, are all JAX transformations supported? The README mentions jit, grad, and vmap. What about pmap, scan (and all the others)?