Skip to content

Commit

Permalink
add docs, try getting validate, NamedTuple
Browse files Browse the repository at this point in the history
  • Loading branch information
AFg6K7h4fhy2 committed Aug 7, 2024
1 parent ccb3925 commit 3200f51
Show file tree
Hide file tree
Showing 3 changed files with 652 additions and 140 deletions.
8 changes: 4 additions & 4 deletions scratch/config/params_2024-01-20.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ weekly_rw_prior_scale = 0.25
first_fitting_date = "2023-09-15"


adapt_delta = 0.85
max_treedepth = 9
adapt_delta = 0.80
max_treedepth = 12
n_chains = 1
n_warmup = 1000
n_iter = 2000
n_warmup = 150
n_iter = 200
283 changes: 283 additions & 0 deletions scratch/paste_bin.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,289 @@ observation: renewal process in progress, alpha ≠ 0, and observed hosps have v
Look to Admission Model for inspiration for sample().


REMOVE undocumented Model process (CFAEPIM_Model)


class CFAEPIM_Model(Model):  # numpydoc ignore=GL08,PR01
    """
    Top-level CFAEPIM renewal model.

    Wires together three sub-components built from a flat
    configuration dictionary: an Rt random-walk process
    (``CFAEPIM_Rt``), an infection process with exponential seeding
    (``CFAEPIM_Infections``), and a negative-binomial hospital
    admissions observation process (``CFAEPIM_Observation``).
    ``sample`` runs one forward pass through all three.
    """

    def __init__(
        self,
        # NOTE(review): annotation previously read ``dict[str, any]``
        # (the builtin ``any`` function); ``object`` is the intended
        # catch-all value type.
        config: dict[str, object],
        population: int,
        week_indices: ArrayLike,
        first_week_hosp: int,
        predictors: list[int],
        data_observed_hosp_admissions: pl.DataFrame,
    ):  # numpydoc ignore=GL08
        # Raw inputs.
        self.population = population
        self.week_indices = week_indices
        self.first_week_hosp = first_week_hosp
        self.predictors = predictors
        self.data_observed_hosp_admissions = data_observed_hosp_admissions

        # Promote every config key to an instance attribute; the
        # attribute reads below (generation_time_dist, max_rt, ...)
        # are therefore required keys of ``config``.
        self.config = config
        for key, value in config.items():
            setattr(self, key, value)

        # transmission: generation time distribution (fixed PMF)
        self.pmf_array = jnp.array(self.generation_time_dist)
        self.gen_int = DeterministicPMF(name="gen_int", value=self.pmf_array)

        # transmission: prior for the random-walk intercept
        # (initial level of the weekly Rt walk)
        self.intercept_RW_prior = dist.Normal(
            self.rt_intercept_prior_mode, self.rt_intercept_prior_scale
        )

        # transmission: Rt process (weekly walk broadcast to days
        # via week_indices)
        self.Rt_process = CFAEPIM_Rt(
            intercept_RW_prior=self.intercept_RW_prior,
            max_rt=self.max_rt,
            gamma_RW_prior_scale=self.weekly_rw_prior_scale,
            week_indices=self.week_indices,
        )

        # infections: mean value for infection seeding (initialization);
        # per-capita prior infections scaled by population, plus the
        # first week's hospitalizations back-scaled by the IHR
        # intercept mode over 7 days.
        self.mean_inf_val = (
            self.inf_model_prior_infections_per_capita * self.population
        ) + (self.first_week_hosp / (self.ihr_intercept_prior_mode * 7))

        # infections: initial infections, drawn i.i.d. Exponential with
        # mean ``mean_inf_val`` for each of ``inf_model_seed_days`` days.
        self.I0 = InfectionInitializationProcess(
            name="I0_initialization",
            I_pre_init_rv=DistributionalRV(
                name="I0",
                dist=dist.Exponential(rate=1 / self.mean_inf_val).expand(
                    [self.inf_model_seed_days]
                ),
            ),
            infection_init_method=InitializeInfectionsFromVec(
                n_timepoints=self.inf_model_seed_days
            ),
            t_unit=1,
        )

        # infections: susceptibility depletion prior
        self.susceptibility_prior = dist.Normal(
            self.susceptible_fraction_prior_mode,
            self.susceptible_fraction_prior_scale,
        )

        # infections component
        self.infections = CFAEPIM_Infections(
            I0=self.I0, susceptibility_prior=self.susceptibility_prior
        )
        # TODO(review): check that, post-instantiation, changing
        # ``susceptibility_prior`` propagates into CFAEPIM_Infections
        # (believed to, but unverified).

        # observations component: NB concentration, IHR intercept, and
        # GLM coefficient priors (day-of-week effects plus holiday,
        # post-holiday, and non-observation effects, concatenated in
        # that order for loc and scale alike).
        self.nb_concentration_prior = dist.Normal(
            self.reciprocal_dispersion_prior_mode,
            self.reciprocal_dispersion_prior_scale,
        )
        self.alpha_prior_dist = dist.Normal(
            self.ihr_intercept_prior_mode, self.ihr_intercept_prior_scale
        )
        self.coefficient_priors = dist.Normal(
            loc=jnp.array(
                self.day_of_week_effect_prior_modes
                + [
                    self.holiday_eff_prior_mode,
                    self.post_holiday_eff_prior_mode,
                    self.non_obs_effect_prior_mode,
                ]
            ),
            scale=jnp.array(
                self.day_of_week_effect_prior_scales
                + [
                    self.holiday_eff_prior_scale,
                    self.post_holiday_eff_prior_scale,
                    self.non_obs_effect_prior_scale,
                ]
            ),
        )
        self.obs_process = CFAEPIM_Observation(
            predictors=self.predictors,
            alpha_prior_dist=self.alpha_prior_dist,
            coefficient_priors=self.coefficient_priors,
            max_rt=self.max_rt,
            nb_concentration_prior=self.nb_concentration_prior,
        )

    @staticmethod
    def validate() -> None:  # numpydoc ignore=GL08
        """No-op validation hook."""
        pass

    def sample(
        self,
        n_steps: int,
        **kwargs,
    ) -> tuple:  # numpydoc ignore=GL08
        """
        Run one forward pass of the full model.

        Samples Rt, latent infections/susceptibles, ascertainment
        rates, expected hospitalizations, and observed admissions
        (conditioned on ``self.data_observed_hosp_admissions``),
        records each as a numpyro deterministic site, and returns them
        bundled in a ``CFAEPIM_Model_Sample``.

        Parameters
        ----------
        n_steps : int
            Number of Rt random-walk steps to sample (appears to be
            weekly — confirm against ``CFAEPIM_Rt``).
        **kwargs
            Forwarded to the negative-binomial observation ``sample``.
        """
        sampled_Rts = self.Rt_process.sample(n_steps=n_steps)
        sampled_gen_int = self.gen_int.sample()
        all_I_t, all_S_t = self.infections.sample(
            Rt=sampled_Rts,
            gen_int=sampled_gen_int[0].value,
            P=self.population,
        )
        sampled_alphas, expected_hosps = self.obs_process.sample(
            infections=all_I_t,
            inf_to_hosp_dist=jnp.array(self.inf_to_hosp_dist),
        )
        observed_hosp_admissions = self.obs_process.nb_observation.sample(
            mu=expected_hosps,
            obs=self.data_observed_hosp_admissions,
            **kwargs,
        )
        # Expose every intermediate quantity as a deterministic site so
        # it appears in posterior draws.
        numpyro.deterministic("Rts", sampled_Rts)
        numpyro.deterministic("latent_infections", all_I_t)
        numpyro.deterministic("susceptibles", all_S_t)
        numpyro.deterministic("alphas", sampled_alphas)
        numpyro.deterministic("expected_hospitalizations", expected_hosps)
        numpyro.deterministic(
            "observed_hospitalizations", observed_hosp_admissions[0].value
        )
        return CFAEPIM_Model_Sample(
            Rts=sampled_Rts,
            latent_infections=all_I_t,
            susceptibles=all_S_t,
            ascertainment_rates=sampled_alphas,
            expected_hospitalizations=expected_hosps,
            observed_hospital_admissions=observed_hosp_admissions[0].value,
        )


REMOVE undocumented Observation process (CFAEPIM_Observation)

class CFAEPIM_Observation(RandomVariable):  # numpydoc ignore=GL08
    """
    Observation component of the CFAEPIM model.

    Couples a sigmoid-linked GLM for the ascertainment rate
    ("alpha") with a negative-binomial observation process for
    hospital admissions.
    """

    def __init__(
        self,
        predictors,
        alpha_prior_dist,
        coefficient_priors,
        max_rt,
        nb_concentration_prior,
    ):  # numpydoc ignore=GL08
        """Store priors/predictors and build both sub-processes."""
        self.predictors = predictors
        self.alpha_prior_dist = alpha_prior_dist
        self.coefficient_priors = coefficient_priors
        self.max_rt = max_rt
        self.nb_concentration_prior = nb_concentration_prior

        # Ascertainment-rate GLM: linear predictor pushed through the
        # inverse sigmoid. NOTE(review): confirm which direction of the
        # link is intended here (which is g and which is g^{-1}).
        self.alpha_process = GLMPrediction(
            name="alpha_t",
            fixed_predictor_values=self.predictors,
            intercept_prior=self.alpha_prior_dist,
            coefficient_priors=self.coefficient_priors,
            transform=t.SigmoidTransform().inv,
        )

        # Negative-binomial admissions observation; its concentration
        # (reciprocal dispersion) is itself a sampled random variable.
        self.nb_observation = NegativeBinomialObservation(
            name="negbinom_rv",
            concentration_rv=DistributionalRV(
                name="nb_concentration",
                dist=self.nb_concentration_prior,
            ),
        )

    @staticmethod
    def validate() -> None:  # numpydoc ignore=GL08
        """No-op validation hook."""
        pass

    def sample(
        self,
        infections: ArrayLike,
        inf_to_hosp_dist: ArrayLike,
        **kwargs,
    ) -> tuple:  # numpydoc ignore=GL08
        """
        Sample ascertainment rates and expected hospitalizations.

        Convolves ``infections`` with the infection-to-hospitalization
        delay PMF, scales by the sampled ascertainment rates, and
        returns ``(alpha_samples, expected_hosp)``, each truncated to
        the length of ``infections``.
        """
        n_days = infections.shape[0]
        rates = self.alpha_process.sample()["prediction"][:n_days]
        delayed_infections = jnp.convolve(
            infections, inf_to_hosp_dist, mode="full"
        )[:n_days]
        return rates, rates * delayed_infections

    # NOTE(review): sampling the NB observation here, i.e.
    # ``self.nb_observation.sample(mu=expected_hosp, **kwargs)``, is
    # deliberately left to the caller — done here unused it would
    # create an unobserved discrete site.


REMOVE undocumented Rt process (CFAEPIM_Rt)

class CFAEPIM_Rt(RandomVariable):  # numpydoc ignore=GL08
    """
    Rt process for the CFAEPIM model: a random walk transformed into
    (0, ``max_rt``) and broadcast to daily resolution via
    ``week_indices``.
    """

    def __init__(
        self,
        intercept_RW_prior: numpyro.distributions,
        max_rt: float,
        gamma_RW_prior_scale: float,
        week_indices: ArrayLike,
    ):  # numpydoc ignore=GL08
        """Store the walk priors and the day-to-week index map."""
        self.intercept_RW_prior = intercept_RW_prior
        self.max_rt = max_rt
        self.gamma_RW_prior_scale = gamma_RW_prior_scale
        self.week_indices = week_indices

    @staticmethod
    def validate() -> None:  # numpydoc ignore=GL08
        """No-op validation hook."""
        pass

    def sample(self, n_steps: int, **kwargs) -> tuple:  # numpydoc ignore=GL08
        """
        Draw ``n_steps`` walk values and broadcast them to days.
        """
        # Scale of the random-walk increments.
        step_sd = numpyro.sample(
            "Wt_rw_sd", dist.HalfNormal(self.gamma_RW_prior_scale)
        )
        # Unbounded walk W_t with a prior on its starting level.
        weekly_walk = SimpleRandomWalkProcess(
            name="Wt",
            step_rv=DistributionalRV(
                name="rw_step_rv",
                dist=dist.Normal(0, step_sd),
                reparam=LocScaleReparam(0),
            ),
            init_rv=DistributionalRV(
                name="init_Wt_rv",
                dist=self.intercept_RW_prior,
            ),
        )
        # Map the walk into (0, max_rt) via the inverse scaled logit.
        rt_rv = TransformedRandomVariable(
            name="transformed_rt_rw",
            base_rv=weekly_walk,
            transforms=t.ScaledLogitTransform(x_max=self.max_rt).inv,
        )
        rt_draws = rt_rv.sample(n_steps=n_steps, **kwargs)
        # Index each day's value by its epiweek to get the daily series.
        return rt_draws[0].value[self.week_indices]


REMOVE various text:

# instantiate MSR-cfaepim model
# simulate data
# run the model for NY
# print summary (print_summary)
# visualize prior predictive (prior_predictive)
# visualize posterior predictive (posterior_predictive)
# spread draws (spread_draws)

# you have a dataset and a configuration file
# you can generate 3 reports, one on the priors, one on
# the model, one on the forecasts; these are intelligently
# looked for and then concatenated; file naming is done
# intelligently.
# you can choose, using argparse, which states are run
# as well as the report date and target end date;
# data is used differently when available
# comparison is done with --historical
#

REMOVE various removals:

# rt_samples = self.Rt_process.sample(n_steps=n_steps, **kwargs)["value"]
Expand Down
Loading

0 comments on commit 3200f51

Please sign in to comment.