Add option to save first samples of peak(lets) waveform #867

Open · wants to merge 10 commits into master
strax/dtypes.py (20 changes: 19 additions & 1 deletion)
@@ -161,7 +161,12 @@ def hitlet_with_data_dtype(n_samples=2):
 
 
 def peak_dtype(
-    n_channels=100, n_sum_wv_samples=200, n_widths=11, digitize_top=True, hits_timing=True
+    n_channels=100,
+    n_sum_wv_samples=200,
+    n_widths=11,
+    digitize_top=True,
+    hits_timing=True,
+    save_waveform_start=True,
 ):
     """Data type for peaks - ranges across all channels in a detector
     Remember to set channel to -1 (todo: make enum)
@@ -206,6 +211,19 @@ def peak_dtype(
             n_sum_wv_samples,
         )
         dtype.insert(9, top_field)
 
+    if save_waveform_start:
+        dtype += [
+            (
+                (
+                    "Waveform data in PE/sample (not PE/ns!), first 200 samples without downsampling",
+                    "data_start",
+                ),
+                np.float32,
+                n_sum_wv_samples,
+            )
+        ]
+
     return dtype
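For illustration (not part of the diff): with the flag at its default, the extra field appears next to data in the peak dtype. A minimal sketch, assuming peak_dtype is re-exported at the top-level strax namespace like the other dtype helpers:

    import numpy as np
    import strax

    dtype = np.dtype(strax.peak_dtype(save_waveform_start=True))
    assert "data_start" in dtype.names          # first samples, kept at full resolution
    assert dtype["data_start"].shape == (200,)  # n_sum_wv_samples defaults to 200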


strax/processing/peak_building.py (27 changes: 23 additions & 4 deletions)
@@ -138,7 +138,11 @@ def find_peaks(
 @export
 @numba.jit(nopython=True, nogil=True, cache=True)
 def store_downsampled_waveform(
-    p, wv_buffer, store_in_data_top=False, wv_buffer_top=np.ones(1, dtype=np.float32)
+    p,
+    wv_buffer,
+    store_in_data_top=False,
+    store_waveform_start=False,
+    wv_buffer_top=np.ones(1, dtype=np.float32),
 ):
     """Downsample the waveform in buffer and store it in p['data'] and in p['data_top'] if indicated
     to do so.
@@ -170,6 +174,14 @@ def store_downsampled_waveform(
             wv_buffer[: p["length"] * downsample_factor].reshape(-1, downsample_factor).sum(axis=1)
         )
         p["dt"] *= downsample_factor
+
+        # If the waveform is downsampled, we can store the first samples of the waveform
+        if store_waveform_start & (downsample_factor <= 6):
+            if p["length"] > len(p["data_start"]):
+                p["data_start"] = wv_buffer[: len(p["data_start"])]
+            else:
+                p["data_start"][: p["length"]] = wv_buffer[: p["length"]]
+
     else:
         if store_in_data_top:
             p["data_top"][: p["length"]] = wv_buffer_top[: p["length"]]
@@ -229,7 +241,14 @@ def _simple_summed_waveform(records, containers, touching_windows, to_pe):
 @export
 @numba.jit(nopython=True, nogil=True, cache=True)
 def sum_waveform(
-    peaks, hits, records, record_links, adc_to_pe, n_top_channels=0, select_peaks_indices=None
+    peaks,
+    hits,
+    records,
+    record_links,
+    adc_to_pe,
+    n_top_channels=0,
+    select_peaks_indices=None,
+    save_waveform_start=False,
 ):
     """Compute sum waveforms for all peaks in peaks. Only builds summed waveform over regions in
     which hits were found. This is required to avoid any bias due to zero-padding and baselining.
@@ -357,9 +376,9 @@ def sum_waveform(
             p["area"] += area_pe
 
             if n_top_channels > 0:
-                store_downsampled_waveform(p, swv_buffer, True, twv_buffer)
+                store_downsampled_waveform(p, swv_buffer, True, save_waveform_start, twv_buffer)
             else:
-                store_downsampled_waveform(p, swv_buffer)
+                store_downsampled_waveform(p, swv_buffer, False, save_waveform_start)
 
             p["n_saturated_channels"] = p["saturated_channel"].sum()
             p["area_per_channel"][:] = area_per_channel
strax/processing/peak_merging.py (2 changes: 1 addition & 1 deletion)
@@ -85,7 +85,7 @@ def merge_peaks(peaks, start_merge_at, end_merge_at, max_buffer=int(1e5)):
 
         # Downsample the buffers into new_p['data'], new_p['data_top'],
         # and new_p['data_bot']
-        strax.store_downsampled_waveform(new_p, buffer, True, buffer_top)
+        strax.store_downsampled_waveform(new_p, buffer, True, True, buffer_top)
 
         new_p["n_saturated_channels"] = new_p["saturated_channel"].sum()
 
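The two positional booleans are easy to misread; they map onto store_in_data_top and store_waveform_start respectively, i.e. merged peaks always keep their first samples. An equivalent keyword form as a sketch, assuming numba accepts keyword arguments at this call site:

    strax.store_downsampled_waveform(
        new_p,
        buffer,
        store_in_data_top=True,
        store_waveform_start=True,
        wv_buffer_top=buffer_top,
    )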
strax/processing/peak_splitting.py (22 changes: 20 additions & 2 deletions)
@@ -15,6 +15,7 @@ def split_peaks(
     algorithm="local_minimum",
     data_type="peaks",
     n_top_channels=0,
+    save_waveform_start=False,
     **kwargs,
 ):
     """Return peaks split according to algorithm, with waveforms summed and widths computed.
@@ -49,7 +50,15 @@ def split_peaks(
     if data_type_is_not_supported:
         raise TypeError(f'Data_type "{data_type}" is not supported.')
     return splitter(
-        peaks, hits, records, rlinks, to_pe, data_type, n_top_channels=n_top_channels, **kwargs
+        peaks,
+        hits,
+        records,
+        rlinks,
+        to_pe,
+        data_type,
+        n_top_channels=n_top_channels,
+        save_waveform_start=save_waveform_start,
+        **kwargs,
     )


@@ -88,6 +97,7 @@ def __call__(
         do_iterations=1,
         min_area=0,
         n_top_channels=0,
+        save_waveform_start=False,
         **kwargs,
     ):
         if not len(records) or not len(peaks) or not do_iterations:
@@ -127,7 +137,15 @@ def __call__(
         if is_split.sum() != 0:
             # Found new peaks: compute basic properties
             if data_type == "peaks":
-                strax.sum_waveform(new_peaks, hits, records, rlinks, to_pe, n_top_channels)
+                strax.sum_waveform(
+                    new_peaks,
+                    hits,
+                    records,
+                    rlinks,
+                    to_pe,
+                    n_top_channels,
+                    save_waveform_start=save_waveform_start,
+                )
                 strax.compute_widths(new_peaks)
             elif data_type == "hitlets":
                 # Add record fields here
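End to end, the flag simply plumbs through: split_peaks forwards it to the splitter, whose __call__ hands it to strax.sum_waveform for the newly split peaks. A usage sketch with placeholder inputs, assuming the argument order shown in the diff above:

    new_peaks = strax.split_peaks(
        peaks, hits, records, rlinks, to_pe,
        algorithm="local_minimum",
        data_type="peaks",
        n_top_channels=0,
        save_waveform_start=True,  # split peaks also get data_start filled
    )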
tests/test_peak_processing.py (1 change: 1 addition & 0 deletions)
@@ -316,6 +316,7 @@ def test_simple_summed_waveform(pulses):
     fake_event_dtype = strax.time_dt_fields + [
         ("data", np.float32, 200),
         ("data_top", np.float32, 200),
+        ("data_start", np.float32, 200),
     ]
 
     records = np.zeros(len(pulses), dtype=strax.record_dtype())
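A possible follow-up test (a sketch, not part of this diff) exercising the new branch directly through store_downsampled_waveform, assuming the signature shown above:

    import numpy as np
    import strax

    def test_store_waveform_start_sketch():
        # 400 samples are downsampled by a factor 2 into 'data', while
        # 'data_start' keeps the first 200 samples at full resolution.
        p = np.zeros(1, dtype=strax.peak_dtype(save_waveform_start=True))[0]
        p["length"] = 400
        p["dt"] = 1
        wv = np.arange(400, dtype=np.float32)
        strax.store_downsampled_waveform(p, wv, False, True)
        assert p["dt"] == 2  # downsampled by factor 2
        np.testing.assert_array_equal(p["data_start"], np.arange(200, dtype=np.float32))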