
Improved type stability with explicit params #1248

Merged (3 commits, Aug 1, 2022)
Conversation

@ToucheSir (Member) commented Jun 23, 2022

In explicit mode, we can stop accumulating (implicit) parameters into the gradient cache. This can dramatically improve type stability, because otherwise `accum_param` returns a `Union{Nothing, [grad type]}`.
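To illustrate why that `Union` return type matters, here is a minimal, self-contained sketch (illustrative only, not Zygote's actual code) of how a `Union{Nothing, T}` return forces the compiler to track both branches, while a concrete return type stays fully inferred:

```julia
# Illustrative only: a function whose return type is Union{Nothing, Int}
# forces every caller to handle both possibilities.
f_unstable(x::Int) = x > 0 ? x : nothing
f_stable(x::Int)   = x > 0 ? x : 0

# Inference reports a Union for the first, a concrete type for the second:
Base.return_types(f_unstable, (Int,))  # [Union{Nothing, Int64}]
Base.return_types(f_stable, (Int,))    # [Int64]
```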

One impact of this PR is that taking gradients of functions with both implicit and explicit parameters (i.e. calling pullback twice) may involve some additional compilation. However, given that we're trying to move users off of using implicit params anyhow, I see it as a small price to pay for being friendlier to the compiler.
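For readers unfamiliar with the two modes, a small sketch of the difference (assumes Zygote is loaded; variable names are illustrative):

```julia
using Zygote

W = randn(2, 3); b = randn(2); x = randn(3)

# Implicit style (being phased out): parameters are tracked via Params,
# and gradients are looked up in a cache keyed by the arrays themselves.
gs_implicit = gradient(() -> sum(W * x .+ b), Params([W, b]))
gs_implicit[W]  # gradient w.r.t. W

# Explicit style (preferred): parameters are ordinary arguments,
# and gradients come back positionally.
gs_explicit = gradient((W, b) -> sum(W * x .+ b), W, b)
gs_explicit[1]  # gradient w.r.t. W
```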

Benchmarking TTFG (time to first gradient) on the MWE in #1126, modified to use explicit params:

julia> @time loss_grad(model, lr_images)
 85.027100 seconds (61.36 M allocations: 3.216 GiB, 1.17% gc time, 99.98% compilation time) # 0.6.40
 59.024238 seconds (60.72 M allocations: 3.174 GiB, 1.69% gc time, 99.98% compilation time) # this PR

Closes #1243.

@ToucheSir force-pushed the bc/no-cache-context branch 3 times, most recently from 6ccf008 to 7540bd6 on June 24, 2022 02:53
@ToucheSir (Member, Author) commented Jun 24, 2022

OK, we're now back to just the known failures on Nightly and downstream NeuralPDE. The Molly failure appears to be intermittent (a minor numerical error) and showed up recently in PRs from a few days ago, but it should be investigated separately.

One unexpected find while working on this PR is that Zygote and downstream packages were calling pullback with the wrong argument order. This worked by sheer coincidence, but I've added a warning so that it can be rectified before we disallow it in the next breaking release.
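As a hypothetical sketch (not Zygote's actual implementation), the kind of guard described above might look like this: warn once when the first argument to a pullback-like entry point is not callable but a later one is, which catches the swapped argument order:

```julia
# Hypothetical sketch only — `checked_pullback` is not a real Zygote function.
function checked_pullback(f, args...)
    # Swapped-order heuristic: first argument is not a Function, but a later one is.
    if !(f isa Function) && any(a -> a isa Function, args)
        @warn "`pullback(f, args...)` expects the function as the first argument; " *
              "the swapped order will be disallowed in the next breaking release." maxlog = 1
    end
    # ... the real implementation would build and return the pullback here
    return nothing
end

checked_pullback(sin, 1.0)  # correct order: no warning
checked_pullback(1.0, sin)  # swapped order: warns once
```

Note that `f isa Function` is a simplification: Flux models are callable structs rather than `Function` subtypes, so a real check would need to be more careful.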

@mcabbott (Member) commented:

This doesn't sound crazy, but FWIW I do not see the same speedup:

julia> @time gradient(loss, lr_images);
 51.949997 seconds (59.92 M allocations: 3.781 GiB, 4.14% gc time, 99.92% compilation time)  # tagged
 54.520161 seconds (60.06 M allocations: 3.801 GiB, 2.68% gc time, 99.94% compilation time)  # this PR

(Julia master, M1 mac.)

@ToucheSir (Member, Author) commented:

That may be because loss is closing over a global model. My version passes both as args.
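A minimal illustration of the difference (independent of Flux; `gmodel` is a stand-in for the model): closing over a non-const global defeats type inference, while passing the model as an argument lets the compiler specialize on its type.

```julia
gmodel = x -> 2 .* x                 # non-const global standing in for `model`

loss_global(x) = sum(gmodel(x))      # gmodel's type is unknown at compile time
loss_explicit(m, x) = sum(m(x))      # m's type is part of the method specialization

Base.return_types(loss_global, (Vector{Float64},))[1]                    # Any
Base.return_types(loss_explicit, (typeof(gmodel), Vector{Float64}))[1]   # Float64
```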

@mcabbott (Member) commented:

Avoiding globals makes a surprisingly large difference here:

julia> @time gradient((m, x) -> sum(m(x)), model, lr_images);
 35.009302 seconds (58.93 M allocations: 3.719 GiB, 8.44% gc time, 99.92% compilation time)  # tagged
 34.581913 seconds (59.08 M allocations: 3.738 GiB, 2.97% gc time, 99.93% compilation time)  # this PR

@ToucheSir (Member, Author) commented:

Perhaps a newer Julia version helps close the gap, but I'm consistently seeing this ~15s difference on 1.7.3.

This is the full script I'm using:

using Flux

channels = 4

function resblock(channels)
    return SkipConnection(Chain(
        Conv((3, 3), channels => channels, pad=1),
        Conv((3, 3), channels => channels, pad=1),
    ), +)
end

model = Chain(
    SkipConnection(
        # 15 residual blocks
        Chain([resblock(channels) for _ in 1:15]...),
        +,
    ),
    AdaptiveMeanPool((1, 1)),
)

@show typeof(model)

loss(m, x) = sum(m(x))

lr_images = randn(Float32, 2, 2, channels, 1)
@time loss(model, lr_images)
@time loss(model, lr_images)

loss_grad(m, x) = gradient(m -> loss(m, x), m)
# This gives the same numbers:
# loss_grad(m, x) = gradient((m, x) -> loss(m, x), m, x)

@time loss_grad(model, lr_images)
@time loss_grad(model, lr_images)

We can disable accumulating (implicit) parameters to the gradient cache
in explicit mode. This can dramatically improve type stability because
`accum_param` will return a `Union{Nothing, [grad type]}` otherwise.
Co-authored-by: Michael Abbott <[email protected]>
@ToucheSir ToucheSir requested a review from mcabbott July 31, 2022 17:22
@mcabbott (Member) left a review comment:

Let's do it.

Should be marked as closing #1243?

Successfully merging this pull request may close these issues.

Zygote hangs when taking explicit gradients of NaiveGAFlux model