-
Notifications
You must be signed in to change notification settings - Fork 62
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
76 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
name = "Lux" | ||
uuid = "b2108857-7c20-44ae-9111-449ecde12c47" | ||
authors = ["Avik Pal <[email protected]> and contributors"] | ||
version = "0.5.51" | ||
version = "0.5.52" | ||
|
||
[deps] | ||
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" | ||
|
@@ -40,6 +40,7 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" | |
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" | ||
NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b" | ||
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" | ||
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" | ||
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" | ||
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5" | ||
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" | ||
|
@@ -57,6 +58,7 @@ LuxMLUtilsExt = "MLUtils" | |
LuxMPIExt = "MPI" | ||
LuxMPINCCLExt = ["CUDA", "MPI", "NCCL"] | ||
LuxOptimisersExt = "Optimisers" | ||
LuxReactantExt = "Reactant" | ||
LuxReverseDiffExt = "ReverseDiff" | ||
LuxSimpleChainsExt = "SimpleChains" | ||
LuxTrackerExt = "Tracker" | ||
|
@@ -102,6 +104,7 @@ Pkg = "1.10" | |
PrecompileTools = "1.2" | ||
Preferences = "1.4.3" | ||
Random = "1.10" | ||
Reactant = "0.1.1" | ||
ReTestItems = "1.23.1" | ||
Reexport = "1.2.2" | ||
ReverseDiff = "1.15" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
module LuxReactantExt | ||
|
||
using ArgCheck: @argcheck | ||
using Random: AbstractRNG, Xoshiro | ||
using Reactant: Reactant | ||
using Lux: Lux | ||
using LuxCore: LuxCore, AbstractExplicitLayer | ||
|
||
@inline __make_concrete_array(x::Reactant.ConcreteRArray) = x | ||
@inline __make_concrete_array(x::AbstractArray) = Reactant.ConcreteRArray(x) | ||
@inline function __make_concrete_array(x) | ||
return Reactant.make_tracer(IdDict(), x, (), Reactant.ArrayToConcrete, nothing) | ||
end | ||
|
||
# FIXME: currently only `stateless_apply` is supported: https://github.com/EnzymeAD/Reactant.jl/issues/8 | ||
function Lux.__to_reactant_adaptor(model::AbstractExplicitLayer, input_prototype) | ||
concrete_input = __make_concrete_array(input_prototype) | ||
cmodel = __make_concrete_array(model) | ||
# We generate fake parameters and states to compile the model | ||
ps = LuxCore.initialparameters(Xoshiro(123), model) | ||
cps = __make_concrete_array(ps) | ||
|
||
st = LuxCore.initialstates(Xoshiro(123), model) | ||
@argcheck st==LuxCore._getemptystate(model) "Currently only stateless models are supported." | ||
|
||
fwd = Reactant.compile( | ||
(m, x, ps) -> LuxCore.stateless_apply(m, x, ps), (cmodel, concrete_input, cps)) | ||
|
||
# TODO: conditionally compile the backward pass | ||
|
||
return Lux.ReactantLayer(model, cmodel, fwd, nothing) | ||
end | ||
|
||
function LuxCore.initialparameters(rng::AbstractRNG, layer::Lux.ReactantLayer) | ||
return __make_concrete_array(LuxCore.initialparameters(rng, layer.layer)) | ||
end | ||
|
||
# FIXME: Change once https://github.com/EnzymeAD/Reactant.jl/pull/8 is fixed | ||
function LuxCore.initialstates(::AbstractRNG, layer::Lux.ReactantLayer) | ||
return NamedTuple() # __make_concrete_array(LuxCore.initialstates(rng, layer.layer)) | ||
end | ||
|
||
# TODO: Add a type assert here to make it type stable | ||
function (l::Lux.ReactantLayer)(x, ps, ::NamedTuple{()}) | ||
return LuxCore.stateless_apply(l.clayer, __make_concrete_array(x), ps), NamedTuple() | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# TODO: Add options to compile the gradients directly using Enzyme.jl | ||
@concrete struct ToReactantAdaptor <: AbstractFromLuxAdaptor | ||
input_prototype | ||
end | ||
|
||
function Adapt.adapt(to::ToReactantAdaptor, model::AbstractExplicitLayer) | ||
if Base.get_extension(@__MODULE__, :LuxReactantExt) === nothing | ||
error("`ToReactantAdaptor` requires `LuxReactantExt.jl` to be loaded.") | ||
end | ||
return __to_reactant_adaptor(model, to.input_prototype) | ||
end | ||
|
||
function __to_reactant_adaptor end |