From 44da1c85a0f57828d4e52a671ea826203569468e Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sun, 26 May 2024 20:03:46 -0700
Subject: [PATCH] Auto compile Lux models to reactant

---
 Project.toml              |  5 +++-
 ext/LuxReactantExt.jl     | 58 +++++++++++++++++++++++++++++++++++++++
 src/Lux.jl                |  2 ++
 src/layers/extension.jl   | 12 ++++++++
 src/transform/reactant.jl | 18 ++++++++++++
 5 files changed, 94 insertions(+), 1 deletion(-)
 create mode 100644 ext/LuxReactantExt.jl
 create mode 100644 src/transform/reactant.jl

diff --git a/Project.toml b/Project.toml
index 2ba3bded8e..c7b54b8936 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "Lux"
 uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
 authors = ["Avik Pal and contributors"]
-version = "0.5.51"
+version = "0.5.52"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -40,6 +40,7 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
 NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
+Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
@@ -57,6 +58,7 @@ LuxMLUtilsExt = "MLUtils"
 LuxMPIExt = "MPI"
 LuxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
 LuxOptimisersExt = "Optimisers"
+LuxReactantExt = "Reactant"
 LuxReverseDiffExt = "ReverseDiff"
 LuxSimpleChainsExt = "SimpleChains"
 LuxTrackerExt = "Tracker"
@@ -102,6 +104,7 @@ Pkg = "1.10"
 PrecompileTools = "1.2"
 Preferences = "1.4.3"
 Random = "1.10"
+Reactant = "0.1.1"
 ReTestItems = "1.23.1"
 Reexport = "1.2.2"
 ReverseDiff = "1.15"
diff --git a/ext/LuxReactantExt.jl b/ext/LuxReactantExt.jl
new file mode 100644
index 0000000000..1c39ce1d12
--- /dev/null
+++ b/ext/LuxReactantExt.jl
@@ -0,0 +1,58 @@
+module LuxReactantExt
+
+using ArgCheck: @argcheck
+using Random: AbstractRNG, Xoshiro
+using Reactant: Reactant
+using Lux: Lux
+using LuxCore: LuxCore, AbstractExplicitLayer
+
+# Convert `x` to its Reactant concrete representation: `ConcreteRArray`s pass through
+# unchanged, other arrays are wrapped directly, and anything else is handed to
+# `Reactant.make_tracer` in `ArrayToConcrete` mode.
+@inline __make_concrete_array(x::Reactant.ConcreteRArray) = x
+@inline __make_concrete_array(x::AbstractArray) = Reactant.ConcreteRArray(x)
+@inline function __make_concrete_array(x)
+    return Reactant.make_tracer(IdDict(), x, (), Reactant.ArrayToConcrete, nothing)
+end
+
+# Compile the forward pass of `model` with `input_prototype` as the example input and
+# wrap the result in a `Lux.ReactantLayer`. Fails the `@argcheck` for models with state.
+# FIXME: currently only `stateless_apply` is supported: https://github.com/EnzymeAD/Reactant.jl/issues/8
+function Lux.__to_reactant_adaptor(model::AbstractExplicitLayer, input_prototype)
+    concrete_input = __make_concrete_array(input_prototype)
+    cmodel = __make_concrete_array(model)
+    # We generate fake parameters and states to compile the model
+    ps = LuxCore.initialparameters(Xoshiro(123), model)
+    cps = __make_concrete_array(ps)
+
+    st = LuxCore.initialstates(Xoshiro(123), model)
+    @argcheck st==LuxCore._getemptystate(model) "Currently only stateless models are supported."
+
+    fwd = Reactant.compile(
+        (m, x, ps) -> LuxCore.stateless_apply(m, x, ps), (cmodel, concrete_input, cps))
+
+    # TODO: conditionally compile the backward pass
+
+    return Lux.ReactantLayer(model, cmodel, fwd, nothing)
+end
+
+# Parameters of the wrapped layer, converted with `__make_concrete_array`.
+function LuxCore.initialparameters(rng::AbstractRNG, layer::Lux.ReactantLayer)
+    return __make_concrete_array(LuxCore.initialparameters(rng, layer.layer))
+end
+
+# Only stateless models are supported for now, so the state is always empty.
+# FIXME: Once https://github.com/EnzymeAD/Reactant.jl/pull/8 is resolved, return
+# `__make_concrete_array(LuxCore.initialstates(rng, layer.layer))` instead.
+function LuxCore.initialstates(::AbstractRNG, layer::Lux.ReactantLayer)
+    return NamedTuple()
+end
+
+# Call the function returned by `Reactant.compile` rather than re-dispatching through
+# `stateless_apply`; otherwise the compiled `fwd` stored in the layer is never used.
+# TODO: Add a type assert here to make it type stable
+function (l::Lux.ReactantLayer)(x, ps, ::NamedTuple{()})
+    return l.fwd(l.clayer, __make_concrete_array(x), ps), NamedTuple()
+end
+
+end
diff --git a/src/Lux.jl b/src/Lux.jl
index 92695f60e2..89cb5c21c8 100644
--- a/src/Lux.jl
+++ b/src/Lux.jl
@@ -76,6 +76,7 @@ include("helpers/nested_ad.jl")
 include("transform/types.jl")
 include("transform/flux.jl")
 include("transform/simplechains.jl")
+include("transform/reactant.jl")
 
 # Distributed Training
 include("distributed/backend.jl")
@@ -110,6 +111,7 @@ export f16, f32, f64
 export transform
 export FromFluxAdaptor, FluxLayer
 export ToSimpleChainsAdaptor, SimpleChainsLayer
+export ToReactantAdaptor
 export DynamicExpressionsLayer
 
 export MPIBackend, NCCLBackend, DistributedUtils
diff --git a/src/layers/extension.jl b/src/layers/extension.jl
index 8df169ae92..f4ea4cb1d9 100644
--- a/src/layers/extension.jl
+++ b/src/layers/extension.jl
@@ -242,3 +242,15 @@ function CRC.rrule(::typeof(__apply_simple_chain), layer, x, ps, ::LuxCPUDevice)
     end
     return res, __∇apply_simple_chain
 end
+
+# Layer wrapper built by the Reactant extension: `layer` is the original Lux layer,
+# `clayer` its Reactant-concrete counterpart, `fwd` the compiled forward pass and `bwd`
+# the compiled backward pass (currently always `nothing`).
+# TODO: Add a ChainRules rrule that calls the `bwd` function, i.e. uses Enzyme for the
+# gradient computation
+@concrete struct ReactantLayer{F, B, L <: AbstractExplicitLayer} <: AbstractExplicitLayer
+    layer::L
+    clayer
+    fwd::F
+    bwd::B
+end
diff --git a/src/transform/reactant.jl b/src/transform/reactant.jl
new file mode 100644
index 0000000000..fb21336fde
--- /dev/null
+++ b/src/transform/reactant.jl
@@ -0,0 +1,18 @@
+# Adaptor that compiles a Lux model using Reactant.jl. `input_prototype` is the example
+# input handed to the compiler.
+# TODO: Add options to compile the gradients directly using Enzyme.jl
+@concrete struct ToReactantAdaptor <: AbstractFromLuxAdaptor
+    input_prototype
+end
+
+# Convert `model` via the Reactant extension; errors if the extension is not loaded.
+function Adapt.adapt(to::ToReactantAdaptor, model::AbstractExplicitLayer)
+    if Base.get_extension(@__MODULE__, :LuxReactantExt) === nothing
+        error("`ToReactantAdaptor` requires `LuxReactantExt.jl` to be loaded. " *
+              "Run `using Reactant` first.")
+    end
+    return __to_reactant_adaptor(model, to.input_prototype)
+end
+
+# Implemented in `ext/LuxReactantExt.jl`.
+function __to_reactant_adaptor end