Convert a BrainPy model to process batched input by jax.vmap
#608
Thanks for this great question. The object-oriented style in BrainPy does not support a general mapping transformation with `jax.vmap`.
The key of BrainPy's … In your case, you want to batch the model's states together with the inputs while sharing its weights, for example:

```python
import jax
import jax.numpy as jnp
import brainpy.math as bm
from functools import wraps


def vmap_grad_fun(f, *inputs):
    # Step 1: find all variables used by `f`
    # --------------------------------------
    # evaluate shapes only, without spending any actual FLOPs;
    # trace with one (unbatched) example, since `f` is written for a single input
    single_inputs = jax.tree_util.tree_map(lambda x: x[0], inputs)
    all_vars, _ = bm.eval_shape(f, *single_inputs)
    # separate the variables into two groups: weights (bm.TrainVar) and states
    weights, states = all_vars.separate_by_instance(bm.TrainVar)

    # Step 2: turn the object into a function that is compatible with jax.vmap
    # -------------------------------------------------------------------------
    @wraps(f)
    def new_fun(ws, sts, inputs):
        # A. write the shared weights and this batch element's states into the model
        for key in ws: weights[key].value = ws[key]
        for key in sts: states[key].value = sts[key]
        # B. run the function
        outputs = f(*inputs)
        # C. return the outputs of this batch element
        return outputs

    # keep the original (unbatched) values so they can be restored afterwards
    ori_weights, ori_states = weights.dict_data(), states.dict_data()
    # replicate the states along a new leading batch axis
    batch_size = inputs[0].shape[0]
    batched_states = jax.tree_util.tree_map(
        lambda x: jnp.repeat(jnp.expand_dims(x, 0), batch_size, axis=0), ori_states)
    # share the weights (in_axes=None); map over the states and inputs (in_axes=0)
    batched_outs = jax.vmap(new_fun, in_axes=(None, 0, 0), out_axes=0)(
        ori_weights, batched_states, inputs)
    del batched_states
    # restore the original weights and states
    for key in ori_weights: weights[key].value = ori_weights[key]
    for key in ori_states: states[key].value = ori_states[key]

    # Step 3: return the batched outputs
    return batched_outs
```

I hope this example can help you achieve the desired transformation.
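The same idea can be tried without BrainPy. Below is a minimal sketch in plain JAX, assuming a hypothetical `TinyNeuron` class (weight `w`, state `v`) and a hypothetical helper `vmap_stateful`; neither is part of BrainPy. It mirrors the steps of the answer: read the original values, replicate only the state along a batch axis, vmap a pure wrapper with `in_axes=(None, 0, 0)`, and restore the originals afterwards.

```python
import jax
import jax.numpy as jnp


class TinyNeuron:
    """Hypothetical stateful model written for a single (unbatched) input.

    `w` plays the role of a trainable weight (shared across the batch) and
    `v` the role of a state variable (one copy per batch element).
    """

    def __init__(self, n: int):
        self.w = 0.5 * jnp.ones((n,))  # weight, shape (n,)
        self.v = jnp.zeros((n,))       # state, shape (n,)

    def update(self, x):
        # Mutates the state on the object, object-oriented style.
        self.v = 0.9 * self.v + self.w * x
        return self.v.sum()


def vmap_stateful(model, method, *inputs):
    """Hypothetical helper: batch a stateful method with jax.vmap."""
    ori_w, ori_v = model.w, model.v
    batch_size = inputs[0].shape[0]
    # Replicate the state along a new leading batch axis; the weight is not copied.
    batched_v = jnp.broadcast_to(ori_v, (batch_size,) + ori_v.shape)

    def pure_fun(w, v, xs):
        # Write the per-example values onto the object, then run the method.
        model.w, model.v = w, v
        return method(*xs)

    try:
        return jax.vmap(pure_fun, in_axes=(None, 0, 0))(ori_w, batched_v, inputs)
    finally:
        # Restore the original (untraced) values so no tracer stays on the object.
        model.w, model.v = ori_w, ori_v


neuron = TinyNeuron(3)
xs = jnp.arange(4.0)[:, None] * jnp.ones((4, 3))  # batch of 4 inputs, each of shape (3,)
out = vmap_stateful(neuron, neuron.update, xs)     # one scalar output per batch element
```

Putting the restore step in `finally` means the object never keeps a traced value, even if the wrapped method raises during tracing.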
For a model written to process single input data, is it possible to convert the model to process batched input data simply by using `jax.vmap`? Or do we have to re-write the model to process batched data? The code section looks like this:
It currently raises the following error:
I found a previous issue (#206) mentioning this. Is it still not possible to use `jax.vmap` with brainpy models?
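For comparison: when a model is already written as a pure function that takes its parameters, its state, and a single input explicitly, `jax.vmap` batches it with no rewrite. A minimal sketch, assuming a hypothetical single-example `step` function:

```python
import jax
import jax.numpy as jnp


def step(w, v, x):
    """Hypothetical single-example model step: returns (new_state, output)."""
    new_v = 0.9 * v + w * x
    return new_v, new_v.sum()


w = 0.5 * jnp.ones(3)                             # weight, shared by all examples
v0 = jnp.zeros((4, 3))                            # one state row per example
xs = jnp.arange(4.0)[:, None] * jnp.ones((4, 3))  # batch of 4 inputs

# in_axes=(None, 0, 0): broadcast `w`, map over the leading axis of `v0` and `xs`
new_v, out = jax.vmap(step, in_axes=(None, 0, 0))(w, v0, xs)

# the batched result should equal a Python loop over single examples
loop_out = jnp.stack([step(w, v0[i], xs[i])[1] for i in range(4)])
```

Because the state is an explicit argument and return value, nothing has to be written onto an object or restored afterwards; that bookkeeping is only needed in the object-oriented case.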