diff --git a/axidence/context.py b/axidence/context.py index f4a80ae..135d5f2 100644 --- a/axidence/context.py +++ b/axidence/context.py @@ -8,6 +8,7 @@ PeakProximitySalted, PeakShadowSalted, PeakAmbienceSalted, + PeakNearestTriggeringSalted, PeakSEDensitySalted, ) from axidence import ( @@ -15,6 +16,7 @@ EventBasicsSalted, EventShadowSalted, EventAmbienceSalted, + EventNearestTriggeringSalted, EventSEDensitySalted, EventBuilding, ) @@ -150,11 +152,13 @@ def _salt_to_context(self): PeakProximitySalted, PeakShadowSalted, PeakAmbienceSalted, + PeakNearestTriggeringSalted, PeakSEDensitySalted, EventsSalted, EventBasicsSalted, EventShadowSalted, EventAmbienceSalted, + EventNearestTriggeringSalted, EventSEDensitySalted, ) ) diff --git a/axidence/plugins/salting/event_fields.py b/axidence/plugins/salting/event_fields.py index d7473dd..15fc857 100644 --- a/axidence/plugins/salting/event_fields.py +++ b/axidence/plugins/salting/event_fields.py @@ -1,11 +1,19 @@ from typing import Tuple import strax -from straxen import EventShadow, EventAmbience, EventSEDensity +from strax import Plugin +from straxen import EventShadow, EventAmbience, EventNearestTriggering, EventSEDensity from ...utils import merge_salted_real -class EventShadowSalted(EventShadow): +class EventFieldsSalted(Plugin): + + def compute(self, events_salted, peaks_salted, peaks): + _peaks = merge_salted_real(peaks_salted, peaks, peaks.dtype) + return super().compute(events_salted, _peaks) + + +class EventShadowSalted(EventFieldsSalted, EventShadow): __version__ = "0.0.0" depends_on = ( "event_basics_salted", @@ -18,12 +26,8 @@ class EventShadowSalted(EventShadow): data_kind = "events_salted" save_when = strax.SaveWhen.EXPLICIT - def compute(self, events_salted, peaks_salted, peaks): - _peaks = merge_salted_real(peaks_salted, peaks, peaks.dtype) - return super().compute(events_salted, _peaks) - -class EventAmbienceSalted(EventAmbience): +class EventAmbienceSalted(EventFieldsSalted, EventAmbience): __version__ = "0.0.0" depends_on = ( "event_basics_salted", @@ -36,12 +40,22 @@ class EventAmbienceSalted(EventAmbience): data_kind = "events_salted" save_when = strax.SaveWhen.EXPLICIT - def compute(self, events_salted, peaks_salted, peaks): - _peaks = merge_salted_real(peaks_salted, peaks, peaks.dtype) - return super().compute(events_salted, _peaks) + +class EventNearestTriggeringSalted(EventFieldsSalted, EventNearestTriggering): + __version__ = "0.0.0" + depends_on = ( + "event_basics_salted", + "peaks_salted", + "peak_nearest_triggering_salted", + "peak_basics", + "peak_nearest_triggering", + ) + provides = "event_nearest_triggering_salted" + data_kind = "events_salted" + save_when = strax.SaveWhen.EXPLICIT -class EventSEDensitySalted(EventSEDensity): +class EventSEDensitySalted(EventFieldsSalted, EventSEDensity): __version__ = "0.0.0" depends_on: Tuple[str, ...] = ( "event_basics_salted", @@ -53,7 +67,3 @@ class EventSEDensitySalted(EventSEDensity): provides = "event_se_density_salted" data_kind = "events_salted" save_when = strax.SaveWhen.EXPLICIT - - def compute(self, events_salted, peaks_salted, peaks): - _peaks = merge_salted_real(peaks_salted, peaks, peaks.dtype) - return super().compute(events_salted, _peaks) diff --git a/axidence/plugins/salting/peak_correlation.py b/axidence/plugins/salting/peak_correlation.py index 1e93df8..1215f65 100644 --- a/axidence/plugins/salting/peak_correlation.py +++ b/axidence/plugins/salting/peak_correlation.py @@ -1,7 +1,7 @@ import numba import numpy as np import strax -from straxen import PeakProximity, PeakShadow, PeakAmbience, PeakSEDensity +from straxen import PeakProximity, PeakShadow, PeakAmbience, PeakNearestTriggering, PeakSEDensity from ...utils import copy_dtype @@ -98,6 +98,26 @@ def compute(self, peaks_salted, lone_hits, peaks): return result +class PeakNearestTriggeringSalted(PeakNearestTriggering): + __version__ = "0.0.0" + depends_on = ("peaks_salted", "peak_proximity_salted", "peak_basics", "peak_proximity") + provides = "peak_nearest_triggering_salted" + data_kind = "peaks_salted" + save_when = strax.SaveWhen.EXPLICIT + + def infer_dtype(self): + dtype = super().infer_dtype() + dtype += [ + (("Salting number of peaks", "salt_number"), np.int64), + ] + return dtype + + def compute(self, peaks_salted, peaks): + result = self.compute_triggering(peaks, peaks_salted) + result["salt_number"] = peaks_salted["salt_number"] + return result + + class PeakSEDensitySalted(PeakSEDensity): __version__ = "0.0.0" depends_on = ("peaks_salted", "peak_basics", "peak_positions")