Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix broadcasts which are type unstable with Dual numbers #1441

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 30 additions & 40 deletions src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -281,72 +281,62 @@ end
@inline function broadcast_forward(f, args::Vararg{Any,N}) where N
out = dual_function(f).(args...)
T = eltype(out)
T <: Union{Dual, Complex{<:Dual}} || return (out, _ -> nothing)
if any(eltype(a) <: Complex for a in args)
_broadcast_forward_complex(T, out, args...)
if !isconcretetype(T) || T <: Union{Dual, Complex{<:Dual}}
if any(eltype(a) <: Complex for a in args)
return _broadcast_forward_complex(out, args...)
else
return _broadcast_forward(out, args...)
end
else
_broadcast_forward(T, out, args...)
return (out, _ -> nothing)
end
end

# Real input and real output pullback
@inline function _broadcast_forward(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N}
# Real input
@inline _extract_value(x) = value(x)
@inline _extract_value(x::Complex) = Complex(value(real(x)), value(imag(x)))
@inline _broadcast_scalar_pullback(ȳ, out, i) = ȳ * partials(out, i)
@inline function _broadcast_scalar_pullback(ȳ, out::Complex, i)
return real(ȳ) * partials(real(out), i) + imag(ȳ) * partials(imag(out), i)
end
@inline function _broadcast_forward(out, args::Vararg{Any, N}) where {N}
valN = Val(N)
y = broadcast(x -> value(x), out)
y = broadcast(x -> _extract_value(x), out)
function bc_fwd_back(ȳ)
dargs = ntuple(valN) do i
unbroadcast(args[i], broadcast((y1, o1) -> y1 * partials(o1,i), ȳ, out))
unbroadcast(args[i],
broadcast((y1, o1) -> _broadcast_scalar_pullback(y1, o1, i), ȳ, out)
)
end
(nothing, nothing, dargs...) # nothings for broadcasted & f
end
return y, bc_fwd_back
end

# This handles the complex output and real input pullback
@inline function _broadcast_forward(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N}
valN = Val(N)
y = broadcast(x -> Complex(value(real(x)), value(imag(x))), out)
function bc_fwd_back(ȳ)
dargs = ntuple(valN) do i
unbroadcast(args[i], broadcast((y1, o1) -> (real(y1)*partials(real(o1),i) + imag(y1)*partials(imag(o1), i)), ȳ, out))
end
(nothing, nothing, dargs...) # nothings for broadcasted & f
end
return y, bc_fwd_back
end

# This handles complex input and real output. We use the gradient definition from ChainRules here
# since it agrees with what Zygote did for real(x).
@inline function _broadcast_forward_complex(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N}
valN = Val(N)
y = broadcast(x -> value(x), out)
function bc_fwd_back(ȳ)
dargs = ntuple(valN) do i
unbroadcast(args[i], broadcast((y1, o1) -> y1 * Complex(partials(o1, i), partials(o1, i+N)), ȳ, out))
end
(nothing, nothing, dargs...) # nothings for broadcasted & f
end
return y, bc_fwd_back
@inline function _broadcast_scalar_pullback_complex(N, Δz, df, i)
return Δz * Complex(partials(df, i), partials(df, i + N))
end

# # # This is for complex input and complex output
# If we assume that
# f(x + iy) = u(x,y) + iv(x,y)
# then we do the following for the adjoint
# Δu ∂u/∂x + Δv∂v/∂x + i(Δu∂u/∂y + Δv ∂v/∂y )
# this follows https://juliadiff.org/ChainRulesCore.jl/stable/maths/complex.html
function _adjoint_complex(N, Δz, df, i)
Δu, Δv = reim(Δz)
du, dv = reim(df)
return Complex(Δu*partials(du, i) + Δv*partials(dv, i), Δu*partials(du, i+N) + Δv*partials(dv, i+N))
@inline function _broadcast_scalar_pullback_complex(N, Δz, df::Complex, i)
Δu, Δv = reim(Δz)
du, dv = reim(df)
return Complex(Δu * partials(du, i) + Δv * partials(dv, i), Δu * partials(du, i + N) + Δv * partials(dv, i + N))
end

@inline function _broadcast_forward_complex(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N}
@inline function _broadcast_forward_complex(out, args::Vararg{Any, N}) where {N}
valN = Val(N)
y = broadcast(x -> Complex(value(real(x)), value(imag(x))), out)
y = broadcast(x -> _extract_value(x), out)
function bc_fwd_back(ȳ)
dargs = ntuple(valN) do i
unbroadcast(args[i], broadcast((y1, o1) -> _adjoint_complex(N, y1, o1, i), ȳ, out))
unbroadcast(args[i],
broadcast((y1, o1) -> _broadcast_scalar_pullback_complex(N, y1, o1, i), ȳ, out)
)
end
(nothing, nothing, dargs...) # nothings for broadcasted & f
end
Expand Down
23 changes: 23 additions & 0 deletions test/features.jl
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,29 @@ end
@test gradient(xs -> sum(map((x -> x<2 ? false : x^2), xs)), [1,2,3])[1][2:3] == [4, 6]
@test gradient(xs -> mapreduce((x -> x<2 ? false : x^2), +, xs), [1,2,3])[1][2:3] == [4, 6]

# https://github.com/FluxML/Zygote.jl/issues/1439
# type stable forward pass with given input, but type unstable with dualized input
# Real input, real output
f = x -> x > 1.0 ? 1.0 : x^2
@test gradient(xs -> sum(f.(xs)), [0.5, 1.0, 1.5])[1] == [1.0, 2.0, 0.0]
# Real input, complex output
f = x -> x > 1.0 ? 1.0im : (x + 1.0im)^2
@test gradient(xs -> sum(abs2, f.(xs)), [0.5, 1.0, 1.5])[1] == [2.5, 8.0, 0.0]
# Complex input, complex output
f = x -> imag(x) > 1.0 ? 1.0im : x^2
@test gradient(xs -> sum(abs2, f.(xs)), [0.5im, 1.0im, 1.5im])[1] == [
0.0 + 0.5im, 0.0 + 4.0im, 0.0 + 0.0im
]
# Complex input, real output
f = x -> imag(x) > 1.0 ? 1.0 : abs2(x)
@test gradient(xs -> sum(abs2, f.(xs)), [0.5im, 1.0im, 1.5im])[1] == [
0.0 + 0.5im, 0.0 + 4.0im, 0.0 + 0.0im
]
# Slightly more complex case that used to error
f = x -> x > 1.0 ? 1.0 : x^2
g = x -> sum(repeat(x, inner=2) .* f.(repeat(x, inner=2)))
@test gradient(g, [0.5, 1.0, 1.5])[1] == [1.5, 6.0, 2.0]

# with Ref, Val, Symbol
@test gradient(x -> sum(x .+ Ref(x[1])), [1,2,3]) == ([4,1,1],)
@test gradient(x -> sum(x .+ (x[1],)), [1,2,3]) == ([4,1,1],)
Expand Down