
[WIP] Add a Lux example #5

Closed · wants to merge 5 commits

Conversation

avik-pal (Collaborator)

Needs #4

test/nn_lux.jl Outdated
losses = []
for epoch in 1:1_000
    for (x, y) in loader
        loss, grads = Flux.withgradient(model) do m
Member

While we're at it, let's do the training too. We should make a new function to be compiled that contains both the autodiff and the update! step [ideally with Enzyme now]. Want to try?

I'm adding something similar in the background.
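
A minimal sketch of the idea being discussed: one function that does both the gradient computation and the parameter update, so the whole training step can be compiled at once. The use of Enzyme.autodiff and Optimisers.update! here is an assumption (the PR had not settled on an API), loss_function refers to the loss defined later in this thread, and copt_state is a hypothetical converted optimiser state.

using Enzyme, Optimisers

# Hypothetical fused training step: autodiff + parameter update in one call,
# so Reactant can compile the whole thing. Assumes loss_function returns a scalar.
function train_step!(model, x, y, ps, st, opt_state)
    dps = Enzyme.make_zero(ps)    # shadow storage for the gradients
    Enzyme.autodiff(Reverse, loss_function, Active,
        Const(model), Const(x), Const(y), Duplicated(ps, dps), Const(st))
    opt_state, ps = Optimisers.update!(opt_state, ps, dps)
    return ps, opt_state
end

# compiled_step = Reactant.compile(train_step!, (cmodel, cnoisy, ctarget, cps, cst, copt_state))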

Collaborator Author

I was going to rewrite this part with LuxDL/Lux.jl#640

Collaborator Author

Do you want to pull in Lux.Experimental.apply_gradients! and compile it together?

Collaborator Author

I just realized this one is updated using the original model and not the compiled model.

@wsmoses (Member) commented May 14, 2024

Also, you only need to trace through a variable if it contains data like an array.

@wsmoses (Member) commented May 14, 2024 via email

@avik-pal (Collaborator Author)

A couple of problems:

  1. batchnorm doesn't seem to work because it tries to trace through mean, and the arrays don't have getindex defined.
  2. Without batchnorm, after the function is compiled and I try to run it, I get:
julia> comp = f(cmodel, cnoisy, cps, cst)
ERROR: UndefVarError: `layer_2` not defined
Stacktrace:
 [1] (::Reactant.var"#109#110")(arg_1::Chain{…}, arg_2::Reactant.ConcreteRArray{…}, arg_3::@NamedTuple{}, arg_4::@NamedTuple{})
   @ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/Reactant.jl:704
 [2] top-level scope
   @ /mnt/research/ongoing/lux/Reactant.jl/test/nn_lux.jl:29
Some type information was truncated. Use `show(err)` to see complete types.
  3. In create_result, should we add a case for NamedTuple similar to the other cases? The returned state is always a NamedTuple. (A purely illustrative sketch follows below.)
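
A purely illustrative sketch of the kind of NamedTuple case being asked about; the real create_result signature in Reactant is not shown in this thread, so all names here are assumptions, mirroring however the existing container cases recurse:

# Hypothetical sketch only: recurse over the fields and rebuild the NamedTuple,
# by analogy with the other container cases in create_result.
function create_result(tocopy::NamedTuple{K}, args...) where {K}
    elems = [create_result(v, args...) for v in values(tocopy)]
    return :(NamedTuple{$K}(($(elems...),)))
end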

@wsmoses (Member) commented May 14, 2024

So it is possible to add getindex etc. here, but I'm intentionally preventing it for now so we can ensure we trace fully vectorized code.

And yeah, we should just add an overload for mean. Want to give it a shot?
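
A minimal sketch of what such an overload could look like, assuming sum, length, and scalar division already work on TracedRArray (the overload that actually lands in Reactant may differ):

using Statistics

# Hypothetical: forward `mean` to operations that are already traceable,
# instead of letting it fall back to scalar getindex.
Statistics.mean(x::Reactant.TracedRArray) = sum(x) / length(x)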

@avik-pal (Collaborator Author)

This doesn't seem right:

julia> @code_lowered  f(cmodel, cnoisy, cps, cst)
CodeInfo(
1 ─ %1  = Reactant.XLA.synced_buffer
│   %2  = Base.getfield
│   %3  = Base.getfield
│   %4  = (%3)(arg_3, Reactant.layer_2)
│   %5  = (%2)(%4, Reactant.weight)
│   %6  = Base.getproperty(%5, :data)
│         sbuf_1 = (%1)(%6)
│   %8  = Reactant.XLA.synced_buffer
│   %9  = Base.getfield
│   %10 = Base.getfield
│   %11 = (%10)(arg_3, Reactant.layer_1)
│   %12 = (%9)(%11, Reactant.weight)
│   %13 = Base.getproperty(%12, :data)
│         sbuf_2 = (%8)(%13)
│   %15 = Reactant.XLA.synced_buffer
│   %16 = Base.getfield
│   %17 = Base.getfield
│   %18 = (%17)(arg_3, Reactant.layer_1)
│   %19 = (%16)(%18, Reactant.bias)
│   %20 = Base.getproperty(%19, :data)
│         sbuf_3 = (%15)(%20)
│   %22 = Reactant.XLA.synced_buffer
│   %23 = Base.getfield
│   %24 = Base.getfield
│   %25 = (%24)(arg_3, Reactant.layer_2)
│   %26 = (%23)(%25, Reactant.bias)
│   %27 = Base.getproperty(%26, :data)
│         sbuf_4 = (%22)(%27)
│   %29 = Reactant.XLA.synced_buffer
│   %30 = Base.getproperty(arg_2, :data)
│         sbuf_5 = (%29)(%30)
│   %32 = $(Expr(:gc_preserve_begin, :(sbuf_1), :(sbuf_2), :(sbuf_3), :(sbuf_4), :(sbuf_5)))
│   %33 = Reactant.XLA.ExecutableCall
│   %34 = Base.getproperty(sbuf_1, :buffer)
│   %35 = Base.getproperty(sbuf_2, :buffer)
│   %36 = Base.getproperty(sbuf_3, :buffer)
│   %37 = Base.getproperty(sbuf_4, :buffer)
│   %38 = Base.getproperty(sbuf_5, :buffer)
│   %39 = Core.tuple(%34, %35, %36, %37, %38)
│   %40 = Reactant.Val(1)
│   %41 = (%33)(Reactant.XLA.LoadedExecutable(Ptr{Nothing} @0x000000001114fd40), %39, (0x01, 0x01, 0x01, 0x01, 0x01), %40)
│         linearized_results = %41
│         $(Expr(:gc_preserve_end, :(%32)))
│         concrete_res_1 = Base.getindex(linearized_results, 1)
│         result = (Reactant.ConcreteRArray{Float32, (2, 1000), 2})(concrete_res_1)
└──       return result
)

It shouldn't be Reactant.layer_2; the layer name is being looked up as a global in the Reactant module instead of being used as a field symbol.

@avik-pal (Collaborator Author)

> So it is possible to add getindex etc. here, but I'm intentionally preventing it for now so we can ensure we trace fully vectorized code.

Makes sense. I added ArrayInterface to allow easy checking for that in downstream code.
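
A hedged sketch of how downstream code could use that check; fast_scalar_indexing is the standard ArrayInterface query, while the trait value for Reactant's traced arrays is an assumption here:

using ArrayInterface

# Assumed trait definition on the Reactant side:
# ArrayInterface.fast_scalar_indexing(::Type{<:Reactant.TracedRArray}) = false

# Downstream code can then branch on the trait instead of attempting getindex.
function normalize_columns(x::AbstractMatrix)
    if ArrayInterface.fast_scalar_indexing(typeof(x))
        # a scalar loop is fine for plain Arrays
        s = [sum(view(x, :, j)) for j in axes(x, 2)]
        out = similar(x)
        for j in axes(x, 2), i in axes(x, 1)
            out[i, j] = x[i, j] / s[j]
        end
        return out
    else
        # stay fully vectorized so tracing works
        return x ./ sum(x; dims=1)
    end
end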

@wsmoses (Member) commented May 14, 2024

Ah, we should probably add an escape in the macro.

@wsmoses (Member) commented May 14, 2024

Wait, that's odd though; shouldn't it be a symbol being looked up there?

@avik-pal (Collaborator Author)

> Wait, that's odd though; shouldn't it be a symbol being looked up there?

It was getting interpolated directly; it needed a Meta.quot.
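
A small standalone illustration of that interpolation pitfall (the field name here is made up):

field = :layer_2

# Interpolating the Symbol directly splices it in as an identifier, so the
# generated code looks up a global `layer_2` in whatever module it runs in:
bad = :(Base.getfield(arg_3, $field))                 # => Base.getfield(arg_3, layer_2)

# Meta.quot keeps it a Symbol literal, which is what getfield needs:
good = :(Base.getfield(arg_3, $(Meta.quot(field))))   # => Base.getfield(arg_3, :layer_2)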

@avik-pal (Collaborator Author) commented May 14, 2024

function xlogy(x, y)
    result = x * log(y)
    return ifelse(iszero(x), zero(result), result)
end
function crossentropy(ŷ, y)
    return .-sum(xlogy.(y, ŷ))
end

function loss_function(model, x, y, ps, st)
    y_hat, _ = model(x, ps, st)
    return crossentropy(y_hat, y)
end

compiled_loss_function = Reactant.compile(
    loss_function, (cmodel, cnoisy, ctarget, cps, cst))

elem_apply is not defined for xlogy. How do I trace into the body of xlogy? There isn't a direct mapping for it in https://openxla.org/stablehlo/spec#log

Also, the eltypes don't match:

MethodError: no method matching elem_apply(::typeof(xlogy), ::Reactant.TracedRArray{Bool, (2, 1000), 2}, ::Reactant.TracedRArray{Float32, (2, 1000), 2})
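
One way to sidestep both issues for the example, until elem_apply can trace into user functions, is to write the loss with broadcasts Reactant already handles and align the eltypes up front. A sketch under those assumptions, not the fix that ultimately landed:

# Hypothetical alternative: avoid the xlogy closure and the Bool/Float32 mismatch
# by converting the one-hot targets and using plain broadcasted log/*/sum.
function crossentropy_broadcasted(ŷ, y)
    yf = Float32.(y)             # targets come in as Bool
    return -sum(yf .* log.(ŷ))   # assumes ŷ > 0 (softmax output), so 0 * log(ŷ) stays finite
end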

@avik-pal (Collaborator Author)

My failed attempt at defining the comparisons:

for (jlop, hloop, hlocomp, RT) in ((:(Base.:(==)), :compare, 0, :ElType),)
    @eval begin
        function elem_apply(::typeof($jlop), lhs::TracedRArray{ElType,Shape,N}, rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N}
            return TracedRArray{$RT,Shape,N}((),  MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data;
                comparison_direction=$hlocomp), 1))
        end

        function elem_apply(::typeof($jlop), lhs::TracedRArray{ElType,Shape,N}, rhs) where {ElType,Shape,N}
            rhs = promote_to(lhs, rhs)
            return TracedRArray{$RT,Shape,N}((),  MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data; comparison_direction=$hlocomp), 1))
        end 

        function elem_apply(::typeof($jlop), lhs, rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N}
            lhs = promote_to(rhs, lhs)
            return TracedRArray{$RT,Shape,N}((),  MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data; comparison_direction=$hlocomp), 1))
        end
    end
end

How do I pass the comparison direction enum? https://openxla.org/stablehlo/spec#compare

@wsmoses (Member) commented May 15, 2024

function stablehloComparisonDirectionAttrGet(ctx, value)
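
That is, build the comparison_direction attribute through the StableHLO C API rather than passing a raw integer. Roughly, with the caveat that the wrapper paths under Reactant.MLIR are assumptions and only the C function name above comes from this thread:

# Hypothetical sketch: construct the ComparisonDirection attribute via the C API
# ("EQ", "NE", "GE", "GT", "LE", "LT") and pass it to stablehlo.compare.
ctx = Reactant.MLIR.IR.context()
direction = Reactant.MLIR.IR.Attribute(
    Reactant.MLIR.API.stablehloComparisonDirectionAttrGet(ctx, "EQ"))

op = Reactant.MLIR.Dialects.stablehlo.compare(
    lhs.mlir_data, rhs.mlir_data; comparison_direction=direction)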

@wsmoses (Member) commented May 16, 2024

@avik-pal comparisons added here: 0f7a912

@avik-pal (Collaborator Author)

Seems like compiling the gradient is hitting this:

julia> compiled_gradient = Reactant.compile(
           gradient_loss_function, (cmodel, cnoisy, ctarget, cps, cst))
ERROR: MethodError: no method matching (Reactant.TracedRArray{Float32, Shape, 2} where Shape)(::Tuple{}, ::Reactant.MLIR.IR.Value)
Stacktrace:
  [1] make_zero(::Type{…}, seen::IdDict{…}, prev::Reactant.TracedRArray{…}, ::Val{…})
    @ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/Reactant.jl:149
  [2] #42
    @ ~/.julia/packages/Enzyme/NVk8T/src/compiler.jl:1251 [inlined]
  [3] ntuple
    @ ./ntuple.jl:19 [inlined]
  [4] make_zero
    @ ~/.julia/packages/Enzyme/NVk8T/src/compiler.jl:1249 [inlined]
  [5] make_zero(::Type{…}, seen::IdDict{…}, prev::@NamedTuple{}, ::Val{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/NVk8T/src/compiler.jl:1256
  [6] (::Enzyme.Compiler.var"#42#43"{Tuple{}, false, IdDict{}, Tuple{}})(i::Int64)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/NVk8T/src/compiler.jl:1251
  [7] ntuple(f::Enzyme.Compiler.var"#42#43"{Tuple{}, false, IdDict{}, Tuple{}}, n::Int64)
    @ Base ./ntuple.jl:19
  [8] make_zero(::Type{…}, seen::IdDict{…}, prev::Tuple{…}, ::Val{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/NVk8T/src/compiler.jl:1249
  [9] make_zero(::Type{…}, seen::IdDict{…}, prev::@NamedTuple{}, ::Val{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/NVk8T/src/compiler.jl:1256
 [10] make_zero (repeats 2 times)
    @ ~/.julia/packages/EnzymeCore/Z0CgU/src/EnzymeCore.jl:237 [inlined]
 [11] overdub
    @ /mnt/research/ongoing/lux/Reactant.jl/src/overloads.jl:358 [inlined]
 [12] gradient_loss_function(::Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, ::Reactant.TracedRArray{Float32, (2, 1000), 2}, ::Reactant.TracedRArray{Float32, (2, 1000), 2}, ::@NamedTuple{layer_1::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}})
    @ /mnt/research/ongoing/lux/Reactant.jl/test/nn_lux.jl:61 [inlined]
 [13] gradient_loss_function
    @ /mnt/research/ongoing/lux/Reactant.jl/test/nn_lux.jl:61 [inlined]
 [14] overdub(::Cassette.Context{…}, ::typeof(gradient_loss_function), ::Chain{…}, ::Reactant.TracedRArray{…}, ::Reactant.TracedRArray{…}, ::@NamedTuple{}, ::@NamedTuple{})
    @ Cassette ~/.julia/packages/Cassette/4UsSX/src/overdub.jl:0
 [15] (::Reactant.var"#5#13"{typeof(gradient_loss_function), Tuple{}, Reactant.MLIR.IR.Block, Vector{}, Tuple{}})()
    @ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/utils.jl:53
 [16] block!(f::Reactant.var"#5#13"{}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/research/ongoing/lux/Reactant.jl/src/mlir/IR/Block.jl:198
 [17] make_mlir_fn(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool)
    @ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/utils.jl:46
 [18] (::Reactant.var"#100#105"{Reactant.MLIR.IR.Module, typeof(gradient_loss_function), Tuple{}, Int64})()
    @ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/Reactant.jl:927
 [19] mmodule!(f::Reactant.var"#100#105"{}, blk::Reactant.MLIR.IR.Module)
    @ Reactant.MLIR.IR /mnt/research/ongoing/lux/Reactant.jl/src/mlir/IR/Module.jl:89
 [20] #99
    @ /mnt/research/ongoing/lux/Reactant.jl/src/Reactant.jl:925 [inlined]
 [21] context!(f::Reactant.var"#99#104"{typeof(gradient_loss_function), Tuple{}, Int64}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR /mnt/research/ongoing/lux/Reactant.jl/src/mlir/IR/Context.jl:68
 [22] compile(f::typeof(gradient_loss_function), args::Tuple{…}; pipeline_options::String, client::Nothing)
    @ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/Reactant.jl:923
 [23] compile(f::typeof(gradient_loss_function), args::Tuple{…})
    @ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/Reactant.jl:918
 [24] top-level scope
    @ /mnt/research/ongoing/lux/Reactant.jl/test/nn_lux.jl:70
Some type information was truncated. Use `show(err)` to see complete types.
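
For reference, the shape of gradient_loss_function being compiled here would be roughly the following; this is a reconstruction from the stacktrace (the exact code in test/nn_lux.jl is not shown in this thread):

using Enzyme

# Hedged reconstruction: reverse-mode Enzyme over the loss, with make_zero
# building the shadow parameters that the stacktrace shows failing for
# TracedRArray (the `(Tuple{}, ::MLIR.IR.Value)` constructor call).
function gradient_loss_function(model, x, y, ps, st)
    dps = Enzyme.make_zero(ps)
    _, loss = Enzyme.autodiff(
        ReverseWithPrimal, loss_function, Active,
        Const(model), Const(x), Const(y), Duplicated(ps, dps), Const(st))
    return loss, dps
end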

@wsmoses (Member) commented May 17, 2024 via email

@avik-pal (Collaborator Author)

julia> cps
(layer_1 = @NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}((Float32[-0.98918384 0.190184; 0.046477042 -1.0701349; -0.36382833 0.8563723], Float32[0.0; 0.0; 0.0;;])), layer_2 = @NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}((Float32[0.8828562 -0.19665341 -0.70401317; -0.67718965 0.056223422 -0.2397092], Float32[0.0; 0.0;;])), layer_3 = NamedTuple())

Shouldn't the Shape parameter here be concrete (fixed) rather than left as Shape where Shape?

@wsmoses (Member) commented May 17, 2024 via email

@avik-pal (Collaborator Author)

This still needs the reduce-pipeline error to be fixed before it is ready to be merged.

@wsmoses (Member) commented Jun 18, 2024

That fix is here: EnzymeAD/Enzyme-JAX#93, and a jll should hopefully be out later today.

@wsmoses wsmoses closed this Jun 18, 2024
@avik-pal avik-pal deleted the ap/lux branch July 10, 2024 20:02
@wsmoses wsmoses mentioned this pull request Oct 1, 2024