[ADD] K-FAC for BatchNormNd (f-dangel#259)
Extension to support BatchNormNd (eval) K-FAC

Resolves f-dangel/issues/259

Auxiliary:

- The kfac quantity contains only one element, which represents the GGN
  approximation.
- It only supports evaluation mode (see the usage sketch below).
- A test script (test_kfac_bn.py) checks these two properties.
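
A minimal usage sketch, mirroring the test script in this commit (the random
inputs stand in for the MNIST batch used there):

import torch
from torch.nn import BatchNorm1d, CrossEntropyLoss, Flatten, Linear, Sequential

from backpack import backpack, extend
from backpack.extensions import KFAC

# BatchNorm layers are only supported in evaluation mode.
model = extend(
    Sequential(Flatten(), Linear(784, 3), BatchNorm1d(3), Linear(3, 10)).eval()
)
lossfunc = extend(CrossEntropyLoss())

X, y = torch.rand(32, 1, 28, 28), torch.randint(0, 10, (32,))
loss = lossfunc(model(X), y)

with backpack(KFAC(mc_samples=100)):
    loss.backward()

# For BatchNorm parameters, param.kfac holds a single factor: the MC-sampled
# GGN approximation itself (no Kronecker factorization).
for name, param in model.named_parameters():
    print(name, [factor.shape for factor in param.kfac])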

Signed-off-by: pyun <[email protected]>
pyun-ram committed Sep 3, 2022
1 parent 0ab9421 commit 41a37a6
Showing 3 changed files with 103 additions and 0 deletions.
7 changes: 7 additions & 0 deletions backpack/extensions/secondorder/hbp/__init__.py
@@ -1,6 +1,9 @@
from torch import Tensor
from torch.nn import (
    AvgPool2d,
    BatchNorm1d,
    BatchNorm2d,
    BatchNorm3d,
    Conv2d,
    CrossEntropyLoss,
    Dropout,
@@ -35,6 +38,7 @@
    losses,
    padding,
    pooling,
    batchnorm_nd,
)


@@ -71,6 +75,9 @@ def __init__(
                SumModule: custom_module.HBPSumModule(),
                ScaleModule: custom_module.HBPScaleModule(),
                Identity: custom_module.HBPScaleModule(),
                BatchNorm1d: batchnorm_nd.HBPBatchNormNd(),
                BatchNorm2d: batchnorm_nd.HBPBatchNormNd(),
                BatchNorm3d: batchnorm_nd.HBPBatchNormNd(),
            },
        )

35 changes: 35 additions & 0 deletions backpack/extensions/secondorder/hbp/batchnorm_nd.py
@@ -0,0 +1,35 @@
"""K-FAC (HBP) extension for BatchNorm1d/2d/3d layers (evaluation mode only)."""
from typing import Tuple, Union

from torch import Tensor, einsum
from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d

from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives
from backpack.extensions.backprop_extension import BackpropExtension
from backpack.extensions.secondorder.hbp.hbpbase import HBPBaseModule
from backpack.utils.errors import batch_norm_raise_error_if_train


class HBPBatchNormNd(HBPBaseModule):
    """K-FAC quantities for BatchNormNd: a single GGN-like factor per parameter."""

    def __init__(self):
        super().__init__(BatchNormNdDerivatives(), params=["weight", "bias"])

    def weight(self, ext, module, grad_inp, grad_out, backproped):
        x_hat, _ = self.derivatives._get_normalized_input_and_var(module)
        v = backproped
        # Transposed Jacobian w.r.t. the weight: scale by x_hat, sum spatial dims.
        JTv = einsum("mnc...,nc...->mnc", v, x_hat)
        # Contract MC and batch dimensions into a [C x C] block.
        kfac_gamma = einsum("mnc...,mnd...->cd", JTv, JTv)
        return [kfac_gamma]

    def bias(self, ext, module, grad_inp, grad_out, backproped):
        v = backproped
        JTv = v
        # Contract MC, batch, and spatial dimensions into a [C x C] block.
        kfac_beta = einsum("mnc...,mnd...->cd", JTv, JTv)
        return [kfac_beta]

    def check_hyperparameters_module_extension(
        self,
        ext: BackpropExtension,
        module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d],
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
    ) -> None:  # noqa: D102
        batch_norm_raise_error_if_train(module)
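
In index notation (m: MC/class index of the backpropagated square-root factors v,
n: batch index, c and d: channels, l: spatial locations), the einsum contractions
above compute

$[J^\top v]_{m,n,c} = \sum_{l} v_{m,n,c,l}\, \hat{x}_{n,c,l}$,
$\mathrm{kfac}_\gamma[c,d] = \sum_{m,n} [J^\top v]_{m,n,c}\, [J^\top v]_{m,n,d}$,
$\mathrm{kfac}_\beta[c,d] = \sum_{m,n,l} v_{m,n,c,l}\, v_{m,n,d,l}$,

i.e. a single dense [C x C] block per parameter rather than a Kronecker product.
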
61 changes: 61 additions & 0 deletions test/test_kfac_bn.py
@@ -0,0 +1,61 @@
import torch
from matplotlib import pyplot as plt
from torch.nn import BatchNorm1d, CrossEntropyLoss, Flatten, Linear, Sequential

from backpack import backpack, extend
from backpack.extensions import KFAC, SqrtGGNExact
from backpack.utils.examples import load_one_batch_mnist


def visualize_hessian(H, param_names, param_length, fig_path, vmin=None, vmax=None):
    """Plot a Hessian-like matrix and mark the per-parameter blocks.

    Args:
        H (torch.Tensor): Hessian matrix ([M, M])
        param_names (List[str]): list of param names
        param_length (List[int]): list of param lengths
        fig_path (str): path to save the figure
        vmin, vmax (float, optional): color limits passed to imshow

    Returns:
        H_min (torch.Tensor): min of H
        H_max (torch.Tensor): max of H
    """
    plt.figure(figsize=(10, 10))
    plt.imshow(H.cpu().numpy(), vmin=vmin, vmax=vmax, origin="upper")
    acc = -0.5
    all_ = H.shape[0]
    for _, length in zip(param_names, param_length):
        # Draw boundaries between consecutive parameter blocks.
        plt.plot([-0.5, all_], [acc, acc], "b-", linewidth=2)
        plt.plot([acc, acc], [-0.5, all_], "b-", linewidth=2)
        acc += length
    plt.xlim([-0.5, all_ - 0.5])
    plt.ylim([all_ - 0.5, -0.5])
    plt.colorbar()
    plt.savefig(fig_path, bbox_inches="tight")
    return H.min(), H.max()

X, y = load_one_batch_mnist(batch_size=512)
model = Sequential(Flatten(), Linear(784, 3), BatchNorm1d(3), Linear(3, 10))
lossfunc = CrossEntropyLoss()
model = extend(model.eval())
lossfunc = extend(lossfunc)

loss = lossfunc(model(X), y)
with backpack(KFAC(mc_samples=1000), SqrtGGNExact()):
    loss.backward()

for name, param in model.named_parameters():
    GGN_VT = param.sqrt_ggn_exact.reshape(-1, param.numel())
    GGN = GGN_VT.t() @ GGN_VT
    KFAC_ = (
        torch.kron(param.kfac[0], param.kfac[1])
        if len(param.kfac) == 2
        else param.kfac[0]
    )
    visualize_hessian(GGN, [name], [param.numel()], f"./{name}_GGN.png")
    visualize_hessian(KFAC_, [name], [param.numel()], f"./{name}_KFAC.png")
    print(name, torch.norm(GGN - KFAC_, 2).item())

# Check handling of the train-mode case
model = extend(model.train())
loss = lossfunc(model(X), y)
try:
    with backpack(KFAC(mc_samples=1000), SqrtGGNExact()):
        loss.backward()
except NotImplementedError:
    print("PASS. NotImplementedError is raised when the model is in training mode.")
