Skip to content

Commit

Permalink
fix: minor bug fixes for whitening and ctf + pytest device fix
Browse files Browse the repository at this point in the history
  • Loading branch information
mgiammar committed Dec 11, 2024
1 parent d39badb commit 601b7e5
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 37 deletions.
4 changes: 2 additions & 2 deletions src/torch_fourier_filter/ctf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import einops
import torch
from scipy import constants as C
from torch_grid_utils.fftfreq_grid import fftfreq_grid
from torch_grid_utils.fftfreq_grid import fftfreq_grid, fftshift_2d


def calculate_relativistic_electron_wavelength(energy: float) -> float:
Expand Down Expand Up @@ -166,7 +166,7 @@ def calculate_ctf_2d(
if k4 > 0:
ctf *= torch.exp(k4 * n4)
if fftshift is True:
ctf = torch.fft.fftshift(ctf, dim=(-2, -1))
ctf = fftshift_2d(ctf, rfft=rfft)
return ctf


Expand Down
18 changes: 10 additions & 8 deletions src/torch_fourier_filter/dft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def rotational_average_dft_2d(
image_shape: tuple[int, ...],
rfft: bool = False,
fftshifted: bool = False,
return_2d_average: bool = False,
return_1d_average: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]: # rotational_average, frequency_bins
"""
Calculate the rotational average of a 2D DFT.
Expand All @@ -30,8 +30,9 @@ def rotational_average_dft_2d(
Whether the input is from an rfft (True) or full fft (False)
fftshifted : bool
Whether the input is fftshifted
return_2d_average : bool
Whether to return the 2D rotational average and frequency bins
return_1d_average : bool
If true, return a 1D rotational average and frequency bins, otherwise
return a 2D average.
Returns
-------
Expand All @@ -54,7 +55,7 @@ def rotational_average_dft_2d(
for shell in shell_data
]
rotational_average = einops.rearrange(mean_per_shell, "shells ... -> ... shells")
if return_2d_average is True:
if not return_1d_average:
if len(dft.shape) > len(image_shape):
image_shape = (*dft.shape[:-2], *image_shape[-2:])
rotational_average = _1d_to_rotational_average_2d_dft(
Expand All @@ -77,7 +78,7 @@ def rotational_average_dft_3d(
image_shape: tuple[int, ...],
rfft: bool = False,
fftshifted: bool = False,
return_3d_average: bool = False,
return_1d_average: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]: # rotational_average, frequency_bins
"""
Calculate the rotational average of a 3D DFT.
Expand All @@ -93,8 +94,9 @@ def rotational_average_dft_3d(
Whether the input is from an rfft (True) or full fft (False)
fftshifted : bool
Whether the input is fftshifted
return_3d_average : bool
Whether to return the 3D rotational average and frequency bins
return_1d_average : bool
If true, return a 1D rotational average and frequency bins, otherwise
return a 3D average.
Returns
-------
Expand All @@ -117,7 +119,7 @@ def rotational_average_dft_3d(
for shell in shell_data
]
rotational_average = einops.rearrange(mean_per_shell, "shells ... -> ... shells")
if return_3d_average is True:
if not return_1d_average:
if len(dft.shape) > len(image_shape):
image_shape = (*dft.shape[:-3], *image_shape[-3:])
rotational_average = _1d_to_rotational_average_3d_dft(
Expand Down
46 changes: 24 additions & 22 deletions src/torch_fourier_filter/whitening.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def gaussian_smoothing(
torch.Tensor
The smoothed tensor.
"""
assert tensor.dim() in [1, 2], "Input tensor must be 1D or 2D"

# Create a 1D Gaussian kernel
x = torch.arange(
-kernel_size // 2 + 1,
Expand Down Expand Up @@ -100,44 +102,44 @@ def whitening_filter(
Whitening filter
"""
power_spectrum = torch.abs(image_dft)

if power_spec:
power_spectrum = power_spectrum**2
radial_average = None

if len(image_shape) == 2:
radial_average, _ = rotational_average_dft_2d(
dft=power_spectrum,
image_shape=image_shape,
rfft=rfft,
fftshifted=fftshift,
return_2d_average=False, # output 1D average
)
rot_avg_mth = rotational_average_dft_2d
elif len(image_shape) == 3:
radial_average, _ = rotational_average_dft_3d(
dft=power_spectrum,
image_shape=image_shape,
rfft=rfft,
fftshifted=fftshift,
return_3d_average=False, # output 1D average
)
rot_avg_mth = rotational_average_dft_3d

# Get 1-dimensional radial average
power_spectrum_1d, _ = rot_avg_mth(
dft=power_spectrum,
image_shape=image_shape,
rfft=rfft,
fftshifted=fftshift,
return_1d_average=True,
)

whitening_filter_1d = 1 / power_spectrum_1d

# Take the reciprical of the square root of the radial average
whiten_filter = 1 / (radial_average)
if power_spec:
whiten_filter = whiten_filter**0.5
whitening_filter_1d = whitening_filter_1d**0.5

# Apply Gaussian smoothing
if smoothing:
whiten_filter = gaussian_smoothing(whiten_filter)
whitening_filter_1d = gaussian_smoothing(whitening_filter_1d)

# bin or interpolate to output size
whiten_filter = bin_or_interpolate_to_output_size(whiten_filter, output_shape)
whitening_filter_1d = bin_or_interpolate_to_output_size(
whitening_filter_1d, output_shape
)

# put back to 2 or 3D if necessary
if dimensions_output == 2:
if len(power_spectrum.shape) > len(output_shape):
output_shape = (*power_spectrum.shape[:-2], *output_shape[-2:])
whiten_filter = _1d_to_rotational_average_2d_dft(
values=radial_average,
values=whitening_filter_1d,
image_shape=output_shape,
rfft=rfft,
fftshifted=fftshift,
Expand All @@ -146,7 +148,7 @@ def whitening_filter(
if len(power_spectrum.shape) > len(output_shape):
output_shape = (*power_spectrum.shape[:-3], *output_shape[-3:])
whiten_filter = _1d_to_rotational_average_3d_dft(
values=radial_average,
values=whitening_filter_1d,
image_shape=output_shape,
rfft=rfft,
fftshifted=fftshift,
Expand Down
7 changes: 7 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import pytest
import torch


@pytest.fixture(autouse=True)
def set_default_device():
torch.set_default_device("cpu")
5 changes: 0 additions & 5 deletions tests/mtf/test_mtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,6 @@ def test_make_mtf_grid():
image_shape_3d = (32, 64, 64)
mtf_frequencies = torch.linspace(0, 0.5, 10)
mtf_amplitudes = torch.linspace(1, 0, 10)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move tensors to the appropriate device
mtf_frequencies = mtf_frequencies.to(device)
mtf_amplitudes = mtf_amplitudes.to(device)

mtf_grid_2d = make_mtf_grid(
image_shape_2d, mtf_frequencies, mtf_amplitudes, rfft=True, fftshift=False
Expand Down

0 comments on commit 601b7e5

Please sign in to comment.