Skip to content

Commit

Permalink
test: add BatchNorm to the lux test
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 27, 2024
1 parent 6516bce commit 477bf21
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
1 change: 1 addition & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Reactant
using Test
using Enzyme
using Statistics

# Reactant.set_default_backend("gpu")

Expand Down
11 changes: 7 additions & 4 deletions test/nn_lux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ truth = [xor(col[1] > 0.5, col[2] > 0.5) for col in eachcol(noisy)] # 1000-ele
# Define our model, a multi-layer perceptron with one hidden layer of size 3:
model = Lux.Chain(
Lux.Dense(2 => 3, tanh), # activation function inside layer
Lux.BatchNorm(3, gelu),
Lux.Dense(3 => 2),
softmax,
)
Expand All @@ -17,8 +18,7 @@ ps, st = Lux.setup(Xoshiro(123), model)
using BenchmarkTools

origout, _ = model(noisy, ps, st)
@show origout[3]
@btime model($noisy, $ps, $st) # 52.731 μs (10 allocations: 32.03 KiB)
@btime model($noisy, $ps, $st) # 68.444 μs (46 allocations: 45.88 KiB)

cmodel = Reactant.make_tracer(IdDict(), model, (), Reactant.ArrayToConcrete)
cps = Reactant.make_tracer(IdDict(), ps, (), Reactant.ArrayToConcrete)
Expand All @@ -31,8 +31,9 @@ f = Reactant.compile((a, b, c, d) -> first(a(b, c, d)), (cmodel, cnoisy, cps, cs
# # @show @code_typed f(cmodel,cnoisy)
# # @show @code_llvm f(cmodel,cnoisy)
comp = f(cmodel, cnoisy, cps, cst)
@show comp[3]
@btime f($cmodel, $cnoisy, $cps, $cst) # 4.430 μs (5 allocations: 160 bytes)
@btime f($cmodel, $cnoisy, $cps, $cst) # 21.790 μs (6 allocations: 224 bytes)

@test comp origout atol = 1e-5 rtol = 1e-2

# To train the model, we use batches of 64 samples, and one-hot encoding:

Expand Down Expand Up @@ -81,6 +82,8 @@ compiled_gradient = Reactant.compile(
gradient_loss_function, (cmodel, cnoisy, ctarget, cps, cst)
)

@test length(compiled_gradient(cmodel, cnoisy, ctarget, cps, cst)) == 2

# # Training loop, using the whole data set 1000 times:
# losses = []
# for epoch in 1:1_000
Expand Down

0 comments on commit 477bf21

Please sign in to comment.