-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Implemented Dense Hessian Laplace! * Docs fix * Fix tests * Updated build function and documentation in accordance with feedback
- Loading branch information
Showing
6 changed files
with
282 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
--- | ||
title: Laplace Dense Hessian | ||
--- | ||
|
||
# Laplace Dense Hessian | ||
|
||
::: posteriors.laplace.dense_hessian |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
from typing import Any, NamedTuple | ||
from functools import partial | ||
import torch | ||
from optree import tree_map | ||
from optree.integration.torch import tree_ravel | ||
|
||
from posteriors.types import TensorTree, Transform, LogProbFn | ||
from posteriors.tree_utils import tree_size | ||
from posteriors.utils import ( | ||
is_scalar, | ||
CatchAuxError, | ||
) | ||
|
||
from torch.func import jacrev, jacfwd | ||
|
||
|
||
def build( | ||
log_posterior: LogProbFn, | ||
init_prec: torch.Tensor | float = 0.0, | ||
epsilon: float = 0.0, | ||
rescale: float = 1.0, | ||
) -> Transform: | ||
"""Builds a transform for dense Hessian Laplace. | ||
**Warning:** | ||
The Hessian is not guaranteed to be positive definite, | ||
so setting epsilon > 0 ought to be considered. | ||
Args: | ||
log_posterior: Function that takes parameters and input batch and | ||
returns the log posterior value (which can be unnormalised) | ||
as well as auxiliary information, e.g. from the model call. | ||
init_prec: Initial precision matrix. | ||
If it is a float, it is defined as an identity matrix | ||
scaled by that float. | ||
epsilon: Added to the diagonal of the Hessian | ||
for numerical stability. | ||
rescale: Value to multiply the Hessian by | ||
(i.e. to normalize by batch size) | ||
Returns: | ||
Hessian Laplace transform instance. | ||
""" | ||
init_fn = partial(init, init_prec=init_prec) | ||
update_fn = partial( | ||
update, | ||
log_posterior=log_posterior, | ||
epsilon=epsilon, | ||
rescale=rescale, | ||
) | ||
return Transform(init_fn, update_fn) | ||
|
||
|
||
class DenseLaplaceState(NamedTuple): | ||
"""State encoding a Normal distribution over parameters, | ||
with a dense precision matrix | ||
Attributes: | ||
params: Mean of the Normal distribution. | ||
prec: Precision matrix of the Normal distribution. | ||
aux: Auxiliary information from the log_posterior call. | ||
""" | ||
|
||
params: TensorTree | ||
prec: torch.Tensor | ||
aux: Any = None | ||
|
||
|
||
def init( | ||
params: TensorTree, | ||
init_prec: torch.Tensor | float = 0.0, | ||
) -> DenseLaplaceState: | ||
"""Initialise Normal distribution over parameters | ||
with a dense precision matrix. | ||
Args: | ||
params: Mean of the Normal distribution. | ||
init_prec: Initial precision matrix. | ||
If it is a float, it is defined as an identity matrix | ||
scaled by that float. | ||
Returns: | ||
Initial DenseLaplaceState. | ||
""" | ||
|
||
if is_scalar(init_prec): | ||
num_params = tree_size(params) | ||
init_prec = init_prec * torch.eye(num_params, requires_grad=False) | ||
|
||
return DenseLaplaceState(params, init_prec) | ||
|
||
|
||
def update( | ||
state: DenseLaplaceState, | ||
batch: Any, | ||
log_posterior: LogProbFn, | ||
epsilon: float = 0.0, | ||
rescale: float = 1.0, | ||
inplace: bool = False, | ||
) -> DenseLaplaceState: | ||
"""Adds the Hessian of the negative log-posterior over given batch. | ||
**Warning:** | ||
The Hessian is not guaranteed to be positive definite, | ||
so setting epsilon > 0 ought to be considered. | ||
Args: | ||
state: Current state. | ||
batch: Input data to log_posterior. | ||
log_posterior: Function that takes parameters and input batch and | ||
returns the log posterior value (which can be unnormalised) | ||
epsilon: Added to the diagonal of the Hessian | ||
for numerical stability. | ||
rescale: Value to multiply the Hessian by | ||
(i.e. to normalize by batch size) | ||
inplace: If True, the state is updated in place. Otherwise, a new | ||
state is returned. | ||
Returns: | ||
Updated DenseLaplaceState. | ||
""" | ||
with torch.no_grad(), CatchAuxError(): | ||
flat_params, params_unravel = tree_ravel(state.params) | ||
num_params = flat_params.numel() | ||
|
||
def neg_log_p(p_flat): | ||
value, aux = log_posterior(params_unravel(p_flat), batch) | ||
return -value, aux | ||
|
||
hess, aux = jacfwd(jacrev(neg_log_p, has_aux=True), has_aux=True)(flat_params) | ||
hess = hess * rescale + epsilon * torch.eye(num_params) | ||
|
||
if inplace: | ||
state.prec.data += hess | ||
return state._replace(aux=aux) | ||
else: | ||
return DenseLaplaceState(state.params, state.prec + hess, aux) | ||
|
||
|
||
def sample( | ||
state: DenseLaplaceState, | ||
sample_shape: torch.Size = torch.Size([]), | ||
) -> TensorTree: | ||
"""Sample from Normal distribution over parameters. | ||
Args: | ||
state: State encoding mean and precision matrix. | ||
sample_shape: Shape of the desired samples. | ||
Returns: | ||
Sample(s) from the Normal distribution. | ||
""" | ||
samples = torch.distributions.MultivariateNormal( | ||
loc=torch.zeros(state.prec.shape[0], device=state.prec.device), | ||
precision_matrix=state.prec, | ||
validate_args=False, | ||
).sample(sample_shape) | ||
samples = samples.flatten(end_dim=-2) # ensure samples is 2D | ||
mean_flat, unravel_func = tree_ravel(state.params) | ||
samples += mean_flat | ||
samples = torch.vmap(unravel_func)(samples) | ||
samples = tree_map(lambda x: x.reshape(sample_shape + x.shape[-1:]), samples) | ||
return samples |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import torch | ||
from torch.distributions import Normal | ||
from torch.utils.data import DataLoader, TensorDataset | ||
from torch.func import functional_call, hessian | ||
from optree import tree_map | ||
from optree.integration.torch import tree_ravel | ||
|
||
from posteriors import tree_size, diag_normal_log_prob | ||
from posteriors.laplace import dense_hessian | ||
|
||
from tests.scenarios import TestModel | ||
|
||
|
||
def normal_log_likelihood(y, y_pred): | ||
return ( | ||
Normal(y_pred, 1, validate_args=False).log_prob(y).sum(dim=-1) | ||
) # validate args introduces control flows not yet supported in torch.func.vmap | ||
|
||
|
||
def log_posterior_n(params, batch, model, n_data): | ||
y_pred = functional_call(model, params, batch[0]) | ||
return diag_normal_log_prob(params, mean=0.0, sd_diag=1.0) + normal_log_likelihood( | ||
batch[1], y_pred | ||
) * n_data, torch.tensor([]) | ||
|
||
|
||
def test_dense_hessian_vmap(): | ||
torch.manual_seed(42) | ||
model = TestModel() | ||
|
||
xs = torch.randn(100, 10) | ||
ys = model(xs) | ||
|
||
batch_size = 2 | ||
|
||
dataloader = DataLoader( | ||
TensorDataset(xs, ys), | ||
batch_size=batch_size, | ||
) | ||
|
||
def log_posterior(p, b): | ||
return log_posterior_n(p, b, model, len(xs))[0].mean(), torch.tensor([]) | ||
|
||
params = dict(model.named_parameters()) | ||
|
||
# Test inplace = False | ||
transform = dense_hessian.build(log_posterior) | ||
laplace_state = transform.init(params) | ||
laplace_state_prec_init = laplace_state.prec | ||
for batch in dataloader: | ||
laplace_state = transform.update( | ||
laplace_state, batch, rescale=batch_size / xs.size()[0], inplace=False | ||
) | ||
|
||
flat_params, params_unravel = tree_ravel(params) | ||
|
||
num_params = tree_size(params) | ||
expected = torch.zeros((num_params, num_params)) | ||
for x, y in zip(xs, ys): | ||
with torch.no_grad(): | ||
hess = hessian(lambda p: -log_posterior(params_unravel(p), (x, y))[0])( | ||
flat_params | ||
) | ||
expected += hess / xs.size()[0] | ||
|
||
assert torch.allclose(expected, laplace_state.prec, atol=1e-5) | ||
assert not torch.allclose(laplace_state.prec, laplace_state_prec_init) | ||
|
||
# Also check full batch | ||
laplace_state_fb = transform.init(params) | ||
laplace_state_fb = transform.update(laplace_state_fb, (xs, ys)) | ||
|
||
assert torch.allclose(expected, laplace_state_fb.prec, atol=1e-5) | ||
|
||
# Test inplace = True | ||
transform = dense_hessian.build(log_posterior) | ||
laplace_state = transform.init(params) | ||
laplace_state_prec_diag_init = laplace_state.prec | ||
for batch in dataloader: | ||
laplace_state = transform.update( | ||
laplace_state, batch, rescale=batch_size / xs.size()[0], inplace=True | ||
) | ||
|
||
assert torch.allclose(expected, laplace_state.prec, atol=1e-5) | ||
assert torch.allclose(laplace_state.prec, laplace_state_prec_diag_init, atol=1e-5) | ||
|
||
# Test sampling | ||
num_samples = 10000 | ||
laplace_state.prec.data += 0.1 * torch.eye( | ||
num_params | ||
) # regularize to ensure PSD and reduce variance | ||
|
||
mean_copy = tree_map(lambda x: x.clone(), laplace_state.params) | ||
sd_flat = torch.diag(torch.linalg.inv(laplace_state.prec)).sqrt() | ||
|
||
samples = dense_hessian.sample(laplace_state, (num_samples,)) | ||
|
||
samples_mean = tree_map(lambda x: x.mean(dim=0), samples) | ||
samples_sd = tree_map(lambda x: x.std(dim=0), samples) | ||
samples_sd_flat = tree_ravel(samples_sd)[0] | ||
|
||
for key in samples_mean: | ||
assert samples[key].shape[0] == num_samples | ||
assert samples[key].shape[1:] == samples_mean[key].shape | ||
assert torch.allclose(samples_mean[key], laplace_state.params[key], atol=1e-1) | ||
assert torch.allclose(mean_copy[key], laplace_state.params[key]) | ||
|
||
assert torch.allclose(sd_flat, samples_sd_flat, atol=1e-1) |