KFAC support in BatchNorm (eval mode) #259
Hi, thanks for your question! Just to make sure I'm getting it right: you want to compute KFAC for the parameters of the BatchNorm2d layers as well, in addition to the supported Linear and Conv2d layers?
Indeed, I want to compute the KFAC for the parameters of the Conv2d, Linear, and BatchNorm2d layers.
BackPACK can compute KFAC for Linear and Conv2d layers, but not for BatchNorm2d. Sadly, there is no easy way to tell BackPACK to ignore the parameters of batch norm layers, because it tries to compute its quantities on all parameters that have `requires_grad=True`. If you want to get KFAC for the supported layers, you will have to set …
Thanks for the prompt reply! Yep, I have not seen any paper discussing the KFAC calculation for BN either... To get the KFAC for the supported layers (Linear and Conv2d), I found a way to bypass the `NotImplementedError`. Extend `hbp/__init__.py` by

```python
class HBP(SecondOrderBackpropExtension):
    def __init__(
        self,
        curv_type,
        loss_hessian_strategy,
        backprop_strategy,
        ea_strategy,
        savefield="hbp",
    ):
        ...
        super().__init__(
            savefield=savefield,
            fail_mode="ERROR",
            module_exts={
                ...
                Identity: custom_module.HBPScaleModule(),
                BatchNorm2d: batchnorm_nd.HBPBatchNormNd(),
            },
        )
```

The `HBPBatchNormNd` is defined as

```python
from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives
from backpack.extensions.secondorder.hbp.hbpbase import HBPBaseModule


class HBPBatchNormNd(HBPBaseModule):
    def __init__(self):
        # params=None: backpropagate through the layer, but compute no
        # quantities for its weight and bias
        super().__init__(BatchNormNdDerivatives(), params=None)
```

With this modification, it works without raising the error. I am not quite sure whether it is the right way to do it. Do you have any advice?
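For completeness, here is a minimal usage sketch for the workaround above (hedged: the toy architecture, shapes, and `mc_samples` value are made up for illustration). With `BatchNorm2d` mapped to a module extension that has `params=None`, its parameters simply receive no `.kfac` attribute, while the Linear and Conv2d parameters do:

```python
import torch
from torch import nn
from backpack import backpack, extend
from backpack.extensions import KFAC

model = extend(
    nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(8 * 16 * 16, 10),
    )
).eval()  # the discussion in this thread targets evaluation mode
loss_func = extend(nn.CrossEntropyLoss())

X, y = torch.randn(4, 3, 16, 16), torch.randint(0, 10, (4,))
loss = loss_func(model(X), y)

with backpack(KFAC(mc_samples=1)):
    loss.backward()

for name, p in model.named_parameters():
    # BatchNorm parameters are skipped, all other parameters carry Kronecker factors
    print(name, [f.shape for f in p.kfac] if hasattr(p, "kfac") else "ignored")
```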
Hi, that workaround looks good! Indeed, this will ignore the BN parameters, while keeping BackPACK's backpropagation through the layer for KFAC intact.
That's a great relief! :)
One way to get started on this would be to add support for KFAC in BatchNorm in evaluation mode. I will outline in the following what needs to be done (this may be technically not 100% accurate). Pull requests welcome. Let's assume a BatchNorm layer in evaluation mode with weight γ and bias β, each of shape [C] (where C is the number of channels). We don't really need a Kronecker factorization here, because the curvature blocks for γ and β are only of size [C, C] and can be stored as single matrices.
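To make the size argument concrete, here is a small self-contained sketch in plain PyTorch (no BackPACK internals; the toy shapes and the helper `loss_fn` are made up for illustration). Because a BatchNorm layer in evaluation mode is affine in γ and β and the downstream head here is linear, the Hessian of the loss with respect to (γ, β) coincides with the GGN, and each block is only of size [C, C]:

```python
import torch
from torch import nn

torch.manual_seed(0)
N, C, H, W = 3, 4, 5, 5

bn = nn.BatchNorm2d(C).eval()     # eval mode: normalization uses running statistics
head = nn.Linear(C * H * W, 2)    # small linear head so the loss is a scalar
loss_func = nn.CrossEntropyLoss()

X = torch.randn(N, C, H, W)
y = torch.randint(0, 2, (N,))

def loss_fn(gamma, beta):
    # re-implement BN in eval mode so we can differentiate w.r.t. gamma and beta
    mean = bn.running_mean.view(1, C, 1, 1)
    var = bn.running_var.view(1, C, 1, 1)
    x_hat = (X - mean) / (var + bn.eps).sqrt()
    out = gamma.view(1, C, 1, 1) * x_hat + beta.view(1, C, 1, 1)
    logits = head(out.flatten(start_dim=1))
    return loss_func(logits, y)

# The full curvature blocks w.r.t. gamma and beta are only [C, C],
# so they can be stored directly, without a Kronecker factorization.
hess = torch.autograd.functional.hessian(loss_fn, (bn.weight.detach(), bn.bias.detach()))
print(hess[0][0].shape, hess[1][1].shape)  # torch.Size([4, 4]) torch.Size([4, 4])
```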
Extension to support BatchNormNd (eval) K-FAC

Resolves f-dangel/issues/259

Auxiliary:
- The kfac quantity contains only one element and represents the GGN approximation.
- It only supports evaluation mode.
- A test script (test_kfac_bn.py) checks these two properties.

Signed-off-by: pyun <[email protected]>
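A rough idea of what such a test could look like (a hedged sketch, not the actual test_kfac_bn.py from the PR; it assumes the PR's KFAC support for BatchNormNd is installed, and the layer sizes are arbitrary):

```python
import pytest
import torch
from torch import nn
from backpack import backpack, extend
from backpack.extensions import KFAC


def make_problem(train_mode: bool):
    model = extend(nn.Sequential(nn.Linear(10, 8), nn.BatchNorm1d(8), nn.Linear(8, 3)))
    model.train(train_mode)
    loss_func = extend(nn.CrossEntropyLoss())
    X, y = torch.randn(5, 10), torch.randint(0, 3, (5,))
    return model, loss_func(model(X), y)


def test_bn_kfac_has_single_factor():
    model, loss = make_problem(train_mode=False)  # evaluation mode
    with backpack(KFAC(mc_samples=1)):
        loss.backward()
    bn = model[1]
    # property 1: the kfac quantity of BN parameters holds a single [C, C] factor
    assert len(bn.weight.kfac) == 1 and bn.weight.kfac[0].shape == (8, 8)
    assert len(bn.bias.kfac) == 1 and bn.bias.kfac[0].shape == (8, 8)


def test_bn_kfac_requires_eval_mode():
    model, loss = make_problem(train_mode=True)  # training mode
    # property 2: training mode is not supported (exact exception type may differ)
    with pytest.raises(Exception):
        with backpack(KFAC(mc_samples=1)):
            loss.backward()
```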
Thanks for your guidance! :)

```python
kfac_gamma = einsum("mnc,mnd->cd", JTv, JTv)  # [C, C]
```
Hi, thanks for the PR; apologies, you will have to be patient with my review. Regarding your question: good point! The factor … For …

Best,
Is there a way to make BackPACK ignore the modules it does not support?
Hi, are you asking this question w.r.t. KFAC? Best,
I mean the second-order extensions. I just want to have an estimate, and I don't care if the value is not precise; I just want to check how it changes during training. Can I simply remove the module if it does not change the dimensions?
Hey, not sure if I'm following what exactly you want to do. If you remove the BN layers and all remaining layers are supported by BackPACK, you can use second-order extensions. But then your network will also behave differently, because you eliminated the BN layers. Best,
I only use BackPACK's second-order extensions to measure the Fisher information of the weights during training, but I do not use them for the training itself. Every n steps, before I do the next training step, I use the Hessian to measure the information, save the result, and clear the gradients for the next step. It is not a problem for me to skip this information for the batch norm layers.
If I got your setup right, that will still be difficult without implementing the batchnorm operation. There is no option to disable the extension on the batchnorm parameters only, because BackPACK still needs to backpropagate the second-order information through the batchnorm layer to compute the information for the parameters of the earlier layers.

Here's a workaround that could work without having to code the batchnorm extension. Say we start with a network that contains a batch norm layer. We can make a second network that does the same operation, but with the batch norm layer replaced by Linear layers. To make them the same, we need to map the weights from the original network to the new one. Since batch norm in evaluation mode is just an affine, per-channel transformation, we should be able to implement the batchnorm operation with 2 linear layers by remapping the parameters as in the sketch below (needs a double check). Now we can extend the second network with BackPACK and run the second-order extension on it instead of the original one. (Although operations on …)
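Here is a minimal sketch of that remapping for the 1d case (the helper name `bn_eval_as_two_linears` and the sizes are made up; BatchNorm2d would additionally need a reshape or grouped 1x1 convolutions, since Linear acts on the last dimension). The first Linear carries the frozen normalization by the running statistics, the second carries γ as a diagonal weight and β as its bias:

```python
import torch
from torch import nn


def bn_eval_as_two_linears(bn: nn.BatchNorm1d) -> nn.Sequential:
    """Rebuild a BatchNorm1d layer (eval mode) as two Linear layers (sketch)."""
    C = bn.num_features
    std = (bn.running_var + bn.eps).sqrt()

    # 1st Linear: x_hat = (x - running_mean) / sqrt(running_var + eps)
    normalize = nn.Linear(C, C)
    normalize.weight.data = torch.diag(1.0 / std)
    normalize.bias.data = -bn.running_mean / std
    normalize.weight.requires_grad = False  # fixed statistics, not parameters
    normalize.bias.requires_grad = False

    # 2nd Linear: y = gamma * x_hat + beta
    affine = nn.Linear(C, C)
    affine.weight.data = torch.diag(bn.weight.data)
    affine.bias.data = bn.bias.data.clone()

    return nn.Sequential(normalize, affine)


# quick check that both give the same output in eval mode
bn = nn.BatchNorm1d(4).eval()
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 1.5)
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-1, 1)

x = torch.randn(3, 4)
print(torch.allclose(bn(x), bn_eval_as_two_linears(bn)(x), atol=1e-6))
```

Note that in the second Linear the full C×C weight matrix is trainable, not just its diagonal, so any curvature computed for it covers more parameters than the original γ; whether that matters depends on what you want to measure.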
Thanks a lot for your thoughts! I will study this... :-)
Hi,
Thanks for the repo! This is really nice work.
I am planning to calculate KFAC with BackPACK, but it raises the following error:
My network is as follows:
When calculating the KFAC with:
It raises a `NotImplementedError`. I am wondering whether calculating KFAC for a network with BN layers in the middle is supported by BackPACK? It seems like it should be supported, since it successfully works for ResNet.
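Since the original network definition and the traceback did not survive in this thread, here is a hedged sketch of the kind of setup that triggers the error (layer sizes and data are made up; at the time of this issue, BackPACK's KFAC raised a NotImplementedError when it encountered the BatchNorm2d parameters):

```python
import torch
from torch import nn
from backpack import backpack, extend
from backpack.extensions import KFAC

model = extend(
    nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1),
        nn.BatchNorm2d(8),          # unsupported by KFAC at the time of this issue
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(8 * 8 * 8, 10),
    )
)
loss_func = extend(nn.CrossEntropyLoss())

X, y = torch.randn(4, 3, 8, 8), torch.randint(0, 10, (4,))
loss = loss_func(model(X), y)

with backpack(KFAC(mc_samples=1)):
    loss.backward()                 # raises NotImplementedError for BatchNorm2d
```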
Thanks