
INN.Nonlinear

Zhang Yanbo edited this page Oct 26, 2022 · 3 revisions

CLASS Nonlinear(dim, method='NICE', **kwargs) [source]

A nonlinear INN layer for one-dimensional vector transformations.

Common parameters

  • dim: Dimension of the input and output vectors
  • method: Coupling method; one of 'NICE', 'RealNVP' or 'ResFlow'
  • activation_fn: Activation function for the coupling networks. If a coupling function is given explicitly, this argument is ignored

RealNVP and NICE method

  • k: Number of hidden layers in the coupling networks;
  • mask: Mask used to split the input vector into two parts

NICE

  • m: Addition function. It should be a neural network mapping a vector of dimension dim // 2 to one of dimension dim - dim // 2. If m=None, it is generated automatically by INN.utils.default_net(dim, k, activation_fn);
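
As a sketch of what the NICE coupling computes (following the standard additive coupling; the `m` below is a toy stand-in for the network described above, not the library's implementation):

```python
import numpy as np

def nice_forward(x, m):
    # Split the input: x1 keeps dim // 2 entries, x2 the rest.
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    # Additive coupling: x1 passes through unchanged, x2 is shifted
    # by m(x1). The Jacobian is triangular with unit diagonal,
    # so log|det J| = 0.
    return np.concatenate([x1, x2 + m(x1)], axis=-1)

def nice_inverse(y, m):
    d = y.shape[-1] // 2
    y1, y2 = y[..., :d], y[..., d:]
    # Invert by subtracting the same shift.
    return np.concatenate([y1, y2 - m(y1)], axis=-1)

# Toy stand-in "network": any map from dim // 2 to dim - dim // 2 works.
m = lambda h: np.sin(h)

x = np.array([[1.0, 2.0, 3.0, 4.0]])
y = nice_forward(x, m)
x_recon = nice_inverse(y, m)
```

Because the shift only depends on the untouched half, the inverse is exact and requires no extra computation.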

RealNVP

  • f_log_s: Multiplication (log-scale) function. It has the same dimension requirements as m;
  • f_t: Addition function. It has the same dimension requirements as m;
  • clip: (default: clip=1) Clips the output of f_log_s to [-clip, clip] to avoid extreme values. The clipping uses tanh so that gradients are preserved;
  • scale: Scale used to initialize the weights of the coupling networks. A large scale may produce NaN results because of the exponential;
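
A sketch of the RealNVP affine coupling with the tanh-based clipping described above (the `f_log_s` and `f_t` below are toy stand-ins, not the library's networks):

```python
import numpy as np

def realnvp_forward(x, f_log_s, f_t, clip=1.0):
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    # Squash log_s into [-clip, clip] with tanh so gradients survive.
    log_s = clip * np.tanh(f_log_s(x1))
    # Affine coupling: scale and shift the second half.
    y2 = x2 * np.exp(log_s) + f_t(x1)
    # log|det J| is just the sum of the log-scales.
    log_det = log_s.sum(axis=-1)
    return np.concatenate([x1, y2], axis=-1), log_det

def realnvp_inverse(y, f_log_s, f_t, clip=1.0):
    d = y.shape[-1] // 2
    y1, y2 = y[..., :d], y[..., d:]
    # y1 equals x1, so the same log_s and t can be recomputed.
    log_s = clip * np.tanh(f_log_s(y1))
    x2 = (y2 - f_t(y1)) * np.exp(-log_s)
    return np.concatenate([y1, x2], axis=-1)

f_log_s = lambda h: 0.5 * h   # toy stand-ins for the coupling nets
f_t = lambda h: h - 1.0

x = np.array([[1.0, 2.0, 3.0, 4.0]])
y, log_det = realnvp_forward(x, f_log_s, f_t)
x_recon = realnvp_inverse(y, f_log_s, f_t)
```

The exp of the (clipped) log-scale is where a large initialization scale can blow up to NaN, which is why the scale parameter above matters.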

ResFlow

  • hidden: Dimension of the hidden layers
  • n_hidden: Number of hidden layers
  • lipschitz_constrain: Lipschitz constraint; it must be lower than 1. A low value may reduce the expressive power of the neural network;
  • mem_efficient: Use memory-efficient back-propagation if True;
  • est_steps: Number of iterations used to estimate gradients and the Jacobian
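
The Lipschitz constraint is what makes a residual flow invertible: if the residual block g has Lipschitz constant below 1, the inverse can be found by fixed-point iteration. A minimal sketch (the `g` below is a toy contraction, not the library's constrained network):

```python
import numpy as np

def resflow_forward(x, g):
    # Residual flow step: y = x + g(x), with Lip(g) < 1.
    return x + g(x)

def resflow_inverse(y, g, n_iter=50):
    # Because Lip(g) < 1, x = y - g(x) is a contraction, so this
    # fixed-point iteration converges to the unique inverse.
    x = y.copy()
    for _ in range(n_iter):
        x = y - g(x)
    return x

# Toy residual block with Lipschitz constant 0.5 (< 1).
g = lambda h: 0.5 * np.tanh(h)

x = np.array([[1.0, -2.0, 0.5]])
y = resflow_forward(x, g)
x_recon = resflow_inverse(y, g)
```

This also shows the trade-off behind lipschitz_constrain: a smaller constant makes the inversion converge faster but leaves the block closer to the identity map.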

Methods

forward(x, log_p0=0, log_det_J_=0)

Compute the transformed output y. If compute_p=True, return (y, log_p, log_det) instead of y alone.

The log_det term is the log-determinant of the Jacobian matrix. It is essential for controlling the distribution of the output.
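
To see why the log-determinant matters, here is a small change-of-variables check in plain NumPy (not part of the INN API): for an invertible map f, log p_Y(y) = log p_X(x) - log|det J_f(x)|, illustrated with a linear map where the Jacobian is constant.

```python
import numpy as np

# Invertible linear map y = A @ x as the simplest possible flow.
A = np.array([[2.0, 0.0],
              [1.0, 1.0]])
x = np.array([0.3, -0.7])
y = A @ x

# For a linear map the Jacobian is A itself.
log_det = np.log(abs(np.linalg.det(A)))

# Standard-normal base density for the input.
log_p0 = lambda v: -0.5 * (v @ v) - v.size / 2 * np.log(2 * np.pi)

# Change of variables: the model's density at y.
log_p_y = log_p0(x) - log_det
```

Without the log_det correction the output density would not integrate to 1, which is why every layer must report it.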

inverse(y, **args)

Compute the inverse of y. The **args here is a placeholder that keeps the interface consistent across layers.

Examples

import torch
import INN

model = INN.Nonlinear(dim=4, method='RealNVP')
x = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8]])

# Forward pass
y, log_p, log_det = model(x)

# Inverse pass
x_recon = model.inverse(y)

The outputs are:

# y

tensor([[0.9963, 2.0015, 3.0014, 3.9964],
        [4.9938, 6.0042, 7.0035, 7.9923]], grad_fn=<AddBackward0>)

# x_recon

tensor([[1.0000, 2.0000, 3.0000, 4.0000],
        [5.0000, 6.0000, 7.0000, 8.0000]], grad_fn=<AddBackward0>)