diff --git a/ext/AbstractFFTsChainRulesCoreExt.jl b/ext/AbstractFFTsChainRulesCoreExt.jl index f0c788e6..a0c7b738 100644 --- a/ext/AbstractFFTsChainRulesCoreExt.jl +++ b/ext/AbstractFFTsChainRulesCoreExt.jl @@ -37,7 +37,7 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims) project_x = ChainRulesCore.ProjectTo(x) function rfft_pullback(ȳ) - x̄ = project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ scale, d, dims)) + x̄ = project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ typeof(x)(scale), d, dims)) return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() end return y, rfft_pullback @@ -79,7 +79,7 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims) project_x = ChainRulesCore.ProjectTo(x) function irfft_pullback(ȳ) - x̄ = project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)) + x̄ = project_x(typeof(x)(scale) .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)) return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent() end return y, irfft_pullback