Add more Duplicated methods for Enzyme.jl support #2471

Draft · wants to merge 1 commit into master
2 changes: 2 additions & 0 deletions Project.toml
@@ -6,6 +6,7 @@ version = "0.14.17"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
@@ -43,6 +44,7 @@ CUDA = "4, 5"
ChainRulesCore = "1.12"
Compat = "4.10.0"
Enzyme = "0.12"
EnzymeCore = "0.7.7"
Functors = "0.4"
MLUtils = "0.4"
MacroTools = "0.5"
93 changes: 91 additions & 2 deletions ext/FluxEnzymeExt/FluxEnzymeExt.jl
@@ -12,10 +12,80 @@ _make_zero_internal!(x::AbstractArray) = fill!(x, 0)
_make_zero_internal!(x) = x
_make_zero!(model) = fmap(_make_zero_internal!, model)

_applyloss(loss, model, d...) = loss(model, d...)

EnzymeRules.inactive(::typeof(Flux.Losses._check_sizes), args...) = true


### gradient & withgradient
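# These helpers extract the gradient stored in a Duplicated's shadow (`dval`);
# Const arguments and non-trainable leaves map to `nothing`.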

_grad_or_nothing(dup::Duplicated) = Flux.fmapstructure(_grad_or_nothing, dup.dval)
_grad_or_nothing(::Const) = nothing
_grad_or_nothing(x) = Optimisers.isnumeric(x) ? x : nothing

function Flux.withgradient(f, args::Union{Const, Duplicated}...)
for x in args
x isa Duplicated && _make_zero!(x.dval)
end
# TODO allow for f to return a tuple here, like in Zygote
_, val = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
(; val, grad = map(_grad_or_nothing, args))
end
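For illustration (not part of the diff): a minimal sketch of how this Enzyme-backed `withgradient` might be called, assuming a small `Dense` model and random data.

```julia
using Flux, Enzyme

model = Chain(Dense(2 => 1))     # any @layer type works here
x = randn(Float32, 2, 8)

# Duplicated(model) carries shadow storage for the gradient;
# Const marks arguments that should not be differentiated.
out = Flux.withgradient(Duplicated(model), Const(x)) do m, xb
    sum(abs2, m(xb))
end
out.val    # the scalar loss
out.grad   # (nested gradient for the model, nothing for the Const)
```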

"""
gradient(f, Duplicated(model), args...)

This should return the same answer as `gradient(f, model, args...)`,
but it uses Enzyme.jl instead of Zygote.jl to compute the derivative.

Only available when Enzyme is loaded!

Besides being returned, the gradient is also stored within the `Duplicated` object.
Calling `Enzyme.Duplicated(model)` allocates space for the gradient,
which is zeroed before use when `gradient` is called.

!!! warning "Experimental"
Enzyme support like this is new and somewhat experimental.
It has known problems if your model has shared parameters.

# Example
```
julia> using Flux, Enzyme

julia> model = Chain(Dense([3.0;;]));

julia> Flux.gradient(model, [1]) do m, x
sum(abs2, m(x))
end
((layers = ((weight = [6.0;;], bias = [6.0], σ = nothing),),), [18.0])

julia> Flux.gradient(Duplicated(model), Const([1])) do m, x
sum(abs2, m(x))
end
┌ Warning: Using fallback BLAS replacements for (["dsymv_64_"]), performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
((layers = ((weight = [6.0;;], bias = [6.0], σ = nothing),),), nothing)

```
"""
function Flux.gradient(f, args::Union{Const, Duplicated}...)
for x in args
x isa Duplicated && _make_zero!(x.dval)
end
Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
map(_grad_or_nothing, args)
end

_const_unless_dup(x) = Const(x)
_const_unless_dup(dup::Duplicated) = dup

# TODO allow for Duplicated as 2nd argument, assume others const? This produces ambiguities...
# Flux.withgradient(f, dup::Duplicated, rest...) = Flux.withgradient(f, dup, map(_const_unless_dup, rest)...)
# Flux.gradient(f, dup::Duplicated, rest...) = Flux.gradient(f, dup, map(_const_unless_dup, rest)...)


### Flux.Train, for train!

_applyloss(loss, model, d...) = loss(model, d...)

using Flux: _old_to_new # from src/deprecations.jl
train!(loss, model::Duplicated, data, opt::Optimise.AbstractOptimiser; cb=nothing) =
train!(loss, model, data, _old_to_new(opt); cb)
@@ -44,4 +114,23 @@ function train!(loss, model::Duplicated, data, opt; cb = nothing)
end
end
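As a hedged usage sketch (the model, data, and optimiser below are illustrative, not part of this PR): `train!` with a `Duplicated` model can be driven much like the Zygote path.

```julia
using Flux, Enzyme

model = Chain(Dense(2 => 1))
dup   = Duplicated(model)                      # model plus shadow gradient storage
data  = [(randn(Float32, 2, 8), randn(Float32, 1, 8))]
opt_state = Flux.setup(Adam(0.01), model)      # optimiser state for the underlying model

# The loss receives the model followed by one datapoint, as for ordinary train!:
Flux.train!(dup, data, opt_state) do m, xb, yb
    Flux.mse(m(xb), yb)
end
```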


### Optimisers.update!, piracy, for now!

"""
Flux.update!(opt_state, model::Duplicated)

Method of `update!` for use with Enzyme, and in particular with `gradient(loss, Duplicated(model))`.
Since `Duplicated(model)` stores the gradient, `update!` can read it & update the model itself,
by calling `Flux.update!(opt_state, model.val, model.dval)`.

!!! warning "Experimental"
Enzyme support like this is new and somewhat experimental.
This method is type piracy, and should eventually move to Optimisers.jl.
"""
function Flux.update!(opt_state, model::Duplicated)
Flux.update!(opt_state, model.val, model.dval)
model
end
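For illustration, a sketch of the gradient-then-update loop this method is meant to support (the optimiser and model names here are assumptions, not from the diff):

```julia
using Flux, Enzyme

model = Duplicated(Chain(Dense(3 => 1)))
opt_state = Flux.setup(Descent(0.1), model.val)
x = randn(Float32, 3, 4)

Flux.gradient(model, Const(x)) do m, xb     # writes the gradient into model.dval
    sum(abs2, m(xb))
end
Flux.update!(opt_state, model)              # reads model.dval, updates model.val in place
```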

end # FluxEnzymeExt
7 changes: 5 additions & 2 deletions src/Flux.jl
@@ -14,9 +14,8 @@ import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owne
using Optimisers: freeze!, thaw!, adjust!, trainables
using Random: default_rng
using Zygote, ChainRulesCore
using Zygote: Params, @adjoint, gradient, pullback
using Zygote: Params, @adjoint, pullback
using Zygote.ForwardDiff: value
export gradient

# Pirate error to catch a common mistake. (Internal function `base` because overloading `update!` is more likely to give ambiguities.)
Optimisers.base(dx::Zygote.Grads) = error("Optimisers.jl cannot be used with Zygote.jl's implicit gradients, `Params` & `Grads`")
@@ -55,6 +54,10 @@ include("train.jl")
using .Train
using .Train: setup

include("gradient.jl")
export gradient
@compat(public, (withgradient,))

using Adapt, Functors, OneHotArrays
include("utils.jl")
include("functor.jl")
12 changes: 12 additions & 0 deletions src/deprecations.jl
@@ -212,6 +212,18 @@ ChainRulesCore.@non_differentiable _greek_ascii_depwarn(::Any...)

Base.@deprecate_binding FluxAMDAdaptor FluxAMDGPUAdaptor

function gradient(f, p::Zygote.Params)
Base.depwarn("""Implicit gradients such as `gradient(f, ::Params)` are deprecated!
Please see the docs for new explicit form.""", :gradient)
Zygote.gradient(f, p)
end

function withgradient(f, p::Zygote.Params)
Base.depwarn("""Implicit gradients such as `withgradient(f, ::Params)` are deprecated!
Please see the docs for new explicit form.""", :withgradient)
Zygote.withgradient(f, p)
end
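For reference, a minimal runnable sketch of the migration these warnings point to (the `model` and `loss` below are illustrative):

```julia
using Flux   # Zygote-backed gradient by default

model = Chain(Dense(2 => 1))
loss(m) = sum(abs2, m(ones(Float32, 2, 3)))

# Deprecated implicit style (now warns):
#   gs = gradient(() -> loss(model), Flux.params(model))
#   gs[model[1].weight]                      # gradients looked up by parameter array

# Explicit style:
gs = gradient(loss, model)                   # gs[1] mirrors the model's structure
nt = Flux.withgradient(loss, model)          # nt.val and nt.grad
```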

# v0.15 deprecations

# Enable these when 0.15 is released, and delete const ClipGrad = Optimise.ClipValue etc:
66 changes: 66 additions & 0 deletions src/gradient.jl
@@ -0,0 +1,66 @@

"""
gradient(f, args...)

Returns a tuple containing `∂f/∂x` for each argument `x`,
the derivative (for scalar `x`) or the gradient.
If no gradient is defined, `∂f/∂x` will be `nothing`.

`f(args...)` must be a real number; see [`Zygote.jacobian`](@ref) for array output.

By default, `Flux.gradient` calls Zygote. If you load Enzyme, then other methods become available.

See also [`withgradient`](@ref) to keep the value `f(args...)`.

```jldoctest; setup=:(using Zygote)
julia> gradient(*, 2.0, 3.0, 5.0)
(15.0, 10.0, 6.0)

julia> gradient(x -> sum(abs2,x), [7.0, 11.0, 13.0])
([14.0, 22.0, 26.0],)

julia> gradient([7, 11], 0, 1) do x, y, d
p = size(x, d)
sum(x.^p .+ y)
end
([14.0, 22.0], 2.0, nothing)
```
"""
gradient(f, args...) = Zygote.gradient(f, args...)



"""
withgradient(f, args...)

Returns both the value of the function and the [`gradient`](@ref), as a named tuple.

By default, `Flux.withgradient` calls Zygote. If you load Enzyme, then other methods become available.

```jldoctest; setup=:(using Zygote)
julia> y, ∇ = withgradient(/, 1, 2)
(val = 0.5, grad = (0.5, -0.25))

julia> ∇ == gradient(/, 1, 2)
true
```

Allows you to capture auxiliary outputs, in addition to the scalar
used by `gradient`. To do this, `f` must return a Tuple or NamedTuple.
Then it calculates `grad = gradient(first∘f, args...)`
but returns the whole `val = f(args...)`:

```jldoctest; setup=:(using Zygote)
julia> withgradient([1,2,4]) do x
z = 1 ./ x
sum(z), z # here z is an auxiliary output
end
(val = (1.75, [1.0, 0.5, 0.25]), grad = ([-1.0, -0.25, -0.0625],))

julia> withgradient(3.0, 4.0) do x, y
(div = x/y, mul = x*y)
end
(val = (div = 0.75, mul = 12.0), grad = (0.25, -0.1875))
```
"""
withgradient(f, args...) = Zygote.withgradient(f, args...)
4 changes: 4 additions & 0 deletions src/layers/macro.jl
@@ -1,3 +1,4 @@
import EnzymeCore

"""
@layer Dense
@@ -65,6 +66,9 @@ macro layer(exs...)

# This function exists only for depwarns when you use @functor directly
push!(out.args, :(Flux._check_new_macro(::$(esc(type))) = nothing))

# TODO this should probably make it zero first?
push!(out.args, :($EnzymeCore.Duplicated(m::$(esc(type))) = $EnzymeCore.Duplicated(m, $deepcopy(m))))

push!(out.args, _macro_functor(esc(type)))
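In effect, the `EnzymeCore.Duplicated` definition pushed above gives every `@layer` type a one-argument constructor. A small sketch, assuming this branch of Flux and a hypothetical layer type `MyLayer`:

```julia
using Flux, EnzymeCore

struct MyLayer                    # hypothetical layer, not part of this PR
    weight::Matrix{Float32}
end
Flux.@layer MyLayer               # generates, roughly:
# EnzymeCore.Duplicated(m::MyLayer) = EnzymeCore.Duplicated(m, deepcopy(m))

m = MyLayer(rand(Float32, 2, 2))
dup = EnzymeCore.Duplicated(m)    # uses the generated one-argument method
dup.dval.weight                   # deep copy, used as shadow storage for the gradient
```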
