Auto compile Lux models to Reactant #665

Closed · wants to merge 15 commits
6 changes: 5 additions & 1 deletion Project.toml
@@ -40,6 +40,7 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
 NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
+Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
@@ -57,6 +58,7 @@ LuxMLUtilsExt = "MLUtils"
 LuxMPIExt = "MPI"
 LuxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
 LuxOptimisersExt = "Optimisers"
+LuxReactantExt = ["Enzyme", "Reactant"]
 LuxReverseDiffExt = "ReverseDiff"
 LuxSimpleChainsExt = "SimpleChains"
 LuxTrackerExt = "Tracker"
@@ -102,6 +104,7 @@ Pkg = "1.10"
 PrecompileTools = "1.2"
 Preferences = "1.4.3"
 Random = "1.10"
+Reactant = "0.1.1"
 ReTestItems = "1.23.1"
 Reexport = "1.2.2"
 ReverseDiff = "1.15"
@@ -134,6 +137,7 @@ MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
 NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
+Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
 ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
@@ -144,4 +148,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [targets]
-test = ["Aqua", "ComponentArrays", "DynamicExpressions", "Enzyme", "ExplicitImports", "FiniteDifferences", "Flux", "ForwardDiff", "Logging", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "MLUtils", "MPI", "NCCL", "Optimisers", "Pkg", "ReTestItems", "ReverseDiff", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"]
+test = ["Aqua", "ComponentArrays", "DynamicExpressions", "Enzyme", "ExplicitImports", "FiniteDifferences", "Flux", "ForwardDiff", "Logging", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "MLUtils", "MPI", "NCCL", "Optimisers", "Pkg", "Reactant", "ReTestItems", "ReverseDiff", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"]
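
With these entries, Reactant becomes a weak dependency: `LuxReactantExt` is only compiled and loaded once both Enzyme and Reactant are present in the active environment. A minimal sketch of verifying that the extension actually loaded (`Base.get_extension` is standard Julia 1.9+; nothing Lux-specific is assumed here):

using Lux, Enzyme, Reactant  # importing both weak deps triggers the extension

ext = Base.get_extension(Lux, :LuxReactantExt)
ext === nothing && @warn "LuxReactantExt did not load; check Enzyme/Reactant installs"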
23 changes: 23 additions & 0 deletions ext/LuxReactantExt/LuxReactantExt.jl
@@ -0,0 +1,23 @@
module LuxReactantExt

using Adapt: adapt
using ArgCheck: @argcheck
using ConcreteStructs: @concrete
using Enzyme: Enzyme, Active, Const, Duplicated
using Functors: fmapstructure, fmap
using Random: AbstractRNG, Xoshiro
using Reactant: Reactant
using Lux: Lux, LuxEltypeAdaptor, AutoReactant
using LuxCore: LuxCore, AbstractExplicitLayer

include("utils.jl")

# compile just the model. This allows us to run part of the model in vanilla LLVM. Needed
# for cases where we can't currently compile via Reactant or where XLA is not great
# for the model.
include("layer.jl")

# compile the entire training loop
include("train.jl")

end
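
The adaptor entry point itself lives in the main Lux package and is not part of this diff, so the following is only a plausible usage sketch inferred from the fields this extension reads (`input_prototype`, `rng`, `ps_transform`, the `skip_compile_*` and `force_*` flags); the exact constructor keywords and the call form are assumptions:

using Lux, Random, Reactant, Enzyme

model = Chain(Dense(2 => 4, tanh), Dense(4 => 2))
x = rand(Float32, 2, 3)   # prototype: compilation specializes on this shape and eltype

# Hypothetical constructor call; `{true}` is the FST (fixed state type) parameter
# seen in the dispatches below, and the keyword mirrors a field this file reads.
adaptor = ToReactantAdaptor{true}(x; force_compile_backward=false)
rmodel = adaptor(model)   # expected to produce a `Lux.ReactantLayer`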
258 changes: 258 additions & 0 deletions ext/LuxReactantExt/layer.jl
@@ -0,0 +1,258 @@
# Reactant doesn't handle mixed eltypes that well, so we will first try to compile it as
# a usual julia function. However, if that fails, we will type cast and try to recompile.
# Note that this is only a one time operation so it doesn't matter if this step is too slow.
function Lux.__to_reactant_adaptor(
        to::Lux.ToReactantAdaptor{FST}, model::AbstractExplicitLayer) where {FST}
    input_prototype = to.input_prototype
    input_eltype = Lux.__recursive_eltype(input_prototype)
    ps, st = Lux.setup(LuxCore.replicate(to.rng), model)
    ps = to.ps_transform(ps)
    ps_eltype = Lux.__recursive_eltype(ps)
    st_eltype = Lux.__recursive_eltype(st)

    newT = promote_type(input_eltype, ps_eltype, st_eltype)
    eltype_adaptor = nothing

    if !to.force_allow_mixed_eltypes &&
       any(x -> x != newT && x != Union{}, (input_eltype, ps_eltype, st_eltype))
        try # Try compiling, but this might fail
            return Lux.__to_reactant_adaptor(to, model, input_prototype, ps, st, nothing)
        catch err
            @warn """
            Mixed Eltypes detected. Failure is NOT unexpected. Trying to recompile with a \
            common eltype.

            HINT: To force compiling the mixed eltypes, set \
            `force_allow_mixed_eltypes=true` in the constructor of `ToReactantAdaptor`.

            If compilation succeeds, all inputs to the compiled model will be \
            automatically type casted to the common eltype.\n
            """ exception=err input_eltype ps_eltype st_eltype common_eltype=newT
        end

        eltype_adaptor = LuxEltypeAdaptor{newT}()
        input_prototype = adapt(eltype_adaptor, to.input_prototype)
        ps = adapt(eltype_adaptor, ps)
        st = adapt(eltype_adaptor, st)
    end

    return Lux.__to_reactant_adaptor(to, model, input_prototype, ps, st, eltype_adaptor)
end
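
The promotion step above reduces to computing one common element type over every array leaf. A self-contained illustration of that decision (an illustrative reimplementation, not the actual `Lux.__recursive_eltype`):

using Functors: fmap

# Walk all leaves and promote array eltypes to a single common type.
function recursive_eltype(x)
    T = Union{}
    fmap(x) do leaf
        leaf isa AbstractArray && (T = promote_type(T, eltype(leaf)))
        return leaf
    end
    return T
end

ps = (weight=rand(Float32, 4, 2), bias=rand(Float64, 4))
newT = recursive_eltype(ps)   # Float64: eltypes are mixed, so Lux would cast the Float32 leaf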

function Lux.__to_reactant_adaptor(
        to::Lux.ToReactantAdaptor{FST}, model::AbstractExplicitLayer,
        input_prototype, ps, st, eltype_adaptor) where {FST}
    output = first(model(input_prototype, ps, st))
    concrete_output = Lux.__make_reactant_array(output)

    concrete_input = Lux.__make_reactant_array(input_prototype)
    cps = Lux.__make_reactant_array(ps)
    cst = Lux.__make_reactant_array(st)

    smodel = Lux.StatefulLuxLayer{FST}(model, cps, cst)
    fwd_fn = to.skip_compile_forward ? nothing :
             Reactant.compile((m, x) -> m(x), (smodel, concrete_input))

    cst_test = Lux.__make_reactant_array(Lux.testmode(st))
    smodel_test = Lux.StatefulLuxLayer{FST}(model, cps, cst_test)
    inference_fn = to.skip_compile_inference ? nothing :
                   Reactant.compile((m, x) -> m(x), (smodel_test, concrete_input))

    vjp_fn = if to.skip_compile_vjp
        nothing
    else
        function enzyme_vjp_fn(m, x, y, dy)
            dx = Enzyme.make_zero(x)
            dps = Enzyme.make_zero(m.ps)
            st_m = ifelse(FST, m.st, m.st_any)

            function wrapper_fn!(y, model, x, ps, st)
                copyto!(y, first(LuxCore.apply(model, x, ps, st)))
                return nothing
            end

            Enzyme.autodiff(
                Enzyme.Reverse, wrapper_fn!, Enzyme.Const, Enzyme.Duplicated(y, dy),
                Enzyme.Const(m.model), Enzyme.Duplicated(x, dx),
                Enzyme.Duplicated(m.ps, dps), Enzyme.Const(st_m))
            return dx, dps
        end

        try
            concrete_output2 = Lux.__make_reactant_array(deepcopy(output))
            Reactant.compile(
                enzyme_vjp_fn, (smodel, concrete_input, concrete_output, concrete_output2))
        catch err
            to.force_compile_backward && rethrow(err)
            @error """
            Enzyme failed to compile the backward pass. Differentiation will be disabled \
            for this model.

            HINT: To force compilation of the backward pass, set \
            `force_compile_backward=true` in the constructor of `ToReactantAdaptor`.\n
            """ exception=err
            nothing
        end
    end

    jvp_fn = if to.skip_compile_jvp
        nothing
    else # TODO: Implement JVP with Enzyme.Forward
        throw(ArgumentError("JVPs are not implemented yet."))
    end

    return Lux.ReactantLayer{
        FST, Lux.__recursive_eltype(input_prototype), typeof(input_prototype),
        typeof(concrete_input), typeof(cst), typeof(cst_test)}(
        to, cps, model, fwd_fn, inference_fn, vjp_fn, jvp_fn,
        eltype_adaptor, fmapstructure(Lux.__size, input_prototype))
end
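
The compiled VJP relies on a standard Enzyme reverse-mode pattern: wrap the forward pass in a mutating function that writes into a preallocated output, then mark arrays `Duplicated` so their shadows carry the cotangents. The same pattern in isolation, without Reactant or Lux (a minimal sketch for an affine map):

using Enzyme

# Primal writes into `y` in place; reverse mode then propagates `dy` into `dW` and `dx`.
function affine!(y, W, x)
    copyto!(y, W * x)
    return nothing
end

W, x = rand(3, 2), rand(2)
y, dy = zeros(3), ones(3)                     # seed the output cotangent with ones
dW, dx = Enzyme.make_zero(W), Enzyme.make_zero(x)

Enzyme.autodiff(Enzyme.Reverse, affine!, Enzyme.Const,
    Enzyme.Duplicated(y, dy), Enzyme.Duplicated(W, dW), Enzyme.Duplicated(x, dx))

# dW and dx now hold the cotangents: dW == ones(3) * x' and dx == W' * ones(3),
# the usual affine-map VJP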

# TODO: Currently we are maintaining 2 copies of the parameters, this is not ideal.
# We can return the parameters and states from the layer itself, since we don't care
# about the values, but just the type.
function LuxCore.initialparameters(rng::AbstractRNG, layer::Lux.ReactantLayer)
    ps = layer.adaptor(LuxCore.initialparameters(rng, layer.layer))
    layer.eltype_adaptor !== nothing && (ps = adapt(layer.eltype_adaptor, ps))
    return Lux.__make_reactant_array(ps)
end

function LuxCore.initialstates(rng::AbstractRNG, layer::Lux.ReactantLayer)
    st = LuxCore.initialstates(rng, layer.layer)
    layer.eltype_adaptor !== nothing && (st = adapt(layer.eltype_adaptor, st))
    return (; states=Lux.__make_reactant_array(st), training=Val(true))
end
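
Continuing the earlier (hypothetical) `rmodel` sketch: `Lux.setup` on a `ReactantLayer` routes through these two overloads, so parameters come back as Reactant arrays and the state is the wrapped NamedTuple:

ps, st = Lux.setup(Xoshiro(0), rmodel)
keys(st) == (:states, :training)   # the wrapper added by `initialstates` above
st.training == Val(true)           # freshly initialized layers start in training mode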

function (l::Lux.ReactantLayer{FST, T})(x, ps, st::NamedTuple) where {FST, T}
    l.eltype_adaptor !== nothing && (x = adapt(l.eltype_adaptor, x))

    # XLARuntimeError is not great, so check and terminate early if needed
    @argcheck fmapstructure(Lux.__size, x) == l.input_structure

    # TODO: For non array inputs this makes the eltype uniform, which might not be
    # desirable. We should handle those cases with `fmap`
    if T != Lux.__recursive_eltype(x)
        @warn """
        `Reactant.compile` was called with input eltype $(T) but the current input eltype \
        is $(Lux.__recursive_eltype(x)). This might lead to unexpected behavior.

        We will convert the input to $(T) and continue. If you want to avoid this, please \
        recompile the model with the correct input eltype.
        """ maxlog=1
        x = adapt(LuxEltypeAdaptor{T}(), x)
    end

    return Lux.__apply_reactant(l, x, ps, st)
end
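
The `@argcheck` above compares a structural fingerprint (container layout plus per-leaf sizes) recorded at compile time against the incoming input, failing fast instead of surfacing an opaque XLARuntimeError later. The fingerprinting idea in miniature, using plain `size` in place of `Lux.__size`:

using Functors: fmapstructure

fingerprint(x) = fmapstructure(size, x)   # layout + leaf sizes, values discarded

compiled = fingerprint((; image=rand(Float32, 28, 28, 1, 4)))
current  = fingerprint((; image=rand(Float32, 28, 28, 1, 8)))
compiled == current   # false: the batch dimension changed, so the layer errors early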

@inline function Lux.__apply_reactant(l::Lux.ReactantLayer, x, ps, st)
    y, st_ = Lux.__apply_reactant(l, x, ps, st.states, st.training)
    return y, (; states=st_, training=st.training)
end

# This is the ideal case where all the types match correctly.
# Input type mismatches should not happen here, they should be handled before this
# function is called.
@inline function Lux.__apply_reactant(l::Lux.ReactantLayer{FST, T, inType}, x::inType,
        ps, st, training) where {FST, T, inType}
    return Lux.__apply_reactant(l, Lux.__make_reactant_array(x), ps, st, training)
end

@inline function Lux.__apply_reactant(
        l::Lux.ReactantLayer{FST, T, inType, inCType, stType, stTestType, psType},
        x::inCType, ps::psType, st::stType,
        training) where {FST, T, inType, inCType, psType, stType, stTestType}
    smodel = Lux.StatefulLuxLayer{FST}(l.layer, ps, st)
    return (
        Lux.__apply_reactant(l, smodel, x, training), ifelse(FST, smodel.st, smodel.st_any))
end

# Parameter type mismatch. This might be too common so try to handle it gracefully.
@inline function Lux.__apply_reactant(
        l::Lux.ReactantLayer{FST, T, inType, inCType, stType, stTestType, psType},
        x::inCType, ps::psType2, st,
        training) where {FST, T, inType, inCType, stType, stTestType, psType, psType2}
    @warn "Parameter Type Mismatch with compiled Reactant function. This will lead to \
           performance regressions" maxlog=1

    ps = __try_similar_structure(Lux.__named_tuple(ps), l.concrete_ps)
    ps = l.adaptor(ps)
    l.eltype_adaptor !== nothing && (ps = adapt(l.eltype_adaptor, ps))
    ps = Lux.__make_reactant_array(ps)

    if typeof(ps) != psType
        @warn "Automatic type conversion failed for `ps`." original_ps_type=psType2
        __graceful_type_mismatch_error(l, x, ps, st, training)
    end

    return Lux.__apply_reactant(l, Lux.__make_reactant_array(x), ps, st, training)
end

function Lux.__apply_reactant(l, x, ps, st, training)
    return __graceful_type_mismatch_error(l, x, ps, st, training)
end

@inline function Lux.__apply_reactant(l::Lux.ReactantLayer, smodel, x, ::Val{true})
    @argcheck l.fwd_fn !== nothing
    return l.fwd_fn(smodel, x)
end

@inline function Lux.__apply_reactant(l::Lux.ReactantLayer, smodel, x, ::Val{false})
    @argcheck l.inference_fn !== nothing
    return l.inference_fn(smodel, x)
end
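
Train versus inference is thus resolved by dispatch on the `Val` stored in the state, not by a runtime branch. The pattern in isolation:

# Each compiled function is selected purely by the type of `training`.
forward(::Val{true})  = "compiled training forward"
forward(::Val{false}) = "compiled inference forward"

st = (; states=(;), training=Val(true))
forward(st.training)   # "compiled training forward"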

# Don't inline, else types don't get displayed in the stack trace
function __graceful_type_mismatch_error(
        ::Lux.ReactantLayer{FST, T, inType, inCType, stType, stTestType, psType},
        x, ps, st, ::Val{training}) where {
        FST, T, inType, inCType, psType, stType, stTestType, training}
    #! format: off
    input_type_mismatch_str = typeof(x) == inType || typeof(x) == inCType ? """
    1. Input Types Matched.
    """ : """
    1. Input Type: $(typeof(x)).
       Compiled Input Type: $(inType).
       Compiled Concrete Input Type: $(inCType).
    """
    #! format: on

    ps_type_mismatch_str = typeof(ps) == psType ? """
    2. Parameter Types Matched.
    """ : """
    2. Parameter Type: $(typeof(ps)).
       Compiled Parameter Type: $(psType).
    """

    st_type_mismatch_str = if training
        typeof(st) == stType ? """
        3. State Types Matched.
        """ : """
        3. State Type: $(typeof(st)).
           Compiled State Type: $(stType).
        """
    else
        typeof(st) == stTestType ? """
        3. State Types Matched.
        """ : """
        3. State Type: $(typeof(st)).
           Compiled State Type: $(stTestType).
        """
    end

    throw(ArgumentError("""
    Model compiled types and input types don't match. We tried our best to convert the \
    types to the right ones, but we failed. Ideally the argument types should not be \
    modified after compilation.

    1. Recompile the model with the correct input types.
    2. Open an issue on the Lux.jl repository, to check if we can ease out the automatic \
       type conversion.

    List of Type Mismatches:

    $(input_type_mismatch_str) $(ps_type_mismatch_str) $(st_type_mismatch_str)"""))
end
55 changes: 55 additions & 0 deletions ext/LuxReactantExt/train.jl
@@ -0,0 +1,55 @@
# TODO: Do the same for the iip versions. Metaprogram that part.

# Case III: Nothing is cached. First call to `single_train_step`
function Lux.Experimental.single_train_step(
        ad::AutoReactant, obj_fn::F, data, ts::Lux.Experimental.TrainState) where {F}
    dps = Lux.__recursive_make_zero(ts.parameters)

    data = Lux.__make_reactant_array(data)
    model = Lux.__make_reactant_array(ts.model)
    dps = Lux.__make_reactant_array(dps)
    ps = Lux.__make_reactant_array(ts.parameters)
    st = Lux.__make_reactant_array(ts.states)

    obj_fn_wrapper, st_updated, stats = Lux.Experimental.__wrap_objective_function(
        obj_fn, st)

    # WIP: the commented-out scaffolding below sketches the intended flow, i.e.
    # differentiate the wrapped objective with Enzyme and return the loss, the
    # updated states, and the stats.
    #
    # function obj_fn_wrapper(obj_fn, model, ps, st, data) # Intentionally boxing
    #     y, st_, stats = obj_fn(model, ps, st, data)
    #     return y
    # end
    #
    # _, loss = Enzyme.autodiff(Enzyme.ReverseWithPrimal, obj_fn_wrapper, Active,
    #     Const(model), Duplicated(ps, dps), Const(st), Const(data))
    #
    # return loss, st_new, stats # FIXME: Return the correct things

    compiled_fn = Reactant.compile(__update_fn_wrapper, (obj_fn, model, ps, dps, st, data))

    # return compiled_fn, (obj_fn, model, ps, dps, st, data)
end

function __update_fn_wrapper(obj_fn, model, ps, dps, st, data)
    _, loss = Enzyme.autodiff(Enzyme.ReverseWithPrimal, obj_fn, Active, Const(model),
        Duplicated(ps, dps), Const(st), Const(data))
    # TODO: call Lux.Experimental.apply_gradients() here once wired up
    return loss
end
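
Once this path is complete, the intended entry point appears to be `single_train_step` with the `AutoReactant` AD type. A sketch of how that would slot into the usual `Lux.Experimental` training loop; the `TrainState` construction and the four-tuple return shape are assumptions carried over from the other AD backends, not something this WIP diff guarantees:

using Lux, Optimisers, Random, Reactant, Enzyme

model = Chain(Dense(2 => 8, relu), Dense(8 => 1))
ts = Lux.Experimental.TrainState(Xoshiro(0), model, Adam(0.01f0))  # assumed constructor

# Objective follows the Lux.Experimental convention:
# (model, ps, st, data) -> (loss, st, stats)
function mse(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    return sum(abs2, ŷ .- y), st_, (;)
end

x, y = rand(Float32, 2, 32), rand(Float32, 1, 32)
grads, loss, stats, ts = Lux.Experimental.single_train_step(
    AutoReactant(), mse, (x, y), ts)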
20 changes: 20 additions & 0 deletions ext/LuxReactantExt/utils.jl
@@ -0,0 +1,20 @@
@inline Lux.__make_reactant_array(x::Reactant.RArray) = x
@inline function Lux.__make_reactant_array(x::AbstractArray)
    hasmethod(Reactant.ArrayToConcrete, Tuple{typeof(x)}) &&
        return Reactant.ConcreteRArray(x)
    return __make_tracer(x)
end
@inline Lux.__make_reactant_array(x) = __make_tracer(x)

@inline function __make_tracer(x)
    return Reactant.make_tracer(IdDict(), x, (), Reactant.ArrayToConcrete, nothing)
end

@inline function __try_similar_structure(x::AbstractArray, y::NamedTuple{()})
    length(x) == 0 && return y
    throw(DimensionMismatch(lazy"Expected empty array, got $(size(x))."))
end
@inline function __try_similar_structure(x::AbstractArray, y::AbstractArray)
    return parent(x) !== x ? copy(x) : x # unview arrays and such
end
@inline __try_similar_structure(x, y) = fmap(__try_similar_structure, x, y)
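
`__try_similar_structure` exists so that parameters arriving in a different container than the compiled NamedTuple layout can be coerced back toward it, leaf by leaf, via a two-argument `fmap` that walks both trees in lockstep. A small illustration of that walk with plain arrays (illustrative only, not the Lux internals):

using Functors: fmap

target   = (; weight=zeros(Float32, 2, 2), bias=zeros(Float32, 2))
incoming = (; weight=view(rand(Float32, 2, 4), :, 1:2), bias=rand(Float32, 2))

# Visit matching leaves of both trees; materialize views so leaves own their memory.
fixed = fmap((x, y) -> parent(x) !== x ? copy(x) : x, incoming, target)
fixed.weight isa Array   # true: the view was copied into a plain Array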

Check warning on line 20 in ext/LuxReactantExt/utils.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxReactantExt/utils.jl#L20

Added line #L20 was not covered by tests
Loading
Loading