Add more `Duplicated` methods for Enzyme.jl support #2471
base: master
Conversation
Force-pushed from 60acf27 to 6310548
The docs failure here looks real, but I'm not sure why.
Fixed the docs. I think we need to own [...], but while I fix that, any objections to the interface? We seem to be merging things in a hurry now...
CUDA test failure is like this (and one more), why now?
```julia
(m::$EnzymeCore.Duplicated{<:$type})(xs...) = m.val(xs...)

# Not sure but this does prevent printing of 2nd copy:
$Optimisers.trainable(m::$EnzymeCore.Duplicated{<:$type}) = (; val = m.val)
```
This leads to errors in `show` if used with Functors 0.4. Optimisers.jl gets confused if something which isn't a functor has `trainable` defined... but soon this will go away:

```julia
julia> struct Two{A,B}; a::A; b::B; end

julia> Flux.trainable(x::Two) = (; x.a)

julia> Flux.trainables(Two([1, 2.], [3, 4, 5.]))
ERROR: MethodError: no method matching _trainable(::Tuple{}, ::@NamedTuple{a::Vector{Float64}})
The function `_trainable` exists, but no method is defined for this combination of argument types.
```
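For context, a minimal sketch of the usual workaround under Functors 0.4 (my own illustration, not part of this PR): registering the struct as a functor first avoids the `_trainable` error above.

```julia
using Flux, Functors

struct Two{A,B}; a::A; b::B; end
Functors.@functor Two               # make Two a functor first (Functors 0.4 style)
Flux.trainable(x::Two) = (; x.a)    # then restricting the trainable fields works

Flux.trainables(Two([1, 2.], [3, 4, 5.]))  # 1-element Vector containing [1.0, 2.0]
```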
The interface looks reasonable to me, and Flux owning [...]
Can you clarify what you mean by "the careful internal plumbing"? If you mean how [...]

Edit: now requires FluxML/Optimisers.jl#192
Codecov Report

Attention: Patch coverage is [...]

Additional details and impacted files

```
@@            Coverage Diff            @@
##           master    #2471       +/-   ##
===========================================
+ Coverage   33.40%   59.24%   +25.84%
===========================================
  Files          31       32        +1
  Lines        1907     2012      +105
===========================================
+ Hits          637     1192      +555
+ Misses       1270      820      -450
```

☔ View full report in Codecov by Sentry.
This reverts commit ca5a20f.
The gradient code looked pretty straightforward, but there was also [...]. The ideal path right now would be landing FluxML/Optimisers.jl#192 and then removing [...]
Re gradient, one quirk is that Enzyme has the rule that anything not wrapped in `Const` gets a gradient:

```julia
julia> Enzyme.gradient(Reverse, dot, [1, 2.], [3, 4.])
([3.0, 4.0], [1.0, 2.0])

julia> Enzyme.gradient(Reverse, dot, [1, 2.], Const([3, 4.]))
([3.0, 4.0], nothing)
```

while this PR's rule is that, once one thing is `Duplicated`, any plain argument is implicitly `Const`:

```julia
julia> Flux.gradient(dot, [1, 2.], [3, 4.])  # Zygote
([3.0, 4.0], [1.0, 2.0])

julia> Flux.gradient(dot, [1, 2.], Duplicated([3, 4.], [NaN, NaN]))  # implicit Const
(nothing, [1.0, 2.0])

julia> Flux.gradient(dot, [1, 2.], Const([3, 4.]))
ERROR: ArgumentError: The method `gradient(f, xs...)` using Enzyme.jl requires at least one `Duplicated` argument, not just `Const`.
```

IDK if that's too weird.

Edit, one more quirk... now fixed:

```julia
julia> Flux.gradient((x,y) -> sum((x .* y).^2), [1,2,3.], 4.0)  # Zygote
([32.0, 64.0, 96.0], 112.0)

julia> Flux.gradient((x,y) -> sum((x .* y).^2), Duplicated([1,2,3.], zeros(3)), 4.0)  # implicit Const
([32.0, 64.0, 96.0], nothing)

julia> Flux.gradient((x,y) -> sum((x .* y).^2), Duplicated([1,2,3.], zeros(3)), Active(4.0))
ERROR: ArgumentError: The method `gradient(f, xs...)` using Enzyme.jl does not support `Active`, only `Duplicated` and `Const`.

julia> Flux.gradient((x,y) -> sum((x .* y).^2), Duplicated([1,2,3.], zeros(3)), Duplicated(4.0, 0.0))  # now an error, 2636454
ERROR: `Flux.gradient(f, Duplicated(x), ...)` expects `x` to contain mutable parameter arrays.
```
One question this PR opens is how other AD packages might slot in. Mooncake.jl similarly prefers pre-allocated space for gradients, but uses different types. The obvious interface would be to have some container like so:

```julia
struct MoonPair{X,DX}; x::X; dx::DX; end
MoonPair(x) = MoonPair(x, Mooncake.zero_codual(x))
```

Flux could own this. That would be easy, in terms of e.g. defining [...]. Mooncake.jl could also own it, but then [...]
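A purely hypothetical sketch of what "Flux could own this" might buy (none of these names are real Flux or Mooncake API): with a Flux-owned wrapper, back-end selection is just dispatch, the same way this PR dispatches on `Duplicated`.

```julia
# Hypothetical: a Flux-owned wrapper type, so no Mooncake dependency is needed here.
struct MoonPair{X,DX}
    x::X
    dx::DX
end

# Back-end selection by dispatch (illustrative names, not real functions):
gradient_backend(xs...) = :zygote               # plain arguments → default Zygote path
gradient_backend(xs::MoonPair...) = :mooncake   # wrapped arguments → Mooncake path

gradient_backend([1.0, 2.0], 3.0)                    # :zygote
gradient_backend(MoonPair([1.0, 2.0], [0.0, 0.0]))   # :mooncake
```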
```julia
forward, reverse = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(f)}, Active, map(typeof, args)...)
tape, result, shadow_result = forward(Const(f), args...)
reverse(Const(f), args..., _sensitivity(result), tape)
```
why not just call autodiff here?
This was adapted from the `ReverseSplitWithPrimal` doc example. My understanding was that it needs that in order to construct `_sensitivity(result)` after seeing what `f` returns.
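For comparison, here is a minimal split-mode example in the spirit of the Enzyme.jl docs (my own sketch, not this PR's code), showing why the seed can only be chosen after the forward pass has produced the primal result:

```julia
using Enzyme

f(x) = sum(abs2, x)

x, dx = [3.0, 4.0], zeros(2)

# Build the split-mode thunks, mirroring the PR's call above but with a concrete argument type.
forward, reverse = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(f)}, Active,
                                  Duplicated{Vector{Float64}})

tape, result, shadow = forward(Const(f), Duplicated(x, dx))  # augmented primal pass: result == 25.0
reverse(Const(f), Duplicated(x, dx), 1.0, tape)              # seed 1.0 chosen after seeing `result`

dx  # now [6.0, 8.0]
```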
What is the expected return type of `f`? If it's a float it should be fine to just use `autodiff` directly.

If not a float (as looks like below), it might be more efficient to do something like the following (since split mode will introduce overhead):

```julia
@inline asfloat(x) = error("""`Flux.withgradient(f, xs...)` expects that `y = f(xs...)` is a real number,
    or else a Tuple or NamedTuple whose first element is a real number.""")
@inline asfloat(x::Real) = x
@inline asfloat(x::Tuple) = asfloat(x[1])
@inline asfloat(x::NamedTuple) = asfloat(x[1])

function return_asfloat(f, args...)
    return asfloat(@inline f(args...))
end

autodiff(Reverse, Const(return_asfloat), Active, Const(f), ...)
```
The point is to pass things other than the loss out of the function:

```julia
julia> withgradient([1,2,4]) do x
         z = 1 ./ x
         sum(z), string("aux output: z = ", z)
       end
(val = (1.75, "aux output: z = [1.0, 0.5, 0.25]"), grad = ([-1.0, -0.25, -0.0625],))
```
It's possible that trying to infer the return type of `f` would pay... when it is shown to return a float you could call `ReverseWithPrimal`. How big is the overhead of this `ReverseSplitWithPrimal`?

Edit: with examples from the other thread I can't measure it, <1% maybe?
Ah, well in that case maybe something like the following (assuming you know the return type):

```julia
function byref(out, f, args...)
    res = f(args...)
    out[] = asfloat(res)
    return res
end

dout = DuplicatedNoNeed(Ref(0.0), Ref(1.0))
autodiff(Reverse, Const(byref), Const, dout, Const(f), ....)
```
This looks possible. But where is it better? Is the overhead of `ReverseSplitWithPrimal` in running it, or in generating code or something (which I didn't try to time)?
Longer term it might be worth making sure that the design is flexible enough for Reactant integration as well (x-ref https://lux.csail.mit.edu/stable/manual/compiling_lux_models). Lux has already shown relatively big speedups, at least on CPU.
MLDataDevices has [...]
This adds a method like `gradient(f, ::Duplicated)` which, like `train!(loss, model::Duplicated, data, opt)` from #2446, uses the `Duplicated` type to signal that you want to use Enzyme not Zygote. It returns the gradient (for compatibility?) and mutates the `Duplicated` object.

To avoid piracy, this creates a new function `Flux.gradient` which by default calls `Zygote.gradient`. Unfortunately that's going to mean every `using Flux, Zygote` now produces ambiguities... ~~so probably it should not be exported? Which means 0.15.~~ Such ambiguities give a clear message, maybe that's OK? Maybe clearer than stopping exporting.

There's also `withgradient`, but ~~it doesn't allow you to return a tuple the way Zygote does, not yet.~~ Now it does.

There's also a method of `update!` which ~~either needs to move to Optimisers.jl, or again we need to... we should let Flux own the function?~~ has moved to: Add `Duplicated` methods, Optimisers.jl#192.

Finally, `@layer Chain` defines a 1-argument `Duplicated(c::Chain)` method, so that you don't need to construct the dual by hand.

WIP, RFC? Needs tests, and docs.

PR Checklist
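Usage sketch (my own hedged illustration of the interface described above, assuming the 1-argument `Duplicated` constructor from `@layer` and the Enzyme dispatch; exact names and behaviour may differ from the merged code):

```julia
using Flux, Enzyme

model = Chain(Dense(2 => 3, tanh), Dense(3 => 1))
dup   = Duplicated(model)        # 1-arg constructor provided via `@layer Chain`, per the description
x     = randn(Float32, 2, 8)

# Passing a Duplicated model routes Flux.gradient to Enzyme instead of Zygote;
# the gradient is returned and also written into dup.dval.
grads = Flux.gradient(m -> sum(abs2, m(x)), dup)
```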