Zygote hangs when taking explicit gradients of NaiveGAFlux model #1243

Closed
ToucheSir opened this issue Jun 14, 2022 · 7 comments · Fixed by #1248
Labels
bug Something isn't working


@ToucheSir
Member

Continuing from FluxML/Flux.jl#1986 (comment). @DrChainsaw, are you able to capture a profile or at least a stacktrace mid-hang? I think that would be the easiest way to start troubleshooting this and to put together an MWE.

@DrChainsaw

Unfortunately I can't seem to terminate the program gracefully enough to get a stacktrace.

Here is a WE (note the absence of the M), where I manually built the model layer by layer until the gradient calculation stalled. I haven't let this one run for 8 hours, though, so maybe it is not enough.

Packages:

]add NaiveNASflux#942bc90, [email protected], NaiveNASlib, ChainRulesCore, Flux, Functors

That NaiveNASflux commit is from this PR, where I removed the Zygote adjoint, which had also accidentally masked the problem through its use of @nograd.

rrule definitions:

using NaiveNASflux, NaiveNASlib.Extend, Flux
import Functors

import ChainRulesCore
import ChainRulesCore: RuleConfig, HasReverseMode, rrule, rrule_via_ad, NoTangent

# Need to pirate NaiveNASlib's forward pass because it is not Zygote-compatible (it uses get!)
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(NaiveNASlib.output!), memo, v)
    rrule_via_ad(config, output_rrule!, memo, v)
end

# This is just for logging and so we can return NoTangent instead of the computed gradient 
function output_rrule!(args...) end
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(output_rrule!), memo, v)
    res, back = rrule_via_ad(config, _output_rrule!, memo, v)
    @info "Forward $(name(v))"
    return res, function (d)
        @info "Backward $(name(v))"
        back(d)
        # Uncomment the line below to prevent the stall
        #return NoTangent(), NoTangent(), NoTangent()
    end
end

# This is the actual Zygote-compatible forward pass
function _output_rrule!(memo, v::AbstractVertex)
    v in keys(memo) && return memo[v]
    inpt = map(iv -> output_rrule!(memo, iv),  inputs(v))
    memo[v] = v(inpt...)
end

Logging of the forward and backward passes is not needed to trigger the stall; just remove it if it bothers you.
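For readers unfamiliar with the pattern, the memoized traversal in _output_rrule! can be illustrated without NaiveNASlib or Zygote. This is a plain-Julia sketch only: ToyVertex, toy_output! and CALLS are hypothetical stand-ins, not NaiveNASlib's actual types or output!. As _output_rrule! requires, the memo is assumed to be pre-seeded with the value of the input vertex (which has no inputs of its own). The fork shows that a shared vertex is evaluated only once:

```julia
# Hypothetical stand-in for a NaiveNASlib vertex: a function applied to the
# outputs of its input vertices (not NaiveNASlib's real vertex types)
struct ToyVertex
    name::String
    f::Function
    inputs::Vector{ToyVertex}
end

# Records how many times each vertex's function is actually evaluated
const CALLS = Dict{String,Int}()

# Same shape as _output_rrule!: check the memo explicitly, compute inputs
# recursively, then store the result with setindex! instead of using get!
function toy_output!(memo::IdDict{ToyVertex,Any}, v::ToyVertex)
    haskey(memo, v) && return memo[v]
    inpt = map(iv -> toy_output!(memo, iv), v.inputs)
    CALLS[v.name] = get(CALLS, v.name, 0) + 1
    memo[v] = v.f(inpt...)
end

# A small fork: v1 feeds both va and vb, which are merged again in v2
iv = ToyVertex("in", identity, ToyVertex[])
v1 = ToyVertex("v1", x -> 2x, [iv])
va = ToyVertex("va", x -> x + 1, [v1])
vb = ToyVertex("vb", x -> x - 1, [v1])
v2 = ToyVertex("v2", +, [va, vb])

# Seed the memo with the graph input, since the input vertex has no inputs
memo = IdDict{ToyVertex,Any}()
memo[iv] = 3.0
out = toy_output!(memo, v2)   # (2*3 + 1) + (2*3 - 1)
```

Here v1 is computed once even though both branches of the fork need it; the second request is served from the memo.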

Model definition:

function makemodel(;layerfun=identity)
    iv = conv2dinputvertex("in", 3)
    v1 = convvertex("v1", (5,5), iv, 8; layerfun)
    v2 = bnvertex("v2", v1, selu; layerfun)
    v3 = fluxvertex("v3", MeanPool((2,2)), v2; layerfun)

    # Fork with 3 paths a, b and c
    v4a1 = bnvertex("v4a1", v3, relu; layerfun)
    v4a2 = convvertex("v4a2", (1, 7), v4a1, 8, selu; layerfun)

    v4b1 = convvertex("v4b1", (3,3), v3, 8, relu; layerfun)
    v4b2 = bnvertex("v4b2", v4b1; layerfun)
    v4b3 = convvertex("v4b3", (7,3), v4b2, 256; layerfun)

    v4c1 = convvertex("v4c1", (3,3), v3, 8, relu; layerfun)
    v4c2 = bnvertex("v4c2", v4c1; layerfun)
    v4c3 = convvertex("v4c3", (1,7), v4c2, 16, relu; layerfun)
    v4c4 = bnvertex("v4c4", v4c3, relu; layerfun)

    v5 = concat("v5", v4a2, v4b3, v4c4; layerfun)
    v6 = fluxvertex("v6", MaxPool((2,2)), v5; layerfun)
    v7 = convvertex("v7", (3,3), v6, 32; layerfun)
    v8 = bnvertex("v8", v7, selu; layerfun)
    v9 = convvertex("v9", (5,3), v8, 512, selu; layerfun)
    v10 = bnvertex("v10", v9; layerfun)
    v11 = fluxvertex("v11", Conv((2,2), nout(v10) => 512, relu; stride=2), v10; layerfun)

    CompGraph(iv, v11)
end

function convvertex(name, ks, in, outsize, act=identity; layerfun) 
    fluxvertex(name, Conv(ks, nout(in) => outsize, act; pad=SamePad()), in; layerfun)
end

bnvertex(name, in, act=identity; layerfun) = fluxvertex(name, BatchNorm(nout(in), act), in; layerfun)

I made some attempts at removing parts of the model to make it simpler, but nothing exhaustive. Here is one example of how to generate an overview of the model if needed:

[name.(vertices(model)) layer.(vertices(model)) map(v -> name.(inputs(v)), vertices(model))]

Some utilities for experiments:

# Triggers issue #1111 
mutable struct MutableWrapper{T}
    wrapped::T
end
(m::MutableWrapper)(x...) = m.wrapped(x...)
NaiveNASflux.layertype(g::MutableWrapper) = NaiveNASflux.layertype(g.wrapped)
NaiveNASlib.nout(g::MutableWrapper) = nout(g.wrapped)

# Removes output vertices, making the CompGraph structure non-cyclic
stripoutputs(g::CompGraph) = Functors.fmap(identity, g; walk=stripoutputs)
stripoutputs(f, x) = Functors._default_walk(f, x)
stripoutputs(f, v::InputVertex) = Functors._default_walk(f, v)
stripoutputs(f, v::AbstractVertex) = stripoutputs(f, base(v))
stripoutputs(f, v::CompVertex) = Functors._default_walk(f, v)

Phew, here is the experiment code. Logging output is omitted for brevity; Forward and Backward are always printed for each vertex.

# This terminates despite BatchNorm having NoTangent. Gradients are given for all layers except BatchNorms
gradient(m -> sum(m(ones(Float32,32,32, 3,1))), makemodel());  

# This also terminates, despite the MutableWrapper erasing all gradients
gradient(m -> sum(m(ones(Float32,32,32, 3,1))), makemodel(layerfun=MutableWrapper) |> stripoutputs);

# But this stalls (unless you uncomment the return statement in the rrule definition)
gradient(m -> sum(m(ones(Float32,32,32, 3,1))), makemodel(layerfun=MutableWrapper));

It could be that the outputs are a red herring, and that the middle example only works because the structure is simpler overall. However, this works:

gradient(m -> sum(m(ones(Float32,32,32, 3,1))), makemodel(layerfun=reduce(∘, Iterators.repeated(MutableWrapper,10))) |> stripoutputs);

Note the result is the same with implicit gradients for all examples above.
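For reference, here is a small plain-Julia check of what the layerfun in the last experiment builds: reduce(∘, ...) composes the MutableWrapper constructor with itself ten times, so every layer ends up inside ten nested wrappers. MutableWrapper is copied from the utilities above; wrapdepth is a hypothetical helper, and a plain function stands in for a Flux layer:

```julia
# MutableWrapper as defined in the utilities above
mutable struct MutableWrapper{T}
    wrapped::T
end
(m::MutableWrapper)(x...) = m.wrapped(x...)

# reduce(∘, ...) composes the constructor with itself ten times, so
# layerfun(l) == MutableWrapper(MutableWrapper(...MutableWrapper(l)...))
layerfun = reduce(∘, Iterators.repeated(MutableWrapper, 10))

# Hypothetical helper: count the wrapper layers around the innermost object
wrapdepth(x) = 0
wrapdepth(m::MutableWrapper) = 1 + wrapdepth(m.wrapped)

# A plain function stands in for a Flux layer
wrapped = layerfun(x -> 2x)
```

Calls are forwarded through every wrapper, so wrapped(3) still evaluates the inner function.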

@mcabbott mcabbott added the bug Something isn't working label Jul 26, 2022
@mcabbott
Member

mcabbott commented Jul 30, 2022

I think this mutable example, FluxML/Flux.jl#1986 (comment), can be simplified to the following, which has been broken since at least Zygote v0.6.0:

julia> gradient(x -> x[], Ref(1.0))
(Base.RefValue{Any}((x = 1.0,)),)  # v0.6.0
((x = 1.0,),)  # v0.6.41

julia> gradient(x -> x[1][], (Ref(1.0),))
(nothing,)

That nothing comes from lines like these:

accum_param(__context__, val, Δ) === nothing && return

accum_param(__context__, val, Δ) === nothing && return

The accum_param call is necessary for implicit mode, but removing the branches that return nothing early (which are taken for mutable structs) seems to fix these examples.

I am not sure why it was ever there. I worry a little about introducing double-counting between the updated Ref version and the returned version. But in the examples I can invent, one seems to matter for implicit mode, the other for explicit.

@ToucheSir
Member Author

ToucheSir commented Jul 31, 2022

Nice bisect. With #1248:

julia> gradient(x -> x[1][], (Ref(1.0),))
(((x = 1.0,),),)

Edit: #1243 (comment) appears to terminate locally as well with the aforementioned PR. @DrChainsaw do you mind checking on your side?

@mcabbott
Member

mcabbott commented Jul 31, 2022

I see. That PR sounds like a safer way to get the same effect, as this line will (I think) mean it never returns nothing in explicit mode:

https://github.com/FluxML/Zygote.jl/pull/1248/files#diff-cd0210083ce3136f79bee6ebca2bcca77f41a14f11b5a7a65ea1cc54803164c3R51

Perhaps tests from my attempt might be worth borrowing: mcabbott@927ee27

@DrChainsaw

DrChainsaw commented Jul 31, 2022

@ToucheSir Will do, hopefully later tonight when I get some time off.

Bit of an unrelated shower thought: Zygote ranks pretty high on my "I give up" codebases and I'm sure many others feel the same (I think you guys are heroic for making the effort to maintain it). Anyway, my concrete proposal is to add a strong request for more comments to the contributor guidelines, something along the lines of:

Zygote is a very complex project maintained by the Julia community, as its original creator no longer maintains it. Please help us make it more maintainable by adding comments whenever you have figured out the purpose of some undocumented part. Even if you are not certain, an "I think the purpose of this code is..." comment can often help immensely.

For example, I suspect that you two have some insights into the purpose of accum and accum_param that are far from obvious to an outsider. Writing some of those insights down as comments at the call sites will probably be quite helpful for future maintainers as well as users.

@ToucheSir
Member Author

ToucheSir commented Jul 31, 2022

Bit of an unrelated shower thought: Zygote ranks pretty high on my "I give up" codebases and I'm sure many others feel the same

I think the big problem is that the current group of maintainers also falls into this bucket. Speaking for myself, I don't want to even think about AD, let alone touch it. Unfortunately, I have to because it ends up causing issues further up the stack.

For Zygote specifically, it's basically on life support unless we can get someone who understands the compiler well enough and is also motivated to revive it. Presently, bugs like #1236 just pile up and we're kind of helpless to do much about them.

That's why, though I try to add some comments while creating new PRs, I am hesitant to go on a docs spree across the codebase. When there appears to be zero appetite from the broader community for helping with Zygote maintenance, every non-bugfix feels like another sunk cost. Perhaps others have a different perspective, but from my POV we've somehow arrived at an XKCD #2347-type scenario, where the de facto reverse-mode AD is a sinking ship and there are no ready alternatives to pick up the slack.

Stepping off my soapbox, internal documentation is now tracked in #1274. If anyone wants to take a shot, I am more than happy to prioritize reviewing PRs for this. Otherwise, at least there's a list now.

@DrChainsaw

I think the big problem is that the current group of maintainers also falls into this bucket.

This is my view as well. I was perhaps a bit too careful in my wording, to the point that the message became unclear.

I am hesitant to go on a docs spree across the codebase.

Fully understandable. I was trying to propose a somewhat milder and more distributed approach: whenever someone spends more than five minutes figuring out some undocumented part of Zygote, instead of just fixing the issue and moving on, they add some comments describing their understanding of it. Over time, the codebase might then become a bit more approachable.

About the actual issue:

Strangely enough, the WE no longer seems to hang, despite using the exact same project and manifest. There is, however, the following, which may hint at what the hang is about:

# "Hanging" example with all gradients being nothing
julia> gg = gradient(m -> sum(m(ones(Float32,32,32, 3,1))), makemodel(layerfun=MutableWrapper));

julia> @time @show gg[1];
# Takes about 373 seconds, then prints about 20 lines of output
373.533915 seconds (180.54 k allocations: 11.777 MiB, 99.99% compilation time)

# Non-hanging example where we get gradients, run in a fresh session
julia> ggg = gradient(m -> sum(m(ones(Float32,32,32, 3,1))), makemodel() |> stripoutputs);

julia> @time @show ggg[1];
# Starts printing right away, but spends about 37 seconds printing numbers
 37.013652 seconds (14.55 M allocations: 578.515 MiB, 0.34% gc time, 0.35% compilation time)

Note that I started from one of the innermost gradients (gg[1][2][1][1][1][2][1][1][1][2][1][1][1][2][1][1][1][2][1][1][1][2][1][1][1][2][1][1][1][2][1][1][1][2][1][1][1][2][1][1][1][2][1][1][1][2][1][1][1][2][1]) and removed one index at a time, since a previous attempt to print gg seemed to stall indefinitely (though maybe I was just impatient). This was not done for the second gradient (ggg).
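The "start from the innermost gradient and drop one index at a time" procedure can be scripted with foldl. This is a sketch under stated assumptions: a small hypothetical nested tuple stands in for gg, and follow_path is a hypothetical helper, not part of Zygote:

```julia
# Hypothetical helper: follow an index path through a nested structure,
# so that follow_path(x, [1, 2, 1]) == x[1][2][1]
follow_path(x, path) = foldl(getindex, path; init = x)

# A small nested tuple/NamedTuple standing in for the gradient gg
nested = ((a = 1, b = (2, 3)), 4)
path = [1, 2, 2]

# Start from the innermost element and drop one index at a time,
# ending with the full structure (empty path)
for k in length(path):-1:0
    println(path[1:k], " => ", follow_path(nested, path[1:k]))
end
```

With an empty path, foldl simply returns init, so the last iteration prints the whole structure; that is the step that would correspond to printing gg itself.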

Perhaps it is just the horribly nested NamedTuple, which for some reason sees exponential compile time when it is full of nothings but not when it is full of Arrays, due to some @nospecialize somewhere?

I might have started working on the example while printing gradients, and then added output suppression without carefully checking that this didn't change the outcome.

Anyway, with #1248 I get exactly the same behaviour for both gradients, so as far as I can tell it solves the issue.


3 participants