Frozen parameters #107

Open
mcabbott opened this issue Aug 28, 2022 · 5 comments
Labels
enhancement New feature or request

Comments

@mcabbott (Member) commented Aug 28, 2022

It would be nice to be able to temporarily exclude some parameters from training.
(Edit: I forgot that there is FluxML/Flux.jl#1931, now folded in here.)

  1. One mechanism is to alter Leaf to record whether it is frozen. This is what Per-leaf freezing #49 does, and what Allow shared parameters, take III #106 suggests as an aside. The former is immutable, changed by walking & re-building. The latter makes Leaf mutable (for other reasons) so this can be changed in place; a rough sketch follows this list. (Edit: implemented in Add freeze!/thaw! #112, merged.)

  2. Another mechanism would be to insert some Frozen struct into the state tree which stops further exploration. This may make it easier to freeze a whole branch, but it will result in a tree whose labels differ from the model's: a pattern like model.layers[1].layers[1].weight will no longer translate directly to one for the state tree.

  3. A similar struct could equally be inserted into the model rather than the state, or into both. Since the gradient calculation never sees the state, changing the model may allow for faster gradients. Does Optimisers.jl own the struct Frozen, if it is to recognise it?
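For concreteness, here is a minimal, self-contained sketch of mechanism 1, a flag on the leaf that is consulted at update time. The names (FrozenableLeaf, the plain-SGD "rule") are illustrative only, not the exact layout chosen in #49 or #112:

```julia
# Illustrative only: a leaf carrying its rule's hyperparameter plus a `frozen`
# flag, which `update!` checks before touching the parameter.
mutable struct FrozenableLeaf
    eta::Float64      # learning rate of a plain-SGD "rule"
    frozen::Bool      # when true, the parameter is left untouched
end

function update!(leaf::FrozenableLeaf, x::AbstractArray, dx::AbstractArray)
    leaf.frozen && return x        # frozen: skip this parameter entirely
    x .-= leaf.eta .* dx           # otherwise apply the rule in place
    return x
end

freeze!(leaf::FrozenableLeaf) = (leaf.frozen = true; nothing)
thaw!(leaf::FrozenableLeaf)   = (leaf.frozen = false; nothing)
```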

Perhaps independently of the mechanism, there needs to be a friendly way to set & remove these labels.

  4. PR Per-leaf freezing #49 proposes that you give an address like freeze(state, (:layers, 1, :enc, 3)). It seems a bit awkward to require you to know all the field names from the root.

  5. It would also be possible to work based on just one field name: freeze(state, :enc) acts on anything within any field called enc (which in practice is some Chain(enc = ..., dec = ...)). Likewise freeze(state, :bias) could affect the bias of every layer.

  6. Another possibility is to allow control based on the type in the model. Then it has to walk both trees: state = freeze(state, model, cond), or perhaps state = freeze(f, state, model) where f is a do block which tests x isa Dense or similar. This doesn't lend itself so obviously to freezing only some fields, enc or bias... unless f returns not a Bool but a list of fields, like x isa Chain && return :enc.

  7. If the modification is to the model, then 6. becomes model = freeze(f, model).

  8. If Leaf is mutable, then instead of an address you can just pass a part of the tree: freeze!(tree.layers[1].enc[3]), after confirming that model.layers[1].enc[3] is the part you want. (Edit: implemented as Add freeze!/thaw! #112, merged.)
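A usage sketch of option 8, hedged since it reflects the freeze!/thaw! approach merged later in #112 rather than an API settled at the time of writing:

```julia
using Flux, Optimisers

model = Chain(enc = Dense(2 => 3, tanh), dec = Dense(3 => 2))
tree  = Optimisers.setup(Optimisers.Adam(1e-3), model)

Optimisers.freeze!(tree.layers.enc)   # mark every Leaf under the encoder as frozen
# ... calls to Optimisers.update!(tree, model, grad) now skip the encoder ...
Optimisers.thaw!(tree)                # make everything trainable again
```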

There's a related API question for shared weights. At present Flux (and Functors) rely on objectid. This won't work for immutable arrays.

  9. One idea is to wrap them in a struct like TiedWeight(array, Ref()) to get an objectid (and possibly remove this later); a rough sketch follows this list.

  10. The idea of Transparent handling of tied weights #100 is that instead the state tree can have the same (mutable) Leaf struct at the location of tied arrays. How do you construct this? With 4. this might be tie(state, (:layers, 1, :enc, 3) => (:layers, 1, :dec, 3, :parent)) where the :parent is because of a Transpose. Is there a less ugly way?
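A purely hypothetical sketch of option 9; the wrapper is not an existing Optimisers.jl type, and the field names are made up:

```julia
# Hypothetical TiedWeight wrapper: the shared Ref supplies a stable objectid
# even when the wrapped array is immutable, so tied occurrences can be
# recognised as the same parameter.
struct TiedWeight{A}
    array::A
    id::Base.RefValue{Nothing}
end
TiedWeight(array) = TiedWeight(array, Ref(nothing))

w_enc = TiedWeight(randn(3, 3))
w_dec = w_enc                    # tied: w_dec.id === w_enc.id
```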

@ToucheSir (Member) commented:

Optax's design may be of interest here. They of course can get away with making everything immutable. However, if we think of a masked state tree as a temporary view of the original, perhaps we can do something similar.

One concern with wholesale replacement vs mutating a flag is tied weights. We'd either have to do a two-pass solution to also catch those, or document that they won't be picked up.

But it will result in a tree whose labels differ from the model's: a pattern like model.layers[1].layers[1].weight will no longer translate directly to one for the state tree.

Inserting this struct in place of leaf nodes instead of entire subtrees would mean more time spent traversing, but it would preserve the structural equivalence.

With 4. this might be tie(state, (:layers, 1, :enc, 3), (:layers, 1, :dec, 3, :parent)). Is there a less ugly way?

Accessors.@set state.layers[1].enc[3] = state.layers[1].dec[3].parent

That would be the direct equivalent. Not prettier, but at least more familiar syntax. At this stage, I think figuring out how to tie immutable params in bulk can be left to users.
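To illustrate on a toy tree (hypothetical values, not Optimisers.jl types): @set rebuilds the nested structure with one leaf replaced and returns a new copy, so no mutation is needed.

```julia
using Accessors

state  = (layers = ((enc = (w = 1,), dec = (w = 2,)),),)
state2 = Accessors.@set state.layers[1].enc.w = state.layers[1].dec.w

state2.layers[1].enc.w == 2   # true, while `state` itself is unchanged
```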

@mcabbott (Member, Author) commented:

Inserting this struct in place of leaf nodes instead

To be clear, I've only considered reversible modifications. So the closest thing is 1., i.e. #49, which replaces the Leaf with a different one.

Irreversibly truncating the state tree is another option. But then perhaps we need a way to merge back the old one. And the user needs to keep two objects.

tied weights. We'd either have to do a two-pass solution to also catch those, or document that they won't be picked up.

Good question. First, what's the desired behaviour here at all? If you freeze enc, should this freeze dec too? (As the mutable Leaf would do.) Or un-tie them from dec? Or just update using the gradients from dec alone (but update the shared momenta)?

@ToucheSir (Member) commented:

How common are state tree merges? If you make Frozen a wrapper type, then the modification is reversible.

Good question. First, what's the desired behaviour here at all? If you freeze enc, should this freeze dec too? (As the mutable Leaf would do.) Or un-tie them from dec? Or just update using the gradients from dec alone (but update the shared momenta)?

There doesn't seem to be a clear answer, yeah. For safety then, I think we should try to match current Flux semantics and freeze ties as well.

@mcabbott (Member, Author) commented Aug 28, 2022

The optax.masked is interesting. If I understand right, this is closer to a modification of our setup, to apply different rules to different parameters (or to append more rules to an OptimiserChain). I think we discussed allowing setup(f, model) at some point, and f has similar API questions to 6. above: does it see the layer, the field name, or the parameter array? The optax one seems to see the array. (I guess it could see the layer and then the array.)
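Purely as a sketch of the API question (none of these names exist in Optimisers.jl): a setup variant whose callback sees each parameter array and picks the rule could look roughly like this.

```julia
using Flux, Functors, Optimisers

# Hypothetical: `f` receives each parameter array and returns the rule to use there.
function setup_by(f, model)
    Functors.fmap(model; exclude = x -> x isa AbstractArray{<:Number}) do x
        Optimisers.setup(f(x), x)      # one Leaf per parameter array
    end
end

# e.g. a smaller step for bias vectors than for weight matrices:
model = Chain(Dense(2 => 3, relu), Dense(3 => 1))
tree  = setup_by(x -> x isa AbstractVector ? Optimisers.Descent(1e-4) : Optimisers.Descent(1e-2), model)
```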

For shared parameters, maybe the word "freeze" implies they never change (hence tied ones must be frozen too) while "mask" could be read as caring about some gradients. As you say Flux freezes both.

@ToucheSir (Member) commented Aug 28, 2022

The optax one sees the entire tree. In their example they map over all leaves in the callback function, but as long as you return something with the same shape as the original state they don't care how you do it. Optax in general, though, tends to write its rules "vectorized": instead of taking in a leaf, each rule takes in a state tree and is responsible for mapping some function over each leaf. A direct comparison to Optimisers.jl needs to account for that.
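For contrast, a rule in Optimisers.jl is written per-leaf. Roughly, a minimal SGD-like rule looks like this, assuming the standard two-method init/apply! contract for custom rules:

```julia
using Optimisers

# A per-leaf rule: `init` builds this rule's state for one array, and `apply!`
# returns the updated state plus the increment to subtract from that array.
struct PlainSGD <: Optimisers.AbstractRule
    eta::Float64
end
Optimisers.init(::PlainSGD, x::AbstractArray) = nothing
Optimisers.apply!(o::PlainSGD, state, x, dx) = state, o.eta .* dx
```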

@mcabbott mentioned this issue Oct 13, 2022
@mcabbott added the enhancement label Nov 15, 2022