
[ADD] K-FAC for BatchNormNd (#259) #260

Open. pyun-ram wants to merge 3 commits into master

Conversation

@pyun-ram commented Sep 3, 2022

Extension to support BatchNormNd (eval) K-FAC

Resolves f-dangel/issues/259

Auxiliary:

  • The kfac quantity contains only one element, and represents the GGN
    approximation.
  • It only supports the evaluation mode.
  • A test script (test_kfac_bn.py) checks these two properties.
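
For context, a minimal pure-PyTorch sanity check (not part of the PR) of why a single element suffices: in eval mode, BatchNorm applies a fixed per-channel affine map built from the running statistics, so there is no batch coupling and no Kronecker structure to exploit; the GGN approximation can be stored directly as one dense block.

```python
import torch
from torch.nn import BatchNorm1d

torch.manual_seed(0)
bn = BatchNorm1d(3).eval()
# give the running statistics non-trivial values
bn.running_mean = torch.randn(3)
bn.running_var = torch.rand(3) + 0.5

x = torch.randn(8, 3)
# In eval mode, BN reduces to a fixed per-channel affine map:
# y = scale * (x - running_mean) + bias
scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
manual = scale * (x - bn.running_mean) + bn.bias
assert torch.allclose(bn(x), manual, atol=1e-6)
```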

Signed-off-by: pyun [email protected]

@f-dangel (Owner) commented Sep 5, 2022

Hi, thanks a lot for submitting this PR!

I summarized the next steps to merge this here:

first round: boring parts, i.e. fixing CI:

  • Fix import sorting. Install isort, then run make isort in the repository root
  • Fix code formatting. Install black, then run make black in the repository root
  • Fix linter. Install flake8, then run make flake8 in the repository root and fix the issues.

second round: Integrate the tests into BackPACK's test suite:

  • Add a test setting to the KFAC tests. You can use initialize_batch_norm_eval to set up the BN layer.
    You can then run the test with pytest -vx test -k 'BatchNorm and kfac' from the repository root
  • Add a test setting with BN in train mode to LOCAL_NOT_SUPPORTED_SETTINGS for KFAC.
  • Remove test_kfac_bn.py (requires matplotlib as dependency). Please feel free to post the example as a comment in this PR as it provides a nice visualization
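
A hypothetical sketch of the helper mentioned above (the exact signature of initialize_batch_norm_eval in BackPACK's test suite may differ; this only illustrates the idea of filling the affine parameters and running statistics with non-default values before switching to eval mode):

```python
import torch
from torch.nn import BatchNorm1d


def initialize_batch_norm_eval(bn):
    """Randomize parameters and running statistics, then switch to eval mode.

    Hypothetical stand-in for the test-suite helper: with default values
    (weight=1, bias=0, mean=0, var=1) eval-mode BN would be the identity,
    which makes for a weak test case.
    """
    with torch.no_grad():
        bn.weight.copy_(torch.rand_like(bn.weight))
        bn.bias.copy_(torch.rand_like(bn.bias))
    bn.running_mean = torch.rand_like(bn.running_mean)
    bn.running_var = torch.rand_like(bn.running_var)
    return bn.eval()


bn = initialize_batch_norm_eval(BatchNorm1d(3))
assert not bn.training  # eval mode: the K-FAC extension is applicable
```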

Auxiliary:

- [FIX] Fix import sorting, code formatting, and linter.
- [ADD] Add a test setting with BN in eval mode.
- [ADD] Add a test setting with BN in train mode (not supported setting).
- [FIX] Remove the test_kfac_bn.py.
@pyun-ram (Author) commented Sep 8, 2022

A visualization example to compare this feature with GGN:

import torch
from matplotlib import pyplot as plt
from torch.nn import BatchNorm1d, CrossEntropyLoss, Flatten, Linear, Sequential

from backpack import backpack, extend
from backpack.extensions import KFAC, SqrtGGNExact
from backpack.utils.examples import load_one_batch_mnist


def visualize_hessian(H, param_names, param_length, fig_path, vmin=None, vmax=None):
    """
    Args:
        H(torch.Tensor): Hessian matrix ([M, M])
        param_names(List[str]): list of param names
        param_length(List[int]): list of param lengths
        fig_path(str): path to save the figure

    Returns:
        H_min(float): min of H
        H_max(float): max of H
    """
    plt.figure(figsize=(10, 10))
    plt.imshow(H.cpu().numpy(), vmin=vmin, vmax=vmax, origin="upper")
    acc = -0.5
    all_ = H.shape[0]
    for name, l in zip(param_names, param_length):
        plt.plot([0 - 0.5, all_], [acc, acc], "b-", linewidth=2)
        plt.plot([acc, acc], [0 - 0.5, all_], "b-", linewidth=2)
        acc += l
    plt.xlim([-0.5, all_ - 0.5])
    plt.ylim([all_ - 0.5, -0.5])
    plt.colorbar()
    plt.savefig(fig_path, bbox_inches="tight")
    return H.min(), H.max()


X, y = load_one_batch_mnist(batch_size=512)
model = Sequential(Flatten(), Linear(784, 3), BatchNorm1d(3), Linear(3, 10))
lossfunc = CrossEntropyLoss()
model = extend(model.eval())
lossfunc = extend(lossfunc)

loss = lossfunc(model(X), y)
with backpack(KFAC(mc_samples=1000), SqrtGGNExact()):
    loss.backward()

for name, param in model.named_parameters():
    GGN_VT = param.sqrt_ggn_exact.reshape(-1, param.numel())
    GGN = GGN_VT.t() @ GGN_VT
    KFAC_ = (
        torch.kron(param.kfac[0], param.kfac[1])
        if len(param.kfac) == 2
        else param.kfac[0]
    )
    visualize_hessian(GGN, [name], [param.numel()], f"./{name}_GGN.png")
    visualize_hessian(KFAC_, [name], [param.numel()], f"./{name}_KFAC.png")
    print(name, torch.norm(GGN - KFAC_, 2).item())

# Check handling of the train-mode situation
model = extend(model.train())
loss = lossfunc(model(X), y)
try:
    with backpack(KFAC(mc_samples=1000), SqrtGGNExact()):
        loss.backward()
except NotImplementedError:
    print("PASS. It raises NotImplementedError when model is in the training mode.")

Auxiliary:

- Make BaseDerivatives, BaseParameterDerivatives, BaseLossDerivatives
  not abstract base classes, since they have no abstract methods.
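
The reasoning behind that change can be shown in plain Python: inheriting from abc.ABC only blocks instantiation when at least one @abstractmethod is defined, so for classes without abstract methods the ABC base adds nothing (Base is a hypothetical stand-in for the derivatives classes):

```python
from abc import ABC


class Base(ABC):
    # no @abstractmethod members, mirroring BaseDerivatives & friends
    def helper(self):
        return "ok"


# ABC without abstract methods does NOT prevent instantiation ...
assert Base().helper() == "ok"


# ... so dropping the ABC base changes nothing observable:
class PlainBase:
    def helper(self):
        return "ok"


assert PlainBase().helper() == "ok"
```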
Successfully merging this pull request may close these issues.

KFAC support in BatchNorm (eval mode)