From 89676522df8df1e03a4265acff148272e0540a52 Mon Sep 17 00:00:00 2001
From: CarloLucibello
Date: Thu, 5 Dec 2024 06:53:54 +0100
Subject: [PATCH] attempts at unbreaking enzyme tests

---
 test/ext_enzyme/enzyme.jl | 139 ++++++++++++++++++++++----------------
 test/test_utils.jl        |  18 ++++-
 2 files changed, 98 insertions(+), 59 deletions(-)

diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl
index 5d91770363..0208503ee3 100644
--- a/test/ext_enzyme/enzyme.jl
+++ b/test/ext_enzyme/enzyme.jl
@@ -1,6 +1,7 @@
 using Test
 using Flux
 import Zygote
+using Statistics, Random
 using Enzyme: Enzyme, make_zero, Active, Duplicated, Const, ReverseWithPrimal
 
@@ -8,61 +9,87 @@ using Functors
 using FiniteDifferences
 
-function gradient_fd(f, x...)
-    f = f |> f64
-    x = [cpu(x) for x in x]
-    ps_and_res = [x isa AbstractArray ? (x, identity) : Flux.destructure(x) for x in x]
-    ps = [f64(x[1]) for x in ps_and_res]
-    res = [x[2] for x in ps_and_res]
-    fdm = FiniteDifferences.central_fdm(5, 1)
-    gs = FiniteDifferences.grad(fdm, (ps...) -> f((re(p) for (p,re) in zip(ps, res))...), ps...)
-    return ((re(g) for (re, g) in zip(res, gs))...,)
-end
+function test_gradients_ez(
+            f,
+            xs...;
+            rtol=1e-4, atol=1e-4,
+            test_gpu = false,
+            test_grad_f = true,
+            test_grad_x = true,
+            compare_finite_diff = true,
+            loss = (f, xs...) -> mean(f(xs...)),
+        )
+
+    if !test_gpu && !compare_finite_diff
+        error("You should either compare finite diff vs CPU AD \
+               or CPU AD vs GPU AD.")
+    end
 
-function gradient_ez(f, x...)
-    args = []
-    for x in x
-        if x isa Number
-            push!(args, Active(x))
-        else
-            push!(args, Duplicated(x, make_zero(x)))
-        end
+    ## First, make sure the forward pass works.
+    l = loss(f, xs...)
+    @test l isa Number
+    if test_gpu
+        gpu_dev = gpu_device(force=true)
+        cpu_dev = cpu_device()
+        xs_gpu = xs |> gpu_dev
+        f_gpu = f |> gpu_dev
+        l_gpu = loss(f_gpu, xs_gpu...)
+        @test l_gpu isa Number
     end
-    ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
-    g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
-    return g
-end
 
-function test_grad(g1, g2; broken=false)
-    fmap_with_path(g1, g2) do kp, x, y
-        :state ∈ kp && return # ignore RNN and LSTM state
-        if x isa AbstractArray{<:Number}
-            # @show kp
-            @test x ≈ y rtol=1e-2 atol=1e-6 broken=broken
+    if test_grad_x
+        # Enzyme gradient with respect to the input.
+        y, g = ez_withgradient((xs...) -> loss(f, xs...), xs...)
+
+        if compare_finite_diff
+            # Cast to Float64 to avoid precision issues.
+            f64 = f |> Flux.f64
+            xs64 = xs .|> Flux.f64
+            y_fd, g_fd = finitediff_withgradient((xs...) -> loss(f64, xs...), xs64...)
+            @test y ≈ y_fd rtol=rtol atol=atol
+            check_equal_leaves(g, g_fd; rtol, atol)
         end
-        return x
-    end
-end
 
-function test_enzyme_grad(loss, model, x)
-    Flux.trainmode!(model)
-    l = loss(model, x)
-    @test loss(model, x) == l # Check loss doesn't change with multiple runs
+        if test_gpu
+            # Enzyme gradient with respect to the input on GPU.
+            y_gpu, g_gpu = ez_withgradient((xs...) -> loss(f_gpu, xs...), xs_gpu...)
+            @test get_device(g_gpu) == get_device(xs_gpu)
+            @test y_gpu ≈ y rtol=rtol atol=atol
+            check_equal_leaves(g_gpu |> cpu_dev, g; rtol, atol)
+        end
+    end
 
-    grads_fd = gradient_fd(loss, model, x) |> cpu
-    grads_flux = Flux.gradient(loss, model, x) |> cpu
-    grads_enzyme = gradient_ez(loss, model, x) |> cpu
+    if test_grad_f
+        # Enzyme gradient with respect to f.
+        y, g = ez_withgradient(f -> loss(f, xs...), f)
+
+        if compare_finite_diff
+            # Cast to Float64 to avoid precision issues.
+            f64 = f |> Flux.f64
+            ps, re = Flux.destructure(f64)
+            y_fd, g_fd = finitediff_withgradient(ps -> loss(re(ps), xs...), ps)
+            g_fd = (re(g_fd[1]),)
+            @test y ≈ y_fd rtol=rtol atol=atol
+            check_equal_leaves(g, g_fd; rtol, atol)
+        end
 
-    # test_grad(grads_flux, grads_enzyme)
-    test_grad(grads_fd, grads_enzyme)
+        if test_gpu
+            # Enzyme gradient with respect to f on GPU.
+            y_gpu, g_gpu = ez_withgradient(f -> loss(f, xs_gpu...), f_gpu)
+            # @test get_device(g_gpu) == get_device(xs_gpu)
+            @test y_gpu ≈ y rtol=rtol atol=atol
+            check_equal_leaves(g_gpu |> cpu_dev, g; rtol, atol)
+        end
+    end
+    return true
 end
 
-@testset "gradient_ez" begin
+@testset "test_utils.jl: ez_withgradient" begin
     @testset "number and arrays" begin
         f(x, y) = sum(x.^2) + y^3
         x = Float32[1, 2, 3]
         y = 3f0
-        g = gradient_ez(f, x, y)
+        out, g = ez_withgradient(f, x, y)
         @test g[1] isa Array{Float32}
         @test g[2] isa Float32
         @test g[1] ≈ 2x
@@ -82,7 +109,7 @@ end
 
         x = randn(Float32, 2)
         loss(model, x) = sum(model(x))
-        g = gradient_ez(loss, model, x)
+        out, g = ez_withgradient(loss, model, x)
         @test g[1] isa SimpleDense
         @test g[2] isa Array{Float32}
         @test g[1].weight isa Array{Float32}
@@ -93,10 +120,6 @@ end
 
 @testset "Models" begin
-    function loss(model, x)
-        sum(model(x))
-    end
-
     models_xs = [
         (Dense(2=>4), randn(Float32, 2), "Dense"),
         (Chain(Dense(2=>4, tanh), Dense(4=>3)), randn(Float32, 2), "Chain(Dense, Dense)"),
@@ -117,7 +140,7 @@ end
     for (model, x, name) in models_xs
         @testset "Enzyme grad check $name" begin
             println("testing $name with Enzyme")
-            test_enzyme_grad(loss, model, x)
+            test_gradients_ez(model, x)
         end
     end
 end
@@ -127,12 +150,12 @@ end
         for i in 1:3
             x = model(x)
         end
-        return sum(x)
+        return mean(x)
     end
 
-    struct LSTMChain
-        rnn1
-        rnn2
+    struct LSTMChain{RNN1, RNN2}
+        rnn1::RNN1
+        rnn2::RNN2
     end
     function (m::LSTMChain)(x)
         st = m.rnn1(x)
@@ -141,17 +164,17 @@ end
     end
 
     models_xs = [
-        # (RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
-        # (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
-        # (GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
-        # (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
-        # (LSTMChain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "LSTMChain(LSTM, LSTM)"),
+        (RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
+        (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
+        (GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
+        (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
+        (LSTMChain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "LSTMChain(LSTM, LSTM)"),
     ]
 
     for (model, x, name) in models_xs
        @testset "check grad $name" begin
            println("testing $name")
-           test_enzyme_grad(loss, model, x)
+           test_gradients_ez(model, x; loss)
        end
     end
 end

diff --git a/test/test_utils.jl b/test/test_utils.jl
index c736943f1c..788547cfc5 100644
--- a/test/test_utils.jl
+++ b/test/test_utils.jl
@@ -11,13 +11,27 @@ const ALL_LOSSES = [Flux.Losses.mse, Flux.Losses.mae, Flux.Losses.msle,
                     Flux.Losses.binary_focal_loss, Flux.Losses.focal_loss, Flux.Losses.siamese_contrastive_loss]
 
-
+### Finite Differences ####
 function finitediff_withgradient(f, x...)
     y = f(x...)
     # We set a max range to avoid domain errors.
     fdm = FiniteDifferences.central_fdm(5, 1, max_range=1e-2)
     return y, FiniteDifferences.grad(fdm, f, x...)
 end
+###########################
+
+### Enzyme #####
+duplicated(x) = Duplicated(x, make_zero(x))
+duplicated(x::Number) = Active(x)
+
+function ez_withgradient(f, xs...)
+    dups = [duplicated(x) for x in xs]
+    grads, y = Enzyme.autodiff(ReverseWithPrimal, Const(f), Active, dups...)
+    # Gradients of numbers are returned in grads; array gradients accumulate in dups.
+    gs = [g isa Number ? g : d.dval for (g, d) in zip(grads, dups)]
+    return y, gs
+end
+################
 
 function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4)
     fmapstructure_with_path(a, b) do kp, x, y
@@ -103,3 +117,5 @@ function test_gradients(
     end
     return true
 end
+
+
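The new ez_withgradient helper is a thin wrapper around Enzyme.autodiff: arrays are wrapped in Duplicated with a zeroed shadow from make_zero, scalars in Active, and ReverseWithPrimal returns the primal value alongside the gradients. A minimal standalone sketch of that mechanism; the function f and the inputs here are illustrative, not part of the patch:

using Enzyme: Enzyme, make_zero, Active, Duplicated, Const, ReverseWithPrimal

f(x, y) = sum(x .^ 2) + y^3
x = Float32[1, 2, 3]
y = 3f0

dx = make_zero(x)  # zeroed shadow buffer that will receive the array gradient
grads, primal = Enzyme.autodiff(ReverseWithPrimal, Const(f), Active,
                                Duplicated(x, dx), Active(y))

# Duplicated arguments report `nothing` in `grads` (their gradient lands in the
# shadow buffer), while Active scalars are returned directly.
@assert grads[1] === nothing && dx ≈ 2 .* x
@assert grads[2] ≈ 3 * y^2
@assert primal ≈ f(x, y)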
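In the test_grad_f branch, the finite-difference reference gradient goes through Flux.destructure, which flattens the model into a parameter vector and returns a rebuilder; the same rebuilder reshapes the flat gradient back into model structure, which is what g_fd = (re(g_fd[1]),) does in the patch. A sketch of that round trip, with an illustrative model and loss:

using Flux, FiniteDifferences
using Statistics: mean

model = Dense(2 => 3) |> Flux.f64   # Float64 keeps finite differences accurate
x = randn(2)

ps, re = Flux.destructure(model)    # flat parameter vector + rebuilder
fdm = FiniteDifferences.central_fdm(5, 1, max_range=1e-2)

# Differentiate the loss through the flat vector, then rebuild the gradient
# with the same nested structure as the model.
g_flat = FiniteDifferences.grad(fdm, p -> mean(re(p)(x)), ps)[1]
g_model = re(g_flat)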
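Putting it together, test_gradients_ez checks the forward pass, then compares Enzyme gradients against finite differences, or against GPU AD when test_gpu=true. A hedged usage sketch, assuming test/test_utils.jl has been included so the helpers are in scope:

using Flux
using Statistics: mean

model = Chain(Dense(2 => 4, tanh), Dense(4 => 3))
x = randn(Float32, 2)

# Default: Enzyme CPU gradients vs finite differences,
# with loss(f, xs...) = mean(f(xs...)).
test_gradients_ez(model, x)

# A custom loss, matching the `loss` keyword's (f, xs...) signature:
test_gradients_ez(model, x; loss = (f, x) -> sum(abs2, f(x)))

# Or compare CPU AD against GPU AD instead of finite differences:
# test_gradients_ez(model, x; test_gpu=true, compare_finite_diff=false)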