diff --git a/backpack/extensions/secondorder/hbp/__init__.py b/backpack/extensions/secondorder/hbp/__init__.py
index 642ac172f..3d7a8b958 100644
--- a/backpack/extensions/secondorder/hbp/__init__.py
+++ b/backpack/extensions/secondorder/hbp/__init__.py
@@ -1,6 +1,9 @@
 from torch import Tensor
 from torch.nn import (
     AvgPool2d,
+    BatchNorm1d,
+    BatchNorm2d,
+    BatchNorm3d,
     Conv2d,
     CrossEntropyLoss,
     Dropout,
@@ -35,6 +38,7 @@
     losses,
     padding,
     pooling,
+    batchnorm_nd,
 )
 
 
@@ -71,6 +75,9 @@ def __init__(
             SumModule: custom_module.HBPSumModule(),
             ScaleModule: custom_module.HBPScaleModule(),
             Identity: custom_module.HBPScaleModule(),
+            BatchNorm1d: batchnorm_nd.HBPBatchNormNd(),
+            BatchNorm2d: batchnorm_nd.HBPBatchNormNd(),
+            BatchNorm3d: batchnorm_nd.HBPBatchNormNd(),
         },
     )
 
diff --git a/backpack/extensions/secondorder/hbp/batchnorm_nd.py b/backpack/extensions/secondorder/hbp/batchnorm_nd.py
new file mode 100644
index 000000000..f9e313100
--- /dev/null
+++ b/backpack/extensions/secondorder/hbp/batchnorm_nd.py
@@ -0,0 +1,41 @@
+from typing import Tuple, Union
+
+from torch import Tensor, einsum
+from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d
+
+from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives
+from backpack.extensions.backprop_extension import BackpropExtension
+from backpack.extensions.secondorder.hbp.hbpbase import HBPBaseModule
+from backpack.utils.errors import batch_norm_raise_error_if_train
+
+
+class HBPBatchNormNd(HBPBaseModule):
+    def __init__(self):
+        super().__init__(BatchNormNdDerivatives(), params=["weight", "bias"])
+
+    def weight(self, ext, module, grad_inp, grad_out, backproped):
+        # The Jacobian w.r.t. the scale gamma multiplies by the normalized
+        # input: contract the backpropagated vectors with x_hat over all
+        # spatial axes, then sum outer products over MC samples (m), batch (n).
+        x_hat, _ = self.derivatives._get_normalized_input_and_var(module)
+        JTv = einsum("mnc...,nc...->mnc", backproped, x_hat)
+        kfac_gamma = einsum("mnc,mnd->cd", JTv, JTv)
+        return [kfac_gamma]
+
+    def bias(self, ext, module, grad_inp, grad_out, backproped):
+        # The Jacobian w.r.t. the shift beta is the identity per channel: sum
+        # the backpropagated vectors over all spatial axes before taking the
+        # outer products (contracting a shared spatial index instead would be
+        # incorrect for BatchNorm2d/3d).
+        JTv = einsum("mnc...->mnc", backproped)
+        kfac_beta = einsum("mnc,mnd->cd", JTv, JTv)
+        return [kfac_beta]
+
+    def check_hyperparameters_module_extension(
+        self,
+        ext: BackpropExtension,
+        module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d],
+        g_inp: Tuple[Tensor],
+        g_out: Tuple[Tensor],
+    ) -> None:  # noqa: D102
+        batch_norm_raise_error_if_train(module)
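Note on the contractions in batchnorm_nd.py: since batch normalization computes y = gamma * x_hat + beta per channel, the transposed Jacobian applied to a backpropagated vector v is a contraction with x_hat over the spatial axes (for gamma) and a plain spatial sum (for beta); the [C, C] factor then sums outer products over MC samples and batch. Below is a minimal self-contained sanity check of the gamma contraction, with random tensors standing in for BackPACK's backpropagated quantities and illustrative sizes M, N, C, S:

    import torch

    M, N, C, S = 4, 3, 2, 5  # MC samples, batch, channels, spatial (illustrative)
    v = torch.randn(M, N, C, S)   # stand-in for the backpropagated sqrt factor
    x_hat = torch.randn(N, C, S)  # stand-in for the normalized input

    # J^T v for gamma: contract v with x_hat over the spatial axis.
    JTv = torch.einsum("mncs,ncs->mnc", v, x_hat)
    # The same contraction as an explicit broadcast-multiply-and-sum.
    assert torch.allclose(JTv, (v * x_hat.unsqueeze(0)).sum(-1))

    # The [C, C] KFAC factor: outer products summed over m and n.
    kfac_gamma = torch.einsum("mnc,mnd->cd", JTv, JTv)
    assert kfac_gamma.shape == (C, C)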
diff --git a/test/test_kfac_bn.py b/test/test_kfac_bn.py
new file mode 100644
index 000000000..4f29a71cd
--- /dev/null
+++ b/test/test_kfac_bn.py
@@ -0,0 +1,65 @@
+import torch
+from matplotlib import pyplot as plt
+from torch.nn import BatchNorm1d, CrossEntropyLoss, Flatten, Linear, Sequential
+
+from backpack import backpack, extend
+from backpack.extensions import KFAC, SqrtGGNExact
+from backpack.utils.examples import load_one_batch_mnist
+
+
+def visualize_hessian(H, param_names, param_length, fig_path, vmin=None, vmax=None):
+    """Plot a Hessian approximation and mark per-parameter blocks.
+
+    Args:
+        H (torch.Tensor): Hessian matrix ([M, M])
+        param_names (List[str]): list of parameter names
+        param_length (List[int]): list of parameter lengths
+        fig_path (str): path to save the figure
+
+    Returns:
+        (torch.Tensor, torch.Tensor): min and max entry of H
+    """
+    plt.figure(figsize=(10, 10))
+    plt.imshow(H.cpu().numpy(), vmin=vmin, vmax=vmax, origin="upper")
+    acc = -0.5
+    all_ = H.shape[0]
+    for length in param_length:
+        # draw boundaries between consecutive parameter blocks
+        plt.plot([-0.5, all_], [acc, acc], "b-", linewidth=2)
+        plt.plot([acc, acc], [-0.5, all_], "b-", linewidth=2)
+        acc += length
+    plt.xlim([-0.5, all_ - 0.5])
+    plt.ylim([all_ - 0.5, -0.5])
+    plt.colorbar()
+    plt.savefig(fig_path, bbox_inches="tight")
+    return H.min(), H.max()
+
+
+X, y = load_one_batch_mnist(batch_size=512)
+model = Sequential(Flatten(), Linear(784, 3), BatchNorm1d(3), Linear(3, 10))
+lossfunc = CrossEntropyLoss()
+model = extend(model.eval())  # KFAC requires BatchNorm in evaluation mode
+lossfunc = extend(lossfunc)
+
+loss = lossfunc(model(X), y)
+with backpack(KFAC(mc_samples=1000), SqrtGGNExact()):
+    loss.backward()
+
+# Compare each parameter's KFAC approximation against its exact GGN block.
+for name, param in model.named_parameters():
+    GGN_VT = param.sqrt_ggn_exact.reshape(-1, param.numel())
+    GGN = GGN_VT.t() @ GGN_VT
+    KFAC_ = (torch.kron(param.kfac[0], param.kfac[1])
+             if len(param.kfac) == 2 else param.kfac[0])
+    visualize_hessian(GGN, [name], [param.numel()], f"./{name}_GGN.png")
+    visualize_hessian(KFAC_, [name], [param.numel()], f"./{name}_KFAC.png")
+    print(name, torch.norm(GGN - KFAC_, 2).item())
+
+# Check that train mode raises an error.
+model = extend(model.train())
+loss = lossfunc(model(X), y)
+try:
+    with backpack(KFAC(mc_samples=1000), SqrtGGNExact()):
+        loss.backward()
+except NotImplementedError:
+    print("PASS: NotImplementedError is raised when the model is in train mode.")
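Note on the comparison in test_kfac_bn.py: for parameters with two Kronecker factors (the Linear weights), torch.kron(param.kfac[0], param.kfac[1]) materializes the approximated GGN block, while single-factor parameters (the biases and the BatchNorm parameters) are compared directly. A small sketch of the kron reconstruction, with illustrative factor sizes matching the Linear(3, 10) weight and assuming the factor ordering used in the script (output-side factor first):

    import torch

    A = torch.randn(10, 10)  # output-side factor (illustrative)
    B = torch.randn(3, 3)    # input-side factor (illustrative)
    full = torch.kron(A, B)  # [30, 30], matching the flattened weight
    assert full.shape == (30, 30)
    # kron builds the block matrix [A[i, j] * B]; check the top-left block.
    assert torch.allclose(full[:3, :3], A[0, 0] * B)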