
INN.BatchNorm1d


CLASS INN.BatchNorm1d(dim, requires_grad=True)

Implements batch normalization as in PyTorch. In the forward pass, INN.BatchNorm1d computes the same thing as nn.BatchNorm1d(*, affine=False); a sketch verifying this equivalence follows the parameter list.

  • dim: dimension of the input feature
  • requires_grad: if requires_grad=True, the variance will have a gradient
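
The equivalence to the plain PyTorch layer can be checked directly. Below is a minimal sketch of that check, assuming only a standard PyTorch installation; the tolerance in allclose is arbitrary, not anything specified by INN:

import INN
import torch
import torch.nn as nn

x = torch.Tensor([[1., 2., 3.],
                  [4., 5., 6.],
                  [7., 8., 9.]])

inn_bn = INN.BatchNorm1d(3)
torch_bn = nn.BatchNorm1d(3, affine=False)

y_inn, _, _ = inn_bn(x)   # INN version also returns logp and logdet
y_torch = torch_bn(x)     # plain PyTorch version returns only y

print(torch.allclose(y_inn, y_torch, atol=1e-6))  # expected: True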

Methods

forward(input, log_p0=0, log_det_J_=0)

Computes the batch-normalized result y. If compute_p=True, it returns y, logp, and log_detJ.

import INN
import torch

# a 1D batch-norm layer over 3 features
model = INN.BatchNorm1d(3)

x = torch.Tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

# forward returns the normalized y plus the log-probability and log-determinant terms
y, logp, logdet = model(x)
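
For reference, both the normalization and its log-determinant can be reproduced by hand. The sketch below assumes the usual normalizing-flow convention of treating the batch statistics as constants, biased variance (as nn.BatchNorm1d uses in its forward pass), and PyTorch's default eps of 1e-5; none of these details are stated on this page:

import torch

x = torch.Tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

mean = x.mean(dim=0)                # per-feature batch mean
var = x.var(dim=0, unbiased=False)  # biased variance, matching nn.BatchNorm1d
eps = 1e-5                          # assumed PyTorch default

y_manual = (x - mean) / torch.sqrt(var + eps)
# The Jacobian of this map is diagonal, so per sample:
# log|det J| = -0.5 * sum_d log(var_d + eps)
log_det_manual = -0.5 * torch.log(var + eps).sum()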

inverse(y, **args)

Computes the inverse of y. **args is only a placeholder for consistency with other layers.

The inverse does not work in training mode, so we need to call model.eval() before using inverse:

import INN
import torch

model = INN.BatchNorm1d(3)
model.eval()  # inverse only works in evaluation mode
# fill the running statistics with example values; a trained model would already have them
model.running_var = torch.abs(torch.randn(3))   # variance must be positive
model.running_mean = torch.abs(torch.randn(3))

x = torch.Tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

y, logp, logdet = model(x)
print(y)

x_hat = model.inverse(y)  # should recover x
print(x_hat)

'''Results
# y = model(x)
tensor([[-1.3197,  1.3629,  0.8760],
        [ 2.3425,  4.4854,  3.7336],
        [ 6.0047,  7.6079,  6.5912]])
# x_hat = model.inverse(y)
tensor([[1.0000, 2.0000, 3.0000],
        [4.0000, 5.0000, 6.0000],
        [7.0000, 8.0000, 9.0000]])
'''
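
In evaluation mode the inverse is simply the de-normalization with the running statistics. Continuing from the example above, a minimal sketch of that relationship, again assuming PyTorch's default eps of 1e-5:

# manual inverse: undo the normalization with the running statistics
eps = 1e-5  # assumed default
x_manual = y * torch.sqrt(model.running_var + eps) + model.running_mean
print(torch.allclose(x_manual, x_hat, atol=1e-4))  # expected: True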