diff --git a/bayesflow/networks/__init__.py b/bayesflow/networks/__init__.py index fce9b27fa..9a915572b 100644 --- a/bayesflow/networks/__init__.py +++ b/bayesflow/networks/__init__.py @@ -3,6 +3,7 @@ from .coupling_flow import CouplingFlow from .deep_set import DeepSet from .flow_matching import FlowMatching +from .free_form_flow import FreeFormFlow from .inference_network import InferenceNetwork from .mlp import MLP from .lstnet import LSTNet diff --git a/bayesflow/networks/free_form_flow/__init__.py b/bayesflow/networks/free_form_flow/__init__.py new file mode 100644 index 000000000..803280523 --- /dev/null +++ b/bayesflow/networks/free_form_flow/__init__.py @@ -0,0 +1 @@ +from .free_form_flow import FreeFormFlow diff --git a/bayesflow/networks/free_form_flow/free_form_flow.py b/bayesflow/networks/free_form_flow/free_form_flow.py new file mode 100644 index 000000000..c893d7df8 --- /dev/null +++ b/bayesflow/networks/free_form_flow/free_form_flow.py @@ -0,0 +1,183 @@ +import keras +from keras import ops +from keras.saving import register_keras_serializable as serializable + +from bayesflow.types import Tensor +from bayesflow.utils import find_network, keras_kwargs, concatenate, log_jacobian_determinant, jvp, vjp + +from ..inference_network import InferenceNetwork + + +@serializable(package="networks.free_form_flow") +class FreeFormFlow(InferenceNetwork): + """Implements a dimensionality-preserving Free-form Flow. + Incorporates ideas from [1-2]. + + [1] Draxler, F., Sorrenson, P., Zimmermann, L., Rousselot, A., & Köthe, U. (2024).F + ree-form flows: Make Any Architecture a Normalizing Flow. + In International Conference on Artificial Intelligence and Statistics. + + [2] Sorrenson, P., Draxler, F., Rousselot, A., Hummerich, S., Zimmermann, L., & + Köthe, U. (2024). Lifting Architectural Constraints of Injective Flows. + In International Conference on Learning Representations. + """ + + def __init__( + self, + beta: float = 50.0, + encoder_subnet: str | type = "mlp", + decoder_subnet: str | type = "mlp", + base_distribution: str = "normal", + hutchinson_sampling: str = "qr", + **kwargs, + ): + """Creates an instance of a Free-form Flow. + + Parameters: + ----------- + beta : float, optional, default: 50.0 + encoder_subnet : str or type, optional, default: "mlp" + A neural network type for the flow, will be instantiated using + encoder_subnet_kwargs. Will be equipped with a projector to ensure + the correct output dimension and a global skip connection. + decoder_subnet : str or type, optional, default: "mlp" + A neural network type for the flow, will be instantiated using + decoder_subnet_kwargs. Will be equipped with a projector to ensure + the correct output dimension and a global skip connection. + base_distribution : str, optional, default: "normal" + The latent distribution + hutchinson_sampling : str, optional, default: "qr + One of `["sphere", "qr"]`. Select the sampling scheme for the + vectors of the Hutchinson trace estimator. + **kwargs : dict, optional, default: {} + Additional keyword arguments + """ + super().__init__(base_distribution=base_distribution, **keras_kwargs(kwargs)) + self.encoder_subnet = find_network(encoder_subnet, **kwargs.get("encoder_subnet_kwargs", {})) + self.encoder_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros") + self.decoder_subnet = find_network(decoder_subnet, **kwargs.get("decoder_subnet_kwargs", {})) + self.decoder_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros") + + self.hutchinson_sampling = hutchinson_sampling + self.beta = beta + + self.seed_generator = keras.random.SeedGenerator() + + # noinspection PyMethodOverriding + def build(self, xz_shape, conditions_shape=None): + super().build(xz_shape) + self.encoder_projector.units = xz_shape[-1] + self.decoder_projector.units = xz_shape[-1] + + # construct input shape for subnet and subnet projector + input_shape = list(xz_shape) + + if conditions_shape is not None: + input_shape[-1] += conditions_shape[-1] + + input_shape = tuple(input_shape) + + self.encoder_subnet.build(input_shape) + self.decoder_subnet.build(input_shape) + + input_shape = self.encoder_subnet.compute_output_shape(input_shape) + self.encoder_projector.build(input_shape) + + input_shape = self.decoder_subnet.compute_output_shape(input_shape) + self.decoder_projector.build(input_shape) + + def _forward( + self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs + ) -> Tensor | tuple[Tensor, Tensor]: + if density: + if conditions is None: + # None cannot be batched, so supply as keyword argument + z, log_det = log_jacobian_determinant(x, self.encode, conditions=None, training=training, **kwargs) + else: + # conditions should be batched, supply as positional argument + z, log_det = log_jacobian_determinant(x, self.encode, conditions, training=training, **kwargs) + + log_density = self.base_distribution.log_prob(z) + log_det + return z, log_density + + z = self.encode(x, conditions, training=training, **kwargs) + return z + + def _inverse( + self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs + ) -> Tensor | tuple[Tensor, Tensor]: + if density: + if conditions is None: + # None cannot be batched, so supply as keyword argument + x, log_det = log_jacobian_determinant(z, self.decode, conditions=None, training=training, **kwargs) + else: + # conditions should be batched, supply as positional argument + x, log_det = log_jacobian_determinant(z, self.decode, conditions, training=training, **kwargs) + log_density = self.base_distribution.log_prob(z) - log_det + return x, log_density + + x = self.decode(z, conditions, training=training, **kwargs) + return x + + def encode(self, x: Tensor, conditions: Tensor = None, training: bool = False, **kwargs) -> Tensor: + if conditions is None: + inp = x + else: + inp = concatenate(x, conditions, axis=-1) + network_out = self.encoder_projector( + self.encoder_subnet(inp, training=training, **kwargs), training=training, **kwargs + ) + return network_out + x + + def decode(self, z: Tensor, conditions: Tensor = None, training: bool = False, **kwargs) -> Tensor: + if conditions is None: + inp = z + else: + inp = concatenate(z, conditions, axis=-1) + network_out = self.decoder_projector( + self.decoder_subnet(inp, training=training, **kwargs), training=training, **kwargs + ) + return network_out + z + + def _sample_v(self, x): + batch_size = ops.shape(x)[0] + total_dim = ops.shape(x)[-1] + match self.hutchinson_sampling: + case "qr": + # Use QR decomposition as described in [2] + v_raw = keras.random.normal((batch_size, total_dim, 1), dtype=ops.dtype(x), seed=self.seed_generator) + q = ops.reshape(ops.qr(v_raw)[0], ops.shape(x)) + v = q * ops.sqrt(total_dim) + case "sphere": + # Sample from sphere with radius sqrt(total_dim), as implemented in [1] + v_raw = keras.random.normal((batch_size, total_dim), dtype=ops.dtype(x), seed=self.seed_generator) + v = v_raw * ops.sqrt(total_dim) / ops.sqrt(ops.sum(v_raw**2, axis=-1, keepdims=True)) + case _: + raise ValueError(f"{self.hutchinson_sampling} is not a valid value for hutchinson_sampling.") + return v + + def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]: + base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage) + # sample random vector + v = self._sample_v(x) + + def encode(x): + return self.encode(x, conditions, training=stage == "training") + + def decode(z): + return self.decode(z, conditions, training=stage == "training") + + # VJP computation + z, vjp_fn = vjp(encode, x) + v1 = vjp_fn(v)[0] + # JVP computation + x_pred, v2 = jvp(decode, (z,), (v,)) + + # equivalent: surrogate = ops.matmul(ops.stop_gradient(v2[:, None]), v1[:, :, None])[:, 0, 0] + surrogate = ops.sum((ops.stop_gradient(v2) * v1), axis=-1) + nll = -self.base_distribution.log_prob(z) + maximum_likelihood_loss = nll - surrogate + reconstruction_loss = ops.sum((x - x_pred) ** 2, axis=-1) + loss = ops.mean(maximum_likelihood_loss + self.beta * reconstruction_loss) + + return base_metrics | {"loss": loss} diff --git a/bayesflow/utils/__init__.py b/bayesflow/utils/__init__.py index 669ad33de..78f692c0a 100644 --- a/bayesflow/utils/__init__.py +++ b/bayesflow/utils/__init__.py @@ -25,7 +25,9 @@ parse_bytes, ) from .jacobian_trace import jacobian_trace +from .jacobian import compute_jacobian, log_jacobian_determinant from .jvp import jvp +from .vjp import vjp from .optimal_transport import optimal_transport from .tensor_utils import ( expand_left, diff --git a/bayesflow/utils/jacobian.py b/bayesflow/utils/jacobian.py new file mode 100644 index 000000000..830ef6e01 --- /dev/null +++ b/bayesflow/utils/jacobian.py @@ -0,0 +1,129 @@ +from collections.abc import Callable +import keras +from keras import ops +from bayesflow.types import Tensor + +from functools import partial, wraps + + +def compute_jacobian( + x_in: Tensor, + fn: Callable, + *func_args: any, + grad_type: str = "backward", + **func_kwargs: any, +) -> tuple[Tensor, Tensor]: + """Computes the Jacobian of a function with respect to its input. + + :param x_in: The input tensor to compute the jacobian at. + Shape: (batch_size, in_dim). + :param fn: The function to compute the jacobian of, which transforms + `x` to `fn(x)` of shape (batch_size, out_dim). + :param func_args: The positional arguments to pass to the function. + func_args are batched over the first dimension. + :param grad_type: The type of gradient to use. Either 'backward' or + 'forward'. + :param func_kwargs: The keyword arguments to pass to the function. + func_kwargs are not batched. + :return: The output of the function `fn(x)` and the jacobian + of the function with respect to its input `x` of shape + (batch_size, out_dim, in_dim).""" + + def batch_wrap(fn: Callable) -> Callable: + """Add a batch dimension to each tensor argument. + + :param fn: + :return: wrapped function""" + + def deep_unsqueeze(arg): + if ops.is_tensor(arg): + return arg[None, ...] + elif isinstance(arg, dict): + return {key: deep_unsqueeze(value) for key, value in arg.items()} + elif isinstance(arg, (list, tuple)): + return [deep_unsqueeze(value) for value in arg] + raise ValueError(f"Argument cannot be batched: {arg}") + + @wraps(fn) + def wrapper(*args, **kwargs): + args = deep_unsqueeze(args) + return fn(*args, **kwargs)[0] + + return wrapper + + def double_output(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + out = fn(*args, **kwargs) + return out, out + + return wrapper + + match keras.backend.backend(): + case "torch": + import torch + from torch.func import jacrev, jacfwd, vmap + + jacfn = jacrev if grad_type == "backward" else jacfwd + with torch.inference_mode(False): + with torch.no_grad(): + fn_kwargs_prefilled = partial(fn, **func_kwargs) + fn_batch_expanded = batch_wrap(fn_kwargs_prefilled) + fn_return_val = double_output(fn_batch_expanded) + fn_jac_batched = vmap(jacfn(fn_return_val, has_aux=True)) + jac, x_out = fn_jac_batched(x_in, *func_args) + case "jax": + from jax import jacrev, jacfwd, vmap + + jacfn = jacrev if grad_type == "backward" else jacfwd + fn_kwargs_prefilled = partial(fn, **func_kwargs) + fn_batch_expanded = batch_wrap(fn_kwargs_prefilled) + fn_return_val = double_output(fn_batch_expanded) + fn_jac_batched = vmap(jacfn(fn_return_val, has_aux=True)) + jac, x_out = fn_jac_batched(x_in, *func_args) + case "tensorflow": + if grad_type == "forward": + raise NotImplementedError("For TensorFlow, only backward mode Jacobian computation is available.") + import tensorflow as tf + + with tf.GradientTape() as tape: + tape.watch(x_in) + x_out = fn(x_in, *func_args, **func_kwargs) + jac = tape.batch_jacobian(x_out, x_in) + + case _: + raise NotImplementedError(f"compute_jacobian not implemented for {keras.backend.backend()}.") + return x_out, jac + + +def log_jacobian_determinant( + x_in: Tensor, + fn: Callable, + *func_args: any, + grad_type: str = "backward", + **func_kwargs: any, +) -> tuple[Tensor, Tensor]: + """Computes the log Jacobian determinant of a function + with respect to its input. + + :param x_in: The input tensor to compute the jacobian at. + Shape: (batch_size, in_dim). + :param fn: The function to compute the jacobian of, which transforms + `x` to `fn(x)` of shape (batch_size, out_dim). + :param func_args: The positional arguments to pass to the function. + func_args are batched over the first dimension. + :param grad_type: The type of gradient to use. Either 'backward' or + 'forward'. + :param func_kwargs: The keyword arguments to pass to the function. + func_kwargs are not batched. + :return: The output of the function `fn(x)` and the log jacobian determinant + of the function with respect to its input `x` of shape + (batch_size, out_dim, in_dim).""" + + x_out, jac = compute_jacobian(x_in, fn, *func_args, grad_type=grad_type, **func_kwargs) + jac = ops.reshape( + jac, (ops.shape(x_in)[0], ops.prod(list(ops.shape(x_out)[1:])), ops.prod(list(ops.shape(x_in)[1:]))) + ) + log_det = ops.slogdet(jac)[1] + + return x_out, log_det diff --git a/bayesflow/utils/jacobian_trace/_vjp.py b/bayesflow/utils/jacobian_trace/_vjp.py deleted file mode 100644 index b7e71e494..000000000 --- a/bayesflow/utils/jacobian_trace/_vjp.py +++ /dev/null @@ -1,38 +0,0 @@ -import keras - -from bayesflow.types import Tensor - - -def _make_vjp_fn(f: callable, x: Tensor) -> (Tensor, callable): - match keras.backend.backend(): - case "jax": - import jax - - fx, _vjp_fn = jax.vjp(f, x) - - def vjp_fn(projector): - return _vjp_fn(projector)[0] - case "tensorflow": - import tensorflow as tf - - with tf.GradientTape(persistent=True) as tape: - tape.watch(x) - fx = f(x) - - def vjp_fn(projector): - return tape.gradient(fx, x, projector) - case "torch": - import torch - - x = keras.ops.copy(x) - x.requires_grad_(True) - - with torch.enable_grad(): - fx = f(x) - - def vjp_fn(projector): - return torch.autograd.grad(fx, x, projector, retain_graph=True)[0] - case other: - raise NotImplementedError(f"Cannot build a vjp function for backend '{other}'.") - - return fx, vjp_fn diff --git a/bayesflow/utils/jacobian_trace/compute_jacobian_trace.py b/bayesflow/utils/jacobian_trace/compute_jacobian_trace.py index de03baa0a..98f3d9697 100644 --- a/bayesflow/utils/jacobian_trace/compute_jacobian_trace.py +++ b/bayesflow/utils/jacobian_trace/compute_jacobian_trace.py @@ -3,15 +3,13 @@ import numpy as np from bayesflow.types import Tensor +from ..vjp import vjp -from ._vjp import _make_vjp_fn - - -def compute_jacobian_trace(f: Callable[[Tensor], Tensor], x: Tensor) -> (Tensor, Tensor): +def compute_jacobian_trace(fn: Callable[[Tensor], Tensor], x: Tensor) -> (Tensor, Tensor): """Compute the exact trace of the Jacobian matrix of f by projection on each axis. - :param f: The function to be differentiated. + :param fn: The function to be differentiated. :param x: Tensor of shape (n, ..., d) The input tensor to f. @@ -24,15 +22,15 @@ def compute_jacobian_trace(f: Callable[[Tensor], Tensor], x: Tensor) -> (Tensor, shape = keras.ops.shape(x) trace = keras.ops.zeros(shape[:-1]) - fx, vjp_fn = _make_vjp_fn(f, x) + fx, vjp_fn = vjp(fn, x) for dim in range(shape[-1]): projector = np.zeros(shape, dtype="float32") projector[..., dim] = 1.0 projector = keras.ops.convert_to_tensor(projector) - vjp = vjp_fn(projector) + vjp_value = vjp_fn(projector)[0] - trace += vjp[..., dim] + trace += vjp_value[..., dim] return fx, trace diff --git a/bayesflow/utils/jacobian_trace/estimate_jacobian_trace.py b/bayesflow/utils/jacobian_trace/estimate_jacobian_trace.py index c0a867d19..0f612f1e9 100644 --- a/bayesflow/utils/jacobian_trace/estimate_jacobian_trace.py +++ b/bayesflow/utils/jacobian_trace/estimate_jacobian_trace.py @@ -1,11 +1,12 @@ +from collections.abc import Callable import keras from bayesflow.types import Tensor -from ._vjp import _make_vjp_fn +from ..vjp import vjp -def estimate_jacobian_trace(f: callable, x: Tensor, steps: int = 1) -> (Tensor, Tensor): +def estimate_jacobian_trace(f: Callable[[Tensor], Tensor], x: Tensor, steps: int = 1) -> (Tensor, Tensor): """Estimate the trace of the Jacobian matrix of f using Hutchinson's algorithm. :param f: The function to be differentiated. @@ -25,13 +26,13 @@ def estimate_jacobian_trace(f: callable, x: Tensor, steps: int = 1) -> (Tensor, shape = keras.ops.shape(x) trace = keras.ops.zeros(shape[:-1]) - fx, vjp_fn = _make_vjp_fn(f, x) + fx, vjp_fn = vjp(f, x) for _ in range(steps): projector = keras.random.normal(shape) - vjp = vjp_fn(projector) + vjp_val = vjp_fn(projector) - trace += keras.ops.sum(vjp * projector, axis=-1) + trace += keras.ops.sum(vjp_val * projector, axis=-1) return fx, trace diff --git a/bayesflow/utils/jvp.py b/bayesflow/utils/jvp.py index d086cfdc1..1f25fb3a2 100644 --- a/bayesflow/utils/jvp.py +++ b/bayesflow/utils/jvp.py @@ -1,31 +1,42 @@ +from collections.abc import Callable import keras from bayesflow.types import Tensor -def jvp(fn: callable, primals: tuple[Tensor] | Tensor, tangents: tuple[Tensor] | Tensor): - """Compute the dot product between the Jacobian of the given function at the point given by - the input (primals) and vectors in tangents.""" +def jvp(fn: Callable, primals: Tensor | tuple[Tensor, ...], tangents: Tensor | tuple[Tensor, ...]) -> (any, Tensor): + """ + Backend-agnostic version of the Jacobian-vector product (jvp). + Compute the Jacobian-vector product of the given function at the point given by the input (primals). + + :param fn: The function to differentiate. + Signature and return value must be compatible with the vjp method of the backend in use. + + :param primals: Input tensors to `fn`. + + :param tangents: Tangent vectors to differentiate `fn` with respect to. + + :return: The output of `fn(*primals)` and the Jacobian-vector product of `fn` evaluated at `primals` with respect to + `tangents`. + """ match keras.backend.backend(): + case "jax": + import jax + + fx, _jvp = jax.jvp(fn, primals, tangents) case "torch": import torch - fn_output, _jvp = torch.autograd.functional.jvp(fn, primals, tangents) + fx, _jvp = torch.func.jvp(fn, primals, tangents) case "tensorflow": import tensorflow as tf - with tf.autodiff.ForwardAccumulator(primals=primals, tangents=tangents) as acc: - fn_output = fn(*primals) - _jvp = acc.jvp(fn_output) - case "jax": - import jax + with tf.autodiff.ForwardAccumulator(primals, tangents) as acc: + fx = fn(*primals) - fn_output, _jvp = jax.jvp( - fn, - primals, - tangents, - ) + _jvp = acc.jvp(fx) case _: raise NotImplementedError(f"JVP not implemented for backend {keras.backend.backend()}") - return fn_output, _jvp + + return fx, _jvp diff --git a/bayesflow/utils/vjp.py b/bayesflow/utils/vjp.py new file mode 100644 index 000000000..435c46334 --- /dev/null +++ b/bayesflow/utils/vjp.py @@ -0,0 +1,42 @@ +from collections.abc import Callable +import keras +from functools import partial + +from bayesflow.types import Tensor + + +def vjp(fn: Callable, *primals: Tensor) -> (any, Callable[[Tensor], tuple[Tensor, ...]]): + """ + Backend-agnostic version of the vector-Jacobian product (vjp). + Computes the vector-Jacobian product of the given function at the point given by the input (primals). + + :param fn: The function to differentiate. + Signature and return value must be compatible with the vjp method of the backend in use. + + :param primals: Input tensors to `fn`. + + :return: The output of `fn(*primals)` and a vjp function. + The vjp function takes a single tensor argument, and returns the vector-Jacobian product of this argument with + `fn` as evaluated at `primals`. + """ + match keras.backend.backend(): + case "jax": + import jax + + fx, vjp_fn = jax.vjp(fn, *primals) + case "torch": + import torch + + fx, vjp_fn = torch.func.vjp(fn, *primals) + case "tensorflow": + import tensorflow as tf + + with tf.GradientTape(persistent=True) as tape: + for p in primals: + tape.watch(p) + fx = fn(*primals) + vjp_fn = partial(tape.gradient, fx, primals) + case _: + raise NotImplementedError(f"VJP not implemented for backend {keras.backend.backend()}") + + return fx, vjp_fn diff --git a/tests/conftest.py b/tests/conftest.py index 738fbd539..d7933545e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -46,14 +46,7 @@ def feature_size(request): return request.param -@pytest.fixture(scope="function") -def flow_matching(): - from bayesflow.networks import FlowMatching - - return FlowMatching(subnet="mlp", subnet_kwargs=dict(widths=(32, 32))) - - -@pytest.fixture(params=["coupling_flow", "flow_matching"], scope="function") +@pytest.fixture(params=["coupling_flow"], scope="function") def inference_network(request): return request.getfixturevalue(request.param) diff --git a/tests/test_networks/conftest.py b/tests/test_networks/conftest.py index 6eff398f5..62796f11b 100644 --- a/tests/test_networks/conftest.py +++ b/tests/test_networks/conftest.py @@ -8,6 +8,32 @@ def deep_set(): return DeepSet() +@pytest.fixture() +def coupling_flow(): + from bayesflow.networks import CouplingFlow + + return CouplingFlow() + + +@pytest.fixture() +def flow_matching(): + from bayesflow.networks import FlowMatching + + return FlowMatching() + + +@pytest.fixture() +def free_form_flow(): + from bayesflow.networks import FreeFormFlow + + return FreeFormFlow() + + +@pytest.fixture(params=["coupling_flow", "flow_matching", "free_form_flow"], scope="function") +def inference_network(request): + return request.getfixturevalue(request.param) + + @pytest.fixture() def lst_net(): from bayesflow.networks import LSTNet