Add more Duplicated methods for Enzyme.jl support #2471

Open · wants to merge 18 commits into master

Conversation

mcabbott
Member

@mcabbott mcabbott commented Jul 25, 2024

This adds a method gradient(f, ::Duplicated) which, like train!(loss, model::Duplicated, data, opt) from #2446, uses the Duplicated type to signal that you want to use Enzyme rather than Zygote. It returns the gradient (for compatibility?) and also mutates the Duplicated object. (A minimal usage sketch follows the list below.)

  • To avoid piracy, this creates a new function Flux.gradient which by default calls Zygote.gradient. Unfortunately that means every using Flux, Zygote now produces ambiguities... so probably it should not be exported, which means waiting for 0.15. Then again, such ambiguities give a clear message, so maybe that's OK, and clearer than un-exporting it.

  • There's also withgradient, which at first didn't allow returning a tuple the way Zygote does; now it does.

  • There's also a method of update!, which either needs to move to Optimisers.jl, or else Flux needs to own that function too. This has moved to Optimisers.jl#192 (Add Duplicated methods).

  • Finally, @layer Chain defines a 1-argument Duplicated(c::Chain) method, so that you don't need to construct the dual by hand.
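
A minimal usage sketch of the interface above (the model, loss and data here are placeholders, not taken from this PR's tests):

using Flux, Enzyme

model = Chain(Dense(2 => 3, tanh), Dense(3 => 1))
dup = Duplicated(model)            # shadow allocated by the @layer-defined method

loss(m, x, y) = Flux.mse(m(x), y)
x, y = rand(Float32, 2, 8), rand(Float32, 1, 8)

# Passing a Duplicated model selects the Enzyme path; plain arguments are treated
# as constant, and the gradient is also accumulated into dup.dval:
grads = Flux.gradient(loss, dup, x, y)

opt_state = Flux.setup(Adam(), model)
Flux.update!(opt_state, model, grads[1])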

WIP, RFC?

Needs tests and docs.

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

@ToucheSir
Member

The docs failure here looks real, but I'm not sure why.

@mcabbott
Member Author

mcabbott commented Nov 6, 2024

Fixed the docs. I think we need to own update! too, for this not to be piracy.

But while I fix that, any objections to the interface? We seem to be merging things in a hurry now...

@mcabbott
Member Author

mcabbott commented Nov 6, 2024

The CUDA test failure looks like this (plus one more); why now?

Dropout Layer GPU grad test: Test Failed at /var/lib/buildkite-agent/builds/gpuci-15/julialang/flux-dot-jl/test/test_utils.jl:77
  Expression: ≈(y_gpu, y, rtol = rtol, atol = atol)
  Evaluated: 0.50719213f0 ≈ 0.5073142f0 (rtol=0.0001, atol=0.0001)

(m::$EnzymeCore.Duplicated{<:$type})(xs...) = m.val(xs...)

# Not sure but this does prevent printing of 2nd copy:
$Optimisers.trainable(m::$EnzymeCore.Duplicated{<:$type}) = (; val = m.val)
Member Author

This leads to errors in show if used with Functors 0.4. Optimisers.jl gets confused if something which isn't a functor has trainable defined... but soon this will go away:

julia> struct Two{A,B}; a::A; b::B; end

julia> Flux.trainable(x::Two) = (; x.a)

julia> Flux.trainables(Two([1,2.], [3,4,5.]))
ERROR: MethodError: no method matching _trainable(::Tuple{}, ::@NamedTuple{a::Vector{Float64}})
The function `_trainable` exists, but no method is defined for this combination of argument types.
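
For contrast, a quick sketch (not from this PR) of the combination that does work under Functors 0.4: mark the struct as a functor before restricting trainable, so that Optimisers.jl sees matching children.

julia> struct Two{A,B}; a::A; b::B; end

julia> Flux.@layer Two   # or Functors.@functor Two

julia> Flux.trainable(x::Two) = (; x.a)

julia> Flux.trainables(Two([1,2.], [3,4,5.]))   # now collects only the `a` array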

@ToucheSir
Member

The interface looks reasonable to me and Flux owning gradient could make sense for moving off Zygote, but I don't love the careful internal plumbing required to make this work. I wonder if we can work around concerns about type piracy by moving the update! and trainable overloads to an Optimisers.jl extension.
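
A rough sketch of what such an extension could look like (illustrative only, not the code that later landed as Optimisers.jl#192):

module OptimisersEnzymeCoreExt

using Optimisers, EnzymeCore

# With the gradient accumulated in the shadow, update! can read it from dval:
function Optimisers.update!(opt_state, model::EnzymeCore.Duplicated)
    Optimisers.update!(opt_state, model.val, model.dval)
    return opt_state, model
end

# Hide the shadow copy from the trainable-parameter walk:
Optimisers.trainable(m::EnzymeCore.Duplicated) = (; val = m.val)

end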

@mcabbott
Member Author

mcabbott commented Nov 8, 2024

Can you clarify what you mean by "the careful internal plumbing"? If you mean how gradient works, it's trying to give friendly errors if you have somehow loaded only EnzymeCore. It also tries not to require Const; we could simplify by requiring either all Const/Duplicated or none. But this doesn't seem so horrendous.

update! has a lot of methods because we keep old Flux.Optimise around (and aim for friendly errors). But I don't mind moving the new ones to Optimisers.jl. Flux has its own setup but using the wrong version matters less there.

Edit: now requires FluxML/Optimisers.jl#192


codecov bot commented Nov 8, 2024

Codecov Report

Attention: Patch coverage is 26.74419% with 63 lines in your changes missing coverage. Please review.

Project coverage is 59.24%. Comparing base (c51e6fb) to head (a9424ab).
Report is 1 commit behind head on master.

Files with missing lines                Patch %   Lines
ext/FluxEnzymeExt/FluxEnzymeExt.jl        0.00%   29 Missing ⚠️
src/gradient.jl                          58.62%   12 Missing ⚠️
src/deprecations.jl                       0.00%    8 Missing ⚠️
src/layers/macro.jl                      46.15%    7 Missing ⚠️
src/train.jl                              0.00%    7 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           master    #2471       +/-   ##
===========================================
+ Coverage   33.40%   59.24%   +25.84%     
===========================================
  Files          31       32        +1     
  Lines        1907     2012      +105     
===========================================
+ Hits          637     1192      +555     
+ Misses       1270      820      -450     


@ToucheSir
Member

The gradient code looked pretty straightforward, but there was also update!, train! and all of _macro_enzyme. I think the gradient changes are fine (though I'm surprised the code is type stable!), but would also be ok with limiting to Const + Duplicated if that makes things easier. It's easier to relax parts of the API than to tighten them up later.

The ideal path right now would be landing FluxML/Optimisers.jl#192 and then removing _macro_enzyme from here. I think that should be straightforward?

@mcabbott
Member Author

mcabbott commented Nov 8, 2024

Re gradient, one quirk is that Enzyme has the rule that anything not Const is active:

julia> Enzyme.gradient(Reverse, dot, [1, 2.], [3, 4.])
([3.0, 4.0], [1.0, 2.0])

julia> Enzyme.gradient(Reverse, dot, [1, 2.], Const([3, 4.]))
([3.0, 4.0], nothing)

while this PR's rule is that, once one thing is Duplicated, everything else is constant:

julia> Flux.gradient(dot, [1, 2.], [3, 4.])  # Zygote
([3.0, 4.0], [1.0, 2.0])

julia> Flux.gradient(dot, [1, 2.], Duplicated([3, 4.], [NaN, NaN]))  # implicit Const
(nothing, [1.0, 2.0])

julia> Flux.gradient(dot, [1, 2.], Const([3, 4.]))
ERROR: ArgumentError: The method `gradient(f, xs...)` using Enzyme.jl requires at least one `Duplicated` argument, not just `Const`.

IDK if that's too weird.

Edit, one more quirk... now fixed:

julia> Flux.gradient((x,y) -> sum((x .* y).^2), [1,2,3.], 4.0)  # Zygote
([32.0, 64.0, 96.0], 112.0)

julia> Flux.gradient((x,y) -> sum((x .* y).^2), Duplicated([1,2,3.], zeros(3)), 4.0)  # implicit Const
([32.0, 64.0, 96.0], nothing)

julia> Flux.gradient((x,y) -> sum((x .* y).^2), Duplicated([1,2,3.], zeros(3)), Active(4.0))
ERROR: ArgumentError: The method `gradient(f, xs...)` using Enzyme.jl does not support `Active`, only `Duplicated` and `Const`.

julia> Flux.gradient((x,y) -> sum((x .* y).^2), Duplicated([1,2,3.], zeros(3)), Duplicated(4.0, 0.0))  # now an error, 2636454
ERROR: `Flux.gradient(f, Duplicated(x), ...)` expects `x` to contain mutable parameter arrays.

@mcabbott
Member Author

One question this PR opens is how other AD packages might slot in. Mooncake.jl similarly prefers pre-allocated space for gradients, but uses different types. The obvious interface would be to have some container like so:

struct MoonPair{X,DX}; x::X; dx::DX; end
MoonPair(x) = MoonPair(x, Mooncake.zero_codual(x))

Flux could own this. That would be easy, in terms of e.g. defining show nicely.

Mooncake.jl could also own it, but then @layer cannot rely on it being defined. An extension would have to dispatch on something like Flux.gradient(f, args::Union{MoonPair, Const}...) and maybe that gets messy if the model is the 3rd argument.
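
Continuing that sketch, the pieces @layer would generate for such a wrapper could mirror the Duplicated definitions quoted above (all hypothetical; MoonPair is not part of this PR):

using Optimisers

(m::MoonPair)(xs...) = m.x(xs...)                 # forward calls to the wrapped model
Optimisers.trainable(m::MoonPair) = (; x = m.x)   # hide the pre-allocated gradient space

# A Mooncake extension would then add a method next to the Enzyme one, e.g.
#   Flux.gradient(f, args::Union{MoonPair, Const}...)
# writing gradients into each pair's dx field.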


forward, reverse = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(f)}, Active, map(typeof, args)...)
tape, result, shadow_result = forward(Const(f), args...)
reverse(Const(f), args..., _sensitivity(result), tape)
Contributor

why not just call autodiff here?

Member Author

This was adapted from the ReverseSplitWithPrimal doc example. My understanding was that split mode is needed in order to construct _sensitivity(result) after seeing what f returns.

Contributor

What is the expected return type of f? If it's a float, it should be fine to just use autodiff directly.

If it's not a float (as it looks like below), it might be more efficient to do something like the following, since split mode will introduce overhead.

@inline asfloat(x) = error("""`Flux.withgradient(f, xs...)` expects that `y = f(xs...)` is a real number,
    or else a Tuple or NamedTuple whose first element is a real number.""")

@inline asfloat(x::Real) = x

@inline asfloat(x::Tuple) = asfloat(x[1])

@inline asfloat(x::NamedTuple) = asfloat(x[1])

function return_asfloat(f, args...)
    return asfloat(@inline f(args...))
end

autodiff(Reverse, Const(return_asfloat), Active, Const(f), ...)

Member Author

The point is to pass things other than the loss out of the function:

julia> withgradient([1,2,4]) do x
            z = 1 ./ x
            sum(z), string("aux output: z = ", z)
         end
(val = (1.75, "aux output: z = [1.0, 0.5, 0.25]"), grad = ([-1.0, -0.25, -0.0625],))

Member Author

@mcabbott mcabbott Nov 11, 2024

It's possible that trying to infer the return type of f would pay off... when it is known to return a float, you could call ReverseWithPrimal instead. How big is the overhead of this ReverseSplitWithPrimal?

Edit, with examples from the other thread I can't measure it, <1% maybe?

Contributor

Ah, well in that case maybe something like the following (assuming you know the return type):

function byref(out, f, args...)
  res = f(args...)
  out[] = asfloat(res)
  return res
end

dout = DuplicatedNoNeed(Ref(0.0), Ref(1.0))
autodiff(Reverse, byref, Const, dout, Const(f), ...)

Member Author

This looks possible. But where is it better? Is the overhead of ReverseSplitWithPrimal in running it, or in generating the code, or something else (which I didn't try to time)?

@wsmoses
Contributor

wsmoses commented Nov 11, 2024

Longer term it might be worth making sure that the design is flexible enough for Reactant integration as well (x/ref https://lux.csail.mit.edu/stable/manual/compiling_lux_models). Lux has already shown relatively big speedups, at least on CPU.

@avik-pal
Member

> Longer term it might be worth making sure that the design is flexible enough for Reactant integration as well (x/ref lux.csail.mit.edu/stable/manual/compiling_lux_models). Lux has already shown relatively big speedups at least on CPU

MLDataDevices has get_device_type, which lets you get the device type. See https://github.com/LuxDL/Lux.jl/blob/0be75045c37a51fc6369a28c9e8e893c1044089d/src/helpers/training.jl#L200-L206. If we get an AutoEnzyme + ReactantDevice combination, it switches to using Reactant. Also you need to ensure that compilation is done only once (https://github.com/LuxDL/Lux.jl/blob/0be75045c37a51fc6369a28c9e8e893c1044089d/ext/LuxReactantExt/training.jl#L37-L57) and then reused (https://github.com/LuxDL/Lux.jl/blob/0be75045c37a51fc6369a28c9e8e893c1044089d/ext/LuxReactantExt/training.jl#L59-L70).
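
A rough sketch of that dispatch pattern (assuming MLDataDevices' get_device_type and a ReactantDevice type as described; the branch bodies are placeholders):

using MLDataDevices

function choose_backend(model)
    dev = get_device_type(model)        # e.g. CPUDevice, CUDADevice, ReactantDevice
    if dev <: ReactantDevice
        # compile the training step once with Reactant, cache the compiled function,
        # and reuse it on later calls (as in the LuxReactantExt links above)
        return :reactant
    else
        # fall back to the Enzyme / Zygote paths in this PR
        return :default
    end
end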
