Skip to content

Commit

Permalink
Define RunMeta to help extract start and end of a run in the salting and pairing network (#49)
Browse files Browse the repository at this point in the history

* Add `RunMeta` plugin

* Change base classes of `PeaksPaired`
  • Loading branch information
dachengx committed Apr 29, 2024
1 parent 85f81b8 commit 44753f1
Show file tree
Hide file tree
Showing 9 changed files with 83 additions and 94 deletions.
3 changes: 2 additions & 1 deletion axidence/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import strax
import straxen

from axidence import EventsSalting, PeaksSalted
from axidence import RunMeta, EventsSalting, PeaksSalted
from axidence import (
PeakProximitySalted,
PeakShadowSalted,
Expand Down Expand Up @@ -147,6 +147,7 @@ def _salt_to_context(self):
"""Register the salted plugins to the context."""
self.register(
(
RunMeta,
EventsSalting,
PeaksSalted,
PeakProximitySalted,
Expand Down
2 changes: 2 additions & 0 deletions axidence/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@

kind_colors.update(
{
"run_meta": "#ffff00",
"events_salting": "#0080ff",
"peaks_salted": "#00c0ff",
"events_salted": "#00ffff",
"peaks_paired": "#ff00ff",
"truth_paired": "#ff00ff",
"events_paired": "#ffccff",
"isolated_s1": "#80ff00",
"isolated_s2": "#80ff00",
Expand Down
3 changes: 3 additions & 0 deletions axidence/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from . import meta
from .meta import *

from . import cuts
from .cuts import *

Expand Down
2 changes: 2 additions & 0 deletions axidence/plugins/meta/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from . import run_meta
from .run_meta import *
22 changes: 22 additions & 0 deletions axidence/plugins/meta/run_meta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import numpy as np
import strax

from ...plugin import ExhaustPlugin


class RunMeta(ExhaustPlugin):
    """Plugin that provides run metadata.

    Emits exactly one record covering the whole run, whose ``time`` and
    ``endtime`` fields are the chunk boundaries supplied by strax, so that
    downstream plugins can recover the run's start and end.
    """

    __version__ = "0.0.0"
    depends_on = "event_basics"
    provides = "run_meta"
    data_kind = "run_meta"
    save_when = strax.SaveWhen.EXPLICIT

    dtype = strax.time_fields

    def compute(self, events, start, end):
        # Single row spanning [start, end); the events payload itself is
        # not read — the dependency only anchors this plugin in the chain.
        meta = np.zeros(1, dtype=self.dtype)
        meta["time"], meta["endtime"] = start, end
        return meta
47 changes: 13 additions & 34 deletions axidence/plugins/pairing/peaks_paired.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

from ...utils import copy_dtype
from ...dtypes import peak_positions_dtype
from ...plugin import ExhaustPlugin, RunMetaPlugin
from ...plugin import ExhaustPlugin


class PeaksPaired(ExhaustPlugin, RunMetaPlugin):
class PeaksPaired(ExhaustPlugin):
__version__ = "0.0.0"
depends_on = ("isolated_s1", "isolated_s2", "cut_event_building_salted", "event_shadow_salted")
provides = ("peaks_paired", "truth_paired")
Expand All @@ -25,24 +25,6 @@ class PeaksPaired(ExhaustPlugin, RunMetaPlugin):
help="Seed for pairing",
)

real_run_start = straxen.URLConfig(
default=None,
type=(int, None),
help="Real start time of run [ns]",
)

real_run_end = straxen.URLConfig(
default=None,
type=(int, None),
help="Real start time of run [ns]",
)

strict_real_run_time_check = straxen.URLConfig(
default=True,
type=bool,
help="Whether to strictly check the real run time is provided",
)

min_drift_length = straxen.URLConfig(
default=0,
type=(int, float),
Expand Down Expand Up @@ -135,7 +117,6 @@ def infer_dtype(self):
return dict(peaks_paired=peaks_dtype, truth_paired=truth_dtype)

def setup(self, prepare=True):
self.init_run_meta()
self.min_drift_time = int(self.min_drift_length / self.electron_drift_velocity)
self.max_drift_time = int(self.max_drift_length / self.electron_drift_velocity)
if self.pairing_seed is None:
Expand Down Expand Up @@ -211,6 +192,8 @@ def split_chunks(self, n_peaks):

def build_arrays(
self,
start,
end,
drift_time,
s1_group_number,
s2_group_number,
Expand All @@ -223,9 +206,7 @@ def build_arrays(

# set center time of S1 & S2
# paired events are separated by roughly `event_interval`
s1_center_time = (
np.arange(len(drift_time)).astype(int) * self.paring_event_interval + self.run_start
)
s1_center_time = np.arange(len(drift_time)).astype(int) * self.paring_event_interval + start
s2_center_time = s1_center_time + drift_time
# total number of isolated S1 & S2 peaks
peaks_arrays = np.zeros(n_peaks.sum(), dtype=self.dtype["peaks_paired"])
Expand Down Expand Up @@ -322,7 +303,7 @@ def build_arrays(

return peaks_arrays, truth_arrays

def compute(self, isolated_s1, isolated_s2, events_salted):
def compute(self, isolated_s1, isolated_s2, events_salted, start, end):
for i, s in enumerate([isolated_s1, isolated_s2]):
if np.any(np.diff(s["group_number"]) < 0):
raise ValueError(f"Group number is not sorted in isolated S{i}!")
Expand Down Expand Up @@ -350,7 +331,7 @@ def compute(self, isolated_s1, isolated_s2, events_salted):
paring_rate_full, s1_group_number, s2_group_number, drift_time = self.simple_pairing(
isolated_s1,
main_isolated_s2,
self.run_time,
(end - start) / units.s,
self.max_drift_time,
self.min_drift_time,
paring_rate_correction,
Expand All @@ -377,6 +358,8 @@ def compute(self, isolated_s1, isolated_s2, events_salted):
chunk_i = 0
left_i, right_i = slices[chunk_i]
peaks_arrays, truth_arrays = self.build_arrays(
start,
end,
drift_time[left_i:right_i],
s1_group_number[left_i:right_i],
s2_group_number[left_i:right_i],
Expand All @@ -389,18 +372,14 @@ def compute(self, isolated_s1, isolated_s2, events_salted):
peaks_arrays["event_number"] += left_i
truth_arrays["event_number"] += left_i

start = (
self.run_start + left_i * self.paring_event_interval - self.paring_event_interval // 2
)
end = (
self.run_start + right_i * self.paring_event_interval - self.paring_event_interval // 2
)
_start = start + left_i * self.paring_event_interval - int(self.paring_event_interval // 2)
_end = start + right_i * self.paring_event_interval - int(self.paring_event_interval // 2)
result = dict()
result["peaks_paired"] = self.chunk(
start=start, end=end, data=peaks_arrays, data_type="peaks_paired"
start=_start, end=_end, data=peaks_arrays, data_type="peaks_paired"
)
result["truth_paired"] = self.chunk(
start=start, end=end, data=truth_arrays, data_type="truth_paired"
start=_start, end=_end, data=truth_arrays, data_type="truth_paired"
)
# chunk size should be less than default chunk size in strax
assert result["peaks_paired"].nbytes < self.chunk_target_size_mb * 1e6
Expand Down
1 change: 1 addition & 0 deletions axidence/plugins/salting/event_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@


class EventFieldsSalted(Plugin):
child_plugin = True

def compute(self, events_salted, peaks_salted, peaks):
_peaks = merge_salted_real(peaks_salted, peaks, peaks.dtype)
Expand Down
96 changes: 37 additions & 59 deletions axidence/plugins/salting/events_salting.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from typing import Tuple
import numpy as np
import strax
from strax import DownChunkingPlugin
import straxen
from straxen import units, EventBasics, EventPositions

from ...utils import copy_dtype
from ...plugin import RunMetaPlugin
from ...plugin import ExhaustPlugin


class EventsSalting(EventPositions, EventBasics, RunMetaPlugin):
class EventsSalting(ExhaustPlugin, DownChunkingPlugin, EventPositions, EventBasics):
__version__ = "0.0.0"
depends_on: Tuple = tuple()
child_plugin = True
depends_on = "run_meta"
provides = "events_salting"
data_kind = "events_salting"
save_when = strax.SaveWhen.EXPLICIT
Expand All @@ -27,24 +28,6 @@ class EventsSalting(EventPositions, EventBasics, RunMetaPlugin):
help="Rate of salting in Hz",
)

real_run_start = straxen.URLConfig(
default=None,
type=(int, None),
help="Real start time of run [ns]",
)

real_run_end = straxen.URLConfig(
default=None,
type=(int, None),
help="Real start time of run [ns]",
)

strict_real_run_time_check = straxen.URLConfig(
default=True,
type=bool,
help="Whether to strictly check the real run time is provided",
)

s1_area_range = straxen.URLConfig(
default=(1, 150),
type=(list, tuple),
Expand Down Expand Up @@ -115,26 +98,31 @@ def infer_dtype(self):
dtype += [(("Salting number of events", "salt_number"), np.int64)]
return dtype

def setup(self):
super(EventPositions, self).setup()
super(EventsSalting, self).setup()

self.init_rng()

def init_rng(self):
    """Initialize the random number generator.

    Uses ``salting_seed`` when it is set; otherwise falls back to a
    deterministic seed derived from the run id, so results are
    reproducible per run.
    """
    seed = int(self.run_id) if self.salting_seed is None else self.salting_seed
    self.rng = np.random.default_rng(seed=seed)

def sample_time(self):
def sample_time(self, start, end):
"""Sample the time according to the start and end of the run."""
self.event_time_interval = units.s // self.salting_rate
self.event_time_interval = int(units.s // self.salting_rate)

if units.s / self.salting_rate < self.drift_time_max * self.n_drift_time_window * 2:
raise ValueError("Salting rate is too high according the drift time window!")

time = np.arange(
self.run_start + self.veto_length_run_start,
self.run_end - self.veto_length_run_end,
start + self.veto_length_run_start,
end - self.veto_length_run_end,
self.event_time_interval,
).astype(np.int64)
self.time_left = self.event_time_interval // 2
self.time_right = self.event_time_interval - self.time_left
return time

def inverse_field_distortion(self, x, y, z):
Expand All @@ -156,15 +144,9 @@ def set_chunk_splitting(self):
self.time_left = self.event_time_interval // 2
self.time_right = self.event_time_interval - self.time_left

def setup(self):
"""Sample the features of events."""
super(EventPositions, self).setup()
super(EventsSalting, self).setup()

self.init_rng()
self.init_run_meta()

time = self.sample_time()
def sampling(self, start, end):
"""Sample the features of events, (t, x, y, z, S1, S2) et al."""
time = self.sample_time(start, end)
self.n_events = len(time)
self.events_salting = np.empty(self.n_events, dtype=self.dtype)
self.events_salting["salt_number"] = np.arange(self.n_events)
Expand Down Expand Up @@ -208,26 +190,22 @@ def setup(self):

self.set_chunk_splitting()

def compute(self, chunk_i):
def compute(self, run_meta, start, end):
"""Copy and assign the salting events into chunk."""
indices = self.slices[chunk_i]

if chunk_i == 0:
start = self.run_start
else:
start = self.events_salting["time"][indices[0]] - self.time_left

if chunk_i == len(self.slices) - 1:
end = self.run_end
else:
end = self.events_salting["time"][indices[1] - 1] + self.time_right
return self.chunk(start=start, end=end, data=self.events_salting[indices[0] : indices[1]])

def is_ready(self, chunk_i):
if chunk_i < len(self.slices):
return True
else:
return False

def source_finished(self):
return True
self.sampling(start, end)
for chunk_i in range(len(self.slices)):
indices = self.slices[chunk_i]

if chunk_i == 0:
_start = start
else:
_start = self.events_salting["time"][indices[0]] - self.time_left

if chunk_i == len(self.slices) - 1:
_end = end
else:
_end = self.events_salting["time"][indices[1] - 1] + self.time_right

yield self.chunk(
start=_start, end=_end, data=self.events_salting[indices[0] : indices[1]]
)
1 change: 1 addition & 0 deletions axidence/plugins/salting/peaks_salted.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

class PeaksSalted(PeakBasics):
__version__ = "0.0.0"
child_plugin = True
depends_on = "events_salting"
provides = "peaks_salted"
data_kind = "peaks_salted"
Expand Down

0 comments on commit 44753f1

Please sign in to comment.