add more Duplicated methods
mcabbott committed Jul 25, 2024
1 parent 8c15898 commit ef9d7b4
Showing 6 changed files with 180 additions and 4 deletions.
2 changes: 2 additions & 0 deletions Project.toml
@@ -6,6 +6,7 @@ version = "0.14.17"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
@@ -43,6 +44,7 @@ CUDA = "4, 5"
ChainRulesCore = "1.12"
Compat = "4.10.0"
Enzyme = "0.12"
EnzymeCore = "0.7.7"
Functors = "0.4"
MLUtils = "0.4"
MacroTools = "0.5"
93 changes: 91 additions & 2 deletions ext/FluxEnzymeExt/FluxEnzymeExt.jl
@@ -12,10 +12,80 @@ _make_zero_internal!(x::AbstractArray) = fill!(x, 0)
_make_zero_internal!(x) = x
_make_zero!(model) = fmap(_make_zero_internal!, model)

_applyloss(loss, model, d...) = loss(model, d...)

EnzymeRules.inactive(::typeof(Flux.Losses._check_sizes), args...) = true


### gradient & withgradient

# Extract the gradient stored in a `Duplicated`'s shadow, preserving the model's structure;
# `Const` arguments and non-numeric leaves become `nothing`.
_grad_or_nothing(dup::Duplicated) = Flux.fmapstructure(_grad_or_nothing, dup.dval)
_grad_or_nothing(::Const) = nothing
_grad_or_nothing(x) = Optimisers.isnumeric(x) ? x : nothing

function Flux.withgradient(f, args::Union{Const, Duplicated}...)
  for x in args
    x isa Duplicated && _make_zero!(x.dval)
  end
  # TODO allow for f to return a tuple here, like in Zygote
  _, val = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
  (; val, grad = map(_grad_or_nothing, args))
end
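A minimal usage sketch of this `withgradient` method (not part of the commit; `model` and `x` are hypothetical placeholders, and Enzyme must be loaded):

```julia
using Flux, Enzyme

model = Chain(Dense(2 => 1))   # hypothetical model
x = randn(Float32, 2, 8)       # hypothetical data

dup = Duplicated(model)        # allocates shadow storage for the gradient
out = Flux.withgradient(dup, Const(x)) do m, xb
  sum(abs2, m(xb))
end
out.val    # the scalar loss
out.grad   # (gradient for the model, nothing for the Const argument)
```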

"""
    gradient(f, Duplicated(model), args...)

This should return the same answer as `gradient(f, model, args...)`,
but it uses Enzyme.jl instead of Zygote.jl to compute the derivative.
Only available when Enzyme is loaded!

Besides being returned, the gradient is also stored within the `Duplicated` object.
Calling `Enzyme.Duplicated(model)` allocates space for the gradient,
which is zeroed before use when `gradient` is called.
!!! warning "Experimental"
    Enzyme support like this is new and somewhat experimental.
    It has known problems if your model has shared parameters.

# Example

```
julia> using Flux, Enzyme

julia> model = Chain(Dense([3.0;;]));

julia> Flux.gradient(model, [1]) do m, x
         sum(abs2, m(x))
       end
((layers = ((weight = [6.0;;], bias = [6.0], σ = nothing),),), [18.0])

julia> Flux.gradient(Duplicated(model), Const([1])) do m, x
         sum(abs2, m(x))
       end
┌ Warning: Using fallback BLAS replacements for (["dsymv_64_"]), performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
((layers = ((weight = [6.0;;], bias = [6.0], σ = nothing),),), nothing)
```
"""
function Flux.gradient(f, args::Union{Const, Duplicated}...)
  for x in args
    x isa Duplicated && _make_zero!(x.dval)
  end
  Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
  map(_grad_or_nothing, args)
end
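As the docstring notes, the gradient is also written into the `Duplicated` shadow. A small sketch of that side effect, assuming the `model` from the docstring example above:

```julia
dup = Duplicated(model)        # the shadow is a deepcopy; `gradient` zeroes it before use
g = Flux.gradient(m -> sum(abs2, m([1.0])), dup)
g[1]       # gradient as a nested NamedTuple, as with Zygote
dup.dval   # the same gradient, stored in a model-shaped object
```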

_const_unless_dup(x) = Const(x)
_const_unless_dup(dup::Duplicated) = dup

# TODO allow for Duplicated as 2nd argument, assume others const? This produces ambiguities...
# Flux.withgradient(f, dup::Duplicated, rest...) = Flux.withgradient(f, dup, map(_const_unless_dup, rest)...)
# Flux.gradient(f, dup::Duplicated, rest...) = Flux.gradient(f, dup, map(_const_unless_dup, rest)...)


### Flux.Train, for train!

_applyloss(loss, model, d...) = loss(model, d...)

using Flux: _old_to_new # from src/deprecations.jl
train!(loss, model::Duplicated, data, opt::Optimise.AbstractOptimiser; cb=nothing) =
  train!(loss, model, data, _old_to_new(opt); cb)
@@ -44,4 +114,23 @@ function train!(loss, model::Duplicated, data, opt; cb = nothing)
end
end
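For reference, a sketch of how this `train!` method might be called (not part of the commit; `model` is a hypothetical placeholder, and setting up the optimiser state against the inner model is an assumption about intended usage):

```julia
using Flux, Enzyme

model = Chain(Dense(2 => 1))                             # hypothetical model
data = [(randn(Float32, 2, 8), randn(Float32, 1, 8))]    # hypothetical batches
dup_model = Duplicated(model)
opt_state = Flux.setup(Adam(), model)                    # state for the inner model

Flux.train!(dup_model, data, opt_state) do m, xb, yb
  Flux.mse(m(xb), yb)
end
```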


### Optimisers.update!, piracy, for now!

"""
    Flux.update!(opt_state, model::Duplicated)

Method of `update!` for use with Enzyme, and in particular with `gradient(loss, Duplicated(model))`.
Since `Duplicated(model)` stores the gradient, `update!` can read it and update the model itself,
by calling `Flux.update!(opt_state, model.val, model.dval)`.

!!! warning "Experimental"
    Enzyme support like this is new and somewhat experimental.
    This method is type piracy, and should eventually move to Optimisers.jl.
"""
function Flux.update!(opt_state, model::Duplicated)
  Flux.update!(opt_state, model.val, model.dval)
  model
end
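Putting the pieces together, a sketch of one manual training step with these methods (not part of the commit; `model`, `x`, `y` are hypothetical placeholders):

```julia
dup = Duplicated(model)
opt_state = Flux.setup(Adam(), model)

Flux.gradient(dup, Const(x), Const(y)) do m, xb, yb
  Flux.mse(m(xb), yb)
end

Flux.update!(opt_state, dup)   # reads dup.dval and updates dup.val (and opt_state) in place
```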

end # FluxEnzymeExt
7 changes: 5 additions & 2 deletions src/Flux.jl
@@ -14,9 +14,8 @@ import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owne
using Optimisers: freeze!, thaw!, adjust!, trainables
using Random: default_rng
using Zygote, ChainRulesCore
using Zygote: Params, @adjoint, gradient, pullback
using Zygote: Params, @adjoint, pullback
using Zygote.ForwardDiff: value
export gradient

# Pirate error to catch a common mistake. (Internal function `base` because overloading `update!` is more likely to give ambiguities.)
Optimisers.base(dx::Zygote.Grads) = error("Optimisers.jl cannot be used with Zygote.jl's implicit gradients, `Params` & `Grads`")
@@ -55,6 +54,10 @@ include("train.jl")
using .Train
using .Train: setup

include("gradient.jl")
export gradient
@compat(public, (withgradient,))

using Adapt, Functors, OneHotArrays
include("utils.jl")
include("functor.jl")
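In practice this means `gradient` is exported and now refers to Flux's own wrapper, while `withgradient` is declared public but not exported. A quick sketch (not from the commit):

```julia
using Flux

gradient(x -> sum(abs2, x), [1.0, 2.0])           # exported, defaults to Zygote
Flux.withgradient(x -> sum(abs2, x), [1.0, 2.0])  # public, call it qualified
```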
12 changes: 12 additions & 0 deletions src/deprecations.jl
@@ -212,6 +212,18 @@ ChainRulesCore.@non_differentiable _greek_ascii_depwarn(::Any...)

Base.@deprecate_binding FluxAMDAdaptor FluxAMDGPUAdaptor

function gradient(f, p::Zygote.Params)
  Base.depwarn("""Implicit gradients such as `gradient(f, ::Params)` are deprecated!
                  Please see the docs for the new explicit form.""", :gradient)
  Zygote.gradient(f, p)
end

function withgradient(f, p::Zygote.Params)
  Base.depwarn("""Implicit gradients such as `withgradient(f, ::Params)` are deprecated!
                  Please see the docs for the new explicit form.""", :withgradient)
  Zygote.withgradient(f, p)
end

# v0.15 deprecations

# Enable these when 0.15 is released, and delete const ClipGrad = Optimise.ClipValue etc:
66 changes: 66 additions & 0 deletions src/gradient.jl
@@ -0,0 +1,66 @@

"""
    gradient(f, args...)

Returns a tuple containing `∂f/∂x` for each argument `x`,
the derivative (for scalar `x`) or the gradient.
If no gradient is defined, `∂f/∂x` will be `nothing`.

`f(args...)` must return a real number; see [`Zygote.jacobian`](@ref) for array output.

By default, `Flux.gradient` calls Zygote. If you load Enzyme, then other methods become available.

See also [`withgradient`](@ref) to keep the value `f(args...)`.

```jldoctest; setup=:(using Zygote)
julia> gradient(*, 2.0, 3.0, 5.0)
(15.0, 10.0, 6.0)

julia> gradient(x -> sum(abs2,x), [7.0, 11.0, 13.0])
([14.0, 22.0, 26.0],)

julia> gradient([7, 11], 0, 1) do x, y, d
         p = size(x, d)
         sum(x.^p .+ y)
       end
([14.0, 22.0], 2.0, nothing)
```
"""
gradient(f, args...) = Zygote.gradient(f, args...)



"""
    withgradient(f, args...)

Returns both the value of the function and the [`gradient`](@ref), as a named tuple.

By default, `Flux.withgradient` calls Zygote. If you load Enzyme, then other methods become available.

```jldoctest; setup=:(using Zygote)
julia> y, ∇ = withgradient(/, 1, 2)
(val = 0.5, grad = (0.5, -0.25))

julia> ∇ == gradient(/, 1, 2)
true
```

Allows you to capture auxiliary outputs, in addition to the scalar
used by `gradient`. To do this, `f` must return a Tuple or NamedTuple.
Then it calculates `grad = gradient(first∘f, args...)`
but returns the whole `val = f(args...)`:

```jldoctest; setup=:(using Zygote)
julia> withgradient([1,2,4]) do x
         z = 1 ./ x
         sum(z), z  # here z is an auxiliary output
       end
(val = (1.75, [1.0, 0.5, 0.25]), grad = ([-1.0, -0.25, -0.0625],))

julia> withgradient(3.0, 4.0) do x, y
         (div = x/y, mul = x*y)
       end
(val = (div = 0.75, mul = 12.0), grad = (0.25, -0.1875))
```
"""
withgradient(f, args...) = Zygote.withgradient(f, args...)
4 changes: 4 additions & 0 deletions src/layers/macro.jl
@@ -1,3 +1,4 @@
import EnzymeCore

"""
@layer Dense
@@ -65,6 +66,9 @@ macro layer(exs...)

# This function exists only for depwarns when you use @functor directly
push!(out.args, :(Flux._check_new_macro(::$(esc(type))) = nothing))

# TODO this should probably make it zero first?
push!(out.args, :($EnzymeCore.Duplicated(m::$(esc(type))) = $EnzymeCore.Duplicated(m, $deepcopy(m))))

push!(out.args, _macro_functor(esc(type)))

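For illustration, roughly what the extra definition pushed by `@layer` expands to, for a hypothetical user-defined layer (the struct name here is invented):

```julia
using Flux, Enzyme

struct MyLayer{W}
  weight::W
end
Flux.@layer MyLayer

# The macro now also defines, roughly:
#   EnzymeCore.Duplicated(m::MyLayer) = EnzymeCore.Duplicated(m, deepcopy(m))
# Per the TODO above, the shadow is a deepcopy rather than zeroed storage;
# Flux.gradient / withgradient zero it before each call.

m = MyLayer(rand(Float32, 3, 3))
dup = Duplicated(m)   # dup.dval is the deepcopy, used as gradient storage
```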
