Skip to content

Commit

Permalink
Add other peak and event level pairing plugins (#39)
Browse files Browse the repository at this point in the history
* Add `PeakProximityPaired`

* Add `PeakPositionsPaired`

* Add `EventsForcePaired`

* Add `EventInfosPaired`

* Save more fields into isolated peaks
  • Loading branch information
dachengx committed Apr 27, 2024
1 parent 7ab745f commit 3f9a319
Show file tree
Hide file tree
Showing 7 changed files with 314 additions and 30 deletions.
35 changes: 28 additions & 7 deletions axidence/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,22 @@
IsolatedS1,
IsolatedS2,
PeaksPaired,
PeakProximityPaired,
PeakPositionsPaired,
EventInfosPaired,
)

export, __all__ = strax.exporter()


@export
def ordinary_context(**kwargs):
    """Return a plain straxen online context (no pairing, no salting).

    Wraps ``straxen.contexts.xenonnt_online`` with ``_database_init=False``
    so no RunDB connection is attempted; extra keyword arguments are
    forwarded unchanged.
    """
    return straxen.contexts.xenonnt_online(_database_init=False, **kwargs)


@strax.Context.add_method
def plugin_factory(st, data_type, suffixes):
def plugin_factory(st, data_type, suffixes, assign_attributes=None):
"""Create new plugins inheriting from the plugin which provides
data_type."""
plugin = st._plugin_class_registry[data_type]
Expand Down Expand Up @@ -64,6 +70,11 @@ def do_compute(self, chunk_i=None, **kwargs):
# https://github.com/AxFoundation/strax/blob/7da9a2a6375e7614181830484b322389986cf064/strax/context.py#L324
new_plugin.__name__ = plugin.__name__ + suffix

# assign the attributes from the original plugin
if assign_attributes and plugin.__name__ in assign_attributes:
for attr in assign_attributes[plugin.__name__]:
setattr(new_plugin, attr, getattr(p, attr))

# assign the same attributes as the original plugin
if hasattr(p, "depends_on"):
new_plugin.depends_on = tuple(d + snake for d in p.depends_on)
Expand Down Expand Up @@ -105,20 +116,25 @@ def do_compute(self, chunk_i=None, **kwargs):


@strax.Context.add_method
def replication_tree(st, suffixes=("Paired", "Salted"), assign_attributes=None, tqdm_disable=True):
    """Replicate the dependency tree.

    The plugins in the new tree will have the suffixed depends_on,
    provides and data_kind as the plugins in the original tree.

    :param suffixes: suffixes appended (in CamelCase) to replicated plugin
        class names; the snake_case form is appended to provides/depends_on.
        Default is a tuple (not a list) to avoid the mutable-default pitfall.
    :param assign_attributes: mapping of original plugin class name ->
        attribute names to copy onto the replicated plugin.
    :param tqdm_disable: disable the progress bar when True.
    :raises ValueError: if a plugin with a suffixed name is already registered.
    """
    if assign_attributes is None:
        # this is due to some features are assigned in `infer_dtype` of the original plugins:
        # https://github.com/XENONnT/straxen/blob/e555c7dcada2743d2ea627ea49df783e9dba40e3/straxen/plugins/events/event_basics.py#L69
        assign_attributes = {"EventBasics": ["peak_properties", "posrec_save"]}

    snakes = ["_" + strax.camel_to_snake(suffix) for suffix in suffixes]
    # Guard against double replication: a suffixed name in the registry
    # means this tree was already replicated with the same suffix.
    for k in st._plugin_class_registry.keys():
        for s in snakes:
            if s in k:
                raise ValueError(f"{k} with suffix {s} is already registered!")
    plugins_collection = []
    for k in tqdm(st._plugin_class_registry.keys(), disable=tqdm_disable):
        plugins_collection += st.plugin_factory(k, suffixes, assign_attributes=assign_attributes)

    st.register(plugins_collection)

Expand Down Expand Up @@ -153,22 +169,27 @@ def _pair_to_context(self):
IsolatedS1,
IsolatedS2,
PeaksPaired,
PeakProximityPaired,
PeakPositionsPaired,
EventInfosPaired,
)
)


@strax.Context.add_method
def salt_to_context(st, assign_attributes=None, tqdm_disable=True):
    """Register the salted plugins to the context.

    Registers ``EventBuilding``, then replicates the whole dependency tree
    with the ``Salted`` suffix and registers the salting-specific plugins.

    :param assign_attributes: forwarded to ``replication_tree``; mapping of
        original plugin class name -> attributes to copy onto replicas.
    :param tqdm_disable: disable the replication progress bar when True.
    """
    st.register((EventBuilding,))
    st.replication_tree(
        suffixes=["Salted"], assign_attributes=assign_attributes, tqdm_disable=tqdm_disable
    )
    st._salt_to_context()


@strax.Context.add_method
def salt_and_pair_to_context(st, assign_attributes=None, tqdm_disable=True):
    """Register the salted and paired plugins to the context.

    Registers ``EventBuilding``, replicates the dependency tree with the
    default suffixes (both ``Paired`` and ``Salted``), then registers the
    salting- and pairing-specific plugins.

    :param assign_attributes: forwarded to ``replication_tree``; mapping of
        original plugin class name -> attributes to copy onto replicas.
    :param tqdm_disable: disable the replication progress bar when True.
    """
    st.register((EventBuilding,))
    st.replication_tree(assign_attributes=assign_attributes, tqdm_disable=tqdm_disable)
    st._salt_to_context()
    st._pair_to_context()
4 changes: 2 additions & 2 deletions axidence/plugins/isolated/isolated_s1.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import straxen

from ...utils import needed_dtype, copy_dtype
from ...dtypes import positioned_peak_dtype
from ...dtypes import positioned_peak_dtype, correlation_fields


class IsolatedS1(Plugin):
Expand All @@ -24,7 +24,7 @@ class IsolatedS1(Plugin):
data_kind = "isolated_s1"

isolated_peaks_fields = straxen.URLConfig(
default=np.dtype(positioned_peak_dtype()).names,
default=list(np.dtype(positioned_peak_dtype()).names) + correlation_fields,
type=(list, tuple),
help="Needed fields in isolated peaks",
)
Expand Down
6 changes: 3 additions & 3 deletions axidence/plugins/isolated/isolated_s2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import straxen

from ...utils import needed_dtype, copy_dtype
from ...dtypes import positioned_peak_dtype
from ...dtypes import positioned_peak_dtype, correlation_fields, event_level_fields


class IsolatedS2(Plugin):
Expand All @@ -29,13 +29,13 @@ class IsolatedS2(Plugin):
data_kind = "isolated_s2"

isolated_peaks_fields = straxen.URLConfig(
default=np.dtype(positioned_peak_dtype()).names,
default=list(np.dtype(positioned_peak_dtype()).names) + correlation_fields,
type=(list, tuple),
help="Needed fields in isolated peaks",
)

isolated_events_fields = straxen.URLConfig(
default=[],
default=event_level_fields,
type=(list, tuple),
help="Needed fields in isolated events",
)
Expand Down
186 changes: 186 additions & 0 deletions axidence/plugins/pairing/events_paired.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
import numpy as np
import strax
from strax import OverlapWindowPlugin
import straxen
from straxen import Events, EventBasics

from ...utils import copy_dtype

export, __all__ = strax.exporter()


class EventsForcePaired(OverlapWindowPlugin):
    """Build events by force-pairing isolated S1 & S2 peaks (mimics Events).

    Each group of paired peaks sharing an ``event_number`` becomes one
    event spanning from the first peak's time to the last peak's endtime.
    NOT actually used in AC simulation; kept for debugging only.
    """

    depends_on = "peaks_paired"
    provides = "events_paired"
    data_kind = "events_paired"
    save_when = strax.SaveWhen.EXPLICIT

    # NOTE: "paring" (sic) is the established config name; renaming would
    # break existing configuration, so the typo is kept.
    paring_event_interval = straxen.URLConfig(
        default=int(1e8),
        type=int,
        help="The interval which separates two events S1 [ns]",
    )

    def infer_dtype(self):
        # Reuse the time/endtime/event_number field definitions from the
        # upstream paired-peaks dtype so descriptions stay consistent.
        dtype_reference = strax.unpack_dtype(self.deps["peaks_paired"].dtype_for("peaks_paired"))
        required_names = ["time", "endtime", "event_number"]
        return copy_dtype(dtype_reference, required_names)

    def get_window_size(self):
        # Overlap window must comfortably cover one pairing interval.
        return 10 * self.paring_event_interval

    def compute(self, peaks_paired):
        # Sort so each event's peaks are contiguous and time-ordered.
        peaks_event_number_sorted = np.sort(peaks_paired, order=("event_number", "time"))
        event_number, event_number_index = np.unique(
            peaks_event_number_sorted["event_number"], return_index=True
        )
        # Append sentinel so index pairs [i, i+1) bound each event's peaks.
        event_number_index = np.append(event_number_index, len(peaks_event_number_sorted))
        result = np.zeros(len(event_number), self.dtype)
        # Event time = first peak's time; endtime = last peak's endtime.
        result["time"] = peaks_event_number_sorted["time"][event_number_index[:-1]]
        result["endtime"] = strax.endtime(peaks_event_number_sorted)[event_number_index[1:] - 1]
        result["event_number"] = event_number
        return result


@export
class EventInfosPaired(Events):
    """Superset of EventInfo for paired (AC-simulated) events.

    Besides the features in EventInfo, also stores the shadow- and
    ambience-related features, the origin group_number and times of the
    main S1/S2 peaks, and a pairing-type classification (``event_type``).
    """

    __version__ = "0.0.0"
    depends_on = ("event_info_paired", "peaks_paired")
    provides = "event_infos_paired"
    data_kind = "events_paired"
    save_when = strax.SaveWhen.EXPLICIT

    ambience_fields = straxen.URLConfig(
        default=["lh_before", "s0_before", "s1_before", "s2_before", "s2_near"],
        type=(list, tuple),
        help="Needed ambience related fields",
    )

    alternative_peak_add_fields = straxen.URLConfig(
        default=["se_density"],
        type=(list, tuple),
        help="Fields to store also for alternative peaks",
    )

    @property
    def peak_fields(self):
        """Peak-level field names copied from paired peaks into events."""
        required_names = []
        for key in ["s2_time_shadow", "s2_position_shadow"]:
            required_names += [f"shadow_{key}", f"dt_{key}"]
        for ambience in self.ambience_fields:
            required_names += [f"n_{ambience}"]
        required_names += [
            "pdf_s2_position_shadow",
            "nearest_dt_s1",
            "nearest_dt_s2",
            "se_density",
            "left_dtime",
            "right_dtime",
        ]
        # TODO: reconsider about how to store run_id after implementing super runs
        required_names += [
            # "origin_run_id",
            "origin_group_number",
            "origin_time",
            "origin_endtime",
            "origin_center_time",
        ]
        return required_names

    @property
    def event_fields(self):
        """Event-level dtype: each peak field duplicated for main S1/S2
        (and for alternative S1/S2 when whitelisted), plus event_type and
        event_number."""
        # Take field definitions (descriptions + types) from the upstream
        # paired-peaks dtype so they stay in sync.
        dtype_reference = strax.unpack_dtype(self.deps["peaks_paired"].dtype_for("peaks_paired"))
        peaks_dtype = copy_dtype(dtype_reference, self.peak_fields)
        dtype = []
        for d in peaks_dtype:
            # d is ((description, name), type); prefix for main S1 and S2.
            dtype += [
                (("Main S1 " + d[0][0], "s1_" + d[0][1]), d[1]),
                (("Main S2 " + d[0][0], "s2_" + d[0][1]), d[1]),
            ]
            # Only selected fields are also stored for alternative peaks.
            if d[0][1] in self.alternative_peak_add_fields:
                dtype += [
                    (("Alternative S1 " + d[0][0], "alt_s1_" + d[0][1]), d[1]),
                    (("Alternative S2 " + d[0][0], "alt_s2_" + d[0][1]), d[1]),
                ]
        dtype += [
            (
                ("Type of event indicating whether the isolated S1 becomes main S1", "event_type"),
                np.int8,
            ),
            (("Event number in this dataset", "event_number"), np.int64),
        ]
        return dtype

    def infer_dtype(self):
        # Union of the plain EventInfo dtype and the paired-specific fields.
        return strax.merged_dtype(
            [
                self.deps["event_info_paired"].dtype_for("event_info_paired"),
                np.dtype(self.event_fields),
            ]
        )

    def compute(self, events_paired, peaks_paired):
        result = np.zeros(len(events_paired), dtype=self.dtype)

        # assign the additional fields
        EventBasics.set_nan_defaults(result)

        # assign the features already in EventInfo
        for q in self.deps["event_info_paired"].dtype_for("event_info_paired").names:
            result[q] = events_paired[q]

        # store AC-type
        split_peaks = strax.split_by_containment(peaks_paired, events_paired)
        for i, (event, sp) in enumerate(zip(events_paired, split_peaks)):
            # Each event must contain peaks of exactly one pairing group;
            # multiple event_numbers means paired events overlapped.
            if np.unique(sp["event_number"]).size != 1:
                raise ValueError(
                    f"Event {i} has multiple event numbers: "
                    f"{np.unique(sp['event_number'])}. "
                    "Maybe the paired events overlap."
                )
            result["event_number"][i] = sp["event_number"][0]
            # Copy all peak_fields from the main S1/S2 peak (index < 0
            # means that peak type is absent from this event).
            for idx, main_peak in zip([event["s1_index"], event["s2_index"]], ["s1_", "s2_"]):
                if idx >= 0:
                    for n in self.peak_fields:
                        result[main_peak + n][i] = sp[n][idx]
            # Alternative peaks only get the whitelisted subset of fields.
            for idx, main_peak in zip(
                [event["alt_s1_index"], event["alt_s2_index"]], ["alt_s1_", "alt_s2_"]
            ):
                if idx >= 0:
                    for n in self.peak_fields:
                        if n in self.alternative_peak_add_fields:
                            result[main_peak + n][i] = sp[n][idx]
            # Classify the pairing outcome into event_type codes 1-5:
            # 1/2: pure-isolated S2 (no own main S1), paired/unpaired;
            # 3/4: ext-isolated S2 (has own main S1), AC-paired/other;
            # 5: event without a main S2.
            # if the AC event have S2
            if event["s2_index"] != -1:
                if sp["origin_s1_index"][0] == -1:
                    # if isolated S2 is pure-isolated S2(w/o main S1)
                    if event["s1_index"] != -1:
                        # if successfully paired, considered as AC
                        result["event_type"][i] = 1
                    else:
                        # if unsuccessfully paired, not considered as AC
                        result["event_type"][i] = 2
                else:
                    # if isolated S2 is ext-isolated S2(w/ main S1)
                    if event["s1_index"] != -1 and sp["origin_group_type"][event["s1_index"]] == 1:
                        # if successfully paired and main S1 is from isolated S1 but not isolated S2
                        # considered as AC
                        result["event_type"][i] = 3
                    else:
                        # otherwise, not considered as AC
                        result["event_type"][i] = 4
            else:
                result["event_type"][i] = 5

        result["time"] = events_paired["time"]
        result["endtime"] = events_paired["endtime"]

        return result
Loading

0 comments on commit 3f9a319

Please sign in to comment.