diff --git a/docs/src/api.md b/docs/src/api.md
index 1017af94..ad00a2aa 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -35,6 +35,8 @@ Optimisers.setup
 Optimisers.update
 Optimisers.update!
 Optimisers.adjust(::Any, ::Real)
+Optimisers.freeze!
+Optimisers.thaw!
 ```
 
 Calling `Functors.@functor` on your model's layer types by default causes
diff --git a/docs/src/index.md b/docs/src/index.md
index 65b441bb..659ae837 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -138,6 +138,33 @@ Optimisers.trainable(x::Layer) = (; alpha = x.alpha)  # must be a subset of children
 st = Optimisers.setup(DecayDescent(0.1), Layer(3))
 ```
 
+## Frozen Parameters
+
+To temporarily prevent training from affecting some parameters,
+use [freeze!](@ref Optimisers.freeze!) and `thaw!`.
+They work by mutating all `Leaf`s of the state tree, or part of it.
+
+```julia
+using Flux, Optimisers
+
+x = randn(Float32, 28, 28, 1, 1);
+net = @autosize (size(x)...,) Chain(
+    Conv((3, 3), 1 => 3, stride=2, bias=false), Flux.flatten, Dense(_ => 2, relu),
+)
+opt = Optimisers.setup(Optimisers.Momentum(), net);
+
+net.layers[3] isa Dense  # now freeze this layer's parameters:
+Optimisers.freeze!(opt.layers[3])
+opt.layers[3].bias  # confirm: Leaf(Momentum(...), [0.0, 0.0], frozen = true)
+
+Optimisers.update!(opt, net, gradient(m -> sum(m(x)), net)...);
+
+net.layers[3].bias  # still zero, and its momentum is too:
+
+Optimisers.thaw!(opt)
+opt.layers[3].bias  # Leaf(Momentum(...), [0.0, 0.0])
+```
+
 ## Tied Parameters
 
 If the same array appears twice (or more) in the model, [Functors.jl](https://fluxml.ai/Functors.jl) should recognise this.
@@ -159,7 +186,7 @@ st.layers.enc.layers[1].weight === st.layers.dec.layers[1].weight.parent  # true
 This identification relies on `===`, and will work for ordinary `Array`s and `CuArray`s.
 It will not at present work for `reshape`d arrays, nor for immutable arrays such as those
 from StaticArrays.jl.
- 
+
 
 ## Obtaining a flat parameter vector
 
diff --git a/src/adjust.jl b/src/adjust.jl
index 78b3d452..a8123676 100644
--- a/src/adjust.jl
+++ b/src/adjust.jl
@@ -1,3 +1,59 @@
+###
+### freezing
+###
+
+"""
+    Optimisers.freeze!(tree)
+
+Temporarily alters the state `tree = setup(rule, model)` so that parameters
+will not be updated. Un-done by [`thaw!`](@ref Optimisers.thaw!).
+
+Can be applied to the state corresponding to only part of a model,
+for instance with `model::Chain`, to freeze `model.layers[1]` you
+should call `freeze!(tree.layers[1])`.
+
+# Example
+```jldoctest
+julia> m = (x = ([1.0], 2.0), y = [3.0]);
+
+julia> s = Optimisers.setup(Momentum(), m);
+
+julia> Optimisers.freeze!(s.x)
+
+julia> Optimisers.update!(s, m, (x = ([pi], 10pi), y = [100pi]));  # with fake gradient
+
+julia> m
+(x = ([1.0], 2.0), y = [-0.14159258336972558])
+
+julia> s
+(x = (Leaf(Momentum{Float32}(0.01, 0.9), [0.0], frozen = true), ()), y = Leaf(Momentum{Float32}(0.01, 0.9), [3.14159]))
+
+julia> Optimisers.thaw!(s)
+
+julia> s.x
+(Leaf(Momentum{Float32}(0.01, 0.9), [0.0]), ())
+```
+"""
+freeze!(tree) = foreach(freeze!, tree)
+freeze!(ℓ::Leaf) = (ℓ.frozen = true; nothing)
+
+"""
+    Optimisers.thaw!(tree)
+
+The reverse of [`freeze!`](@ref Optimisers.freeze!). Applies to all parameters,
+mutating every `Leaf(rule, state, frozen = true)` to `Leaf(rule, state, frozen = false)`.
+"""
+thaw!(tree) = foreach(thaw!, tree)
+thaw!(ℓ::Leaf) = (ℓ.frozen = false; nothing)
+
+freeze!(::Union{Number, AbstractArray{<:Number}}) = throw(ArgumentError(
+  "`freeze!` must not be applied to a model, only to the state tree from `setup`"))
+thaw!(::Union{Number, AbstractArray{<:Number}}) = throw(ArgumentError(
+  "`thaw!` must not be applied to a model, only to the state tree from `setup`"))
+
+###
+### adjust
+###
 """
     Optimisers.adjust(tree, η) -> tree
 
@@ -47,8 +103,8 @@ adjust(tree; kw...) = map(st -> adjust(st; kw...), tree)
 adjust(::Nothing, ::Real) = nothing
 adjust(::Nothing; kw...) = nothing
 
-adjust(ℓ::Leaf, eta::Real) = Leaf(adjust(ℓ.rule, eta), ℓ.state)
-adjust(ℓ::Leaf; kw...) = Leaf(adjust(ℓ.rule; kw...), ℓ.state)
+adjust(ℓ::Leaf, eta::Real) = Leaf(adjust(ℓ.rule, eta), ℓ.state, ℓ.frozen)
+adjust(ℓ::Leaf; kw...) = Leaf(adjust(ℓ.rule; kw...), ℓ.state, ℓ.frozen)
 
 """
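Not part of the patch, just an illustrative aside: the reason the two `adjust` methods above now forward `ℓ.frozen` is that `adjust` rebuilds each `Leaf` out-of-place, so without this change re-tuning a hyperparameter would silently reset every leaf to `frozen = false`. A minimal sketch, with purely illustrative names `m` and `st`:

```julia
using Optimisers

m = (x = [1.0, 2.0], y = [3.0, 4.0])
st = Optimisers.setup(Momentum(), m)   # state tree with one Leaf per array

Optimisers.freeze!(st.x)               # freeze only the x-leaf
st2 = Optimisers.adjust(st, 0.1)       # out-of-place change of the learning rate

st2.x.frozen, st2.y.frozen             # (true, false): adjust keeps the frozen flag
```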
+""" +thaw!(tree) = foreach(thaw!, tree) +thaw!(ℓ::Leaf) = (ℓ.frozen = false; nothing) + +freeze!(::Union{Number, AbstractArray{<:Number}}) = throw(ArgumentError( + "`freeze!` must not be applied to a model, only to the state tree from `setup`")) +thaw!(::Union{Number, AbstractArray{<:Number}}) = throw(ArgumentError( + "`thaw!` must not be applied to a model, only to the state tree from `setup`")) + +### +### adjust +### """ Optimisers.adjust(tree, η) -> tree @@ -47,8 +103,8 @@ adjust(tree; kw...) = map(st -> adjust(st; kw...), tree) adjust(::Nothing, ::Real) = nothing adjust(::Nothing; kw...) = nothing -adjust(ℓ::Leaf, eta::Real) = Leaf(adjust(ℓ.rule, eta), ℓ.state) -adjust(ℓ::Leaf; kw...) = Leaf(adjust(ℓ.rule; kw...), ℓ.state) +adjust(ℓ::Leaf, eta::Real) = Leaf(adjust(ℓ.rule, eta), ℓ.state, ℓ.frozen) +adjust(ℓ::Leaf; kw...) = Leaf(adjust(ℓ.rule; kw...), ℓ.state, ℓ.frozen) """ diff --git a/src/interface.jl b/src/interface.jl index 79d03396..401c9b1c 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -10,10 +10,12 @@ abstract type AbstractRule end ### setup ### -mutable struct Leaf{R,S} # mutable so that its identity encodes parameter sharing +mutable struct Leaf{R,S} # mutable so that its identity encodes parameter sharing... rule::R state::S + frozen::Bool # ... and to allow freeze! to act on this. end +Leaf(rule, state; frozen::Bool = false) = Leaf(rule, state, frozen) @functor Leaf @@ -42,11 +44,12 @@ function _setup(rule, x; cache) end end -function Base.show(io::IO, ℓ::Leaf) # show method is mostly to hide its long type! +function Base.show(io::IO, ℓ::Leaf; colour = ℓ.frozen ? :cyan : :green) ioc = IOContext(io, :compact => true) - print(ioc, "Leaf(", ℓ.rule, ", ") + str = sprint(show, ℓ.rule; context = ioc) # produces Adam{Float32}(0.001, ... not 0.001f0 + printstyled(io, "Leaf(", str, ", "; color = colour) show(ioc, ℓ.state) - print(ioc, ")") + printstyled(io, ℓ.frozen ? ", frozen = true)" : ")"; color = colour) end ### @@ -83,6 +86,7 @@ function _update!(tree, x; grads, params) end function _update!(ℓ::Leaf, x; grads, params) haskey(params, (ℓ,x)) && return params[(ℓ,x)] + ℓ.frozen && return x params[(ℓ,x)] = if haskey(grads, ℓ) ℓ.state, x̄′ = apply!(ℓ.rule, ℓ.state, x, grads[ℓ]...) subtract!(x, x̄′) diff --git a/test/runtests.jl b/test/runtests.jl index 51e76053..a8cef6f6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -221,6 +221,24 @@ end @test sc2.γ.state[2][1] ≈ [0.1, 0.2, 0.2] end + @testset "freeze/thaw" begin + m = (x=[1.0, 2.0], y=([3.0, 4.0], sin)); + st = Optimisers.setup(Descent(0.1), m); + Optimisers.freeze!(st.y) + st, m = Optimisers.update(st, m, (x=[1,10], y=([100,1000], nothing))); + @test m.x ≈ [0.9, 1.0] + @test m.y[1] == [3, 4] + + st = Optimisers.adjust(st, 0.2) + Optimisers.thaw!(st) + st, m = Optimisers.update(st, m, (x=[1,10], y=([100,1000], nothing))); + @test m.y[1] ≈ [-17.0, -196.0] + @test m.x ≈ [0.7, -1.0] + + @test_throws ArgumentError Optimisers.freeze!(m) + @test_throws ArgumentError Optimisers.thaw!(m) + end + @testset "forgotten gradient" begin x = [1.0, 2.0] sx = Optimisers.setup(Descent(), x)