Skip to content

Commit

Permalink
add more comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ziyiyin97 committed Apr 13, 2023
1 parent dc2cfab commit a22c168
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions ext/AbstractFFTsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims)
project_x = ChainRulesCore.ProjectTo(x)
function rfft_pullback(ȳ)
dY = ChainRulesCore.unthunk(ȳ)
# apply scaling
# apply scaling; below approach is for GPU CuArray compatibility, see PR #96
dY_scaled = similar(dY)
dY_scaled .= dY
dY_scaled ./= 2
dY_scaled .= dY ./ 2
# assign view to a separate variable before assignment, to support Julia <1.2
# see https://github.com/JuliaLang/julia/issues/31295
v = selectdim(dY_scaled, halfdim, 1)
v .*= 2
if 2 * (n - 1) == d
Expand Down Expand Up @@ -80,10 +81,11 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims)
project_x = ChainRulesCore.ProjectTo(x)
function irfft_pullback(ȳ)
dX = rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)
# apply scaling
# apply scaling; below approach is for GPU CuArray compatibility, see PR #96
dX_scaled = similar(dX)
dX_scaled .= dX
dX_scaled .*= invN .* 2
dX_scaled .= dX .* invN .* 2
# assign view to a separate variable before assignment, to support Julia <1.2
# see https://github.com/JuliaLang/julia/issues/31295
v = selectdim(dX_scaled, halfdim, 1)
v ./= 2
if 2 * (n - 1) == d
Expand Down Expand Up @@ -125,10 +127,11 @@ function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims)
project_x = ChainRulesCore.ProjectTo(x)
function brfft_pullback(ȳ)
dX = rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)
# apply scaling
# apply scaling; below approach is for GPU CuArray compatibility, see PR #96
dX_scaled = similar(dX)
dX_scaled .= dX
dX_scaled .*= 2
dX_scaled .= dX .* 2
# assign view to a separate variable before assignment, to support Julia <1.2
# see https://github.com/JuliaLang/julia/issues/31295
v = selectdim(dX_scaled, halfdim, 1)
v ./= 2
if 2 * (n - 1) == d
Expand Down

0 comments on commit a22c168

Please sign in to comment.