From 722277542dd9ea5027e17f58863f5bcfe16aa975 Mon Sep 17 00:00:00 2001 From: DomCRose Date: Thu, 13 Jul 2023 13:23:33 +0100 Subject: [PATCH 1/7] Initial edits for real inputs --- src/lib/broadcast.jl | 40 +++++++++++++++++++--------------------- 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 504ef614d..b55de8c5a 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -281,40 +281,38 @@ end @inline function broadcast_forward(f, args::Vararg{Any,N}) where N out = dual_function(f).(args...) T = eltype(out) - T <: Union{Dual, Complex{<:Dual}} || return (out, _ -> nothing) - if any(eltype(a) <: Complex for a in args) - _broadcast_forward_complex(T, out, args...) + if !isconcretetype(T) || T <: Union{Dual, Complex{<:Dual}} + if any(eltype(a) <: Complex for a in args) + return _broadcast_forward_complex(out, args...) + else + return _broadcast_forward(out, args...) + end else - _broadcast_forward(T, out, args...) + return (out, _ -> nothing) end end -# Real input and real output pullback -@inline function _broadcast_forward(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N} +# Real input +@inline _extract_value(x) = value(x) +@inline _extract_value(x::Complex) = Complex(value(real(x)), value(imag(x))) +@inline _broadcast_scalar_pullback(ȳ, out, i) = ȳ * partials(out, i) +@inline function _broadcast_scalar_pullback(ȳ::Complex, out, i) + return Complex(real(ȳ) * partials(real(out), i), imag(ȳ) * partials(imag(out), i)) +end +@inline function _broadcast_forward(out, args::Vararg{Any, N}) where {N} valN = Val(N) - y = broadcast(x -> value(x), out) + y = broadcast(x -> _extract_value(x), out) function bc_fwd_back(ȳ) dargs = ntuple(valN) do i - unbroadcast(args[i], broadcast((y1, o1) -> y1 * partials(o1,i), ȳ, out)) + unbroadcast(args[i], + broadcast((y1, o1) -> _broadcast_scalar_pullback(y1, o1, i), ȳ, out) + ) end (nothing, nothing, dargs...) # nothings for broadcasted & f end return y, bc_fwd_back end -# This handles the complex output and real input pullback -@inline function _broadcast_forward(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N} - valN = Val(N) - y = broadcast(x -> Complex(value(real(x)), value(imag(x))), out) - function bc_fwd_back(ȳ) - dargs = ntuple(valN) do i - unbroadcast(args[i], broadcast((y1, o1) -> (real(y1)*partials(real(o1),i) + imag(y1)*partials(imag(o1), i)), ȳ, out)) - end - (nothing, nothing, dargs...) # nothings for broadcasted & f - end - return y, bc_fwd_back - end - # This handles complex input and real output. We use the gradient definition from ChainRules here # since it agrees with what Zygote did for real(x). @inline function _broadcast_forward_complex(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N} From e3eaaf53d33a43bf8bcdb3ebcd048985c7881842 Mon Sep 17 00:00:00 2001 From: DomCRose Date: Thu, 13 Jul 2023 13:30:26 +0100 Subject: [PATCH 2/7] Fix real input complex output --- src/lib/broadcast.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index b55de8c5a..992e74915 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -283,7 +283,7 @@ end T = eltype(out) if !isconcretetype(T) || T <: Union{Dual, Complex{<:Dual}} if any(eltype(a) <: Complex for a in args) - return _broadcast_forward_complex(out, args...) + return _broadcast_forward_complex(T, out, args...) else return _broadcast_forward(out, args...) end @@ -297,7 +297,7 @@ end @inline _extract_value(x::Complex) = Complex(value(real(x)), value(imag(x))) @inline _broadcast_scalar_pullback(ȳ, out, i) = ȳ * partials(out, i) @inline function _broadcast_scalar_pullback(ȳ::Complex, out, i) - return Complex(real(ȳ) * partials(real(out), i), imag(ȳ) * partials(imag(out), i)) + return real(ȳ) * partials(real(out), i) + imag(ȳ) * partials(imag(out), i) end @inline function _broadcast_forward(out, args::Vararg{Any, N}) where {N} valN = Val(N) From ccd63b2a41967bb69b9fccf4d684b71b3885e670 Mon Sep 17 00:00:00 2001 From: DomCRose Date: Thu, 13 Jul 2023 14:21:33 +0100 Subject: [PATCH 3/7] Update complex input scalar broadcast --- src/lib/broadcast.jl | 33 +++++++++++++-------------------- 1 file changed, 13 insertions(+), 20 deletions(-) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 992e74915..5e90e49b5 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -281,9 +281,10 @@ end @inline function broadcast_forward(f, args::Vararg{Any,N}) where N out = dual_function(f).(args...) T = eltype(out) + println(T) if !isconcretetype(T) || T <: Union{Dual, Complex{<:Dual}} if any(eltype(a) <: Complex for a in args) - return _broadcast_forward_complex(T, out, args...) + return _broadcast_forward_complex(out, args...) else return _broadcast_forward(out, args...) end @@ -315,36 +316,28 @@ end # This handles complex input and real output. We use the gradient definition from ChainRules here # since it agrees with what Zygote did for real(x). -@inline function _broadcast_forward_complex(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N} - valN = Val(N) - y = broadcast(x -> value(x), out) - function bc_fwd_back(ȳ) - dargs = ntuple(valN) do i - unbroadcast(args[i], broadcast((y1, o1) -> y1 * Complex(partials(o1, i), partials(o1, i+N)), ȳ, out)) - end - (nothing, nothing, dargs...) # nothings for broadcasted & f - end - return y, bc_fwd_back +@inline function _broadcast_scalar_pullback_complex(N, Δz, df::Dual, i) + return Δz * Complex(partials(df, i), partials(df, i + N)) end - # # # This is for complex input and complex output # If we assume that # f(x + iy) = u(x,y) + iv(x,y) # then we do the following for the adjoint # Δu ∂u/∂x + Δv∂v/∂x + i(Δu∂u/∂y + Δv ∂v/∂y ) # this follows https://juliadiff.org/ChainRulesCore.jl/stable/maths/complex.html -function _adjoint_complex(N, Δz, df, i) - Δu, Δv = reim(Δz) - du, dv = reim(df) - return Complex(Δu*partials(du, i) + Δv*partials(dv, i), Δu*partials(du, i+N) + Δv*partials(dv, i+N)) +function _broadcast_scalar_pullback_complex(N, Δz, df::Complex, i) + Δu, Δv = reim(Δz) + du, dv = reim(df) + return Complex(Δu * partials(du, i) + Δv * partials(dv, i), Δu * partials(du, i + N) + Δv * partials(dv, i + N)) end - -@inline function _broadcast_forward_complex(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N} +@inline function _broadcast_forward_complex(out, args::Vararg{Any, N}) where {N} valN = Val(N) - y = broadcast(x -> Complex(value(real(x)), value(imag(x))), out) + y = broadcast(x -> _extract_value(x), out) function bc_fwd_back(ȳ) dargs = ntuple(valN) do i - unbroadcast(args[i], broadcast((y1, o1) -> _adjoint_complex(N, y1, o1, i), ȳ, out)) + unbroadcast(args[i], + broadcast((y1, o1) -> _broadcast_scalar_pullback_complex(N, y1, o1, i), ȳ, out) + ) end (nothing, nothing, dargs...) # nothings for broadcasted & f end From 695f78acf436702794d68a38b9c94ca56e0bc0ec Mon Sep 17 00:00:00 2001 From: DomCRose Date: Thu, 13 Jul 2023 15:13:22 +0100 Subject: [PATCH 4/7] Merge all scalar real/complex forward passes --- src/lib/broadcast.jl | 59 ++++++++++++++++++-------------------------- 1 file changed, 24 insertions(+), 35 deletions(-) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 5e90e49b5..ffb41d2db 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -281,42 +281,26 @@ end @inline function broadcast_forward(f, args::Vararg{Any,N}) where N out = dual_function(f).(args...) T = eltype(out) - println(T) if !isconcretetype(T) || T <: Union{Dual, Complex{<:Dual}} - if any(eltype(a) <: Complex for a in args) - return _broadcast_forward_complex(out, args...) - else - return _broadcast_forward(out, args...) - end + return _broadcast_forward(out, args...) else return (out, _ -> nothing) end end -# Real input @inline _extract_value(x) = value(x) @inline _extract_value(x::Complex) = Complex(value(real(x)), value(imag(x))) -@inline _broadcast_scalar_pullback(ȳ, out, i) = ȳ * partials(out, i) -@inline function _broadcast_scalar_pullback(ȳ::Complex, out, i) - return real(ȳ) * partials(real(out), i) + imag(ȳ) * partials(imag(out), i) +# Real input, real output +@inline function _broadcast_scalar_pullback(::Type{<:Real}, N, Δz::Real, df, i) + return Δz * partials(df, i) end -@inline function _broadcast_forward(out, args::Vararg{Any, N}) where {N} - valN = Val(N) - y = broadcast(x -> _extract_value(x), out) - function bc_fwd_back(ȳ) - dargs = ntuple(valN) do i - unbroadcast(args[i], - broadcast((y1, o1) -> _broadcast_scalar_pullback(y1, o1, i), ȳ, out) - ) - end - (nothing, nothing, dargs...) # nothings for broadcasted & f - end - return y, bc_fwd_back +# Real input, complex output +@inline function _broadcast_scalar_pullback(::Type{<:Real}, N, Δz::Complex, df, i) + return real(Δz) * partials(real(df), i) + imag(Δz) * partials(imag(df), i) end - # This handles complex input and real output. We use the gradient definition from ChainRules here # since it agrees with what Zygote did for real(x). -@inline function _broadcast_scalar_pullback_complex(N, Δz, df::Dual, i) +@inline function _broadcast_scalar_pullback(::Type{<:Complex}, N, Δz::Real, df, i) return Δz * Complex(partials(df, i), partials(df, i + N)) end # # # This is for complex input and complex output @@ -325,25 +309,30 @@ end # then we do the following for the adjoint # Δu ∂u/∂x + Δv∂v/∂x + i(Δu∂u/∂y + Δv ∂v/∂y ) # this follows https://juliadiff.org/ChainRulesCore.jl/stable/maths/complex.html -function _broadcast_scalar_pullback_complex(N, Δz, df::Complex, i) +@inline function _broadcast_scalar_pullback(::Type{<:Complex}, N, Δz::Complex, df, i) Δu, Δv = reim(Δz) du, dv = reim(df) return Complex(Δu * partials(du, i) + Δv * partials(dv, i), Δu * partials(du, i + N) + Δv * partials(dv, i + N)) end -@inline function _broadcast_forward_complex(out, args::Vararg{Any, N}) where {N} - valN = Val(N) - y = broadcast(x -> _extract_value(x), out) - function bc_fwd_back(ȳ) - dargs = ntuple(valN) do i - unbroadcast(args[i], - broadcast((y1, o1) -> _broadcast_scalar_pullback_complex(N, y1, o1, i), ȳ, out) +@inline function _broadcast_forward(out, args::Vararg{Any, N}) where {N} + valN = Val(N) + y = broadcast(x -> _extract_value(x), out) + function bc_fwd_back(ȳ) + dargs = ntuple(valN) do i + unbroadcast(args[i], + broadcast( + (y1, o1) -> _broadcast_scalar_pullback(eltype(args[i]), N, y1, o1, i), + ȳ, + out ) - end - (nothing, nothing, dargs...) # nothings for broadcasted & f + ) end - return y, bc_fwd_back + (nothing, nothing, dargs...) # nothings for broadcasted & f + end + return y, bc_fwd_back end + using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve git blame # Ordinary broadcasting calls broadcast_forward anyway when certain its' safe, From f4016741bc9619ffb7d2a46a40489ac593ea64f1 Mon Sep 17 00:00:00 2001 From: DomCRose Date: Thu, 13 Jul 2023 17:01:37 +0100 Subject: [PATCH 5/7] Break up broadcast code to help GPU compiler --- src/lib/broadcast.jl | 58 ++++++++++++++++++++++++++------------------ test/features.jl | 3 +++ 2 files changed, 37 insertions(+), 24 deletions(-) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index ffb41d2db..52da0dad7 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -282,25 +282,40 @@ end out = dual_function(f).(args...) T = eltype(out) if !isconcretetype(T) || T <: Union{Dual, Complex{<:Dual}} - return _broadcast_forward(out, args...) + if any(eltype(a) <: Complex for a in args) + return _broadcast_forward_complex(out, args...) + else + return _broadcast_forward(out, args...) + end else return (out, _ -> nothing) end end +# Real input @inline _extract_value(x) = value(x) @inline _extract_value(x::Complex) = Complex(value(real(x)), value(imag(x))) -# Real input, real output -@inline function _broadcast_scalar_pullback(::Type{<:Real}, N, Δz::Real, df, i) - return Δz * partials(df, i) +@inline _broadcast_scalar_pullback(ȳ, out, i) = ȳ * partials(out, i) +@inline function _broadcast_scalar_pullback(ȳ, out::Complex, i) + return real(ȳ) * partials(real(out), i) + imag(ȳ) * partials(imag(out), i) end -# Real input, complex output -@inline function _broadcast_scalar_pullback(::Type{<:Real}, N, Δz::Complex, df, i) - return real(Δz) * partials(real(df), i) + imag(Δz) * partials(imag(df), i) +@inline function _broadcast_forward(out, args::Vararg{Any, N}) where {N} + valN = Val(N) + y = broadcast(x -> _extract_value(x), out) + function bc_fwd_back(ȳ) + dargs = ntuple(valN) do i + unbroadcast(args[i], + broadcast((y1, o1) -> _broadcast_scalar_pullback(y1, o1, i), ȳ, out) + ) + end + (nothing, nothing, dargs...) # nothings for broadcasted & f + end + return y, bc_fwd_back end + # This handles complex input and real output. We use the gradient definition from ChainRules here # since it agrees with what Zygote did for real(x). -@inline function _broadcast_scalar_pullback(::Type{<:Complex}, N, Δz::Real, df, i) +@inline function _broadcast_scalar_pullback_complex(N, Δz, df, i) return Δz * Complex(partials(df, i), partials(df, i + N)) end # # # This is for complex input and complex output @@ -309,30 +324,25 @@ end # then we do the following for the adjoint # Δu ∂u/∂x + Δv∂v/∂x + i(Δu∂u/∂y + Δv ∂v/∂y ) # this follows https://juliadiff.org/ChainRulesCore.jl/stable/maths/complex.html -@inline function _broadcast_scalar_pullback(::Type{<:Complex}, N, Δz::Complex, df, i) +@inline function _broadcast_scalar_pullback_complex(N, Δz, df::Complex, i) Δu, Δv = reim(Δz) du, dv = reim(df) return Complex(Δu * partials(du, i) + Δv * partials(dv, i), Δu * partials(du, i + N) + Δv * partials(dv, i + N)) end -@inline function _broadcast_forward(out, args::Vararg{Any, N}) where {N} - valN = Val(N) - y = broadcast(x -> _extract_value(x), out) - function bc_fwd_back(ȳ) - dargs = ntuple(valN) do i - unbroadcast(args[i], - broadcast( - (y1, o1) -> _broadcast_scalar_pullback(eltype(args[i]), N, y1, o1, i), - ȳ, - out +@inline function _broadcast_forward_complex(out, args::Vararg{Any, N}) where {N} + valN = Val(N) + y = broadcast(x -> _extract_value(x), out) + function bc_fwd_back(ȳ) + dargs = ntuple(valN) do i + unbroadcast(args[i], + broadcast((y1, o1) -> _broadcast_scalar_pullback_complex(N, y1, o1, i), ȳ, out) ) - ) + end + (nothing, nothing, dargs...) # nothings for broadcasted & f end - (nothing, nothing, dargs...) # nothings for broadcasted & f - end - return y, bc_fwd_back + return y, bc_fwd_back end - using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve git blame # Ordinary broadcasting calls broadcast_forward anyway when certain its' safe, diff --git a/test/features.jl b/test/features.jl index 908ae5815..7fe9ea247 100644 --- a/test/features.jl +++ b/test/features.jl @@ -798,6 +798,9 @@ end @test gradient(xs -> sum(map((x -> x<2 ? false : x^2), xs)), [1,2,3])[1][2:3] == [4, 6] @test gradient(xs -> mapreduce((x -> x<2 ? false : x^2), +, xs), [1,2,3])[1][2:3] == [4, 6] + # type stable forward pass with input, but type unstable with dualized input + + # with Ref, Val, Symbol @test gradient(x -> sum(x .+ Ref(x[1])), [1,2,3]) == ([4,1,1],) @test gradient(x -> sum(x .+ (x[1],)), [1,2,3]) == ([4,1,1],) From 0c23770ddcdf9addf6710ad32945dc0f41de6be4 Mon Sep 17 00:00:00 2001 From: DomCRose Date: Thu, 13 Jul 2023 17:29:33 +0100 Subject: [PATCH 6/7] Add some tests --- test/features.jl | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/test/features.jl b/test/features.jl index 7fe9ea247..10d08ffa9 100644 --- a/test/features.jl +++ b/test/features.jl @@ -798,8 +798,28 @@ end @test gradient(xs -> sum(map((x -> x<2 ? false : x^2), xs)), [1,2,3])[1][2:3] == [4, 6] @test gradient(xs -> mapreduce((x -> x<2 ? false : x^2), +, xs), [1,2,3])[1][2:3] == [4, 6] - # type stable forward pass with input, but type unstable with dualized input - + # https://github.com/FluxML/Zygote.jl/issues/1439 + # type stable forward pass with given input, but type unstable with dualized input + # Real input, real output + f = x -> x > 1.0 ? 1.0 : x^2 + @test gradient(xs -> sum(f.(xs)), [0.5, 1.0, 1.5])[1] == [1.0, 2.0, 0.0] + # Real input, complex output + f = x -> x > 1.0 ? 1.0im : (x + 1.0im)^2 + @test gradient(xs -> sum(f.(xs)), [0.5, 1.0, 1.5])[1] == [2.5, 8.0, 0.0] + # Complex input, complex output + f = x -> imag(x) > 1.0 ? 1.0im : x^2 + @test gradient(xs -> sum(abs2, f.(xs)), [0.5im, 1.0im, 1.5im])[1] == [ + 0.0 + 0.5im, 0.0 + 4.0im, 0.0 + 0.0im + ] + # Complex input, real output + f = x -> imag(x) > 1.0 ? 1.0 : abs2(x) + @test gradient(xs -> sum(abs2, f.(xs)), [0.5im, 1.0im, 1.5im])[1] == [ + 0.0 + 0.5im, 0.0 + 4.0im, 0.0 + 0.0im + ] + # Slightly more complex case that used to error + f = x -> x > 1.0 ? 1.0 : x^2 + g = x -> sum(repeat(x, inner=2) .* f.(repeat(x, inner=2))) + @test gradient(g, [0.5, 1.0, 1.5])[1] == [1.5, 6.0, 2.0] # with Ref, Val, Symbol @test gradient(x -> sum(x .+ Ref(x[1])), [1,2,3]) == ([4,1,1],) From 17c07682a25383ead4822088976a208ebdb0829e Mon Sep 17 00:00:00 2001 From: DomCRose Date: Thu, 13 Jul 2023 17:46:58 +0100 Subject: [PATCH 7/7] Fix new real -> complex broadcasting test --- test/features.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/features.jl b/test/features.jl index 10d08ffa9..25f6661b4 100644 --- a/test/features.jl +++ b/test/features.jl @@ -805,7 +805,7 @@ end @test gradient(xs -> sum(f.(xs)), [0.5, 1.0, 1.5])[1] == [1.0, 2.0, 0.0] # Real input, complex output f = x -> x > 1.0 ? 1.0im : (x + 1.0im)^2 - @test gradient(xs -> sum(f.(xs)), [0.5, 1.0, 1.5])[1] == [2.5, 8.0, 0.0] + @test gradient(xs -> sum(abs2, f.(xs)), [0.5, 1.0, 1.5])[1] == [2.5, 8.0, 0.0] # Complex input, complex output f = x -> imag(x) > 1.0 ? 1.0im : x^2 @test gradient(xs -> sum(abs2, f.(xs)), [0.5im, 1.0im, 1.5im])[1] == [