Skip to content

Commit

Permalink
add theoretical capability for condition handling
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue committed Jun 3, 2024
1 parent 00fe1d6 commit 80f6a98
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 53 deletions.
8 changes: 4 additions & 4 deletions bayesflow/experimental/networks/coupling_flow/actnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,16 @@ def build(self, input_shape):

def call(self, xz: Tensor, inverse: bool = False, **kwargs):
    """Dispatch to the forward or inverse activation-normalization pass.

    Parameters
    ----------
    xz : Tensor
        Input tensor (x for forward, z for inverse).
    inverse : bool
        If True, apply the inverse transform; otherwise the forward one.
    **kwargs
        Forwarded unchanged to ``_forward`` / ``_inverse``.

    Returns
    -------
    (Tensor, Tensor)
        Transformed tensor and the log-determinant of the Jacobian.
    """
    # Stale pre-change duplicates of these returns (diff residue) removed;
    # only the kwargs-forwarding versions are kept.
    if inverse:
        return self._inverse(xz, **kwargs)
    return self._forward(xz, **kwargs)

def _forward(self, x: Tensor, **kwargs) -> (Tensor, Tensor):
    """Apply the learned element-wise affine map: z = scale * x + bias.

    Returns
    -------
    (Tensor, Tensor)
        The transformed tensor z and the log-determinant of the Jacobian,
        broadcast to the batch shape (all but the last axis of ``x``).
    """
    # Stale pre-change signature (without **kwargs) removed as diff residue.
    z = self.scale * x + self.bias
    # log|det J| of an element-wise scale is the sum of log|scale| over features.
    log_det = ops.sum(ops.log(ops.abs(self.scale)), axis=-1)
    log_det = ops.broadcast_to(log_det, ops.shape(x)[:-1])
    return z, log_det

def _inverse(self, z: Tensor) -> (Tensor, Tensor):
def _inverse(self, z: Tensor, **kwargs) -> (Tensor, Tensor):
x = (z - self.bias) / self.scale
log_det = -ops.sum(ops.log(ops.abs(self.scale)), axis=-1)
log_det = ops.broadcast_to(log_det, ops.shape(z)[:-1])
Expand Down
34 changes: 19 additions & 15 deletions bayesflow/experimental/networks/coupling_flow/coupling_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,50 +49,54 @@ def __init__(
super().__init__(**kwargs)

self._layers = []
for _ in range(depth):
for i in range(depth):
if use_actnorm:
self._layers.append(ActNorm())
self._layers.append(DualCoupling(subnet, transform))
self._layers.append(ActNorm(name=f"ActNorm{i}"))
self._layers.append(DualCoupling(subnet, transform, name=f"DualCoupling{i}"))
if permutation.lower() == "random":
self._layers.append(RandomPermutation())
elif permutation.lower() == "swap":
self._layers.append(Swap())
elif permutation.lower() == "learnable":
self._layers.append(OrthogonalPermutation())

# noinspection PyMethodOverriding
def build(self, xz_shape, conditions_shape=None):
    """Build the flow by running one symbolic call so sub-layers get built.

    Parameters
    ----------
    xz_shape : shape tuple
        Shape of the data tensor (x or z).
    conditions_shape : shape tuple, optional
        Shape of the conditioning tensor; when given, the symbolic call
        includes a conditions tensor so conditional sub-layers build too.
    """
    # Stale pre-change single-argument build (diff residue) removed.
    super().build(xz_shape)
    if conditions_shape is None:
        self.call(keras.KerasTensor(xz_shape))
    else:
        self.call(keras.KerasTensor(xz_shape), conditions=keras.KerasTensor(conditions_shape))

def call(self, xz: Tensor, conditions: Tensor = None, inverse: bool = False, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    """Dispatch to the forward or inverse pass, threading conditions through.

    Parameters
    ----------
    xz : Tensor
        Input tensor (x for forward, z for inverse).
    conditions : Tensor, optional
        Conditioning tensor passed down to every coupling layer.
    inverse : bool
        If True, run the inverse transform.
    **kwargs
        Forwarded unchanged (e.g. ``jacobian=True``).
    """
    # Stale pre-change signature and returns (diff residue) removed.
    if inverse:
        return self._inverse(xz, conditions=conditions, **kwargs)
    return self._forward(xz, conditions=conditions, **kwargs)

def _forward(self, x: Tensor, conditions: Tensor = None, jacobian: bool = False, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    """Push x through every layer in order, accumulating log|det J|.

    Returns ``z`` or, when ``jacobian=True``, the pair ``(z, log_det)``.
    """
    # Stale pre-change signature/body lines (diff residue) removed.
    z = x
    # Accumulator shaped like the batch dims (all but the feature axis).
    log_det = keras.ops.zeros(keras.ops.shape(x)[:-1])
    for layer in self._layers:
        z, det = layer(z, conditions=conditions, inverse=False, **kwargs)
        log_det += det

    if jacobian:
        return z, log_det
    return z

def _inverse(self, z: Tensor, conditions: Tensor = None, jacobian: bool = False, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    """Undo the flow: apply the layers in reverse order with inverse=True.

    Returns ``x`` or, when ``jacobian=True``, the pair ``(x, log_det)``.
    """
    # Stale pre-change signature/body lines (diff residue) removed.
    x = z
    log_det = keras.ops.zeros(keras.ops.shape(z)[:-1])
    # Inversion must traverse the layer stack back-to-front.
    for layer in reversed(self._layers):
        x, det = layer(x, conditions=conditions, inverse=True, **kwargs)
        log_det += det

    if jacobian:
        return x, log_det
    return x

def compute_loss(self, x: Tensor = None, **kwargs):
z, log_det = self(x, inverse=False, jacobian=True, **kwargs)
def compute_loss(self, x: Tensor = None, conditions: Tensor = None, **kwargs):
z, log_det = self(x, conditions=conditions, inverse=False, jacobian=True, **kwargs)
log_prob = self.base_distribution.log_prob(z)
nll = -keras.ops.mean(log_prob + log_det)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,43 +11,36 @@

@register_keras_serializable(package="bayesflow.networks.coupling_flow")
class DualCoupling(InvertibleLayer):
def __init__(self, subnet: str = "resnet", transform: str = "affine", **kwargs):
    """Two stacked single couplings so every feature is transformed once.

    Parameters
    ----------
    subnet : str
        Name of the subnet architecture passed to each SingleCoupling.
    transform : str
        Name of the invertible transform passed to each SingleCoupling.
    **kwargs
        Forwarded to the base layer (e.g. ``name``).
    """
    super().__init__(**kwargs)
    # Extraneous f-prefixes removed: these strings contain no placeholders (F541).
    self.coupling1 = SingleCoupling(subnet, transform, name="CouplingA")
    self.coupling2 = SingleCoupling(subnet, transform, name="CouplingB")
    # Split index along the feature axis; assigned in build() once shapes are known.
    self.pivot = None

def build(self, input_shape):
    """Cache the split index: half of the last (feature) dimension."""
    feature_dim = input_shape[-1]
    self.pivot = feature_dim // 2

def call(self, xz: Tensor, conditions: Tensor = None, inverse: bool = False, **kwargs) -> (Tensor, Tensor):
    """Dispatch to the forward or inverse dual coupling.

    Parameters
    ----------
    xz : Tensor
        Input tensor (x for forward, z for inverse).
    conditions : Tensor, optional
        Conditioning tensor forwarded to both inner couplings.
    inverse : bool
        If True, run the inverse transform.
    **kwargs
        Forwarded unchanged (e.g. ``training``).
    """
    # Stale pre-change multi-line signature and training-only returns
    # (diff residue) removed; kwargs are now forwarded generically.
    if inverse:
        return self._inverse(xz, conditions=conditions, **kwargs)
    return self._forward(xz, conditions=conditions, **kwargs)

def _forward(self, x: Tensor, conditions: Tensor = None, **kwargs) -> (Tensor, Tensor):
    """Transform (x1, x2) -> (g(x1; f(x2; x1)), f(x2; x1)).

    The input is split at ``self.pivot`` along the last axis; each half is
    transformed once by swapping the roles of the two single couplings.
    """
    # Stale pre-change signature/body lines (diff residue) removed.
    x1, x2 = x[..., :self.pivot], x[..., self.pivot:]
    (z1, z2), log_det1 = self.coupling1(x1, x2, conditions=conditions, **kwargs)
    (z2, z1), log_det2 = self.coupling2(z2, z1, conditions=conditions, **kwargs)

    z = keras.ops.concatenate([z1, z2], axis=-1)
    # Jacobian of the composition: log-determinants add.
    log_det = log_det1 + log_det2

    return z, log_det

def _inverse(self, z: Tensor, conditions: any = None) -> (Tensor, Tensor):
def _inverse(self, z: Tensor, conditions: Tensor = None, **kwargs) -> (Tensor, Tensor):
""" Transform (g(x1; f(x2; x1)), f(x2; x1)) -> (x1, x2) """
z1, z2 = z[..., :self.pivot], z[..., self.pivot:]
(z2, z1), log_det2 = self.coupling2(z2, z1, conditions=conditions, inverse=True)
(x1, x2), log_det1 = self.coupling1(z1, z2, conditions=conditions, inverse=True)
(z2, z1), log_det2 = self.coupling2(z2, z1, conditions=conditions, inverse=True, **kwargs)
(x1, x2), log_det1 = self.coupling1(z1, z2, conditions=conditions, inverse=True, **kwargs)

x = keras.ops.concatenate([x1, x2], axis=-1)
log_det = log_det1 + log_det2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,30 +27,29 @@ def __init__(self, network: str = "resnet", transform: str = "affine", **kwargs)
def build(self, x1_shape, x2_shape):
    """Size the output projection to emit all transform parameters for x2.

    The projector must produce ``params_per_dim`` values for every feature
    of the transformed half, so its unit count is the product of the two.
    """
    num_features = x2_shape[-1]
    self.output_projector.units = num_features * self.transform.params_per_dim

def call(self, x1: Tensor, x2: Tensor, conditions: Tensor = None, inverse: bool = False) -> ((Tensor, Tensor), Tensor):
    """Dispatch to the forward or inverse single coupling.

    Parameters
    ----------
    x1, x2 : Tensor
        The two halves: x1 conditions the transform, x2 is transformed.
    conditions : Tensor, optional
        Extra conditioning tensor.
    inverse : bool
        If True, run the inverse transform.
    """
    # Stale pre-change signature with `conditions: any` (diff residue) removed.
    if inverse:
        return self._inverse(x1, x2, conditions=conditions)
    return self._forward(x1, x2, conditions=conditions)

def _forward(self, x1: Tensor, x2: Tensor, conditions: any = None) -> ((Tensor, Tensor), Tensor):
def _forward(self, x1: Tensor, x2: Tensor, conditions: Tensor = None) -> ((Tensor, Tensor), Tensor):
""" Transform (x1, x2) -> (x1, f(x2; x1)) """
z1 = x1
parameters = self.get_parameters(x1, conditions)
parameters = self.get_parameters(x1, conditions=conditions)
z2, log_det = self.transform(x2, parameters=parameters)

return (z1, z2), log_det

def _inverse(self, z1: Tensor, z2: Tensor, conditions: any = None) -> ((Tensor, Tensor), Tensor):
def _inverse(self, z1: Tensor, z2: Tensor, conditions: Tensor = None) -> ((Tensor, Tensor), Tensor):
""" Transform (x1, f(x2; x1)) -> (x1, x2) """
x1 = z1
parameters = self.get_parameters(x1, conditions)
parameters = self.get_parameters(x1, conditions=conditions)
x2, log_det = self.transform(z2, parameters=parameters, inverse=True)

return (x1, x2), log_det

def get_parameters(self, x, conditions: any = None) -> dict[str, Tensor]:
# TODO: pass conditions to subnet via kwarg if possible
if keras.ops.is_tensor(conditions):
def get_parameters(self, x: Tensor, conditions: Tensor = None) -> dict[str, Tensor]:
if conditions is not None:
x = keras.ops.concatenate([x, conditions], axis=-1)

parameters = self.output_projector(self.network(x))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@ def get_config(self):

return base_config | config

def compute_subnet_output_shape(self, input_shape):
    """Subnet output shape: leading dims unchanged, last dim doubled.

    The doubling matches the two parameter groups split off downstream
    (scale and shift).
    """
    *leading, last = input_shape
    return (*leading, 2 * last)

def split_parameters(self, parameters: Tensor) -> dict[str, Tensor]:
scale, shift = ops.split(parameters, 2, axis=-1)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@ class Transform(InvertibleLayer):
def params_per_dim(self) -> int:
    """Number of transform parameters required per feature dimension.

    Abstract: concrete transforms must override this.
    """
    raise NotImplementedError()

def compute_subnet_output_shape(self, input_shape):
    """Output shape the parameter subnet must produce for ``input_shape``.

    Abstract: concrete transforms must override this.
    """
    raise NotImplementedError()

def split_parameters(self, parameters: Tensor) -> dict[str, Tensor]:
    """Split the flat parameter tensor into named parameter groups.

    Abstract: concrete transforms must override this.
    """
    raise NotImplementedError()

Expand Down

0 comments on commit 80f6a98

Please sign in to comment.