attempts at unbreaking enzyme tests #2540

Draft
wants to merge 1 commit into base: master
Changes from all commits
139 changes: 81 additions & 58 deletions test/ext_enzyme/enzyme.jl
@@ -1,68 +1,95 @@
using Test
using Flux
import Zygote
using Statistics, Random

using Enzyme: Enzyme, make_zero, Active, Duplicated, Const, ReverseWithPrimal

using Functors
using FiniteDifferences

-function gradient_fd(f, x...)
-    f = f |> f64
-    x = [cpu(x) for x in x]
-    ps_and_res = [x isa AbstractArray ? (x, identity) : Flux.destructure(x) for x in x]
-    ps = [f64(x[1]) for x in ps_and_res]
-    res = [x[2] for x in ps_and_res]
-    fdm = FiniteDifferences.central_fdm(5, 1)
-    gs = FiniteDifferences.grad(fdm, (ps...) -> f((re(p) for (p, re) in zip(ps, res))...), ps...)
-    return ((re(g) for (re, g) in zip(res, gs))...,)
-end
-
-function gradient_ez(f, x...)
-    args = []
-    for x in x
-        if x isa Number
-            push!(args, Active(x))
-        else
-            push!(args, Duplicated(x, make_zero(x)))
-        end
-    end
-    ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
-    g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
-    return g
-end
-
-function test_grad(g1, g2; broken=false)
-    fmap_with_path(g1, g2) do kp, x, y
-        :state ∈ kp && return # ignore RNN and LSTM state
-        if x isa AbstractArray{<:Number}
-            # @show kp
-            @test x ≈ y rtol=1e-2 atol=1e-6 broken=broken
-        end
-        return x
-    end
-end
-
-function test_enzyme_grad(loss, model, x)
-    Flux.trainmode!(model)
-    l = loss(model, x)
-    @test loss(model, x) == l # Check loss doesn't change with multiple runs
-
-    grads_fd = gradient_fd(loss, model, x) |> cpu
-    grads_flux = Flux.gradient(loss, model, x) |> cpu
-    grads_enzyme = gradient_ez(loss, model, x) |> cpu
-
-    # test_grad(grads_flux, grads_enzyme)
-    test_grad(grads_fd, grads_enzyme)
-end
+function test_gradients_ez(
+        f,
+        xs...;
+        rtol=1e-4, atol=1e-4,
+        test_gpu = false,
+        test_grad_f = true,
+        test_grad_x = true,
+        compare_finite_diff = true,
+        loss = (f, xs...) -> mean(f(xs...)),
+    )
+
+    if !test_gpu && !compare_finite_diff
+        error("You should either compare finite diff vs CPU AD \
+               or CPU AD vs GPU AD.")
+    end
+
+    ## Let's make sure first that the forward pass works.
+    l = loss(f, xs...)
+    @test l isa Number
+    if test_gpu
+        gpu_dev = gpu_device(force=true)
+        cpu_dev = cpu_device()
+        xs_gpu = xs |> gpu_dev
+        f_gpu = f |> gpu_dev
+        l_gpu = loss(f_gpu, xs_gpu...)
+        @test l_gpu isa Number
+    end
+
+    if test_grad_x
+        # Enzyme gradient with respect to the input.
+        y, g = ez_withgradient((xs...) -> loss(f, xs...), xs...)
+
+        if compare_finite_diff
+            # Cast to Float64 to avoid precision issues.
+            f64 = f |> Flux.f64
+            xs64 = xs .|> Flux.f64
+            y_fd, g_fd = finitediff_withgradient((xs...) -> loss(f64, xs...), xs64...)
+            @test y ≈ y_fd rtol=rtol atol=atol
+            check_equal_leaves(g, g_fd; rtol, atol)
+        end
+
+        if test_gpu
+            # Enzyme gradient with respect to the input on GPU.
+            y_gpu, g_gpu = ez_withgradient((xs...) -> loss(f_gpu, xs...), xs_gpu...)
+            @test get_device(g_gpu) == get_device(xs_gpu)
+            @test y_gpu ≈ y rtol=rtol atol=atol
+            check_equal_leaves(g_gpu |> cpu_dev, g; rtol, atol)
+        end
+    end
+
+    if test_grad_f
+        # Enzyme gradient with respect to f.
+        y, g = ez_withgradient(f -> loss(f, xs...), f)
+
+        if compare_finite_diff
+            # Cast to Float64 to avoid precision issues.
+            f64 = f |> Flux.f64
+            ps, re = Flux.destructure(f64)
+            y_fd, g_fd = finitediff_withgradient(ps -> loss(re(ps), xs...), ps)
+            g_fd = (re(g_fd[1]),)
+            @test y ≈ y_fd rtol=rtol atol=atol
+            check_equal_leaves(g, g_fd; rtol, atol)
+        end
+
+        if test_gpu
+            # Enzyme gradient with respect to f on GPU.
+            y_gpu, g_gpu = ez_withgradient(f -> loss(f, xs_gpu...), f_gpu)
+            # @test get_device(g_gpu) == get_device(xs_gpu)
+            @test y_gpu ≈ y rtol=rtol atol=atol
+            check_equal_leaves(g_gpu |> cpu_dev, g; rtol, atol)
+        end
+    end
+    return true
+end

-@testset "gradient_ez" begin
+@testset "test_utils.jl: ez_withgradient" begin
    @testset "number and arrays" begin
        f(x, y) = sum(x.^2) + y^3
        x = Float32[1, 2, 3]
        y = 3f0
-        g = gradient_ez(f, x, y)
+        out, g = ez_withgradient(f, x, y)
        @test g[1] isa Array{Float32}
        @test g[2] isa Float32
        @test g[1] ≈ 2x

@@ -82,7 +109,7 @@ end
        x = randn(Float32, 2)
        loss(model, x) = sum(model(x))

-        g = gradient_ez(loss, model, x)
+        out, g = ez_withgradient(loss, model, x)
        @test g[1] isa SimpleDense
        @test g[2] isa Array{Float32}
        @test g[1].weight isa Array{Float32}

@@ -93,10 +120,6 @@ end
    end

@testset "Models" begin
-    function loss(model, x)
-        sum(model(x))
-    end
-
    models_xs = [
        (Dense(2=>4), randn(Float32, 2), "Dense"),
        (Chain(Dense(2=>4, tanh), Dense(4=>3)), randn(Float32, 2), "Chain(Dense, Dense)"),

@@ -117,7 +140,7 @@ end
    for (model, x, name) in models_xs
        @testset "Enzyme grad check $name" begin
            println("testing $name with Enzyme")
-            test_enzyme_grad(loss, model, x)
+            test_gradients_ez(model, x)
        end
    end
end

@@ -127,12 +150,12 @@ end
    for i in 1:3
        x = model(x)
    end
-    return sum(x)
+    return mean(x)
end

-struct LSTMChain
-    rnn1
-    rnn2
+struct LSTMChain{RNN1, RNN2}
+    rnn1::RNN1
+    rnn2::RNN2
end
function (m::LSTMChain)(x)
    st = m.rnn1(x)

@@ -141,17 +164,17 @@ end
end

    models_xs = [
-        # (RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
-        # (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
-        # (GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
-        # (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
-        # (LSTMChain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "LSTMChain(LSTM, LSTM)"),
+        (RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
+        (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
+        (GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
+        (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
+        (LSTMChain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "LSTMChain(LSTM, LSTM)"),
    ]

    for (model, x, name) in models_xs
        @testset "check grad $name" begin
            println("testing $name")
-            test_enzyme_grad(loss, model, x)
+            test_gradients_ez(model, x; loss)
        end
    end
end
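Note (not part of the diff): the core check these utilities perform can be reproduced standalone. The sketch below uses an illustrative toy function, data, and tolerances; it computes an Enzyme reverse-mode gradient the same way ez_withgradient does (the array argument wrapped in Duplicated with a zero shadow) and compares it against a FiniteDifferences reference.

using Enzyme: Enzyme, make_zero, Active, Duplicated, Const, ReverseWithPrimal
using FiniteDifferences

f(x) = sum(abs2, x)          # toy scalar loss of an array input
x = Float32[1, 2, 3]

# Enzyme reverse mode: the gradient is accumulated into the shadow array dx.
dx = make_zero(x)
Enzyme.autodiff(ReverseWithPrimal, Const(f), Active, Duplicated(x, dx))

# Finite-difference reference (5th-order central differences), in Float64 for accuracy.
fdm = FiniteDifferences.central_fdm(5, 1)
g_fd, = FiniteDifferences.grad(fdm, f, Float64.(x))

@assert isapprox(dx, g_fd; rtol=1e-4, atol=1e-4)   # both should approximate 2x
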
18 changes: 17 additions & 1 deletion test/test_utils.jl
@@ -11,13 +11,27 @@ const ALL_LOSSES = [Flux.Losses.mse, Flux.Losses.mae, Flux.Losses.msle,
                    Flux.Losses.binary_focal_loss, Flux.Losses.focal_loss,
                    Flux.Losses.siamese_contrastive_loss]


### Finite Differences ####
function finitediff_withgradient(f, x...)
    y = f(x...)
    # We set a range to avoid domain errors
    fdm = FiniteDifferences.central_fdm(5, 1, max_range=1e-2)
    return y, FiniteDifferences.grad(fdm, f, x...)
end
###########################

+### Enzyme #####
+duplicated(x) = Duplicated(x, make_zero(x))
+duplicated(x::Number) = Active(x)
+
+function ez_withgradient(f, xs...)
+    dups = [duplicated(x) for x in xs]
+    grads, y = Enzyme.autodiff(ReverseWithPrimal, Const(f), Active, dups...)
+    # Numbers' gradients are returned in `grads`; arrays' gradients are accumulated into the `dups` shadows.
+    gs = [g isa Number ? g : d.dval for (g, d) in zip(grads, dups)]
+    return y, gs
+end
+################

function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4)
    fmapstructure_with_path(a, b) do kp, x, y

@@ -103,3 +117,5 @@ function test_gradients(
    end
    return true
end
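Note (not part of the diff): the test_grad_f branch in test_gradients_ez relies on Flux.destructure so that finite differences can be taken over a flat parameter vector and the resulting gradient can be restored to the model's structure. A minimal standalone sketch of that pattern follows; the model, input, and loss here are illustrative assumptions.

using Flux, FiniteDifferences
using Statistics: mean

model = Dense(2 => 3) |> Flux.f64      # work in Float64 to keep finite differences accurate
x = randn(Float64, 2)

ps, re = Flux.destructure(model)       # flat parameter vector and a re-builder
loss(p) = mean(re(p)(x))               # rebuild the model from parameters inside the loss

fdm = FiniteDifferences.central_fdm(5, 1, max_range=1e-2)   # bounded step, as in finitediff_withgradient
y = loss(ps)
g_fd, = FiniteDifferences.grad(fdm, loss, ps)

# re(g_fd) has the same structure as `model`, ready for leaf-wise comparison
# against an Enzyme gradient via check_equal_leaves.
g_structured = re(g_fd)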