Hey @JINKEHE
```python
import jax

state = jax.tree.map(lambda x: x, state)  # clone: rebuilds the pytree containers; the immutable leaf arrays are shared
```
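Here is a minimal sketch of why this works. It uses a plain nested dict as a stand-in for an `nnx.State`, which is an assumption; `jax.tree.map` treats any registered pytree the same way. The identity map returns new containers but reuses the leaf arrays. That is safe because JAX arrays are immutable, so no weights are copied.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for an nnx.State: a nested dict of arrays (any pytree works).
state = {"linear": {"kernel": jnp.ones((2, 3)), "bias": jnp.zeros((3,))}}

# Identity map: rebuilds every container, reuses every leaf array.
cloned = jax.tree.map(lambda x: x, state)

# The containers are new objects, so structural edits to the clone stay local.
assert cloned is not state
assert cloned["linear"] is not state["linear"]

# The leaves are the very same arrays, so no weight buffers were copied.
assert cloned["linear"]["kernel"] is state["linear"]["kernel"]

# "Updating" the clone rebinds a leaf; the original is untouched because
# JAX arrays cannot be mutated in place.
cloned["linear"]["kernel"] = cloned["linear"]["kernel"] + 1.0
print(float(state["linear"]["kernel"][0, 0]), float(cloned["linear"]["kernel"][0, 0]))
```

Because the leaves are shared, this kind of clone only pays for rebuilding the containers, not for duplicating the weights.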
Hello,
I have two questions about how to clone things in `flax.nnx`.

1. How do I clone a state? For a model, we can do `cloned_model = nnx.clone(model)`. But is there anything like `graphdef, state = nnx.split(model)` followed by `new_state = nnx.clone(state)`?
2. I also have a question about how `clone` is implemented in `flax.nnx`. In `graph.py`, we have the implementation of `clone`, but why does it amount to a deepcopy? Where does the copy happen: in `split` or in `merge`? Does it mean that every time we use the nnx functional API, we perform a deepcopy of the neural network weights? That could be a performance concern.

Thanks in advance for your help!