Add Enzyme train function #2446

Merged: 21 commits into FluxML:master, Jun 24, 2024

Conversation

@wsmoses (Contributor) commented May 14, 2024

A quick test on the README example looks positive. I have no opinions on the design/API, so I'll hand this PR over to you all to shape however you see fit (and go back to staring at CUDA).

I will note that performance is currently unclear and worth investigating. Before we do that, though, having a good way to run and test things is critical, hence this PR.

using Enzyme, Flux

data = [([x], 2x-x^3) for x in -2:0.1f0:2]
model = Chain(Dense(1 => 23, tanh), Dense(23 => 1, bias=false), only)

optim = Flux.setup(Adam(), model)


julia> fn(m, x, y) = (m(x) - y)^2
fn (generic function with 1 method)

julia> @time for epoch in 1:1000
         Flux.train!(fn, model, data, optim)
       end
  1.414480 seconds (13.66 M allocations: 800.513 MiB, 3.57% gc time, 5.66% compilation time)

julia> @time for epoch in 1:1000
         Flux.train!(fn, model, data, optim)
       end
  1.364847 seconds (13.41 M allocations: 785.187 MiB, 3.36% gc time)

julia> @time for epoch in 1:1000
         Flux.train!(fn, model, data, optim)
       end
  1.348747 seconds (13.41 M allocations: 785.187 MiB, 3.73% gc time)

julia> @time for epoch in 1:1000
         Flux.train_enzyme!(fn, model, data, optim)
       end
┌ Warning: Using fallback BLAS replacements for (["ssymv_64_"]), performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/kqxyC/src/utils.jl:59
  2.484299 seconds (15.29 M allocations: 708.970 MiB, 3.77% gc time, 79.16% compilation time)

julia> @time for epoch in 1:1000
         Flux.train_enzyme!(fn, model, data, optim)
       end
  0.521387 seconds (7.05 M allocations: 296.158 MiB, 5.02% gc time)

julia> @time for epoch in 1:1000
         Flux.train_enzyme!(fn, model, data, optim)
       end
  0.524848 seconds (7.05 M allocations: 296.158 MiB, 4.13% gc time)
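For orientation, here is a minimal sketch of what one pass of an Enzyme-backed train! like the train_enzyme! calls above might do. It is illustrative only: the helper name train_enzyme_sketch! is made up and the actual calls inside Flux may differ. It assumes opt_state comes from Flux.setup and that loss(model, x, y) returns a scalar.

using Enzyme, Flux

# Illustrative sketch only: one Enzyme-backed pass over `data`.
function train_enzyme_sketch!(loss, model, data, opt_state)
    for (x, y) in data
        dmodel = Enzyme.make_zero(model)            # shadow that receives the gradients
        Enzyme.autodiff(Reverse, loss, Active,
                        Duplicated(model, dmodel),  # gradients are accumulated into dmodel
                        Const(x), Const(y))
        Flux.update!(opt_state, model, dmodel)      # apply the optimiser update (mutates arrays in place)
    end
end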

@wsmoses force-pushed the master branch 3 times, most recently from 966be4a to 3ce9e41 on May 14, 2024 04:59
@wsmoses (Contributor, Author) commented May 14, 2024

bumping @CarloLucibello or @ToucheSir for review

@mcabbott (Member) left a comment:

This looks fine, and I guess the minimal version is to define this only in the tests.

Here's one idea for a public-facing API. We could have two methods like this:

train!(loss, model, data, opt) --> Zygote
train!(loss, model_and_shadow::Duplicated, data, opt) --> Enzyme

ideally with Enzyme.Duplicated(x) = Duplicated(x, Enzyme.make_zero(x)), so that you can call it as train!(loss, Duplicated(model), data, opt).

That suggests also defining methods of withgradient like this:

val, grads = Flux.withgradient(loss, model, data) --> Zygote
val, grads = Flux.withgradient(loss, Duplicated(model), data) --> Enzyme

That's a minimal change to select Enzyme. But unlike passing a token such as AutoEnzyme(), the Duplicated struct does other things too: you can construct it in advance if you wish, and it's fairly obvious that you cannot do this without using Enzyme.

We could go one further and define a method

update!(opt_state, model_and_grad::Duplicated)

That's a bigger departure from calling val, grads = Flux.withgradient(...), since you would discard what that returns and hold onto the Duplicated instead. But perhaps quite neat.
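To make this concrete, calls under the proposal could look as follows (a sketch against the suggested API, reusing fn, model, data and optim from the benchmark above; not a claim about the final implementation):

using Enzyme, Flux

# Backend selected purely by the wrapper type:
Flux.train!(fn, model, data, optim)                                       # Zygote path (existing)
Flux.train!(fn, Duplicated(model, Enzyme.make_zero(model)), data, optim)  # Enzyme path

# With the suggested one-argument constructor
#   Enzyme.Duplicated(x) = Duplicated(x, Enzyme.make_zero(x))
# the second call shortens to Flux.train!(fn, Duplicated(model), data, optim),
# and similarly val, grads = Flux.withgradient(fn, Duplicated(model), data[1]...) would pick Enzyme.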

Review threads on test/train.jl and src/train.jl (outdated, resolved).
@wsmoses (Contributor, Author) commented May 16, 2024

@mcabbott I made the suggested API change.

@mcabbott (Member) commented May 16, 2024

Thanks!

Thoughts on defining a one-argument Duplicated(x) = Duplicated(x, Enzyme.make_zero(x))? That seems like what you'd almost always want, but perhaps I'm overlooking something.

@wsmoses (Contributor, Author) commented May 16, 2024

I'm quite hesitant to do so, for a couple of reasons (including that Duplicated isn't necessarily all you want, and making the second argument explicit makes the user aware that something else is being updated in place). For example, a one-argument Duplicated, if passed directly into autodiff as follows:

autodiff(Reverse, sum, Duplicated(x));

will of course cause the shadow to be updated in place. But since the user didn't store the dval, they don't have the derivative available anywhere and will end up confused. They could construct the Duplicated before the autodiff call and store its dval somewhere, but I'd like to avoid that confusion by design if possible.
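A small sketch of that pitfall (a hypothetical snippet, not code from this PR):

using Enzyme

x = rand(3)

# An inline one-argument Duplicated (if it existed) would update a shadow nobody holds on to,
# so the gradient would be lost as soon as autodiff returns:
#   autodiff(Reverse, sum, Duplicated(x))

# The explicit two-argument form keeps a handle on the shadow:
dx = Enzyme.make_zero(x)
autodiff(Reverse, sum, Active, Duplicated(x, dx))
dx  # now holds d(sum(x))/dx, i.e. a vector of ones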

@wsmoses (Contributor, Author) commented May 17, 2024

@mcabbott the present API now hits the following ambiguity. Thoughts?

  MethodError: train!(::var"#2104#loss#143"{Matrix{Float64}}, ::Duplicated{@NamedTuple{weight::Matrix{Float64}, bias::Vector{Float64}, ignore::Nothing}}, ::Base.Generator{UnitRange{Int64}, var"#140#144"}, ::Descent) is ambiguous.
  
  Candidates:
    train!(loss, model, data, opt::Flux.Optimise.AbstractOptimiser; cb)
      @ Flux ~/work/Flux.jl/Flux.jl/src/deprecations.jl:110
    train!(loss, model_and_shadow::Duplicated, data, opt_state)
      @ Flux.Train ~/work/Flux.jl/Flux.jl/src/train.jl:124
  
  Possible fix, define
    train!(::Any, ::Duplicated, ::Any, ::Flux.Optimise.AbstractOptimiser)
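For reference, the ambiguity-breaking method the error message asks for would look roughly like this (a sketch only; how the PR actually resolves the ambiguity may differ, e.g. by forwarding or by raising a clearer error):

using Enzyme
import Flux

# Disambiguates the Duplicated method against the deprecated implicit-optimiser path.
function Flux.train!(loss, model_and_shadow::Duplicated, data,
                     opt::Flux.Optimise.AbstractOptimiser; cb = nothing)
    error("train! with a Duplicated model expects optimiser state from Flux.setup, " *
          "not a legacy Flux.Optimise optimiser.")
end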

@wsmoses closed this May 29, 2024
@wsmoses reopened this May 29, 2024
@wsmoses (Contributor, Author) commented May 31, 2024

@darsnack @ToucheSir @mcabbott bumping for review

@darsnack (Member) left a comment:

Looks good, minus figuring out the dispatch issue and moving the docstring. Sorry the dispatch is a mess; it comes from trying to keep support for the old optimizer interface.

We could consider merging this change together with dropping support for implicit optimizers completely.

Review threads on src/losses/utils.jl and src/train.jl (outdated, resolved).
@wsmoses (Contributor, Author) commented Jun 16, 2024

@darsnack I made the appropriate changes; mind giving it a final once-over and merging?

@wsmoses (Contributor, Author) commented Jun 21, 2024

gentle bump

@darsnack (Member) commented:

For some reason the "allow edits by maintainers" option is not letting me push to your fork. Can you manually add write permission to your fork for me?

@wsmoses (Contributor, Author) commented Jun 21, 2024

@darsnack Weird, but in any case, added!

@wsmoses (Contributor, Author) commented Jun 22, 2024

@darsnack At least one of these failures is due to removing the custom rule for params.

@darsnack (Member) commented:

Okay, it looks ready to me. Can you give it a once-over, and then I can merge it?

@wsmoses (Contributor, Author) commented Jun 23, 2024

@darsnack LGTM!

After this, we should open a tracking PR to follow the status of Flux + CUDA/AMDGPU + Enzyme training.

@darsnack merged commit 0be4401 into FluxML:master on Jun 24, 2024 (5 of 8 checks passed).