perf: move things into so3g and other speedups
* so3g block_moment
* matched filter in so3g
* 2pi jumps in so3g
* Faster _diff_buffed
* Switch to mean sub
* Do more things inplace
* Remove unneeded std check
skhrg committed Aug 27, 2024
1 parent f080944 commit ce75039
Showing 2 changed files with 149 additions and 94 deletions.
170 changes: 76 additions & 94 deletions sotodlib/tod_ops/jumps.py
@@ -4,12 +4,12 @@

import numpy as np
import scipy.ndimage as simg
import scipy.signal as sig
import scipy.stats as ss
from numpy.typing import NDArray
from pixell.utils import block_expand, block_mean_filter, block_reduce
from pixell.utils import block_expand, block_reduce, moveaxis
from scipy.sparse import csr_array
from skimage.restoration import denoise_tv_chambolle
from so3g import matched_jumps, matched_jumps64, clean_flag, find_quantized_jumps, find_quantized_jumps64
from so3g.proj import Ranges, RangesMatrix
from sotodlib.core import AxisManager

@@ -21,6 +21,7 @@
def std_est(
x: NDArray[np.floating],
ds: int = 1,
win_size: int = 20,
axis: int = -1,
method: str = "median_unbiased",
) -> NDArray[np.floating]:
@@ -32,7 +33,9 @@ def std_est(
x: Data to compute standard deviation of.
ds: Downsample factor to use, does a naive slicing.
ds: Downsample factor to use, does a naive slicing in blocks of ``win_size``.
win_size: Window size to downsample by.
axis: The axis to compute along.
@@ -44,12 +47,22 @@ def std_est(
"""
if ds > 2 * x.shape[axis]:
ds = 1
sl = [slice(None)] * len(x.shape)
if ds > 1:
sl[axis] = slice(None, None, ds)
x = np.moveaxis(x, axis, -1)
x = x[..., : x.shape[-1] - (x.shape[-1] % win_size)]
shape = list(x.shape) + [win_size]
shape[-2] = -1
x = x.reshape(tuple(shape))
x = np.moveaxis(x, -2, 0)
diff = np.diff(x[::ds], axis=-1)
diff = moveaxis(diff, 0, -2)
diff = diff.reshape(shape[:-1])
diff = np.moveaxis(diff, -1, axis)
else:
diff = np.diff(x, axis=axis)
# Find ~1 sigma limits of differenced data
lims = np.quantile(
np.diff(x, axis=axis)[tuple(sl)],
diff,
np.array([0.159, 0.841]),
axis=axis,
method=method,
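
In plain terms, the new ``std_est`` differences the data inside ``win_size`` blocks (sliced by ``ds``) and reads a robust ~1 sigma off the 0.159/0.841 quantiles. A minimal standalone sketch of that estimator, assuming Gaussian noise and skipping the block bookkeeping; the ``sqrt(2)`` factor, presumably applied in the hidden tail of this hunk, corrects for differencing two independent samples:

```python
import numpy as np

def sigma_from_diffs(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Robust white-noise sigma from quantiles of differenced data."""
    diff = np.diff(x, axis=axis)
    lo, hi = np.quantile(diff, [0.159, 0.841], axis=axis)
    # Differencing two iid Gaussians inflates sigma by sqrt(2); undo it.
    return (hi - lo) / (2 * np.sqrt(2))
```
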
@@ -90,73 +103,45 @@ def _jumpfinder(
# and in the case of 2d data we find jumps along rows
orig_shape = x.shape
x = np.atleast_2d(x)
dtype = x.dtype.name
if len(x.shape) > 2:
raise ValueError("x may not have more than 2 dimensions")
if dtype == "float32":
matched_filt = matched_jumps
elif dtype == "float64":
matched_filt = matched_jumps64
else:
raise TypeError("x must be float32 or float64")

jumps = np.zeros(x.shape, dtype=bool)
jumps = np.zeros(x.shape, dtype=bool, order="C")
if x.shape[-1] < win_size:
return jumps.reshape(orig_shape)

size_msk = (np.max(x, axis=-1) - np.min(x, axis=-1)) < min_size
if np.all(size_msk):
return jumps.reshape(orig_shape)

# If std is basically 0 no need to check for jumps
std = np.std(x, axis=-1)
std_msk = np.isclose(std, 0.0) + np.isclose(std_est(x, ds=win_size, axis=-1), std)

msk = ~(size_msk + std_msk)
msk = np.ptp(x, axis=-1) > min_size
if not np.any(msk):
return jumps.reshape(orig_shape)

# Build a mean filter
# Flag with a matched filter
win_size += win_size % 2 # Odd win size adds a weird phasing issue
half_win = int(win_size / 2)
x_br = block_mean_filter(x[msk], win_size)
diff = x[msk] - x_br
# Take cumulative sum, this is equivalent to convolving with a step
x_step = np.abs(np.cumsum(diff, axis=-1))

# If the jump is at a multiple of win_size we will miss it so also do a shifted filter
x_br_shift = block_mean_filter(x[msk, half_win:], win_size)
x_step[:, half_win:] = np.maximum(
x_step[:, half_win:],
np.abs(np.cumsum(x[msk, half_win:] - x_br_shift, axis=-1)),
) # TODO: Is there something better than using the max for this?

# Because of the shift, the closest we can be to a window edge is win_size/4
# In this case the slope of the shorter segment is ~3*height/4
# So the peaks should be at least (3*win_size*min_size)/16
# We want to include peaks of that size so we use a denominator of 32
# Note that after this filtering we are left with at least win_size/4 width
_x = np.ascontiguousarray(x[msk])
_jumps = np.ascontiguousarray(np.empty_like(_x), "int32")
if isinstance(min_size, np.ndarray):
_min_size = (3 * win_size * min_size / 32)[..., None]
_min_size = min_size[msk].astype(_x.dtype)
elif min_size is None:
raise TypeError("min_size is None")
else:
_min_size = (3 * win_size * min_size / 32) * np.ones((1, len(x)))

peak_msk = np.zeros_like(jumps, dtype=bool)
peak_msk[msk] = x_step > _min_size[msk]
has_peaks = np.any(peak_msk, -1)
if not np.any(has_peaks):
return jumps.reshape(orig_shape)

quarter_win = int(half_win / 2)
# This is equivalent to this convolution
# jumps[has_peaks] = (
# sig.fftconvolve(np.ones((1, half_win)), peak_msk[has_peaks], axes=-1)[
# :, : -1 * (half_win - 1)
# ]
# > quarter_win
# )
peak_msk = peak_msk.astype(float)
peak_msk[has_peaks, half_win:] -= peak_msk[has_peaks, : -1 * half_win]
jumps[has_peaks] = np.cumsum(peak_msk[has_peaks], axis=-1) > quarter_win

# Recall that we set _min_size to be half the actual peak min above
jumps[has_peaks] *= x_step[has_peaks[msk]] >= 2 * _min_size[has_peaks]
_min_size = (min_size * np.ones(len(_x))).astype(_x.dtype)
matched_filt(_x, _jumps, _min_size, win_size)
jumps[msk] = _jumps > 0

if exact:
structure = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]])
labels, _ = simg.label(jumps, structure)
peak_idx = np.array(simg.maximum_position(x_step, labels))
peak_idx = np.array(
simg.maximum_position(
np.diff(_x, axis=-1, prepend=np.zeros(len(_x))), labels
)
)
jump_rows = [peak_idx[:, 0]]
jump_cols = peak_idx[:, 1]
jumps[:] = False
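
For reference, a minimal usage sketch of the so3g matched-filter call wired up above, with the dtype, contiguity, and per-row threshold array prepared the same way ``_jumpfinder`` does; the injected step and threshold values are illustrative only:

```python
import numpy as np
from so3g import matched_jumps

x = np.ascontiguousarray(np.random.randn(2, 1000).astype("float32"))
x[0, 500:] += 5.0  # inject a step well above the noise
flags = np.ascontiguousarray(np.empty_like(x), "int32")
min_size = np.full(len(x), 1.0, dtype=x.dtype)  # per-row jump threshold
matched_jumps(x, flags, min_size, 20)  # win_size, kept even as above
jump_mask = flags > 0
```
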
@@ -254,13 +239,15 @@ def _diff_buffed(
make_step: bool,
) -> NDArray[np.floating]:
win_size = int(win_size + win_size % 2)
pad = np.zeros((len(signal.shape), 2), dtype=int)
half_win = int(win_size / 2)
pad[-1, :] = half_win
if jumps is not None and make_step:
signal = _make_step(signal, jumps)
padded = np.pad(signal, pad, mode="edge")
diff_buffed = padded[..., win_size:] - padded[..., : (-1 * win_size)]
diff_buffed = np.empty_like(signal)
diff_buffed[..., :win_size] = 0
diff_buffed[..., win_size:] = np.subtract(
signal[..., win_size:],
signal[..., : (-1 * win_size)],
out=diff_buffed[..., win_size:],
)

return diff_buffed

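
The rewritten ``_diff_buffed`` is just a lagged difference computed in place of the old edge-padded version; a sketch of the equivalent computation, assuming the same zero-fill convention for the first ``win_size`` samples:

```python
import numpy as np

def lagged_diff(signal: np.ndarray, win_size: int = 20) -> np.ndarray:
    out = np.zeros_like(signal)
    # out[..., t] = signal[..., t] - signal[..., t - win_size]
    out[..., win_size:] = signal[..., win_size:] - signal[..., :-win_size]
    return out
```
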
@@ -431,34 +418,32 @@ def twopi_jumps(
if not isinstance(signal, np.ndarray):
raise TypeError("Signal is not an array")
if atol is None:
atol = nsigma * std_est(signal.astype(float), ds=win_size)
atol = nsigma * std_est(
signal.astype(float), ds=win_size * 10, win_size=win_size
)
atol = np.clip(atol, 1e-8, 1e-2)

_signal = _filter(signal, **filter_pars)
diff_buffed = _diff_buffed(_signal, None, win_size, False)

if isinstance(atol, int):
atol = float(atol)
if isinstance(atol, float):
ratio = np.abs(diff_buffed) / (2 * np.pi)
jumps = (np.abs(ratio - np.round(ratio, 0)) <= atol) & (ratio >= 0.5)
jumps[..., :win_size] = False
elif isinstance(atol, np.ndarray):
jumps = np.atleast_2d(np.zeros_like(signal, dtype=bool))
diff_buffed = np.atleast_2d(diff_buffed)
if len(atol) != len(jumps):
raise ValueError(f"Non-scalar atol provided with length {len(atol)}")
ratio = np.abs(diff_buffed / (2 * np.pi))
jumps = (np.abs(ratio - np.round(ratio, 0)) <= atol[..., None]) & (ratio >= 0.5)
jumps.reshape(signal.shape)
else:
_signal = np.atleast_2d(_signal)
if isinstance(atol, int) or isinstance(atol, float):
atol = np.ones(len(_signal), float) * float(atol)
elif np.isscalar(atol):
raise TypeError(f"Invalid atol type: {type(atol)}")
if len(atol) != len(signal):
raise ValueError(f"Non-scalar atol provided with length {len(atol)}")

_signal = np.ascontiguousarray(_signal)
heights = np.empty_like(_signal)
atol = np.ascontiguousarray(atol, dtype=_signal.dtype)
if _signal.dtype.name == "float32":
find_quantized_jumps(_signal, heights, atol, win_size, 2 * np.pi)
elif _signal.dtype.name == "float64":
find_quantized_jumps64(_signal, heights, atol, win_size, 2 * np.pi)
else:
raise TypeError("signal must be float32 or float64")

jumps = heights != 0
jump_ranges = RangesMatrix.from_mask(jumps).buffer(int(win_size / 2))
jumps = jump_ranges.mask()
heights = estimate_heights(
signal, jumps, win_size=win_size, twopi=True, diff_buffed=diff_buffed
)

if merge:
_merge(aman, jump_ranges, name, overwrite)
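
A minimal usage sketch of the so3g quantized-jump search as called above, with 2*pi as the quantum; the tolerance and the injected phase wrap are illustrative values, not defaults from the library:

```python
import numpy as np
from so3g import find_quantized_jumps

sig = np.ascontiguousarray((0.1 * np.random.randn(2, 1000)).astype("float32"))
sig[1, 300:] += 2 * np.pi  # inject a 2*pi wrap
heights = np.empty_like(sig)
atol = np.full(len(sig), 0.1, dtype=sig.dtype)  # per-row tolerance
find_quantized_jumps(sig, heights, atol, 20, 2 * np.pi)
jump_mask = heights != 0  # heights holds the jump size, 0 where none found
```
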
@@ -595,7 +580,6 @@ def slow_jumps(
def find_jumps(
aman,
signal=...,
max_iters=...,
min_sigma=...,
min_size=...,
win_size=...,
@@ -614,7 +598,6 @@
def find_jumps(
aman,
signal=...,
max_iters=...,
min_sigma=...,
min_size=...,
win_size=...,
@@ -632,7 +615,6 @@
def find_jumps(
aman: AxisManager,
signal: Optional[NDArray[np.floating]] = None,
max_iters: int = 1,
min_sigma: Optional[float] = None,
min_size: Optional[Union[float, NDArray[np.floating]]] = None,
win_size: int = 20,
@@ -706,7 +688,9 @@ def find_jumps(
raise ValueError("Jumpfinder only works on 1D or 2D data")

if min_size is None and min_sigma is not None:
min_size = min_sigma * std_est(signal, ds=win_size, axis=-1)
min_size = min_sigma * std_est(
signal, ds=win_size * 10, win_size=win_size, axis=-1
)
if min_size is None:
raise ValueError("min_size is somehow still None")
if isinstance(min_size, np.ndarray) and np.ndim(min_size) > 1: # type: ignore
Expand All @@ -715,12 +699,10 @@ def find_jumps(
min_size = float(min_size) * np.ones(len(signal))

_signal = _filter(signal, **filter_pars)
if max_iters > 1:
_signal = signal.copy()
_signal = np.atleast_2d(_signal)
# Median subtract, if we don't do this then when we cumsum we get floats
# Mean subtract, if we don't do this then when we cumsum we get floats
# that are too big and lack the precision to find jumps well
_signal -= np.median(_signal, axis=-1)[..., None]
_signal -= np.mean(_signal, axis=-1)[..., None]

nfuture = min(len(_signal), NFUTURE)
slice_size = len(_signal) // nfuture
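
Centering the signal before the filter matters because the matched filter cumulative-sums the data; in float32, large partial sums swallow O(1) jumps. A quick illustration with arbitrary numbers:

```python
import numpy as np

x = (1e4 + np.random.randn(100_000)).astype("float32")
raw = np.cumsum(x)  # partial sums reach ~1e9, where float32 resolves only ~100s
centered = np.cumsum(x - x.mean())  # partial sums stay small, precision kept
```
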
73 changes: 73 additions & 0 deletions sotodlib/tod_ops/utils.py
@@ -0,0 +1,73 @@
"""
Generically useful utility functions.
"""
from typing import Optional
import numpy as np
from numpy.typing import NDArray
from so3g import block_moment, block_moment64


def get_block_moment(
tod: NDArray[np.floating],
block_size: int,
moment: int = 1,
central: bool = True,
shift: int = 0,
output: Optional[NDArray[np.floating]] = None,
) -> NDArray[np.floating]:
"""
Compute the n'th moment of data in blocks along each row.
Note that the blocks are non-overlapping,
so any samples left at the end will be in a smaller standalone block.
This is a wrapper around ``so3g.block_moment``.
Arguments:
tod: Data to compute the moment of.
Should be (ndet, nsamp) or (nsamp).
Must be float32 or float64.
block_size: Size of block to use.
moment: Which moment to compute.
Must be >= 1.
central: If True compute the mean centered moment.
shift: Sample to start the blocks at; the output is 0 before this sample.
output: Array to put the blocked moment into.
If provided must be the same shape as tod.
If None, will be initialized from tod.
Returns:
block_moment: The blocked moment.
Will have the same shape as tod.
If output is provided it is modified in place and returned here.
"""
if not np.all(np.isfinite(tod)):
raise ValueError("Only finite values allowed in tod")
orig_shape = tod.shape
dtype = tod.dtype.name
tod = np.atleast_2d(tod)
if len(tod.shape) > 2:
raise ValueError("tod may not have more than 2 dimensions")
if dtype not in ["float32", "float64"]:
raise TypeError("tod must be float32 or float64")

if output is None:
output = np.ascontiguousarray(np.empty_like(tod))
if output.shape != tod.shape:
raise ValueError("output shape does not match tod")
if output.dtype.name != dtype:
raise TypeError("output type does not match tod")

if moment < 1:
raise ValueError("moment must be at least 1")

if dtype == "float32":
block_moment(tod, output, block_size, moment, central, shift)
else:
block_moment64(tod, output, block_size, moment, central, shift)

return output.reshape(orig_shape)
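
A short usage sketch of the new wrapper, assuming a float32 TOD; with ``moment=2`` and ``central=True`` each block holds its own variance:

```python
import numpy as np
from sotodlib.tod_ops.utils import get_block_moment

tod = np.random.randn(3, 10_000).astype("float32")
block_var = get_block_moment(tod, block_size=200, moment=2, central=True)
# Same shape as tod; every sample in a block carries that block's moment.
```
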
