Try using Enzyme for the backward pass
avik-pal committed May 27, 2024
1 parent 78bbc1e commit 008953d
Showing 1 changed file with 20 additions and 1 deletion.
ext/LuxReactantExt.jl: 21 changes (20 additions, 1 deletion)
@@ -29,7 +29,26 @@ function Lux.__to_reactant_adaptor(
 
     fwd = Reactant.compile((m, x) -> m(x), (csmodel, concrete_input))
 
-    return Lux.ReactantLayer{FST}(model, cmodel, fwd, nothing)
+    bwd = try
+        enzyme_grad_fn = (m, x) -> begin
+            dx = Enzyme.make_zero(x)
+            dps = Enzyme.make_zero(m.ps)
+            st = ifelse(FST, m.st, m.st_any)
+            Enzyme.autodiff(
+                Enzyme.Reverse, (m, x, ps, st) -> first(LuxCore.apply(m, x, ps, st)),
+                Enzyme.Duplicated, Enzyme.Const(m), Enzyme.Duplicated(x, dx),
+                Enzyme.Duplicated(m.ps, dps), Enzyme.Const(st))
+            return (; ps=dps), dx
+        end
+
+        Reactant.compile(enzyme_grad_fn, (csmodel, concrete_input))
+    catch err
+        @error "Enzyme failed to compile the backward pass. Differentiation will be \
+                disabled for this model." exception=err
+        nothing
+    end
+
+    return Lux.ReactantLayer{FST}(model, cmodel, fwd, bwd)
 end
 
 function LuxCore.initialparameters(rng::AbstractRNG, layer::Lux.ReactantLayer)
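For context, the forward pass above relies on Reactant's trace-and-compile workflow: a plain Julia function is traced against concrete sample arguments and lowered to a compiled executable. A minimal standalone sketch of that pattern, assuming the Reactant.compile / Reactant.ConcreteRArray API as of this commit (the toy function and variable names are illustrative, not from the commit):

    using Reactant

    square(x) = x .^ 2

    # Trace `square` against a concrete sample input and compile it;
    # the compiled function is then invoked with the same concrete types.
    x = Reactant.ConcreteRArray(rand(Float32, 4))
    compiled_square = Reactant.compile(square, (x,))
    compiled_square(x)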

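The backward pass wraps Enzyme's reverse-mode autodiff: Enzyme.make_zero allocates zero-initialized shadow buffers, Enzyme.Duplicated pairs each primal value with its shadow, and autodiff accumulates gradients into the shadows. A self-contained sketch of the same pattern on a toy scalar loss (names here are illustrative; note the commit requests a Duplicated return activity for the array-valued model output, whereas a scalar loss uses Active):

    using Enzyme

    # Toy parameterized "model" with a scalar loss.
    loss(x, ps) = sum(abs2, ps.weight * x)

    x = randn(Float32, 4)
    ps = (; weight=randn(Float32, 3, 4))

    # Zero-initialized shadow buffers; gradients accumulate into these.
    dx = Enzyme.make_zero(x)
    dps = Enzyme.make_zero(ps)

    Enzyme.autodiff(Enzyme.Reverse, loss, Enzyme.Active,
        Enzyme.Duplicated(x, dx), Enzyme.Duplicated(ps, dps))

    dx          # gradient of loss with respect to x
    dps.weight  # gradient of loss with respect to ps.weight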