Skip to content

Commit

Permalink
add functions
Browse files Browse the repository at this point in the history
  • Loading branch information
zihaoxu98 committed Feb 29, 2024
1 parent 138bb56 commit 287dd84
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 6 deletions.
15 changes: 9 additions & 6 deletions appletree/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np

from appletree.randgen import TwoHalfNorm
from appletree.utils import errors_to_two_half_norm_sigmas


class Parameter:
Expand Down Expand Up @@ -81,10 +82,12 @@ def sample_prior(self):
val = np.random.normal(**kwargs)
self._parameter_dict[par_name] = np.clip(val, *setting["allowed_range"])
elif prior_type == "twohalfnorm":
sigmas = errors_to_two_half_norm_sigmas([args["sigma_pos"],
args["sigma_neg"]])
kwargs = {
"mu": args["mu"],
"sigma_pos": args["sigma_pos"],
"sigma_neg": args["sigma_neg"],
"sigma_pos": sigmas[0],
"sigma_neg": sigmas[1],
}
val = TwoHalfNorm.rvs(**kwargs)
self._parameter_dict[par_name] = np.clip(val, *setting["allowed_range"])
Expand Down Expand Up @@ -150,14 +153,14 @@ def log_prior(self):
std = args["std"]
log_prior += -((val - mean) ** 2) / 2 / std**2
elif prior_type == "twohalfnorm":
sigmas = errors_to_two_half_norm_sigmas([args["sigma_pos"],
args["sigma_neg"]])
mu = args["mu"]
sigma_pos = args["sigma_pos"]
sigma_neg = args["sigma_neg"]
log_prior += TwoHalfNorm.logpdf(
x=val,
mu=mu,
sigma_pos=sigma_pos,
sigma_neg=sigma_neg,
sigma_pos=sigmas[0],
sigma_neg=sigmas[1],
)
elif prior_type == "free":
pass
Expand Down
14 changes: 14 additions & 0 deletions appletree/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
import matplotlib as mpl
from matplotlib.patches import Rectangle
from matplotlib import pyplot as plt
from scipy.special import erf
from scipy.optimize import root
from scipy.stats import chi2

import GOFevaluation
from appletree.share import _cached_configs
Expand Down Expand Up @@ -576,3 +579,14 @@ def check_unused_configs():
unused_configs = set(_cached_configs.keys()) - _cached_configs.accessed_keys
if unused_configs:
warn(f"Detected unused configs: {unused_configs}, you might set the configs incorrectly.")


def errors_to_two_half_norm_sigmas(errors):
"""This function solves the sigmas for a two-half-norm distribution,
such that the 16 and 84 percentile corresponds to the given errors."""
def _to_solve(x, errors, p):
return [x[0] / (x[0] + x[1]) * (1 - erf(errors[0] / x[0] / np.sqrt(2))) - p / 2,
x[1] / (x[0] + x[1]) * (1 - erf(errors[1] / x[1] / np.sqrt(2))) - p / 2]
res = root(_to_solve, errors, args=(errors, 1 - chi2.cdf(1, 1)))
assert res.success, f"Cannot solve sigmas of TwoHalfNorm for errors {errors}!"
return res.x

0 comments on commit 287dd84

Please sign in to comment.