How to Convert Keras Model's Inference Method (with JAX Backend) to Flax Training State for Using Flax to Predict #20255
You can just use stateless_call. Check out this example: https://github.com/keras-team/keras/blob/master/examples/demo_custom_jax_workflow.py
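For context, the pattern in that example boils down to calling model.stateless_call on the current variable values. Here is a minimal sketch, not the linked file verbatim; the toy model and the name forward are placeholders:

    import jax
    import keras  # assumes the JAX backend is active, e.g. KERAS_BACKEND=jax

    model = keras.Sequential([keras.layers.Dense(4), keras.layers.Dense(1)])
    model.build((None, 8))

    # Snapshot the variable values as plain arrays.
    trainable_variables = [v.value for v in model.trainable_variables]
    non_trainable_variables = [v.value for v in model.non_trainable_variables]

    @jax.jit
    def forward(trainable_variables, non_trainable_variables, x):
        # stateless_call returns (outputs, updated_non_trainable_variables).
        outputs, _ = model.stateless_call(
            trainable_variables, non_trainable_variables, x
        )
        return outputs

forward then behaves like a pure function of arrays, which is what the rest of the JAX tooling expects.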
Thank you, François! Your suggestion worked perfectly for my program. I successfully computed the Jacobian matrix using:

    def func_to_diff(x):
        x = x[None, :]
        return model.stateless_call(trainable_variables, non_trainable_variables, x)[0]

    def jac_fwd_lambda(single_input):
        return jax.jacfwd(func_to_diff)(single_input)

    jax.vmap(jac_fwd_lambda, in_axes=0)(input_data)
Hi, however, in my code I need to pass my model as a function parameter. The code throws an error, and I found that the type of the passed model does not seem to be a JAX function. So I am wondering how to pass the model as a parameter. Here is my notebook with the full code: https://colab.research.google.com/drive/1nV8oIn4TzgmtcAk1xFaaFg4EnxLN9n4c?usp=sharing
Many thanks!
Update: I solved this issue by defining an inner function that closes over the model, as in the sketch below.
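For anyone who runs into the same thing, the inner-function workaround can look roughly like this. It is only a sketch: the wrapper name batched_jacobian and its argument list are made up for illustration, while the body mirrors the earlier comment:

    import jax

    def batched_jacobian(model, trainable_variables, non_trainable_variables, input_data):
        # The model is only referenced from Python; the JAX transformations
        # see inner functions whose arguments are plain arrays.
        def func_to_diff(x):
            x = x[None, :]
            return model.stateless_call(
                trainable_variables, non_trainable_variables, x
            )[0]

        def jac_fwd_single(single_input):
            return jax.jacfwd(func_to_diff)(single_input)

        return jax.vmap(jac_fwd_single)(input_data)

Because jacfwd and vmap only ever see func_to_diff and jac_fwd_single, the Keras model object itself is never traced as an argument.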
I am using TensorFlow as the backend to train a DNN model. After training, I successfully converted the inference method of the Keras model to a JAX function and created a Flax training_state to perform inference using Flax. This workflow is working well. Here is my notebook.
However, when I switch to using JAX as the backend, I am unsure how to convert the inference method of the Keras model into a JAX function, and I am also unclear about the steps needed to create a Flax training_state afterwards.
Could anyone provide guidance on how to achieve this? Any help would be greatly appreciated!
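In case it helps anyone else: with the JAX backend there is nothing to convert, because model.stateless_call already acts as a pure function of the variable values, so one option is to wrap it as the apply_fn of a Flax TrainState. A rough sketch, assuming flax and optax are installed, model is the trained Keras model, and input_data is a batch of inputs; the params packing below is my own convention, not a Keras or Flax API:

    import jax
    import optax
    from flax.training import train_state

    # Current Keras variable values as plain arrays (JAX backend).
    params = {
        "trainable": [v.value for v in model.trainable_variables],
        "non_trainable": [v.value for v in model.non_trainable_variables],
    }

    def apply_fn(params, x):
        outputs, _ = model.stateless_call(
            params["trainable"], params["non_trainable"], x
        )
        return outputs

    state = train_state.TrainState.create(
        apply_fn=apply_fn, params=params, tx=optax.adam(1e-3)
    )

    # Inference through the Flax state.
    preds = jax.jit(state.apply_fn)(state.params, input_data)

For actual training you would probably keep only the trainable values inside params, so the optimizer does not update the non-trainable ones, and pass the non-trainable values in separately.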