Skip to content

Commit

Permalink
Merge branch 'master' into pre-commit-ci-update-config
Browse files Browse the repository at this point in the history
  • Loading branch information
zihaoxu98 authored Jun 18, 2024
2 parents f4dfe77 + 453f5fb commit de1e329
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 11 deletions.
19 changes: 17 additions & 2 deletions appletree/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,22 @@ def build_point(self, data):
self.coordinate_system = jnp.asarray(data["coordinate_system"], dtype=float)
self.map = jnp.asarray(data["map"], dtype=float)

setattr(self, "interpolator", interpolation.curve_interpolator)
if self.method == "IDW":
setattr(self, "interpolator", interpolation.curve_interpolator)
elif self.method == "NN":
setattr(
self,
"interpolator",
interpolation.map_interpolator_nearest_neighbor_1d,
)
elif self.method == "LERP":
setattr(
self,
"interpolator",
interpolation.map_interpolator_linear_1d,
)
else:
raise ValueError(f"Unknown method {self.method} for 1D regular binning.")
if self.coordinate_type == "log_point":
if jnp.any(self.coordinate_system <= 0):
raise ValueError(
Expand Down Expand Up @@ -262,7 +277,7 @@ def build_regbin(self, data):
)
else:
raise ValueError(f"Unknown method {self.method} for 2D regular binning.")
elif len(self.coordinate_lowers) == 3 and self.method == "IDW":
elif len(self.coordinate_lowers) == 3:
if self.method == "IDW":
setattr(self, "interpolator", interpolation.map_interpolator_regular_binning_3d)
elif self.method == "NN":
Expand Down
31 changes: 31 additions & 0 deletions appletree/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,37 @@ def find_nearest_indices(x, y):
return indices


@export
@jit
def map_interpolator_linear_1d(pos, ref_pos, ref_val):
"""Linear 1D interpolation. Copied to prevent misuse of other arguments of jnp.interp.
Args:
pos: array with shape (N,), as the points to be interpolated.
ref_pos: array with shape (M,), as the reference points.
ref_val: array with shape (M,), as the reference values.
"""
return jnp.interp(pos, ref_pos, ref_val)


@export
@jit
def map_interpolator_nearest_neighbor_1d(pos, ref_pos, ref_val):
"""Nearest neighbor 1D interpolation.
Args:
pos: array with shape (N,), as the points to be interpolated.
ref_pos: array with shape (M,), as the reference points.
ref_val: array with shape (M,), as the reference values.
"""
ind = find_nearest_indices(pos, ref_pos)

val = ref_val[ind]
return val


@export
@jit
def map_interpolator_regular_binning_nearest_neighbor_2d(
Expand Down
7 changes: 6 additions & 1 deletion appletree/plugins/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,12 @@ def simulate(self, key, parameters, batch_size):

@export
@takes_config(
Map(name="energy_spectrum", default="_nr_spectrum.json", help="Recoil energy spectrum"),
Map(
name="energy_spectrum",
method="LERP",
default="_nr_spectrum.json",
help="Recoil energy spectrum",
),
)
class FixedEnergySpectra(Plugin):
depends_on = ["batch_size"]
Expand Down
7 changes: 6 additions & 1 deletion appletree/plugins/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,12 @@ def simulate(self, key, parameters, num_s1_phd):

@export
@takes_config(
Map(name="elife", default="_elife.json", help="Electron lifetime correction"),
Map(
name="elife",
method="LERP",
default="_elife.json",
help="Electron lifetime correction",
),
)
class DriftLoss(Plugin):
depends_on = ["z"]
Expand Down
3 changes: 3 additions & 0 deletions appletree/plugins/efficiency.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def simulate(self, key, parameters, s2_area):
@takes_config(
SigmaMap(
name="s1_eff_3f",
method="NN",
default="_3fold_recon_eff.json",
help="3fold S1 reconstruction efficiency",
),
Expand All @@ -44,6 +45,7 @@ def simulate(self, key, parameters, num_s1_phd):
@takes_config(
SigmaMap(
name="s1_cut_acc",
method="LERP",
default=["_s1_cut_acc.json", "_s1_cut_acc.json", "_s1_cut_acc.json"],
help="S1 cut acceptance",
),
Expand All @@ -64,6 +66,7 @@ def simulate(self, key, parameters, s1_area):
@takes_config(
SigmaMap(
name="s2_cut_acc",
method="LERP",
default=["_s2_cut_acc.json", "_s2_cut_acc.json", "_s2_cut_acc.json", "s2_cut_acc_sigma"],
help="S2 cut acceptance",
),
Expand Down
14 changes: 12 additions & 2 deletions appletree/plugins/lyqy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@

@export
@takes_config(
Map(name="ly_median", default="_nr_ly.json", help="Light yield curve from NESTv2"),
Map(
name="ly_median",
method="LERP",
default="_nr_ly.json",
help="Light yield curve from NESTv2",
),
)
class LightYield(Plugin):
depends_on = ["energy"]
Expand All @@ -39,7 +44,12 @@ def simulate(self, key, parameters, energy, light_yield):

@export
@takes_config(
Map(name="qy_median", default="_nr_qy.json", help="Charge yield curve from NESTv2"),
Map(
name="qy_median",
method="LERP",
default="_nr_qy.json",
help="Charge yield curve from NESTv2",
),
)
class ChargeYield(Plugin):
depends_on = ["energy"]
Expand Down
35 changes: 30 additions & 5 deletions appletree/plugins/reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@

@export
@takes_config(
Map(name="posrec_reso", default="_posrec_reso.json", help="Position reconstruction resolution"),
Map(
name="posrec_reso",
method="LERP",
default="_posrec_reso.json",
help="Position reconstruction resolution",
),
)
class PositionRecon(Plugin):
depends_on = ["x", "y", "z", "num_electron_drifted"]
Expand All @@ -35,8 +40,18 @@ def simulate(self, key, parameters, x, y, z, num_electron_drifted):

@export
@takes_config(
Map(name="s1_bias_3f", default="_s1_bias.json", help="3fold S1 reconstruction bias"),
Map(name="s1_smear_3f", default="_s1_smearing.json", help="3fold S1 reconstruction smearing"),
Map(
name="s1_bias_3f",
method="LERP",
default="_s1_bias.json",
help="3fold S1 reconstruction bias",
),
Map(
name="s1_smear_3f",
method="LERP",
default="_s1_smearing.json",
help="3fold S1 reconstruction smearing",
),
)
class S1(Plugin):
depends_on = ["num_s1_phd", "num_s1_pe"]
Expand All @@ -53,8 +68,18 @@ def simulate(self, key, parameters, num_s1_phd, num_s1_pe):

@export
@takes_config(
Map(name="s2_bias", default="_s2_bias.json", help="S2 reconstruction bias"),
Map(name="s2_smear", default="_s2_smearing.json", help="S2 reconstruction smearing"),
Map(
name="s2_bias",
method="LERP",
default="_s2_bias.json",
help="S2 reconstruction bias",
),
Map(
name="s2_smear",
method="LERP",
default="_s2_smearing.json",
help="S2 reconstruction smearing",
),
)
class S2(Plugin):
depends_on = ["num_s2_pe"]
Expand Down

0 comments on commit de1e329

Please sign in to comment.