Accumulate NamedTuple + Tangent #88

Merged on Sep 4, 2022 (3 commits)

mcabbott

This is a hack to fix problems like this:

g_kw(;x=1.0) = sin(x)
f_kw(x) = g_kw(;x)
@test bwd(f_kw)(1.0) == bwd(sin)(1.0)  broken=true
#=
MethodError: no method matching +(::Tangent{var"#g_kw#47"{var"#g_kw#11#48"}, NamedTuple{(Symbol("#g_kw#11"),), Tuple{ZeroTangent}}}, ::Tangent{Diffractor.KwFunc{var"#g_kw#47"{var"#g_kw#11#48"}, var"#g_kw#47##kw"}, NamedTuple{(:kwf,), Tuple{ZeroTangent}}})
...
  [2] elementwise_add(a::NamedTuple{(:contents,), Tuple{Tangent{var"#g_kw#47"{var"#g_kw#11#48"}, NamedTuple{(Symbol("#g_kw#11"),), Tuple{ZeroTangent}}}}}, b::NamedTuple{(:contents,), Tuple{Tangent{Diffractor.KwFunc{var"#g_kw#47"{var"#g_kw#11#48"}, var"#g_kw#47##kw"}, NamedTuple{(:kwf,), Tuple{ZeroTangent}}}}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_types/tangent.jl:287
  [3] +(a::Tangent{Core.Box, NamedTuple{(:contents,), Tuple{Tangent{var"#g_kw#47"{var"#g_kw#11#48"}, NamedTuple{(Symbol("#g_kw#11"),), Tuple{ZeroTangent}}}}}}, b::Tangent{Core.Box, NamedTuple{(:contents,), Tuple{Tangent{Diffractor.KwFunc{var"#g_kw#47"{var"#g_kw#11#48"}, var"#g_kw#47##kw"}, NamedTuple{(:kwf,), Tuple{ZeroTangent}}}}}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_arithmetic.jl:130
=#

Diffractor sometimes produces ChainRulesCore's Tangent types and sometimes produces plain NamedTuples instead. If both turn up as structural gradients of the same object, accumulating them fails with an error.

Ideally it should settle on one representation. Maybe Tangents are a pain to deal with at higher order?
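
A minimal sketch of the kind of mixed accumulation described here. It assumes a stand-in wrapper type `ToyTangent` (hypothetical, not ChainRulesCore's `Tangent`) and a toy `accum` that is not Diffractor's actual implementation. The idea is to unwrap the structural tangent to its backing NamedTuple before adding, so that both representations can be combined:

```julia
# Hypothetical stand-in for a structural tangent: a NamedTuple tagged with the
# primal type P. This is NOT ChainRulesCore.Tangent, only an illustration.
struct ToyTangent{P,B<:NamedTuple}
    backing::B
end
ToyTangent{P}(nt::B) where {P,B<:NamedTuple} = ToyTangent{P,B}(nt)

# Toy accumulation rules (assumed for this sketch, not Diffractor's code).
accum(a, b) = a + b

# Add two NamedTuples field by field; a field missing on one side is
# treated as zero, so the other side's value is kept unchanged.
function accum(a::NamedTuple, b::NamedTuple)
    ks = (keys(a)..., filter(k -> !haskey(a, k), keys(b))...)
    vals = map(ks) do k
        if haskey(a, k) && haskey(b, k)
            accum(a[k], b[k])
        elseif haskey(a, k)
            a[k]
        else
            b[k]
        end
    end
    return NamedTuple{ks}(vals)
end

# Mixed cases: unwrap the wrapper and fall back to the NamedTuple method.
accum(a::ToyTangent, b::NamedTuple) = accum(a.backing, b)
accum(a::NamedTuple, b::ToyTangent) = accum(a, b.backing)
accum(a::ToyTangent, b::ToyTangent) = accum(a.backing, b.backing)

struct Point
    x::Float64
    y::Float64
end

t  = ToyTangent{Point}((x = 1.0, y = 2.0))
nt = (x = 10.0, y = 20.0)
mixed = accum(t, nt)   # a plain NamedTuple with the summed fields
```

Returning a plain NamedTuple from the mixed methods is a choice made for this sketch; a real fix could equally well re-wrap the result.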

@codecov-commenter

codecov-commenter commented Sep 1, 2022

Codecov Report

Merging #88 (c9f2a90) into main (9a8a788) will decrease coverage by 2.19%.
The diff coverage is 72.72%.

@@            Coverage Diff             @@
##             main      #88      +/-   ##
==========================================
- Coverage   53.89%   51.69%   -2.20%     
==========================================
  Files          21       21              
  Lines        2171     2124      -47     
==========================================
- Hits         1170     1098      -72     
- Misses       1001     1026      +25     
Impacted Files               Coverage Δ
src/runtime.jl               78.94% <70.00%> (+20.61%) ⬆️
src/extra_rules.jl           35.45% <100.00%> (-0.33%) ⬇️
src/stage1/recurse.jl        91.48% <0.00%> (-5.29%) ⬇️
src/stage1/forward.jl        69.38% <0.00%> (-4.77%) ⬇️
src/stage1/hacks.jl          2.12% <0.00%> (-4.26%) ⬇️
src/tangent.jl               32.97% <0.00%> (-1.07%) ⬇️
src/stage1/generated.jl      72.72% <0.00%> (-0.58%) ⬇️
src/jet.jl                   39.16% <0.00%> (-0.51%) ⬇️
src/stage1/recurse_fwd.jl    94.11% <0.00%> (-0.17%) ⬇️
... and 3 more


@@ -5,11 +5,25 @@ struct DiffractorRuleConfig <: RuleConfig{Union{HasReverseMode,HasForwardsMode}}
@Base.constprop :aggressive accum(a::Tuple, b::Tuple) = map(accum, a, b)

I get some errors from this method adding tuples of mismatched lengths, which I haven't tracked down. It is not touched in this PR.

What I don't know here is whether accum is only ever called on a single tangent, or also on a tuple of a function's arguments. IIRC there is no distinction in Zygote, but in ChainRules the former can be an Array, a Tangent, or NoTangent, while the latter can only ever be a Tuple (of the right length).
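
One way to make that distinction visible, as a sketch only: give the per-argument case its own entry point that dispatches on equal tuple lengths, so a mismatch fails loudly at the call site. The name `accum_args` is hypothetical and does not exist in Diffractor, and `accum` here is a toy fallback rather than the package's method.

```julia
# Toy fallback for a single tangent (assumed for this sketch).
accum(a, b) = a + b

# Hypothetical entry point for tuples of per-argument tangents. Both tuples
# must have the same length N; otherwise no method matches and Julia throws
# a MethodError.
accum_args(a::NTuple{N,Any}, b::NTuple{N,Any}) where {N} = map(accum, a, b)

same_len = accum_args((1.0, 2.0), (3.0, 4.0))
```

Calling `accum_args((1.0,), (1.0, 2.0))` has no matching method, which would point directly at the caller that produced the wrong number of tangents.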

Comment on lines +270 to +271
Base.real(z::ZeroTangent) = z # TODO should be in CRC
Base.real(z::NoTangent) = z

@oscardssmith

I don't think this is the right long term solution, but it seems right for now.

@mcabbott

mcabbott commented Sep 4, 2022

Yes, for sure it's a hack. But it doesn't seem like it will hide serious bugs, and having tests that work for now is good.

Last commit disables tests on 1.8, since judging by "┌ Warning: ir verification broken. Either use 1.9 or 1.7" there isn't an intention to make everything work there. (The actual failures are inference tests.)

Edit: and now nightly has some unexpected passes, inference tests... Maybe we need to split CI into two files, to allow failures on 1.8 + nightly while returning a useful overall pass/fail?

@ToucheSir

On the topic of longer-term solutions, why is Box appearing here at all? kwsorters are usually type stable and Diffractor doesn't appear to change their branching in ways which would disrupt that. Diffractor.KwFunc also doesn't make use of any inner closures.

@mcabbott

mcabbott commented Sep 4, 2022

I don't know. Where should we put things like this, where a failing test has uncovered bad behaviour, but possibly no longer bad enough to make the test fail?

(Fun Zygote example where re-using a variable name caused a 5x slowdown for ages, later became a bug: Z#1290)
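
For context on how `Core.Box` can arise in general, here is a small sketch. It is not a diagnosis of the Diffractor case: when a local variable is assigned more than once and is also captured by a closure, Julia's lowering stores it in a `Core.Box`. The function names are made up for illustration.

```julia
# `y` is assigned twice and then captured, so the closure holds a Core.Box.
function boxed_capture(x)
    y = x
    y = y + 1
    return () -> y
end

# `y` is assigned once, so the closure captures it with a concrete type.
function plain_capture(x)
    y = x + 1
    return () -> y
end

# Inspect the type of the captured field in each closure.
boxed_field = fieldtype(typeof(boxed_capture(1.0)), :y)
plain_field = fieldtype(typeof(plain_capture(1.0)), :y)
```

Comparing `boxed_field` with `plain_field` shows the difference; re-using a variable name is enough to trigger it, which is the pattern the Zygote example refers to.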

@mcabbott mcabbott merged commit fb1d4ec into JuliaDiff:main Sep 4, 2022
@mcabbott mcabbott deleted the mixed_accum branch September 4, 2022 20:30