
Add working state - fitting benchmarks
stefanradev93 committed Jun 3, 2024
1 parent 7e52075 commit 142d61f
Showing 8 changed files with 36 additions and 51 deletions.
33 changes: 5 additions & 28 deletions bayesflow/experimental/amortizers/amortizer.py
@@ -64,40 +64,17 @@ def compute_loss(self, x: dict = None, y: dict = None, y_pred: dict = None, **kwargs):

         inference_conditions = self.configure_inference_conditions(x, y_pred.get("summary_outputs"))
         inference_loss = self.inference_network.compute_loss(
-            x=(inferred_variables, inference_conditions),
-            y=y.get("inference_targets"),
-            y_pred=y_pred.get("inference_outputs")
+            x=inferred_variables,
+            conditions=inference_conditions,
+            **kwargs
         )
 
         return inference_loss + summary_loss
 
     def compute_metrics(self, x: dict, y: dict, y_pred: dict, **kwargs):
         base_metrics = super().compute_metrics(x, y, y_pred, **kwargs)
-
-        inferred_variables = self.configure_inferred_variables(x)
-        observed_variables = self.configure_observed_variables(x)
-
-        if self.summary_network:
-            summary_conditions = self.configure_summary_conditions(x)
-            summary_metrics = self.summary_network.compute_metrics(
-                x=(observed_variables, summary_conditions),
-                y=y.get("summary_targets"),
-                y_pred=y_pred.get("summary_outputs")
-            )
-        else:
-            summary_metrics = {}
-
-        inference_conditions = self.configure_inference_conditions(x, y_pred.get("summary_outputs"))
-        inference_metrics = self.inference_network.compute_metrics(
-            x=(inferred_variables, inference_conditions),
-            y=y.get("inference_targets"),
-            y_pred=y_pred.get("inference_outputs")
-        )
-
-        summary_metrics = {f"summary/{key}": value for key, value in summary_metrics.items()}
-        inference_metrics = {f"inference/{key}": value for key, value in inference_metrics.items()}
-
-        return base_metrics | summary_metrics | inference_metrics
+        #TODO - add back metrics
+        return base_metrics
 
     def sample(self, data: dict, num_samples: int, sample_summaries=False, **kwargs):

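The loss contract changes here: the inference network now receives the raw inferred variables and their conditions (plus **kwargs) instead of precomputed y/y_pred pairs, so it can compute its own objective. A minimal sketch of what such a compute_loss could amount to for a flow-based network, assuming the change-of-variables logic that appears in inference_network.py further down (the function name and mean reduction are illustrative, not the repo's API):

import keras

def flow_nll_loss(flow, x, conditions=None, **kwargs):
    # Forward pass into latent space; jacobian=True also returns log|det J|.
    z, log_det = flow(x, conditions=conditions, inverse=False, jacobian=True, **kwargs)
    # Change of variables: log p(x | c) = log p(z) + log|det J|; minimize the negative mean.
    return -keras.ops.mean(flow.base_distribution.log_prob(z) + log_det)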
2 changes: 1 addition & 1 deletion bayesflow/experimental/distributions/diagonal_normal.py
@@ -16,7 +16,7 @@ class DiagonalNormal(Distribution):
     - ``_log_unnormalized_prob`` method is used as a loss function
     - ``log_prob`` is used for density computation
     """
-    def __init__(self, mean: float = 0.0, std: float = 1.0, **kwargs):
+    def __init__(self, mean: float | Tensor = 0.0, std: float | Tensor = 1.0, **kwargs):
         super().__init__(**kwargs)
         self.mean = mean
         self.std = std
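Widening the annotation from float to float | Tensor documents that location and scale may be given per dimension. A hedged usage sketch (the import path is inferred from the file path above; broadcasting against the event shape is assumed to follow standard diagonal-Gaussian semantics):

import keras
from bayesflow.experimental.distributions.diagonal_normal import DiagonalNormal

standard = DiagonalNormal(mean=0.0, std=1.0)  # scalars broadcast over all dimensions
per_dim = DiagonalNormal(
    mean=keras.ops.zeros((4,)),      # one location per latent dimension
    std=0.5 * keras.ops.ones((4,)),  # one scale per latent dimension
)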
17 changes: 12 additions & 5 deletions bayesflow/experimental/networks/coupling_flow/coupling_flow.py
@@ -46,6 +46,7 @@ def __init__(
         use_actnorm: bool = True,
         **kwargs
     ):
+        # TODO - propagate optional keyword arguments to find_network and ResNet respectively
         super().__init__(**kwargs)
 
         self._layers = []
@@ -54,11 +55,11 @@ def __init__(
             self._layers.append(ActNorm(name=f"ActNorm{i}"))
             self._layers.append(DualCoupling(subnet, transform, name=f"DualCoupling{i}"))
             if permutation.lower() == "random":
-                self._layers.append(RandomPermutation())
+                self._layers.append(RandomPermutation(name=f"RandomPermutation{i}"))
             elif permutation.lower() == "swap":
-                self._layers.append(Swap())
+                self._layers.append(Swap(name=f"Swap{i}"))
             elif permutation.lower() == "learnable":
-                self._layers.append(OrthogonalPermutation())
+                self._layers.append(OrthogonalPermutation(name=f"OrthogonalPermutation{i}"))
 
     # noinspection PyMethodOverriding
     def build(self, xz_shape, conditions_shape=None):
@@ -77,7 +78,10 @@ def _forward(self, x: Tensor, conditions: Tensor = None, jacobian: bool = False,
         z = x
         log_det = keras.ops.zeros(keras.ops.shape(x)[:-1])
         for layer in self._layers:
-            z, det = layer(z, conditions=conditions, inverse=False, **kwargs)
+            if isinstance(layer, DualCoupling):
+                z, det = layer(z, conditions=conditions, inverse=False, **kwargs)
+            else:
+                z, det = layer(z, inverse=False, **kwargs)
             log_det += det
 
         if jacobian:
@@ -88,7 +92,10 @@ def _inverse(self, z: Tensor, conditions: Tensor = None, jacobian: bool = False,
         x = z
         log_det = keras.ops.zeros(keras.ops.shape(z)[:-1])
         for layer in reversed(self._layers):
-            x, det = layer(x, conditions=conditions, inverse=True, **kwargs)
+            if isinstance(layer, DualCoupling):
+                x, det = layer(x, conditions=conditions, inverse=True, **kwargs)
+            else:
+                x, det = layer(x, inverse=True, **kwargs)
             log_det += det
 
         if jacobian:
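The isinstance dispatch encodes a layer-type distinction: only DualCoupling transforms are conditional, while ActNorm and the permutation layers act on their input alone and would reject an unexpected conditions keyword. A self-contained restatement of the pattern, with the conditional types passed in so the sketch does not depend on the repo's import layout (the wrapper function is illustrative):

import keras

def forward(layers, x, conditions=None, conditional_types=(), **kwargs):
    z = x
    log_det = keras.ops.zeros(keras.ops.shape(x)[:-1])
    for layer in layers:
        if isinstance(layer, conditional_types):   # e.g. (DualCoupling,)
            z, det = layer(z, conditions=conditions, inverse=False, **kwargs)
        else:                                      # ActNorm / permutations
            z, det = layer(z, inverse=False, **kwargs)
        log_det += det
    return z, log_det

A duck-typing alternative (checking whether a layer's call accepts a conditions parameter) would avoid the explicit type check, at the cost of less obvious control flow.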
@@ -27,32 +27,32 @@ def __init__(self, network: str = "resnet", transform: str = "affine", **kwargs)
     def build(self, x1_shape, x2_shape):
         self.output_projector.units = self.transform.params_per_dim * x2_shape[-1]
 
-    def call(self, x1: Tensor, x2: Tensor, conditions: Tensor = None, inverse: bool = False) -> ((Tensor, Tensor), Tensor):
+    def call(self, x1: Tensor, x2: Tensor, conditions: Tensor = None, inverse: bool = False, **kwargs) -> ((Tensor, Tensor), Tensor):
         if inverse:
-            return self._inverse(x1, x2, conditions=conditions)
-        return self._forward(x1, x2, conditions=conditions)
+            return self._inverse(x1, x2, conditions=conditions, **kwargs)
+        return self._forward(x1, x2, conditions=conditions, **kwargs)
 
-    def _forward(self, x1: Tensor, x2: Tensor, conditions: Tensor = None) -> ((Tensor, Tensor), Tensor):
+    def _forward(self, x1: Tensor, x2: Tensor, conditions: Tensor = None, **kwargs) -> ((Tensor, Tensor), Tensor):
         """ Transform (x1, x2) -> (x1, f(x2; x1)) """
         z1 = x1
-        parameters = self.get_parameters(x1, conditions=conditions)
+        parameters = self.get_parameters(x1, conditions=conditions, **kwargs)
         z2, log_det = self.transform(x2, parameters=parameters)
 
         return (z1, z2), log_det
 
-    def _inverse(self, z1: Tensor, z2: Tensor, conditions: Tensor = None) -> ((Tensor, Tensor), Tensor):
+    def _inverse(self, z1: Tensor, z2: Tensor, conditions: Tensor = None, **kwargs) -> ((Tensor, Tensor), Tensor):
         """ Transform (x1, f(x2; x1)) -> (x1, x2) """
         x1 = z1
-        parameters = self.get_parameters(x1, conditions=conditions)
+        parameters = self.get_parameters(x1, conditions=conditions, **kwargs)
         x2, log_det = self.transform(z2, parameters=parameters, inverse=True)
 
         return (x1, x2), log_det
 
-    def get_parameters(self, x: Tensor, conditions: Tensor = None) -> dict[str, Tensor]:
+    def get_parameters(self, x: Tensor, conditions: Tensor = None, **kwargs) -> dict[str, Tensor]:
         if conditions is not None:
             x = keras.ops.concatenate([x, conditions], axis=-1)
 
-        parameters = self.output_projector(self.network(x))
+        parameters = self.output_projector(self.network(x, **kwargs))
         parameters = self.transform.split_parameters(parameters)
         parameters = self.transform.constrain_parameters(parameters)
 
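Threading **kwargs through call -> _forward/_inverse -> get_parameters -> network matters mostly for flags like training, which stateful sublayers inside the conditioner need to see. A self-contained Keras illustration of the effect with a toy network (not the repo's ResNet): dropout only activates when the flag actually reaches it.

import keras
import numpy as np

net = keras.Sequential([keras.layers.Dense(8), keras.layers.Dropout(0.5)])
x = np.ones((2, 4), dtype="float32")

y_eval = net(x, training=False)   # dropout inactive: output is deterministic
y_train = net(x, training=True)   # dropout active: roughly half the units are zeroed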
8 changes: 4 additions & 4 deletions bayesflow/experimental/networks/inference_network.py
@@ -31,12 +31,12 @@ def _forward(self, x: Tensor, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]:
     def _inverse(self, z: Tensor, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]:
         raise NotImplementedError
 
-    def sample(self, num_samples: int, **kwargs) -> Tensor:
+    def sample(self, num_samples: int, conditions: Tensor = None, **kwargs) -> Tensor:
         samples = self.base_distribution.sample((num_samples,))
-        return self(samples, inverse=True, jacobian=False, **kwargs)
+        return self(samples, conditions=conditions, inverse=True, jacobian=False, **kwargs)
 
-    def log_prob(self, x: Tensor, **kwargs) -> Tensor:
-        samples, log_det = self(x, inverse=False, jacobian=True, **kwargs)
+    def log_prob(self, x: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
+        samples, log_det = self(x, conditions=conditions, inverse=False, jacobian=True, **kwargs)
         log_prob = self.base_distribution.log_prob(samples)
         return log_prob + log_det
 
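sample and log_prob now take conditions explicitly: sampling draws from the base distribution and runs the inverse pass, while log_prob uses the forward pass and the change-of-variables identity log p(x | c) = log p(z) + log|det J|. A toy class satisfying the same contract, assuming nothing beyond a base distribution with sample/log_prob (all names here are illustrative, not the repo's classes):

import math
import keras

class StandardNormal:
    """Stand-in base distribution: 2-D standard Gaussian."""
    def sample(self, batch_shape):
        return keras.random.normal(batch_shape + (2,))

    def log_prob(self, z):
        return keras.ops.sum(-0.5 * z ** 2 - 0.5 * math.log(2.0 * math.pi), axis=-1)

class ShiftFlow:
    """Toy conditional flow z = x - c: a pure shift, so log|det J| = 0."""
    base_distribution = StandardNormal()

    def __call__(self, x, conditions=None, inverse=False, jacobian=False, **kwargs):
        shift = 0.0 if conditions is None else conditions
        out = x + shift if inverse else x - shift
        if jacobian:
            return out, keras.ops.zeros(keras.ops.shape(x)[:-1])
        return out

    def sample(self, num_samples, conditions=None, **kwargs):
        z = self.base_distribution.sample((num_samples,))
        return self(z, conditions=conditions, inverse=True, jacobian=False)

    def log_prob(self, x, conditions=None, **kwargs):
        z, log_det = self(x, conditions=conditions, inverse=False, jacobian=True)
        return self.base_distribution.log_prob(z) + log_det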
6 changes: 3 additions & 3 deletions bayesflow/experimental/networks/resnet/hidden_block.py
@@ -36,9 +36,9 @@ def __init__(
             self.dense = layers.SpectralNormalization(self.dense)
         self.dropout = keras.layers.Dropout(dropout_rate)
 
-    def call(self, inputs: Tensor, **kwargs):
-        x = self.dense(inputs, **kwargs)
-        x = self.dropout(x, **kwargs)
+    def call(self, inputs: Tensor, training=False):
+        x = self.dense(inputs, training=training)
+        x = self.dropout(x, training=training)
         if self.residual:
             x = x + inputs
         return self.activation_fn(x)
2 changes: 1 addition & 1 deletion bayesflow/experimental/networks/resnet/resnet.py
@@ -80,5 +80,5 @@ def __init__(
         )
 
     def call(self, inputs: Tensor, **kwargs):
-        return self.res_blocks(inputs, **kwargs)
+        return self.res_blocks(inputs, training=kwargs.get("training", False))
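Both ResNet changes replace blanket **kwargs forwarding with an explicit training flag: pushing arbitrary keyword arguments into sublayers can raise TypeError on layers that don't accept them, while Dropout genuinely needs training to toggle. A minimal sketch of the same extraction pattern (the layer composition here is illustrative):

import keras

class Block(keras.layers.Layer):
    def __init__(self, units, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(units)
        self.dropout = keras.layers.Dropout(dropout_rate)

    def call(self, inputs, **kwargs):
        # Forward only the flag sublayers understand; default to inference mode.
        training = kwargs.get("training", False)
        return self.dropout(self.dense(inputs), training=training)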

@@ -1,4 +1,5 @@
 
 
 class SetTransformer:
+    #TODO - whole module
     pass
