
KFAC support in BatchNorm (eval mode) #259

Open
pyun-ram opened this issue Sep 1, 2022 · 16 comments · May be fixed by #260
Labels: good first issue (Good for newcomers)

pyun-ram commented Sep 1, 2022

Hi,

Thanks for the repo! This is really nice work.
I am planning to calculate the KFAC with BackPACK, but it raises the following error:

NotImplementedError: Extension saving to kfac does not have an extension for Module <class 'torch.nn.modules.batchnorm.BatchNorm2d'>

My network is as follows:

model = nn.Sequential(
            nn.Conv2d(1, 8, 3, stride=3),
            nn.BatchNorm2d(8),
            nn.ReLU(),
            nn.Conv2d(8, 4, 3, stride=3),
            nn.BatchNorm2d(4),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(36, 10))
loss_func = nn.CrossEntropyLoss()

When calculating the KFAC with:

    model_ = extend(model.eval())
    logits = model_(X)
    loss = extend(loss_func)(logits, Y)
    with backpack(KFAC(mc_samples=1000)):
        loss.backward()

It raises the NotImplementedError above. I am wondering whether calculating KFAC in a network with BN layers in the middle is supported by BackPACK? It seems like it should be, since BackPACK works successfully on ResNets.

Thanks

f-dangel (Owner) commented Sep 1, 2022

Hi, thanks for your question!

Just to make sure I'm getting it right: do you want to compute KFAC only for the Conv2d and Linear layers in your network, or also for the parameters of the BatchNorm2d layers? (For the latter, I'm not sure KFAC is even defined.)

pyun-ram (Author) commented Sep 1, 2022

Indeed, I want to compute the KFAC for the parameters of the Conv2d, Linear, and BatchNorm2d layers.
Is it possible to achieve this?

f-dangel (Owner) commented Sep 1, 2022

BackPACK can compute KFAC for Linear and Conv2d layers, but not for BatchNorm2d. I don't know how the KFAC papers deal with batch normalization. Do you know? If so, one could implement this missing feature.

Sadly, there is no easy way to tell BackPACK to ignore the parameters of batch norm layers, because it tries to compute its quantities on all parameters that have requires_grad=True.

If you want to get KFAC for the supported layers, you will have to set p.requires_grad=False for the BN parameters. But then you also won't get their gradient.
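
A minimal sketch of that freeze (assuming model is the network from your first post; note that this only switches off gradient computation for the BN parameters and, on its own, may not silence the missing-extension error, since second-order extensions still need to backpropagate through the BN layer):

import torch.nn as nn

# Freeze all batch norm parameters so no BackPACK quantities are requested for them.
# As a side effect, their .grad will no longer be populated.
for module in model.modules():
    if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        for p in module.parameters():
            p.requires_grad = False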

pyun-ram (Author) commented Sep 2, 2022

Thanks for the prompt reply! Yep, I have not seen any paper discussing the KFAC calculation for BN either...

To get the KFAC for the supported layers (Linear and Conv2d), I found a way to bypass the NotImplementedError.

# Extend hbp/__init__.py by
class HBP(SecondOrderBackpropExtension):
    def __init__(
        self,
        curv_type,
        loss_hessian_strategy,
        backprop_strategy,
        ea_strategy,
        savefield="hbp",
    ):
        ...
        super().__init__(
            savefield=savefield,
            fail_mode="ERROR",
            module_exts={
                ...
                Identity: custom_module.HBPScaleModule(),
                BatchNorm2d: batchnorm_nd.HBPBatchNormNd(),
            },
        )

# The HBPBatchNormNd is defined as
from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives
from backpack.extensions.secondorder.hbp.hbpbase import HBPBaseModule

class HBPBatchNormNd(HBPBaseModule):
    def __init__(self):
        super().__init__(BatchNormNdDerivatives(), params=None)

With this modification, it works without raising the error. I am not quite sure whether this is the right approach. Do you have any advice?

f-dangel (Owner) commented Sep 2, 2022

Hi,

that workaround looks good! Indeed, this will ignore the BN parameters, while keeping BackPACK's backpropagation through the layer for KFAC intact.

pyun-ram (Author) commented Sep 2, 2022

That's a great relief! :)

pyun-ram closed this as completed Sep 2, 2022

f-dangel (Owner) commented Sep 2, 2022

One way to get started on this would be to add support for KFAC in BatchNorm in evaluation mode.

In the following, I will outline what needs to be done (this may not be 100% technically accurate).

Pull requests welcome.


Let's assume a BatchNorm1d layer that takes an input X of shape [N, C] and maps it to an output Z of shape [N, C]. The parameters γ and β are both of shape [C]. The forward pass (in evaluation mode) is

Z[n, :] = γ ⊙ X[n, :] + β        n = 1, ... , N

(where ⊙ denotes elementwise multiplication). This looks a bit like a Linear layer with weight W = diag(γ) and bias b = β.
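
As a quick sanity check of this analogy (a sketch that assumes zero running mean, unit running variance, and eps=0, so that the eval-mode normalization reduces to the identity):

import torch
import torch.nn as nn

N, C = 4, 3
bn = nn.BatchNorm1d(C, eps=0.0).eval()  # eval mode: uses the running statistics
bn.weight.data = torch.rand(C)          # gamma
bn.bias.data = torch.rand(C)            # beta
# default running statistics are mean 0 and variance 1, so the normalization is the identity

X = torch.rand(N, C)
assert torch.allclose(bn(X), bn.weight * X + bn.bias)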

We don't really need a Kronecker factorization here, because the curvature blocks for γ and β are both of shape [C, C]. So instead we compute the MC-sampled Fisher/GGN block:

  • Computing the KFAC for β is like computing the MC-approximated Fisher block for β.

    backpropagated_grads = ... # from BackPACK, has shape [M, N, C] where M denotes the number of MC samples
    
    v = backpropagated_grads
    JTv = v # apply transpose Jacobian (identity in this case)
    
    # square the result to get the GGN block
    kfac_beta = einsum("mnc,mnd->cd", JTv, JTv) # [C, C]
    
    return [kfac_beta] # The KFAC extension returns lists with Kronecker factors
  • Computing the KFAC for γ is like computing the MC-approximated Fisher block for γ.

    backpropagated_grads = ... # from BackPACK, has shape [M, N, C] where M denotes the number of MC samples
    X = module.input0
    
    v = backpropagated_grads
    JTv = einsum("mnc,nc->mnc", v, X) # apply transpose Jacobian
    
    # square the result to get the GGN block
    kfac_gamma = einsum("mnc,mnd->cd", JTv, JTv) # [C, C]
    
    return [kfac_gamma] # The KFAC extension returns lists with Kronecker factors
  • One can test this by setting N=1 and checking that kfac_gamma → GGN(gamma) and kfac_beta → GGN(beta) as the number of MC samples grows (M → ∞).

  • To generalize this to BatchNormNd, simply replace "mnc" by "mnc...", "mnd" by "mnd...", and "nc" by "nc..." in the above einsums.

  • There should be an error message if the module is not in evaluation mode. (A standalone sketch of these computations follows this list.)
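
Putting the two computations together, here is a standalone sketch (not wired into BackPACK's extension machinery; kfac_batchnorm1d_eval and its argument names are illustrative, with backpropagated_grads the MC-sampled Hessian square root of shape [M, N, C] and module_input the layer input of shape [N, C]):

from torch import Tensor, einsum

def kfac_batchnorm1d_eval(backpropagated_grads: Tensor, module_input: Tensor):
    """Sketch: MC-sampled Fisher/GGN blocks for BatchNorm1d parameters in eval mode."""
    v = backpropagated_grads

    # beta: the transpose Jacobian w.r.t. beta is the identity
    JTv_beta = v
    kfac_beta = einsum("mnc,mnd->cd", JTv_beta, JTv_beta)  # [C, C]

    # gamma: the transpose Jacobian multiplies elementwise by the layer input
    JTv_gamma = einsum("mnc,nc->mnc", v, module_input)
    kfac_gamma = einsum("mnc,mnd->cd", JTv_gamma, JTv_gamma)  # [C, C]

    # the KFAC extension returns singleton lists of Kronecker factors per parameter
    return [kfac_gamma], [kfac_beta]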

f-dangel reopened this Sep 2, 2022
f-dangel added the good first issue label Sep 2, 2022
pyun-ram added a commit to pyun-ram/backpack that referenced this issue Sep 3, 2022
Extension to support BatchNormNd (eval) K-FAC

Resolves f-dangel/issues/259

Auxiliary:

- The kfac quantity contains only one element, and represents the GGN
  approximation.
- It only supports the evaluation mode.
- A test script (test_kfac_bn.py) checks these two properties.

Signed-off-by: pyun <[email protected]>
pyun-ram linked a pull request Sep 3, 2022 that will close this issue

pyun-ram (Author) commented Sep 3, 2022

Thanks for your guidance! :)
A pull request has been raised. A simple test case is also added to verify the result and the eval-mode check.
One more question: why is it not necessary to divide kfac_gamma by JTv.shape[0], i.e. the number of MC samples, when calculating kfac_gamma?

kfac_gamma = einsum("mnc,mnd->cd", JTv, JTv) # [C, C]

f-dangel (Owner) commented Sep 24, 2022

Hi,

thanks for the PR; apologies, you will have to be patient with my review.

Regarding your question: Good point! The factor 1 / sqrt(M), where M is the number of MC samples, is inserted by the loss function, which creates the MC-approximated Hessian square root that is then backpropagated through all layers. Squaring that yields the desired 1 / M.

For CrossEntropyLoss this happens here in the code (M denotes the number of MC samples).
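
A small numerical illustration of this argument (purely illustrative; s stands for the unscaled MC square-root factors, not for a BackPACK-internal variable):

import torch

M, N, C = 1000, 8, 3
s = torch.randn(M, N, C, dtype=torch.float64)  # unscaled MC square-root factors

S = s / M**0.5  # the loss function inserts the 1/sqrt(M) factor here

# contracting ("squaring") the scaled factors already contains the 1/M average
assert torch.allclose(torch.einsum("mnc,mnd->cd", S, S),
                      torch.einsum("mnc,mnd->cd", s, s) / M)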

Best,
Felix

f-dangel changed the title from "KFAC support in BatchNorm" to "KFAC support in BatchNorm (eval mode)" Sep 24, 2022

fredguth commented Jun 5, 2024

Is there a way to make BackPACK ignore the modules it does not support?
I want to use it with models I did not implement myself (timm models, for example).

f-dangel (Owner) commented Jun 6, 2024

Hi,

are you asking this question w.r.t. KFAC?
If you want to use a first-order extension, you can simply extend the layers that are supported by BackPACK.
If you want to use a second-order extension, all layers must be supported by BackPACK, as otherwise it cannot backpropagate the additional information through the compute graph.
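
For the first-order case, a minimal sketch of extending only the supported layers (BatchGrad serves as the example extension; model, X, and y are placeholders for your own setup):

import torch.nn as nn
from backpack import backpack, extend
from backpack.extensions import BatchGrad

# extend only the layers BackPACK supports and leave e.g. BatchNorm untouched
for module in model.modules():
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        extend(module)
loss_func = extend(nn.CrossEntropyLoss())

with backpack(BatchGrad()):
    loss_func(model(X), y).backward()

# parameters of the extended layers now carry a .grad_batch attribute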

Best,
Felix

@fredguth

I mean a second-order extension. I just want an estimate and I don't care if the value is not precise; I just want to track how it changes during training. Can I simply remove the module if it does not change the dimensions?

@f-dangel (Owner)

Hey,

not sure if I'm following what exactly you want to do. If you remove the BN layers and all remaining layers are supported by BackPACK, you can use second-order extensions. But your network will then behave differently, because you eliminated the BN layers.

Best,
Felix

fredguth commented Jun 13, 2024

I only use BackPACK's second-order extensions to measure the Fisher information of the weights during training, but I do not use it for the training itself. Every n steps, before the next training step, I use the Hessian to measure the information, save the result, and clear the gradients for the next step. Not measuring the information in the batch norm layers is not a problem for me.

@fKunstner (Collaborator)

If I got your setup right, that will still be difficult without implementing the batchnorm operation. There is no option to disable the extension on the batchnorm parameters only, because BackPACK still needs to backpropagate the second-order information through the batchnorm layer to compute the information for the parameters of the earlier layers.

Here's a workaround that could work without having to code the batchnorm extension.

Say we start with the network

net1 = Sequential(
  Linear(1,1),
  BatchNorm1d(1),
  Linear(1,1),
)

We can make a second network that does the same operation as net1 (if net1 is in eval mode) using only Linear layers,

net2 = Sequential(
  Linear(1,1), 
  Linear(1,1), 
  Linear(1,1), 
  Linear(1,1),
)

To make them the same, we need to map the weights from net1 to net2.
For the linear layers, we just copy the data

net2[0].weight.data = net1[0].weight.data
net2[0].bias.data = net1[0].bias.data
net2[3].weight.data = net1[2].weight.data
net2[3].bias.data = net1[2].bias.data

And we should be able to implement the batchnorm operation with 2 linear layers by remapping them as follows (needs a double check)

# Implement the normalization
# x -> (x - running_mean) / sqrt(running_var + eps)
#    = (1 / sqrt(running_var + eps)) * x - running_mean / sqrt(running_var + eps)

bnlayer = net1[1]
scale = 1 / torch.sqrt(bnlayer._buffers["running_var"].data + bnlayer.eps)

# Linear weights are matrices, so put the per-feature scale on the diagonal;
# note the minus sign on the bias term.
net2[1].weight.data = torch.diag(scale)
net2[1].bias.data = -bnlayer._buffers["running_mean"].data * scale

# gamma and beta of the batch norm layer
net2[2].weight.data = torch.diag(bnlayer._parameters["weight"].data)
net2[2].bias.data = bnlayer._parameters["bias"].data
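
A quick sanity check of the remapping (assuming net1 is in eval mode, so that the batch norm uses its running statistics):

net1.eval()
net2.eval()

x = torch.randn(16, 1)
assert torch.allclose(net1(x), net2(x), atol=1e-6)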

Now we can extend net2 using BackPACK to compute KFAC.

Instead of doing

extend(net1)

...

with backpack(KFAC()):
    loss(net1).backward()

we do

extend(net2)

...

map_weights(net1, net2)
with backpack(KFAC()):
    loss(net2).backward()
inverse_map_grad_and_kfac(net2, net1)

where inverse_map_grad_and_kfac would map the .grad and .kfac attributes of the parameters of net2 to the right parameters of net1.

(Although operations on .data shouldn't be tracked by autodiff, maybe put all of this in a torch.no_grad() block to make sure gradients don't get propagated from one network to the other.)

@fredguth

Thanks a lot for your thoughts! I will study this... :-)
