From 8f187cd4fd777ea6d6bc252da034dccb86585745 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Fri, 11 Oct 2024 16:13:18 +0000 Subject: [PATCH] Documentation: Consistency Model example notebook --- examples/ConsistencyModel.ipynb | 681 ++++++++++++++++++++++++++++++++ 1 file changed, 681 insertions(+) create mode 100644 examples/ConsistencyModel.ipynb diff --git a/examples/ConsistencyModel.ipynb b/examples/ConsistencyModel.ipynb new file mode 100644 index 000000000..00bdba4b4 --- /dev/null +++ b/examples/ConsistencyModel.ipynb @@ -0,0 +1,681 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "009b6adf", + "metadata": {}, + "source": [ + "# Consistency Models for Posterior Estimation" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "d5f88a59", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:46.551814Z", + "start_time": "2024-09-23T14:39:46.032170Z" + } + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import seaborn as sns\n", + "\n", + "# ensure the backend is set\n", + "import os\n", + "if \"KERAS_BACKEND\" not in os.environ:\n", + " # set this to \"torch\", \"tensorflow\", or \"jax\"\n", + " os.environ[\"KERAS_BACKEND\"] = \"jax\"\n", + "\n", + "import keras\n", + "\n", + "# for BayesFlow devs: this ensures that the latest dev version can be found\n", + "import sys\n", + "sys.path.append('../')\n", + "\n", + "import bayesflow as bf" + ] + }, + { + "cell_type": "markdown", + "id": "eadaf793-ab63-4f69-b962-178e343ca21b", + "metadata": {}, + "source": [ + "In this notebook, we use Consistency Models (CMs) as a plug-in replacement to obtain posterior samples with fewer sampling steps.\n", + "\n", + "CMs can be trained in two ways: First, they can be used to _distill_ an existing score-based diffusion model, thereby massively decreasing the sampling time at the expense of an additional training phase. 
Second, they can be trained from scratch using a procedure named _Consistency Training_. For now, we only support the latter." + ] + }, + { + "cell_type": "markdown", + "id": "6286c800-460a-4881-87d8-c3aca7aeec70", + "metadata": {}, + "source": [ + "## Background" + ] + }, + { + "cell_type": "markdown", + "id": "fdff817a-6321-4af0-9d41-7ec80097f93b", + "metadata": {}, + "source": [ + "Consistency Models [1] leverage some nice properties of score-based diffusion to enable few-step sampling. Score-based diffusion initially relied on a stochastic differential equation (SDE) for sampling, but there is also an ordinary (non-stochastic) differential equation (ODE) that has the same _marginal_ distribution at each time step $t$ [2]. This means that even though SDE and ODE produce different paths from the noise distribution to the target distribution, the resulting distributions when looking at many paths at time $t$ are the same. This ODE is also called the Probability Flow ODE." + ] + }, + { + "cell_type": "markdown", + "id": "4a2e996d-355a-4fab-8347-728e563c6014", + "metadata": {}, + "source": [ + "CMs now leverage the fact that there is no randomness in the ODE formulation. That means that if you start at a certain point in the latent space, you will always take the same path and always end up at the same point in the data space. The same is true for every point on the path: if you integrate to get to time $t=0$, you will end up at the same point as well. In short: for each path, there is exactly one corresponding point in latent space (at $t=T$) and one corresponding point in data space (at $t=0$). The goal of CMs is now the following: each point at a time $t$ belongs to exactly one path, and we want to predict where this path will end up at $t=0$. The function that does this is called the _consistency function_ $f$. 
If we have the correct function for all $t\\in(0,T]$, we can just sample from the latent distribution ($t=T$) and use $f$ to directly map to the corresponding point at $t=0$, which is in the target distribution. So for sampling from the target distribution, we avoid any integration and only need one evaluation of the consistency function. In practice, one-step sampling does not work very well. Instead, we leverage a multi-step sampling method where we call $f$ multiple times. Please check out [1] for more background on this sampling procedure." + ] + }, + { + "cell_type": "markdown", + "id": "25023294-3096-4ebc-83a6-a372208e0504", + "metadata": {}, + "source": [ + "Reading only the above, you might wonder why we also learn the mapping to $t=0$ for all intermediate time steps, and not only for $t=T$. The main answer is that for efficient training, we do not want to actually compute the two associated points explicitly. Doing so would require precise integration at training time, which is often not feasible as it is too computationally costly. Learning all time steps opens up the possibility for a different training approach where we can avoid this." + ] + }, + { + "cell_type": "markdown", + "id": "9e6eb121-f89f-4268-9a0d-6fa733c758ff", + "metadata": {}, + "source": [ + "The details of this become a bit more complicated, and we advise you to take a look at [1] if you are interested in a more thorough and mathematical discussion. Here we will give a rough description of the underlying concepts." + ] + }, + { + "cell_type": "markdown", + "id": "9b201899-b946-4432-bcac-106b9d580d32", + "metadata": {}, + "source": [ + "First, we know that at $t=0$, it holds that $f(x,t=0)=x$, as $x$ is part of the path that ends at $x$. 
This _boundary condition_ serves as an \"anchor\" for our training: it is the information that the network knows at the start of the training procedure (we encode it with a time-dependent skip-connection, so the network is forced to be the identity function at $t=0$)." + ] + }, + { + "cell_type": "markdown", + "id": "a11f7ac7-9c11-49c1-b29c-3f0d44c4bf49", + "metadata": {}, + "source": [ + "For training, we now somehow have to propagate this information to the rest of the path. The basic idea for this is simple. We just take a point $x_1$ closer to the data distribution (smaller time $t_1$) and integrate for a small time step $dt$ to a point $x_2$ on the same path that is closer to the latent distribution (larger time $t_2=t_1+dt$). As we know that for $t=0$ our network provides the correct output for our path, we want to propagate the information from smaller times to larger times. Our training goal is to move the output of $f(x_2, t=t_2)$ towards the output of $f(x_1, t=t_1)$. How to choose $x_1$, $t_1$ and $dt$ is an empirical question; see [1] for some reasoning on what works." + ] + }, + { + "cell_type": "markdown", + "id": "73ff5ed5-dbd0-423b-b4f4-8d209947ed87", + "metadata": {}, + "source": [ + "In the case of _distillation_, we start with a trained score-based diffusion model. We can use it to integrate the Probability Flow ODE to get from $x_1$ to $x_2$. If we do not have such a model, it seems as if we were stuck. We do not know which points lie on the same path, so we do not know which outputs to make similar. Fortunately, it turns out that there is an _unbiased approximator_ that, if averaged over many samples (check out [1] for the exact description), will also give us the correct score. If we use this approximator instead of the score model, and use only a single Euler step to move along the path, we get an algorithm similar to the one described for distillation. 
It is called Consistency Training (CT) and allows us to train a consistency model using only samples from the data distribution. The algorithm for this was improved a lot in [3], and we have incorporated those improvements into our implementation." + ] + }, + { + "cell_type": "markdown", + "id": "283e8fea-9a36-4f8a-80f5-19e2fe98bd09", + "metadata": {}, + "source": [ + "Because we made several approximations to get here, choosing the hyperparameters unfortunately becomes quite unintuitive. We have to rely on empirical observations to see what works. This was done in [4], and we encourage you to use the values provided there as starting points. If you happen to find hyperparameters that work significantly better, please let us know. This will help others to find a good region in the hyperparameter space." + ] + }, + { + "cell_type": "markdown", + "id": "8da3535f-0354-40a4-991f-845c33ff75b8", + "metadata": {}, + "source": [ + "To make this work for simulation-based inference, we can make the whole process conditional, so we can produce conditional distributions as well. Below, you can see a conceptual visualization of posterior estimation with CMs." + ] + }, + { + "cell_type": "markdown", + "id": "6c105be4-c490-4100-9502-60dd7405cfb4", + "metadata": {}, + "source": [ + "![Visualization of the way consistency models map from the path to the end point in the data distribution. Depicts the concepts described in the main text.](https://arxiv.org/html/2312.05440v2/extracted/5435837/figures/cmpe_main.png)" + ] + }, + { + "cell_type": "markdown", + "id": "baafb8fd-5b14-4ddf-a6dd-8272f706deaa", + "metadata": {}, + "source": [ + "### References\n", + "\n", + "[1] Song, Y., Dhariwal, P., Chen, M., & Sutskever, I. (2023). Consistency Models. *arXiv preprint*. [https://doi.org/10.48550/arXiv.2303.01469](https://doi.org/10.48550/arXiv.2303.01469)\n", + "\n", + "[2] Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., & Poole, B. (2021). 
Score-Based Generative Modeling through Stochastic Differential Equations. In _International Conference on Learning Representations_. [https://openreview.net/forum?id=PxTIG12RRHS](https://openreview.net/forum?id=PxTIG12RRHS)\n", + "\n", + "[3] Song, Y., & Dhariwal, P. (2023). Improved Techniques for Training Consistency Models. *arXiv preprint*. [https://doi.org/10.48550/arXiv.2310.14189](https://doi.org/10.48550/arXiv.2310.14189)\n", + "\n", + "[4] Schmitt, M., Pratz, V., Köthe, U., Bürkner, P.-C., & Radev, S. T. (2024). Consistency Models for Scalable and Fast Simulation-Based Inference. *arXiv preprint*. [https://doi.org/10.48550/arXiv.2312.05440](https://doi.org/10.48550/arXiv.2312.05440)" + ] + }, + { + "cell_type": "markdown", + "id": "c63b26ba", + "metadata": {}, + "source": [ + "## Simulator: Two Moons" + ] + }, + { + "cell_type": "markdown", + "id": "9525ffd7", + "metadata": {}, + "source": [ + "We will use the Consistency Model as a plug-in replacement for Flow Matching. Refer to the tutorial \"Two moons toy example with flow matching\" for more details on this simulator."
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "4b89c861527c13b8", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:46.747091Z", + "start_time": "2024-09-23T14:39:46.744830Z" + } + }, + "outputs": [], + "source": [ + "simulator = bf.benchmarks.simulators.TwoMoons()" + ] + }, + { + "cell_type": "markdown", + "id": "f6e1eb5777c59eba", + "metadata": {}, + "source": [ + "We generate some data to see what the simulator does:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "e6218e61d529e357", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:46.798575Z", + "start_time": "2024-09-23T14:39:46.790581Z" + } + }, + "outputs": [], + "source": [ + "# generate 64 random draws from the joint distribution p(r, alpha, theta, x)\n", + "sample_data = simulator.sample((64,))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "46174ccb0167026c", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:46.854911Z", + "start_time": "2024-09-23T14:39:46.852129Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Type of sample_data:\n", + "\t \n", + "Keys of sample_data:\n", + "\t dict_keys(['parameters', 'observables'])\n", + "Types of sample_data values:\n", + "\t {'parameters': , 'observables': }\n", + "Shapes of sample_data values:\n", + "\t {'parameters': (64, 2), 'observables': (64, 2)}\n" + ] + } + ], + "source": [ + "print(\"Type of sample_data:\\n\\t\", type(sample_data))\n", + "print(\"Keys of sample_data:\\n\\t\", sample_data.keys())\n", + "print(\"Types of sample_data values:\\n\\t\", {k: type(v) for k, v in sample_data.items()})\n", + "print(\"Shapes of sample_data values:\\n\\t\", {k: v.shape for k, v in sample_data.items()})" + ] + }, + { + "cell_type": "markdown", + "id": "fee88fcfd7a373b0", + "metadata": {}, + "source": [ + "## Data Adapter\n", + "\n", + "The next step is to tell BayesFlow how to deal with the simulated variables. 
You may also think of this as informing BayesFlow about the data flow, i.e., which variables go into which network.\n", + "\n", + "For this example, we want to learn the posterior distribution $p(\\theta | x)$, so we **infer** $\\theta$, **conditioning** on $x$. In the output from the last command, we see that the simulator provides $\\theta$ as `\"parameters\"` and $x$ as `\"observables\"`." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "c9637c576d4ad4e5", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:46.905081Z", + "start_time": "2024-09-23T14:39:46.903091Z" + } + }, + "outputs": [], + "source": [ + "data_adapter = bf.ContinuousApproximator.build_data_adapter(\n", + " inference_variables=[\"parameters\"],\n", + " inference_conditions=[\"observables\"],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "254e287b2bccdad", + "metadata": {}, + "source": [ + "## Dataset\n", + "For this example, we will sample our training data ahead of time and use offline training with a `bf.datasets.OfflineDataset`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "39cb5a1c9824246f", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:46.950573Z", + "start_time": "2024-09-23T14:39:46.948624Z" + } + }, + "outputs": [], + "source": [ + "batch_size = 64\n", + "num_training_batches = 512\n", + "num_validation_batches = 128" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "9dee7252ef99affa", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:53.268860Z", + "start_time": "2024-09-23T14:39:46.994697Z" + } + }, + "outputs": [], + "source": [ + "training_samples = simulator.sample((num_training_batches * batch_size,))\n", + "validation_samples = simulator.sample((num_validation_batches * batch_size,))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "51045bbed88cb5c2", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:53.281170Z", + "start_time": "2024-09-23T14:39:53.275921Z" + } + }, + "outputs": [], + "source": [ + "training_dataset = bf.datasets.OfflineDataset(training_samples, batch_size=batch_size, data_adapter=data_adapter)\n", + "validation_dataset = bf.datasets.OfflineDataset(validation_samples, batch_size=batch_size, data_adapter=data_adapter)" + ] + }, + { + "cell_type": "markdown", + "id": "2d4c6eb0", + "metadata": {}, + "source": [ + "## Training a neural network to approximate all posteriors\n", + "\n", + "The next step is to set up the neural network that will approximate the posterior $p(\\theta|x)$.\n", + "\n", + "Consistency models use _scheduling functions_ to adjust some of the hyperparameters, for example the time discretization during training. 
Consequently, we have to specify the total number of training steps (_gradient updates_) before training starts.\n", + "For offline training with a given number of epochs, we can calculate it as below:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "516ac3c4-b66f-4cf0-a443-b00705a6ace5", + "metadata": {}, + "outputs": [], + "source": [ + "epochs = 30\n", + "total_steps = epochs * num_training_batches" + ] + }, + { + "cell_type": "markdown", + "id": "2c7980ec-5623-43e3-847e-16a5b6eb8777", + "metadata": {}, + "source": [ + "Apart from the usual hyperparameters like learning rate and batch size, CMs come with a number of additional hyperparameters. Unfortunately, they can heavily interact, so they can be hard to tune. The main hyperparameters are:\n", + "\n", + "- Maximum time `max_time`: This also serves as the standard deviation of the latent distribution. You can experiment with this; values from 10 to 200 seem to work well. In any case, it should be larger than the standard deviation of the target distribution.\n", + "- Minimum/maximum number of discretization steps during training `s0`/`s1`: The effect of these is hard to grasp. A value of 10 works well for `s0`. Intuitively, increasing `s1` along with the number of epochs should lead to better results, but in practice we sometimes observe a breakdown for high values of `s1`. This seems to be problem-dependent, so just try it out.\n", + "- `sigma2` modifies the time-dependency of the skip connection. Its effect on the training is unclear; we recommend leaving it at 1.0 or setting it to the approximate variance of the target distribution.\n", + "- Smallest time value `eps` ($t=\\epsilon$ is used instead of $t=0$ for numerical reasons): No large effect, as long as it is kept small enough. Probably not worth tuning."
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "09206e6f", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:53.339590Z", + "start_time": "2024-09-23T14:39:53.319852Z" + } + }, + "outputs": [], + "source": [ + "inference_network = bf.networks.ConsistencyModel(\n", + " subnet=\"mlp\",\n", + " subnet_kwargs=dict(\n", + " depth=6,\n", + " width=256,\n", + " ),\n", + " total_steps=total_steps,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "851e522f", + "metadata": {}, + "source": [ + "This inference network is just a general CM architecture, not yet adapted to the specific inference task at hand (i.e., posterior approximation). To achieve this adaptation, we combine the network with our data adapter, which together form an `approximator`. In this case, we need a `ContinuousApproximator` since the target we want to approximate is the posterior of the *continuous* parameter vector $\\theta$." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "96ca6ffa", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:53.371691Z", + "start_time": "2024-09-23T14:39:53.369375Z" + } + }, + "outputs": [], + "source": [ + "approximator = bf.ContinuousApproximator(\n", + " inference_network=inference_network,\n", + " data_adapter=data_adapter,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "566264eadc76c2c", + "metadata": {}, + "source": [ + "### Optimizer and Learning Rate\n", + "\n", + "We use an Adam optimizer with [cosine decay](https://keras.io/api/optimizers/learning_rate_schedules/cosine_decay/) to decrease the learning rate towards zero over the training time. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "e8d7e053", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:53.433012Z", + "start_time": "2024-09-23T14:39:53.415903Z" + } + }, + "outputs": [], + "source": [ + "initial_learning_rate = 5e-4\n", + "scheduled_lr = keras.optimizers.schedules.CosineDecay(\n", + " initial_learning_rate,\n", + " total_steps,\n", + ")\n", + "\n", + "optimizer = keras.optimizers.Adam(learning_rate=scheduled_lr)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "51808fcd560489ac", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:53.476089Z", + "start_time": "2024-09-23T14:39:53.466001Z" + } + }, + "outputs": [], + "source": [ + "approximator.compile(optimizer=optimizer)" + ] + }, + { + "cell_type": "markdown", + "id": "708b1303", + "metadata": {}, + "source": [ + "### Training\n", + "\n", + "We are ready to train our deep posterior approximator on the two moons example. We pass the dataset object to the `fit` method and watch as Bayesflow trains." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "0f496bda", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:42:36.067393Z", + "start_time": "2024-09-23T14:39:53.513436Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:bayesflow:Fitting on dataset instance of OfflineDataset.\n", + "INFO:bayesflow:Building on a test batch.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 8ms/step - loss: 0.3707 - loss/inference_loss: 0.3707 - val_loss: 0.3482 - val_loss/inference_loss: 0.3482\n", + "Epoch 2/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.3199 - loss/inference_loss: 0.3199 - val_loss: 0.3060 - val_loss/inference_loss: 0.3060\n", + "Epoch 3/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.3031 - loss/inference_loss: 0.3031 - val_loss: 0.2138 - val_loss/inference_loss: 0.2138\n", + "Epoch 4/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.2939 - loss/inference_loss: 0.2939 - val_loss: 0.2869 - val_loss/inference_loss: 0.2869\n", + "Epoch 5/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.2921 - loss/inference_loss: 0.2921 - val_loss: 0.3400 - val_loss/inference_loss: 0.3400\n", + "Epoch 6/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.2898 - loss/inference_loss: 0.2898 - val_loss: 0.2698 - val_loss/inference_loss: 0.2698\n", + "Epoch 7/30\n", + "\u001b[1m512/512\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.2862 - loss/inference_loss: 0.2862 - val_loss: 0.2233 - val_loss/inference_loss: 0.2233\n", + "Epoch 8/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.2798 - loss/inference_loss: 0.2798 - val_loss: 0.2343 - val_loss/inference_loss: 0.2343\n", + "Epoch 9/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.2768 - loss/inference_loss: 0.2768 - val_loss: 0.2516 - val_loss/inference_loss: 0.2516\n", + "Epoch 10/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.2802 - loss/inference_loss: 0.2802 - val_loss: 0.2083 - val_loss/inference_loss: 0.2083\n", + "Epoch 11/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.2744 - loss/inference_loss: 0.2744 - val_loss: 0.2741 - val_loss/inference_loss: 0.2741\n", + "Epoch 12/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.2732 - loss/inference_loss: 0.2732 - val_loss: 0.2556 - val_loss/inference_loss: 0.2556\n", + "Epoch 13/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.2697 - loss/inference_loss: 0.2697 - val_loss: 0.3025 - val_loss/inference_loss: 0.3025\n", + "Epoch 14/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.2691 - loss/inference_loss: 0.2691 - val_loss: 0.2102 - val_loss/inference_loss: 0.2102\n", + "Epoch 15/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step 
- loss: 0.2672 - loss/inference_loss: 0.2672 - val_loss: 0.2959 - val_loss/inference_loss: 0.2959\n", + "Epoch 16/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.2624 - loss/inference_loss: 0.2624 - val_loss: 0.2902 - val_loss/inference_loss: 0.2902\n", + "Epoch 17/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.2592 - loss/inference_loss: 0.2592 - val_loss: 0.2846 - val_loss/inference_loss: 0.2846\n", + "Epoch 18/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.2604 - loss/inference_loss: 0.2604 - val_loss: 0.3606 - val_loss/inference_loss: 0.3606\n", + "Epoch 19/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.2564 - loss/inference_loss: 0.2564 - val_loss: 0.2005 - val_loss/inference_loss: 0.2005\n", + "Epoch 20/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.2608 - loss/inference_loss: 0.2608 - val_loss: 0.2988 - val_loss/inference_loss: 0.2988\n", + "Epoch 21/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.2532 - loss/inference_loss: 0.2532 - val_loss: 0.2182 - val_loss/inference_loss: 0.2182\n", + "Epoch 22/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.2553 - loss/inference_loss: 0.2553 - val_loss: 0.3509 - val_loss/inference_loss: 0.3509\n", + "Epoch 23/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.2544 - loss/inference_loss: 0.2544 - val_loss: 0.2090 - 
val_loss/inference_loss: 0.2090\n", + "Epoch 24/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.2503 - loss/inference_loss: 0.2503 - val_loss: 0.2571 - val_loss/inference_loss: 0.2571\n", + "Epoch 25/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.2532 - loss/inference_loss: 0.2532 - val_loss: 0.3162 - val_loss/inference_loss: 0.3162\n", + "Epoch 26/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.2545 - loss/inference_loss: 0.2545 - val_loss: 0.1789 - val_loss/inference_loss: 0.1789\n", + "Epoch 27/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.2511 - loss/inference_loss: 0.2511 - val_loss: 0.2579 - val_loss/inference_loss: 0.2579\n", + "Epoch 28/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.2481 - loss/inference_loss: 0.2481 - val_loss: 0.2809 - val_loss/inference_loss: 0.2809\n", + "Epoch 29/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.2512 - loss/inference_loss: 0.2512 - val_loss: 0.2382 - val_loss/inference_loss: 0.2382\n", + "Epoch 30/30\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.2498 - loss/inference_loss: 0.2498 - val_loss: 0.2691 - val_loss/inference_loss: 0.2691\n" + ] + } + ], + "source": [ + "history = approximator.fit(\n", + " epochs=epochs,\n", + " dataset=training_dataset,\n", + " validation_data=validation_dataset,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "b90a6062", + "metadata": {}, + "source": [ + "## Validation" + ] + }, + { 
+ "cell_type": "markdown", + "id": "ca62b21d", + "metadata": {}, + "source": [ + "### Two Moons Posterior\n", + "\n", + "The two moons posterior at point $x = (0, 0)$ should resemble two crescent shapes. Below, we plot the corresponding posterior samples. \n", + "\n", + "These results suggest that our consistency model setup can approximate the expected analytical posterior well. You can achieve an even better fit if you use online training, more epochs, or better hyperparameters." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "8562caeb", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:42:38.584554Z", + "start_time": "2024-09-23T14:42:36.076923Z" + } + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Set the number of posterior draws you want to get\n", + "num_samples = 5000\n", + "\n", + "# Obtain samples from amortized posterior\n", + "conditions = {\"observables\": np.array([[0.0, 0.0]]).astype(\"float32\")}\n", + "samples_at_origin = approximator.sample(conditions=conditions, num_samples=num_samples)[\"parameters\"]\n", + "\n", + "# Prepare figure\n", + "f, axes = plt.subplots(1, figsize=(6, 6))\n", + "\n", + "# Plot samples\n", + "axes.scatter(samples_at_origin[0, :, 0], samples_at_origin[0, :, 1], color=\"#153c7a\", alpha=0.75, s=0.5)\n", + "sns.despine(ax=axes)\n", + "axes.set_title(r\"Posterior samples at origin $x=(0, 0)$\")\n", + "axes.grid(alpha=0.3)\n", + "axes.set_aspect(\"equal\", adjustable=\"box\")\n", + "axes.set_xlim([-0.5, 0.5])\n", + "_ = axes.set_ylim([-0.5, 0.5])" + ] + }, + { + "cell_type": "markdown", + "id": "01821d24", + "metadata": {}, + "source": [ + "The posterior looks as expected in this case. However, in general, we do not know what the posterior is supposed to look like for any specific dataset. As such, we need diagnostics that validate the correctness of the inferred posterior. One such diagnostic is simulation-based calibration (SBC), which we can apply for free due to amortization. For more details on SBC and diagnostic plots, see:\n", + "\n", + "1. Talts, S., Betancourt, M., Simpson, D., Vehtari, A., & Gelman, A. (2018). Validating Bayesian inference algorithms with simulation-based calibration. *arXiv preprint*.\n", + "2. Säilynoja, T., Bürkner, P. C., & Vehtari, A. (2022). Graphical test for discrete uniformity and its applications in goodness-of-fit evaluation and multiple sample comparison. *Statistics and Computing*."
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.10" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": true, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": true, + "toc_position": { + "height": "calc(100% - 180px)", + "left": "10px", + "top": "150px", + "width": "165px" + }, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}