[WIP] Add a Lux example #5
Conversation
test/nn_lux.jl (outdated diff):
```julia
losses = []
for epoch in 1:1_000
    for (x, y) in loader
        loss, grads = Flux.withgradient(model) do m
```
While we're at it, let's do training too. We should make a new function, which will get compiled, containing both the autodiff and the `update!` [ideally with Enzyme now]. Want to try?
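As a rough sketch of that shape (the `train_step!`/`compute_loss` names, the Enzyme call pattern, and the use of `Optimisers.update!` are assumptions here, not necessarily what the PR ends up with):
```julia
using Enzyme, Optimisers  # assumed dependencies for this sketch

# Hypothetical loss helper; the real example defines its own loss later in the thread.
function compute_loss(model, x, y, ps, st)
    ŷ, _ = model(x, ps, st)
    return sum(abs2, ŷ .- y)
end

# One function containing both the autodiff and the optimiser update, so the whole
# training step can be handed to Reactant.compile as a single unit.
function train_step!(model, x, y, ps, st, opt_state)
    dps = Enzyme.make_zero(ps)
    _, loss = Enzyme.autodiff(
        Enzyme.ReverseWithPrimal, compute_loss, Active,
        Const(model), Const(x), Const(y), Duplicated(ps, dps), Const(st))
    opt_state, ps = Optimisers.update!(opt_state, ps, dps)
    return loss, ps, opt_state
end

# compiled_step = Reactant.compile(train_step!, (cmodel, cx, cy, cps, cst, copt_state))
```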
am adding similar in bg
I was going to rewrite this part with LuxDL/Lux.jl#640
You want to pull in Lux.Experimental.apply_gradients! and compile it together?
I just realized this one is updated using the original model and not the compiled model. Also, you only need to trace through if a variable contains data, like an array.
Let’s do it.
Also btw this will test both cpu and gpu
Couple of problems:
```julia
julia> comp = f(cmodel, cnoisy, cps, cst)
ERROR: UndefVarError: `layer_2` not defined
Stacktrace:
[1] (::Reactant.var"#109#110")(arg_1::Chain{…}, arg_2::Reactant.ConcreteRArray{…}, arg_3::@NamedTuple{…}, arg_4::@NamedTuple{…})
@ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/Reactant.jl:704
[2] top-level scope
@ /mnt/research/ongoing/lux/Reactant.jl/test/nn_lux.jl:29
Some type information was truncated. Use `show(err)` to see complete types.
```
So it is possible to add getindex/etc. here, but I'm intentionally preventing it for now so we can ensure we trace fully vectorized code. And yeah, we should just add an overload for mean, want to give it a shot?
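For reference, a minimal sketch of such an overload, assuming `sum` and `length` already work on `TracedRArray` so `mean` just reuses that reduction (not the actual implementation):
```julia
import Statistics

# mean as sum / length, so it lowers through whatever reduction `sum` already traces to
Statistics.mean(A::Reactant.TracedRArray) = sum(A) / length(A)
Statistics.mean(f::Function, A::Reactant.TracedRArray) = sum(f, A) / length(A)
```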
This doesn't seem right:
```julia
julia> @code_lowered f(cmodel, cnoisy, cps, cst)
CodeInfo(
1 ─ %1 = Reactant.XLA.synced_buffer
│ %2 = Base.getfield
│ %3 = Base.getfield
│ %4 = (%3)(arg_3, Reactant.layer_2)
│ %5 = (%2)(%4, Reactant.weight)
│ %6 = Base.getproperty(%5, :data)
│ sbuf_1 = (%1)(%6)
│ %8 = Reactant.XLA.synced_buffer
│ %9 = Base.getfield
│ %10 = Base.getfield
│ %11 = (%10)(arg_3, Reactant.layer_1)
│ %12 = (%9)(%11, Reactant.weight)
│ %13 = Base.getproperty(%12, :data)
│ sbuf_2 = (%8)(%13)
│ %15 = Reactant.XLA.synced_buffer
│ %16 = Base.getfield
│ %17 = Base.getfield
│ %18 = (%17)(arg_3, Reactant.layer_1)
│ %19 = (%16)(%18, Reactant.bias)
│ %20 = Base.getproperty(%19, :data)
│ sbuf_3 = (%15)(%20)
│ %22 = Reactant.XLA.synced_buffer
│ %23 = Base.getfield
│ %24 = Base.getfield
│ %25 = (%24)(arg_3, Reactant.layer_2)
│ %26 = (%23)(%25, Reactant.bias)
│ %27 = Base.getproperty(%26, :data)
│ sbuf_4 = (%22)(%27)
│ %29 = Reactant.XLA.synced_buffer
│ %30 = Base.getproperty(arg_2, :data)
│ sbuf_5 = (%29)(%30)
│ %32 = $(Expr(:gc_preserve_begin, :(sbuf_1), :(sbuf_2), :(sbuf_3), :(sbuf_4), :(sbuf_5)))
│ %33 = Reactant.XLA.ExecutableCall
│ %34 = Base.getproperty(sbuf_1, :buffer)
│ %35 = Base.getproperty(sbuf_2, :buffer)
│ %36 = Base.getproperty(sbuf_3, :buffer)
│ %37 = Base.getproperty(sbuf_4, :buffer)
│ %38 = Base.getproperty(sbuf_5, :buffer)
│ %39 = Core.tuple(%34, %35, %36, %37, %38)
│ %40 = Reactant.Val(1)
│ %41 = (%33)(Reactant.XLA.LoadedExecutable(Ptr{Nothing} @0x000000001114fd40), %39, (0x01, 0x01, 0x01, 0x01, 0x01), %40)
│ linearized_results = %41
│ $(Expr(:gc_preserve_end, :(%32)))
│ concrete_res_1 = Base.getindex(linearized_results, 1)
│ result = (Reactant.ConcreteRArray{Float32, (2, 1000), 2})(concrete_res_1)
└── return result
)
```
It shouldn't be looking up `Reactant.layer_2` as a global there.
Makes sense. I added ArrayInterface to allow easy checking for that in downstream code.
Ah, we should probably add an escape in the macro.
Wait, that's odd though; it should be a symbol being looked up there?
It was directly getting interpolated, needed a
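For illustration, the usual interpolation pitfall this is pointing at (a generic sketch, not the actual macro code):
```julia
field = :layer_2
# interpolating the Symbol splices it in as an identifier, which later resolves
# as a global (hence the `Reactant.layer_2` lookups in the lowered code above):
:(Base.getfield(arg_3, $field))               # => :(Base.getfield(arg_3, layer_2))
# wrapping it in a QuoteNode keeps it as a Symbol literal:
:(Base.getfield(arg_3, $(QuoteNode(field))))  # => :(Base.getfield(arg_3, :layer_2))
```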
```julia
function xlogy(x, y)
result = x * log(y)
return ifelse(iszero(x), zero(result), result)
end
function crossentropy(ŷ, y)
return .-sum(xlogy.(y, ŷ))
end
function loss_function(model, x, y, ps, st)
y_hat, _ = model(x, ps, st)
return crossentropy(y_hat, y)
end
compiled_loss_function = Reactant.compile(
    loss_function, (cmodel, cnoisy, ctarget, cps, cst))
```
The eltypes don't match:
```
MethodError: no method matching elem_apply(::typeof(xlogy), ::Reactant.TracedRArray{Bool, (2, 1000), 2}, ::Reactant.TracedRArray{Float32, (2, 1000), 2})
```
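One possible workaround (just a guess, not necessarily the fix taken here) would be to make the one-hot target Float32 on the host before wrapping it, so both broadcast arguments trace with the same eltype:
```julia
# Convert the Bool one-hot matrix to Float32 on the host before handing it to Reactant
# (assumes the same ConcreteRArray(::Array) construction used for the other inputs).
ctarget = Reactant.ConcreteRArray(Float32.(Array(target)))
```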
My failed attempt at defining the comparisons:
```julia
for (jlop, hloop, hlocomp, RT) in ((:(Base.:(==)), :compare, 0, :ElType),)
@eval begin
function elem_apply(::typeof($jlop), lhs::TracedRArray{ElType,Shape,N}, rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N}
return TracedRArray{$RT,Shape,N}((), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data;
comparison_direction=$hlocomp), 1))
end
function elem_apply(::typeof($jlop), lhs::TracedRArray{ElType,Shape,N}, rhs) where {ElType,Shape,N}
rhs = promote_to(lhs, rhs)
return TracedRArray{$RT,Shape,N}((), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data; comparison_direction=$hlocomp), 1))
end
function elem_apply(::typeof($jlop), lhs, rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N}
lhs = promote_to(rhs, lhs)
return TracedRArray{$RT,Shape,N}((), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data; comparison_direction=$hlocomp), 1))
end
end
end
```
How do I pass the comparison direction enum? https://openxla.org/stablehlo/spec#compare
Reactant.jl/src/mlir/libMLIR_h.jl, line 8252 (at afc44a0).
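If that's the right entry point, passing the direction presumably looks something like the following (wrapper name taken from the file above; the context/string-ref plumbing is assumed, and `ctx`, `lhs`, `rhs` are as in the elem_apply methods above):
```julia
# Build the StableHLO ComparisonDirection attribute from its string form
# ("EQ", "NE", "GE", "GT", "LE", "LT") and pass it to stablehlo.compare.
direction = Reactant.MLIR.API.stablehloComparisonDirectionAttrGet(ctx, "EQ")
op = MLIR.Dialects.stablehlo.compare(
    lhs.mlir_data, rhs.mlir_data; comparison_direction=direction)
```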
Seems like compiling the gradient is hitting:
```julia
julia> compiled_gradient = Reactant.compile(
           gradient_loss_function, (cmodel, cnoisy, ctarget, cps, cst))
ERROR: MethodError: no method matching (Reactant.TracedRArray{Float32, Shape, 2} where Shape)(::Tuple{}, ::Reactant.MLIR.IR.Value)
Stacktrace:
[1] make_zero(::Type{…}, seen::IdDict{…}, prev::Reactant.TracedRArray{…}, ::Val{…})
@ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/Reactant.jl:149
[2] #42
@ ~/.julia/packages/Enzyme/NVk8T/src/compiler.jl:1251 [inlined]
[3] ntuple
@ ./ntuple.jl:19 [inlined]
[4] make_zero
@ ~/.julia/packages/Enzyme/NVk8T/src/compiler.jl:1249 [inlined]
[5] make_zero(::Type{…}, seen::IdDict{…}, prev::@NamedTuple{…}, ::Val{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/NVk8T/src/compiler.jl:1256
[6] (::Enzyme.Compiler.var"#42#43"{Tuple{…}, false, IdDict{…}, Tuple{…}})(i::Int64)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/NVk8T/src/compiler.jl:1251
[7] ntuple(f::Enzyme.Compiler.var"#42#43"{Tuple{…}, false, IdDict{…}, Tuple{…}}, n::Int64)
@ Base ./ntuple.jl:19
[8] make_zero(::Type{…}, seen::IdDict{…}, prev::Tuple{…}, ::Val{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/NVk8T/src/compiler.jl:1249
[9] make_zero(::Type{…}, seen::IdDict{…}, prev::@NamedTuple{…}, ::Val{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/NVk8T/src/compiler.jl:1256
[10] make_zero (repeats 2 times)
@ ~/.julia/packages/EnzymeCore/Z0CgU/src/EnzymeCore.jl:237 [inlined]
[11] overdub
@ /mnt/research/ongoing/lux/Reactant.jl/src/overloads.jl:358 [inlined]
[12] gradient_loss_function(::Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, ::Reactant.TracedRArray{Float32, (2, 1000), 2}, ::Reactant.TracedRArray{Float32, (2, 1000), 2}, ::@NamedTuple{layer_1::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}})
@ /mnt/research/ongoing/lux/Reactant.jl/test/nn_lux.jl:61 [inlined]
[13] gradient_loss_function
@ /mnt/research/ongoing/lux/Reactant.jl/test/nn_lux.jl:61 [inlined]
[14] overdub(::Cassette.Context{…}, ::typeof(gradient_loss_function), ::Chain{…}, ::Reactant.TracedRArray{…}, ::Reactant.TracedRArray{…}, ::@NamedTuple{…}, ::@NamedTuple{…})
@ Cassette ~/.julia/packages/Cassette/4UsSX/src/overdub.jl:0
[15] (::Reactant.var"#5#13"{typeof(gradient_loss_function), Tuple{}, Reactant.MLIR.IR.Block, Vector{…}, Tuple{…}})()
@ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/utils.jl:53
[16] block!(f::Reactant.var"#5#13"{…}, blk::Reactant.MLIR.IR.Block)
@ Reactant.MLIR.IR /mnt/research/ongoing/lux/Reactant.jl/src/mlir/IR/Block.jl:198
[17] make_mlir_fn(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool)
@ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/utils.jl:46
[18] (::Reactant.var"#100#105"{Reactant.MLIR.IR.Module, typeof(gradient_loss_function), Tuple{…}, Int64})()
@ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/Reactant.jl:927
[19] mmodule!(f::Reactant.var"#100#105"{…}, blk::Reactant.MLIR.IR.Module)
@ Reactant.MLIR.IR /mnt/research/ongoing/lux/Reactant.jl/src/mlir/IR/Module.jl:89
[20] #99
@ /mnt/research/ongoing/lux/Reactant.jl/src/Reactant.jl:925 [inlined]
[21] context!(f::Reactant.var"#99#104"{typeof(gradient_loss_function), Tuple{…}, Int64}, ctx::Reactant.MLIR.IR.Context)
@ Reactant.MLIR.IR /mnt/research/ongoing/lux/Reactant.jl/src/mlir/IR/Context.jl:68
[22] compile(f::typeof(gradient_loss_function), args::Tuple{…}; pipeline_options::String, client::Nothing)
@ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/Reactant.jl:923
[23] compile(f::typeof(gradient_loss_function), args::Tuple{…})
@ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/Reactant.jl:918
[24] top-level scope
@ /mnt/research/ongoing/lux/Reactant.jl/test/nn_lux.jl:70
Some type information was truncated. Use `show(err)` to see complete types.
```
Oh, we should just provide a hook into `make_zero` for TracedRArray.
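A very rough sketch of what that hook could look like (signature mirrored from the stacktrace; building the zero via `promote_to` is purely an assumption):
```julia
# Rough sketch: give Enzyme a way to build a zero shadow for traced arrays instead of
# calling the TracedRArray constructor (which is what fails in the stacktrace above).
function Enzyme.Compiler.make_zero(
    ::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}
) where {RT<:Reactant.TracedRArray,copy_if_inactive}
    haskey(seen, prev) && return seen[prev]
    res = Reactant.promote_to(prev, zero(eltype(prev)))  # assumed helper for a zero traced array
    seen[prev] = res
    return res
end
```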
```julia
julia> cps
(layer_1 = @NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}((Float32[-0.98918384 0.190184; 0.046477042 -1.0701349; -0.36382833 0.8563723], Float32[0.0; 0.0; 0.0;;])), layer_2 = @NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}((Float32[0.8828562 -0.19665341 -0.70401317; -0.67718965 0.056223422 -0.2397092], Float32[0.0; 0.0;;])), layer_3 = NamedTuple())
```
Shouldn't the `Shape` here be fixed?
Depends on the requirements of the type. If something takes in a vector of Float64 as a member variable, we default to replacing it with the union over sizes, since that's semantically equivalent (if, say, you have code that changes the size). But if it's possible to keep it consistent, it may be nice to fully type the size. Check out our trace type function (I forget the exact name).
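Concretely, the two forms in question (types taken from the `cps` dump above, where layer 1's weight is 3×2):
```julia
# size-erased: any 2-D Float32 ConcreteRArray fits this field
Reactant.ConcreteRArray{Float32, Shape, 2} where Shape

# fully sized: the (3, 2) shape is part of the type, so tracing can specialize on it
Reactant.ConcreteRArray{Float32, (3, 2), 2}
```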
This still needs the reduce pipeline error to be fixed before it is ready to be merged.
That fix is here: EnzymeAD/Enzyme-JAX#93, and will have a jll later today hopefully.
Needs #4