Try compiling Enzyme gradient
avik-pal committed May 16, 2024
1 parent 1738019 commit 63911dd
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions test/nn_lux.jl
@@ -1,4 +1,5 @@
using Reactant, Lux, Random, Statistics
using Enzyme
using Test

# Generate some data for the XOR problem: vectors of length 2, as columns of a matrix:
@@ -44,7 +45,6 @@ losses = []

# Lux.Experimental.TrainState is very specialized for Lux models, so we write out the
# training loop manually:

function crossentropy(ŷ, y)
    logŷ = log.(ŷ)
    result = y .* logŷ
@@ -57,10 +57,18 @@ function loss_function(model, x, y, ps, st)
    return crossentropy(y_hat, y)
end

compiled_loss_function = Reactant.compile(
    loss_function, (cmodel, cnoisy, ctarget, cps, cst))
function gradient_loss_function(model, x, y, ps, st)
    dps = Enzyme.make_zero(ps)
    _, res = Enzyme.autodiff(
        ReverseWithPrimal, loss_function, Active, Const(model), Const(x), Const(y),
        Duplicated(ps, dps), Const(st))
    return res, dps
end

gradient_loss_function(model, noisy, target, ps, st)

compiled_gradient = Reactant.compile(
    gradient_loss_function, (cmodel, cnoisy, ctarget, cps, cst))

# # Training loop, using the whole data set 1000 times:
# losses = []
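A minimal usage sketch of the compiled gradient (not part of the committed file): it assumes the `losses` vector and the `c`-prefixed Reactant inputs defined earlier in the test, and the loop count is illustrative.

# Illustrative only: call the compiled gradient a few times and record the
# primal loss returned alongside the parameter gradients; a real optimiser
# update is omitted here.
for _ in 1:10
    loss, dps = compiled_gradient(cmodel, cnoisy, ctarget, cps, cst)
    push!(losses, loss)
end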
