diff --git a/bayesflow/networks/coupling_flow/transforms/affine_transform.py b/bayesflow/networks/coupling_flow/transforms/affine_transform.py index 3c83393f4..38d8b8569 100644 --- a/bayesflow/networks/coupling_flow/transforms/affine_transform.py +++ b/bayesflow/networks/coupling_flow/transforms/affine_transform.py @@ -1,15 +1,14 @@ -import math - import keras.ops as ops from keras.saving import register_keras_serializable as serializable from bayesflow.types import Tensor +from bayesflow.utils.keras_utils import shifted_softplus from .transform import Transform @serializable(package="networks.coupling_flow") class AffineTransform(Transform): - def __init__(self, clamp: float | None = 1.9, **kwargs): + def __init__(self, clamp: bool = True, **kwargs): super().__init__(**kwargs) self.clamp = clamp @@ -25,12 +24,12 @@ def split_parameters(self, parameters: Tensor) -> dict[str, Tensor]: def constrain_parameters(self, parameters: dict[str, Tensor]) -> dict[str, Tensor]: scale = parameters["scale"] - # soft clamp - if self.clamp is not None: - (2.0 * self.clamp / math.pi) * ops.arctan(scale / self.clamp) - # constrain to positive values - scale = ops.exp(scale) + scale = shifted_softplus(scale) + + # soft clamp + if self.clamp: + scale = ops.arcsinh(scale) parameters["scale"] = scale return parameters