Adds Annealed Importance Sampling #550

Draft · wants to merge 6 commits into base: develop
74 changes: 74 additions & 0 deletions examples/scripts/importance_sampling.py
@@ -0,0 +1,74 @@
import numpy as np
import pybamm

import pybop

# Parameter set and model definition
solver = pybamm.IDAKLUSolver()
parameter_set = pybop.ParameterSet.pybamm("Chen2020")
parameter_set.update(
    {
        "Negative electrode active material volume fraction": 0.66,
        "Positive electrode active material volume fraction": 0.68,
    }
)
synth_model = pybop.lithium_ion.DFN(parameter_set=parameter_set, solver=solver)

# Fitting parameters
parameters = pybop.Parameters(
    pybop.Parameter(
        "Negative electrode active material volume fraction",
        prior=pybop.Gaussian(0.6, 0.02),
    ),
    pybop.Parameter(
        "Positive electrode active material volume fraction",
        prior=pybop.Gaussian(0.6, 0.02),
    ),
)

# Generate data
init_soc = 0.5
sigma = 0.002


def noise(sigma):
    return np.random.normal(0, sigma, len(values["Voltage [V]"].data))


experiment = pybop.Experiment(
    [
        (
            "Discharge at 0.5C for 1 minutes (5 second period)",
            "Charge at 0.5C for 1 minutes (5 second period)",
        ),
    ]
)
values = synth_model.predict(
    initial_state={"Initial SoC": init_soc}, experiment=experiment
)

# Form dataset
dataset = pybop.Dataset(
    {
        "Time [s]": values["Time [s]"].data,
        "Current function [A]": values["Current [A]"].data,
        "Voltage [V]": values["Voltage [V]"].data + noise(sigma),
    }
)

# Generate problem, likelihood, and sampler
model = pybop.lithium_ion.SPMe(
    parameter_set=parameter_set, solver=pybamm.IDAKLUSolver()
)
model.build(initial_state={"Initial SoC": init_soc})
problem = pybop.FittingProblem(model, parameters, dataset)
likelihood = pybop.GaussianLogLikelihoodKnownSigma(problem, sigma0=sigma)
# posterior = pybop.LogPosterior(likelihood)
prior = pybop.JointLogPrior(*parameters.priors())

sampler = pybop.AnnealedImportanceSampler(
    likelihood, prior, chains=100, num_beta=100, cov0=np.eye(2) * 4e-4
)
log_w, I, samples = sampler.run()

print(f"Log weights (mean): {np.mean(log_w)}, integral estimate (mean): {np.mean(I)}")
1 change: 1 addition & 0 deletions pybop/__init__.py
@@ -161,6 +161,7 @@
    SliceRankShrinkingMCMC, SliceStepoutMCMC,
)
from .samplers.mcmc_sampler import MCMCSampler
from .samplers.annealed_importance import AnnealedImportanceSampler

#
# Observer classes
35 changes: 35 additions & 0 deletions pybop/parameters/priors.py
@@ -430,6 +430,41 @@

        return output, doutput

    def rvs(self, size=1, random_state=None):
        """
        Generates random variates from the joint distribution.

        Parameters
        ----------
        size : int
            The number of random variates to generate.
        random_state : int, optional
            The random state seed for reproducibility. Default is None.

        Returns
        -------
        array_like
            An array of random variates from the distribution.

        Raises
        ------
        ValueError
            If size is not a positive integer or a tuple of positive integers.
        """
        if not isinstance(size, (int, tuple)):
            raise ValueError(
                "size must be a positive integer or tuple of positive integers"
            )
        if isinstance(size, int) and size < 1:
            raise ValueError("size must be a positive integer")
        if isinstance(size, tuple) and any(s < 1 for s in size):
            raise ValueError("size must be a tuple of positive integers")

        samples = []
        for prior in self._priors:
            samples.append(prior.rvs(size=size, random_state=random_state)[0])
        return samples

    def __repr__(self) -> str:
        priors_repr = ", ".join([repr(prior) for prior in self._priors])
        return f"{self.__class__.__name__}(priors: [{priors_repr}])"
138 changes: 138 additions & 0 deletions pybop/samplers/annealed_importance.py
@@ -0,0 +1,138 @@
from typing import Optional

import numpy as np

from pybop import BaseLikelihood, BasePrior


class AnnealedImportanceSampler:
    """
    This class implements annealed importance sampling of
    the posterior distribution to compute the model evidence,
    as introduced in [1].

    [1] "Annealed Importance Sampling", Radford M. Neal, 1998, Technical Report
    No. 9805.
    """

    def __init__(
        self,
        log_likelihood: BaseLikelihood,
        log_prior: BasePrior,
        cov0=None,
        num_beta: int = 30,
        chains: Optional[int] = None,
    ):
        self._log_likelihood = log_likelihood
        self._log_prior = log_prior
        self._cov0 = (
            cov0 if cov0 is not None else np.eye(log_likelihood.n_parameters) * 0.1
        )

        # Total number of chains (independent annealing runs)
        self._chains = (
            chains if chains is not None else log_likelihood.n_parameters * 300
        )

        # Number of beta divisions to consider:
        # 0 = beta_n < beta_n-1 < ... < beta_0 = 1
        self.set_num_beta(num_beta)

    @property
    def chains(self) -> int:
        """Returns the total number of chains."""
        return self._chains

    @chains.setter
    def chains(self, value: int) -> None:
        """Sets the total number of chains."""
        if not isinstance(value, (int, np.integer)):
            raise TypeError("chains must be an integer")
        if value <= 0:
            raise ValueError("chains must be positive")
        self._chains = int(value)

    @property
    def num_beta(self) -> int:
        """Returns the number of beta points."""
        return self._num_beta

    def set_num_beta(self, num_beta: int) -> None:
        """Sets the number of beta point values."""
        if not isinstance(num_beta, (int, np.integer)):
            raise TypeError("num_beta must be an integer")
        if num_beta <= 1:
            raise ValueError("num_beta must be greater than 1")
        self._num_beta = num_beta
        self._beta = np.linspace(0, 1, num_beta)

    def transition_distribution(self, x, j):
        """
        Transition distribution for each beta value [j] - Eqn 3.
        """
        return (1.0 - self._beta[j]) * self._log_prior(x) + self._beta[
            j
        ] * self._log_likelihood(x)
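    # The method above evaluates the annealed (intermediate) log-density
    #     log f_j(x) = (1 - beta_j) * log prior(x) + beta_j * log likelihood(x),
    # i.e. the geometric path of eqn 3, with beta increasing from 0 to 1 along
    # the schedule, so beta = 0 recovers the prior and beta = 1 the likelihood term.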

    def run(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Run the annealed importance sampling algorithm.

        Returns:
            Tuple containing (log_w, I, samples): the log weights, the
            per-chain integral estimates, and the samples from the final chain.

        Raises:
            ValueError: If a starting position has non-finite log-likelihood
        """
        log_w = np.zeros(self._chains)
        I = np.zeros(self._chains)
        samples = np.zeros(self._num_beta)

        for i in range(self._chains):
            current = self._log_prior.rvs()
            if not np.isfinite(self._log_likelihood(current)):
                raise ValueError("Starting position has non-finite log-likelihood.")

            current_f = self._log_prior(current)

            log_density_current = np.zeros(self._num_beta)
            log_density_current[0] = current_f
            log_density_previous = np.zeros(self._num_beta)
            log_density_previous[0] = current_f
[Review comment (Contributor)]: same here, log_density_previous[0] should be f_{1}(x_0) from eqn 11


            # Main sampling loop
            for j in range(1, self._num_beta):
                # Compute jth transition with current sample
                log_density_current[j] = self.transition_distribution(current, j)
[Review comment (Contributor)]: you are doing f_j(x_j) / f_{j-1}(x_j), should be f_j(x_j) / f_{j+1}(x_j) from eqn 11
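For reference, the eqn 11 weight under discussion is, in Neal's indexing (beta decreasing along the chain, f_n the prior and f_0 the target), a product of ratios of adjacent densities, each evaluated at the same sample (a transcription; note the indexing runs in the reverse direction to the code above):

\[
w^{(i)} = \frac{f_{n-1}(x_{n-1})}{f_n(x_{n-1})}\,
          \frac{f_{n-2}(x_{n-2})}{f_{n-1}(x_{n-2})}\cdots
          \frac{f_1(x_1)}{f_2(x_1)}\,
          \frac{f_0(x_0)}{f_1(x_0)},
\qquad
\mathbb{E}\!\left[w^{(i)}\right] = \frac{Z_0}{Z_n},
\]

so in log space the weight is a sum of log-density differences, each pairing two adjacent temperatures at the same point, which is what the comment above refers to.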


                # Calculate the previous transition with current sample
                log_density_previous[j] = self.transition_distribution(current, j - 1)

                # Generate new sample from current (eqn.4)
                proposed = np.random.multivariate_normal(current, self._cov0)

                # Evaluate proposed sample
                if np.isfinite(self._log_likelihood(proposed)):
                    proposed_f = self.transition_distribution(proposed, j)

                    # Metropolis acceptance
                    acceptance_log_prob = proposed_f - current_f
[Review comment (Contributor)]: for the gaussian test case they use in the paper, a much more complicated transition T_j is used, a sequence of 3 metropolis tests repeated 5-10 times. I'm not sure if all that is necessary however.

                    if np.log(np.random.rand()) < acceptance_log_prob:
                        current = proposed
                        current_f = proposed_f

                samples[j] = current

            # Sum for weights (eqn.24)
            log_w[i] = (
                np.sum(log_density_current - log_density_previous) / self._num_beta
            )

            # Compute integral using weights and samples
            I[i] = np.mean(
                self._log_likelihood(samples)
                * np.exp((log_density_current - log_density_previous) / self._num_beta)
            )

        # Return log weights, integral, samples
        return log_w, I, samples
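On the reviewer comment above about the transition T_j: a rough sketch of a composite kernel, i.e. repeating the Metropolis update several times at each beta, could look like the hypothetical helper below (n_repeats and the fixed proposal covariance are assumptions, not part of the diff):

    def composite_transition(self, current, current_f, j, n_repeats=5):
        """Apply several Metropolis updates at inverse temperature beta[j]."""
        for _ in range(n_repeats):
            proposed = np.random.multivariate_normal(current, self._cov0)
            if not np.isfinite(self._log_likelihood(proposed)):
                continue
            proposed_f = self.transition_distribution(proposed, j)
            # Standard Metropolis accept/reject in log space
            if np.log(np.random.rand()) < proposed_f - current_f:
                current, current_f = proposed, proposed_f
        return current, current_f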
28 changes: 28 additions & 0 deletions tests/unit/test_annealed_importance.py
@@ -0,0 +1,28 @@
import numpy as np
import pytest

import pybop


class TestAnnealedImportanceSampler:
    """
    Class for unit tests of AnnealedImportanceSampler.
    """

    @pytest.mark.unit
    def test_annealed_importance_sampler(self):
        likelihood = pybop.Gaussian(5, 0.5)

        def scaled_likelihood(x):
            return likelihood(x) * 2

        prior = pybop.Gaussian(4.7, 2)

        # Sample
        sampler = pybop.AnnealedImportanceSampler(
            scaled_likelihood, prior, chains=15, num_beta=500, cov0=np.eye(1) * 1e-2
        )
        log_w, I, samples = sampler.run()

        # Assertions to be added
        print(f"Integral: {np.mean(I)}, std: {np.std(I)}")
6 changes: 3 additions & 3 deletions tests/unit/test_problem.py
@@ -63,7 +63,7 @@ def dataset(self, model, experiment):

    @pytest.fixture
    def signal(self):
-        return "Voltage [V]"
+        return ["Voltage [V]"]

    @pytest.mark.unit
    def test_base_problem(self, parameters, model, dataset):
@@ -248,8 +248,8 @@ def test_multi_fitting_problem(self, model, parameters, dataset, signal):
            problem_1._dataset["Time [s]"]
        ) + len(problem_2._dataset["Time [s]"])
        assert len(combined_problem._dataset["Combined signal"]) == len(
-            problem_1._dataset[signal]
-        ) + len(problem_2._dataset[signal])
+            problem_1._dataset[signal[0]]
+        ) + len(problem_2._dataset[signal[0]])

        y = combined_problem.evaluate(inputs=[1e-5, 1e-5])
        assert len(y["Combined signal"]) == len(