Skip to content

Commit

Permalink
Debug do_compute in plugin factory (#37)
Browse files Browse the repository at this point in the history
* Modify `do_compute` in plugin factory

* Minor change

* Bug fix
  • Loading branch information
dachengx authored Apr 26, 2024
1 parent 7b576b7 commit cfe18d0
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 10 deletions.
6 changes: 6 additions & 0 deletions axidence/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ def infer_dtype(self):
# so we assign the dtype manually and raise error in infer_dtype method
raise RuntimeError

def do_compute(self, chunk_i=None, **kwargs):
# remove the suffix from the keys
new_keys = [k.replace(self.suffix, "") for k in kwargs.keys()]
new_kwargs = dict(zip(new_keys, kwargs.values()))
return super().do_compute(chunk_i=chunk_i, **new_kwargs)

# need to be compatible with strax.camel_to_snake
# https://github.com/AxFoundation/strax/blob/7da9a2a6375e7614181830484b322389986cf064/strax/context.py#L324
new_plugin.__name__ = plugin.__name__ + suffix
Expand Down
9 changes: 2 additions & 7 deletions axidence/plugins/pairing/peaks_paired.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,9 @@ def simple_pairing(
print(f"S2 rate is {s2_rate * 1e3:.3f}mHz")

paring_rate_full = (
s1_rate
* s2_rate
* (max_drift_time - min_drift_time)
/ units.s
* run_time
/ paring_rate_correction
s1_rate * s2_rate * (max_drift_time - min_drift_time) / units.s / paring_rate_correction
)
n_events = round(paring_rate_full * paring_rate_bootstrap_factor)
n_events = round(paring_rate_full * run_time * paring_rate_bootstrap_factor)
s1_group_number = rng.choice(len(s1), size=n_events, replace=True)
s2_group_number = rng.choice(len(s2), size=n_events, replace=True)
if fixed_drift_time is None:
Expand Down
6 changes: 3 additions & 3 deletions axidence/plugins/salting/event_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class EventShadowSalted(EventShadow):
__version__ = "0.0.0"
depends_on = (
"event_basics",
"event_basics_salted",
"peaks_salted",
"peak_shadow_salted",
"peak_basics",
Expand All @@ -26,7 +26,7 @@ def compute(self, events_salted, peaks_salted, peaks):
class EventAmbienceSalted(EventAmbience):
__version__ = "0.0.0"
depends_on = (
"event_basics",
"event_basics_salted",
"peaks_salted",
"peak_ambience_salted",
"peak_basics",
Expand All @@ -44,7 +44,7 @@ def compute(self, events_salted, peaks_salted, peaks):
class EventSEDensitySalted(EventSEDensity):
__version__ = "0.0.0"
depends_on: Tuple[str, ...] = (
"event_basics",
"event_basics_salted",
"peaks_salted",
"peak_se_density_salted",
"peak_basics",
Expand Down

0 comments on commit cfe18d0

Please sign in to comment.