From e0e667f5de1ba7b20fdb4e71e81988f27cca5802 Mon Sep 17 00:00:00 2001 From: Ziyi Yin Date: Mon, 27 Mar 2023 22:50:00 -0400 Subject: [PATCH 1/4] fix https://github.com/JuliaMath/AbstractFFTs.jl/issues/95 --- ext/AbstractFFTsChainRulesCoreExt.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ext/AbstractFFTsChainRulesCoreExt.jl b/ext/AbstractFFTsChainRulesCoreExt.jl index f0c788e6..953c4b80 100644 --- a/ext/AbstractFFTsChainRulesCoreExt.jl +++ b/ext/AbstractFFTsChainRulesCoreExt.jl @@ -30,10 +30,10 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims) halfdim = first(dims) d = size(x, halfdim) n = size(y, halfdim) - scale = reshape( + scale = typeof(y)(reshape( [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n], ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), - ) + )) project_x = ChainRulesCore.ProjectTo(x) function rfft_pullback(ȳ) @@ -72,10 +72,10 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims) n = size(x, halfdim) invN = AbstractFFTs.normalization(y, dims) twoinvN = 2 * invN - scale = reshape( + scale = typeof(y)(reshape( [i == 1 || (i == n && 2 * (i - 1) == d) ? invN : twoinvN for i in 1:n], ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), - ) + )) project_x = ChainRulesCore.ProjectTo(x) function irfft_pullback(ȳ) @@ -111,10 +111,10 @@ function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims) # compute scaling factors halfdim = first(dims) n = size(x, halfdim) - scale = reshape( + scale = typeof(y)(reshape( [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n], ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), - ) + )) project_x = ChainRulesCore.ProjectTo(x) function brfft_pullback(ȳ) From eac7da49f1eec175a586e7a03c3e7d6ec58298bc Mon Sep 17 00:00:00 2001 From: Ziyi Yin Date: Mon, 10 Apr 2023 20:31:15 -0400 Subject: [PATCH 2/4] scaling in place instead of type conversion --- ext/AbstractFFTsChainRulesCoreExt.jl | 34 +++++++++++++++------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/ext/AbstractFFTsChainRulesCoreExt.jl b/ext/AbstractFFTsChainRulesCoreExt.jl index 953c4b80..bbb0cf49 100644 --- a/ext/AbstractFFTsChainRulesCoreExt.jl +++ b/ext/AbstractFFTsChainRulesCoreExt.jl @@ -30,14 +30,15 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims) halfdim = first(dims) d = size(x, halfdim) n = size(y, halfdim) - scale = typeof(y)(reshape( - [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n], - ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), - )) project_x = ChainRulesCore.ProjectTo(x) function rfft_pullback(ȳ) - x̄ = project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ scale, d, dims)) + dY = ChainRulesCore.unthunk(ȳ) ./ 2 + selectdim(dY, halfdim, 1) .*= 2 + if 2 * (n - 1) == d + selectdim(dY, halfdim, n) .*= 2 + end + x̄ = project_x(brfft(dY, d, dims)) return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() end return y, rfft_pullback @@ -71,15 +72,15 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims) halfdim = first(dims) n = size(x, halfdim) invN = AbstractFFTs.normalization(y, dims) - twoinvN = 2 * invN - scale = typeof(y)(reshape( - [i == 1 || (i == n && 2 * (i - 1) == d) ? invN : twoinvN for i in 1:n], - ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), - )) project_x = ChainRulesCore.ProjectTo(x) function irfft_pullback(ȳ) - x̄ = project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)) + dX = rfft(real.(ChainRulesCore.unthunk(ȳ)), dims) .* invN .* 2 + selectdim(dX, halfdim, 1) ./= 2 + if 2 * (n - 1) == d + selectdim(dX, halfdim, n) ./= 2 + end + x̄ = project_x(dX) return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent() end return y, irfft_pullback @@ -111,14 +112,15 @@ function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims) # compute scaling factors halfdim = first(dims) n = size(x, halfdim) - scale = typeof(y)(reshape( - [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n], - ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), - )) project_x = ChainRulesCore.ProjectTo(x) function brfft_pullback(ȳ) - x̄ = project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)) + dX = rfft(real.(ChainRulesCore.unthunk(ȳ)), dims) .* 2 + selectdim(dX, halfdim, 1) ./= 2 + if 2 * (n - 1) == d + selectdim(dX, halfdim, n) ./= 2 + end + x̄ = project_x(dX) return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent() end return y, brfft_pullback From dc2cfab69b68c16ce289b2db64891960a7ed7858 Mon Sep 17 00:00:00 2001 From: Ziyi Yin Date: Wed, 12 Apr 2023 10:14:20 -0400 Subject: [PATCH 3/4] fix immutable array by similar --- ext/AbstractFFTsChainRulesCoreExt.jl | 45 ++++++++++++++++++---------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/ext/AbstractFFTsChainRulesCoreExt.jl b/ext/AbstractFFTsChainRulesCoreExt.jl index bbb0cf49..119c079a 100644 --- a/ext/AbstractFFTsChainRulesCoreExt.jl +++ b/ext/AbstractFFTsChainRulesCoreExt.jl @@ -26,19 +26,24 @@ end function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims) y = rfft(x, dims) - # compute scaling factors halfdim = first(dims) d = size(x, halfdim) n = size(y, halfdim) project_x = ChainRulesCore.ProjectTo(x) function rfft_pullback(ȳ) - dY = ChainRulesCore.unthunk(ȳ) ./ 2 - selectdim(dY, halfdim, 1) .*= 2 + dY = ChainRulesCore.unthunk(ȳ) + # apply scaling + dY_scaled = similar(dY) + dY_scaled .= dY + dY_scaled ./= 2 + v = selectdim(dY_scaled, halfdim, 1) + v .*= 2 if 2 * (n - 1) == d - selectdim(dY, halfdim, n) .*= 2 + v = selectdim(dY_scaled, halfdim, n) + v .*= 2 end - x̄ = project_x(brfft(dY, d, dims)) + x̄ = project_x(brfft(dY_scaled, d, dims)) return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() end return y, rfft_pullback @@ -68,19 +73,24 @@ end function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims) y = irfft(x, d, dims) - # compute scaling factors halfdim = first(dims) n = size(x, halfdim) invN = AbstractFFTs.normalization(y, dims) project_x = ChainRulesCore.ProjectTo(x) function irfft_pullback(ȳ) - dX = rfft(real.(ChainRulesCore.unthunk(ȳ)), dims) .* invN .* 2 - selectdim(dX, halfdim, 1) ./= 2 + dX = rfft(real.(ChainRulesCore.unthunk(ȳ)), dims) + # apply scaling + dX_scaled = similar(dX) + dX_scaled .= dX + dX_scaled .*= invN .* 2 + v = selectdim(dX_scaled, halfdim, 1) + v ./= 2 if 2 * (n - 1) == d - selectdim(dX, halfdim, n) ./= 2 + v = selectdim(dX_scaled, halfdim, n) + v ./= 2 end - x̄ = project_x(dX) + x̄ = project_x(dX_scaled) return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent() end return y, irfft_pullback @@ -109,18 +119,23 @@ end function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims) y = brfft(x, d, dims) - # compute scaling factors halfdim = first(dims) n = size(x, halfdim) project_x = ChainRulesCore.ProjectTo(x) function brfft_pullback(ȳ) - dX = rfft(real.(ChainRulesCore.unthunk(ȳ)), dims) .* 2 - selectdim(dX, halfdim, 1) ./= 2 + dX = rfft(real.(ChainRulesCore.unthunk(ȳ)), dims) + # apply scaling + dX_scaled = similar(dX) + dX_scaled .= dX + dX_scaled .*= 2 + v = selectdim(dX_scaled, halfdim, 1) + v ./= 2 if 2 * (n - 1) == d - selectdim(dX, halfdim, n) ./= 2 + v = selectdim(dX_scaled, halfdim, n) + v ./= 2 end - x̄ = project_x(dX) + x̄ = project_x(dX_scaled) return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent() end return y, brfft_pullback From a22c1689b949e0ac8df04485890129326dd66252 Mon Sep 17 00:00:00 2001 From: Ziyi Yin Date: Wed, 12 Apr 2023 11:25:29 -0400 Subject: [PATCH 4/4] add more comments --- ext/AbstractFFTsChainRulesCoreExt.jl | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/ext/AbstractFFTsChainRulesCoreExt.jl b/ext/AbstractFFTsChainRulesCoreExt.jl index 119c079a..7dc6862f 100644 --- a/ext/AbstractFFTsChainRulesCoreExt.jl +++ b/ext/AbstractFFTsChainRulesCoreExt.jl @@ -33,10 +33,11 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims) project_x = ChainRulesCore.ProjectTo(x) function rfft_pullback(ȳ) dY = ChainRulesCore.unthunk(ȳ) - # apply scaling + # apply scaling; below approach is for GPU CuArray compatibility, see PR #96 dY_scaled = similar(dY) - dY_scaled .= dY - dY_scaled ./= 2 + dY_scaled .= dY ./ 2 + # assign view to a separate variable before assignment, to support Julia <1.2 + # see https://github.com/JuliaLang/julia/issues/31295 v = selectdim(dY_scaled, halfdim, 1) v .*= 2 if 2 * (n - 1) == d @@ -80,10 +81,11 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims) project_x = ChainRulesCore.ProjectTo(x) function irfft_pullback(ȳ) dX = rfft(real.(ChainRulesCore.unthunk(ȳ)), dims) - # apply scaling + # apply scaling; below approach is for GPU CuArray compatibility, see PR #96 dX_scaled = similar(dX) - dX_scaled .= dX - dX_scaled .*= invN .* 2 + dX_scaled .= dX .* invN .* 2 + # assign view to a separate variable before assignment, to support Julia <1.2 + # see https://github.com/JuliaLang/julia/issues/31295 v = selectdim(dX_scaled, halfdim, 1) v ./= 2 if 2 * (n - 1) == d @@ -125,10 +127,11 @@ function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims) project_x = ChainRulesCore.ProjectTo(x) function brfft_pullback(ȳ) dX = rfft(real.(ChainRulesCore.unthunk(ȳ)), dims) - # apply scaling + # apply scaling; below approach is for GPU CuArray compatibility, see PR #96 dX_scaled = similar(dX) - dX_scaled .= dX - dX_scaled .*= 2 + dX_scaled .= dX .* 2 + # assign view to a separate variable before assignment, to support Julia <1.2 + # see https://github.com/JuliaLang/julia/issues/31295 v = selectdim(dX_scaled, halfdim, 1) v ./= 2 if 2 * (n - 1) == d