From 8d926fd83056ebba6eef3dce860abfd23159b647 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Mon, 27 May 2024 15:54:26 -0700
Subject: [PATCH] Try converting common parameter types to the compiled type

---
 ext/LuxReactantExt.jl     | 116 +++++++++++++++++++++++++++++++++-----
 src/Lux.jl                |  16 +++---
 src/layers/extension.jl   |  23 +++++++-
 src/transform/reactant.jl |  11 ++--
 4 files changed, 138 insertions(+), 28 deletions(-)

diff --git a/ext/LuxReactantExt.jl b/ext/LuxReactantExt.jl
index 6914c136bb..4cb4839a40 100644
--- a/ext/LuxReactantExt.jl
+++ b/ext/LuxReactantExt.jl
@@ -3,18 +3,26 @@ module LuxReactantExt
 using Adapt: adapt
 using ArgCheck: @argcheck
 using Enzyme: Enzyme
-using Functors: fmapstructure
+using Functors: fmapstructure, fmap
+using Markdown: @md_str
 using Random: AbstractRNG, Xoshiro
 using Reactant: Reactant
 
 using Lux: Lux, LuxEltypeAdaptor
 using LuxCore: LuxCore, AbstractExplicitLayer
 
-@inline __make_concrete_array(x::Reactant.ConcreteRArray) = x
-@inline __make_concrete_array(x::AbstractArray) = Reactant.ConcreteRArray(x)
 @inline function __make_concrete_array(x)
     return Reactant.make_tracer(IdDict(), x, (), Reactant.ArrayToConcrete, nothing)
 end
 
+@inline function __try_similar_structure(x::AbstractArray, y::NamedTuple{()})
+    length(x) == 0 && return y
+    throw(DimensionMismatch(lazy"Expected an empty array, got an array of size $(size(x))."))
+end
+@inline function __try_similar_structure(x::AbstractArray, y::AbstractArray)
+    return parent(x) !== x ? copy(x) : x # materialize views, wrappers, and the like
+end
+@inline __try_similar_structure(x, y) = fmap(__try_similar_structure, x, y)
+
 # Reactant doesn't handle mixed eltypes that well, so we first try to compile the model
 # as a usual Julia function. If that fails, we type-cast and recompile. Note that this
 # is a one-time operation, so it doesn't matter if this step is slow.
@@ -22,7 +30,8 @@ function Lux.__to_reactant_adaptor(
         to::Lux.ToReactantAdaptor{FST}, model::AbstractExplicitLayer) where {FST}
     input_prototype = to.input_prototype
     input_eltype = Lux.__recursive_eltype(input_prototype)
-    ps, st = Lux.setup(Xoshiro(123), model) # We generate fake parameters and states to compile the model
+    ps, st = Lux.setup(LuxCore.replicate(to.rng), model)
+    ps = to.ps_transform(ps)
     ps_eltype = Lux.__recursive_eltype(ps)
     st_eltype = Lux.__recursive_eltype(st)
 
@@ -31,8 +40,7 @@ function Lux.__to_reactant_adaptor(
 
     if !to.force_allow_mixed_eltypes &&
        any(x -> x != newT && x != Union{}, (input_eltype, ps_eltype, st_eltype))
-        # Try compiling, but this might fail
-        try
+        try # Try compiling, but this might fail
            return Lux.__to_reactant_adaptor(to, model, input_prototype, ps, st, nothing)
        catch err
            @warn """
@@ -93,12 +101,18 @@ function Lux.__to_reactant_adaptor(
        nothing
    end
 
+    # TODO: Add the compiled types to the layer type information so that we can check
+    # whether the model is being executed with the correct types.
     return Lux.ReactantLayer{FST, Lux.__recursive_eltype(input_prototype)}(
-        model, cmodel, fwd, bwd, eltype_adaptor, fmapstructure(Lux.__size, input_prototype))
+        to, input_prototype, concrete_input, cps, cst, model, cmodel, fwd,
+        bwd, eltype_adaptor, fmapstructure(Lux.__size, input_prototype))
 end
 
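As an aside on the mixed-eltype handling above: when recompilation is forced, every leaf
is promoted to a common eltype. A minimal sketch of that coercion, assuming Lux's
internal (unexported) `LuxEltypeAdaptor`, which also backs `f32`/`f64`:

    using Adapt: adapt
    using Lux: LuxEltypeAdaptor

    # promote_type(Float32, Float64) == Float64, so the mixed tree becomes uniform
    ps = (; w=rand(Float32, 2, 2), b=rand(Float64, 2))
    ps64 = adapt(LuxEltypeAdaptor{Float64}(), ps)
    @assert eltype(ps64.w) == eltype(ps64.b) == Float64
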
+# TODO: We are currently maintaining two copies of the parameters, which is not ideal.
+# We could return the parameters and states from the layer itself, since we don't care
+# about the values, only the types.
 function LuxCore.initialparameters(rng::AbstractRNG, layer::Lux.ReactantLayer)
-    ps = LuxCore.initialparameters(rng, layer.layer)
+    ps = layer.adaptor(LuxCore.initialparameters(rng, layer.layer))
     layer.eltype_adaptor !== nothing && (ps = adapt(layer.eltype_adaptor, ps))
     return __make_concrete_array(ps)
 end
@@ -110,7 +124,6 @@ function LuxCore.initialstates(rng::AbstractRNG, layer::Lux.ReactantLayer)
 end
 
 function (l::Lux.ReactantLayer{FST, T})(x, ps, st::NamedTuple) where {FST, T}
-    csmodel = Lux.StatefulLuxLayer{FST}(l.clayer, ps, st)
     l.eltype_adaptor !== nothing && (x = adapt(l.eltype_adaptor, x))
 
     # XLARuntimeError is not great, so check and terminate early if needed
@@ -120,7 +133,7 @@ function (l::Lux.ReactantLayer{FST, T})(x, ps, st::NamedTuple) where {FST, T}
     end
 
     # TODO: For non-array inputs we make the eltype uniform, which might not be
-    # desirable. We should handle those cases with `fmap`
+    # desirable. We should handle those cases with `fmap`.
     if T != Lux.__recursive_eltype(x)
         @warn """
         `Reactant.compile` was called with input eltype $(T) but the current input eltype \
@@ -132,11 +145,86 @@ function (l::Lux.ReactantLayer{FST, T})(x, ps, st::NamedTuple) where {FST, T}
         x = adapt(LuxEltypeAdaptor{T}(), x)
     end
 
-    y = Lux.__apply_reactant(l, csmodel, x)
-    return y, ifelse(FST, csmodel.st, csmodel.st_any)
+    return Lux.__apply_reactant(l, x, ps, st)
+end
+
+# This is the ideal case where all the types match correctly.
+# Input type mismatches should not happen here; they should be handled before this
+# function is called.
+# If an `st` mismatch happens, the user modified the states after compilation; we can't
+# do anything about that.
+@inline function Lux.__apply_reactant(
+        l::Lux.ReactantLayer{FST, T, inType}, x::inType, ps, st) where {FST, T, inType}
+    return Lux.__apply_reactant(l, __make_concrete_array(x), ps, st)
+end
+
+@inline function Lux.__apply_reactant(
+        l::Lux.ReactantLayer{FST, T, inType, inCType, psType, stType}, x::inCType,
+        ps::psType, st::stType) where {FST, T, inType, inCType, psType, stType}
+    csmodel = Lux.StatefulLuxLayer{FST}(l.clayer, ps, st)
+    return Lux.__apply_reactant(l, csmodel, x), ifelse(FST, csmodel.st, csmodel.st_any)
+end
+
+# Parameter type mismatches are likely to be common, so we try to handle them gracefully.
+@inline function Lux.__apply_reactant(
+        l::Lux.ReactantLayer{FST, T, inType, inCType, psType, stType}, x::inCType,
+        ps::psType2, st::stType) where {FST, T, inType, inCType, psType, psType2, stType}
+    ps = __try_similar_structure(Lux.__named_tuple(ps), l.concrete_ps)
+    ps = l.adaptor(ps)
+    l.eltype_adaptor !== nothing && (ps = adapt(l.eltype_adaptor, ps))
+    ps = __make_concrete_array(ps)
+
+    if typeof(ps) != psType
+        @warn "Automatic type conversion failed for `ps`." original_ps_type=psType2
+        __graceful_type_mismatch_error(l, x, ps, st)
+    end
+
+    return Lux.__apply_reactant(l, __make_concrete_array(x), ps, st)
+end
+
+Lux.__apply_reactant(l, x, ps, st) = __graceful_type_mismatch_error(l, x, ps, st)
+
+@inline Lux.__apply_reactant(l::Lux.ReactantLayer, csmodel, x) = l.fwd(csmodel, x)
+
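To make the graceful parameter recovery above concrete, here is a small sketch (with
made-up leaf values, not part of the patch) of what `__try_similar_structure` does:
`fmap` zips the user's parameter tree with the compiled one, and `copy` materializes
wrapped arrays so the leaf types line up again:

    using Functors: fmap

    compiled_ps = (; weight=rand(Float32, 4, 4), bias=rand(Float32, 4))
    user_ps = (; weight=view(rand(Float32, 8, 8), 1:4, 1:4), bias=rand(Float32, 4))

    # copying the view yields a Matrix{Float32}, matching the compiled leaf type
    fixed_ps = fmap((x, y) -> parent(x) !== x ? copy(x) : x, user_ps, compiled_ps)
    @assert typeof(fixed_ps) == typeof(compiled_ps)
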
+# Don't inline, otherwise the types don't get displayed in the stack trace
+function __graceful_type_mismatch_error(
+        ::Lux.ReactantLayer{FST, T, inType, inCType, psType, stType},
+        x, ps, st) where {FST, T, inType, inCType, psType, stType}
+    #! format: off
+    input_type_mismatch_str = typeof(x) == inType || typeof(x) == inCType ? """
+    1. Input Types Matched.
+    """ : """
+    1. Input Type: $(typeof(x)).
+       Compiled Input Type: $(inType).
+       Compiled Concrete Input Type: $(inCType).
+    """
+    #! format: on
+
+    ps_type_mismatch_str = typeof(ps) == psType ? """
+    2. Parameter Types Matched.
+    """ : """
+    2. Parameter Type: $(typeof(ps)).
+       Compiled Parameter Type: $(psType).
+    """
+
+    st_type_mismatch_str = typeof(st) == stType ? """
+    3. State Types Matched.
+    """ : """
+    3. State Type: $(typeof(st)).
+       Compiled State Type: $(stType).
+    """
+
+    throw(ArgumentError("""
+    The compiled argument types of the model don't match the types it is being called \
+    with. We tried our best to convert the types automatically, but failed. Ideally the \
+    argument types should not be modified after compilation. Possible ways forward:
+
+    1. Recompile the model with the correct input types.
+    2. Open an issue on the Lux.jl repository so we can check whether the automatic \
+       type conversion can be extended to cover this case.
+
+    List of Type Mismatches:
+
+    $(input_type_mismatch_str) $(ps_type_mismatch_str) $(st_type_mismatch_str)"""))
+end
 
 end
diff --git a/src/Lux.jl b/src/Lux.jl
index 89cb5c21c8..22616d898b 100644
--- a/src/Lux.jl
+++ b/src/Lux.jl
@@ -16,7 +16,7 @@ using PrecompileTools: @recompile_invalidations
     using Markdown: @doc_str
     using OhMyThreads: tmapreduce
     using Preferences: @load_preference
-    using Random: Random, AbstractRNG
+    using Random: Random, AbstractRNG, Xoshiro
     using Reexport: @reexport
 
     using LuxCore, LuxLib, LuxDeviceUtils, WeightInitializers
@@ -48,6 +48,12 @@ const DISABLE_AUTOMATIC_NESTED_AD_SWITCH = @load_preference("DisableAutomaticNes
 # Utilities
 include("utils.jl")
 
+# Transform to and from other frameworks
+include("transform/types.jl")
+include("transform/flux.jl")
+include("transform/simplechains.jl")
+include("transform/reactant.jl")
+
 # Layer Implementations
 include("layers/basic.jl")
 include("layers/containers.jl")
@@ -72,12 +78,6 @@ include("helpers/compact.jl")
 include("helpers/autodiff.jl")
 include("helpers/nested_ad.jl")
 
-# Transform to and from other frameworks
-include("transform/types.jl")
-include("transform/flux.jl")
-include("transform/simplechains.jl")
-include("transform/reactant.jl")
-
 # Distributed Training
 include("distributed/backend.jl")
 include("distributed/public_api.jl")
@@ -111,7 +111,7 @@ export f16, f32, f64
 export transform
 export FromFluxAdaptor, FluxLayer
 export ToSimpleChainsAdaptor, SimpleChainsLayer
-export ToReactantAdaptor
+export ToReactantAdaptor, ReactantLayer
 export DynamicExpressionsLayer
 export MPIBackend, NCCLBackend, DistributedUtils
diff --git a/src/layers/extension.jl b/src/layers/extension.jl
index 3d4eeb79b6..1a4285c8d8 100644
--- a/src/layers/extension.jl
+++ b/src/layers/extension.jl
@@ -245,8 +245,13 @@ end
 
 # TODO: Add a ChainRules rrule that calls the `bwd` function, i.e. uses Enzyme for the
 # gradient computation
-@concrete struct ReactantLayer{FST, T, F, B, L <: AbstractExplicitLayer} <:
-                 AbstractExplicitLayer
+@concrete struct ReactantLayer{FST, T, inType, inCType, psType, stType, F, B,
+        L <: AbstractExplicitLayer, AD <: ToReactantAdaptor} <: AbstractExplicitLayer
+    adaptor::AD
+    input_prototype::inType
+    concrete_input_prototype::inCType
+    concrete_ps::psType
+    concrete_st::stType
     layer::L
     clayer
     fwd::F
@@ -254,3 +259,17 @@ end
     eltype_adaptor
     input_structure
 end
+
+function Base.show(io::IO, s::ReactantLayer{ST}) where {ST}
+    if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL
+        print(io, "ReactantLayer{$ST}(\n")
+        _big_show(io, s.layer, 4)
+    elseif !get(io, :compact, false) # e.g. printed inside a Vector, but not a Matrix
+        print(io, "ReactantLayer{$ST}(")
+        _layer_show(io, s.layer)
+    else
+        print(io, "ReactantLayer{$ST}(")
+        show(io, s.layer)
+    end
+    print(io, ")")
+end
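For context, the struct above bakes the compiled argument types into the layer's type
parameters so that the `__apply_reactant` methods can dispatch on exact type matches. A
self-contained sketch of that dispatch pattern (illustrative names, not part of the
patch):

    struct Compiled{inT, psT}
        fwd::Function
    end

    # Fast path: the runtime types match the compiled types exactly.
    apply(c::Compiled{inT, psT}, x::inT, ps::psT) where {inT, psT} = c.fwd(x, ps)
    # Fallback: any other combination means the types changed after compilation.
    apply(::Compiled, x, ps) = throw(ArgumentError("argument types changed after compilation"))

    c = Compiled{Matrix{Float32}, Vector{Float32}}((x, ps) -> x .+ ps)
    apply(c, rand(Float32, 2, 2), rand(Float32, 2)) # hits the fast path
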
diff --git a/src/transform/reactant.jl b/src/transform/reactant.jl
index 34e5bb3ddf..fce9ec63e4 100644
--- a/src/transform/reactant.jl
+++ b/src/transform/reactant.jl
@@ -1,13 +1,16 @@
-@concrete struct ToReactantAdaptor{FST} <: AbstractFromLuxAdaptor
+@concrete struct ToReactantAdaptor{FST, R <: AbstractRNG} <: AbstractFromLuxAdaptor
     input_prototype
+    ps_transform
+    rng::R
     force_compile_backward::Bool
     force_allow_mixed_eltypes::Bool
 end
 
-function ToReactantAdaptor{FST}(input_prototype; force_compile_backward::Bool=false,
+function ToReactantAdaptor{FST}(input_prototype; rng=Xoshiro(123), ps_transform=identity,
+        force_compile_backward::Bool=false,
         force_allow_mixed_eltypes::Bool=false) where {FST}
-    return ToReactantAdaptor{FST}(
-        input_prototype, force_compile_backward, force_allow_mixed_eltypes)
+    return ToReactantAdaptor{FST}(input_prototype, ps_transform, rng,
+        force_compile_backward, force_allow_mixed_eltypes)
 end
 
 function ToReactantAdaptor(args...; fixed_state_type::Val=Val(true), kwargs...)
     return ToReactantAdaptor{__unwrap_val(fixed_state_type)}(args...; kwargs...)
 end
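For reference, a hypothetical end-to-end sketch of how the new `rng` and `ps_transform`
keyword arguments would be used (the adaptor API is experimental, so the exact call
sequence may differ):

    using Lux, Random

    model = Chain(Dense(2 => 4, tanh), Dense(4 => 2))
    adaptor = ToReactantAdaptor(rand(Float32, 2, 3); rng=Xoshiro(0), ps_transform=identity)
    rmodel = adaptor(model) # compiles against the input prototype

    ps, st = Lux.setup(Xoshiro(0), rmodel) # parameters come back as ConcreteRArrays
    y, st_new = rmodel(rand(Float32, 2, 3), ps, st)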