Add free-form flows as inference networks (#251)
* feat: add free-form flows as inference networks
* implements the fff loss
* still missing: calculation of the log probability
* fff: add log jacobian determinant computation
* util: make vjp globally accessible
  Change `torch.autograd.functional.vjp` to `torch.func.vjp` as the former implementation broke gradient flow. It then also uses the same API as Jax, making the code easier to parse.
* utils: change autograd backend for torch jvp
  Change from `torch.autograd.functional.jvp` to `torch.func.jvp`, as recommended in the documentation: https://pytorch.org/docs/stable/generated/torch.autograd.functional.jvp.html
  Using autograd.functional seems to break the gradient flow, while `func` does not produce problems in this regard.
* fff: use vjp and jvp from utils
* improve docs and type hints
* fix vjp call in fff
* add fff to tests, remove flow matching from global tests
* fix default kwargs for fff subnets
* remove double source attribution
* improve type hints
* adjust batch_wrap to handle non-iterable arguments
* fff: handle conditions=None

---------

Co-authored-by: LarsKue <[email protected]>
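The backend change described above matters because the free-form flow loss differentiates through the results of `vjp` and `jvp`. The following minimal sketch is not part of the diff; the function `f` and the tensor shapes are purely illustrative. It only shows the `torch.func` API the commit switches to, which the wrappers in `bayesflow.utils` mirror:

import torch
from torch.func import jvp, vjp


def f(x):
    return x ** 3  # any differentiable map


x = torch.randn(4, requires_grad=True)
v = torch.randn(4)

# Reverse mode: vjp returns f(x) and a closure computing v^T J
y, vjp_fn = vjp(f, x)
vT_J = vjp_fn(v)[0]

# Forward mode: jvp returns f(x) and J v in a single pass
y_fwd, J_v = jvp(f, (x,), (v,))

# Unlike torch.autograd.functional, these results remain connected to the
# autograd graph, so a loss built from them (mirroring the FFF surrogate)
# can be backpropagated.
loss = (vT_J * J_v.detach()).sum()
loss.backward()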
Showing 12 changed files with 423 additions and 74 deletions.
@@ -0,0 +1 @@
from .free_form_flow import FreeFormFlow
@@ -0,0 +1,183 @@
import keras
from keras import ops
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils import find_network, keras_kwargs, concatenate, log_jacobian_determinant, jvp, vjp

from ..inference_network import InferenceNetwork


@serializable(package="networks.free_form_flow")
class FreeFormFlow(InferenceNetwork):
"""Implements a dimensionality-preserving Free-form Flow. | ||
Incorporates ideas from [1-2]. | ||
[1] Draxler, F., Sorrenson, P., Zimmermann, L., Rousselot, A., & Köthe, U. (2024).F | ||
ree-form flows: Make Any Architecture a Normalizing Flow. | ||
In International Conference on Artificial Intelligence and Statistics. | ||
[2] Sorrenson, P., Draxler, F., Rousselot, A., Hummerich, S., Zimmermann, L., & | ||
Köthe, U. (2024). Lifting Architectural Constraints of Injective Flows. | ||
In International Conference on Learning Representations. | ||
""" | ||
|
||
    def __init__(
        self,
        beta: float = 50.0,
        encoder_subnet: str | type = "mlp",
        decoder_subnet: str | type = "mlp",
        base_distribution: str = "normal",
        hutchinson_sampling: str = "qr",
        **kwargs,
    ):
        """Creates an instance of a Free-form Flow.

        Parameters
        ----------
        beta : float, optional, default: 50.0
            Weight of the reconstruction error relative to the maximum
            likelihood term in the total loss (see ``compute_metrics``).
        encoder_subnet : str or type, optional, default: "mlp"
            A neural network type for the flow, will be instantiated using
            encoder_subnet_kwargs. Will be equipped with a projector to ensure
            the correct output dimension and a global skip connection.
        decoder_subnet : str or type, optional, default: "mlp"
            A neural network type for the flow, will be instantiated using
            decoder_subnet_kwargs. Will be equipped with a projector to ensure
            the correct output dimension and a global skip connection.
        base_distribution : str, optional, default: "normal"
            The latent distribution
        hutchinson_sampling : str, optional, default: "qr"
            One of `["sphere", "qr"]`. Select the sampling scheme for the
            vectors of the Hutchinson trace estimator.
        **kwargs : dict, optional, default: {}
            Additional keyword arguments
        """
        super().__init__(base_distribution=base_distribution, **keras_kwargs(kwargs))
        self.encoder_subnet = find_network(encoder_subnet, **kwargs.get("encoder_subnet_kwargs", {}))
        self.encoder_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros")
        self.decoder_subnet = find_network(decoder_subnet, **kwargs.get("decoder_subnet_kwargs", {}))
        self.decoder_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros")

        self.hutchinson_sampling = hutchinson_sampling
        self.beta = beta

        self.seed_generator = keras.random.SeedGenerator()

    # noinspection PyMethodOverriding
    def build(self, xz_shape, conditions_shape=None):
        super().build(xz_shape)
        self.encoder_projector.units = xz_shape[-1]
        self.decoder_projector.units = xz_shape[-1]

        # construct input shape for subnet and subnet projector
        input_shape = list(xz_shape)

        if conditions_shape is not None:
            input_shape[-1] += conditions_shape[-1]

        input_shape = tuple(input_shape)

        self.encoder_subnet.build(input_shape)
        self.decoder_subnet.build(input_shape)

        input_shape = self.encoder_subnet.compute_output_shape(input_shape)
        self.encoder_projector.build(input_shape)

        input_shape = self.decoder_subnet.compute_output_shape(input_shape)
        self.decoder_projector.build(input_shape)

    def _forward(
        self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
    ) -> Tensor | tuple[Tensor, Tensor]:
        if density:
            if conditions is None:
                # None cannot be batched, so supply as keyword argument
                z, log_det = log_jacobian_determinant(x, self.encode, conditions=None, training=training, **kwargs)
            else:
                # conditions should be batched, supply as positional argument
                z, log_det = log_jacobian_determinant(x, self.encode, conditions, training=training, **kwargs)

            log_density = self.base_distribution.log_prob(z) + log_det
            return z, log_density

        z = self.encode(x, conditions, training=training, **kwargs)
        return z

    def _inverse(
        self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
    ) -> Tensor | tuple[Tensor, Tensor]:
        if density:
            if conditions is None:
                # None cannot be batched, so supply as keyword argument
                x, log_det = log_jacobian_determinant(z, self.decode, conditions=None, training=training, **kwargs)
            else:
                # conditions should be batched, supply as positional argument
                x, log_det = log_jacobian_determinant(z, self.decode, conditions, training=training, **kwargs)
            log_density = self.base_distribution.log_prob(z) - log_det
            return x, log_density

        x = self.decode(z, conditions, training=training, **kwargs)
        return x

    def encode(self, x: Tensor, conditions: Tensor = None, training: bool = False, **kwargs) -> Tensor:
        if conditions is None:
            inp = x
        else:
            inp = concatenate(x, conditions, axis=-1)
        network_out = self.encoder_projector(
            self.encoder_subnet(inp, training=training, **kwargs), training=training, **kwargs
        )
        return network_out + x

    def decode(self, z: Tensor, conditions: Tensor = None, training: bool = False, **kwargs) -> Tensor:
        if conditions is None:
            inp = z
        else:
            inp = concatenate(z, conditions, axis=-1)
        network_out = self.decoder_projector(
            self.decoder_subnet(inp, training=training, **kwargs), training=training, **kwargs
        )
        return network_out + z

    def _sample_v(self, x):
        batch_size = ops.shape(x)[0]
        total_dim = ops.shape(x)[-1]
        match self.hutchinson_sampling:
            case "qr":
                # Use QR decomposition as described in [2]
                v_raw = keras.random.normal((batch_size, total_dim, 1), dtype=ops.dtype(x), seed=self.seed_generator)
                q = ops.reshape(ops.qr(v_raw)[0], ops.shape(x))
                v = q * ops.sqrt(total_dim)
            case "sphere":
                # Sample from sphere with radius sqrt(total_dim), as implemented in [1]
                v_raw = keras.random.normal((batch_size, total_dim), dtype=ops.dtype(x), seed=self.seed_generator)
                v = v_raw * ops.sqrt(total_dim) / ops.sqrt(ops.sum(v_raw**2, axis=-1, keepdims=True))
            case _:
                raise ValueError(f"{self.hutchinson_sampling} is not a valid value for hutchinson_sampling.")
        return v

    def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
        base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
        # sample random vector
        v = self._sample_v(x)

        def encode(x):
            return self.encode(x, conditions, training=stage == "training")

        def decode(z):
            return self.decode(z, conditions, training=stage == "training")

        # VJP computation
        z, vjp_fn = vjp(encode, x)
        v1 = vjp_fn(v)[0]
        # JVP computation
        x_pred, v2 = jvp(decode, (z,), (v,))

        # equivalent: surrogate = ops.matmul(ops.stop_gradient(v2[:, None]), v1[:, :, None])[:, 0, 0]
        surrogate = ops.sum((ops.stop_gradient(v2) * v1), axis=-1)
        nll = -self.base_distribution.log_prob(z)
        maximum_likelihood_loss = nll - surrogate
        reconstruction_loss = ops.sum((x - x_pred) ** 2, axis=-1)
        loss = ops.mean(maximum_likelihood_loss + self.beta * reconstruction_loss)

        return base_metrics | {"loss": loss}
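For orientation, a short usage sketch of the class above. It is not part of the diff; the shapes, the direct call to `build`, and the behavior inherited from `InferenceNetwork` are assumptions, since the surrounding BayesFlow workflow is not shown in this commit:

import keras

# Hypothetical setup: 4-dimensional inference targets, 6-dimensional conditions.
flow = FreeFormFlow(beta=50.0, encoder_subnet="mlp", decoder_subnet="mlp", hutchinson_sampling="qr")
flow.build(xz_shape=(None, 4), conditions_shape=(None, 6))

x = keras.random.normal((32, 4))
conditions = keras.random.normal((32, 6))

# Encoder and decoder are free-form networks with a global skip connection;
# decode(encode(x)) only approximately reconstructs x, which the beta-weighted
# reconstruction term of the loss penalizes.
z = flow.encode(x, conditions)
x_rec = flow.decode(z, conditions)

# The training loss combines the maximum likelihood surrogate (Hutchinson-style
# log-determinant estimate) with the reconstruction error.
metrics = flow.compute_metrics(x, conditions=conditions, stage="training")
print(metrics["loss"])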
@@ -0,0 +1,129 @@
from collections.abc import Callable
import keras
from keras import ops
from bayesflow.types import Tensor

from functools import partial, wraps


def compute_jacobian(
    x_in: Tensor,
    fn: Callable,
    *func_args: any,
    grad_type: str = "backward",
    **func_kwargs: any,
) -> tuple[Tensor, Tensor]:
"""Computes the Jacobian of a function with respect to its input. | ||
:param x_in: The input tensor to compute the jacobian at. | ||
Shape: (batch_size, in_dim). | ||
:param fn: The function to compute the jacobian of, which transforms | ||
`x` to `fn(x)` of shape (batch_size, out_dim). | ||
:param func_args: The positional arguments to pass to the function. | ||
func_args are batched over the first dimension. | ||
:param grad_type: The type of gradient to use. Either 'backward' or | ||
'forward'. | ||
:param func_kwargs: The keyword arguments to pass to the function. | ||
func_kwargs are not batched. | ||
:return: The output of the function `fn(x)` and the jacobian | ||
of the function with respect to its input `x` of shape | ||
(batch_size, out_dim, in_dim).""" | ||
|
||
    def batch_wrap(fn: Callable) -> Callable:
        """Add a batch dimension to each tensor argument before calling `fn`,
        and strip the batch dimension from the result.

        :param fn: The function to wrap.
        :return: The wrapped function."""

        def deep_unsqueeze(arg):
            if ops.is_tensor(arg):
                return arg[None, ...]
            elif isinstance(arg, dict):
                return {key: deep_unsqueeze(value) for key, value in arg.items()}
            elif isinstance(arg, (list, tuple)):
                return [deep_unsqueeze(value) for value in arg]
            raise ValueError(f"Argument cannot be batched: {arg}")

        @wraps(fn)
        def wrapper(*args, **kwargs):
            args = deep_unsqueeze(args)
            return fn(*args, **kwargs)[0]

        return wrapper

    def double_output(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            out = fn(*args, **kwargs)
            return out, out

        return wrapper

    match keras.backend.backend():
        case "torch":
            import torch
            from torch.func import jacrev, jacfwd, vmap

            jacfn = jacrev if grad_type == "backward" else jacfwd
            with torch.inference_mode(False):
                with torch.no_grad():
                    fn_kwargs_prefilled = partial(fn, **func_kwargs)
                    fn_batch_expanded = batch_wrap(fn_kwargs_prefilled)
                    fn_return_val = double_output(fn_batch_expanded)
                    fn_jac_batched = vmap(jacfn(fn_return_val, has_aux=True))
                    jac, x_out = fn_jac_batched(x_in, *func_args)
case "jax": | ||
from jax import jacrev, jacfwd, vmap | ||
|
||
jacfn = jacrev if grad_type == "backward" else jacfwd | ||
fn_kwargs_prefilled = partial(fn, **func_kwargs) | ||
fn_batch_expanded = batch_wrap(fn_kwargs_prefilled) | ||
fn_return_val = double_output(fn_batch_expanded) | ||
fn_jac_batched = vmap(jacfn(fn_return_val, has_aux=True)) | ||
jac, x_out = fn_jac_batched(x_in, *func_args) | ||
case "tensorflow": | ||
if grad_type == "forward": | ||
raise NotImplementedError("For TensorFlow, only backward mode Jacobian computation is available.") | ||
import tensorflow as tf | ||
|
||
with tf.GradientTape() as tape: | ||
tape.watch(x_in) | ||
x_out = fn(x_in, *func_args, **func_kwargs) | ||
jac = tape.batch_jacobian(x_out, x_in) | ||
|
||
case _: | ||
raise NotImplementedError(f"compute_jacobian not implemented for {keras.backend.backend()}.") | ||
return x_out, jac | ||
|
||
|
||
def log_jacobian_determinant(
    x_in: Tensor,
    fn: Callable,
    *func_args: any,
    grad_type: str = "backward",
    **func_kwargs: any,
) -> tuple[Tensor, Tensor]:
    """Computes the log Jacobian determinant of a function
    with respect to its input.

    :param x_in: The input tensor to compute the jacobian at.
        Shape: (batch_size, in_dim).
    :param fn: The function to compute the jacobian of, which transforms
        `x` to `fn(x)` of shape (batch_size, out_dim).
    :param func_args: The positional arguments to pass to the function.
        func_args are batched over the first dimension.
    :param grad_type: The type of gradient to use. Either 'backward' or
        'forward'.
    :param func_kwargs: The keyword arguments to pass to the function.
        func_kwargs are not batched.
    :return: The output of the function `fn(x)` and the log jacobian determinant
        of the function with respect to its input `x`, of shape (batch_size,)."""

    x_out, jac = compute_jacobian(x_in, fn, *func_args, grad_type=grad_type, **func_kwargs)
    jac = ops.reshape(
        jac, (ops.shape(x_in)[0], ops.prod(list(ops.shape(x_out)[1:])), ops.prod(list(ops.shape(x_in)[1:])))
    )
    log_det = ops.slogdet(jac)[1]

    return x_out, log_det
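As a quick sanity check of the utilities above (they are exported through `bayesflow.utils`, as the imports in the FreeFormFlow module indicate), consider an elementwise scaling, whose log-determinant is known in closed form; the helper `scale` is purely illustrative:

import math

import keras
from keras import ops


def scale(x):
    # Linear map with Jacobian 2 * I, so log|det J| = D * log(2) in D dimensions
    return 2.0 * x


x = keras.random.normal((8, 3))
y, log_det = log_jacobian_determinant(x, scale)

print(ops.shape(y))       # (8, 3)
print(log_det)            # each of the 8 entries is close to 3 * log(2)
print(3 * math.log(2.0))  # ≈ 2.0794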