diff --git a/axidence/plugins/pairing/events_paired.py b/axidence/plugins/pairing/events_paired.py
index 3a49a57..8b18b34 100644
--- a/axidence/plugins/pairing/events_paired.py
+++ b/axidence/plugins/pairing/events_paired.py
@@ -18,7 +18,7 @@ class EventsForcePaired(OverlapWindowPlugin):
     data_kind = "events_paired"
     save_when = strax.SaveWhen.EXPLICIT
 
-    paring_event_interval = straxen.URLConfig(
+    paring_time_interval = straxen.URLConfig(
         default=int(1e8),
         type=int,
         help="The interval which separates two events S1 [ns]",
@@ -31,7 +31,7 @@ def infer_dtype(self):
         return dtype
 
     def get_window_size(self):
-        return 10 * self.paring_event_interval
+        return 10 * self.paring_time_interval
 
     def compute(self, peaks_paired):
         peaks_event_number_sorted = np.sort(peaks_paired, order=("event_number", "time"))
diff --git a/axidence/plugins/pairing/peaks_paired.py b/axidence/plugins/pairing/peaks_paired.py
index 53c650d..f17db38 100644
--- a/axidence/plugins/pairing/peaks_paired.py
+++ b/axidence/plugins/pairing/peaks_paired.py
@@ -2,7 +2,7 @@
 from immutabledict import immutabledict
 import numpy as np
 import strax
-from strax import Plugin
+from strax import Plugin, DownChunkingPlugin
 import straxen
 from straxen import units
 from straxen import PeakProximity
@@ -12,12 +12,13 @@
 from ...plugin import ExhaustPlugin
 
 
-class PeaksPaired(ExhaustPlugin):
+class PeaksPaired(ExhaustPlugin, DownChunkingPlugin):
     __version__ = "0.0.0"
     depends_on = ("isolated_s1", "isolated_s2", "cut_event_building_salted", "event_shadow_salted")
     provides = ("peaks_paired", "truth_paired")
     data_kind = immutabledict(zip(provides, provides))
     save_when = immutabledict(zip(provides, [strax.SaveWhen.EXPLICIT, strax.SaveWhen.ALWAYS]))
+    rechunk_on_save = immutabledict(zip(provides, [False, True]))
 
     pairing_seed = straxen.URLConfig(
         default=None,
@@ -81,7 +82,7 @@ class PeaksPaired(ExhaustPlugin):
         help="The fixed drift time [ns]",
     )
 
-    paring_event_interval = straxen.URLConfig(
+    paring_time_interval = straxen.URLConfig(
         default=int(1e8),
         type=int,
         help="The interval which separates two events S1 [ns]",
@@ -123,6 +124,8 @@ def setup(self, prepare=True):
             self.rng = np.random.default_rng(seed=int(self.run_id))
         else:
             self.rng = np.random.default_rng(seed=self.pairing_seed)
+        self.time_left = self.paring_time_interval // 2
+        self.time_right = self.paring_time_interval - self.time_left
 
     @staticmethod
     def preprocess_isolated_s2(s2):
@@ -133,9 +136,9 @@ def preprocess_isolated_s2(s2):
         # index of main S2s in isolated S2
         s2_main_index = s2_group_index + s2["s2_index"][s2_group_index]
 
-        indices = np.append(s2_group_index, len(s2))
+        s2_group_index = np.append(s2_group_index, len(s2))
         # time coverage of the first time and last endtime of S2s in each group
-        s2_length = s2["endtime"][indices[1:] - 1] - s2["time"][indices[:-1]]
+        s2_length = s2["endtime"][s2_group_index[1:] - 1] - s2["time"][s2_group_index[:-1]]
         return s2_group_index, s2_main_index, s2_n_peaks, s2_length
 
     @staticmethod
@@ -193,7 +196,6 @@ def split_chunks(self, n_peaks):
     def build_arrays(
         self,
         start,
-        end,
         drift_time,
         s1_group_number,
         s2_group_number,
@@ -206,7 +208,7 @@
 
         # set center time of S1 & S2
        # paired events are separated by roughly `event_interval`
-        s1_center_time = np.arange(len(drift_time)).astype(int) * self.paring_event_interval + start
+        s1_center_time = np.arange(len(drift_time)).astype(int) * self.paring_time_interval + start
         s2_center_time = s1_center_time + drift_time
         # total number of isolated S1 & S2 peaks
         peaks_arrays = np.zeros(n_peaks.sum(), dtype=self.dtype["peaks_paired"])
@@ -344,47 +346,45 @@ def compute(self, isolated_s1, isolated_s2, events_salted, start, end):
 
         print(f"AC event number is {len(drift_time)}")
         # make sure events are not very long
-        assert (s2_length.max() + drift_time.max()) * 5.0 < self.paring_event_interval
+        assert (s2_length.max() + drift_time.max()) * 5.0 < self.paring_time_interval
 
         # peaks number in each event
         n_peaks = 1 + s2_n_peaks[s2_group_number]
         slices = self.split_chunks(n_peaks)
 
-        if len(slices) > 1:
-            raise NotImplementedError(
-                f"Got {len(slices)} chunks. Multiple chunks are not implemented yet!"
+        print(f"Number of chunks is {len(slices)}")
+
+        for chunk_i in range(len(slices)):
+            left_i, right_i = slices[chunk_i]
+
+            _start = start + left_i * self.paring_time_interval
+            _end = start + right_i * self.paring_time_interval
+
+            peaks_arrays, truth_arrays = self.build_arrays(
+                _start + self.time_left,
+                drift_time[left_i:right_i],
+                s1_group_number[left_i:right_i],
+                s2_group_number[left_i:right_i],
+                n_peaks[left_i:right_i],
+                isolated_s1,
+                isolated_s2,
+                main_isolated_s2,
+                s2_group_index,
             )
+            peaks_arrays["event_number"] += left_i
+            truth_arrays["event_number"] += left_i
 
-        chunk_i = 0
-        left_i, right_i = slices[chunk_i]
-        peaks_arrays, truth_arrays = self.build_arrays(
-            start,
-            end,
-            drift_time[left_i:right_i],
-            s1_group_number[left_i:right_i],
-            s2_group_number[left_i:right_i],
-            n_peaks[left_i:right_i],
-            isolated_s1,
-            isolated_s2,
-            main_isolated_s2,
-            s2_group_index,
-        )
-        peaks_arrays["event_number"] += left_i
-        truth_arrays["event_number"] += left_i
-
-        _start = start + left_i * self.paring_event_interval - int(self.paring_event_interval // 2)
-        _end = start + right_i * self.paring_event_interval - int(self.paring_event_interval // 2)
-        result = dict()
-        result["peaks_paired"] = self.chunk(
-            start=_start, end=_end, data=peaks_arrays, data_type="peaks_paired"
-        )
-        result["truth_paired"] = self.chunk(
-            start=_start, end=_end, data=truth_arrays, data_type="truth_paired"
-        )
-        # chunk size should be less than default chunk size in strax
-        assert result["peaks_paired"].nbytes < self.chunk_target_size_mb * 1e6
+            result = dict()
+            result["peaks_paired"] = self.chunk(
+                start=_start, end=_end, data=peaks_arrays, data_type="peaks_paired"
+            )
+            result["truth_paired"] = self.chunk(
+                start=_start, end=_end, data=truth_arrays, data_type="truth_paired"
+            )
+            # chunk size should be less than default chunk size in strax
+            assert result["peaks_paired"].nbytes < self.chunk_target_size_mb * 1e6
 
-        return result
+            yield result
 
 
 class PeakProximityPaired(PeakProximity):
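
Below is a minimal standalone sketch (not part of the diff) of the chunk-boundary arithmetic introduced above: each chunk spans [start + left_i * interval, start + right_i * interval), and the paired events inside it are centered time_left past each interval boundary, so their centers never touch the chunk edges. This is what lets compute() yield one result per slice as a DownChunkingPlugin instead of raising on multiple chunks. The values of start, paring_time_interval and slices are made-up examples.

    import numpy as np

    # Assumed example inputs (hypothetical values, not taken from the diff):
    start = 1_000_000_000                    # chunk start time [ns]
    paring_time_interval = int(1e8)          # default of the URLConfig above
    time_left = paring_time_interval // 2    # mirrors setup()
    slices = [(0, 3), (3, 5)]                # hypothetical split_chunks output for 5 events

    for left_i, right_i in slices:
        # chunk boundaries, as computed in the new compute()
        _start = start + left_i * paring_time_interval
        _end = start + right_i * paring_time_interval
        # S1 center times of the events in this chunk, as in build_arrays()
        s1_center_time = np.arange(right_i - left_i) * paring_time_interval + _start + time_left
        # every event center lies strictly inside [_start, _end)
        assert np.all((s1_center_time >= _start) & (s1_center_time < _end))
        print(f"chunk [{_start}, {_end}): S1 centers at {s1_center_time.tolist()}")

The assert in compute() that (s2_length.max() + drift_time.max()) * 5.0 stays below paring_time_interval is what keeps the full peaks, not just their centers, inside these boundaries.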