Skip to content

Commit

Permalink
clean up and refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
talonchandler committed Jul 11, 2023
1 parent 372eda3 commit 2b50aea
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 65 deletions.
6 changes: 4 additions & 2 deletions tests/test_focus_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,15 @@ def test_focus_estimator_snr(tmp_path):
phantom * np.sqrt(snr), size=phantom.shape
) + np.random.normal(loc=0, scale=3, size=phantom.shape)

plot_path = tmp_path / f"test-{snr}.pdf"
slice = focus.focus_from_transverse_band(
data,
ps,
lambda_ill,
NA_det,
plot_path=f"./test-{snr}.pdf",#tmp_path /
peak_width_threshold=5,
plot_path=plot_path,
threshold_FWHM=5,
)
assert plot_path.exists()
if slice is not None:
assert np.abs(slice - 10) <= 2
141 changes: 78 additions & 63 deletions waveorder/focus.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def focus_from_transverse_band(
midband_fractions=(0.125, 0.25),
mode: Literal["min" "max"] = "max",
plot_path: Optional[str] = None,
peak_width_threshold: float = 0,
threshold_FWHM: float = 0,
):
"""Estimates the in-focus slice from a 3D stack by optimizing a transverse spatial frequency band.
Expand All @@ -38,10 +38,10 @@ def focus_from_transverse_band(
plot_path: str or None, optional
File name for a diagnostic plot (supports matplotlib filetypes .png, .pdf, .svg, etc.).
Use None to skip.
peak_width_threshold: float, optional
Threshold width for a peak to be considered in focus.
The default value, 0, applies no threshold, and the maximum midband power is considered in focus.
For values >0, the peak's FWHM must be greater than the threshold for the slice to be considered in focus.
threshold_FWHM: float, optional
Threshold full-width half max for a peak to be considered in focus.
The default value, 0, applies no threshold, and the maximum midband power is always considered in focus.
For values > 0, the peak's FWHM must be greater than the threshold for the slice to be considered in focus.
If the peak does not meet this threshold, the function returns None.
Returns
Expand All @@ -60,8 +60,48 @@ def focus_from_transverse_band(
>>> slice = focus_from_transverse_band(zyx_array, NA_det=0.55, lambda_ill=0.532, pixel_size=6.5/20)
>>> in_focus_data = data[slice,:,:]
"""
minmaxfunc = _check_focus_inputs(
zyx_array, NA_det, lambda_ill, pixel_size, midband_fractions, mode
)

# Calculate coordinates
_, Y, X = zyx_array.shape
_, _, fxx, fyy = util.gen_coordinate((Y, X), pixel_size)
frr = np.sqrt(fxx**2 + fyy**2)

# Calculate fft
xy_abs_fft = np.abs(np.fft.fftn(zyx_array, axes=(1, 2)))

# Calculate midband mask
cutoff = 2 * NA_det / lambda_ill
midband_mask = np.logical_and(
frr > cutoff * midband_fractions[0],
frr < cutoff * midband_fractions[1],
)

# Find slice index with min/max power in midband
midband_sum = np.sum(xy_abs_fft[:, midband_mask], axis=1)
peak_index = minmaxfunc(midband_sum)

peak_results = peak_widths(midband_sum, [peak_index])
FWHM = peak_results[0][0]
if FWHM > threshold_FWHM:
in_focus_index = peak_index
else:
in_focus_index = None

# Plot
if plot_path is not None:
_plot_focus_metric(
plot_path, midband_sum, peak_index, in_focus_index, peak_results, threshold_FWHM
)

return in_focus_index


# Check inputs
def _check_focus_inputs(
zyx_array, NA_det, lambda_ill, pixel_size, midband_fractions, mode
):
N = len(zyx_array.shape)
if N != 3:
raise ValueError(
Expand Down Expand Up @@ -99,65 +139,40 @@ def focus_from_transverse_band(
minmaxfunc = np.argmax
else:
raise ValueError("mode must be either `min` or `max`")
return minmaxfunc

# Calculate coordinates
_, Y, X = zyx_array.shape
_, _, fxx, fyy = util.gen_coordinate((Y, X), pixel_size)
frr = np.sqrt(fxx**2 + fyy**2)

# Calculate fft
xy_abs_fft = np.abs(np.fft.fftn(zyx_array, axes=(1, 2)))

# Calculate midband mask
cutoff = 2 * NA_det / lambda_ill
midband_mask = np.logical_and(
frr > cutoff * midband_fractions[0],
frr < cutoff * midband_fractions[1],
def _plot_focus_metric(
plot_path, midband_sum, peak_index, in_focus_index, peak_results, threshold_FWHM
):
_, ax = plt.subplots(1, 1, figsize=(4, 4))
ax.plot(midband_sum, "-k")
ax.plot(
peak_index,
midband_sum[peak_index],
"go" if in_focus_index is not None else "ro",
)
ax.hlines(*peak_results[1:], color="k", linestyles="dashed")

ax.set_xlabel("Slice index")
ax.set_ylabel("Midband power")

ax.annotate(
f"In-focus slice = {in_focus_index}\n Peak width = {peak_results[0][0]:.2f}\n Peak width threshold = {threshold_FWHM}",
xy=(1, 1),
xytext=(1.0, 1.1),
textcoords="axes fraction",
xycoords="axes fraction",
ha="right",
va="center",
annotation_clip=False,
)

# Find slice index with min/max power in midband
midband_sum = np.sum(xy_abs_fft[:, midband_mask], axis=1)
peak_index = minmaxfunc(midband_sum)

peak_results = peak_widths(midband_sum, [peak_index])
width = peak_results[0][0]
if width > peak_width_threshold:
in_focus_index = peak_index
else:
in_focus_index = None

# Plot
if plot_path is not None:
_, ax = plt.subplots(1, 1, figsize=(4, 4))
ax.plot(midband_sum, "-k")
ax.plot(
peak_index,
midband_sum[peak_index],
"go" if in_focus_index is not None else "ro",
)
ax.hlines(*peak_results[1:], color="k", linestyles="dashed")

ax.set_xlabel("Slice index")
ax.set_ylabel("Midband power")

ax.annotate(
f"In-focus slice = {in_focus_index}\n Peak width = {width:.2f}\n Peak width threshold = {peak_width_threshold}",
xy=(1, 1),
xytext=(1.0, 1.1),
textcoords="axes fraction",
xycoords="axes fraction",
ha="right",
va="center",
annotation_clip=False,
)

ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.spines["left"].set_position(("outward", 10))
ax.spines["bottom"].set_position(("outward", 10))
ax.ticklabel_format(style="sci", scilimits=(-2, 2))

print(f"Saving plot to {plot_path}")
plt.savefig(plot_path, bbox_inches="tight", dpi=300)
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.spines["left"].set_position(("outward", 10))
ax.spines["bottom"].set_position(("outward", 10))
ax.ticklabel_format(style="sci", scilimits=(-2, 2))

return in_focus_index
print(f"Saving plot to {plot_path}")
plt.savefig(plot_path, bbox_inches="tight", dpi=300)

0 comments on commit 2b50aea

Please sign in to comment.