
How to mark parameters as trainable or not? #866

Open
tomsch420 opened this issue Sep 28, 2024 · 9 comments
Labels
question User queries

Comments

@tomsch420

Greetings!

I have custom layers in Equinox that look approximately like this:

class ProductLayer(InnerLayer):

    child_layers: List[Union[SumLayer, InputLayer]]
    edges: BCOO

class SumLayer(InnerLayer):

    log_weights: List[BCOO]
    child_layers: Union[List[ProductLayer], List[InputLayer]]

class ContinuousLayerWithFiniteSupport(ContinuousLayer, ABC):
    interval: jax.Array

I now want to exclude ProductLayer.edges from the parameters of a model, since they cannot be adjusted by gradient descent. Furthermore, SumLayer.log_weights.indices cannot be adjusted either, and neither can ContinuousLayerWithFiniteSupport.interval. How can I best filter these out for the eqx.partition method?
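
One common pattern for this is to build a boolean filter_spec with the same structure as the model, mark everything as trainable, and then flip the frozen leaves to False before calling eqx.partition. Below is a minimal sketch with a toy module rather than the classes above; for the real model the where-function would select the edges, log_weights indices and interval leaves instead.

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


class ToyLayer(eqx.Module):
    weight: jax.Array    # trainable
    interval: jax.Array  # not trainable, analogous to `interval` above

    def __call__(self, x):
        return self.weight * jnp.clip(x, self.interval[0], self.interval[1])


model = ToyLayer(jnp.ones(3), jnp.array([-1.0, 1.0]))

# Mark every leaf as trainable...
filter_spec = jtu.tree_map(lambda _: True, model)
# ...then switch off the leaves that gradient descent must not touch.
filter_spec = eqx.tree_at(lambda m: m.interval, filter_spec, replace=False)

params, static = eqx.partition(model, filter_spec)
# Inside the loss function: model = eqx.combine(params, static)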

@patrick-kidger
Owner

@patrick-kidger patrick-kidger added the question User queries label Sep 29, 2024
@danielward27
Contributor

There is a risk to the suggested approach that should at least be highlighted in the docs: the excluded parameters may still be penalized by regularization (e.g. the weight decay in adamw below).

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array
from optax import adamw


class Model(eqx.Module):
    buffer: Array
    param: Array

    def __call__(self, x):
        return self.param * x + jax.lax.stop_gradient(self.buffer)

@eqx.filter_value_and_grad
def loss(model, x):
    return model(x)

model = Model(jnp.ones(()), jnp.ones(()))
loss_value, grad = loss(model, 2)
optimizer = adamw(1e-1)  # optimizer with weight-decay regularization
opt_state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))
updates, opt_state = optimizer.update(grad, opt_state, eqx.filter(model, eqx.is_inexact_array))
model = eqx.apply_updates(model, updates)
assert model.buffer == jnp.ones(())  # Fails: weight decay still moved the buffer!

Unless I am missing a downside, the approach I think should be recommended is to use a wrapper class (NonTrainable) around non-trainable nodes, and to partition the parameters e.g. with:

params, static = eqx.partition(
    model,
    eqx.is_inexact_array,
    is_leaf=lambda leaf: isinstance(leaf, NonTrainable),
)
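
For reference, a minimal version of such a wrapper might look like the following sketch. The unwrap helper is a name introduced here for illustration; the paramax library linked later in the thread provides a maintained implementation of the same idea.

import equinox as eqx
import jax
import jax.tree_util as jtu
from jaxtyping import PyTree


class NonTrainable(eqx.Module):
    # Wrapping a subtree means the whole wrapper lands in the static
    # partition when filtering with the is_leaf predicate above.
    tree: PyTree


def unwrap(model: PyTree) -> PyTree:
    # Replace every NonTrainable wrapper with its (gradient-stopped) contents;
    # typically called at the top of the loss function or forward pass.
    return jtu.tree_map(
        lambda x: jax.lax.stop_gradient(x.tree) if isinstance(x, NonTrainable) else x,
        model,
        is_leaf=lambda x: isinstance(x, NonTrainable),
    )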

@patrick-kidger
Owner

Ah! That really isn't very good, you're right.

Hmm, I'm trying to figure out if there's a way to handle this ergonomically. The best I can come up with is to wrap the Optax calls (like we already do for eqx.apply_updates) with something that respects such a NonTrainable wrapper. This is just such an easy footgun!

@dlwh
Contributor

dlwh commented Oct 22, 2024

FWIW I've landed on the optax-wrapper approach. I have a trainable/non_trainable mask that I create early on and partition that way. I don't even bother with stop_grad most of the time and pray that XLA does the dead-code elimination (DCE) for me (it seems to).

For things that are really constants (e.g. rotary embeddings) I just materialize those in the kernel with jax.ensure_compile_time_eval.
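
The ensure_compile_time_eval part might look roughly like this (a sketch; the rotary-frequency computation is just an illustrative constant, not code from this thread):

import jax
import jax.numpy as jnp


def rotary_frequencies(dim: int, base: float = 10000.0) -> jax.Array:
    # Evaluated eagerly at trace time, so inside a jitted kernel the result
    # is baked in as a constant rather than traced (and differentiated).
    with jax.ensure_compile_time_eval():
        return 1.0 / (base ** (jnp.arange(0, dim, 2) / dim))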

@patrick-kidger
Owner

Ah, nice! Okay, I think I'm convinced.

I'd be happy to take a PR implementing this, then.

@smorad

smorad commented Jan 21, 2025

Just for posterity, @danielward27 has written a small library that fixes this (along with enabling other parameterizations) in https://github.com/danielward27/paramax
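
A usage sketch, assuming the NonTrainable and unwrap names from the paramax README (check the linked repository for the current API; the Scaler module is just a toy example):

import equinox as eqx
import jax
import jax.numpy as jnp
import paramax  # https://github.com/danielward27/paramax


class Scaler(eqx.Module):
    scale: jax.Array

    def __call__(self, x):
        return self.scale * x


model = Scaler(jnp.ones(()))
# Wrap the leaf to mark it as non-trainable...
model = eqx.tree_at(lambda m: m.scale, model, replace=paramax.NonTrainable(model.scale))
# ...and unwrap before use, so downstream code sees a plain array again.
y = paramax.unwrap(model)(2.0)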

@patrick-kidger
Owner

Oh, this is excellent. Small and does exactly the right thing. @danielward27 would you be interested in having this be advertised in the various Equinox-ecosystem READMEs, e.g. https://github.com/patrick-kidger/equinox/?tab=readme-ov-file#see-also-other-libraries-in-the-jax-ecosystem ?

@danielward27
Contributor

Thanks! It would be great to have it added to the list! I can open a pull request to add it if you would like - whatever is easiest for you.

@patrick-kidger
Owner

Awesome :) Yup send a PR! The ecosystem lists appear in README.md and in docs/index.md.
