From 63911dd01169abcdbf15cfc69d2f88d162ecacf0 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Thu, 16 May 2024 18:33:02 -0400
Subject: [PATCH] Try compiling Enzyme gradient

---
 test/nn_lux.jl | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/test/nn_lux.jl b/test/nn_lux.jl
index af0605758..0b52f0994 100644
--- a/test/nn_lux.jl
+++ b/test/nn_lux.jl
@@ -1,4 +1,5 @@
 using Reactant, Lux, Random, Statistics
+using Enzyme
 using Test
 
 # Generate some data for the XOR problem: vectors of length 2, as columns of a matrix:
@@ -44,7 +45,6 @@ losses = []
 
 # Lux.Exprimental.TrainState is very specialized for Lux models, so we write out the
 # training loop manually:
-
 function crossentropy(ŷ, y)
     logŷ = log.(ŷ)
     result = y .* logŷ
@@ -57,10 +57,18 @@ function loss_function(model, x, y, ps, st)
     return crossentropy(y_hat, y)
 end
 
-compiled_loss_function = Reactant.compile(
-    loss_function, (cmodel, cnoisy, ctarget, cps, cst))
+function gradient_loss_function(model, x, y, ps, st)
+    dps = Enzyme.make_zero(ps)
+    _, res = Enzyme.autodiff(
+        ReverseWithPrimal, loss_function, Active, Const(model), Const(x), Const(y),
+        Duplicated(ps, dps), Const(st))
+    return res, dps
+end
+gradient_loss_function(model, noisy, target, ps, st)
+compiled_gradient = Reactant.compile(
+    gradient_loss_function, (cmodel, cnoisy, ctarget, cps, cst))
 
 # # Training loop, using the whole data set 1000 times:
 # losses = []
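
Not part of the patch: the hunk above ends at a commented-out training loop, so here is a
minimal sketch of how that loop might consume `compiled_gradient` once it compiles. It
assumes a plain SGD update via `Functors.fmap` over matching parameter/gradient trees and
an assumed learning rate `eta`; Reactant's compiled arrays may need a different update
path, so treat this as an illustration, not the author's method. Note that
`gradient_loss_function` returns `(loss, dps)`, since `ReverseWithPrimal` discards the
derivative tuple into `_` and keeps the primal in `res`.

    using Functors: fmap  # assumption: Functors' multi-tree fmap for the update

    losses = []
    eta = 0.01  # assumed learning rate, not taken from the patch
    for epoch in 1:1000
        # run the Reactant-compiled forward + Enzyme reverse pass
        loss, dps = compiled_gradient(cmodel, cnoisy, ctarget, cps, cst)
        # descend every parameter leaf along its Enzyme gradient
        global cps = fmap((p, g) -> p .- eta .* g, cps, dps)
        push!(losses, loss)
    end

Keeping the compiled callable's argument order identical to `gradient_loss_function`'s
signature is what lets the loop swap between the eager and compiled versions when debugging.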