Auto compile Lux models to reactant #665
Benchmark Results
Benchmark suite | Current: 8ce9707 | Previous: 60c595e | Ratio |
---|---|---|---|
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) | 3669.375 ns | 3646.75 ns | 1.01 |
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) | 7116.5 ns | 7285 ns | 0.98 |
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) | 20819 ns | 21210 ns | 0.98 |
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) | 9519.8 ns | 9781.666666666666 ns | 0.97 |
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) | 8884.25 ns | 9087.2 ns | 0.98 |
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) | 4494.625 ns | 4453.888888888889 ns | 1.01 |
Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) | 1168.6785714285713 ns | 1176.2706766917292 ns | 0.99 |
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) | 1112.012658227848 ns | 1112.28025477707 ns | 1.00 |
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) | 1181.0820895522388 ns | 1189.374074074074 ns | 0.99 |
Dense(2 => 2)/cpu/forward/Flux/(2, 128) | 1780.4406779661017 ns | 1814.3181818181818 ns | 0.98 |
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) | 179.55460992907803 ns | 179.93324061196105 ns | 1.00 |
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) | 17252 ns | 17212 ns | 1.00 |
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) | 17292 ns | 17463 ns | 0.99 |
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) | 36588 ns | 36689 ns | 1.00 |
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) | 27992 ns | 28147.5 ns | 0.99 |
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) | 19677 ns | 20058 ns | 0.98 |
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) | 16982 ns | 16921 ns | 1.00 |
Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) | 4298 ns | 4310.5 ns | 1.00 |
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) | 3867.25 ns | 3867.25 ns | 1 |
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) | 3952.375 ns | 3951.125 ns | 1.00 |
Dense(20 => 20)/cpu/forward/Flux/(20, 128) | 4775.357142857143 ns | 4787.571428571428 ns | 1.00 |
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) | 1664.2 ns | 1659.1 ns | 1.00 |
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) | 39058487.5 ns | 38839150 ns | 1.01 |
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) | 58048876.5 ns | 57478179 ns | 1.01 |
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) | 69429268 ns | 68637336 ns | 1.01 |
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) | 88231886 ns | 80248739.5 ns | 1.10 |
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) | 72608460 ns | 66510498 ns | 1.09 |
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) | 12052157 ns | 11601127 ns | 1.04 |
Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) | 8361274 ns | 8302158.5 ns | 1.01 |
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) | 7009826 ns | 6958814.5 ns | 1.01 |
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) | 6996013 ns | 6935871 ns | 1.01 |
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) | 9927628.5 ns | 9886349 ns | 1.00 |
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) | 6387327.5 ns | 6377484 ns | 1.00 |
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) | 714625608 ns | 711495815.5 ns | 1.00 |
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) | 2809656510 ns | 2802293498 ns | 1.00 |
vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) | 143359748 ns | 158450926 ns | 0.90 |
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) | 742274565 ns | 745197995 ns | 1.00 |
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) | 2535181882 ns | 2536517155 ns | 1.00 |
vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) | 194702054 ns | 186814591 ns | 1.04 |
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) | 646623448.5 ns | 698620045 ns | 0.93 |
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) | 2730166614 ns | 2703329300 ns | 1.01 |
vgg16/cpu/reverse/Flux/(32, 32, 3, 2) | 119575299 ns | 122294200.5 ns | 0.98 |
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) | 194591286 ns | 172044480 ns | 1.13 |
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) | 648527666 ns | 643441503 ns | 1.01 |
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) | 45366635 ns | 45114156 ns | 1.01 |
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) | 164210512 ns | 163454975.5 ns | 1.00 |
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) | 636442834 ns | 628139701 ns | 1.01 |
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) | 29909710.5 ns | 29335904 ns | 1.02 |
vgg16/cpu/forward/Flux/(32, 32, 3, 16) | 186029419 ns | 207955667.5 ns | 0.89 |
vgg16/cpu/forward/Flux/(32, 32, 3, 64) | 711562847 ns | 722173872 ns | 0.99 |
vgg16/cpu/forward/Flux/(32, 32, 3, 2) | 35379449 ns | 37423155 ns | 0.95 |
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) | 1232480179.5 ns | 1242027523.5 ns | 0.99 |
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) | 1864117910.5 ns | 1847309072 ns | 1.01 |
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) | 2009631582.5 ns | 1988297584 ns | 1.01 |
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) | 2333653016.5 ns | 2337208631 ns | 1.00 |
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) | 1783708622.5 ns | 1825164998 ns | 0.98 |
Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) | 344253318 ns | 347875405.5 ns | 0.99 |
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) | 316840031.5 ns | 318366365 ns | 1.00 |
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) | 318335263.5 ns | 319738018 ns | 1.00 |
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) | 344631084 ns | 452781616 ns | 0.76 |
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) | 11976289 ns | 11803413 ns | 1.01 |
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) | 18037586 ns | 17882962 ns | 1.01 |
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) | 19230996 ns | 19018033 ns | 1.01 |
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) | 23857907 ns | 23755630 ns | 1.00 |
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) | 18084016 ns | 17832966.5 ns | 1.01 |
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) | 1155306.5 ns | 1148767 ns | 1.01 |
Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) | 2520956 ns | 2512938 ns | 1.00 |
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) | 2047653 ns | 2035570 ns | 1.01 |
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) | 2054475 ns | 2023578.5 ns | 1.02 |
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) | 2070591 ns | 2055760 ns | 1.01 |
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) | 201787 ns | 195727.5 ns | 1.03 |
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) | 289931.5 ns | 288322 ns | 1.01 |
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) | 264083 ns | 262603 ns | 1.01 |
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) | 365111 ns | 354936.5 ns | 1.03 |
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) | 406990 ns | 400938 ns | 1.02 |
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) | 272469 ns | 270257 ns | 1.01 |
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) | 405988 ns | 397421 ns | 1.02 |
Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) | 83225.5 ns | 83306 ns | 1.00 |
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) | 80735.5 ns | 80271 ns | 1.01 |
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) | 81221 ns | 80581 ns | 1.01 |
Dense(200 => 200)/cpu/forward/Flux/(200, 128) | 86432 ns | 85480 ns | 1.01 |
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) | 104584 ns | 104617 ns | 1.00 |
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) | 193678547 ns | 187932820.5 ns | 1.03 |
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) | 322350064 ns | 321827872.5 ns | 1.00 |
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) | 390601991 ns | 393773632.5 ns | 0.99 |
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) | 455632147 ns | 454117809 ns | 1.00 |
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) | 342018490 ns | 366877761 ns | 0.93 |
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) | 324867768.5 ns | 309426428 ns | 1.05 |
Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) | 51185900 ns | 51303991 ns | 1.00 |
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) | 43468243 ns | 43675671.5 ns | 1.00 |
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) | 43323005 ns | 43447693 ns | 1.00 |
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) | 59553933.5 ns | 49289683 ns | 1.21 |
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) | 27964827 ns | 28489085 ns | 0.98 |
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) | 18639172.5 ns | 18511523 ns | 1.01 |
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) | 19409965 ns | 19373919.5 ns | 1.00 |
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) | 23095282.5 ns | 22860858 ns | 1.01 |
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) | 24104485.5 ns | 23821494.5 ns | 1.01 |
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) | 19543279.5 ns | 19452776.5 ns | 1.00 |
Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) | 6497967 ns | 6471809.5 ns | 1.00 |
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) | 6517959 ns | 6467840.5 ns | 1.01 |
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) | 6472398.5 ns | 6458192 ns | 1.00 |
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) | 6506948 ns | 6475071.5 ns | 1.00 |
This comment was automatically generated by workflow using github-action-benchmark.
Codecov Report
Additional details and impacted files
@@ Coverage Diff @@
## main #665 +/- ##
==========================================
- Coverage 87.11% 80.26% -6.85%
==========================================
Files 50 55 +5
Lines 2515 2671 +156
==========================================
- Hits 2191 2144 -47
- Misses 324 527 +203
@avik-pal can this be done in a way that Reactant compiles the whole update, not just the gradient as a step separate from the inference pass? I expect a substantial performance improvement from doing so, including the model update actually occurring fully in place. E.g. the function Reactant compiles being something like
Not with the layers API. Currently, if we can accelerate just the neural network part, I would consider it a good win. Also, having it like this makes it possible to use regular Julia ops for cases where we can't compile to Reactant; for example, the ODE solve happens in Julia and the neural network runs via XLA. We can add
Okay, things are mostly working now; we just need a copyto! for TracedRArray.
using Reactant, Lux, Random, ComponentArrays
model = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
reactant_model = ToReactantAdaptor{true}(rand(Float32, 10, 3))(model)
Gives me a
2c67548
to
fb7ea0a
Compare
@avik-pal you should open an issue on Reactant with the pipeline error once the prereqs are merged.
ext/LuxReactantExt.jl
Outdated
@inline __try_similar_structure(x, y) = fmap(__try_similar_structure, x, y)

# Reactant doesn't handle mixed eltypes that well, so we will first try to compile it as
# a usual julia function. However, if that fails, we will type cast and try to recompile.
we should be able to fix mixed eltypes if you have an mwe by chance
using Reactant, Lux, Random
model = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
reactant_model = ToReactantAdaptor{true}(rand(10, 3))(model)
ext/LuxReactantExt.jl
Outdated
ps = __try_similar_structure(Lux.__named_tuple(ps), l.concrete_ps)
ps = l.adaptor(ps)
l.eltype_adaptor !== nothing && (ps = adapt(l.eltype_adaptor, ps))
ps = __make_concrete_array(ps)
can this be set up where the ps are already reactant arrays so we don't need to call __make_concrete_array
[hope would be avoiding data shuffling, especially cpu<->gpu of the whole model]
This should not be hit in most cases. I wasn't able to compile a ComponentArrays based version yet (needs a make_tracer overload), so it is a temporary solution.
The correct use of this should hit L163
@wsmoses seems like an incorrect generation? Module:
module attributes {transform.with_named_sequence} {
func.func @main(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>, %arg2: tensor<3x10xf32>, %arg3: tensor<10x5xf32>, %arg4: tensor<10x5xf32>) -> (tensor<1x5xf32>, tensor<f32>, tensor<10x5xf32>) {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<5x3xf32>
%cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%cst_1 = arith.constant dense<0.000000e+00> : tensor<5xf32>
%0 = stablehlo.reshape %arg1 : (tensor<1x5xf32>) -> tensor<5xf32>
%1 = stablehlo.dot_general %arg4, %arg2, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<10x5xf32>, tensor<3x10xf32>) -> tensor<5x3xf32>
%2 = stablehlo.broadcast_in_dim %0, dims = [0] : (tensor<5xf32>) -> tensor<5x3xf32>
%3 = stablehlo.add %1, %2 : tensor<5x3xf32>
%4 = stablehlo.tanh %3 : tensor<5x3xf32>
%5 = stablehlo.reduce(%4 init: %cst_0) applies stablehlo.add across dimensions = [0, 1] : (tensor<5x3xf32>, tensor<f32>) -> tensor<f32>
%6 = stablehlo.multiply %4, %4 : tensor<5x3xf32>
%7 = stablehlo.subtract %cst, %6 : tensor<5x3xf32>
%8 = stablehlo.reduce(%7 init: %cst_1) across dimensions = [1] : (tensor<5x3xf32>, tensor<5xf32>) -> tensor<5xf32>
reducer(%arg5: tensor<5xf32>, %arg6: tensor<5xf32>) {
%13 = stablehlo.add %arg5, %arg6 : tensor<5xf32>
stablehlo.return %13 : tensor<5xf32>
}
%9 = stablehlo.dot_general %arg2, %7, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<3x10xf32>, tensor<5x3xf32>) -> tensor<10x5xf32>
%10 = stablehlo.add %arg3, %9 : tensor<10x5xf32>
%11 = stablehlo.reshape %8 : (tensor<5xf32>) -> tensor<1x5xf32>
%12 = stablehlo.add %arg0, %11 : tensor<1x5xf32>
return %12, %5, %10 : tensor<1x5xf32>, tensor<f32>, tensor<10x5xf32>
}
}
terminate called after throwing an instance of 'xla::XlaRuntimeError'
what(): UNKNOWN: <unknown>:0: error: Reduction function must return a scalar or tuple of scalars but returns shape: f32[5]:
<unknown>:0: note: see current operation: "func.return"(%15, %8, %13) : (tensor<1x5xf32>, tensor<f32>, tensor<10x5xf32>) -> ()
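For reference, the quantities in the module above can be written out in plain Julia. This is only a sketch under assumptions: the names W, b, and x are illustrative and do not come from the PR, and it assumes the MLIR shapes are the Julia shapes with dimensions reversed (so tensor<10x5xf32> corresponds to a 5×10 weight). The bias gradient (%8) is a plain sum over the batch dimension, which is the step whose reducer is declared on tensor<5xf32> values in the dump and which XLA rejects with "Reduction function must return a scalar".

```julia
# Plain-Julia reference for the values computed in the MLIR module.
# W, b, x are hypothetical stand-ins for %arg4, %arg1, %arg2.
W = rand(Float32, 5, 10)   # weight, shown as tensor<10x5xf32> in MLIR
b = rand(Float32, 5)       # bias, shown as tensor<1x5xf32> before the reshape
x = rand(Float32, 10, 3)   # input batch, shown as tensor<3x10xf32>

y = tanh.(W * x .+ b)      # %4: forward pass of Dense(10 => 5, tanh)
loss = sum(y)              # %5: full reduction to a scalar
dy = 1 .- y .^ 2           # %7: derivative of tanh, with a cotangent of ones
db = vec(sum(dy; dims=2))  # %8: sum over the batch dimension -> length 5
dW = dy * x'               # %9: weight gradient, 5×10
```

Each element of db is produced by adding scalars along one row of dy, so the reduction body only ever needs to combine two scalars.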
@avik-pal the Lux fixes (and named tuple) just landed and were released. I'll give the reduction error a go shortly, but at minimum we can see what works (and perhaps mark that as expected broken to start with).
Currently the Julia session crashes because of the broken reverse pass, so we can't mark it broken.
Can we have a no-copy transfer between Julia AbstractArrays and Reactant/XLA arrays? This makes life simpler to support wrapper types like ComponentArrays. Also, we can keep the parameters as regular Julia arrays, which works more nicely with the current optimisers and such.
Not easily, as we need to own the data.
Similarly, the model and ideally the inputs are always kept as RArrays here.
Also, for better performance the optimizers themselves are compiled by Reactant.
Right, but I don't think we would be able to compile NeuralODE style models yet, right? So having an eager version that can perform operations directly on RArrays seems like a good tradeoff to run part of the model in regular Julia. I might pull out the AutoReactant code (compiling the training iteration) into a separate PR because that would be easier to merge.
There’s no reason why we couldn’t in theory, but I don’t think we do right now.
Worth testing and opening an issue so we know what to work on though.
25325ea
to
e931a5e
Compare
@avik-pal the fix has landed, can we retry this?
This one is too broadly scoped, so I will hold it off. First, I want to finish #673, which compiles the entire training loop and doesn't need to worry about users doing unwanted things to the parameters.
Reworking the partial model compilation logic in #999.
Example Usage
This follows the same structure as SimpleChains: the user requests a conversion and provides an input prototype.
Upstream Needs
create_result for NamedTuple
EnzymeAD/Reactant.jl#8 for extending the support to LuxCore.apply instead of LuxCore.apply
TODOs
__make_reactant_array