From 601b7e52477ad44804498927cf2edf41ae05f6af Mon Sep 17 00:00:00 2001 From: Matthew Giammmar Date: Wed, 11 Dec 2024 14:57:45 -0800 Subject: [PATCH] fix: minor bug fixes for whitening and ctf + pytest device fix --- src/torch_fourier_filter/ctf.py | 4 +-- src/torch_fourier_filter/dft_utils.py | 18 ++++++----- src/torch_fourier_filter/whitening.py | 46 ++++++++++++++------------- tests/conftest.py | 7 ++++ tests/mtf/test_mtf.py | 5 --- 5 files changed, 43 insertions(+), 37 deletions(-) create mode 100644 tests/conftest.py diff --git a/src/torch_fourier_filter/ctf.py b/src/torch_fourier_filter/ctf.py index 936b3f0..9aace52 100644 --- a/src/torch_fourier_filter/ctf.py +++ b/src/torch_fourier_filter/ctf.py @@ -3,7 +3,7 @@ import einops import torch from scipy import constants as C -from torch_grid_utils.fftfreq_grid import fftfreq_grid +from torch_grid_utils.fftfreq_grid import fftfreq_grid, fftshift_2d def calculate_relativistic_electron_wavelength(energy: float) -> float: @@ -166,7 +166,7 @@ def calculate_ctf_2d( if k4 > 0: ctf *= torch.exp(k4 * n4) if fftshift is True: - ctf = torch.fft.fftshift(ctf, dim=(-2, -1)) + ctf = fftshift_2d(ctf, rfft=rfft) return ctf diff --git a/src/torch_fourier_filter/dft_utils.py b/src/torch_fourier_filter/dft_utils.py index d5980fb..940209d 100644 --- a/src/torch_fourier_filter/dft_utils.py +++ b/src/torch_fourier_filter/dft_utils.py @@ -14,7 +14,7 @@ def rotational_average_dft_2d( image_shape: tuple[int, ...], rfft: bool = False, fftshifted: bool = False, - return_2d_average: bool = False, + return_1d_average: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: # rotational_average, frequency_bins """ Calculate the rotational average of a 2D DFT. @@ -30,8 +30,9 @@ def rotational_average_dft_2d( Whether the input is from an rfft (True) or full fft (False) fftshifted : bool Whether the input is fftshifted - return_2d_average : bool - Whether to return the 2D rotational average and frequency bins + return_1d_average : bool + If true, return a 1D rotational average and frequency bins, otherwise + return a 2D average. Returns ------- @@ -54,7 +55,7 @@ def rotational_average_dft_2d( for shell in shell_data ] rotational_average = einops.rearrange(mean_per_shell, "shells ... -> ... shells") - if return_2d_average is True: + if not return_1d_average: if len(dft.shape) > len(image_shape): image_shape = (*dft.shape[:-2], *image_shape[-2:]) rotational_average = _1d_to_rotational_average_2d_dft( @@ -77,7 +78,7 @@ def rotational_average_dft_3d( image_shape: tuple[int, ...], rfft: bool = False, fftshifted: bool = False, - return_3d_average: bool = False, + return_1d_average: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: # rotational_average, frequency_bins """ Calculate the rotational average of a 3D DFT. @@ -93,8 +94,9 @@ def rotational_average_dft_3d( Whether the input is from an rfft (True) or full fft (False) fftshifted : bool Whether the input is fftshifted - return_3d_average : bool - Whether to return the 3D rotational average and frequency bins + return_1d_average : bool + If true, return a 1D rotational average and frequency bins, otherwise + return a 3D average. Returns ------- @@ -117,7 +119,7 @@ def rotational_average_dft_3d( for shell in shell_data ] rotational_average = einops.rearrange(mean_per_shell, "shells ... -> ... shells") - if return_3d_average is True: + if not return_1d_average: if len(dft.shape) > len(image_shape): image_shape = (*dft.shape[:-3], *image_shape[-3:]) rotational_average = _1d_to_rotational_average_3d_dft( diff --git a/src/torch_fourier_filter/whitening.py b/src/torch_fourier_filter/whitening.py index 2fe2a6a..1ad314f 100644 --- a/src/torch_fourier_filter/whitening.py +++ b/src/torch_fourier_filter/whitening.py @@ -33,6 +33,8 @@ def gaussian_smoothing( torch.Tensor The smoothed tensor. """ + assert tensor.dim() in [1, 2], "Input tensor must be 1D or 2D" + # Create a 1D Gaussian kernel x = torch.arange( -kernel_size // 2 + 1, @@ -100,44 +102,44 @@ def whitening_filter( Whitening filter """ power_spectrum = torch.abs(image_dft) + if power_spec: power_spectrum = power_spectrum**2 - radial_average = None + if len(image_shape) == 2: - radial_average, _ = rotational_average_dft_2d( - dft=power_spectrum, - image_shape=image_shape, - rfft=rfft, - fftshifted=fftshift, - return_2d_average=False, # output 1D average - ) + rot_avg_mth = rotational_average_dft_2d elif len(image_shape) == 3: - radial_average, _ = rotational_average_dft_3d( - dft=power_spectrum, - image_shape=image_shape, - rfft=rfft, - fftshifted=fftshift, - return_3d_average=False, # output 1D average - ) + rot_avg_mth = rotational_average_dft_3d + + # Get 1-dimensional radial average + power_spectrum_1d, _ = rot_avg_mth( + dft=power_spectrum, + image_shape=image_shape, + rfft=rfft, + fftshifted=fftshift, + return_1d_average=True, + ) + + whitening_filter_1d = 1 / power_spectrum_1d - # Take the reciprical of the square root of the radial average - whiten_filter = 1 / (radial_average) if power_spec: - whiten_filter = whiten_filter**0.5 + whitening_filter_1d = whitening_filter_1d**0.5 # Apply Gaussian smoothing if smoothing: - whiten_filter = gaussian_smoothing(whiten_filter) + whitening_filter_1d = gaussian_smoothing(whitening_filter_1d) # bin or interpolate to output size - whiten_filter = bin_or_interpolate_to_output_size(whiten_filter, output_shape) + whitening_filter_1d = bin_or_interpolate_to_output_size( + whitening_filter_1d, output_shape + ) # put back to 2 or 3D if necessary if dimensions_output == 2: if len(power_spectrum.shape) > len(output_shape): output_shape = (*power_spectrum.shape[:-2], *output_shape[-2:]) whiten_filter = _1d_to_rotational_average_2d_dft( - values=radial_average, + values=whitening_filter_1d, image_shape=output_shape, rfft=rfft, fftshifted=fftshift, @@ -146,7 +148,7 @@ def whitening_filter( if len(power_spectrum.shape) > len(output_shape): output_shape = (*power_spectrum.shape[:-3], *output_shape[-3:]) whiten_filter = _1d_to_rotational_average_3d_dft( - values=radial_average, + values=whitening_filter_1d, image_shape=output_shape, rfft=rfft, fftshifted=fftshift, diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..5cc8126 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,7 @@ +import pytest +import torch + + +@pytest.fixture(autouse=True) +def set_default_device(): + torch.set_default_device("cpu") diff --git a/tests/mtf/test_mtf.py b/tests/mtf/test_mtf.py index d150407..35c1d34 100644 --- a/tests/mtf/test_mtf.py +++ b/tests/mtf/test_mtf.py @@ -11,11 +11,6 @@ def test_make_mtf_grid(): image_shape_3d = (32, 64, 64) mtf_frequencies = torch.linspace(0, 0.5, 10) mtf_amplitudes = torch.linspace(1, 0, 10) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - # Move tensors to the appropriate device - mtf_frequencies = mtf_frequencies.to(device) - mtf_amplitudes = mtf_amplitudes.to(device) mtf_grid_2d = make_mtf_grid( image_shape_2d, mtf_frequencies, mtf_amplitudes, rfft=True, fftshift=False