Skip to content

Commit

Permalink
fix bug finding focus in stack with only one slice (#162)
Browse files Browse the repository at this point in the history
* fix bug finding focus in stack with only one slice

* refactor for clarify

* formatting

* print -> warnings.warn

* test single-slice case

* fix test bugs

---------

Co-authored-by: Talon Chandler <[email protected]>
  • Loading branch information
ieivanov and talonchandler authored Mar 22, 2024
1 parent af3afeb commit d08f296
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 21 deletions.
15 changes: 12 additions & 3 deletions tests/test_focus_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,21 @@ def test_focus_estimator(tmp_path):
plot_path = tmp_path.joinpath("test.pdf")
data3D = np.random.random((11, 256, 256))
slice = focus.focus_from_transverse_band(
data3D, ps, lambda_ill, NA_det, plot_path=str(plot_path)
data3D, NA_det, lambda_ill, ps, plot_path=str(plot_path)
)
assert slice >= 0
assert slice <= data3D.shape[0]
assert plot_path.exists()

# Check single slice
slice = focus.focus_from_transverse_band(
np.random.random((1, 10, 10)),
NA_det,
lambda_ill,
ps,
)
assert slice == 0


def test_focus_estimator_snr(tmp_path):
ps = 6.5 / 100
Expand All @@ -66,9 +75,9 @@ def test_focus_estimator_snr(tmp_path):
plot_path = tmp_path / f"test-{snr}.pdf"
slice = focus.focus_from_transverse_band(
data,
ps,
lambda_ill,
NA_det,
lambda_ill,
ps,
plot_path=plot_path,
threshold_FWHM=5,
)
Expand Down
54 changes: 36 additions & 18 deletions waveorder/focus.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from waveorder import util
import matplotlib.pyplot as plt
import numpy as np
import warnings


def focus_from_transverse_band(
Expand Down Expand Up @@ -60,10 +61,19 @@ def focus_from_transverse_band(
>>> slice = focus_from_transverse_band(zyx_array, NA_det=0.55, lambda_ill=0.532, pixel_size=6.5/20)
>>> in_focus_data = data[slice,:,:]
"""
minmaxfunc = _check_focus_inputs(
zyx_array, NA_det, lambda_ill, pixel_size, midband_fractions, mode
minmaxfunc = _mode_to_minmaxfunc(mode)

_check_focus_inputs(
zyx_array, NA_det, lambda_ill, pixel_size, midband_fractions
)

# Check for single slice
if zyx_array.shape[0] == 1:
warnings.warn(
"The dataset only contained a single slice. Returning trivial slice index = 0."
)
return 0

# Calculate coordinates
_, Y, X = zyx_array.shape
_, _, fxx, fyy = util.gen_coordinate((Y, X), pixel_size)
Expand Down Expand Up @@ -94,25 +104,35 @@ def focus_from_transverse_band(
# Plot
if plot_path is not None:
_plot_focus_metric(
plot_path, midband_sum, peak_index, in_focus_index, peak_results, threshold_FWHM
plot_path,
midband_sum,
peak_index,
in_focus_index,
peak_results,
threshold_FWHM,
)

return in_focus_index


def _mode_to_minmaxfunc(mode):
if mode == "min":
minmaxfunc = np.argmin
elif mode == "max":
minmaxfunc = np.argmax
else:
raise ValueError("mode must be either `min` or `max`")
return minmaxfunc


def _check_focus_inputs(
zyx_array, NA_det, lambda_ill, pixel_size, midband_fractions, mode
zyx_array, NA_det, lambda_ill, pixel_size, midband_fractions
):
N = len(zyx_array.shape)
if N != 3:
raise ValueError(
f"{N}D array supplied. `focus_from_transverse_band` only accepts 3D arrays."
)
if zyx_array.shape[0] == 1:
print(
"WARNING: The dataset only contained a single slice. Returning trivial slice index = 0."
)
return 0

if NA_det < 0:
raise ValueError("NA must be > 0")
Expand All @@ -121,7 +141,7 @@ def _check_focus_inputs(
if pixel_size < 0:
raise ValueError("pixel_size must be > 0")
if not 0.4 < lambda_ill / pixel_size < 10:
print(
warnings.warn(
f"WARNING: lambda_ill/pixel_size = {lambda_ill/pixel_size}."
f"Did you use the same units?"
f"Did you enter the pixel size in (demagnified) object-space units?"
Expand All @@ -134,17 +154,15 @@ def _check_focus_inputs(
raise ValueError("midband_fractions[0] must be between 0 and 1")
if not (0 <= midband_fractions[1] <= 1):
raise ValueError("midband_fractions[1] must be between 0 and 1")
if mode == "min":
minmaxfunc = np.argmin
elif mode == "max":
minmaxfunc = np.argmax
else:
raise ValueError("mode must be either `min` or `max`")
return minmaxfunc


def _plot_focus_metric(
plot_path, midband_sum, peak_index, in_focus_index, peak_results, threshold_FWHM
plot_path,
midband_sum,
peak_index,
in_focus_index,
peak_results,
threshold_FWHM,
):
_, ax = plt.subplots(1, 1, figsize=(4, 4))
ax.plot(midband_sum, "-k")
Expand Down

0 comments on commit d08f296

Please sign in to comment.