From 44753f108120e094090341dbf07600b227c47d45 Mon Sep 17 00:00:00 2001 From: Dacheng Xu Date: Sun, 28 Apr 2024 21:48:48 -0400 Subject: [PATCH] Define `RunMeta` to help extract `start` and `end` of a run in the salting and pairing network (#49) * Add `RunMeta` plugin * Change base classes of `PeaksPaired` --- axidence/context.py | 3 +- axidence/dtypes.py | 2 + axidence/plugins/__init__.py | 3 + axidence/plugins/meta/__init__.py | 2 + axidence/plugins/meta/run_meta.py | 22 +++++ axidence/plugins/pairing/peaks_paired.py | 47 +++-------- axidence/plugins/salting/event_fields.py | 1 + axidence/plugins/salting/events_salting.py | 96 +++++++++------------- axidence/plugins/salting/peaks_salted.py | 1 + 9 files changed, 83 insertions(+), 94 deletions(-) create mode 100644 axidence/plugins/meta/__init__.py create mode 100644 axidence/plugins/meta/run_meta.py diff --git a/axidence/context.py b/axidence/context.py index 135d5f2..f517d6a 100644 --- a/axidence/context.py +++ b/axidence/context.py @@ -3,7 +3,7 @@ import strax import straxen -from axidence import EventsSalting, PeaksSalted +from axidence import RunMeta, EventsSalting, PeaksSalted from axidence import ( PeakProximitySalted, PeakShadowSalted, @@ -147,6 +147,7 @@ def _salt_to_context(self): """Register the salted plugins to the context.""" self.register( ( + RunMeta, EventsSalting, PeaksSalted, PeakProximitySalted, diff --git a/axidence/dtypes.py b/axidence/dtypes.py index 568ad63..3dea0e0 100644 --- a/axidence/dtypes.py +++ b/axidence/dtypes.py @@ -7,10 +7,12 @@ kind_colors.update( { + "run_meta": "#ffff00", "events_salting": "#0080ff", "peaks_salted": "#00c0ff", "events_salted": "#00ffff", "peaks_paired": "#ff00ff", + "truth_paired": "#ff00ff", "events_paired": "#ffccff", "isolated_s1": "#80ff00", "isolated_s2": "#80ff00", diff --git a/axidence/plugins/__init__.py b/axidence/plugins/__init__.py index 0098fd7..61d2ad3 100644 --- a/axidence/plugins/__init__.py +++ b/axidence/plugins/__init__.py @@ -1,3 +1,6 @@ +from . import meta +from .meta import * + from . import cuts from .cuts import * diff --git a/axidence/plugins/meta/__init__.py b/axidence/plugins/meta/__init__.py new file mode 100644 index 0000000..ca514d7 --- /dev/null +++ b/axidence/plugins/meta/__init__.py @@ -0,0 +1,2 @@ +from . import run_meta +from .run_meta import * diff --git a/axidence/plugins/meta/run_meta.py b/axidence/plugins/meta/run_meta.py new file mode 100644 index 0000000..ee7d374 --- /dev/null +++ b/axidence/plugins/meta/run_meta.py @@ -0,0 +1,22 @@ +import numpy as np +import strax + +from ...plugin import ExhaustPlugin + + +class RunMeta(ExhaustPlugin): + """Plugin that provides run metadata.""" + + __version__ = "0.0.0" + depends_on = "event_basics" + provides = "run_meta" + data_kind = "run_meta" + save_when = strax.SaveWhen.EXPLICIT + + dtype = strax.time_fields + + def compute(self, events, start, end): + result = np.zeros(1, dtype=self.dtype) + result["time"] = start + result["endtime"] = end + return result diff --git a/axidence/plugins/pairing/peaks_paired.py b/axidence/plugins/pairing/peaks_paired.py index 921b695..53c650d 100644 --- a/axidence/plugins/pairing/peaks_paired.py +++ b/axidence/plugins/pairing/peaks_paired.py @@ -9,10 +9,10 @@ from ...utils import copy_dtype from ...dtypes import peak_positions_dtype -from ...plugin import ExhaustPlugin, RunMetaPlugin +from ...plugin import ExhaustPlugin -class PeaksPaired(ExhaustPlugin, RunMetaPlugin): +class PeaksPaired(ExhaustPlugin): __version__ = "0.0.0" depends_on = ("isolated_s1", "isolated_s2", "cut_event_building_salted", "event_shadow_salted") provides = ("peaks_paired", "truth_paired") @@ -25,24 +25,6 @@ class PeaksPaired(ExhaustPlugin, RunMetaPlugin): help="Seed for pairing", ) - real_run_start = straxen.URLConfig( - default=None, - type=(int, None), - help="Real start time of run [ns]", - ) - - real_run_end = straxen.URLConfig( - default=None, - type=(int, None), - help="Real start time of run [ns]", - ) - - strict_real_run_time_check = straxen.URLConfig( - default=True, - type=bool, - help="Whether to strictly check the real run time is provided", - ) - min_drift_length = straxen.URLConfig( default=0, type=(int, float), @@ -135,7 +117,6 @@ def infer_dtype(self): return dict(peaks_paired=peaks_dtype, truth_paired=truth_dtype) def setup(self, prepare=True): - self.init_run_meta() self.min_drift_time = int(self.min_drift_length / self.electron_drift_velocity) self.max_drift_time = int(self.max_drift_length / self.electron_drift_velocity) if self.pairing_seed is None: @@ -211,6 +192,8 @@ def split_chunks(self, n_peaks): def build_arrays( self, + start, + end, drift_time, s1_group_number, s2_group_number, @@ -223,9 +206,7 @@ def build_arrays( # set center time of S1 & S2 # paired events are separated by roughly `event_interval` - s1_center_time = ( - np.arange(len(drift_time)).astype(int) * self.paring_event_interval + self.run_start - ) + s1_center_time = np.arange(len(drift_time)).astype(int) * self.paring_event_interval + start s2_center_time = s1_center_time + drift_time # total number of isolated S1 & S2 peaks peaks_arrays = np.zeros(n_peaks.sum(), dtype=self.dtype["peaks_paired"]) @@ -322,7 +303,7 @@ def build_arrays( return peaks_arrays, truth_arrays - def compute(self, isolated_s1, isolated_s2, events_salted): + def compute(self, isolated_s1, isolated_s2, events_salted, start, end): for i, s in enumerate([isolated_s1, isolated_s2]): if np.any(np.diff(s["group_number"]) < 0): raise ValueError(f"Group number is not sorted in isolated S{i}!") @@ -350,7 +331,7 @@ def compute(self, isolated_s1, isolated_s2, events_salted): paring_rate_full, s1_group_number, s2_group_number, drift_time = self.simple_pairing( isolated_s1, main_isolated_s2, - self.run_time, + (end - start) / units.s, self.max_drift_time, self.min_drift_time, paring_rate_correction, @@ -377,6 +358,8 @@ def compute(self, isolated_s1, isolated_s2, events_salted): chunk_i = 0 left_i, right_i = slices[chunk_i] peaks_arrays, truth_arrays = self.build_arrays( + start, + end, drift_time[left_i:right_i], s1_group_number[left_i:right_i], s2_group_number[left_i:right_i], @@ -389,18 +372,14 @@ def compute(self, isolated_s1, isolated_s2, events_salted): peaks_arrays["event_number"] += left_i truth_arrays["event_number"] += left_i - start = ( - self.run_start + left_i * self.paring_event_interval - self.paring_event_interval // 2 - ) - end = ( - self.run_start + right_i * self.paring_event_interval - self.paring_event_interval // 2 - ) + _start = start + left_i * self.paring_event_interval - int(self.paring_event_interval // 2) + _end = start + right_i * self.paring_event_interval - int(self.paring_event_interval // 2) result = dict() result["peaks_paired"] = self.chunk( - start=start, end=end, data=peaks_arrays, data_type="peaks_paired" + start=_start, end=_end, data=peaks_arrays, data_type="peaks_paired" ) result["truth_paired"] = self.chunk( - start=start, end=end, data=truth_arrays, data_type="truth_paired" + start=_start, end=_end, data=truth_arrays, data_type="truth_paired" ) # chunk size should be less than default chunk size in strax assert result["peaks_paired"].nbytes < self.chunk_target_size_mb * 1e6 diff --git a/axidence/plugins/salting/event_fields.py b/axidence/plugins/salting/event_fields.py index 15fc857..5f86bf2 100644 --- a/axidence/plugins/salting/event_fields.py +++ b/axidence/plugins/salting/event_fields.py @@ -7,6 +7,7 @@ class EventFieldsSalted(Plugin): + child_plugin = True def compute(self, events_salted, peaks_salted, peaks): _peaks = merge_salted_real(peaks_salted, peaks, peaks.dtype) diff --git a/axidence/plugins/salting/events_salting.py b/axidence/plugins/salting/events_salting.py index 913cde5..06c326a 100644 --- a/axidence/plugins/salting/events_salting.py +++ b/axidence/plugins/salting/events_salting.py @@ -1,16 +1,17 @@ -from typing import Tuple import numpy as np import strax +from strax import DownChunkingPlugin import straxen from straxen import units, EventBasics, EventPositions from ...utils import copy_dtype -from ...plugin import RunMetaPlugin +from ...plugin import ExhaustPlugin -class EventsSalting(EventPositions, EventBasics, RunMetaPlugin): +class EventsSalting(ExhaustPlugin, DownChunkingPlugin, EventPositions, EventBasics): __version__ = "0.0.0" - depends_on: Tuple = tuple() + child_plugin = True + depends_on = "run_meta" provides = "events_salting" data_kind = "events_salting" save_when = strax.SaveWhen.EXPLICIT @@ -27,24 +28,6 @@ class EventsSalting(EventPositions, EventBasics, RunMetaPlugin): help="Rate of salting in Hz", ) - real_run_start = straxen.URLConfig( - default=None, - type=(int, None), - help="Real start time of run [ns]", - ) - - real_run_end = straxen.URLConfig( - default=None, - type=(int, None), - help="Real start time of run [ns]", - ) - - strict_real_run_time_check = straxen.URLConfig( - default=True, - type=bool, - help="Whether to strictly check the real run time is provided", - ) - s1_area_range = straxen.URLConfig( default=(1, 150), type=(list, tuple), @@ -115,26 +98,31 @@ def infer_dtype(self): dtype += [(("Salting number of events", "salt_number"), np.int64)] return dtype + def setup(self): + super(EventPositions, self).setup() + super(EventsSalting, self).setup() + + self.init_rng() + def init_rng(self): + """Initialize the random number generator.""" if self.salting_seed is None: self.rng = np.random.default_rng(seed=int(self.run_id)) else: self.rng = np.random.default_rng(seed=self.salting_seed) - def sample_time(self): + def sample_time(self, start, end): """Sample the time according to the start and end of the run.""" - self.event_time_interval = units.s // self.salting_rate + self.event_time_interval = int(units.s // self.salting_rate) if units.s / self.salting_rate < self.drift_time_max * self.n_drift_time_window * 2: raise ValueError("Salting rate is too high according the drift time window!") time = np.arange( - self.run_start + self.veto_length_run_start, - self.run_end - self.veto_length_run_end, + start + self.veto_length_run_start, + end - self.veto_length_run_end, self.event_time_interval, ).astype(np.int64) - self.time_left = self.event_time_interval // 2 - self.time_right = self.event_time_interval - self.time_left return time def inverse_field_distortion(self, x, y, z): @@ -156,15 +144,9 @@ def set_chunk_splitting(self): self.time_left = self.event_time_interval // 2 self.time_right = self.event_time_interval - self.time_left - def setup(self): - """Sample the features of events.""" - super(EventPositions, self).setup() - super(EventsSalting, self).setup() - - self.init_rng() - self.init_run_meta() - - time = self.sample_time() + def sampling(self, start, end): + """Sample the features of events, (t, x, y, z, S1, S2) et al.""" + time = self.sample_time(start, end) self.n_events = len(time) self.events_salting = np.empty(self.n_events, dtype=self.dtype) self.events_salting["salt_number"] = np.arange(self.n_events) @@ -208,26 +190,22 @@ def setup(self): self.set_chunk_splitting() - def compute(self, chunk_i): + def compute(self, run_meta, start, end): """Copy and assign the salting events into chunk.""" - indices = self.slices[chunk_i] - - if chunk_i == 0: - start = self.run_start - else: - start = self.events_salting["time"][indices[0]] - self.time_left - - if chunk_i == len(self.slices) - 1: - end = self.run_end - else: - end = self.events_salting["time"][indices[1] - 1] + self.time_right - return self.chunk(start=start, end=end, data=self.events_salting[indices[0] : indices[1]]) - - def is_ready(self, chunk_i): - if chunk_i < len(self.slices): - return True - else: - return False - - def source_finished(self): - return True + self.sampling(start, end) + for chunk_i in range(len(self.slices)): + indices = self.slices[chunk_i] + + if chunk_i == 0: + _start = start + else: + _start = self.events_salting["time"][indices[0]] - self.time_left + + if chunk_i == len(self.slices) - 1: + _end = end + else: + _end = self.events_salting["time"][indices[1] - 1] + self.time_right + + yield self.chunk( + start=_start, end=_end, data=self.events_salting[indices[0] : indices[1]] + ) diff --git a/axidence/plugins/salting/peaks_salted.py b/axidence/plugins/salting/peaks_salted.py index afd9d13..f9f5dbb 100644 --- a/axidence/plugins/salting/peaks_salted.py +++ b/axidence/plugins/salting/peaks_salted.py @@ -8,6 +8,7 @@ class PeaksSalted(PeakBasics): __version__ = "0.0.0" + child_plugin = True depends_on = "events_salting" provides = "peaks_salted" data_kind = "peaks_salted"