Compile training loop with Reactant #673
Conversation
Force-pushed from 0c35c52 to e52e708
Benchmark Results
| Benchmark suite | Current: ceaf4d3 | Previous: 2a55829 | Ratio |
|---|---|---|---|
| Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) | 3634.25 ns | 3661.875 ns | 0.99 |
| Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) | 7235.5 ns | 7419.8 ns | 0.98 |
| Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) | 20949 ns | 20999 ns | 1.00 |
| Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) | 9772.4 ns | 9738 ns | 1.00 |
| Dense(2 => 2)/cpu/reverse/Flux/(2, 128) | 8837 ns | 9050.8 ns | 0.98 |
| Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) | 4447.125 ns | 4475.875 ns | 0.99 |
| Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) | 4658.75 ns | 4693.875 ns | 0.99 |
| Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) | 1110.8733333333332 ns | 1112.656050955414 ns | 1.00 |
| Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) | 1188.320610687023 ns | 1169.8661971830986 ns | 1.02 |
| Dense(2 => 2)/cpu/forward/Flux/(2, 128) | 1789.6296296296296 ns | 1777.3272727272727 ns | 1.01 |
| Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) | 179.2549019607843 ns | 179.5702364394993 ns | 1.00 |
| Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) | 17392 ns | 17302 ns | 1.01 |
| Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) | 16862 ns | 17032 ns | 0.99 |
| Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) | 37139 ns | 37040 ns | 1.00 |
| Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) | 29205 ns | 29184 ns | 1.00 |
| Dense(20 => 20)/cpu/reverse/Flux/(20, 128) | 21390 ns | 21551 ns | 0.99 |
| Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) | 17222 ns | 17293 ns | 1.00 |
| Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) | 25437 ns | 25498 ns | 1.00 |
| Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) | 3844.625 ns | 3844.75 ns | 1.00 |
| Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) | 3952.375 ns | 3933.5 ns | 1.00 |
| Dense(20 => 20)/cpu/forward/Flux/(20, 128) | 4834.857142857143 ns | 4952.142857142857 ns | 0.98 |
| Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) | 1656 ns | 1654.1 ns | 1.00 |
| Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) | 40070662 ns | 50721668.5 ns | 0.79 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) | 58285994 ns | 58546553 ns | 1.00 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) | 82358150 ns | 101678849 ns | 0.81 |
| Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) | 81107120 ns | 101190618 ns | 0.80 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) | 78168071 ns | 78796997 ns | 0.99 |
| Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) | 11628216.5 ns | 12343077 ns | 0.94 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) | 91647252 ns | 92604699 ns | 0.99 |
| Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) | 7672283.5 ns | 7675430 ns | 1.00 |
| Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) | 7611777 ns | 7608555 ns | 1.00 |
| Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) | 11686366 ns | 12512726 ns | 0.93 |
| Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) | 6385580 ns | 6420844 ns | 0.99 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) | 693526419 ns | 698316993 ns | 0.99 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) | 2584030394 ns | 2576787594 ns | 1.00 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) | 144300903.5 ns | 141986062.5 ns | 1.02 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) | 791406151 ns | 893749885.5 ns | 0.89 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) | 3067094223 ns | 3362911269 ns | 0.91 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) | 221829171 ns | 207659968 ns | 1.07 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 16) | 787350317.5 ns | 857934198 ns | 0.92 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 64) | 2609668860 ns | 2864984030 ns | 0.91 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 2) | 127890003 ns | 149742910 ns | 0.85 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) | 176812287 ns | 177136415.5 ns | 1.00 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) | 662011151 ns | 662230879 ns | 1.00 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) | 34877100 ns | 36459877 ns | 0.96 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) | 167010500 ns | 167990874.5 ns | 0.99 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) | 652562919 ns | 655057499 ns | 1.00 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) | 30473235.5 ns | 35749851 ns | 0.85 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 16) | 213466842 ns | 228962990.5 ns | 0.93 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 64) | 762734758 ns | 847928462 ns | 0.90 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 2) | 40460078.5 ns | 37808714 ns | 1.07 |
| Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) | 1261696884.5 ns | 1328888452 ns | 0.95 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) | 1875037129.5 ns | 1883995892 ns | 1.00 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) | 2276188576 ns | 2418351457 ns | 0.94 |
| Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) | 2418092871 ns | 2440238931 ns | 0.99 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) | 1896459269.5 ns | 1948384572 ns | 0.97 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) | 2060237643 ns | 2084376262 ns | 0.99 |
| Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) | 339686792 ns | 336503625 ns | 1.01 |
| Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) | 331852824 ns | 336582406 ns | 0.99 |
| Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) | 405068534 ns | 456801176 ns | 0.89 |
| Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) | 11937972 ns | 11858627 ns | 1.01 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) | 18066755.5 ns | 18114833 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) | 19171977 ns | 19305583 ns | 0.99 |
| Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) | 23806904 ns | 24143369.5 ns | 0.99 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) | 17821618 ns | 17990331 ns | 0.99 |
| Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) | 1158131 ns | 1181481.5 ns | 0.98 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) | 23016860 ns | 23281651.5 ns | 0.99 |
| Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) | 2282789 ns | 2329001 ns | 0.98 |
| Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) | 2206837 ns | 2259670 ns | 0.98 |
| Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) | 2070232 ns | 2106879 ns | 0.98 |
| Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) | 199262 ns | 216514 ns | 0.92 |
| Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) | 291294 ns | 295862 ns | 0.98 |
| Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) | 265065 ns | 269132.5 ns | 0.98 |
| Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) | 367124 ns | 370853 ns | 0.99 |
| Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) | 409343.5 ns | 416127 ns | 0.98 |
| Dense(200 => 200)/cpu/reverse/Flux/(200, 128) | 273731 ns | 278089 ns | 0.98 |
| Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) | 405005 ns | 409124 ns | 0.99 |
| Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) | 393324 ns | 399266 ns | 0.99 |
| Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) | 80921 ns | 84137 ns | 0.96 |
| Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) | 82213.5 ns | 86591 ns | 0.95 |
| Dense(200 => 200)/cpu/forward/Flux/(200, 128) | 86782 ns | 87734 ns | 0.99 |
| Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) | 104335 ns | 104285 ns | 1.00 |
| Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) | 199086726 ns | 196679056 ns | 1.01 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) | 329556427 ns | 331090526 ns | 1.00 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) | 417942847.5 ns | 446330080.5 ns | 0.94 |
| Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) | 438080100.5 ns | 501425011 ns | 0.87 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) | 387166012 ns | 419437580.5 ns | 0.92 |
| Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) | 320184207 ns | 346091233 ns | 0.93 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) | 472105455.5 ns | 483112933.5 ns | 0.98 |
| Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) | 47204961 ns | 47755090 ns | 0.99 |
| Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) | 46703341 ns | 47275593 ns | 0.99 |
| Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) | 52498749.5 ns | 57880765 ns | 0.91 |
| Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) | 27950628.5 ns | 28154162.5 ns | 0.99 |
| Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) | 18942323.5 ns | 19633406 ns | 0.96 |
| Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) | 19627356 ns | 19863886 ns | 0.99 |
| Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) | 23670610 ns | 23978995 ns | 0.99 |
| Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) | 24290902.5 ns | 24483866 ns | 0.99 |
| Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) | 19726361.5 ns | 19853090.5 ns | 0.99 |
| Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) | 21131087 ns | 21211617 ns | 1.00 |
| Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) | 6616113.5 ns | 6587852 ns | 1.00 |
| Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) | 6562468 ns | 6555306 ns | 1.00 |
| Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) | 6579424 ns | 6543112 ns | 1.01 |
This comment was automatically generated by a workflow using github-action-benchmark.
Force-pushed from 7d90e08 to db8fc3d
Force-pushed from 2e2b6b6 to be504a3
Force-pushed from d661097 to 25325ea
Codecov Report
Attention: Patch coverage is …

Additional details and impacted files:

@@            Coverage Diff             @@
##             main     #673      +/-   ##
==========================================
- Coverage   96.13%   93.61%    -2.52%
==========================================
  Files          54       58        +4
  Lines        2818     2868       +50
==========================================
- Hits         2709     2685       -24
- Misses        109      183       +74

☔ View full report in Codecov by Sentry.
Force-pushed from be504a3 to 96f6cb1
Force-pushed from 1311b31 to 88a3742
Force-pushed from 88a3742 to 982c67d
@avik-pal that should now be fixed.
Can confirm this works; now I need to finish the wrapper. Once it is merged, it should be easy to get a list of all the models that don't work.
Ooof, Optimisers is going to be a bit of a pain; it seems to do a lot of operations (not sure why):
Force-pushed from 982c67d to 9ecca42
I mean, we can quickly try to add overloads if you have relevant backtraces?
Here is one which was simple, using ChainRulesCore:

Base.:+(a::Reactant.TracedRArray, ::AbstractZero) = a

The other one:

julia> Lux.Experimental.single_train_step(AutoReactant(), loss_fn, data, ts)
ERROR: MethodError: no method matching elem_apply(::Type{Float32}, ::Reactant.TracedRArray{Float32, (5, 10), 2})
Closest candidates are:
elem_apply(::typeof(*), ::Reactant.TracedRArray{ElType, Shape, N}, ::Reactant.TracedRArray{ElType, Shape, N}) where {ElType, Shape, N}
@ Reactant ~/.julia/packages/Reactant/LF3m2/src/overloads.jl:470
elem_apply(::typeof(*), ::Any, ::Reactant.TracedRArray{ElType, Shape, N}) where {ElType, Shape, N}
@ Reactant ~/.julia/packages/Reactant/LF3m2/src/overloads.jl:495
elem_apply(::typeof(*), ::Reactant.TracedRArray{ElType, Shape, N}, ::Any) where {ElType, Shape, N}
@ Reactant ~/.julia/packages/Reactant/LF3m2/src/overloads.jl:483
...
Stacktrace:
[1] _copyto!(dest::Reactant.TracedRArray{Float32, (5, 10), 2}, bc::Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Type{Float32}, Tuple{Base.Broadcast.Broadcasted{Reactant.AbstractReactantArrayStyle{…}, Nothing, typeof(-), Tuple{…}}}})
@ Reactant ~/.julia/packages/Reactant/LF3m2/src/overloads.jl:832
[2] copyto!
@ ~/.julia/packages/Reactant/LF3m2/src/overloads.jl:750 [inlined]
[3] copyto!
@ ./broadcast.jl:956 [inlined]
[4] copy
@ ~/.julia/packages/Reactant/LF3m2/src/overloads.jl:740 [inlined]
[5] overdub(context::Cassette.Context{Reactant.var"##TraceCtx#Name", Nothing, Nothing, Cassette.var"##PassType#235", Nothing, Nothing}, f::typeof(copy), args::Base.Broadcast.Broadcasted{Reactant.AbstractReactantArrayStyle{…}, Tuple{…}, Type{…}, Tuple{…}})
@ Reactant ~/.julia/packages/Reactant/LF3m2/src/overloads.jl:627
[6] materialize(::Base.Broadcast.Broadcasted{Reactant.AbstractReactantArrayStyle{2}, Nothing, Type{Float32}, Tuple{Base.Broadcast.Broadcasted{Reactant.AbstractReactantArrayStyle{2}, Nothing, typeof(-), Tuple{Reactant.TracedRArray{Float32, (5, 10), 2}, Base.Broadcast.Broadcasted{Reactant.AbstractReactantArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(*), Tuple{Reactant.TracedRArray{Float32, (5, 10), 2}, Float32}}}}}})
@ ./broadcast.jl:903 [inlined]
[7] materialize
@ ./broadcast.jl:903 [inlined]
[8] subtract!(::Reactant.TracedRArray{Float32, (5, 10), 2}, ::Base.Broadcast.Broadcasted{Reactant.AbstractReactantArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(*), Tuple{Reactant.TracedRArray{Float32, (5, 10), 2}, Float32}})
@ ~/.julia/packages/Optimisers/yDIWk/src/interface.jl:103 [inlined]
[9] subtract!
@ ~/.julia/packages/Optimisers/yDIWk/src/interface.jl:103 [inlined]

That should also just be a no-op, probably.
Yeah -- or, more specifically, we have it call convert (and that will get compiled to a no-op).
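In that spirit, a hypothetical sketch of such an overload (elem_apply and TracedRArray are Reactant internals, so the real signature and the fix that actually landed upstream may differ): broadcasting a type constructor like Float32.(x) over an already-Float32 traced array is just a conversion, which is trivially the identity.

```julia
# Hypothetical overload, NOT the actual upstream fix: when the element type
# already matches, the conversion is the identity, so the broadcast is a no-op.
function Reactant.elem_apply(
        ::Type{T}, x::Reactant.TracedRArray{T,Shape,N}) where {T,Shape,N}
    return x
end
```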
ext/LuxReactantExt/train.jl (outdated)

end

function __update_fn_wrapper(obj_fn, model, ps, dps, st, st_opt, data)
    _, (loss, st_, stats) = Enzyme.autodiff(
Minor, but this is going to be more efficient if dps is defined within the function and doesn't escape its scope. Similarly, it would be more efficient to have this do ps .= ps_ rather than return a second allocation.
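A minimal sketch of both suggestions, assuming a scalar-loss objective taking (ps, data). The Optimisers.update call is a stand-in for the PR's update step, and the actual wrapper also threads the model and Lux states through:

```julia
using Enzyme, Optimisers

function __update_fn_wrapper(obj_fn, ps, st_opt, data)
    # Shadow allocated *inside* the traced function: it never crosses the
    # compiled-function boundary, so the compiler is free to optimize it out.
    dps = Enzyme.make_zero(ps)
    # `ReverseWithPrimal` returns (derivatives, primal); gradients land in dps.
    _, loss = Enzyme.autodiff(
        Enzyme.ReverseWithPrimal, obj_fn, Enzyme.Active,
        Enzyme.Duplicated(ps, dps), Enzyme.Const(data))
    st_opt, ps_ = Optimisers.update(st_opt, ps, dps)
    ps .= ps_  # update in place instead of returning a second allocation
    return loss, st_opt
end
```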
In other words, all temporaries/allocations that cross the compiled-function boundary need to be materialized, but all variables defined only within it can be optimized out.
Do you know why I am able to return loss, st_, stats here without Enzyme complaining that I specified the return type as Active?
> In other words, all temporaries/allocations that cross the compiled-function boundary need to be materialized, but all variables defined only within it can be optimized out.

Makes sense, I will do that.
@avik-pal here's your fix for the above: EnzymeAD/Reactant.jl#25
@avik-pal we just released a version with better broadcasting, so I'm now curious whether the earlier workarounds you added for that can be removed?
In the example, I just used a hacky loss function that avoided the operations Reactant did not support. Let me try with a proper loss now.
Also, I'd actually probably try benchmarking on CUDA.
Force-pushed from 48bb523 to 434a744
Force-pushed from bd24029 to ceaf4d3
@avik-pal with things released (and CUDA functional), is anything blocking here?
No, this is mostly functional; I will have to rebase once the other PR lands.
Force-pushed from ceaf4d3 to c47a9d0
Superseded by #969.
Pulling out training logic from #665, because this is simpler to implement and merge.
Example Usage
We can simply replace AutoEnzyme (or any of those backends) with AutoReactant and use single_train_step or single_train_step!, and every…

TODO: Code example (a hedged sketch follows below).
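A sketch of the intended usage, based on the call shown earlier in this thread. The TrainState constructor and the exact return values are assumptions that may differ from the final API, and loading Reactant is assumed to enable the extension that provides AutoReactant:

```julia
using Lux, Optimisers, Random, Reactant

model = Chain(Dense(10 => 5, tanh), Dense(5 => 2))
rng = Random.default_rng()
ts = Lux.Experimental.TrainState(rng, model, Adam(0.01f0))

# Objective follows Lux's training-API convention:
# (model, ps, st, data) -> (loss, updated_state, stats)
function loss_fn(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    return sum(abs2, ŷ .- y), st_, (;)
end

x = randn(rng, Float32, 10, 32)
y = randn(rng, Float32, 2, 32)

# Drop-in swap: AutoEnzyme() -> AutoReactant(); the first call traces the
# whole train step and compiles it with Reactant.
_, loss, _, ts = Lux.Experimental.single_train_step!(
    AutoReactant(), loss_fn, (x, y), ts)
```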
Upstream Needs
- Meta.quot (EnzymeAD/Reactant.jl#34)

TODOs
- ReactantBackend
- single_train_step!