Skip to content

Commit

Permalink
Set PeaksPaired to be subclass of DownChunkingPlugin
Browse files Browse the repository at this point in the history
  • Loading branch information
dachengx committed Apr 29, 2024
1 parent 44753f1 commit b4dd9dd
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 42 deletions.
4 changes: 2 additions & 2 deletions axidence/plugins/pairing/events_paired.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class EventsForcePaired(OverlapWindowPlugin):
data_kind = "events_paired"
save_when = strax.SaveWhen.EXPLICIT

paring_event_interval = straxen.URLConfig(
paring_time_interval = straxen.URLConfig(
default=int(1e8),
type=int,
help="The interval which separates two events S1 [ns]",
Expand All @@ -31,7 +31,7 @@ def infer_dtype(self):
return dtype

def get_window_size(self):
return 10 * self.paring_event_interval
return 10 * self.paring_time_interval

def compute(self, peaks_paired):
peaks_event_number_sorted = np.sort(peaks_paired, order=("event_number", "time"))
Expand Down
80 changes: 40 additions & 40 deletions axidence/plugins/pairing/peaks_paired.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from immutabledict import immutabledict
import numpy as np
import strax
from strax import Plugin
from strax import Plugin, DownChunkingPlugin
import straxen
from straxen import units
from straxen import PeakProximity
Expand All @@ -12,12 +12,13 @@
from ...plugin import ExhaustPlugin


class PeaksPaired(ExhaustPlugin):
class PeaksPaired(ExhaustPlugin, DownChunkingPlugin):
__version__ = "0.0.0"
depends_on = ("isolated_s1", "isolated_s2", "cut_event_building_salted", "event_shadow_salted")
provides = ("peaks_paired", "truth_paired")
data_kind = immutabledict(zip(provides, provides))
save_when = immutabledict(zip(provides, [strax.SaveWhen.EXPLICIT, strax.SaveWhen.ALWAYS]))
rechunk_on_save = immutabledict(zip(provides, [False, True]))

pairing_seed = straxen.URLConfig(
default=None,
Expand Down Expand Up @@ -81,7 +82,7 @@ class PeaksPaired(ExhaustPlugin):
help="The fixed drift time [ns]",
)

paring_event_interval = straxen.URLConfig(
paring_time_interval = straxen.URLConfig(
default=int(1e8),
type=int,
help="The interval which separates two events S1 [ns]",
Expand Down Expand Up @@ -123,6 +124,8 @@ def setup(self, prepare=True):
self.rng = np.random.default_rng(seed=int(self.run_id))
else:
self.rng = np.random.default_rng(seed=self.pairing_seed)
self.time_left = self.paring_time_interval // 2
self.time_right = self.paring_time_interval - self.time_left

@staticmethod
def preprocess_isolated_s2(s2):
Expand All @@ -133,9 +136,9 @@ def preprocess_isolated_s2(s2):
# index of main S2s in isolated S2
s2_main_index = s2_group_index + s2["s2_index"][s2_group_index]

indices = np.append(s2_group_index, len(s2))
s2_group_index = np.append(s2_group_index, len(s2))
# time coverage of the first time and last endtime of S2s in each group
s2_length = s2["endtime"][indices[1:] - 1] - s2["time"][indices[:-1]]
s2_length = s2["endtime"][s2_group_index[1:] - 1] - s2["time"][s2_group_index[:-1]]
return s2_group_index, s2_main_index, s2_n_peaks, s2_length

@staticmethod
Expand Down Expand Up @@ -193,7 +196,6 @@ def split_chunks(self, n_peaks):
def build_arrays(
self,
start,
end,
drift_time,
s1_group_number,
s2_group_number,
Expand All @@ -206,7 +208,7 @@ def build_arrays(

# set center time of S1 & S2
# paired events are separated by roughly `event_interval`
s1_center_time = np.arange(len(drift_time)).astype(int) * self.paring_event_interval + start
s1_center_time = np.arange(len(drift_time)).astype(int) * self.paring_time_interval + start
s2_center_time = s1_center_time + drift_time
# total number of isolated S1 & S2 peaks
peaks_arrays = np.zeros(n_peaks.sum(), dtype=self.dtype["peaks_paired"])
Expand Down Expand Up @@ -344,47 +346,45 @@ def compute(self, isolated_s1, isolated_s2, events_salted, start, end):
print(f"AC event number is {len(drift_time)}")

# make sure events are not very long
assert (s2_length.max() + drift_time.max()) * 5.0 < self.paring_event_interval
assert (s2_length.max() + drift_time.max()) * 5.0 < self.paring_time_interval

# peaks number in each event
n_peaks = 1 + s2_n_peaks[s2_group_number]
slices = self.split_chunks(n_peaks)

if len(slices) > 1:
raise NotImplementedError(
f"Got {len(slices)} chunks. Multiple chunks are not implemented yet!"
print(f"Number of chunks is {len(slices)}")

for chunk_i in range(len(slices)):
left_i, right_i = slices[chunk_i]

_start = start + left_i * self.paring_time_interval
_end = start + right_i * self.paring_time_interval

peaks_arrays, truth_arrays = self.build_arrays(
_start + self.time_left,
drift_time[left_i:right_i],
s1_group_number[left_i:right_i],
s2_group_number[left_i:right_i],
n_peaks[left_i:right_i],
isolated_s1,
isolated_s2,
main_isolated_s2,
s2_group_index,
)
peaks_arrays["event_number"] += left_i
truth_arrays["event_number"] += left_i

chunk_i = 0
left_i, right_i = slices[chunk_i]
peaks_arrays, truth_arrays = self.build_arrays(
start,
end,
drift_time[left_i:right_i],
s1_group_number[left_i:right_i],
s2_group_number[left_i:right_i],
n_peaks[left_i:right_i],
isolated_s1,
isolated_s2,
main_isolated_s2,
s2_group_index,
)
peaks_arrays["event_number"] += left_i
truth_arrays["event_number"] += left_i

_start = start + left_i * self.paring_event_interval - int(self.paring_event_interval // 2)
_end = start + right_i * self.paring_event_interval - int(self.paring_event_interval // 2)
result = dict()
result["peaks_paired"] = self.chunk(
start=_start, end=_end, data=peaks_arrays, data_type="peaks_paired"
)
result["truth_paired"] = self.chunk(
start=_start, end=_end, data=truth_arrays, data_type="truth_paired"
)
# chunk size should be less than default chunk size in strax
assert result["peaks_paired"].nbytes < self.chunk_target_size_mb * 1e6
result = dict()
result["peaks_paired"] = self.chunk(
start=_start, end=_end, data=peaks_arrays, data_type="peaks_paired"
)
result["truth_paired"] = self.chunk(
start=_start, end=_end, data=truth_arrays, data_type="truth_paired"
)
# chunk size should be less than default chunk size in strax
assert result["peaks_paired"].nbytes < self.chunk_target_size_mb * 1e6

return result
yield result


class PeakProximityPaired(PeakProximity):
Expand Down

0 comments on commit b4dd9dd

Please sign in to comment.