Merge pull request #9 from normal-computing/ekf
EKF (diagonal)
Showing 14 changed files with 914 additions and 27 deletions.
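For orientation, here is a minimal usage sketch of the interface this PR adds, pieced together from the tests and module below. The toy target, data, and loop length are hypothetical, not part of the diff:

import torch
from optree import tree_map

from uqlib import ekf
from uqlib.utils import diag_normal_log_prob

# Hypothetical target: a diagonal Normal over a parameter tree.
target_mean = {"w": torch.randn(3)}
target_sd = tree_map(torch.ones_like, target_mean)


def log_lik(params, batch):
    # Toy log-likelihood; ignores the batch contents.
    return diag_normal_log_prob(params, target_mean, target_sd)


transform = ekf.diag_fisher.build(log_lik, lr=1e-2)
state = transform.init(
    tree_map(lambda x: torch.zeros_like(x, requires_grad=True), target_mean)
)

batch = torch.zeros(4, 1)  # dummy batch of four samples
for _ in range(100):
    state = transform.update(state, batch)  # one EKF step per call

samples = ekf.diag_fisher.sample(state, sample_shape=torch.Size([10]))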
Empty file.
@@ -0,0 +1,71 @@
from functools import partial
from typing import Any
import torch
from optree import tree_map

from uqlib import ekf
from uqlib.utils import diag_normal_log_prob


def batch_normal_log_prob(
    p: dict, batch: Any, mean: dict, sd_diag: dict
) -> torch.Tensor:
    return diag_normal_log_prob(p, mean, sd_diag)


def test_ekf_diag():
    torch.manual_seed(42)
    target_mean = {"a": torch.randn(2, 1), "b": torch.randn(1, 1)}
    target_sds = tree_map(lambda x: torch.randn_like(x).abs(), target_mean)

    batch_normal_log_prob_spec = partial(
        batch_normal_log_prob, mean=target_mean, sd_diag=target_sds
    )

    init_mean = tree_map(lambda x: torch.zeros_like(x, requires_grad=True), target_mean)

    batch = torch.arange(3).reshape(-1, 1)

    n_steps = 1000
    transform = ekf.diag_fisher.build(batch_normal_log_prob_spec, lr=1e-3)

    state = transform.init(init_mean)

    log_liks = []

    for _ in range(n_steps):
        state = transform.update(state, batch)
        log_liks.append(state.log_likelihood)

    for key in state.mean:
        assert torch.allclose(state.mean[key], target_mean[key], atol=1e-1)

    # Test inplace
    state_ip = transform.init(init_mean)
    state_ip2 = transform.update(
        state_ip,
        batch,
        inplace=True,
    )

    for key in state_ip2.mean:
        assert torch.allclose(state_ip2.mean[key], state_ip.mean[key], atol=1e-8)
        assert torch.allclose(state_ip2.sd_diag[key], state_ip.sd_diag[key], atol=1e-8)

    # Test not inplace
    state_ip_false = transform.init(
        tree_map(lambda x: torch.zeros_like(x, requires_grad=True), target_mean)
    )
    state_ip_false2 = transform.update(
        state_ip_false,
        batch,
        inplace=False,
    )

    for key in state_ip.mean:
        assert not torch.allclose(
            state_ip_false2.mean[key], state_ip_false.mean[key], atol=1e-8
        )
        assert not torch.allclose(
            state_ip_false2.sd_diag[key], state_ip_false.sd_diag[key], atol=1e-8
        )
@@ -0,0 +1,71 @@
from functools import partial
from typing import Any
import torch
from optree import tree_map

from uqlib import ekf
from uqlib.utils import diag_normal_log_prob


def batch_normal_log_prob(
    p: dict, batch: Any, mean: dict, sd_diag: dict
) -> torch.Tensor:
    return diag_normal_log_prob(p, mean, sd_diag)


def test_ekf_diag():
    torch.manual_seed(42)
    target_mean = {"a": torch.randn(2, 1), "b": torch.randn(1, 1)}
    target_sds = tree_map(lambda x: torch.randn_like(x).abs(), target_mean)

    batch_normal_log_prob_spec = partial(
        batch_normal_log_prob, mean=target_mean, sd_diag=target_sds
    )

    init_mean = tree_map(lambda x: torch.zeros_like(x, requires_grad=True), target_mean)

    batch = torch.arange(3).reshape(-1, 1)

    n_steps = 1000
    transform = ekf.diag_hessian.build(batch_normal_log_prob_spec, lr=1e-3)

    state = transform.init(init_mean)

    log_liks = []

    for _ in range(n_steps):
        state = transform.update(state, batch)
        log_liks.append(state.log_likelihood)

    for key in state.mean:
        assert torch.allclose(state.mean[key], target_mean[key], atol=1e-1)

    # Test inplace
    state_ip = transform.init(init_mean)
    state_ip2 = transform.update(
        state_ip,
        batch,
        inplace=True,
    )

    for key in state_ip2.mean:
        assert torch.allclose(state_ip2.mean[key], state_ip.mean[key], atol=1e-8)
        assert torch.allclose(state_ip2.sd_diag[key], state_ip.sd_diag[key], atol=1e-8)

    # Test not inplace
    state_ip_false = transform.init(
        tree_map(lambda x: torch.zeros_like(x, requires_grad=True), target_mean)
    )
    state_ip_false2 = transform.update(
        state_ip_false,
        batch,
        inplace=False,
    )

    for key in state_ip.mean:
        assert not torch.allclose(
            state_ip_false2.mean[key], state_ip_false.mean[key], atol=1e-8
        )
        assert not torch.allclose(
            state_ip_false2.sd_diag[key], state_ip_false.sd_diag[key], atol=1e-8
        )
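The two test files above are identical apart from the transform they build; only the curvature estimate differs between the variants. Assuming diag_hessian mirrors diag_fisher's API, as the test implies (its internals are not shown in this diff), the two are interchangeable at the call site. The log-likelihood here is a toy placeholder:

import torch
from uqlib import ekf


def log_lik(params, batch):
    # Toy placeholder log-likelihood.
    return -(params["w"] ** 2).sum()


# diag_fisher estimates curvature from squared per-sample gradients;
# diag_hessian, per its name, from a Hessian-based estimate.
fisher_transform = ekf.diag_fisher.build(log_lik, lr=1e-3)
hessian_transform = ekf.diag_hessian.build(log_lik, lr=1e-3)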
Empty file.
@@ -0,0 +1,2 @@
from uqlib.ekf import diag_fisher
from uqlib.ekf import diag_hessian
@@ -0,0 +1,170 @@
from typing import Callable, Any, NamedTuple
from functools import partial
import torch
from torch.func import vmap, jacrev
from optree import tree_map

from uqlib.types import TensorTree, Transform
from uqlib.utils import diag_normal_sample, flexi_tree_map


class EKFDiagState(NamedTuple):
    """State encoding a diagonal Normal distribution over parameters.

    Args:
        mean: Mean of the Normal distribution.
        sd_diag: Square-root diagonal of the covariance matrix of the
            Normal distribution.
        log_likelihood: Log likelihood of the data given the parameters.
    """

    mean: TensorTree
    sd_diag: TensorTree
    log_likelihood: float = 0

def init(
    params: TensorTree,
    init_sds: TensorTree | None = None,
) -> EKFDiagState:
    """Initialise diagonal Normal distribution over parameters.

    Args:
        params: Initial mean of the variational distribution.
        init_sds: Initial square-root diagonal of the covariance matrix
            of the variational distribution. Defaults to ones.

    Returns:
        Initial EKFDiagState.
    """
    if init_sds is None:
        init_sds = tree_map(
            lambda x: torch.ones_like(x, requires_grad=x.requires_grad), params
        )

    return EKFDiagState(params, init_sds)

def update(
    state: EKFDiagState,
    batch: Any,
    log_likelihood: Callable[[TensorTree, Any], float],
    lr: float,
    transition_sd: float = 0.0,
    per_sample: bool = False,
    inplace: bool = True,
) -> EKFDiagState:
    """Applies an extended Kalman filter update to the diagonal Normal distribution.

    The update is first order, i.e. the log-likelihood is approximated by its
    quadratic expansion around the current mean:

        log p(y | x, p) ≈ log p(y | x, μ) + lr * g(μ)ᵀ (p - μ)
                          + lr * ½ (p - μ)ᵀ F_d(μ) (p - μ)

    where μ is the mean of the variational distribution, lr is the learning
    rate (likelihood inverse temperature), g(μ) is the gradient and F_d(μ)
    the negative diagonal empirical Fisher approximation to the Hessian of
    the log-likelihood with respect to the parameters.

    Args:
        state: Current state.
        batch: Input data to log_likelihood.
        log_likelihood: Function that takes parameters and input batch and
            returns the log-likelihood.
        lr: Inverse temperature of the update, which behaves like a learning
            rate. See https://arxiv.org/abs/1703.00209 for details.
        transition_sd: Standard deviation of the transition noise, used to
            additively inflate the diagonal covariance before the update.
            Defaults to zero.
        per_sample: If True, log_likelihood is assumed to return a vector of
            log-likelihoods, one per sample in the batch. If False,
            log_likelihood is assumed to return a scalar log-likelihood for
            the whole batch; in that case torch.func.vmap is called internally,
            which is typically slower than writing log_likelihood to be
            per-sample directly.
        inplace: Whether to update the state parameters in-place.

    Returns:
        Updated EKFDiagState.
    """

    if per_sample:
        log_likelihood_per_sample = log_likelihood
    else:
        # Per-sample gradients following
        # https://pytorch.org/tutorials/intermediate/per_sample_grads.html
        @partial(vmap, in_dims=(None, 0))
        def log_likelihood_per_sample(params, batch):
            batch = tree_map(lambda x: x.unsqueeze(0), batch)
            return log_likelihood(params, batch)

    # Predict step: inflate the covariance with transition noise.
    predict_sd_diag = flexi_tree_map(
        lambda x: (x**2 + transition_sd**2) ** 0.5, state.sd_diag, inplace=inplace
    )
    with torch.no_grad():
        log_lik = log_likelihood_per_sample(state.mean, batch).mean()
        jac = jacrev(log_likelihood_per_sample)(state.mean, batch)
        grad = tree_map(lambda x: x.mean(0), jac)
        # Negative diagonal empirical Fisher (mean of squared per-sample gradients).
        diag_lik_hessian_approx = tree_map(lambda x: -(x**2).mean(0), jac)

    # Update step: add the tempered Fisher to the precision,
    # then move the mean along the preconditioned gradient.
    update_sd_diag = flexi_tree_map(
        lambda sig, h: (sig**-2 - lr * h) ** -0.5,
        predict_sd_diag,
        diag_lik_hessian_approx,
        inplace=inplace,
    )
    update_mean = flexi_tree_map(
        lambda mu, sig, g: mu + sig**2 * lr * g,
        state.mean,
        update_sd_diag,
        grad,
        inplace=inplace,
    )
    return EKFDiagState(update_mean, update_sd_diag, log_lik.item())

def build(
    log_likelihood: Callable[[TensorTree, Any], float],
    lr: float,
    transition_sd: float = 0.0,
    per_sample: bool = False,
    init_sds: TensorTree | None = None,
) -> Transform:
    """Builds a transform for variational inference with a diagonal Normal
    distribution over parameters.

    Args:
        log_likelihood: Function that takes parameters and input batch and
            returns the log-likelihood.
        lr: Inverse temperature of the update, which behaves like a learning
            rate. See https://arxiv.org/abs/1703.00209 for details.
        transition_sd: Standard deviation of the transition noise, used to
            additively inflate the diagonal covariance before the update.
            Defaults to zero.
        per_sample: If True, log_likelihood is assumed to return a vector of
            log-likelihoods, one per sample in the batch. If False,
            log_likelihood is assumed to return a scalar log-likelihood for
            the whole batch; in that case torch.func.vmap is called internally,
            which is typically slower than writing log_likelihood to be
            per-sample directly.
        init_sds: Initial square-root diagonal of the covariance matrix
            of the variational distribution. Defaults to ones.

    Returns:
        Diagonal EKF transform (uqlib.types.Transform instance).
    """
    init_fn = partial(init, init_sds=init_sds)
    update_fn = partial(
        update,
        log_likelihood=log_likelihood,
        lr=lr,
        transition_sd=transition_sd,
        per_sample=per_sample,
    )
    return Transform(init_fn, update_fn)

def sample(state: EKFDiagState, sample_shape: torch.Size = torch.Size([])):
    """Sample from the diagonal Normal distribution over parameters.

    Args:
        state: State encoding mean and standard deviations.
        sample_shape: Shape of the sample batch to draw. Defaults to
            torch.Size([]), i.e. a single sample.

    Returns:
        Sample(s) from the Normal distribution.
    """
    return diag_normal_sample(state.mean, state.sd_diag, sample_shape=sample_shape)
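To see the predict/update arithmetic in isolation, here is the same diagonal update written out on plain tensors. This is a sketch mirroring the lambdas in update above; the per-sample gradients are random stand-ins rather than real model quantities:

import torch

# Current belief: diagonal Normal N(mu, diag(sig**2)).
mu = torch.zeros(3)
sig = torch.ones(3)

lr = 1e-2            # likelihood inverse temperature
transition_sd = 0.0  # additive transition noise

# Per-sample gradients of the log-likelihood at mu, shape (batch, dim).
# Random stand-ins here; the module computes them with jacrev over vmap.
jac = torch.randn(4, 3)
g = jac.mean(0)        # average gradient g(mu)
h = -(jac**2).mean(0)  # negative diagonal empirical Fisher F_d(mu)

# Predict: inflate the covariance with transition noise.
sig = (sig**2 + transition_sd**2) ** 0.5

# Update: add the tempered Fisher to the precision, then move the mean.
sig = (sig**-2 - lr * h) ** -0.5
mu = mu + sig**2 * lr * g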