
Missing implementation of supported layers for DiagHessian and BatchDiagHessian #316

Open

hlzl opened this issue Sep 5, 2023 · 0 comments
hlzl commented Sep 5, 2023

Multiple layers that are specified as supported for second-order derivatives do not actually work when trying to compute the Hessian diagonal with backpack-for-pytorch<=1.6.0.

So far, I've run into this problem with the following layers:

  • backpack.custom_module.branching.ScaleModule, torch.nn.Identity
  • torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d
  • backpack.custom_module.branching.SumModule

This can be tested with a script such as the following:

import torch
from backpack import backpack, extend
from backpack.extensions import DiagHessian, BatchDiagHessian
from backpack.custom_module.branching import Parallel, SumModule

# Model containing the layers listed above, all of which are nominally supported.
model = extend(
    torch.nn.Sequential(
        torch.nn.Conv2d(3, 16, kernel_size=(3, 3)),
        Parallel(
            torch.nn.Identity(), torch.nn.BatchNorm2d(16), merge_module=SumModule()
        ),
        torch.nn.AdaptiveAvgPool2d(output_size=1),
        torch.nn.Flatten(),
        torch.nn.Linear(16, 2),
    ).cuda()
)
criterion = extend(torch.nn.CrossEntropyLoss())

batch = torch.randn((2, 3, 8, 8)).cuda()
target = torch.tensor([[1.0, 0.0], [0.0, 1.0]]).cuda()

model.eval()
model.zero_grad()
loss = criterion(model(batch), target)

# Fails here for the layers listed above.
with backpack(DiagHessian(), BatchDiagHessian()):
    loss.backward()

# Collect the Hessian diagonal (per parameter / per sample).
hessian_diag = torch.cat(
    [p.diag_h.view(-1) for p in model.parameters()], dim=0
)
hessian_diag_batch = torch.cat(
    [p.diag_h_batch.view(batch.shape[0], -1) for p in model.parameters()], dim=1
)
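
As a side note (not part of the original report), a rough diagnostic variant of the backward step above is to catch the error BackPACK raises when a module has no registered extension, which names the offending layer, and then check which parameters actually received the diag_h / diag_h_batch quantities; the exact exception type and message are assumptions and may differ between BackPACK versions:

try:
    with backpack(DiagHessian(), BatchDiagHessian()):
        loss.backward()
except NotImplementedError as err:
    # BackPACK reports the module class without a registered extension here;
    # the exception type is an assumption and may vary between versions.
    print(f"Second-order extension failed: {err}")

for name, p in model.named_parameters():
    # Only parameters of fully supported layers will carry both quantities.
    print(name, hasattr(p, "diag_h"), hasattr(p, "diag_h_batch"))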

I'm guessing that these require independent fixes, but I think it is a good idea to collect all layers with missing support in this issue.
