Skip to content

Commit

Permalink
feat: breaking change, process data as 3D zyx array
Browse files Browse the repository at this point in the history
  • Loading branch information
Eoghan O'Connell committed Jan 7, 2025
1 parent c777591 commit e9a623a
Show file tree
Hide file tree
Showing 9 changed files with 243 additions and 60 deletions.
84 changes: 84 additions & 0 deletions qpretrieve/data_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import numpy as np
import warnings

allowed_data_formats = [
"rgb",
"rgba",
"3d",
"2d",
]


def check_data_input_format(data_input):
"""Figure out what data input is provided."""
if len(data_input.shape) == 3:
if data_input.shape[-1] in [1, 2, 3]:
# take the first slice (we have alpha or RGB information)
data, data_format = _convert_rgb_to_3d(data_input)
elif data_input.shape[-1] == 4:
# take the first slice (we have alpha or RGB information)
data, data_format = _convert_rgba_to_3d(data_input)
else:
# we have a 3D image stack (z, y, x)
data, data_format = data_input, "3d"
elif len(data_input.shape) == 2:
# we have a 2D image (y, x). convert to (z, y, z)
data, data_format = _convert_2d_to_3d(data_input)
else:
raise ValueError(f"data_input shape must be 2d or 3d, "
f"got shape {data_input.shape}.")
return data.copy(), data_format


def revert_to_data_input_format(data_format, field):
"""Convert the outputted field shape to the original input shape,
for user convenience."""
assert data_format in allowed_data_formats
assert len(field.shape) == 3, "the field should be 3d"
field = field.copy()
if data_format == "rgb":
field = _revert_3d_to_rgb(field)
elif data_format == "rgba":
field = _revert_3d_to_rgba(field)
elif data_format == "3d":
field = field
else:
field = _revert_3d_to_2d(field)
return field


def _convert_rgb_to_3d(data_input):
data = data_input[:, :, 0]
data = data[np.newaxis, :, :]
data_format = "rgb"
warnings.warn(f"Format of input data detected as {data_format}. "
f"The first channel will be used for processing")
return data, data_format


def _convert_rgba_to_3d(data_input):
data, _ = _convert_rgb_to_3d(data_input)
data_format = "rgba"
return data, data_format


def _convert_2d_to_3d(data_input):
data = data_input[np.newaxis, :, :]
data_format = "2d"
return data, data_format


def _revert_3d_to_rgb(data_input):
data = data_input[0]
data = np.dstack((data, data, data))
return data


def _revert_3d_to_rgba(data_input):
data = data_input[0]
data = np.dstack((data, data, data, np.ones_like(data)))
return data


def _revert_3d_to_2d(data_input):
return data_input[0]
10 changes: 6 additions & 4 deletions qpretrieve/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def get_filter_array(filter_name, filter_size, freq_pos, fft_shape):
and must be between 0 and `max(fft_shape)/2`
freq_pos: tuple of floats
The position of the filter in frequency coordinates as
returned by :func:`nunpy.fft.fftfreq`.
returned by :func:`numpy.fft.fftfreq`.
fft_shape: tuple of int
The shape of the Fourier transformed image for which the
The shape of the Fourier transformed image (2d) for which the
filter will be applied. The shape must be squared (two
identical integers).
Expand Down Expand Up @@ -104,8 +104,10 @@ def get_filter_array(filter_name, filter_size, freq_pos, fft_shape):
# TODO: avoid the np.roll, instead use the indices directly
alpha = 0.1
rsize = int(min(fx.size, fy.size) * filter_size) * 2
tukey_window_x = signal.tukey(rsize, alpha=alpha).reshape(-1, 1)
tukey_window_y = signal.tukey(rsize, alpha=alpha).reshape(1, -1)
tukey_window_x = signal.windows.tukey(
rsize, alpha=alpha).reshape(-1, 1)
tukey_window_y = signal.windows.tukey(
rsize, alpha=alpha).reshape(1, -1)
tukey = tukey_window_x * tukey_window_y
base = np.zeros(fft_shape)
s1 = (np.array(fft_shape) - rsize) // 2
Expand Down
13 changes: 13 additions & 0 deletions qpretrieve/fourier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,19 @@
PREFERRED_INTERFACE = None


def get_available_interfaces():
"""Return a list of available FFT algorithms"""
interfaces = [
FFTFilterPyFFTW,
FFTFilterNumpy,
]
interfaces_available = []
for interface in interfaces:
if interface is not None and interface.is_available:
interfaces_available.append(interface)
return interfaces_available


def get_best_interface():
"""Return the fastest refocusing interface available
Expand Down
53 changes: 36 additions & 17 deletions qpretrieve/fourier/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import numpy as np

from .. import filter
from ..utils import padding_3d, mean_3d
from ..data_input import check_data_input_format


class FFTCache:
Expand Down Expand Up @@ -35,12 +37,19 @@ def cleanup(key):


class FFTFilter(ABC):
def __init__(self, data, subtract_mean=True, padding=2, copy=True):
def __init__(self,
data: np.ndarray,
subtract_mean: bool = True,
padding: int = 2,
copy: bool = True):
r"""
Parameters
----------
data: 2d real-valued np.ndarray
The experimental input image
data
The experimental input real-valued image. Allowed input shapes are:
- 2d (y, x)
- 3d (z, y, x)
- 3d rgb (y, x, 3) or rgba (y, x, 4)
subtract_mean: bool
If True, subtract the mean of `data` before performing
the Fourier transform. This setting is recommended as it
Expand Down Expand Up @@ -70,9 +79,15 @@ def __init__(self, data, subtract_mean=True, padding=2, copy=True):
else:
# convert integer-arrays to floating point arrays
dtype = float
if not copy:
# numpy v2.x behaviour requires asarray with copy=False
copy = None
data_ed = np.array(data, dtype=dtype, copy=copy)
# figure out what type of data we have, change it to 3d-stack
data_ed, self.data_format = check_data_input_format(data_ed)
#: original data (with subtracted mean)
self.origin = data_ed
# for `subtract_mean` and `padding`, we could use `np.atleast_3d`
#: whether padding is enabled
self.padding = padding
#: whether the mean was subtracted
Expand All @@ -81,14 +96,13 @@ def __init__(self, data, subtract_mean=True, padding=2, copy=True):
# remove contributions of the central band
# (this affects more than one pixel in the FFT
# because of zero-padding)
data_ed -= data_ed.mean()
data_ed = mean_3d(data_ed)
if padding:
# zero padding size is next order of 2
logfact = np.log(padding * max(data_ed.shape))
order = int(2 ** np.ceil(logfact / np.log(2)))
# this is faster than np.pad
datapad = np.zeros((order, order), dtype=dtype)
datapad[:data_ed.shape[0], :data_ed.shape[1]] = data_ed

datapad = padding_3d(data_ed, order, dtype)
#: padded input data
self.origin_padded = datapad
data_ed = datapad
Expand Down Expand Up @@ -175,7 +189,7 @@ def filter(self, filter_name: str, filter_size: float,
and must be between 0 and `max(fft_shape)/2`
freq_pos: tuple of floats
The position of the filter in frequency coordinates as
returned by :func:`nunpy.fft.fftfreq`.
returned by :func:`numpy.fft.fftfreq`.
scale_to_filter: bool or float
Crop the image in Fourier space after applying the filter,
effectively removing surplus (zero-padding) data and
Expand Down Expand Up @@ -220,36 +234,41 @@ def filter(self, filter_name: str, filter_size: float,
filter_name=filter_name,
filter_size=filter_size,
freq_pos=freq_pos,
fft_shape=self.fft_origin.shape)
# only take shape of a single fft
fft_shape=self.fft_origin.shape[-2:])
fft_filtered = self.fft_origin * filt_array
px = int(freq_pos[0] * self.shape[0])
py = int(freq_pos[1] * self.shape[1])
fft_used = np.roll(np.roll(fft_filtered, -px, axis=0), -py, axis=1)
px = int(freq_pos[0] * self.shape[-2])
py = int(freq_pos[1] * self.shape[-1])
fft_used = np.roll(np.roll(
fft_filtered, -px, axis=-2), -py, axis=-1)
if scale_to_filter:
# Determine the size of the cropping region.
# We compute the "radius" of the region, so we can
# crop the data left and right from the center of the
# Fourier domain.
osize = fft_filtered.shape[0] # square shaped
osize = fft_filtered.shape[-2] # square shaped
crad = int(np.ceil(filter_size * osize * scale_to_filter))
ccent = osize // 2
cslice = slice(ccent - crad, ccent + crad)
# We now have the interesting peak already shifted to
# the first entry of our array in `shifted`.
fft_used = fft_used[cslice, cslice]
fft_used = fft_used[:, cslice, cslice]

field = self._ifft(np.fft.ifftshift(fft_used))

if self.padding:
# revert padding
sx, sy = self.origin.shape
sx, sy = self.origin.shape[-2:]
if scale_to_filter:
sx = int(np.ceil(sx * 2 * crad / osize))
sy = int(np.ceil(sy * 2 * crad / osize))
field = field[:sx, :sy]

field = field[:, :sx, :sy]

if scale_to_filter:
# Scale the absolute value of the field. This does not
# have any influence on the phase, but on the amplitude.
field *= (2 * crad / osize)**2
field *= (2 * crad / osize) ** 2
# Add FFT to cache
# (The cache will only be cleared if this instance is deleted)
FFTCache.add_item(weakref_key, self.fft_origin,
Expand Down
46 changes: 32 additions & 14 deletions qpretrieve/interfere/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import numpy as np

from ..fourier import get_best_interface
from ..fourier import get_best_interface, get_available_interfaces
from ..fourier.base import FFTFilter


class BaseInterferogram(ABC):
Expand All @@ -15,11 +16,19 @@ class BaseInterferogram(ABC):
"invert_phase": False,
}

def __init__(self, data, subtract_mean=True, padding=2, copy=True,
def __init__(self, data, fft_interface: FFTFilter = None,
subtract_mean=True, padding=2, copy=True,
**pipeline_kws):
"""
Parameters
----------
fft_interface: FFTFilter
A Fourier transform interface.
See :func:`qpretrieve.fourier.get_available_interfaces`
to get a list of implemented interfaces.
Default is None, which will use
:func:`qpretrieve.fourier.get_best_interface`. This is in line
with old behaviour.
subtract_mean: bool
If True, remove the mean of the hologram before performing
the Fourier transform. This setting is recommended as it
Expand All @@ -38,15 +47,24 @@ def __init__(self, data, subtract_mean=True, padding=2, copy=True,
Any additional keyword arguments for :func:`run_pipeline`
as defined in :const:`default_pipeline_kws`.
"""
ff_iface = get_best_interface()
if len(data.shape) == 3:
# take the first slice (we have alpha or RGB information)
data = data[:, :, 0]
if fft_interface == 'auto' or fft_interface is None:
self.ff_iface = get_best_interface()
else:
if fft_interface in get_available_interfaces():
self.ff_iface = fft_interface
else:
raise ValueError(
f"User-chosen FFT Interface '{fft_interface}' is not "
f"available. The available interfaces are: "
f"{get_available_interfaces()}.\n"
f"You can use `fft_interface='auto'` to get the best "
f"available interface.")

#: qpretrieve Fourier transform interface class
self.fft = ff_iface(data=data,
subtract_mean=subtract_mean,
padding=padding,
copy=copy)
self.fft = self.ff_iface(data=data,
subtract_mean=subtract_mean,
padding=padding,
copy=copy)
#: originally computed Fourier transform
self.fft_origin = self.fft.fft_origin
#: filtered Fourier data from last run of `run_pipeline`
Expand Down Expand Up @@ -94,18 +112,18 @@ def compute_filter_size(self, filter_size, filter_size_interpretation,
raise ValueError("For sideband distance interpretation, "
"`filter_size` must be between 0 and 1; "
f"got '{filter_size}'!")
fsize = np.sqrt(np.sum(np.array(sideband_freq)**2)) * filter_size
fsize = np.sqrt(np.sum(np.array(sideband_freq) ** 2)) * filter_size
elif filter_size_interpretation == "frequency index":
# filter size given in Fourier index (number of Fourier pixels)
# The user probably does not know that we are padding in
# Fourier space, so we use the unpadded size and translate it.
if filter_size <= 0 or filter_size >= self.fft.shape[0] / 2:
if filter_size <= 0 or filter_size >= self.fft.shape[-2] / 2:
raise ValueError("For frequency index interpretation, "
+ "`filter_size` must be between 0 and "
+ f"{self.fft.shape[0] / 2}, got "
+ f"{self.fft.shape[-2] / 2}, got "
+ f"'{filter_size}'!")
# convert to frequencies (compatible with fx and fy)
fsize = filter_size / self.fft.shape[0]
fsize = filter_size / self.fft.shape[-2]
else:
raise ValueError("Invalid value for `filter_size_interpretation`: "
+ f"'{filter_size_interpretation}'")
Expand Down
6 changes: 4 additions & 2 deletions qpretrieve/interfere/if_oah.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np

from .base import BaseInterferogram
from ..data_input import revert_to_data_input_format


class OffAxisHologram(BaseInterferogram):
Expand Down Expand Up @@ -73,7 +74,7 @@ def run_pipeline(self, **pipeline_kws):

if pipeline_kws["sideband_freq"] is None:
pipeline_kws["sideband_freq"] = find_peak_cosine(
self.fft.fft_origin)
self.fft.fft_origin[0])

# convert filter_size to frequency coordinates
fsize = self.compute_filter_size(
Expand All @@ -92,6 +93,7 @@ def run_pipeline(self, **pipeline_kws):
if pipeline_kws["invert_phase"]:
field.imag *= -1

field = revert_to_data_input_format(self.fft.data_format, field)
self._field = field
self._phase = None
self._amplitude = None
Expand All @@ -101,7 +103,7 @@ def run_pipeline(self, **pipeline_kws):


def find_peak_cosine(ft_data, copy=True):
"""Find the side band position of a regular off-axis hologram
"""Find the side band position of a 2d regular off-axis hologram
The Fourier transform of a cosine function (known as the
striped fringe pattern in off-axis holography) results in
Expand Down
Loading

0 comments on commit e9a623a

Please sign in to comment.