diff --git a/NEWS.md b/NEWS.md
index 87333f8717..a4b0856327 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -2,6 +2,9 @@
See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release.

+## v0.14.17
+* Add [support for Enzyme](https://github.com/FluxML/Flux.jl/pull/2446) with `Flux.train!`.
+
## v0.14.13
* New macro `Flux.@layer` which should be used in place of `@functor`. This also adds `show` methods for pretty printing.
diff --git a/Project.toml b/Project.toml
index 3abbf2988e..48158f4175 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,11 +1,12 @@
name = "Flux"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-version = "0.14.16"
+version = "0.14.17"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
diff --git a/src/deprecations.jl b/src/deprecations.jl
index 5acdec5455..24372a570e 100644
--- a/src/deprecations.jl
+++ b/src/deprecations.jl
@@ -107,7 +107,10 @@ train!(loss, ps::Params, data, opt::Optimisers.AbstractRule; cb=nothing) = error
  But better to use the new explicit style, in which `m` itself is the 2nd argument.
  """)
-train!(loss, model, data, opt::Optimise.AbstractOptimiser; cb=nothing) = train!(loss, model, data, _old_to_new(opt); cb)
+train!(loss, model, data, opt::Optimise.AbstractOptimiser; cb=nothing) =
+  train!(loss, model, data, _old_to_new(opt); cb)
+train!(loss, model::Enzyme.Duplicated, data, opt::Optimise.AbstractOptimiser; cb=nothing) =
+  train!(loss, model, data, _old_to_new(opt); cb)

# Next, to use the new `setup` with the still-exported old-style `Adam` etc:
import .Train: setup
diff --git a/src/functor.jl b/src/functor.jl
index eeaffab1c3..e48246ebde 100644
--- a/src/functor.jl
+++ b/src/functor.jl
@@ -3,6 +3,7 @@ using LinearAlgebra: Cholesky
using Zygote: IdSet
import Functors: Functors, @functor, functor, fmap, isleaf
using SparseArrays: AbstractSparseArray
+using Enzyme

"""
    testmode!(model, [mode]) -> model
diff --git a/src/losses/utils.jl b/src/losses/utils.jl
index e42bdfbe2e..c380564908 100644
--- a/src/losses/utils.jl
+++ b/src/losses/utils.jl
@@ -1,3 +1,5 @@
+import Enzyme
+
"""
    xlogx(x)
@@ -36,3 +38,4 @@ end
_check_sizes(ŷ, y) = nothing # pass-through, for constant label e.g. y = 1
ChainRulesCore.@non_differentiable _check_sizes(ŷ::Any, y::Any)
+Enzyme.EnzymeRules.inactive(::typeof(_check_sizes), args...) = true
diff --git a/src/train.jl b/src/train.jl
index e72eedebf3..6094e13ac6 100644
--- a/src/train.jl
+++ b/src/train.jl
@@ -5,6 +5,7 @@ using Optimisers: Optimisers
using Functors: fmap, fmapstructure
using ..Flux: Flux # used only in docstring
import ..Flux.Optimise: train!, update! # during 0.13, we add methods to the old functions
+import Enzyme

export setup, train!
@@ -52,6 +53,12 @@ function setup(rule::Optimisers.AbstractRule, model)
  state
end

+_make_zero_internal!(x::AbstractArray) = fill!(x, 0)
+_make_zero_internal!(x) = x
+_make_zero!(model) = fmap(_make_zero_internal!, model)
+
+_applyloss(loss, model, d...) = loss(model, d...)
+
"""
    train!(loss, model, data, opt_state)

@@ -60,6 +67,9 @@ according to a particular optimisation rule encoded in `opt_state`.
Iterates through `data` once, evaluating for each `d in data` either
`loss(model, d...)` if `d isa Tuple`, or else `loss(model, d)` for other `d`.
+If `model` is an `Enzyme.Duplicated`, gradients will be computed with Enzyme;
+otherwise they will be computed with Zygote.
+
For example, with these definitions...
```
data = [(x1, y1), (x2, y2), (x3, y3)]
@@ -100,11 +110,33 @@ function train!(loss, model, data, opt; cb = nothing)
    For more control use a loop with `gradient` and `update!`.""")
  @withprogress for (i,d) in enumerate(data)
    d_splat = d isa Tuple ? d : (d,)
+
    l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model)
+
    if !isfinite(l)
      throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
    end
+
    opt, model = Optimisers.update!(opt, model, gs[1])
+
+    @logprogress Base.haslength(data) ? i/length(data) : nothing
+  end
+end
+function train!(loss, model::Enzyme.Duplicated, data, opt; cb = nothing)
+  isnothing(cb) || error("""train! does not support callback functions.
+                            For more control use a loop with `gradient` and `update!`.""")
+  @withprogress for (i,d) in enumerate(data)
+    d_splat = d isa Tuple ? d : (d,)
+
+    _make_zero!(model.dval)
+    _, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, _applyloss, Enzyme.Active, Enzyme.Const(loss), model, map(Enzyme.Const, d_splat)...)
+
+    if !isfinite(l)
+      throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
+    end
+    opt, model2 = Optimisers.update!(opt, model.val, model.dval)
+    model = Enzyme.Duplicated(model2, model.dval)
+
    @logprogress Base.haslength(data) ? i/length(data) : nothing
  end
end
@@ -113,6 +145,9 @@ end
function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing)
  train!(loss, model, data, _rule_to_state(model, rule); cb)
end
+function train!(loss, model::Enzyme.Duplicated, data, rule::Optimisers.AbstractRule; cb = nothing)
+  train!(loss, model, data, _rule_to_state(model, rule); cb)
+end

function _rule_to_state(model, rule::Optimisers.AbstractRule)
  state = setup(rule, model)
diff --git a/test/train.jl b/test/train.jl
index 1d938649d0..3ed0e658ea 100644
--- a/test/train.jl
+++ b/test/train.jl
@@ -4,8 +4,14 @@ import Optimisers
using Test
using Random
+using Enzyme

-@testset "Explicit Flux.train! with Zygote" begin
+function train_enzyme!(fn, model, args...; kwargs...)
+  Flux.train!(fn, Duplicated(model, Enzyme.make_zero(model)), args...; kwargs...)
+end
+
+for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme"))
+@testset "Explicit Flux.train! with $name" begin
  Random.seed!(84)
  w = randn(10, 10)
  w2 = randn(10, 10)  # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset.
@@ -18,31 +24,40 @@ using Random
    @test loss(model, rand(10, 10)) > 1

    opt = Flux.setup(rule, model)
-    Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt)
+    trainfn!(loss, model, ((rand(10),) for _ in 1: 10^5), opt)
    @test loss(model, rand(10, 10)) < 0.01
  end

  # Test direct use of Optimisers.jl rule, only really OK for `Descent`:
+  # Enzyme doesn't currently work with un-initialised optimiser state, presumably due to trainmode?
+  if name != "Enzyme"
  @testset "without setup, $opt" for opt in [Descent(0.1), Optimisers.Descent(0.1), Optimisers.Adam()]
    loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
    model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
    @test loss(model, rand(10, 10)) > 1
-    Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt)
+    trainfn!(loss, model, ((rand(10),) for _ in 1: 10^5), opt)
    @test loss(model, rand(10, 10)) < 0.01
  end
+  end
+end
end

-@testset "Explicit Flux.train! features" begin
+for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme"))
+@testset "Explicit Flux.train! features with $name" begin
  @testset "Stop on NaN" begin
    m1 = Dense(1 => 1)
    m1.weight .= 0
-    CNT = 0
-    @test_throws DomainError Flux.train!(m1, tuple.(1:100), Descent(0.1)) do m, i
-      CNT += 1
+    CNT = Ref(0)
+    @test_throws DomainError trainfn!(m1, tuple.(1:100), Descent(0.1)) do m, i
+      CNT[] += 1
      (i == 51 ? NaN32 : 1f0) * sum(m([1.0]))
    end
-    @test CNT == 51 # stopped early
-    @test m1.weight[1] ≈ -5 # did not corrupt weights
+    @test CNT[] == 51 # stopped early
+    if name != "Enzyme"
+      @test m1.weight[1] ≈ -5 # did not corrupt weights
+    else
+      @test m1.weight[1] ≈ 0.0 # did not corrupt weights
+    end
  end

  @testset "non-tuple data" begin
@@ -51,16 +66,17 @@
    loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
    model = (weight=copy(w2), bias=zeros(10))
    opt = Flux.setup(AdamW(), model)
-    Flux.train!(loss, model, (rand(10) for _ in 1: 10^5), opt)
+    trainfn!(loss, model, (rand(10) for _ in 1: 10^5), opt)
    @test loss(model, rand(10, 10)) < 0.01
  end

  @testset "callbacks give helpful error" begin
    m1 = Dense(1 => 1)
    cb = () -> println("this should not be printed")
-    @test_throws ErrorException Flux.train!((args...,) -> 1, m1, [(1,2)], Descent(0.1); cb)
+    @test_throws ErrorException trainfn!((args...,) -> 1, m1, [(1,2)], Descent(0.1); cb)
  end
end
+end

@@ -68,7 +84,7 @@
@testset "Explicit Flux.update! features" begin
  m = Chain(Dense(2=>3, tanh), Dense(3=>1), only)
  x = rand(2)
  y1 = m(x)  # before

  # Implicit gradient
-  gold = gradient(() -> m(x), Flux.params(m))
+  gold = Zygote.gradient(() -> m(x), Flux.params(m))
  @test gold isa Flux.Zygote.Grads
  @test_throws ErrorException Flux.update!(Flux.Adam(), m, gold)  # friendly
  Flux.update!(Flux.Adam(), Flux.params(m), gold)
@@ -76,7 +92,7 @@
  y2 = m(x)
  @test y2 < y1

  # Explicit gradient
-  gs = gradient(marg -> marg(x), m)
+  gs = Zygote.gradient(marg -> marg(x), m)
  @test gs isa Tuple
  @test_throws ErrorException Flux.update!(Flux.Adam(), Flux.params(m), gs)  # friendly
  @test_throws ErrorException Flux.update!(Flux.Adam(), Flux.params(m), gs[1])  # friendly
@@ -98,7 +114,8 @@
  @test y5 < y4
end

-@testset "L2 regularisation" begin
+for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme"))
+@testset "L2 regularisation with $name" begin
  # New docs claim an exact equivalent. It's a bit long to put the example in there,
  # but perhaps the tests should contain it.
@@ -108,7 +125,7 @@
  # Take 1: explicitly add a penalty in the loss function
  opt = Flux.setup(Adam(0.1), model)
-  Flux.train!(model, data, opt) do m, x, y
+  trainfn!(model, data, opt) do m, x, y
    err = Flux.mse(m(x), y)
    l2 = sum(abs2, m.weight)/2 + sum(abs2, m.bias)/2
    err + 0.33 * l2
@@ -116,28 +133,32 @@
  end
  diff1 = model.weight .- init_weight

  # Take 2: the same, but with Flux.params. Was broken for a bit, no tests!
-  model.weight .= init_weight
-  model.bias .= 0
-  pen2(x::AbstractArray) = sum(abs2, x)/2
-  opt = Flux.setup(Adam(0.1), model)
-  Flux.train!(model, data, opt) do m, x, y
-    err = Flux.mse(m(x), y)
-    l2 = sum(pen2, Flux.params(m))
-    err + 0.33 * l2
+  # skipping this test for Enzyme because implicit params are unsupported
+  if name == "Zygote"
+    model.weight .= init_weight
+    model.bias .= 0
+    pen2(x::AbstractArray) = sum(abs2, x)/2
+    opt = Flux.setup(Adam(0.1), model)
+    trainfn!(model, data, opt) do m, x, y
+      err = Flux.mse(m(x), y)
+      l2 = sum(pen2, Flux.params(m))
+      err + 0.33 * l2
+    end
+    diff2 = model.weight .- init_weight
+    @test diff1 ≈ diff2
  end
-  diff2 = model.weight .- init_weight
-  @test diff1 ≈ diff2

  # Take 3: using WeightDecay instead. Need the /2 above, to match exactly.
  model.weight .= init_weight
  model.bias .= 0
  decay_opt = Flux.setup(OptimiserChain(WeightDecay(0.33), Adam(0.1)), model);
-  Flux.train!(model, data, decay_opt) do m, x, y
+  trainfn!(model, data, decay_opt) do m, x, y
    Flux.mse(m(x), y)
  end
  diff3 = model.weight .- init_weight
  @test diff1 ≈ diff3
end
+end

@testset "Flux.setup bugs" begin
  # https://github.com/FluxML/Flux.jl/issues/2144
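For reference, here is a minimal usage sketch of the Enzyme path introduced by this patch, mirroring the `train_enzyme!` helper added in test/train.jl. The model, data and optimiser settings below are made up for illustration and are not part of the patch:

```julia
using Flux, Enzyme

# Hypothetical toy model and data, for illustration only.
model = Chain(Dense(2 => 3, tanh), Dense(3 => 1))
data = [(randn(Float32, 2, 16), randn(Float32, 1, 16)) for _ in 1:100]

# Optimiser state is set up against the plain model, as usual.
opt_state = Flux.setup(Adam(0.01), model)

# Wrapping the model in Enzyme.Duplicated makes train! compute gradients with
# Enzyme instead of Zygote; make_zero allocates the shadow that holds them.
dup_model = Enzyme.Duplicated(model, Enzyme.make_zero(model))

# The loss closure still receives the underlying model plus one data point.
Flux.train!(dup_model, data, opt_state) do m, x, y
  Flux.mse(m(x), y)
end
```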