Skip to content

Commit

Permalink
Merge branch 'set_default_as_mergesort' of github.com:AxFoundation/st…
Browse files Browse the repository at this point in the history
…rax into set_default_as_mergesort
  • Loading branch information
yuema137 committed Nov 13, 2024
2 parents 507a6cd + a9006f1 commit 93e4b5b
Showing 1 changed file with 26 additions and 19 deletions.
45 changes: 26 additions & 19 deletions strax/processing/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
@register_jitable
def _compute_hdr_core(data, fractions_desired, only_upper_part=False, _buffer_size=10):
"""Core computation for highest density region.
Returns the data needed for interval computation and the result arrays.
"""
fi = 0 # number of fractions seen
res = np.zeros((len(fractions_desired), 2, _buffer_size), dtype=np.int32)
Expand All @@ -29,49 +31,54 @@ def _compute_hdr_core(data, fractions_desired, only_upper_part=False, _buffer_si
max_to_min = stable_argsort(data)[::-1]
return max_to_min, area_tot, res, res_amp, fi


@export
def _process_hdr_intervals(ind, gaps, fi, res, g0, _buffer_size):
"""Process the intervals for highest density region.
This function handles the stable sorting part outside of numba.
"""
if len(gaps) > _buffer_size:
res[fi, 0, :] = -1
res[fi, 1, :] = -1
return fi + 1, res

g_ind = -1
for g_ind, g in enumerate(gaps):
interval = ind[g0:g]
res[fi, 0, g_ind] = interval[0]
res[fi, 1, g_ind] = interval[-1] + 1
g0 = g

# Last interval
interval = ind[g0:]
res[fi, 0, g_ind + 1] = interval[0]
res[fi, 1, g_ind + 1] = interval[-1] + 1
return fi + 1, res


@export
@register_jitable
def highest_density_region(data, fractions_desired, only_upper_part=False, _buffer_size=10):
"""Computes for a given sampled distribution the highest density region of the desired
fractions. Does not assume anything on the normalisation of the data.
:param data: Sampled distribution
:param fractions_desired: numpy.array Area/probability for which
the hdr should be computed.
:param _buffer_size: Size of the result buffer. The size is
equivalent to the maximal number of allowed intervals.
:param only_upper_part: Boolean, if true only computes
area/probability between maximum and current height.
:return: two arrays: The first one stores the start and inclusive
endindex of the highest density region. The second array holds
the amplitude for which the desired fraction was reached.
:param fractions_desired: numpy.array Area/probability for which the hdr should be computed.
:param _buffer_size: Size of the result buffer. The size is equivalent to the maximal number of
allowed intervals.
:param only_upper_part: Boolean, if true only computes area/probability between maximum and
current height.
:return: two arrays: The first one stores the start and inclusive endindex of the highest
density region. The second array holds the amplitude for which the desired fraction was
reached.
"""
max_to_min, area_tot, res, res_amp, fi = _compute_hdr_core(
data, fractions_desired, only_upper_part, _buffer_size)

data, fractions_desired, only_upper_part, _buffer_size
)

lowest_sample_seen = np.inf
for j in range(1, len(data)):
if lowest_sample_seen == data[max_to_min[j]]:
Expand All @@ -92,15 +99,15 @@ def highest_density_region(data, fractions_desired, only_upper_part=False, _buff
res_amp[fi] = true_height

# This part needs stable_sort - switch to object mode
with numba.objmode(ind='int64[:]'):
with numba.objmode(ind="int64[:]"):
ind = stable_sort(max_to_min[:j])

gaps = np.arange(1, len(ind) + 1)
diff = ind[1:] - ind[:-1]
gaps = gaps[:-1][diff > 1]

# Process intervals outside numba
with numba.objmode(fi='int64', res='int32[:, :, :]'):
with numba.objmode(fi="int64", res="int32[:, :, :]"):
fi, res = _process_hdr_intervals(ind, gaps, fi, res, 0, _buffer_size)

if fi == len(fractions_desired):
Expand All @@ -112,5 +119,5 @@ def highest_density_region(data, fractions_desired, only_upper_part=False, _buff
res[fi:, 1, 0] = len(data)
for ind, fraction_desired in enumerate(fractions_desired[fi:]):
res_amp[fi + ind] = (1 - fraction_desired) * np.sum(data) / len(data)
return res, res_amp

return res, res_amp

0 comments on commit 93e4b5b

Please sign in to comment.