diff --git a/pybop/samplers/annealed_importance.py b/pybop/samplers/annealed_importance.py index 3f0f301e..026412ab 100644 --- a/pybop/samplers/annealed_importance.py +++ b/pybop/samplers/annealed_importance.py @@ -74,7 +74,7 @@ def transition_distribution(self, x, j): j ] * self._log_likelihood(x) - def run(self) -> tuple[float, float, float, float]: + def run(self) -> tuple[float, float, float]: """ Run the annealed importance sampling algorithm. @@ -85,6 +85,8 @@ def run(self) -> tuple[float, float, float, float]: ValueError: If starting position has non-finite log-likelihood """ log_w = np.zeros(self._chains) + I = np.zeros(self._chains) + samples = np.zeros(self._num_beta) for i in range(self._chains): current = self._log_prior.rvs() @@ -119,10 +121,18 @@ def run(self) -> tuple[float, float, float, float]: current = proposed current_f = proposed_f + samples[j] = current + # Sum for weights (eqn.24) log_w[i] = ( np.sum(log_density_current - log_density_previous) / self._num_beta ) - # Return moments across chains - return np.mean(log_w), np.median(log_w), np.std(log_w), np.var(log_w) + # Compute integral using weights and samples + I[i] = np.mean( + self._log_likelihood(samples) + * np.exp((log_density_current - log_density_previous) / self._num_beta) + ) + + # Return log weights, integral, samples + return log_w, I, samples diff --git a/tests/unit/test_annealed_importance.py b/tests/unit/test_annealed_importance.py index 43954dd0..08287fb5 100644 --- a/tests/unit/test_annealed_importance.py +++ b/tests/unit/test_annealed_importance.py @@ -20,9 +20,9 @@ def scaled_likelihood(x): # Sample sampler = pybop.AnnealedImportanceSampler( - scaled_likelihood, prior, chains=15, num_beta=200, cov0=np.eye(1) * 5e-2 + scaled_likelihood, prior, chains=15, num_beta=500, cov0=np.eye(1) * 1e-2 ) - mean, median, std, var = sampler.run() + log_w, I, samples = sampler.run() # Assertions to be added - print(f"mean: {mean}, std: {std}, median: {median}, var: {var}") + print(f"Integral: {np.mean(I)}, std: {np.std(I)}")