Coupling layers #92

Open · wants to merge 16 commits into base: main
4 changes: 4 additions & 0 deletions flowtorch/bijectors/__init__.py
@@ -16,6 +16,8 @@
from flowtorch.bijectors.autoregressive import Autoregressive
from flowtorch.bijectors.base import Bijector
from flowtorch.bijectors.compose import Compose
from flowtorch.bijectors.coupling import ConvCouplingBijector
from flowtorch.bijectors.coupling import CouplingBijector
from flowtorch.bijectors.elementwise import Elementwise
from flowtorch.bijectors.elu import ELU
from flowtorch.bijectors.exp import Exp
@@ -35,6 +37,8 @@
("Affine", Affine),
("AffineAutoregressive", AffineAutoregressive),
("AffineFixed", AffineFixed),
("ConvCouplingBijector", ConvCouplingBijector),
("CouplingBijector", CouplingBijector),
("ELU", ELU),
("Exp", Exp),
("LeakyReLU", LeakyReLU),
2 changes: 1 addition & 1 deletion flowtorch/bijectors/autoregressive.py
@@ -60,7 +60,7 @@ def inverse(
        # TODO: Make permutation, inverse work for other event shapes
        log_detJ: Optional[torch.Tensor] = None
        for idx in cast(torch.LongTensor, permutation):
-            _params = self._params_fn(x_new.clone(), context=context)
+            _params = self._params_fn(x_new.clone(), inverse=False, context=context)
            x_temp, log_detJ = self._inverse(y, params=_params)
            x_new[..., idx] = x_temp[..., idx]
            # _log_detJ = out[1]
16 changes: 12 additions & 4 deletions flowtorch/bijectors/base.py
@@ -71,7 +71,11 @@ def forward(
            assert isinstance(x, BijectiveTensor)
            return x.get_parent_from_bijector(self)

-        params = self._params_fn(x, context) if self._params_fn is not None else None
+        params = (
+            self._params_fn(x, inverse=False, context=context)
+            if self._params_fn is not None
+            else None
+        )
        y, log_detJ = self._forward(x, params)
        if (
            is_record_flow_graph_enabled()
@@ -117,7 +121,11 @@ def inverse(
            return y.get_parent_from_bijector(self)

        # TODO: What to do in this line?
-        params = self._params_fn(x, context) if self._params_fn is not None else None
+        params = (
+            self._params_fn(y, inverse=True, context=context)
+            if self._params_fn is not None
+            else None
+        )
        x, log_detJ = self._inverse(y, params)

        if (
@@ -170,10 +178,10 @@ def log_abs_det_jacobian(
        if ladj is None:
            if is_record_flow_graph_enabled():
                warnings.warn(
-                    "Computing _log_abs_det_jacobian from values and not " "from cache."
+                    "Computing _log_abs_det_jacobian from values and not from cache."
                )
            params = (
-                self._params_fn(x, context) if self._params_fn is not None else None
+                self._params_fn(x, y, context) if self._params_fn is not None else None
            )
            return self._log_abs_det_jacobian(x, y, params)
        return ladj
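For context on the calling-convention change above: `Bijector.forward` now calls the parameter network with the input `x` and `inverse=False`, while `Bijector.inverse` passes the output `y` with `inverse=True`. Below is a minimal sketch of a callable that satisfies this contract; the class and its internals are illustrative stand-ins, not part of this PR.

```python
import torch


class ToyCouplingParams:
    """Illustrative stand-in for a flowtorch Parameters module (not part of this PR)."""

    def __call__(self, input, inverse, context=None):
        # A coupling-style parameterization conditions only on the first half of
        # the event dimension; the coupling transform leaves that half unchanged,
        # so it is the same tensor whether `input` is x (inverse=False) or y
        # (inverse=True).
        d = input.shape[-1] // 2
        h = input[..., :d]
        shift = torch.cat([torch.zeros_like(h), h.tanh()], dim=-1)  # stand-in for a network
        log_scale = torch.zeros_like(input)
        return shift, log_scale


params_fn = ToyCouplingParams()
x = torch.randn(4, 6)
params_fwd = params_fn(x, inverse=False)  # as called from Bijector.forward(x)
params_inv = params_fn(x, inverse=True)   # as called from Bijector.inverse(y)
```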
131 changes: 131 additions & 0 deletions flowtorch/bijectors/coupling.py
@@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc
from copy import deepcopy
from typing import Optional, Sequence, Tuple

import flowtorch.parameters
import torch
from flowtorch.bijectors.ops.affine import Affine as AffineOp
from flowtorch.parameters import ConvCoupling, DenseCoupling
from torch.distributions import constraints


_REAL3d = deepcopy(constraints.real)
_REAL3d.event_dim = 3

_REAL1d = deepcopy(constraints.real)
_REAL1d.event_dim = 1


class CouplingBijector(AffineOp):
    """
    Examples:
        >>> params = DenseCoupling()
        >>> bij = CouplingBijector(params)
        >>> bij = bij(shape=torch.Size([32,]))
        >>> for p in bij.parameters():
        ...     p.data += torch.randn_like(p)/10
        >>> x = torch.randn(1, 32, requires_grad=True)
        >>> y = bij.forward(x).detach_from_flow()
        >>> x_bis = bij.inverse(y)
        >>> torch.testing.assert_allclose(x, x_bis)
    """

    domain: constraints.Constraint = _REAL1d
    codomain: constraints.Constraint = _REAL1d

    def __init__(
        self,
        params_fn: Optional[flowtorch.Lazy] = None,
        *,
        shape: torch.Size,
        context_shape: Optional[torch.Size] = None,
        clamp_values: bool = False,
        log_scale_min_clip: float = -5.0,
        log_scale_max_clip: float = 3.0,
        sigmoid_bias: float = 2.0,
        positive_map: str = "softplus",
        positive_bias: Optional[float] = None,
    ) -> None:

        if params_fn is None:
            params_fn = DenseCoupling()  # type: ignore

        AffineOp.__init__(
            self,
            params_fn,
            shape=shape,
            context_shape=context_shape,
            clamp_values=clamp_values,
            log_scale_min_clip=log_scale_min_clip,
            log_scale_max_clip=log_scale_max_clip,
            sigmoid_bias=sigmoid_bias,
            positive_map=positive_map,
            positive_bias=positive_bias,
        )

    def _forward(
        self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        assert self._params_fn is not None

        y, ldj = super()._forward(x, params)
        return y, ldj

    def _inverse(
        self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        assert self._params_fn is not None

        x, ldj = super()._inverse(y, params)
        return x, ldj


class ConvCouplingBijector(CouplingBijector):
    """
    Examples:
        >>> params = ConvCoupling()
        >>> bij = ConvCouplingBijector(params)
        >>> bij = bij(shape=torch.Size([3, 16, 16]))
        >>> for p in bij.parameters():
        ...     p.data += torch.randn_like(p)/10
        >>> x = torch.randn(4, 3, 16, 16)
        >>> y = bij.forward(x)
        >>> x_bis = bij.inverse(y.detach_from_flow())
        >>> torch.testing.assert_allclose(x, x_bis)
    """

    domain: constraints.Constraint = _REAL3d
    codomain: constraints.Constraint = _REAL3d

    def __init__(
        self,
        params_fn: Optional[flowtorch.Lazy] = None,
        *,
        shape: torch.Size,
        context_shape: Optional[torch.Size] = None,
        clamp_values: bool = False,
        log_scale_min_clip: float = -5.0,
        log_scale_max_clip: float = 3.0,
        sigmoid_bias: float = 2.0,
        positive_map: str = "softplus",
        positive_bias: Optional[float] = None,
    ) -> None:

        if not len(shape) == 3:
            raise ValueError(f"Expected a 3d-tensor shape, got {shape}")

        if params_fn is None:
            params_fn = ConvCoupling()  # type: ignore

        AffineOp.__init__(
            self,
            params_fn,
            shape=shape,
            context_shape=context_shape,
            clamp_values=clamp_values,
            log_scale_min_clip=log_scale_min_clip,
            log_scale_max_clip=log_scale_max_clip,
            sigmoid_bias=sigmoid_bias,
            positive_map=positive_map,
            positive_bias=positive_bias,
        )
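As a sketch of how the new bijector composes with the rest of the library, the snippet below follows the usual flowtorch pattern of declaring a bijector lazily and wrapping it in `flowtorch.distributions.Flow`; the event size, batch size, and base distribution are arbitrary choices for illustration.

```python
import torch
import flowtorch.bijectors as bij
import flowtorch.distributions as dist
import flowtorch.parameters as params

# Lazily declared coupling bijector; the event shape is bound by Flow below.
bijector = bij.CouplingBijector(params.DenseCoupling())

base = torch.distributions.Independent(
    torch.distributions.Normal(torch.zeros(32), torch.ones(32)), 1
)
flow = dist.Flow(base, bijector)

y = flow.rsample(torch.Size([8]))  # forward pass through the coupling layer
log_p = flow.log_prob(y)           # inverse pass plus log|det J|
```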
10 changes: 9 additions & 1 deletion flowtorch/parameters/__init__.py
@@ -7,7 +7,15 @@
"""

from flowtorch.parameters.base import Parameters
from flowtorch.parameters.coupling import ConvCoupling
from flowtorch.parameters.coupling import DenseCoupling
from flowtorch.parameters.dense_autoregressive import DenseAutoregressive
from flowtorch.parameters.tensor import Tensor

-__all__ = ["Parameters", "DenseAutoregressive", "Tensor"]
+__all__ = [
+    "Parameters",
+    "ConvCoupling",
+    "DenseCoupling",
+    "DenseAutoregressive",
+    "Tensor",
+]
8 changes: 5 additions & 3 deletions flowtorch/parameters/base.py
@@ -24,15 +24,17 @@ def __init__(

    def forward(
        self,
-        x: Optional[torch.Tensor] = None,
+        input: torch.Tensor,
+        inverse: bool,
        context: Optional[torch.Tensor] = None,
    ) -> Optional[Sequence[torch.Tensor]]:
        # TODO: Caching etc.
-        return self._forward(x, context)
+        return self._forward(input, inverse=inverse, context=context)

    def _forward(
        self,
-        x: Optional[torch.Tensor] = None,
+        input: torch.Tensor,
+        inverse: bool,
        context: Optional[torch.Tensor] = None,
    ) -> Optional[Sequence[torch.Tensor]]:
        # I raise an exception rather than using @abstractmethod and
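Downstream `Parameters` subclasses need to adopt the new `_forward(input, inverse, context)` signature. Below is a hedged sketch of a minimal input-independent subclass; the class name, constructor details, and parameter handling are illustrative assumptions, and only the `_forward` contract comes from this PR.

```python
from typing import Optional, Sequence

import torch
from flowtorch.parameters import Parameters


class ConstantAffineParams(Parameters):
    """Returns learnable, input-independent parameters (e.g. shift and log_scale)."""

    def __init__(self, param_shapes, input_shape, context_shape=None):
        super().__init__(param_shapes, input_shape, context_shape)
        self.params = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.zeros(s)) for s in param_shapes]
        )

    def _forward(
        self,
        input: torch.Tensor,
        inverse: bool,
        context: Optional[torch.Tensor] = None,
    ) -> Optional[Sequence[torch.Tensor]]:
        # `input` is x when inverse=False and y when inverse=True; constant
        # parameters can ignore both the tensor and the direction flag.
        return list(self.params)
```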