Skip to content

Commit

Permalink
Implement a working VJP function
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed May 30, 2024
1 parent 2a390ab commit a5a0190
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 21 deletions.
44 changes: 26 additions & 18 deletions ext/LuxReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,28 +67,38 @@ end
function Lux.__to_reactant_adaptor(
to::Lux.ToReactantAdaptor{FST}, model::AbstractExplicitLayer,
input_prototype, ps, st, eltype_adaptor) where {FST}
output = first(model(input_prototype, ps, st))
concrete_output = __make_concrete_array(output)

concrete_input = __make_concrete_array(input_prototype)
cmodel = __make_concrete_array(model)
cps = __make_concrete_array(ps)
cst = __make_concrete_array(st)

csmodel = Lux.StatefulLuxLayer{FST}(cmodel, cps, cst)

fwd = Reactant.compile((m, x) -> m(x), (csmodel, concrete_input))

bwd = try
enzyme_grad_fn = (m, x) -> begin
dx = Enzyme.make_zero(x)
dps = Enzyme.make_zero(m.ps)
st = ifelse(FST, m.st, m.st_any)
Enzyme.autodiff(
Enzyme.Reverse, (m, x, ps, st) -> first(LuxCore.apply(m, x, ps, st)),
Enzyme.Duplicated, Enzyme.Const(m.model), Enzyme.Duplicated(x, dx),
Enzyme.Duplicated(m.ps, dps), Enzyme.Const(st))
return (; ps=dps), dx
fwd_fn = Reactant.compile((m, x) -> m(x), (csmodel, concrete_input))

function enzyme_vjp_fn(m, x, y, dy)
dx = Enzyme.make_zero(x)
dps = Enzyme.make_zero(m.ps)
st_m = ifelse(FST, m.st, m.st_any)

function wrapper_fn!(y, model, x, ps, st)
copyto!(y, first(LuxCore.apply(model, x, ps, st)))
return nothing
end

Reactant.compile(enzyme_grad_fn, (csmodel, concrete_input))
Enzyme.autodiff(Enzyme.Reverse, wrapper_fn!, Enzyme.Const, Enzyme.Duplicated(y, dy),
Enzyme.Const(m.model), Enzyme.Duplicated(x, dx),
Enzyme.Duplicated(m.ps, dps), Enzyme.Const(st_m))
return dx, dps
end

vjp_fn = try
concrete_output2 = __make_concrete_array(deepcopy(output))
Reactant.compile(
enzyme_vjp_fn, (csmodel, concrete_input, concrete_output, concrete_output2))
catch err
to.force_compile_backward && rethrow(err)
@error """
Expand All @@ -101,11 +111,9 @@ function Lux.__to_reactant_adaptor(
nothing
end

# TODO: Add compiled types to the layer type information. That way we can check
# if the model is being executed with the correct types.
return Lux.ReactantLayer{FST, Lux.__recursive_eltype(input_prototype)}(
to, input_prototype, concrete_input, cps, cst, model, cmodel, fwd,
bwd, eltype_adaptor, fmapstructure(Lux.__size, input_prototype))
to, input_prototype, concrete_input, cps, cst, model, cmodel, fwd_fn,
vjp_fn, eltype_adaptor, fmapstructure(Lux.__size, input_prototype))
end

# TODO: Currently we are maintaining 2 copies of the parameters, this is not ideal.
Expand Down Expand Up @@ -183,7 +191,7 @@ end

Lux.__apply_reactant(l, x, ps, st) = __graceful_type_mismatch_error(l, x, ps, st)

@inline Lux.__apply_reactant(l::Lux.ReactantLayer, csmodel, x) = l.fwd(csmodel, x)
@inline Lux.__apply_reactant(l::Lux.ReactantLayer, csmodel, x) = l.fwd_fn(csmodel, x)

# Don't inline, else types don't get displayed in the stack trace
function __graceful_type_mismatch_error(
Expand Down
7 changes: 4 additions & 3 deletions src/layers/extension.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,8 @@ end

# TODO: Add a ChainRules rrule that calls the `bwd` function, i.e. uses Enzyme for the
# gradient computation
@concrete struct ReactantLayer{FST, T, inType, inCType, psType, stType, F, B,
# TODO: Inference won't work OOTB, we will have to compile that separately
@concrete struct ReactantLayer{FST, T, inType, inCType, psType, stType,
L <: AbstractExplicitLayer, AD <: ToReactantAdaptor} <: AbstractExplicitLayer
adaptor::AD
input_prototype::inType
Expand All @@ -260,8 +261,8 @@ end
concrete_st::stType
layer::L
clayer
fwd::F
bwd::B
fwd_fn
vjp_fn
eltype_adaptor
input_structure
end
Expand Down

0 comments on commit a5a0190

Please sign in to comment.