
DifferentiationInterface testing #769

Open · gdalle opened this issue Jul 15, 2024 · 10 comments

gdalle (Contributor) commented Jul 15, 2024

Hi @avik-pal!
I'm heading towards multi-argument and non-array support in DI, and I'd like to start testing Lux layers. For this I would need two things:

  • Suggestions for a test suite of layers so that we hit some corner cases
  • Your definition of what it means to be "the right gradient" (in other words, a recursive comparison function between a given gradient and the reference output).
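
A minimal sketch of what such a recursive comparison might look like, assuming gradients are nested NamedTuples/Tuples whose leaves are arrays, numbers, or nothing (`check_approx` is a hypothetical name, not an existing Lux utility):

```julia
# Hypothetical recursive gradient comparison: walk two nested structures in
# lockstep and compare the numeric leaves with isapprox.
check_approx(x::AbstractArray, y::AbstractArray; kwargs...) = isapprox(x, y; kwargs...)
check_approx(x::Number, y::Number; kwargs...) = isapprox(x, y; kwargs...)
check_approx(::Nothing, ::Nothing; kwargs...) = true
function check_approx(x::NamedTuple{K}, y::NamedTuple{K}; kwargs...) where {K}
    return all(check_approx(getproperty(x, k), getproperty(y, k); kwargs...) for k in K)
end
function check_approx(x::Tuple, y::Tuple; kwargs...)
    return length(x) == length(y) &&
           all(check_approx(xi, yi; kwargs...) for (xi, yi) in zip(x, y))
end
check_approx(x, y; kwargs...) = x == y  # fallback for other leaf types

check_approx((a=[1.0],), (a=[1.0 + 1e-12],); rtol=1e-6)  # true
```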

Do you think you could help me out?

gdalle changed the title from "DI testing" to "DifferentiationInterface testing" on Jul 15, 2024
avik-pal (Member) commented
gdalle (Contributor) commented Jul 19, 2024

Thanks! In the file you linked, the source of truth seems to be Zygote? What would you use to validate the Zygote gradients themselves?

avik-pal (Member) commented

Currently I compute with Zygote and then test against other backends, depending on the device:

  1. For CPU
    1. Tracker
    2. ReverseDiff
    3. ForwardDiff (if the array sizes are < 100)
    4. FiniteDifferences
    5. Enzyme (currently only tested in that file, but testing is being expanded here)
  2. For GPU
    1. Tracker
    2. ForwardDiff in certain situations depending on the problem

Tracker and Zygote hit very different code paths in LuxLib (Zygote is the optimized path, often with handwritten rules). In case of a conflict or mismatch, the general assumption is that Tracker (on GPU) or FiniteDifferences (on CPU) is the source of truth.
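
A hedged sketch of this cross-checking scheme on CPU, using DifferentiationInterface on a toy loss over a flat parameter vector (not Lux's actual test code; the loss, sizes, and tolerance are made up, and the `AutoX` backend types are assumed to be the ADTypes ones re-exported by DifferentiationInterface):

```julia
using DifferentiationInterface
import Zygote, Tracker, ForwardDiff, FiniteDifferences

f(p) = sum(abs2, p)  # stand-in for a real loss closed over (model, x, st)
p = randn(10)

g_ref = gradient(f, AutoZygote(), p)  # Zygote gradient, to be cross-checked
backends = (
    AutoTracker(),
    AutoForwardDiff(),
    AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(5, 1)),
)
for backend in backends
    g = gradient(f, backend, p)
    # a mismatch here would cast suspicion on the rule-based (Zygote) path
    @assert isapprox(g, g_ref; rtol=1e-3)
end
```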

avik-pal (Member) commented

I have been meaning to try out FiniteDiff to validate the GPU gradients but haven't had the time to set it up.

gdalle (Contributor) commented Jul 30, 2024

Hi @avik-pal,

gdalle/DifferentiationInterface.jl#372 has a first version of Lux tests with ComponentArrays.jl encoding. A few points remain unclear to me:

  • When you say the ground truth is finite differences, do you mean FiniteDifferences.jl or FiniteDiff.jl? I only found the latter in LuxTestUtils.jl
  • How do you handle the flattening and unflattening of ps for finite differences? The only code I found is this one; should I copy it inside DITest?

https://github.com/LuxDL/LuxTestUtils.jl/blob/100b2de3214a82dfdfc9643895fd66fc05ec0ddf/src/autodiff.jl#L72-L86

  • Because of DI's single-argument limitation, at the moment I close over (model, x, st) for the loss. Should I deepcopy(st) between two calls, in order to avoid state evolution?
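
For context, here is a rough sketch of that closure plus a flatten/unflatten round trip, assuming Lux's `model(x, ps, st)` call convention and a ComponentArrays encoding (the `Dense` layer and the loss are made-up examples, not code from this thread):

```julia
using Lux, ComponentArrays, Random
import FiniteDifferences

rng = Random.Xoshiro(0)
model = Dense(3 => 2)
ps, st = Lux.setup(rng, model)

ps_ca = ComponentArray(ps)  # flatten: NamedTuple -> ComponentArray
ax = getaxes(ps_ca)         # keep the axes to unflatten later

x = randn(rng, Float32, 3, 4)
# Close over (model, x, st); Lux returns the new state instead of mutating st.
loss(p) = sum(abs2, first(model(x, p, st)))

# Unflatten inside a wrapper so finite differencing only sees a plain Vector:
loss_vec(v) = loss(ComponentArray(v, ax))
fdm = FiniteDifferences.central_fdm(5, 1)
g_flat = FiniteDifferences.grad(fdm, loss_vec, Vector(ps_ca))[1]
g = ComponentArray(g_flat, ax)  # named view of the flat gradient
```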

avik-pal (Member) commented

When you say the ground truth is finite differences, do you mean FiniteDifferences.jl or FiniteDiff.jl? I only found the latter in LuxTestUtils.jl

I meant FiniteDifferences.jl originally. But you are looking at the new release where I migrated to FiniteDiff 😅

How do you handle the flattening and unflattening of ps for finite differences? The only code I found is this one, should I copy it inside DITest?

That + https://github.com/LuxDL/LuxTestUtils.jl/blob/100b2de3214a82dfdfc9643895fd66fc05ec0ddf/src/utils.jl#L53-L62

Because of DI's single-argument limitation, at the moment I close over (model, x, st) for the loss. Should I deepcopy(st) between two calls, in order to avoid state evolution?

No, states cannot be mutated (it is a bug in Lux if that happens for any of the layers). One particular case to be careful about is TaskLocalRNG: we do print a warning, but it is easy to miss. Since we cannot copy a TaskLocalRNG, it is impossible to guarantee the same results across multiple calls. The recommended solution is to use Xoshiro in regular code and StableRNG in test code.
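
A minimal illustration of that RNG advice, assuming StableRNGs.jl in test code (the layer is a made-up example):

```julia
using Lux, StableRNGs

rng = StableRNG(12345)          # copyable and stable across Julia versions
model = Dense(2 => 3)           # example layer
ps, st = Lux.setup(rng, model)  # avoid Random.TaskLocalRNG() here
```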

gdalle (Contributor) commented Jul 30, 2024

How do you choose FiniteDiff parameters like the epsilon? Keep the package defaults? I'm looking at test failures in gdalle/DifferentiationInterface.jl#372, which I think are due to numerical errors in finite differencing, but I don't know if I should use even higher atol and rtol (I'm using FiniteDifferences.jl at the time of writing).

Are there layers that contain an rng? I don't see you handling this specifically in the Enzyme tests you pointed me to.

avik-pal (Member) commented Jul 30, 2024

How do you choose FiniteDiff parameters like the epsilon? Keep the package defaults?

Yes, keep the defaults. Normalization layers are tricky to test with finite differencing, especially for the reason you cited. In those cases, I rely on comparing Zygote with one of the other AD backends. For example, Tracker hits generic code paths while Zygote hits optimized code paths with custom rrules, so the assumption is that Tracker, without custom rules, gets the gradient correct. Alternatively, for smaller systems, comparing against ForwardDiff is also an option.
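
A hedged sketch of the ForwardDiff cross-check for small systems (the loss and sizes are illustrative, not Lux's actual tests):

```julia
import ForwardDiff, Zygote

loss(p) = sum(abs2, p) / length(p)  # stand-in for a small loss
p = randn(50)                       # small enough for forward mode

g_zygote = only(Zygote.gradient(loss, p))
g_forward = ForwardDiff.gradient(loss, p)
@assert isapprox(g_zygote, g_forward; rtol=1e-6)
```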

Are there layers that contain an rng? I don't see you handling this specifically in the Enzyme tests you pointed me to.

The rng passed to Lux.setup(rng, model) is cached with some ugly tricks to make reproducibility work. Make sure this rng is not TaskLocalRNG and you are good.

gdalle (Contributor) commented Jul 31, 2024

Have you ever encountered this error with Tracker + ComponentArrays? The tests pass with Zygote, so I'm trying to add more backends:
https://github.com/gdalle/DifferentiationInterface.jl/actions/runs/10176602495/job/28146280540?pr=393

avik-pal (Member) commented

Yes, you should call Tracker.param on the ComponentArray directly, instead of going NamedTuple → Tracker.param → ComponentArray.
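
A sketch of that ordering fix (the parameter names are illustrative):

```julia
using ComponentArrays
import Tracker

ps = (weight=randn(2, 3), bias=randn(2))  # example NamedTuple of parameters

# Correct: build the ComponentArray first, then track it as a single array.
ps_tracked = Tracker.param(ComponentArray(ps))

# Problematic: tracking the leaves first (NamedTuple -> Tracker.param ->
# ComponentArray) buries TrackedArrays inside the ComponentArray and breaks
# Tracker's bookkeeping in the backward pass.
```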
