Change TransformState to NamedTuple (#106)
* Change TransformState to NamedTuple

* Change class type docstring from Args to Attributes

* Update inplace gotcha

* Update docstring
SamDuffield authored Jul 24, 2024
1 parent a89667f commit e08e729
Showing 22 changed files with 183 additions and 183 deletions.
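Before the per-file diffs, a minimal sketch of the pattern this commit applies to every state class (field names follow `EKFDiagState` as it appears later in this diff; the simplified `Any` annotations and toy values are illustrative assumptions, not the library's exact code):

```python
from typing import Any, NamedTuple
import torch

# Before: a mutable dataclass subclassing TransformState, e.g.
#     @dataclass
#     class EKFDiagState(TransformState):
#         ...
#         log_likelihood: float = 0

# After: an immutable NamedTuple with a tensor placeholder that can be
# filled in-place via tree_insert_ (see the gotchas.md changes below)
class EKFDiagState(NamedTuple):
    params: Any  # TensorTree in the real code: mean of the Normal
    sd_diag: Any  # TensorTree: square-root diagonal of the covariance
    log_likelihood: torch.Tensor = torch.tensor([])
    aux: Any = None

state = EKFDiagState(params={"w": torch.zeros(2)}, sd_diag={"w": torch.ones(2)})
state = state._replace(aux={"step": 1})  # new tuple; tensor leaves are shared, not copied
```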
2 changes: 1 addition & 1 deletion docs/getting_started.md

@@ -64,7 +64,7 @@ Here:
 - `build` is a function that loads `config_args` into the `init` and `update` functions
   and stores them within the `transform` instance. The `init` and `update`
   functions then conform to a preset signature allowing for easy switching between algorithms.
-- `state` is a [`dataclass`](https://docs.python.org/3/library/dataclasses.html)
+- `state` is a [`NamedTuple`](https://docs.python.org/3/library/typing.html#typing.NamedTuple)
   encoding the state of the algorithm, including `params` and `aux` attributes.
 - `init` constructs the iteration-varying `state` based on the model parameters `params`.
 - `update` updates the `state` based on a new `batch` of data.
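To make the `build`/`init`/`update` pattern above concrete, here is a hedged usage sketch (the toy `log_likelihood` and the exact `build` arguments such as `lr` are assumptions, not taken from the changed files):

```python
import torch
import posteriors

def log_likelihood(params, batch):
    # Toy Gaussian log-likelihood returning (scalar, aux), as posteriors expects
    x, y = batch
    pred = x @ params["w"]
    return -torch.sum((pred - y) ** 2), {}

params = {"w": torch.zeros(3)}
transform = posteriors.ekf.diag_fisher.build(log_likelihood, lr=1e-2)

state = transform.init(params)  # iteration-varying NamedTuple state
batch = (torch.randn(8, 3), torch.randn(8))
state = transform.update(state, batch)  # pure: returns a fresh state
state = transform.update(state, batch, inplace=True)  # mutates tensor data, returns pointer
```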
23 changes: 23 additions & 0 deletions docs/gotchas.md

@@ -70,6 +70,29 @@ state2 = transform.update(state, batch, inplace=True)
 # state is updated and state2 is a pointer to state
 ```
 
+When adding a new algorithm, in-place support can be achieved by modifying `TensorTree`s
+via the [`flexi_tree_map`](https://normal-computing.github.io/posteriors/api/tree_utils/#posteriors.tree_utils.flexi_tree_map) function:
+
+```python
+from posteriors.tree_utils import flexi_tree_map
+
+new_state = flexi_tree_map(lambda x: x + 1, state, inplace=True)
+```
+
+As `posteriors` transform states are immutable `NamedTuple`s, in-place modification of
+`TensorTree` leaves can be achieved by modifying the data of the tensor directly with [`tree_insert_`](https://normal-computing.github.io/posteriors/api/tree_utils/#posteriors.tree_utils.tree_insert_):
+
+```python
+from posteriors.tree_utils import tree_insert_
+
+tree_insert_(state.log_posterior, log_post.detach())
+```
+
+However, the `aux` component of the `TransformState` is not guaranteed to be a `TensorTree`,
+and so in-place modification of `aux` is not supported. Using `state._replace(aux=aux)`
+will return a state with all `TensorTree`s pointing to the same memory as the input `state`,
+but with a new `aux` component (`aux` is not modified in the input `state` object).
+
 
 ## `torch.tensor` with autograd
 
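The `_replace` behaviour described in the new gotcha can be checked in isolation; a self-contained sketch with a toy state (not one of the library's classes):

```python
from typing import Any, NamedTuple
import torch

class ToyState(NamedTuple):
    params: torch.Tensor
    aux: Any = None

state = ToyState(params=torch.zeros(3))
state2 = state._replace(aux={"info": 42})

assert state2.params is state.params  # tensor leaves share memory
assert state.aux is None  # aux in the original state object is untouched
assert state2.aux == {"info": 42}
```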
3 changes: 1 addition & 2 deletions docs/tutorials/lightning_autoencoder.md

@@ -16,7 +16,6 @@ from torchvision.datasets import MNIST
 from torchvision.transforms import ToTensor
 import lightning as L
 import torchopt
-from dataclasses import asdict
 
 import posteriors
 
@@ -100,7 +99,7 @@ class LitAutoEncoderUQ(L.LightningModule):
         # it is independent of forward
         self.state = self.transform.update(self.state, batch, inplace=True)
         # Logging to TensorBoard (if installed) by default
-        for k, v in asdict(self.state).items():
+        for k, v in self.state._asdict().items():
             if isinstance(v, float) or (isinstance(v, torch.Tensor) and v.numel() == 1):
                 self.log(k, v)
 
3 changes: 1 addition & 2 deletions examples/lightning_autoencoder.py

@@ -5,7 +5,6 @@
 from torchvision.transforms import ToTensor
 import lightning as L
 import torchopt
-from dataclasses import asdict
 
 import posteriors
 
@@ -54,7 +53,7 @@ def training_step(self, batch, batch_idx):
         # it is independent of forward
         self.state = self.transform.update(self.state, batch, inplace=True)
         # Logging to TensorBoard (if installed) by default
-        for k, v in asdict(self.state).items():
+        for k, v in self.state._asdict().items():
             if isinstance(v, float) or (isinstance(v, torch.Tensor) and v.numel() == 1):
                 self.log(k, v)
 
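The `asdict` → `_asdict` swap above is the NamedTuple counterpart of the dataclass helper; note that `_asdict` is shallow where `dataclasses.asdict` recurses, which suffices here because only top-level scalar entries are logged. A standalone sketch of the filtering loop (toy state, `print` standing in for `self.log`):

```python
from typing import Any, NamedTuple
import torch

class ToyState(NamedTuple):
    params: Any
    log_likelihood: torch.Tensor = torch.tensor([])
    aux: Any = None

state = ToyState(params={"w": torch.zeros(2)}, log_likelihood=torch.tensor(-1.3))

for k, v in state._asdict().items():
    if isinstance(v, float) or (isinstance(v, torch.Tensor) and v.numel() == 1):
        print(f"{k}: {v}")  # only scalar entries, here log_likelihood
```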
24 changes: 11 additions & 13 deletions posteriors/ekf/dense_fisher.py

@@ -1,13 +1,12 @@
-from typing import Any
+from typing import Any, NamedTuple
 from functools import partial
 import torch
 from torch.func import grad_and_value
-from dataclasses import dataclass
 from optree.integration.torch import tree_ravel
 
-from posteriors.tree_utils import tree_size
+from posteriors.tree_utils import tree_size, tree_insert_
 
-from posteriors.types import TensorTree, Transform, LogProbFn, TransformState
+from posteriors.types import TensorTree, Transform, LogProbFn
 from posteriors.utils import (
     per_samplify,
     empirical_fisher,
@@ -67,11 +66,10 @@ def build(
     return Transform(init_fn, update_fn)
 
 
-@dataclass
-class EKFDenseState(TransformState):
+class EKFDenseState(NamedTuple):
     """State encoding a Normal distribution over parameters.
 
-    Args:
+    Attributes:
         params: Mean of the Normal distribution.
         cov: Covariance matrix of the
             Normal distribution.
@@ -81,7 +79,7 @@ class EKFDenseState(TransformState):
 
     params: TensorTree
     cov: torch.Tensor
-    log_likelihood: float = 0
+    log_likelihood: torch.Tensor = torch.tensor([])
     aux: Any = None
 
 
@@ -170,11 +168,11 @@ def log_likelihood_reduced(params, batch):
     update_mean = mu_unravel_f(update_mean)
 
     if inplace:
-        state.params = update_mean
-        state.cov = update_cov
-        state.log_likelihood = log_liks.mean().detach()
-        state.aux = aux
-        return state
+        tree_insert_(state.params, update_mean)
+        tree_insert_(state.cov, update_cov)
+        tree_insert_(state.log_likelihood, log_liks.mean().detach())
+        return state._replace(aux=aux)
 
     return EKFDenseState(update_mean, update_cov, log_liks.mean().detach(), aux)
 
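The `tree_insert_` calls above replace plain attribute assignment; the idea, sketched on a single tensor (a toy stand-in, assuming the real `posteriors.tree_utils.tree_insert_` generalises this over whole `TensorTree`s):

```python
import torch

def toy_insert_(target: torch.Tensor, value: torch.Tensor) -> None:
    # Overwrite the tensor's storage so every existing reference,
    # including the one held by the NamedTuple state, sees the update
    target.data = value.detach().clone()

log_likelihood = torch.tensor([])  # the NamedTuple default above
alias = log_likelihood  # e.g. the field as seen through the state tuple
toy_insert_(log_likelihood, torch.tensor(-2.5))
assert alias.item() == -2.5  # the alias observes the in-place update
```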
20 changes: 9 additions & 11 deletions posteriors/ekf/diag_fisher.py

@@ -1,12 +1,11 @@
-from typing import Any
+from typing import Any, NamedTuple
 from functools import partial
 import torch
 from torch.func import jacrev
 from optree import tree_map
-from dataclasses import dataclass
 
-from posteriors.types import TensorTree, Transform, LogProbFn, TransformState
-from posteriors.tree_utils import flexi_tree_map
+from posteriors.types import TensorTree, Transform, LogProbFn
+from posteriors.tree_utils import flexi_tree_map, tree_insert_
 from posteriors.utils import (
     diag_normal_sample,
     per_samplify,
@@ -68,11 +67,10 @@ def build(
     return Transform(init_fn, update_fn)
 
 
-@dataclass
-class EKFDiagState(TransformState):
+class EKFDiagState(NamedTuple):
     """State encoding a diagonal Normal distribution over parameters.
 
-    Args:
+    Attributes:
         params: Mean of the Normal distribution.
        sd_diag: Square-root diagonal of the covariance matrix of the
            Normal distribution.
@@ -82,7 +80,7 @@ class EKFDiagState(TransformState):
 
     params: TensorTree
     sd_diag: TensorTree
-    log_likelihood: float = 0
+    log_likelihood: torch.Tensor = torch.tensor([])
     aux: Any = None
 
 
@@ -176,9 +174,9 @@ def update(
     )
 
     if inplace:
-        state.log_likelihood = log_liks.mean().detach()
-        state.aux = aux
-        return state
+        tree_insert_(state.log_likelihood, log_liks.mean().detach())
+        return state._replace(aux=aux)
 
     return EKFDiagState(update_mean, update_sd_diag, log_liks.mean().detach(), aux)
 
15 changes: 6 additions & 9 deletions posteriors/laplace/dense_fisher.py

@@ -1,11 +1,10 @@
-from typing import Any
-from dataclasses import dataclass
+from typing import Any, NamedTuple
 from functools import partial
 import torch
 from optree import tree_map
 from optree.integration.torch import tree_ravel
 
-from posteriors.types import TensorTree, Transform, LogProbFn, TransformState
+from posteriors.types import TensorTree, Transform, LogProbFn
 from posteriors.tree_utils import tree_size
 from posteriors.utils import (
     per_samplify,
@@ -55,12 +54,11 @@ def build(
     return Transform(init_fn, update_fn)
 
 
-@dataclass
-class DenseLaplaceState(TransformState):
+class DenseLaplaceState(NamedTuple):
     """State encoding a Normal distribution over parameters,
     with a dense precision matrix
 
-    Args:
+    Attributes:
         params: Mean of the Normal distribution.
         prec: Precision matrix of the Normal distribution.
         aux: Auxiliary information from the log_posterior call.
@@ -130,9 +128,8 @@ def update(
     )(state.params)
 
     if inplace:
-        state.prec += fisher
-        state.aux = aux
-        return state
+        state.prec.data += fisher
+        return state._replace(aux=aux)
     else:
         return DenseLaplaceState(state.params, state.prec + fisher, aux)
 
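The switch from `state.prec += fisher` to `state.prec.data += fisher` is forced by the new container: NamedTuple fields are read-only, so augmented assignment to the field raises, while mutating the tensor's storage does not. A self-contained check:

```python
from typing import NamedTuple
import torch

class ToyLaplaceState(NamedTuple):
    params: torch.Tensor
    prec: torch.Tensor

state = ToyLaplaceState(params=torch.zeros(2), prec=torch.eye(2))
fisher = 0.5 * torch.eye(2)

# state.prec += fisher  # AttributeError: can't set attribute (tuple field)
state.prec.data += fisher  # mutate the tensor's storage in place instead
assert torch.allclose(state.prec, 1.5 * torch.eye(2))
```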
14 changes: 5 additions & 9 deletions posteriors/laplace/dense_ggn.py

@@ -1,16 +1,14 @@
 from functools import partial
-from typing import Any
+from typing import Any, NamedTuple
 import torch
 from optree import tree_map
-from dataclasses import dataclass
 from optree.integration.torch import tree_ravel
 
 from posteriors.types import (
     TensorTree,
     Transform,
     ForwardFn,
     OuterLogProbFn,
-    TransformState,
 )
 from posteriors.utils import (
     tree_size,
@@ -67,12 +65,11 @@ def build(
     return Transform(init_fn, update_fn)
 
 
-@dataclass
-class DenseLaplaceState(TransformState):
+class DenseLaplaceState(NamedTuple):
     """State encoding a Normal distribution over parameters,
     with a dense precision matrix
 
-    Args:
+    Attributes:
         params: Mean of the Normal distribution.
         prec: Precision matrix of the Normal distribution.
         aux: Auxiliary information from the log_posterior call.
@@ -145,9 +142,8 @@ def outer_loss(z, batch):
     )(state.params)
 
     if inplace:
-        state.prec += ggn_batch
-        state.aux = aux
-        return state
+        state.prec.data += ggn_batch
+        return state._replace(aux=aux)
     else:
         return DenseLaplaceState(state.params, state.prec + ggn_batch, aux)
 
13 changes: 5 additions & 8 deletions posteriors/laplace/diag_fisher.py

@@ -1,11 +1,10 @@
 from functools import partial
-from typing import Any
+from typing import Any, NamedTuple
 import torch
 from torch.func import jacrev
 from optree import tree_map
-from dataclasses import dataclass
 
-from posteriors.types import TensorTree, Transform, LogProbFn, TransformState
+from posteriors.types import TensorTree, Transform, LogProbFn
 from posteriors.tree_utils import flexi_tree_map
 from posteriors.utils import (
     diag_normal_sample,
@@ -54,11 +53,10 @@ def build(
     return Transform(init_fn, update_fn)
 
 
-@dataclass
-class DiagLaplaceState(TransformState):
+class DiagLaplaceState(NamedTuple):
     """State encoding a diagonal Normal distribution over parameters.
 
-    Args:
+    Attributes:
         params: Mean of the Normal distribution.
         prec_diag: Diagonal of the precision matrix of the Normal distribution.
         aux: Auxiliary information from the log_posterior call.
@@ -134,8 +132,7 @@ def update_func(x, y):
     )
 
     if inplace:
-        state.aux = aux
-        return state
+        return state._replace(aux=aux)
     return DiagLaplaceState(state.params, prec_diag, aux)
 
12 changes: 4 additions & 8 deletions posteriors/laplace/diag_ggn.py

@@ -1,15 +1,13 @@
 from functools import partial
-from typing import Any
+from typing import Any, NamedTuple
 import torch
 from optree import tree_map
-from dataclasses import dataclass
 
 from posteriors.types import (
     TensorTree,
     Transform,
     ForwardFn,
     OuterLogProbFn,
-    TransformState,
 )
 from posteriors.tree_utils import flexi_tree_map
 from posteriors.utils import (
@@ -66,11 +64,10 @@ def build(
     return Transform(init_fn, update_fn)
 
 
-@dataclass
-class DiagLaplaceState(TransformState):
+class DiagLaplaceState(NamedTuple):
     """State encoding a diagonal Normal distribution over parameters.
 
-    Args:
+    Attributes:
         params: Mean of the Normal distribution.
         prec_diag: Diagonal of the precision matrix of the Normal distribution.
        aux: Auxiliary information from the log_posterior call.
@@ -149,8 +146,7 @@ def update_func(x, y):
     )
 
    if inplace:
-        state.aux = aux
-        return state
+        return state._replace(aux=aux)
     return DiagLaplaceState(state.params, prec_diag, aux)
 
18 changes: 8 additions & 10 deletions posteriors/optim.py

@@ -1,10 +1,10 @@
-from typing import Type, Any
+from typing import Type, Any, NamedTuple
 from functools import partial
 import torch
-from dataclasses import dataclass
 
-from posteriors.types import TensorTree, Transform, LogProbFn, TransformState
+from posteriors.types import TensorTree, Transform, LogProbFn
 from posteriors.utils import CatchAuxError
+from posteriors.tree_utils import tree_insert_
 
 
 def build(
@@ -36,11 +36,10 @@ def build(
     return Transform(init_fn, update_fn)
 
 
-@dataclass
-class OptimState(TransformState):
+class OptimState(NamedTuple):
    """State of an optimizer from [torch.optim](https://pytorch.org/docs/stable/optim.html).
 
-    Args:
+    Attributes:
        params: Parameters to be optimized.
        optimizer: torch.optim optimizer instance.
        loss: Loss value.
@@ -49,7 +48,7 @@ class OptimState(TransformState):
 
     params: TensorTree
     optimizer: torch.optim.Optimizer
-    loss: torch.tensor = None
+    loss: torch.tensor = torch.tensor([])
     aux: Any = None
 
 
@@ -104,6 +103,5 @@ def update(
     loss, aux = loss_fn(state.params, batch)
     loss.backward()
     state.optimizer.step()
-    state.loss = loss
-    state.aux = aux
-    return state
+    tree_insert_(state.loss, loss.detach())
+    return state._replace(aux=aux)