Coupling layers #92

Open
wants to merge 16 commits into base: main
12 changes: 8 additions & 4 deletions flowtorch/bijectors/__init__.py
@@ -16,6 +2,8 @@
from flowtorch.bijectors.autoregressive import Autoregressive
from flowtorch.bijectors.base import Bijector
from flowtorch.bijectors.compose import Compose
from flowtorch.bijectors.coupling import ConvCouplingBijector
from flowtorch.bijectors.coupling import CouplingBijector
from flowtorch.bijectors.elementwise import Elementwise
from flowtorch.bijectors.elu import ELU
from flowtorch.bijectors.exp import Exp
@@ -35,26 +37,28 @@
("Affine", Affine),
("AffineAutoregressive", AffineAutoregressive),
("AffineFixed", AffineFixed),
("Fixed", Fixed),
("ConvCouplingBijector", ConvCouplingBijector),
("CouplingBijector", CouplingBijector),
("ELU", ELU),
("Exp", Exp),
("LeakyReLU", LeakyReLU),
("Permute", Permute),
("VolumePreserving", VolumePreserving),
("Power", Power),
("Sigmoid", Sigmoid),
("Softplus", Softplus),
("Spline", Spline),
("SplineAutoregressive", SplineAutoregressive),
("Tanh", Tanh),
]

meta_bijectors = [
("Elementwise", Elementwise),
("Autoregressive", Autoregressive),
("Fixed", Fixed),
("Bijector", Bijector),
("Compose", Compose),
("Invert", Invert),
("VolumePreserving", VolumePreserving),
("Permute", Permute),
("SplineAutoregressive", SplineAutoregressive),
]


21 changes: 17 additions & 4 deletions flowtorch/bijectors/affine_autoregressive.py
@@ -16,15 +16,28 @@ def __init__(
*,
shape: torch.Size,
context_shape: Optional[torch.Size] = None,
clamp_values: bool = False,
log_scale_min_clip: float = -5.0,
log_scale_max_clip: float = 3.0,
sigmoid_bias: float = 2.0,
positive_map: str = "softplus",
positive_bias: Optional[float] = None,
) -> None:
super().__init__(
AffineOp.__init__(
self,
params_fn,
shape=shape,
context_shape=context_shape,
clamp_values=clamp_values,
log_scale_min_clip=log_scale_min_clip,
log_scale_max_clip=log_scale_max_clip,
sigmoid_bias=sigmoid_bias,
positive_map=positive_map,
positive_bias=positive_bias,
)
Autoregressive.__init__(
self,
params_fn,
shape=shape,
context_shape=context_shape,
)
self.log_scale_min_clip = log_scale_min_clip
self.log_scale_max_clip = log_scale_max_clip
self.sigmoid_bias = sigmoid_bias
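
The rewritten constructor above calls each base initializer explicitly instead of a single super().__init__, so the affine-specific keywords go only to the affine op while the autoregressive base receives just the shape arguments. A minimal sketch of that pattern, using hypothetical stand-in classes rather than FlowTorch's real ones:

# --- illustrative sketch (not part of the diff) ---
class ToyAffineOp:
    def __init__(self, shape, log_scale_min_clip=-5.0, sigmoid_bias=2.0):
        self.shape = shape
        self.log_scale_min_clip = log_scale_min_clip
        self.sigmoid_bias = sigmoid_bias


class ToyAutoregressive:
    def __init__(self, shape, context_shape=None):
        self.shape = shape
        self.context_shape = context_shape


class ToyAffineAutoregressive(ToyAffineOp, ToyAutoregressive):
    def __init__(self, shape, context_shape=None, log_scale_min_clip=-5.0):
        # Initialize each base directly so every keyword lands on the base
        # that understands it, instead of threading all arguments through
        # the cooperative super() chain.
        ToyAffineOp.__init__(self, shape, log_scale_min_clip=log_scale_min_clip)
        ToyAutoregressive.__init__(self, shape, context_shape=context_shape)
# --- end sketch ---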
2 changes: 1 addition & 1 deletion flowtorch/bijectors/autoregressive.py
@@ -60,7 +60,7 @@ def inverse(
# TODO: Make permutation, inverse work for other event shapes
log_detJ: Optional[torch.Tensor] = None
for idx in cast(torch.LongTensor, permutation):
_params = self._params_fn(x_new.clone(), context=context)
_params = self._params_fn(x_new.clone(), inverse=False, context=context)
x_temp, log_detJ = self._inverse(y, params=_params)
x_new[..., idx] = x_temp[..., idx]
# _log_detJ = out[1]
16 changes: 12 additions & 4 deletions flowtorch/bijectors/base.py
@@ -71,7 +71,11 @@ def forward(
assert isinstance(x, BijectiveTensor)
return x.get_parent_from_bijector(self)

params = self._params_fn(x, context) if self._params_fn is not None else None
params = (
self._params_fn(x, inverse=False, context=context)
if self._params_fn is not None
else None
)
y, log_detJ = self._forward(x, params)
if (
is_record_flow_graph_enabled()
@@ -117,7 +121,11 @@ def inverse(
return y.get_parent_from_bijector(self)

# TODO: What to do in this line?
params = self._params_fn(x, context) if self._params_fn is not None else None
params = (
self._params_fn(y, inverse=True, context=context)
if self._params_fn is not None
else None
)
x, log_detJ = self._inverse(y, params)

if (
@@ -170,10 +178,10 @@ def log_abs_det_jacobian(
if ladj is None:
if is_record_flow_graph_enabled():
warnings.warn(
"Computing _log_abs_det_jacobian from values and not " "from cache."
"Computing _log_abs_det_jacobian from values and not from cache."
)
params = (
self._params_fn(x, context) if self._params_fn is not None else None
self._params_fn(x, y, context) if self._params_fn is not None else None
)
return self._log_abs_det_jacobian(x, y, params)
return ladj
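
Both forward and inverse now pass an explicit inverse flag to the parameter callable, and the inverse path feeds it y rather than x. A hedged sketch of a parameter function compatible with that call convention, with the signature inferred from the calls in this diff rather than taken from FlowTorch's documented API:

# --- illustrative sketch (not part of the diff) ---
from typing import Optional, Tuple

import torch


def toy_params_fn(
    inputs: torch.Tensor,
    inverse: bool = False,
    context: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Returns (mean, unbounded_scale) with the same shape as `inputs`,
    # which is what the affine ops below unpack from `params`.  A coupling
    # network would condition only on the untouched half of `inputs` and
    # could use `inverse` to decide which half that is.
    return torch.zeros_like(inputs), torch.zeros_like(inputs)
# --- end sketch ---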
131 changes: 131 additions & 0 deletions flowtorch/bijectors/coupling.py
@@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc
from copy import deepcopy
from typing import Optional, Sequence, Tuple

import flowtorch.parameters
import torch
from flowtorch.bijectors.ops.affine import Affine as AffineOp
from flowtorch.parameters import ConvCoupling, DenseCoupling
from torch.distributions import constraints


_REAL3d = deepcopy(constraints.real)
_REAL3d.event_dim = 3

_REAL1d = deepcopy(constraints.real)
_REAL1d.event_dim = 1


class CouplingBijector(AffineOp):
"""
Examples:
>>> params = DenseCoupling()
>>> bij = CouplingBijector(params)
>>> bij = bij(shape=torch.Size([32,]))
>>> for p in bij.parameters():
... p.data += torch.randn_like(p)/10
>>> x = torch.randn(1, 32, requires_grad=True)
>>> y = bij.forward(x).detach_from_flow()
>>> x_bis = bij.inverse(y)
>>> torch.testing.assert_allclose(x, x_bis)
"""

domain: constraints.Constraint = _REAL1d
codomain: constraints.Constraint = _REAL1d

def __init__(
self,
params_fn: Optional[flowtorch.Lazy] = None,
*,
shape: torch.Size,
context_shape: Optional[torch.Size] = None,
clamp_values: bool = False,
log_scale_min_clip: float = -5.0,
log_scale_max_clip: float = 3.0,
sigmoid_bias: float = 2.0,
positive_map: str = "softplus",
positive_bias: Optional[float] = None,
) -> None:

if params_fn is None:
params_fn = DenseCoupling() # type: ignore

AffineOp.__init__(
self,
params_fn,
shape=shape,
context_shape=context_shape,
clamp_values=clamp_values,
log_scale_min_clip=log_scale_min_clip,
log_scale_max_clip=log_scale_max_clip,
sigmoid_bias=sigmoid_bias,
positive_map=positive_map,
positive_bias=positive_bias,
)

def _forward(
self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
assert self._params_fn is not None

y, ldj = super()._forward(x, params)
return y, ldj

def _inverse(
self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
assert self._params_fn is not None

x, ldj = super()._inverse(y, params)
return x, ldj


class ConvCouplingBijector(CouplingBijector):
"""
Examples:
>>> params = ConvCoupling()
>>> bij = ConvCouplingBijector(params)
>>> bij = bij(shape=torch.Size([3,16,16]))
>>> for p in bij.parameters():
... p.data += torch.randn_like(p)/10
>>> x = torch.randn(4, 3, 16, 16)
>>> y = bij.forward(x)
>>> x_bis = bij.inverse(y.detach_from_flow())
>>> torch.testing.assert_allclose(x, x_bis)
"""

domain: constraints.Constraint = _REAL3d
codomain: constraints.Constraint = _REAL3d

def __init__(
self,
params_fn: Optional[flowtorch.Lazy] = None,
*,
shape: torch.Size,
context_shape: Optional[torch.Size] = None,
clamp_values: bool = False,
log_scale_min_clip: float = -5.0,
log_scale_max_clip: float = 3.0,
sigmoid_bias: float = 2.0,
positive_map: str = "softplus",
positive_bias: Optional[float] = None,
) -> None:

if not len(shape) == 3:
raise ValueError(f"Expected a 3d-tensor shape, got {shape}")

if params_fn is None:
params_fn = ConvCoupling() # type: ignore

AffineOp.__init__(
self,
params_fn,
shape=shape,
context_shape=context_shape,
clamp_values=clamp_values,
log_scale_min_clip=log_scale_min_clip,
log_scale_max_clip=log_scale_max_clip,
sigmoid_bias=sigmoid_bias,
positive_map=positive_map,
positive_bias=positive_bias,
)
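
For readers new to coupling layers: the bijector leaves part of the input untouched and applies an element-wise affine map to the remainder, with the shift and scale predicted from the untouched part, which is what keeps the inverse cheap. A hand-rolled sketch of that standard construction (not the DenseCoupling/ConvCoupling parameterizations used above):

# --- illustrative sketch (not part of the diff) ---
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyAffineCoupling(nn.Module):
    """RealNVP-style affine coupling on a flat feature vector."""

    def __init__(self, dim: int):
        super().__init__()
        self.d = dim // 2
        # Predicts (shift, unbounded_scale) for the second half from the first.
        self.net = nn.Linear(self.d, 2 * (dim - self.d))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x[..., : self.d], x[..., self.d:]
        shift, unbounded_scale = self.net(x1).chunk(2, dim=-1)
        y2 = F.softplus(unbounded_scale) * x2 + shift
        return torch.cat([x1, y2], dim=-1)

    def inverse(self, y: torch.Tensor) -> torch.Tensor:
        # The first half passed through unchanged, so the same network
        # reproduces the shift and scale and the affine map can be undone.
        y1, y2 = y[..., : self.d], y[..., self.d:]
        shift, unbounded_scale = self.net(y1).chunk(2, dim=-1)
        x2 = (y2 - shift) / F.softplus(unbounded_scale)
        return torch.cat([y1, x2], dim=-1)
# --- end sketch ---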
76 changes: 59 additions & 17 deletions flowtorch/bijectors/ops/affine.py
@@ -1,13 +1,24 @@
# Copyright (c) Meta Platforms, Inc

from typing import Optional, Sequence, Tuple
from typing import Callable, Dict, Optional, Sequence, Tuple

import flowtorch
import torch
from flowtorch.bijectors.base import Bijector
from flowtorch.ops import clamp_preserve_gradients
from torch.distributions.utils import _sum_rightmost

_DEFAULT_POSITIVE_BIASES = {
"softplus": 0.5413248538970947,
"exp": 0.0,
}

_POSITIVE_MAPS: Dict[str, Callable[[torch.Tensor], torch.Tensor]] = {
"softplus": torch.nn.functional.softplus,
"sigmoid": torch.sigmoid,
"exp": torch.exp,
}


class Affine(Bijector):
r"""
@@ -22,38 +33,63 @@ def __init__(
*,
shape: torch.Size,
context_shape: Optional[torch.Size] = None,
clamp_values: bool = False,
log_scale_min_clip: float = -5.0,
log_scale_max_clip: float = 3.0,
sigmoid_bias: float = 2.0,
positive_map: str = "softplus",
positive_bias: Optional[float] = None,
) -> None:
super().__init__(params_fn, shape=shape, context_shape=context_shape)
self.clamp_values = clamp_values
self.log_scale_min_clip = log_scale_min_clip
self.log_scale_max_clip = log_scale_max_clip
self.sigmoid_bias = sigmoid_bias
if positive_bias is None:
positive_bias = _DEFAULT_POSITIVE_BIASES[positive_map]
self.positive_bias = positive_bias
if positive_map not in _POSITIVE_MAPS:
raise RuntimeError(f"Unknown positive map {positive_map}")
self._positive_map = _POSITIVE_MAPS[positive_map]
self._exp_map = self._positive_map is torch.exp and self.positive_bias == 0

def positive_map(self, x: torch.Tensor) -> torch.Tensor:
return self._positive_map(x + self.positive_bias)

def _forward(
self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
assert params is not None

mean, log_scale = params
log_scale = clamp_preserve_gradients(
log_scale, self.log_scale_min_clip, self.log_scale_max_clip
)
scale = torch.exp(log_scale)
mean, unbounded_scale = params
if self.clamp_values:
unbounded_scale = clamp_preserve_gradients(
unbounded_scale, self.log_scale_min_clip, self.log_scale_max_clip
)
scale = self.positive_map(unbounded_scale)
log_scale = scale.log() if not self._exp_map else unbounded_scale
y = scale * x + mean
return y, _sum_rightmost(log_scale, self.domain.event_dim)

def _inverse(
self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
assert params is not None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
assert (
params is not None
), f"{self.__class__.__name__}._inverse got no parameters"

mean, log_scale = params
log_scale = clamp_preserve_gradients(
log_scale, self.log_scale_min_clip, self.log_scale_max_clip
)
inverse_scale = torch.exp(-log_scale)
mean, unbounded_scale = params
if self.clamp_values:
unbounded_scale = clamp_preserve_gradients(
unbounded_scale, self.log_scale_min_clip, self.log_scale_max_clip
)

if not self._exp_map:
inverse_scale = self.positive_map(unbounded_scale).reciprocal()
log_scale = -inverse_scale.log()
else:
inverse_scale = torch.exp(-unbounded_scale)
log_scale = unbounded_scale
x_new = (y - mean) * inverse_scale
return x_new, _sum_rightmost(log_scale, self.domain.event_dim)

@@ -65,9 +101,15 @@ def _log_abs_det_jacobian(
) -> torch.Tensor:
assert params is not None

_, log_scale = params
log_scale = clamp_preserve_gradients(
log_scale, self.log_scale_min_clip, self.log_scale_max_clip
_, unbounded_scale = params
if self.clamp_values:
unbounded_scale = clamp_preserve_gradients(
unbounded_scale, self.log_scale_min_clip, self.log_scale_max_clip
)
log_scale = (
self.positive_map(unbounded_scale).log()
if not self._exp_map
else unbounded_scale
)
return _sum_rightmost(log_scale, self.domain.event_dim)

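
The default softplus bias above is, up to its last few digits, log(e - 1), chosen so that an all-zero unbounded scale maps to a scale of one: softplus(0 + log(e - 1)) = log(1 + (e - 1)) = 1, making a freshly initialized layer close to the identity. The exp entry uses a bias of 0.0 for the same reason, since exp(0) = 1. A quick check of that constant, assuming only standard PyTorch:

# --- illustrative check (not part of the diff) ---
import math

import torch
import torch.nn.functional as F

bias = 0.5413248538970947                    # _DEFAULT_POSITIVE_BIASES["softplus"]
print(F.softplus(torch.tensor(0.0) + bias))  # tensor(1.0000): identity-like scale at init
print(math.log(math.e - 1.0))                # 0.5413248546..., the closed-form value
# --- end check ---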
2 changes: 1 addition & 1 deletion flowtorch/distributions/flow.py
@@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc

from typing import Any, Dict, Optional, Union, Iterator
from typing import Any, Dict, Iterator, Optional, Union

import flowtorch
import torch