Skip to content

Commit

Permalink
fix: use enzyme correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 21, 2024
1 parent 9b7f121 commit 879a599
Showing 1 changed file with 30 additions and 26 deletions.
56 changes: 30 additions & 26 deletions ext/LuxReactantExt/training.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
mutable struct StatsAndNewStateWrapper
stats::Any
st::Any
end

function wrapped_objective_function(fn::F, model, ps, st, data, cache) where {F}
loss, stₙ, stats = fn(model, ps, st, data)
cache.stats = stats
cache.st = stₙ
return loss
end

function Lux.Training.compute_gradients_impl(
backend::ReactantBackend, objective_function::F,
data, ts::Training.TrainState) where {F}
Expand All @@ -23,11 +35,14 @@ function Lux.Training.compute_gradients_impl(::ReactantBackend, obj_fn::F, data,
end

function compute_gradients_internal(objective_function::F, model, data, ps, st) where {F}
dps = Enzyme.make_zero(ps)
_, (loss, stₙ, stats) = Enzyme.autodiff(
Enzyme.ReverseWithPrimal, Const(objective_function), Active, Const(model),
Duplicated(ps, dps), Const(st), Const(data))
return dps, loss, stats, stₙ
stats_wrapper = StatsAndNewStateWrapper(nothing, nothing)
res = Enzyme.gradient(
Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI),
Const(wrapped_objective_function), Const(objective_function),
Const(model), ps, Const(st), Const(data), stats_wrapper
)
loss, dps = res.val, res.derivs[3]
return dps, loss, stats_wrapper.stats, stats_wrapper.st
end

for inplace in ("!", "")
Expand Down Expand Up @@ -70,27 +85,16 @@ for inplace in ("!", "")

return grads, loss, stats, ts
end
end

function compute_gradients_internal_and_step(objective_function::F, model, data, ps,
st, opt_state) where {F}
dps = Enzyme.make_zero(ps)
_, (loss, stₙ, stats) = Enzyme.autodiff(
Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI),
Const(objective_function), Active, Const(model),
Duplicated(ps, dps), Const(st), Const(data))
opt_state, ps = Optimisers.update(opt_state, ps, dps)
return dps, ps, loss, stats, stₙ, opt_state
end
# XXX: Inplace version not actually inplace
internal_fn = Symbol(:compute_gradients_internal_and_step, inplace)
update_fn = Symbol(:update, inplace)

function compute_gradients_internal_and_step!(objective_function::F, model, data, ps,
st, opt_state) where {F}
dps = Enzyme.make_zero(ps)
_, (loss, stₙ, stats) = Enzyme.autodiff(
Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI),
Const(objective_function), Active, Const(model),
Duplicated(ps, dps), Const(st), Const(data))
# XXX: Inplace updates not actually inplace
opt_state, ps = Optimisers.update!(opt_state, ps, dps)
return dps, ps, loss, stats, stₙ, opt_state
@eval function $(internal_fn)(
objective_function::F, model, data, ps, st, opt_state) where {F}
dps, loss, stats, stₙ = compute_gradients_internal(
objective_function, model, data, ps, st)
opt_state, ps = Optimisers.$(update_fn)(opt_state, ps, dps)
return dps, ps, loss, stats, stₙ, opt_state
end
end

0 comments on commit 879a599

Please sign in to comment.