Dense Hessian Laplace (#112)
* Implemented Dense Hessian Laplace!

* Docs fix

* Fix tests

* Updated build function and documentation in accordance with feedback
jcqcai authored Oct 10, 2024
1 parent 87e0af0 commit 3505484
Showing 6 changed files with 282 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/api/index.md
@@ -15,6 +15,8 @@ approximation](https://arxiv.org/abs/2106.14806).
- [`laplace.dense_ggn`](laplace/dense_ggn.md) calculates the Generalised
Gauss-Newton matrix which is equivalent to the non-empirical Fisher in most
neural network settings - see [Martens, 2020](https://jmlr.org/papers/volume21/17-678/17-678.pdf).
- [`laplace.dense_hessian`](laplace/dense_hessian.md) calculates the Hessian of the negative
log posterior.
- [`laplace.diag_fisher`](laplace/diag_fisher.md) same as `laplace.dense_fisher` but
uses the diagonal of the empirical Fisher information matrix instead.
- [`laplace.diag_ggn`](laplace/diag_ggn.md) same as `laplace.dense_ggn` but
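The new transform implements the classical Laplace approximation: fit a Gaussian at a mode of the posterior, with the Hessian of the negative log posterior as its precision. A brief sketch in standard notation (the symbols here are generic and not defined in this commit):

```latex
% Laplace approximation with a dense Hessian precision (standard notation;
% \theta^{\ast} denotes a mode found beforehand, \mathcal{D} the observed data).
H = -\nabla^2_{\theta} \log p(\theta \mid \mathcal{D}) \big|_{\theta = \theta^{\ast}},
\qquad
p(\theta \mid \mathcal{D}) \approx \mathcal{N}\!\big(\theta^{\ast},\, H^{-1}\big)
```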
7 changes: 7 additions & 0 deletions docs/api/laplace/dense_hessian.md
@@ -0,0 +1,7 @@
---
title: Laplace Dense Hessian
---

# Laplace Dense Hessian

::: posteriors.laplace.dense_hessian
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -73,6 +73,7 @@ nav:
- Laplace:
- Dense Fisher: api/laplace/dense_fisher.md
- Dense GGN: api/laplace/dense_ggn.md
- Dense Hessian: api/laplace/dense_hessian.md
- Diagonal Fisher: api/laplace/diag_fisher.md
- Diagonal GGN: api/laplace/diag_ggn.md
- SGMCMC:
1 change: 1 addition & 0 deletions posteriors/laplace/__init__.py
@@ -2,3 +2,4 @@
from posteriors.laplace import diag_fisher
from posteriors.laplace import dense_ggn
from posteriors.laplace import diag_ggn
from posteriors.laplace import dense_hessian
163 changes: 163 additions & 0 deletions posteriors/laplace/dense_hessian.py
@@ -0,0 +1,163 @@
from typing import Any, NamedTuple
from functools import partial
import torch
from optree import tree_map
from optree.integration.torch import tree_ravel

from posteriors.types import TensorTree, Transform, LogProbFn
from posteriors.tree_utils import tree_size
from posteriors.utils import (
    is_scalar,
    CatchAuxError,
)

from torch.func import jacrev, jacfwd


def build(
    log_posterior: LogProbFn,
    init_prec: torch.Tensor | float = 0.0,
    epsilon: float = 0.0,
    rescale: float = 1.0,
) -> Transform:
    """Builds a transform for dense Hessian Laplace.

    **Warning:**
    The Hessian is not guaranteed to be positive definite,
    so setting epsilon > 0 should be considered.

    Args:
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior value (which can be unnormalised)
            as well as auxiliary information, e.g. from the model call.
        init_prec: Initial precision matrix.
            If it is a float, it is defined as an identity matrix
            scaled by that float.
        epsilon: Added to the diagonal of the Hessian
            for numerical stability.
        rescale: Value to multiply the Hessian by
            (e.g. to normalize by batch size).

    Returns:
        Hessian Laplace transform instance.
    """
    init_fn = partial(init, init_prec=init_prec)
    update_fn = partial(
        update,
        log_posterior=log_posterior,
        epsilon=epsilon,
        rescale=rescale,
    )
    return Transform(init_fn, update_fn)


class DenseLaplaceState(NamedTuple):
    """State encoding a Normal distribution over parameters,
    with a dense precision matrix.

    Attributes:
        params: Mean of the Normal distribution.
        prec: Precision matrix of the Normal distribution.
        aux: Auxiliary information from the log_posterior call.
    """

    params: TensorTree
    prec: torch.Tensor
    aux: Any = None


def init(
    params: TensorTree,
    init_prec: torch.Tensor | float = 0.0,
) -> DenseLaplaceState:
    """Initialise Normal distribution over parameters
    with a dense precision matrix.

    Args:
        params: Mean of the Normal distribution.
        init_prec: Initial precision matrix.
            If it is a float, it is defined as an identity matrix
            scaled by that float.

    Returns:
        Initial DenseLaplaceState.
    """

    if is_scalar(init_prec):
        num_params = tree_size(params)
        init_prec = init_prec * torch.eye(num_params, requires_grad=False)

    return DenseLaplaceState(params, init_prec)


def update(
    state: DenseLaplaceState,
    batch: Any,
    log_posterior: LogProbFn,
    epsilon: float = 0.0,
    rescale: float = 1.0,
    inplace: bool = False,
) -> DenseLaplaceState:
    """Adds the Hessian of the negative log-posterior over given batch.

    **Warning:**
    The Hessian is not guaranteed to be positive definite,
    so setting epsilon > 0 should be considered.

    Args:
        state: Current state.
        batch: Input data to log_posterior.
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior value (which can be unnormalised)
            as well as auxiliary information, e.g. from the model call.
        epsilon: Added to the diagonal of the Hessian
            for numerical stability.
        rescale: Value to multiply the Hessian by
            (e.g. to normalize by batch size).
        inplace: If True, the state is updated in place. Otherwise, a new
            state is returned.

    Returns:
        Updated DenseLaplaceState.
    """
    with torch.no_grad(), CatchAuxError():
        flat_params, params_unravel = tree_ravel(state.params)
        num_params = flat_params.numel()

        def neg_log_p(p_flat):
            value, aux = log_posterior(params_unravel(p_flat), batch)
            return -value, aux

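        # Forward-over-reverse autodiff: jacfwd of jacrev yields the dense
        # Hessian of the scalar negative log posterior at the flattened params.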
        hess, aux = jacfwd(jacrev(neg_log_p, has_aux=True), has_aux=True)(flat_params)
        hess = hess * rescale + epsilon * torch.eye(num_params)

    if inplace:
        state.prec.data += hess
        return state._replace(aux=aux)
    else:
        return DenseLaplaceState(state.params, state.prec + hess, aux)


def sample(
    state: DenseLaplaceState,
    sample_shape: torch.Size = torch.Size([]),
) -> TensorTree:
    """Sample from Normal distribution over parameters.

    Args:
        state: State encoding mean and precision matrix.
        sample_shape: Shape of the desired samples.

    Returns:
        Sample(s) from the Normal distribution.
    """
    samples = torch.distributions.MultivariateNormal(
        loc=torch.zeros(state.prec.shape[0], device=state.prec.device),
        precision_matrix=state.prec,
        validate_args=False,
    ).sample(sample_shape)
    samples = samples.flatten(end_dim=-2)  # ensure samples is 2D
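    # The draws above are zero-mean with the stored dense precision: shift by
    # the flattened mean, unravel each draw back to the parameter tree, then
    # restore the requested sample shape.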
    mean_flat, unravel_func = tree_ravel(state.params)
    samples += mean_flat
    samples = torch.vmap(unravel_func)(samples)
    samples = tree_map(lambda x: x.reshape(sample_shape + x.shape[-1:]), samples)
    return samples
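A minimal usage sketch of the new transform, mirroring the pattern in the tests below; the model, data, and log posterior here are illustrative assumptions, not part of this commit:

```python
import torch
from torch.func import functional_call

from posteriors.laplace import dense_hessian

# Illustrative model and data (assumptions, not from this commit).
model = torch.nn.Linear(10, 1)
xs, ys = torch.randn(100, 10), torch.randn(100, 1)
params = dict(model.named_parameters())


def log_posterior(p, batch):
    x, y = batch
    y_pred = functional_call(model, p, x)
    # validate_args=False avoids control flow unsupported under torch.func
    dist = torch.distributions.Normal(y_pred, 1.0, validate_args=False)
    log_lik = dist.log_prob(y).sum()
    log_prior = sum(-0.5 * v.square().sum() for v in p.values())  # N(0, 1) prior
    return log_prior + log_lik, torch.tensor([])


# epsilon > 0 guards against a Hessian that is not positive definite.
transform = dense_hessian.build(log_posterior, epsilon=1e-3)
state = transform.init(params)
state = transform.update(state, (xs, ys))

# Draw 5 samples from the Gaussian posterior approximation.
samples = dense_hessian.sample(state, (5,))
```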
108 changes: 108 additions & 0 deletions tests/laplace/test_dense_hessian.py
@@ -0,0 +1,108 @@
import torch
from torch.distributions import Normal
from torch.utils.data import DataLoader, TensorDataset
from torch.func import functional_call, hessian
from optree import tree_map
from optree.integration.torch import tree_ravel

from posteriors import tree_size, diag_normal_log_prob
from posteriors.laplace import dense_hessian

from tests.scenarios import TestModel


def normal_log_likelihood(y, y_pred):
    return (
        Normal(y_pred, 1, validate_args=False).log_prob(y).sum(dim=-1)
    )  # validate args introduces control flows not yet supported in torch.func.vmap


def log_posterior_n(params, batch, model, n_data):
    y_pred = functional_call(model, params, batch[0])
    return diag_normal_log_prob(params, mean=0.0, sd_diag=1.0) + normal_log_likelihood(
        batch[1], y_pred
    ) * n_data, torch.tensor([])


def test_dense_hessian_vmap():
    torch.manual_seed(42)
    model = TestModel()

    xs = torch.randn(100, 10)
    ys = model(xs)

    batch_size = 2

    dataloader = DataLoader(
        TensorDataset(xs, ys),
        batch_size=batch_size,
    )

    def log_posterior(p, b):
        return log_posterior_n(p, b, model, len(xs))[0].mean(), torch.tensor([])

    params = dict(model.named_parameters())

    # Test inplace = False
    transform = dense_hessian.build(log_posterior)
    laplace_state = transform.init(params)
    laplace_state_prec_init = laplace_state.prec
    for batch in dataloader:
        laplace_state = transform.update(
            laplace_state, batch, rescale=batch_size / xs.size()[0], inplace=False
        )

    flat_params, params_unravel = tree_ravel(params)

    num_params = tree_size(params)
    expected = torch.zeros((num_params, num_params))
    for x, y in zip(xs, ys):
        with torch.no_grad():
            hess = hessian(lambda p: -log_posterior(params_unravel(p), (x, y))[0])(
                flat_params
            )
        expected += hess / xs.size()[0]

    assert torch.allclose(expected, laplace_state.prec, atol=1e-5)
    assert not torch.allclose(laplace_state.prec, laplace_state_prec_init)

    # Also check full batch
    laplace_state_fb = transform.init(params)
    laplace_state_fb = transform.update(laplace_state_fb, (xs, ys))

    assert torch.allclose(expected, laplace_state_fb.prec, atol=1e-5)

    # Test inplace = True
    transform = dense_hessian.build(log_posterior)
    laplace_state = transform.init(params)
    laplace_state_prec_inplace_init = laplace_state.prec
    for batch in dataloader:
        laplace_state = transform.update(
            laplace_state, batch, rescale=batch_size / xs.size()[0], inplace=True
        )

    assert torch.allclose(expected, laplace_state.prec, atol=1e-5)
    assert torch.allclose(
        laplace_state.prec, laplace_state_prec_inplace_init, atol=1e-5
    )

    # Test sampling
    num_samples = 10000
    laplace_state.prec.data += 0.1 * torch.eye(
        num_params
    )  # regularize to ensure PSD and reduce variance

    mean_copy = tree_map(lambda x: x.clone(), laplace_state.params)
    sd_flat = torch.diag(torch.linalg.inv(laplace_state.prec)).sqrt()

    samples = dense_hessian.sample(laplace_state, (num_samples,))

    samples_mean = tree_map(lambda x: x.mean(dim=0), samples)
    samples_sd = tree_map(lambda x: x.std(dim=0), samples)
    samples_sd_flat = tree_ravel(samples_sd)[0]

    for key in samples_mean:
        assert samples[key].shape[0] == num_samples
        assert samples[key].shape[1:] == samples_mean[key].shape
        assert torch.allclose(samples_mean[key], laplace_state.params[key], atol=1e-1)
        assert torch.allclose(mean_copy[key], laplace_state.params[key])

    assert torch.allclose(sd_flat, samples_sd_flat, atol=1e-1)
