diff --git a/Project.toml b/Project.toml index 01ac5cdeaa..7c65d587b3 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.14.17" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" @@ -43,6 +44,7 @@ CUDA = "4, 5" ChainRulesCore = "1.12" Compat = "4.10.0" Enzyme = "0.12" +EnzymeCore = "0.7.7" Functors = "0.4" MLUtils = "0.4" MacroTools = "0.5" diff --git a/ext/FluxEnzymeExt/FluxEnzymeExt.jl b/ext/FluxEnzymeExt/FluxEnzymeExt.jl index e6ce51297f..4b1d00fac7 100644 --- a/ext/FluxEnzymeExt/FluxEnzymeExt.jl +++ b/ext/FluxEnzymeExt/FluxEnzymeExt.jl @@ -12,10 +12,80 @@ _make_zero_internal!(x::AbstractArray) = fill!(x, 0) _make_zero_internal!(x) = x _make_zero!(model) = fmap(_make_zero_internal!, model) -_applyloss(loss, model, d...) = loss(model, d...) - EnzymeRules.inactive(::typeof(Flux.Losses._check_sizes), args...) = true + +### gradient & withgradient + +_grad_or_nothing(dup::Duplicated) = Flux.fmapstructure(_grad_or_nothing, dup.dval) +_grad_or_nothing(::Const) = nothing +_grad_or_nothing(x) = Optimisers.isnumeric(x) ? x : nothing + +function Flux.withgradient(f, args::Union{Const, Duplicated}...) + for x in args + x isa Duplicated && _make_zero!(x.dval) + end + # TODO allow for f to return a tuple here, like in Zygote + _, val = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...) + (; val, grad = map(_grad_or_nothing, args)) +end + +""" + gradient(f, Duplicated(model), args...) + +This should return the same answer as `gradient(f, model, args...)`, +but it uses Enzyme.jl instead of Zygote.jl to compute the derivative. + +Only available when Enzyme is loaded! + +Besides returning the gradient, this is also stored within the `Duplicated` object. +Calling `Enzyme.Duplicated(model)` allocates space for the gradient, +which is zero'd befor use when calling `gradient`. + +!!! warning "Experimental" + Enzyme support like this is new and somewhat experimental. + It has known problems if your model has shared parameters. + +# Example +``` +julia> using Flux, Enzyme + +julia> model = Chain(Dense([3.0;;])); + +julia> Flux.gradient(model, [1]) do m, x + sum(abs2, m(x)) + end +((layers = ((weight = [6.0;;], bias = [6.0], σ = nothing),),), [18.0]) + +julia> Flux.gradient(Duplicated(model), Const([1])) do m, x + sum(abs2, m(x)) + end +┌ Warning: Using fallback BLAS replacements for (["dsymv_64_"]), performance may be degraded +└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59 +((layers = ((weight = [6.0;;], bias = [6.0], σ = nothing),),), nothing) + +``` +""" +function Flux.gradient(f, args::Union{Const, Duplicated}...) + for x in args + x isa Duplicated && _make_zero!(x.dval) + end + Enzyme.autodiff(ReverseWithPrimal, f, Active, args...) + map(_grad_or_nothing, args) +end + +_const_unless_dup(x) = Const(x) +_const_unless_dup(dup::Duplicated) = x + +# TODO allow for Duplicated as 2nd argument, assume others const? This produces ambiguities... +# Flux.withgradient(f, dup::Duplicated, rest...) = Flux.withgradient(f, dup, map(_const_unless_dup, rest)...) +# Flux.gradient(f, dup::Duplicated, rest...) = Flux.gradient(f, dup, map(_const_unless_dup, rest)...) + + +### Flux.Train, for train! + +_applyloss(loss, model, d...) = loss(model, d...) + using Flux: _old_to_new # from src/deprecations.jl train!(loss, model::Duplicated, data, opt::Optimise.AbstractOptimiser; cb=nothing) = train!(loss, model, data, _old_to_new(opt); cb) @@ -44,4 +114,23 @@ function train!(loss, model::Duplicated, data, opt; cb = nothing) end end + +### Optimisers.update!, piracy, for now! + +""" + Flux.update!(opt_state, model::Duplicated) + +Method of `update!` for use with Enzyme, and in particular with `gradient(loss, Duplicated(model))`. +Since `Duplicated(model)` stores the gradient, `update!` can read it & update the model itself, +by calling `Flux.update!(opt_state, model.val, model.dval)`. + +!!! warning "Experimental" + Enzyme support like this is new and somewhat experimental. + This method is piracy, and must move to Optimisers.jl in the end. +""" +function Flux.update!(opt_state, model::Duplicated) + Flux.update!(opt_state, model.val, model.dval) + model +end + end # FluxEnzymeExt \ No newline at end of file diff --git a/src/Flux.jl b/src/Flux.jl index 2681ea6729..86852170d3 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -14,9 +14,8 @@ import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owne using Optimisers: freeze!, thaw!, adjust!, trainables using Random: default_rng using Zygote, ChainRulesCore -using Zygote: Params, @adjoint, gradient, pullback +using Zygote: Params, @adjoint, pullback using Zygote.ForwardDiff: value -export gradient # Pirate error to catch a common mistake. (Internal function `base` because overloading `update!` is more likely to give ambiguities.) Optimisers.base(dx::Zygote.Grads) = error("Optimisers.jl cannot be used with Zygote.jl's implicit gradients, `Params` & `Grads`") @@ -55,6 +54,10 @@ include("train.jl") using .Train using .Train: setup +include("gradient.jl") +export gradient +@compat(public, (withgradient,)) + using Adapt, Functors, OneHotArrays include("utils.jl") include("functor.jl") diff --git a/src/deprecations.jl b/src/deprecations.jl index 9306671494..2caf9198b3 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -212,6 +212,18 @@ ChainRulesCore.@non_differentiable _greek_ascii_depwarn(::Any...) Base.@deprecate_binding FluxAMDAdaptor FluxAMDGPUAdaptor +function gradient(f, p::Zygote.Params) + Base.depwarn("""Implicit gradients such as `gradient(f, ::Params)` are deprecated! + Please see the docs for new explicit form.""", :gradient) + Zygote.gradient(f, args...) +end + +function withgradient(f, p::Zygote.Params) + Base.depwarn("""Implicit gradients such as `withgradient(f, ::Params)` are deprecated! + Please see the docs for new explicit form.""", :withgradient) + Zygote.withgradient(f, args...) +end + # v0.15 deprecations # Enable these when 0.15 is released, and delete const ClipGrad = Optimise.ClipValue etc: diff --git a/src/gradient.jl b/src/gradient.jl new file mode 100644 index 0000000000..d968336e42 --- /dev/null +++ b/src/gradient.jl @@ -0,0 +1,66 @@ + +""" + gradient(f, args...) + +Returns a tuple containing `∂f/∂x` for each argument `x`, +the derivative (for scalar `x`) or the gradient. +If no gradient is defined, `∂f/∂x` will be `nothing`. + +`f(args...)` must be a real number, see [`Zygote.jacobian`](@ref) for array output. + +By default, `Flux.gradient` calls Zygote. If you load Enzyme, then other methods become available. + +See also [`withgradient`](@ref) to keep the value `f(args...)`. + +```jldoctest; setup=:(using Zygote) +julia> gradient(*, 2.0, 3.0, 5.0) +(15.0, 10.0, 6.0) + +julia> gradient(x -> sum(abs2,x), [7.0, 11.0, 13.0]) +([14.0, 22.0, 26.0],) + +julia> gradient([7, 11], 0, 1) do x, y, d + p = size(x, d) + sum(x.^p .+ y) + end +([14.0, 22.0], 2.0, nothing) +``` +""" +gradient(f, args...) = Zygote.gradient(f, args...) + + + +""" + withgradient(f, args...) + +Returns both the value of the function and the [`gradient`](@ref), as a named tuple. + +By default, `Flux.withgradient` calls Zygote. If you load Enzyme, then other methods become available. + +```jldoctest; setup=:(using Zygote) +julia> y, ∇ = withgradient(/, 1, 2) +(val = 0.5, grad = (0.5, -0.25)) + +julia> ∇ == gradient(/, 1, 2) +true +``` + +Allows you to capture auxillary outputs, in addition to the scalar +used by `gradient`. To do this, `f` must return a Tuple or NamedTuple. +Then it calculates `grad = gradient(first∘f, args...) +but returns the whole `val = f(args...)`: + +```jldoctest; setup=:(using Zygote) +julia> withgradient([1,2,4]) do x + z = 1 ./ x + sum(z), z # here z is an auxillary output + end +(val = (1.75, [1.0, 0.5, 0.25]), grad = ([-1.0, -0.25, -0.0625],)) + +julia> withgradient(3.0, 4.0) do x, y + (div = x/y, mul = x*y) + end +(val = (div = 0.75, mul = 12.0), grad = (0.25, -0.1875)) +``` +""" +withgradient(f, args...) = Zygote.withgradient(f, args...) diff --git a/src/layers/macro.jl b/src/layers/macro.jl index dcebe551e3..a615ab9ad2 100644 --- a/src/layers/macro.jl +++ b/src/layers/macro.jl @@ -1,3 +1,4 @@ +import EnzymeCore """ @layer Dense @@ -65,6 +66,9 @@ macro layer(exs...) # This function exists only for depwarns when you use @functor directly push!(out.args, :(Flux._check_new_macro(::$(esc(type))) = nothing)) + + # TODO this should probably make it zero first? + push!(out.args, :($EnzymeCore.Duplicated(m::$(esc(type))) = $EnzymeCore.Duplicated(m, $deepcopy(m)))) push!(out.args, _macro_functor(esc(type)))