More comprehensive compilation options
avik-pal committed Jun 1, 2024
1 parent 7ec09d7 commit 916e685
Showing 5 changed files with 156 additions and 82 deletions.
171 changes: 107 additions & 64 deletions ext/LuxReactantExt/layer.jl
@@ -43,52 +43,67 @@ function Lux.__to_reactant_adaptor(
to::Lux.ToReactantAdaptor{FST}, model::AbstractExplicitLayer,
input_prototype, ps, st, eltype_adaptor) where {FST}
output = first(model(input_prototype, ps, st))
concrete_output = __make_concrete_array(output)
concrete_output = Lux.__make_reactant_array(output)

concrete_input = __make_concrete_array(input_prototype)
cmodel = __make_concrete_array(model)
cps = __make_concrete_array(ps)
cst = __make_concrete_array(st)
concrete_input = Lux.__make_reactant_array(input_prototype)
cps = Lux.__make_reactant_array(ps)
cst = Lux.__make_reactant_array(st)

csmodel = Lux.StatefulLuxLayer{FST}(cmodel, cps, cst)
smodel = Lux.StatefulLuxLayer{FST}(model, cps, cst)
fwd_fn = Reactant.compile((m, x) -> m(x), (smodel, concrete_input))

fwd_fn = Reactant.compile((m, x) -> m(x), (csmodel, concrete_input))
cst_test = Lux.__make_reactant_array(Lux.testmode(st))
smodel_test = Lux.StatefulLuxLayer{FST}(model, cps, cst_test)
inference_fn = Reactant.compile((m, x) -> m(x), (smodel_test, concrete_input))

function enzyme_vjp_fn(m, x, y, dy)
dx = Enzyme.make_zero(x)
dps = Enzyme.make_zero(m.ps)
st_m = ifelse(FST, m.st, m.st_any)

function wrapper_fn!(y, model, x, ps, st)
copyto!(y, first(LuxCore.apply(model, x, ps, st)))
return nothing
vjp_fn = if to.skip_compile_vjp
nothing
else
function enzyme_vjp_fn(m, x, y, dy)
dx = Enzyme.make_zero(x)
dps = Enzyme.make_zero(m.ps)
st_m = ifelse(FST, m.st, m.st_any)

function wrapper_fn!(y, model, x, ps, st)
copyto!(y, first(LuxCore.apply(model, x, ps, st)))
return nothing
end

Enzyme.autodiff(
Enzyme.Reverse, wrapper_fn!, Enzyme.Const, Enzyme.Duplicated(y, dy),
Enzyme.Const(m.model), Enzyme.Duplicated(x, dx),
Enzyme.Duplicated(m.ps, dps), Enzyme.Const(st_m))
return dx, dps
end

Enzyme.autodiff(Enzyme.Reverse, wrapper_fn!, Enzyme.Const, Enzyme.Duplicated(y, dy),
Enzyme.Const(m.model), Enzyme.Duplicated(x, dx),
Enzyme.Duplicated(m.ps, dps), Enzyme.Const(st_m))
return dx, dps
try
concrete_output2 = Lux.__make_reactant_array(deepcopy(output))
Reactant.compile(
enzyme_vjp_fn, (smodel, concrete_input, concrete_output, concrete_output2))
catch err
to.force_compile_backward && rethrow(err)
@error """
Enzyme failed to compile the backward pass. Differentiation will be disabled \
for this model.
HINT: To force compilation of the backward pass, set \
`force_compile_backward=true` in the constructor of `ToReactantAdaptor`.\n
""" exception=err
nothing
end
end
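For reference, the backward pass compiled above uses Enzyme's reverse mode over a mutating wrapper, with `Duplicated` pairing each buffer with its shadow (cotangent) and `Const` marking the non-differentiated arguments. A self-contained sketch of that annotation pattern on a toy function (everything here is illustrative; only the `autodiff`/`Duplicated`/`Const`/`make_zero` calls mirror the code above):

using Enzyme

# y[1] = x[1]^2 + x[2]; written as a mutating function that returns `nothing`
f!(y, x) = (y[1] = x[1]^2 + x[2]; nothing)

x = [3.0, 1.0]; dx = Enzyme.make_zero(x)   # shadow that accumulates the gradient wrt x
y = [0.0];      dy = [1.0]                 # seed the output cotangent with 1

Enzyme.autodiff(Enzyme.Reverse, f!, Enzyme.Const,
    Enzyme.Duplicated(y, dy), Enzyme.Duplicated(x, dx))

dx  # ≈ [6.0, 1.0], i.e. [2 * x[1], 1]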

vjp_fn = try
concrete_output2 = __make_concrete_array(deepcopy(output))
Reactant.compile(
enzyme_vjp_fn, (csmodel, concrete_input, concrete_output, concrete_output2))
catch err
to.force_compile_backward && rethrow(err)
@error """
Enzyme failed to compile the backward pass. Differentiation will be disabled for \
this model.
HINT: To force compilation of the backward pass, set `force_compile_backward=true` \
in the constructor of `ToReactantAdaptor`.\n
""" exception=err
jvp_fn = if to.skip_compile_jvp
nothing
else # TODO: Implement JVP with Enzyme.Forward
throw(ArgumentError("JVPs are not implemented yet."))
end

return Lux.ReactantLayer{FST, Lux.__recursive_eltype(input_prototype)}(
to, input_prototype, concrete_input, cps, cst, model, cmodel, fwd_fn,
vjp_fn, eltype_adaptor, fmapstructure(Lux.__size, input_prototype))
return Lux.ReactantLayer{
FST, Lux.__recursive_eltype(input_prototype), typeof(input_prototype),
typeof(concrete_input), typeof(cst), typeof(cst_test)}(
to, cps, model, fwd_fn, inference_fn, vjp_fn, jvp_fn,
eltype_adaptor, fmapstructure(Lux.__size, input_prototype))
end
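For orientation, each compiled callable above (`fwd_fn`, `inference_fn`, and the VJP) comes from `Reactant.compile(f, example_args)`, which traces `f` on concrete Reactant arrays and returns an XLA-compiled executable specialized to those argument types. A minimal sketch of that pattern in isolation (the two-argument `compile` form and `Reactant.ConcreteRArray` are taken from this diff; the toy function and values are illustrative):

using Reactant

x = Reactant.ConcreteRArray(rand(Float32, 8))  # concrete array the tracer can consume
compiled_sum = Reactant.compile(sum, (x,))     # trace and compile `sum` for inputs like `x`
compiled_sum(x)                                # runs the compiled executable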

# TODO: Currently we are maintaining 2 copies of the parameters, this is not ideal.
@@ -97,23 +112,20 @@ end
function LuxCore.initialparameters(rng::AbstractRNG, layer::Lux.ReactantLayer)
ps = layer.adaptor(LuxCore.initialparameters(rng, layer.layer))
layer.eltype_adaptor !== nothing && (ps = adapt(layer.eltype_adaptor, ps))
return __make_concrete_array(ps)
return Lux.__make_reactant_array(ps)
end

function LuxCore.initialstates(rng::AbstractRNG, layer::Lux.ReactantLayer)
st = LuxCore.initialstates(rng, layer.layer)
layer.eltype_adaptor !== nothing && (st = adapt(layer.eltype_adaptor, st))
return __make_concrete_array(st)
return (; states=Lux.__make_reactant_array(st), training=Val(true))
end

function (l::Lux.ReactantLayer{FST, T})(x, ps, st::NamedTuple) where {FST, T}
l.eltype_adaptor !== nothing && (x = adapt(l.eltype_adaptor, x))

# XLARuntimeError is not great, so check and terminate early if needed
input_structure = fmapstructure(Lux.__size, x)
if l.input_structure != input_structure
throw(DimensionMismatch(lazy"Input structure mismatch. Expected $(l.input_structure), got $(input_structure)."))
end
@argcheck fmapstructure(Lux.__size, x) == l.input_structure

# TODO: For non-array inputs this makes the eltype uniform, which might not be
# desirable. We should handle those cases with `fmap`.
@@ -131,47 +143,69 @@ function (l::Lux.ReactantLayer{FST, T})(x, ps, st::NamedTuple) where {FST, T}
return Lux.__apply_reactant(l, x, ps, st)
end

@inline function Lux.__apply_reactant(l::Lux.ReactantLayer, x, ps, st)
y, st_ = Lux.__apply_reactant(l, x, ps, st.states, st.training)
return y, (; states=st_, training=st.training)
end

# This is the ideal case where all the types match correctly.
# Input type mismatches should not happen here; they should be handled before this function
# is called.
# If an `st` mismatch happens, the user has misconfigured something upstream; there is
# nothing we can do about it here.
@inline function Lux.__apply_reactant(
l::Lux.ReactantLayer{FST, T, inType}, x::inType, ps, st) where {FST, T, inType}
return Lux.__apply_reactant(l, __make_concrete_array(x), ps, st)
@inline function Lux.__apply_reactant(l::Lux.ReactantLayer{FST, T, inType}, x::inType,
ps, st, training) where {FST, T, inType}
return Lux.__apply_reactant(l, Lux.__make_reactant_array(x), ps, st, training)
end

@inline function Lux.__apply_reactant(
l::Lux.ReactantLayer{FST, T, inType, inCType, psType, stType}, x::inCType,
ps::psType, st::stType) where {FST, T, inType, inCType, psType, stType}
csmodel = Lux.StatefulLuxLayer{FST}(l.clayer, ps, st)
return Lux.__apply_reactant(l, csmodel, x), ifelse(FST, csmodel.st, csmodel.st_any)
l::Lux.ReactantLayer{FST, T, inType, inCType, stType, stTestType, psType},
x::inCType, ps::psType, st::stType,
training) where {FST, T, inType, inCType, psType, stType, stTestType}
smodel = Lux.StatefulLuxLayer{FST}(l.layer, ps, st)
return (
Lux.__apply_reactant(l, smodel, x, training), ifelse(FST, smodel.st, smodel.st_any))
end

# Parameter type mismatch. This might be too common so try to handle it gracefully.
@inline function Lux.__apply_reactant(
l::Lux.ReactantLayer{FST, T, inType, inCType, psType, stType}, x::inCType,
ps::psType2, st::stType) where {FST, T, inType, inCType, psType, psType2, stType}
l::Lux.ReactantLayer{FST, T, inType, inCType, stType, stTestType, psType},
x::inCType, ps::psType2, st,
training) where {FST, T, inType, inCType, stType, stTestType, psType, psType2}
@warn "Parameter Type Mismatch with compiled Reactant function. This will lead to \
performance regressions" maxlog=1

ps = __try_similar_structure(Lux.__named_tuple(ps), l.concrete_ps)
ps = l.adaptor(ps)
l.eltype_adaptor !== nothing && (ps = adapt(l.eltype_adaptor, ps))
ps = __make_concrete_array(ps)
ps = Lux.__make_reactant_array(ps)

if typeof(ps) != psType
@warn "Automatic type conversion failed for `ps`." original_ps_type=psType2
__graceful_type_mismatch_error(l, x, ps, st)
__graceful_type_mismatch_error(l, x, ps, st, training)
end

return Lux.__apply_reactant(l, __make_concrete_array(x), ps, st)
return Lux.__apply_reactant(l, Lux.__make_reactant_array(x), ps, st, training)
end

function Lux.__apply_reactant(l, x, ps, st, training)
return __graceful_type_mismatch_error(l, x, ps, st, training)
end

Lux.__apply_reactant(l, x, ps, st) = __graceful_type_mismatch_error(l, x, ps, st)
@inline function Lux.__apply_reactant(l::Lux.ReactantLayer, smodel, x, ::Val{true})
return l.fwd_fn(smodel, x)
end

@inline Lux.__apply_reactant(l::Lux.ReactantLayer, csmodel, x) = l.fwd_fn(csmodel, x)
@inline function Lux.__apply_reactant(l::Lux.ReactantLayer, smodel, x, ::Val{false})
return l.inference_fn(smodel, x)
end
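These two methods are where the separately compiled `inference_fn` pays off: the `training` flag carried in the layer state is a `Val`, so the choice between the training-mode and test-mode executables is resolved by dispatch rather than a runtime branch. A generic sketch of the same pattern (the struct and function names here are illustrative, not Lux API):

struct CompiledPair{F, G}
    train_fn::F      # e.g. compiled with the training-mode states
    inference_fn::G  # compiled separately with test-mode states
end

apply(c::CompiledPair, x, ::Val{true}) = c.train_fn(x)
apply(c::CompiledPair, x, ::Val{false}) = c.inference_fn(x)

c = CompiledPair(x -> 2x, x -> x)
apply(c, 3.0, Val(true))   # 6.0 -- routed to the training function
apply(c, 3.0, Val(false))  # 3.0 -- routed to the inference function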

# Don't inline, else types don't get displayed in the stack trace
function __graceful_type_mismatch_error(
::Lux.ReactantLayer{FST, T, inType, inCType, psType, stType},
x, ps, st) where {FST, T, inType, inCType, psType, stType}
::Lux.ReactantLayer{FST, T, inType, inCType, stType, stTestType, psType},
x,
ps,
st,
::Val{training}) where {
FST, T, inType, inCType, psType, stType, stTestType, training}
#! format: off
input_type_mismatch_str = typeof(x) == inType || typeof(x) == inCType ? """
1. Input Types Matched.
@@ -189,12 +223,21 @@ function __graceful_type_mismatch_error(
Compiled Parameter Type: $(psType).
"""

st_type_mismatch_str = typeof(st) == stType ? """
3. State Types Matched.
""" : """
3. State Type: $(typeof(st)).
Compiled State Type: $(stType).
"""
st_type_mismatch_str = if training
typeof(st) == stType ? """
3. State Types Matched.
""" : """
3. State Type: $(typeof(st)).
Compiled State Type: $(stType).
"""
else
typeof(st) == stTestType ? """
3. State Types Matched.
""" : """
3. State Type: $(typeof(st)).
Compiled State Type: $(stTestType).
"""
end

throw(ArgumentError("""
Model compiled types and input types don't match. We tried our best to convert the \
10 changes: 5 additions & 5 deletions ext/LuxReactantExt/train.jl
@@ -8,11 +8,11 @@ function Lux.Experimental.single_train_step(
# st = ts.states
# model = ts.model

data = __make_concrete_array(data)
model = __make_concrete_array(ts.model)
dps = __make_concrete_array(dps)
ps = __make_concrete_array(ts.parameters)
st = __make_concrete_array(ts.states)
data = Lux.__make_reactant_array(data)
model = Lux.__make_reactant_array(ts.model)
dps = Lux.__make_reactant_array(dps)
ps = Lux.__make_reactant_array(ts.parameters)
st = Lux.__make_reactant_array(ts.states)

# @show

10 changes: 9 additions & 1 deletion ext/LuxReactantExt/utils.jl
@@ -1,4 +1,12 @@
@inline function __make_concrete_array(x)
@inline Lux.__make_reactant_array(x::Reactant.RArray) = x
@inline function Lux.__make_reactant_array(x::AbstractArray)
hasmethod(Reactant.ArrayToConcrete, Tuple{typeof(x)}) &&
return Reactant.ConcreteRArray(x)
return __make_tracer(x)
end
@inline Lux.__make_reactant_array(x) = __make_tracer(x)

@inline function __make_tracer(x)
return Reactant.make_tracer(IdDict(), x, (), Reactant.ArrayToConcrete, nothing)
end

16 changes: 8 additions & 8 deletions src/layers/extension.jl
@@ -240,9 +240,8 @@ end
# Workaround for SimpleChains not being able to handle some input types
function CRC.rrule(::typeof(__apply_simple_chain), layer, x, ps, ::LuxCPUDevice)
res, pb = CRC.rrule(layer, x, ps)
# Safety measure to prevent errors from weird Array types that SimpleChains doesn't support
__∇apply_simple_chain = @closure Δ -> begin
# Safety measure to prevent errors from weird Array types that SimpleChains doesn't
# support
∂layer, ∂x, ∂ps = pb(convert(Array, Δ))
return CRC.NoTangent(), ∂layer, ∂x, ∂ps, CRC.NoTangent()
end
@@ -251,18 +250,19 @@ end

# TODO: Add a ChainRules rrule that calls the `bwd` function, i.e. uses Enzyme for the
# gradient computation
# TODO: Inference won't work OOTB, we will have to compile that separately
@concrete struct ReactantLayer{FST, T, inType, inCType, psType, stType,
# TODO: Docstring
@concrete struct ReactantLayer{FST, T, inType, inCType, stType, stTestType, psType,
L <: AbstractExplicitLayer, AD <: ToReactantAdaptor} <: AbstractExplicitLayer
adaptor::AD
input_prototype::inType
concrete_input_prototype::inCType
concrete_ps::psType
concrete_st::stType
layer::L
clayer

# Compiled Functions
fwd_fn
inference_fn
vjp_fn
jvp_fn

eltype_adaptor
input_structure
end
31 changes: 27 additions & 4 deletions src/transform/reactant.jl
@@ -1,16 +1,26 @@
@concrete struct ToReactantAdaptor{FST, R <: AbstractRNG} <: AbstractFromLuxAdaptor
input_prototype

ps_transform
rng::R
force_compile_backward::Bool

force_allow_mixed_eltypes::Bool
skip_compile_vjp::Bool
force_compile_vjp::Bool
skip_compile_jvp::Bool
force_compile_jvp::Bool
end

function ToReactantAdaptor{FST}(input_prototype; rng=Xoshiro(123), ps_transform=identity,
force_compile_backward::Bool=false,
force_allow_mixed_eltypes::Bool=false) where {FST}
force_allow_mixed_eltypes::Bool=false, force_compile_vjp::Bool=false,
skip_compile_vjp::Bool=false, force_compile_jvp::Bool=false,
skip_compile_jvp::Bool=true) where {FST}
skip_compile_vjp && @argcheck !force_compile_vjp
skip_compile_jvp && @argcheck !force_compile_jvp

return ToReactantAdaptor{FST}(input_prototype, ps_transform, rng,
force_compile_backward, force_allow_mixed_eltypes)
force_allow_mixed_eltypes, skip_compile_vjp, force_compile_vjp,
skip_compile_jvp, force_compile_jvp)
end
function ToReactantAdaptor(args...; fixed_state_type::Val=Val(true), kwargs...)
return ToReactantAdaptor{__unwrap_val(fixed_state_type)}(args...; kwargs...)
@@ -36,3 +46,16 @@ only a limited subset of Lux models can be compiled via `Reactant.jl`. If you encounter
issues, please report them on the `Lux.jl` or `Reactant.jl` GitHub repository.
"""
struct AutoReactant end
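Putting the new switches together, a hypothetical construction of the adaptor could look as follows (the keyword names come from the constructor above; the prototype input, the model, and the commented-out `adaptor(model)` application are illustrative and not verified against the final API):

using Lux, Random

adaptor = Lux.ToReactantAdaptor(rand(Float32, 2, 3);  # input_prototype
    rng=Xoshiro(123),
    force_compile_backward=false,
    skip_compile_vjp=false, force_compile_vjp=false,
    skip_compile_jvp=true, force_compile_jvp=false,    # JVPs are not implemented yet
    fixed_state_type=Val(true))

# rmodel = adaptor(Dense(2 => 4, tanh))  # would compile forward, inference, and VJP functions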

"""
__make_reactant_array(x)
Converts `x` to a `Reactant.ConcreteRArray` if it is not already one.
"""
function __make_reactant_array end

@inline function __make_reactant_array(nt::NamedTuple{names}) where {names}
return NamedTuple{names}(map(__make_reactant_array, values(nt)))
end
@inline __make_reactant_array(t::Tuple) = map(__make_reactant_array, t)
@inline __make_reactant_array(x::AbstractExplicitLayer) = x
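A small usage sketch of the container methods above, assuming the Reactant extension is loaded so that array leaves take the `Reactant.ConcreteRArray` path (the parameter values are illustrative):

using Lux, Reactant

ps = (; weight=rand(Float32, 4, 2), bias=rand(Float32, 4))

cps = Lux.__make_reactant_array(ps)                  # NamedTuple method maps over the leaves
Lux.__make_reactant_array(Dense(2 => 4)) isa Dense   # layer structs are returned unchanged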
