From 16a896f04a6fd84eb9643b63d195f4f8d6535acc Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Sat, 11 May 2024 10:00:09 -0700 Subject: [PATCH 01/19] Enable remaining enzyme test --- Project.toml | 2 +- test/ext_enzyme/enzyme.jl | 34 +--------------------------------- 2 files changed, 2 insertions(+), 34 deletions(-) diff --git a/Project.toml b/Project.toml index bcc2ca8a05..2283464f5e 100644 --- a/Project.toml +++ b/Project.toml @@ -46,7 +46,7 @@ Functors = "0.4" MLUtils = "0.4" MacroTools = "0.5" Metal = "0.5, 1" -NNlib = "0.9.14" +NNlib = "0.9.15" OneHotArrays = "0.2.4" Optimisers = "0.3.3" Preferences = "1" diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl index bef48c4da0..8241a3f8dd 100644 --- a/test/ext_enzyme/enzyme.jl +++ b/test/ext_enzyme/enzyme.jl @@ -6,10 +6,6 @@ using Functors using FiniteDifferences using CUDA -Enzyme.API.typeWarning!(false) # suppresses a warning with Bilinear https://github.com/EnzymeAD/Enzyme.jl/issues/1341 -Enzyme.API.runtimeActivity!(true) # for Enzyme debugging -# Enzyme.Compiler.bitcode_replacement!(false) - _make_zero(x::Union{Number,AbstractArray}) = zero(x) _make_zero(x) = x make_zero(model) = fmap(_make_zero, model) @@ -121,6 +117,7 @@ end (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"), (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"), (GRU(3 => 5), randn(Float32, 3, 10), "GRU"), + (ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"), ] for (model, x, name) in models_xs @@ -155,32 +152,3 @@ end end end end - -@testset "Broken Models" begin - function loss(model, x) - Flux.reset!(model) - sum(model(x)) - end - - device = Flux.get_device() - - models_xs = [ - # Pending https://github.com/FluxML/NNlib.jl/issues/565 - (ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"), - ] - - for (model, x, name) in models_xs - @testset "check grad $name" begin - println("testing $name") - broken = false - try - test_enzyme_grad(loss, model, x) - catch e - println(e) - broken = true - end - @test broken - end - end -end - From 546ff165f620542ad014257012c6b2bc211f86e4 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Mon, 13 May 2024 18:39:27 -0700 Subject: [PATCH 02/19] Add Enzyme train function --- Project.toml | 1 + src/train.jl | 36 +++++++++++++++++++++++++++++++++++- test/train.jl | 28 +++++++++++++++++----------- 3 files changed, 53 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index 2283464f5e..201d52560e 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.14.15" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" diff --git a/src/train.jl b/src/train.jl index e72eedebf3..76e708a536 100644 --- a/src/train.jl +++ b/src/train.jl @@ -5,8 +5,9 @@ using Optimisers: Optimisers using Functors: fmap, fmapstructure using ..Flux: Flux # used only in docstring import ..Flux.Optimise: train!, update! # during 0.13, we add methods to the old functions +import Enzyme -export setup, train! +export setup, train!, train_enzyme! using ProgressLogging: @progress, @withprogress, @logprogress using Zygote: Zygote, Params @@ -109,11 +110,44 @@ function train!(loss, model, data, opt; cb = nothing) end end +_make_zero!(x::AbstractArray) = fill!(x, 0) +_make_zero!(x) = x +make_zero!(model) = fmap(_make_zero!, model) + +applyloss(loss, model, d...) = loss(model, d...) + +""" + train_enzyme!(loss, model, data, opt::AbstractOptimiser; [cb]) + +Like [`train!](@ref), but gradient computed in place using [Enzyme](github.com/EnzymeAD/Enzyme.jl) +""" +function train_enzyme!(loss, model, data, opt; cb = nothing) + isnothing(cb) || error("""train_enzyme! does not support callback functions. + For more control use a loop with `gradient` and `update!`.""") + dmodel = Enzyme.make_zero(model) + @withprogress for (i,d) in enumerate(data) + d_splat = d isa Tuple ? d : (d,) + make_zero!(dmodel) + _, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, applyloss, Enzyme.Active, Enzyme.Const(loss), Enzyme.Duplicated(model, dmodel), map(Enzyme.Const, d_splat)...) + + if !isfinite(l) + throw(DomainError(lazy"Loss is $l on data item $i, stopping training")) + end + opt, model = Optimisers.update!(opt, model, dmodel) + @logprogress Base.haslength(data) ? i/length(data) : nothing + end +end + # This method let you use Optimisers.Descent() without setup, when there is no state function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing) train!(loss, model, data, _rule_to_state(model, rule); cb) end +# This method let you use Optimisers.Descent() without setup, when there is no state +function train_enzyme!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing) + train_enzyme!(loss, model, data, _rule_to_state(model, rule); cb) +end + function _rule_to_state(model, rule::Optimisers.AbstractRule) state = setup(rule, model) @gensym warn_id diff --git a/test/train.jl b/test/train.jl index 1d938649d0..dec47cfdea 100644 --- a/test/train.jl +++ b/test/train.jl @@ -5,7 +5,8 @@ import Optimisers using Test using Random -@testset "Explicit Flux.train! with Zygote" begin +for (trainfn!, name) in ((Flux.train!, "Zygote"), (Flux.train_enzyme!, "Enzyme")) +@testset "Explicit Flux.train! with $name" begin Random.seed!(84) w = randn(10, 10) w2 = randn(10, 10) # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset. @@ -18,7 +19,7 @@ using Random @test loss(model, rand(10, 10)) > 1 opt = Flux.setup(rule, model) - Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) + trainfn!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) @test loss(model, rand(10, 10)) < 0.01 end @@ -27,17 +28,19 @@ using Random loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias) model = (weight=copy(w2), bias=zeros(10), ignore=nothing) @test loss(model, rand(10, 10)) > 1 - Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) + trainfn!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) @test loss(model, rand(10, 10)) < 0.01 end end +end -@testset "Explicit Flux.train! features" begin +for (trainfn!, name) in ((Flux.train!, "Zygote"), (Flux.train_enzyme!, "Enzyme")) +@testset "Explicit Flux.train! features with $name" begin @testset "Stop on NaN" begin m1 = Dense(1 => 1) m1.weight .= 0 CNT = 0 - @test_throws DomainError Flux.train!(m1, tuple.(1:100), Descent(0.1)) do m, i + @test_throws DomainError Flux.trainfn!(m1, tuple.(1:100), Descent(0.1)) do m, i CNT += 1 (i == 51 ? NaN32 : 1f0) * sum(m([1.0])) end @@ -51,16 +54,17 @@ end loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias) model = (weight=copy(w2), bias=zeros(10)) opt = Flux.setup(AdamW(), model) - Flux.train!(loss, model, (rand(10) for _ in 1: 10^5), opt) + trainfn!(loss, model, (rand(10) for _ in 1: 10^5), opt) @test loss(model, rand(10, 10)) < 0.01 end @testset "callbacks give helpful error" begin m1 = Dense(1 => 1) cb = () -> println("this should not be printed") - @test_throws ErrorException Flux.train!((args...,) -> 1, m1, [(1,2)], Descent(0.1); cb) + @test_throws ErrorException trainfn!((args...,) -> 1, m1, [(1,2)], Descent(0.1); cb) end end +end @testset "Explicit Flux.update! features" begin m = Chain(Dense(2=>3, tanh), Dense(3=>1), only) @@ -98,7 +102,8 @@ end @test y5 < y4 end -@testset "L2 regularisation" begin +for (trainfn!, name) in ((Flux.train!, "Zygote"), (Flux.train_enzyme!, "Enzyme")) +@testset "L2 regularisation with $name" begin # New docs claim an exact equivalent. It's a bit long to put the example in there, # but perhaps the tests should contain it. @@ -108,7 +113,7 @@ end # Take 1: explicitly add a penalty in the loss function opt = Flux.setup(Adam(0.1), model) - Flux.train!(model, data, opt) do m, x, y + trainfn!(model, data, opt) do m, x, y err = Flux.mse(m(x), y) l2 = sum(abs2, m.weight)/2 + sum(abs2, m.bias)/2 err + 0.33 * l2 @@ -120,7 +125,7 @@ end model.bias .= 0 pen2(x::AbstractArray) = sum(abs2, x)/2 opt = Flux.setup(Adam(0.1), model) - Flux.train!(model, data, opt) do m, x, y + trainfn!(model, data, opt) do m, x, y err = Flux.mse(m(x), y) l2 = sum(pen2, Flux.params(m)) err + 0.33 * l2 @@ -132,12 +137,13 @@ end model.weight .= init_weight model.bias .= 0 decay_opt = Flux.setup(OptimiserChain(WeightDecay(0.33), Adam(0.1)), model); - Flux.train!(model, data, decay_opt) do m, x, y + trainfn!(model, data, decay_opt) do m, x, y Flux.mse(m(x), y) end diff3 = model.weight .- init_weight @test diff1 ≈ diff3 end +end @testset "Flux.setup bugs" begin # https://github.com/FluxML/Flux.jl/issues/2144 From 1d2052cf1afa92535226c633a161eef78800972d Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Mon, 13 May 2024 21:50:07 -0700 Subject: [PATCH 03/19] Mark check_sizes as inactive --- src/losses/utils.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/losses/utils.jl b/src/losses/utils.jl index e42bdfbe2e..312a6e348b 100644 --- a/src/losses/utils.jl +++ b/src/losses/utils.jl @@ -36,3 +36,5 @@ end _check_sizes(ŷ, y) = nothing # pass-through, for constant label e.g. y = 1 ChainRulesCore.@non_differentiable _check_sizes(ŷ::Any, y::Any) +import Enzyme +Enzyme.EnzymeRules.inactive(::typeof(_check_sizes), args...) = true From 77c35044a1709bfd4f5c77210632c82c424c4a58 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 16 May 2024 12:02:24 -0400 Subject: [PATCH 04/19] Adapt to suggestions --- src/train.jl | 27 ++++++++++----------------- test/train.jl | 11 ++++++++--- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/src/train.jl b/src/train.jl index 76e708a536..ffedc43a95 100644 --- a/src/train.jl +++ b/src/train.jl @@ -110,30 +110,28 @@ function train!(loss, model, data, opt; cb = nothing) end end -_make_zero!(x::AbstractArray) = fill!(x, 0) -_make_zero!(x) = x -make_zero!(model) = fmap(_make_zero!, model) +_make_zero_internal!(x::AbstractArray) = fill!(x, 0) +_make_zero_internal!(x) = x +_make_zero!(model) = fmap(_make_zero_internal!, model) -applyloss(loss, model, d...) = loss(model, d...) +_applyloss(loss, model, d...) = loss(model, d...) """ - train_enzyme!(loss, model, data, opt::AbstractOptimiser; [cb]) + train_enzyme!(loss, model_and_shadow, data, opt_state) Like [`train!](@ref), but gradient computed in place using [Enzyme](github.com/EnzymeAD/Enzyme.jl) """ -function train_enzyme!(loss, model, data, opt; cb = nothing) - isnothing(cb) || error("""train_enzyme! does not support callback functions. - For more control use a loop with `gradient` and `update!`.""") - dmodel = Enzyme.make_zero(model) +function train!(loss, model_and_shadow::Enzyme.Duplicated, data, opt_state) @withprogress for (i,d) in enumerate(data) d_splat = d isa Tuple ? d : (d,) - make_zero!(dmodel) - _, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, applyloss, Enzyme.Active, Enzyme.Const(loss), Enzyme.Duplicated(model, dmodel), map(Enzyme.Const, d_splat)...) + _make_zero!(model_and_shadow.dval) + _, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, applyloss, Enzyme.Active, Enzyme.Const(loss), model_and_shadow, map(Enzyme.Const, d_splat)...) if !isfinite(l) throw(DomainError(lazy"Loss is $l on data item $i, stopping training")) end - opt, model = Optimisers.update!(opt, model, dmodel) + opt_state, model = Optimisers.update!(opt_state, model_and_shadow.val, model_and_shadow.dval) + model_and_shadow = Duplicated(model, model_and_shadow.dval) @logprogress Base.haslength(data) ? i/length(data) : nothing end end @@ -143,11 +141,6 @@ function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing) train!(loss, model, data, _rule_to_state(model, rule); cb) end -# This method let you use Optimisers.Descent() without setup, when there is no state -function train_enzyme!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing) - train_enzyme!(loss, model, data, _rule_to_state(model, rule); cb) -end - function _rule_to_state(model, rule::Optimisers.AbstractRule) state = setup(rule, model) @gensym warn_id diff --git a/test/train.jl b/test/train.jl index dec47cfdea..26c7381124 100644 --- a/test/train.jl +++ b/test/train.jl @@ -4,8 +4,13 @@ import Optimisers using Test using Random +using Enzyme -for (trainfn!, name) in ((Flux.train!, "Zygote"), (Flux.train_enzyme!, "Enzyme")) +function train_enzyme!(fn, model, args...) + Flux.train!(fn, Duplicated(model, Enzyme.make_zero(model)), args...) +end + +for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme")) @testset "Explicit Flux.train! with $name" begin Random.seed!(84) w = randn(10, 10) @@ -34,7 +39,7 @@ for (trainfn!, name) in ((Flux.train!, "Zygote"), (Flux.train_enzyme!, "Enzyme") end end -for (trainfn!, name) in ((Flux.train!, "Zygote"), (Flux.train_enzyme!, "Enzyme")) +for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme")) @testset "Explicit Flux.train! features with $name" begin @testset "Stop on NaN" begin m1 = Dense(1 => 1) @@ -102,7 +107,7 @@ end @test y5 < y4 end -for (trainfn!, name) in ((Flux.train!, "Zygote"), (Flux.train_enzyme!, "Enzyme")) +for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme")) @testset "L2 regularisation with $name" begin # New docs claim an exact equivalent. It's a bit long to put the example in there, # but perhaps the tests should contain it. From 97a490f3de6a4a6b164002b76d8ff0410d4350e9 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 16 May 2024 10:21:07 -0700 Subject: [PATCH 05/19] Update train.jl --- src/train.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/train.jl b/src/train.jl index ffedc43a95..5b9998fb66 100644 --- a/src/train.jl +++ b/src/train.jl @@ -125,7 +125,7 @@ function train!(loss, model_and_shadow::Enzyme.Duplicated, data, opt_state) @withprogress for (i,d) in enumerate(data) d_splat = d isa Tuple ? d : (d,) _make_zero!(model_and_shadow.dval) - _, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, applyloss, Enzyme.Active, Enzyme.Const(loss), model_and_shadow, map(Enzyme.Const, d_splat)...) + _, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, _applyloss, Enzyme.Active, Enzyme.Const(loss), model_and_shadow, map(Enzyme.Const, d_splat)...) if !isfinite(l) throw(DomainError(lazy"Loss is $l on data item $i, stopping training")) From d8f1ad8347a67a92a1e2a092f5dde6105e245ab2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 16 May 2024 14:02:23 -0700 Subject: [PATCH 06/19] Update train.jl --- src/train.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/train.jl b/src/train.jl index 5b9998fb66..c070cd6bf7 100644 --- a/src/train.jl +++ b/src/train.jl @@ -131,7 +131,7 @@ function train!(loss, model_and_shadow::Enzyme.Duplicated, data, opt_state) throw(DomainError(lazy"Loss is $l on data item $i, stopping training")) end opt_state, model = Optimisers.update!(opt_state, model_and_shadow.val, model_and_shadow.dval) - model_and_shadow = Duplicated(model, model_and_shadow.dval) + model_and_shadow = Enzyme.Duplicated(model, model_and_shadow.dval) @logprogress Base.haslength(data) ? i/length(data) : nothing end end From e17bef83c8c090e56fde38abe013fe5e0396c9e8 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 16 May 2024 23:36:28 -0700 Subject: [PATCH 07/19] tmp --- test/train.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/train.jl b/test/train.jl index 26c7381124..38e05940aa 100644 --- a/test/train.jl +++ b/test/train.jl @@ -40,6 +40,7 @@ end end for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme")) +@eval begin @testset "Explicit Flux.train! features with $name" begin @testset "Stop on NaN" begin m1 = Dense(1 => 1) @@ -70,6 +71,7 @@ for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme")) end end end +end @testset "Explicit Flux.update! features" begin m = Chain(Dense(2=>3, tanh), Dense(3=>1), only) From eb62e15c1c1ce3db6b31397732db3aca0fac1935 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Tue, 28 May 2024 07:17:58 -0400 Subject: [PATCH 08/19] try fix --- src/train.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/train.jl b/src/train.jl index c070cd6bf7..ae4ba6fd1d 100644 --- a/src/train.jl +++ b/src/train.jl @@ -121,7 +121,7 @@ _applyloss(loss, model, d...) = loss(model, d...) Like [`train!](@ref), but gradient computed in place using [Enzyme](github.com/EnzymeAD/Enzyme.jl) """ -function train!(loss, model_and_shadow::Enzyme.Duplicated, data, opt_state) +function train!(loss, model_and_shadow::Enzyme.Duplicated, data, opt_state::T) where T @withprogress for (i,d) in enumerate(data) d_splat = d isa Tuple ? d : (d,) _make_zero!(model_and_shadow.dval) From 13c69890beb99ce420c9e0a0393aad058e594181 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 1 Jun 2024 00:40:51 +0200 Subject: [PATCH 09/19] Update train.jl --- src/train.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/train.jl b/src/train.jl index ae4ba6fd1d..e1fab2a6ce 100644 --- a/src/train.jl +++ b/src/train.jl @@ -136,6 +136,25 @@ function train!(loss, model_and_shadow::Enzyme.Duplicated, data, opt_state::T) w end end +# Required per method ambiguity with +# train!(loss, model, data, opt::Flux.Optimise.AbstractOptimiser; cb) +# @ Flux ~/work/Flux.jl/Flux.jl/src/deprecations.jl:110 +function train!(loss, model_and_shadow::Enzyme.Duplicated, data, opt_state::Flux.Optimise.AbstractOptimiser) + @withprogress for (i,d) in enumerate(data) + d_splat = d isa Tuple ? d : (d,) + _make_zero!(model_and_shadow.dval) + _, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, _applyloss, Enzyme.Active, Enzyme.Const(loss), model_and_shadow, map(Enzyme.Const, d_splat)...) + + if !isfinite(l) + throw(DomainError(lazy"Loss is $l on data item $i, stopping training")) + end + opt_state, model = Optimisers.update!(opt_state, model_and_shadow.val, model_and_shadow.dval) + model_and_shadow = Enzyme.Duplicated(model, model_and_shadow.dval) + @logprogress Base.haslength(data) ? i/length(data) : nothing + end +end + + # This method let you use Optimisers.Descent() without setup, when there is no state function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing) train!(loss, model, data, _rule_to_state(model, rule); cb) From c07bea977342f0a4de077c2c6d427f65ca0ac319 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 3 Jun 2024 03:19:32 +0200 Subject: [PATCH 10/19] Update train.jl --- src/train.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/train.jl b/src/train.jl index e1fab2a6ce..03dbe04826 100644 --- a/src/train.jl +++ b/src/train.jl @@ -121,7 +121,7 @@ _applyloss(loss, model, d...) = loss(model, d...) Like [`train!](@ref), but gradient computed in place using [Enzyme](github.com/EnzymeAD/Enzyme.jl) """ -function train!(loss, model_and_shadow::Enzyme.Duplicated, data, opt_state::T) where T +function train!(loss, model_and_shadow::Enzyme.Duplicated, data, opt_state::T) where T<:Optimisers.AbstractRule @withprogress for (i,d) in enumerate(data) d_splat = d isa Tuple ? d : (d,) _make_zero!(model_and_shadow.dval) From c01ab6fec07a1a201e3bc884b132ac04d0fabfbe Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 3 Jun 2024 03:20:58 +0200 Subject: [PATCH 11/19] Update train.jl --- test/train.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/train.jl b/test/train.jl index 38e05940aa..0222a7a1ac 100644 --- a/test/train.jl +++ b/test/train.jl @@ -40,13 +40,12 @@ end end for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme")) -@eval begin @testset "Explicit Flux.train! features with $name" begin @testset "Stop on NaN" begin m1 = Dense(1 => 1) m1.weight .= 0 CNT = 0 - @test_throws DomainError Flux.trainfn!(m1, tuple.(1:100), Descent(0.1)) do m, i + @test_throws DomainError trainfn!(m1, tuple.(1:100), Descent(0.1)) do m, i CNT += 1 (i == 51 ? NaN32 : 1f0) * sum(m([1.0])) end @@ -71,7 +70,6 @@ for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme")) end end end -end @testset "Explicit Flux.update! features" begin m = Chain(Dense(2=>3, tanh), Dense(3=>1), only) From 0cc61907c27b44c2f02d8ff934ddf313e974641d Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Sat, 15 Jun 2024 15:33:13 -0400 Subject: [PATCH 12/19] Fixup --- src/losses/utils.jl | 3 +- src/train.jl | 71 ++++++++++++++++----------------------------- 2 files changed, 27 insertions(+), 47 deletions(-) diff --git a/src/losses/utils.jl b/src/losses/utils.jl index 312a6e348b..c380564908 100644 --- a/src/losses/utils.jl +++ b/src/losses/utils.jl @@ -1,3 +1,5 @@ +import Enzyme + """ xlogx(x) @@ -36,5 +38,4 @@ end _check_sizes(ŷ, y) = nothing # pass-through, for constant label e.g. y = 1 ChainRulesCore.@non_differentiable _check_sizes(ŷ::Any, y::Any) -import Enzyme Enzyme.EnzymeRules.inactive(::typeof(_check_sizes), args...) = true diff --git a/src/train.jl b/src/train.jl index 03dbe04826..a367bba554 100644 --- a/src/train.jl +++ b/src/train.jl @@ -7,7 +7,7 @@ using ..Flux: Flux # used only in docstring import ..Flux.Optimise: train!, update! # during 0.13, we add methods to the old functions import Enzyme -export setup, train!, train_enzyme! +export setup, train! using ProgressLogging: @progress, @withprogress, @logprogress using Zygote: Zygote, Params @@ -53,6 +53,12 @@ function setup(rule::Optimisers.AbstractRule, model) state end +_make_zero_internal!(x::AbstractArray) = fill!(x, 0) +_make_zero_internal!(x) = x +_make_zero!(model) = fmap(_make_zero_internal!, model) + +_applyloss(loss, model, d...) = loss(model, d...) + """ train!(loss, model, data, opt_state) @@ -61,6 +67,9 @@ according to a particular optimisation rule encoded in `opt_state`. Iterates through `data` once, evaluating for each `d in data` either `loss(model, d...)` if `d isa Tuple`, or else `loss(model, d)` for other `d`. +If `model` is an Enzyme.Duplicated, gradients will be computed with Enzyme, +otherwise they will be computed with Zygote. + For example, with these definitions... ``` data = [(x1, y1), (x2, y2), (x3, y3)] @@ -101,60 +110,30 @@ function train!(loss, model, data, opt; cb = nothing) For more control use a loop with `gradient` and `update!`.""") @withprogress for (i,d) in enumerate(data) d_splat = d isa Tuple ? d : (d,) - l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model) - if !isfinite(l) - throw(DomainError(lazy"Loss is $l on data item $i, stopping training")) - end - opt, model = Optimisers.update!(opt, model, gs[1]) - @logprogress Base.haslength(data) ? i/length(data) : nothing - end -end + + if model isa Enzyme.Duplicated + _make_zero!(model.dval) + _, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, _applyloss, Enzyme.Active, Enzyme.Const(loss), model, map(Enzyme.Const, d_splat)...) -_make_zero_internal!(x::AbstractArray) = fill!(x, 0) -_make_zero_internal!(x) = x -_make_zero!(model) = fmap(_make_zero_internal!, model) + if !isfinite(l) + throw(DomainError(lazy"Loss is $l on data item $i, stopping training")) + end + opt, model2 = Optimisers.update!(opt, model.val, gs[1]) + model = Enzyme.Duplicated(model2, model.dval) + else + Zygote.withgradient(m -> loss(m, d_splat...), model) -_applyloss(loss, model, d...) = loss(model, d...) + if !isfinite(l) + throw(DomainError(lazy"Loss is $l on data item $i, stopping training")) + end -""" - train_enzyme!(loss, model_and_shadow, data, opt_state) - -Like [`train!](@ref), but gradient computed in place using [Enzyme](github.com/EnzymeAD/Enzyme.jl) -""" -function train!(loss, model_and_shadow::Enzyme.Duplicated, data, opt_state::T) where T<:Optimisers.AbstractRule - @withprogress for (i,d) in enumerate(data) - d_splat = d isa Tuple ? d : (d,) - _make_zero!(model_and_shadow.dval) - _, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, _applyloss, Enzyme.Active, Enzyme.Const(loss), model_and_shadow, map(Enzyme.Const, d_splat)...) - - if !isfinite(l) - throw(DomainError(lazy"Loss is $l on data item $i, stopping training")) - end - opt_state, model = Optimisers.update!(opt_state, model_and_shadow.val, model_and_shadow.dval) - model_and_shadow = Enzyme.Duplicated(model, model_and_shadow.dval) - @logprogress Base.haslength(data) ? i/length(data) : nothing - end -end + opt, model = Optimisers.update!(opt, model, gs[1]) -# Required per method ambiguity with -# train!(loss, model, data, opt::Flux.Optimise.AbstractOptimiser; cb) -# @ Flux ~/work/Flux.jl/Flux.jl/src/deprecations.jl:110 -function train!(loss, model_and_shadow::Enzyme.Duplicated, data, opt_state::Flux.Optimise.AbstractOptimiser) - @withprogress for (i,d) in enumerate(data) - d_splat = d isa Tuple ? d : (d,) - _make_zero!(model_and_shadow.dval) - _, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, _applyloss, Enzyme.Active, Enzyme.Const(loss), model_and_shadow, map(Enzyme.Const, d_splat)...) - - if !isfinite(l) - throw(DomainError(lazy"Loss is $l on data item $i, stopping training")) end - opt_state, model = Optimisers.update!(opt_state, model_and_shadow.val, model_and_shadow.dval) - model_and_shadow = Enzyme.Duplicated(model, model_and_shadow.dval) @logprogress Base.haslength(data) ? i/length(data) : nothing end end - # This method let you use Optimisers.Descent() without setup, when there is no state function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing) train!(loss, model, data, _rule_to_state(model, rule); cb) From b7654bdfcb1e56ee11a6df5c1807bd65600e42aa Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Sat, 15 Jun 2024 17:06:27 -0400 Subject: [PATCH 13/19] fix --- src/functor.jl | 40 ++++++++++++++++++++++++++++++++++++++++ src/train.jl | 4 ++-- test/train.jl | 3 +++ 3 files changed, 45 insertions(+), 2 deletions(-) diff --git a/src/functor.jl b/src/functor.jl index eeaffab1c3..9c00cc9335 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -3,6 +3,7 @@ using LinearAlgebra: Cholesky using Zygote: IdSet import Functors: Functors, @functor, functor, fmap, isleaf using SparseArrays: AbstractSparseArray +using Enzyme """ testmode!(model, [mode]) -> model @@ -89,6 +90,45 @@ function params!(p::Params, x, seen = IdSet()) end end +function Enzyme.EnzymeRules.augmented_primal(config, func::Enzyme.Const{typeof(params!)}, ::Type{RT}, + p::Enzyme.Annotation, + x::Enzyme.Annotation, + seen::Enzyme.Annotation) where {RT} + + res = func.val(p.val, x.val, seen.val) + + primal = if EnzymeRules.needs_primal(config) + res + else + nothing + end + + sres = if EnzymeRules.width(config) == 1 + func.val(p.dval, x.dval, seen isa Const ? IdSet() : seen.dval) + else + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + func.val(p.dval[i], x.dval[i], seen isa Const ? IdSet() : seen.dval[i]) + end + end + + shadow = if EnzymeRules.needs_shadow(config) + sres + else + nothing + end + + return EnzymeRules.AugmentedReturn(primal, shadow, nothing) +end + +function Enzyme.EnzymeRules.reverse(config, func::Enzyme.Const{typeof(params!)}, ::Type{RT}, cache, + p::Enzyme.Annotation, + x::Enzyme.Annotation, + seen::Enzyme.Annotation) where {RT} + + return (nothing, nothing, nothing) +end + """ params(model) params(layers...) diff --git a/src/train.jl b/src/train.jl index a367bba554..eac1f23f8b 100644 --- a/src/train.jl +++ b/src/train.jl @@ -118,10 +118,10 @@ function train!(loss, model, data, opt; cb = nothing) if !isfinite(l) throw(DomainError(lazy"Loss is $l on data item $i, stopping training")) end - opt, model2 = Optimisers.update!(opt, model.val, gs[1]) + opt, model2 = Optimisers.update!(opt, model.val, model.dval) model = Enzyme.Duplicated(model2, model.dval) else - Zygote.withgradient(m -> loss(m, d_splat...), model) + l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model) if !isfinite(l) throw(DomainError(lazy"Loss is $l on data item $i, stopping training")) diff --git a/test/train.jl b/test/train.jl index 0222a7a1ac..4a0ece5ccc 100644 --- a/test/train.jl +++ b/test/train.jl @@ -29,6 +29,8 @@ for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme")) end # Test direct use of Optimisers.jl rule, only really OK for `Descent`: + # Enzyme doesn't work with un-initialized atm, presumably due to trainmode? + if name != "Enzyme" @testset "without setup, $opt" for opt in [Descent(0.1), Optimisers.Descent(0.1), Optimisers.Adam()] loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias) model = (weight=copy(w2), bias=zeros(10), ignore=nothing) @@ -36,6 +38,7 @@ for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme")) trainfn!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) @test loss(model, rand(10, 10)) < 0.01 end + end end end From f4e300c44172fcfbfc9d4a523ae8b568256b2af3 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Sat, 15 Jun 2024 20:04:15 -0400 Subject: [PATCH 14/19] kwargs --- test/train.jl | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/test/train.jl b/test/train.jl index 4a0ece5ccc..6367de6eb0 100644 --- a/test/train.jl +++ b/test/train.jl @@ -6,8 +6,8 @@ using Test using Random using Enzyme -function train_enzyme!(fn, model, args...) - Flux.train!(fn, Duplicated(model, Enzyme.make_zero(model)), args...) +function train_enzyme!(fn, model, args...; kwargs...) + Flux.train!(fn, Duplicated(model, Enzyme.make_zero(model)), args...; kwargs...) end for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme")) @@ -47,13 +47,17 @@ for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme")) @testset "Stop on NaN" begin m1 = Dense(1 => 1) m1.weight .= 0 - CNT = 0 + CNT = Ref(0) @test_throws DomainError trainfn!(m1, tuple.(1:100), Descent(0.1)) do m, i - CNT += 1 + CNT[] += 1 (i == 51 ? NaN32 : 1f0) * sum(m([1.0])) end - @test CNT == 51 # stopped early - @test m1.weight[1] ≈ -5 # did not corrupt weights + @test CNT[] == 51 # stopped early + if name != "Enzyme" + @test m1.weight[1] ≈ 0.0 # did not corrupt weights + else + @test m1.weight[1] ≈ -5 # did not corrupt weights + end end @testset "non-tuple data" begin From b329ef148e10305b4ecd344ad8deacf31e149425 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Sat, 15 Jun 2024 23:11:23 -0400 Subject: [PATCH 15/19] fixup --- test/train.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/train.jl b/test/train.jl index 6367de6eb0..0ec40c2f22 100644 --- a/test/train.jl +++ b/test/train.jl @@ -54,9 +54,9 @@ for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme")) end @test CNT[] == 51 # stopped early if name != "Enzyme" - @test m1.weight[1] ≈ 0.0 # did not corrupt weights - else @test m1.weight[1] ≈ -5 # did not corrupt weights + else + @test m1.weight[1] ≈ 0.0 # did not corrupt weights end end From 7b8309de2dc645481ea25d97883a23a608524a89 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Fri, 21 Jun 2024 13:49:00 -0400 Subject: [PATCH 16/19] Rearrange dispatch for enzyme train --- src/functor.jl | 39 --------------------------------------- src/train.jl | 38 ++++++++++++++++++++++++-------------- 2 files changed, 24 insertions(+), 53 deletions(-) diff --git a/src/functor.jl b/src/functor.jl index 9c00cc9335..e48246ebde 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -90,45 +90,6 @@ function params!(p::Params, x, seen = IdSet()) end end -function Enzyme.EnzymeRules.augmented_primal(config, func::Enzyme.Const{typeof(params!)}, ::Type{RT}, - p::Enzyme.Annotation, - x::Enzyme.Annotation, - seen::Enzyme.Annotation) where {RT} - - res = func.val(p.val, x.val, seen.val) - - primal = if EnzymeRules.needs_primal(config) - res - else - nothing - end - - sres = if EnzymeRules.width(config) == 1 - func.val(p.dval, x.dval, seen isa Const ? IdSet() : seen.dval) - else - ntuple(Val(EnzymeRules.width(config))) do i - Base.@_inline_meta - func.val(p.dval[i], x.dval[i], seen isa Const ? IdSet() : seen.dval[i]) - end - end - - shadow = if EnzymeRules.needs_shadow(config) - sres - else - nothing - end - - return EnzymeRules.AugmentedReturn(primal, shadow, nothing) -end - -function Enzyme.EnzymeRules.reverse(config, func::Enzyme.Const{typeof(params!)}, ::Type{RT}, cache, - p::Enzyme.Annotation, - x::Enzyme.Annotation, - seen::Enzyme.Annotation) where {RT} - - return (nothing, nothing, nothing) -end - """ params(model) params(layers...) diff --git a/src/train.jl b/src/train.jl index eac1f23f8b..be38d2a092 100644 --- a/src/train.jl +++ b/src/train.jl @@ -111,25 +111,32 @@ function train!(loss, model, data, opt; cb = nothing) @withprogress for (i,d) in enumerate(data) d_splat = d isa Tuple ? d : (d,) - if model isa Enzyme.Duplicated - _make_zero!(model.dval) - _, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, _applyloss, Enzyme.Active, Enzyme.Const(loss), model, map(Enzyme.Const, d_splat)...) + l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model) - if !isfinite(l) - throw(DomainError(lazy"Loss is $l on data item $i, stopping training")) - end - opt, model2 = Optimisers.update!(opt, model.val, model.dval) - model = Enzyme.Duplicated(model2, model.dval) - else - l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model) + if !isfinite(l) + throw(DomainError(lazy"Loss is $l on data item $i, stopping training")) + end - if !isfinite(l) - throw(DomainError(lazy"Loss is $l on data item $i, stopping training")) - end + opt, model = Optimisers.update!(opt, model, gs[1]) - opt, model = Optimisers.update!(opt, model, gs[1]) + @logprogress Base.haslength(data) ? i/length(data) : nothing + end +end +function train!(loss, model::Enzyme.Duplicated, data, opt; cb = nothing) + isnothing(cb) || error("""train! does not support callback functions. + For more control use a loop with `gradient` and `update!`.""") + @withprogress for (i,d) in enumerate(data) + d_splat = d isa Tuple ? d : (d,) + + _make_zero!(model.dval) + _, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, _applyloss, Enzyme.Active, Enzyme.Const(loss), model, map(Enzyme.Const, d_splat)...) + if !isfinite(l) + throw(DomainError(lazy"Loss is $l on data item $i, stopping training")) end + opt, model2 = Optimisers.update!(opt, model.val, model.dval) + model = Enzyme.Duplicated(model2, model.dval) + @logprogress Base.haslength(data) ? i/length(data) : nothing end end @@ -138,6 +145,9 @@ end function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing) train!(loss, model, data, _rule_to_state(model, rule); cb) end +function train!(loss, model::Enzyme.Duplicated, data, rule::Optimisers.AbstractRule; cb = nothing) + train!(loss, model, data, _rule_to_state(model, rule); cb) +end function _rule_to_state(model, rule::Optimisers.AbstractRule) state = setup(rule, model) From 419b99a92eeef8e9cbf8174fc03eeba3b4cd3b58 Mon Sep 17 00:00:00 2001 From: darsnack Date: Sun, 23 Jun 2024 17:10:29 -0400 Subject: [PATCH 17/19] Fix method ambiguity and skip some test for enzyme --- src/deprecations.jl | 5 ++++- src/train.jl | 4 ++-- test/train.jl | 27 +++++++++++++++------------ 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/src/deprecations.jl b/src/deprecations.jl index 5acdec5455..24372a570e 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -107,7 +107,10 @@ train!(loss, ps::Params, data, opt::Optimisers.AbstractRule; cb=nothing) = error But better to use the new explicit style, in which `m` itself is the 2nd argument. """) -train!(loss, model, data, opt::Optimise.AbstractOptimiser; cb=nothing) = train!(loss, model, data, _old_to_new(opt); cb) +train!(loss, model, data, opt::Optimise.AbstractOptimiser; cb=nothing) = + train!(loss, model, data, _old_to_new(opt); cb) +train!(loss, model::Enzyme.Duplicated, data, opt::Optimise.AbstractOptimiser; cb=nothing) = + train!(loss, model, data, _old_to_new(opt); cb) # Next, to use the new `setup` with the still-exported old-style `Adam` etc: import .Train: setup diff --git a/src/train.jl b/src/train.jl index be38d2a092..6094e13ac6 100644 --- a/src/train.jl +++ b/src/train.jl @@ -110,7 +110,7 @@ function train!(loss, model, data, opt; cb = nothing) For more control use a loop with `gradient` and `update!`.""") @withprogress for (i,d) in enumerate(data) d_splat = d isa Tuple ? d : (d,) - + l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model) if !isfinite(l) @@ -127,7 +127,7 @@ function train!(loss, model::Enzyme.Duplicated, data, opt; cb = nothing) For more control use a loop with `gradient` and `update!`.""") @withprogress for (i,d) in enumerate(data) d_splat = d isa Tuple ? d : (d,) - + _make_zero!(model.dval) _, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, _applyloss, Enzyme.Active, Enzyme.Const(loss), model, map(Enzyme.Const, d_splat)...) diff --git a/test/train.jl b/test/train.jl index 0ec40c2f22..3ed0e658ea 100644 --- a/test/train.jl +++ b/test/train.jl @@ -84,7 +84,7 @@ end y1 = m(x) # before # Implicit gradient - gold = gradient(() -> m(x), Flux.params(m)) + gold = Zygote.gradient(() -> m(x), Flux.params(m)) @test gold isa Flux.Zygote.Grads @test_throws ErrorException Flux.update!(Flux.Adam(), m, gold) # friendly Flux.update!(Flux.Adam(), Flux.params(m), gold) @@ -92,7 +92,7 @@ end @test y2 < y1 # Explicit gradient - gs = gradient(marg -> marg(x), m) + gs = Zygote.gradient(marg -> marg(x), m) @test gs isa Tuple @test_throws ErrorException Flux.update!(Flux.Adam(), Flux.params(m), gs) # friendly @test_throws ErrorException Flux.update!(Flux.Adam(), Flux.params(m), gs[1]) # friendly @@ -133,17 +133,20 @@ for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme")) diff1 = model.weight .- init_weight # Take 2: the same, but with Flux.params. Was broken for a bit, no tests! - model.weight .= init_weight - model.bias .= 0 - pen2(x::AbstractArray) = sum(abs2, x)/2 - opt = Flux.setup(Adam(0.1), model) - trainfn!(model, data, opt) do m, x, y - err = Flux.mse(m(x), y) - l2 = sum(pen2, Flux.params(m)) - err + 0.33 * l2 + # skipping this test for Enzyme cause implicit params is unsupported + if name == "Zygote" + model.weight .= init_weight + model.bias .= 0 + pen2(x::AbstractArray) = sum(abs2, x)/2 + opt = Flux.setup(Adam(0.1), model) + trainfn!(model, data, opt) do m, x, y + err = Flux.mse(m(x), y) + l2 = sum(pen2, Flux.params(m)) + err + 0.33 * l2 + end + diff2 = model.weight .- init_weight + @test diff1 ≈ diff2 end - diff2 = model.weight .- init_weight - @test diff1 ≈ diff2 # Take 3: using WeightDecay instead. Need the /2 above, to match exactly. model.weight .= init_weight From 9a0f48679a04bd17269c2f16f95888f137d9cf8e Mon Sep 17 00:00:00 2001 From: darsnack Date: Sun, 23 Jun 2024 18:36:05 -0400 Subject: [PATCH 18/19] Updated news --- NEWS.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/NEWS.md b/NEWS.md index 87333f8717..aa135e13af 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,9 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release. +## v0.14.14 +* Add [support for Enzyme](https://github.com/FluxML/Flux.jl/pull/2446) with `Flux.train!`. + ## v0.14.13 * New macro `Flux.@layer` which should be used in place of `@functor`. This also adds `show` methods for pretty printing. From a5afde4f804a23fc86e11a72e56caf477b07b8c4 Mon Sep 17 00:00:00 2001 From: darsnack Date: Mon, 24 Jun 2024 06:55:27 -0400 Subject: [PATCH 19/19] Fix version in news and bump project.toml --- NEWS.md | 2 +- Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/NEWS.md b/NEWS.md index aa135e13af..a4b0856327 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,7 +2,7 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release. -## v0.14.14 +## v0.14.17 * Add [support for Enzyme](https://github.com/FluxML/Flux.jl/pull/2446) with `Flux.train!`. ## v0.14.13 diff --git a/Project.toml b/Project.toml index 201d52560e..6d02fc621f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Flux" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.15" +version = "0.14.17" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"