How to Convert Keras Model's Inference Method (with JAX Backend) to Flax Training State for Using Flax to Predict #20255

Closed
KaiyueDuan opened this issue Sep 13, 2024 · 5 comments

@KaiyueDuan

KaiyueDuan commented Sep 13, 2024

I am using TensorFlow as the backend to train a DNN model. After training, I successfully converted the inference method of the Keras model to a JAX function and created a Flax training_state to perform inference using Flax. This workflow is working well. Here is my notebook.

However, when I switch to JAX as the backend, I am unsure how to convert the inference method of the Keras model into a JAX function. I am also unclear about the steps needed to create a Flax training_state afterwards.

Could anyone provide guidance on how to achieve this? Any help would be greatly appreciated!

@KaiyueDuan KaiyueDuan changed the title How to Convert Keras Model with JAX B、ackend Inference Method to JAX Function with JAX Backend and Use it with Flax Training State? How to Convert Keras Model‘s Inference Method (with JAX Backend) to Flax Training State for Using Flax to Predict Sep 13, 2024
@fchollet
Member

You can just use model.stateless_call(trainable_variables, non_trainable_variables, *args) (args are the model/layer's call arguments). It returns a tuple (outputs, updated non_trainable_variables), and with the JAX backend it is a pure JAX function.

@fchollet
Member

Check out this example. https://github.com/keras-team/keras/blob/master/examples/demo_custom_jax_workflow.py

@KaiyueDuan
Author

KaiyueDuan commented Sep 17, 2024

Thank you, François! Your suggestion worked perfectly for my program. I successfully computed the Jacobian matrix using model.stateless_call(). Here is the code I used:

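import jax

# trainable_variables / non_trainable_variables: lists of the model's current weight values
def func_to_diff(x):
    x = x[None, :]  # add a batch dimension of size 1
    # stateless_call returns (outputs, non_trainable_variables); keep the outputs
    return model.stateless_call(trainable_variables, non_trainable_variables, x)[0]

def jac_fwd_lambda(single_input):
    return jax.jacfwd(func_to_diff)(single_input)

# One Jacobian per sample, each of shape (1, output_dim, input_dim) because of the batch dimension
jax.vmap(jac_fwd_lambda, in_axes=0)(input_data)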

@KaiyueDuan
Author

Hi,
I encountered a new issue. I noticed that in the sample code https://github.com/keras-team/keras/blob/master/examples/demo_custom_jax_workflow.py, the model is a global variable.

However, in my code, I need to pass my model as a function parameter; here is my code:

from functools import partial
key = jax.random.PRNGKey(42)
tf2jax.update_config('strict_shape_check', False)

@jax.jit
def _jax_predict_tf_single(model_state, single_point):
    return model_state.apply_fn( model_state.params, single_point)[0]


@partial(jax.jit, static_argnums=0)
def predict_jax_single(my_model,single_point):
    return  my_model.stateless_call(my_model.trainable_variables,my_model.non_trainable_variables,single_point[None, :])[0].squeeze(axis=0)


def f_jacfwd(predict_single,my_model,input_data):
  # mode 0: flax; mode 1: tf; mode 2: jax; otherwise: error

    def jac_fwd_lambda(single_input):
        return jax.jacfwd(predict_single)(my_model,single_input)

    return jax.vmap(predict_single, in_axes=(None,0))(my_model,input_data), jax.vmap(jac_fwd_lambda, in_axes=(0))(input_data)

input_data = np.ones((5,3))
ret = f_jacfwd(predict_jax_single, model, input_data)

The code throws the following error:

 ---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-6-0d48a9ca553d> in <cell line: 24>()
     22 
     23 input_data = np.ones((5,3))
---> 24 ret = f_jacfwd(predict_jax_single, model, input_data)

<ipython-input-6-0d48a9ca553d> in f_jacfwd(predict_single, my_model, input_data)
     19         return jax.jacfwd(predict_single)(my_model,single_input)
     20 
---> 21     return jax.vmap(predict_single, in_axes=(None,0))(my_model,input_data), jax.vmap(jac_fwd_lambda, in_axes=(0))(input_data)
     22 
     23 input_data = np.ones((5,3))

    [... skipping hidden 3 frame]

<ipython-input-6-0d48a9ca553d> in jac_fwd_lambda(single_input)
     17 
     18     def jac_fwd_lambda(single_input):
---> 19         return jax.jacfwd(predict_single)(my_model,single_input)
     20 
     21     return jax.vmap(predict_single, in_axes=(None,0))(my_model,input_data), jax.vmap(jac_fwd_lambda, in_axes=(0))(input_data)

    [... skipping hidden 4 frame]

/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py in check_arg(arg)
    279 def check_arg(arg: Any):
    280   if not (isinstance(arg, core.Tracer) or core.valid_jaxtype(arg)):
--> 281     raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid "
    282                     "JAX type.")
    283 

TypeError: Argument '<Functional name=functional, built=True>' of type <class 'keras.src.models.functional.Functional'> is not a valid JAX type.

I also found that the type of model.stateless_call is

<class 'method'>
{'__wrapped__': <function Layer.stateless_call at 0x795fc6fdfa30>}

which does not seem to be a JAX function.

So I am wondering how to pass the model as a parameter. Here is my notebook https://colab.research.google.com/drive/1nV8oIn4TzgmtcAk1xFaaFg4EnxLN9n4c?usp=sharing

Many thanks!

@KaiyueDuan KaiyueDuan reopened this Sep 20, 2024
@KaiyueDuan
Author

Update: I solved this issue by defining an inner function, model_call, which closes over the local variable my_model instead of receiving it as an argument:

@partial(jax.jit, static_argnums=(0,1))
def f_jacfwd(predict_single,my_model,input_data):
    def jac_fwd_lambda(single_input):
        if "jax_single" in predict_single.__name__:
            # Close over my_model so that jax.jacfwd only receives the input array
            def model_call(input_val):
                result = my_model.stateless_call(my_model.trainable_variables, my_model.non_trainable_variables, input_val[None, :])[0]
                return result.squeeze(axis=0)
            return jax.jacfwd(model_call)(single_input)
        # argnums=1: differentiate with respect to single_input, not the model/state in position 0
        return jax.jacfwd(predict_single, argnums=1)(my_model,single_input)

    return jax.vmap(predict_single, in_axes=(None,0))(my_model,input_data), jax.vmap(jac_fwd_lambda, in_axes=0)(input_data)
