Skip to content

Commit

Permalink
rework to use negative sig_thresh values
Browse files Browse the repository at this point in the history
  • Loading branch information
PyxieLouStar committed Feb 1, 2024
1 parent 4209b36 commit f40ee33
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 21 deletions.
45 changes: 24 additions & 21 deletions SSINS/match_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,38 +138,41 @@ def match_test(self, INS):
shape_max = None
for shape in self.slice_dict:
if shape == 'narrow':
t, f, p = np.unravel_index(np.absolute(INS.metric_ms).argmax(),
if self.sig_thresh[shape] < 0:
t, f, p = np.unravel_index((INS.metric_ms).argmin(),
INS.metric_ms.shape)
sig = np.absolute(INS.metric_ms[t, f, p])
sig = INS.metric_ms[t, f, p]
if sig < 0:
sig = np.absolute(sig)
else:
continue
else:
t, f, p = np.unravel_index(np.absolute(INS.metric_ms).argmax(),
INS.metric_ms.shape)
sig = np.absolute(INS.metric_ms[t, f, p])
t = slice(t, t + 1)
f = slice(f, f + 1)
elif shape == 'center_packet_loss':
N = np.count_nonzero(np.logical_not(INS.metric_ms[:, self.slice_dict[shape]].mask),
axis=1)
sliced_arr = (INS.metric_ms[:, self.slice_dict[shape]].mean(axis=1)) * np.sqrt(N)
t, p = np.unravel_index((sliced_arr / self.sig_thresh[shape]).argmin(),
sliced_arr.shape)
t = slice(t, t + 1)
f = self.slice_dict[shape]
# Pull out the number instead of a sliced arr
sig = sliced_arr[t, p][0]
if sig < 0:
print('found packet loss')
print(sig)
sig = np.absolute(sig)
else:
continue
else:
N = np.count_nonzero(np.logical_not(INS.metric_ms[:, self.slice_dict[shape]].mask),
axis=1)
sliced_arr = np.absolute(INS.metric_ms[:, self.slice_dict[shape]].mean(axis=1)) * np.sqrt(N)
t, p = np.unravel_index((sliced_arr / self.sig_thresh[shape]).argmax(),
if self.sig_thresh[shape] < 0:
sliced_arr = (INS.metric_ms[:, self.slice_dict[shape]].mean(axis=1)) * np.sqrt(N)
t, p = np.unravel_index((sliced_arr / np.abs(self.sig_thresh[shape])).argmin(),
sliced_arr.shape)
else:
sliced_arr = np.absolute(INS.metric_ms[:, self.slice_dict[shape]].mean(axis=1)) * np.sqrt(N)
t, p = np.unravel_index((sliced_arr / self.sig_thresh[shape]).argmax(),
sliced_arr.shape)
t = slice(t, t + 1)
f = self.slice_dict[shape]
# Pull out the number instead of a sliced arr
sig = sliced_arr[t, p][0]
if sig > self.sig_thresh[shape]:
if self.sig_thresh[shape] < 0:
if sig < 0:
sig = np.absolute(sig)
else:
continue
if sig > np.absolute(self.sig_thresh[shape]):
if sig > sig_max:
t_max, f_max, shape_max, sig_max = (t, f, shape, sig)

Expand Down
67 changes: 67 additions & 0 deletions SSINS/tests/test_MF.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,3 +378,70 @@ def test_MF_write(tmp_path):

with pytest.raises(ValueError, match="matchfilter file with prefix"):
mf.write(f"{prefix}_test", clobber=False)

def test_negative_sig_thresh():
obs = '1061313128_99bl_1pol_half_time'
insfile = os.path.join(DATA_PATH, '%s_SSINS.h5' % obs)

ins = INS(insfile)

# Mock a simple metric_array and freq_array
ins.metric_array[:] = np.ones_like(ins.metric_array)
ins.weights_array = np.copy(ins.metric_array)
ins.weights_square_array = np.copy(ins.weights_array)

# Make a shape dictionary for a shape that will be injected later
ch_wid = ins.freq_array[1] - ins.freq_array[0]
shape = [ins.freq_array[127] - 0.2 * ch_wid, ins.freq_array[255] + 0.2 * ch_wid]
shape_dict = {'neg_shape': shape}
sig_thresh = {'neg_shape': -5, 'narrow': 5, 'streak': 5}
mf = MF(ins.freq_array, sig_thresh, shape_dict=shape_dict)

# Inject a packet loss event and a streak event
ins.metric_array[7, 127:256] = -10
ins.metric_array[13, 127:256] = 10
ins.metric_ms = ins.mean_subtract()

mf.apply_match_test(ins, event_record=True)

# Check that the right events are flagged
test_mask = np.zeros(ins.metric_array.shape, dtype=bool)
test_mask[7, 127:256] = 1
test_mask[13, :] = 1

assert np.all(test_mask == ins.metric_array.mask), "Flags are incorrect"

test_match_events_slc = [(slice(7, 8, None), slice(127, 256, None), 'neg_shape'),
(slice(13, 14, None), slice(0, 384, None), 'streak')]

for i, event in enumerate(test_match_events_slc):
assert ins.match_events[i][:-1] == test_match_events_slc[i], f"{i}th event is wrong"

def test_all_negative_sig_thresh():
obs = '1061313128_99bl_1pol_half_time'
insfile = os.path.join(DATA_PATH, '%s_SSINS.h5' % obs)

ins = INS(insfile)

# Mock a simple metric_array and freq_array
ins.metric_array[:] = np.ones_like(ins.metric_array)
ins.weights_array = np.copy(ins.metric_array)
ins.weights_square_array = np.copy(ins.weights_array)

# Make a shape dictionary
ch_wid = ins.freq_array[1] - ins.freq_array[0]
shape = [ins.freq_array[127] - 0.2 * ch_wid, ins.freq_array[255] + 0.2 * ch_wid]
shape_dict = {'neg_shape': shape}
sig_thresh = {'neg_shape': -5, 'narrow': -5, 'streak': -5}
mf = MF(ins.freq_array, sig_thresh, shape_dict=shape_dict)

# Inject some positive rfi events
ins.metric_array[13, 299] = 10
ins.metric_array[7, :125] = 10
ins.metric_ms = ins.mean_subtract()

mf.apply_match_test(ins, event_record=True)

# No events should be flagged
assert np.all(~ins.metric_array.mask), "Flags are incorrect"
assert not ins.match_events, "Match events are incorrect"

0 comments on commit f40ee33

Please sign in to comment.