Skip to content

Commit

Permalink
tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
dlfivefifty committed Dec 17, 2024
1 parent e06611d commit c182bc6
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 12 deletions.
9 changes: 8 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,16 @@ MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"

[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

[extensions]
FFTWForwardDiffExt = "ForwardDiff"

[compat]
AbstractFFTs = "1.5"
AbstractFFTs = "1.6"
FFTW_jll = "3.3.9"
ForwardDiff = "0.10"
LinearAlgebra = "<0.0.1, 1"
MKL_jll = "2019.0.117, 2020, 2021, 2022, 2023, 2024"
Preferences = "1.2"
Expand Down
17 changes: 7 additions & 10 deletions ext/FFTWForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
module FFTWForwardDiffExt
# AbstractFFTs.complexfloat(x::AbstractArray{<:Dual}) = float.(x .+ 0im)
using FFTW
using ForwardDiff
import FFTW: plan_r2r, r2r
import FFTW.AbstractFFTs: dualplan, dual2array
import ForwardDiff: Dual

plan_r2r(x::AbstractArray{D}, FLAG, dims=1:ndims(x)) where D<:Dual = dualplan(D, plan_r2r(dual2array(x), FLAG, 1 .+ dims))
plan_r2r(x::AbstractArray{<:Complex{D}}, FLAG, dims=1:ndims(x)) where D<:Dual = dualplan(D, plan_r2r(dual2array(x), FLAG, 1 .+ dims))

plan_r2r(x::AbstractArray{<:Dual}, FLAG, dims=1:ndims(x)) = plan_r2r(dual2array(x), FLAG, 1 .+ dims)
plan_r2r(x::AbstractArray{<:Complex{<:Dual}}, FLAG, dims=1:ndims(x)) = plan_r2r(dual2array(x), FLAG, 1 .+ dims)

for plan in (:plan_irfft, :plan_brfft) # these take an extra argument, only when complex?
@eval begin
$plan(x::AbstractArray{<:Dual}, dims=1:ndims(x)) = $plan(dual2array(x), 1 .+ dims)
$plan(x::AbstractArray{<:Complex{<:Dual}}, d::Integer, dims=1:ndims(x)) = $plan(dual2array(x), d, 1 .+ dims)
end
end

r2r(x::AbstractArray{<:Dual}, kinds, region...) = plan_r2r(x, kinds, region...) * x
r2r(x::AbstractArray{<:Complex{<:Dual}}, kinds, region...) = plan_r2r(x, kinds, region...) * x
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Aqua = "0.8"
ForwardDiff = "0.10"
Test = "<0.0.1, 1"
5 changes: 4 additions & 1 deletion test/fftwforwarddiff.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
using FFTW, ForwardDiff, Test
using ForwardDiff: Dual, value, partials

@testset "r2r" begin
x1 = Dual.(1:4.0, 2:5, 3:6)
t = FFTW.r2r(x1, FFTW.R2HC)
Expand All @@ -12,7 +15,7 @@
@test partials.(t, 2) == FFTW.r2r(partials.(x1 + 2im*x1, 2), FFTW.R2HC)

f = ω -> FFTW.r2r([ω; zeros(9)], FFTW.R2HC)[1]
@test derivative(f, 0.1) 1.0
@test ForwardDiff.derivative(f, 0.1) 1.0

@test mul!(similar(x1), FFTW.plan_r2r(x1, FFTW.R2HC), x1) == FFTW.r2r(x1, FFTW.R2HC)
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -606,3 +606,5 @@ end
AbstractFFTs.TestUtils.test_real_ffts(Array; copy_input=true)
end
end

include("fftwforwarddiff.jl")

0 comments on commit c182bc6

Please sign in to comment.