Separate 2 and 3 hits pairing and add normalization factor
dachengx committed Apr 29, 2024
1 parent 79fd51c commit e6ef782
Showing 2 changed files with 126 additions and 66 deletions.
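In short, the commit splits the isolated S1 sample into 2-hit and 3-or-more-hit populations, pairs each with its own bootstrap factor, and writes a per-event normalization equal to the inverse of the factor that was used. A minimal standalone sketch of that bookkeeping (toy arrays and made-up factors, not the plugin code itself):

```python
import numpy as np

rng = np.random.default_rng(2024)

# Toy isolated-S1 records; only "n_hits" matters for the split.
isolated_s1 = np.zeros(10, dtype=[("n_hits", np.int32)])
isolated_s1["n_hits"] = rng.integers(2, 6, size=10)

# Hypothetical bootstrap factors for (2-hit, 3+ hit) S1s.
bootstrap_factor = [100.0, 50.0]

masks = [isolated_s1["n_hits"] == 2, isolated_s1["n_hits"] > 2]
for mask, factor in zip(masks, bootstrap_factor):
    n_paired = int(mask.sum() * factor)              # oversampled pairings for this population
    normalization = np.full(n_paired, 1.0 / factor)  # weight that undoes the oversampling
    print(f"{mask.sum()} isolated S1s -> {n_paired} paired events, each with weight {1.0 / factor}")
```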
2 changes: 2 additions & 0 deletions axidence/plugins/pairing/events_paired.py
@@ -116,6 +116,7 @@ def event_fields(self):
np.int8,
),
(("Event number in this dataset", "event_number"), np.int64),
(("Normalization of number of paired events", "normalization"), np.float32),
]
return dtype

@@ -147,6 +148,7 @@ def compute(self, events_paired, peaks_paired):
"Maybe the paired events overlap."
)
result["event_number"][i] = sp["event_number"][0]
result["normalization"][i] = sp["normalization"][0]
for idx, main_peak in zip([event["s1_index"], event["s2_index"]], ["s1_", "s2_"]):
if idx >= 0:
for n in self.peak_fields:
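The events_paired change is bookkeeping only: the event-level dtype gains a normalization field, and each paired event copies the value stored on its paired peaks. A standalone sketch of that copy with toy structured arrays (not the plugin's actual compute loop):

```python
import numpy as np

event_dtype = [
    (("Event number in this dataset", "event_number"), np.int64),
    (("Normalization of number of paired events", "normalization"), np.float32),
]
result = np.zeros(1, dtype=event_dtype)

# "sp" mimics the paired peaks belonging to one event; all of them carry
# the same normalization, so taking the first entry is enough.
sp = np.zeros(2, dtype=[("event_number", np.int64), ("normalization", np.float32)])
sp["event_number"] = 7
sp["normalization"] = 0.01

result["event_number"][0] = sp["event_number"][0]
result["normalization"][0] = sp["normalization"][0]
print(result)
```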
190 changes: 124 additions & 66 deletions axidence/plugins/pairing/peaks_paired.py
@@ -68,8 +68,15 @@ class PeaksPaired(ExhaustPlugin, DownChunkingPlugin):
# multiple factor is 100, then we will make 100 AC events
paring_rate_bootstrap_factor = straxen.URLConfig(
default=1e2,
- type=(int, float),
- help="Bootstrap factor for AC rate",
+ type=(int, float, list, tuple),
+ help=(
+ "Bootstrap factor for AC rate, "
+ "if list or tuple, they are the factor for 2 and 3+ hits S1"
+ ),
)

s1_min_coincidence = straxen.URLConfig(
default=2, type=int, help="Minimum tight coincidence necessary to make an S1"
)

apply_shadow_matching = straxen.URLConfig(
@@ -115,6 +122,7 @@ def infer_dtype(self):
(("Original type of group", "origin_group_type"), np.int8),
(("Original s1_index in isolated S1", "origin_s1_index"), np.int32),
(("Original s2_index in isolated S2", "origin_s2_index"), np.int32),
(("Normalization of number of paired events", "normalization"), np.float32),
]
truth_dtype = [
(("Event number in this dataset", "event_number"), np.int64),
@@ -126,6 +134,7 @@
),
(("Original isolated S1 group", "s1_group_number"), np.int32),
(("Original isolated S2 group", "s2_group_number"), np.int32),
(("Normalization of number of paired events", "normalization"), np.float32),
] + strax.time_fields
return dict(peaks_paired=peaks_dtype, truth_paired=truth_dtype)

@@ -138,6 +147,17 @@ def setup(self, prepare=True):
self.rng = np.random.default_rng(seed=self.pairing_seed)
self.time_left = self.paring_time_interval // 2
self.time_right = self.paring_time_interval - self.time_left
if self.s1_min_coincidence != 2:
raise NotImplementedError("Only support s1_min_coincidence = 2 now!")
if isinstance(self.paring_rate_bootstrap_factor, (list, tuple)):
if len(self.paring_rate_bootstrap_factor) != 2:
raise ValueError(
"The length of paring_rate_bootstrap_factor should be 2 "
"if provided list or tuple!"
)
self.bootstrap_factor = list(self.paring_rate_bootstrap_factor)
else:
self.bootstrap_factor = [self.paring_rate_bootstrap_factor] * 2
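With this setup, paring_rate_bootstrap_factor may stay a single number (applied to both S1 populations) or be a two-element list/tuple, where the first entry scales the 2-hit pairing and the second the 3+ hit pairing. A hedged usage sketch, assuming a strax context named st with this plugin registered:

```python
# One factor for both populations (previous behaviour):
st.set_config({"paring_rate_bootstrap_factor": 100})

# Separate factors: first for 2-hit S1s, second for 3+ hit S1s.
st.set_config({"paring_rate_bootstrap_factor": [100, 20]})
```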

@staticmethod
def preprocess_isolated_s2(s2):
@@ -171,14 +191,18 @@ def simple_pairing(
s1_rate * s2_rate * (max_drift_time - min_drift_time) / units.s / paring_rate_correction
)
n_events = round(paring_rate_full * run_time * paring_rate_bootstrap_factor)
- s1_group_number = rng.choice(len(s1), size=n_events, replace=True)
- s2_group_number = rng.choice(len(s2), size=n_events, replace=True)
+ s1_group_index = rng.choice(len(s1), size=n_events, replace=True)
+ s2_group_index = rng.choice(len(s2), size=n_events, replace=True)
if fixed_drift_time is None:
drift_time = rng.uniform(min_drift_time, max_drift_time, size=n_events)
else:
warnings.warn(f"Using fixed drift time {fixed_drift_time}ns")
drift_time = np.full(n_events, fixed_drift_time)
- return paring_rate_full, s1_group_number, s2_group_number, drift_time
+ return paring_rate_full, (
+ s1["group_number"][s1_group_index],
+ s2["group_number"][s2_group_index],
+ drift_time,
+ )

def shadow_reference_selection(self, peaks_salted):
return peaks_salted[peaks_salted["type"] == 2]
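The simple (non-shadow) pairing rate above is just the accidental-coincidence product of the two rates over the allowed drift-time window, divided by the correction factor; the bootstrap factor then scales how many fake events are drawn. A toy number check with made-up rates, assuming times are in nanoseconds so that units.s corresponds to 1e9:

```python
# Hypothetical inputs, not taken from the repository:
s1_rate = 0.5              # Hz
s2_rate = 5e-3             # Hz
max_drift_time = 2.3e6     # ns
min_drift_time = 1.0e4     # ns
paring_rate_correction = 1.0
run_time = 3600.0          # s
bootstrap_factor = 100

ns_per_s = 1e9  # assumed value of units.s when times are in ns
paring_rate_full = (
    s1_rate * s2_rate * (max_drift_time - min_drift_time) / ns_per_s / paring_rate_correction
)
n_events = round(paring_rate_full * run_time * bootstrap_factor)
print(f"AC rate {paring_rate_full * 1e3:.4f} mHz -> {n_events} bootstrapped events")
```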
@@ -210,6 +234,7 @@ def preprocess_shadow(data, shadow_deltatime_exponent, delta_t=0, prefix=""):
x = np.log10(pre_s2_area * dt_s2_time_shadow**shadow_deltatime_exponent)
y = np.sqrt(np.log10(pre_s2_area) ** 2 + np.log10(dt_s2_time_shadow) ** 2)
sample = np.stack([x, y]).T
# sample = np.stack([np.log10(dt_s2_time_shadow), np.log10(pre_s2_area)]).T
return sample
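preprocess_shadow maps each peak into the 2D matching space: x is log10(pre_s2_area * dt**exponent) and y is the quadrature sum of log10(pre_s2_area) and log10(dt), the two being roughly orthogonal. A standalone sketch of the same transform on toy values (hypothetical numbers, negative exponent so x behaves like a shadow S2/dt**k):

```python
import numpy as np

def shadow_space(pre_s2_area, dt_s2_time_shadow, shadow_deltatime_exponent):
    """Map (area, time-delta) pairs into the (x, y) shadow-matching space."""
    x = np.log10(pre_s2_area * dt_s2_time_shadow**shadow_deltatime_exponent)
    y = np.sqrt(np.log10(pre_s2_area) ** 2 + np.log10(dt_s2_time_shadow) ** 2)
    return np.stack([x, y]).T

sample = shadow_space(
    pre_s2_area=np.array([1e3, 1e4]),
    dt_s2_time_shadow=np.array([1e6, 1e7]),
    shadow_deltatime_exponent=-1.0,
)
print(sample)  # shape (2, 2): one (x, y) row per peak
```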

@staticmethod
@@ -227,6 +252,7 @@ def shadow_matching(
paring_rate_correction,
paring_rate_bootstrap_factor,
rng,
preprocess_shadow,
onlyrate=False,
):
# perform Shadow matching technique
@@ -235,10 +261,8 @@
# 2D equal binning
# prepare the 2D space, x is log(S2/dt), y is (log(S2)**2+log(dt)**2)**0.5
# because these 2 dimension is orthogonal
- sampled_correlation = PeaksPaired.preprocess_shadow(
- shadow_reference, shadow_deltatime_exponent
- )
- s1_sample = PeaksPaired.preprocess_shadow(s1, shadow_deltatime_exponent)
+ sampled_correlation = preprocess_shadow(shadow_reference, shadow_deltatime_exponent)
+ s1_sample = preprocess_shadow(s1, shadow_deltatime_exponent)

# use (x, y) distribution of isolated S1 as reference
# because it is more intense when shadow(S2/dt) is large
@@ -285,16 +309,16 @@
if not onlyrate:
# get indices of the 2D bins
s1_digit = PeaksPaired.digitize2d(s1_sample, bin_edges, n_shadow_bins)
- _s1_group_number = np.arange(len(s1))
- s1_group_number_list = [
- _s1_group_number[s1_digit == xd].tolist()
+ _s1_group_index = np.arange(len(s1))
+ s1_group_index_list = [
+ _s1_group_index[s1_digit == xd].tolist()
for xd in range(n_shadow_bins * n_shadow_bins)
]

drift_time_bins = np.linspace(min_drift_time, max_drift_time, n_drift_time_bins + 1)
drift_time_bin_center = (drift_time_bins[:-1] + drift_time_bins[1:]) / 2

- group_number_list = []
+ group_index_list = []
_paring_rate_full = np.zeros(len(drift_time_bin_center))
for i in range(len(drift_time_bin_center)):
if shift_dt_shadow_matching:
Expand All @@ -303,9 +327,7 @@ def shadow_matching(
delta_t = drift_time_bin_center[i]
else:
delta_t = 0
- data_sample = PeaksPaired.preprocess_shadow(
- s2, shadow_deltatime_exponent, delta_t=delta_t
- )
+ data_sample = preprocess_shadow(s2, shadow_deltatime_exponent, delta_t=delta_t)
ge.check_sample_sanity(data_sample)
# apply binning to (x, y)
s2_shadow_count = ge.apply_irregular_binning(
@@ -331,43 +353,48 @@
if count_pairing.max() == 0:
count_pairing[mu_shadow.argmax()] = 1
s2_digit = PeaksPaired.digitize2d(data_sample, bin_edges, n_shadow_bins)
- _s2_group_number = np.arange(len(s2))
- s2_group_number_list = [
- _s2_group_number[s2_digit == xd].tolist()
+ _s2_group_index = np.arange(len(s2))
+ s2_group_index_list = [
+ _s2_group_index[s2_digit == xd].tolist()
for xd in range(n_shadow_bins * n_shadow_bins)
]
# random sample isolated S1 and S2's group number
- _s1_group_number = np.hstack(
+ _s1_group_index = np.hstack(
[
rng.choice(
- s1_group_number_list[xd],
+ s1_group_index_list[xd],
size=count_pairing[xd],
)
for xd in range(n_shadow_bins * n_shadow_bins)
]
)
- _s2_group_number = np.hstack(
+ _s2_group_index = np.hstack(
[
rng.choice(
- s2_group_number_list[xd],
+ s2_group_index_list[xd],
size=count_pairing[xd],
)
for xd in range(n_shadow_bins * n_shadow_bins)
]
)
# sample drift time in this bin
- _drift_time = rng.choice(
- round(drift_time_bins[i + 1] - drift_time_bins[i]),
+ _drift_time = rng.uniform(
+ drift_time_bins[i],
+ drift_time_bins[i + 1],
size=count_pairing.sum(),
- ) + round(drift_time_bins[i])
- group_number_list.append([_s1_group_number, _s2_group_number, _drift_time])
+ )
+ group_index_list.append([_s1_group_index, _s2_group_index, _drift_time])
paring_rate_full = _paring_rate_full.sum()
if not onlyrate:
- s1_group_number = np.hstack([group[0] for group in group_number_list]).astype(int)
- s2_group_number = np.hstack([group[1] for group in group_number_list]).astype(int)
- drift_time = np.hstack([group[2] for group in group_number_list]).astype(int)
- assert len(s1_group_number) == len(s2_group_number)
- return paring_rate_full, s1_group_number, s2_group_number, drift_time
+ s1_group_index = np.hstack([group[0] for group in group_index_list]).astype(int)
+ s2_group_index = np.hstack([group[1] for group in group_index_list]).astype(int)
+ drift_time = np.hstack([group[2] for group in group_index_list]).astype(int)
+ assert len(s1_group_index) == len(s2_group_index)
+ return paring_rate_full, (
+ s1["group_number"][s1_group_index],
+ s2["group_number"][s2_group_index],
+ drift_time,
+ )
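Note that drift times are now drawn uniformly inside each drift-time bin via rng.uniform, instead of integer offsets from rng.choice, so the samples fill the bin continuously. A minimal sketch of the per-bin draw with made-up bin edges:

```python
import numpy as np

rng = np.random.default_rng(42)
drift_time_bins = np.linspace(1e4, 2.3e6, 11)  # hypothetical edges in ns

i = 3            # index of one drift-time bin
n_pairs = 5      # pairings assigned to this bin
drift_time = rng.uniform(drift_time_bins[i], drift_time_bins[i + 1], size=n_pairs)
print(drift_time)  # all values fall inside [bins[i], bins[i + 1])
```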

def split_chunks(self, n_peaks):
# divide results into chunks
@@ -416,7 +443,7 @@ def build_arrays(
# isolated S1 is assigned peak by peak
s1_index = s1_group_number[i]
for q in self.dtype["peaks_paired"].names:
if "origin" not in q and q not in ["event_number"]:
if "origin" not in q and q not in ["event_number", "normalization"]:
_array[0][q] = s1[s1_index][q]
# _array[0]["origin_run_id"] = s1["run_id"][s1_index]
_array[0]["origin_group_number"] = s1["group_number"][s1_index]
@@ -434,7 +461,7 @@
group_number = s2_group_number[i]
s2_group_i = s2[s2_group_index[group_number] : s2_group_index[group_number + 1]]
for q in self.dtype["peaks_paired"].names:
if "origin" not in q and q not in ["event_number"]:
if "origin" not in q and q not in ["event_number", "normalization"]:
_array[1:][q] = s2_group_i[q]
s2_index = s2_group_i["s2_index"]
# _array[1:]["origin_run_id"] = s2_group_i["run_id"]
@@ -524,38 +551,64 @@ def compute(self, isolated_s1, isolated_s2, peaks_salted, start, end):
print(f"S1 rate is {s1_rate:.3f}Hz")
print(f"There are {len(main_isolated_s2)} S2 peaks group")
print(f"S2 rate is {s2_rate * 1e3:.3f}mHz")
- if self.apply_shadow_matching:
- # simulate AC's drift time bin by bin
- shadow_reference = self.shadow_reference_selection(peaks_salted)
- paring_rate_full, s1_group_number, s2_group_number, drift_time = self.shadow_matching(
- isolated_s1,
- main_isolated_s2,
- shadow_reference,
- self.shadow_deltatime_exponent,
- self.max_n_shadow_bins,
- run_time,
- self.max_drift_time,
- self.min_drift_time,
- self.n_drift_time_bins,
- self.shift_dt_shadow_matching,
- paring_rate_correction,
- self.paring_rate_bootstrap_factor,
- self.rng,
- )
- else:
- paring_rate_full, s1_group_number, s2_group_number, drift_time = self.simple_pairing(
- isolated_s1,
- main_isolated_s2,
- s1_rate,
- s2_rate,
- run_time,
- self.max_drift_time,
- self.min_drift_time,
- paring_rate_correction,
- self.paring_rate_bootstrap_factor,
- self.fixed_drift_time,
- self.rng,
- )
+ n_hits_2 = isolated_s1["n_hits"] == 2
+ n_hits_masks = [n_hits_2, ~n_hits_2]
+ truths = []
+ for i, mask in enumerate(n_hits_masks):
+ if mask.sum() != 0:
+ if self.apply_shadow_matching:
+ # simulate AC's drift time bin by bin
+ shadow_reference = self.shadow_reference_selection(peaks_salted)
+ truth = self.shadow_matching(
+ isolated_s1[mask],
+ main_isolated_s2,
+ shadow_reference,
+ self.shadow_deltatime_exponent,
+ self.max_n_shadow_bins,
+ run_time,
+ self.max_drift_time,
+ self.min_drift_time,
+ self.n_drift_time_bins,
+ self.shift_dt_shadow_matching,
+ paring_rate_correction,
+ self.bootstrap_factor[i],
+ self.rng,
+ self.preprocess_shadow,
+ )
+ else:
+ truth = self.simple_pairing(
+ isolated_s1[mask],
+ main_isolated_s2,
+ s1_rate,
+ s2_rate,
+ run_time,
+ self.max_drift_time,
+ self.min_drift_time,
+ paring_rate_correction,
+ self.bootstrap_factor[i],
+ self.fixed_drift_time,
+ self.rng,
+ )
+ else:
+ truth = (
+ 0.0,
+ (
+ np.empty(0, dtype=isolated_s1["group_number"].dtype),
+ np.empty(0, dtype=main_isolated_s2["group_number"].dtype),
+ [],
+ ),
+ )
+ truths.append(truth)
+ paring_rate_full = truths[0][0] + truths[1][0]
+ s1_group_number = np.hstack([truths[0][1][0], truths[1][1][0]])
+ s2_group_number = np.hstack([truths[0][1][1], truths[1][1][1]])
+ drift_time = np.hstack([truths[0][1][2], truths[1][1][2]])
+ normalization = np.hstack(
+ [
+ np.full(len(truths[0][1][0]), 1 / self.bootstrap_factor[0]),
+ np.full(len(truths[1][1][0]), 1 / self.bootstrap_factor[1]),
+ ]
+ )

print(f"AC pairing rate is {paring_rate_full * 1e3:.3f}mHz")
print(f"AC event number is {len(drift_time)}")
Expand Down Expand Up @@ -588,6 +641,11 @@ def compute(self, isolated_s1, isolated_s2, peaks_salted, start, end):
)
peaks_arrays["event_number"] += left_i
truth_arrays["event_number"] += left_i
peaks_arrays["normalization"] = np.repeat(
normalization[left_i:right_i],
n_peaks[left_i:right_i],
)
truth_arrays["normalization"] = normalization[left_i:right_i]

result = dict()
result["peaks_paired"] = self.chunk(
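Downstream, the stored normalization is intended as a per-event weight: every bootstrapped AC event carries 1/bootstrap_factor, so weighted counts recover the physical pairing rate regardless of how strongly each population was oversampled. A hedged analysis snippet, assuming a strax context st and that the paired events are delivered under a data type named event_infos_paired (both names are assumptions here):

```python
import numpy as np

run_id = "012345"  # placeholder run
events = st.get_array(run_id, "event_infos_paired")

# Each oversampled AC event carries weight 1/bootstrap_factor, so the
# weighted count estimates the physical number of AC events in the run.
expected_ac_events = np.sum(events["normalization"])
print(f"Expected AC events in this run: {expected_ac_events:.2f}")
```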
