Skip to content

Commit

Permalink
Auto compile Lux models to reactant
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed May 27, 2024
1 parent ca2c635 commit 44da1c8
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 1 deletion.
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.5.51"
version = "0.5.52"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -40,6 +40,7 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Expand All @@ -57,6 +58,7 @@ LuxMLUtilsExt = "MLUtils"
LuxMPIExt = "MPI"
LuxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
LuxOptimisersExt = "Optimisers"
LuxReactantExt = "Reactant"
LuxReverseDiffExt = "ReverseDiff"
LuxSimpleChainsExt = "SimpleChains"
LuxTrackerExt = "Tracker"
Expand Down Expand Up @@ -102,6 +104,7 @@ Pkg = "1.10"
PrecompileTools = "1.2"
Preferences = "1.4.3"
Random = "1.10"
Reactant = "0.1.1"
ReTestItems = "1.23.1"
Reexport = "1.2.2"
ReverseDiff = "1.15"
Expand Down
48 changes: 48 additions & 0 deletions ext/LuxReactantExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
module LuxReactantExt

using ArgCheck: @argcheck
using Random: AbstractRNG, Xoshiro
using Reactant: Reactant
using Lux: Lux
using LuxCore: LuxCore, AbstractExplicitLayer

# Convert a value into Reactant's concrete-array representation.
# Fallback: traverse arbitrary (possibly nested) structures — e.g. parameter
# NamedTuples — through Reactant's tracer machinery.
@inline function __make_concrete_array(x)
    return Reactant.make_tracer(IdDict(), x, (), Reactant.ArrayToConcrete, nothing)
end
# Plain arrays are wrapped directly.
@inline function __make_concrete_array(x::AbstractArray)
    return Reactant.ConcreteRArray(x)
end
# Already-concrete arrays pass through untouched.
@inline __make_concrete_array(x::Reactant.ConcreteRArray) = x

# FIXME: currently only `stateless_apply` is supported: https://github.com/EnzymeAD/Reactant.jl/issues/8
# FIXME: currently only `stateless_apply` is supported: https://github.com/EnzymeAD/Reactant.jl/issues/8
function Lux.__to_reactant_adaptor(model::AbstractExplicitLayer, input_prototype)
    # Move the model and the prototype input into Reactant's concrete
    # representation so they can drive compilation.
    cmodel = __make_concrete_array(model)
    cinput = __make_concrete_array(input_prototype)

    # Fake parameters from a fixed-seed RNG are sufficient to compile the
    # model; the compiled function is later invoked with real parameters.
    cps = __make_concrete_array(LuxCore.initialparameters(Xoshiro(123), model))

    st = LuxCore.initialstates(Xoshiro(123), model)
    @argcheck st==LuxCore._getemptystate(model) "Currently only stateless models are supported."

    # Compile the stateless forward pass for the concrete model/input/params.
    fwd = Reactant.compile(
        (m, x, ps) -> LuxCore.stateless_apply(m, x, ps), (cmodel, cinput, cps))

    # TODO: conditionally compile the backward pass
    return Lux.ReactantLayer(model, cmodel, fwd, nothing)
end

# Initialize parameters from the wrapped layer, then move them onto
# Reactant's concrete arrays so they match the compiled forward pass.
function LuxCore.initialparameters(rng::AbstractRNG, layer::Lux.ReactantLayer)
    ps = LuxCore.initialparameters(rng, layer.layer)
    return __make_concrete_array(ps)
end

# FIXME: Change once https://github.com/EnzymeAD/Reactant.jl/pull/8 is fixed
# FIXME: Change once https://github.com/EnzymeAD/Reactant.jl/pull/8 is fixed
function LuxCore.initialstates(::AbstractRNG, layer::Lux.ReactantLayer)
    # Only stateless models are supported for now (enforced in
    # `Lux.__to_reactant_adaptor`), so the state is always empty.
    return NamedTuple()
end

# TODO: Add a type assert here to make it type stable
# TODO: Add a type assert here to make it type stable
# Forward pass: only the stateless (empty-`NamedTuple` state) form is
# supported; the input is made concrete before calling the layer.
function (l::Lux.ReactantLayer)(x, ps, ::NamedTuple{()})
    cx = __make_concrete_array(x)
    y = LuxCore.stateless_apply(l.clayer, cx, ps)
    return y, NamedTuple()
end

end
2 changes: 2 additions & 0 deletions src/Lux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ include("helpers/nested_ad.jl")
include("transform/types.jl")
include("transform/flux.jl")
include("transform/simplechains.jl")
include("transform/reactant.jl")

# Distributed Training
include("distributed/backend.jl")
Expand Down Expand Up @@ -110,6 +111,7 @@ export f16, f32, f64
export transform
export FromFluxAdaptor, FluxLayer
export ToSimpleChainsAdaptor, SimpleChainsLayer
export ToReactantAdaptor
export DynamicExpressionsLayer

export MPIBackend, NCCLBackend, DistributedUtils
Expand Down
9 changes: 9 additions & 0 deletions src/layers/extension.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,3 +242,12 @@ function CRC.rrule(::typeof(__apply_simple_chain), layer, x, ps, ::LuxCPUDevice)
end
return res, __∇apply_simple_chain
end

# TODO: Add a ChainRules rrule that calls the `bwd` function, i.e. uses Enzyme for the
# gradient computation
# Wraps a Lux layer together with its Reactant-compiled forward (and,
# eventually, backward) pass. Constructed by `Lux.__to_reactant_adaptor`.
@concrete struct ReactantLayer{F, B, L <: AbstractExplicitLayer} <: AbstractExplicitLayer
    layer::L  # the original Lux layer
    clayer    # the layer converted to Reactant's concrete representation
    fwd::F    # compiled forward pass
    bwd::B    # compiled backward pass (currently always `nothing`)
end
13 changes: 13 additions & 0 deletions src/transform/reactant.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# TODO: Add options to compile the gradients directly using Enzyme.jl
"""
    ToReactantAdaptor(input_prototype)

Adaptor that compiles a Lux model with Reactant.jl via `Adapt.adapt`.
`input_prototype` is a sample input used to drive compilation. Requires the
`LuxReactantExt` package extension to be loaded (i.e. `using Reactant`).
"""
@concrete struct ToReactantAdaptor <: AbstractFromLuxAdaptor
    input_prototype
end

# Compile `model` with Reactant using the adaptor's prototype input.
# Errors unless the `LuxReactantExt` extension (loaded via Reactant.jl)
# is available, since it provides the `__to_reactant_adaptor` implementation.
function Adapt.adapt(to::ToReactantAdaptor, model::AbstractExplicitLayer)
    ext = Base.get_extension(@__MODULE__, :LuxReactantExt)
    ext === nothing &&
        error("`ToReactantAdaptor` requires `LuxReactantExt.jl` to be loaded.")
    return __to_reactant_adaptor(model, to.input_prototype)
end

function __to_reactant_adaptor end

0 comments on commit 44da1c8

Please sign in to comment.