[ADD] K-FAC for BatchNormNd (#259) #260

Open · wants to merge 3 commits into master
7 changes: 3 additions & 4 deletions backpack/core/derivatives/basederivatives.py
@@ -1,6 +1,5 @@
"""Base classes for more flexible Jacobians and second-order information."""
import warnings
-from abc import ABC
from typing import Callable, List, Tuple

from torch import Tensor
@@ -9,7 +8,7 @@
from backpack.core.derivatives import shape_check


-class BaseDerivatives(ABC):
+class BaseDerivatives:
"""First- and second-order partial derivatives of unparameterized module.

Note:
@@ -306,7 +305,7 @@ def reshape_like_output(cls, mat: Tensor, module: Module) -> Tensor:
return cls._reshape_like(mat, module.output.shape)


-class BaseParameterDerivatives(BaseDerivatives, ABC):
+class BaseParameterDerivatives(BaseDerivatives):
"""First- and second order partial derivatives of a module with parameters.

Assumptions (true for `nn.Linear`, `nn.Conv(Transpose)Nd`, `nn.BatchNormNd`):
@@ -435,7 +434,7 @@ def _weight_jac_mat_prod(
raise NotImplementedError


-class BaseLossDerivatives(BaseDerivatives, ABC):
+class BaseLossDerivatives(BaseDerivatives):
"""Second- order partial derivatives of loss functions."""

# TODO Add shape check
7 changes: 7 additions & 0 deletions backpack/extensions/secondorder/hbp/__init__.py
@@ -1,6 +1,9 @@
from torch import Tensor
from torch.nn import (
AvgPool2d,
BatchNorm1d,
BatchNorm2d,
BatchNorm3d,
Conv2d,
CrossEntropyLoss,
Dropout,
@@ -27,6 +30,7 @@

from . import (
activations,
batchnorm_nd,
conv2d,
custom_module,
dropout,
@@ -71,6 +75,9 @@ def __init__(
SumModule: custom_module.HBPSumModule(),
ScaleModule: custom_module.HBPScaleModule(),
Identity: custom_module.HBPScaleModule(),
BatchNorm1d: batchnorm_nd.HBPBatchNormNd(),
BatchNorm2d: batchnorm_nd.HBPBatchNormNd(),
BatchNorm3d: batchnorm_nd.HBPBatchNormNd(),
},
)

36 changes: 36 additions & 0 deletions backpack/extensions/secondorder/hbp/batchnorm_nd.py
@@ -0,0 +1,36 @@
from typing import Tuple, Union

from torch import Tensor, einsum
from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d

from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives
from backpack.extensions.backprop_extension import BackpropExtension
from backpack.extensions.secondorder.hbp.hbpbase import HBPBaseModule
from backpack.utils.errors import batch_norm_raise_error_if_train


class HBPBatchNormNd(HBPBaseModule):
def __init__(self):
super().__init__(BatchNormNdDerivatives(), params=["weight", "bias"])

def weight(self, ext, module, grad_inp, grad_out, backproped):
x_hat, _ = self.derivatives._get_normalized_input_and_var(module)
v = backproped
JTv = einsum("mnc...,nc...->mnc", v, x_hat)
kfac_gamma = einsum("mnc...,mnd...->cd", JTv, JTv)
return [kfac_gamma]

def bias(self, ext, module, grad_inp, grad_out, backproped):
v = backproped
JTv = v
kfac_beta = einsum("mnc...,mnd...->cd", JTv, JTv)
return [kfac_beta]

def check_hyperparameters_module_extension(
self,
ext: BackpropExtension,
module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d],
g_inp: Tuple[Tensor],
g_out: Tuple[Tensor],
) -> None: # noqa: D102
batch_norm_raise_error_if_train(module)
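For context, here is a minimal usage sketch (not part of this PR) of how the new K-FAC support for BatchNorm layers would be exercised through BackPACK's public API. Since the PR rejects BatchNorm in training mode, the extended model is switched to eval mode before the backward pass; the architecture and tensor shapes below are illustrative only.

```python
# Hedged sketch: KFAC factors for a BatchNorm1d layer via BackPACK (assumes the
# extension dispatch added in this PR; architecture and shapes are made up).
import torch
from torch.nn import BatchNorm1d, Linear, MSELoss, ReLU, Sequential

from backpack import backpack, extend
from backpack.extensions import KFAC

model = extend(Sequential(Linear(7, 3), BatchNorm1d(3), ReLU(), Linear(3, 1)))
model.eval()  # train-mode BatchNorm is rejected by check_hyperparameters_module_extension
loss_func = extend(MSELoss())

X, y = torch.rand(5, 7), torch.rand(5, 1)
loss = loss_func(model(X), y)

with backpack(KFAC()):
    loss.backward()

for name, param in model.named_parameters():
    # For the BatchNorm weight and bias, this PR produces a single [c, c] factor
    # (c = num_features), stored like other KFAC results in param.kfac.
    print(name, [factor.shape for factor in param.kfac])
```

The test settings below use the `initialize_batch_norm_eval` helper from `test.utils.evaluation_mode` instead; a plain `.eval()` call is used here only to keep the sketch self-contained.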
76 changes: 74 additions & 2 deletions test/extensions/secondorder/hbp/kfac_settings.py
@@ -5,9 +5,11 @@
GROUP_CONV_SETTINGS,
LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS,
)
from test.utils.evaluation_mode import initialize_batch_norm_eval

from torch import rand
from torch.nn import (
BatchNorm1d,
CrossEntropyLoss,
Flatten,
Identity,
@@ -26,8 +28,6 @@
)
LOCAL_NOT_SUPPORTED_SETTINGS = []

-NOT_SUPPORTED_SETTINGS = SHARED_NOT_SUPPORTED_SETTINGS + LOCAL_NOT_SUPPORTED_SETTINGS

BATCH_SIZE_1_SETTINGS = [
{
"input_fn": lambda: rand(1, 7),
@@ -75,3 +75,75 @@
"id_prefix": "branching-scalar",
},
]

BATCH_SIZE_1_SETTINGS += [
{
"input_fn": lambda: rand(1, 7),
"module_fn": lambda: Sequential(
Linear(7, 3),
initialize_batch_norm_eval(BatchNorm1d(3)),
ReLU(),
Flatten(start_dim=1, end_dim=-1),
Linear(3, 1),
),
"loss_function_fn": lambda: MSELoss(reduction="mean"),
"target_fn": lambda: regression_targets((1, 1)),
"id_prefix": "one-additional(bn)",
},
{
"input_fn": lambda: rand(3, 10),
"module_fn": lambda: Sequential(
Linear(10, 5),
initialize_batch_norm_eval(BatchNorm1d(5)),
ReLU(),
# skip connection
Parallel(
Identity(), Linear(5, 5), initialize_batch_norm_eval(BatchNorm1d(5))
),
# end of skip connection
Sigmoid(),
Linear(5, 4),
),
"loss_function_fn": lambda: CrossEntropyLoss(),
"target_fn": lambda: classification_targets((3,), 4),
"id_prefix": "branching-linear(bn)",
},
{
"input_fn": lambda: rand(3, 10),
"module_fn": lambda: Sequential(
Linear(10, 5),
initialize_batch_norm_eval(BatchNorm1d(5)),
ReLU(),
# skip connection
Parallel(
ScaleModule(weight=3.0),
Linear(5, 5),
initialize_batch_norm_eval(BatchNorm1d(5)),
),
# end of skip connection
Sigmoid(),
Linear(5, 4),
),
"loss_function_fn": lambda: CrossEntropyLoss(),
"target_fn": lambda: classification_targets((3,), 4),
"id_prefix": "branching-scalar(bn)",
},
]

LOCAL_NOT_SUPPORTED_SETTINGS += [
{
"input_fn": lambda: rand(3, 7),
"module_fn": lambda: Sequential(
Linear(7, 3),
initialize_batch_norm_eval(BatchNorm1d(3)).train(),
ReLU(),
Flatten(start_dim=1, end_dim=-1),
Linear(3, 1),
),
"loss_function_fn": lambda: MSELoss(reduction="mean"),
"target_fn": lambda: regression_targets((3, 1)),
"id_prefix": "one-additional(bn-train)",
},
]

NOT_SUPPORTED_SETTINGS = SHARED_NOT_SUPPORTED_SETTINGS + LOCAL_NOT_SUPPORTED_SETTINGS
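As a companion sketch (also not from the PR), the unsupported setting above corresponds to a BatchNorm layer left in training mode, which `check_hyperparameters_module_extension` routes to `batch_norm_raise_error_if_train`. The exact exception type is not visible in this diff, so the example catches a generic `Exception`.

```python
# Hedged sketch of the rejected case: KFAC over a train-mode BatchNorm layer.
# The concrete error type raised by batch_norm_raise_error_if_train is not
# shown in this PR, so a broad Exception is caught here (an assumption).
import torch
from torch.nn import BatchNorm1d, Linear, MSELoss, Sequential

from backpack import backpack, extend
from backpack.extensions import KFAC

model = extend(Sequential(Linear(7, 3), BatchNorm1d(3), Linear(3, 1)))  # stays in train mode
loss_func = extend(MSELoss())

loss = loss_func(model(torch.rand(3, 7)), torch.rand(3, 1))

try:
    with backpack(KFAC()):
        loss.backward()
except Exception as err:
    print(f"train-mode BatchNorm rejected: {err}")
```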