Skip to content

Commit

Permalink
Simplify and improve clamp
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanradev93 committed Nov 8, 2024
1 parent 0602071 commit 04e25d6
Showing 1 changed file with 4 additions and 15 deletions.
19 changes: 4 additions & 15 deletions bayesflow/networks/coupling_flow/transforms/affine_transform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import keras
import keras.ops as ops
from keras.saving import register_keras_serializable as serializable

Expand All @@ -9,19 +8,9 @@

@serializable(package="networks.coupling_flow")
class AffineTransform(Transform):
def __init__(self, clamp: bool | int | float | None = 3.0, **kwargs):
def __init__(self, clamp: bool = True, **kwargs):
super().__init__(**kwargs)
match clamp:
case True:
self.clamp_factor = 3.0
case False:
self.clamp_factor = None
case int() | float():
self.clamp_factor = float(clamp)
case None:
self.clamp_factor = None
case _:
raise ValueError(f"Invalid value for 'clamp': {clamp}")
self.clamp = clamp

@property
def params_per_dim(self):
Expand All @@ -39,8 +28,8 @@ def constrain_parameters(self, parameters: dict[str, Tensor]) -> dict[str, Tenso
scale = shifted_softplus(scale)

# soft clamp
if self.clamp_factor is not None:
scale = self.clamp_factor * keras.ops.tanh(scale)
if self.clamp:
scale = ops.arcsinh(scale)

parameters["scale"] = scale
return parameters
Expand Down

0 comments on commit 04e25d6

Please sign in to comment.