From 3f9a319e909ec01f262bcd7c5c535f4d3b3f7eae Mon Sep 17 00:00:00 2001 From: Dacheng Xu Date: Sat, 27 Apr 2024 10:40:34 -0400 Subject: [PATCH] Add other peak and event level pairing plugins (#39) * Add `PeakProximityPaired` * Add `PeakPositionsPaired` * Add `EventsForcePaired` * Add `EventInfosPaired` * Save more fields into isolated peaks --- axidence/context.py | 35 +++- axidence/plugins/isolated/isolated_s1.py | 4 +- axidence/plugins/isolated/isolated_s2.py | 6 +- axidence/plugins/pairing/events_paired.py | 186 ++++++++++++++++++++++ axidence/plugins/pairing/peaks_paired.py | 98 ++++++++++-- axidence/plugins/salting/peaks_salted.py | 4 +- axidence/utils.py | 11 +- 7 files changed, 314 insertions(+), 30 deletions(-) diff --git a/axidence/context.py b/axidence/context.py index 1672585..0ab7f59 100644 --- a/axidence/context.py +++ b/axidence/context.py @@ -24,16 +24,22 @@ IsolatedS1, IsolatedS2, PeaksPaired, + PeakProximityPaired, + PeakPositionsPaired, + EventInfosPaired, ) +export, __all__ = strax.exporter() + +@export def ordinary_context(**kwargs): """Return a straxen context without paring and salting.""" return straxen.contexts.xenonnt_online(_database_init=False, **kwargs) @strax.Context.add_method -def plugin_factory(st, data_type, suffixes): +def plugin_factory(st, data_type, suffixes, assign_attributes=None): """Create new plugins inheriting from the plugin which provides data_type.""" plugin = st._plugin_class_registry[data_type] @@ -64,6 +70,11 @@ def do_compute(self, chunk_i=None, **kwargs): # https://github.com/AxFoundation/strax/blob/7da9a2a6375e7614181830484b322389986cf064/strax/context.py#L324 new_plugin.__name__ = plugin.__name__ + suffix + # assign the attributes from the original plugin + if assign_attributes and plugin.__name__ in assign_attributes: + for attr in assign_attributes[plugin.__name__]: + setattr(new_plugin, attr, getattr(p, attr)) + # assign the same attributes as the original plugin if hasattr(p, "depends_on"): new_plugin.depends_on = tuple(d + snake for d in p.depends_on) @@ -105,12 +116,17 @@ def do_compute(self, chunk_i=None, **kwargs): @strax.Context.add_method -def replication_tree(st, suffixes=["Paired", "Salted"], tqdm_disable=True): +def replication_tree(st, suffixes=["Paired", "Salted"], assign_attributes=None, tqdm_disable=True): """Replicate the dependency tree. The plugins in the new tree will have the suffixed depends_on, provides and data_kind as the plugins in original tree. """ + if assign_attributes is None: + # this is due to some features are assigned in `infer_dtype` of the original plugins: + # https://github.com/XENONnT/straxen/blob/e555c7dcada2743d2ea627ea49df783e9dba40e3/straxen/plugins/events/event_basics.py#L69 + assign_attributes = {"EventBasics": ["peak_properties", "posrec_save"]} + snakes = ["_" + strax.camel_to_snake(suffix) for suffix in suffixes] for k in st._plugin_class_registry.keys(): for s in snakes: @@ -118,7 +134,7 @@ def replication_tree(st, suffixes=["Paired", "Salted"], tqdm_disable=True): raise ValueError(f"{k} with suffix {s} is already registered!") plugins_collection = [] for k in tqdm(st._plugin_class_registry.keys(), disable=tqdm_disable): - plugins_collection += st.plugin_factory(k, suffixes) + plugins_collection += st.plugin_factory(k, suffixes, assign_attributes=assign_attributes) st.register(plugins_collection) @@ -153,22 +169,27 @@ def _pair_to_context(self): IsolatedS1, IsolatedS2, PeaksPaired, + PeakProximityPaired, + PeakPositionsPaired, + EventInfosPaired, ) ) @strax.Context.add_method -def salt_to_context(st, tqdm_disable=True): +def salt_to_context(st, assign_attributes=None, tqdm_disable=True): """Register the salted plugins to the context.""" st.register((EventBuilding,)) - st.replication_tree(suffixes=["Salted"], tqdm_disable=tqdm_disable) + st.replication_tree( + suffixes=["Salted"], assign_attributes=assign_attributes, tqdm_disable=tqdm_disable + ) st._salt_to_context() @strax.Context.add_method -def salt_and_pair_to_context(st, tqdm_disable=True): +def salt_and_pair_to_context(st, assign_attributes=None, tqdm_disable=True): """Register the salted and paired plugins to the context.""" st.register((EventBuilding,)) - st.replication_tree(tqdm_disable=tqdm_disable) + st.replication_tree(assign_attributes=assign_attributes, tqdm_disable=tqdm_disable) st._salt_to_context() st._pair_to_context() diff --git a/axidence/plugins/isolated/isolated_s1.py b/axidence/plugins/isolated/isolated_s1.py index a600c27..e56cb74 100644 --- a/axidence/plugins/isolated/isolated_s1.py +++ b/axidence/plugins/isolated/isolated_s1.py @@ -4,7 +4,7 @@ import straxen from ...utils import needed_dtype, copy_dtype -from ...dtypes import positioned_peak_dtype +from ...dtypes import positioned_peak_dtype, correlation_fields class IsolatedS1(Plugin): @@ -24,7 +24,7 @@ class IsolatedS1(Plugin): data_kind = "isolated_s1" isolated_peaks_fields = straxen.URLConfig( - default=np.dtype(positioned_peak_dtype()).names, + default=list(np.dtype(positioned_peak_dtype()).names) + correlation_fields, type=(list, tuple), help="Needed fields in isolated peaks", ) diff --git a/axidence/plugins/isolated/isolated_s2.py b/axidence/plugins/isolated/isolated_s2.py index 31c53a9..488477b 100644 --- a/axidence/plugins/isolated/isolated_s2.py +++ b/axidence/plugins/isolated/isolated_s2.py @@ -4,7 +4,7 @@ import straxen from ...utils import needed_dtype, copy_dtype -from ...dtypes import positioned_peak_dtype +from ...dtypes import positioned_peak_dtype, correlation_fields, event_level_fields class IsolatedS2(Plugin): @@ -29,13 +29,13 @@ class IsolatedS2(Plugin): data_kind = "isolated_s2" isolated_peaks_fields = straxen.URLConfig( - default=np.dtype(positioned_peak_dtype()).names, + default=list(np.dtype(positioned_peak_dtype()).names) + correlation_fields, type=(list, tuple), help="Needed fields in isolated peaks", ) isolated_events_fields = straxen.URLConfig( - default=[], + default=event_level_fields, type=(list, tuple), help="Needed fields in isolated events", ) diff --git a/axidence/plugins/pairing/events_paired.py b/axidence/plugins/pairing/events_paired.py index e69de29..3a49a57 100644 --- a/axidence/plugins/pairing/events_paired.py +++ b/axidence/plugins/pairing/events_paired.py @@ -0,0 +1,186 @@ +import numpy as np +import strax +from strax import OverlapWindowPlugin +import straxen +from straxen import Events, EventBasics + +from ...utils import copy_dtype + +export, __all__ = strax.exporter() + + +class EventsForcePaired(OverlapWindowPlugin): + """Mimicking Events Force manually pairing of isolated S1 & S2 Actually NOT + used in AC simulation, but just for debug.""" + + depends_on = "peaks_paired" + provides = "events_paired" + data_kind = "events_paired" + save_when = strax.SaveWhen.EXPLICIT + + paring_event_interval = straxen.URLConfig( + default=int(1e8), + type=int, + help="The interval which separates two events S1 [ns]", + ) + + def infer_dtype(self): + dtype_reference = strax.unpack_dtype(self.deps["peaks_paired"].dtype_for("peaks_paired")) + required_names = ["time", "endtime", "event_number"] + dtype = copy_dtype(dtype_reference, required_names) + return dtype + + def get_window_size(self): + return 10 * self.paring_event_interval + + def compute(self, peaks_paired): + peaks_event_number_sorted = np.sort(peaks_paired, order=("event_number", "time")) + event_number, event_number_index, event_number_count = np.unique( + peaks_event_number_sorted["event_number"], return_index=True, return_counts=True + ) + event_number_index = np.append(event_number_index, len(peaks_event_number_sorted)) + result = np.zeros(len(event_number), self.dtype) + result["time"] = peaks_event_number_sorted["time"][event_number_index[:-1]] + result["endtime"] = strax.endtime(peaks_event_number_sorted)[event_number_index[1:] - 1] + result["event_number"] = event_number + return result + + +@export +class EventInfosPaired(Events): + """Superset to EventInfo Besides the features in EventInfo, also store the + shadow and ambience related features and the origin run_id, group_number, + time of main S1/2 peaks and the pairing-type.""" + + __version__ = "0.0.0" + depends_on = ("event_info_paired", "peaks_paired") + provides = "event_infos_paired" + data_kind = "events_paired" + save_when = strax.SaveWhen.EXPLICIT + + ambience_fields = straxen.URLConfig( + default=["lh_before", "s0_before", "s1_before", "s2_before", "s2_near"], + type=(list, tuple), + help="Needed ambience related fields", + ) + + alternative_peak_add_fields = straxen.URLConfig( + default=["se_density"], + type=(list, tuple), + help="Fields to store also for alternative peaks", + ) + + @property + def peak_fields(self): + required_names = [] + for key in ["s2_time_shadow", "s2_position_shadow"]: + required_names += [f"shadow_{key}", f"dt_{key}"] + for ambience in self.ambience_fields: + required_names += [f"n_{ambience}"] + required_names += [ + "pdf_s2_position_shadow", + "nearest_dt_s1", + "nearest_dt_s2", + "se_density", + "left_dtime", + "right_dtime", + ] + # TODO: reconsider about how to store run_id after implementing super runs + required_names += [ + # "origin_run_id", + "origin_group_number", + "origin_time", + "origin_endtime", + "origin_center_time", + ] + return required_names + + @property + def event_fields(self): + dtype_reference = strax.unpack_dtype(self.deps["peaks_paired"].dtype_for("peaks_paired")) + peaks_dtype = copy_dtype(dtype_reference, self.peak_fields) + dtype = [] + for d in peaks_dtype: + dtype += [ + (("Main S1 " + d[0][0], "s1_" + d[0][1]), d[1]), + (("Main S2 " + d[0][0], "s2_" + d[0][1]), d[1]), + ] + if d[0][1] in self.alternative_peak_add_fields: + dtype += [ + (("Alternative S1 " + d[0][0], "alt_s1_" + d[0][1]), d[1]), + (("Alternative S2 " + d[0][0], "alt_s2_" + d[0][1]), d[1]), + ] + dtype += [ + ( + ("Type of event indicating whether the isolated S1 becomes main S1", "event_type"), + np.int8, + ), + (("Event number in this dataset", "event_number"), np.int64), + ] + return dtype + + def infer_dtype(self): + return strax.merged_dtype( + [ + self.deps["event_info_paired"].dtype_for("event_info_paired"), + np.dtype(self.event_fields), + ] + ) + + def compute(self, events_paired, peaks_paired): + result = np.zeros(len(events_paired), dtype=self.dtype) + + # assign the additional fields + EventBasics.set_nan_defaults(result) + + # assign the features already in EventInfo + for q in self.deps["event_info_paired"].dtype_for("event_info_paired").names: + result[q] = events_paired[q] + + # store AC-type + split_peaks = strax.split_by_containment(peaks_paired, events_paired) + for i, (event, sp) in enumerate(zip(events_paired, split_peaks)): + if np.unique(sp["event_number"]).size != 1: + raise ValueError( + f"Event {i} has multiple event numbers: " + f"{np.unique(sp['event_number'])}. " + "Maybe the paired events overlap." + ) + result["event_number"][i] = sp["event_number"][0] + for idx, main_peak in zip([event["s1_index"], event["s2_index"]], ["s1_", "s2_"]): + if idx >= 0: + for n in self.peak_fields: + result[main_peak + n][i] = sp[n][idx] + for idx, main_peak in zip( + [event["alt_s1_index"], event["alt_s2_index"]], ["alt_s1_", "alt_s2_"] + ): + if idx >= 0: + for n in self.peak_fields: + if n in self.alternative_peak_add_fields: + result[main_peak + n][i] = sp[n][idx] + # if the AC event have S2 + if event["s2_index"] != -1: + if sp["origin_s1_index"][0] == -1: + # if isolated S2 is pure-isolated S2(w/o main S1) + if event["s1_index"] != -1: + # if successfully paired, considered as AC + result["event_type"][i] = 1 + else: + # if unsuccessfully paired, not considered as AC + result["event_type"][i] = 2 + else: + # if isolated S2 is ext-isolated S2(w/ main S1) + if event["s1_index"] != -1 and sp["origin_group_type"][event["s1_index"]] == 1: + # if successfully paired and main S1 is from isolated S1 but not isolated S2 + # considered as AC + result["event_type"][i] = 3 + else: + # otherwise, not considered as AC + result["event_type"][i] = 4 + else: + result["event_type"][i] = 5 + + result["time"] = events_paired["time"] + result["endtime"] = events_paired["endtime"] + + return result diff --git a/axidence/plugins/pairing/peaks_paired.py b/axidence/plugins/pairing/peaks_paired.py index cb5b5cc..921b695 100644 --- a/axidence/plugins/pairing/peaks_paired.py +++ b/axidence/plugins/pairing/peaks_paired.py @@ -2,10 +2,13 @@ from immutabledict import immutabledict import numpy as np import strax +from strax import Plugin import straxen from straxen import units +from straxen import PeakProximity -from ...dtypes import positioned_peak_dtype +from ...utils import copy_dtype +from ...dtypes import peak_positions_dtype from ...plugin import ExhaustPlugin, RunMetaPlugin @@ -22,18 +25,6 @@ class PeaksPaired(ExhaustPlugin, RunMetaPlugin): help="Seed for pairing", ) - isolated_peaks_fields = straxen.URLConfig( - default=np.dtype(positioned_peak_dtype()).names, - type=(list, tuple), - help="Needed fields in isolated peaks", - ) - - isolated_events_fields = straxen.URLConfig( - default=[], - type=(list, tuple), - help="Needed fields in isolated events", - ) - real_run_start = straxen.URLConfig( default=None, type=(int, None), @@ -118,7 +109,8 @@ def infer_dtype(self): dtype = strax.unpack_dtype(self.deps["isolated_s1"].dtype_for("isolated_s1")) # TODO: reconsider about how to store run_id after implementing super runs peaks_dtype = dtype + [ - (("Event number in this dataset", "event_number"), np.int32), + # since event_number is int64 in event_basics + (("Event number in this dataset", "event_number"), np.int64), # (("Original run id", "origin_run_id"), np.int32), (("Original isolated S1/S2 group", "origin_group_number"), np.int32), (("Original time of peaks", "origin_time"), np.int64), @@ -130,7 +122,7 @@ def infer_dtype(self): (("Original s2_index in isolated S2", "origin_s2_index"), np.int32), ] truth_dtype = [ - (("Event number in this dataset", "event_number"), np.int32), + (("Event number in this dataset", "event_number"), np.int64), # (("Original run id of isolated S1", "s1_run_id"), np.int32), # (("Original run id of isolated S2", "s2_run_id"), np.int32), ( @@ -414,3 +406,79 @@ def compute(self, isolated_s1, isolated_s2, events_salted): assert result["peaks_paired"].nbytes < self.chunk_target_size_mb * 1e6 return result + + +class PeakProximityPaired(PeakProximity): + __version__ = "0.0.0" + depends_on = "peaks_paired" + provides = "peak_proximity_paired" + data_kind = "peaks_paired" + save_when = strax.SaveWhen.EXPLICIT + + use_origin_n_competing = straxen.URLConfig( + default=False, + type=bool, + help="Whether use original n_competing", + ) + + def infer_dtype(self): + dtype_reference = strax.unpack_dtype(self.deps["peaks_paired"].dtype_for("peaks_paired")) + required_names = ["time", "endtime", "n_competing"] + dtype = copy_dtype(dtype_reference, required_names) + return dtype + + def compute(self, peaks_paired): + if self.use_origin_n_competing: + warnings.warn("Using original n_competing for paired peaks") + n_competing = peaks_paired["origin_n_competing"].copy() + else: + # add `n_competing` to isolated S1 and isolated S2 because injection of peaks + # will not consider the competing window because + # that window is much larger than the max drift time + n_competing = np.zeros(len(peaks_paired), self.dtype["n_competing"]) + peaks_event_number_sorted = np.sort(peaks_paired, order=("event_number", "time")) + event_number, event_number_index, event_number_count = np.unique( + peaks_event_number_sorted["event_number"], + return_index=True, + return_counts=True, + ) + event_number_index = np.append(event_number_index, len(peaks_event_number_sorted)) + for i in range(len(event_number)): + areas = peaks_event_number_sorted["area"][ + event_number_index[i] : event_number_index[i + 1] + ].copy() + types = peaks_event_number_sorted["origin_group_type"][ + event_number_index[i] : event_number_index[i + 1] + ].copy() + n_competing_s = peaks_event_number_sorted["origin_n_competing"][ + event_number_index[i] : event_number_index[i + 1] + ].copy() + threshold = areas * self.min_area_fraction + for j in range(event_number_count[i]): + if types[j] == 1: + n_competing_s[j] += np.sum(areas[types == 2] > threshold[j]) + elif types[j] == 2: + n_competing_s[j] += np.sum(areas[types == 1] > threshold[j]) + n_competing[event_number_index[i] : event_number_index[i + 1]] = n_competing_s + + return dict( + time=peaks_paired["time"], + endtime=strax.endtime(peaks_paired), + n_competing=n_competing[peaks_event_number_sorted["time"].argsort()], + ) + + +class PeakPositionsPaired(Plugin): + __version__ = "0.0.0" + depends_on = "peaks_paired" + provides = "peak_positions_paired" + save_when = strax.SaveWhen.EXPLICIT + + def infer_dtype(self): + return peak_positions_dtype() + + def compute(self, peaks_paired): + result = np.zeros(len(peaks_paired), dtype=self.dtype) + for q in self.dtype.names: + result[q] = peaks_paired[q] + return result diff --git a/axidence/plugins/salting/peaks_salted.py b/axidence/plugins/salting/peaks_salted.py index 9d66c90..afd9d13 100644 --- a/axidence/plugins/salting/peaks_salted.py +++ b/axidence/plugins/salting/peaks_salted.py @@ -35,8 +35,8 @@ def infer_dtype(self): dtype = copy_dtype(dtype_reference, required_names) # since event_number is int64 in event_basics dtype += [ - ("x", np.float32, "Reconstructed S2 X position (cm), uncorrected"), - ("y", np.float32, "Reconstructed S2 Y position (cm), uncorrected"), + (("Reconstructed S2 X position (cm), uncorrected", "x"), np.float32), + (("Reconstructed S2 Y position (cm), uncorrected", "y"), np.float32), (("Salting number of peaks", "salt_number"), np.int64), ] return dtype diff --git a/axidence/utils.py b/axidence/utils.py index 21a01ba..670311e 100644 --- a/axidence/utils.py +++ b/axidence/utils.py @@ -11,6 +11,15 @@ def copy_dtype(dtype_reference, required_names): Returns: list: copied dtype """ + if not isinstance(required_names, (set, list, tuple)): + raise ValueError( + "required_names must be set, list or tuple, " + f"not {type(required_names)}, got {required_names}!" + ) + if not isinstance(dtype_reference, list): + raise ValueError( + f"dtype_reference must be list, not {type(dtype_reference)}, got {dtype_reference}!" + ) dtype = [] for n in required_names: for x in dtype_reference: @@ -20,7 +29,7 @@ def copy_dtype(dtype_reference, required_names): found = True break if not found: - raise ValueError(f"Could not find {n} in dtype_reference!") + raise ValueError(f"Could not find {n} in {dtype_reference}!") return dtype