Add freeze!/thaw! #112

Merged (8 commits) on Nov 25, 2022
2 changes: 2 additions & 0 deletions docs/src/api.md
@@ -35,6 +35,8 @@ Optimisers.setup
Optimisers.update
Optimisers.update!
Optimisers.adjust(::Any, ::Real)
Optimisers.freeze!
Optimisers.thaw!
```

Calling `Functors.@functor` on your model's layer types by default causes
29 changes: 28 additions & 1 deletion docs/src/index.md
@@ -138,6 +138,33 @@ Optimisers.trainable(x::Layer) = (; alpha = x.alpha) # must be a subset of chid
st = Optimisers.setup(DecayDescent(0.1), Layer(3))
```

## Frozen Parameters

To temporarily prevent training from affecting some parameters,
use [`freeze!`](@ref Optimisers.freeze!) and [`thaw!`](@ref Optimisers.thaw!).
They work by mutating every `Leaf` of the state tree, or of whatever part of it you pass in.

```julia
using Flux, Optimisers

x = randn(Float32, 28, 28, 1, 1);
net = @autosize (size(x)...,) Chain(
  # Review note: the splat `size(x)...` is used here because `@autosize`
  # currently wants a literal tuple expression, not a call such as `size(x)`.
Conv((3, 3), 1 => 3, stride=2, bias=false), Flux.flatten, Dense(_ => 2, relu),
)
opt = Optimisers.setup(Optimisers.Momentum(), net);

net.layers[3] isa Dense # now freeze this layer's parameters:
Optimisers.freeze!(opt.layers[3])
opt.layers[3].bias # confirm: Leaf(Momentum(...), [0.0, 0.0], frozen = true)

Optimisers.update!(opt, net, gradient(m -> sum(m(x)), net)...);

net.layers[3].bias # still zero, and its momentum is unchanged too

Optimisers.thaw!(opt)
opt.layers[3].bias # Leaf(Momentum(...), [0.0, 0.0])
```
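
Because `freeze!` and `thaw!` simply recurse down to each `Leaf` with `foreach`, they
accept any sub-tree of the state, down to a single parameter's `Leaf`. A minimal sketch
continuing the example above (the results shown in the comments are indicative):

```julia
Optimisers.freeze!(opt.layers[1].weight)  # freeze only the Conv layer's kernel
opt.layers[1].weight                      # Leaf(Momentum(...), ..., frozen = true)

Optimisers.thaw!(opt.layers[1])           # un-freeze everything within that layer
```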

## Tied Parameters

If the same array appears twice (or more) in the model, [Functors.jl](https://fluxml.ai/Functors.jl) should recognise this.
@@ -159,7 +186,7 @@ st.layers.enc.layers[1].weight === st.layers.dec.layers[1].weight.parent # true
This identification relies on `===`, and will work for ordinary `Array`s and `CuArray`s.
It will not at present work for `reshape`d arrays, nor for immutable arrays such as those
from StaticArrays.jl.


## Obtaining a flat parameter vector

60 changes: 58 additions & 2 deletions src/adjust.jl
@@ -1,3 +1,59 @@
###
### freezing
###

"""
Optimisers.freeze!(tree)

Temporarily alters the state `tree = setup(rule, model)` so that parameters
will not be updated. Undone by [`thaw!`](@ref Optimisers.thaw!).

Can be applied to the state corresponding to only part of a model,
for instance with `model::Chain`: to freeze `model.layers[1]`,
call `freeze!(tree.layers[1])`.

# Example
```jldoctest
julia> m = (x = ([1.0], 2.0), y = [3.0]);

julia> s = Optimisers.setup(Momentum(), m);

julia> Optimisers.freeze!(s.x)

julia> Optimisers.update!(s, m, (x = ([pi], 10pi), y = [100pi])); # with fake gradient

julia> m
(x = ([1.0], 2.0), y = [-0.14159258336972558])

julia> s
(x = (Leaf(Momentum{Float32}(0.01, 0.9), [0.0], frozen = true), ()), y = Leaf(Momentum{Float32}(0.01, 0.9), [3.14159]))

julia> Optimisers.thaw!(s)

julia> s.x
(Leaf(Momentum{Float32}(0.01, 0.9), [0.0]), ())
```
"""
freeze!(tree) = foreach(freeze!, tree)
freeze!(ℓ::Leaf) = (ℓ.frozen = true; nothing)

"""
Optimisers.thaw!(tree)

The reverse of [`freeze!`](@ref Optimisers.freeze!). Applies to all parameters,
mutating every `Leaf(rule, state, frozen = true)` to `Leaf(rule, state, frozen = false)`.
"""
thaw!(tree) = foreach(thaw!, tree)
thaw!(ℓ::Leaf) = (ℓ.frozen = false; nothing)

freeze!(::Union{Number, AbstractArray{<:Number}}) = throw(ArgumentError(
"`freeze!` must not be applied to a model, only to the state tree from `setup`"))
thaw!(::Union{Number, AbstractArray{<:Number}}) = throw(ArgumentError(
"`thaw!` must not be applied to a model, only to the state tree from `setup`"))

###
### adjust
###

"""
Optimisers.adjust(tree, η) -> tree
@@ -47,8 +103,8 @@ adjust(tree; kw...) = map(st -> adjust(st; kw...), tree)
adjust(::Nothing, ::Real) = nothing
adjust(::Nothing; kw...) = nothing

adjust(ℓ::Leaf, eta::Real) = Leaf(adjust(ℓ.rule, eta), ℓ.state)
adjust(ℓ::Leaf; kw...) = Leaf(adjust(ℓ.rule; kw...), ℓ.state)
adjust(ℓ::Leaf, eta::Real) = Leaf(adjust(ℓ.rule, eta), ℓ.state, ℓ.frozen)
adjust(ℓ::Leaf; kw...) = Leaf(adjust(ℓ.rule; kw...), ℓ.state, ℓ.frozen)
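
Since `adjust` now rebuilds each `Leaf` carrying over its `frozen` field, changing a
hyperparameter does not silently un-freeze anything. A minimal sketch of that behaviour
(assuming `using Optimisers`; the `frozen` field is read here only for illustration):

```julia
s = Optimisers.setup(Momentum(), (x = [1.0, 2.0],));
Optimisers.freeze!(s.x)

s2 = Optimisers.adjust(s, 0.2)   # new learning rate, freshly constructed Leafs
s2.x.frozen                      # true -- still frozen after the adjustment
```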


"""
12 changes: 8 additions & 4 deletions src/interface.jl
@@ -10,10 +10,12 @@ abstract type AbstractRule end
### setup
###

mutable struct Leaf{R,S} # mutable so that its identity encodes parameter sharing
mutable struct Leaf{R,S} # mutable so that its identity encodes parameter sharing...
rule::R
state::S
frozen::Bool # ... and to allow freeze! to act on this.
end
Leaf(rule, state; frozen::Bool = false) = Leaf(rule, state, frozen)
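
A quick sketch of the keyword constructor's behaviour (constructing a `Leaf` by hand is
unusual outside `setup`, and shown only for illustration):

```julia
ℓ  = Optimisers.Leaf(Optimisers.Descent(0.1), nothing)                 # frozen = false by default
ℓ2 = Optimisers.Leaf(Optimisers.Descent(0.1), nothing; frozen = true)  # starts out frozen
```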

@functor Leaf

@@ -42,11 +44,12 @@ function _setup(rule, x; cache)
end
end

function Base.show(io::IO, ℓ::Leaf) # show method is mostly to hide its long type!
function Base.show(io::IO, ℓ::Leaf; colour = ℓ.frozen ? :cyan : :green)
ioc = IOContext(io, :compact => true)
print(ioc, "Leaf(", ℓ.rule, ", ")
str = sprint(show, ℓ.rule; context = ioc) # produces Adam{Float32}(0.001, ... not 0.001f0
printstyled(io, "Leaf(", str, ", "; color = colour)
show(ioc, ℓ.state)
print(ioc, ")")
printstyled(io, ℓ.frozen ? ", frozen = true)" : ")"; color = colour)
end

###
@@ -83,6 +86,7 @@ function _update!(tree, x; grads, params)
end
function _update!(ℓ::Leaf, x; grads, params)
haskey(params, (ℓ,x)) && return params[(ℓ,x)]
ℓ.frozen && return x
params[(ℓ,x)] = if haskey(grads, ℓ)
ℓ.state, x̄′ = apply!(ℓ.rule, ℓ.state, x, grads[ℓ]...)
subtract!(x, x̄′)
18 changes: 18 additions & 0 deletions test/runtests.jl
@@ -221,6 +221,24 @@ end
@test sc2.γ.state[2][1] ≈ [0.1, 0.2, 0.2]
end

@testset "freeze/thaw" begin
m = (x=[1.0, 2.0], y=([3.0, 4.0], sin));
st = Optimisers.setup(Descent(0.1), m);
Optimisers.freeze!(st.y)
st, m = Optimisers.update(st, m, (x=[1,10], y=([100,1000], nothing)));
@test m.x ≈ [0.9, 1.0]
@test m.y[1] == [3, 4]

st = Optimisers.adjust(st, 0.2)
Optimisers.thaw!(st)
st, m = Optimisers.update(st, m, (x=[1,10], y=([100,1000], nothing)));
@test m.y[1] ≈ [-17.0, -196.0]
@test m.x ≈ [0.7, -1.0]

@test_throws ArgumentError Optimisers.freeze!(m)
@test_throws ArgumentError Optimisers.thaw!(m)
end

@testset "forgotten gradient" begin
x = [1.0, 2.0]
sx = Optimisers.setup(Descent(), x)