How to mark parameters as trainable or not? #866
Comments
There is a risk to the suggested approach that should at least be highlighted in the docs: the parameters may still be punished by regularization.

    import equinox as eqx
    import jax
    import jax.numpy as jnp
    from jaxtyping import Array
    from optax import adamw

    class Model(eqx.Module):
        buffer: Array
        param: Array

        def __call__(self, x):
            return self.param * x + jax.lax.stop_gradient(self.buffer)

    @eqx.filter_value_and_grad
    def loss(model, x):
        return model(x)

    model = Model(jnp.ones(()), jnp.ones(()))
    loss, grad = loss(model, 2)
    optimizer = adamw(1e-1)  # Optimizer with regularization
    opt_state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))
    updates, opt_state = optimizer.update(grad, opt_state, eqx.filter(model, eqx.is_array))
    model = eqx.apply_updates(model, updates)
    assert model.buffer == jnp.ones(())  # Fails!

Unless I am missing a downside, the approach I think should be recommended is to use a wrapper class (NonTrainable in the snippet below) and treat it as a leaf when partitioning:

    params, static = eqx.partition(
        model,
        eqx.is_inexact_array,
        is_leaf=lambda leaf: isinstance(leaf, NonTrainable),
    )
Ah! That really isn't very good, you're right. Hmm, I'm trying to figure out if there's a way to handle this ergonomically. The best I can come up with is to wrap the Optax calls (like we already do for [...]
FWIW I've landed on the optax wrapper approach. I have a trainable/non_trainable mask that I create early on and partition that way. I don't even bother with stop_grad most of the time and pray that XLA does the DCE for me (it seems to). For things that are really constants (e.g. rotary embeddings) I just materialize those in the kernel with [...]
Ah, nice! Okay, I think I'm convinced. I'd be happy to take a PR implementing this, then.
Just for posterity, @danielward27 has written a small library that fixes this (along with enabling other parameterizations) in https://github.com/danielward27/paramax
Oh, this is excellent. Small and does exactly the right thing. @danielward27 would you be interested in having this be advertised in the various Equinox-ecosystem READMEs, e.g. https://github.com/patrick-kidger/equinox/?tab=readme-ov-file#see-also-other-libraries-in-the-jax-ecosystem ?
Thanks! That would be great to be added to the list! I can do a pull request to add it if you would like - whatever is easiest for you.
Awesome :) Yup send a PR! The ecosystem lists appear in [...]
Greetings!
I have custom layers in Equinox that look approximately like this.
I now want to exclude ProductLayer.edges from the parameters of a model, since they cannot be adjusted by gradient descent.
Furthermore, SumLayer.log_weights.indices cannot be adjusted either, and neither can ContinuousLayerWithFiniteSupport.interval. How can I best filter these out for the eqx.partition method?