From 142d61f3ed0ecca1139e7d34b27560e954e997d6 Mon Sep 17 00:00:00 2001 From: stefanradev93 Date: Mon, 3 Jun 2024 16:50:30 -0400 Subject: [PATCH] Add working state - fitting benchmarks --- .../experimental/amortizers/amortizer.py | 33 +++---------------- .../distributions/diagonal_normal.py | 2 +- .../networks/coupling_flow/coupling_flow.py | 17 +++++++--- .../couplings/single_coupling.py | 18 +++++----- .../networks/inference_network.py | 8 ++--- .../networks/resnet/hidden_block.py | 6 ++-- .../experimental/networks/resnet/resnet.py | 2 +- .../set_transformer/set_transformer.py | 1 + 8 files changed, 36 insertions(+), 51 deletions(-) diff --git a/bayesflow/experimental/amortizers/amortizer.py b/bayesflow/experimental/amortizers/amortizer.py index f8cabd1e2..9962b9626 100644 --- a/bayesflow/experimental/amortizers/amortizer.py +++ b/bayesflow/experimental/amortizers/amortizer.py @@ -64,40 +64,17 @@ def compute_loss(self, x: dict = None, y: dict = None, y_pred: dict = None, **kw inference_conditions = self.configure_inference_conditions(x, y_pred.get("summary_outputs")) inference_loss = self.inference_network.compute_loss( - x=(inferred_variables, inference_conditions), - y=y.get("inference_targets"), - y_pred=y_pred.get("inference_outputs") + x=inferred_variables, + conditions=inference_conditions, + **kwargs ) return inference_loss + summary_loss def compute_metrics(self, x: dict, y: dict, y_pred: dict, **kwargs): base_metrics = super().compute_metrics(x, y, y_pred, **kwargs) - - inferred_variables = self.configure_inferred_variables(x) - observed_variables = self.configure_observed_variables(x) - - if self.summary_network: - summary_conditions = self.configure_summary_conditions(x) - summary_metrics = self.summary_network.compute_metrics( - x=(observed_variables, summary_conditions), - y=y.get("summary_targets"), - y_pred=y_pred.get("summary_outputs") - ) - else: - summary_metrics = {} - - inference_conditions = self.configure_inference_conditions(x, y_pred.get("summary_outputs")) - inference_metrics = self.inference_network.compute_metrics( - x=(inferred_variables, inference_conditions), - y=y.get("inference_targets"), - y_pred=y_pred.get("inference_outputs") - ) - - summary_metrics = {f"summary/{key}": value for key, value in summary_metrics.items()} - inference_metrics = {f"inference/{key}": value for key, value in inference_metrics.items()} - - return base_metrics | summary_metrics | inference_metrics + #TODO - add back metrics + return base_metrics def sample(self, data: dict, num_samples: int, sample_summaries=False, **kwargs): diff --git a/bayesflow/experimental/distributions/diagonal_normal.py b/bayesflow/experimental/distributions/diagonal_normal.py index 1947f4d82..f861550cd 100644 --- a/bayesflow/experimental/distributions/diagonal_normal.py +++ b/bayesflow/experimental/distributions/diagonal_normal.py @@ -16,7 +16,7 @@ class DiagonalNormal(Distribution): - ``_log_unnormalized_prob`` method is used as a loss function - ``log_prob`` is used for density computation """ - def __init__(self, mean: float = 0.0, std: float = 1.0, **kwargs): + def __init__(self, mean: float | Tensor = 0.0, std: float | Tensor = 1.0, **kwargs): super().__init__(**kwargs) self.mean = mean self.std = std diff --git a/bayesflow/experimental/networks/coupling_flow/coupling_flow.py b/bayesflow/experimental/networks/coupling_flow/coupling_flow.py index a95c17578..cc1b7354c 100644 --- a/bayesflow/experimental/networks/coupling_flow/coupling_flow.py +++ b/bayesflow/experimental/networks/coupling_flow/coupling_flow.py @@ -46,6 +46,7 @@ def __init__( use_actnorm: bool = True, **kwargs ): + # TODO - propagate optional keyword arguments to find_network and ResNet respectively super().__init__(**kwargs) self._layers = [] @@ -54,11 +55,11 @@ def __init__( self._layers.append(ActNorm(name=f"ActNorm{i}")) self._layers.append(DualCoupling(subnet, transform, name=f"DualCoupling{i}")) if permutation.lower() == "random": - self._layers.append(RandomPermutation()) + self._layers.append(RandomPermutation(name=f"RandomPermutation{i}")) elif permutation.lower() == "swap": - self._layers.append(Swap()) + self._layers.append(Swap(name=f"Swap{i}")) elif permutation.lower() == "learnable": - self._layers.append(OrthogonalPermutation()) + self._layers.append(OrthogonalPermutation(name=f"OrthogonalPermutation{i}")) # noinspection PyMethodOverriding def build(self, xz_shape, conditions_shape=None): @@ -77,7 +78,10 @@ def _forward(self, x: Tensor, conditions: Tensor = None, jacobian: bool = False, z = x log_det = keras.ops.zeros(keras.ops.shape(x)[:-1]) for layer in self._layers: - z, det = layer(z, conditions=conditions, inverse=False, **kwargs) + if isinstance(layer, DualCoupling): + z, det = layer(z, conditions=conditions, inverse=False, **kwargs) + else: + z, det = layer(z, inverse=False, **kwargs) log_det += det if jacobian: @@ -88,7 +92,10 @@ def _inverse(self, z: Tensor, conditions: Tensor = None, jacobian: bool = False, x = z log_det = keras.ops.zeros(keras.ops.shape(z)[:-1]) for layer in reversed(self._layers): - x, det = layer(x, conditions=conditions, inverse=True, **kwargs) + if isinstance(layer, DualCoupling): + x, det = layer(x, conditions=conditions, inverse=True, **kwargs) + else: + x, det = layer(x, inverse=True, **kwargs) log_det += det if jacobian: diff --git a/bayesflow/experimental/networks/coupling_flow/couplings/single_coupling.py b/bayesflow/experimental/networks/coupling_flow/couplings/single_coupling.py index 862ac45bb..adde6dd31 100644 --- a/bayesflow/experimental/networks/coupling_flow/couplings/single_coupling.py +++ b/bayesflow/experimental/networks/coupling_flow/couplings/single_coupling.py @@ -27,32 +27,32 @@ def __init__(self, network: str = "resnet", transform: str = "affine", **kwargs) def build(self, x1_shape, x2_shape): self.output_projector.units = self.transform.params_per_dim * x2_shape[-1] - def call(self, x1: Tensor, x2: Tensor, conditions: Tensor = None, inverse: bool = False) -> ((Tensor, Tensor), Tensor): + def call(self, x1: Tensor, x2: Tensor, conditions: Tensor = None, inverse: bool = False, **kwargs) -> ((Tensor, Tensor), Tensor): if inverse: - return self._inverse(x1, x2, conditions=conditions) - return self._forward(x1, x2, conditions=conditions) + return self._inverse(x1, x2, conditions=conditions, **kwargs) + return self._forward(x1, x2, conditions=conditions, **kwargs) - def _forward(self, x1: Tensor, x2: Tensor, conditions: Tensor = None) -> ((Tensor, Tensor), Tensor): + def _forward(self, x1: Tensor, x2: Tensor, conditions: Tensor = None, **kwargs) -> ((Tensor, Tensor), Tensor): """ Transform (x1, x2) -> (x1, f(x2; x1)) """ z1 = x1 - parameters = self.get_parameters(x1, conditions=conditions) + parameters = self.get_parameters(x1, conditions=conditions, **kwargs) z2, log_det = self.transform(x2, parameters=parameters) return (z1, z2), log_det - def _inverse(self, z1: Tensor, z2: Tensor, conditions: Tensor = None) -> ((Tensor, Tensor), Tensor): + def _inverse(self, z1: Tensor, z2: Tensor, conditions: Tensor = None, **kwargs) -> ((Tensor, Tensor), Tensor): """ Transform (x1, f(x2; x1)) -> (x1, x2) """ x1 = z1 - parameters = self.get_parameters(x1, conditions=conditions) + parameters = self.get_parameters(x1, conditions=conditions, **kwargs) x2, log_det = self.transform(z2, parameters=parameters, inverse=True) return (x1, x2), log_det - def get_parameters(self, x: Tensor, conditions: Tensor = None) -> dict[str, Tensor]: + def get_parameters(self, x: Tensor, conditions: Tensor = None, **kwargs) -> dict[str, Tensor]: if conditions is not None: x = keras.ops.concatenate([x, conditions], axis=-1) - parameters = self.output_projector(self.network(x)) + parameters = self.output_projector(self.network(x, **kwargs)) parameters = self.transform.split_parameters(parameters) parameters = self.transform.constrain_parameters(parameters) diff --git a/bayesflow/experimental/networks/inference_network.py b/bayesflow/experimental/networks/inference_network.py index 3f29a7186..7a0667cbe 100644 --- a/bayesflow/experimental/networks/inference_network.py +++ b/bayesflow/experimental/networks/inference_network.py @@ -31,12 +31,12 @@ def _forward(self, x: Tensor, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]: def _inverse(self, z: Tensor, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]: raise NotImplementedError - def sample(self, num_samples: int, **kwargs) -> Tensor: + def sample(self, num_samples: int, conditions: Tensor = None, **kwargs) -> Tensor: samples = self.base_distribution.sample((num_samples,)) - return self(samples, inverse=True, jacobian=False, **kwargs) + return self(samples, conditions=conditions, inverse=True, jacobian=False, **kwargs) - def log_prob(self, x: Tensor, **kwargs) -> Tensor: - samples, log_det = self(x, inverse=False, jacobian=True, **kwargs) + def log_prob(self, x: Tensor, conditions: Tensor = None, **kwargs) -> Tensor: + samples, log_det = self(x, conditions=conditions, inverse=False, jacobian=True, **kwargs) log_prob = self.base_distribution.log_prob(samples) return log_prob + log_det diff --git a/bayesflow/experimental/networks/resnet/hidden_block.py b/bayesflow/experimental/networks/resnet/hidden_block.py index 89780fa64..da52eca23 100644 --- a/bayesflow/experimental/networks/resnet/hidden_block.py +++ b/bayesflow/experimental/networks/resnet/hidden_block.py @@ -36,9 +36,9 @@ def __init__( self.dense = layers.SpectralNormalization(self.dense) self.dropout = keras.layers.Dropout(dropout_rate) - def call(self, inputs: Tensor, **kwargs): - x = self.dense(inputs, **kwargs) - x = self.dropout(x, **kwargs) + def call(self, inputs: Tensor, training=False): + x = self.dense(inputs, training=training) + x = self.dropout(x, training=training) if self.residual: x = x + inputs return self.activation_fn(x) diff --git a/bayesflow/experimental/networks/resnet/resnet.py b/bayesflow/experimental/networks/resnet/resnet.py index b20a7acc9..4cc467536 100644 --- a/bayesflow/experimental/networks/resnet/resnet.py +++ b/bayesflow/experimental/networks/resnet/resnet.py @@ -80,5 +80,5 @@ def __init__( ) def call(self, inputs: Tensor, **kwargs): - return self.res_blocks(inputs, **kwargs) + return self.res_blocks(inputs, training=kwargs.get("training", False)) diff --git a/bayesflow/experimental/networks/set_transformer/set_transformer.py b/bayesflow/experimental/networks/set_transformer/set_transformer.py index 530ba664f..0d8e7b4a5 100644 --- a/bayesflow/experimental/networks/set_transformer/set_transformer.py +++ b/bayesflow/experimental/networks/set_transformer/set_transformer.py @@ -1,4 +1,5 @@ class SetTransformer: + #TODO - whole module pass