INN.BatchNorm1d
Zhang Yanbo edited this page Oct 26, 2022
Implements batch normalization as PyTorch does. In the forward pass, INN.BatchNorm1d does the same thing as nn.BatchNorm1d(*, affine=False).
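For reference, here is a minimal pure-Python sketch of what batch normalization without affine parameters computes in training mode. Each feature (column) is shifted by its batch mean and divided by the square root of its biased batch variance plus a small eps. The value eps=1e-5 is PyTorch's default for nn.BatchNorm1d. The function name batch_norm_1d is hypothetical and not part of INN.

```python
import math


def batch_norm_1d(x, eps=1e-5):
    """Normalize each column of a batch (a list of rows) to zero mean, unit variance.

    Sketch of the training-mode forward of nn.BatchNorm1d(affine=False):
    batch statistics, biased variance (divide by N), no learned scale or shift.
    """
    n = len(x)
    num_features = len(x[0])
    means = [sum(row[j] for row in x) / n for j in range(num_features)]
    # Biased variance (divide by n), which batch norm uses for normalization.
    variances = [
        sum((row[j] - means[j]) ** 2 for row in x) / n for j in range(num_features)
    ]
    y = [
        [(row[j] - means[j]) / math.sqrt(variances[j] + eps) for j in range(num_features)]
        for row in x
    ]
    return y, means, variances


x = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
y, means, variances = batch_norm_1d(x)
print(means)      # [4.0, 5.0, 6.0]
print(variances)  # [6.0, 6.0, 6.0]
```

If INN.BatchNorm1d matches nn.BatchNorm1d(affine=False) as stated, a freshly constructed layer (which starts in training mode) should produce values close to this y for the same input.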
- dim: dimension of the input feature.
- requires_grad: var will have a gradient if requires_grad=True.
Computes the batch-normalized result y. If compute_p=True, it returns y, logp and log_detJ.
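In the usual normalizing-flow treatment, batch normalization with fixed statistics is an elementwise affine map y_j = (x_j - mean_j) / sqrt(var_j + eps). Its Jacobian is therefore diagonal, and log|det J| = -0.5 * sum_j log(var_j + eps), the same for every sample. The sketch below computes that quantity. It is an assumption about how log_detJ is obtained, not a transcription of INN's code, and it does not model how logp is handled. The function name batch_norm_log_det is hypothetical.

```python
import math


def batch_norm_log_det(variances, eps=1e-5):
    """Per-sample log |det J| of y = (x - mean) / sqrt(var + eps).

    Assumes the statistics are treated as constants, so the Jacobian is
    diagonal with entries 1 / sqrt(var_j + eps); the log-determinant is
    the sum of the logs of those entries.
    """
    return -0.5 * sum(math.log(v + eps) for v in variances)


# For the batch [[1, 2, 3], [4, 5, 6], [7, 8, 9]] used in the example below,
# every column has biased variance 6.
log_det = batch_norm_log_det([6.0, 6.0, 6.0])
print(round(log_det, 4))
```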
import INN
import torch

model = INN.BatchNorm1d(3)
x = torch.Tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
# Forward pass: returns y, logp and log_detJ
y, logp, logdet = model(x)
Computes the inverse of y. **args is only a placeholder for consistency.
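Accepting **args even when a layer ignores them lets a container call inverse on every layer with the same arguments. The classes below are a hypothetical illustration of that pattern: Scale, Shift and SequentialInverse are not INN classes, and num_iter is a made-up keyword standing in for an argument that some other layer might need.

```python
class Scale:
    """Hypothetical invertible layer: multiply by a constant."""

    def __init__(self, factor):
        self.factor = factor

    def forward(self, x):
        return x * self.factor

    def inverse(self, y, **args):
        # **args is accepted but ignored, so callers can pass the same
        # keyword arguments to every layer without special cases.
        return y / self.factor


class Shift:
    """Hypothetical invertible layer: add a constant."""

    def __init__(self, offset):
        self.offset = offset

    def forward(self, x):
        return x + self.offset

    def inverse(self, y, **args):
        return y - self.offset


class SequentialInverse:
    """Hypothetical container: forward in order, inverse in reverse order."""

    def __init__(self, *layers):
        self.layers = layers

    def forward(self, x):
        for layer in self.layers:
            x = layer.forward(x)
        return x

    def inverse(self, y, **args):
        # Every layer receives the same **args, whether it uses them or not.
        for layer in reversed(self.layers):
            y = layer.inverse(y, **args)
        return y


model = SequentialInverse(Scale(2.0), Shift(3.0))
y = model.forward(5.0)             # 5 * 2 + 3
x = model.inverse(y, num_iter=10)  # num_iter is ignored by both layers
print(y, x)  # 13.0 5.0
```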
The inverse does not work in training mode, so model.eval() must be called before using inverse:
import INN
import torch

model = INN.BatchNorm1d(3)
model.eval()  # inverse requires evaluation mode
# Use arbitrary positive running statistics for this demo
model.running_var = torch.abs(torch.randn(3))
model.running_mean = torch.abs(torch.randn(3))
x = torch.Tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
y, logp, logdet = model(x)
print(y)
x_hat = model.inverse(y)
print(x_hat)  # recovers x
'''Results
# y = model(x)
tensor([[-1.3197,  1.3629,  0.8760],
        [ 2.3425,  4.4854,  3.7336],
        [ 6.0047,  7.6079,  6.5912]])
# x_hat = model.inverse(y)
tensor([[1.0000, 2.0000, 3.0000],
        [4.0000, 5.0000, 6.0000],
        [7.0000, 8.0000, 9.0000]])
'''
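In evaluation mode the layer normalizes with the fixed running statistics, so the forward map is a per-feature affine function that can be undone exactly: x = y * sqrt(running_var + eps) + running_mean. In training mode the statistics come from the batch being normalized, and the output always has roughly zero mean and unit variance, so the original mean and variance cannot be recovered from y alone. This is consistent with the requirement to call model.eval() before inverse. The sketch below assumes eps=1e-5 (PyTorch's default); the function names and the running statistics are made up for illustration and are not INN's.

```python
import math


def bn_eval_forward(x, running_mean, running_var, eps=1e-5):
    """Eval-mode batch norm (affine=False): normalize with stored statistics."""
    return [
        [(v - m) / math.sqrt(s + eps) for v, m, s in zip(row, running_mean, running_var)]
        for row in x
    ]


def bn_eval_inverse(y, running_mean, running_var, eps=1e-5):
    """Undo bn_eval_forward by rescaling and shifting each feature back."""
    return [
        [v * math.sqrt(s + eps) + m for v, m, s in zip(row, running_mean, running_var)]
        for row in y
    ]


# Illustrative running statistics (not the random ones from the example above).
running_mean = [0.5, 1.2, 0.3]
running_var = [0.8, 0.4, 1.5]
x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]

y = bn_eval_forward(x, running_mean, running_var)
x_hat = bn_eval_inverse(y, running_mean, running_var)
# The round trip recovers x up to floating-point error.
print(all(abs(a - b) < 1e-9 for ra, rb in zip(x, x_hat) for a, b in zip(ra, rb)))  # True
```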