Skip to content

Commit

Permalink
FFT filter changes with Numpy v2.0 (#592)
Browse files Browse the repository at this point in the history
* Include accelerate base tests by default

* Change expected output in FFT filter tests to reflect single-precision FFT support in NumPy >= 2.0

* Relax absolute tolerance (atol 1e-6 to 1e-5) in PyCUDA FFT filter tests
  • Loading branch information
daurer authored and kahntm committed Jan 8, 2025
1 parent aa34bba commit e3f812c
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 77 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ testpaths = [
"test/ptyscan_tests",
"test/template_tests",
"test/util_tests",
"test/accelerate_tests/base_tests"
]

# this is all BETA according to setuptools
Expand Down
142 changes: 71 additions & 71 deletions test/accelerate_tests/base_tests/array_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,42 +305,42 @@ def test_fft_filter(self):
kernel = np.fft.fftn(rk)

output = au.fft_filter(data, kernel, prefactor, postfactor)

known_test_output = np.array([-0.0000000e+00+0.00000000e+00j, -0.0000000e+00+0.00000000e+00j,
0.0000000e+00+0.00000000e+00j, 0.0000000e+00+0.00000000e+00j,
0.0000000e+00+0.00000000e+00j, 0.0000000e+00+0.00000000e+00j,
0.0000000e+00-0.00000000e+00j, 0.0000000e+00+0.00000000e+00j,
-0.0000000e+00+0.00000000e+00j, 0.0000000e+00+0.00000000e+00j,
-0.0000000e+00+0.00000000e+00j, -0.0000000e+00+0.00000000e+00j,
0.0000000e+00+0.00000000e+00j, 0.0000000e+00+0.00000000e+00j,
0.0000000e+00+0.00000000e+00j, 0.0000000e+00+0.00000000e+00j,
0.0000000e+00+0.00000000e+00j, 0.0000000e+00-0.00000000e+00j,
0.0000000e+00+0.00000000e+00j, 0.0000000e+00+0.00000000e+00j,
0.0000000e+00+0.00000000e+00j, 0.0000000e+00+0.00000000e+00j,
0.0000000e+00+0.00000000e+00j, 0.0000000e+00+0.00000000e+00j,
0.0000000e+00+0.00000000e+00j, 0.0000000e+00+0.00000000e+00j,
0.0000000e+00+0.00000000e+00j, 0.0000000e+00+0.00000000e+00j,
0.0000000e+00+0.00000000e+00j, 0.0000000e+00+0.00000000e+00j,
0.0000000e+00+0.00000000e+00j, 0.0000000e+00-0.00000000e+00j,
0.0000000e+00+0.00000000e+00j, 6.1097220e-05+2.92563982e-05j,
4.0044695e-05+2.52102855e-05j, 8.9999994e+02+9.00000000e+02j,
8.9999988e+02+8.99999939e+02j, 5.0999994e+02+5.09999939e+02j,
1.9365043e-05+5.84280206e-05j, 3.0681291e-05+2.31116355e-05j,
1.2552022e-05-1.01537153e-05j, -1.4034913e-05-1.17988075e-05j,
-1.9193330e-05+2.72889110e-07j, 1.3895768e-05+1.64778357e-05j,
6.5228807e-05+2.45708943e-05j, 3.8999994e+02+3.89999939e+02j,
8.9999988e+02+8.99999878e+02j, 8.9999982e+02+8.99999878e+02j,
3.0000015e+01+3.00000248e+01j, 3.8863189e-05+3.26705631e-05j,
2.8768281e-06-1.62116921e-05j, -3.2418033e-05-1.97073969e-05j,
-6.6843757e-05+7.19546824e-06j, 6.5036993e-06+3.95851657e-06j,
-2.4053887e-05+9.88548163e-06j, 1.5231475e-05+1.31202614e-06j,
8.7000000e+01+8.70000305e+01j, 6.1035156e-05+0.00000000e+00j,
6.1035156e-05+0.00000000e+00j, -2.4943074e-07+6.62429193e-06j,
1.6712515e-06-2.97475322e-06j, 1.9025241e-05+2.97752194e-07j,
-9.2436176e-07-3.86252796e-05j, -8.8145862e-06-9.89961700e-06j,
-1.5782407e-06+1.01533060e-05j, -4.7593076e-06+2.96332291e-05j])

known_test_output = np.array([-0.00000000e+00+0.00000000e+00j, 0.00000000e+00-0.00000000e+00j,
-0.00000000e+00 + 0.00000000e+00j, 0.00000000e+00 - 0.00000000e+00j,
0.00000000e+00-0.00000000e+00j, -0.00000000e+00+0.00000000e+00j,
0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j,
0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j,
0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j,
-0.00000000e+00+0.00000000e+00j, -0.00000000e+00+0.00000000e+00j,
-0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j,
0.00000000e+00-0.00000000e+00j, 0.00000000e+00+0.00000000e+00j,
0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j,
0.00000000e+00-0.00000000e+00j, 0.00000000e+00-0.00000000e+00j,
0.00000000e+00-0.00000000e+00j, 0.00000000e+00+0.00000000e+00j,
0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j,
0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j,
0.00000000e+00+0.00000000e+00j, 0.00000000e+00-0.00000000e+00j,
0.00000000e+00-0.00000000e+00j, 0.00000000e+00-0.00000000e+00j,
0.00000000e+00-0.00000000e+00j, 8.66422277e-14+4.86768828e-14j,
7.23113320e-14+2.82331542e-14j, 9.00000000e+02+9.00000000e+02j,
9.00000000e+02+9.00000000e+02j, 5.10000000e+02+5.10000000e+02j,
1.41172830e-14+3.62223425e-14j, 2.61684238e-14-4.13866575e-14j,
2.16691314e-14-1.95102733e-14j, -1.36536942e-13-9.94589021e-14j,
-1.42905371e-13-5.77964697e-14j, -5.00005072e-14+4.08620637e-14j,
6.38160272e-14+7.61753583e-14j, 3.90000000e+02+3.90000000e+02j,
9.00000000e+02+9.00000000e+02j, 9.00000000e+02+9.00000000e+02j,
3.00000000e+01+3.00000000e+01j, 8.63255773e-14+7.08532924e-14j,
1.80941313e-14-3.85517154e-14j, 7.84277340e-14-1.32008745e-14j,
-6.57025196e-14-1.72739350e-14j, -6.69570857e-15+6.49622898e-14j,
6.27436466e-15+7.57162569e-14j, 2.01150157e-15+3.65538558e-14j,
8.70000000e+01+8.70000000e+01j, -1.13686838e-13-1.70530257e-13j,
0.00000000e+00-2.27373675e-13j, -1.84492121e-14-9.21502853e-14j,
2.12418687e-14-8.62209232e-14j, 1.20880692e-13+3.86522371e-14j,
1.03754734e-13+9.19851759e-14j, 5.50926123e-14+1.17150422e-13j,
-5.47869215e-14+5.87176511e-14j, -3.52652980e-14+8.44455504e-15j])

np.testing.assert_array_almost_equal(output.flat[::2000], known_test_output)
np.testing.assert_array_almost_equal(output.flat[::2000], known_test_output, decimal=5)

def test_fft_filter_batched(self):
data = np.zeros((2,256, 512), dtype=COMPLEX_TYPE)
Expand All @@ -356,42 +356,42 @@ def test_fft_filter_batched(self):
kernel = np.fft.fftn(rk)

output = au.fft_filter(data, kernel, prefactor, postfactor)
known_test_output = np.array([-0.00000000e+00+0.00000000e+00j, 0.00000000e+00-0.00000000e+00j,
-0.00000000e+00 + 0.00000000e+00j, 0.00000000e+00 - 0.00000000e+00j,
0.00000000e+00-0.00000000e+00j, -0.00000000e+00+0.00000000e+00j,
0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j,
0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j,
0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j,
-0.00000000e+00+0.00000000e+00j, -0.00000000e+00+0.00000000e+00j,
-0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j,
0.00000000e+00-0.00000000e+00j, 0.00000000e+00+0.00000000e+00j,
0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j,
0.00000000e+00-0.00000000e+00j, 0.00000000e+00-0.00000000e+00j,
0.00000000e+00-0.00000000e+00j, 0.00000000e+00+0.00000000e+00j,
0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j,
0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j,
0.00000000e+00+0.00000000e+00j, 0.00000000e+00-0.00000000e+00j,
0.00000000e+00-0.00000000e+00j, 0.00000000e+00-0.00000000e+00j,
0.00000000e+00-0.00000000e+00j, 8.66422277e-14+4.86768828e-14j,
7.23113320e-14+2.82331542e-14j, 9.00000000e+02+9.00000000e+02j,
9.00000000e+02+9.00000000e+02j, 5.10000000e+02+5.10000000e+02j,
1.41172830e-14+3.62223425e-14j, 2.61684238e-14-4.13866575e-14j,
2.16691314e-14-1.95102733e-14j, -1.36536942e-13-9.94589021e-14j,
-1.42905371e-13-5.77964697e-14j, -5.00005072e-14+4.08620637e-14j,
6.38160272e-14+7.61753583e-14j, 3.90000000e+02+3.90000000e+02j,
9.00000000e+02+9.00000000e+02j, 9.00000000e+02+9.00000000e+02j,
3.00000000e+01+3.00000000e+01j, 8.63255773e-14+7.08532924e-14j,
1.80941313e-14-3.85517154e-14j, 7.84277340e-14-1.32008745e-14j,
-6.57025196e-14-1.72739350e-14j, -6.69570857e-15+6.49622898e-14j,
6.27436466e-15+7.57162569e-14j, 2.01150157e-15+3.65538558e-14j,
8.70000000e+01+8.70000000e+01j, -1.13686838e-13-1.70530257e-13j,
0.00000000e+00-2.27373675e-13j, -1.84492121e-14-9.21502853e-14j,
2.12418687e-14-8.62209232e-14j, 1.20880692e-13+3.86522371e-14j,
1.03754734e-13+9.19851759e-14j, 5.50926123e-14+1.17150422e-13j,
-5.47869215e-14+5.87176511e-14j, -3.52652980e-14+8.44455504e-15j])

np.testing.assert_array_almost_equal(output[1].flat[::2000], known_test_output)

known_test_output = np.array([ 0.00000000e+00-0.0000000e+00j, 0.00000000e+00-0.0000000e+00j,
0.00000000e+00+0.0000000e+00j, 0.00000000e+00+0.0000000e+00j,
0.00000000e+00+0.0000000e+00j, 0.00000000e+00+0.0000000e+00j,
0.00000000e+00+0.0000000e+00j, 0.00000000e+00+0.0000000e+00j,
-0.00000000e+00+0.0000000e+00j, 0.00000000e+00-0.0000000e+00j,
0.00000000e+00-0.0000000e+00j, 0.00000000e+00-0.0000000e+00j,
0.00000000e+00+0.0000000e+00j, 0.00000000e+00+0.0000000e+00j,
0.00000000e+00+0.0000000e+00j, 0.00000000e+00+0.0000000e+00j,
0.00000000e+00+0.0000000e+00j, 0.00000000e+00+0.0000000e+00j,
0.00000000e+00+0.0000000e+00j, 0.00000000e+00-0.0000000e+00j,
0.00000000e+00-0.0000000e+00j, 0.00000000e+00-0.0000000e+00j,
0.00000000e+00+0.0000000e+00j, 0.00000000e+00+0.0000000e+00j,
0.00000000e+00+0.0000000e+00j, 0.00000000e+00+0.0000000e+00j,
0.00000000e+00+0.0000000e+00j, 0.00000000e+00+0.0000000e+00j,
0.00000000e+00+0.0000000e+00j, 0.00000000e+00-0.0000000e+00j,
0.00000000e+00-0.0000000e+00j, 0.00000000e+00-0.0000000e+00j,
0.00000000e+00-0.0000000e+00j, 4.86995195e-05-9.1511911e-06j,
5.89395277e-05+3.6706428e-05j, 8.99999817e+02+9.0000000e+02j,
8.99999817e+02+9.0000000e+02j, 5.09999969e+02+5.0999997e+02j,
6.86399580e-05+5.5245564e-05j, -2.15578075e-06-8.0761157e-07j,
-5.99612467e-05-3.7489859e-05j, -2.08058154e-05-1.7001423e-05j,
-3.15661709e-05-2.0192698e-05j, -1.17410173e-05-2.3929812e-05j,
8.41844594e-05+4.9635066e-05j, 3.90000031e+02+3.9000003e+02j,
8.99999817e+02+8.9999994e+02j, 8.99999817e+02+8.9999994e+02j,
3.00000153e+01+3.0000000e+01j, 4.75842753e-05+1.7961407e-05j,
-1.28229876e-05-3.3492659e-05j, -1.50405585e-05+3.0159079e-05j,
-1.00799960e-04-6.6932058e-05j, -4.90295024e-05-3.6601130e-05j,
-4.48861247e-05-1.4717044e-05j, 2.60417364e-05-8.3221821e-06j,
8.69999847e+01+8.7000046e+01j, 4.31583721e-05+4.3158372e-05j,
4.31583721e-05+4.3158372e-05j, 4.04649109e-06-1.6836095e-05j,
1.37377283e-05+5.2577798e-06j, -2.30404657e-05-3.4596611e-05j,
-1.33214944e-05-3.2517899e-05j, 2.45428764e-05-3.5186855e-07j,
-1.85950885e-05-2.1921931e-05j, -1.65030433e-05-8.0249208e-07j])
np.testing.assert_array_almost_equal(output[1].flat[::2000], known_test_output, decimal=5)


def test_complex_gaussian_filter_fft(self):
Expand Down
12 changes: 6 additions & 6 deletions test/accelerate_tests/cuda_pycuda_tests/array_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ def test_fft_filter_UNITY(self):

output = au.fft_filter(data, kernel, prefactor, postfactor)

np.testing.assert_allclose(output, data_dev.get(), rtol=1e-5, atol=1e-6)
np.testing.assert_allclose(output, data_dev.get(), rtol=1e-5, atol=1e-5)

def test_fft_filter_batched_UNITY(self):
sh = (2,16, 35)
Expand All @@ -607,7 +607,7 @@ def test_fft_filter_batched_UNITY(self):
output = au.fft_filter(data, kernel, prefactor, postfactor)
print(data_dev.get())

np.testing.assert_allclose(output, data_dev.get(), rtol=1e-5, atol=1e-6)
np.testing.assert_allclose(output, data_dev.get(), rtol=1e-5, atol=1e-5)

def test_complex_gaussian_filter_fft_little_blurring_UNITY(self):
# Arrange
Expand All @@ -624,7 +624,7 @@ def test_complex_gaussian_filter_fft_little_blurring_UNITY(self):
out_exp = au.complex_gaussian_filter_fft(data, mfs)
out = data_dev.get()

np.testing.assert_allclose(out_exp, out, atol=1e-6)
np.testing.assert_allclose(out_exp, out, atol=1e-5)

def test_complex_gaussian_filter_fft_more_blurring_UNITY(self):
# Arrange
Expand All @@ -641,7 +641,7 @@ def test_complex_gaussian_filter_fft_more_blurring_UNITY(self):
out_exp = au.complex_gaussian_filter_fft(data, mfs)
out = data_dev.get()

np.testing.assert_allclose(out_exp, out, atol=1e-6)
np.testing.assert_allclose(out_exp, out, atol=1e-5)

def test_complex_gaussian_filter_fft_nonsquare_UNITY(self):
# Arrange
Expand All @@ -660,7 +660,7 @@ def test_complex_gaussian_filter_fft_nonsquare_UNITY(self):
out_exp = au.complex_gaussian_filter_fft(data, mfs)
out = data_dev.get()

np.testing.assert_allclose(out_exp, out, atol=1e-6)
np.testing.assert_allclose(out_exp, out, atol=1e-5)

def test_complex_gaussian_filter_fft_batched(self):
# Arrange
Expand All @@ -680,4 +680,4 @@ def test_complex_gaussian_filter_fft_batched(self):
out_exp = au.complex_gaussian_filter_fft(data, mfs)
out = data_dev.get()

np.testing.assert_allclose(out_exp, out, atol=1e-6)
np.testing.assert_allclose(out_exp, out, atol=1e-5)

0 comments on commit e3f812c

Please sign in to comment.