Separate 2 and 3 hits pairing and add normalization factor #54

Merged · 1 commit · Apr 29, 2024
2 changes: 2 additions & 0 deletions axidence/plugins/pairing/events_paired.py
@@ -116,6 +116,7 @@ def event_fields(self):
np.int8,
),
(("Event number in this dataset", "event_number"), np.int64),
(("Normalization of number of paired events", "normalization"), np.float32),
]
return dtype

@@ -147,6 +148,7 @@ def compute(self, events_paired, peaks_paired):
"Maybe the paired events overlap."
)
result["event_number"][i] = sp["event_number"][0]
result["normalization"][i] = sp["normalization"][0]
for idx, main_peak in zip([event["s1_index"], event["s2_index"]], ["s1_", "s2_"]):
if idx >= 0:
for n in self.peak_fields:
190 changes: 124 additions & 66 deletions axidence/plugins/pairing/peaks_paired.py
@@ -68,8 +68,15 @@ class PeaksPaired(ExhaustPlugin, DownChunkingPlugin):
# multiple factor is 100, then we will make 100 AC events
paring_rate_bootstrap_factor = straxen.URLConfig(
default=1e2,
type=(int, float),
help="Bootstrap factor for AC rate",
type=(int, float, list, tuple),
help=(
"Bootstrap factor for AC rate; "
"if a list or tuple, the entries are the factors for 2-hit and 3+ hit S1"
),
)

s1_min_coincidence = straxen.URLConfig(
default=2, type=int, help="Minimum tight coincidence necessary to make an S1"
)

apply_shadow_matching = straxen.URLConfig(
@@ -115,6 +122,7 @@ def infer_dtype(self):
(("Original type of group", "origin_group_type"), np.int8),
(("Original s1_index in isolated S1", "origin_s1_index"), np.int32),
(("Original s2_index in isolated S2", "origin_s2_index"), np.int32),
(("Normalization of number of paired events", "normalization"), np.float32),
]
truth_dtype = [
(("Event number in this dataset", "event_number"), np.int64),
@@ -126,6 +134,7 @@
),
(("Original isolated S1 group", "s1_group_number"), np.int32),
(("Original isolated S2 group", "s2_group_number"), np.int32),
(("Normalization of number of paired events", "normalization"), np.float32),
] + strax.time_fields
return dict(peaks_paired=peaks_dtype, truth_paired=truth_dtype)

@@ -138,6 +147,17 @@ def setup(self, prepare=True):
self.rng = np.random.default_rng(seed=self.pairing_seed)
self.time_left = self.paring_time_interval // 2
self.time_right = self.paring_time_interval - self.time_left
if self.s1_min_coincidence != 2:
raise NotImplementedError("Only support s1_min_coincidence = 2 now!")
if isinstance(self.paring_rate_bootstrap_factor, (list, tuple)):
if len(self.paring_rate_bootstrap_factor) != 2:
raise ValueError(
"The length of paring_rate_bootstrap_factor should be 2 "
"if provided list or tuple!"
)
self.bootstrap_factor = list(self.paring_rate_bootstrap_factor)
else:
self.bootstrap_factor = [self.paring_rate_bootstrap_factor] * 2
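
The setup logic above normalizes the config into a two-element list: a scalar factor applies to both hit categories, while a 2-element list/tuple gives separate factors for 2-hit and 3+ hit S1s. A minimal standalone sketch of that handling (resolve_bootstrap_factor is a hypothetical helper, not part of the plugin):

def resolve_bootstrap_factor(factor):
    # Scalar -> same factor for both 2-hit and 3+ hit S1 pairing;
    # list/tuple -> separate factors, length must be exactly 2.
    if isinstance(factor, (list, tuple)):
        if len(factor) != 2:
            raise ValueError("paring_rate_bootstrap_factor must have length 2")
        return list(factor)
    return [factor] * 2

assert resolve_bootstrap_factor(1e2) == [100.0, 100.0]
assert resolve_bootstrap_factor((50, 200)) == [50, 200]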

@staticmethod
def preprocess_isolated_s2(s2):
@@ -171,14 +191,18 @@ def simple_pairing(
s1_rate * s2_rate * (max_drift_time - min_drift_time) / units.s / paring_rate_correction
)
n_events = round(paring_rate_full * run_time * paring_rate_bootstrap_factor)
s1_group_number = rng.choice(len(s1), size=n_events, replace=True)
s2_group_number = rng.choice(len(s2), size=n_events, replace=True)
s1_group_index = rng.choice(len(s1), size=n_events, replace=True)
s2_group_index = rng.choice(len(s2), size=n_events, replace=True)
if fixed_drift_time is None:
drift_time = rng.uniform(min_drift_time, max_drift_time, size=n_events)
else:
warnings.warn(f"Using fixed drift time {fixed_drift_time}ns")
drift_time = np.full(n_events, fixed_drift_time)
return paring_rate_full, s1_group_number, s2_group_number, drift_time
return paring_rate_full, (
s1["group_number"][s1_group_index],
s2["group_number"][s2_group_index],
drift_time,
)
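
For orientation on the rate arithmetic above: the full AC pairing rate is the product of the isolated S1 and S2 rates times the allowed drift-time window, and the bootstrap factor only inflates the number of simulated events, not the physical rate. A numeric sketch with purely illustrative values (units_s stands in for the units.s constant, ns per second):

s1_rate, s2_rate = 1.0, 5e-3                 # Hz, illustrative
min_drift_time, max_drift_time = 0.0, 2.3e6  # ns, illustrative
paring_rate_correction = 1.0                 # assumed no correction
run_time, bootstrap_factor = 3600.0, 1e2     # s, illustrative
units_s = 1e9                                # ns per second

paring_rate_full = (
    s1_rate * s2_rate * (max_drift_time - min_drift_time) / units_s / paring_rate_correction
)
n_events = round(paring_rate_full * run_time * bootstrap_factor)
# ~1.15e-5 Hz physical rate, but ~4 simulated events thanks to the factor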

def shadow_reference_selection(self, peaks_salted):
return peaks_salted[peaks_salted["type"] == 2]
@@ -210,6 +234,7 @@ def preprocess_shadow(data, shadow_deltatime_exponent, delta_t=0, prefix=""):
x = np.log10(pre_s2_area * dt_s2_time_shadow**shadow_deltatime_exponent)
y = np.sqrt(np.log10(pre_s2_area) ** 2 + np.log10(dt_s2_time_shadow) ** 2)
sample = np.stack([x, y]).T
# sample = np.stack([np.log10(dt_s2_time_shadow), np.log10(pre_s2_area)]).T
return sample
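
The (x, y) space built in preprocess_shadow remaps (log10 shadow S2, log10 Δt) into two roughly orthogonal coordinates: x folds the time dependence into the area via the shadow exponent, and y is the quadrature sum of the two logs. A small self-contained sketch of the same transform on toy arrays (toy values, not detector data; the exponent value is an assumption):

import numpy as np

pre_s2_area = np.array([1e3, 1e4, 1e5])        # toy areas
dt_s2_time_shadow = np.array([1e4, 1e5, 1e6])  # toy time deltas, ns
shadow_deltatime_exponent = -1.0               # assumed exponent value

x = np.log10(pre_s2_area * dt_s2_time_shadow**shadow_deltatime_exponent)
y = np.sqrt(np.log10(pre_s2_area) ** 2 + np.log10(dt_s2_time_shadow) ** 2)
sample = np.stack([x, y]).T  # shape (3, 2): one (x, y) point per peak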

@staticmethod
@@ -227,6 +252,7 @@ def shadow_matching(
paring_rate_correction,
paring_rate_bootstrap_factor,
rng,
preprocess_shadow,
onlyrate=False,
):
# perform Shadow matching technique
@@ -235,10 +261,8 @@
# 2D equal binning
# prepare the 2D space, x is log(S2/dt), y is (log(S2)**2+log(dt)**2)**0.5
# because these 2 dimension is orthogonal
sampled_correlation = PeaksPaired.preprocess_shadow(
shadow_reference, shadow_deltatime_exponent
)
s1_sample = PeaksPaired.preprocess_shadow(s1, shadow_deltatime_exponent)
sampled_correlation = preprocess_shadow(shadow_reference, shadow_deltatime_exponent)
s1_sample = preprocess_shadow(s1, shadow_deltatime_exponent)

# use (x, y) distribution of isolated S1 as reference
# because it is more intense when shadow(S2/dt) is large
@@ -285,16 +309,16 @@
if not onlyrate:
# get indices of the 2D bins
s1_digit = PeaksPaired.digitize2d(s1_sample, bin_edges, n_shadow_bins)
_s1_group_number = np.arange(len(s1))
s1_group_number_list = [
_s1_group_number[s1_digit == xd].tolist()
_s1_group_index = np.arange(len(s1))
s1_group_index_list = [
_s1_group_index[s1_digit == xd].tolist()
for xd in range(n_shadow_bins * n_shadow_bins)
]

drift_time_bins = np.linspace(min_drift_time, max_drift_time, n_drift_time_bins + 1)
drift_time_bin_center = (drift_time_bins[:-1] + drift_time_bins[1:]) / 2

group_number_list = []
group_index_list = []
_paring_rate_full = np.zeros(len(drift_time_bin_center))
for i in range(len(drift_time_bin_center)):
if shift_dt_shadow_matching:
@@ -303,9 +327,7 @@
delta_t = drift_time_bin_center[i]
else:
delta_t = 0
data_sample = PeaksPaired.preprocess_shadow(
s2, shadow_deltatime_exponent, delta_t=delta_t
)
data_sample = preprocess_shadow(s2, shadow_deltatime_exponent, delta_t=delta_t)
ge.check_sample_sanity(data_sample)
# apply binning to (x, y)
s2_shadow_count = ge.apply_irregular_binning(
@@ -331,43 +353,48 @@
if count_pairing.max() == 0:
count_pairing[mu_shadow.argmax()] = 1
s2_digit = PeaksPaired.digitize2d(data_sample, bin_edges, n_shadow_bins)
_s2_group_number = np.arange(len(s2))
s2_group_number_list = [
_s2_group_number[s2_digit == xd].tolist()
_s2_group_index = np.arange(len(s2))
s2_group_index_list = [
_s2_group_index[s2_digit == xd].tolist()
for xd in range(n_shadow_bins * n_shadow_bins)
]
# random sample isolated S1 and S2's group number
_s1_group_number = np.hstack(
_s1_group_index = np.hstack(
[
rng.choice(
s1_group_number_list[xd],
s1_group_index_list[xd],
size=count_pairing[xd],
)
for xd in range(n_shadow_bins * n_shadow_bins)
]
)
_s2_group_number = np.hstack(
_s2_group_index = np.hstack(
[
rng.choice(
s2_group_number_list[xd],
s2_group_index_list[xd],
size=count_pairing[xd],
)
for xd in range(n_shadow_bins * n_shadow_bins)
]
)
# sample drift time in this bin
_drift_time = rng.choice(
round(drift_time_bins[i + 1] - drift_time_bins[i]),
_drift_time = rng.uniform(
drift_time_bins[i],
drift_time_bins[i + 1],
size=count_pairing.sum(),
) + round(drift_time_bins[i])
group_number_list.append([_s1_group_number, _s2_group_number, _drift_time])
)
group_index_list.append([_s1_group_index, _s2_group_index, _drift_time])
paring_rate_full = _paring_rate_full.sum()
if not onlyrate:
s1_group_number = np.hstack([group[0] for group in group_number_list]).astype(int)
s2_group_number = np.hstack([group[1] for group in group_number_list]).astype(int)
drift_time = np.hstack([group[2] for group in group_number_list]).astype(int)
assert len(s1_group_number) == len(s2_group_number)
return paring_rate_full, s1_group_number, s2_group_number, drift_time
s1_group_index = np.hstack([group[0] for group in group_index_list]).astype(int)
s2_group_index = np.hstack([group[1] for group in group_index_list]).astype(int)
drift_time = np.hstack([group[2] for group in group_index_list]).astype(int)
assert len(s1_group_index) == len(s2_group_index)
return paring_rate_full, (
s1["group_number"][s1_group_index],
s2["group_number"][s2_group_index],
drift_time,
)
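
Two behavioral changes in this hunk are easy to miss: drift times are now drawn uniformly within each drift-time bin via rng.uniform instead of as integer offsets via rng.choice, and the function now returns group numbers looked up through the sampled indices rather than the raw indices themselves. A minimal sketch of the new drift-time sampling, with toy values:

import numpy as np

rng = np.random.default_rng(seed=42)
drift_time_bins = np.linspace(0.0, 2.3e6, 11)  # toy binning, ns
i, n_pairs = 3, 5  # stand-ins for the bin index and count_pairing.sum()

# Continuous, uniform drift times inside bin i (the new behavior):
drift_time = rng.uniform(drift_time_bins[i], drift_time_bins[i + 1], size=n_pairs)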

def split_chunks(self, n_peaks):
# divide results into chunks
@@ -416,7 +443,7 @@ def build_arrays(
# isolated S1 is assigned peak by peak
s1_index = s1_group_number[i]
for q in self.dtype["peaks_paired"].names:
if "origin" not in q and q not in ["event_number"]:
if "origin" not in q and q not in ["event_number", "normalization"]:
_array[0][q] = s1[s1_index][q]
# _array[0]["origin_run_id"] = s1["run_id"][s1_index]
_array[0]["origin_group_number"] = s1["group_number"][s1_index]
@@ -434,7 +461,7 @@
group_number = s2_group_number[i]
s2_group_i = s2[s2_group_index[group_number] : s2_group_index[group_number + 1]]
for q in self.dtype["peaks_paired"].names:
if "origin" not in q and q not in ["event_number"]:
if "origin" not in q and q not in ["event_number", "normalization"]:
_array[1:][q] = s2_group_i[q]
s2_index = s2_group_i["s2_index"]
# _array[1:]["origin_run_id"] = s2_group_i["run_id"]
@@ -524,38 +551,64 @@ def compute(self, isolated_s1, isolated_s2, peaks_salted, start, end):
print(f"S1 rate is {s1_rate:.3f}Hz")
print(f"There are {len(main_isolated_s2)} S2 peaks group")
print(f"S2 rate is {s2_rate * 1e3:.3f}mHz")
if self.apply_shadow_matching:
# simulate AC's drift time bin by bin
shadow_reference = self.shadow_reference_selection(peaks_salted)
paring_rate_full, s1_group_number, s2_group_number, drift_time = self.shadow_matching(
isolated_s1,
main_isolated_s2,
shadow_reference,
self.shadow_deltatime_exponent,
self.max_n_shadow_bins,
run_time,
self.max_drift_time,
self.min_drift_time,
self.n_drift_time_bins,
self.shift_dt_shadow_matching,
paring_rate_correction,
self.paring_rate_bootstrap_factor,
self.rng,
)
else:
paring_rate_full, s1_group_number, s2_group_number, drift_time = self.simple_pairing(
isolated_s1,
main_isolated_s2,
s1_rate,
s2_rate,
run_time,
self.max_drift_time,
self.min_drift_time,
paring_rate_correction,
self.paring_rate_bootstrap_factor,
self.fixed_drift_time,
self.rng,
)
n_hits_2 = isolated_s1["n_hits"] == 2
n_hits_masks = [n_hits_2, ~n_hits_2]
truths = []
for i, mask in enumerate(n_hits_masks):
if mask.sum() != 0:
if self.apply_shadow_matching:
# simulate AC's drift time bin by bin
shadow_reference = self.shadow_reference_selection(peaks_salted)
truth = self.shadow_matching(
isolated_s1[mask],
main_isolated_s2,
shadow_reference,
self.shadow_deltatime_exponent,
self.max_n_shadow_bins,
run_time,
self.max_drift_time,
self.min_drift_time,
self.n_drift_time_bins,
self.shift_dt_shadow_matching,
paring_rate_correction,
self.bootstrap_factor[i],
self.rng,
self.preprocess_shadow,
)
else:
truth = self.simple_pairing(
isolated_s1[mask],
main_isolated_s2,
s1_rate,
s2_rate,
run_time,
self.max_drift_time,
self.min_drift_time,
paring_rate_correction,
self.bootstrap_factor[i],
self.fixed_drift_time,
self.rng,
)
else:
truth = (
0.0,
(
np.empty(0, dtype=isolated_s1["group_number"].dtype),
np.empty(0, dtype=main_isolated_s2["group_number"].dtype),
[],
),
)
truths.append(truth)
paring_rate_full = truths[0][0] + truths[1][0]
s1_group_number = np.hstack([truths[0][1][0], truths[1][1][0]])
s2_group_number = np.hstack([truths[0][1][1], truths[1][1][1]])
drift_time = np.hstack([truths[0][1][2], truths[1][1][2]])
normalization = np.hstack(
[
np.full(len(truths[0][1][0]), 1 / self.bootstrap_factor[0]),
np.full(len(truths[1][1][0]), 1 / self.bootstrap_factor[1]),
]
)
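
The normalization assembled here is simply the inverse bootstrap factor of each category, so that weighted counts of paired events recover the physical AC expectation even when the 2-hit and 3+ hit samples are inflated by different factors. A sketch of the weighting with toy counts (all numbers illustrative):

import numpy as np

bootstrap_factor = [50.0, 200.0]  # assumed per-category factors
n_paired = [40, 160]              # toy numbers of paired events

normalization = np.hstack(
    [np.full(n, 1 / f) for n, f in zip(n_paired, bootstrap_factor)]
)
# Weighted total ~ physical expectation, independent of the factors:
normalization.sum()  # 40/50 + 160/200 = 1.6 events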

print(f"AC pairing rate is {paring_rate_full * 1e3:.3f}mHz")
print(f"AC event number is {len(drift_time)}")
@@ -588,6 +641,11 @@ def compute(self, isolated_s1, isolated_s2, peaks_salted, start, end):
)
peaks_arrays["event_number"] += left_i
truth_arrays["event_number"] += left_i
peaks_arrays["normalization"] = np.repeat(
normalization[left_i:right_i],
n_peaks[left_i:right_i],
)
truth_arrays["normalization"] = normalization[left_i:right_i]
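
Because every paired event contributes several peaks, the event-level normalization is broadcast to peak level with np.repeat using the per-event peak counts, while the truth array keeps one value per event. A toy illustration of that broadcast:

import numpy as np

normalization = np.array([0.02, 0.005])  # per-event weights, toy values
n_peaks = np.array([3, 2])               # peaks per event, toy values

peak_normalization = np.repeat(normalization, n_peaks)
# -> [0.02, 0.02, 0.02, 0.005, 0.005]: each peak inherits its event's weight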

result = dict()
result["peaks_paired"] = self.chunk(