Skip to content

Commit

Permalink
Merge pull request #921 from HERA-Team/fix-inpaint-nonfiring
Browse files Browse the repository at this point in the history
Fix inpaint nonfiring
  • Loading branch information
steven-murray authored Nov 29, 2023
2 parents d89aaff + 7de3c49 commit 7005da8
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 16 deletions.
58 changes: 49 additions & 9 deletions hera_cal/lstbin_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,10 +357,9 @@ def reduce_lst_bins(

if d.size:
d, f = get_masked_data(
d, n, f, inpainted=inpf, inpainted_mode=inpainted_mode
d, n, f, inpainted=inpf,
inpainted_mode=inpainted_mode, flag_thresh=flag_thresh
)
f = threshold_flags(f, inplace=True, flag_thresh=flag_thresh)
d.mask |= f

(
out_data[:, lstbin],
Expand Down Expand Up @@ -417,6 +416,7 @@ def get_masked_data(
flags: np.ndarray,
inpainted: np.ndarray | None = None,
inpainted_mode: bool = False,
flag_thresh: float = 0.7,
) -> np.ma.MaskedArray:
if not inpainted_mode:
# Act like nothing is inpainted.
Expand All @@ -429,6 +429,12 @@ def get_masked_data(
inpainted = flags.copy() * (~allf)

flags = flags | np.isnan(data) | np.isinf(data) | (nsamples == 0)

# Threshold flags over time here, because we want the new flags to be treated on the
# same footing as the inpainted flags.
threshold_flags(flags, inplace=True, flag_thresh=flag_thresh)

logger.info(f"In inpainted_mode: {inpainted_mode}. Got {np.sum(inpainted)} inpainted samples, {np.sum(flags)} total flags, {np.sum(flags & ~inpainted)} non-inpainted flags.")
data = np.ma.masked_array(data, mask=(flags & ~inpainted))
return data, flags

Expand All @@ -453,6 +459,31 @@ def threshold_flags(
inplace: bool = False,
flag_thresh: float = 0.7,
):
"""
Thresholds the input flags array based on the flag fraction.
Parameters
----------
flags : numpy.ndarray
A numpy array of shape (Nnights, ...) representing the flags.
inplace : bool, optional
If True, modifies the input flags array in place. If False, creates a copy of
the flags array.
flag_thresh : float, optional
The threshold value for the flag fraction.
Returns
-------
numpy.ndarray
A numpy array of shape (N, ...) with the thresholded flags.
Examples
--------
>>> flags = np.array([[True, False, True], [False, True, False]])
>>> threshold_flags(flags, inplace=True, flag_thresh=0.5)
array([[ True, False, True],
[False, True, False]])
"""
if not inplace:
flags = flags.copy()

Expand Down Expand Up @@ -568,7 +599,12 @@ def lst_average(

nsamples = np.ma.masked_array(nsamples, mask=data.mask)

# Norm is the total number of samples over the nights. In the in-painted case,
# it *counts* in-painted data as samples. In the non-inpainted case, it does not.
norm = np.sum(nsamples, axis=0)

# Ndays binned is the number of days that count towards the mean. This is the same
# in in-painted and flagged mode.
ndays_binned = np.sum((~flags).astype(int), axis=0)

logger.info("Calculating mean")
Expand All @@ -580,10 +616,18 @@ def lst_average(
lstbin_flagged[ndays_binned < sigma_clip_min_N] = True

normalizable = norm > 0

meandata[normalizable] /= norm[normalizable]
# Multiply by nan instead of just setting as nan, so both real and imag parts are nan
meandata[~normalizable] *= np.nan

# While the previous nsamples is different for in-painted and flagged mode, which is
# what we want for the mean, for the std and nsamples we want to treat flags as really
# flagged.
nsamples.mask = flags
norm = np.sum(nsamples, axis=0)
normalizable = norm > 0

# get other stats
logger.info("Calculating std")
with warnings.catch_warnings():
Expand All @@ -597,12 +641,8 @@ def lst_average(

std[~normalizable] = np.inf

# While the previous norm is correct for normalizing the mean, we now
# calculate nsamples as the unflagged samples in each LST bin.
nsamples.mask = flags
nsamples = np.sum(nsamples, axis=0)

return meandata.data, lstbin_flagged, std.data, nsamples.data, ndays_binned
logger.info(f"Mean of meandata: {np.mean(meandata)}. Mean of std: {np.mean(std)}. Total nsamples: {np.sum(norm)}")
return meandata.data, lstbin_flagged, std.data, norm.data, ndays_binned


def adjust_lst_bin_edges(lst_bin_edges: np.ndarray) -> np.ndarray:
Expand Down
12 changes: 5 additions & 7 deletions hera_cal/tests/test_lstbin_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,13 +624,11 @@ def test_flag_below_min_N(self):
)

assert np.all(flg_n)
# just because we're flagging it, doesn't mean we need to set nsamples=0
# or the std to inf. We have info there, we're just choosing not to use it.
assert np.all(norm_n == 2)
assert not np.any(np.isinf(std_n))
# nsamples is zero because all are flagged.
assert np.all(norm_n == 0)
assert np.all(np.isinf(std_n))

# this time, only one column is flagged too much...
# this time, there's enough samples, but too many are flagged...
flags[:] = False
flags[:5, 0] = True
data_n, flg_n, std_n, norm_n, db = lstbin_simple.lst_average(
Expand All @@ -642,8 +640,8 @@ def test_flag_below_min_N(self):
)

assert np.all(flg_n[0])
assert np.all(norm_n[0] == 2)
assert not np.any(np.isinf(std_n[0]))
assert np.all(norm_n[0] == 0)
assert np.all(np.isinf(std_n[0]))

assert not np.any(flg_n[1:])
assert np.all(norm_n[1:] == 7)
Expand Down

0 comments on commit 7005da8

Please sign in to comment.