-
-
Notifications
You must be signed in to change notification settings - Fork 610
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Patch Flux._isleaf
for abstract arrays with bitstype elements
#2436
Conversation
Thanks for the PR! I'm a little late responding to the main issue so I'll do so here. It looks like there are still GPU-related tests failing and it's not immediately clear why, so more tweaking may be required. Stepping back a bit, you were right to point out in #2432 that the call stack/responsibilities of various libraries here is not super clear. Let me lay out what I think a more ideal design could be, and let's see how much of that can inform this PR. One possibility to bridge the disparate views of Flux, Functors and Adapt (which wasn't mentioned in #2432 but is a very important third player here) is to move some logic to the latter. My proposal would be to share the cache |
I think the only failure is Line 112 in c442f0c
which fails because previously
All of this sounds fantastic, and of course much more robust than this PR, which is really just meant as a stopgap to patch adjoint/transpose/etc. until a more complete fix is implemented. |
…CuArray}` instead of plain `CuArray`
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #2436 +/- ##
===========================================
+ Coverage 46.37% 74.10% +27.72%
===========================================
Files 32 32
Lines 1876 1923 +47
===========================================
+ Hits 870 1425 +555
+ Misses 1006 498 -508 ☔ View full report in Codecov by Sentry. |
Fixes #2432:
On master we have
Flux._isleaf(x) = _isbitsarray(x) || Functors.isleaf(x)
, but_isbitsarray(x)
returnstrue
for anyAbstractArray{T}
whereisbitstype(T) == true
, and so we get_isbitsarray([1.0]') == true
and thereforeFlux._isleaf([1.0]') == true
. In the referenced issue, this breaks parameter sharing between a set of weights and their transpose when a model is moved to the gpu.The fundamental issue is that AFAICT there is not a good way to extend
Functors.isleaf
outside of Functors.jl for abstract types which may contain children of the same abstract type. For example, hereTranspose <: AbstractArray
contains aparent::AbstractArray
field, and so Functors.jl must overloadFunctor.functor(::Transpose)
otherwise it would not be recursed into. But of course this can't be done similarly outside of Functors.jl without type piracy (henceFlux._isleaf
).So in order to:
Functors.functor(::AbstractArray)
methods defined here into Flux, andAbstractArray
s with bitstype elements as leaves,I've removed
_isbitsarray
in favour of defining_isleaf
methods directly, and special-casedFlux._isleaf(::Union{Transpose, Adjoint, PermutedDimsArray}) = false
to match Functors.jl. The other option is to do this in Functors.jl directly, but I'm not sure if treating all bitstype arrays as leaves is desirable in general (also it's probably breaking?).