diff --git a/flowtorch/bijectors/affine_autoregressive.py b/flowtorch/bijectors/affine_autoregressive.py
index 610e5477..fa90fd68 100644
--- a/flowtorch/bijectors/affine_autoregressive.py
+++ b/flowtorch/bijectors/affine_autoregressive.py
@@ -16,15 +16,17 @@ def __init__(
         *,
         shape: torch.Size,
         context_shape: Optional[torch.Size] = None,
+        clamp_values: bool = False,
         log_scale_min_clip: float = -5.0,
         log_scale_max_clip: float = 3.0,
-        sigmoid_bias: float = 2.0,
+        scale_fn: str = "softplus",
     ) -> None:
         super().__init__(
             params_fn,
             shape=shape,
             context_shape=context_shape,
         )
+        self.clamp_values = clamp_values
         self.log_scale_min_clip = log_scale_min_clip
         self.log_scale_max_clip = log_scale_max_clip
-        self.sigmoid_bias = sigmoid_bias
+        self.scale_fn = scale_fn
diff --git a/flowtorch/bijectors/ops/affine.py b/flowtorch/bijectors/ops/affine.py
index d9cdf56f..304d83c6 100644
--- a/flowtorch/bijectors/ops/affine.py
+++ b/flowtorch/bijectors/ops/affine.py
@@ -4,6 +4,7 @@
 
 import flowtorch
 import torch
+import torch.nn.functional as F
 from flowtorch.bijectors.base import Bijector
 from flowtorch.ops import clamp_preserve_gradients
 from torch.distributions.utils import _sum_rightmost
@@ -22,25 +23,66 @@ def __init__(
         *,
         shape: torch.Size,
         context_shape: Optional[torch.Size] = None,
+        clamp_values: bool = False,
         log_scale_min_clip: float = -5.0,
         log_scale_max_clip: float = 3.0,
-        sigmoid_bias: float = 2.0,
+        scale_fn: str = "softplus",
     ) -> None:
         super().__init__(params_fn, shape=shape, context_shape=context_shape)
+        self.clamp_values = clamp_values
         self.log_scale_min_clip = log_scale_min_clip
         self.log_scale_max_clip = log_scale_max_clip
-        self.sigmoid_bias = sigmoid_bias
+        self.scale_fn = scale_fn
+
+    def _scale_fn(
+        self, unbounded_scale: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # NOTE: Need to hardcode log(f(x)) for numerical stability
+        if self.scale_fn == "softplus":
+            scale = F.softplus(unbounded_scale)
+            log_scale = torch.log(scale)
+        elif self.scale_fn == "exp":
+            scale = torch.exp(unbounded_scale)
+            log_scale = unbounded_scale
+        elif self.scale_fn == "sigmoid":
+            scale = torch.sigmoid(unbounded_scale)
+            log_scale = F.logsigmoid(unbounded_scale)
+        else:
+            raise ValueError(f"Unknown scale function: {self.scale_fn}")
+
+        return scale, log_scale
+
+    def _inv_scale_fn(
+        self, unbounded_scale: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # NOTE: Need to hardcode 1./log(f(x)) for numerical stability
+        if self.scale_fn == "softplus":
+            scale = F.softplus(unbounded_scale)
+            inverse_scale = F.softplus(unbounded_scale).reciprocal()
+            log_scale = torch.log(scale)
+        elif self.scale_fn == "exp":
+            inverse_scale = torch.exp(-unbounded_scale)
+            log_scale = unbounded_scale
+        elif self.scale_fn == "sigmoid":
+            inverse_scale = torch.exp(-unbounded_scale) + 1.0
+            log_scale = F.logsigmoid(unbounded_scale)
+        else:
+            raise ValueError(f"Unknown scale function: {self.scale_fn}")
+
+        return inverse_scale, log_scale
 
     def _forward(
         self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         assert params is not None
 
-        mean, log_scale = params
-        log_scale = clamp_preserve_gradients(
-            log_scale, self.log_scale_min_clip, self.log_scale_max_clip
-        )
-        scale = torch.exp(log_scale)
+        mean, unbounded_scale = params
+        if self.clamp_values:
+            unbounded_scale = clamp_preserve_gradients(
+                unbounded_scale, self.log_scale_min_clip, self.log_scale_max_clip
+            )
+
+        scale, log_scale = self._scale_fn(unbounded_scale)
         y = scale * x + mean
         return y, _sum_rightmost(log_scale, self.domain.event_dim)
 
@@ -49,11 +91,13 @@ def _inverse(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         assert params is not None
 
-        mean, log_scale = params
-        log_scale = clamp_preserve_gradients(
-            log_scale, self.log_scale_min_clip, self.log_scale_max_clip
-        )
-        inverse_scale = torch.exp(-log_scale)
+        mean, unbounded_scale = params
+        if self.clamp_values:
+            unbounded_scale = clamp_preserve_gradients(
+                unbounded_scale, self.log_scale_min_clip, self.log_scale_max_clip
+            )
+
+        inverse_scale, log_scale = self._inv_scale_fn(unbounded_scale)
         x_new = (y - mean) * inverse_scale
         return x_new, _sum_rightmost(log_scale, self.domain.event_dim)
 
@@ -65,10 +109,13 @@ def _log_abs_det_jacobian(
     ) -> torch.Tensor:
         assert params is not None
 
-        _, log_scale = params
-        log_scale = clamp_preserve_gradients(
-            log_scale, self.log_scale_min_clip, self.log_scale_max_clip
-        )
+        _, unbounded_scale = params
+        if self.clamp_values:
+            unbounded_scale = clamp_preserve_gradients(
+                unbounded_scale, self.log_scale_min_clip, self.log_scale_max_clip
+            )
+        _, log_scale = self._scale_fn(unbounded_scale)
+
         return _sum_rightmost(log_scale, self.domain.event_dim)
 
     def param_shapes(self, shape: torch.Size) -> Tuple[torch.Size, torch.Size]:
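For reviewers, a minimal usage sketch of the new options. This is a sketch only: it assumes the lazy-instantiation pattern from the flowtorch README and that these keyword arguments are forwarded through `bij.AffineAutoregressive`; the base distribution and sample shape are purely illustrative.

```python
import torch
import flowtorch.bijectors as bij
import flowtorch.distributions as dist

# New keyword arguments: scale_fn selects how the conditioner's unbounded
# output is mapped to a positive scale ("softplus" is the new default);
# clamp_values optionally clips that output to
# [log_scale_min_clip, log_scale_max_clip] before it is transformed.
bijector = bij.AffineAutoregressive(scale_fn="softplus", clamp_values=False)

base_dist = torch.distributions.Independent(
    torch.distributions.Normal(torch.zeros(2), torch.ones(2)), 1
)
flow = dist.Flow(base_dist, bijector)

x = flow.rsample(torch.Size([5]))  # base draws pushed through y = scale * x + mean
print(flow.log_prob(x))            # affine log|det J| is the summed log_scale
```

`scale_fn="exp"` with `clamp_values=True` recovers the previous clipped `exp(log_scale)` behaviour, while the `softplus` default grows only linearly in the conditioner output and so avoids the exponential blow-up of large values.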
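The expressions hardcoded in `_scale_fn` / `_inv_scale_fn` are standard identities; a quick standalone check (illustration only, not part of the PR):

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-6.0, 6.0, steps=13)

# exp: scale = exp(x), so log_scale is x itself and the inverse scale is exp(-x)
assert torch.allclose(torch.log(torch.exp(x)), x, atol=1e-6)
assert torch.allclose(torch.exp(x).reciprocal(), torch.exp(-x))

# sigmoid: log_scale via the numerically stable logsigmoid,
# and 1 / sigmoid(x) = 1 + exp(-x)
assert torch.allclose(torch.log(torch.sigmoid(x)), F.logsigmoid(x), atol=1e-6)
assert torch.allclose(torch.sigmoid(x).reciprocal(), torch.exp(-x) + 1.0)

# softplus: log_scale and the inverse scale are taken from softplus(x) directly
scale = F.softplus(x)
assert torch.allclose(scale * scale.reciprocal(), torch.ones_like(x))
```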