Skip to content

Commit

Permalink
Add integral calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
BradyPlanden committed Nov 14, 2024
1 parent 4f45591 commit c5eef05
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
16 changes: 13 additions & 3 deletions pybop/samplers/annealed_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def transition_distribution(self, x, j):
j
] * self._log_likelihood(x)

def run(self) -> tuple[float, float, float, float]:
def run(self) -> tuple[float, float, float]:
"""
Run the annealed importance sampling algorithm.
Expand All @@ -85,6 +85,8 @@ def run(self) -> tuple[float, float, float, float]:
ValueError: If starting position has non-finite log-likelihood
"""
log_w = np.zeros(self._chains)
I = np.zeros(self._chains)
samples = np.zeros(self._num_beta)

for i in range(self._chains):
current = self._log_prior.rvs()
Expand Down Expand Up @@ -119,10 +121,18 @@ def run(self) -> tuple[float, float, float, float]:
current = proposed
current_f = proposed_f

samples[j] = current

# Sum for weights (eqn.24)
log_w[i] = (
np.sum(log_density_current - log_density_previous) / self._num_beta
)

# Return moments across chains
return np.mean(log_w), np.median(log_w), np.std(log_w), np.var(log_w)
# Compute integral using weights and samples
I[i] = np.mean(
self._log_likelihood(samples)
* np.exp((log_density_current - log_density_previous) / self._num_beta)
)

# Return log weights, integral, samples
return log_w, I, samples
6 changes: 3 additions & 3 deletions tests/unit/test_annealed_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ def scaled_likelihood(x):

# Sample
sampler = pybop.AnnealedImportanceSampler(
scaled_likelihood, prior, chains=15, num_beta=200, cov0=np.eye(1) * 5e-2
scaled_likelihood, prior, chains=15, num_beta=500, cov0=np.eye(1) * 1e-2
)
mean, median, std, var = sampler.run()
log_w, I, samples = sampler.run()

# Assertions to be added
print(f"mean: {mean}, std: {std}, median: {median}, var: {var}")
print(f"Integral: {np.mean(I)}, std: {np.std(I)}")

0 comments on commit c5eef05

Please sign in to comment.