Add free-form flows as inference networks (#251)
* feat: add free-form flows as inference networks

* implements the fff loss
* still missing: calculation of the log probability

* fff: add log jacobian determinant computation

* util: make vjp globally accessible

Change `torch.autograd.functional.vjp` to `torch.func.vjp`, as the former
implementation broke gradient flow. The new call also uses the same API as
JAX, making the code easier to parse.

* utils: change autograd backend for torch jvp

Change from `torch.autograd.functional.jvp` to
`torch.func.jvp`, as recommended in the documentation.
https://pytorch.org/docs/stable/generated/torch.autograd.functional.jvp.html

Using `autograd.functional` seems to break the gradient flow, while `torch.func`
does not cause problems in this regard; a short illustrative sketch follows the
changed-files summary below.

* fff: use vjp and jvp from utils

* improve docs and type hints

* fix vjp call in fff

* add fff to tests, remove flow matching from global tests

* fix default kwargs for fff subnets

* remove double source attribution

* improve type hints

* adjust batch_wrap to handle non-iterable arguments

* fff: handle conditions=None

---------

Co-authored-by: LarsKue <[email protected]>
vpratz and LarsKue authored Dec 2, 2024
1 parent 7f37aef commit 0537f2a
Showing 12 changed files with 423 additions and 74 deletions.
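
As a minimal, illustrative sketch of the API change described in the commit message (not part of this commit), the two PyTorch calls differ as follows; the function f is a stand-in:

import torch

def f(x):
    return x ** 2

x = torch.randn(3, requires_grad=True)
v = torch.ones(3)

# Old call: returns (output, v^T J) directly; by default it does not build a
# graph for the result, so gradients through the product can be cut off.
out_old, vjp_old = torch.autograd.functional.vjp(f, x, v)

# New call: mirrors jax.vjp by returning the output and a closure that
# evaluates v^T J and composes with the surrounding autograd graph.
out_new, vjp_fn = torch.func.vjp(f, x)
(vjp_new,) = vjp_fn(v)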
1 change: 1 addition & 0 deletions bayesflow/networks/__init__.py
@@ -3,6 +3,7 @@
from .coupling_flow import CouplingFlow
from .deep_set import DeepSet
from .flow_matching import FlowMatching
from .free_form_flow import FreeFormFlow
from .inference_network import InferenceNetwork
from .mlp import MLP
from .lstnet import LSTNet
1 change: 1 addition & 0 deletions bayesflow/networks/free_form_flow/__init__.py
@@ -0,0 +1 @@
from .free_form_flow import FreeFormFlow
183 changes: 183 additions & 0 deletions bayesflow/networks/free_form_flow/free_form_flow.py
@@ -0,0 +1,183 @@
import keras
from keras import ops
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils import find_network, keras_kwargs, concatenate, log_jacobian_determinant, jvp, vjp

from ..inference_network import InferenceNetwork


@serializable(package="networks.free_form_flow")
class FreeFormFlow(InferenceNetwork):
"""Implements a dimensionality-preserving Free-form Flow.
Incorporates ideas from [1-2].
[1] Draxler, F., Sorrenson, P., Zimmermann, L., Rousselot, A., & Köthe, U. (2024).
Free-form Flows: Make Any Architecture a Normalizing Flow.
In International Conference on Artificial Intelligence and Statistics.
[2] Sorrenson, P., Draxler, F., Rousselot, A., Hummerich, S., Zimmermann, L., &
Köthe, U. (2024). Lifting Architectural Constraints of Injective Flows.
In International Conference on Learning Representations.
"""

def __init__(
self,
beta: float = 50.0,
encoder_subnet: str | type = "mlp",
decoder_subnet: str | type = "mlp",
base_distribution: str = "normal",
hutchinson_sampling: str = "qr",
**kwargs,
):
"""Creates an instance of a Free-form Flow.
Parameters
----------
beta : float, optional, default: 50.0
Weight of the reconstruction loss relative to the maximum likelihood loss.
encoder_subnet : str or type, optional, default: "mlp"
A neural network type for the flow, will be instantiated using
encoder_subnet_kwargs. Will be equipped with a projector to ensure
the correct output dimension and a global skip connection.
decoder_subnet : str or type, optional, default: "mlp"
A neural network type for the flow, will be instantiated using
decoder_subnet_kwargs. Will be equipped with a projector to ensure
the correct output dimension and a global skip connection.
base_distribution : str, optional, default: "normal"
The latent distribution
hutchinson_sampling : str, optional, default: "qr"
One of `["sphere", "qr"]`. Select the sampling scheme for the
vectors of the Hutchinson trace estimator.
**kwargs : dict, optional, default: {}
Additional keyword arguments
"""
super().__init__(base_distribution=base_distribution, **keras_kwargs(kwargs))
self.encoder_subnet = find_network(encoder_subnet, **kwargs.get("encoder_subnet_kwargs", {}))
self.encoder_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros")
self.decoder_subnet = find_network(decoder_subnet, **kwargs.get("decoder_subnet_kwargs", {}))
self.decoder_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros")

self.hutchinson_sampling = hutchinson_sampling
self.beta = beta

self.seed_generator = keras.random.SeedGenerator()

# noinspection PyMethodOverriding
def build(self, xz_shape, conditions_shape=None):
super().build(xz_shape)
self.encoder_projector.units = xz_shape[-1]
self.decoder_projector.units = xz_shape[-1]

# construct input shape for subnet and subnet projector
input_shape = list(xz_shape)

if conditions_shape is not None:
input_shape[-1] += conditions_shape[-1]

input_shape = tuple(input_shape)

self.encoder_subnet.build(input_shape)
self.decoder_subnet.build(input_shape)

input_shape = self.encoder_subnet.compute_output_shape(input_shape)
self.encoder_projector.build(input_shape)

input_shape = self.decoder_subnet.compute_output_shape(input_shape)
self.decoder_projector.build(input_shape)

def _forward(
self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
) -> Tensor | tuple[Tensor, Tensor]:
if density:
if conditions is None:
# None cannot be batched, so supply as keyword argument
z, log_det = log_jacobian_determinant(x, self.encode, conditions=None, training=training, **kwargs)
else:
# conditions should be batched, supply as positional argument
z, log_det = log_jacobian_determinant(x, self.encode, conditions, training=training, **kwargs)

log_density = self.base_distribution.log_prob(z) + log_det
return z, log_density

z = self.encode(x, conditions, training=training, **kwargs)
return z

def _inverse(
self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
) -> Tensor | tuple[Tensor, Tensor]:
if density:
if conditions is None:
# None cannot be batched, so supply as keyword argument
x, log_det = log_jacobian_determinant(z, self.decode, conditions=None, training=training, **kwargs)
else:
# conditions should be batched, supply as positional argument
x, log_det = log_jacobian_determinant(z, self.decode, conditions, training=training, **kwargs)
log_density = self.base_distribution.log_prob(z) - log_det
return x, log_density

x = self.decode(z, conditions, training=training, **kwargs)
return x

def encode(self, x: Tensor, conditions: Tensor = None, training: bool = False, **kwargs) -> Tensor:
if conditions is None:
inp = x
else:
inp = concatenate(x, conditions, axis=-1)
network_out = self.encoder_projector(
self.encoder_subnet(inp, training=training, **kwargs), training=training, **kwargs
)
return network_out + x

def decode(self, z: Tensor, conditions: Tensor = None, training: bool = False, **kwargs) -> Tensor:
if conditions is None:
inp = z
else:
inp = concatenate(z, conditions, axis=-1)
network_out = self.decoder_projector(
self.decoder_subnet(inp, training=training, **kwargs), training=training, **kwargs
)
return network_out + z

def _sample_v(self, x):
batch_size = ops.shape(x)[0]
total_dim = ops.shape(x)[-1]
match self.hutchinson_sampling:
case "qr":
# Use QR decomposition as described in [2]
v_raw = keras.random.normal((batch_size, total_dim, 1), dtype=ops.dtype(x), seed=self.seed_generator)
q = ops.reshape(ops.qr(v_raw)[0], ops.shape(x))
v = q * ops.sqrt(total_dim)
case "sphere":
# Sample from sphere with radius sqrt(total_dim), as implemented in [1]
v_raw = keras.random.normal((batch_size, total_dim), dtype=ops.dtype(x), seed=self.seed_generator)
v = v_raw * ops.sqrt(total_dim) / ops.sqrt(ops.sum(v_raw**2, axis=-1, keepdims=True))
case _:
raise ValueError(f"{self.hutchinson_sampling} is not a valid value for hutchinson_sampling.")
return v

def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
# sample random vector
v = self._sample_v(x)

def encode(x):
return self.encode(x, conditions, training=stage == "training")

def decode(z):
return self.decode(z, conditions, training=stage == "training")

# VJP computation
z, vjp_fn = vjp(encode, x)
v1 = vjp_fn(v)[0]
# JVP computation
x_pred, v2 = jvp(decode, (z,), (v,))

# equivalent: surrogate = ops.matmul(ops.stop_gradient(v2[:, None]), v1[:, :, None])[:, 0, 0]
surrogate = ops.sum((ops.stop_gradient(v2) * v1), axis=-1)
nll = -self.base_distribution.log_prob(z)
maximum_likelihood_loss = nll - surrogate
reconstruction_loss = ops.sum((x - x_pred) ** 2, axis=-1)
loss = ops.mean(maximum_likelihood_loss + self.beta * reconstruction_loss)

return base_metrics | {"loss": loss}
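
A minimal usage sketch of the new network (not part of the diff); the shapes, the explicit build call, and the random inputs are illustrative assumptions, since the network is normally built and trained through a BayesFlow approximator:

import keras
from bayesflow.networks import FreeFormFlow

# Illustrative shapes: 4-dimensional targets, 6-dimensional conditions, batch size 8.
flow = FreeFormFlow(beta=50.0, encoder_subnet="mlp", decoder_subnet="mlp")
flow.build(xz_shape=(8, 4), conditions_shape=(8, 6))

x = keras.random.normal((8, 4))
conditions = keras.random.normal((8, 6))

z = flow.encode(x, conditions)                 # encoder pass with global skip connection
x_rec = flow.decode(z, conditions)             # approximate inverse via the decoder
metrics = flow.compute_metrics(x, conditions)  # negative log-likelihood surrogate + beta * reconstruction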
2 changes: 2 additions & 0 deletions bayesflow/utils/__init__.py
@@ -25,7 +25,9 @@
parse_bytes,
)
from .jacobian_trace import jacobian_trace
from .jacobian import compute_jacobian, log_jacobian_determinant
from .jvp import jvp
from .vjp import vjp
from .optimal_transport import optimal_transport
from .tensor_utils import (
expand_left,
129 changes: 129 additions & 0 deletions bayesflow/utils/jacobian.py
@@ -0,0 +1,129 @@
from collections.abc import Callable
import keras
from keras import ops
from bayesflow.types import Tensor

from functools import partial, wraps


def compute_jacobian(
x_in: Tensor,
fn: Callable,
*func_args: any,
grad_type: str = "backward",
**func_kwargs: any,
) -> tuple[Tensor, Tensor]:
"""Computes the Jacobian of a function with respect to its input.
:param x_in: The input tensor to compute the jacobian at.
Shape: (batch_size, in_dim).
:param fn: The function to compute the jacobian of, which transforms
`x` to `fn(x)` of shape (batch_size, out_dim).
:param func_args: The positional arguments to pass to the function.
func_args are batched over the first dimension.
:param grad_type: The type of gradient to use. Either 'backward' or
'forward'.
:param func_kwargs: The keyword arguments to pass to the function.
func_kwargs are not batched.
:return: The output of the function `fn(x)` and the jacobian
of the function with respect to its input `x` of shape
(batch_size, out_dim, in_dim)."""

def batch_wrap(fn: Callable) -> Callable:
"""Add a batch dimension to each tensor argument.
:param fn:
:return: wrapped function"""

def deep_unsqueeze(arg):
if ops.is_tensor(arg):
return arg[None, ...]
elif isinstance(arg, dict):
return {key: deep_unsqueeze(value) for key, value in arg.items()}
elif isinstance(arg, (list, tuple)):
return [deep_unsqueeze(value) for value in arg]
raise ValueError(f"Argument cannot be batched: {arg}")

@wraps(fn)
def wrapper(*args, **kwargs):
args = deep_unsqueeze(args)
return fn(*args, **kwargs)[0]

return wrapper

def double_output(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
out = fn(*args, **kwargs)
return out, out

return wrapper

match keras.backend.backend():
case "torch":
import torch
from torch.func import jacrev, jacfwd, vmap

jacfn = jacrev if grad_type == "backward" else jacfwd
with torch.inference_mode(False):
with torch.no_grad():
fn_kwargs_prefilled = partial(fn, **func_kwargs)
fn_batch_expanded = batch_wrap(fn_kwargs_prefilled)
fn_return_val = double_output(fn_batch_expanded)
fn_jac_batched = vmap(jacfn(fn_return_val, has_aux=True))
jac, x_out = fn_jac_batched(x_in, *func_args)
case "jax":
from jax import jacrev, jacfwd, vmap

jacfn = jacrev if grad_type == "backward" else jacfwd
fn_kwargs_prefilled = partial(fn, **func_kwargs)
fn_batch_expanded = batch_wrap(fn_kwargs_prefilled)
fn_return_val = double_output(fn_batch_expanded)
fn_jac_batched = vmap(jacfn(fn_return_val, has_aux=True))
jac, x_out = fn_jac_batched(x_in, *func_args)
case "tensorflow":
if grad_type == "forward":
raise NotImplementedError("For TensorFlow, only backward mode Jacobian computation is available.")
import tensorflow as tf

with tf.GradientTape() as tape:
tape.watch(x_in)
x_out = fn(x_in, *func_args, **func_kwargs)
jac = tape.batch_jacobian(x_out, x_in)

case _:
raise NotImplementedError(f"compute_jacobian not implemented for {keras.backend.backend()}.")
return x_out, jac


def log_jacobian_determinant(
x_in: Tensor,
fn: Callable,
*func_args: any,
grad_type: str = "backward",
**func_kwargs: any,
) -> tuple[Tensor, Tensor]:
"""Computes the log Jacobian determinant of a function
with respect to its input.
:param x_in: The input tensor to compute the jacobian at.
Shape: (batch_size, in_dim).
:param fn: The function to compute the jacobian of, which transforms
`x` to `fn(x)` of shape (batch_size, out_dim).
:param func_args: The positional arguments to pass to the function.
func_args are batched over the first dimension.
:param grad_type: The type of gradient to use. Either 'backward' or
'forward'.
:param func_kwargs: The keyword arguments to pass to the function.
func_kwargs are not batched.
:return: The output of the function `fn(x)` and the log Jacobian determinant
of the function with respect to its input `x`, of shape
(batch_size,)."""

x_out, jac = compute_jacobian(x_in, fn, *func_args, grad_type=grad_type, **func_kwargs)
jac = ops.reshape(
jac, (ops.shape(x_in)[0], ops.prod(list(ops.shape(x_out)[1:])), ops.prod(list(ops.shape(x_in)[1:])))
)
log_det = ops.slogdet(jac)[1]

return x_out, log_det
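
A brief usage sketch of the new utility (not part of the diff), assuming a simple element-wise scaling function so the expected result is easy to verify:

from keras import ops
from bayesflow.utils import log_jacobian_determinant

def fn(x):
    # Scaling by 3 has a diagonal Jacobian, so log|det J| = in_dim * log(3).
    return 3.0 * x

x = ops.ones((2, 5))  # (batch_size, in_dim)
x_out, log_det = log_jacobian_determinant(x, fn)
# x_out has shape (2, 5); log_det has shape (2,), each entry approximately 5 * log(3).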
38 changes: 0 additions & 38 deletions bayesflow/utils/jacobian_trace/_vjp.py

This file was deleted.

