diff --git a/harmonic/evidence.py b/harmonic/evidence.py index 29523b4..1a6e20f 100644 --- a/harmonic/evidence.py +++ b/harmonic/evidence.py @@ -269,7 +269,7 @@ def add_chains(self, chains): for i, i_samples in enumerate(range(i_samples_start, i_samples_end)): lnpredict = self.model.predict(X[i_samples, :]) - # lnpred[i_samples] = lnpredict + lnpred[i_samples] = lnpredict lnprob = Y[i_samples] lnargs[i_samples] = lnpredict - lnprob @@ -277,8 +277,8 @@ def add_chains(self, chains): if np.isinf(lnargs[i_samples]): lnargs[i_samples] = np.nan - # if np.isinf(lnpred[i_samples]): - # lnpred[i_samples] = np.nan + if np.isinf(lnpred[i_samples]): + lnpred[i_samples] = np.nan # The following performs a shift in log-space to avoid overflow or float # rounding errors in realspace. @@ -304,78 +304,30 @@ def get_nans_per_chain(lnargs, mask): nans_num = jnp.sum(jnp.where(mask, jnp.isnan(lnargs), 0.0)) return nans_num - if self.batch_calculation: - lnargs += self.shift_value + lnargs += self.shift_value - masks = self.get_masks(jnp.array(chains.start_indices)) + masks = self.get_masks(jnp.array(chains.start_indices)) - running_sum_val = jax.vmap(get_running_sum, in_axes=(None, 0))( - lnargs, masks - ) - self.running_sum += running_sum_val + running_sum_val = jax.vmap(get_running_sum, in_axes=(None, 0))(lnargs, masks) + self.running_sum += running_sum_val - # Count added number of samples per chain - added_nsamples_per_chain = np.diff(jnp.array(chains.start_indices)) - self.nsamples_per_chain += added_nsamples_per_chain + # Count added number of samples per chain + added_nsamples_per_chain = np.diff(jnp.array(chains.start_indices)) + self.nsamples_per_chain += added_nsamples_per_chain - # Count number of NaN values per chain and subtract to get effective - # number of added samples per chain - nan_count_per_chain = jax.vmap(get_nans_per_chain, in_axes=(None, 0))( - lnargs, masks - ) - self.nsamples_eff_per_chain += ( - added_nsamples_per_chain - nan_count_per_chain - ) - - self.lnargmax = jnp.nanmax(lnargs) - self.lnargmin = jnp.nanmin(lnargs) - self.lnprobmax = jnp.nanmax(Y) - self.lnprobmin = jnp.nanmin(Y) - self.lnpredictmax = jnp.nanmax(lnpred) - self.lnpredictmin = jnp.nanmin(lnpred) - - else: - for i_chains in range(nchains): - i_samples_start = chains.start_indices[i_chains] - i_samples_end = chains.start_indices[i_chains + 1] - - for i, i_samples in enumerate(range(i_samples_start, i_samples_end)): - # Apply shifting term to avoid overflow. - lnarg = lnargs[i_samples] + self.shift_value - # Store realspace or logspace sum depending on choice. - term = np.exp(lnarg) - nsamples_per_chain[i_chains] += 1 - - if not np.isnan(lnargs[i_samples]): - # Count number of samples used. - nsamples_eff_per_chain[i_chains] += 1 - - # Add contribution to running sum. - running_sum[i_chains] += term - - # Log diagnostic terms. - self.lnargmax = ( - lnarg if lnarg > self.lnargmax else self.lnargmax - ) - self.lnargmin = ( - lnarg if lnarg < self.lnargmin else self.lnargmin - ) - self.lnprobmax = ( - lnprob if lnprob > self.lnprobmax else self.lnprobmax - ) - self.lnprobmin = ( - lnprob if lnprob < self.lnprobmin else self.lnprobmin - ) - self.lnpredictmax = ( - lnpredict - if lnpredict > self.lnpredictmax - else self.lnpredictmax - ) - self.lnpredictmin = ( - lnpredict - if lnpredict < self.lnpredictmin - else self.lnpredictmin - ) + # Count number of NaN values per chain and subtract to get effective + # number of added samples per chain + nan_count_per_chain = jax.vmap(get_nans_per_chain, in_axes=(None, 0))( + lnargs, masks + ) + self.nsamples_eff_per_chain += added_nsamples_per_chain - nan_count_per_chain + + self.lnargmax = jnp.nanmax(lnargs) + self.lnargmin = jnp.nanmin(lnargs) + self.lnprobmax = jnp.nanmax(Y) + self.lnprobmin = jnp.nanmin(Y) + self.lnpredictmax = jnp.nanmax(lnpred) + self.lnpredictmin = jnp.nanmin(lnpred) self.process_run() self.chains_added = True