Skip to content

Commit

Permalink
Jaxify legacy model evidence and diagnostics.
Browse files — browse the repository at this point in the history
  • Loading branch information
alicjapolanska committed May 1, 2024
1 parent 2a1af73 commit d1b9869
Showing 1 changed file with 23 additions and 71 deletions.
94 changes: 23 additions & 71 deletions harmonic/evidence.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -269,16 +269,16 @@ def add_chains(self, chains):

for i, i_samples in enumerate(range(i_samples_start, i_samples_end)):
lnpredict = self.model.predict(X[i_samples, :])
# lnpred[i_samples] = lnpredict
lnpred[i_samples] = lnpredict

lnprob = Y[i_samples]
lnargs[i_samples] = lnpredict - lnprob

if np.isinf(lnargs[i_samples]):
lnargs[i_samples] = np.nan

# if np.isinf(lnpred[i_samples]):
# lnpred[i_samples] = np.nan
if np.isinf(lnpred[i_samples]):
lnpred[i_samples] = np.nan

# The following performs a shift in log-space to avoid overflow or
# floating-point rounding errors in real space.
Expand All @@ -304,78 +304,30 @@ def get_nans_per_chain(lnargs, mask):
nans_num = jnp.sum(jnp.where(mask, jnp.isnan(lnargs), 0.0))
return nans_num

if self.batch_calculation:
lnargs += self.shift_value
lnargs += self.shift_value

masks = self.get_masks(jnp.array(chains.start_indices))
masks = self.get_masks(jnp.array(chains.start_indices))

running_sum_val = jax.vmap(get_running_sum, in_axes=(None, 0))(
lnargs, masks
)
self.running_sum += running_sum_val
running_sum_val = jax.vmap(get_running_sum, in_axes=(None, 0))(lnargs, masks)
self.running_sum += running_sum_val

# Count added number of samples per chain
added_nsamples_per_chain = np.diff(jnp.array(chains.start_indices))
self.nsamples_per_chain += added_nsamples_per_chain
# Count added number of samples per chain
added_nsamples_per_chain = np.diff(jnp.array(chains.start_indices))
self.nsamples_per_chain += added_nsamples_per_chain

# Count number of NaN values per chain and subtract to get effective
# number of added samples per chain
nan_count_per_chain = jax.vmap(get_nans_per_chain, in_axes=(None, 0))(
lnargs, masks
)
self.nsamples_eff_per_chain += (
added_nsamples_per_chain - nan_count_per_chain
)

self.lnargmax = jnp.nanmax(lnargs)
self.lnargmin = jnp.nanmin(lnargs)
self.lnprobmax = jnp.nanmax(Y)
self.lnprobmin = jnp.nanmin(Y)
self.lnpredictmax = jnp.nanmax(lnpred)
self.lnpredictmin = jnp.nanmin(lnpred)

else:
for i_chains in range(nchains):
i_samples_start = chains.start_indices[i_chains]
i_samples_end = chains.start_indices[i_chains + 1]

for i, i_samples in enumerate(range(i_samples_start, i_samples_end)):
# Apply shifting term to avoid overflow.
lnarg = lnargs[i_samples] + self.shift_value
# Store realspace or logspace sum depending on choice.
term = np.exp(lnarg)
nsamples_per_chain[i_chains] += 1

if not np.isnan(lnargs[i_samples]):
# Count number of samples used.
nsamples_eff_per_chain[i_chains] += 1

# Add contribution to running sum.
running_sum[i_chains] += term

# Log diagnostic terms.
self.lnargmax = (
lnarg if lnarg > self.lnargmax else self.lnargmax
)
self.lnargmin = (
lnarg if lnarg < self.lnargmin else self.lnargmin
)
self.lnprobmax = (
lnprob if lnprob > self.lnprobmax else self.lnprobmax
)
self.lnprobmin = (
lnprob if lnprob < self.lnprobmin else self.lnprobmin
)
self.lnpredictmax = (
lnpredict
if lnpredict > self.lnpredictmax
else self.lnpredictmax
)
self.lnpredictmin = (
lnpredict
if lnpredict < self.lnpredictmin
else self.lnpredictmin
)
# Count number of NaN values per chain and subtract to get effective
# number of added samples per chain
nan_count_per_chain = jax.vmap(get_nans_per_chain, in_axes=(None, 0))(
lnargs, masks
)
self.nsamples_eff_per_chain += added_nsamples_per_chain - nan_count_per_chain

self.lnargmax = jnp.nanmax(lnargs)
self.lnargmin = jnp.nanmin(lnargs)
self.lnprobmax = jnp.nanmax(Y)
self.lnprobmin = jnp.nanmin(Y)
self.lnpredictmax = jnp.nanmax(lnpred)
self.lnpredictmin = jnp.nanmin(lnpred)

self.process_run()
self.chains_added = True
Expand Down

0 comments on commit d1b9869

Please sign in to comment.