Skip to content

Commit

Permalink
feat(models.py): add nonlinear timing to models
Browse files Browse the repository at this point in the history
Co-authored-by: Andrew Kaiser <[email protected]>
  • Loading branch information
2 people authored and davecwright3 committed Aug 20, 2024
1 parent 40cfcf0 commit 1061ed2
Showing 1 changed file with 121 additions and 21 deletions.
142 changes: 121 additions & 21 deletions enterprise_extensions/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,21 @@
dm_noise_block, red_noise_block,
white_noise_block)
from enterprise_extensions.chromatic.solar_wind import solar_wind_block
from enterprise_extensions.timing import timing_block
from enterprise_extensions.timing import timing_block, dm_block

# from enterprise.signals.signal_base import LookupLikelihood


def model_singlepsr_noise(psr, tm_var=False, tm_linear=False,
tmparam_list=None,
tm_param_list=None,
ltm_list=None,
tm_param_dict=None,
tm_prior="uniform",
tm_prior_sigma=2.0,
normalize_prior_bound=5.0,
fit_remaining_pars=True,
dm_quadfit_prior='normal', dm_quadfit_sigma=2.0,
dm_quadfit_prior_bound=5.0, dmepoch=None,
red_var=True, psd='powerlaw', red_select=None,
noisedict=None, tm_svd=False, tm_norm=True,
white_vary=True, components=30, upper_limit=False,
Expand Down Expand Up @@ -65,7 +73,17 @@ def model_singlepsr_noise(psr, tm_var=False, tm_linear=False,
:param psr: enterprise pulsar object
:param tm_var: explicitly vary the timing model parameters
:param tm_linear: vary the timing model in the linear approximation
:param tmparam_list: an explicit list of timing model parameters to vary
:param tm_param_list: an explicit list of timing model parameters to vary
:param ltm_list: a list of parameters that will linearly varied, default is to vary anything not in tm_param_list
:param tm_param_dict: a nested dictionary of parameters to vary in the model and their user defined values and priors
:param tm_prior: prior type on varied timing model parameters {'uniform', 'normal', bounded-normal'}
:param normalize_prior_bound: scaling value for parameter errors that sets the upper and lower bounds on the nonlinear timing model priors (e.g. Uniform(-5.,5.) as default)
:param fit_remaining_pars: boolean to switch combined non-linear + linear timing models on, only works for tm_var True
:param dm_quadfit_prior: the function used for the priors ['uniform', 'normal', 'bounded-normal']
:param dm_quadfit_sigma: the sigma for the prior if ``prior_type`` is 'bounded-normal'
:param dm_quadfit_prior_bound: scaling value for parameter errors that sets the upper
and lower bounds on the quadratic DM model priors. Only used when ``dm_quadfit_prior`` is 'uniform' or 'bounded-normal', not used as default.
:param dmepoch: the reference epoch for DM [days]
:param red_var: include red noise in the model
:param psd: red noise psd model
:param noisedict: dictionary of noise parameters
Expand Down Expand Up @@ -152,9 +170,42 @@ def model_singlepsr_noise(psr, tm_var=False, tm_linear=False,
:return s: single pulsar noise model
"""

if tm_param_list is None:
tm_param_list = []

if ltm_list is None:
ltm_list = []

if tm_param_dict is None:
tm_param_dict = {}

amp_prior = 'uniform' if upper_limit else 'log-uniform'

# timing model
wideband_kwargs = {}
if is_wideband and use_dmdata:
if dmjump_var:
wideband_kwargs["dmjump"] = parameter.Uniform(pmin=-0.005, pmax=0.005)
else:
wideband_kwargs["dmjump"] = parameter.Constant()
if white_vary:
wideband_kwargs["dmefac"] = parameter.Uniform(pmin=0.1, pmax=10.0)
wideband_kwargs["log10_dmequad"] = parameter.Uniform(pmin=-7.0, pmax=0.0)
# dmjump = parameter.Uniform(pmin=-0.005, pmax=0.005)
else:
wideband_kwargs["dmefac"] = parameter.Constant()
wideband_kwargs["log10_dmequad"] = parameter.Constant()
# dmjump = parameter.Constant()
wideband_kwargs["dmefac_selection"] = selections.Selection(
selections.by_backend
)
wideband_kwargs["log10_dmequad_selection"] = selections.Selection(
selections.by_backend
)
wideband_kwargs["dmjump_selection"] = selections.Selection(
selections.by_frontend
)
if not tm_var:
if (is_wideband and use_dmdata):
if dmjump_var:
Expand Down Expand Up @@ -185,15 +236,26 @@ def model_singlepsr_noise(psr, tm_var=False, tm_linear=False,
s = gp_signals.TimingModel(use_svd=tm_svd, normed=tm_norm,
coefficients=coefficients)
else:
# create new attribute for enterprise pulsar object
psr.tmparams_orig = OrderedDict.fromkeys(psr.t2pulsar.pars())
for key in psr.tmparams_orig:
psr.tmparams_orig[key] = (psr.t2pulsar[key].val,
psr.t2pulsar[key].err)
if not tm_linear:
s = timing_block(tmparam_list=tmparam_list)
else:
pass
if tm_linear:
# create new attribute for enterprise pulsar object
# UNSURE IF NECESSARY
psr.tm_params_orig = OrderedDict.fromkeys(psr.t2pulsar.pars())
for key in psr.tm_params_orig:
psr.tm_params_orig[key] = (psr.t2pulsar[key].val, psr.t2pulsar[key].err)
s = gp_signals.TimingModel(use_svd=tm_svd, normed=tm_norm, coefficients=coefficients)
else:
s = timing_block(
psr,
tm_param_list=tm_param_list,
ltm_list=ltm_list,
prior_type=tm_prior,
prior_sigma=tm_prior_sigma,
prior_lower_bound=-normalize_prior_bound,
prior_upper_bound=normalize_prior_bound,
tm_param_dict=tm_param_dict,
fit_remaining_pars=fit_remaining_pars,
wideband_kwargs=wideband_kwargs,
)

# red noise and common process
if factorized_like:
Expand Down Expand Up @@ -313,6 +375,11 @@ def model_singlepsr_noise(psr, tm_var=False, tm_linear=False,
swgp_prior=swgp_prior, swgp_basis=swgp_basis,
Tspan=Tspan)

if tm_var and not tm_linear:
s += dm_block(psr, dmepoch=dmepoch, prior_type=dm_quadfit_prior, prior_sigma=dm_quadfit_sigma,
prior_lower_bound=dm_quadfit_prior_bound, prior_upper_bound=dm_quadfit_prior_bound,
dmx_data=dmx_data)

if extra_sigs is not None:
s += extra_sigs

Expand Down Expand Up @@ -616,7 +683,8 @@ def model_2a(psrs, psd='powerlaw', noisedict=None, components=30,
return pta


def model_general(psrs, tm_var=False, tm_linear=False, tmparam_list=None,
def model_general(psrs, tm_var=False, tm_linear=False, tm_param_list=None, ltm_list=None,
tm_param_dict=None, tm_prior="uniform", normalize_prior_bound=5.0, fit_remaining_pars=True,
tm_svd=False, tm_norm=True, noisedict=None, white_vary=False,
Tspan=None, modes=None, wgts=None, logfreq=False, nmodes_log=10,
common_psd='powerlaw', common_components=30, tnequad=False,
Expand All @@ -640,7 +708,7 @@ def model_general(psrs, tm_var=False, tm_linear=False, tmparam_list=None,
[default = False]
:param tm_linear: boolean to vary timing model under linear approximation.
[default = False]
:param tmparam_list: list of timing model parameters to vary.
:param tm_param_list: list of timing model parameters to vary.
[default = None]
:param tm_svd: stabilize timing model designmatrix with SVD.
[default = False]
Expand Down Expand Up @@ -773,6 +841,15 @@ def model_general(psrs, tm_var=False, tm_linear=False, tmparam_list=None,
30 sampling frequencies. (global)
"""

if tm_param_list is None:
tm_param_list = []

if ltm_list is None:
ltm_list = []

if tm_param_dict is None:
tm_param_dict = {}

amp_prior = 'uniform' if upper_limit else 'log-uniform'
gp_priors = [upper_limit_red, upper_limit_dm, upper_limit_common]
if all(ii is None for ii in gp_priors):
Expand Down Expand Up @@ -808,14 +885,37 @@ def model_general(psrs, tm_var=False, tm_linear=False, tmparam_list=None,
else:
# create new attribute for enterprise pulsar object
for p in psrs:
p.tmparams_orig = OrderedDict.fromkeys(p.t2pulsar.pars())
for key in p.tmparams_orig:
p.tmparams_orig[key] = (p.t2pulsar[key].val,
p.t2pulsar[key].err)
p.tm_params_orig = OrderedDict.fromkeys(p.t2pulsar.pars())
for key in p.tm_params_orig:
p.tm_params_orig[key] = (p.t2pulsar[key].val, p.t2pulsar[key].err)
if not tm_linear:
s = timing_block(tmparam_list=tmparam_list)
else:
pass
s = timing_block(tm_param_list=tm_param_list)
else:
for i, p in enumerate(psrs):
if i == 0:
s = timing_block(
psrs,
tm_param_list=tm_param_list,
ltm_list=ltm_list,
prior_type=tm_prior,
prior_sigma=2.0,
prior_lower_bound=-5.0,
prior_upper_bound=5.0,
tm_param_dict=tm_param_dict,
fit_remaining_pars=fit_remaining_pars,
)
else:
s += timing_block(
psrs,
tm_param_list=tm_param_list,
ltm_list=ltm_list,
prior_type=tm_prior,
prior_sigma=2.0,
prior_lower_bound=-5.0,
prior_upper_bound=5.0,
tm_param_dict=tm_param_dict,
fit_remaining_pars=fit_remaining_pars,
)

# find the maximum time span to set GW frequency sampling
if Tspan is not None:
Expand Down

0 comments on commit 1061ed2

Please sign in to comment.