From 07811ade08baa8ed44c1456ed9e4de6a44caca65 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 22 Dec 2024 00:03:34 +0530 Subject: [PATCH] fix: use enzyme correctly --- ext/LuxReactantExt/training.jl | 56 ++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/ext/LuxReactantExt/training.jl b/ext/LuxReactantExt/training.jl index a37cd54a5..1a60bcb7e 100644 --- a/ext/LuxReactantExt/training.jl +++ b/ext/LuxReactantExt/training.jl @@ -1,3 +1,15 @@ +mutable struct StatsAndNewStateWrapper + stats::Any + st::Any +end + +function wrapped_objective_function(fn::F, model, ps, st, data, cache) where {F} + loss, stₙ, stats = fn(model, ps, st, data) + cache.stats = stats + cache.st = stₙ # XXX: Reactant bug here + return loss +end + function Lux.Training.compute_gradients_impl( backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState) where {F} @@ -23,11 +35,14 @@ function Lux.Training.compute_gradients_impl(::ReactantBackend, obj_fn::F, data, end function compute_gradients_internal(objective_function::F, model, data, ps, st) where {F} - dps = Enzyme.make_zero(ps) - _, (loss, stₙ, stats) = Enzyme.autodiff( - Enzyme.ReverseWithPrimal, Const(objective_function), Active, Const(model), - Duplicated(ps, dps), Const(st), Const(data)) - return dps, loss, stats, stₙ + stats_wrapper = StatsAndNewStateWrapper(nothing, nothing) + res = Enzyme.gradient( + Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI), + Const(wrapped_objective_function), Const(objective_function), + Const(model), ps, Const(st), Const(data), stats_wrapper + ) + loss, dps = res.val, res.derivs[3] + return dps, loss, stats_wrapper.stats, stats_wrapper.st end for inplace in ("!", "") @@ -70,27 +85,16 @@ for inplace in ("!", "") return grads, loss, stats, ts end -end -function compute_gradients_internal_and_step(objective_function::F, model, data, ps, - st, opt_state) where {F} - dps = Enzyme.make_zero(ps) - _, (loss, stₙ, stats) = Enzyme.autodiff( - Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI), - Const(objective_function), Active, Const(model), - Duplicated(ps, dps), Const(st), Const(data)) - opt_state, ps = Optimisers.update(opt_state, ps, dps) - return dps, ps, loss, stats, stₙ, opt_state -end + # XXX: Inplace version not actually inplace + internal_fn = Symbol(:compute_gradients_internal_and_step, inplace) + update_fn = Symbol(:update, inplace) -function compute_gradients_internal_and_step!(objective_function::F, model, data, ps, - st, opt_state) where {F} - dps = Enzyme.make_zero(ps) - _, (loss, stₙ, stats) = Enzyme.autodiff( - Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI), - Const(objective_function), Active, Const(model), - Duplicated(ps, dps), Const(st), Const(data)) - # XXX: Inplace updates not actually inplace - opt_state, ps = Optimisers.update!(opt_state, ps, dps) - return dps, ps, loss, stats, stₙ, opt_state + @eval function $(internal_fn)( + objective_function::F, model, data, ps, st, opt_state) where {F} + dps, loss, stats, stₙ = compute_gradients_internal( + objective_function, model, data, ps, st) + opt_state, ps = Optimisers.$(update_fn)(opt_state, ps, dps) + return dps, ps, loss, stats, stₙ, opt_state + end end