Skip to content

Commit

Permalink
Add run_meta as the dependency of PeaksPaired and store run_id
Browse files Browse the repository at this point in the history
…into `peaks_paired` (#63)
  • Loading branch information
dachengx authored May 2, 2024
1 parent 198bcf1 commit 4394f39
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 14 deletions.
4 changes: 3 additions & 1 deletion axidence/plugins/isolated/isolated_s1.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def infer_dtype(self):
dtype_reference = self.refer_dtype("peaks")
dtype = copy_dtype(dtype_reference, self.isolated_peaks_fields)
dtype += [
(("Run id", "run_id"), np.int32),
(("Group number of peaks", "group_number"), np.int64),
]
return dtype
Expand All @@ -49,9 +50,10 @@ def compute(self, peaks):
_result = peaks[peaks["cut_isolated_s1"]]
result = np.empty(len(_result), dtype=self.dtype)
for n in result.dtype.names:
if n not in ["group_number"]:
if n not in ["run_id", "group_number"]:
result[n] = _result[n]

result["run_id"] = int(self.run_id)
result["group_number"] = np.arange(len(result)) + self.groups_seen

self.groups_seen += len(result)
Expand Down
4 changes: 3 additions & 1 deletion axidence/plugins/isolated/isolated_s2.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def infer_dtype(self):
dtype = copy_dtype(self.refer_dtype("peaks"), self.isolated_peaks_fields)
dtype += copy_dtype(self.refer_dtype("events"), self.isolated_events_fields)
dtype += [
(("Run id", "run_id"), np.int32),
(("Group number of peaks", "group_number"), np.int64),
]
return dtype
Expand All @@ -69,14 +70,15 @@ def compute(self, events, peaks):

result = np.empty(_events["n_peaks"].sum(), dtype=self.dtype)
for n in result.dtype.names:
if n not in ["group_number"]:
if n not in ["run_id", "group_number"]:
if n in self.isolated_peaks_fields:
result[n] = _peaks[n]
elif n in self.isolated_events_fields:
result[n] = np.repeat(_events[n], _events["n_peaks"])
else:
raise ValueError(f"Field {n} not found in peaks or events!")

result["run_id"] = int(self.run_id)
result["group_number"] = (
np.repeat(np.arange(len(_events)), _events["n_peaks"]) + self.groups_seen
)
Expand Down
3 changes: 1 addition & 2 deletions axidence/plugins/pairing/events_paired.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,8 @@ def peak_fields(self):
"left_dtime",
"right_dtime",
]
# TODO: reconsider about how to store run_id after implementing super runs
required_names += [
# "origin_run_id",
"origin_run_id",
"origin_group_number",
"origin_time",
"origin_endtime",
Expand Down
33 changes: 23 additions & 10 deletions axidence/plugins/pairing/peaks_paired.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
class PeaksPaired(ExhaustPlugin, DownChunkingPlugin):
__version__ = "0.0.0"
depends_on = (
"run_meta",
"isolated_s1",
"isolated_s2",
"peaks_salted",
Expand Down Expand Up @@ -140,11 +141,10 @@ class PeaksPaired(ExhaustPlugin, DownChunkingPlugin):

def infer_dtype(self):
dtype = strax.unpack_dtype(self.deps["isolated_s1"].dtype_for("isolated_s1"))
# TODO: reconsider about how to store run_id after implementing super runs
peaks_dtype = dtype + [
# since event_number is int64 in event_basics
(("Event number in this dataset", "event_number"), np.int64),
# (("Original run id", "origin_run_id"), np.int32),
(("Original run id", "origin_run_id"), np.int32),
(("Original isolated S1/S2 group", "origin_group_number"), np.int32),
(("Original time of peaks", "origin_time"), np.int64),
(("Original endtime of peaks", "origin_endtime"), np.int64),
Expand All @@ -157,8 +157,8 @@ def infer_dtype(self):
]
truth_dtype = [
(("Event number in this dataset", "event_number"), np.int64),
# (("Original run id of isolated S1", "s1_run_id"), np.int32),
# (("Original run id of isolated S2", "s2_run_id"), np.int32),
(("Original run id of isolated S1", "s1_run_id"), np.int32),
(("Original run id of isolated S2", "s2_run_id"), np.int32),
(
("Drift time between isolated S1 and main isolated S2 [ns]", "drift_time"),
np.float32,
Expand Down Expand Up @@ -190,6 +190,17 @@ def setup(self, prepare=True):
else:
self.bootstrap_factor = [self.paring_rate_bootstrap_factor] * 2

@staticmethod
def update_group_number(isolated, run_meta):
    """Make ``group_number`` increase monotonically across runs.

    Each run's isolated peaks carry group numbers that restart at 0, so when
    several runs are concatenated the numbering collides. For every run interval
    in ``run_meta``, count how many distinct groups fall inside it, then shift
    that run's group numbers by the total number of groups seen in all earlier
    runs (an exclusive prefix sum).

    :param isolated: structured array of isolated peaks, sorted by time,
        with a ``group_number`` field
    :param run_meta: structured array of per-run time intervals
    :return: copy of ``isolated`` with globally unique, sorted ``group_number``
    """
    updated = isolated.copy()
    # [start, end) index range of isolated entries overlapping each run interval
    windows = strax.touching_windows(isolated, run_meta)
    # distinct group count inside each run's window
    groups_per_run = np.array(
        [len(np.unique(isolated["group_number"][left:right])) for left, right in windows]
    )
    # exclusive prefix sum: offset added to each run's local group numbers
    offsets = np.cumsum(np.hstack([[0], groups_per_run]))[:-1]
    updated["group_number"] += np.repeat(offsets, windows[:, 1] - windows[:, 0])
    return updated

@staticmethod
def preprocess_isolated_s2(s2):
# index of isolated S2 groups
Expand Down Expand Up @@ -514,7 +525,7 @@ def build_arrays(
for q in self.dtype["peaks_paired"].names:
if "origin" not in q and q not in ["event_number", "normalization"]:
_array[0][q] = s1[s1_index][q]
# _array[0]["origin_run_id"] = s1["run_id"][s1_index]
_array[0]["origin_run_id"] = s1["run_id"][s1_index]
_array[0]["origin_group_number"] = s1["group_number"][s1_index]
_array[0]["origin_time"] = s1["time"][s1_index]
_array[0]["origin_endtime"] = strax.endtime(s1)[s1_index]
Expand All @@ -533,7 +544,7 @@ def build_arrays(
if "origin" not in q and q not in ["event_number", "normalization"]:
_array[1:][q] = s2_group_i[q]
s2_index = s2_group_i["s2_index"]
# _array[1:]["origin_run_id"] = s2_group_i["run_id"]
_array[1:]["origin_run_id"] = s2_group_i["run_id"]
_array[1:]["origin_group_number"] = s2_group_i["group_number"]
_array[1:]["origin_time"] = s2_group_i["time"]
_array[1:]["origin_endtime"] = strax.endtime(s2_group_i)
Expand Down Expand Up @@ -574,8 +585,8 @@ def build_arrays(
]
truth_arrays["event_number"] = np.arange(len(n_peaks))
truth_arrays["drift_time"] = drift_time
# truth_arrays["s1_run_id"] = s1["run_id"][s1_group_number]
# truth_arrays["s2_run_id"] = main_s2["run_id"][s2_group_number]
truth_arrays["s1_run_id"] = s1["run_id"][s1_group_number]
truth_arrays["s2_run_id"] = main_s2["run_id"][s2_group_number]
truth_arrays["s1_group_number"] = s1["group_number"][s1_group_number]
truth_arrays["s2_group_number"] = main_s2["group_number"][s2_group_number]

Expand All @@ -590,7 +601,9 @@ def build_arrays(

return peaks_arrays, truth_arrays

def compute(self, isolated_s1, isolated_s2, peaks_salted, events_salted, start, end):
def compute(self, run_meta, isolated_s1, isolated_s2, peaks_salted, events_salted, start, end):
isolated_s1 = self.update_group_number(isolated_s1, run_meta)
isolated_s2 = self.update_group_number(isolated_s2, run_meta)
for i, s in enumerate([isolated_s1, isolated_s2]):
if np.any(np.diff(s["group_number"]) < 0):
raise ValueError(f"Group number is not sorted in isolated S{i}!")
Expand All @@ -603,7 +616,7 @@ def compute(self, isolated_s1, isolated_s2, peaks_salted, events_salted, start,
paring_rate_correction = self.get_paring_rate_correction(peaks_salted)
print(f"Isolated S1 correction factor is {paring_rate_correction:.3f}")

run_time = (end - start) / units.s
run_time = (run_meta["endtime"] - run_meta["time"]).sum() / units.s
s1_rate = len(isolated_s1) / run_time
s2_rate = len(main_isolated_s2) / run_time
print(f"There are {len(isolated_s1)} S1 peaks group")
Expand Down

0 comments on commit 4394f39

Please sign in to comment.