Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set PeaksPaired to be subclass of DownChunkingPlugin #50

Merged
merged 1 commit into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions axidence/plugins/pairing/events_paired.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class EventsForcePaired(OverlapWindowPlugin):
data_kind = "events_paired"
save_when = strax.SaveWhen.EXPLICIT

paring_event_interval = straxen.URLConfig(
paring_time_interval = straxen.URLConfig(
default=int(1e8),
type=int,
help="The interval which separates two events S1 [ns]",
Expand All @@ -31,7 +31,7 @@ def infer_dtype(self):
return dtype

def get_window_size(self):
return 10 * self.paring_event_interval
return 10 * self.paring_time_interval

def compute(self, peaks_paired):
peaks_event_number_sorted = np.sort(peaks_paired, order=("event_number", "time"))
Expand Down
80 changes: 40 additions & 40 deletions axidence/plugins/pairing/peaks_paired.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from immutabledict import immutabledict
import numpy as np
import strax
from strax import Plugin
from strax import Plugin, DownChunkingPlugin
import straxen
from straxen import units
from straxen import PeakProximity
Expand All @@ -12,12 +12,13 @@
from ...plugin import ExhaustPlugin


class PeaksPaired(ExhaustPlugin):
class PeaksPaired(ExhaustPlugin, DownChunkingPlugin):
__version__ = "0.0.0"
depends_on = ("isolated_s1", "isolated_s2", "cut_event_building_salted", "event_shadow_salted")
provides = ("peaks_paired", "truth_paired")
data_kind = immutabledict(zip(provides, provides))
save_when = immutabledict(zip(provides, [strax.SaveWhen.EXPLICIT, strax.SaveWhen.ALWAYS]))
rechunk_on_save = immutabledict(zip(provides, [False, True]))

pairing_seed = straxen.URLConfig(
default=None,
Expand Down Expand Up @@ -81,7 +82,7 @@ class PeaksPaired(ExhaustPlugin):
help="The fixed drift time [ns]",
)

paring_event_interval = straxen.URLConfig(
paring_time_interval = straxen.URLConfig(
default=int(1e8),
type=int,
help="The interval which separates two events S1 [ns]",
Expand Down Expand Up @@ -123,6 +124,8 @@ def setup(self, prepare=True):
self.rng = np.random.default_rng(seed=int(self.run_id))
else:
self.rng = np.random.default_rng(seed=self.pairing_seed)
self.time_left = self.paring_time_interval // 2
self.time_right = self.paring_time_interval - self.time_left

@staticmethod
def preprocess_isolated_s2(s2):
Expand All @@ -133,9 +136,9 @@ def preprocess_isolated_s2(s2):
# index of main S2s in isolated S2
s2_main_index = s2_group_index + s2["s2_index"][s2_group_index]

indices = np.append(s2_group_index, len(s2))
s2_group_index = np.append(s2_group_index, len(s2))
# time coverage of the first time and last endtime of S2s in each group
s2_length = s2["endtime"][indices[1:] - 1] - s2["time"][indices[:-1]]
s2_length = s2["endtime"][s2_group_index[1:] - 1] - s2["time"][s2_group_index[:-1]]
return s2_group_index, s2_main_index, s2_n_peaks, s2_length

@staticmethod
Expand Down Expand Up @@ -193,7 +196,6 @@ def split_chunks(self, n_peaks):
def build_arrays(
self,
start,
end,
drift_time,
s1_group_number,
s2_group_number,
Expand All @@ -206,7 +208,7 @@ def build_arrays(

# set center time of S1 & S2
# paired events are separated by roughly `event_interval`
s1_center_time = np.arange(len(drift_time)).astype(int) * self.paring_event_interval + start
s1_center_time = np.arange(len(drift_time)).astype(int) * self.paring_time_interval + start
s2_center_time = s1_center_time + drift_time
# total number of isolated S1 & S2 peaks
peaks_arrays = np.zeros(n_peaks.sum(), dtype=self.dtype["peaks_paired"])
Expand Down Expand Up @@ -344,47 +346,45 @@ def compute(self, isolated_s1, isolated_s2, events_salted, start, end):
print(f"AC event number is {len(drift_time)}")

# make sure events are not very long
assert (s2_length.max() + drift_time.max()) * 5.0 < self.paring_event_interval
assert (s2_length.max() + drift_time.max()) * 5.0 < self.paring_time_interval

# peaks number in each event
n_peaks = 1 + s2_n_peaks[s2_group_number]
slices = self.split_chunks(n_peaks)

if len(slices) > 1:
raise NotImplementedError(
f"Got {len(slices)} chunks. Multiple chunks are not implemented yet!"
print(f"Number of chunks is {len(slices)}")

for chunk_i in range(len(slices)):
left_i, right_i = slices[chunk_i]

_start = start + left_i * self.paring_time_interval
_end = start + right_i * self.paring_time_interval

peaks_arrays, truth_arrays = self.build_arrays(
_start + self.time_left,
drift_time[left_i:right_i],
s1_group_number[left_i:right_i],
s2_group_number[left_i:right_i],
n_peaks[left_i:right_i],
isolated_s1,
isolated_s2,
main_isolated_s2,
s2_group_index,
)
peaks_arrays["event_number"] += left_i
truth_arrays["event_number"] += left_i

chunk_i = 0
left_i, right_i = slices[chunk_i]
peaks_arrays, truth_arrays = self.build_arrays(
start,
end,
drift_time[left_i:right_i],
s1_group_number[left_i:right_i],
s2_group_number[left_i:right_i],
n_peaks[left_i:right_i],
isolated_s1,
isolated_s2,
main_isolated_s2,
s2_group_index,
)
peaks_arrays["event_number"] += left_i
truth_arrays["event_number"] += left_i

_start = start + left_i * self.paring_event_interval - int(self.paring_event_interval // 2)
_end = start + right_i * self.paring_event_interval - int(self.paring_event_interval // 2)
result = dict()
result["peaks_paired"] = self.chunk(
start=_start, end=_end, data=peaks_arrays, data_type="peaks_paired"
)
result["truth_paired"] = self.chunk(
start=_start, end=_end, data=truth_arrays, data_type="truth_paired"
)
# chunk size should be less than default chunk size in strax
assert result["peaks_paired"].nbytes < self.chunk_target_size_mb * 1e6
result = dict()
result["peaks_paired"] = self.chunk(
start=_start, end=_end, data=peaks_arrays, data_type="peaks_paired"
)
result["truth_paired"] = self.chunk(
start=_start, end=_end, data=truth_arrays, data_type="truth_paired"
)
# chunk size should be less than default chunk size in strax
assert result["peaks_paired"].nbytes < self.chunk_target_size_mb * 1e6

return result
yield result


class PeakProximityPaired(PeakProximity):
Expand Down
Loading