Dictionary indexing failure inside closure and structs #717

Closed
willtebbutt opened this issue Jun 27, 2020 · 15 comments
Labels
bug Something isn't working dictionary

Comments

@willtebbutt
Member

willtebbutt commented Jun 27, 2020

julia> Zygote.gradient(x -> (() -> x[:y])(), Dict(:y => 0.4))
(nothing,)

The gradient w.r.t. the y element of x should be 1.

This bug doesn't occur with the equivalent closure-free function

julia> Zygote.gradient(x -> x[:y], Dict(:y => 0.4))
(Dict{Any,Any}(:y => 1.0),)

and appears to be Dict-specific since

julia> Zygote.gradient(x -> (() -> x.y)(), (y = 0.4,))
((y = 1.0,),)
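
As a sanity check that does not depend on Zygote, a central finite difference can confirm the expected value for both the closure and closure-free forms. This is only a minimal sketch; `fd_dict_grad` is a hypothetical helper introduced here for illustration and is not part of Zygote.

```julia
# Hypothetical helper (not part of Zygote): central finite-difference
# estimate of the derivative of f(d) with respect to the entry d[key].
function fd_dict_grad(f, d::Dict, key; h = 1e-6)
    up = copy(d); up[key] += h      # perturb the entry upwards
    down = copy(d); down[key] -= h  # perturb the entry downwards
    return (f(up) - f(down)) / (2h)
end

f_closure(x) = (() -> x[:y])()  # same shape as the failing example
f_plain(x) = x[:y]              # closure-free equivalent

d = Dict(:y => 0.4)
g_closure = fd_dict_grad(f_closure, d, :y)
g_plain = fd_dict_grad(f_plain, d, :y)
println((g_closure, g_plain))  # both should be close to 1
```

Both functions compute the same value, so any difference between their reverse-mode gradients points at the AD pipeline rather than at the function itself.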

This bug was introduced in 0.4.21 -- the correct result is obtained on 0.4.20. The bug persists on 0.4.22 and 0.5.

This is breaking for Stheno.jl.

@MikeInnes @CarloLucibello any thoughts on what might be causing this?

@willtebbutt willtebbutt added the bug Something isn't working label Jun 27, 2020
@willtebbutt willtebbutt changed the title Dictionary indexing failure Dictionary indexing failure inside closure Jun 27, 2020
@willtebbutt
Member Author

@DhairyaLGandhi do you have any thoughts on what might be causing this?

@femtomc

femtomc commented Aug 6, 2020

Here's another MWE. This one is a little more complex, because it matches a use case that I have.

Julia version = 1.5, Zygote version = 0.5.4

module GradsMVP

using Zygote

mutable struct Foo
    store::Dict{Symbol, Float64}
    score::Float64
end

function (f::Foo)(acc::Symbol, fn::Function, args...)
    val = getindex(f.store, acc)
    ret = fn(val)
    f.score += ret
    fn(args...)
end

function get_grads(store, acc, ret_grad, call, args...)
    fn = (args, store) -> begin
        f = Foo(store, 0.0)
        ret = f(acc, call, args...)
        (f.score, ret)
    end
    _, back = Zygote.pullback(fn, args, store)
    arg_grads, store_grads = back((1.0, ret_grad))
    return arg_grads, store_grads
end

function foo(a::Float64)
    return a
end

ags, gs = get_grads(Dict(:x => 1.0), :x, 1.0, foo, 1.0)
println(ags)
println(gs) # = nothing

end # module

whereas this code works fine

module GradsMVP

using Zygote

mutable struct Foo
    store::Float64
    score::Float64
end

function (f::Foo)(fn::Function, args...)
    val = f.store
    ret = fn(val)
    f.score += ret
    fn(args...)
end

function get_grads(store, ret_grad, call, args...)
    fn = (args, store) -> begin
        f = Foo(store, 0.0)
        ret = f(call, args...)
        (f.score, ret)
    end
    _, back = Zygote.pullback(fn, args, store)
    arg_grads, store_grads = back((1.0, ret_grad))
    return arg_grads, store_grads
end

function foo(a::Float64)
    return a
end

ags, gs = get_grads(1.0, 1.0, foo, 1.0)
println(ags)
println(gs) # = 1.0

end # module

@femtomc

femtomc commented Aug 6, 2020

To fix this MWE, it suffices to define the adjoint for getindex:

Zygote.@adjoint getindex(d::Dict, acc) = getindex(d, acc), retgrad -> (retgrad, nothing)

I'm unsure if this will break something fundamental.

Edit: sorry, this is supposed to be retgrad
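
For comparison, the closure-free call at the top of the thread returns a Dict-shaped gradient (`Dict{Any,Any}(:y => 1.0)`), whereas the workaround above passes `retgrad` through unchanged. A minimal plain-Julia sketch of a hand-written pullback that produces the Dict shape might look like the following. `getindex_with_pullback` is a hypothetical helper for illustration only; it does not use Zygote and is not a description of Zygote's internals.

```julia
# Hypothetical helper: returns the looked-up value and a hand-written pullback.
# Illustration only; this is not Zygote's implementation.
function getindex_with_pullback(d::Dict, key)
    value = d[key]
    # The cotangent for the Dict is a Dict holding Δ at `key`;
    # the key itself is not differentiable, hence `nothing`.
    pullback(Δ) = (Dict{Any,Any}(key => Δ), nothing)
    return value, pullback
end

value, back = getindex_with_pullback(Dict(:y => 0.4), :y)
dict_grad, key_grad = back(1.0)
println(value)
println(dict_grad)
```

Whether the Dict-shaped or the pass-through form is appropriate inside Zygote depends on how gradients for mutable containers are accumulated, which this sketch does not attempt to model.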

@femtomc

femtomc commented Aug 6, 2020

@DhairyaLGandhi it's not Zygote's version of getindex - printouts of grad show the correct gradients. This makes sense - that's obviously something which has been tested numerous times.

Something else is happening in the pipeline.

@femtomc

femtomc commented Aug 6, 2020

PS This works correctly on 0.4.20, as @willtebbutt says. I just checked with my own codebase.

@DhairyaLGandhi
Member

Are you suggesting that the gradient is correctly calculated but isn't actually returned to the user properly?

@willtebbutt
Member Author

What's happening is entirely unclear to me. It seems to be Dict-specific, and I could only reproduce the bug in conjunction with a closure 🤷

@femtomc

femtomc commented Aug 11, 2020

@DhairyaLGandhi when I print out accum in the adjoint for getindex - I see the correct gradients. But in the MWE above, the returned grad is nothing.

@femtomc

femtomc commented Aug 15, 2020

@DhairyaLGandhi @willtebbutt any update on this?

This is highly frustrating to me. I can't update to the latest version of Zygote, so I can't use the latest version of IRTools, so I can't use the latest version of Flux, which means I can't use neural networks in my PPs.

I have no idea where this bug is occurring, but I'm motivated to find it and fix it - especially since it worked in 0.4.20, so it can't be hard to track down, can it? Any ideas where to start looking?

@femtomc

femtomc commented Aug 15, 2020

I've set up a PR. I don't know what I'm doing, so I don't know if this fix breaks other things - please let me know.

femtomc added a commit to femtomc/Jaynes.jl that referenced this issue Aug 25, 2020
@femtomc

femtomc commented Sep 21, 2020

@willtebbutt @DhairyaLGandhi did this happen to get squashed in recent tags/PRs?

@willtebbutt
Member Author

Hmmm I'm not sure. @DhairyaLGandhi is more likely to know.

@CarloLucibello CarloLucibello changed the title Dictionary indexing failure inside closure Dictionary indexing failure inside closure and structs Jul 21, 2021
@CarloLucibello
Member

This problem is still present:

julia> d = Dict("x"=>rand(2))
Dict{String, Vector{Float64}} with 1 entry:
  "x" => [0.626974, 0.519716]

julia> gradient(x -> sum(x["x"]), d)  #OK
(Dict{Any, Any}("x" => 2-element Fill{Float64}: entries equal to 1.0),)

julia> nt = (; data=rand(2))
(data = [0.7536687262661153, 0.34819635465370324],)

julia> gradient(x -> sum(x.data), nt)  #OK
((data = 2-element Fill{Float64}: entries equal to 1.0,),)

julia> ntd = (; data = Dict("x" => rand(2)))
(data = Dict("x" => [0.6917549230112572, 0.16463696222948876]),)

julia> gradient(x -> sum(x.data["x"]), ntd) #WRONG
(nothing,)

@ToucheSir
Member

Came across this issue and I see all MWEs passing with #1248. If anyone still has a larger example to test, could you confirm it passes as well? Otherwise I'll consider this issue fixed if nothing pops up after a few days.

@CarloLucibello
Member

Closing, as all examples are fixed. Will add tests.
