Skip to content

Commit

Permalink
Merge pull request #248 from bayesflow-org/forward-kwargs
Browse files Browse the repository at this point in the history
Forward Keyword-Arguments in Dispatch and Sampling
  • Loading branch information
vpratz authored Dec 20, 2024
2 parents 770244e + a4b6558 commit 8a870c0
Showing 1 changed file with 27 additions and 10 deletions.
37 changes: 27 additions & 10 deletions bayesflow/approximators/continuous_approximator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from bayesflow.adapters import Adapter
from bayesflow.networks import InferenceNetwork, SummaryNetwork
from bayesflow.types import Tensor
from bayesflow.utils import logging, split_arrays
from bayesflow.utils import filter_kwargs, logging, split_arrays
from .approximator import Approximator


Expand Down Expand Up @@ -141,7 +141,7 @@ def sample(
) -> dict[str, np.ndarray]:
conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
conditions = {"inference_variables": self._sample(num_samples=num_samples, **conditions)}
conditions = {"inference_variables": self._sample(num_samples=num_samples, **conditions, **kwargs)}
conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions)
conditions = self.adapter(conditions, inverse=True, strict=False, **kwargs)

Expand All @@ -154,6 +154,7 @@ def _sample(
num_samples: int,
inference_conditions: Tensor = None,
summary_variables: Tensor = None,
**kwargs,
) -> Tensor:
if self.summary_network is None:
if summary_variables is not None:
Expand All @@ -162,7 +163,9 @@ def _sample(
if summary_variables is None:
raise ValueError("Summary variables are required when a summary network is present.")

summary_outputs = self.summary_network(summary_variables)
summary_outputs = self.summary_network(
summary_variables, **filter_kwargs(kwargs, self.summary_network.call)
)

if inference_conditions is None:
inference_conditions = summary_outputs
Expand All @@ -180,18 +183,26 @@ def _sample(
else:
batch_shape = (num_samples,)

return self.inference_network.sample(batch_shape, conditions=inference_conditions)
return self.inference_network.sample(
batch_shape,
conditions=inference_conditions,
**filter_kwargs(kwargs, self.inference_network.sample),
)

def log_prob(self, data: dict[str, np.ndarray]) -> np.ndarray:
data = self.adapter(data, strict=False, stage="inference")
def log_prob(self, data: dict[str, np.ndarray], **kwargs) -> np.ndarray:
data = self.adapter(data, strict=False, stage="inference", **kwargs)
data = keras.tree.map_structure(keras.ops.convert_to_tensor, data)
log_prob = self._log_prob(**data)
log_prob = self._log_prob(**data, **kwargs)
log_prob = keras.ops.convert_to_numpy(log_prob)

return log_prob

def _log_prob(
self, inference_variables: Tensor, inference_conditions: Tensor = None, summary_variables: Tensor = None
self,
inference_variables: Tensor,
inference_conditions: Tensor = None,
summary_variables: Tensor = None,
**kwargs,
) -> Tensor:
if self.summary_network is None:
if summary_variables is not None:
Expand All @@ -200,11 +211,17 @@ def _log_prob(
if summary_variables is None:
raise ValueError("Summary variables are required when a summary network is present.")

summary_outputs = self.summary_network(summary_variables)
summary_outputs = self.summary_network(
summary_variables, **filter_kwargs(kwargs, self.summary_network.call)
)

if inference_conditions is None:
inference_conditions = summary_outputs
else:
inference_conditions = keras.ops.concatenate([inference_conditions, summary_outputs], axis=-1)

return self.inference_network.log_prob(inference_variables, conditions=inference_conditions)
return self.inference_network.log_prob(
inference_variables,
conditions=inference_conditions,
**filter_kwargs(kwargs, self.inference_network.log_prob),
)

0 comments on commit 8a870c0

Please sign in to comment.