From 85ae279be3d327bf793a3815eadadd6678045172 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 9 Feb 2022 16:38:43 +0000 Subject: [PATCH 01/14] init --- flowtorch/bijectors/__init__.py | 2 + flowtorch/bijectors/affine_autoregressive.py | 21 +- flowtorch/bijectors/autoregressive.py | 2 +- flowtorch/bijectors/base.py | 12 +- flowtorch/bijectors/bijective_tensor.py | 2 +- flowtorch/bijectors/compose.py | 4 +- flowtorch/bijectors/coupling.py | 61 ++++++ flowtorch/bijectors/ops/affine.py | 73 +++++-- flowtorch/distributions/flow.py | 2 +- flowtorch/parameters/__init__.py | 3 +- flowtorch/parameters/base.py | 4 +- flowtorch/parameters/coupling.py | 204 +++++++++++++++++++ flowtorch/parameters/dense_autoregressive.py | 5 + flowtorch/parameters/tensor.py | 5 +- 14 files changed, 368 insertions(+), 32 deletions(-) create mode 100644 flowtorch/bijectors/coupling.py create mode 100644 flowtorch/parameters/coupling.py diff --git a/flowtorch/bijectors/__init__.py b/flowtorch/bijectors/__init__.py index e57fc400..c4c4d6c5 100644 --- a/flowtorch/bijectors/__init__.py +++ b/flowtorch/bijectors/__init__.py @@ -16,6 +16,7 @@ from flowtorch.bijectors.autoregressive import Autoregressive from flowtorch.bijectors.base import Bijector from flowtorch.bijectors.compose import Compose +from flowtorch.bijectors.coupling import Coupling from flowtorch.bijectors.elementwise import Elementwise from flowtorch.bijectors.elu import ELU from flowtorch.bijectors.exp import Exp @@ -33,6 +34,7 @@ standard_bijectors = [ ("Affine", Affine), ("AffineAutoregressive", AffineAutoregressive), + ("Coupling", Coupling), ("AffineFixed", AffineFixed), ("ELU", ELU), ("Exp", Exp), diff --git a/flowtorch/bijectors/affine_autoregressive.py b/flowtorch/bijectors/affine_autoregressive.py index 610e5477..a855cf5d 100644 --- a/flowtorch/bijectors/affine_autoregressive.py +++ b/flowtorch/bijectors/affine_autoregressive.py @@ -16,15 +16,28 @@ def __init__( *, shape: torch.Size, context_shape: Optional[torch.Size] = None, + clamp_values: bool = False, log_scale_min_clip: float = -5.0, log_scale_max_clip: float = 3.0, sigmoid_bias: float = 2.0, + positive_map: str = "softplus", + positive_bias: Optional[float] = None, ) -> None: - super().__init__( + AffineOp.__init__( + self, + params_fn, + shape=shape, + context_shape=context_shape, + clamp_values=clamp_values, + log_scale_min_clip=log_scale_min_clip, + log_scale_max_clip=log_scale_max_clip, + sigmoid_bias=sigmoid_bias, + positive_map=positive_map, + positive_bias=positive_bias, + ) + Autoregressive.__init__( + self, params_fn, shape=shape, context_shape=context_shape, ) - self.log_scale_min_clip = log_scale_min_clip - self.log_scale_max_clip = log_scale_max_clip - self.sigmoid_bias = sigmoid_bias diff --git a/flowtorch/bijectors/autoregressive.py b/flowtorch/bijectors/autoregressive.py index 8367b51b..4d8c2d7a 100644 --- a/flowtorch/bijectors/autoregressive.py +++ b/flowtorch/bijectors/autoregressive.py @@ -60,7 +60,7 @@ def inverse( # TODO: Make permutation, inverse work for other event shapes log_detJ: Optional[torch.Tensor] = None for idx in cast(torch.LongTensor, permutation): - _params = self._params_fn(x_new.clone(), context=context) + _params = self._params_fn(x_new.clone(), None, context=context) x_temp, log_detJ = self._inverse(y, params=_params) x_new[..., idx] = x_temp[..., idx] # _log_detJ = out[1] diff --git a/flowtorch/bijectors/base.py b/flowtorch/bijectors/base.py index 6a388c83..55867640 100644 --- a/flowtorch/bijectors/base.py +++ b/flowtorch/bijectors/base.py @@ -1,11 
+1,11 @@ # Copyright (c) Meta Platforms, Inc import warnings -from typing import Optional, Sequence, Tuple, Union, Callable, Iterator +from typing import Callable, Iterator, Optional, Sequence, Tuple, Union import flowtorch.parameters import torch import torch.distributions -from flowtorch.bijectors.bijective_tensor import to_bijective_tensor, BijectiveTensor +from flowtorch.bijectors.bijective_tensor import BijectiveTensor, to_bijective_tensor from flowtorch.bijectors.utils import is_record_flow_graph_enabled from flowtorch.parameters import Parameters from torch.distributions import constraints @@ -75,7 +75,9 @@ def forward( assert isinstance(x, BijectiveTensor) return x.get_parent_from_bijector(self) - params = self._params_fn(x, context) if self._params_fn is not None else None + params = ( + self._params_fn(x, None, context) if self._params_fn is not None else None + ) y, log_detJ = self._forward(x, params) if ( is_record_flow_graph_enabled() @@ -119,7 +121,7 @@ def inverse( return y.get_parent_from_bijector(self) # TODO: What to do in this line? - params = self._params_fn(x, context) if self._params_fn is not None else None + params = self._params_fn(x, y, context) if self._params_fn is not None else None x, log_detJ = self._inverse(y, params) if ( @@ -173,7 +175,7 @@ def log_abs_det_jacobian( "Computing _log_abs_det_jacobian from values and not from cache." ) params = ( - self._params_fn(x, context) if self._params_fn is not None else None + self._params_fn(x, y, context) if self._params_fn is not None else None ) return self._log_abs_det_jacobian(x, y, params) return ladj diff --git a/flowtorch/bijectors/bijective_tensor.py b/flowtorch/bijectors/bijective_tensor.py index 8a3e6338..46a1b171 100644 --- a/flowtorch/bijectors/bijective_tensor.py +++ b/flowtorch/bijectors/bijective_tensor.py @@ -1,5 +1,5 @@ # Copyright (c) Meta Platforms, Inc -from typing import Any, Optional, Iterator, Type, TYPE_CHECKING, Union +from typing import Any, Iterator, Optional, Type, TYPE_CHECKING, Union if TYPE_CHECKING: from flowtorch.bijectors.base import Bijector diff --git a/flowtorch/bijectors/compose.py b/flowtorch/bijectors/compose.py index 5bc13317..a3caa285 100644 --- a/flowtorch/bijectors/compose.py +++ b/flowtorch/bijectors/compose.py @@ -1,11 +1,11 @@ # Copyright (c) Meta Platforms, Inc -from typing import Optional, Sequence, Iterator +from typing import Iterator, Optional, Sequence import flowtorch.parameters import torch import torch.distributions from flowtorch.bijectors.base import Bijector -from flowtorch.bijectors.bijective_tensor import to_bijective_tensor, BijectiveTensor +from flowtorch.bijectors.bijective_tensor import BijectiveTensor, to_bijective_tensor from flowtorch.bijectors.utils import is_record_flow_graph_enabled, requires_log_detJ from torch.distributions.utils import _sum_rightmost diff --git a/flowtorch/bijectors/coupling.py b/flowtorch/bijectors/coupling.py new file mode 100644 index 00000000..c9a0e3d2 --- /dev/null +++ b/flowtorch/bijectors/coupling.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc + +from typing import Optional, Sequence, Tuple + +import flowtorch.parameters + +import torch +from flowtorch.bijectors.ops.affine import Affine as AffineOp +from flowtorch.parameters import DenseCoupling + + +class Coupling(AffineOp): + def __init__( + self, + params_fn: Optional[flowtorch.Lazy] = None, + *, + shape: torch.Size, + context_shape: Optional[torch.Size] = None, + clamp_values: bool = False, + log_scale_min_clip: float = -5.0, + log_scale_max_clip: float = 
3.0, + sigmoid_bias: float = 2.0, + positive_map: str = "softplus", + positive_bias: Optional[float] = None, + ) -> None: + + if params_fn is None: + params_fn = DenseCoupling() # type: ignore + + AffineOp.__init__( + self, + params_fn, + shape=shape, + context_shape=context_shape, + clamp_values=clamp_values, + log_scale_min_clip=log_scale_min_clip, + log_scale_max_clip=log_scale_max_clip, + sigmoid_bias=sigmoid_bias, + positive_map=positive_map, + positive_bias=positive_bias, + ) + + def _forward( + self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, torch.Tensor]: + assert self._params_fn is not None + + x = x[..., self._params_fn.permutation] + y, ldj = super()._forward(x, params) + y = y[..., self._params_fn.inv_permutation] + return y, ldj + + def _inverse( + self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, torch.Tensor]: + assert self._params_fn is not None + + y = y[..., self._params_fn.inv_permutation] + x, ldj = super()._inverse(y, params) + x = x[..., self._params_fn.permutation] + return x, ldj diff --git a/flowtorch/bijectors/ops/affine.py b/flowtorch/bijectors/ops/affine.py index d9cdf56f..4cee536c 100644 --- a/flowtorch/bijectors/ops/affine.py +++ b/flowtorch/bijectors/ops/affine.py @@ -1,13 +1,24 @@ # Copyright (c) Meta Platforms, Inc -from typing import Optional, Sequence, Tuple +from typing import Optional, Sequence, Tuple, Dict, Callable import flowtorch + import torch from flowtorch.bijectors.base import Bijector from flowtorch.ops import clamp_preserve_gradients from torch.distributions.utils import _sum_rightmost +_DEFAULT_POSITIVE_BIASES = { + "softplus": torch.expm1(torch.ones(1)).log().item(), + "exp": 0.0, +} +_POSITIVE_MAPS: Dict[str, Callable[[torch.Tensor], torch.Tensor]] = { + "softplus": torch.nn.functional.softplus, + "sigmoid": torch.sigmoid, + "exp": torch.exp, +} + class Affine(Bijector): r""" @@ -22,38 +33,64 @@ def __init__( *, shape: torch.Size, context_shape: Optional[torch.Size] = None, + clamp_values: bool = False, log_scale_min_clip: float = -5.0, log_scale_max_clip: float = 3.0, sigmoid_bias: float = 2.0, + positive_map: str = "softplus", + positive_bias: Optional[float] = None, ) -> None: super().__init__(params_fn, shape=shape, context_shape=context_shape) + self.clamp_values = clamp_values self.log_scale_min_clip = log_scale_min_clip self.log_scale_max_clip = log_scale_max_clip self.sigmoid_bias = sigmoid_bias + if positive_bias is None: + positive_bias = _DEFAULT_POSITIVE_BIASES[positive_map] + self.positive_bias = positive_bias + if positive_map not in _POSITIVE_MAPS: + raise RuntimeError(f"Unknwon positive map {positive_map}") + self._positive_map = _POSITIVE_MAPS[positive_map] + self._exp_map = self._positive_map is torch.exp and self.positive_bias == 0 + + def positive_map(self, x: torch.Tensor) -> torch.Tensor: + return self._positive_map(x + self.positive_bias) def _forward( self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]] ) -> Tuple[torch.Tensor, torch.Tensor]: assert params is not None - mean, log_scale = params - log_scale = clamp_preserve_gradients( - log_scale, self.log_scale_min_clip, self.log_scale_max_clip - ) - scale = torch.exp(log_scale) + mean, unbounded_scale = params + if self.clamp_values: + unbounded_scale = clamp_preserve_gradients( + unbounded_scale, self.log_scale_min_clip, self.log_scale_max_clip + ) + scale = self.positive_map(unbounded_scale) + log_scale = scale.log() if not self._exp_map else unbounded_scale y = scale * 
x + mean return y, _sum_rightmost(log_scale, self.domain.event_dim) def _inverse( self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] ) -> Tuple[torch.Tensor, torch.Tensor]: - assert params is not None + assert ( + params is not None + ), f"{self.__class__.__name__}._inverse got no parameters" + + mean, unbounded_scale = params + if self.clamp_values: + unbounded_scale = clamp_preserve_gradients( + unbounded_scale, self.log_scale_min_clip, self.log_scale_max_clip + ) + + if not self._exp_map: + inverse_scale = self.positive_map(unbounded_scale).reciprocal() + log_scale = inverse_scale.log() + else: + inverse_scale = torch.exp(-unbounded_scale) + log_scale = unbounded_scale - mean, log_scale = params - log_scale = clamp_preserve_gradients( - log_scale, self.log_scale_min_clip, self.log_scale_max_clip - ) - inverse_scale = torch.exp(-log_scale) x_new = (y - mean) * inverse_scale return x_new, _sum_rightmost(log_scale, self.domain.event_dim) @@ -65,9 +102,15 @@ def _log_abs_det_jacobian( ) -> torch.Tensor: assert params is not None - _, log_scale = params - log_scale = clamp_preserve_gradients( - log_scale, self.log_scale_min_clip, self.log_scale_max_clip + _, unbounded_scale = params + if self.clamp_values: + unbounded_scale = clamp_preserve_gradients( + unbounded_scale, self.log_scale_min_clip, self.log_scale_max_clip + ) + log_scale = ( + self.positive_map(unbounded_scale).log() + if not self._exp_map + else unbounded_scale ) return _sum_rightmost(log_scale, self.domain.event_dim) diff --git a/flowtorch/distributions/flow.py b/flowtorch/distributions/flow.py index bfb0e97d..6b7d6488 100644 --- a/flowtorch/distributions/flow.py +++ b/flowtorch/distributions/flow.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc -from typing import Any, Dict, Optional, Union, Iterator +from typing import Any, Dict, Iterator, Optional, Union import flowtorch import torch diff --git a/flowtorch/parameters/__init__.py b/flowtorch/parameters/__init__.py index 86f8045c..6ac0d040 100644 --- a/flowtorch/parameters/__init__.py +++ b/flowtorch/parameters/__init__.py @@ -7,7 +7,8 @@ """ from flowtorch.parameters.base import Parameters +from flowtorch.parameters.coupling import DenseCoupling from flowtorch.parameters.dense_autoregressive import DenseAutoregressive from flowtorch.parameters.tensor import Tensor -__all__ = ["Parameters", "DenseAutoregressive", "Tensor"] +__all__ = ["Parameters", "DenseAutoregressive", "Tensor", "DenseCoupling"] diff --git a/flowtorch/parameters/base.py b/flowtorch/parameters/base.py index 72e4b69f..62391825 100644 --- a/flowtorch/parameters/base.py +++ b/flowtorch/parameters/base.py @@ -25,14 +25,16 @@ def __init__( def forward( self, x: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None, ) -> Optional[Sequence[torch.Tensor]]: # TODO: Caching etc. 
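
The reworked scale parameterisation above can be exercised on its own: the raw network output is optionally clamped, shifted by positive_bias, and pushed through the chosen positive map, with the default softplus bias picked so that a raw value of 0 gives a scale of exactly 1. A minimal sketch with plain torch (the clamp helper is a common straight-through variant written here for illustration, not flowtorch's own clamp_preserve_gradients):

import torch
import torch.nn.functional as F

def clamp_preserve_gradients(x, lo, hi):
    # clamp the value but let gradients pass through as if unclamped
    return x + (x.clamp(lo, hi) - x).detach()

bias = torch.expm1(torch.ones(1)).log()        # log(e - 1), about 0.5413
raw = torch.tensor([-20.0, 0.0, 4.0])
clamped = clamp_preserve_gradients(raw, -5.0, 3.0)
scale = F.softplus(clamped + bias)             # strictly positive scale
assert torch.allclose(F.softplus(bias), torch.ones(1))   # raw 0 -> scale 1

x = torch.randn(3)
mean = torch.zeros(3)
y = scale * x + mean                           # the affine forward map

With these defaults a zero shift and a zero raw scale give back x unchanged, which is what the coupling layers introduced later in this series rely on: the parameters for the pass-through half are masked to zero, so that half is transformed by the identity.
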
- return self._forward(x, context) + return self._forward(x, y, context) def _forward( self, x: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None, ) -> Optional[Sequence[torch.Tensor]]: # I raise an exception rather than using @abstractmethod and diff --git a/flowtorch/parameters/coupling.py b/flowtorch/parameters/coupling.py new file mode 100644 index 00000000..a4e3413c --- /dev/null +++ b/flowtorch/parameters/coupling.py @@ -0,0 +1,204 @@ +# Copyright (c) Meta Platforms, Inc + +import warnings +from typing import Callable, Optional, Sequence + +import torch +import torch.nn as nn +from flowtorch.nn.made import create_mask, MaskedLinear +from flowtorch.parameters.base import Parameters + + +class DenseCoupling(Parameters): + autoregressive = True + + def __init__( + self, + param_shapes: Sequence[torch.Size], + input_shape: torch.Size, + context_shape: Optional[torch.Size], + *, + hidden_dims: Sequence[int] = (256, 256), + nonlinearity: Callable[[], nn.Module] = nn.ReLU, + permutation: Optional[torch.LongTensor] = None, + skip_connections: bool = False, + ) -> None: + super().__init__(param_shapes, input_shape, context_shape) + + # Check consistency of input_shape with param_shapes + # We need each param_shapes to match input_shape in + # its leftmost dimensions + for s in param_shapes: + assert len(s) >= len(input_shape) and s[: len(input_shape)] == input_shape + + self.hidden_dims = hidden_dims + self.nonlinearity = nonlinearity + self.skip_connections = skip_connections + self._build(input_shape, param_shapes, context_shape, permutation) + + def _build( + self, + input_shape: torch.Size, + param_shapes: Sequence[torch.Size], + context_shape: Optional[torch.Size], + permutation: Optional[torch.LongTensor], + ) -> None: + # Work out flattened input and output shapes + param_shapes_ = list(param_shapes) + input_dims = int(torch.sum(torch.tensor(input_shape)).int().item()) + self.input_dims = input_dims + if input_dims == 0: + input_dims = 1 # scalars represented by torch.Size([]) + if permutation is None: + # By default set a random permutation of variables, which is + # important for performance with multiple steps + permutation = torch.LongTensor( + torch.randperm(input_dims, device="cpu").to( + torch.LongTensor((1,)).device + ) + ) + else: + # The permutation is chosen by the user + permutation = torch.LongTensor(permutation) + + self.param_dims = [ + int(max(torch.prod(torch.tensor(s[len(input_shape) :])).item(), 1)) + for s in param_shapes_ + ] + + self.output_multiplier = sum(self.param_dims) + + if input_dims == 1: + warnings.warn( + "DenseAutoregressive input_dim = 1. " + "Consider using an affine transformation instead." + ) + + # Calculate the indices on the output corresponding to each parameter + # TODO: Is this logic correct??? + # ends = torch.cumsum( + # torch.tensor( + # [max(torch.prod(torch.tensor(s)).item(), 1) for s in param_shapes_] + # ), + # dim=0, + # ) + # starts = torch.cat((torch.zeros(1).type_as(ends), ends[:-1])) + # self.param_slices = [slice(s.item(), e.item()) for s, e in zip(starts, ends)] + + # Hidden dimension must be not less than the input otherwise it isn't + # possible to connect to the outputs correctly + for h in self.hidden_dims: + if h < input_dims: + raise ValueError( + "Hidden dimension must not be less than input dimension." + ) + + # TODO: Check that the permutation is valid for the input dimension! 
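
The inverse permutation registered a few lines below is obtained with argsort, which inverts any permutation; a quick standalone check:

import torch

perm = torch.randperm(8)
inv_perm = perm.argsort()                      # argsort of a permutation is its inverse
x = torch.randn(8)
assert torch.equal(x[perm][inv_perm], x)       # permute, then un-permute
assert torch.equal(inv_perm[perm], torch.arange(8))
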
+ # Implement ispermutation() that sorts permutation and checks whether it + # has all integers from 0, 1, ..., self.input_dims - 1 + self.register_buffer("permutation", permutation) + self.register_buffer("inv_permutation", permutation.argsort()) + + # Create masks + hidden_dims = self.hidden_dims + + # Create masked layers: + # input is [x1 ; 0] + # output is [0 ; mu2], [0 ; sig2] + mask_input = torch.ones(hidden_dims[0], input_dims) + self.x1_dim = x1_dim = input_dims // 2 + mask_input[:, x1_dim:] = 0.0 + + out_dims = input_dims * self.output_multiplier + mask_output = torch.ones(self.output_multiplier, input_dims, hidden_dims[-1]) + mask_output[:x1_dim] = 0.0 + mask_output = mask_output.view(-1, hidden_dims[-1]) + self._bias = nn.Parameter( + torch.zeros(self.output_multiplier, x1_dim, requires_grad=True) + ) + + layers = [ + MaskedLinear( + input_dims, # + context_dims, + hidden_dims[0], + mask_input, + ), + self.nonlinearity(), + ] + for i in range(1, len(hidden_dims)): + layers.extend( + [ + nn.Linear(hidden_dims[i - 1], hidden_dims[i]), + self.nonlinearity(), + ] + ) + layers.append( + MaskedLinear( + hidden_dims[-1], + out_dims, + mask_output, + bias=False, + ) + ) + + for l in layers[::2]: + l.weight.data.normal_(0.0, 1e-3) # type: ignore + if l.bias is not None: + l.bias.data.fill_(0.0) # type: ignore + + if self.skip_connections: + mask_skip = torch.ones(out_dims, input_dims) + mask_skip[:, input_dims // 2 :] = 0.0 + mask_skip[: mask_output // 2] = 0.0 + self.skip_layer = MaskedLinear( + input_dims, # + context_dims, + out_dims, + mask_skip, + bias=False, + ) + + self.layers = nn.Sequential(*layers) + + @property + def bias(self) -> torch.Tensor: + z = torch.zeros( + self.output_multiplier, + self.input_dims - self.x1_dim, + device=self._bias.device, + dtype=self._bias.dtype, + ) + return torch.cat([z, self._bias], -1).view(-1) + + def _forward( + self, + x: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + ) -> Optional[Sequence[torch.Tensor]]: + if (x is None) and (y is None): + raise RuntimeError("Either x or y must be provided.") + elif x is None: + x = y[..., self.inv_permutation] # type: ignore + inverse = True + else: + x = x[..., self.permutation] # type: ignore + inverse = False + + if context is not None: + x_aug = torch.cat([context.expand((*x.shape[:-1], -1)), x], dim=-1) + else: + x_aug = x + + h = self.layers(x_aug) + self.bias + + # TODO: Get skip_layers working again! + if self.skip_connections: + h = h + self.skip_layer(x_aug) + + # Shape the output + h = h.view(*x.shape[:-1], self.output_multiplier, -1) + + result = h.unbind(-2) + perm = self.inv_permutation if inverse else self.permutation + result = tuple(r[..., perm] for r in result) + return result diff --git a/flowtorch/parameters/dense_autoregressive.py b/flowtorch/parameters/dense_autoregressive.py index 8110e5a6..e9043f23 100644 --- a/flowtorch/parameters/dense_autoregressive.py +++ b/flowtorch/parameters/dense_autoregressive.py @@ -45,6 +45,7 @@ def _build( ) -> None: # Work out flattened input and output shapes param_shapes_ = list(param_shapes) + # Why not just (sum(input_shape))? input_dims = int(torch.sum(torch.tensor(input_shape)).int().item()) if input_dims == 0: input_dims = 1 # scalars represented by torch.Size([]) @@ -60,6 +61,7 @@ def _build( # The permutation is chosen by the user permutation = torch.LongTensor(permutation) + # why not math.pod(s[len(input_shape):]), where math.prod([])=1? 
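
The masks built above implement the usual affine coupling scheme: the first half of the (permuted) features passes through unchanged, and the shift and scale applied to the second half are predicted from that first half alone, so the inverse is a single pass rather than the per-dimension loop an autoregressive flow needs. A toy version with an ordinary linear conditioner (hypothetical code, not the flowtorch API):

import torch
import torch.nn as nn

torch.manual_seed(0)
d, d1 = 6, 3                                    # feature count and size of the pass-through half
conditioner = nn.Linear(d1, 2 * (d - d1))       # predicts (shift, log-scale) for the other half

def coupling_forward(x):
    x1, x2 = x[..., :d1], x[..., d1:]
    shift, log_scale = conditioner(x1).chunk(2, dim=-1)
    return torch.cat([x1, x2 * log_scale.exp() + shift], dim=-1)

def coupling_inverse(y):
    y1, y2 = y[..., :d1], y[..., d1:]
    shift, log_scale = conditioner(y1).chunk(2, dim=-1)    # y1 equals x1, so same parameters
    return torch.cat([y1, (y2 - shift) * torch.exp(-log_scale)], dim=-1)

x = torch.randn(4, d)
assert torch.allclose(coupling_inverse(coupling_forward(x)), x, atol=1e-5)
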
self.param_dims = [ int(max(torch.prod(torch.tensor(s[len(input_shape) :])).item(), 1)) for s in param_shapes_ @@ -141,11 +143,13 @@ def _build( ) ) + # Why not using regular sequential? self.layers = nn.ModuleList(layers) def _forward( self, x: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None, ) -> Optional[Sequence[torch.Tensor]]: assert x is not None @@ -161,6 +165,7 @@ def _forward( else: h = x + # Why not using regular sequential? for idx in range(len(self.layers) // 2): h = self.layers[2 * idx + 1](self.layers[2 * idx](h)) h = self.layers[-1](h) diff --git a/flowtorch/parameters/tensor.py b/flowtorch/parameters/tensor.py index 3de8680a..4eba7de0 100644 --- a/flowtorch/parameters/tensor.py +++ b/flowtorch/parameters/tensor.py @@ -22,6 +22,9 @@ def __init__( ) def _forward( - self, x: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None + self, + x: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, ) -> Optional[Sequence[torch.Tensor]]: return list(self.params) From 96c039916e68a0bad6a0915c8f9477d33a5823d5 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 9 Feb 2022 16:41:22 +0000 Subject: [PATCH 02/14] minor --- flowtorch/bijectors/ops/affine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flowtorch/bijectors/ops/affine.py b/flowtorch/bijectors/ops/affine.py index 4cee536c..06e008fe 100644 --- a/flowtorch/bijectors/ops/affine.py +++ b/flowtorch/bijectors/ops/affine.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc -from typing import Optional, Sequence, Tuple, Dict, Callable +from typing import Callable, Dict, Optional, Sequence, Tuple import flowtorch From 5d35d7e65c315a740abe3c7d5a65c5120cabcaf2 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 16 Feb 2022 16:16:14 +0000 Subject: [PATCH 03/14] Parameters read 'input' and not x and/or y anymore --- flowtorch/bijectors/autoregressive.py | 2 +- flowtorch/bijectors/base.py | 10 +++- flowtorch/parameters/base.py | 10 ++-- flowtorch/parameters/coupling.py | 57 ++++++-------------- flowtorch/parameters/dense_autoregressive.py | 17 +++--- flowtorch/parameters/tensor.py | 4 +- 6 files changed, 40 insertions(+), 60 deletions(-) diff --git a/flowtorch/bijectors/autoregressive.py b/flowtorch/bijectors/autoregressive.py index 4d8c2d7a..8a8371f9 100644 --- a/flowtorch/bijectors/autoregressive.py +++ b/flowtorch/bijectors/autoregressive.py @@ -60,7 +60,7 @@ def inverse( # TODO: Make permutation, inverse work for other event shapes log_detJ: Optional[torch.Tensor] = None for idx in cast(torch.LongTensor, permutation): - _params = self._params_fn(x_new.clone(), None, context=context) + _params = self._params_fn(x_new.clone(), inverse=False, context=context) x_temp, log_detJ = self._inverse(y, params=_params) x_new[..., idx] = x_temp[..., idx] # _log_detJ = out[1] diff --git a/flowtorch/bijectors/base.py b/flowtorch/bijectors/base.py index 55867640..322e4a3d 100644 --- a/flowtorch/bijectors/base.py +++ b/flowtorch/bijectors/base.py @@ -76,7 +76,9 @@ def forward( return x.get_parent_from_bijector(self) params = ( - self._params_fn(x, None, context) if self._params_fn is not None else None + self._params_fn(x, inverse=False, context=context) + if self._params_fn is not None + else None ) y, log_detJ = self._forward(x, params) if ( @@ -121,7 +123,11 @@ def inverse( return y.get_parent_from_bijector(self) # TODO: What to do in this line? 
- params = self._params_fn(x, y, context) if self._params_fn is not None else None + params = ( + self._params_fn(y, inverse=True, context=context) + if self._params_fn is not None + else None + ) x, log_detJ = self._inverse(y, params) if ( diff --git a/flowtorch/parameters/base.py b/flowtorch/parameters/base.py index 62391825..f0526847 100644 --- a/flowtorch/parameters/base.py +++ b/flowtorch/parameters/base.py @@ -24,17 +24,17 @@ def __init__( def forward( self, - x: Optional[torch.Tensor] = None, - y: Optional[torch.Tensor] = None, + input: torch.Tensor, + inverse: bool, context: Optional[torch.Tensor] = None, ) -> Optional[Sequence[torch.Tensor]]: # TODO: Caching etc. - return self._forward(x, y, context) + return self._forward(input, inverse, context) def _forward( self, - x: Optional[torch.Tensor] = None, - y: Optional[torch.Tensor] = None, + input: torch.Tensor, + inverse: bool, context: Optional[torch.Tensor] = None, ) -> Optional[Sequence[torch.Tensor]]: # I raise an exception rather than using @abstractmethod and diff --git a/flowtorch/parameters/coupling.py b/flowtorch/parameters/coupling.py index a4e3413c..92c7beac 100644 --- a/flowtorch/parameters/coupling.py +++ b/flowtorch/parameters/coupling.py @@ -1,11 +1,11 @@ # Copyright (c) Meta Platforms, Inc -import warnings from typing import Callable, Optional, Sequence import torch import torch.nn as nn -from flowtorch.nn.made import create_mask, MaskedLinear + +from flowtorch.nn.made import MaskedLinear from flowtorch.parameters.base import Parameters @@ -69,33 +69,10 @@ def _build( self.output_multiplier = sum(self.param_dims) if input_dims == 1: - warnings.warn( - "DenseAutoregressive input_dim = 1. " - "Consider using an affine transformation instead." + raise ValueError( + "Coupling input_dim = 1. Coupling transforms require at least two features." ) - # Calculate the indices on the output corresponding to each parameter - # TODO: Is this logic correct??? - # ends = torch.cumsum( - # torch.tensor( - # [max(torch.prod(torch.tensor(s)).item(), 1) for s in param_shapes_] - # ), - # dim=0, - # ) - # starts = torch.cat((torch.zeros(1).type_as(ends), ends[:-1])) - # self.param_slices = [slice(s.item(), e.item()) for s, e in zip(starts, ends)] - - # Hidden dimension must be not less than the input otherwise it isn't - # possible to connect to the outputs correctly - for h in self.hidden_dims: - if h < input_dims: - raise ValueError( - "Hidden dimension must not be less than input dimension." - ) - - # TODO: Check that the permutation is valid for the input dimension! 
- # Implement ispermutation() that sorts permutation and checks whether it - # has all integers from 0, 1, ..., self.input_dims - 1 self.register_buffer("permutation", permutation) self.register_buffer("inv_permutation", permutation.argsort()) @@ -171,32 +148,30 @@ def bias(self) -> torch.Tensor: def _forward( self, - x: Optional[torch.Tensor] = None, - y: Optional[torch.Tensor] = None, + input: torch.Tensor, + inverse: bool, context: Optional[torch.Tensor] = None, ) -> Optional[Sequence[torch.Tensor]]: - if (x is None) and (y is None): - raise RuntimeError("Either x or y must be provided.") - elif x is None: - x = y[..., self.inv_permutation] # type: ignore - inverse = True + if inverse: + input = input[..., self.inv_permutation] # type: ignore else: - x = x[..., self.permutation] # type: ignore - inverse = False + input = input[..., self.permutation] # type: ignore if context is not None: - x_aug = torch.cat([context.expand((*x.shape[:-1], -1)), x], dim=-1) + input_aug = torch.cat( + [context.expand((*input.shape[:-1], -1)), input], dim=-1 + ) else: - x_aug = x + input_aug = input - h = self.layers(x_aug) + self.bias + h = self.layers(input_aug) + self.bias # TODO: Get skip_layers working again! if self.skip_connections: - h = h + self.skip_layer(x_aug) + h = h + self.skip_layer(input_aug) # Shape the output - h = h.view(*x.shape[:-1], self.output_multiplier, -1) + h = h.view(*input.shape[:-1], self.output_multiplier, -1) result = h.unbind(-2) perm = self.inv_permutation if inverse else self.permutation diff --git a/flowtorch/parameters/dense_autoregressive.py b/flowtorch/parameters/dense_autoregressive.py index e9043f23..70a432b9 100644 --- a/flowtorch/parameters/dense_autoregressive.py +++ b/flowtorch/parameters/dense_autoregressive.py @@ -148,22 +148,21 @@ def _build( def _forward( self, - x: Optional[torch.Tensor] = None, - y: Optional[torch.Tensor] = None, + input: torch.Tensor, + inverse: bool, context: Optional[torch.Tensor] = None, ) -> Optional[Sequence[torch.Tensor]]: - assert x is not None - # Flatten x - batch_shape = x.shape[: len(x.shape) - len(self.input_shape)] + # Flatten input + batch_shape = input.shape[: len(input.shape) - len(self.input_shape)] if len(batch_shape) > 0: - x = x.reshape(batch_shape + (-1,)) + input = input.reshape(batch_shape + (-1,)) if context is not None: # TODO: Fix the following! - h = torch.cat([context.expand((x.shape[0], -1)), x], dim=-1) + h = torch.cat([context.expand((input.shape[0], -1)), input], dim=-1) else: - h = x + h = input # Why not using regular sequential? for idx in range(len(self.layers) // 2): @@ -172,7 +171,7 @@ def _forward( # TODO: Get skip_layers working again! 
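
The input masking that the coupling conditioner relies on can be reproduced with a plain linear layer whose weight columns for the second half are zeroed, which makes its output, and hence every predicted shift and scale, a function of the first half only (illustration only; flowtorch uses MaskedLinear for this):

import torch
import torch.nn as nn

d, d1, hidden = 6, 3, 16
layer = nn.Linear(d, hidden)
mask = torch.ones(hidden, d)
mask[:, d1:] = 0.0                              # cut every connection from the second half
with torch.no_grad():
    layer.weight.mul_(mask)

x = torch.randn(2, d)
x_alt = x.clone()
x_alt[:, d1:] = torch.randn(2, d - d1)          # perturb only the second half
assert torch.allclose(layer(x), layer(x_alt))   # the output never sees it
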
# if self.skip_layer is not None: - # h = h + self.skip_layer(x) + # h = h + self.skip_layer(input) # Shape the output # h ~ (batch_dims * input_dims, total_params_per_dim) diff --git a/flowtorch/parameters/tensor.py b/flowtorch/parameters/tensor.py index 4eba7de0..213188cc 100644 --- a/flowtorch/parameters/tensor.py +++ b/flowtorch/parameters/tensor.py @@ -23,8 +23,8 @@ def __init__( def _forward( self, - x: Optional[torch.Tensor] = None, - y: Optional[torch.Tensor] = None, + input: torch.Tensor, + inverse: bool, context: Optional[torch.Tensor] = None, ) -> Optional[Sequence[torch.Tensor]]: return list(self.params) From a6645710770a95160db330a6ed17de94f1753956 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 23 Feb 2022 21:19:58 +0000 Subject: [PATCH 04/14] bugfix --- flowtorch/bijectors/coupling.py | 4 ---- flowtorch/bijectors/ops/affine.py | 2 +- flowtorch/parameters/coupling.py | 30 +++++++++++++++++++++--------- 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/flowtorch/bijectors/coupling.py b/flowtorch/bijectors/coupling.py index c9a0e3d2..97f486e3 100644 --- a/flowtorch/bijectors/coupling.py +++ b/flowtorch/bijectors/coupling.py @@ -45,9 +45,7 @@ def _forward( ) -> Tuple[torch.Tensor, torch.Tensor]: assert self._params_fn is not None - x = x[..., self._params_fn.permutation] y, ldj = super()._forward(x, params) - y = y[..., self._params_fn.inv_permutation] return y, ldj def _inverse( @@ -55,7 +53,5 @@ def _inverse( ) -> Tuple[torch.Tensor, torch.Tensor]: assert self._params_fn is not None - y = y[..., self._params_fn.inv_permutation] x, ldj = super()._inverse(y, params) - x = x[..., self._params_fn.permutation] return x, ldj diff --git a/flowtorch/bijectors/ops/affine.py b/flowtorch/bijectors/ops/affine.py index 06e008fe..642e89e2 100644 --- a/flowtorch/bijectors/ops/affine.py +++ b/flowtorch/bijectors/ops/affine.py @@ -13,6 +13,7 @@ "softplus": torch.expm1(torch.ones(1)).log().item(), "exp": 0.0, } + _POSITIVE_MAPS: Dict[str, Callable[[torch.Tensor], torch.Tensor]] = { "softplus": torch.nn.functional.softplus, "sigmoid": torch.sigmoid, @@ -90,7 +91,6 @@ def _inverse( else: inverse_scale = torch.exp(-unbounded_scale) log_scale = unbounded_scale - x_new = (y - mean) * inverse_scale return x_new, _sum_rightmost(log_scale, self.domain.event_dim) diff --git a/flowtorch/parameters/coupling.py b/flowtorch/parameters/coupling.py index 92c7beac..1c638091 100644 --- a/flowtorch/parameters/coupling.py +++ b/flowtorch/parameters/coupling.py @@ -87,8 +87,9 @@ def _build( mask_input[:, x1_dim:] = 0.0 out_dims = input_dims * self.output_multiplier - mask_output = torch.ones(self.output_multiplier, input_dims, hidden_dims[-1]) - mask_output[:x1_dim] = 0.0 + mask_output = torch.ones(self.output_multiplier, input_dims, hidden_dims[-1], dtype=torch.bool) + mask_output[:, :x1_dim] = 0.0 + mask_output_buffer = mask_output[0, :, 0] mask_output = mask_output.view(-1, hidden_dims[-1]) self._bias = nn.Parameter( torch.zeros(self.output_multiplier, x1_dim, requires_grad=True) @@ -135,6 +136,15 @@ def _build( ) self.layers = nn.Sequential(*layers) + self.register_buffer('mask_output', mask_output_buffer) + self._init_weights() + + def _init_weights(self) -> None: + for layer in self.modules(): + if hasattr(layer, 'weight'): + layer.weight.data.normal_(0.0, 1e-3) + if hasattr(layer, 'bias') and layer.bias is not None: + layer.bias.data.fill_(0.0) @property def bias(self) -> torch.Tensor: @@ -152,17 +162,18 @@ def _forward( inverse: bool, context: Optional[torch.Tensor] = None, ) -> 
Optional[Sequence[torch.Tensor]]: - if inverse: - input = input[..., self.inv_permutation] # type: ignore - else: - input = input[..., self.permutation] # type: ignore + # if inverse: + # input = input[..., self.inv_permutation] # type: ignore + # else: + input = input[..., self.permutation] # type: ignore + input_masked = input.masked_fill(self.mask_output, 0.0) if context is not None: input_aug = torch.cat( - [context.expand((*input.shape[:-1], -1)), input], dim=-1 + [context.expand((*input.shape[:-1], -1)), input_masked], dim=-1 ) else: - input_aug = input + input_aug = input_masked h = self.layers(input_aug) + self.bias @@ -174,6 +185,7 @@ def _forward( h = h.view(*input.shape[:-1], self.output_multiplier, -1) result = h.unbind(-2) - perm = self.inv_permutation if inverse else self.permutation + perm = self.inv_permutation + result = tuple(r.masked_fill(~self.mask_output.expand_as(r), 0.0) for r in result) result = tuple(r[..., perm] for r in result) return result From a85a0d205ace63ced6dfe341e59b0479508eb903 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 24 Feb 2022 16:23:53 +0000 Subject: [PATCH 05/14] simplify coupling params --- flowtorch/parameters/coupling.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/flowtorch/parameters/coupling.py b/flowtorch/parameters/coupling.py index 1c638091..067b7e1a 100644 --- a/flowtorch/parameters/coupling.py +++ b/flowtorch/parameters/coupling.py @@ -85,12 +85,15 @@ def _build( mask_input = torch.ones(hidden_dims[0], input_dims) self.x1_dim = x1_dim = input_dims // 2 mask_input[:, x1_dim:] = 0.0 + mask_input = mask_input[:, self.permutation] out_dims = input_dims * self.output_multiplier mask_output = torch.ones(self.output_multiplier, input_dims, hidden_dims[-1], dtype=torch.bool) mask_output[:, :x1_dim] = 0.0 - mask_output_buffer = mask_output[0, :, 0] + mask_output = mask_output[:, self.permutation] + mask_output_reg = mask_output[0, :, 0] mask_output = mask_output.view(-1, hidden_dims[-1]) + self._bias = nn.Parameter( torch.zeros(self.output_multiplier, x1_dim, requires_grad=True) ) @@ -125,18 +128,15 @@ def _build( l.bias.data.fill_(0.0) # type: ignore if self.skip_connections: - mask_skip = torch.ones(out_dims, input_dims) - mask_skip[:, input_dims // 2 :] = 0.0 - mask_skip[: mask_output // 2] = 0.0 self.skip_layer = MaskedLinear( input_dims, # + context_dims, out_dims, - mask_skip, + mask_output, bias=False, ) self.layers = nn.Sequential(*layers) - self.register_buffer('mask_output', mask_output_buffer) + self.register_buffer('mask_output', mask_output_reg.to(torch.bool)) self._init_weights() def _init_weights(self) -> None: @@ -162,10 +162,6 @@ def _forward( inverse: bool, context: Optional[torch.Tensor] = None, ) -> Optional[Sequence[torch.Tensor]]: - # if inverse: - # input = input[..., self.inv_permutation] # type: ignore - # else: - input = input[..., self.permutation] # type: ignore input_masked = input.masked_fill(self.mask_output, 0.0) if context is not None: @@ -185,7 +181,5 @@ def _forward( h = h.view(*input.shape[:-1], self.output_multiplier, -1) result = h.unbind(-2) - perm = self.inv_permutation result = tuple(r.masked_fill(~self.mask_output.expand_as(r), 0.0) for r in result) - result = tuple(r[..., perm] for r in result) return result From 7d4d9dc6836dbee30561effd7277a3fd185bdfbf Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 24 Feb 2022 16:25:43 +0000 Subject: [PATCH 06/14] make mypy happy --- flowtorch/parameters/coupling.py | 20 ++++++++++++-------- 1 file changed, 
12 insertions(+), 8 deletions(-) diff --git a/flowtorch/parameters/coupling.py b/flowtorch/parameters/coupling.py index 067b7e1a..694f5971 100644 --- a/flowtorch/parameters/coupling.py +++ b/flowtorch/parameters/coupling.py @@ -88,7 +88,9 @@ def _build( mask_input = mask_input[:, self.permutation] out_dims = input_dims * self.output_multiplier - mask_output = torch.ones(self.output_multiplier, input_dims, hidden_dims[-1], dtype=torch.bool) + mask_output = torch.ones( + self.output_multiplier, input_dims, hidden_dims[-1], dtype=torch.bool + ) mask_output[:, :x1_dim] = 0.0 mask_output = mask_output[:, self.permutation] mask_output_reg = mask_output[0, :, 0] @@ -136,15 +138,15 @@ def _build( ) self.layers = nn.Sequential(*layers) - self.register_buffer('mask_output', mask_output_reg.to(torch.bool)) + self.register_buffer("mask_output", mask_output_reg.to(torch.bool)) self._init_weights() def _init_weights(self) -> None: for layer in self.modules(): - if hasattr(layer, 'weight'): - layer.weight.data.normal_(0.0, 1e-3) - if hasattr(layer, 'bias') and layer.bias is not None: - layer.bias.data.fill_(0.0) + if hasattr(layer, "weight"): + layer.weight.data.normal_(0.0, 1e-3) # type: ignore + if hasattr(layer, "bias") and layer.bias is not None: + layer.bias.data.fill_(0.0) # type: ignore @property def bias(self) -> torch.Tensor: @@ -163,7 +165,7 @@ def _forward( context: Optional[torch.Tensor] = None, ) -> Optional[Sequence[torch.Tensor]]: - input_masked = input.masked_fill(self.mask_output, 0.0) + input_masked = input.masked_fill(self.mask_output, 0.0) # type: ignore if context is not None: input_aug = torch.cat( [context.expand((*input.shape[:-1], -1)), input_masked], dim=-1 @@ -181,5 +183,7 @@ def _forward( h = h.view(*input.shape[:-1], self.output_multiplier, -1) result = h.unbind(-2) - result = tuple(r.masked_fill(~self.mask_output.expand_as(r), 0.0) for r in result) + result = tuple( + r.masked_fill(~self.mask_output.expand_as(r), 0.0) for r in result # type: ignore + ) return result From 8855f964faba6d50da08cfeec99f12a564ce7f8a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 24 Feb 2022 18:12:44 +0000 Subject: [PATCH 07/14] some exampled --- flowtorch/bijectors/__init__.py | 5 +- flowtorch/bijectors/coupling.py | 76 +++++++++++++- flowtorch/parameters/__init__.py | 10 +- flowtorch/parameters/coupling.py | 174 +++++++++++++++++++++++++++++-- 4 files changed, 248 insertions(+), 17 deletions(-) diff --git a/flowtorch/bijectors/__init__.py b/flowtorch/bijectors/__init__.py index c4c4d6c5..1ffd19b3 100644 --- a/flowtorch/bijectors/__init__.py +++ b/flowtorch/bijectors/__init__.py @@ -16,7 +16,7 @@ from flowtorch.bijectors.autoregressive import Autoregressive from flowtorch.bijectors.base import Bijector from flowtorch.bijectors.compose import Compose -from flowtorch.bijectors.coupling import Coupling +from flowtorch.bijectors.coupling import ConvCouplingBijector, CouplingBijector from flowtorch.bijectors.elementwise import Elementwise from flowtorch.bijectors.elu import ELU from flowtorch.bijectors.exp import Exp @@ -34,7 +34,8 @@ standard_bijectors = [ ("Affine", Affine), ("AffineAutoregressive", AffineAutoregressive), - ("Coupling", Coupling), + ("CouplingBijector", CouplingBijector), + ("ConvCouplingBijector", ConvCouplingBijector), ("AffineFixed", AffineFixed), ("ELU", ELU), ("Exp", Exp), diff --git a/flowtorch/bijectors/coupling.py b/flowtorch/bijectors/coupling.py index 97f486e3..dcac0861 100644 --- a/flowtorch/bijectors/coupling.py +++ b/flowtorch/bijectors/coupling.py @@ -1,15 
+1,34 @@ # Copyright (c) Meta Platforms, Inc - +from copy import deepcopy from typing import Optional, Sequence, Tuple import flowtorch.parameters import torch from flowtorch.bijectors.ops.affine import Affine as AffineOp -from flowtorch.parameters import DenseCoupling +from flowtorch.parameters import ConvCoupling, DenseCoupling + +from torch.distributions import constraints + + +_REAL3d = deepcopy(constraints.real) +_REAL3d.event_dim = 3 + +class CouplingBijector(AffineOp): + """ + Examples: + >>> params = DenseCoupling() + >>> bij = CouplingBijector(params) + >>> bij = bij(shape=torch.Size([32,])) + >>> for p in bij.parameters(): + ... p.data += torch.randn_like(p)/10 + >>> x = torch.randn(1, 32,requires_grad=True) + >>> y = bij.forward(x).detach_from_flow() + >>> x_bis = bij.inverse(y) + >>> torch.testing.assert_allclose(x, x_bis) + """ -class Coupling(AffineOp): def __init__( self, params_fn: Optional[flowtorch.Lazy] = None, @@ -55,3 +74,54 @@ def _inverse( x, ldj = super()._inverse(y, params) return x, ldj + + +class ConvCouplingBijector(CouplingBijector): + """ + Examples: + >>> params = ConvCoupling() + >>> bij = ConvCouplingBijector(params) + >>> bij = bij(shape=torch.Size([3,16,16])) + >>> for p in bij.parameters(): + ... p.data += torch.randn_like(p)/10 + >>> x = torch.randn(4, 3, 16, 16) + >>> y = bij.forward(x) + >>> x_bis = bij.inverse(y.detach_from_flow()) + >>> torch.testing.assert_allclose(x, x_bis) + """ + + domain: constraints.Constraint = _REAL3d + codomain: constraints.Constraint = _REAL3d + + def __init__( + self, + params_fn: Optional[flowtorch.Lazy] = None, + *, + shape: torch.Size, + context_shape: Optional[torch.Size] = None, + clamp_values: bool = False, + log_scale_min_clip: float = -5.0, + log_scale_max_clip: float = 3.0, + sigmoid_bias: float = 2.0, + positive_map: str = "softplus", + positive_bias: Optional[float] = None, + ) -> None: + + if not len(shape) == 3: + raise ValueError(f"Expected a 3d-tensor shape, got {shape}") + + if params_fn is None: + params_fn = ConvCoupling() # type: ignore + + AffineOp.__init__( + self, + params_fn, + shape=shape, + context_shape=context_shape, + clamp_values=clamp_values, + log_scale_min_clip=log_scale_min_clip, + log_scale_max_clip=log_scale_max_clip, + sigmoid_bias=sigmoid_bias, + positive_map=positive_map, + positive_bias=positive_bias, + ) diff --git a/flowtorch/parameters/__init__.py b/flowtorch/parameters/__init__.py index 6ac0d040..09c81fd3 100644 --- a/flowtorch/parameters/__init__.py +++ b/flowtorch/parameters/__init__.py @@ -7,8 +7,14 @@ """ from flowtorch.parameters.base import Parameters -from flowtorch.parameters.coupling import DenseCoupling +from flowtorch.parameters.coupling import ConvCoupling, DenseCoupling from flowtorch.parameters.dense_autoregressive import DenseAutoregressive from flowtorch.parameters.tensor import Tensor -__all__ = ["Parameters", "DenseAutoregressive", "Tensor", "DenseCoupling"] +__all__ = [ + "Parameters", + "DenseAutoregressive", + "Tensor", + "DenseCoupling", + "ConvCoupling", +] diff --git a/flowtorch/parameters/coupling.py b/flowtorch/parameters/coupling.py index 694f5971..6e27d83a 100644 --- a/flowtorch/parameters/coupling.py +++ b/flowtorch/parameters/coupling.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc -from typing import Callable, Optional, Sequence +from typing import Callable, Iterable, Optional, Sequence import torch import torch.nn as nn @@ -9,8 +9,25 @@ from flowtorch.parameters.base import Parameters +def _make_mask(shape: torch.Size, mask_type: str) -> 
torch.Tensor: + if mask_type.startswith("neg_"): + return _make_mask(shape, mask_type[4:]) + elif mask_type == "chessboard": + z = torch.zeros(shape, dtype=torch.bool) + z[:, ::2, ::2] = 1 + z[:, 1::2, 1::2] = 1 + return z + elif mask_type == "quadrant": + z = torch.zeros(shape, dtype=torch.bool) + z[:, shape[1] // 2 :, : shape[2] // 2] = 1 + z[:, : shape[1] // 2, shape[2] // 2 :] = 1 + return z + else: + raise NotImplementedError(shape) + + class DenseCoupling(Parameters): - autoregressive = True + autoregressive = False def __init__( self, @@ -43,15 +60,15 @@ def _build( context_shape: Optional[torch.Size], permutation: Optional[torch.LongTensor], ) -> None: + # Work out flattened input and output shapes param_shapes_ = list(param_shapes) - input_dims = int(torch.sum(torch.tensor(input_shape)).int().item()) + input_dims = sum(input_shape) self.input_dims = input_dims if input_dims == 0: input_dims = 1 # scalars represented by torch.Size([]) if permutation is None: - # By default set a random permutation of variables, which is - # important for performance with multiple steps + # permutation will define the split of the input permutation = torch.LongTensor( torch.randperm(input_dims, device="cpu").to( torch.LongTensor((1,)).device @@ -124,11 +141,6 @@ def _build( ) ) - for l in layers[::2]: - l.weight.data.normal_(0.0, 1e-3) # type: ignore - if l.bias is not None: - l.bias.data.fill_(0.0) # type: ignore - if self.skip_connections: self.skip_layer = MaskedLinear( input_dims, # + context_dims, @@ -187,3 +199,145 @@ def _forward( r.masked_fill(~self.mask_output.expand_as(r), 0.0) for r in result # type: ignore ) return result + + +class ConvCoupling(Parameters): + autoregressive = False + _mask_types = ["chessboard", "quadrants"] + + def __init__( + self, + param_shapes: Sequence[torch.Size], + input_shape: torch.Size, + context_shape: Optional[torch.Size], + *, + cnn_activate_input: bool = True, + cnn_channels: int = 256, + cnn_kernel: Sequence[int] = None, + cnn_padding: Sequence[int] = None, + cnn_stride: Sequence[int] = None, + nonlinearity: Callable[[], nn.Module] = nn.ReLU, + skip_connections: bool = False, + mask_type: str = "chessboard", + ) -> None: + super().__init__(param_shapes, input_shape, context_shape) + + # Check consistency of input_shape with param_shapes + # We need each param_shapes to match input_shape in + # its leftmost dimensions + for s in param_shapes: + assert len(s) >= len(input_shape) and s[: len(input_shape)] == input_shape + + if cnn_kernel is None: + cnn_kernel = [3, 1, 3] + if cnn_padding is None: + cnn_padding = [1, 0, 1] + if cnn_stride is None: + cnn_stride = [1, 1, 1] + + self.cnn_channels = cnn_channels + self.cnn_activate_input = cnn_activate_input + self.cnn_kernel = cnn_kernel + self.cnn_padding = cnn_padding + self.cnn_stride = cnn_stride + + self.nonlinearity = nonlinearity + self.skip_connections = skip_connections + self._build(input_shape, param_shapes, context_shape, mask_type) + + def _build( + self, + input_shape: torch.Size, # something like [C, W, H] + param_shapes: Sequence[torch.Size], # something like [[C, W, H], [C, W, H]] + context_shape: Optional[torch.Size], + mask_type: str, + ) -> None: + + mask = _make_mask(input_shape, mask_type) + self.register_buffer("mask", mask) + self.output_multiplier = len(param_shapes) + + out_channels, width, height = input_shape + + layers = [] + if self.cnn_activate_input: + layers.append(self.nonlinearity()) + layers.append( + nn.LazyConv2d( + out_channels=self.cnn_channels, + 
kernel_size=self.cnn_kernel[0], + padding=self.cnn_padding[0], + stride=self.cnn_stride[0], + ) + ) + layers.append(self.nonlinearity()) + layers.append( + nn.Conv2d( + in_channels=self.cnn_channels, + out_channels=self.cnn_channels, + kernel_size=self.cnn_kernel[1], + padding=self.cnn_padding[1], + stride=self.cnn_stride[1], + ) + ) + layers.append(self.nonlinearity()) + layers.append( + nn.Conv2d( + in_channels=self.cnn_channels, + out_channels=out_channels * self.output_multiplier, + kernel_size=self.cnn_kernel[2], + padding=self.cnn_padding[2], + stride=self.cnn_stride[2], + ) + ) + + self.layers = nn.Sequential(*layers) + self._init_weights() + + def _init_weights(self) -> None: + for layer in self.modules(): + if hasattr(layer, "weight"): + layer.weight.data.normal_(0.0, 1e-3) # type: ignore + if hasattr(layer, "bias") and layer.bias is not None: + layer.bias.data.fill_(0.0) # type: ignore + + def _forward( + self, + input: torch.Tensor, + inverse: bool, + context: Optional[torch.Tensor] = None, + ) -> Optional[Sequence[torch.Tensor]]: + + unsqueeze = False + if input.ndimension() == 3: + # mostly for initialization + unsqueeze = True + input = input.unsqueeze(0) + + input_masked = input.masked_fill(self.mask, 0.0) # type: ignore + if context is not None: + context_shape = [shape for shape in input_masked.shape] + context_shape[-3] = context.shape[-3] + input_aug = torch.cat( + [context.expand(*context_shape), input_masked], dim=-1 + ) + else: + input_aug = input_masked + + print(self.layers) + h = self.layers(input_aug) + + if self.skip_connections: + h = h + input_masked + + # Shape the output + + if unsqueeze: + h = h.squeeze(0) + result = h.chunk(2, -3) + + result = tuple( + r.masked_fill(~self.mask.expand_as(r), 0.0) for r in result # type: ignore + ) + + return result From ab0cc9e8cdea11ccb7013403d6d513ec1a0a33dc Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 2 Mar 2022 14:18:16 +0000 Subject: [PATCH 08/14] usort --- flowtorch/parameters/coupling.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/flowtorch/parameters/coupling.py b/flowtorch/parameters/coupling.py index 6e27d83a..5dc7d70e 100644 --- a/flowtorch/parameters/coupling.py +++ b/flowtorch/parameters/coupling.py @@ -4,7 +4,6 @@ import torch import torch.nn as nn - from flowtorch.nn.made import MaskedLinear from flowtorch.parameters.base import Parameters @@ -203,7 +202,7 @@ def _forward( class ConvCoupling(Parameters): autoregressive = False - _mask_types = ["chessboard", "quadrants"] + _mask_types = ["chessboard", "quadrants", "inv_chessboard", "inv_quadrants"] def __init__( self, From 9eec8550c41b9bd54b34de587080976feee4a58a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 21 Mar 2022 11:41:28 +0000 Subject: [PATCH 09/14] minor --- flowtorch/bijectors/coupling.py | 2 -- flowtorch/bijectors/ops/affine.py | 1 - 2 files changed, 3 deletions(-) diff --git a/flowtorch/bijectors/coupling.py b/flowtorch/bijectors/coupling.py index dcac0861..43c4128c 100644 --- a/flowtorch/bijectors/coupling.py +++ b/flowtorch/bijectors/coupling.py @@ -3,11 +3,9 @@ from typing import Optional, Sequence, Tuple import flowtorch.parameters - import torch from flowtorch.bijectors.ops.affine import Affine as AffineOp from flowtorch.parameters import ConvCoupling, DenseCoupling - from torch.distributions import constraints diff --git a/flowtorch/bijectors/ops/affine.py b/flowtorch/bijectors/ops/affine.py index 642e89e2..5684fd18 100644 --- a/flowtorch/bijectors/ops/affine.py +++ 
b/flowtorch/bijectors/ops/affine.py @@ -3,7 +3,6 @@ from typing import Callable, Dict, Optional, Sequence, Tuple import flowtorch - import torch from flowtorch.bijectors.base import Bijector from flowtorch.ops import clamp_preserve_gradients From b7a2722c92b727707a57e17af63ed2b1d139ae8a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 21 Apr 2022 15:24:57 +0100 Subject: [PATCH 10/14] amend --- flowtorch/bijectors/__init__.py | 7 +++--- flowtorch/bijectors/coupling.py | 6 +++++ flowtorch/parameters/__init__.py | 7 +++--- flowtorch/parameters/base.py | 2 +- flowtorch/parameters/coupling.py | 41 ++++++++++++++++++++++---------- 5 files changed, 44 insertions(+), 19 deletions(-) diff --git a/flowtorch/bijectors/__init__.py b/flowtorch/bijectors/__init__.py index 1ffd19b3..b68d9992 100644 --- a/flowtorch/bijectors/__init__.py +++ b/flowtorch/bijectors/__init__.py @@ -16,7 +16,8 @@ from flowtorch.bijectors.autoregressive import Autoregressive from flowtorch.bijectors.base import Bijector from flowtorch.bijectors.compose import Compose -from flowtorch.bijectors.coupling import ConvCouplingBijector, CouplingBijector +from flowtorch.bijectors.coupling import ConvCouplingBijector +from flowtorch.bijectors.coupling import CouplingBijector from flowtorch.bijectors.elementwise import Elementwise from flowtorch.bijectors.elu import ELU from flowtorch.bijectors.exp import Exp @@ -34,9 +35,9 @@ standard_bijectors = [ ("Affine", Affine), ("AffineAutoregressive", AffineAutoregressive), - ("CouplingBijector", CouplingBijector), - ("ConvCouplingBijector", ConvCouplingBijector), ("AffineFixed", AffineFixed), + ("ConvCouplingBijector", ConvCouplingBijector), + ("CouplingBijector", CouplingBijector), ("ELU", ELU), ("Exp", Exp), ("LeakyReLU", LeakyReLU), diff --git a/flowtorch/bijectors/coupling.py b/flowtorch/bijectors/coupling.py index 43c4128c..95b8f2f0 100644 --- a/flowtorch/bijectors/coupling.py +++ b/flowtorch/bijectors/coupling.py @@ -12,6 +12,9 @@ _REAL3d = deepcopy(constraints.real) _REAL3d.event_dim = 3 +_REAL1d = deepcopy(constraints.real) +_REAL1d.event_dim = 1 + class CouplingBijector(AffineOp): """ @@ -27,6 +30,9 @@ class CouplingBijector(AffineOp): >>> torch.testing.assert_allclose(x, x_bis) """ + domain: constraints.Constraint = _REAL1d + codomain: constraints.Constraint = _REAL1d + def __init__( self, params_fn: Optional[flowtorch.Lazy] = None, diff --git a/flowtorch/parameters/__init__.py b/flowtorch/parameters/__init__.py index 09c81fd3..f1c097ab 100644 --- a/flowtorch/parameters/__init__.py +++ b/flowtorch/parameters/__init__.py @@ -7,14 +7,15 @@ """ from flowtorch.parameters.base import Parameters -from flowtorch.parameters.coupling import ConvCoupling, DenseCoupling +from flowtorch.parameters.coupling import ConvCoupling +from flowtorch.parameters.coupling import DenseCoupling from flowtorch.parameters.dense_autoregressive import DenseAutoregressive from flowtorch.parameters.tensor import Tensor __all__ = [ "Parameters", + "ConvCoupling", + "DenseCoupling", "DenseAutoregressive", "Tensor", - "DenseCoupling", - "ConvCoupling", ] diff --git a/flowtorch/parameters/base.py b/flowtorch/parameters/base.py index f0526847..5d0a74cc 100644 --- a/flowtorch/parameters/base.py +++ b/flowtorch/parameters/base.py @@ -29,7 +29,7 @@ def forward( context: Optional[torch.Tensor] = None, ) -> Optional[Sequence[torch.Tensor]]: # TODO: Caching etc. 
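
For the convolutional coupling added earlier in this series, _make_mask selects pixels instead of feature halves; the chessboard pattern alternates over the two spatial dimensions and is shared across channels, while the quadrant pattern selects two diagonal image quadrants. A standalone reconstruction of the chessboard case (mirroring the patch rather than importing it):

import torch

def chessboard(shape):
    # shape is (channels, height, width); the pattern is shared across channels
    z = torch.zeros(shape, dtype=torch.bool)
    z[:, ::2, ::2] = True
    z[:, 1::2, 1::2] = True
    return z

mask = chessboard((1, 4, 4))
print(mask[0].int())
# rows alternate: [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]
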
- return self._forward(input, inverse, context) + return self._forward(input, inverse=inverse, context=context) def _forward( self, diff --git a/flowtorch/parameters/coupling.py b/flowtorch/parameters/coupling.py index 5dc7d70e..d38ecb7d 100644 --- a/flowtorch/parameters/coupling.py +++ b/flowtorch/parameters/coupling.py @@ -1,9 +1,10 @@ # Copyright (c) Meta Platforms, Inc -from typing import Callable, Iterable, Optional, Sequence +from typing import Callable, Optional, Sequence import torch import torch.nn as nn + from flowtorch.nn.made import MaskedLinear from flowtorch.parameters.base import Parameters @@ -45,7 +46,9 @@ def __init__( # We need each param_shapes to match input_shape in # its leftmost dimensions for s in param_shapes: - assert len(s) >= len(input_shape) and s[: len(input_shape)] == input_shape + assert (len(s) >= len(input_shape)) and ( + s[: len(input_shape)] == input_shape + ) self.hidden_dims = hidden_dims self.nonlinearity = nonlinearity @@ -86,7 +89,8 @@ def _build( if input_dims == 1: raise ValueError( - "Coupling input_dim = 1. Coupling transforms require at least two features." + "Coupling input_dim = 1. Coupling transforms require at least " + "two features." ) self.register_buffer("permutation", permutation) @@ -105,7 +109,10 @@ def _build( out_dims = input_dims * self.output_multiplier mask_output = torch.ones( - self.output_multiplier, input_dims, hidden_dims[-1], dtype=torch.bool + self.output_multiplier, + input_dims, + hidden_dims[-1], + dtype=torch.bool, ) mask_output[:, :x1_dim] = 0.0 mask_output = mask_output[:, self.permutation] @@ -171,11 +178,12 @@ def bias(self) -> torch.Tensor: def _forward( self, - input: torch.Tensor, + *input: torch.Tensor, inverse: bool, context: Optional[torch.Tensor] = None, ) -> Optional[Sequence[torch.Tensor]]: + input = input[0] input_masked = input.masked_fill(self.mask_output, 0.0) # type: ignore if context is not None: input_aug = torch.cat( @@ -195,14 +203,20 @@ def _forward( result = h.unbind(-2) result = tuple( - r.masked_fill(~self.mask_output.expand_as(r), 0.0) for r in result # type: ignore + r.masked_fill(~self.mask_output.expand_as(r), 0.0) + for r in result # type: ignore ) return result class ConvCoupling(Parameters): autoregressive = False - _mask_types = ["chessboard", "quadrants", "inv_chessboard", "inv_quadrants"] + _mask_types = [ + "chessboard", + "quadrants", + "inv_chessboard", + "inv_quadrants", + ] def __init__( self, @@ -225,7 +239,9 @@ def __init__( # We need each param_shapes to match input_shape in # its leftmost dimensions for s in param_shapes: - assert len(s) >= len(input_shape) and s[: len(input_shape)] == input_shape + assert (len(s) >= len(input_shape)) and ( + s[: len(input_shape)] == input_shape + ) if cnn_kernel is None: cnn_kernel = [3, 1, 3] @@ -247,7 +263,7 @@ def __init__( def _build( self, input_shape: torch.Size, # something like [C, W, H] - param_shapes: Sequence[torch.Size], # something like [[C, W, H], [C, W, H]] + param_shapes: Sequence[torch.Size], # [[C, W, H], [C, W, H]] context_shape: Optional[torch.Size], mask_type: str, ) -> None: @@ -302,11 +318,12 @@ def _init_weights(self) -> None: def _forward( self, - input: torch.Tensor, + *input: torch.Tensor, inverse: bool, context: Optional[torch.Tensor] = None, ) -> Optional[Sequence[torch.Tensor]]: + input = input[0] unsqueeze = False if input.ndimension() == 3: # mostly for initialization @@ -323,7 +340,6 @@ def _forward( else: input_aug = input_masked - print(self.layers) h = self.layers(input_aug) if self.skip_connections: @@ 
-336,7 +352,8 @@ def _forward( result = h.chunk(2, -3) result = tuple( - r.masked_fill(~self.mask.expand_as(r), 0.0) for r in result # type: ignore + r.masked_fill(~self.mask.expand_as(r), 0.0) + for r in result # type: ignore ) return result From 60fa3679b50532976adbe9a9269f6117da6db78e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 22 Apr 2022 13:03:20 +0100 Subject: [PATCH 11/14] fix tests and others --- flowtorch/bijectors/coupling.py | 4 ++-- flowtorch/bijectors/ops/affine.py | 8 ++++---- flowtorch/bijectors/ops/spline.py | 2 +- flowtorch/parameters/coupling.py | 12 ++++-------- tests/test_bijectivetensor.py | 1 - tests/test_bijector.py | 26 +++++++++++++++++++------- tests/test_distribution.py | 7 ++++--- 7 files changed, 34 insertions(+), 26 deletions(-) diff --git a/flowtorch/bijectors/coupling.py b/flowtorch/bijectors/coupling.py index 95b8f2f0..73b47f50 100644 --- a/flowtorch/bijectors/coupling.py +++ b/flowtorch/bijectors/coupling.py @@ -65,7 +65,7 @@ def __init__( def _forward( self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]] - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: assert self._params_fn is not None y, ldj = super()._forward(x, params) @@ -73,7 +73,7 @@ def _forward( def _inverse( self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: assert self._params_fn is not None x, ldj = super()._inverse(y, params) diff --git a/flowtorch/bijectors/ops/affine.py b/flowtorch/bijectors/ops/affine.py index 5684fd18..196cfc20 100644 --- a/flowtorch/bijectors/ops/affine.py +++ b/flowtorch/bijectors/ops/affine.py @@ -9,7 +9,7 @@ from torch.distributions.utils import _sum_rightmost _DEFAULT_POSITIVE_BIASES = { - "softplus": torch.expm1(torch.ones(1)).log().item(), + "softplus": 0.5413248538970947, "exp": 0.0, } @@ -58,7 +58,7 @@ def positive_map(self, x: torch.Tensor) -> torch.Tensor: def _forward( self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]] - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: assert params is not None mean, unbounded_scale = params @@ -73,7 +73,7 @@ def _forward( def _inverse( self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: assert ( params is not None ), f"{self.__class__.__name__}._inverse got no parameters" @@ -86,7 +86,7 @@ def _inverse( if not self._exp_map: inverse_scale = self.positive_map(unbounded_scale).reciprocal() - log_scale = inverse_scale.log() + log_scale = -inverse_scale.log() else: inverse_scale = torch.exp(-unbounded_scale) log_scale = unbounded_scale diff --git a/flowtorch/bijectors/ops/spline.py b/flowtorch/bijectors/ops/spline.py index 687d1bac..31d308a9 100644 --- a/flowtorch/bijectors/ops/spline.py +++ b/flowtorch/bijectors/ops/spline.py @@ -56,7 +56,7 @@ def _inverse( # TODO: Should I invert the sign of log_detJ? 
# TODO: A unit test that compares log_detJ from _forward and _inverse - return x_new, _sum_rightmost(log_detJ, self.domain.event_dim) + return x_new, _sum_rightmost(-log_detJ, self.domain.event_dim) def _log_abs_det_jacobian( self, diff --git a/flowtorch/parameters/coupling.py b/flowtorch/parameters/coupling.py index d38ecb7d..f8179c55 100644 --- a/flowtorch/parameters/coupling.py +++ b/flowtorch/parameters/coupling.py @@ -4,7 +4,6 @@ import torch import torch.nn as nn - from flowtorch.nn.made import MaskedLinear from flowtorch.parameters.base import Parameters @@ -178,12 +177,11 @@ def bias(self) -> torch.Tensor: def _forward( self, - *input: torch.Tensor, + input: torch.Tensor, inverse: bool, context: Optional[torch.Tensor] = None, ) -> Optional[Sequence[torch.Tensor]]: - input = input[0] input_masked = input.masked_fill(self.mask_output, 0.0) # type: ignore if context is not None: input_aug = torch.cat( @@ -203,7 +201,7 @@ def _forward( result = h.unbind(-2) result = tuple( - r.masked_fill(~self.mask_output.expand_as(r), 0.0) + r.masked_fill(~self.mask_output.expand_as(r), 0.0) # type: ignore for r in result # type: ignore ) return result @@ -318,12 +316,11 @@ def _init_weights(self) -> None: def _forward( self, - *input: torch.Tensor, + input: torch.Tensor, inverse: bool, context: Optional[torch.Tensor] = None, ) -> Optional[Sequence[torch.Tensor]]: - input = input[0] unsqueeze = False if input.ndimension() == 3: # mostly for initialization @@ -352,8 +349,7 @@ def _forward( result = h.chunk(2, -3) result = tuple( - r.masked_fill(~self.mask.expand_as(r), 0.0) - for r in result # type: ignore + r.masked_fill(~self.mask.expand_as(r), 0.0) for r in result # type: ignore ) return result diff --git a/tests/test_bijectivetensor.py b/tests/test_bijectivetensor.py index 72bbdf70..fa340f57 100644 --- a/tests/test_bijectivetensor.py +++ b/tests/test_bijectivetensor.py @@ -15,7 +15,6 @@ def get_net() -> AffineAutoregressive: [ AffineAutoregressive(params.DenseAutoregressive()), AffineAutoregressive(params.DenseAutoregressive()), - AffineAutoregressive(params.DenseAutoregressive()), ] ) ar = ar( diff --git a/tests/test_bijector.py b/tests/test_bijector.py index adb4b68f..d81a4fb0 100644 --- a/tests/test_bijector.py +++ b/tests/test_bijector.py @@ -1,4 +1,6 @@ # Copyright (c) Meta Platforms, Inc +import math + import flowtorch.bijectors as bijectors import numpy as np import pytest @@ -17,11 +19,13 @@ def test_bijector_constructor(): @pytest.fixture(params=[bij_name for _, bij_name in bijectors.standard_bijectors]) def flow(request): + torch.set_default_dtype(torch.double) bij = request.param event_dim = max(bij.domain.event_dim, 1) event_shape = event_dim * [3] base_dist = dist.Independent( - dist.Normal(torch.zeros(event_shape), torch.ones(event_shape)), event_dim + dist.Normal(torch.zeros(event_shape), torch.ones(event_shape)), + event_dim, ) flow = Flow(base_dist, bij) @@ -37,10 +41,12 @@ def test_jacobian(flow, epsilon=1e-2): x = torch.randn(*flow.event_shape) x = torch.distributions.transform_to(bij.domain)(x) y = bij.forward(x) - if bij.domain.event_dim == 1: - analytic_ldt = bij.log_abs_det_jacobian(x, y).data + if bij.domain.event_dim == 0: + analytic_ldt = bij.log_abs_det_jacobian(x, y).data.sum(-1) else: - analytic_ldt = bij.log_abs_det_jacobian(x, y).sum(-1).data + analytic_ldt = bij.log_abs_det_jacobian(x, y).data + for _ in range(bij.domain.event_dim - 1): + analytic_ldt = analytic_ldt.sum(-1) # Calculate numerical Jacobian # TODO: Better way to get all indices of array/tensor? 
@@ -82,7 +88,8 @@ def test_jacobian(flow, epsilon=1e-2): if hasattr(params, "permutation"): numeric_ldt = torch.sum(torch.log(torch.diag(jacobian))) else: - numeric_ldt = torch.log(torch.abs(jacobian.det())) + jacobian = jacobian.view(int(math.sqrt(jacobian.numel())), -1) + numeric_ldt = torch.log(torch.abs(jacobian.det())).sum() ldt_discrepancy = (analytic_ldt - numeric_ldt).abs() assert ldt_discrepancy < epsilon @@ -105,15 +112,20 @@ def test_inverse(flow, epsilon=1e-5): # Test g^{-1}(g(x)) = x x_true = base_dist.sample(torch.Size([10])) + assert x_true.dtype is torch.double x_true = torch.distributions.transform_to(bij.domain)(x_true) y = bij.forward(x_true) + J_1 = y.log_detJ + y = y.detach_from_flow() + x_calculated = bij.inverse(y) + J_2 = x_calculated.log_detJ + x_calculated = x_calculated.detach_from_flow() + assert (x_true - x_calculated).abs().max().item() < epsilon # Test that Jacobian after inverse op is same as after forward - J_1 = bij.log_abs_det_jacobian(x_true, y) - J_2 = bij.log_abs_det_jacobian(x_calculated, y) assert (J_1 - J_2).abs().max().item() < epsilon diff --git a/tests/test_distribution.py b/tests/test_distribution.py index db7c9095..25c065a5 100644 --- a/tests/test_distribution.py +++ b/tests/test_distribution.py @@ -15,7 +15,8 @@ def test_tdist_standalone(): def make_tdist(): # train a flow here base_dist = torch.distributions.Independent( - torch.distributions.Normal(torch.zeros(input_dim), torch.ones(input_dim)), 1 + torch.distributions.Normal(torch.zeros(input_dim), torch.ones(input_dim)), + 1, ) bijector = bijs.AffineAutoregressive() tdist = dist.Flow(base_dist, bijector) @@ -37,9 +38,9 @@ def test_neals_funnel_vi(): flow = dist.Flow(base_dist, bijector) bijector = flow.bijector - opt = torch.optim.Adam(flow.parameters(), lr=2e-3) + opt = torch.optim.Adam(flow.parameters(), lr=1e-2) num_elbo_mc_samples = 200 - for _ in range(100): + for _ in range(500): z0 = flow.base_dist.rsample(sample_shape=(num_elbo_mc_samples,)) zk = bijector.forward(z0) ldj = zk._log_detJ From 98491b83445cd1c51ff232a0e11c4192e8ad26a2 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 26 Apr 2022 16:26:05 +0100 Subject: [PATCH 12/14] merge main --- flowtorch/bijectors/__init__.py | 8 ++++---- flowtorch/bijectors/base.py | 16 ++++++++++++---- tests/test_bijector.py | 17 ++++++++++++----- 3 files changed, 28 insertions(+), 13 deletions(-) diff --git a/flowtorch/bijectors/__init__.py b/flowtorch/bijectors/__init__.py index 7313dcfb..6c98c025 100644 --- a/flowtorch/bijectors/__init__.py +++ b/flowtorch/bijectors/__init__.py @@ -37,28 +37,28 @@ ("Affine", Affine), ("AffineAutoregressive", AffineAutoregressive), ("AffineFixed", AffineFixed), + ("Fixed", Fixed), ("ConvCouplingBijector", ConvCouplingBijector), ("CouplingBijector", CouplingBijector), ("ELU", ELU), ("Exp", Exp), ("LeakyReLU", LeakyReLU), - ("Permute", Permute), + ("VolumePreserving", VolumePreserving), ("Power", Power), ("Sigmoid", Sigmoid), ("Softplus", Softplus), ("Spline", Spline), - ("SplineAutoregressive", SplineAutoregressive), ("Tanh", Tanh), ] meta_bijectors = [ ("Elementwise", Elementwise), ("Autoregressive", Autoregressive), - ("Fixed", Fixed), ("Bijector", Bijector), ("Compose", Compose), ("Invert", Invert), - ("VolumePreserving", VolumePreserving), + ("Permute", Permute), + ("SplineAutoregressive", SplineAutoregressive), ] diff --git a/flowtorch/bijectors/base.py b/flowtorch/bijectors/base.py index 2a3d0f01..a9621cca 100644 --- a/flowtorch/bijectors/base.py +++ b/flowtorch/bijectors/base.py @@ -71,7 +71,11 @@ 
def forward( assert isinstance(x, BijectiveTensor) return x.get_parent_from_bijector(self) - params = self._params_fn(x, context) if self._params_fn is not None else None + params = ( + self._params_fn(x, inverse=False, context=context) + if self._params_fn is not None + else None + ) y, log_detJ = self._forward(x, params) if ( is_record_flow_graph_enabled() @@ -117,7 +121,11 @@ def inverse( return y.get_parent_from_bijector(self) # TODO: What to do in this line? - params = self._params_fn(x, context) if self._params_fn is not None else None + params = ( + self._params_fn(y, inverse=True, context=context) + if self._params_fn is not None + else None + ) x, log_detJ = self._inverse(y, params) if ( @@ -170,10 +178,10 @@ def log_abs_det_jacobian( if ladj is None: if is_record_flow_graph_enabled(): warnings.warn( - "Computing _log_abs_det_jacobian from values and not " "from cache." + "Computing _log_abs_det_jacobian from values and not from cache." ) params = ( - self._params_fn(x, context) if self._params_fn is not None else None + self._params_fn(x, y, context) if self._params_fn is not None else None ) return self._log_abs_det_jacobian(x, y, params) return ladj diff --git a/tests/test_bijector.py b/tests/test_bijector.py index e9344ef6..eec9f949 100644 --- a/tests/test_bijector.py +++ b/tests/test_bijector.py @@ -1,4 +1,5 @@ # Copyright (c) Meta Platforms, Inc +import math import warnings import flowtorch.bijectors as bijectors @@ -21,11 +22,13 @@ def test_bijector_constructor(): @pytest.fixture(params=[bij_name for _, bij_name in bijectors.standard_bijectors]) def flow(request): + torch.set_default_dtype(torch.double) bij = request.param event_dim = max(bij.domain.event_dim, 1) event_shape = event_dim * [3] base_dist = dist.Independent( - dist.Normal(torch.zeros(event_shape), torch.ones(event_shape)), event_dim + dist.Normal(torch.zeros(event_shape), torch.ones(event_shape)), + event_dim, ) flow = Flow(base_dist, bij) @@ -41,10 +44,12 @@ def test_jacobian(flow, epsilon=1e-2): x = torch.randn(*flow.event_shape) x = torch.distributions.transform_to(bij.domain)(x) y = bij.forward(x) - if bij.domain.event_dim == 1: - analytic_ldt = bij.log_abs_det_jacobian(x, y).data + if bij.domain.event_dim == 0: + analytic_ldt = bij.log_abs_det_jacobian(x, y).data.sum(-1) else: - analytic_ldt = bij.log_abs_det_jacobian(x, y).sum(-1).data + analytic_ldt = bij.log_abs_det_jacobian(x, y).data + for _ in range(bij.domain.event_dim - 1): + analytic_ldt = analytic_ldt.sum(-1) # Calculate numerical Jacobian # TODO: Better way to get all indices of array/tensor? 
@@ -86,7 +91,8 @@ def test_jacobian(flow, epsilon=1e-2): if hasattr(params, "permutation"): numeric_ldt = torch.sum(torch.log(torch.diag(jacobian))) else: - numeric_ldt = torch.log(torch.abs(jacobian.det())) + jacobian = jacobian.view(int(math.sqrt(jacobian.numel())), -1) + numeric_ldt = torch.log(torch.abs(jacobian.det())).sum() ldt_discrepancy = (analytic_ldt - numeric_ldt).abs() assert ldt_discrepancy < epsilon @@ -109,6 +115,7 @@ def test_inverse(flow, epsilon=1e-5): # Test g^{-1}(g(x)) = x x_true = base_dist.sample(torch.Size([10])) + assert x_true.dtype is torch.double x_true = torch.distributions.transform_to(bij.domain)(x_true) y = bij.forward(x_true) From fb690c22d876cc029e3429df59b2c047f3877cf6 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 26 Apr 2022 16:37:04 +0100 Subject: [PATCH 13/14] optimize imports --- flowtorch/bijectors/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flowtorch/bijectors/__init__.py b/flowtorch/bijectors/__init__.py index 6c98c025..7313dcfb 100644 --- a/flowtorch/bijectors/__init__.py +++ b/flowtorch/bijectors/__init__.py @@ -37,28 +37,28 @@ ("Affine", Affine), ("AffineAutoregressive", AffineAutoregressive), ("AffineFixed", AffineFixed), - ("Fixed", Fixed), ("ConvCouplingBijector", ConvCouplingBijector), ("CouplingBijector", CouplingBijector), ("ELU", ELU), ("Exp", Exp), ("LeakyReLU", LeakyReLU), - ("VolumePreserving", VolumePreserving), + ("Permute", Permute), ("Power", Power), ("Sigmoid", Sigmoid), ("Softplus", Softplus), ("Spline", Spline), + ("SplineAutoregressive", SplineAutoregressive), ("Tanh", Tanh), ] meta_bijectors = [ ("Elementwise", Elementwise), ("Autoregressive", Autoregressive), + ("Fixed", Fixed), ("Bijector", Bijector), ("Compose", Compose), ("Invert", Invert), - ("Permute", Permute), - ("SplineAutoregressive", SplineAutoregressive), + ("VolumePreserving", VolumePreserving), ] From 5ec166a0596c5079ecb2c601356aa95880af39b9 Mon Sep 17 00:00:00 2001 From: Stefan Webb Date: Wed, 4 May 2022 17:13:38 -0700 Subject: [PATCH 14/14] Reverted changes to AffineOp since they're in separate PR --- flowtorch/bijectors/affine_autoregressive.py | 21 ++---- flowtorch/bijectors/ops/affine.py | 76 +++++--------------- 2 files changed, 21 insertions(+), 76 deletions(-) diff --git a/flowtorch/bijectors/affine_autoregressive.py b/flowtorch/bijectors/affine_autoregressive.py index a855cf5d..610e5477 100644 --- a/flowtorch/bijectors/affine_autoregressive.py +++ b/flowtorch/bijectors/affine_autoregressive.py @@ -16,28 +16,15 @@ def __init__( *, shape: torch.Size, context_shape: Optional[torch.Size] = None, - clamp_values: bool = False, log_scale_min_clip: float = -5.0, log_scale_max_clip: float = 3.0, sigmoid_bias: float = 2.0, - positive_map: str = "softplus", - positive_bias: Optional[float] = None, ) -> None: - AffineOp.__init__( - self, - params_fn, - shape=shape, - context_shape=context_shape, - clamp_values=clamp_values, - log_scale_min_clip=log_scale_min_clip, - log_scale_max_clip=log_scale_max_clip, - sigmoid_bias=sigmoid_bias, - positive_map=positive_map, - positive_bias=positive_bias, - ) - Autoregressive.__init__( - self, + super().__init__( params_fn, shape=shape, context_shape=context_shape, ) + self.log_scale_min_clip = log_scale_min_clip + self.log_scale_max_clip = log_scale_max_clip + self.sigmoid_bias = sigmoid_bias diff --git a/flowtorch/bijectors/ops/affine.py b/flowtorch/bijectors/ops/affine.py index 196cfc20..d9cdf56f 100644 --- a/flowtorch/bijectors/ops/affine.py +++ 
b/flowtorch/bijectors/ops/affine.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc -from typing import Callable, Dict, Optional, Sequence, Tuple +from typing import Optional, Sequence, Tuple import flowtorch import torch @@ -8,17 +8,6 @@ from flowtorch.ops import clamp_preserve_gradients from torch.distributions.utils import _sum_rightmost -_DEFAULT_POSITIVE_BIASES = { - "softplus": 0.5413248538970947, - "exp": 0.0, -} - -_POSITIVE_MAPS: Dict[str, Callable[[torch.Tensor], torch.Tensor]] = { - "softplus": torch.nn.functional.softplus, - "sigmoid": torch.sigmoid, - "exp": torch.exp, -} - class Affine(Bijector): r""" @@ -33,63 +22,38 @@ def __init__( *, shape: torch.Size, context_shape: Optional[torch.Size] = None, - clamp_values: bool = False, log_scale_min_clip: float = -5.0, log_scale_max_clip: float = 3.0, sigmoid_bias: float = 2.0, - positive_map: str = "softplus", - positive_bias: Optional[float] = None, ) -> None: super().__init__(params_fn, shape=shape, context_shape=context_shape) - self.clamp_values = clamp_values self.log_scale_min_clip = log_scale_min_clip self.log_scale_max_clip = log_scale_max_clip self.sigmoid_bias = sigmoid_bias - if positive_bias is None: - positive_bias = _DEFAULT_POSITIVE_BIASES[positive_map] - self.positive_bias = positive_bias - if positive_map not in _POSITIVE_MAPS: - raise RuntimeError(f"Unknwon positive map {positive_map}") - self._positive_map = _POSITIVE_MAPS[positive_map] - self._exp_map = self._positive_map is torch.exp and self.positive_bias == 0 - - def positive_map(self, x: torch.Tensor) -> torch.Tensor: - return self._positive_map(x + self.positive_bias) def _forward( self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]] - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> Tuple[torch.Tensor, torch.Tensor]: assert params is not None - mean, unbounded_scale = params - if self.clamp_values: - unbounded_scale = clamp_preserve_gradients( - unbounded_scale, self.log_scale_min_clip, self.log_scale_max_clip - ) - scale = self.positive_map(unbounded_scale) - log_scale = scale.log() if not self._exp_map else unbounded_scale + mean, log_scale = params + log_scale = clamp_preserve_gradients( + log_scale, self.log_scale_min_clip, self.log_scale_max_clip + ) + scale = torch.exp(log_scale) y = scale * x + mean return y, _sum_rightmost(log_scale, self.domain.event_dim) def _inverse( self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - assert ( - params is not None - ), f"{self.__class__.__name__}._inverse got no parameters" - - mean, unbounded_scale = params - if self.clamp_values: - unbounded_scale = clamp_preserve_gradients( - unbounded_scale, self.log_scale_min_clip, self.log_scale_max_clip - ) + ) -> Tuple[torch.Tensor, torch.Tensor]: + assert params is not None - if not self._exp_map: - inverse_scale = self.positive_map(unbounded_scale).reciprocal() - log_scale = -inverse_scale.log() - else: - inverse_scale = torch.exp(-unbounded_scale) - log_scale = unbounded_scale + mean, log_scale = params + log_scale = clamp_preserve_gradients( + log_scale, self.log_scale_min_clip, self.log_scale_max_clip + ) + inverse_scale = torch.exp(-log_scale) x_new = (y - mean) * inverse_scale return x_new, _sum_rightmost(log_scale, self.domain.event_dim) @@ -101,15 +65,9 @@ def _log_abs_det_jacobian( ) -> torch.Tensor: assert params is not None - _, unbounded_scale = params - if self.clamp_values: - unbounded_scale = clamp_preserve_gradients( - unbounded_scale, self.log_scale_min_clip, 
self.log_scale_max_clip - ) - log_scale = ( - self.positive_map(unbounded_scale).log() - if not self._exp_map - else unbounded_scale + _, log_scale = params + log_scale = clamp_preserve_gradients( + log_scale, self.log_scale_min_clip, self.log_scale_max_clip ) return _sum_rightmost(log_scale, self.domain.event_dim)
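
Editor's note (not part of the patch series): the two sketches below are minimal, standalone illustrations of conventions the later patches touch. They assume only PyTorch and Python's standard library; names such as numeric_log_abs_det are hypothetical helpers for illustration, not flowtorch API.

First, the hardcoded softplus bias 0.5413248538970947 in _DEFAULT_POSITIVE_BIASES is log(e - 1), the shift b for which softplus(0 + b) == 1, so an unbounded scale parameter of zero maps to a scale of exactly one (identity-like initialisation). Relatedly, both _forward and _inverse are expected to report log|det dy/dx| of the forward map, which is why the fixes negate inverse_scale.log() and log_detJ on the inverse path. A rough check under those assumptions:

# Standalone sketch (not flowtorch code): verify the softplus bias constant
# and the sign convention of the inverse pass for a scalar affine map.
import math

import torch
import torch.nn.functional as F

# log(e - 1) reproduces the hardcoded bias, and softplus(0 + bias) == 1,
# so a zero "unbounded scale" corresponds to a scale of one.
bias = math.log(math.e - 1.0)
assert abs(bias - 0.5413248538970947) < 1e-8
assert abs(F.softplus(torch.tensor(bias)).item() - 1.0) < 1e-6

# Both directions should report log|det dy/dx| of the *forward* affine map,
# hence the sign flip on the inverse path.
x = torch.tensor(0.7)
mean, unbounded_scale = torch.tensor(0.3), torch.tensor(-1.2)
scale = F.softplus(unbounded_scale + bias)   # positive map with bias
y = scale * x + mean                         # forward: y = s * x + m
log_detJ_forward = scale.log()

inverse_scale = scale.reciprocal()
x_back = (y - mean) * inverse_scale          # inverse: x = (y - m) / s
log_detJ_inverse = -inverse_scale.log()      # == log(s), same as forward

assert torch.isclose(x, x_back)
assert torch.isclose(log_detJ_forward, log_detJ_inverse)

Second, the test_jacobian change flattens the finite-difference Jacobian into a square matrix before taking log|det|, so event shapes with more than one dimension are handled. A sketch of that idea, checked against an elementwise transform whose log-determinant is known in closed form (for y_i = exp(x_i), log|det J| = sum(x_i)):

# Standalone sketch (not flowtorch code): central-difference Jacobian of f at x,
# flattened to an (n, n) matrix before taking log|det|, as in the updated test.
import torch

def numeric_log_abs_det(f, x, eps=1e-4):
    x_flat = x.reshape(-1)
    n = x_flat.numel()
    jac = torch.zeros(n, n, dtype=x.dtype)
    for i in range(n):
        e = torch.zeros_like(x_flat)
        e[i] = eps
        diff = f((x_flat + e).reshape(x.shape)) - f((x_flat - e).reshape(x.shape))
        jac[:, i] = diff.reshape(-1) / (2 * eps)
    return jac.det().abs().log()

x = torch.randn(2, 3, dtype=torch.double)
analytic = x.sum()                           # log|det J| of elementwise exp
numeric = numeric_log_abs_det(torch.exp, x)
assert torch.isclose(analytic, numeric, atol=1e-4)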