Skip to content

Commit

Permalink
feat(hypermodel.py): add nonlinear timing to hypermodels
Browse files Browse the repository at this point in the history
Co-authored-by: Andrew Kaiser <[email protected]>
  • Loading branch information
2 people authored and davecwright3 committed Aug 20, 2024
1 parent 1061ed2 commit f4cfa17
Showing 1 changed file with 111 additions and 6 deletions.
117 changes: 111 additions & 6 deletions enterprise_extensions/hypermodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from enterprise import constants as const
from PTMCMCSampler.PTMCMCSampler import PTSampler as ptmcmc

from .sampler import JumpProposal, get_parameter_groups, save_runtime_info
from .sampler import JumpProposal, get_parameter_groups, save_runtime_info, get_timing_groups, group_from_params


class HyperModel(object):
Expand Down Expand Up @@ -71,6 +71,27 @@ def __init__(self, models, log_weights=None):
for q in uniq_params]].tolist()
#########

#########
# Timing Model
self.tm_groups = []
self.special_idxs = []
for i, x in enumerate(self.params):
if "timing_model" in str(x):
self.tm_groups.append(i)
if "Uniform" in str(x):
pmin = float(str(x).split("Uniform")[-1].split("pmin=")[1].split(",")[0])
pmax = float(str(x).split("Uniform")[-1].split("pmax=")[-1].split(")")[0])
if pmin + pmax != 0.0:
self.special_idxs.append(i)
elif "BoundedNormal" in str(x):
pmin = float(str(x).split("BoundedNormal")[-1].split("[")[-1].split(",")[0])
pmax = float(str(x).split("BoundedNormal")[-1].split("[")[-1].split(",")[1].split(']')[0])
if pmin + pmax != 0.0:
self.special_idxs.append(i)
else:
self.special_idxs.append(i)
#########

def get_lnlikelihood(self, x):

# find model index variable
Expand Down Expand Up @@ -115,6 +136,18 @@ def get_parameter_groups(self):
unique_groups = []
for p in self.models.values():
groups = get_parameter_groups(p)
if self.tm_groups:
groups.extend(get_timing_groups(p))
groups.append(
group_from_params(
p,
[
x
for x in p.param_names
if any(y in x for y in ["timing_model", "ecorr"])
],
)
)
# check for any duplicate groups
# e.g. the GWB may have different indices in model 1 and model 2
for group in groups:
Expand All @@ -127,12 +160,39 @@ def get_parameter_groups(self):
unique_groups.extend([[len(self.param_names) - 1]])
return unique_groups

def initial_sample(self):
def initial_sample(self, tm_params_orig=None, tm_param_dict=None, zero_start=True):
"""
Draw an initial sample from within the hyper-model prior space.
:param tm_params_orig: dictionary of timing model parameter tuples, (val, err)
:param tm_param_dict: a nested dictionary of parameters to vary in the model and their user defined values and priors
:param zero_start: start all timing parameters at their parfile value (in tm_params_orig), or their refit values (tm_param_dict)
"""

x0 = [np.array(p.sample()).ravel().tolist() for p in self.models[0].params]
if zero_start and tm_params_orig:
x0 = []
for xx, p in enumerate(self.models[0].params):
if "timing" in p.name:
if "DMX" in p.name:
p_name = ("_").join(p.name.split("_")[-2:])
else:
p_name = p.name.split("_")[-1]
if tm_params_orig[p_name][-1] == "normalized":
x0.append([np.double(0.0)])
else:
if p_name in tm_param_dict.keys():
x0.append([np.double(tm_param_dict[p_name]["prior_mu"])])
else:
x0.append([np.double(tm_params_orig[p_name][0])])
elif "dm_model" in p.name:
if "mu" in str(p):
x0.append([float(str(p).split("(")[1].split(",")[0].split("=")[-1])])
else:
x0.append(np.array(p.sample()).ravel().tolist())
else:
x0.append(np.array(p.sample()).ravel().tolist())
else:
x0 = [np.array(p.sample()).ravel().tolist() for p in self.models[0].params]

uniq_params = [str(p) for p in self.models[0].params]

for model in self.models.values():
Expand Down Expand Up @@ -161,8 +221,9 @@ def draw_from_nmodel_prior(self, x, iter, beta):
return q, float(lqxy)

def setup_sampler(self, outdir='chains', resume=False, sample_nmodel=True,
empirical_distr=None, groups=None, human=None,
loglkwargs={}, logpkwargs={}):
empirical_distr=None, groups=None, timing=False, psr=None, human=None,
restrict_mass=True,
loglkwargs=None, logpkwargs=None):
"""
Sets up an instance of PTMCMC sampler.
Expand All @@ -182,6 +243,12 @@ def setup_sampler(self, outdir='chains', resume=False, sample_nmodel=True,
draws from uniform distributions.
"""

if loglkwargs is None:
loglkwargs = {}

if logpkwargs is None:
logpkwargs = {}

# dimension of parameter space
ndim = len(self.param_names)

Expand Down Expand Up @@ -212,7 +279,9 @@ def setup_sampler(self, outdir='chains', resume=False, sample_nmodel=True,
save_runtime_info(self, sampler.outDir, human)

# additional jump proposals
jp = JumpProposal(self, self.snames, empirical_distr=empirical_distr)
jp = JumpProposal(self, self.snames, empirical_distr=empirical_distr,
timing=timing, psr=psr, sampler=sampler,
restrict_mass=restrict_mass)
sampler.jp = jp

# always add draw from prior
Expand Down Expand Up @@ -268,6 +337,16 @@ def setup_sampler(self, outdir='chains', resume=False, sample_nmodel=True,
print('Adding Chromatic GP noise prior draws...\n')
sampler.addProposalToCycle(jp.draw_from_chrom_gp_prior, 10)

# SW prior draw
if "gp_sw" in jp.snames:
print("Adding Solar Wind DM GP prior draws...\n")
sampler.addProposalToCycle(jp.draw_from_dm_sw_prior, 10)

# Chromatic GP noise prior draw
if 'chrom_gp' in self.snames:
print('Adding Chromatic GP noise prior draws...\n')
sampler.addProposalToCycle(jp.draw_from_chrom_gp_prior, 10)

# Ephemeris prior draw
if 'd_jupiter_mass' in self.param_names:
print('Adding ephemeris model prior draws...\n')
Expand Down Expand Up @@ -316,6 +395,32 @@ def setup_sampler(self, outdir='chains', resume=False, sample_nmodel=True,
p in list(self.params)
if 'gw' in str(p)]), 10)

# Non Linear Timing Draws
if "timing_model" in jp.snames:
print("Adding timing model jump proposal...\n")
sampler.addProposalToCycle(jp.draw_from_timing_model, 25)
if "timing_model" in jp.snames:
print("Adding timing model prior draw...\n")
sampler.addProposalToCycle(jp.draw_from_timing_model_prior, 25)

# DM Model Draws
if "dm_model" in jp.snames and len(jp.snames["dm_model"]):
print("Adding dm model prior draw...\n")
sampler.addProposalToCycle(jp.draw_from_signal("dm_model"), 10)

if timing:
if jp.restrict_mass:
# SCAM and AM Draws
# add SCAM
print("Adding SCAM Jump Proposal...\n")
sampler.addProposalToCycle(jp.covarianceJumpProposalSCAM, 20)

# add AM
print("Adding AM Jump Proposal...\n")
sampler.addProposalToCycle(jp.covarianceJumpProposalAM, 20)

# DE does not work well with restricting the pulsar mass

# Model index distribution draw
if sample_nmodel:
if 'nmodel' in self.param_names:
Expand Down

0 comments on commit f4cfa17

Please sign in to comment.