Merge pull request #9 from normal-computing/ekf
EKF (diagonal)
SamDuffield authored Feb 1, 2024
2 parents a08c55d + 0012b63 commit 2784e73
Showing 14 changed files with 914 additions and 27 deletions.
337 changes: 337 additions & 0 deletions examples/yelp/yelp_subspace_ekf_diag_hessian.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion examples/yelp/yelp_subspace_vi_diag.ipynb
@@ -242,7 +242,7 @@
}
],
"source": [
"# Visualize the standard deviations of the Laplace approximation\n",
"# Visualize the standard deviations of the final Normal distribution\n",
"sd_diag = torch.cat([v.exp().detach().cpu().flatten() for v in vi_state.log_sd_diag.values()]).numpy()\n",
"\n",
"plt.hist(sd_diag, bins=100, density=True);"
Empty file added tests/ekf/__init__.py
Empty file.
71 changes: 71 additions & 0 deletions tests/ekf/test_diag_fisher.py
@@ -0,0 +1,71 @@
from functools import partial
from typing import Any
import torch
from optree import tree_map

from uqlib import ekf
from uqlib.utils import diag_normal_log_prob


def batch_normal_log_prob(
p: dict, batch: Any, mean: dict, sd_diag: dict
) -> torch.Tensor:
return diag_normal_log_prob(p, mean, sd_diag)


def test_ekf_diag():
torch.manual_seed(42)
target_mean = {"a": torch.randn(2, 1), "b": torch.randn(1, 1)}
target_sds = tree_map(lambda x: torch.randn_like(x).abs(), target_mean)

batch_normal_log_prob_spec = partial(
batch_normal_log_prob, mean=target_mean, sd_diag=target_sds
)

init_mean = tree_map(lambda x: torch.zeros_like(x, requires_grad=True), target_mean)

batch = torch.arange(3).reshape(-1, 1)

n_steps = 1000
transform = ekf.diag_fisher.build(batch_normal_log_prob_spec, lr=1e-3)

state = transform.init(init_mean)

log_liks = []

for _ in range(n_steps):
state = transform.update(state, batch)
log_liks.append(state.log_likelihood)

for key in state.mean:
assert torch.allclose(state.mean[key], target_mean[key], atol=1e-1)

# Test inplace
state_ip = transform.init(init_mean)
state_ip2 = transform.update(
state_ip,
batch,
inplace=True,
)

for key in state_ip2.mean:
assert torch.allclose(state_ip2.mean[key], state_ip.mean[key], atol=1e-8)
assert torch.allclose(state_ip2.sd_diag[key], state_ip.sd_diag[key], atol=1e-8)

# Test not inplace
state_ip_false = transform.init(
tree_map(lambda x: torch.zeros_like(x, requires_grad=True), target_mean)
)
state_ip_false2 = transform.update(
state_ip_false,
batch,
inplace=False,
)

for key in state_ip.mean:
assert not torch.allclose(
state_ip_false2.mean[key], state_ip_false.mean[key], atol=1e-8
)
assert not torch.allclose(
state_ip_false2.sd_diag[key], state_ip_false.sd_diag[key], atol=1e-8
)
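The test above uses the default per_sample=False, so the scalar log-likelihood is vmapped over the batch inside the EKF update. Below is a hedged sketch, not part of this diff, of a log-likelihood written directly per sample so that per_sample=True could be passed to build and the internal torch.func.vmap skipped; the linear-Gaussian model and the names per_sample_log_likelihood, "x", "y" and "w" are illustrative assumptions:

import torch

def per_sample_log_likelihood(params: dict, batch: dict) -> torch.Tensor:
    # Returns one log-likelihood value per row of the batch, shape (batch_size,)
    preds = batch["x"] @ params["w"]
    resid = batch["y"] - preds
    return -0.5 * (resid ** 2).sum(dim=-1)

# Hypothetical usage:
# transform = ekf.diag_fisher.build(per_sample_log_likelihood, lr=1e-3, per_sample=True)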
71 changes: 71 additions & 0 deletions tests/ekf/test_diag_hessian.py
@@ -0,0 +1,71 @@
from functools import partial
from typing import Any
import torch
from optree import tree_map

from uqlib import ekf
from uqlib.utils import diag_normal_log_prob


def batch_normal_log_prob(
p: dict, batch: Any, mean: dict, sd_diag: dict
) -> torch.Tensor:
return diag_normal_log_prob(p, mean, sd_diag)


def test_ekf_diag():
torch.manual_seed(42)
target_mean = {"a": torch.randn(2, 1), "b": torch.randn(1, 1)}
target_sds = tree_map(lambda x: torch.randn_like(x).abs(), target_mean)

batch_normal_log_prob_spec = partial(
batch_normal_log_prob, mean=target_mean, sd_diag=target_sds
)

init_mean = tree_map(lambda x: torch.zeros_like(x, requires_grad=True), target_mean)

batch = torch.arange(3).reshape(-1, 1)

n_steps = 1000
transform = ekf.diag_hessian.build(batch_normal_log_prob_spec, lr=1e-3)

state = transform.init(init_mean)

log_liks = []

for _ in range(n_steps):
state = transform.update(state, batch)
log_liks.append(state.log_likelihood)

for key in state.mean:
assert torch.allclose(state.mean[key], target_mean[key], atol=1e-1)

# Test inplace
state_ip = transform.init(init_mean)
state_ip2 = transform.update(
state_ip,
batch,
inplace=True,
)

for key in state_ip2.mean:
assert torch.allclose(state_ip2.mean[key], state_ip.mean[key], atol=1e-8)
assert torch.allclose(state_ip2.sd_diag[key], state_ip.sd_diag[key], atol=1e-8)

# Test not inplace
state_ip_false = transform.init(
tree_map(lambda x: torch.zeros_like(x, requires_grad=True), target_mean)
)
state_ip_false2 = transform.update(
state_ip_false,
batch,
inplace=False,
)

for key in state_ip.mean:
assert not torch.allclose(
state_ip_false2.mean[key], state_ip_false.mean[key], atol=1e-8
)
assert not torch.allclose(
state_ip_false2.sd_diag[key], state_ip_false.sd_diag[key], atol=1e-8
)
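uqlib/ekf/diag_hessian.py itself is not rendered in this view; it presumably differs from diag_fisher.py by using the diagonal of the log-likelihood Hessian in place of the squared-gradient empirical Fisher. A minimal sketch, assuming a flat parameter tensor rather than the parameter pytrees the library operates on, of how such a diagonal could be obtained with torch.func:

import torch
from torch.func import hessian

def log_likelihood(theta: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    # Toy Gaussian log-likelihood with unit variance, illustrative only
    return -0.5 * ((batch - theta.sum()) ** 2).sum()

theta = torch.zeros(3)
batch = torch.randn(5)

full_hessian = hessian(log_likelihood)(theta, batch)  # shape (3, 3)
diag_hess = torch.diagonal(full_hessian)  # shape (3,), plays the role of F_d(μ)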
Empty file added tests/laplace/__init__.py
Empty file.
5 changes: 4 additions & 1 deletion uqlib/__init__.py
@@ -1,7 +1,8 @@
from uqlib import ekf
from uqlib import laplace
from uqlib import vi
from uqlib import sgmcmc
from uqlib import types
from uqlib import vi

from uqlib.utils import model_to_function
from uqlib.utils import hvp
@@ -16,3 +17,5 @@
from uqlib.utils import insert_requires_grad_
from uqlib.utils import extract_requires_grad_and_func
from uqlib.utils import inplacify
from uqlib.utils import tree_map_inplacify_
from uqlib.utils import flexi_tree_map
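flexi_tree_map, exported here and used throughout uqlib/ekf/diag_fisher.py below, is not shown in this view. The following is a hypothetical sketch of the behaviour its call sites suggest, i.e. a tree_map that can optionally write results back into the leaves of the first tree; the real implementation in uqlib/utils may differ:

from optree import tree_map

def flexi_tree_map_sketch(fn, tree, *rest, inplace=False):
    if not inplace:
        return tree_map(fn, tree, *rest)

    def write_back(leaf, *other_leaves):
        # In-place variant: overwrite the first tree's leaf storage with fn's result
        leaf.data = fn(leaf, *other_leaves)
        return leaf

    return tree_map(write_back, tree, *rest)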
2 changes: 2 additions & 0 deletions uqlib/ekf/__init__.py
@@ -0,0 +1,2 @@
from uqlib.ekf import diag_fisher
from uqlib.ekf import diag_hessian
170 changes: 170 additions & 0 deletions uqlib/ekf/diag_fisher.py
@@ -0,0 +1,170 @@
from typing import Callable, Any, NamedTuple
from functools import partial
import torch
from torch.func import vmap, jacrev
from optree import tree_map

from uqlib.types import TensorTree, Transform
from uqlib.utils import diag_normal_sample, flexi_tree_map


class EKFDiagState(NamedTuple):
"""State encoding a diagonal Normal distribution over parameters.
Args:
mean: Mean of the Normal distribution.
sd_diag: Square-root diagonal of the covariance matrix of the
Normal distribution.
log_likelihood: Log likelihood of the data given the parameters.
"""

mean: TensorTree
sd_diag: TensorTree
log_likelihood: float = 0


def init(
params: TensorTree,
init_sds: TensorTree | None = None,
) -> EKFDiagState:
"""Initialise diagonal Normal distribution over parameters.
Args:
params: Initial mean of the variational distribution.
init_sds: Initial square-root diagonal of the covariance matrix
of the variational distribution. Defaults to ones.
Returns:
Initial EKFDiagState.
"""
if init_sds is None:
init_sds = tree_map(
lambda x: torch.ones_like(x, requires_grad=x.requires_grad), params
)

return EKFDiagState(params, init_sds)


def update(
state: EKFDiagState,
batch: Any,
log_likelihood: Callable[[TensorTree, Any], float],
lr: float,
transition_sd: float = 0.0,
per_sample: bool = False,
inplace: bool = True,
) -> EKFDiagState:
"""Applies an extended Kalman Filter update to the diagonal Normal distribution.
The update is first order in the sense that only gradients of the log-likelihood
are used; the log-likelihood is locally approximated by
log p(y | x, p) ≈ log p(y | x, μ) + lr * g(μ)ᵀ(p - μ) + lr * 1/2 (p - μ)ᵀ F_d(μ) (p - μ)
where μ is the mean of the variational distribution, lr is the learning rate
(likelihood inverse temperature), g(μ) is the gradient and F_d(μ) the
negative diagonal empirical Fisher of the log-likelihood with respect to the
parameters.
Args:
state: Current state.
batch: Input data to log_likelihood.
log_likelihood: Function that takes parameters and input batch and
returns the log-likelihood.
lr: Inverse temperature of the update, which behaves like a learning rate;
see https://arxiv.org/abs/1703.00209 for details.
transition_sd: Standard deviation of the transition noise, used to additively
inflate the diagonal covariance before the update. Defaults to zero.
per_sample: If True, log_likelihood is assumed to return a vector of
log likelihoods, one for each sample in the batch. If False, log_likelihood
is assumed to return a scalar log likelihood for the whole batch, in which
case torch.func.vmap will be called; this is typically slower than
writing log_likelihood to be per sample directly.
inplace: Whether to update the state parameters in-place.
Returns:
Updated EKFDiagState.
"""

if per_sample:
log_likelihood_per_sample = log_likelihood
else:
# per-sample gradients following https://pytorch.org/tutorials/intermediate/per_sample_grads.html
@partial(vmap, in_dims=(None, 0))
def log_likelihood_per_sample(params, batch):
batch = tree_map(lambda x: x.unsqueeze(0), batch)
return log_likelihood(params, batch)

predict_sd_diag = flexi_tree_map(
lambda x: (x**2 + transition_sd**2) ** 0.5, state.sd_diag, inplace=inplace
)
with torch.no_grad():
log_lik = log_likelihood_per_sample(state.mean, batch).mean()
jac = jacrev(log_likelihood_per_sample)(state.mean, batch)
grad = tree_map(lambda x: x.mean(0), jac)
diag_lik_hessian_approx = tree_map(lambda x: -(x**2).mean(0), jac)

update_sd_diag = flexi_tree_map(
lambda sig, h: (sig**-2 - lr * h) ** -0.5,
predict_sd_diag,
diag_lik_hessian_approx,
inplace=inplace,
)
update_mean = flexi_tree_map(
lambda mu, sig, g: mu + sig**2 * lr * g,
state.mean,
update_sd_diag,
grad,
inplace=inplace,
)
return EKFDiagState(update_mean, update_sd_diag, log_lik.item())


def build(
log_likelihood: Callable[[TensorTree, Any], float],
lr: float,
transition_sd: float = 0.0,
per_sample: bool = False,
init_sds: TensorTree | None = None,
) -> Transform:
"""Builds a transform for variational inference with a diagonal Normal
distribution over parameters.
Args:
log_likelihood: Function that takes parameters and input batch and
returns the log-likelihood.
lr: Inverse temperature of the update, which behaves like a learning rate;
see https://arxiv.org/abs/1703.00209 for details.
transition_sd: Standard deviation of the transition noise, used to additively
inflate the diagonal covariance before the update. Defaults to zero.
per_sample: If True, log_likelihood is assumed to return a vector of
log likelihoods, one for each sample in the batch. If False, log_likelihood
is assumed to return a scalar log likelihood for the whole batch, in which
case torch.func.vmap will be called; this is typically slower than
writing log_likelihood to be per sample directly.
init_sds: Initial square-root diagonal of the covariance matrix
of the variational distribution. Defaults to ones.
Returns:
Diagonal EKF transform (uqlib.types.Transform instance).
"""
init_fn = partial(init, init_sds=init_sds)
update_fn = partial(
update,
log_likelihood=log_likelihood,
lr=lr,
transition_sd=transition_sd,
per_sample=per_sample,
)
return Transform(init_fn, update_fn)


def sample(state: EKFDiagState, sample_shape: torch.Size = torch.Size([])):
"""Single sample from diagonal Normal distribution over parameters.
Args:
state: State encoding mean and standard deviations.
Returns:
Sample from Normal distribution.
"""
return diag_normal_sample(state.mean, state.sd_diag, sample_shape=sample_shape)
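A minimal end-to-end sketch, not part of the diff, assembling the build/init/update/sample functions above into the intended workflow; the toy diagonal-Normal target, hyperparameters and number of steps mirror the tests earlier in this PR and are illustrative only:

import torch
from optree import tree_map

from uqlib import ekf
from uqlib.utils import diag_normal_log_prob

target_mean = {"a": torch.randn(2, 1), "b": torch.randn(1, 1)}
target_sds = tree_map(lambda x: torch.randn_like(x).abs(), target_mean)

def log_likelihood(params, batch):
    # The batch is ignored in this toy example; the target is a fixed diagonal Normal
    return diag_normal_log_prob(params, target_mean, target_sds)

params = tree_map(lambda x: torch.zeros_like(x, requires_grad=True), target_mean)
batch = torch.arange(3).reshape(-1, 1)

transform = ekf.diag_fisher.build(log_likelihood, lr=1e-3)
state = transform.init(params)

for _ in range(1000):
    state = transform.update(state, batch)

# Draw 100 samples from the fitted diagonal Normal over parameters
samples = ekf.diag_fisher.sample(state, sample_shape=torch.Size((100,)))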