DifferentiationInterface testing #2469

Open
gdalle opened this issue Jul 15, 2024 · 6 comments

@gdalle

gdalle commented Jul 15, 2024

Hi there!
I'm heading towards multi-argument and non-array support in DI, and I'd like to start testing Lux layers. For this I would need two things:

  • Suggestions for a test suite of layers so that we hit some corner cases
  • Your definition of what it means to be "the right gradient" (in other words, a recursive comparison function between a given gradient and the reference output).

Do you think you could help me out?

@CarloLucibello
Member

You can take a look at the tests we added for Enzyme:

https://github.com/FluxML/Flux.jl/blob/master/test/ext_enzyme/enzyme.jl

For example, start with:

```julia
using Flux

x = rand(Float32, 2, 1)
model = Chain(Dense(2 => 3, relu), Dense(3 => 2))
g = gradient(m -> sum(m(x)), model)[1]
```

We impose few restrictions on gradients: they can be nested structs or NamedTuples.
For instance, the gradients returned by Enzyme and by Zygote are compared with

```julia
using Functors, Test

function test_grad(g1, g2; broken=false)
    fmap_with_path(g1, g2) do kp, x, y
        :state ∈ kp && return # ignore RNN and LSTM state
        if x isa AbstractArray{<:Number}
            # @show kp
            @test x ≈ y rtol=1e-2 atol=1e-6 broken=broken
        end
        return x
    end
end
```

Here fmap_with_path is defined in Functors.jl. So what we need is a gradient for each numerical array leaf of the original object, and each of these leaves should be reachable through the same "path" as in the model, e.g. g.layers[1].weight.
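The same recursive comparison can also be sketched without Functors, which makes the "same path" requirement explicit. This is a minimal stdlib-only sketch: mismatched_leaves is a hypothetical helper, not part of Flux, Functors or DI, and it assumes the reference gradient g1 is a tree of NamedTuples, Tuples and arrays (the shape Zygote returns). The other gradient is read with getfield, so it may also be a struct with matching field names. Like test_grad, it only compares numeric array leaves and skips anything under a :state key.

```julia
# Hypothetical helper (not part of Flux or Functors): walks two gradient trees in
# parallel and returns the paths of numeric-array leaves that do not match.
# g1 must be built from NamedTuples, Tuples and arrays, as Zygote returns;
# g2 is read with getfield, so it may also be a struct with the same field names.
function mismatched_leaves(g1, g2; rtol = 1e-2, atol = 1e-6, skip = (:state,),
                           path = (), out = Any[])
    if g1 isa AbstractArray{<:Number}
        isapprox(g1, g2; rtol = rtol, atol = atol) || push!(out, path)
    elseif g1 isa Union{NamedTuple,Tuple}
        for k in keys(g1)
            k in skip && continue            # e.g. ignore RNN/LSTM state
            mismatched_leaves(g1[k], getfield(g2, k); rtol, atol, skip,
                              path = (path..., k), out)
        end
    end
    # Every other leaf (nothing, activation functions, scalars) is ignored,
    # mirroring test_grad, which only checks numeric arrays.
    return out
end

# Two hand-written gradients shaped like Zygote's output for a one-layer Chain.
ga = (layers = ((weight = [1.0 2.0; 3.0 4.0], bias = [0.1, 0.2], σ = nothing),),)
gb = (layers = ((weight = [1.0 2.0; 3.0 4.0], bias = [0.1, 0.9], σ = nothing),),)
mismatched_leaves(ga, ga)   # empty: every leaf matches
mismatched_leaves(ga, gb)   # one entry, the path (:layers, 1, :bias)
```

Returning the failing paths instead of calling @test directly makes it easy to report exactly which layer disagrees; the sketch does not handle a leaf that one backend returns as nothing while the other returns an array.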

@gdalle
Author

gdalle commented Jul 19, 2024

I'm having issues when comparing the true gradients with finite differences: depending on the random seed, I get unpredictable failures. Is that a problem in the Flux test suite as well, @CarloLucibello? I couldn't find a way to pass an RNG to the network constructors; do I have to call seed! on the global RNG?
For now I have increased atol and rtol, but it's hard to know the right threshold.
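One way to reason about the threshold is to look at how accurate a finite-difference estimate can be in each precision. The sketch below uses only the Julia standard library; fd_gradient is a hypothetical helper, not the finite-difference backend used by DI or Flux. It takes the step h = cbrt(eps(T)), a common heuristic for central differences that balances O(h^2) truncation error against O(eps/h) rounding error, so Float32 estimates are inherently far noisier than Float64 ones. Note also that isapprox on arrays compares norms, norm(x - y) <= max(atol, rtol * max(norm(x), norm(y))), rather than individual entries.

```julia
using LinearAlgebra: norm

# Hypothetical helper: central finite-difference gradient of a scalar-valued f
# at an array x. The default step cbrt(eps(T)) balances the O(h^2)
# truncation error against the O(eps/h) rounding error.
function fd_gradient(f, x::AbstractArray{<:AbstractFloat}; h = cbrt(eps(eltype(x))))
    g = similar(x)
    xp = copy(x)
    for i in eachindex(x)
        xi = x[i]
        xp[i] = xi + h
        fplus = f(xp)
        xp[i] = xi - h
        fminus = f(xp)
        xp[i] = xi                       # restore the perturbed entry
        g[i] = (fplus - fminus) / (2h)
    end
    return g
end

# Float64: the estimation error here is far below rtol = 1e-6.
x64 = [1.0, -2.0, 0.5]
g64 = fd_gradient(v -> sum(abs2, v), x64)           # exact gradient is 2x
println(isapprox(g64, 2 .* x64; rtol = 1e-6))

# Float32: h is about 5e-3 and the error is orders of magnitude larger,
# so a much looser tolerance such as rtol = 1e-3 is needed.
x32 = Float32[0.1, 0.2, 0.3]
g32 = fd_gradient(v -> sum(sin, v), x32)            # exact gradient is cos.(x)
println(isapprox(g32, cos.(x32); rtol = 1f-3))
println(norm(g32 - cos.(x32)) / norm(cos.(x32)))    # relative error actually observed
```

Printing the observed relative error, as in the last line, is a cheap way to pick a tolerance with some margin instead of guessing it.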

@CarloLucibello
Member

CarloLucibello commented Jul 19, 2024

We call Random.seed!(0) in runtests.jl and don't see test failures, but I would have expected the tests to be robust to the choice of seed anyway. Can you identify the fragile ones? Maybe the ones with RNNs?
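The two ways of getting reproducible weights can be sketched with the standard library alone. Reseeding the global RNG (as runtests.jl does) works, but the draws then depend on how many random numbers earlier tests consumed; passing an explicit RNG avoids that. In Flux, Dense accepts an init keyword and the built-in initializers such as Flux.glorot_uniform accept an RNG as their first argument, so something like init = (dims...) -> Flux.glorot_uniform(rng, dims...) should thread a local RNG through a layer; treat that as an assumption to check against the Flux version in use. The glorot_like function below is a hypothetical stand-in, not Flux's implementation.

```julia
using Random

# Hypothetical initializer in the style of Glorot-uniform that takes an explicit
# RNG instead of drawing from the global one.
function glorot_like(rng::AbstractRNG, n_out::Integer, n_in::Integer)
    limit = sqrt(6f0 / (n_in + n_out))
    # rand gives values in [0, 1); rescale them to [-limit, limit).
    return (rand(rng, Float32, n_out, n_in) .* 2f0 .- 1f0) .* limit
end

# Option 1: reseed the global RNG, as in Random.seed!(0) in runtests.jl.
Random.seed!(0)
A1 = rand(Float32, 3, 2)
Random.seed!(0)
A2 = rand(Float32, 3, 2)
println(A1 == A2)

# Option 2: pass a local RNG, so the draws do not depend on how many random
# numbers earlier tests consumed from the global stream.
W1 = glorot_like(Xoshiro(0), 3, 2)
W2 = glorot_like(Xoshiro(0), 3, 2)
println(W1 == W2)
```

Xoshiro is exported by Random from Julia 1.7 onward; on older versions MersenneTwister(0) plays the same role.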

@gdalle
Author

gdalle commented Jul 19, 2024

I'll try! Which backends should I aim to test? Zygote, Enzyme and Tracker?

@CarloLucibello
Member

We don't support Tracker anymore. We primarily support Zygote, and Enzyme experimentally.

@gdalle
Author

gdalle commented Jul 19, 2024

I added a random seed in gdalle/DifferentiationInterface.jl#371; tests seem to pass for Zygote with the same tolerances as yours. I'll notify you if I see random failures further down the road.

Any idea why Enzyme fails on two scenarios only (see the PR for details)?
