diff --git a/bayesflow/networks/__init__.py b/bayesflow/networks/__init__.py index cef2ea0e7..858cfbf13 100644 --- a/bayesflow/networks/__init__.py +++ b/bayesflow/networks/__init__.py @@ -1,4 +1,5 @@ from .cif import CIF +from .consistency_models import ConsistencyModel from .coupling_flow import CouplingFlow from .deep_set import DeepSet from .flow_matching import FlowMatching diff --git a/bayesflow/networks/consistency_models/consistency_model.py b/bayesflow/networks/consistency_models/consistency_model.py index 68bf6944f..d316709b8 100644 --- a/bayesflow/networks/consistency_models/consistency_model.py +++ b/bayesflow/networks/consistency_models/consistency_model.py @@ -15,13 +15,24 @@ @register_keras_serializable(package="bayesflow.networks") class ConsistencyModel(InferenceNetwork): - """Implements a consistency model according to https://arxiv.org/abs/2303.01469""" + """Implements a Consistency Model with Consistency Training (CT) as + described in [1-2]. The adaptations to CT described in [2] were taken + into account in this implementation. + + [1] Song, Y., Dhariwal, P., Chen, M. & Sutskever, I. (2023). + Consistency Models. + arXiv preprint arXiv:2303.01469 + + [2] Song, Y., & Dhariwal, P. (2023). + Improved Techniques for Training Consistency Models: + arXiv preprint arXiv:2310.14189 + Discussion: https://openreview.net/forum?id=WNzy9bRDvG + """ def __init__( self, total_steps: int | float, subnet: str | type = "mlp", - base_distribution: str = "normal", max_time: int | float = 200, sigma2: float = 1.0, eps: float = 0.001, @@ -29,14 +40,38 @@ def __init__( s1: int | float = 50, **kwargs, ): - super().__init__(base_distribution=base_distribution, **keras_kwargs(kwargs)) + """Creates an instance of a consistency model (CM) to be used + for standalone consistency training (CT). 
+ + Parameters: + ----------- + total_steps : int + The total number of training steps, can be calculate as + number of epochs * number of batches + subnet : str or type, optional, default: "mlp" + A neural network type for the consistency model, will be + instantiated using subnet_kwargs. + max_time : int or float, optional, default: 200.0 + The maximum time of the diffusion + sigma2 : float or Tensor of dimension (input_dim, 1), + optional, default: 1.0 + Controls the shape of the skip-function + eps : float, optional, default: 0.001 + The minimum time + s0 : int or float, optional, default: 10 + Initial number of discretization steps + s1 : int or float, optional, default: 50 + Final number of discretization steps + **kwargs : dict, optional, default: {} + Additional keyword arguments + """ + # Normal is the only supported base distribution for CMs + super().__init__(base_distribution="normal", **keras_kwargs(kwargs)) self.total_steps = float(total_steps) self.student = find_network(subnet, **kwargs.get("subnet_kwargs", {})) self.student_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros") - self.teacher = None - self.teacher_projector = None self.sigma2 = ops.convert_to_tensor(sigma2) self.sigma = ops.sqrt(sigma2) @@ -49,20 +84,22 @@ def __init__( self.s1 = float(s1) self.current_step = 0.0 + self.seed_generator = keras.random.SeedGenerator() + def _schedule_discretization(self) -> int: - """Schedule function for adjusting the discretization level `N` during the course - of training. + """Schedule function for adjusting the discretization level `N` during + the course of training. - Implements the function N(k) from https://arxiv.org/abs/2310.14189, Section 3.4. + Implements the function N(k) from [2], Section 3.4. 
""" k_ = math.floor(self.total_steps / (math.log(self.s1 / self.s0) / math.log(2.0) + 1.0)) out = min(self.s0 * math.pow(2.0, math.floor(self.current_step / k_)), self.s1) + 1.0 return int(out) - def discretize_time(self, num_steps, rho=7.0): - """Function for obtaining the discretized time according to - https://arxiv.org/pdf/2310.14189.pdf, Section 2, bottom of page 2. + def _discretize_time(self, num_steps, rho=7.0): + """Function for obtaining the discretized time according to [2], + Section 2, bottom of page 2. """ N = num_steps + 1.0 @@ -93,15 +130,7 @@ def build(self, xz_shape, conditions_shape=None): input_shape = self.student.compute_output_shape(input_shape) self.student_projector.build(input_shape) - # Clone - self.teacher = keras.models.clone_model(self.student) - self.teacher_projector = keras.models.clone_model(self.student_projector) - self.teacher.set_weights(self.student.weights) - self.teacher_projector.set_weights(self.student_projector) - self.teacher.trainable = False - self.student_projector.trainable = False - - # Choose coefficient according to https://arxiv.org/pdf/2310.14189.pdf, Section 3.3 + # Choose coefficient according to [2] Section 3.3 self.c_huber = 0.00054 * math.sqrt(xz_shape[-1]) self.c_huber2 = self.c_huber**2 @@ -116,27 +145,70 @@ def call( return self._inverse(xz, conditions=conditions, **kwargs) return self._forward(xz, conditions=conditions, **kwargs) + def _forward_train(self, x: Tensor, noise: Tensor, t: Tensor, conditions: Tensor = None, **kwargs) -> Tensor: + """Forward function for training. Calls consistency function with + noisy input + """ + inp = x + t * noise + return self.consistency_function(inp, t, conditions=conditions, **kwargs) + def _forward(self, x: Tensor, conditions: Tensor = None, **kwargs) -> Tensor: - pass + # Consistency Models only learn the direction from noise distribution + # to target distribution, so we cannot implement this function. 
+ raise NotImplementedError("Consistency Models are not invertible") def _inverse(self, z: Tensor, conditions: Tensor = None, **kwargs) -> Tensor: - pass - - def consistency_function( - self, x: Tensor, t: Tensor, conditions: Tensor = None, student: bool = True, **kwargs - ) -> Tensor: - """Compute consistency function with either the student or the teacher network.""" + """Generate random draws from the approximate target distribution + using the multistep sampling algorithm from [1], Algorithm 1. + + Parameters + ---------- + z : Tensor + Samples from a standard normal distribution + conditions : Tensor, optional, default: None + Conditions for a approximate conditional distribution + **kwargs : dict, optional, default: {} + Additional keyword arguments. Include `steps` (default: 10) to + adjust the number of sampling steps. + + Returns + ------- + x : Tensor + The approximate samples + """ + steps = kwargs.get("steps", 10) + x = keras.ops.copy(z) * self.max_time + discretized_time = keras.ops.flip(self._discretize_time(steps), axis=-1) + t = keras.ops.full((*keras.ops.shape(x)[:-1], 1), discretized_time[0], dtype=x.dtype) + x = self.consistency_function(x, t, conditions=conditions) + for n in range(1, steps): + noise = keras.random.normal(keras.ops.shape(x), dtype=keras.ops.dtype(x), seed=self.seed_generator) + x_n = x + keras.ops.sqrt(keras.ops.square(discretized_time[n]) - self.eps**2) * noise + t = keras.ops.full_like(t, discretized_time[n]) + x = self.consistency_function(x_n, t, conditions=conditions) + return x + + def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None, **kwargs) -> Tensor: + """Compute consistency function. + + Parameters + ---------- + x : Tensor + Input vector + t : Tensor + Vector of time samples in [eps, T] + conditions : Tensor + The conditioning vector + **kwargs : dict, optional, default: {} + Additional keyword arguments passed to the network. 
+ """ if conditions is not None: xtc = ops.concatenate([x, t, conditions], axis=-1) else: xtc = ops.concatenate([x, t], axis=-1) - # Compute either student or teacher output (no grads for teacher during training) - if student: - f = self.student_projector(self.student(xtc, **kwargs)) - else: - f = self.teacher_projector(self.teacher(xtc, **kwargs)) + f = self.student_projector(self.student(xtc, **kwargs)) # Compute skip and out parts (vectorized, since self.sigma2 is of shape (1, input_dim) # Thus, we can do a cross product with the time vector which is (batch_size, 1) for @@ -147,53 +219,41 @@ def consistency_function( out = skip * x + out * f return out - def update_teacher(self): - """ - Update function for copying student network weights to teacher network weights. - Should be called after the optimizer update of the student. EMA was dropped, - see https://arxiv.org/pdf/2310.14189.pdf, Section 3.2. - """ - - for w_teacher, w_student in zip(self.teacher.weights, self.student.weights): - w_teacher.assign(keras.ops.stop_gradient(w_student)) - - for w_teacher, w_student in zip(self.teacher_projector.weights, self.student_projector.weights): - w_teacher.assign(keras.ops.stop_gradient(w_student)) - - def compute_metrics(self, data: dict[str, Tensor], stage: str = "training") -> dict[str, Tensor]: - base_metrics = super().compute_metrics(data, stage=stage) + def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]: + base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage) + # The discretization schedule requires the number of passed training steps. + # To be independent of external information, we track it here. 
self.current_step += 1 - x = data["inference_variables"] - c = data.get("inference_conditions") - - # z = self.base_distribution.sample((ops.shape(x)[0],)) - current_num_steps = self._schedule_discretization() - discretized_time = self.discretize_time(current_num_steps) + discretized_time = self._discretize_time(current_num_steps) # Randomly sample t_n and t_[n+1] and reshape to (batch_size, 1) - # adapted noise schedule from https://arxiv.org/pdf/2310.14189.pdf, - # Section 3.5 + # adapted noise schedule from [2], Section 3.5 p_mean = -1.1 p_std = 2.0 log_p = ops.log( ops.erf((ops.log(discretized_time[1:]) - p_mean) / (ops.sqrt(2.0) * p_std)) - ops.erf((ops.log(discretized_time[:-1]) - p_mean) / (ops.sqrt(2.0) * p_std)) ) - times = keras.random.categorical([log_p], ops.shape(x)[0])[0] + times = keras.random.categorical(ops.expand_dims(log_p, 0), ops.shape(x)[0], seed=self.seed_generator)[0] t1 = ops.take(discretized_time, times)[..., None] t2 = ops.take(discretized_time, times + 1)[..., None] - teacher_out = self._forward(x, conditions=c, student=False, training=stage == "training") + # generate noise vector + noise = keras.random.normal(keras.ops.shape(x), dtype=keras.ops.dtype(x), seed=self.seed_generator) + + teacher_out = self._forward_train(x, noise, t1, conditions=conditions, training=stage == "training") + # difference between teacher and student: different time, + # and no gradient for the teacher teacher_out = ops.stop_gradient(teacher_out) - student_out = self._forward(x, conditions=c, student=True, training=stage == "training") + student_out = self._forward_train(x, noise, t2, conditions=conditions, training=stage == "training") - # weighting function, see https://arxiv.org/pdf/2310.14189.pdf, Section 3.1 + # weighting function, see [2], Section 3.1 lam = 1 / (t2 - t1) - # Pseudo-huber loss, see https://arxiv.org/pdf/2310.14189.pdf, Section 3.3 + # Pseudo-huber loss, see [2], Section 3.3 loss = ops.mean(lam * (ops.sqrt(ops.square(teacher_out - 
student_out) + self.c_huber2) - self.c_huber)) return base_metrics | {"loss": loss} diff --git a/examples/TwoMoons_ConsistencyModel.ipynb b/examples/TwoMoons_ConsistencyModel.ipynb new file mode 100644 index 000000000..7b1efa7ce --- /dev/null +++ b/examples/TwoMoons_ConsistencyModel.ipynb @@ -0,0 +1,704 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "009b6adf", + "metadata": {}, + "source": [ + "# Consistency Models for Posterior Estimation\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "d5f88a59", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:46.551814Z", + "start_time": "2024-09-23T14:39:46.032170Z" + } + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import seaborn as sns\n", + "\n", + "# ensure the backend is set\n", + "import os\n", + "if \"KERAS_BACKEND\" not in os.environ:\n", + " # set this to \"torch\", \"tensorflow\", or \"jax\"\n", + " os.environ[\"KERAS_BACKEND\"] = \"jax\"\n", + "\n", + "import keras\n", + "\n", + "# for bayesflow devs: this ensures that the latest dev version can be found\n", + "import sys\n", + "sys.path.append('../')\n", + "\n", + "import bayesflow as bf" + ] + }, + { + "cell_type": "markdown", + "id": "eadaf793-ab63-4f69-b962-178e343ca21b", + "metadata": {}, + "source": [ + "In this notebook, we use Consistency Models (CMs) as a plug-in replacement to obtain posterior samples with fewer sampling steps.\n", + "\n", + "CMs can be trained in two ways: First, they can be used to _distill_ an existing score-based diffusion model, thereby massively decreasing the sampling time at the expense of an additional training phase. Second, they can be trained from scratch using a procedure named _Consistency Training_. 
For now, we only support the latter.\n" + ] + }, + { + "cell_type": "markdown", + "id": "6286c800-460a-4881-87d8-c3aca7aeec70", + "metadata": {}, + "source": [ + "## Background\n" + ] + }, + { + "cell_type": "markdown", + "id": "fdff817a-6321-4af0-9d41-7ec80097f93b", + "metadata": {}, + "source": [ + "Consistency Models [1] leverage some nice properties of score-based diffusion to enable few-step sampling. Score-based diffusion initially relied on a stochastic differential equation (SDE) for sampling, but there is also an ordinary (non-stochastic) differential equation (ODE) that has the same _marginal_ distribution at each time step $t$ [2]. This means that even though SDE and ODE produce different paths from the noise distribution to the target distribution, the resulting distribution when looking at many paths at time $t$ is the same. The ODE is also called Probability Flow ODE.\n" + ] + }, + { + "cell_type": "markdown", + "id": "4a2e996d-355a-4fab-8347-728e563c6014", + "metadata": {}, + "source": [ + "CMs now leverage the fact that there is no randomness in the ODE formulation. That means, if you start at a certain point in the latent space, you will always take the same path and always end up at the same point in the data space. The same is true for every point on the path: if you integrate to get to time $t=0$, you will end up at the same point as well. In short: for each path, there is exactly one corresponding point in latent space (at $t=T$) and one corresponding point in data space (at $t=0$). The goal of CMs is now the following: each point at a time point $t$ belongs to exactly one path, and we want to predict where this path will end up at $t=0$. The function that does this is called the _consistency function_ $f$. If we have the correct function for all $t\\in(0,T]$, we can just sample from the latent distribution ($t=T$) and use $f$ to directly map to the corresponding point at $t=0$, which is in the target distribution. 
So for sampling from the target distribution, we avoid any integration and only need one evaluation of the consistency function. In practice, the one-step sampling does not work very well. Instead, we leverage a multi-step sampling method where we call $f$ multiple times. Please check out [1] for more background on this sampling procedure.\n" + ] + }, + { + "cell_type": "markdown", + "id": "25023294-3096-4ebc-83a6-a372208e0504", + "metadata": {}, + "source": [ + "When only reading the above you might wonder why we also learn the mapping to $t=0$ for all intermediate time steps $t\\in[0, T]$, and not only for $t=T$. The main answer is that for efficient training, we do not want to actually compute the two associated points explicitly. Doing so would require a precise integration at training time, which is often not feasible as it is too computationally costly. Learning all time steps opens up the possibility for a different training approach where we can avoid this.\n" + ] + }, + { + "cell_type": "markdown", + "id": "9e6eb121-f89f-4268-9a0d-6fa733c758ff", + "metadata": {}, + "source": [ + "The details of this become a bit more complicated, and we advise you to take a look at [1] if you are interested in a more thorough and mathematical discussion. Here we will give a rough description of the underlying concepts.\n" + ] + }, + { + "cell_type": "markdown", + "id": "9b201899-b946-4432-bcac-106b9d580d32", + "metadata": {}, + "source": [ + "First, we know that at $t=0$, it holds that $f(\\theta,t=0)=\\theta$, as $\\theta$ is part of the path that ends at $\\theta$. 
This _boundary condition_ serves as an \"anchor\" for our training: it is the information that the network knows at the start of the training procedure (we encode it with a time-dependent skip-connection, so the network is forced to be the identity function at $t=0$).\n" + ] + }, + { + "cell_type": "markdown", + "id": "a11f7ac7-9c11-49c1-b29c-3f0d44c4bf49", + "metadata": {}, + "source": [ + "For training, we now somehow have to propagate this information to the rest of the path. The basic idea for this is simple. We just take a point $\\theta_1$ closer to the data distribution (smaller time $t_1$) and integrate for a small time step $dt$ to a point $\\theta_2$ on the same path that is closer to the latent distribution (larger time $t_2=t_1+dt$). As we know that for $t=0$ our network provides the correct output for our path, we want to propagate the information from smaller times to larger times. Our training goal is to move the output of $f(\\theta_2, t=t_2)$ towards the output of $f(\\theta_1, t=t_1)$. How to choose $\\theta_1$, $t_1$ and $dt$ is an empirical question, see [1] for some thoughts on what works well.\n" + ] + }, + { + "cell_type": "markdown", + "id": "73ff5ed5-dbd0-423b-b4f4-8d209947ed87", + "metadata": {}, + "source": [ + "In the case of _distillation_, we start with a trained score-based diffusion model. We can use it to integrate the Probability Flow ODE to get from $\\theta_1$ to $\\theta_2$. If we do not have such a model, it seems as if we were stuck. We do not know which points lie on the same path, so we do not know which outputs to make similar. Fortunately, it turns out that there is an _unbiased approximator_ that, if averaged over many samples (check out the paper for the exact description), will also give us the correct score. If we use this approximator instead of the score model, and use only a single Euler step to move along the path, we get an algorithm similar to the one described for distillation. 
It is called Consistency Training (CT) and allows us to train a consistency model using only _samples_ from the data distribution. The algorithm for this was improved a lot in [3], and we have incorporated those improvements into our implementation.\n" + ] + }, + { + "cell_type": "markdown", + "id": "283e8fea-9a36-4f8a-80f5-19e2fe98bd09", + "metadata": {}, + "source": [ + "We have made several approximations to get to a standalone Consistency Training algorithm. As a consequence, the introduced hyperparameters and their choice unfortunately becomes somewhat unintuitive. We have to rely on empirical observations and heuristics to see what works. This was done in [4], we encourage you to use the values provided there as starting points. If you happen to find hyperparameters that work significantly better, please let us know (e.g., by opening an issue or sending an email). This will help others to find the correct region in the hyperparameter space.\n" + ] + }, + { + "cell_type": "markdown", + "id": "8da3535f-0354-40a4-991f-845c33ff75b8", + "metadata": {}, + "source": [ + "To make this work for Bayesian inverse problems in simulation-based inference, we can make the whole process conditional on some quantity $x$, so we can produce conditional distributions as well. Below, you can see a conceptual visualization of posterior estimation with CMs.\n" + ] + }, + { + "cell_type": "markdown", + "id": "6c105be4-c490-4100-9502-60dd7405cfb4", + "metadata": {}, + "source": [ + "![Visualization of the way consistency models map from the path to the end point in the data distribution. Depicts the concepts described in the main text.](https://arxiv.org/html/2312.05440v2/extracted/5435837/figures/cmpe_main.png)\n" + ] + }, + { + "cell_type": "markdown", + "id": "baafb8fd-5b14-4ddf-a6dd-8272f706deaa", + "metadata": {}, + "source": [ + "### References\n", + "\n", + "[1] Song, Y., Dhariwal, P., Chen, M., & Sutskever, I. (2023). Consistency Models. _arXiv preprint_. 
[https://doi.org/10.48550/arXiv.2303.01469](https://doi.org/10.48550/arXiv.2303.01469)\n", + "\n", + "[2] Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., & Poole, B. (2021). Score-Based Generative Modeling through Stochastic Differential Equations. In _International Conference on Learning Representations_. [https://openreview.net/forum?id=PxTIG12RRHS](https://openreview.net/forum?id=PxTIG12RRHS)\n", + "\n", + "[3] Song, Y., & Dhariwal, P. (2023). Improved Techniques for Training Consistency Models. _arXiv preprint_. [https://doi.org/10.48550/arXiv.2310.14189](https://doi.org/10.48550/arXiv.2310.14189)\n", + "\n", + "[4] Schmitt, M., Pratz, V., Köthe, U., Bürkner, P.-C., & Radev, S. T. (2024). Consistency Models for Scalable and Fast Simulation-Based Inference. _arXiv preprint_. [https://doi.org/10.48550/arXiv.2312.05440](https://doi.org/10.48550/arXiv.2312.05440)\n" + ] + }, + { + "cell_type": "markdown", + "id": "c63b26ba", + "metadata": {}, + "source": [ + "## Simulator: Two Moons\n" + ] + }, + { + "cell_type": "markdown", + "id": "9525ffd7", + "metadata": {}, + "source": [ + "We will use the Consistency Model as a plug-in replacement for Flow Matching. 
Check out the tutorial \"Two moons toy example with flow matching\" for more details on the simulator and setting.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "4b89c861527c13b8", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:46.747091Z", + "start_time": "2024-09-23T14:39:46.744830Z" + } + }, + "outputs": [], + "source": [ + "simulator = bf.benchmarks.simulators.TwoMoons()" + ] + }, + { + "cell_type": "markdown", + "id": "f6e1eb5777c59eba", + "metadata": {}, + "source": [ + "We generate some data to see what the simulator does:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "e6218e61d529e357", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:46.798575Z", + "start_time": "2024-09-23T14:39:46.790581Z" + } + }, + "outputs": [], + "source": [ + "# generate 64 random draws from the joint distribution p(r, alpha, theta, x)\n", + "sample_data = simulator.sample((64,))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "46174ccb0167026c", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:46.854911Z", + "start_time": "2024-09-23T14:39:46.852129Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Type of sample_data:\n", + "\t \n", + "Keys of sample_data:\n", + "\t dict_keys(['parameters', 'observables'])\n", + "Types of sample_data values:\n", + "\t {'parameters': , 'observables': }\n", + "Shapes of sample_data values:\n", + "\t {'parameters': (64, 2), 'observables': (64, 2)}\n" + ] + } + ], + "source": [ + "print(\"Type of sample_data:\\n\\t\", type(sample_data))\n", + "print(\"Keys of sample_data:\\n\\t\", sample_data.keys())\n", + "print(\"Types of sample_data values:\\n\\t\", {k: type(v) for k, v in sample_data.items()})\n", + "print(\"Shapes of sample_data values:\\n\\t\", {k: v.shape for k, v in sample_data.items()})" + ] + }, + { + "cell_type": "markdown", + "id": "fee88fcfd7a373b0", + "metadata": {}, + 
"source": [ + "## Data Adapter\n", + "\n", + "The next step is to tell BayesFlow how to deal with the simulated variables. You may also think of this as informing BayesFlow about the data flow, i.e., which variables go into which network.\n", + "\n", + "For this example, we want to learn the posterior distribution $p(\\theta\\,|\\,x)$, so we **infer** $\\theta$, **conditioning** on $x$. In the output from the last command, we see that the simulator provides $\\theta$ as `\"parameters\"` and $x$ as `\"observables\"`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "c9637c576d4ad4e5", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:46.905081Z", + "start_time": "2024-09-23T14:39:46.903091Z" + } + }, + "outputs": [], + "source": [ + "data_adapter = bf.ContinuousApproximator.build_data_adapter(\n", + " inference_variables=[\"parameters\"],\n", + " inference_conditions=[\"observables\"],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "254e287b2bccdad", + "metadata": {}, + "source": [ + "## Dataset\n", + "\n", + "For this example, we will sample our training data ahead of time and use offline training with a `bf.datasets.OfflineDataset`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "39cb5a1c9824246f", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:46.950573Z", + "start_time": "2024-09-23T14:39:46.948624Z" + } + }, + "outputs": [], + "source": [ + "batch_size = 64\n", + "num_training_batches = 512\n", + "num_validation_batches = 128" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "9dee7252ef99affa", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:53.268860Z", + "start_time": "2024-09-23T14:39:46.994697Z" + } + }, + "outputs": [], + "source": [ + "training_samples = simulator.sample((num_training_batches * batch_size,))\n", + "validation_samples = simulator.sample((num_validation_batches * batch_size,))" + ] + }, + { + "cell_type": 
"code", + "execution_count": 8, + "id": "51045bbed88cb5c2", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:53.281170Z", + "start_time": "2024-09-23T14:39:53.275921Z" + } + }, + "outputs": [], + "source": [ + "training_dataset = bf.datasets.OfflineDataset(training_samples, batch_size=batch_size, data_adapter=data_adapter)\n", + "validation_dataset = bf.datasets.OfflineDataset(validation_samples, batch_size=batch_size, data_adapter=data_adapter)" + ] + }, + { + "cell_type": "markdown", + "id": "2d4c6eb0", + "metadata": {}, + "source": [ + "## Training a neural network to approximate all posteriors\n", + "\n", + "The next step is to set up the neural network that will approximate the posterior $p(\\theta\\,|\\,x)$.\n", + "\n", + "Consistency models use _scheduling functions_ to adjust some of the hyperparameters, for example the time discretization during training. Consequently, we have to specify the total number of training steps (_gradient updates_) before the start of the training.\n", + "For offline training with a given number of epochs, we can calculate it as below:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "516ac3c4-b66f-4cf0-a443-b00705a6ace5", + "metadata": {}, + "outputs": [], + "source": [ + "epochs = 30\n", + "total_steps = epochs * num_training_batches" + ] + }, + { + "cell_type": "markdown", + "id": "2c7980ec-5623-43e3-847e-16a5b6eb8777", + "metadata": {}, + "source": [ + "Apart from the usual parameters like learning rate and batch size, CMs come with a number of different hyperparameters. Unfortunately, they can heavily interact, so they can be hard to tune. The main hyperparameters are:\n", + "\n", + "- Maximum time `max_time`: This also serves as the standard deviation of the latent distribution. You can experiment with this, values from 10-200 seem to work well. 
In any case, it should be larger than the standard deviation of the target distribution.\n", + "- Minimum/maximum number of discretization steps during training `s0`/`s1`: The effect of those is hard to grasp. 10 works well for `s0`. Intuitively, increasing `s1` along with the number of epochs should lead to better results, but in practice we sometimes observe a breakdown for high values of `s1`. This seems to be problem-dependent, so just try it out.\n", + "- `sigma2` modifies the time-dependency of the skip connection. Its effect on the training is unclear, we recommend leaving it at 1.0 or setting it to the approximate variance of the target distribution.\n", + "- Smallest time value `eps` ($t=\\epsilon$ is used instead of $t=0$ for numerical reasons): No large effect in our experiments, as long as it is kept small enough. Probably not worth tuning.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "09206e6f", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:53.339590Z", + "start_time": "2024-09-23T14:39:53.319852Z" + } + }, + "outputs": [], + "source": [ + "# Compute the empirical variance of the draws from the prior θ ~ p(θ)\n", + "sigma2 = keras.ops.var(training_samples[\"parameters\"], axis=0, keepdims=True)\n", + "\n", + "inference_network = bf.networks.ConsistencyModel(\n", + " subnet=\"mlp\",\n", + " subnet_kwargs=dict(\n", + " depth=6,\n", + " width=256,\n", + " ),\n", + " total_steps = total_steps,\n", + " max_time=10, # works well for this task\n", + " sigma2=sigma2, # pass the empirical variance to the network\n", + " # the remaining hyperparameters (s0, s1, eps) are the default values\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "851e522f", + "metadata": {}, + "source": [ + "This inference network is just a general consistency model architecture, not yet adapted to the specific inference task at hand (i.e., posterior approximation). 
To achieve this adaptation, we combine the network with our data adapter, which together form an `approximator`. In this case, we need a `ContinuousApproximator` since the target we want to approximate is the posterior of the _continuous_ parameter vector $\\theta$.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "96ca6ffa", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:53.371691Z", + "start_time": "2024-09-23T14:39:53.369375Z" + } + }, + "outputs": [], + "source": [ + "approximator = bf.ContinuousApproximator(\n", + " inference_network=inference_network,\n", + " data_adapter=data_adapter,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "566264eadc76c2c", + "metadata": {}, + "source": [ + "### Optimizer and Learning Rate\n", + "\n", + "We use an Adam optimizer with [cosine decay](https://keras.io/api/optimizers/learning_rate_schedules/cosine_decay/) to decrease the learning rate towards zero over the training time.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "e8d7e053", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:53.433012Z", + "start_time": "2024-09-23T14:39:53.415903Z" + } + }, + "outputs": [], + "source": [ + "initial_learning_rate = 5e-4\n", + "scheduled_lr = keras.optimizers.schedules.CosineDecay(\n", + " initial_learning_rate,\n", + " total_steps,\n", + ")\n", + "\n", + "optimizer = keras.optimizers.Adam(learning_rate=scheduled_lr)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "51808fcd560489ac", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:53.476089Z", + "start_time": "2024-09-23T14:39:53.466001Z" + } + }, + "outputs": [], + "source": [ + "approximator.compile(optimizer=optimizer)" + ] + }, + { + "cell_type": "markdown", + "id": "708b1303", + "metadata": {}, + "source": [ + "### Training\n", + "\n", + "We are ready to train our deep posterior approximator on the two moons example. 
We pass the dataset object to the `fit` method and watch as bayesflow trains. This notebook is being executed on a consumer-grade CPU and training is still reasonably fast. If you have a GPU available, training will be even faster, especially for larger networks.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "0f496bda", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:42:36.067393Z", + "start_time": "2024-09-23T14:39:53.513436Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:bayesflow:Fitting on dataset instance of OfflineDataset.\n", + "INFO:bayesflow:Building on a test batch.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 6ms/step - loss: 0.4117 - loss/inference_loss: 0.4117 - val_loss: 0.3387 - val_loss/inference_loss: 0.3387\n", + "Epoch 2/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3568 - loss/inference_loss: 0.3568 - val_loss: 0.3603 - val_loss/inference_loss: 0.3603\n", + "Epoch 3/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3503 - loss/inference_loss: 0.3503 - val_loss: 0.2898 - val_loss/inference_loss: 0.2898\n", + "Epoch 4/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3362 - loss/inference_loss: 0.3362 - val_loss: 0.4429 - val_loss/inference_loss: 0.4429\n", + "Epoch 5/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3343 - loss/inference_loss: 0.3343 - val_loss: 0.3929 - val_loss/inference_loss: 0.3929\n", + "Epoch 6/30\n", + "\u001b[1m512/512\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3311 - loss/inference_loss: 0.3311 - val_loss: 0.2825 - val_loss/inference_loss: 0.2825\n", + "Epoch 7/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3264 - loss/inference_loss: 0.3264 - val_loss: 0.3029 - val_loss/inference_loss: 0.3029\n", + "Epoch 8/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3222 - loss/inference_loss: 0.3222 - val_loss: 0.3447 - val_loss/inference_loss: 0.3447\n", + "Epoch 9/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3222 - loss/inference_loss: 0.3222 - val_loss: 0.3209 - val_loss/inference_loss: 0.3209\n", + "Epoch 10/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3141 - loss/inference_loss: 0.3141 - val_loss: 0.2195 - val_loss/inference_loss: 0.2195\n", + "Epoch 11/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3123 - loss/inference_loss: 0.3123 - val_loss: 0.3043 - val_loss/inference_loss: 0.3043\n", + "Epoch 12/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3121 - loss/inference_loss: 0.3121 - val_loss: 0.3225 - val_loss/inference_loss: 0.3225\n", + "Epoch 13/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3089 - loss/inference_loss: 0.3089 - val_loss: 0.2082 - val_loss/inference_loss: 0.2082\n", + "Epoch 14/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - 
loss: 0.3028 - loss/inference_loss: 0.3028 - val_loss: 0.2394 - val_loss/inference_loss: 0.2394\n", + "Epoch 15/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2996 - loss/inference_loss: 0.2996 - val_loss: 0.3735 - val_loss/inference_loss: 0.3735\n", + "Epoch 16/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2949 - loss/inference_loss: 0.2949 - val_loss: 0.2624 - val_loss/inference_loss: 0.2624\n", + "Epoch 17/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2956 - loss/inference_loss: 0.2956 - val_loss: 0.3925 - val_loss/inference_loss: 0.3925\n", + "Epoch 18/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2904 - loss/inference_loss: 0.2904 - val_loss: 0.2991 - val_loss/inference_loss: 0.2991\n", + "Epoch 19/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2953 - loss/inference_loss: 0.2953 - val_loss: 0.2517 - val_loss/inference_loss: 0.2517\n", + "Epoch 20/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2867 - loss/inference_loss: 0.2867 - val_loss: 0.3187 - val_loss/inference_loss: 0.3187\n", + "Epoch 21/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2880 - loss/inference_loss: 0.2880 - val_loss: 0.3218 - val_loss/inference_loss: 0.3218\n", + "Epoch 22/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2819 - loss/inference_loss: 0.2819 - val_loss: 0.2689 - 
val_loss/inference_loss: 0.2689\n", + "Epoch 23/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2775 - loss/inference_loss: 0.2775 - val_loss: 0.2354 - val_loss/inference_loss: 0.2354\n", + "Epoch 24/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2848 - loss/inference_loss: 0.2848 - val_loss: 0.2992 - val_loss/inference_loss: 0.2992\n", + "Epoch 25/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2699 - loss/inference_loss: 0.2699 - val_loss: 0.1976 - val_loss/inference_loss: 0.1976\n", + "Epoch 26/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2760 - loss/inference_loss: 0.2760 - val_loss: 0.3003 - val_loss/inference_loss: 0.3003\n", + "Epoch 27/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2774 - loss/inference_loss: 0.2774 - val_loss: 0.3333 - val_loss/inference_loss: 0.3333\n", + "Epoch 28/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2745 - loss/inference_loss: 0.2745 - val_loss: 0.2938 - val_loss/inference_loss: 0.2938\n", + "Epoch 29/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2795 - loss/inference_loss: 0.2795 - val_loss: 0.2968 - val_loss/inference_loss: 0.2968\n", + "Epoch 30/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2720 - loss/inference_loss: 0.2720 - val_loss: 0.2105 - val_loss/inference_loss: 0.2105\n" + ] + } + ], + "source": [ + "history = 
approximator.fit(\n", + " epochs=epochs,\n", + " dataset=training_dataset,\n", + " validation_data=validation_dataset,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "b90a6062", + "metadata": {}, + "source": [ + "## Validation\n" + ] + }, + { + "cell_type": "markdown", + "id": "ca62b21d", + "metadata": {}, + "source": [ + "### Two Moons Posterior\n", + "\n", + "By design, the two moons posterior at point $x = (0, 0)$ should resemble two crescent moons, hence the name. Below, we plot the corresponding posterior samples.\n", + "\n", + "These results suggest that our consistency model posterior estimation setup can approximate the target posterior well. You can achieve an even better fit if you use online training, more epochs, or better hyperparameters. We won't do that here because this tutorial is only meant to illustrate the basic setup for consistency models in amortized inference with bayesflow.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "8562caeb", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:42:38.584554Z", + "start_time": "2024-09-23T14:42:36.076923Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(-0.4, 0.4)" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Set the number of posterior draws you want to get\n", + "num_samples = 3000\n", + "\n", + "# Obtain samples from amortized posterior\n", + "conditions = {\"observables\": np.array([[0.0, 0.0]]).astype(\"float32\")}\n", + "samples_at_origin = approximator.sample(conditions=conditions, num_samples=num_samples)[\"parameters\"]\n", + "\n", + "# Prepare figure\n", + "f, axes = plt.subplots(1, figsize=(6, 6))\n", + "\n", + "# Plot samples\n", + "axes.scatter(samples_at_origin[0, :, 0], samples_at_origin[0, :, 1], color=\"#153c7a\", alpha=0.75, s=0.5)\n", + "sns.despine(ax=axes)\n", + "axes.set_title(r\"Posterior samples at origin $x=(0, 0)$\")\n", + "axes.grid(alpha=0.3)\n", + "axes.set_aspect(\"equal\", adjustable=\"box\")\n", + "axes.set_xlim([-0.4, 0.4])\n", + "axes.set_ylim([-0.4, 0.4])" + ] + }, + { + "cell_type": "markdown", + "id": "01821d24", + "metadata": {}, + "source": [ + "The posterior looks as expected in this case. However, in general, we do not know what the posterior is supposed to look like for any specific dataset. As such, we need diagnostics that validate the correctness of the inferred posterior. One such diagnostic is simulation-based calibration (SBC), which we can compute essentially for free due to amortization. For more details on SBC and diagnostic plots, see:\n", + "\n", + "1. Talts, S., Betancourt, M., Simpson, D., Vehtari, A., & Gelman, A. (2018). Validating Bayesian inference algorithms with simulation-based calibration. _arXiv preprint_.\n", + "2. Säilynoja, T., Bürkner, P. C., & Vehtari, A. (2022). Graphical test for discrete uniformity and its applications in goodness-of-fit evaluation and multiple sample comparison. 
_Statistics and Computing_.\n" + ] + }, + { + "cell_type": "markdown", + "id": "cb38d0c8", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "bayesflow", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": true, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": true, + "toc_position": { + "height": "calc(100% - 180px)", + "left": "10px", + "top": "150px", + "width": "165px" + }, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}