Implement only-S1/2 salting #57

Merged: 1 commit, Apr 30, 2024
65 changes: 35 additions & 30 deletions axidence/plugins/pairing/peaks_paired.py
@@ -1,7 +1,6 @@
import warnings
from immutabledict import immutabledict
import numpy as np
from scipy.stats import poisson
import strax
from strax import Plugin, DownChunkingPlugin
import straxen
@@ -111,6 +110,12 @@ class PeaksPaired(ExhaustPlugin, DownChunkingPlugin):
help="Whether shift drift time when performing shadow matching",
)

only_salt_s1 = straxen.URLConfig(
default=False,
type=bool,
help="Whether only salt S1",
)

apply_shadow_reweight = straxen.URLConfig(
default=True,
type=bool,
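For context, a minimal sketch of how the new option might be enabled on an analysis context (st.set_config and straxen.contexts.xenonnt_online are standard strax/straxen API; the exact context used here is only an example):

import straxen

# hypothetical context setup; only_salt_s1 is the option added in this PR
st = straxen.contexts.xenonnt_online()
st.set_config({"only_salt_s1": True})
# with only S1 salted there is no salted S2 to match shadows against,
# so shadow_reference_selection raises the ValueError added below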
@@ -240,6 +245,10 @@ def shadow_reference_selection(self, events_salted, s2):
"""Select the reference events for shadow matching, also return
weights."""
reference = events_salted[events_salted["cut_main_s2_trigger_salted"]]

if self.only_salt_s1:
raise ValueError("Cannot only salt S1 when performing shadow matching!")

if self.apply_shadow_reweight:
sampler = SAMPLERS[self.s2_distribution](
self.s2_area_range, self.shadow_reweight_n_bins
@@ -248,6 +257,7 @@ def shadow_reference_selection(self, events_salted, s2):
else:
weight = np.ones(len(reference))
weight /= weight.sum()

if np.any(np.isnan(weight)):
raise ValueError("Some weights are NaN!")
dtype = np.dtype(
@@ -351,14 +361,12 @@ def shadow_matching(
],
n_partitions=[n_shadow_bins, n_shadow_bins],
)
if np.any(
ge.apply_irregular_binning(
data_sample=sampled_correlation,
bin_edges=bin_edges,
data_sample_weights=shadow_reference["weight"],
)
<= 0
):
ns = ge.apply_irregular_binning(
data_sample=sampled_correlation,
bin_edges=bin_edges,
data_sample_weights=shadow_reference["weight"],
)
if np.any(ns <= 0):
raise ValueError(
f"Weird! Find empty bin when the bin number is {n_shadow_bins}!"
)
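The refactor above names the binned counts (ns) before testing them, instead of nesting the binning call inside np.any. As a rough analogy with plain numpy in place of GOFevaluation's irregular binning (data and bin edges made up for illustration), the guard amounts to:

import numpy as np

rng = np.random.default_rng(0)
x, y = rng.normal(size=(2, 1000))
edges = np.linspace(-3, 3, 5)

# weighted 2D binning; any empty bin would break the conditional rates
counts, _, _ = np.histogram2d(x, y, bins=[edges, edges], weights=np.ones(1000))
if np.any(counts <= 0):
    raise ValueError(f"Weird! Found an empty bin when the bin number is {len(edges) - 1}!")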
@@ -414,15 +422,10 @@ def shadow_matching(
_paring_rate_full[i] = ac_rate_conditional.sum()
if not onlyrate:
# expectation of AC in each bin in this run
mu_shadow = ac_rate_conditional * run_time * paring_rate_bootstrap_factor
count_pairing = np.zeros_like(mu_shadow, dtype=int)
for ii in range(mu_shadow.shape[0]):
for jj in range(mu_shadow.shape[1]):
count_pairing[ii, jj] = poisson.rvs(mu=mu_shadow[ii, jj])
count_pairing = count_pairing.flatten()
# count_pairing = poisson.rvs(mu=mu_shadow).flatten()
lam_shadow = ac_rate_conditional * run_time * paring_rate_bootstrap_factor
count_pairing = rng.poisson(lam=lam_shadow).flatten()
if count_pairing.max() == 0:
count_pairing[mu_shadow.argmax()] = 1
count_pairing[lam_shadow.argmax()] = 1
s2_digit = PeaksPaired.digitize2d(data_sample, bin_edges, n_shadow_bins)
_s2_group_index = np.arange(len(s2))
s2_group_index_list = [
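The replaced block drew one scipy variate per bin inside a Python loop; numpy's Generator draws the whole array in one call and uses the plugin's seeded rng, which keeps the sampling reproducible. A self-contained comparison (toy rates only):

import numpy as np
from scipy.stats import poisson

lam = np.array([[0.5, 2.0], [1.5, 0.1]])  # expected AC counts per bin

# old approach: element-wise scipy draws in a Python loop
counts_loop = np.array([[poisson.rvs(mu=m) for m in row] for row in lam]).flatten()

# new approach: one vectorized draw from a seeded Generator
rng = np.random.default_rng(42)
counts_vec = rng.poisson(lam=lam).flatten()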
@@ -560,6 +563,12 @@ def build_arrays(
peaks_arrays[peaks_count : peaks_count + len(_array)] = _array
peaks_count += len(_array)

if peaks_count != len(peaks_arrays):
raise ValueError(
"Mismatch in total number of peaks in the chunk, "
f"expected {peaks_count}, got {len(peaks_arrays)}!"
)

# assign truth
truth_arrays = np.zeros(len(n_peaks), dtype=self.dtype["truth_paired"])
truth_arrays["time"] = peaks_arrays["time"][
@@ -584,19 +593,6 @@
):
raise ValueError("Some paired events overlap!")

peaks_arrays = np.sort(peaks_arrays, order=("time", "event_number"))

if peaks_count != len(peaks_arrays):
raise ValueError(
"Mismatch in total number of peaks in the chunk, "
f"expected {peaks_count}, got {len(peaks_arrays)}!"
)

# check overlap of peaks
n_overlap = (peaks_arrays["time"][1:] - peaks_arrays["endtime"][:-1] < 0).sum()
if n_overlap:
warnings.warn(f"{n_overlap} peaks overlap")

return peaks_arrays, truth_arrays

def compute(self, isolated_s1, isolated_s2, events_salted, start, end):
@@ -712,6 +708,7 @@ def compute(self, isolated_s1, isolated_s2, events_salted, start, end):
main_isolated_s2,
s2_group_index,
)

peaks_arrays["event_number"] += left_i
truth_arrays["event_number"] += left_i
peaks_arrays["normalization"] = np.repeat(
@@ -720,6 +717,14 @@ def compute(self, isolated_s1, isolated_s2, events_salted, start, end):
)
truth_arrays["normalization"] = normalization[left_i:right_i]

# be careful: any per-row field assignment must happen before this sort
peaks_arrays = np.sort(peaks_arrays, order=("time", "event_number"))

# check overlap of peaks
n_overlap = (peaks_arrays["time"][1:] - peaks_arrays["endtime"][:-1] < 0).sum()
if n_overlap:
warnings.warn(f"{n_overlap} peaks overlap")

result = dict()
result["peaks_paired"] = self.chunk(
start=_start, end=_end, data=peaks_arrays, data_type="peaks_paired"
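The sort and overlap check moved from build_arrays into compute (see the relocated block above) because np.sort on a structured array reorders whole rows: any positional field assignment, such as the repeated normalization, must land before the reorder. A toy illustration of the pattern, with a made-up dtype:

import numpy as np

dtype = [("time", np.int64), ("endtime", np.int64), ("event_number", np.int32)]
peaks = np.array([(10, 12, 1), (0, 5, 0), (4, 6, 0)], dtype=dtype)

# assign per-row fields first, then sort; sorting reorders whole rows
peaks = np.sort(peaks, order=("time", "event_number"))

# overlap check: a peak starting before the previous one ends
n_overlap = (peaks["time"][1:] - peaks["endtime"][:-1] < 0).sum()  # 1 here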
44 changes: 34 additions & 10 deletions axidence/plugins/salting/event_building.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import Tuple
import numpy as np
import strax
@@ -22,6 +23,18 @@ class EventsSalted(Events, ExhaustPlugin):
help="How many max drift time will the event builder extend",
)

only_salt_s1 = straxen.URLConfig(
default=False,
type=bool,
help="Whether only salt S1",
)

only_salt_s2 = straxen.URLConfig(
default=False,
type=bool,
help="Whether only salt S2",
)

def __init__(self):
super().__init__()
self.dtype = super().dtype + [
@@ -59,20 +72,30 @@ def compute(self, peaks_salted, peaks, start, end):

_peaks = merge_salted_real(peaks_salted, peaks, self._peaks_dtype)

# use S2s as anchors
anchor_peaks = peaks_salted[1::2]
if self.only_salt_s1 or self.only_salt_s2:
anchor_peaks = peaks_salted
else:
# use S2s as anchors by default
anchor_peaks = peaks_salted[1::2]

# check if the salting anchor can trigger
if self.only_salt_s1:
is_triggering = np.full(len(anchor_peaks), False)
else:
is_triggering = self._is_triggering(anchor_peaks)

if np.unique(anchor_peaks["type"]).size != 1:
raise ValueError("Expected only one type of anchor peaks!")

# initialize the final result
n_events = len(peaks_salted) // 2
if self.only_salt_s1 or self.only_salt_s2:
n_events = len(peaks_salted)
else:
n_events = len(peaks_salted) // 2
if np.unique(peaks_salted["salt_number"]).size != n_events:
raise ValueError("Expected salt_number to be half of the input peaks number!")
result = np.empty(n_events, self.dtype)

# check if the salting anchor can trigger
is_triggering = self._is_triggering(anchor_peaks)

# prepare for an empty event
empty_events = np.empty(len(anchor_peaks), dtype=self.dtype)
empty_events["time"] = anchor_peaks["time"]
@@ -102,8 +125,8 @@ def compute(self, peaks_salted, peaks, start, end):

# assign the most important parameters
result["is_triggering"] = is_triggering
result["salt_number"] = peaks_salted["salt_number"][::2]
result["event_number"] = peaks_salted["salt_number"][::2]
result["salt_number"] = np.unique(peaks_salted["salt_number"])
result["event_number"] = result["salt_number"]

if np.any(np.diff(result["time"]) < 0):
raise ValueError("Expected time to be sorted!")
@@ -170,6 +193,7 @@ def compute(self, events_salted, peaks_salted, peaks):
self.fill_events(result, events_salted, split_peaks)
result["is_triggering"] = events_salted["is_triggering"]

if np.all(result["s1_salt_number"] < 0) or np.all(result["s2_salt_number"] < 0):
raise ValueError("Found zero triggered salted peaks!")
for i in [1, 2]:
if np.all(result[f"s{i}_salt_number"] < 0):
warnings.warn(f"Found zero triggered salted S{i}!")
return result
3 changes: 3 additions & 0 deletions axidence/plugins/salting/events_salting.py
@@ -190,6 +190,9 @@ def sampling(self, start, end):
self.events_salting["s1_area"] = np.clip(self.events_salting["s1_area"], *s1_area_range)
self.events_salting["s2_area"] = np.clip(self.events_salting["s2_area"], *s2_area_range)

if np.any(np.diff(self.events_salting["time"]) <= 0):
raise ValueError("The time is not strictly increasing!")

self.set_chunk_splitting()

def compute(self, run_meta, start, end):
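The new guard in sampling also catches duplicated timestamps, since np.diff of equal times is zero and the comparison uses <= rather than <. A minimal sketch:

import numpy as np

times = np.array([10, 20, 20, 30])
if np.any(np.diff(times) <= 0):  # equal neighbors violate strict ordering too
    raise ValueError("The time is not strictly increasing!")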
11 changes: 11 additions & 0 deletions axidence/plugins/salting/peaks_salted.py
@@ -42,6 +42,11 @@ def infer_dtype(self):
]
return dtype

def setup(self):
super().setup()
if self.only_salt_s1 and self.only_salt_s2:
raise ValueError("Cannot only salt both S1 and S2.")

def compute(self, events_salting):
"""Copy features of events_salting into peaks_salted."""
peaks_salted = np.empty(len(events_salting) * 2, dtype=self.dtype)
@@ -76,4 +81,10 @@ def compute(self, events_salting):
]
).T.flatten()
peaks_salted["salt_number"] = np.repeat(events_salting["salt_number"], 2)

# keep only the requested peak type
if self.only_salt_s1:
peaks_salted = peaks_salted[peaks_salted["type"] == 1]
if self.only_salt_s2:
peaks_salted = peaks_salted[peaks_salted["type"] == 2]
return peaks_salted
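Condensed, the new tail of PeaksSalted.compute interleaves the two peak types and then filters; a sketch with a toy dtype standing in for the plugin's full one:

import numpy as np

dtype = [("type", np.int8), ("salt_number", np.int64)]
peaks_salted = np.empty(6, dtype=dtype)
peaks_salted["type"] = np.array([[1, 2]] * 3).flatten()   # S1/S2 interleaved
peaks_salted["salt_number"] = np.repeat(np.arange(3), 2)  # one number per event

only_salt_s1, only_salt_s2 = True, False
if only_salt_s1:
    peaks_salted = peaks_salted[peaks_salted["type"] == 1]
if only_salt_s2:
    peaks_salted = peaks_salted[peaks_salted["type"] == 2]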