Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Define RunMeta to help extract start and end of a run in the salting and pairing network #49

Merged
merged 2 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion axidence/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import strax
import straxen

from axidence import EventsSalting, PeaksSalted
from axidence import RunMeta, EventsSalting, PeaksSalted
from axidence import (
PeakProximitySalted,
PeakShadowSalted,
Expand Down Expand Up @@ -147,6 +147,7 @@ def _salt_to_context(self):
"""Register the salted plugins to the context."""
self.register(
(
RunMeta,
EventsSalting,
PeaksSalted,
PeakProximitySalted,
Expand Down
2 changes: 2 additions & 0 deletions axidence/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@

kind_colors.update(
{
"run_meta": "#ffff00",
"events_salting": "#0080ff",
"peaks_salted": "#00c0ff",
"events_salted": "#00ffff",
"peaks_paired": "#ff00ff",
"truth_paired": "#ff00ff",
"events_paired": "#ffccff",
"isolated_s1": "#80ff00",
"isolated_s2": "#80ff00",
Expand Down
3 changes: 3 additions & 0 deletions axidence/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from . import meta
from .meta import *

from . import cuts
from .cuts import *

Expand Down
2 changes: 2 additions & 0 deletions axidence/plugins/meta/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from . import run_meta
from .run_meta import *
22 changes: 22 additions & 0 deletions axidence/plugins/meta/run_meta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import numpy as np
import strax

from ...plugin import ExhaustPlugin


class RunMeta(ExhaustPlugin):
"""Plugin that provides run metadata."""

__version__ = "0.0.0"
depends_on = "event_basics"
provides = "run_meta"
data_kind = "run_meta"
save_when = strax.SaveWhen.EXPLICIT

dtype = strax.time_fields

def compute(self, events, start, end):
result = np.zeros(1, dtype=self.dtype)
result["time"] = start
result["endtime"] = end
return result
47 changes: 13 additions & 34 deletions axidence/plugins/pairing/peaks_paired.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

from ...utils import copy_dtype
from ...dtypes import peak_positions_dtype
from ...plugin import ExhaustPlugin, RunMetaPlugin
from ...plugin import ExhaustPlugin


class PeaksPaired(ExhaustPlugin, RunMetaPlugin):
class PeaksPaired(ExhaustPlugin):
__version__ = "0.0.0"
depends_on = ("isolated_s1", "isolated_s2", "cut_event_building_salted", "event_shadow_salted")
provides = ("peaks_paired", "truth_paired")
Expand All @@ -25,24 +25,6 @@ class PeaksPaired(ExhaustPlugin, RunMetaPlugin):
help="Seed for pairing",
)

real_run_start = straxen.URLConfig(
default=None,
type=(int, None),
help="Real start time of run [ns]",
)

real_run_end = straxen.URLConfig(
default=None,
type=(int, None),
help="Real start time of run [ns]",
)

strict_real_run_time_check = straxen.URLConfig(
default=True,
type=bool,
help="Whether to strictly check the real run time is provided",
)

min_drift_length = straxen.URLConfig(
default=0,
type=(int, float),
Expand Down Expand Up @@ -135,7 +117,6 @@ def infer_dtype(self):
return dict(peaks_paired=peaks_dtype, truth_paired=truth_dtype)

def setup(self, prepare=True):
self.init_run_meta()
self.min_drift_time = int(self.min_drift_length / self.electron_drift_velocity)
self.max_drift_time = int(self.max_drift_length / self.electron_drift_velocity)
if self.pairing_seed is None:
Expand Down Expand Up @@ -211,6 +192,8 @@ def split_chunks(self, n_peaks):

def build_arrays(
self,
start,
end,
drift_time,
s1_group_number,
s2_group_number,
Expand All @@ -223,9 +206,7 @@ def build_arrays(

# set center time of S1 & S2
# paired events are separated by roughly `event_interval`
s1_center_time = (
np.arange(len(drift_time)).astype(int) * self.paring_event_interval + self.run_start
)
s1_center_time = np.arange(len(drift_time)).astype(int) * self.paring_event_interval + start
s2_center_time = s1_center_time + drift_time
# total number of isolated S1 & S2 peaks
peaks_arrays = np.zeros(n_peaks.sum(), dtype=self.dtype["peaks_paired"])
Expand Down Expand Up @@ -322,7 +303,7 @@ def build_arrays(

return peaks_arrays, truth_arrays

def compute(self, isolated_s1, isolated_s2, events_salted):
def compute(self, isolated_s1, isolated_s2, events_salted, start, end):
for i, s in enumerate([isolated_s1, isolated_s2]):
if np.any(np.diff(s["group_number"]) < 0):
raise ValueError(f"Group number is not sorted in isolated S{i}!")
Expand Down Expand Up @@ -350,7 +331,7 @@ def compute(self, isolated_s1, isolated_s2, events_salted):
paring_rate_full, s1_group_number, s2_group_number, drift_time = self.simple_pairing(
isolated_s1,
main_isolated_s2,
self.run_time,
(end - start) / units.s,
self.max_drift_time,
self.min_drift_time,
paring_rate_correction,
Expand All @@ -377,6 +358,8 @@ def compute(self, isolated_s1, isolated_s2, events_salted):
chunk_i = 0
left_i, right_i = slices[chunk_i]
peaks_arrays, truth_arrays = self.build_arrays(
start,
end,
drift_time[left_i:right_i],
s1_group_number[left_i:right_i],
s2_group_number[left_i:right_i],
Expand All @@ -389,18 +372,14 @@ def compute(self, isolated_s1, isolated_s2, events_salted):
peaks_arrays["event_number"] += left_i
truth_arrays["event_number"] += left_i

start = (
self.run_start + left_i * self.paring_event_interval - self.paring_event_interval // 2
)
end = (
self.run_start + right_i * self.paring_event_interval - self.paring_event_interval // 2
)
_start = start + left_i * self.paring_event_interval - int(self.paring_event_interval // 2)
_end = start + right_i * self.paring_event_interval - int(self.paring_event_interval // 2)
result = dict()
result["peaks_paired"] = self.chunk(
start=start, end=end, data=peaks_arrays, data_type="peaks_paired"
start=_start, end=_end, data=peaks_arrays, data_type="peaks_paired"
)
result["truth_paired"] = self.chunk(
start=start, end=end, data=truth_arrays, data_type="truth_paired"
start=_start, end=_end, data=truth_arrays, data_type="truth_paired"
)
# chunk size should be less than default chunk size in strax
assert result["peaks_paired"].nbytes < self.chunk_target_size_mb * 1e6
Expand Down
1 change: 1 addition & 0 deletions axidence/plugins/salting/event_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@


class EventFieldsSalted(Plugin):
child_plugin = True

def compute(self, events_salted, peaks_salted, peaks):
_peaks = merge_salted_real(peaks_salted, peaks, peaks.dtype)
Expand Down
96 changes: 37 additions & 59 deletions axidence/plugins/salting/events_salting.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from typing import Tuple
import numpy as np
import strax
from strax import DownChunkingPlugin
import straxen
from straxen import units, EventBasics, EventPositions

from ...utils import copy_dtype
from ...plugin import RunMetaPlugin
from ...plugin import ExhaustPlugin


class EventsSalting(EventPositions, EventBasics, RunMetaPlugin):
class EventsSalting(ExhaustPlugin, DownChunkingPlugin, EventPositions, EventBasics):
__version__ = "0.0.0"
depends_on: Tuple = tuple()
child_plugin = True
depends_on = "run_meta"
provides = "events_salting"
data_kind = "events_salting"
save_when = strax.SaveWhen.EXPLICIT
Expand All @@ -27,24 +28,6 @@ class EventsSalting(EventPositions, EventBasics, RunMetaPlugin):
help="Rate of salting in Hz",
)

real_run_start = straxen.URLConfig(
default=None,
type=(int, None),
help="Real start time of run [ns]",
)

real_run_end = straxen.URLConfig(
default=None,
type=(int, None),
help="Real start time of run [ns]",
)

strict_real_run_time_check = straxen.URLConfig(
default=True,
type=bool,
help="Whether to strictly check the real run time is provided",
)

s1_area_range = straxen.URLConfig(
default=(1, 150),
type=(list, tuple),
Expand Down Expand Up @@ -115,26 +98,31 @@ def infer_dtype(self):
dtype += [(("Salting number of events", "salt_number"), np.int64)]
return dtype

def setup(self):
super(EventPositions, self).setup()
super(EventsSalting, self).setup()

self.init_rng()

def init_rng(self):
"""Initialize the random number generator."""
if self.salting_seed is None:
self.rng = np.random.default_rng(seed=int(self.run_id))
else:
self.rng = np.random.default_rng(seed=self.salting_seed)

def sample_time(self):
def sample_time(self, start, end):
"""Sample the time according to the start and end of the run."""
self.event_time_interval = units.s // self.salting_rate
self.event_time_interval = int(units.s // self.salting_rate)

if units.s / self.salting_rate < self.drift_time_max * self.n_drift_time_window * 2:
raise ValueError("Salting rate is too high according the drift time window!")

time = np.arange(
self.run_start + self.veto_length_run_start,
self.run_end - self.veto_length_run_end,
start + self.veto_length_run_start,
end - self.veto_length_run_end,
self.event_time_interval,
).astype(np.int64)
self.time_left = self.event_time_interval // 2
self.time_right = self.event_time_interval - self.time_left
return time

def inverse_field_distortion(self, x, y, z):
Expand All @@ -156,15 +144,9 @@ def set_chunk_splitting(self):
self.time_left = self.event_time_interval // 2
self.time_right = self.event_time_interval - self.time_left

def setup(self):
"""Sample the features of events."""
super(EventPositions, self).setup()
super(EventsSalting, self).setup()

self.init_rng()
self.init_run_meta()

time = self.sample_time()
def sampling(self, start, end):
"""Sample the features of events, (t, x, y, z, S1, S2) et al."""
time = self.sample_time(start, end)
self.n_events = len(time)
self.events_salting = np.empty(self.n_events, dtype=self.dtype)
self.events_salting["salt_number"] = np.arange(self.n_events)
Expand Down Expand Up @@ -208,26 +190,22 @@ def setup(self):

self.set_chunk_splitting()

def compute(self, chunk_i):
def compute(self, run_meta, start, end):
"""Copy and assign the salting events into chunk."""
indices = self.slices[chunk_i]

if chunk_i == 0:
start = self.run_start
else:
start = self.events_salting["time"][indices[0]] - self.time_left

if chunk_i == len(self.slices) - 1:
end = self.run_end
else:
end = self.events_salting["time"][indices[1] - 1] + self.time_right
return self.chunk(start=start, end=end, data=self.events_salting[indices[0] : indices[1]])

def is_ready(self, chunk_i):
if chunk_i < len(self.slices):
return True
else:
return False

def source_finished(self):
return True
self.sampling(start, end)
for chunk_i in range(len(self.slices)):
indices = self.slices[chunk_i]

if chunk_i == 0:
_start = start
else:
_start = self.events_salting["time"][indices[0]] - self.time_left

if chunk_i == len(self.slices) - 1:
_end = end
else:
_end = self.events_salting["time"][indices[1] - 1] + self.time_right

yield self.chunk(
start=_start, end=_end, data=self.events_salting[indices[0] : indices[1]]
)
1 change: 1 addition & 0 deletions axidence/plugins/salting/peaks_salted.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

class PeaksSalted(PeakBasics):
__version__ = "0.0.0"
child_plugin = True
depends_on = "events_salting"
provides = "peaks_salted"
data_kind = "peaks_salted"
Expand Down
Loading