Implementing Bayesian neural networks to close the amortization gap in VAEs in PyTorch

jordandeklerk/Amortized-Bayes

Closing the Amortization Gap in Bayesian Deep Generative Models

Amortized variational inference (A-VI) has emerged as a promising approach for improving the efficiency of Bayesian deep generative models. In this project, we investigate the effectiveness of A-VI in closing the amortization gap between A-VI and traditional variational inference methods such as factorized variational inference (F-VI), also known as mean-field variational inference. We conduct numerical experiments on benchmark imaging datasets to compare the performance of A-VI, with varying neural network architectures, against F-VI and constant-VI.

Our findings demonstrate that A-VI, when implemented with sufficiently deep neural networks, can achieve the same evidence lower bound (ELBO) and reconstruction mean squared error (MSE) as F-VI while being 2 to 3 times computationally faster. These results highlight the potential of A-VI in addressing the amortization interpolation problem and suggest that a deep encoder-decoder linear neural network with full Bayesian inference over the latent variables can effectively approximate an ideal inference function. This work paves the way for more efficient and scalable Bayesian deep generative models.

Overview

In the Bayesian paradigm, statistical inference about unknown variables rests on computations involving posterior probability densities. Because these densities typically lack an analytic form, they must be approximated. Classical methods for approximating the posterior, such as Markov chain Monte Carlo (MCMC), are computationally expensive at test time: they rely on repeated evaluations of the likelihood function and therefore require a new set of likelihood evaluations for each observation. In contrast, variational inference (VI) offers a compelling alternative by recasting the difficult task of estimating complex posterior densities as a more tractable optimization problem. The essence of VI lies in choosing a parameterized family of distributions, $\mathcal{Q}$, and identifying the member that minimizes the Kullback-Leibler (KL) divergence from the posterior

$$ \begin{equation} q^* = \arg \min _{q \in \mathcal{Q}} \mathrm{KL}(q(\theta, \mathbf{z}) \| p(\theta, \mathbf{z} \mid \mathbf{x})). \end{equation} $$
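In practice, this KL divergence cannot be minimized directly because it involves the intractable posterior. VI instead maximizes the evidence lower bound (ELBO), which differs from the negative KL divergence only by the constant log evidence:

$$\begin{equation} \mathrm{ELBO}(q) = \mathbb{E}_{q(\theta, \mathbf{z})}\left[\log p(\theta, \mathbf{z}, \mathbf{x}) - \log q(\theta, \mathbf{z})\right] = \log p(\mathbf{x}) - \mathrm{KL}(q(\theta, \mathbf{z}) \| p(\theta, \mathbf{z} \mid \mathbf{x})). \end{equation}$$

Maximizing the ELBO over $q \in \mathcal{Q}$ is therefore equivalent to minimizing the KL divergence to the posterior.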

The posterior is then approximated by $q^*$, so a central choice in VI is selecting an appropriate variational family $\mathcal{Q}$ to optimize over. Common practice is to adopt the factorized, or mean-field, family, which is characterized by independence across the variables

$$\begin{equation} \mathcal{Q}_{\mathrm{F}}=\left\{q: q(\theta, \mathbf{z})=q_0(\theta) \prod_{n=1}^N q_n\left(z_n\right)\right\}, \end{equation}$$

wherein each latent variable is represented by a distinct factor $q_n$.
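A minimal sketch of what this factorization implies in PyTorch (not code from this repository; the sizes and names are illustrative): each data point $n$ gets its own free variational parameters, so the parameter count grows linearly with the dataset size $N$.

```python
import torch

# Hypothetical mean-field (F-VI) setup: one Gaussian factor q_n per data point,
# each with its own freely optimized mean and log-variance.
N, d = 1000, 16  # number of data points, latent dimension (illustrative)

mu = torch.zeros(N, d, requires_grad=True)      # per-point variational means
logvar = torch.zeros(N, d, requires_grad=True)  # per-point variational log-variances

def sample_z(n):
    """Reparameterized sample z_n ~ q_n(z_n) = N(mu_n, diag(exp(logvar_n)))."""
    eps = torch.randn(d)
    return mu[n] + torch.exp(0.5 * logvar[n]) * eps

z0 = sample_z(0)  # a sample from the factor of the first data point
```

Because every $q_n$ is estimated separately, F-VI cannot share inference computation across observations, which is exactly the cost that amortization targets.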

In contrast to the factorized family, the amortized family uses a stochastic inference function, typically instantiated as a neural network, to determine the variational distribution of each latent variable $z_n$: the function maps each observation $x_n$ to the parameters of its approximating factor $q(z_n)$:

$$\begin{equation} \mathcal{Q}_{\mathrm{A}}=\left\{q: q(\theta, \mathbf{z})=q_0(\theta) \prod_{n=1}^N q\left(z_n ; f_\phi\left(x_n\right)\right)\right\}. \end{equation}$$

This paradigm, known as amortized variational inference (A-VI), optimizes the approximation of the posterior and the inference function simultaneously. Therefore, inference on a single observation can be performed efficiently through a single forward pass through the neural network, framing Bayesian inference as a prediction problem: for any observation, the neural network is trained to predict the posterior distribution, or a quantity that allows the network to infer the posterior without any further simulations.
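A single shared network $f_\phi$ can be sketched in PyTorch as follows (a hypothetical illustration, not the architecture in `src/model/model.py`; the layer sizes are assumptions):

```python
import torch
import torch.nn as nn

# Hypothetical amortized inference network f_phi: one shared MLP maps each
# observation x_n to the parameters (mu_n, logvar_n) of its factor
# q(z_n; f_phi(x_n)), so inference is a single forward pass.
class AmortizedEncoder(nn.Module):
    def __init__(self, x_dim=784, hidden=256, z_dim=16, depth=3):
        super().__init__()
        layers, in_dim = [], x_dim
        for _ in range(depth):
            layers += [nn.Linear(in_dim, hidden), nn.ReLU()]
            in_dim = hidden
        self.body = nn.Sequential(*layers)
        self.mu = nn.Linear(hidden, z_dim)
        self.logvar = nn.Linear(hidden, z_dim)

    def forward(self, x):
        h = self.body(x)
        return self.mu(h), self.logvar(h)

encoder = AmortizedEncoder()
x = torch.randn(8, 784)                                   # mini-batch of flattened images
mu, logvar = encoder(x)                                   # amortized inference in one pass
z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)   # reparameterization trick
```

Note that the number of variational parameters is now fixed by the network size $\phi$ rather than growing with $N$, which is what makes the shared computation across data points possible.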

Key Findings

  • A-VI, when implemented with sufficiently deep neural networks, can achieve the same evidence lower bound (ELBO) and reconstruction mean squared error (MSE) as F-VI while being 2 to 3 times computationally faster.
  • These results highlight the potential of A-VI in addressing the amortization interpolation problem and suggest that a deep encoder-decoder linear neural network with full Bayesian inference over the latent variables can effectively approximate an ideal inference function.

Getting Started

To get started with this project, clone the repository and install the required dependencies:

git clone https://github.com/jordandeklerk/Amortized-Bayes.git
cd Amortized-Bayes
pip install -r requirements.txt

Then run the main.py script:

python main.py

Project Structure

├── experiment.py
├── images
│   ├── fmnist_comp.png
│   ├── fmnist_elbo.png
│   ├── fmnist_mse.png
│   ├── fmnist_mse_test.png
│   ├── index.md
│   ├── mnist_comp.png
│   ├── mnist_elbo.png
│   ├── mnist_mse.png
│   ├── mnist_mse_test.png
│   ├── re1.png
│   ├── re2.png
│   ├── reparm.png
│   ├── reparm4.png
│   ├── vae.png
│   └── variational.png
├── main.py
├── src
│   ├── model
│   │   └── model.py
│   └── utils
│       ├── config.py
│       ├── optimizer.py
│       └── parser.py
└── train.py

Main Results

MNIST

Our results for the MNIST dataset, presented in Figure 3, examine the effects of different network widths and configurations. After 5,000 epochs, amortized variational inference (A-VI) achieved ELBO values comparable to factorized variational inference (F-VI) with sufficiently deep networks (k ≥ 64). We also evaluated the reconstruction mean squared error (MSE) on both the training and test sets and found that A-VI effectively closed the performance gap there as well.


Figure 3: Results for the MNIST dataset

Moreover, A-VI proved to be 2 to 3 times faster computationally than F-VI, as seen in Figure 4, underscoring its efficiency in leveraging shared inference computations across data, thus negating the need to estimate unique latent factors $q_n$ for each $z_n$.

Figure 4: Computational efficiency of A-VI on MNIST

FashionMNIST

Our results for the FashionMNIST experiments are presented in Figure 5 and support the same conclusions as the MNIST experiments.

Figure 5: Results for the FashionMNIST dataset

We also observe a similar computational speedup on the FashionMNIST dataset, as shown in Figure 6.

Figure 6: Computational efficiency of A-VI on FashionMNIST

In Figure 7, we present reconstructed images for a sample of five original images from the MNIST and FashionMNIST datasets. Note that these reconstructions, produced with a linear neural network, exhibit lower visual quality. This was not the primary focus of our project; using a convolutional neural network for both the encoder and the decoder would likely improve the visual quality of these images.

Figure 7: Reconstructed images for MNIST and FashionMNIST
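The convolutional upgrade mentioned above is not implemented in this repository, but a hypothetical encoder for 28x28 grayscale inputs (MNIST / FashionMNIST) could look like this; all layer sizes are assumptions:

```python
import torch
import torch.nn as nn

# Hypothetical convolutional replacement for the linear encoder, suggested as a
# way to improve reconstruction quality. Strided convolutions downsample
# 28x28 -> 14x14 -> 7x7 before projecting to the variational parameters.
class ConvEncoder(nn.Module):
    def __init__(self, z_dim=16):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),   # 28x28 -> 14x14
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),  # 14x14 -> 7x7
            nn.ReLU(),
            nn.Flatten(),                               # -> 64 * 7 * 7 = 3136
        )
        self.mu = nn.Linear(64 * 7 * 7, z_dim)
        self.logvar = nn.Linear(64 * 7 * 7, z_dim)

    def forward(self, x):
        h = self.features(x)
        return self.mu(h), self.logvar(h)

enc = ConvEncoder()
mu, logvar = enc(torch.randn(4, 1, 28, 28))  # batch of 4 single-channel images
```

A mirrored decoder using `nn.ConvTranspose2d` would complete the swap; the rest of the training loop would be unchanged.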
