Try converting common parameter types to the compiled type
avik-pal committed May 27, 2024
1 parent b645921 commit 8d926fd
Showing 4 changed files with 138 additions and 28 deletions.
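
For context, here is a minimal sketch of the workflow this commit targets, assuming Lux is loaded together with Reactant so the extension below is active. The model, the input prototype, and the seed are illustrative, and `Lux.__to_reactant_adaptor` is the internal entry point shown in the diff rather than a public API:

using Lux, Random, Reactant

model = Dense(2 => 4, tanh)
adaptor = ToReactantAdaptor{true}(rand(Float32, 2, 3))
rlayer = Lux.__to_reactant_adaptor(adaptor, model)  # compiles the layer with generated ps/st

# A later call may supply freshly initialized, non-concrete parameters. After this
# commit the layer first tries to convert such `ps` to the compiled parameter types
# instead of failing outright.
ps, st = Lux.setup(Xoshiro(0), rlayer)
y, st_new = rlayer(rand(Float32, 2, 3), ps, st)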
116 changes: 102 additions & 14 deletions ext/LuxReactantExt.jl
@@ -3,26 +3,35 @@ module LuxReactantExt
using Adapt: adapt
using ArgCheck: @argcheck
using Enzyme: Enzyme
using Functors: fmapstructure
using Functors: fmapstructure, fmap
using Markdown: @md_str
using Random: AbstractRNG, Xoshiro
using Reactant: Reactant
using Lux: Lux, LuxEltypeAdaptor
using LuxCore: LuxCore, AbstractExplicitLayer

@inline __make_concrete_array(x::Reactant.ConcreteRArray) = x
@inline __make_concrete_array(x::AbstractArray) = Reactant.ConcreteRArray(x)
@inline function __make_concrete_array(x)
return Reactant.make_tracer(IdDict(), x, (), Reactant.ArrayToConcrete, nothing)
end
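# Illustrative sketch (not part of the diff) of what the helper above is expected to do,
# assuming a working Reactant setup:
#   ps = (weight = rand(Float32, 3, 3), bias = zeros(Float32, 3))
#   cps = __make_concrete_array(ps)
# `cps` keeps the NamedTuple structure, but every array leaf becomes a
# `Reactant.ConcreteRArray`; an argument that is already concrete passes through as-is.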

@inline function __try_similar_structure(x::AbstractArray, y::NamedTuple{()})
length(x) == 0 && return y
throw(DimensionMismatch(lazy"Expected empty array, got $(size(x))."))
end
@inline function __try_similar_structure(x::AbstractArray, y::AbstractArray)
return parent(x) !== x ? copy(x) : x # materialize views and other array wrappers
end
@inline __try_similar_structure(x, y) = fmap(__try_similar_structure, x, y)
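# Illustrative sketch (not part of the diff) of the intended behaviour of
# `__try_similar_structure`, which nudges user-supplied values towards the structure of
# the compiled ones:
#   __try_similar_structure(view(rand(4), 1:2), rand(2))  # copies, materializing the view
#   __try_similar_structure(Float32[], NamedTuple())      # empty array -> empty NamedTuple leaf
#   __try_similar_structure(rand(3), NamedTuple())        # throws DimensionMismatch
# Containers that are not arrays recurse leaf-wise via `fmap`.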

# Reactant doesn't handle mixed eltypes very well, so we first try to compile the model as
# a regular Julia function. If that fails, we cast the eltypes and recompile.
# This is a one-time operation, so it is fine if this step is slow.
function Lux.__to_reactant_adaptor(
to::Lux.ToReactantAdaptor{FST}, model::AbstractExplicitLayer) where {FST}
input_prototype = to.input_prototype
input_eltype = Lux.__recursive_eltype(input_prototype)
ps, st = Lux.setup(Xoshiro(123), model) # We generate fake parameters and states to compile the model
ps, st = Lux.setup(LuxCore.replicate(to.rng), model)
ps = to.ps_transform(ps)
ps_eltype = Lux.__recursive_eltype(ps)
st_eltype = Lux.__recursive_eltype(st)

@@ -31,8 +40,7 @@ function Lux.__to_reactant_adaptor(

if !to.force_allow_mixed_eltypes &&
any(x -> x != newT && x != Union{}, (input_eltype, ps_eltype, st_eltype))
# Try compiling, but this might fail
try
try # Try compiling, but this might fail
return Lux.__to_reactant_adaptor(to, model, input_prototype, ps, st, nothing)
catch err
@warn """
@@ -93,12 +101,18 @@ function Lux.__to_reactant_adaptor(
nothing
end

# TODO: Add compiled types to the layer type information. That way we can check
# if the model is being executed with the correct types.
return Lux.ReactantLayer{FST, Lux.__recursive_eltype(input_prototype)}(
model, cmodel, fwd, bwd, eltype_adaptor, fmapstructure(Lux.__size, input_prototype))
to, input_prototype, concrete_input, cps, cst, model, cmodel, fwd,
bwd, eltype_adaptor, fmapstructure(Lux.__size, input_prototype))
end

# TODO: Currently we are maintaining 2 copies of the parameters, which is not ideal.
# We could return the parameters and states stored in the layer itself, since we only
# care about the types, not the values.
function LuxCore.initialparameters(rng::AbstractRNG, layer::Lux.ReactantLayer)
ps = LuxCore.initialparameters(rng, layer.layer)
ps = layer.adaptor(LuxCore.initialparameters(rng, layer.layer))
layer.eltype_adaptor !== nothing && (ps = adapt(layer.eltype_adaptor, ps))
return __make_concrete_array(ps)
end
@@ -110,7 +124,6 @@ function LuxCore.initialstates(rng::AbstractRNG, layer::Lux.ReactantLayer)
end

function (l::Lux.ReactantLayer{FST, T})(x, ps, st::NamedTuple) where {FST, T}
csmodel = Lux.StatefulLuxLayer{FST}(l.clayer, ps, st)
l.eltype_adaptor !== nothing && (x = adapt(l.eltype_adaptor, x))

# XLARuntimeError is not great, so check and terminate early if needed
@@ -120,7 +133,7 @@ function (l::Lux.ReactantLayer{FST, T})(x, ps, st::NamedTuple) where {FST, T}
end

# TODO: For non-array inputs we make the eltype uniform, which might not be
# desirable. We should handle those cases with `fmap`.
if T != Lux.__recursive_eltype(x)
@warn """
`Reactant.compile` was called with input eltype $(T) but the current input eltype \
@@ -132,11 +145,86 @@ function (l::Lux.ReactantLayer{FST, T})(x, ps, st::NamedTuple) where {FST, T}
x = adapt(LuxEltypeAdaptor{T}(), x)
end

y = Lux.__apply_reactant(l, csmodel, x)
return y, ifelse(FST, csmodel.st, csmodel.st_any)
return Lux.__apply_reactant(l, x, ps, st)
end

# This is the ideal case where all the types match correctly.
# Input type mismatches should not happen here; they should be handled before this function
# is called.
# If an `st` mismatch happens, the user really messed something up and we can't do anything about it.
@inline function Lux.__apply_reactant(
l::Lux.ReactantLayer{FST, T, inType}, x::inType, ps, st) where {FST, T, inType}
return Lux.__apply_reactant(l, __make_concrete_array(x), ps, st)
end

@inline function Lux.__apply_reactant(
l::Lux.ReactantLayer{FST, T, inType, inCType, psType, stType}, x::inCType,
ps::psType, st::stType) where {FST, T, inType, inCType, psType, stType}
csmodel = Lux.StatefulLuxLayer{FST}(l.clayer, ps, st)
return Lux.__apply_reactant(l, csmodel, x), ifelse(FST, csmodel.st, csmodel.st_any)
end

# Parameter type mismatch. This is likely to be fairly common, so try to handle it gracefully.
@inline function Lux.__apply_reactant(
l::Lux.ReactantLayer{FST, T, inType, inCType, psType, stType}, x::inCType,
ps::psType2, st::stType) where {FST, T, inType, inCType, psType, psType2, stType}
ps = __try_similar_structure(Lux.__named_tuple(ps), l.concrete_ps)
ps = l.adaptor(ps)
l.eltype_adaptor !== nothing && (ps = adapt(l.eltype_adaptor, ps))
ps = __make_concrete_array(ps)

if typeof(ps) != psType
@warn "Automatic type conversion failed for `ps`." original_ps_type=psType2
__graceful_type_mismatch_error(l, x, ps, st)
end

return Lux.__apply_reactant(l, __make_concrete_array(x), ps, st)
end

@inline Lux.__apply_reactant(l::Lux.ReactantLayer, csmodel, x) = l.fwd(
csmodel, __make_concrete_array(x))
Lux.__apply_reactant(l, x, ps, st) = __graceful_type_mismatch_error(l, x, ps, st)

@inline Lux.__apply_reactant(l::Lux.ReactantLayer, csmodel, x) = l.fwd(csmodel, x)

# Don't inline, else types don't get displayed in the stack trace
function __graceful_type_mismatch_error(
::Lux.ReactantLayer{FST, T, inType, inCType, psType, stType},
x, ps, st) where {FST, T, inType, inCType, psType, stType}
#! format: off
input_type_mismatch_str = typeof(x) == inType || typeof(x) == inCType ? """
1. Input Types Matched.
""" : """
1. Input Type: $(typeof(x)).
Compiled Input Type: $(inType).
Compiled Concrete Input Type: $(inCType).
"""
#! format: on

ps_type_mismatch_str = typeof(ps) == psType ? """
2. Parameter Types Matched.
""" : """
2. Parameter Type: $(typeof(ps)).
Compiled Parameter Type: $(psType).
"""

st_type_mismatch_str = typeof(st) == stType ? """
3. State Types Matched.
""" : """
3. State Type: $(typeof(st)).
Compiled State Type: $(stType).
"""

throw(ArgumentError("""
Model compiled types and argument types don't match. We tried our best to convert the \
types to the compiled ones, but we failed. Ideally the argument types should not be \
modified after compilation. To fix this, either:
1. Recompile the model with the correct input types, or
2. Open an issue on the Lux.jl repository so that we can check whether the automatic \
type conversion can handle this case.
List of Type Mismatches:
$(input_type_mismatch_str) $(ps_type_mismatch_str) $(st_type_mismatch_str)"""))
end

end
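
The fast-path/fallback structure of `__apply_reactant` above can be distilled into a small, self-contained sketch. The type and function names below are invented for illustration and are not part of Lux or Reactant: a method specialized on the compiled type takes the fast path, while a fallback attempts a one-time conversion, checks that it actually produced the expected type, and otherwise raises a descriptive error, much like `__graceful_type_mismatch_error` does.

struct CompiledThing{T}
    reference::T
end

# Fast path: the argument already has the compiled type.
apply(c::CompiledThing{T}, x::T) where {T} = x

# Fallback: attempt a one-time conversion and re-dispatch; give up with a clear error.
function apply(c::CompiledThing{T}, x) where {T}
    converted = try
        convert(T, x)
    catch
        nothing
    end
    converted isa T ||
        throw(ArgumentError("expected $(T), got $(typeof(x)); automatic conversion failed"))
    return apply(c, converted)
end

apply(CompiledThing(zeros(Float32, 2)), zeros(Float64, 2))  # converted to Vector{Float32}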
16 changes: 8 additions & 8 deletions src/Lux.jl
@@ -16,7 +16,7 @@ using PrecompileTools: @recompile_invalidations
using Markdown: @doc_str
using OhMyThreads: tmapreduce
using Preferences: @load_preference
using Random: Random, AbstractRNG
using Random: Random, AbstractRNG, Xoshiro
using Reexport: @reexport

using LuxCore, LuxLib, LuxDeviceUtils, WeightInitializers
@@ -48,6 +48,12 @@ const DISABLE_AUTOMATIC_NESTED_AD_SWITCH = @load_preference("DisableAutomaticNes
# Utilities
include("utils.jl")

# Transform to and from other frameworks
include("transform/types.jl")
include("transform/flux.jl")
include("transform/simplechains.jl")
include("transform/reactant.jl")

# Layer Implementations
include("layers/basic.jl")
include("layers/containers.jl")
@@ -72,12 +78,6 @@ include("helpers/compact.jl")
include("helpers/autodiff.jl")
include("helpers/nested_ad.jl")

# Transform to and from other frameworks
include("transform/types.jl")
include("transform/flux.jl")
include("transform/simplechains.jl")
include("transform/reactant.jl")

# Distributed Training
include("distributed/backend.jl")
include("distributed/public_api.jl")
@@ -111,7 +111,7 @@ export f16, f32, f64
export transform
export FromFluxAdaptor, FluxLayer
export ToSimpleChainsAdaptor, SimpleChainsLayer
export ToReactantAdaptor
export ToReactantAdaptor, ReactantLayer
export DynamicExpressionsLayer

export MPIBackend, NCCLBackend, DistributedUtils
23 changes: 21 additions & 2 deletions src/layers/extension.jl
@@ -245,12 +245,31 @@ end

# TODO: Add a ChainRules rrule that calls the `bwd` function, i.e. uses Enzyme for the
# gradient computation
@concrete struct ReactantLayer{FST, T, F, B, L <: AbstractExplicitLayer} <:
AbstractExplicitLayer
@concrete struct ReactantLayer{FST, T, inType, inCType, psType, stType, F, B,
L <: AbstractExplicitLayer, AD <: ToReactantAdaptor} <: AbstractExplicitLayer
adaptor::AD
input_prototype::inType
concrete_input_prototype::inCType
concrete_ps::psType
concrete_st::stType
layer::L
clayer
fwd::F
bwd::B
eltype_adaptor
input_structure
end

function Base.show(io::IO, s::ReactantLayer{ST}) where {ST}
if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL
print(io, "ReactantLayer{$ST}(\n")
_big_show(io, s.layer, 4)
elseif !get(io, :compact, false) # e.g. printed inside a Vector, but not a Matrix
print(io, "ReactantLayer{$ST}(")
_layer_show(io, s.layer)
else
print(io, "ReactantLayer{$ST}(")
show(io, s.layer)
end
print(io, ")")
end
11 changes: 7 additions & 4 deletions src/transform/reactant.jl
@@ -1,13 +1,16 @@
@concrete struct ToReactantAdaptor{FST} <: AbstractFromLuxAdaptor
@concrete struct ToReactantAdaptor{FST, R <: AbstractRNG} <: AbstractFromLuxAdaptor
input_prototype
ps_transform
rng::R
force_compile_backward::Bool
force_allow_mixed_eltypes::Bool
end

function ToReactantAdaptor{FST}(input_prototype; force_compile_backward::Bool=false,
function ToReactantAdaptor{FST}(input_prototype; rng=Xoshiro(123), ps_transform=identity,
force_compile_backward::Bool=false,
force_allow_mixed_eltypes::Bool=false) where {FST}
return ToReactantAdaptor{FST}(
input_prototype, force_compile_backward, force_allow_mixed_eltypes)
return ToReactantAdaptor{FST}(input_prototype, ps_transform, rng,
force_compile_backward, force_allow_mixed_eltypes)
end
function ToReactantAdaptor(args...; fixed_state_type::Val=Val(true), kwargs...)
return ToReactantAdaptor{__unwrap_val(fixed_state_type)}(args...; kwargs...)
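
A hedged usage sketch of the extended constructor, following the signatures above; the prototype, the seed, and the choice of `ps_transform` are illustrative (`f32` is Lux's eltype-conversion helper, `Xoshiro` comes from Random):

using Lux, Random

adaptor = ToReactantAdaptor{true}(rand(Float32, 2, 3);
    rng=Xoshiro(0),           # reproducible fake ps/st used only for compilation
    ps_transform=Lux.f32,     # applied to the generated parameters before compiling
    force_compile_backward=false,
    force_allow_mixed_eltypes=false)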
